1use std::collections::HashMap;
20
21use ferrum_kernels::backend::{Backend, KvCache};
22use ferrum_quantization::{Linear, WeightLoader};
23use ferrum_types::Result;
24
25use crate::common::{DecoderOnlyLLM, LlmRuntimeConfig};
26
/// Architecture hyper-parameters shared by every model family built in this
/// file (see the `qwen3_from_def`, `llama_from_def`, `qwen2_from_def` and
/// `mistral_from_def` constructors).
#[derive(Clone, Debug)]
pub struct LlamaFamilyConfig {
    /// Width of the residual stream; per-token scratch buffers hold
    /// `hidden_size` elements.
    pub hidden_size: usize,
    /// MLP inner width; the fused gate/up projection output is sized
    /// `2 * intermediate_size` per token.
    pub intermediate_size: usize,
    /// Number of query heads (`q_dim = num_heads * head_dim`).
    pub num_heads: usize,
    /// Number of key/value heads (`kv_dim = num_kv_heads * head_dim`).
    pub num_kv_heads: usize,
    /// Width of a single attention head.
    pub head_dim: usize,
    /// Number of decoder layers to load and run.
    pub num_layers: usize,
    /// Length of the logits vector produced by the LM head.
    pub vocab_size: usize,
    /// Capacity (in positions) of each KV cache and number of rows in the
    /// RoPE tables.
    pub max_seq_len: usize,
    /// Epsilon handed to the RMS-norm kernels.
    pub rms_norm_eps: f32,
    /// Base used when building the RoPE cos/sin tables.
    pub rope_theta: f64,
    /// When true, per-layer `q_norm` / `k_norm` weights are loaded and the
    /// Q/K kernels are invoked with mode `1` instead of `2`.
    pub has_qk_norm: bool,
    /// Forwarded unchanged to the attention config; `0` when the model
    /// definition carries no `sliding_window` entry.
    /// NOTE(review): presumably `0` means "no sliding window" — confirm
    /// against the attention kernel.
    pub sliding_window: usize,
}
48
49impl LlamaFamilyConfig {
50 pub fn to_runtime(&self) -> LlmRuntimeConfig {
51 LlmRuntimeConfig {
52 hidden_size: self.hidden_size,
53 num_layers: self.num_layers,
54 num_kv_heads: self.num_kv_heads,
55 head_dim: self.head_dim,
56 vocab_size: self.vocab_size,
57 max_seq_len: self.max_seq_len,
58 }
59 }
60
61 fn from_def_base(def: &crate::definition::ModelDefinition) -> LlamaFamilyConfigBase {
65 let num_kv_heads = def.num_key_value_heads.unwrap_or(def.num_attention_heads);
66 let head_dim = def
67 .extra_params
68 .get("head_dim")
69 .and_then(|v| v.as_u64())
70 .map(|v| v as usize)
71 .unwrap_or(def.hidden_size / def.num_attention_heads);
72 let sliding_window = def
75 .extra_params
76 .get("sliding_window")
77 .and_then(|v| v.as_u64())
78 .map(|v| v as usize)
79 .unwrap_or(0);
80
81 LlamaFamilyConfigBase {
82 hidden_size: def.hidden_size,
83 intermediate_size: def.intermediate_size,
84 num_heads: def.num_attention_heads,
85 num_kv_heads,
86 head_dim,
87 num_layers: def.num_hidden_layers,
88 vocab_size: def.vocab_size,
89 max_seq_len: def.max_position_embeddings,
90 rms_norm_eps: def.norm_eps as f32,
91 rope_theta_opt: def.rope_theta,
92 sliding_window,
93 }
94 }
95
96 fn from_base(b: LlamaFamilyConfigBase, rope_default: f64, has_qk_norm: bool) -> Self {
97 Self {
98 hidden_size: b.hidden_size,
99 intermediate_size: b.intermediate_size,
100 num_heads: b.num_heads,
101 num_kv_heads: b.num_kv_heads,
102 head_dim: b.head_dim,
103 num_layers: b.num_layers,
104 vocab_size: b.vocab_size,
105 max_seq_len: b.max_seq_len,
106 rms_norm_eps: b.rms_norm_eps,
107 rope_theta: b.rope_theta_opt.unwrap_or(rope_default),
108 has_qk_norm,
109 sliding_window: b.sliding_window,
110 }
111 }
112
113 pub fn qwen3_from_def(def: &crate::definition::ModelDefinition) -> Self {
115 Self::from_base(Self::from_def_base(def), 1_000_000.0, true)
116 }
117
118 pub fn llama_from_def(def: &crate::definition::ModelDefinition) -> Self {
122 Self::from_base(Self::from_def_base(def), 500_000.0, false)
123 }
124
125 pub fn qwen2_from_def(def: &crate::definition::ModelDefinition) -> Self {
127 Self::from_base(Self::from_def_base(def), 1_000_000.0, false)
128 }
129
130 pub fn mistral_from_def(def: &crate::definition::ModelDefinition) -> Self {
134 Self::from_base(Self::from_def_base(def), 10_000.0, false)
135 }
136}
137
/// Family-independent part of [`LlamaFamilyConfig`], produced by
/// `from_def_base` and completed by `from_base`.
///
/// It mirrors `LlamaFamilyConfig` except that the RoPE base is still optional
/// and `has_qk_norm` has not been decided yet.
struct LlamaFamilyConfigBase {
    hidden_size: usize,
    intermediate_size: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    num_layers: usize,
    vocab_size: usize,
    max_seq_len: usize,
    rms_norm_eps: f32,
    /// RoPE base from the model definition, if it specified one; `from_base`
    /// substitutes the family default otherwise.
    rope_theta_opt: Option<f64>,
    sliding_window: usize,
}
151
/// Weights of a single decoder layer.
pub struct LlamaFamilyLayer<B: Backend> {
    /// RMS-norm weight applied to the residual before attention
    /// (`input_layernorm.weight`).
    pub input_ln_w: B::Buffer,
    /// Fused Q/K/V projection (`self_attn.qkv_proj`).
    pub qkv_proj: Box<dyn Linear<B>>,
    /// Optional per-head Q norm weight (`self_attn.q_norm.weight`); only
    /// attempted when `has_qk_norm` is set.
    pub q_norm_w: Option<B::Buffer>,
    /// Optional per-head K norm weight (`self_attn.k_norm.weight`); only
    /// attempted when `has_qk_norm` is set.
    pub k_norm_w: Option<B::Buffer>,
    /// Attention output projection (`self_attn.o_proj`).
    pub o_proj: Box<dyn Linear<B>>,
    /// RMS-norm weight applied before the MLP
    /// (`post_attention_layernorm.weight`).
    pub post_ln_w: B::Buffer,
    /// Fused gate/up MLP projection (`mlp.gate_up_proj`).
    pub gate_up_proj: Box<dyn Linear<B>>,
    /// MLP down projection (`mlp.down_proj`).
    pub down_proj: Box<dyn Linear<B>>,
}
165
/// Precomputed rotary-embedding tables, as built by `build_rope_cache`.
///
/// Both tables hold `max_seq_len * (head_dim / 2)` values laid out row-major
/// by position: entry `pos * (head_dim / 2) + i`.
pub struct RopeCache<B: Backend> {
    /// Cosine of the rotation angle for each (position, frequency index).
    pub cos: B::Buffer,
    /// Sine of the rotation angle for each (position, frequency index).
    pub sin: B::Buffer,
}
171
/// Reusable intermediate buffers for the forward pass, sized for up to
/// `max_tokens` tokens (see `LlamaFamilyScratch::alloc` for exact sizes).
///
/// In the size notes below `t = max_tokens`, `h = hidden_size`,
/// `q_dim = num_heads * head_dim`, `kv_dim = num_kv_heads * head_dim`.
pub struct LlamaFamilyScratch<B: Backend> {
    /// Residual stream (`t * h`). Wrapped in `Option` so forward passes can
    /// `take()` it, mutate it alongside the other scratch fields, and put it
    /// back when done.
    pub residual: Option<B::Buffer>,
    /// RMS-norm output (`t * h`).
    pub norm_out: B::Buffer,
    /// Fused QKV projection output (`t * (q_dim + 2 * kv_dim)`).
    pub qkv_out: B::Buffer,
    /// Single-token Q slice used by the batched decode path (`q_dim`).
    pub q_single: B::Buffer,
    /// Single-token K slice used by the batched decode path (`kv_dim`).
    pub k_single: B::Buffer,
    /// Single-token V slice used by the batched decode path (`kv_dim`).
    pub v_single: B::Buffer,
    /// Single-token Q after the Q/K kernel (`q_dim`).
    pub q_head_major_single: B::Buffer,
    /// Single-token K after the Q/K kernel (`kv_dim`).
    pub k_head_major_single: B::Buffer,
    /// Single-token V after the Q/K kernel in mode `0` (`kv_dim`).
    pub v_head_major_single: B::Buffer,
    /// Single-token attention output before the head→token transpose (`q_dim`).
    pub attn_head_major_single: B::Buffer,
    /// Single-token attention output after the transpose (`q_dim`).
    pub attn_flat_single: B::Buffer,
    /// Logits for a whole decode batch (`t * vocab_size`).
    pub batch_logits: B::Buffer,
    /// Q split out of `qkv_out` (`t * q_dim`).
    pub q_buf: B::Buffer,
    /// K split out of `qkv_out` (`t * kv_dim`).
    pub k_buf: B::Buffer,
    /// V split out of `qkv_out` (`t * kv_dim`).
    pub v_buf: B::Buffer,
    /// Q after the Q/K kernel (`num_heads * t * head_dim`).
    pub q_head_major: B::Buffer,
    /// K after the Q/K kernel (`num_kv_heads * t * head_dim`).
    pub k_head_major: B::Buffer,
    /// V after the Q/K kernel in mode `0` (`num_kv_heads * t * head_dim`).
    pub v_head_major: B::Buffer,
    /// Attention output before the head→token transpose
    /// (`num_heads * t * head_dim`).
    pub attn_head_major_out: B::Buffer,
    /// Attention output after the transpose (`t * q_dim`).
    pub attn_flat: B::Buffer,
    /// Output of the attention output projection (`t * h`).
    pub o_proj_out: B::Buffer,
    /// Output of the fused gate/up projection (`t * 2 * intermediate_size`).
    pub gate_up_out: B::Buffer,
    /// Output of the fused SiLU-multiply (`t * intermediate_size`).
    pub silu_out: B::Buffer,
    /// Output of the MLP down projection (`t * h`).
    pub mlp_out: B::Buffer,
    /// Hidden state of the last token (`h`).
    pub last_hidden: B::Buffer,
    /// Final-norm output for a single token (`h`).
    pub last_normed: B::Buffer,
    /// Logits for a single token (`vocab_size`).
    pub logits: B::Buffer,
    /// Token capacity these buffers were allocated for.
    pub max_tokens: usize,
}
238
239impl<B: Backend> LlamaFamilyScratch<B> {
240 fn alloc(cfg: &LlamaFamilyConfig, max_tokens: usize) -> Self {
241 let h = cfg.hidden_size;
242 let im = cfg.intermediate_size;
243 let q_dim = cfg.num_heads * cfg.head_dim;
244 let kv_dim = cfg.num_kv_heads * cfg.head_dim;
245 let qkv_dim = q_dim + 2 * kv_dim;
246 let t = max_tokens;
247 Self {
248 residual: Some(B::alloc(t * h)),
249 norm_out: B::alloc(t * h),
250 qkv_out: B::alloc(t * qkv_dim),
251 q_buf: B::alloc(t * q_dim),
252 k_buf: B::alloc(t * kv_dim),
253 v_buf: B::alloc(t * kv_dim),
254 q_head_major: B::alloc(cfg.num_heads * t * cfg.head_dim),
255 k_head_major: B::alloc(cfg.num_kv_heads * t * cfg.head_dim),
256 v_head_major: B::alloc(cfg.num_kv_heads * t * cfg.head_dim),
257 attn_head_major_out: B::alloc(cfg.num_heads * t * cfg.head_dim),
258 attn_flat: B::alloc(t * q_dim),
259 o_proj_out: B::alloc(t * h),
260 gate_up_out: B::alloc(t * 2 * im),
261 silu_out: B::alloc(t * im),
262 mlp_out: B::alloc(t * h),
263 last_hidden: B::alloc(h),
264 last_normed: B::alloc(h),
265 logits: B::alloc(cfg.vocab_size),
266 q_single: B::alloc(q_dim),
267 k_single: B::alloc(kv_dim),
268 v_single: B::alloc(kv_dim),
269 q_head_major_single: B::alloc(q_dim),
270 k_head_major_single: B::alloc(kv_dim),
271 v_head_major_single: B::alloc(kv_dim),
272 attn_head_major_single: B::alloc(q_dim),
273 attn_flat_single: B::alloc(q_dim),
274 batch_logits: B::alloc(t * cfg.vocab_size),
275 max_tokens: t,
276 }
277 }
278}
279
/// Decoder-only transformer covering the Llama-style model families, generic
/// over the compute backend `B`.
pub struct LlamaFamilyModel<B: Backend> {
    /// Full architecture configuration.
    pub cfg: LlamaFamilyConfig,
    /// Reduced view of `cfg`, returned by `DecoderOnlyLLM::config`.
    pub runtime_cfg: LlmRuntimeConfig,

    /// Token embedding table; `None` for models built with
    /// `new_backbone_only`.
    pub embed: Option<B::Buffer>,
    /// Decoder layers in execution order.
    pub layers: Vec<LlamaFamilyLayer<B>>,
    /// Weight of the final RMS norm (`model.norm.weight`).
    pub final_norm_w: B::Buffer,
    /// Output projection to vocabulary logits; `None` for models built with
    /// `new_backbone_only`.
    pub lm_head: Option<Box<dyn Linear<B>>>,

    /// Precomputed RoPE cos/sin tables.
    pub rope: RopeCache<B>,
    /// Reusable intermediate buffers; regrown by `ensure_scratch`.
    pub scratch: LlamaFamilyScratch<B>,

    /// Live KV caches keyed by cache id; one `KvCache` per layer.
    pub kv_caches: HashMap<String, Vec<KvCache<B>>>,
    /// Per-layer cache sets handed back by `release`, reused by `ensure_kv`
    /// instead of allocating new ones.
    kv_free_pool: Vec<Vec<KvCache<B>>>,

    /// Count of decode steps that ran without completing a graph capture
    /// since the last reset; capture is attempted once it reaches the warm-up
    /// threshold in `decode_internal`.
    graph_warmup: usize,
    /// Set when a graph capture or replay step failed; suppresses further
    /// capture attempts until the next reset.
    graph_capture_failed: bool,
}
315
316impl<B: Backend> LlamaFamilyModel<B> {
317 pub fn new(cfg: LlamaFamilyConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
322 {
327 let mut ctx = B::new_context();
328 B::reset_graph(&mut ctx);
329 }
330 let rope = build_rope_cache::<B>(&cfg);
331 let scratch = LlamaFamilyScratch::alloc(&cfg, 1); let embed = loader.load_tensor("model.embed_tokens.weight")?;
335
336 let mut layers = Vec::with_capacity(cfg.num_layers);
338 for li in 0..cfg.num_layers {
339 let prefix = format!("model.layers.{li}");
340 let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
341 let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
342 let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
343 let post_ln_w =
344 loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
345 let gate_up_proj = loader.load_linear(&format!("{prefix}.mlp.gate_up_proj"))?;
346 let down_proj = loader.load_linear(&format!("{prefix}.mlp.down_proj"))?;
347
348 let (q_norm_w, k_norm_w) = if cfg.has_qk_norm {
349 let q = loader
350 .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
351 .ok();
352 let k = loader
353 .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
354 .ok();
355 (q, k)
356 } else {
357 (None, None)
358 };
359
360 layers.push(LlamaFamilyLayer {
361 input_ln_w,
362 qkv_proj,
363 q_norm_w,
364 k_norm_w,
365 o_proj,
366 post_ln_w,
367 gate_up_proj,
368 down_proj,
369 });
370 }
371
372 let final_norm_w = loader.load_tensor("model.norm.weight")?;
373
374 let lm_head = if loader.has_tensor("lm_head.weight") {
382 loader.load_linear("lm_head")?
383 } else {
384 tracing::info!(
385 "LlamaFamilyModel: tied embeddings — loading model.embed_tokens.weight as lm_head"
386 );
387 let as_linear = loader.load_linear("model.embed_tokens")?;
388 if as_linear.out_features() != cfg.vocab_size
390 || as_linear.in_features() != cfg.hidden_size
391 {
392 return Err(ferrum_types::FerrumError::model(format!(
393 "tied embed shape mismatch: got [{}, {}], expected [{}, {}]",
394 as_linear.out_features(),
395 as_linear.in_features(),
396 cfg.vocab_size,
397 cfg.hidden_size
398 )));
399 }
400 as_linear
401 };
402
403 let runtime_cfg = cfg.to_runtime();
404 Ok(Self {
405 cfg,
406 runtime_cfg,
407 embed: Some(embed),
408 layers,
409 final_norm_w,
410 lm_head: Some(lm_head),
411 rope,
412 scratch,
413 kv_caches: HashMap::new(),
414 kv_free_pool: Vec::new(),
415 graph_warmup: 0,
416 graph_capture_failed: false,
417 })
418 }
419
420 pub fn new_backbone_only(cfg: LlamaFamilyConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
432 {
434 let mut ctx = B::new_context();
435 B::reset_graph(&mut ctx);
436 }
437 let rope = build_rope_cache::<B>(&cfg);
438 let scratch = LlamaFamilyScratch::alloc(&cfg, 1);
439
440 let mut layers = Vec::with_capacity(cfg.num_layers);
441 for li in 0..cfg.num_layers {
442 let prefix = format!("model.layers.{li}");
443 let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
444 let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
445 let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
446 let post_ln_w =
447 loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
448 let gate_up_proj = loader.load_linear(&format!("{prefix}.mlp.gate_up_proj"))?;
449 let down_proj = loader.load_linear(&format!("{prefix}.mlp.down_proj"))?;
450
451 let (q_norm_w, k_norm_w) = if cfg.has_qk_norm {
452 let q = loader
453 .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
454 .ok();
455 let k = loader
456 .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
457 .ok();
458 (q, k)
459 } else {
460 (None, None)
461 };
462
463 layers.push(LlamaFamilyLayer {
464 input_ln_w,
465 qkv_proj,
466 q_norm_w,
467 k_norm_w,
468 o_proj,
469 post_ln_w,
470 gate_up_proj,
471 down_proj,
472 });
473 }
474
475 let final_norm_w = loader.load_tensor("model.norm.weight")?;
476
477 let runtime_cfg = cfg.to_runtime();
478 Ok(Self {
479 cfg,
480 runtime_cfg,
481 embed: None,
482 layers,
483 final_norm_w,
484 lm_head: None,
485 rope,
486 scratch,
487 kv_caches: HashMap::new(),
488 kv_free_pool: Vec::new(),
489 graph_warmup: 0,
490 graph_capture_failed: false,
491 })
492 }
493
494 pub(crate) fn ensure_scratch(&mut self, tokens: usize) {
496 if self.scratch.max_tokens < tokens {
497 {
502 let mut ctx = B::new_context();
503 B::reset_graph(&mut ctx);
504 }
505 self.scratch = LlamaFamilyScratch::alloc(&self.cfg, tokens);
506 self.graph_warmup = 0;
507 self.graph_capture_failed = false;
508 }
509 }
510
511 pub(crate) fn ensure_kv(&mut self, cache_id: &str) {
515 if self.kv_caches.contains_key(cache_id) {
516 return;
517 }
518 let nkv = self.cfg.num_kv_heads;
519 let hd = self.cfg.head_dim;
520 let max = self.cfg.max_seq_len;
521
522 let mut caches = self.kv_free_pool.pop().unwrap_or_else(|| {
525 (0..self.cfg.num_layers)
526 .map(|_| KvCache {
527 k: B::alloc(nkv * max * hd),
528 v: B::alloc(nkv * max * hd),
529 len: 0,
530 capacity: max,
531 num_kv_heads: nkv,
532 head_dim: hd,
533 })
534 .collect()
535 });
536 for c in caches.iter_mut() {
540 c.len = 0;
541 }
542 self.kv_caches.insert(cache_id.to_string(), caches);
543 }
544
/// Runs decoder layer `li` over `tokens` tokens, updating `residual` in
/// place and appending the new K/V entries to the cache `cache_id`.
///
/// * `ctx` — backend context every kernel is issued on.
/// * `li` — index into `self.layers` and into the per-layer cache vector.
/// * `cache_id` — key of the KV cache set to append to.
/// * `residual` — residual stream; read by the first norm and updated by the
///   fused add/norm and the final `add_inplace`.
/// * `pos_offset` — position passed to the Q/K kernels and to attention.
/// * `tokens` — number of tokens processed in this call.
///
/// # Panics
/// Panics when no cache set is registered for `cache_id` (call `ensure_kv`
/// first) or when `li` is out of range.
///
/// NOTE(review): nothing here checks `cache.len + tokens` against
/// `cache.capacity` — presumably the caller or the kernel bounds it; confirm.
#[allow(clippy::too_many_arguments)]
pub(crate) fn forward_layer(
    &mut self,
    ctx: &mut B::Context,
    li: usize,
    cache_id: &str,
    residual: &mut B::Buffer,
    pos_offset: usize,
    tokens: usize,
) {
    let layer = &self.layers[li];
    let cfg = &self.cfg;
    let h = cfg.hidden_size;
    let nh = cfg.num_heads;
    let nkv = cfg.num_kv_heads;
    let hd = cfg.head_dim;
    let im = cfg.intermediate_size;
    let eps = cfg.rms_norm_eps;
    let q_dim = nh * hd;
    let kv_dim = nkv * hd;

    // 1. Pre-attention RMS norm: residual -> norm_out.
    B::rms_norm(
        ctx,
        residual,
        &layer.input_ln_w,
        eps,
        &mut self.scratch.norm_out,
        tokens,
        h,
    );

    // 2. Fused QKV projection: norm_out -> qkv_out.
    layer.qkv_proj.forward(
        ctx,
        &self.scratch.norm_out,
        &mut self.scratch.qkv_out,
        tokens,
    );

    // 3. Split the fused output into separate Q, K and V buffers.
    B::split_qkv(
        ctx,
        &self.scratch.qkv_out,
        &mut self.scratch.q_buf,
        &mut self.scratch.k_buf,
        &mut self.scratch.v_buf,
        tokens,
        q_dim,
        kv_dim,
    );

    // 4. Q/K/V through the `qk_norm_rope` kernel. Q and K use mode 1 when
    //    the config has Q/K norm and mode 2 otherwise; V always uses mode 0.
    //    NOTE(review): presumably 1 = norm + RoPE, 2 = RoPE only and
    //    0 = layout change only — confirm against the backend kernel.
    let qk_mode: i32 = if cfg.has_qk_norm { 1 } else { 2 };
    // Stand-in weight handed to the kernel when a norm weight is absent.
    let dummy = &layer.input_ln_w;
    let q_norm_w = layer.q_norm_w.as_ref().unwrap_or(dummy);
    let k_norm_w = layer.k_norm_w.as_ref().unwrap_or(dummy);

    B::qk_norm_rope(
        ctx,
        &self.scratch.q_buf,
        q_norm_w,
        &self.rope.cos,
        &self.rope.sin,
        &mut self.scratch.q_head_major,
        tokens,
        nh,
        hd,
        pos_offset,
        eps,
        qk_mode,
    );
    B::qk_norm_rope(
        ctx,
        &self.scratch.k_buf,
        k_norm_w,
        &self.rope.cos,
        &self.rope.sin,
        &mut self.scratch.k_head_major,
        tokens,
        nkv,
        hd,
        pos_offset,
        eps,
        qk_mode,
    );
    B::qk_norm_rope(
        ctx,
        &self.scratch.v_buf,
        dummy,
        &self.rope.cos,
        &self.rope.sin,
        &mut self.scratch.v_head_major,
        tokens,
        nkv,
        hd,
        pos_offset,
        eps,
        0,
    );

    // 5. Append the new K/V entries to this layer's cache at its current
    //    length, then advance the length.
    let caches = self
        .kv_caches
        .get_mut(cache_id)
        .expect("ensure_kv must be called before forward_layer");
    let cache = &mut caches[li];
    B::kv_cache_append_head_major(
        ctx,
        &mut cache.k,
        &mut cache.v,
        cache.len,
        cache.capacity,
        &self.scratch.k_head_major,
        &self.scratch.v_head_major,
        tokens,
        nkv,
        hd,
    );
    cache.len += tokens;
    let kv_len = cache.len;
    let kv_stride = cache.capacity;

    // 6. Causal attention of the new queries over the whole cache; the cache
    //    capacity is passed as the KV sequence stride.
    let attn_cfg = ferrum_kernels::backend::AttnConfig {
        num_heads: nh,
        num_kv_heads: nkv,
        head_dim: hd,
        causal: true,
        scale: 1.0 / (hd as f32).sqrt(),
        kv_seq_stride: kv_stride,
        sliding_window: cfg.sliding_window,
    };
    B::flash_attention(
        ctx,
        &self.scratch.q_head_major,
        &cache.k,
        &cache.v,
        &mut self.scratch.attn_head_major_out,
        1,
        tokens,
        kv_len,
        pos_offset,
        &attn_cfg,
    );

    // 7. Head-major attention output -> token-major attn_flat.
    B::transpose_head_to_token(
        ctx,
        &self.scratch.attn_head_major_out,
        &mut self.scratch.attn_flat,
        tokens,
        nh,
        hd,
    );

    // 8. Attention output projection: attn_flat -> o_proj_out.
    layer.o_proj.forward(
        ctx,
        &self.scratch.attn_flat,
        &mut self.scratch.o_proj_out,
        tokens,
    );

    // 9. Fused residual add + post-attention RMS norm; writes norm_out.
    B::fused_add_rms_norm(
        ctx,
        residual,
        &self.scratch.o_proj_out,
        &layer.post_ln_w,
        eps,
        &mut self.scratch.norm_out,
        tokens,
        h,
    );

    // 10. MLP: fused gate/up projection, SiLU-multiply, down projection.
    layer.gate_up_proj.forward(
        ctx,
        &self.scratch.norm_out,
        &mut self.scratch.gate_up_out,
        tokens,
    );

    B::fused_silu_mul_split(
        ctx,
        &self.scratch.gate_up_out,
        &mut self.scratch.silu_out,
        tokens,
        im,
    );

    layer.down_proj.forward(
        ctx,
        &self.scratch.silu_out,
        &mut self.scratch.mlp_out,
        tokens,
    );

    // 11. Add the MLP output back onto the residual stream.
    B::add_inplace(ctx, residual, &self.scratch.mlp_out, tokens * h);
}
762
763 pub fn prefill_internal(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
766 let seq_len = tokens.len();
767 assert!(seq_len > 0, "prefill called with empty token list");
768 self.ensure_scratch(seq_len);
769 self.ensure_kv(cache_id);
770
771 let h = self.cfg.hidden_size;
772 let vocab = self.cfg.vocab_size;
773 let mut ctx = B::new_context();
774
775 let mut residual = self
782 .scratch
783 .residual
784 .take()
785 .expect("scratch residual missing (previous call didn't restore)");
786 let embed = self
787 .embed
788 .as_ref()
789 .expect("prefill_internal called on backbone-only model (no embed)");
790 B::embedding_lookup(&mut ctx, embed, tokens, &mut residual, h);
791
792 for li in 0..self.cfg.num_layers {
793 self.forward_layer(&mut ctx, li, cache_id, &mut residual, 0, seq_len);
794 }
795
796 B::copy_slice(
798 &mut ctx,
799 &residual,
800 (seq_len - 1) * h,
801 &mut self.scratch.last_hidden,
802 0,
803 h,
804 );
805
806 B::rms_norm(
808 &mut ctx,
809 &self.scratch.last_hidden,
810 &self.final_norm_w,
811 self.cfg.rms_norm_eps,
812 &mut self.scratch.last_normed,
813 1,
814 h,
815 );
816
817 let lm_head = self
819 .lm_head
820 .as_ref()
821 .expect("prefill_internal called on backbone-only model (no lm_head)");
822 lm_head.forward(
823 &mut ctx,
824 &self.scratch.last_normed,
825 &mut self.scratch.logits,
826 1,
827 );
828
829 B::sync(&mut ctx);
836
837 self.scratch.residual = Some(residual);
839
840 B::to_vec(&self.scratch.logits, vocab)
841 }
842
/// Decodes a single token at position `pos` and returns its logits
/// (`vocab_size` values).
///
/// When the `FERRUM_CUDA_GRAPH` environment variable is set, the method first
/// tries to replay a previously captured backend graph; after
/// `GRAPH_WARMUP` eager steps it attempts to capture one. Without the
/// variable every step runs eagerly.
///
/// # Panics
/// Panics on a backbone-only model (no embedding table or LM head), or when
/// the scratch residual was not restored by a prior call.
pub fn decode_internal(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
    self.ensure_scratch(1);
    self.ensure_kv(cache_id);

    let h = self.cfg.hidden_size;
    let vocab = self.cfg.vocab_size;

    let mut ctx = B::new_context();

    // Number of eager decode steps required before a capture is attempted.
    const GRAPH_WARMUP: usize = 3;
    // Graph mode is opt-in; the variable's value is irrelevant, only its
    // presence (as valid Unicode) matters.
    let graph_enabled = std::env::var("FERRUM_CUDA_GRAPH").is_ok();

    if graph_enabled {
        // Publish the current token/position to the backend before replay.
        B::set_decode_state(&mut ctx, token, pos);

        match B::replay_last_graph(&mut ctx) {
            // A graph was replayed: the logits are already in scratch.
            Ok(true) => {
                B::sync(&mut ctx);
                return B::to_vec(&self.scratch.logits, vocab);
            }
            // No graph replayed — fall through to the eager path.
            Ok(false) => {}
            // Replay failed — fall through to the eager path.
            Err(_) => {}
        }
    }

    let should_capture =
        graph_enabled && !self.graph_capture_failed && self.graph_warmup >= GRAPH_WARMUP;

    if should_capture {
        B::set_dev_state_mode(&mut ctx, true);
        if B::begin_graph_capture(&mut ctx).is_err() {
            // Capture could not start: remember the failure and undo the
            // mode switch so this step runs as a normal eager step.
            self.graph_capture_failed = true;
            B::set_dev_state_mode(&mut ctx, false);
        }
    }

    // Borrow the residual buffer out of scratch so it can be mutated while
    // `forward_layer` holds `&mut self`.
    let mut residual = self
        .scratch
        .residual
        .take()
        .expect("scratch residual missing (previous call didn't restore)");
    let embed = self
        .embed
        .as_ref()
        .expect("decode_internal called on backbone-only model (no embed)");
    B::embedding_lookup(&mut ctx, embed, &[token], &mut residual, h);

    for li in 0..self.cfg.num_layers {
        self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
    }

    // Final norm directly on the single-token residual.
    B::rms_norm(
        &mut ctx,
        &residual,
        &self.final_norm_w,
        self.cfg.rms_norm_eps,
        &mut self.scratch.last_normed,
        1,
        h,
    );

    let lm_head = self
        .lm_head
        .as_ref()
        .expect("decode_internal called on backbone-only model (no lm_head)");
    lm_head.forward(
        &mut ctx,
        &self.scratch.last_normed,
        &mut self.scratch.logits,
        1,
    );

    if should_capture && !self.graph_capture_failed {
        if B::end_graph_capture(&mut ctx).is_err() {
            self.graph_capture_failed = true;
        } else {
            // The freshly captured graph is replayed once right away.
            // NOTE(review): presumably because kernels issued during capture
            // are recorded rather than executed — confirm with the backend.
            if B::replay_last_graph(&mut ctx).is_err() {
                self.graph_capture_failed = true;
            }
        }
        B::set_dev_state_mode(&mut ctx, false);
    } else {
        // Eager step: counts towards the warm-up threshold.
        self.graph_warmup += 1;
    }

    B::sync(&mut ctx);
    // Hand the residual buffer back for the next call.
    self.scratch.residual = Some(residual);

    B::to_vec(&self.scratch.logits, vocab)
}
960
961 pub fn prefill_from_embeds(
970 &mut self,
971 cache_id: &str,
972 embeds: &[f32],
973 seq_len: usize,
974 ) -> Vec<f32> {
975 let h = self.cfg.hidden_size;
976 assert_eq!(
977 embeds.len(),
978 seq_len * h,
979 "embeds length {} != seq_len * hidden_size {}",
980 embeds.len(),
981 seq_len * h
982 );
983 assert!(seq_len > 0, "prefill_from_embeds called with zero length");
984
985 self.ensure_scratch(seq_len);
986 self.ensure_kv(cache_id);
987
988 let mut ctx = B::new_context();
989 let mut residual = self
990 .scratch
991 .residual
992 .take()
993 .expect("scratch residual missing (previous call didn't restore)");
994
995 let embed_buf = B::from_slice(embeds);
997 B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, seq_len * h);
998
999 for li in 0..self.cfg.num_layers {
1000 self.forward_layer(&mut ctx, li, cache_id, &mut residual, 0, seq_len);
1001 }
1002
1003 B::copy_slice(
1004 &mut ctx,
1005 &residual,
1006 (seq_len - 1) * h,
1007 &mut self.scratch.last_hidden,
1008 0,
1009 h,
1010 );
1011 B::sync(&mut ctx);
1012 self.scratch.residual = Some(residual);
1013 B::to_vec(&self.scratch.last_hidden, h)
1014 }
1015
1016 pub fn decode_from_embed(&mut self, cache_id: &str, embed: &[f32], pos: u32) -> Vec<f32> {
1020 let h = self.cfg.hidden_size;
1021 assert_eq!(
1022 embed.len(),
1023 h,
1024 "embed length {} != hidden_size {}",
1025 embed.len(),
1026 h
1027 );
1028
1029 self.ensure_scratch(1);
1030 self.ensure_kv(cache_id);
1031
1032 let mut ctx = B::new_context();
1033 let mut residual = self
1034 .scratch
1035 .residual
1036 .take()
1037 .expect("scratch residual missing (previous call didn't restore)");
1038
1039 let embed_buf = B::from_slice(embed);
1040 B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, h);
1041
1042 for li in 0..self.cfg.num_layers {
1043 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1044 }
1045
1046 B::copy_slice(&mut ctx, &residual, 0, &mut self.scratch.last_hidden, 0, h);
1047 B::sync(&mut ctx);
1048 self.scratch.residual = Some(residual);
1049 B::to_vec(&self.scratch.last_hidden, h)
1050 }
1051
1052 pub fn prefill_all_post_norm(
1063 &mut self,
1064 cache_id: &str,
1065 embeds: &[f32],
1066 seq_len: usize,
1067 pos_offset: usize,
1068 ) -> Vec<f32> {
1069 let h = self.cfg.hidden_size;
1070 assert_eq!(
1071 embeds.len(),
1072 seq_len * h,
1073 "embeds length {} != seq_len * hidden_size {}",
1074 embeds.len(),
1075 seq_len * h
1076 );
1077 assert!(seq_len > 0);
1078
1079 self.ensure_scratch(seq_len);
1080 self.ensure_kv(cache_id);
1081
1082 let mut ctx = B::new_context();
1083 let mut residual = self
1084 .scratch
1085 .residual
1086 .take()
1087 .expect("scratch residual missing (previous call didn't restore)");
1088
1089 let embed_buf = B::from_slice(embeds);
1090 B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, seq_len * h);
1091
1092 for li in 0..self.cfg.num_layers {
1093 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
1094 }
1095
1096 B::rms_norm(
1098 &mut ctx,
1099 &residual,
1100 &self.final_norm_w,
1101 self.cfg.rms_norm_eps,
1102 &mut self.scratch.norm_out,
1103 seq_len,
1104 h,
1105 );
1106 B::sync(&mut ctx);
1107 self.scratch.residual = Some(residual);
1108 B::to_vec(&self.scratch.norm_out, seq_len * h)
1109 }
1110
1111 pub fn decode_post_norm_from_embed(
1115 &mut self,
1116 cache_id: &str,
1117 embed: &[f32],
1118 pos: u32,
1119 ) -> Vec<f32> {
1120 let h = self.cfg.hidden_size;
1121 assert_eq!(embed.len(), h);
1122
1123 self.ensure_scratch(1);
1124 self.ensure_kv(cache_id);
1125
1126 let mut ctx = B::new_context();
1127 let mut residual = self
1128 .scratch
1129 .residual
1130 .take()
1131 .expect("scratch residual missing (previous call didn't restore)");
1132
1133 let embed_buf = B::from_slice(embed);
1134 B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, h);
1135
1136 for li in 0..self.cfg.num_layers {
1137 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1138 }
1139
1140 B::rms_norm(
1141 &mut ctx,
1142 &residual,
1143 &self.final_norm_w,
1144 self.cfg.rms_norm_eps,
1145 &mut self.scratch.last_normed,
1146 1,
1147 h,
1148 );
1149 B::sync(&mut ctx);
1150 self.scratch.residual = Some(residual);
1151 B::to_vec(&self.scratch.last_normed, h)
1152 }
1153
1154 pub fn decode_batch_internal(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
1162 let m = batch.len();
1163 if m == 0 {
1164 return Vec::new();
1165 }
1166 if m == 1 {
1167 let (cid, tok, pos) = &batch[0];
1168 return vec![self.decode_internal(cid, *tok, *pos)];
1169 }
1170
1171 for (cid, _, _) in batch {
1173 self.ensure_kv(cid);
1174 }
1175 self.ensure_scratch(m);
1176
1177 let h = self.cfg.hidden_size;
1178 let vocab = self.cfg.vocab_size;
1179 let mut ctx = B::new_context();
1180
1181 let tokens: Vec<u32> = batch.iter().map(|(_, t, _)| *t).collect();
1183 let mut residual = self
1184 .scratch
1185 .residual
1186 .take()
1187 .expect("scratch residual missing (previous call didn't restore)");
1188 let embed = self
1189 .embed
1190 .as_ref()
1191 .expect("decode_batch_internal called on backbone-only model (no embed)");
1192 B::embedding_lookup(&mut ctx, embed, &tokens, &mut residual, h);
1193
1194 for li in 0..self.cfg.num_layers {
1196 self.forward_layer_batched_decode(&mut ctx, li, batch, &mut residual, m);
1197 }
1198
1199 B::rms_norm(
1201 &mut ctx,
1202 &residual,
1203 &self.final_norm_w,
1204 self.cfg.rms_norm_eps,
1205 &mut self.scratch.norm_out,
1206 m,
1207 h,
1208 );
1209
1210 let lm_head = self
1212 .lm_head
1213 .as_ref()
1214 .expect("decode_batch_internal called on backbone-only model (no lm_head)");
1215 lm_head.forward(
1216 &mut ctx,
1217 &self.scratch.norm_out,
1218 &mut self.scratch.batch_logits,
1219 m,
1220 );
1221
1222 B::sync(&mut ctx);
1224 self.scratch.residual = Some(residual);
1225
1226 let all = B::to_vec(&self.scratch.batch_logits, m * vocab);
1228 (0..m)
1229 .map(|i| all[i * vocab..(i + 1) * vocab].to_vec())
1230 .collect()
1231 }
1232
/// Runs decoder layer `li` for a decode batch of `m` sequences, one new
/// token each.
///
/// The dense steps (norms, projections, MLP) run once over all `m` rows. The
/// attention part runs per sequence, because every entry has its own
/// position and KV cache: each row is copied into the single-token staging
/// buffers, processed, and its result copied back into row `i` of
/// `attn_flat`.
///
/// * `batch` — `(cache_id, token, pos)` per sequence; only `cache_id` and
///   `pos` are used here.
/// * `residual` — residual stream for all `m` rows, updated in place.
///
/// # Panics
/// Panics when a `cache_id` has no registered cache set (call `ensure_kv`
/// first) or when `li` is out of range.
fn forward_layer_batched_decode(
    &mut self,
    ctx: &mut B::Context,
    li: usize,
    batch: &[(String, u32, u32)],
    residual: &mut B::Buffer,
    m: usize,
) {
    let cfg = &self.cfg;
    let h = cfg.hidden_size;
    let nh = cfg.num_heads;
    let nkv = cfg.num_kv_heads;
    let hd = cfg.head_dim;
    let im = cfg.intermediate_size;
    let eps = cfg.rms_norm_eps;
    let q_dim = nh * hd;
    let kv_dim = nkv * hd;

    let layer = &self.layers[li];
    // Q/K kernel mode: 1 when the config has Q/K norm, 2 otherwise.
    // NOTE(review): presumably 1 = norm + RoPE, 2 = RoPE only — confirm
    // against the backend kernel.
    let qk_mode: i32 = if cfg.has_qk_norm { 1 } else { 2 };
    // Stand-in weight handed to the kernel when a norm weight is absent.
    let dummy_w = &layer.input_ln_w;
    let q_norm_w = layer.q_norm_w.as_ref().unwrap_or(dummy_w);
    let k_norm_w = layer.k_norm_w.as_ref().unwrap_or(dummy_w);

    // 1. Pre-attention RMS norm over all m rows.
    B::rms_norm(
        ctx,
        residual,
        &layer.input_ln_w,
        eps,
        &mut self.scratch.norm_out,
        m,
        h,
    );

    // 2. Fused QKV projection over all m rows.
    layer
        .qkv_proj
        .forward(ctx, &self.scratch.norm_out, &mut self.scratch.qkv_out, m);

    // 3. Split into separate Q, K and V buffers.
    B::split_qkv(
        ctx,
        &self.scratch.qkv_out,
        &mut self.scratch.q_buf,
        &mut self.scratch.k_buf,
        &mut self.scratch.v_buf,
        m,
        q_dim,
        kv_dim,
    );

    // 4. Per-sequence attention.
    for (i, (cache_id, _token, pos)) in batch.iter().enumerate() {
        let pos_i = *pos as usize;

        // 4a. Copy row i of Q/K/V into the single-token staging buffers.
        B::copy_slice(
            ctx,
            &self.scratch.q_buf,
            i * q_dim,
            &mut self.scratch.q_single,
            0,
            q_dim,
        );
        B::copy_slice(
            ctx,
            &self.scratch.k_buf,
            i * kv_dim,
            &mut self.scratch.k_single,
            0,
            kv_dim,
        );
        B::copy_slice(
            ctx,
            &self.scratch.v_buf,
            i * kv_dim,
            &mut self.scratch.v_single,
            0,
            kv_dim,
        );

        // 4b. Q/K/V through the `qk_norm_rope` kernel at this sequence's
        //     position; V always uses mode 0.
        B::qk_norm_rope(
            ctx,
            &self.scratch.q_single,
            q_norm_w,
            &self.rope.cos,
            &self.rope.sin,
            &mut self.scratch.q_head_major_single,
            1,
            nh,
            hd,
            pos_i,
            eps,
            qk_mode,
        );
        B::qk_norm_rope(
            ctx,
            &self.scratch.k_single,
            k_norm_w,
            &self.rope.cos,
            &self.rope.sin,
            &mut self.scratch.k_head_major_single,
            1,
            nkv,
            hd,
            pos_i,
            eps,
            qk_mode,
        );
        B::qk_norm_rope(
            ctx,
            &self.scratch.v_single,
            dummy_w,
            &self.rope.cos,
            &self.rope.sin,
            &mut self.scratch.v_head_major_single,
            1,
            nkv,
            hd,
            pos_i,
            eps,
            0,
        );

        // 4c. Append the new K/V entry to this sequence's cache for the
        //     layer, then advance its length.
        let caches = self
            .kv_caches
            .get_mut(cache_id)
            .expect("ensure_kv must be called before forward_layer_batched");
        let cache = &mut caches[li];
        B::kv_cache_append_head_major(
            ctx,
            &mut cache.k,
            &mut cache.v,
            cache.len,
            cache.capacity,
            &self.scratch.k_head_major_single,
            &self.scratch.v_head_major_single,
            1,
            nkv,
            hd,
        );
        cache.len += 1;
        let kv_len = cache.len;
        let kv_stride = cache.capacity;

        // 4d. Causal attention of the single query over this cache.
        let attn_cfg = ferrum_kernels::backend::AttnConfig {
            num_heads: nh,
            num_kv_heads: nkv,
            head_dim: hd,
            causal: true,
            scale: 1.0 / (hd as f32).sqrt(),
            kv_seq_stride: kv_stride,
            sliding_window: cfg.sliding_window,
        };
        B::flash_attention(
            ctx,
            &self.scratch.q_head_major_single,
            &cache.k,
            &cache.v,
            &mut self.scratch.attn_head_major_single,
            1,
            1,
            kv_len,
            pos_i,
            &attn_cfg,
        );

        // 4e. Head-major -> token-major for this single token.
        B::transpose_head_to_token(
            ctx,
            &self.scratch.attn_head_major_single,
            &mut self.scratch.attn_flat_single,
            1,
            nh,
            hd,
        );

        // 4f. Copy the result into row i of the batch-wide attn_flat.
        B::copy_slice(
            ctx,
            &self.scratch.attn_flat_single,
            0,
            &mut self.scratch.attn_flat,
            i * q_dim,
            q_dim,
        );
    }

    // 5. Attention output projection over all m rows.
    layer.o_proj.forward(
        ctx,
        &self.scratch.attn_flat,
        &mut self.scratch.o_proj_out,
        m,
    );

    // 6. Fused residual add + post-attention RMS norm; writes norm_out.
    B::fused_add_rms_norm(
        ctx,
        residual,
        &self.scratch.o_proj_out,
        &layer.post_ln_w,
        eps,
        &mut self.scratch.norm_out,
        m,
        h,
    );

    // 7. MLP: fused gate/up projection, SiLU-multiply, down projection.
    layer.gate_up_proj.forward(
        ctx,
        &self.scratch.norm_out,
        &mut self.scratch.gate_up_out,
        m,
    );

    B::fused_silu_mul_split(
        ctx,
        &self.scratch.gate_up_out,
        &mut self.scratch.silu_out,
        m,
        im,
    );

    layer
        .down_proj
        .forward(ctx, &self.scratch.silu_out, &mut self.scratch.mlp_out, m);

    // 8. Add the MLP output back onto the residual stream.
    B::add_inplace(ctx, residual, &self.scratch.mlp_out, m * h);
}
1471}
1472
1473impl<B: Backend> DecoderOnlyLLM for LlamaFamilyModel<B> {
1474 fn config(&self) -> &LlmRuntimeConfig {
1475 &self.runtime_cfg
1476 }
1477
1478 fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
1479 self.prefill_internal(cache_id, tokens)
1480 }
1481
1482 fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
1483 self.decode_internal(cache_id, token, pos)
1484 }
1485
1486 fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
1487 self.decode_batch_internal(batch)
1488 }
1489
1490 fn release(&mut self, cache_id: &str) {
1491 let mut ctx = B::new_context();
1497 B::sync(&mut ctx);
1498 B::reset_graph(&mut ctx);
1499 B::sync(&mut ctx);
1500 self.graph_warmup = 0;
1501 self.graph_capture_failed = false;
1502
1503 if let Some(caches) = self.kv_caches.remove(cache_id) {
1506 self.kv_free_pool.push(caches);
1507 }
1508 }
1509
1510 fn reset(&mut self) {
1511 let mut ctx = B::new_context();
1513 B::sync(&mut ctx);
1514 B::reset_graph(&mut ctx);
1515 B::sync(&mut ctx);
1516 self.graph_warmup = 0;
1517 self.graph_capture_failed = false;
1518 self.kv_caches.clear();
1519 self.kv_free_pool.clear();
1520 }
1521}
1522
1523fn build_rope_cache<B: Backend>(cfg: &LlamaFamilyConfig) -> RopeCache<B> {
1524 let hd = cfg.head_dim;
1525 let half = hd / 2;
1526 let max = cfg.max_seq_len;
1527 let mut cos = vec![0.0f32; max * half];
1528 let mut sin = vec![0.0f32; max * half];
1529 for pos in 0..max {
1530 for i in 0..half {
1531 let freq = 1.0f64 / cfg.rope_theta.powf((2 * i) as f64 / hd as f64);
1532 let angle = pos as f64 * freq;
1533 cos[pos * half + i] = angle.cos() as f32;
1534 sin[pos * half + i] = angle.sin() as f32;
1535 }
1536 }
1537 RopeCache {
1538 cos: B::from_slice(&cos),
1539 sin: B::from_slice(&sin),
1540 }
1541}