1use crate::config::Qwen35Config;
19
/// Decode state carried between decode steps for a single trunk layer.
///
/// The variant a layer holds follows [`trunk_layer_kinds`]: layers flagged
/// `false` are seeded as [`Qwen35LayerState::Linear`], layers flagged `true`
/// as [`Qwen35LayerState::FullAttn`] (see `seed_cache_from_outputs`).
#[derive(Debug, Clone)]
pub enum Qwen35LayerState {
    /// State of a linear-attention (recurrent) layer.
    Linear {
        /// Flattened convolution state. `seed_cache_from_outputs` requires
        /// `batch * (ssm_conv_kernel - 1) * linear_conv_channels(cfg)` floats.
        conv_state: Vec<f32>,
        /// Flattened SSM state. `zero_recurrent_inputs` creates it with
        /// `batch * ssm_time_step_rank * ssm_state_size * ssm_state_size`
        /// floats; its length is not checked when the cache is seeded.
        ssm_state: Vec<f32>,
    },
    /// State of a full-attention layer.
    FullAttn {
        /// Cached keys, `batch` rows of `past_seq * kv_cols` floats
        /// (`kv_cols = num_key_value_heads * key_length`).
        past_k: Vec<f32>,
        /// Cached values, same layout as `past_k`.
        past_v: Vec<f32>,
    },
}
38
/// Decode-time cache for the Qwen3.5 trunk: one [`Qwen35LayerState`] per
/// trunk layer plus bookkeeping shared by every row of the batch.
#[derive(Debug, Clone)]
pub struct Qwen35DecodeCache {
    /// Number of rows (sequences) decoded together.
    pub batch: usize,
    /// Number of positions currently held in each full-attention KV cache.
    /// Set to the prefill length when seeded and incremented by one per
    /// decode step.
    pub past_seq: usize,
    /// Unpadded prompt length of each row; used for the decode attention
    /// mask and for zeroing prompt padding in the KV caches.
    pub prompt_lens: Vec<usize>,
    /// Per-layer state, indexed by trunk layer.
    pub layers: Vec<Qwen35LayerState>,
}
48
49impl Qwen35DecodeCache {
50 pub fn n_trunk(&self) -> usize {
51 self.layers.len()
52 }
53}
54
55pub fn trunk_layer_kinds(cfg: &Qwen35Config) -> Vec<bool> {
57 let n_main = cfg.num_hidden_layers - cfg.nextn_predict_layers;
58 let interval = cfg.full_attention_interval.max(1);
59 (0..n_main).map(|il| ((il + 1) % interval) == 0).collect()
60}
61
62pub fn recurrent_output_count(cfg: &Qwen35Config) -> usize {
64 trunk_layer_kinds(cfg).len() * 2
65}
66
/// Number of logit-head outputs that precede the per-layer states: the trunk
/// head, plus the MTP head when `with_mtp` is set.
pub fn logit_output_count(with_mtp: bool) -> usize {
    if with_mtp {
        2
    } else {
        1
    }
}
71
72fn truncate_logits_row(_cfg: &Qwen35Config, logits: Vec<f32>, _batch: usize) -> Vec<f32> {
73 logits
76}
77
78fn parse_mtp_logits(cfg: &Qwen35Config, batch: usize, mtp: Vec<f32>) -> anyhow::Result<Vec<f32>> {
79 use anyhow::bail;
80 let lm_vocab = mtp.len() / batch.max(1);
81 let expected = batch * lm_vocab;
82 if mtp.len() != expected {
83 bail!(
84 "mtp logits: len={} expected batch*lm_vocab={expected}",
85 mtp.len()
86 );
87 }
88 Ok(truncate_logits_row(cfg, mtp, batch))
89}
90
91pub fn zero_recurrent_inputs(cfg: &Qwen35Config, batch: usize) -> Vec<(String, Vec<f32>)> {
93 let n_state = cfg.ssm_state_size;
94 let n_v_heads = cfg.ssm_time_step_rank;
95 let conv_channels = linear_conv_channels(cfg);
96 let k_conv = cfg.ssm_conv_kernel;
97 let head_dim = cfg.key_length;
98 let kv_cols = cfg.num_key_value_heads * head_dim;
99
100 let mut out = Vec::new();
101 for (il, is_full) in trunk_layer_kinds(cfg).into_iter().enumerate() {
102 if is_full {
103 let _ = kv_cols;
104 let _ = head_dim;
105 } else {
107 out.push((
108 format!("conv_state_l{il}"),
109 vec![0f32; batch * (k_conv - 1) * conv_channels],
110 ));
111 out.push((
112 format!("ssm_state_l{il}"),
113 vec![0f32; batch * n_v_heads * n_state * n_state],
114 ));
115 }
116 }
117 out
118}
119
120fn linear_conv_channels(cfg: &Qwen35Config) -> usize {
121 let n_state = cfg.ssm_state_size;
122 let n_k_heads = cfg.ssm_group_count;
123 let n_v_heads = cfg.ssm_time_step_rank;
124 let key_dim = n_state * n_k_heads;
125 let value_dim = n_state * n_v_heads;
126 key_dim * 2 + value_dim
127}
128
/// Builds the row-major `batch x (bucket_upper + 1)` decode attention mask.
/// A value of 1.0 means attend; 0.0 means masked.
///
/// For row `b`, the count of valid positions is
/// `prompt_lens[b] + generated_per_row[b]`. Missing entries default to
/// `past_seq` and `0` respectively. Slots `0..=min(past_seq, bucket_upper)`
/// that fall below that count are set to 1.0; every other slot stays 0.0.
///
/// NOTE(review): valid positions are treated as a contiguous prefix. For
/// rows whose prompt was right-padded (see `zero_prompt_padding_kv`), this
/// may mark padding slots instead of later generated positions. Confirm it
/// matches how callers fill `generated_per_row`.
pub fn build_decode_attention_mask(
    batch: usize,
    past_seq: usize,
    bucket_upper: usize,
    prompt_lens: &[usize],
    generated_per_row: &[usize],
) -> Vec<f32> {
    let width = bucket_upper + 1;
    // One past the highest slot index that may ever be marked.
    let reachable = past_seq.min(bucket_upper) + 1;
    let mut mask = vec![0f32; batch * width];
    for (row, slots) in mask.chunks_mut(width).enumerate() {
        let prompt = prompt_lens.get(row).copied().unwrap_or(past_seq);
        let generated = generated_per_row.get(row).copied().unwrap_or(0);
        let valid = prompt + generated;
        slots[..reachable.min(valid)].fill(1.0);
    }
    mask
}
152
/// Right-pads every batch row of a KV cache from `actual_past` positions to
/// `bucket_upper` positions with zeros.
///
/// `src` is read as `batch` rows of `actual_past * kv_cols` floats. The
/// result holds `batch` rows of `bucket_upper * kv_cols` floats.
///
/// # Panics
/// Panics if `actual_past > bucket_upper` (when `batch > 0` and
/// `kv_cols > 0`). Also panics if `src` holds fewer than
/// `batch * actual_past * kv_cols` floats.
pub fn pad_kv_to_bucket(
    src: &[f32],
    batch: usize,
    actual_past: usize,
    bucket_upper: usize,
    kv_cols: usize,
) -> Vec<f32> {
    let src_row = actual_past * kv_cols;
    let dst_row = bucket_upper * kv_cols;
    let mut padded = vec![0f32; batch * dst_row];
    (0..batch).for_each(|row| {
        let from = row * src_row;
        let to = row * dst_row;
        padded[to..to + src_row].copy_from_slice(&src[from..from + src_row]);
    });
    padded
}
170
171pub fn slice_kv_from_bucket(
173 src: &[f32],
174 batch: usize,
175 actual_past: usize,
176 bucket_upper: usize,
177 kv_cols: usize,
178) -> anyhow::Result<Vec<f32>> {
179 use anyhow::bail;
180 let out_seq = bucket_upper.saturating_add(1);
183 let mut out = vec![0f32; batch * actual_past * kv_cols];
184 for b in 0..batch {
185 let src_base = b * out_seq * kv_cols;
186 let dst_base = b * actual_past * kv_cols;
187 let copy_len = actual_past * kv_cols;
188 let end = src_base + copy_len;
189 if end > src.len() {
190 bail!(
191 "slice_kv_from_bucket: need {end} floats in bucket output, got {} \
192 (batch={batch}, actual_past={actual_past}, bucket_upper={bucket_upper})",
193 src.len()
194 );
195 }
196 out[dst_base..dst_base + copy_len].copy_from_slice(&src[src_base..end]);
197 }
198 Ok(out)
199}
200
201pub fn zero_prompt_padding_kv(
203 cfg: &Qwen35Config,
204 cache: &mut Qwen35DecodeCache,
205 padded_seq: usize,
206) {
207 let head_dim = cfg.key_length;
208 let kv_cols = cfg.num_key_value_heads * head_dim;
209 let kinds = trunk_layer_kinds(cfg);
210 for (il, layer) in cache.layers.iter_mut().enumerate() {
211 if !kinds[il] {
212 continue;
213 }
214 if let Qwen35LayerState::FullAttn { past_k, past_v } = layer {
215 for b in 0..cache.batch {
216 let prompt_len = cache.prompt_lens.get(b).copied().unwrap_or(padded_seq);
217 if prompt_len >= padded_seq {
218 continue;
219 }
220 for t in prompt_len..padded_seq {
221 let start = b * padded_seq * kv_cols + t * kv_cols;
222 past_k[start..start + kv_cols].fill(0.0);
223 past_v[start..start + kv_cols].fill(0.0);
224 }
225 }
226 }
227 }
228}
229
230pub fn decode_step_feeds(
235 cfg: &Qwen35Config,
236 cache: &Qwen35DecodeCache,
237 tokens: &[u32],
238 rope_cos: &[f32],
239 rope_sin: &[f32],
240 bucket_upper: Option<usize>,
241 generated_per_row: &[usize],
242) -> anyhow::Result<Vec<(String, Vec<f32>)>> {
243 use anyhow::bail;
244
245 if tokens.len() != cache.batch {
246 bail!(
247 "decode_step_feeds: expected {} tokens, got {}",
248 cache.batch,
249 tokens.len()
250 );
251 }
252 let mut feeds = vec![
253 (
254 "input_ids".into(),
255 tokens.iter().map(|&t| t as f32).collect(),
256 ),
257 ("rope_cos".into(), rope_cos.to_vec()),
258 ("rope_sin".into(), rope_sin.to_vec()),
259 ];
260 if let Some(upper) = bucket_upper {
261 let mask = build_decode_attention_mask(
262 cache.batch,
263 cache.past_seq,
264 upper,
265 &cache.prompt_lens,
266 generated_per_row,
267 );
268 feeds.push(("mask".into(), mask));
269 }
270 let head_dim = cfg.key_length;
271 let kv_cols = cfg.num_key_value_heads * head_dim;
272 let kinds = trunk_layer_kinds(cfg);
273 for (il, layer) in cache.layers.iter().enumerate() {
274 let is_full = kinds[il];
275 match (layer, is_full) {
276 (
277 Qwen35LayerState::Linear {
278 conv_state,
279 ssm_state,
280 },
281 false,
282 ) => {
283 feeds.push((format!("conv_state_l{il}"), conv_state.clone()));
284 feeds.push((format!("ssm_state_l{il}"), ssm_state.clone()));
285 }
286 (Qwen35LayerState::FullAttn { past_k, past_v }, true) => {
287 if let Some(upper) = bucket_upper {
288 feeds.push((
289 format!("past_k_l{il}"),
290 pad_kv_to_bucket(past_k, cache.batch, cache.past_seq, upper, kv_cols),
291 ));
292 feeds.push((
293 format!("past_v_l{il}"),
294 pad_kv_to_bucket(past_v, cache.batch, cache.past_seq, upper, kv_cols),
295 ));
296 } else {
297 feeds.push((format!("past_k_l{il}"), past_k.clone()));
298 feeds.push((format!("past_v_l{il}"), past_v.clone()));
299 }
300 }
301 _ => {}
302 }
303 }
304 Ok(feeds)
305}
306
307pub fn seed_cache_from_outputs(
310 cfg: &Qwen35Config,
311 batch: usize,
312 seq: usize,
313 prompt_lens: &[usize],
314 outputs: Vec<Vec<f32>>,
315 with_mtp: bool,
316 trunk_is_hidden: bool,
317) -> anyhow::Result<(Vec<f32>, Qwen35DecodeCache, Option<Vec<f32>>)> {
318 use anyhow::{Context, bail};
319 let n_head = logit_output_count(with_mtp);
320 let n_extra = recurrent_output_count(cfg);
321 if outputs.len() != n_head + n_extra {
322 bail!(
323 "prefill-cache: expected {} outputs, got {}",
324 n_head + n_extra,
325 outputs.len()
326 );
327 }
328 let mut iter = outputs.into_iter();
329 let trunk = iter.next().context("trunk head output missing")?;
330 let head_dim = cfg.key_length;
331 let kv_cols = cfg.num_key_value_heads * head_dim;
332 let logits = if trunk_is_hidden {
333 let n = cfg.hidden_size;
334 let expected_last = batch * n;
335 let expected_full = batch * seq * n;
336 if trunk.len() == expected_last {
337 trunk
338 } else if trunk.len() == expected_full
339 || (trunk.len().is_multiple_of(n)
340 && trunk.len() >= batch.max(1) * n
341 && trunk.len() % (batch.max(1) * n) == 0)
342 {
343 let row_stride = trunk.len() / batch.max(1);
344 let seq_dim = row_stride / n;
345 if batch > 1 && !prompt_lens.is_empty() {
346 let mut out = Vec::with_capacity(batch * n);
347 for b in 0..batch {
348 let pl = prompt_lens.get(b).copied().unwrap_or(seq).min(seq_dim);
349 let idx = pl.saturating_sub(1);
350 let off = b * row_stride + idx * n;
351 out.extend_from_slice(&trunk[off..off + n]);
352 }
353 out
354 } else if !prompt_lens.is_empty() {
355 let last_pl = *prompt_lens.iter().max().unwrap_or(&seq);
356 let idx = last_pl.saturating_sub(1).min(seq_dim.saturating_sub(1));
357 let off = idx * n;
358 trunk[off..off + n].to_vec()
359 } else {
360 trunk[expected_full.saturating_sub(n)..].to_vec()
361 }
362 } else {
363 bail!(
364 "prefill-cache hidden: len={} expected batch*hidden={expected_last} \
365 or batch*seq*hidden={expected_full} (or padded max_seq layout)",
366 trunk.len()
367 );
368 }
369 } else {
370 let lm_vocab = trunk.len() / batch.max(1);
371 let expected_logits = batch * lm_vocab;
372 if trunk.len() != expected_logits {
373 bail!(
374 "prefill-cache logits: len={} expected batch*lm_vocab={expected_logits} \
375 (batch={batch}, lm_vocab={lm_vocab})",
376 trunk.len()
377 );
378 }
379 truncate_logits_row(cfg, trunk, batch)
380 };
381 let mtp_logits = if with_mtp {
382 Some(parse_mtp_logits(
383 cfg,
384 batch,
385 iter.next().context("mtp logits missing")?,
386 )?)
387 } else {
388 None
389 };
390
391 let mut layers = Vec::with_capacity(trunk_layer_kinds(cfg).len());
392 for (il, is_full) in trunk_layer_kinds(cfg).into_iter().enumerate() {
393 if is_full {
394 let k = iter.next().context("past_k missing")?;
395 let v = iter.next().context("past_v missing")?;
396 let expected = batch * seq * kv_cols;
397 let (past_k, past_v) = if k.len() == expected && v.len() == expected {
398 (k, v)
399 } else if k.len() % kv_cols == 0 && v.len() % kv_cols == 0 {
400 let k_bucket = k.len() / (batch.max(1) * kv_cols);
401 let v_bucket = v.len() / (batch.max(1) * kv_cols);
402 if k_bucket >= seq && v_bucket >= seq {
403 (
404 slice_kv_from_bucket(&k, batch, seq, k_bucket, kv_cols)?,
405 slice_kv_from_bucket(&v, batch, seq, v_bucket, kv_cols)?,
406 )
407 } else {
408 bail!(
409 "layer {il} kv: k.len={} v.len={} expected {expected} \
410 (k_bucket={k_bucket} v_bucket={v_bucket} < seq={seq})",
411 k.len(),
412 v.len()
413 );
414 }
415 } else {
416 bail!(
417 "layer {il} kv: k.len={} v.len={} expected {expected}",
418 k.len(),
419 v.len()
420 );
421 };
422 layers.push(Qwen35LayerState::FullAttn { past_k, past_v });
423 } else {
424 let conv = iter.next().context("conv_state missing")?;
425 let ssm = iter.next().context("ssm_state missing")?;
426 let conv_ring =
427 batch * (cfg.ssm_conv_kernel.saturating_sub(1)) * linear_conv_channels(cfg);
428 let conv_state = if conv.len() == conv_ring {
429 conv
430 } else {
431 bail!(
432 "layer {il} conv_state: len={} expected {conv_ring}",
433 conv.len()
434 );
435 };
436 layers.push(Qwen35LayerState::Linear {
437 conv_state,
438 ssm_state: ssm,
439 });
440 }
441 }
442 Ok((
443 logits,
444 Qwen35DecodeCache {
445 batch,
446 past_seq: seq,
447 prompt_lens: prompt_lens.to_vec(),
448 layers,
449 },
450 mtp_logits,
451 ))
452}
453
454pub fn advance_cache_from_decode_outputs(
457 cfg: &Qwen35Config,
458 cache: &mut Qwen35DecodeCache,
459 outputs: Vec<Vec<f32>>,
460 bucket_upper: Option<usize>,
461 mtp_logits_path: bool,
462 want_mtp: bool,
463 trunk_is_hidden: bool,
464) -> anyhow::Result<(Vec<f32>, Option<Vec<f32>>)> {
465 use anyhow::{Context, bail};
466 let n_head = logit_output_count(mtp_logits_path);
467 let n_extra = recurrent_output_count(cfg);
468 if outputs.len() != n_head + n_extra {
469 bail!(
470 "decode: expected {} outputs, got {}",
471 n_head + n_extra,
472 outputs.len()
473 );
474 }
475 let mut iter = outputs.into_iter();
476 let trunk = iter.next().context("trunk head output missing")?;
477 let new_past = cache.past_seq + 1;
478 let head_dim = cfg.key_length;
479 let kv_cols = cfg.num_key_value_heads * head_dim;
480 let batch = cache.batch;
481
482 let trunk_out = if trunk_is_hidden {
483 let expected = batch * cfg.hidden_size;
484 if trunk.len() != expected {
485 bail!(
486 "decode hidden: len={} expected batch*hidden={expected}",
487 trunk.len()
488 );
489 }
490 trunk
491 } else {
492 let lm_vocab = trunk.len() / batch.max(1);
493 let expected_logits = batch * lm_vocab;
494 if trunk.len() != expected_logits {
495 bail!(
496 "decode logits: len={} expected batch*lm_vocab={expected_logits}",
497 trunk.len()
498 );
499 }
500 truncate_logits_row(cfg, trunk, batch)
501 };
502 let mtp_logits = if mtp_logits_path {
503 let raw = iter.next().context("mtp logits missing")?;
504 if want_mtp {
505 Some(parse_mtp_logits(cfg, batch, raw)?)
506 } else {
507 None
508 }
509 } else {
510 None
511 };
512
513 let mut new_layers = Vec::with_capacity(cache.layers.len());
514 let kinds = trunk_layer_kinds(cfg);
515 for (il, layer) in cache.layers.iter().enumerate() {
516 let is_full = kinds[il];
517 if is_full {
518 let k = iter.next().context("new_k missing")?;
519 let v = iter.next().context("new_v missing")?;
520 let (k, v) = if let Some(upper) = bucket_upper {
521 (
522 slice_kv_from_bucket(&k, batch, new_past, upper, kv_cols)?,
523 slice_kv_from_bucket(&v, batch, new_past, upper, kv_cols)?,
524 )
525 } else {
526 (k, v)
527 };
528 let expected = batch * new_past * kv_cols;
529 if k.len() != expected || v.len() != expected {
530 bail!(
531 "layer {il} kv: k.len={} v.len={} expected {expected}",
532 k.len(),
533 v.len()
534 );
535 }
536 new_layers.push(Qwen35LayerState::FullAttn {
537 past_k: k,
538 past_v: v,
539 });
540 let _ = layer;
541 } else {
542 let conv = iter.next().context("conv_state missing")?;
543 let ssm = iter.next().context("ssm_state missing")?;
544 new_layers.push(Qwen35LayerState::Linear {
545 conv_state: conv,
546 ssm_state: ssm,
547 });
548 }
549 }
550 cache.past_seq = new_past;
551 cache.layers = new_layers;
552 Ok((trunk_out, mtp_logits))
553}
554
555#[allow(dead_code)]
557pub fn trunk_layer_state_sizes(cfg: &Qwen35Config) -> Vec<(bool, usize, usize)> {
558 let n_main = cfg.num_hidden_layers - cfg.nextn_predict_layers;
559 let interval = cfg.full_attention_interval.max(1);
560 let n_state = cfg.ssm_state_size;
561 let n_v_heads = cfg.ssm_time_step_rank;
562 let conv_channels = linear_conv_channels(cfg);
563 let k_conv = cfg.ssm_conv_kernel;
564
565 let mut out = Vec::with_capacity(n_main);
566 for il in 0..n_main {
567 let is_full_attn = ((il + 1) % interval) == 0;
568 if is_full_attn {
569 out.push((true, 0, 0));
570 } else {
571 out.push((
572 false,
573 (k_conv - 1) * conv_channels,
574 n_v_heads * n_state * n_state,
575 ));
576 }
577 }
578 out
579}
580
581pub fn pack_input_ids(batch_prompts: &[Vec<u32>], max_seq: usize) -> anyhow::Result<Vec<f32>> {
583 use anyhow::bail;
584 if batch_prompts.is_empty() {
585 bail!("pack_input_ids: batch must be non-empty");
586 }
587 let batch = batch_prompts.len();
588 let mut out = vec![0f32; batch * max_seq];
589 for (b, prompt) in batch_prompts.iter().enumerate() {
590 if prompt.len() > max_seq {
591 bail!(
592 "pack_input_ids: row {b} length {} exceeds max_seq={max_seq}",
593 prompt.len()
594 );
595 }
596 let base = b * max_seq;
597 for (i, &t) in prompt.iter().enumerate() {
598 out[base + i] = t as f32;
599 }
600 }
601 Ok(out)
602}
603
/// Index of the last real token in each prompt, as f32 for use as a graph
/// input. The index is `len - 1`, clamped to 0 for empty prompts.
pub fn last_token_indices(prompt_lens: &[usize]) -> Vec<f32> {
    let mut indices = Vec::with_capacity(prompt_lens.len());
    for &len in prompt_lens {
        indices.push(if len == 0 { 0.0 } else { (len - 1) as f32 });
    }
    indices
}
611
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal config whose trunk is a single full-attention layer
    /// (`num_hidden_layers: 1`, `nextn_predict_layers: 0`,
    /// `full_attention_interval: 1`). Decode outputs for it are therefore
    /// `[trunk, (mtp), new_k, new_v]`.
    fn one_full_attn_cfg() -> Qwen35Config {
        Qwen35Config {
            vocab_size: 16,
            hidden_size: 4,
            intermediate_size: 8,
            num_hidden_layers: 1,
            nextn_predict_layers: 0,
            num_attention_heads: 2,
            num_key_value_heads: 2,
            key_length: 2,
            value_length: 2,
            max_position_embeddings: 64,
            rms_norm_eps: 1e-6,
            rope_theta: 10_000.0,
            rope_dim_count: 2,
            rope_dim_sections: vec![],
            full_attention_interval: 1,
            ssm_conv_kernel: 4,
            ssm_group_count: 2,
            ssm_inner_size: 8,
            ssm_state_size: 4,
            ssm_time_step_rank: 2,
            tie_word_embeddings: true,
            num_experts: 0,
            num_experts_used: 0,
            expert_ffn_size: 0,
            shared_expert_ffn_size: 0,
            expert_weights_scale: 1.0,
        }
    }

    /// The MTP logits must be read right after the trunk head and before any
    /// per-layer KV output. A correctly ordered list advances the cache; a
    /// list with MTP moved to the end must be rejected.
    #[test]
    fn advance_decode_consumes_mtp_before_kv_states() {
        let cfg = one_full_attn_cfg();
        let batch = 1;
        let past_seq = 1;
        let kv_cols = cfg.num_key_value_heads * cfg.key_length;
        let new_past = past_seq + 1;
        let kv_len = batch * new_past * kv_cols;

        let mut cache = Qwen35DecodeCache {
            batch,
            past_seq,
            prompt_lens: vec![past_seq],
            layers: vec![Qwen35LayerState::FullAttn {
                past_k: vec![0.0; batch * past_seq * kv_cols],
                past_v: vec![0.0; batch * past_seq * kv_cols],
            }],
        };

        let trunk_logits = vec![1.0; batch * cfg.vocab_size];
        let mtp_logits = vec![2.0; batch * cfg.vocab_size];
        // Distinct lengths are what make a mis-ordered output list detectable.
        assert_ne!(
            mtp_logits.len(),
            kv_len,
            "test needs distinct mtp vs kv lengths"
        );
        let new_k = vec![3.0; kv_len];
        let new_v = vec![4.0; kv_len];

        let outputs = vec![
            trunk_logits.clone(),
            mtp_logits.clone(),
            new_k.clone(),
            new_v.clone(),
        ];
        let (trunk_out, mtp) =
            advance_cache_from_decode_outputs(&cfg, &mut cache, outputs, None, true, true, false)
                .unwrap();
        assert_eq!(trunk_out, trunk_logits);
        assert_eq!(mtp.unwrap(), mtp_logits);
        assert_eq!(cache.past_seq, new_past);
        match &cache.layers[0] {
            Qwen35LayerState::FullAttn { past_k, past_v } => {
                assert_eq!(past_k, &new_k);
                assert_eq!(past_v, &new_v);
            }
            _ => panic!("expected full-attn layer"),
        }

        // Mis-ordered: the MTP buffer lands in the `v` slot and fails the KV
        // length check.
        let mut cache2 = cache.clone();
        cache2.past_seq = past_seq;
        let bad = vec![trunk_logits, new_k, new_v, mtp_logits];
        assert!(
            advance_cache_from_decode_outputs(&cfg, &mut cache2, bad, None, true, true, false)
                .is_err()
        );
    }

    /// With `mtp_logits_path = true` but `want_mtp = false`, the MTP output is
    /// still consumed (keeping KV outputs aligned) and `None` is returned.
    #[test]
    fn advance_decode_discards_mtp_when_not_wanted() {
        let cfg = one_full_attn_cfg();
        let batch = 1;
        let kv_cols = cfg.num_key_value_heads * cfg.key_length;
        let kv_len = batch * 2 * kv_cols;

        let mut cache = Qwen35DecodeCache {
            batch,
            past_seq: 1,
            prompt_lens: vec![1],
            layers: vec![Qwen35LayerState::FullAttn {
                past_k: vec![0.0; batch * kv_cols],
                past_v: vec![0.0; batch * kv_cols],
            }],
        };

        let outputs = vec![
            vec![0.0; batch * cfg.vocab_size],
            vec![1.0; batch * cfg.vocab_size],
            vec![2.0; kv_len],
            vec![3.0; kv_len],
        ];
        let (_, mtp) =
            advance_cache_from_decode_outputs(&cfg, &mut cache, outputs, None, true, false, false)
                .unwrap();
        assert!(mtp.is_none());
    }
}