Skip to main content

ferrum_models/common/
decoder_unified.rs

1//! Shared helpers for decoder-only unified mixed-batch forward.
2//!
3//! The Llama / Qwen3-MoE / future decoder families all share the same
4//! outer scaffolding for unified forward: cu_seqlens construction,
5//! block-table stacking, final-token index lookup, graph-cache keying.
6//! These are pure functions — no kernel calls, no model state — extracted
7//! here so each family's `unified_forward_internal` reads as
8//! "scaffolding + family-specific layer loop", not "scaffolding +
9//! 700 lines of scaffolding clone".
10//!
11//! Per `docs/decoder-unified-runner-abstraction.md`. Phase 2A.
12
/// Cumulative q-token counts: `cu_seqlens_q[i+1] - cu_seqlens_q[i] =
/// items[i].q_tokens.len()`. The varlen attention + paged-KV-write
/// kernels read this to find each sequence's slice of the flat
/// `[M_total, *]` tensor.
///
/// Each item is `(cache_id, q_tokens, pos_offset, is_final_chunk)`.
///
/// Returns `(q_lens, cu_seqlens_q, m_total)`:
/// - `q_lens[i] = items[i].q_tokens.len()`,
/// - `cu_seqlens_q` has `items.len() + 1` entries and starts at `0`,
/// - `m_total = sum(q_lens)`.
///
/// # Panics
///
/// Panics if the running q-token total exceeds `u32::MAX`. The offsets are
/// consumed as `u32` indices, so wrapping silently (as plain `u32`
/// arithmetic does in release builds) would hand the kernels corrupt
/// slice boundaries.
pub fn compute_cu_seqlens_q(
    items: &[(String, Vec<u32>, usize, bool)],
) -> (Vec<usize>, Vec<u32>, usize) {
    let q_lens: Vec<usize> = items.iter().map(|it| it.1.len()).collect();
    let mut cu_seqlens_q: Vec<u32> = Vec::with_capacity(items.len() + 1);
    cu_seqlens_q.push(0);
    let mut m_total: usize = 0;
    for &l in &q_lens {
        // Summing lengths of live `Vec`s in `usize` cannot overflow; only the
        // narrowing to the kernels' `u32` offsets needs a check.
        m_total += l;
        assert!(
            m_total <= u32::MAX as usize,
            "cu_seqlens_q overflow: batch q-token total {} exceeds u32::MAX",
            m_total
        );
        cu_seqlens_q.push(m_total as u32);
    }
    (q_lens, cu_seqlens_q, m_total)
}
33
/// Absolute KV position of the first q-token of every item, in batch order.
///
/// A fresh prefill starts at `0`. A chunked-prefill continuation or a decode
/// step starts at the prior `kv_len`. Values are narrowed to `u32` to match
/// the device-side index buffers the varlen kernels read.
pub fn compute_pos_offsets(items: &[(String, Vec<u32>, usize, bool)]) -> Vec<u32> {
    let mut offsets = Vec::with_capacity(items.len());
    for (_, _, pos_offset, _) in items {
        // NOTE(review): `as` truncates positions above u32::MAX without error.
        offsets.push(*pos_offset as u32);
    }
    offsets
}
41
/// Largest reachable KV length in the batch: the maximum over all items of
/// `pos_offset + q_tokens.len()`, or `0` for an empty batch.
///
/// The varlen attention kernel sizes its shared memory from this, so it must
/// cover the furthest `kv_pos` any item in the batch can touch.
pub fn compute_max_kv_len(items: &[(String, Vec<u32>, usize, bool)]) -> usize {
    items
        .iter()
        .fold(0usize, |longest, (_, q_tokens, pos_offset, _)| {
            longest.max(*pos_offset + q_tokens.len())
        })
}
48
/// Concatenate every item's q-tokens, in batch order, into one flat
/// `[M_total]` vector.
///
/// The caller feeds this to `embedding_lookup` so that the whole batch's
/// embeddings land contiguously in the unified residual buffer.
pub fn concat_q_tokens(items: &[(String, Vec<u32>, usize, bool)]) -> Vec<u32> {
    let total: usize = items.iter().map(|(_, q_tokens, _, _)| q_tokens.len()).sum();
    let mut flat = Vec::with_capacity(total);
    for (_, q_tokens, _, _) in items {
        flat.extend_from_slice(q_tokens);
    }
    flat
}
55
/// Build the dense, row-major `[num_seqs, max_blocks_per_seq]` block table
/// that the varlen attention kernel reads.
///
/// Row `i` holds the page indices that `lookup` returns for `items[i]`'s
/// cache id. Shorter lists are zero-padded on the right, and longer lists
/// are cut to `max_blocks_per_seq`. One table per sequence is enough
/// because, in ferrum's paged-KV layout, every layer shares the same
/// block-table list. The layer-specific data lives inside each KV pool.
///
/// The caller wires `lookup` to its model's `kv_caches.get(cid)`. It is
/// invoked exactly once per item.
pub fn stack_block_tables<F: Fn(&str) -> Vec<u32>>(
    items: &[(String, Vec<u32>, usize, bool)],
    max_blocks_per_seq: usize,
    lookup: F,
) -> Vec<u32> {
    let mut table = vec![0u32; items.len() * max_blocks_per_seq];
    for (seq_idx, (cache_id, _, _, _)) in items.iter().enumerate() {
        let pages = lookup(cache_id);
        let row_start = seq_idx * max_blocks_per_seq;
        let row = &mut table[row_start..row_start + max_blocks_per_seq];
        let kept = pages.len().min(row.len());
        row[..kept].copy_from_slice(&pages[..kept]);
    }
    table
}
79
/// For each item whose `is_final_chunk` flag (tuple field `.3`) is set,
/// return `(orig_index, global_token_index)`. `global_token_index` is the
/// row of that item's LAST q-token in the flat `[M_total, hidden]` residual
/// buffer. The final-norm + lm_head stages slice these rows out for
/// sampling.
///
/// `cu_seqlens_q` must be the cumulative offsets for the same `items`, as
/// produced by `compute_cu_seqlens_q`. Results are in batch order.
///
/// # Panics
///
/// - Panics if a final-chunk item has no q-tokens, because there is no last
///   token to sample from.
/// - Panics if `cu_seqlens_q` is shorter than the index of a final item.
pub fn compute_final_indices(
    items: &[(String, Vec<u32>, usize, bool)],
    cu_seqlens_q: &[u32],
) -> Vec<(usize, usize)> {
    items
        .iter()
        .enumerate()
        .filter(|(_, it)| it.3)
        .map(|(orig_idx, it)| {
            // A plain `len() - 1` underflows on an empty final chunk. Debug
            // builds panic, but release builds wrap, and `cu + usize::MAX`
            // then lands on the PREVIOUS item's last row, silently sampling
            // the wrong sequence. Fail loudly in every build instead.
            let last_token_local = it.1.len().checked_sub(1).unwrap_or_else(|| {
                panic!(
                    "final-chunk item {} (cache_id {:?}) has no q-tokens",
                    orig_idx, it.0
                )
            });
            let global = (cu_seqlens_q[orig_idx] as usize) + last_token_local;
            (orig_idx, global)
        })
        .collect()
}
99
/// Graph-cache key for one unified mixed-batch forward shape.
///
/// Bit 63 is always set, so these keys never collide with the legacy
/// decode/batched keys. Those keep `m_padded` / `SINGLE_ITEM = 0` in the
/// low 63 bits. Below the tag, `m_total` is shifted into bits 32 and up,
/// and `num_seqs` occupies the low 32 bits.
///
/// The key covers both `m_total` and `num_seqs` because a captured graph
/// bakes in the grid_dim and per-seq scratch indexing of one specific
/// shape. Replaying it for any other shape accesses memory with the wrong
/// shape. (Memory `project_moe_phase3_graph_bug.md` records the same
/// rationale for the legacy MoE path.)
///
/// NOTE(review): `m_total >= 2^31` or `num_seqs >= 2^32` overflows its bit
/// field, so distinct shapes could alias. No check is performed.
pub const fn unified_graph_key(m_total: usize, num_seqs: usize) -> u64 {
    const UNIFIED_TAG: u64 = 1 << 63;
    let m_total_field = (m_total as u64) << 32;
    let num_seqs_field = num_seqs as u64;
    UNIFIED_TAG | m_total_field | num_seqs_field
}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an item `(cache_id, q_tokens, pos_offset, is_final_chunk)` whose
    /// `q_tokens` are `q_len` zeros. Token values only matter for
    /// `concat_q_tokens`, which is tested with explicit vectors.
    fn item(
        cid: &str,
        q_len: usize,
        pos: usize,
        final_chunk: bool,
    ) -> (String, Vec<u32>, usize, bool) {
        (cid.to_string(), vec![0u32; q_len], pos, final_chunk)
    }

    #[test]
    fn cu_seqlens_q_mixed_lengths() {
        let items = vec![
            item("a", 5, 0, true),
            item("b", 1, 100, true),
            item("c", 3, 10, false),
        ];
        let (q_lens, cu, m_total) = compute_cu_seqlens_q(&items);
        assert_eq!(q_lens, vec![5, 1, 3]);
        assert_eq!(cu, vec![0, 5, 6, 9]);
        assert_eq!(m_total, 9);
    }

    #[test]
    fn pos_offsets_and_max_kv_len() {
        let items = vec![
            item("a", 5, 0, true),
            item("b", 1, 100, true),
            item("c", 3, 10, false),
        ];
        assert_eq!(compute_pos_offsets(&items), vec![0u32, 100, 10]);
        assert_eq!(compute_max_kv_len(&items), 101); // b: 100 + 1
    }

    #[test]
    fn final_indices_only_final_chunks() {
        let items = vec![
            item("a", 5, 0, true),   // last token at global 4
            item("b", 1, 100, true), // last at global 5
            item("c", 3, 10, false), // not final
        ];
        let (_, cu, _) = compute_cu_seqlens_q(&items);
        let fi = compute_final_indices(&items, &cu);
        assert_eq!(fi, vec![(0, 4), (1, 5)]);
    }

    #[test]
    fn final_indices_offset_by_preceding_items() {
        // A non-final chunk ahead of the finals must still shift their rows.
        let items = vec![
            item("a", 3, 0, false), // rows 0..3, not sampled
            item("b", 2, 50, true), // rows 3..5, last at global 4
            item("c", 1, 7, true),  // row 5
        ];
        let (_, cu, _) = compute_cu_seqlens_q(&items);
        assert_eq!(compute_final_indices(&items, &cu), vec![(1, 4), (2, 5)]);
    }

    #[test]
    fn concat_q_tokens_preserves_item_order() {
        let items = vec![
            ("a".to_string(), vec![1u32, 2, 3], 0usize, true),
            ("b".to_string(), Vec::<u32>::new(), 7usize, false),
            ("c".to_string(), vec![4u32], 9usize, true),
        ];
        assert_eq!(concat_q_tokens(&items), vec![1, 2, 3, 4]);
    }

    #[test]
    fn graph_key_high_bit_set() {
        let k = unified_graph_key(32, 4);
        assert!(k & (1u64 << 63) != 0, "high bit must be set");
        // Legacy key with same low bits should differ.
        let legacy = ((32u64) << 32) | 4u64;
        assert_ne!(k, legacy);
    }

    #[test]
    fn stack_block_tables_pads_and_truncates() {
        let items = vec![item("a", 1, 0, true), item("b", 1, 0, true)];
        // Item a has 2 blocks; b has 5 but max_blocks_per_seq=3
        let stacked = stack_block_tables(&items, 3, |cid| match cid {
            "a" => vec![10u32, 11u32],
            "b" => vec![20u32, 21u32, 22u32, 23u32, 24u32],
            _ => unreachable!(),
        });
        // a: [10, 11, 0]  (padded with 0)
        // b: [20, 21, 22] (truncated to 3)
        assert_eq!(stacked, vec![10, 11, 0, 20, 21, 22]);
    }

    #[test]
    fn empty_items() {
        let items: Vec<(String, Vec<u32>, usize, bool)> = Vec::new();
        let (q_lens, cu, m_total) = compute_cu_seqlens_q(&items);
        assert!(q_lens.is_empty());
        assert_eq!(cu, vec![0]);
        assert_eq!(m_total, 0);
    }

    #[test]
    fn empty_items_other_helpers() {
        let items: Vec<(String, Vec<u32>, usize, bool)> = Vec::new();
        assert!(compute_pos_offsets(&items).is_empty());
        assert_eq!(compute_max_kv_len(&items), 0);
        assert!(concat_q_tokens(&items).is_empty());
        assert!(stack_block_tables(&items, 4, |_| vec![1u32]).is_empty());
        assert!(compute_final_indices(&items, &[0]).is_empty());
    }
}
189        assert!(q_lens.is_empty());
190        assert_eq!(cu, vec![0]);
191        assert_eq!(m_total, 0);
192    }
193}