Skip to main content

rlx_qwen35/
cache.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Decode-time recurrent state for the qwen35 hybrid trunk.
17
18use crate::config::Qwen35Config;
19
/// Recurrent state owned by one trunk layer, carried from one decode step to the next.
///
/// Which variant a layer holds follows [`trunk_layer_kinds`]: linear layers keep
/// conv/SSM state, full-attention layers keep their K/V history.
#[derive(Clone, Debug)]
pub enum Qwen35LayerState {
    /// Gated-DeltaNet block with a depthwise causal convolution.
    Linear {
        /// Causal conv ring, laid out `[batch, k-1, conv_channels]`.
        conv_state: Vec<f32>,
        /// SSM matrix per value head, laid out `[batch, n_v_heads, n_state, n_state]`.
        ssm_state: Vec<f32>,
    },
    /// Standard attention block holding the pre-GQA K/V cache.
    FullAttn {
        /// Keys after RoPE, laid out `[batch, past_seq, n_kv * head_dim]`.
        past_k: Vec<f32>,
        /// Values before GQA expansion, laid out `[batch, past_seq, n_kv * head_dim]`.
        past_v: Vec<f32>,
    },
}
38
39/// Host-side decode cache seeded from a prefill-with-states forward.
40#[derive(Debug, Clone)]
41pub struct Qwen35DecodeCache {
42    pub batch: usize,
43    pub past_seq: usize,
44    /// Actual prompt length per batch row (before generation).
45    pub prompt_lens: Vec<usize>,
46    pub layers: Vec<Qwen35LayerState>,
47}
48
49impl Qwen35DecodeCache {
50    pub fn n_trunk(&self) -> usize {
51        self.layers.len()
52    }
53}
54
55/// Trunk layer kinds in declaration order (excludes MTP).
56pub fn trunk_layer_kinds(cfg: &Qwen35Config) -> Vec<bool> {
57    let n_main = cfg.num_hidden_layers - cfg.nextn_predict_layers;
58    let interval = cfg.full_attention_interval.max(1);
59    (0..n_main).map(|il| ((il + 1) % interval) == 0).collect()
60}
61
62/// Number of extra graph outputs after logits (and optional MTP).
63pub fn recurrent_output_count(cfg: &Qwen35Config) -> usize {
64    trunk_layer_kinds(cfg).len() * 2
65}
66
/// Count of logit outputs that precede the recurrent state exports.
///
/// The layout is `[trunk]`, or `[trunk, mtp]` when the MTP head is present.
pub fn logit_output_count(with_mtp: bool) -> usize {
    if with_mtp {
        2
    } else {
        1
    }
}
71
72fn truncate_logits_row(_cfg: &Qwen35Config, logits: Vec<f32>, _batch: usize) -> Vec<f32> {
73    // Graph LM head width comes from the embedding table (`lm_vocab_size`);
74    // do not clip to `cfg.vocab_size` (metadata can under-report vs GGUF).
75    logits
76}
77
78fn parse_mtp_logits(cfg: &Qwen35Config, batch: usize, mtp: Vec<f32>) -> anyhow::Result<Vec<f32>> {
79    use anyhow::bail;
80    let lm_vocab = mtp.len() / batch.max(1);
81    let expected = batch * lm_vocab;
82    if mtp.len() != expected {
83        bail!(
84            "mtp logits: len={} expected batch*lm_vocab={expected}",
85            mtp.len()
86        );
87    }
88    Ok(truncate_logits_row(cfg, mtp, batch))
89}
90
91/// Zero-initialized recurrent inputs for a prefill-cache seed graph.
92pub fn zero_recurrent_inputs(cfg: &Qwen35Config, batch: usize) -> Vec<(String, Vec<f32>)> {
93    let n_state = cfg.ssm_state_size;
94    let n_v_heads = cfg.ssm_time_step_rank;
95    let conv_channels = linear_conv_channels(cfg);
96    let k_conv = cfg.ssm_conv_kernel;
97    let head_dim = cfg.key_length;
98    let kv_cols = cfg.num_key_value_heads * head_dim;
99
100    let mut out = Vec::new();
101    for (il, is_full) in trunk_layer_kinds(cfg).into_iter().enumerate() {
102        if is_full {
103            let _ = kv_cols;
104            let _ = head_dim;
105            // Full-attn layers have no recurrent *inputs* on prefill seed.
106        } else {
107            out.push((
108                format!("conv_state_l{il}"),
109                vec![0f32; batch * (k_conv - 1) * conv_channels],
110            ));
111            out.push((
112                format!("ssm_state_l{il}"),
113                vec![0f32; batch * n_v_heads * n_state * n_state],
114            ));
115        }
116    }
117    out
118}
119
120fn linear_conv_channels(cfg: &Qwen35Config) -> usize {
121    let n_state = cfg.ssm_state_size;
122    let n_k_heads = cfg.ssm_group_count;
123    let n_v_heads = cfg.ssm_time_step_rank;
124    let key_dim = n_state * n_k_heads;
125    let value_dim = n_state * n_v_heads;
126    key_dim * 2 + value_dim
127}
128
/// Build the `[batch, bucket_upper + 1]` attention mask for bucketed decode.
///
/// For row `b`, position `t` is `1.0` when both hold:
/// - `t <= min(past_seq, bucket_upper)`;
/// - `t < prompt_lens[b] + generated_per_row[b]`.
///
/// Every other position is `0.0`. A missing `prompt_lens` entry defaults to
/// `past_seq`, and a missing `generated_per_row` entry defaults to `0`.
///
/// NOTE(review): two assumptions here need confirming against the decode graph.
/// - Index `past_seq` is treated as the incoming token's slot. If the graph
///   appends the new token after the padded past (index `bucket_upper`, as the
///   comment in `slice_kv_from_bucket` describes), the two disagree whenever
///   `past_seq < bucket_upper`.
/// - The `prompt_len + generated` bound assumes generated tokens follow the real
///   prompt. For a right-padded short prompt, they appear to be stored after the
///   padding instead.
pub fn build_decode_attention_mask(
    batch: usize,
    past_seq: usize,
    bucket_upper: usize,
    prompt_lens: &[usize],
    generated_per_row: &[usize],
) -> Vec<f32> {
    let width = bucket_upper + 1;
    // Number of leading positions that may be opened at all.
    let open_span = past_seq.min(bucket_upper) + 1;
    let mut mask = vec![0f32; batch * width];
    for (b, row) in mask.chunks_exact_mut(width).enumerate() {
        let prompt = prompt_lens.get(b).copied().unwrap_or(past_seq);
        let generated = generated_per_row.get(b).copied().unwrap_or(0);
        let attend = (prompt + generated).min(open_span);
        row[..attend].fill(1.0);
    }
    mask
}
152
/// Pad `[batch, actual_past, kv_cols]` K/V to `[batch, bucket_upper, kv_cols]`.
///
/// Each row's `actual_past` positions are copied to the front of its bucket row.
/// The trailing `bucket_upper - actual_past` positions are zero.
///
/// # Panics
///
/// Panics in two cases:
/// - `actual_past > bucket_upper`, when `batch > 0` and `kv_cols > 0`;
/// - `src` holds fewer than `batch * actual_past * kv_cols` floats.
///
/// These are exactly the cases in which the row copy would index out of bounds.
/// The asserts only make the failure message explicit.
pub fn pad_kv_to_bucket(
    src: &[f32],
    batch: usize,
    actual_past: usize,
    bucket_upper: usize,
    kv_cols: usize,
) -> Vec<f32> {
    let src_row = actual_past * kv_cols;
    let dst_row = bucket_upper * kv_cols;
    assert!(
        batch == 0 || src_row <= dst_row,
        "pad_kv_to_bucket: actual_past={actual_past} exceeds bucket_upper={bucket_upper}"
    );
    assert!(
        src.len() >= batch * src_row,
        "pad_kv_to_bucket: need {} floats, got {} (batch={batch}, actual_past={actual_past}, \
         kv_cols={kv_cols})",
        batch * src_row,
        src.len()
    );
    let mut out = vec![0f32; batch * dst_row];
    for b in 0..batch {
        let dst = b * dst_row;
        let from = b * src_row;
        out[dst..dst + src_row].copy_from_slice(&src[from..from + src_row]);
    }
    out
}
170
171/// Slice bucketed K/V outputs back to `[batch, actual_past, kv_cols]`.
172pub fn slice_kv_from_bucket(
173    src: &[f32],
174    batch: usize,
175    actual_past: usize,
176    bucket_upper: usize,
177    kv_cols: usize,
178) -> anyhow::Result<Vec<f32>> {
179    use anyhow::bail;
180    // Decode graphs concat padded `past_k` `[batch, bucket_upper, kv]` with the
181    // new token → `[batch, bucket_upper + 1, kv]` row-major layout.
182    let out_seq = bucket_upper.saturating_add(1);
183    let mut out = vec![0f32; batch * actual_past * kv_cols];
184    for b in 0..batch {
185        let src_base = b * out_seq * kv_cols;
186        let dst_base = b * actual_past * kv_cols;
187        let copy_len = actual_past * kv_cols;
188        let end = src_base + copy_len;
189        if end > src.len() {
190            bail!(
191                "slice_kv_from_bucket: need {end} floats in bucket output, got {} \
192                 (batch={batch}, actual_past={actual_past}, bucket_upper={bucket_upper})",
193                src.len()
194            );
195        }
196        out[dst_base..dst_base + copy_len].copy_from_slice(&src[src_base..end]);
197    }
198    Ok(out)
199}
200
201/// Zero padded prompt positions in full-attention KV (variable-length batch).
202pub fn zero_prompt_padding_kv(
203    cfg: &Qwen35Config,
204    cache: &mut Qwen35DecodeCache,
205    padded_seq: usize,
206) {
207    let head_dim = cfg.key_length;
208    let kv_cols = cfg.num_key_value_heads * head_dim;
209    let kinds = trunk_layer_kinds(cfg);
210    for (il, layer) in cache.layers.iter_mut().enumerate() {
211        if !kinds[il] {
212            continue;
213        }
214        if let Qwen35LayerState::FullAttn { past_k, past_v } = layer {
215            for b in 0..cache.batch {
216                let prompt_len = cache.prompt_lens.get(b).copied().unwrap_or(padded_seq);
217                if prompt_len >= padded_seq {
218                    continue;
219                }
220                for t in prompt_len..padded_seq {
221                    let start = b * padded_seq * kv_cols + t * kv_cols;
222                    past_k[start..start + kv_cols].fill(0.0);
223                    past_v[start..start + kv_cols].fill(0.0);
224                }
225            }
226        }
227    }
228}
229
230/// Build host feeds for a single decode step from `cache`.
231///
232/// `tokens` must have length `cache.batch` — one next-token id per row.
233/// When `bucket_upper` is `Some`, pads K/V and supplies a custom mask.
234pub fn decode_step_feeds(
235    cfg: &Qwen35Config,
236    cache: &Qwen35DecodeCache,
237    tokens: &[u32],
238    rope_cos: &[f32],
239    rope_sin: &[f32],
240    bucket_upper: Option<usize>,
241    generated_per_row: &[usize],
242) -> anyhow::Result<Vec<(String, Vec<f32>)>> {
243    use anyhow::bail;
244
245    if tokens.len() != cache.batch {
246        bail!(
247            "decode_step_feeds: expected {} tokens, got {}",
248            cache.batch,
249            tokens.len()
250        );
251    }
252    let mut feeds = vec![
253        (
254            "input_ids".into(),
255            tokens.iter().map(|&t| t as f32).collect(),
256        ),
257        ("rope_cos".into(), rope_cos.to_vec()),
258        ("rope_sin".into(), rope_sin.to_vec()),
259    ];
260    if let Some(upper) = bucket_upper {
261        let mask = build_decode_attention_mask(
262            cache.batch,
263            cache.past_seq,
264            upper,
265            &cache.prompt_lens,
266            generated_per_row,
267        );
268        feeds.push(("mask".into(), mask));
269    }
270    let head_dim = cfg.key_length;
271    let kv_cols = cfg.num_key_value_heads * head_dim;
272    let kinds = trunk_layer_kinds(cfg);
273    for (il, layer) in cache.layers.iter().enumerate() {
274        let is_full = kinds[il];
275        match (layer, is_full) {
276            (
277                Qwen35LayerState::Linear {
278                    conv_state,
279                    ssm_state,
280                },
281                false,
282            ) => {
283                feeds.push((format!("conv_state_l{il}"), conv_state.clone()));
284                feeds.push((format!("ssm_state_l{il}"), ssm_state.clone()));
285            }
286            (Qwen35LayerState::FullAttn { past_k, past_v }, true) => {
287                if let Some(upper) = bucket_upper {
288                    feeds.push((
289                        format!("past_k_l{il}"),
290                        pad_kv_to_bucket(past_k, cache.batch, cache.past_seq, upper, kv_cols),
291                    ));
292                    feeds.push((
293                        format!("past_v_l{il}"),
294                        pad_kv_to_bucket(past_v, cache.batch, cache.past_seq, upper, kv_cols),
295                    ));
296                } else {
297                    feeds.push((format!("past_k_l{il}"), past_k.clone()));
298                    feeds.push((format!("past_v_l{il}"), past_v.clone()));
299                }
300            }
301            _ => {}
302        }
303    }
304    Ok(feeds)
305}
306
307/// Parse prefill-cache graph outputs into logits/hidden + [`Qwen35DecodeCache`].
308/// When `trunk_is_hidden`, the first output is `[batch × hidden_size]` not logits.
309pub fn seed_cache_from_outputs(
310    cfg: &Qwen35Config,
311    batch: usize,
312    seq: usize,
313    prompt_lens: &[usize],
314    outputs: Vec<Vec<f32>>,
315    with_mtp: bool,
316    trunk_is_hidden: bool,
317) -> anyhow::Result<(Vec<f32>, Qwen35DecodeCache, Option<Vec<f32>>)> {
318    use anyhow::{Context, bail};
319    let n_head = logit_output_count(with_mtp);
320    let n_extra = recurrent_output_count(cfg);
321    if outputs.len() != n_head + n_extra {
322        bail!(
323            "prefill-cache: expected {} outputs, got {}",
324            n_head + n_extra,
325            outputs.len()
326        );
327    }
328    let mut iter = outputs.into_iter();
329    let trunk = iter.next().context("trunk head output missing")?;
330    let head_dim = cfg.key_length;
331    let kv_cols = cfg.num_key_value_heads * head_dim;
332    let logits = if trunk_is_hidden {
333        let n = cfg.hidden_size;
334        let expected_last = batch * n;
335        let expected_full = batch * seq * n;
336        if trunk.len() == expected_last {
337            trunk
338        } else if trunk.len() == expected_full
339            || (trunk.len().is_multiple_of(n)
340                && trunk.len() >= batch.max(1) * n
341                && trunk.len() % (batch.max(1) * n) == 0)
342        {
343            let row_stride = trunk.len() / batch.max(1);
344            let seq_dim = row_stride / n;
345            if batch > 1 && !prompt_lens.is_empty() {
346                let mut out = Vec::with_capacity(batch * n);
347                for b in 0..batch {
348                    let pl = prompt_lens.get(b).copied().unwrap_or(seq).min(seq_dim);
349                    let idx = pl.saturating_sub(1);
350                    let off = b * row_stride + idx * n;
351                    out.extend_from_slice(&trunk[off..off + n]);
352                }
353                out
354            } else if !prompt_lens.is_empty() {
355                let last_pl = *prompt_lens.iter().max().unwrap_or(&seq);
356                let idx = last_pl.saturating_sub(1).min(seq_dim.saturating_sub(1));
357                let off = idx * n;
358                trunk[off..off + n].to_vec()
359            } else {
360                trunk[expected_full.saturating_sub(n)..].to_vec()
361            }
362        } else {
363            bail!(
364                "prefill-cache hidden: len={} expected batch*hidden={expected_last} \
365                 or batch*seq*hidden={expected_full} (or padded max_seq layout)",
366                trunk.len()
367            );
368        }
369    } else {
370        let lm_vocab = trunk.len() / batch.max(1);
371        let expected_logits = batch * lm_vocab;
372        if trunk.len() != expected_logits {
373            bail!(
374                "prefill-cache logits: len={} expected batch*lm_vocab={expected_logits} \
375                 (batch={batch}, lm_vocab={lm_vocab})",
376                trunk.len()
377            );
378        }
379        truncate_logits_row(cfg, trunk, batch)
380    };
381    let mtp_logits = if with_mtp {
382        Some(parse_mtp_logits(
383            cfg,
384            batch,
385            iter.next().context("mtp logits missing")?,
386        )?)
387    } else {
388        None
389    };
390
391    let mut layers = Vec::with_capacity(trunk_layer_kinds(cfg).len());
392    for (il, is_full) in trunk_layer_kinds(cfg).into_iter().enumerate() {
393        if is_full {
394            let k = iter.next().context("past_k missing")?;
395            let v = iter.next().context("past_v missing")?;
396            let expected = batch * seq * kv_cols;
397            let (past_k, past_v) = if k.len() == expected && v.len() == expected {
398                (k, v)
399            } else if k.len() % kv_cols == 0 && v.len() % kv_cols == 0 {
400                let k_bucket = k.len() / (batch.max(1) * kv_cols);
401                let v_bucket = v.len() / (batch.max(1) * kv_cols);
402                if k_bucket >= seq && v_bucket >= seq {
403                    (
404                        slice_kv_from_bucket(&k, batch, seq, k_bucket, kv_cols)?,
405                        slice_kv_from_bucket(&v, batch, seq, v_bucket, kv_cols)?,
406                    )
407                } else {
408                    bail!(
409                        "layer {il} kv: k.len={} v.len={} expected {expected} \
410                         (k_bucket={k_bucket} v_bucket={v_bucket} < seq={seq})",
411                        k.len(),
412                        v.len()
413                    );
414                }
415            } else {
416                bail!(
417                    "layer {il} kv: k.len={} v.len={} expected {expected}",
418                    k.len(),
419                    v.len()
420                );
421            };
422            layers.push(Qwen35LayerState::FullAttn { past_k, past_v });
423        } else {
424            let conv = iter.next().context("conv_state missing")?;
425            let ssm = iter.next().context("ssm_state missing")?;
426            let conv_ring =
427                batch * (cfg.ssm_conv_kernel.saturating_sub(1)) * linear_conv_channels(cfg);
428            let conv_state = if conv.len() == conv_ring {
429                conv
430            } else {
431                bail!(
432                    "layer {il} conv_state: len={} expected {conv_ring}",
433                    conv.len()
434                );
435            };
436            layers.push(Qwen35LayerState::Linear {
437                conv_state,
438                ssm_state: ssm,
439            });
440        }
441    }
442    Ok((
443        logits,
444        Qwen35DecodeCache {
445            batch,
446            past_seq: seq,
447            prompt_lens: prompt_lens.to_vec(),
448            layers,
449        },
450        mtp_logits,
451    ))
452}
453
454/// Advance `cache` from decode-graph outputs (logits or normed hidden + states).
455/// When `trunk_is_hidden`, the first output is `[batch × hidden_size]` not logits.
456pub fn advance_cache_from_decode_outputs(
457    cfg: &Qwen35Config,
458    cache: &mut Qwen35DecodeCache,
459    outputs: Vec<Vec<f32>>,
460    bucket_upper: Option<usize>,
461    mtp_logits_path: bool,
462    want_mtp: bool,
463    trunk_is_hidden: bool,
464) -> anyhow::Result<(Vec<f32>, Option<Vec<f32>>)> {
465    use anyhow::{Context, bail};
466    let n_head = logit_output_count(mtp_logits_path);
467    let n_extra = recurrent_output_count(cfg);
468    if outputs.len() != n_head + n_extra {
469        bail!(
470            "decode: expected {} outputs, got {}",
471            n_head + n_extra,
472            outputs.len()
473        );
474    }
475    let mut iter = outputs.into_iter();
476    let trunk = iter.next().context("trunk head output missing")?;
477    let new_past = cache.past_seq + 1;
478    let head_dim = cfg.key_length;
479    let kv_cols = cfg.num_key_value_heads * head_dim;
480    let batch = cache.batch;
481
482    let trunk_out = if trunk_is_hidden {
483        let expected = batch * cfg.hidden_size;
484        if trunk.len() != expected {
485            bail!(
486                "decode hidden: len={} expected batch*hidden={expected}",
487                trunk.len()
488            );
489        }
490        trunk
491    } else {
492        let lm_vocab = trunk.len() / batch.max(1);
493        let expected_logits = batch * lm_vocab;
494        if trunk.len() != expected_logits {
495            bail!(
496                "decode logits: len={} expected batch*lm_vocab={expected_logits}",
497                trunk.len()
498            );
499        }
500        truncate_logits_row(cfg, trunk, batch)
501    };
502    let mtp_logits = if mtp_logits_path {
503        let raw = iter.next().context("mtp logits missing")?;
504        if want_mtp {
505            Some(parse_mtp_logits(cfg, batch, raw)?)
506        } else {
507            None
508        }
509    } else {
510        None
511    };
512
513    let mut new_layers = Vec::with_capacity(cache.layers.len());
514    let kinds = trunk_layer_kinds(cfg);
515    for (il, layer) in cache.layers.iter().enumerate() {
516        let is_full = kinds[il];
517        if is_full {
518            let k = iter.next().context("new_k missing")?;
519            let v = iter.next().context("new_v missing")?;
520            let (k, v) = if let Some(upper) = bucket_upper {
521                (
522                    slice_kv_from_bucket(&k, batch, new_past, upper, kv_cols)?,
523                    slice_kv_from_bucket(&v, batch, new_past, upper, kv_cols)?,
524                )
525            } else {
526                (k, v)
527            };
528            let expected = batch * new_past * kv_cols;
529            if k.len() != expected || v.len() != expected {
530                bail!(
531                    "layer {il} kv: k.len={} v.len={} expected {expected}",
532                    k.len(),
533                    v.len()
534                );
535            }
536            new_layers.push(Qwen35LayerState::FullAttn {
537                past_k: k,
538                past_v: v,
539            });
540            let _ = layer;
541        } else {
542            let conv = iter.next().context("conv_state missing")?;
543            let ssm = iter.next().context("ssm_state missing")?;
544            new_layers.push(Qwen35LayerState::Linear {
545                conv_state: conv,
546                ssm_state: ssm,
547            });
548        }
549    }
550    cache.past_seq = new_past;
551    cache.layers = new_layers;
552    Ok((trunk_out, mtp_logits))
553}
554
555/// Describe per-layer buffer sizes for a config (trunk only).
556#[allow(dead_code)]
557pub fn trunk_layer_state_sizes(cfg: &Qwen35Config) -> Vec<(bool, usize, usize)> {
558    let n_main = cfg.num_hidden_layers - cfg.nextn_predict_layers;
559    let interval = cfg.full_attention_interval.max(1);
560    let n_state = cfg.ssm_state_size;
561    let n_v_heads = cfg.ssm_time_step_rank;
562    let conv_channels = linear_conv_channels(cfg);
563    let k_conv = cfg.ssm_conv_kernel;
564
565    let mut out = Vec::with_capacity(n_main);
566    for il in 0..n_main {
567        let is_full_attn = ((il + 1) % interval) == 0;
568        if is_full_attn {
569            out.push((true, 0, 0));
570        } else {
571            out.push((
572                false,
573                (k_conv - 1) * conv_channels,
574                n_v_heads * n_state * n_state,
575            ));
576        }
577    }
578    out
579}
580
581/// Pack per-row prompts into `[batch, max_seq]` row-major F32 ids (zero-pad).
582pub fn pack_input_ids(batch_prompts: &[Vec<u32>], max_seq: usize) -> anyhow::Result<Vec<f32>> {
583    use anyhow::bail;
584    if batch_prompts.is_empty() {
585        bail!("pack_input_ids: batch must be non-empty");
586    }
587    let batch = batch_prompts.len();
588    let mut out = vec![0f32; batch * max_seq];
589    for (b, prompt) in batch_prompts.iter().enumerate() {
590        if prompt.len() > max_seq {
591            bail!(
592                "pack_input_ids: row {b} length {} exceeds max_seq={max_seq}",
593                prompt.len()
594            );
595        }
596        let base = b * max_seq;
597        for (i, &t) in prompt.iter().enumerate() {
598            out[base + i] = t as f32;
599        }
600    }
601    Ok(out)
602}
603
/// Per-row index of the last real prompt token (0-based), as `f32`.
///
/// An empty prompt maps to index 0 rather than underflowing.
pub fn last_token_indices(prompt_lens: &[usize]) -> Vec<f32> {
    let mut indices = Vec::with_capacity(prompt_lens.len());
    for &len in prompt_lens {
        indices.push(len.saturating_sub(1) as f32);
    }
    indices
}
611
#[cfg(test)]
mod tests {
    use super::*;

    /// Single-layer config whose only trunk layer is full attention
    /// (`full_attention_interval == 1`).
    fn one_full_attn_cfg() -> Qwen35Config {
        Qwen35Config {
            vocab_size: 16,
            hidden_size: 4,
            intermediate_size: 8,
            num_hidden_layers: 1,
            nextn_predict_layers: 0,
            num_attention_heads: 2,
            num_key_value_heads: 2,
            key_length: 2,
            value_length: 2,
            max_position_embeddings: 64,
            rms_norm_eps: 1e-6,
            rope_theta: 10_000.0,
            rope_dim_count: 2,
            rope_dim_sections: vec![],
            full_attention_interval: 1,
            ssm_conv_kernel: 4,
            ssm_group_count: 2,
            ssm_inner_size: 8,
            ssm_state_size: 4,
            ssm_time_step_rank: 2,
            tie_word_embeddings: true,
            num_experts: 0,
            num_experts_used: 0,
            expert_ffn_size: 0,
            shared_expert_ffn_size: 0,
            expert_weights_scale: 1.0,
        }
    }

    /// Width of one K/V position for `cfg`.
    fn kv_cols(cfg: &Qwen35Config) -> usize {
        cfg.num_key_value_heads * cfg.key_length
    }

    /// One-layer full-attention cache with zeroed K/V for `past_seq` positions.
    fn zero_full_attn_cache(cfg: &Qwen35Config, batch: usize, past_seq: usize) -> Qwen35DecodeCache {
        let len = batch * past_seq * kv_cols(cfg);
        Qwen35DecodeCache {
            batch,
            past_seq,
            prompt_lens: vec![past_seq; batch],
            layers: vec![Qwen35LayerState::FullAttn {
                past_k: vec![0.0; len],
                past_v: vec![0.0; len],
            }],
        }
    }

    #[test]
    fn advance_decode_consumes_mtp_before_kv_states() {
        let cfg = one_full_attn_cfg();
        let (batch, past_seq) = (1, 1);
        let new_past = past_seq + 1;
        let kv_len = batch * new_past * kv_cols(&cfg);
        let mut cache = zero_full_attn_cache(&cfg, batch, past_seq);

        let trunk_logits = vec![1.0; batch * cfg.vocab_size];
        let mtp_logits = vec![2.0; batch * cfg.vocab_size];
        assert_ne!(
            mtp_logits.len(),
            kv_len,
            "test needs distinct mtp vs kv lengths"
        );
        let new_k = vec![3.0; kv_len];
        let new_v = vec![4.0; kv_len];

        // Correct order: trunk, MTP, then the layer's K and V.
        let ordered = vec![
            trunk_logits.clone(),
            mtp_logits.clone(),
            new_k.clone(),
            new_v.clone(),
        ];
        let (trunk_out, mtp) =
            advance_cache_from_decode_outputs(&cfg, &mut cache, ordered, None, true, true, false)
                .unwrap();
        assert_eq!(trunk_out, trunk_logits);
        assert_eq!(mtp.unwrap(), mtp_logits);
        assert_eq!(cache.past_seq, new_past);
        let Qwen35LayerState::FullAttn { past_k, past_v } = &cache.layers[0] else {
            panic!("expected full-attn layer");
        };
        assert_eq!(past_k, &new_k);
        assert_eq!(past_v, &new_v);

        // MTP placed after the KV states must be rejected.
        let mut rewound = cache.clone();
        rewound.past_seq = past_seq;
        let misordered = vec![trunk_logits, new_k, new_v, mtp_logits];
        assert!(
            advance_cache_from_decode_outputs(&cfg, &mut rewound, misordered, None, true, true, false)
                .is_err()
        );
    }

    #[test]
    fn advance_decode_discards_mtp_when_not_wanted() {
        let cfg = one_full_attn_cfg();
        let batch = 1;
        let kv_len = batch * 2 * kv_cols(&cfg);
        let mut cache = zero_full_attn_cache(&cfg, batch, 1);

        let outputs = vec![
            vec![0.0; batch * cfg.vocab_size],
            vec![1.0; batch * cfg.vocab_size],
            vec![2.0; kv_len],
            vec![3.0; kv_len],
        ];
        let (_, mtp) =
            advance_cache_from_decode_outputs(&cfg, &mut cache, outputs, None, true, false, false)
                .unwrap();
        assert!(mtp.is_none());
    }
}