1use crate::builder::{
34 build_qwen3_decode_graph_sized, build_qwen3_decode_graph_sized_ext,
35 build_qwen3_graph_sized_last_logits,
36};
37use crate::capabilities::validate_device;
38use crate::config::Qwen3Config;
39use crate::profile::qwen3_profile_near_weights;
40use crate::sampling::{SampleOpts, sample_token};
41use anyhow::{Context, Result};
42use rlx_core::autoregressive::{
43 DecodeLogitsKv, KvCacheState, compile_cache_ensure_graph, kv_from_prefill_outputs,
44 prefill_cache_key, run_bucketed_kv_decode, split_decode_logits_kv,
45};
46use rlx_core::flow_bridge::compile_options_from_profile;
47use rlx_core::weight_loader::WeightLoader;
48use rlx_core::weight_map::WeightMap;
49use rlx_flow::CompileProfile;
50use rlx_ir::logical_kernel::KernelDispatchConfig;
51use rlx_runtime::attn_mask::bucket_decode_mask;
52use rlx_runtime::compile_cache::{BucketedCompileCache, CacheRunInput, CompileCache};
53use rlx_runtime::{CompileOptions, Device, Session};
54use std::collections::HashMap;
55use std::path::Path;
56
57pub struct Qwen3Generator {
63 cfg: Qwen3Config,
64 weights_cache: HashMap<String, (Vec<f32>, Vec<usize>)>,
69 tokens: Vec<u32>,
70 device: Device,
71 cache: Option<KvCacheState>,
75 prefill_compile_cache: Option<CompileCache>,
79 decode_compile_cache: Option<BucketedCompileCache>,
85 prefill_profile: CompileProfile,
86 decode_profile: CompileProfile,
87}
88
89impl Qwen3Generator {
90 pub fn from_loader(
93 cfg: Qwen3Config,
94 loader: &mut dyn WeightLoader,
95 device: Device,
96 ) -> Result<Self> {
97 validate_device(&cfg, device, false)?;
98 let keys = loader.remaining_keys();
99 let mut weights_cache = HashMap::with_capacity(keys.len());
100 for k in keys {
101 let v = loader
102 .take(&k)
103 .with_context(|| format!("draining weight {k}"))?;
104 let canonical =
110 rlx_core::weight_loader::gguf_to_hf_name(&k).unwrap_or_else(|| k.clone());
111 weights_cache.insert(canonical, v);
112 }
113 let max_past = cfg.max_position_embeddings.clamp(1, 4096);
114 Ok(Self {
115 cfg,
116 weights_cache,
117 tokens: Vec::new(),
118 device,
119 cache: None,
120 prefill_compile_cache: Some(CompileCache::new(device, 8)),
121 decode_compile_cache: Some(BucketedCompileCache::power_of_two_ladder(
122 device,
123 1,
124 max_past as u64,
125 )),
126 prefill_profile: CompileProfile::qwen3_prefill(),
127 decode_profile: CompileProfile::qwen3_decode(),
128 })
129 }
130
131 pub fn from_loader_at(
133 cfg: Qwen3Config,
134 loader: &mut dyn WeightLoader,
135 device: Device,
136 weights_path: &Path,
137 ) -> Result<Self> {
138 let mut g = Self::from_loader(cfg, loader, device)?;
139 g.prefill_profile = qwen3_profile_near_weights(weights_path, false);
140 g.decode_profile = qwen3_profile_near_weights(weights_path, true);
141 Ok(g)
142 }
143
144 pub fn with_compile_profiles(
145 mut self,
146 prefill: CompileProfile,
147 decode: CompileProfile,
148 ) -> Self {
149 self.prefill_profile = prefill;
150 self.decode_profile = decode;
151 self
152 }
153
154 pub fn prefill_profile(&self) -> &CompileProfile {
155 &self.prefill_profile
156 }
157
158 pub fn decode_profile(&self) -> &CompileProfile {
159 &self.decode_profile
160 }
161
162 fn profile_compile_options(&self, decode: bool) -> CompileOptions {
163 let profile = if decode {
164 &self.decode_profile
165 } else {
166 &self.prefill_profile
167 };
168 compile_options_from_profile(profile, self.device, KernelDispatchConfig::default())
169 }
170
171 pub fn with_prefill_cache(mut self, capacity: usize) -> Self {
176 self.prefill_compile_cache = Some(CompileCache::new(self.device, capacity));
177 self
178 }
179
180 pub fn set_decode_compile_cache(&mut self, cache: Option<BucketedCompileCache>) {
192 self.decode_compile_cache = cache;
193 }
194
195 pub fn with_decode_cache(mut self, max_past: usize) -> Self {
196 let cache = BucketedCompileCache::power_of_two_ladder(
197 self.device,
198 1,
199 max_past.max(1) as u64,
200 );
201 self.decode_compile_cache = Some(cache);
202 self
203 }
204
205 pub fn from_path(cfg: Qwen3Config, path: &str, device: Device) -> Result<Self> {
208 Self::from_path_at(cfg, path, device, Path::new("."))
209 }
210
211 pub fn from_path_at(
213 cfg: Qwen3Config,
214 path: &str,
215 device: Device,
216 weights_path: &Path,
217 ) -> Result<Self> {
218 let mut loader = rlx_core::weight_loader::load_from_path(path)?;
219 Self::from_loader_at(cfg, loader.as_mut(), device, weights_path)
220 }
221
222 pub fn from_path_with_mtp(
230 cfg: Qwen3Config,
231 path: &str,
232 device: Device,
233 include_mtp: bool,
234 ) -> Result<Self> {
235 if path.ends_with(".gguf") {
239 let mut gguf = rlx_core::weight_loader::GgufLoader::from_file(path)?;
240 gguf.include_mtp(include_mtp);
241 Self::from_loader(cfg, &mut gguf, device)
242 } else {
243 Self::from_path(cfg, path, device)
244 }
245 }
246
247 pub fn prefill(&mut self, prompt_ids: &[u32]) {
251 self.tokens.clear();
252 self.tokens.extend_from_slice(prompt_ids);
253 self.cache = None;
254 }
255
256 pub fn step(&mut self, opts: SampleOpts) -> Result<u32> {
260 if self.tokens.is_empty() {
261 anyhow::bail!("step() called with empty token history; call prefill() first");
262 }
263 let seq = self.tokens.len();
264 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
265 let (graph, params) = build_qwen3_graph_sized_last_logits(
266 &self.cfg, &mut wm, 1, seq, false,
267 )?;
268 let compile_opts = self.profile_compile_options(false);
269 let mut compiled = Session::new(self.device).compile_with(graph, &compile_opts);
270 for (name, data) in ¶ms {
271 compiled.set_param(name, data);
272 }
273 let ids_f32: Vec<f32> = self.tokens.iter().map(|&i| i as f32).collect();
274 let outputs = compiled.run(&[("input_ids", ids_f32.as_slice())]);
275 let logits = outputs
276 .into_iter()
277 .next()
278 .context("compiled.run returned no outputs")?;
279
280 let vocab = self.cfg.vocab_size;
281 let expected = vocab;
282 if logits.len() < expected {
283 anyhow::bail!(
284 "logits length {} < expected {} (last logits, seq {seq}, vocab {vocab})",
285 logits.len(),
286 expected
287 );
288 }
289 let last_row = &logits[..vocab];
291 let tok = sample_token(last_row, opts) as u32;
292 self.tokens.push(tok);
293 Ok(tok)
294 }
295
296 pub fn generate(&mut self, n: usize, opts: SampleOpts) -> Result<Vec<u32>> {
299 if self.decode_compile_cache.is_some() {
300 return self.generate_cached(n, opts);
301 }
302 let start = self.tokens.len();
303 for _ in 0..n {
304 self.step(opts)?;
305 }
306 Ok(self.tokens[start..].to_vec())
307 }
308
309 pub fn step_cached(&mut self, opts: SampleOpts) -> Result<u32> {
319 if self.tokens.is_empty() {
320 anyhow::bail!("step_cached() called with empty token history; call prefill() first");
321 }
322 if self.cache.is_none() {
323 let tok = self.seed_cache_from_prompt(opts)?;
327 return Ok(tok);
328 }
329 let cache = self.cache.as_ref().unwrap();
330 let past_seq = cache.past_len;
331 if self.tokens.len() <= past_seq {
335 anyhow::bail!(
336 "cache invariant violated: tokens.len() {} <= past_seq {}",
337 self.tokens.len(),
338 past_seq
339 );
340 }
341 let input_tok = self.tokens[past_seq];
342
343 let (logits, new_k, new_v) = if self.decode_compile_cache.is_some()
345 && self
346 .decode_compile_cache
347 .as_ref()
348 .unwrap()
349 .bucket_for(past_seq as u64)
350 .is_some()
351 {
352 self.decode_step_bucketed(past_seq, input_tok)?
353 } else {
354 self.decode_step_oneshot(past_seq, input_tok)?
355 };
356
357 let cache_mut = self.cache.as_mut().unwrap();
358 cache_mut.past_len = past_seq + 1;
359 cache_mut.layers_k = new_k;
360 cache_mut.layers_v = new_v;
361
362 let vocab = self.cfg.vocab_size;
363 if logits.len() != vocab {
364 anyhow::bail!("decode logits length {} != vocab {}", logits.len(), vocab);
365 }
366 let tok = sample_token(&logits, opts) as u32;
367 self.tokens.push(tok);
368 Ok(tok)
369 }
370
371 fn decode_step_oneshot(&mut self, past_seq: usize, input_tok: u32) -> Result<DecodeLogitsKv> {
374 let cache = self.cache.as_ref().unwrap();
375
376 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
377 let (graph, params) =
378 build_qwen3_decode_graph_sized(&self.cfg, &mut wm, 1, past_seq)?;
379 let opts = self.profile_compile_options(true);
380 let mut compiled = Session::new(self.device).compile_with(graph, &opts);
381 for (name, data) in ¶ms {
382 compiled.set_param(name, data);
383 }
384
385 let (cos, sin) = compute_rope_slice(&self.cfg, past_seq);
386 let input_ids_f32 = [input_tok as f32];
387 let key_strs: Vec<String> = (0..self.cfg.num_hidden_layers)
388 .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
389 .collect();
390 let mut inputs: Vec<(&str, &[f32])> =
391 Vec::with_capacity(3 + 2 * self.cfg.num_hidden_layers);
392 inputs.push(("input_ids", input_ids_f32.as_slice()));
393 inputs.push(("rope_cos", cos.as_slice()));
394 inputs.push(("rope_sin", sin.as_slice()));
395 for i in 0..self.cfg.num_hidden_layers {
396 inputs.push((&key_strs[2 * i], cache.layers_k[i].as_slice()));
397 inputs.push((&key_strs[2 * i + 1], cache.layers_v[i].as_slice()));
398 }
399
400 let outputs = compiled.run(&inputs);
401 split_decode_logits_kv(outputs, self.cfg.num_hidden_layers)
402 }
403
404 fn decode_step_bucketed(&mut self, past_seq: usize, input_tok: u32) -> Result<DecodeLogitsKv> {
405 let kv = self.cache.as_ref().unwrap().clone();
406 let kv_dim = self.cfg.kv_proj_dim();
407 let n_layers = self.cfg.num_hidden_layers;
408 let (cos, sin) = compute_rope_slice(&self.cfg, past_seq);
409 let input_ids_f32 = [input_tok as f32];
410 let decode_opts = self.profile_compile_options(true);
411 let upper = self
412 .decode_compile_cache
413 .as_ref()
414 .and_then(|cache_dec| {
415 cache_dec.bucket_for(past_seq as u64).map(|idx| {
416 cache_dec
417 .buckets()
418 .nth(idx)
419 .map(|r| (r.end - 1) as usize)
420 .unwrap_or(past_seq)
421 })
422 })
423 .unwrap_or(past_seq);
424 let mask = bucket_decode_mask(past_seq, upper);
425 let fixed = [
426 CacheRunInput {
427 name: "input_ids",
428 data: &input_ids_f32,
429 row_inner: None,
430 },
431 CacheRunInput {
432 name: "rope_cos",
433 data: &cos,
434 row_inner: None,
435 },
436 CacheRunInput {
437 name: "rope_sin",
438 data: &sin,
439 row_inner: None,
440 },
441 CacheRunInput {
442 name: "mask",
443 data: &mask,
444 row_inner: None,
445 },
446 ];
447 let cfg = self.cfg.clone();
448 let weights = self.weights_cache.clone();
449 let cache_dec = self.decode_compile_cache.as_mut().unwrap();
450 run_bucketed_kv_decode(
451 cache_dec,
452 past_seq,
453 &kv,
454 kv_dim,
455 n_layers,
456 &fixed,
457 |upper| {
458 let mut wm = WeightMap::from_tensors(weights.clone());
459 build_qwen3_decode_graph_sized_ext(&cfg, &mut wm, 1, upper as usize, true)
460 .expect("qwen3 bucketed decode graph")
461 },
462 &decode_opts,
463 )
464 }
465
466 fn run_prefill_with_cache(
470 &mut self,
471 batch: usize,
472 seq: usize,
473 ids_f32: &[f32],
474 ) -> Result<Vec<Vec<f32>>> {
475 let prefill_opts = self.profile_compile_options(false);
476 if let Some(cache) = &mut self.prefill_compile_cache {
477 let key = prefill_cache_key(batch, seq);
478 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
479 let (graph, params) = build_qwen3_graph_sized_last_logits(
480 &self.cfg, &mut wm, batch, seq, true,
481 )?;
482 let compiled = compile_cache_ensure_graph(cache, key, graph, params, &prefill_opts);
483 Ok(compiled.run(&[("input_ids", ids_f32)]))
484 } else {
485 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
486 let (graph, params) = build_qwen3_graph_sized_last_logits(
487 &self.cfg, &mut wm, batch, seq, true,
488 )?;
489 let opts = self.profile_compile_options(false);
490 let mut compiled = Session::new(self.device).compile_with(graph, &opts);
491 for (name, data) in ¶ms {
492 compiled.set_param(name, data);
493 }
494 Ok(compiled.run(&[("input_ids", ids_f32)]))
495 }
496 }
497
498 pub fn generate_cached(&mut self, n: usize, opts: SampleOpts) -> Result<Vec<u32>> {
500 self.generate_cached_with(n, opts, |_| {})
501 }
502
503 pub fn generate_cached_with(
510 &mut self,
511 n: usize,
512 opts: SampleOpts,
513 on_token: impl FnMut(u32),
514 ) -> Result<Vec<u32>> {
515 self.generate_cached_until(n, opts, |_| true, on_token)
516 }
517
518 pub fn generate_cached_until(
521 &mut self,
522 n: usize,
523 opts: SampleOpts,
524 mut should_continue: impl FnMut(u32) -> bool,
525 mut on_token: impl FnMut(u32),
526 ) -> Result<Vec<u32>> {
527 let start = self.tokens.len();
528 for _ in 0..n {
529 let tok = self.step_cached(opts)?;
530 on_token(tok);
531 if !should_continue(tok) {
532 break;
533 }
534 }
535 Ok(self.tokens[start..].to_vec())
536 }
537
538 fn seed_cache_from_prompt(&mut self, opts: SampleOpts) -> Result<u32> {
543 let seq = self.tokens.len();
544 let batch = 1usize;
545 let kv_dim = self.cfg.kv_proj_dim();
546
547 let ids_f32: Vec<f32> = self.tokens.iter().map(|&i| i as f32).collect();
548 let outputs = self.run_prefill_with_cache(batch, seq, &ids_f32)?;
549 let (logits, kv) =
550 kv_from_prefill_outputs(outputs, batch, seq, kv_dim, self.cfg.num_hidden_layers)?;
551 self.cache = Some(kv);
552
553 let vocab = self.cfg.vocab_size;
554 let needed = vocab;
555 if logits.len() < needed {
556 anyhow::bail!("prefill logits length {} < {}", logits.len(), needed);
557 }
558 let last_row = &logits[..vocab];
559 let tok = sample_token(last_row, opts) as u32;
560 self.tokens.push(tok);
561 Ok(tok)
562 }
563
564 pub fn tokens(&self) -> &[u32] {
566 &self.tokens
567 }
568
569 pub fn config(&self) -> &Qwen3Config {
570 &self.cfg
571 }
572
573 pub fn prefill_get_last_logits(&mut self, context: &[u32]) -> Result<Vec<f32>> {
583 if context.is_empty() {
584 anyhow::bail!("prefill_get_last_logits: empty context");
585 }
586 self.tokens.clear();
587 self.tokens.extend_from_slice(context);
588 self.cache = None;
589
590 let seq = context.len();
591 let batch = 1usize;
592 let kv_dim = self.cfg.kv_proj_dim();
593
594 let ids_f32: Vec<f32> = context.iter().map(|&i| i as f32).collect();
595 let outputs = self.run_prefill_with_cache(batch, seq, &ids_f32)?;
596 let (logits, kv) =
597 kv_from_prefill_outputs(outputs, batch, seq, kv_dim, self.cfg.num_hidden_layers)?;
598 self.cache = Some(kv);
599
600 let vocab = self.cfg.vocab_size;
601 let needed = vocab;
602 if logits.len() < needed {
603 anyhow::bail!("logits short: {} < {}", logits.len(), needed);
604 }
605 Ok(logits[..vocab].to_vec())
606 }
607
608 pub fn decode_get_logits(&mut self, input: u32) -> Result<Vec<f32>> {
616 let cache = self.cache.as_ref().ok_or_else(|| {
617 anyhow::anyhow!(
618 "decode_get_logits: cache not seeded; call prefill_get_last_logits first"
619 )
620 })?;
621 let past_seq = cache.past_len;
622
623 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
624 let (graph, params) =
625 build_qwen3_decode_graph_sized(&self.cfg, &mut wm, 1, past_seq)?;
626 let opts = self.profile_compile_options(true);
627 let mut compiled = Session::new(self.device).compile_with(graph, &opts);
628 for (name, data) in ¶ms {
629 compiled.set_param(name, data);
630 }
631
632 let (cos, sin) = compute_rope_slice(&self.cfg, past_seq);
633 let input_ids_f32 = [input as f32];
634 let key_strs: Vec<String> = (0..self.cfg.num_hidden_layers)
635 .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
636 .collect();
637 let mut inputs: Vec<(&str, &[f32])> =
638 Vec::with_capacity(3 + 2 * self.cfg.num_hidden_layers);
639 inputs.push(("input_ids", input_ids_f32.as_slice()));
640 inputs.push(("rope_cos", cos.as_slice()));
641 inputs.push(("rope_sin", sin.as_slice()));
642 for i in 0..self.cfg.num_hidden_layers {
643 let pk = &cache.layers_k[i];
644 let pv = &cache.layers_v[i];
645 inputs.push((&key_strs[2 * i], pk.as_slice()));
646 inputs.push((&key_strs[2 * i + 1], pv.as_slice()));
647 }
648
649 let outputs = compiled.run(&inputs);
650 let (logits, new_k, new_v) = split_decode_logits_kv(outputs, self.cfg.num_hidden_layers)?;
651
652 let cache_mut = self.cache.as_mut().unwrap();
653 cache_mut.past_len = past_seq + 1;
654 cache_mut.layers_k = new_k;
655 cache_mut.layers_v = new_v;
656 self.tokens.push(input);
657
658 Ok(logits)
659 }
660}
661
662fn compute_rope_slice(cfg: &Qwen3Config, pos: usize) -> (Vec<f32>, Vec<f32>) {
666 let dh = cfg.head_dim;
667 let half = dh / 2;
668 let mut cos = vec![0f32; half];
669 let mut sin = vec![0f32; half];
670 for i in 0..half {
671 let freq = 1.0 / cfg.rope_theta.powf((2 * i) as f64 / dh as f64);
672 let angle = pos as f64 * freq;
673 let (s, c) = angle.sin_cos();
674 cos[i] = c as f32;
675 sin[i] = s as f32;
676 }
677 (cos, sin)
678}
679
680#[cfg(test)]
681mod tests {
682 use super::*;
683 use crate::config::Qwen3Config;
684
685 fn tiny_cfg() -> Qwen3Config {
686 Qwen3Config {
687 vocab_size: 16,
688 hidden_size: 16,
689 intermediate_size: 32,
690 num_hidden_layers: 2,
691 num_attention_heads: 4,
692 num_key_value_heads: 2,
693 head_dim: 8,
694 max_position_embeddings: 16,
695 rms_norm_eps: 1e-6,
696 rope_theta: 1_000_000.0,
697 hidden_act: "silu".into(),
698 tie_word_embeddings: false,
699 attention_bias: false,
700 qk_norm: true,
701 sliding_window: None,
702 max_window_layers: usize::MAX,
703 use_sliding_window: false,
704 num_experts: 0,
705 num_experts_used: 0,
706 expert_ffn_size: 0,
707 shared_expert_ffn_size: 0,
708 expert_weights_scale: 1.0,
709 }
710 }
711
712 fn synthetic_weights(cfg: &Qwen3Config) -> WeightMap {
713 let h = cfg.hidden_size;
714 let q_dim = cfg.q_proj_dim();
715 let kv_dim = cfg.kv_proj_dim();
716 let int_dim = cfg.intermediate_size;
717 let dh = cfg.head_dim;
718 let mut t: HashMap<String, (Vec<f32>, Vec<usize>)> = HashMap::new();
719 let pat = |n: usize, salt: u32| -> Vec<f32> {
722 (0..n)
723 .map(|i| {
724 let x = ((i as u32).wrapping_mul(2654435761).wrapping_add(salt)) >> 8;
725 (x as f32 / (1u32 << 24) as f32) - 0.5
726 })
727 .collect()
728 };
729 t.insert(
730 "model.embed_tokens.weight".into(),
731 (pat(cfg.vocab_size * h, 1), vec![cfg.vocab_size, h]),
732 );
733 for i in 0..cfg.num_hidden_layers {
734 let lp = format!("model.layers.{i}");
735 t.insert(
736 format!("{lp}.input_layernorm.weight"),
737 (pat(h, 100 + i as u32), vec![h]),
738 );
739 t.insert(
740 format!("{lp}.post_attention_layernorm.weight"),
741 (pat(h, 200 + i as u32), vec![h]),
742 );
743 t.insert(
744 format!("{lp}.self_attn.q_proj.weight"),
745 (pat(q_dim * h, 300 + i as u32), vec![q_dim, h]),
746 );
747 t.insert(
748 format!("{lp}.self_attn.k_proj.weight"),
749 (pat(kv_dim * h, 400 + i as u32), vec![kv_dim, h]),
750 );
751 t.insert(
752 format!("{lp}.self_attn.v_proj.weight"),
753 (pat(kv_dim * h, 500 + i as u32), vec![kv_dim, h]),
754 );
755 t.insert(
756 format!("{lp}.self_attn.o_proj.weight"),
757 (pat(h * q_dim, 600 + i as u32), vec![h, q_dim]),
758 );
759 t.insert(
760 format!("{lp}.self_attn.q_norm.weight"),
761 (pat(dh, 700 + i as u32), vec![dh]),
762 );
763 t.insert(
764 format!("{lp}.self_attn.k_norm.weight"),
765 (pat(dh, 800 + i as u32), vec![dh]),
766 );
767 t.insert(
768 format!("{lp}.mlp.gate_proj.weight"),
769 (pat(int_dim * h, 900 + i as u32), vec![int_dim, h]),
770 );
771 t.insert(
772 format!("{lp}.mlp.up_proj.weight"),
773 (pat(int_dim * h, 1000 + i as u32), vec![int_dim, h]),
774 );
775 t.insert(
776 format!("{lp}.mlp.down_proj.weight"),
777 (pat(h * int_dim, 1100 + i as u32), vec![h, int_dim]),
778 );
779 }
780 t.insert("model.norm.weight".into(), (pat(h, 2000), vec![h]));
781 t.insert(
782 "lm_head.weight".into(),
783 (pat(cfg.vocab_size * h, 3000), vec![cfg.vocab_size, h]),
784 );
785 WeightMap::from_tensors(t)
786 }
787
/// Building a generator empties the loader, and one uncached step appends a
/// single in-vocabulary token.
#[test]
fn generator_drains_loader_and_runs_one_step() {
    let cfg = tiny_cfg();
    let mut loader = synthetic_weights(&cfg);
    let mut generator =
        Qwen3Generator::from_loader(cfg.clone(), &mut loader, Device::Cpu).unwrap();
    assert_eq!(loader.len(), 0, "loader should be drained");

    generator.prefill(&[1, 2, 3]);
    let sampled = generator.step(SampleOpts::greedy()).unwrap();

    assert!((sampled as usize) < cfg.vocab_size);
    assert_eq!(generator.tokens().len(), 4);
}
799
/// `generate(3)` returns three in-vocabulary tokens and grows the history by
/// exactly three.
#[test]
fn generate_n_appends_n_tokens() {
    let cfg = tiny_cfg();
    let mut loader = synthetic_weights(&cfg);
    let mut generator =
        Qwen3Generator::from_loader(cfg.clone(), &mut loader, Device::Cpu).unwrap();

    generator.prefill(&[5, 6]);
    let produced = generator.generate(3, SampleOpts::greedy()).unwrap();

    assert_eq!(produced.len(), 3);
    assert_eq!(generator.tokens().len(), 5);
    produced
        .iter()
        .for_each(|t| assert!((*t as usize) < cfg.vocab_size));
}
813
/// Stepping with an empty token history must fail.
#[test]
fn step_without_prefill_errors() {
    let cfg = tiny_cfg();
    let mut loader = synthetic_weights(&cfg);
    let mut generator = Qwen3Generator::from_loader(cfg, &mut loader, Device::Cpu).unwrap();

    let outcome = generator.step(SampleOpts::greedy());

    assert!(outcome.is_err());
}
822
/// Greedy decoding must produce the same tokens with all caches disabled
/// (the uncached `step` loop) as with the default cached path.
#[test]
fn cached_matches_naive_on_greedy() {
    let cfg = tiny_cfg();
    let prompt: Vec<u32> = vec![1, 2, 3, 5];
    let steps = 4;

    // Reference run: both compile caches removed so `generate` uses `step`.
    let naive_tokens = {
        let mut loader = synthetic_weights(&cfg);
        let mut generator =
            Qwen3Generator::from_loader(cfg.clone(), &mut loader, Device::Cpu).unwrap();
        generator.prefill_compile_cache = None;
        generator.decode_compile_cache = None;
        generator.prefill(&prompt);
        generator.generate(steps, SampleOpts::greedy()).unwrap()
    };

    // Run under test: default caches, KV-cached generation.
    let cached_tokens = {
        let mut loader = synthetic_weights(&cfg);
        let mut generator =
            Qwen3Generator::from_loader(cfg.clone(), &mut loader, Device::Cpu).unwrap();
        generator.prefill(&prompt);
        generator
            .generate_cached(steps, SampleOpts::greedy())
            .unwrap()
    };

    assert_eq!(
        cached_tokens, naive_tokens,
        "cached vs naive token mismatch — KV cache or kernel-Lq!=Lk bug"
    );
}
856
/// After each cached step the history is exactly one token longer than the
/// cached prefix.
#[test]
fn cached_step_advances_cache_invariant() {
    let cfg = tiny_cfg();
    let mut loader = synthetic_weights(&cfg);
    let mut generator =
        Qwen3Generator::from_loader(cfg.clone(), &mut loader, Device::Cpu).unwrap();
    generator.prefill(&[1, 2, 3]);

    // First call seeds the cache from the three prompt tokens.
    generator.step_cached(SampleOpts::greedy()).unwrap();
    assert_eq!(generator.tokens().len(), 4);
    assert_eq!(generator.cache.as_ref().unwrap().past_len, 3);

    // Second call is a real decode step and advances the cache by one.
    generator.step_cached(SampleOpts::greedy()).unwrap();
    assert_eq!(generator.tokens().len(), 5);
    assert_eq!(generator.cache.as_ref().unwrap().past_len, 4);
}
872
/// Bucketed decode must produce the same greedy tokens as one-shot decode.
///
/// `from_loader` installs a bucketed decode cache by default, so the
/// reference generator has it removed explicitly; otherwise both sides take
/// the bucketed path and the comparison proves nothing.
#[test]
fn bucketed_decode_matches_oneshot() {
    let cfg = tiny_cfg();
    let prompt: Vec<u32> = vec![1, 2, 3, 5];
    let steps = 6;

    let mut wm_one = synthetic_weights(&cfg);
    let mut gn_one =
        Qwen3Generator::from_loader(cfg.clone(), &mut wm_one, Device::Cpu).unwrap();
    // Force `step_cached` onto `decode_step_oneshot`.
    gn_one.set_decode_compile_cache(None);
    gn_one.prefill(&prompt);
    let oneshot_tokens = gn_one.generate_cached(steps, SampleOpts::greedy()).unwrap();

    let mut wm_buc = synthetic_weights(&cfg);
    let mut gn_buc = Qwen3Generator::from_loader(cfg.clone(), &mut wm_buc, Device::Cpu)
        .unwrap()
        .with_decode_cache(32);
    gn_buc.prefill(&prompt);
    let bucketed_tokens = gn_buc.generate_cached(steps, SampleOpts::greedy()).unwrap();

    assert_eq!(
        bucketed_tokens, oneshot_tokens,
        "bucketed-cache decode diverged from one-shot decode — \
         mask, padding, or output-slice bug"
    );
}
903
/// In the bucketed decode graph every rank-3 matmul whose last dimension is
/// the Q projection width must have sequence dimension 1, and the same must
/// hold for the matching `Narrow` nodes after the Metal fusion pipeline.
#[test]
fn bucketed_decode_q_proj_seq_is_one() {
    use rlx_ir::Op;

    let cfg = tiny_cfg();
    let mut weights = synthetic_weights(&cfg);
    let (graph, _) =
        build_qwen3_decode_graph_sized_ext(&cfg, &mut weights, 1, 4, true).unwrap();

    for node in graph.nodes() {
        if !matches!(&node.op, Op::MatMul) {
            continue;
        }
        let sh = graph.shape(node.id);
        if sh.rank() != 3 || sh.dim(2).unwrap_static() != cfg.q_proj_dim() {
            continue;
        }
        assert_eq!(
            sh.dim(1).unwrap_static(),
            1,
            "decode q_proj matmul seq dim must be 1, got {sh} on node {}",
            node.id
        );
    }

    let pipeline = rlx_opt::CompilePipeline::new(rlx_opt::FusionTarget::Metal)
        .with_assert_fusion_clean(false);
    let fused = pipeline.compile_graph(graph).lir.into_graph();

    for node in fused.nodes() {
        let narrow_len = match &node.op {
            Op::Narrow { len, .. } => *len,
            _ => continue,
        };
        let sh = fused.shape(node.id);
        if sh.rank() != 3 || narrow_len != cfg.q_proj_dim() {
            continue;
        }
        assert_eq!(
            sh.dim(1).unwrap_static(),
            1,
            "fused decode q narrow seq dim must be 1, got {sh} on node {}",
            node.id
        );
    }
}
944
/// Enabling the prefill compile cache must not change generated tokens.
///
/// `from_loader` installs a prefill cache by default, so the baseline
/// generator has it removed explicitly; otherwise both sides run with a
/// cache and the uncached path is never exercised.
#[test]
fn prefill_compile_cache_does_not_change_output() {
    let cfg = tiny_cfg();
    let prompt: Vec<u32> = vec![1, 2, 3, 5];

    let mut wm_a = synthetic_weights(&cfg);
    let mut gn_a = Qwen3Generator::from_loader(cfg.clone(), &mut wm_a, Device::Cpu).unwrap();
    // Baseline: prefill compiles directly, without a compile cache.
    gn_a.prefill_compile_cache = None;
    gn_a.prefill(&prompt);
    let a = gn_a.generate_cached(4, SampleOpts::greedy()).unwrap();

    let mut wm_b = synthetic_weights(&cfg);
    let mut gn_b = Qwen3Generator::from_loader(cfg.clone(), &mut wm_b, Device::Cpu)
        .unwrap()
        .with_prefill_cache(4);
    gn_b.prefill(&prompt);
    let b = gn_b.generate_cached(4, SampleOpts::greedy()).unwrap();

    assert_eq!(a, b, "enabling prefill_cache must not change output");
}
963
/// Two generators built from identical weights must emit identical greedy
/// token sequences for the same prompt.
#[test]
fn greedy_is_deterministic_across_runs() {
    let cfg = tiny_cfg();
    let weights = synthetic_weights(&cfg);
    let prompt = [1, 2, 3];

    let run_once = || {
        let mut loader = WeightMap::from_tensors(weights_as_hashmap(&weights));
        let mut generator =
            Qwen3Generator::from_loader(cfg.clone(), &mut loader, Device::Cpu).unwrap();
        generator.prefill(&prompt);
        generator
    };
    let mut first = run_once();
    let mut second = run_once();

    let first_tokens = first.generate(4, SampleOpts::greedy()).unwrap();
    let second_tokens = second.generate(4, SampleOpts::greedy()).unwrap();
    assert_eq!(first_tokens, second_tokens);
}
980
981 fn weights_as_hashmap(wm: &WeightMap) -> HashMap<String, (Vec<f32>, Vec<usize>)> {
982 let _ = wm; let cfg = tiny_cfg();
988 let mut new = synthetic_weights(&cfg);
989 let keys: Vec<String> = new.keys().map(|s| s.to_string()).collect();
990 let mut out = HashMap::new();
991 for k in keys {
992 out.insert(k.clone(), new.take(&k).unwrap());
993 }
994 out
995 }
996}