1use crate::builder::{
34 build_qwen3_decode_graph_sized, build_qwen3_decode_graph_sized_ext,
35 build_qwen3_graph_sized_last_logits,
36};
37use crate::capabilities::validate_device;
38use crate::config::Qwen3Config;
39use crate::profile::qwen3_profile_near_weights;
40use crate::sampling::{SampleOpts, sample_token};
41use anyhow::{Context, Result};
42use rlx_core::autoregressive::{
43 DecodeLogitsKv, KvCacheState, compile_cache_ensure_graph, kv_from_prefill_outputs,
44 prefill_cache_key, run_bucketed_kv_decode, split_decode_logits_kv,
45};
46use rlx_core::flow_bridge::compile_options_from_profile;
47use rlx_core::weight_loader::WeightLoader;
48use rlx_core::weight_map::WeightMap;
49use rlx_flow::CompileProfile;
50use rlx_ir::logical_kernel::KernelDispatchConfig;
51use rlx_runtime::attn_mask::bucket_decode_mask;
52use rlx_runtime::compile_cache::{BucketedCompileCache, CacheRunInput, CompileCache};
53use rlx_runtime::{CompileOptions, Device, Session};
54use std::collections::HashMap;
55use std::path::Path;
56
57pub struct Qwen3Generator {
63 cfg: Qwen3Config,
64 weights_cache: HashMap<String, (Vec<f32>, Vec<usize>)>,
69 tokens: Vec<u32>,
70 device: Device,
71 cache: Option<KvCacheState>,
75 prefill_compile_cache: Option<CompileCache>,
79 decode_compile_cache: Option<BucketedCompileCache>,
85 prefill_profile: CompileProfile,
86 decode_profile: CompileProfile,
87}
88
89impl Qwen3Generator {
90 pub fn from_loader(
93 cfg: Qwen3Config,
94 loader: &mut dyn WeightLoader,
95 device: Device,
96 ) -> Result<Self> {
97 validate_device(&cfg, device, false)?;
98 let keys = loader.remaining_keys();
99 let mut weights_cache = HashMap::with_capacity(keys.len());
100 for k in keys {
101 let v = loader
102 .take(&k)
103 .with_context(|| format!("draining weight {k}"))?;
104 let canonical =
110 rlx_core::weight_loader::gguf_to_hf_name(&k).unwrap_or_else(|| k.clone());
111 weights_cache.insert(canonical, v);
112 }
113 let max_past = cfg.max_position_embeddings.clamp(1, 4096);
114 Ok(Self {
115 cfg,
116 weights_cache,
117 tokens: Vec::new(),
118 device,
119 cache: None,
120 prefill_compile_cache: Some(CompileCache::new(device, 8)),
121 decode_compile_cache: Some(BucketedCompileCache::power_of_two_ladder(
122 device,
123 1,
124 max_past as u64,
125 )),
126 prefill_profile: CompileProfile::qwen3_prefill(),
127 decode_profile: CompileProfile::qwen3_decode(),
128 })
129 }
130
131 pub fn from_loader_at(
133 cfg: Qwen3Config,
134 loader: &mut dyn WeightLoader,
135 device: Device,
136 weights_path: &Path,
137 ) -> Result<Self> {
138 let mut g = Self::from_loader(cfg, loader, device)?;
139 g.prefill_profile = qwen3_profile_near_weights(weights_path, false);
140 g.decode_profile = qwen3_profile_near_weights(weights_path, true);
141 Ok(g)
142 }
143
144 pub fn with_compile_profiles(
145 mut self,
146 prefill: CompileProfile,
147 decode: CompileProfile,
148 ) -> Self {
149 self.prefill_profile = prefill;
150 self.decode_profile = decode;
151 self
152 }
153
154 pub fn prefill_profile(&self) -> &CompileProfile {
155 &self.prefill_profile
156 }
157
158 pub fn decode_profile(&self) -> &CompileProfile {
159 &self.decode_profile
160 }
161
162 fn profile_compile_options(&self, decode: bool) -> CompileOptions {
163 let profile = if decode {
164 &self.decode_profile
165 } else {
166 &self.prefill_profile
167 };
168 compile_options_from_profile(profile, self.device, KernelDispatchConfig::default())
169 }
170
171 pub fn with_prefill_cache(mut self, capacity: usize) -> Self {
176 self.prefill_compile_cache = Some(CompileCache::new(self.device, capacity));
177 self
178 }
179
180 pub fn with_decode_cache(mut self, max_past: usize) -> Self {
190 let cache = BucketedCompileCache::power_of_two_ladder(
191 self.device,
192 1,
193 max_past.max(1) as u64,
194 );
195 self.decode_compile_cache = Some(cache);
196 self
197 }
198
199 pub fn from_path(cfg: Qwen3Config, path: &str, device: Device) -> Result<Self> {
202 Self::from_path_at(cfg, path, device, Path::new("."))
203 }
204
205 pub fn from_path_at(
207 cfg: Qwen3Config,
208 path: &str,
209 device: Device,
210 weights_path: &Path,
211 ) -> Result<Self> {
212 let mut loader = rlx_core::weight_loader::load_from_path(path)?;
213 Self::from_loader_at(cfg, loader.as_mut(), device, weights_path)
214 }
215
216 pub fn from_path_with_mtp(
224 cfg: Qwen3Config,
225 path: &str,
226 device: Device,
227 include_mtp: bool,
228 ) -> Result<Self> {
229 if path.ends_with(".gguf") {
233 let mut gguf = rlx_core::weight_loader::GgufLoader::from_file(path)?;
234 gguf.include_mtp(include_mtp);
235 Self::from_loader(cfg, &mut gguf, device)
236 } else {
237 Self::from_path(cfg, path, device)
238 }
239 }
240
241 pub fn prefill(&mut self, prompt_ids: &[u32]) {
245 self.tokens.clear();
246 self.tokens.extend_from_slice(prompt_ids);
247 self.cache = None;
248 }
249
250 pub fn step(&mut self, opts: SampleOpts) -> Result<u32> {
254 if self.tokens.is_empty() {
255 anyhow::bail!("step() called with empty token history; call prefill() first");
256 }
257 let seq = self.tokens.len();
258 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
259 let (graph, params) = build_qwen3_graph_sized_last_logits(
260 &self.cfg, &mut wm, 1, seq, false,
261 )?;
262 let compile_opts = self.profile_compile_options(false);
263 let mut compiled = Session::new(self.device).compile_with(graph, &compile_opts);
264 for (name, data) in ¶ms {
265 compiled.set_param(name, data);
266 }
267 let ids_f32: Vec<f32> = self.tokens.iter().map(|&i| i as f32).collect();
268 let outputs = compiled.run(&[("input_ids", ids_f32.as_slice())]);
269 let logits = outputs
270 .into_iter()
271 .next()
272 .context("compiled.run returned no outputs")?;
273
274 let vocab = self.cfg.vocab_size;
275 let expected = vocab;
276 if logits.len() < expected {
277 anyhow::bail!(
278 "logits length {} < expected {} (last logits, seq {seq}, vocab {vocab})",
279 logits.len(),
280 expected
281 );
282 }
283 let last_row = &logits[..vocab];
285 let tok = sample_token(last_row, opts) as u32;
286 self.tokens.push(tok);
287 Ok(tok)
288 }
289
290 pub fn generate(&mut self, n: usize, opts: SampleOpts) -> Result<Vec<u32>> {
293 if self.decode_compile_cache.is_some() {
294 return self.generate_cached(n, opts);
295 }
296 let start = self.tokens.len();
297 for _ in 0..n {
298 self.step(opts)?;
299 }
300 Ok(self.tokens[start..].to_vec())
301 }
302
303 pub fn step_cached(&mut self, opts: SampleOpts) -> Result<u32> {
313 if self.tokens.is_empty() {
314 anyhow::bail!("step_cached() called with empty token history; call prefill() first");
315 }
316 if self.cache.is_none() {
317 let tok = self.seed_cache_from_prompt(opts)?;
321 return Ok(tok);
322 }
323 let cache = self.cache.as_ref().unwrap();
324 let past_seq = cache.past_len;
325 if self.tokens.len() <= past_seq {
329 anyhow::bail!(
330 "cache invariant violated: tokens.len() {} <= past_seq {}",
331 self.tokens.len(),
332 past_seq
333 );
334 }
335 let input_tok = self.tokens[past_seq];
336
337 let (logits, new_k, new_v) = if self.decode_compile_cache.is_some()
339 && self
340 .decode_compile_cache
341 .as_ref()
342 .unwrap()
343 .bucket_for(past_seq as u64)
344 .is_some()
345 {
346 self.decode_step_bucketed(past_seq, input_tok)?
347 } else {
348 self.decode_step_oneshot(past_seq, input_tok)?
349 };
350
351 let cache_mut = self.cache.as_mut().unwrap();
352 cache_mut.past_len = past_seq + 1;
353 cache_mut.layers_k = new_k;
354 cache_mut.layers_v = new_v;
355
356 let vocab = self.cfg.vocab_size;
357 if logits.len() != vocab {
358 anyhow::bail!("decode logits length {} != vocab {}", logits.len(), vocab);
359 }
360 let tok = sample_token(&logits, opts) as u32;
361 self.tokens.push(tok);
362 Ok(tok)
363 }
364
365 fn decode_step_oneshot(&mut self, past_seq: usize, input_tok: u32) -> Result<DecodeLogitsKv> {
368 let cache = self.cache.as_ref().unwrap();
369
370 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
371 let (graph, params) =
372 build_qwen3_decode_graph_sized(&self.cfg, &mut wm, 1, past_seq)?;
373 let opts = self.profile_compile_options(true);
374 let mut compiled = Session::new(self.device).compile_with(graph, &opts);
375 for (name, data) in ¶ms {
376 compiled.set_param(name, data);
377 }
378
379 let (cos, sin) = compute_rope_slice(&self.cfg, past_seq);
380 let input_ids_f32 = [input_tok as f32];
381 let key_strs: Vec<String> = (0..self.cfg.num_hidden_layers)
382 .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
383 .collect();
384 let mut inputs: Vec<(&str, &[f32])> =
385 Vec::with_capacity(3 + 2 * self.cfg.num_hidden_layers);
386 inputs.push(("input_ids", input_ids_f32.as_slice()));
387 inputs.push(("rope_cos", cos.as_slice()));
388 inputs.push(("rope_sin", sin.as_slice()));
389 for i in 0..self.cfg.num_hidden_layers {
390 inputs.push((&key_strs[2 * i], cache.layers_k[i].as_slice()));
391 inputs.push((&key_strs[2 * i + 1], cache.layers_v[i].as_slice()));
392 }
393
394 let outputs = compiled.run(&inputs);
395 split_decode_logits_kv(outputs, self.cfg.num_hidden_layers)
396 }
397
398 fn decode_step_bucketed(&mut self, past_seq: usize, input_tok: u32) -> Result<DecodeLogitsKv> {
399 let kv = self.cache.as_ref().unwrap().clone();
400 let kv_dim = self.cfg.kv_proj_dim();
401 let n_layers = self.cfg.num_hidden_layers;
402 let (cos, sin) = compute_rope_slice(&self.cfg, past_seq);
403 let input_ids_f32 = [input_tok as f32];
404 let decode_opts = self.profile_compile_options(true);
405 let upper = self
406 .decode_compile_cache
407 .as_ref()
408 .and_then(|cache_dec| {
409 cache_dec.bucket_for(past_seq as u64).map(|idx| {
410 cache_dec
411 .buckets()
412 .nth(idx)
413 .map(|r| (r.end - 1) as usize)
414 .unwrap_or(past_seq)
415 })
416 })
417 .unwrap_or(past_seq);
418 let mask = bucket_decode_mask(past_seq, upper);
419 let fixed = [
420 CacheRunInput {
421 name: "input_ids",
422 data: &input_ids_f32,
423 row_inner: None,
424 },
425 CacheRunInput {
426 name: "rope_cos",
427 data: &cos,
428 row_inner: None,
429 },
430 CacheRunInput {
431 name: "rope_sin",
432 data: &sin,
433 row_inner: None,
434 },
435 CacheRunInput {
436 name: "mask",
437 data: &mask,
438 row_inner: None,
439 },
440 ];
441 let cfg = self.cfg.clone();
442 let weights = self.weights_cache.clone();
443 let cache_dec = self.decode_compile_cache.as_mut().unwrap();
444 run_bucketed_kv_decode(
445 cache_dec,
446 past_seq,
447 &kv,
448 kv_dim,
449 n_layers,
450 &fixed,
451 |upper| {
452 let mut wm = WeightMap::from_tensors(weights.clone());
453 build_qwen3_decode_graph_sized_ext(&cfg, &mut wm, 1, upper as usize, true)
454 .expect("qwen3 bucketed decode graph")
455 },
456 &decode_opts,
457 )
458 }
459
460 fn run_prefill_with_cache(
464 &mut self,
465 batch: usize,
466 seq: usize,
467 ids_f32: &[f32],
468 ) -> Result<Vec<Vec<f32>>> {
469 let prefill_opts = self.profile_compile_options(false);
470 if let Some(cache) = &mut self.prefill_compile_cache {
471 let key = prefill_cache_key(batch, seq);
472 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
473 let (graph, params) = build_qwen3_graph_sized_last_logits(
474 &self.cfg, &mut wm, batch, seq, true,
475 )?;
476 let compiled = compile_cache_ensure_graph(cache, key, graph, params, &prefill_opts);
477 Ok(compiled.run(&[("input_ids", ids_f32)]))
478 } else {
479 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
480 let (graph, params) = build_qwen3_graph_sized_last_logits(
481 &self.cfg, &mut wm, batch, seq, true,
482 )?;
483 let opts = self.profile_compile_options(false);
484 let mut compiled = Session::new(self.device).compile_with(graph, &opts);
485 for (name, data) in ¶ms {
486 compiled.set_param(name, data);
487 }
488 Ok(compiled.run(&[("input_ids", ids_f32)]))
489 }
490 }
491
492 pub fn generate_cached(&mut self, n: usize, opts: SampleOpts) -> Result<Vec<u32>> {
494 self.generate_cached_with(n, opts, |_| {})
495 }
496
497 pub fn generate_cached_with(
504 &mut self,
505 n: usize,
506 opts: SampleOpts,
507 mut on_token: impl FnMut(u32),
508 ) -> Result<Vec<u32>> {
509 let start = self.tokens.len();
510 for _ in 0..n {
511 let tok = self.step_cached(opts)?;
512 on_token(tok);
513 }
514 Ok(self.tokens[start..].to_vec())
515 }
516
517 fn seed_cache_from_prompt(&mut self, opts: SampleOpts) -> Result<u32> {
522 let seq = self.tokens.len();
523 let batch = 1usize;
524 let kv_dim = self.cfg.kv_proj_dim();
525
526 let ids_f32: Vec<f32> = self.tokens.iter().map(|&i| i as f32).collect();
527 let outputs = self.run_prefill_with_cache(batch, seq, &ids_f32)?;
528 let (logits, kv) =
529 kv_from_prefill_outputs(outputs, batch, seq, kv_dim, self.cfg.num_hidden_layers)?;
530 self.cache = Some(kv);
531
532 let vocab = self.cfg.vocab_size;
533 let needed = vocab;
534 if logits.len() < needed {
535 anyhow::bail!("prefill logits length {} < {}", logits.len(), needed);
536 }
537 let last_row = &logits[..vocab];
538 let tok = sample_token(last_row, opts) as u32;
539 self.tokens.push(tok);
540 Ok(tok)
541 }
542
543 pub fn tokens(&self) -> &[u32] {
545 &self.tokens
546 }
547
548 pub fn config(&self) -> &Qwen3Config {
549 &self.cfg
550 }
551
552 pub fn prefill_get_last_logits(&mut self, context: &[u32]) -> Result<Vec<f32>> {
562 if context.is_empty() {
563 anyhow::bail!("prefill_get_last_logits: empty context");
564 }
565 self.tokens.clear();
566 self.tokens.extend_from_slice(context);
567 self.cache = None;
568
569 let seq = context.len();
570 let batch = 1usize;
571 let kv_dim = self.cfg.kv_proj_dim();
572
573 let ids_f32: Vec<f32> = context.iter().map(|&i| i as f32).collect();
574 let outputs = self.run_prefill_with_cache(batch, seq, &ids_f32)?;
575 let (logits, kv) =
576 kv_from_prefill_outputs(outputs, batch, seq, kv_dim, self.cfg.num_hidden_layers)?;
577 self.cache = Some(kv);
578
579 let vocab = self.cfg.vocab_size;
580 let needed = vocab;
581 if logits.len() < needed {
582 anyhow::bail!("logits short: {} < {}", logits.len(), needed);
583 }
584 Ok(logits[..vocab].to_vec())
585 }
586
587 pub fn decode_get_logits(&mut self, input: u32) -> Result<Vec<f32>> {
595 let cache = self.cache.as_ref().ok_or_else(|| {
596 anyhow::anyhow!(
597 "decode_get_logits: cache not seeded; call prefill_get_last_logits first"
598 )
599 })?;
600 let past_seq = cache.past_len;
601
602 let mut wm = WeightMap::from_tensors(self.weights_cache.clone());
603 let (graph, params) =
604 build_qwen3_decode_graph_sized(&self.cfg, &mut wm, 1, past_seq)?;
605 let opts = self.profile_compile_options(true);
606 let mut compiled = Session::new(self.device).compile_with(graph, &opts);
607 for (name, data) in ¶ms {
608 compiled.set_param(name, data);
609 }
610
611 let (cos, sin) = compute_rope_slice(&self.cfg, past_seq);
612 let input_ids_f32 = [input as f32];
613 let key_strs: Vec<String> = (0..self.cfg.num_hidden_layers)
614 .flat_map(|i| [format!("past_k_{i}"), format!("past_v_{i}")])
615 .collect();
616 let mut inputs: Vec<(&str, &[f32])> =
617 Vec::with_capacity(3 + 2 * self.cfg.num_hidden_layers);
618 inputs.push(("input_ids", input_ids_f32.as_slice()));
619 inputs.push(("rope_cos", cos.as_slice()));
620 inputs.push(("rope_sin", sin.as_slice()));
621 for i in 0..self.cfg.num_hidden_layers {
622 let pk = &cache.layers_k[i];
623 let pv = &cache.layers_v[i];
624 inputs.push((&key_strs[2 * i], pk.as_slice()));
625 inputs.push((&key_strs[2 * i + 1], pv.as_slice()));
626 }
627
628 let outputs = compiled.run(&inputs);
629 let (logits, new_k, new_v) = split_decode_logits_kv(outputs, self.cfg.num_hidden_layers)?;
630
631 let cache_mut = self.cache.as_mut().unwrap();
632 cache_mut.past_len = past_seq + 1;
633 cache_mut.layers_k = new_k;
634 cache_mut.layers_v = new_v;
635 self.tokens.push(input);
636
637 Ok(logits)
638 }
639}
640
641fn compute_rope_slice(cfg: &Qwen3Config, pos: usize) -> (Vec<f32>, Vec<f32>) {
645 let dh = cfg.head_dim;
646 let half = dh / 2;
647 let mut cos = vec![0f32; half];
648 let mut sin = vec![0f32; half];
649 for i in 0..half {
650 let freq = 1.0 / cfg.rope_theta.powf((2 * i) as f64 / dh as f64);
651 let angle = pos as f64 * freq;
652 let (s, c) = angle.sin_cos();
653 cos[i] = c as f32;
654 sin[i] = s as f32;
655 }
656 (cos, sin)
657}
658
659#[cfg(test)]
660mod tests {
661 use super::*;
662 use crate::config::Qwen3Config;
663
664 fn tiny_cfg() -> Qwen3Config {
665 Qwen3Config {
666 vocab_size: 16,
667 hidden_size: 16,
668 intermediate_size: 32,
669 num_hidden_layers: 2,
670 num_attention_heads: 4,
671 num_key_value_heads: 2,
672 head_dim: 8,
673 max_position_embeddings: 16,
674 rms_norm_eps: 1e-6,
675 rope_theta: 1_000_000.0,
676 hidden_act: "silu".into(),
677 tie_word_embeddings: false,
678 attention_bias: false,
679 qk_norm: true,
680 sliding_window: None,
681 max_window_layers: usize::MAX,
682 use_sliding_window: false,
683 num_experts: 0,
684 num_experts_used: 0,
685 expert_ffn_size: 0,
686 shared_expert_ffn_size: 0,
687 expert_weights_scale: 1.0,
688 }
689 }
690
691 fn synthetic_weights(cfg: &Qwen3Config) -> WeightMap {
692 let h = cfg.hidden_size;
693 let q_dim = cfg.q_proj_dim();
694 let kv_dim = cfg.kv_proj_dim();
695 let int_dim = cfg.intermediate_size;
696 let dh = cfg.head_dim;
697 let mut t: HashMap<String, (Vec<f32>, Vec<usize>)> = HashMap::new();
698 let pat = |n: usize, salt: u32| -> Vec<f32> {
701 (0..n)
702 .map(|i| {
703 let x = ((i as u32).wrapping_mul(2654435761).wrapping_add(salt)) >> 8;
704 (x as f32 / (1u32 << 24) as f32) - 0.5
705 })
706 .collect()
707 };
708 t.insert(
709 "model.embed_tokens.weight".into(),
710 (pat(cfg.vocab_size * h, 1), vec![cfg.vocab_size, h]),
711 );
712 for i in 0..cfg.num_hidden_layers {
713 let lp = format!("model.layers.{i}");
714 t.insert(
715 format!("{lp}.input_layernorm.weight"),
716 (pat(h, 100 + i as u32), vec![h]),
717 );
718 t.insert(
719 format!("{lp}.post_attention_layernorm.weight"),
720 (pat(h, 200 + i as u32), vec![h]),
721 );
722 t.insert(
723 format!("{lp}.self_attn.q_proj.weight"),
724 (pat(q_dim * h, 300 + i as u32), vec![q_dim, h]),
725 );
726 t.insert(
727 format!("{lp}.self_attn.k_proj.weight"),
728 (pat(kv_dim * h, 400 + i as u32), vec![kv_dim, h]),
729 );
730 t.insert(
731 format!("{lp}.self_attn.v_proj.weight"),
732 (pat(kv_dim * h, 500 + i as u32), vec![kv_dim, h]),
733 );
734 t.insert(
735 format!("{lp}.self_attn.o_proj.weight"),
736 (pat(h * q_dim, 600 + i as u32), vec![h, q_dim]),
737 );
738 t.insert(
739 format!("{lp}.self_attn.q_norm.weight"),
740 (pat(dh, 700 + i as u32), vec![dh]),
741 );
742 t.insert(
743 format!("{lp}.self_attn.k_norm.weight"),
744 (pat(dh, 800 + i as u32), vec![dh]),
745 );
746 t.insert(
747 format!("{lp}.mlp.gate_proj.weight"),
748 (pat(int_dim * h, 900 + i as u32), vec![int_dim, h]),
749 );
750 t.insert(
751 format!("{lp}.mlp.up_proj.weight"),
752 (pat(int_dim * h, 1000 + i as u32), vec![int_dim, h]),
753 );
754 t.insert(
755 format!("{lp}.mlp.down_proj.weight"),
756 (pat(h * int_dim, 1100 + i as u32), vec![h, int_dim]),
757 );
758 }
759 t.insert("model.norm.weight".into(), (pat(h, 2000), vec![h]));
760 t.insert(
761 "lm_head.weight".into(),
762 (pat(cfg.vocab_size * h, 3000), vec![cfg.vocab_size, h]),
763 );
764 WeightMap::from_tensors(t)
765 }
766
/// Constructing a generator must empty the loader, and one uncached step
/// must append a single in-vocabulary token.
#[test]
fn generator_drains_loader_and_runs_one_step() {
    let cfg = tiny_cfg();
    let mut loader = synthetic_weights(&cfg);
    let mut generator =
        Qwen3Generator::from_loader(cfg.clone(), &mut loader, Device::Cpu).unwrap();
    assert_eq!(loader.len(), 0, "loader should be drained");

    generator.prefill(&[1, 2, 3]);
    let sampled = generator.step(SampleOpts::greedy()).unwrap();

    assert!((sampled as usize) < cfg.vocab_size);
    assert_eq!(generator.tokens().len(), 4);
}
778
/// `generate(n)` must return exactly `n` in-vocabulary tokens and grow the
/// history by the same amount.
#[test]
fn generate_n_appends_n_tokens() {
    let cfg = tiny_cfg();
    let mut loader = synthetic_weights(&cfg);
    let mut generator =
        Qwen3Generator::from_loader(cfg.clone(), &mut loader, Device::Cpu).unwrap();

    generator.prefill(&[5, 6]);
    let produced = generator.generate(3, SampleOpts::greedy()).unwrap();

    assert_eq!(produced.len(), 3);
    assert_eq!(generator.tokens().len(), 5);
    assert!(produced.iter().all(|&tok| (tok as usize) < cfg.vocab_size));
}
792
/// Stepping with an empty token history must return an error.
#[test]
fn step_without_prefill_errors() {
    let cfg = tiny_cfg();
    let mut loader = synthetic_weights(&cfg);
    let mut generator = Qwen3Generator::from_loader(cfg, &mut loader, Device::Cpu).unwrap();
    assert!(generator.step(SampleOpts::greedy()).is_err());
}
801
/// Greedy decoding through the KV-cached path must produce the same tokens
/// as the uncached path that re-runs the full sequence on every step.
#[test]
fn cached_matches_naive_on_greedy() {
    let cfg = tiny_cfg();
    let prompt: Vec<u32> = vec![1, 2, 3, 5];
    let steps = 4;

    // Reference: both compile caches disabled, so `generate` uses `step`.
    let naive_tokens = {
        let mut loader = synthetic_weights(&cfg);
        let mut naive =
            Qwen3Generator::from_loader(cfg.clone(), &mut loader, Device::Cpu).unwrap();
        naive.prefill_compile_cache = None;
        naive.decode_compile_cache = None;
        naive.prefill(&prompt);
        naive.generate(steps, SampleOpts::greedy()).unwrap()
    };

    // Under test: default caches, explicit cached generation.
    let cached_tokens = {
        let mut loader = synthetic_weights(&cfg);
        let mut cached =
            Qwen3Generator::from_loader(cfg.clone(), &mut loader, Device::Cpu).unwrap();
        cached.prefill(&prompt);
        cached.generate_cached(steps, SampleOpts::greedy()).unwrap()
    };

    assert_eq!(
        cached_tokens, naive_tokens,
        "cached vs naive token mismatch — KV cache or kernel-Lq!=Lk bug"
    );
}
835
/// After each cached step the history grows by one token while the cache's
/// `past_len` stays one behind it.
#[test]
fn cached_step_advances_cache_invariant() {
    let cfg = tiny_cfg();
    let mut loader = synthetic_weights(&cfg);
    let mut generator =
        Qwen3Generator::from_loader(cfg.clone(), &mut loader, Device::Cpu).unwrap();
    generator.prefill(&[1, 2, 3]);

    // (expected tokens.len(), expected past_len) after each successive step.
    for (want_tokens, want_past) in [(4usize, 3usize), (5, 4)] {
        let _ = generator.step_cached(SampleOpts::greedy()).unwrap();
        assert_eq!(generator.tokens().len(), want_tokens);
        assert_eq!(generator.cache.as_ref().unwrap().past_len, want_past);
    }
}
851
/// Bucketed-cache decode must produce the same greedy tokens as one-shot
/// decode.
///
/// `from_loader` installs a decode compile cache by default, so the one-shot
/// reference generator has to disable it explicitly; otherwise both sides
/// would take the bucketed path and the comparison would be vacuous.
#[test]
fn bucketed_decode_matches_oneshot() {
    let cfg = tiny_cfg();
    let prompt: Vec<u32> = vec![1, 2, 3, 5];
    let steps = 6;

    let mut wm_one = synthetic_weights(&cfg);
    let mut gn_one =
        Qwen3Generator::from_loader(cfg.clone(), &mut wm_one, Device::Cpu).unwrap();
    // Force `step_cached` onto `decode_step_oneshot`.
    gn_one.decode_compile_cache = None;
    gn_one.prefill(&prompt);
    let oneshot_tokens = gn_one.generate_cached(steps, SampleOpts::greedy()).unwrap();

    let mut wm_buc = synthetic_weights(&cfg);
    let mut gn_buc = Qwen3Generator::from_loader(cfg.clone(), &mut wm_buc, Device::Cpu)
        .unwrap()
        .with_decode_cache(32);
    gn_buc.prefill(&prompt);
    let bucketed_tokens = gn_buc.generate_cached(steps, SampleOpts::greedy()).unwrap();

    assert_eq!(
        bucketed_tokens, oneshot_tokens,
        "bucketed-cache decode diverged from one-shot decode — \
         mask, padding, or output-slice bug"
    );
}
882
/// In the bucketed decode graph, every rank-3 matmul whose last dimension
/// equals `q_proj_dim` must have a sequence dimension of 1. The same must
/// hold for rank-3 `Narrow` nodes of length `q_proj_dim` after the graph has
/// gone through the compile pipeline.
#[test]
fn bucketed_decode_q_proj_seq_is_one() {
    use rlx_ir::Op;

    let cfg = tiny_cfg();
    let mut wm = synthetic_weights(&cfg);
    let (graph, _) = build_qwen3_decode_graph_sized_ext(&cfg, &mut wm, 1, 4, true).unwrap();

    // Pre-fusion: inspect the q-projection matmuls.
    for node in graph.nodes() {
        if !matches!(&node.op, Op::MatMul) {
            continue;
        }
        let sh = graph.shape(node.id);
        if sh.rank() != 3 || sh.dim(2).unwrap_static() != cfg.q_proj_dim() {
            continue;
        }
        assert_eq!(
            sh.dim(1).unwrap_static(),
            1,
            "decode q_proj matmul seq dim must be 1, got {sh} on node {}",
            node.id
        );
    }

    // Post-fusion: inspect the narrow ops of matching length.
    let pipeline = rlx_opt::CompilePipeline::new(rlx_opt::FusionTarget::Metal)
        .with_assert_fusion_clean(false);
    let fused = pipeline.compile_graph(graph).lir.into_graph();
    for node in fused.nodes() {
        if let Op::Narrow { len, .. } = &node.op {
            let sh = fused.shape(node.id);
            if sh.rank() != 3 || *len != cfg.q_proj_dim() {
                continue;
            }
            assert_eq!(
                sh.dim(1).unwrap_static(),
                1,
                "fused decode q narrow seq dim must be 1, got {sh} on node {}",
                node.id
            );
        }
    }
}
923
/// Enabling the prefill compile cache must not change generated tokens.
///
/// `from_loader` installs a prefill compile cache by default, so the
/// baseline generator disables it explicitly; that way the uncached prefill
/// path is actually compared against the cached one.
#[test]
fn prefill_compile_cache_does_not_change_output() {
    let cfg = tiny_cfg();
    let prompt: Vec<u32> = vec![1, 2, 3, 5];

    let mut wm_a = synthetic_weights(&cfg);
    let mut gn_a = Qwen3Generator::from_loader(cfg.clone(), &mut wm_a, Device::Cpu).unwrap();
    // Baseline: no prefill compile cache.
    gn_a.prefill_compile_cache = None;
    gn_a.prefill(&prompt);
    let a = gn_a.generate_cached(4, SampleOpts::greedy()).unwrap();

    let mut wm_b = synthetic_weights(&cfg);
    let mut gn_b = Qwen3Generator::from_loader(cfg.clone(), &mut wm_b, Device::Cpu)
        .unwrap()
        .with_prefill_cache(4);
    gn_b.prefill(&prompt);
    let b = gn_b.generate_cached(4, SampleOpts::greedy()).unwrap();

    assert_eq!(a, b, "enabling prefill_cache must not change output");
}
942
/// Two generators built from identical weights must produce identical
/// greedy output for the same prompt.
#[test]
fn greedy_is_deterministic_across_runs() {
    let cfg = tiny_cfg();
    let weights = synthetic_weights(&cfg);

    // Builds a fresh generator, feeds the prompt and generates four tokens.
    let run_once = || {
        let mut loader = WeightMap::from_tensors(weights_as_hashmap(&weights));
        let mut generator =
            Qwen3Generator::from_loader(cfg.clone(), &mut loader, Device::Cpu).unwrap();
        generator.prefill(&[1, 2, 3]);
        generator.generate(4, SampleOpts::greedy()).unwrap()
    };

    let first = run_once();
    let second = run_once();
    assert_eq!(first, second);
}
959
960 fn weights_as_hashmap(wm: &WeightMap) -> HashMap<String, (Vec<f32>, Vec<usize>)> {
961 let _ = wm; let cfg = tiny_cfg();
967 let mut new = synthetic_weights(&cfg);
968 let keys: Vec<String> = new.keys().map(|s| s.to_string()).collect();
969 let mut out = HashMap::new();
970 for k in keys {
971 out.insert(k.clone(), new.take(&k).unwrap());
972 }
973 out
974 }
975}