Module decoder_unified

Shared helpers for decoder-only unified mixed-batch forward.

The Llama / Qwen3-MoE / future decoder families all share the same outer scaffolding for unified forward: cu_seqlens construction, block-table stacking, final-token index lookup, graph-cache keying. These are pure functions — no kernel calls, no model state — extracted here so each family’s unified_forward_internal reads as “scaffolding + family-specific layer loop”, not “scaffolding + 700 lines of scaffolding clone”.

Per docs/decoder-unified-runner-abstraction.md. Phase 2A.

Functions

compute_cu_seqlens_q
Cumulative q-token counts: cu_seqlens_q[i+1] - cu_seqlens_q[i] = items[i].q_tokens.len(). The varlen attention + paged-KV-write kernels read this to find each sequence’s slice of the flat [M_total, *] tensor.
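A minimal sketch of the cu_seqlens_q invariant, assuming a hypothetical BatchItem type that exposes only q_tokens, u32 token ids, and a u32 output (the real element type is not stated here). The result has items.len() + 1 entries, starting at 0 and ending at M_total.

```rust
/// Hypothetical stand-in for one item in a unified mixed batch;
/// only the field this helper reads is modeled.
struct BatchItem {
    q_tokens: Vec<u32>,
}

/// Sketch: cu[0] = 0 and cu[i + 1] - cu[i] = items[i].q_tokens.len(),
/// so the last entry equals M_total (the flat q-token count).
fn compute_cu_seqlens_q(items: &[BatchItem]) -> Vec<u32> {
    let mut cu = Vec::with_capacity(items.len() + 1);
    let mut running: u32 = 0;
    cu.push(running);
    for item in items {
        running += item.q_tokens.len() as u32;
        cu.push(running);
    }
    cu
}

fn main() {
    // A 3-token prefill chunk followed by two single-token decode steps.
    let items = vec![
        BatchItem { q_tokens: vec![11, 12, 13] },
        BatchItem { q_tokens: vec![21] },
        BatchItem { q_tokens: vec![31] },
    ];
    let cu = compute_cu_seqlens_q(&items);
    println!("{:?}", cu); // [0, 3, 4, 5]
    // Item i owns rows cu[i]..cu[i + 1] of the flat [M_total, *] tensor.
}
```

A kernel thread block assigned to sequence i only needs cu[i] and cu[i + 1] to locate its slice, which is why a prefix sum is used instead of per-item lengths.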
compute_final_indices
For each is_final_chunk = true item, return (orig_index, global_token_index) where global_token_index is the position in the flat [M_total, hidden] residual buffer of that item’s LAST q-token. The final-norm + lm_head stages slice these rows out for sampling.
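The row arithmetic can be sketched as follows, again with a hypothetical BatchItem carrying q_tokens and is_final_chunk. The last q-token of item i sits at row cu_seqlens_q[i + 1] - 1; the sketch tracks that running offset directly and assumes a final chunk is never empty.

```rust
/// Hypothetical stand-in for one scheduled item; only the fields the
/// helper reads are modeled.
struct BatchItem {
    q_tokens: Vec<u32>,
    is_final_chunk: bool,
}

/// Returns (orig_index, global_token_index) for every final-chunk item,
/// where global_token_index is the row of that item's LAST q-token in
/// the flat [M_total, hidden] buffer.
fn compute_final_indices(items: &[BatchItem]) -> Vec<(usize, usize)> {
    let mut out = Vec::new();
    let mut row_start = 0usize; // first row of the current item
    for (orig_index, item) in items.iter().enumerate() {
        let q_len = item.q_tokens.len();
        // Assumption: a final chunk always carries at least one q-token.
        if item.is_final_chunk && q_len > 0 {
            out.push((orig_index, row_start + q_len - 1));
        }
        row_start += q_len;
    }
    out
}

fn main() {
    let items = vec![
        // Non-final prefill chunk: nothing to sample yet.
        BatchItem { q_tokens: vec![1, 2, 3, 4], is_final_chunk: false },
        // Decode step, treated as final in this sketch.
        BatchItem { q_tokens: vec![5], is_final_chunk: true },
        // Last chunk of a chunked prefill.
        BatchItem { q_tokens: vec![6, 7, 8], is_final_chunk: true },
    ];
    println!("{:?}", compute_final_indices(&items)); // [(1, 4), (2, 7)]
}
```

Only these rows go through final-norm and lm_head, so a batch dominated by long non-final prefill chunks pays the vocabulary projection for just a handful of tokens.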
compute_max_kv_len
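Maximum of (pos_offset + q_len) over all items. Under causal masking, an item’s last q-token attends to KV positions 0 through pos_offset + q_len - 1, so this is the longest KV length any item in the batch reaches. The varlen attention kernel needs it for shared-mem sizing, which must fit the longest reachable kv_pos across all items in the batch.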
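A minimal sketch of the reduction, assuming a hypothetical BatchItem that already carries pos_offset and q_len as u32. An empty batch is mapped to 0 here; how the real helper treats that case is not stated.

```rust
/// Hypothetical per-item view: starting KV position and q-token count.
struct BatchItem {
    pos_offset: u32,
    q_len: u32,
}

/// Longest KV length any item reaches after this step: the last q-token
/// of an item sits at pos_offset + q_len - 1 and attends causally to
/// every KV position up to and including it.
fn compute_max_kv_len(items: &[BatchItem]) -> u32 {
    items
        .iter()
        .map(|it| it.pos_offset + it.q_len)
        .max()
        .unwrap_or(0)
}

fn main() {
    let items = [
        BatchItem { pos_offset: 100, q_len: 1 }, // decode step at kv_len 100
        BatchItem { pos_offset: 0, q_len: 64 },  // fresh 64-token prefill
        BatchItem { pos_offset: 32, q_len: 32 }, // chunked-prefill continuation
    ];
    println!("{}", compute_max_kv_len(&items)); // 101
}
```

Note that a short decode step on a long sequence can dominate the maximum even when the batch also contains a much larger prefill chunk.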
compute_pos_offsets
Per-item starting absolute KV position for the FIRST q-token in items[i].q_tokens. Zero for fresh prefill, prior kv_len for chunked-prefill continuations or decode steps. Returned as u32 to match the device-side index buffers the varlen kernels read.
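The sketch below assumes a hypothetical BatchItem with a prior_kv_len field (tokens already in the sequence’s KV cache). The helper expand_positions is also hypothetical and only illustrates how a kernel would derive each q-token’s absolute position as pos_offsets[i] + j.

```rust
/// Hypothetical stand-in for one scheduled item.
struct BatchItem {
    q_tokens: Vec<u32>,
    /// Hypothetical field: tokens already written to this sequence's
    /// KV cache (0 for a fresh prefill).
    prior_kv_len: usize,
}

/// Starting absolute KV position for each item's first q-token, as u32
/// to match device-side index buffers.
fn compute_pos_offsets(items: &[BatchItem]) -> Vec<u32> {
    items
        .iter()
        .map(|it| u32::try_from(it.prior_kv_len).expect("KV position exceeds u32::MAX"))
        .collect()
}

/// Illustration only (hypothetical helper): absolute position of every
/// q-token in the flat batch, i.e. pos_offsets[i] + j for token j of item i.
fn expand_positions(items: &[BatchItem], pos_offsets: &[u32]) -> Vec<u32> {
    items
        .iter()
        .zip(pos_offsets)
        .flat_map(|(it, &off)| (0..it.q_tokens.len() as u32).map(move |j| off + j))
        .collect()
}

fn main() {
    let items = vec![
        BatchItem { q_tokens: vec![10, 11, 12], prior_kv_len: 0 }, // fresh prefill
        BatchItem { q_tokens: vec![20], prior_kv_len: 7 },         // decode step
        BatchItem { q_tokens: vec![30, 31], prior_kv_len: 4 },     // continuation
    ];
    let offsets = compute_pos_offsets(&items);
    println!("{:?}", offsets); // [0, 7, 4]
    println!("{:?}", expand_positions(&items, &offsets)); // [0, 1, 2, 7, 4, 5]
}
```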
concat_q_tokens
Flatten all items’ q-tokens into one concatenated [M_total] vec. Caller passes this to embedding_lookup so the entire batch’s embeddings end up contiguous in the unified residual buffer.
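A minimal sketch of the flattening, assuming u32 token ids and the same hypothetical BatchItem shape used above. Because items are appended in order, item i’s tokens land exactly at rows cu_seqlens_q[i]..cu_seqlens_q[i + 1].

```rust
/// Hypothetical stand-in for one scheduled item.
struct BatchItem {
    q_tokens: Vec<u32>,
}

/// Concatenate every item's q-tokens, in item order, into one [M_total] vec.
fn concat_q_tokens(items: &[BatchItem]) -> Vec<u32> {
    // Reserve M_total up front so the flat vec is allocated once.
    let total: usize = items.iter().map(|it| it.q_tokens.len()).sum();
    let mut flat = Vec::with_capacity(total);
    for it in items {
        flat.extend_from_slice(&it.q_tokens);
    }
    flat
}

fn main() {
    let items = vec![
        BatchItem { q_tokens: vec![11, 12, 13] },
        BatchItem { q_tokens: vec![21] },
        BatchItem { q_tokens: vec![31, 32] },
    ];
    println!("{:?}", concat_q_tokens(&items)); // [11, 12, 13, 21, 31, 32]
}
```

A single embedding lookup over this vec then fills the unified residual buffer with no per-item copies.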
stack_block_tables
Pack per-(seq, layer-0) page indices into the dense [num_seqs, max_blocks_per_seq] layout that the varlen attention kernel reads. Only layer 0’s block table is read because in ferrum’s paged-KV layout every layer shares the same block-table list (the layer-specific data lives inside each KV pool; the table itself is per-sequence).
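A sketch of the dense packing, assuming each sequence’s block table is a Vec<u32> of page indices and that short rows are filled with a caller-chosen pad value (the padding value the real kernel expects is not stated here, so pad is a hypothetical parameter). The result is row-major, so entry (s, b) lives at s * max_blocks_per_seq + b.

```rust
/// Pack ragged per-sequence block tables into a dense row-major
/// [num_seqs, max_blocks_per_seq] buffer; returns (buffer, max_blocks_per_seq).
/// `pad` is a hypothetical filler for unused trailing slots.
fn stack_block_tables(tables: &[Vec<u32>], pad: u32) -> (Vec<u32>, usize) {
    let max_blocks = tables.iter().map(Vec::len).max().unwrap_or(0);
    let mut dense = vec![pad; tables.len() * max_blocks];
    for (row, table) in tables.iter().enumerate() {
        let start = row * max_blocks;
        // Copy this sequence's page indices into the front of its row.
        dense[start..start + table.len()].copy_from_slice(table);
    }
    (dense, max_blocks)
}

fn main() {
    // Sequence 0 spans three pages, sequence 1 only one.
    let tables = vec![vec![5, 9, 2], vec![7]];
    let (dense, width) = stack_block_tables(&tables, 0);
    println!("{:?} width={}", dense, width); // [5, 9, 2, 7, 0, 0] width=3
}
```

Padding to the longest row wastes a little memory but lets the kernel index any (seq, block) pair with one multiply-add.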
unified_graph_key
Graph cache key for a unified mixed-batch forward. High bit set so we never collide with legacy decode/batched keys (which use the low 63 bits for m_padded / SINGLE_ITEM = 0).
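The collision argument can be sketched as below. Only the high-bit rule comes from the description above; the packing of the low bits (num_seqs in bits 32 to 62, a padded q-token count in bits 0 to 31) is a hypothetical layout chosen for illustration.

```rust
/// Bit 63 marks a unified mixed-batch key; legacy keys fit in the low
/// 63 bits, so they always have this bit clear.
const UNIFIED_KEY_BIT: u64 = 1 << 63;

/// Hypothetical layout: bit 63 = unified flag, bits 32..=62 = num_seqs,
/// bits 0..=31 = padded total q-token count.
fn unified_graph_key(m_padded: u32, num_seqs: u32) -> u64 {
    assert!(num_seqs < (1 << 31), "num_seqs must fit in 31 bits");
    UNIFIED_KEY_BIT | ((num_seqs as u64) << 32) | m_padded as u64
}

fn is_unified_key(key: u64) -> bool {
    key & UNIFIED_KEY_BIT != 0
}

fn main() {
    let legacy_key: u64 = 8; // e.g. a legacy batched key with m_padded = 8
    let key = unified_graph_key(8, 3);
    assert!(is_unified_key(key));
    assert!(!is_unified_key(legacy_key));
    println!("{key:#018x}"); // 0x8000000300000008
}
```

Whatever the low bits encode, setting bit 63 partitions the key space in two, so the unified and legacy paths can share one graph cache without coordinating their encodings.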