Shared helpers for the unified mixed-batch forward pass of decoder-only models.
The Llama, Qwen3-MoE, and future decoder families all share the same outer scaffolding for the unified forward pass: cu_seqlens construction, block-table stacking, final-token index lookup, and graph-cache keying. These are pure functions (no kernel calls, no model state), extracted here so that each family’s `unified_forward_internal` reads as “scaffolding + family-specific layer loop” rather than “scaffolding + 700 lines of cloned scaffolding”.
Per docs/decoder-unified-runner-abstraction.md. Phase 2A.
Functions
- `compute_cu_seqlens_q` - Cumulative q-token counts: `cu_seqlens_q[i+1] - cu_seqlens_q[i] = items[i].q_tokens.len()`. The varlen attention and paged-KV-write kernels read this to locate each sequence’s slice of the flat `[M_total, *]` tensor.
- `compute_final_indices` - For each item with `is_final_chunk = true`, return `(orig_index, global_token_index)`, where `global_token_index` is the position of that item’s LAST q-token in the flat `[M_total, hidden]` residual buffer. The final-norm and lm_head stages slice these rows out for sampling.
- `compute_max_kv_len` - Causal max over `(pos_offset + q_len)`, needed for the varlen attention kernel’s shared-memory sizing (it must fit the longest reachable `kv_pos` across all items in the batch).
- `compute_pos_offsets` - Per-item starting absolute KV position for the FIRST q-token in `items[i].q_tokens`. Zero for a fresh prefill; the prior `kv_len` for chunked-prefill continuations or decode steps. Returned as `u32` to match the device-side index buffers the varlen kernels read.
- `concat_q_tokens` - Flatten all items’ q-tokens into one concatenated `[M_total]` vec. The caller passes this to `embedding_lookup` so that the entire batch’s embeddings end up contiguous in the unified residual buffer.
- `stack_block_tables` - Pack per-(seq, layer-0) page indices into the dense `[num_seqs, max_blocks_per_seq]` layout that the varlen attention kernel reads. Only the first layer’s block table is used because, in ferrum’s paged-KV layout, every layer shares the same block-table list (the layer-specific data lives inside each KV pool; the table itself is per-sequence).
- `unified_graph_key` - Graph cache key for a unified mixed-batch forward pass. The high bit is set so these keys never collide with legacy decode/batched keys, which use the low 63 bits for `m_padded`/`SINGLE_ITEM = 0`.
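A minimal sketch of how `compute_cu_seqlens_q` and `compute_final_indices` relate, assuming a hypothetical `UnifiedItem` struct with only `q_tokens` and `is_final_chunk` fields (the real item type and signatures in this module may differ). The key observation is that an item’s last q-token sits at flat row `cu_seqlens_q[i + 1] - 1`.

```rust
/// Hypothetical batch item; field names follow the prose above,
/// but the actual struct used by the unified runner may differ.
pub struct UnifiedItem {
    pub q_tokens: Vec<u32>,
    pub is_final_chunk: bool,
}

/// Prefix sum of q-token counts, length `items.len() + 1`,
/// so that `cu[i + 1] - cu[i] == items[i].q_tokens.len()`.
pub fn compute_cu_seqlens_q(items: &[UnifiedItem]) -> Vec<u32> {
    let mut cu = Vec::with_capacity(items.len() + 1);
    let mut acc: u32 = 0;
    cu.push(acc);
    for item in items {
        acc += item.q_tokens.len() as u32;
        cu.push(acc);
    }
    cu
}

/// `(orig_index, global_token_index)` for every final-chunk item, where
/// `global_token_index` is the flat row of that item's last q-token.
/// Items with no q-tokens are skipped (they have no last token to sample).
pub fn compute_final_indices(items: &[UnifiedItem]) -> Vec<(usize, usize)> {
    let cu = compute_cu_seqlens_q(items);
    items
        .iter()
        .enumerate()
        .filter(|(_, item)| item.is_final_chunk && !item.q_tokens.is_empty())
        .map(|(i, _)| (i, cu[i + 1] as usize - 1))
        .collect()
}
```

For a batch of a 3-token prefill, a 1-token decode step, and a non-final 2-token chunk, this gives `cu_seqlens_q = [0, 3, 4, 6]` and final indices `[(0, 2), (1, 3)]`; the non-final chunk contributes no sampling row.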
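The position bookkeeping behind `compute_pos_offsets` and `compute_max_kv_len` can be sketched as follows, assuming a hypothetical `UnifiedItem` that records how many KV positions are already cached for the sequence (`prior_kv_len`, a name invented here). Because attention is causal, the last q-token of an item reaches KV position `pos_offset + q_len - 1`, so the kernel needs room for `pos_offset + q_len` entries.

```rust
/// Hypothetical batch item for this sketch only.
pub struct UnifiedItem {
    pub q_tokens: Vec<u32>,
    /// KV positions already cached: 0 for fresh prefill,
    /// the prior kv_len for chunked-prefill continuations and decode.
    pub prior_kv_len: usize,
}

/// Absolute KV position of each item's first q-token, as u32
/// to match a device-side index buffer.
pub fn compute_pos_offsets(items: &[UnifiedItem]) -> Vec<u32> {
    items
        .iter()
        .map(|item| {
            debug_assert!(item.prior_kv_len <= u32::MAX as usize);
            item.prior_kv_len as u32
        })
        .collect()
}

/// Longest KV length any item in the batch can reach (0 for an empty batch).
pub fn compute_max_kv_len(items: &[UnifiedItem]) -> usize {
    items
        .iter()
        .map(|item| item.prior_kv_len + item.q_tokens.len())
        .max()
        .unwrap_or(0)
}
```

With a fresh 5-token prefill, a decode step at `kv_len = 17`, and a 4-token chunk continuing from `kv_len = 8`, the offsets are `[0, 17, 8]` and the max KV length is 18, set by the decode step rather than the longest prefill.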
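A sketch of the two packing helpers, `concat_q_tokens` and `stack_block_tables`, assuming a hypothetical `UnifiedItem` that carries layer 0’s page indices in a `block_table` field. How short rows are padded is not specified above; this sketch fills them with a caller-supplied `pad` value, which is an assumption.

```rust
/// Hypothetical batch item for this sketch only.
pub struct UnifiedItem {
    pub q_tokens: Vec<u32>,
    /// Page indices from the first layer's block table (shared by all layers).
    pub block_table: Vec<u32>,
}

/// All q-tokens concatenated in batch order: one flat [M_total] vec.
pub fn concat_q_tokens(items: &[UnifiedItem]) -> Vec<u32> {
    items
        .iter()
        .flat_map(|item| item.q_tokens.iter().copied())
        .collect()
}

/// Row-major [num_seqs, max_blocks_per_seq] table plus max_blocks_per_seq.
/// Rows shorter than the longest table are filled with `pad` (assumed policy).
pub fn stack_block_tables(items: &[UnifiedItem], pad: u32) -> (Vec<u32>, usize) {
    let max_blocks = items
        .iter()
        .map(|item| item.block_table.len())
        .max()
        .unwrap_or(0);
    let mut out = vec![pad; items.len() * max_blocks];
    for (row, item) in items.iter().enumerate() {
        let start = row * max_blocks;
        out[start..start + item.block_table.len()].copy_from_slice(&item.block_table);
    }
    (out, max_blocks)
}
```

A dense row-major layout lets the kernel address sequence `s`, block `b` as `table[s * max_blocks_per_seq + b]` without per-sequence pointers, at the cost of padding the shorter rows.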
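The high-bit namespacing in `unified_graph_key` can be illustrated with a minimal sketch. Which fields the real key encodes in its low 63 bits is not stated above; packing the padded token count and the sequence count, as well as the `is_unified_key` helper, are hypothetical choices for illustration only.

```rust
/// Bit 63 marks unified keys; legacy keys keep it clear.
const UNIFIED_KEY_BIT: u64 = 1 << 63;

/// Hypothetical packing: low 32 bits = padded total q-token count,
/// bits 32..63 = number of sequences. The real key may hash other fields.
pub fn unified_graph_key(m_total_padded: u32, num_seqs: u32) -> u64 {
    assert!(num_seqs < (1 << 31), "num_seqs must fit in 31 bits");
    UNIFIED_KEY_BIT | ((num_seqs as u64) << 32) | m_total_padded as u64
}

/// Hypothetical helper: true if a cache key belongs to the unified namespace.
pub fn is_unified_key(key: u64) -> bool {
    key & UNIFIED_KEY_BIT != 0
}
```

Since every legacy key fits in the low 63 bits, no legacy key can equal any key with bit 63 set, so both kinds of graph can share one cache map.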