1use std::collections::HashMap;
20use std::sync::atomic::AtomicU64;
21
22use ferrum_kernels::backend::{Backend, KvCache};
23
// Process-wide profiling accumulators, one (elapsed microseconds, call count)
// pair per operation bucket. `forward_layer` adds to them only when the
// `FERRUM_DECODE_OP_PROFILE` environment variable is set; `prefill_internal`
// drains them with `swap(0)` when `FERRUM_PREFILL_OP_PROFILE` is set.

// Attention kernel (paged decode attention or flash attention).
static ATTN_TIME_US: AtomicU64 = AtomicU64::new(0);
static ATTN_CALLS: AtomicU64 = AtomicU64::new(0);
// QKV split + QK-norm + RoPE (+ KV-cache write) stage.
static QKR_TIME_US: AtomicU64 = AtomicU64::new(0);
static QKR_CALLS: AtomicU64 = AtomicU64::new(0);
// Linear projections: qkv_proj, o_proj, gate_up_proj, down_proj.
static MATMUL_TIME_US: AtomicU64 = AtomicU64::new(0);
static MATMUL_CALLS: AtomicU64 = AtomicU64::new(0);
// RMS-norm and fused add + RMS-norm.
static NORM_TIME_US: AtomicU64 = AtomicU64::new(0);
static NORM_CALLS: AtomicU64 = AtomicU64::new(0);
// Everything else that is timed: head->token transpose, SiLU-mul, residual add.
static OTHER_TIME_US: AtomicU64 = AtomicU64::new(0);
static OTHER_CALLS: AtomicU64 = AtomicU64::new(0);
34use ferrum_quantization::{Linear, WeightLoader};
35use ferrum_types::Result;
36
37use crate::common::{DecoderOnlyLLM, LlmRuntimeConfig};
38
/// Architecture hyper-parameters shared by the Llama-style model families
/// handled in this file (see the `*_from_def` constructors).
#[derive(Clone, Debug, PartialEq)]
pub struct LlamaFamilyConfig {
    /// Width of the residual stream.
    pub hidden_size: usize,
    /// Width of the MLP inner activation (the gate/up projection emits twice this).
    pub intermediate_size: usize,
    /// Number of query heads.
    pub num_heads: usize,
    /// Number of key/value heads; falls back to `num_heads` when the model
    /// definition does not specify it.
    pub num_kv_heads: usize,
    /// Per-head dimension; taken from the `head_dim` extra parameter when
    /// present, otherwise `hidden_size / num_heads`.
    pub head_dim: usize,
    /// Number of decoder layers.
    pub num_layers: usize,
    /// Vocabulary size (length of one logits row).
    pub vocab_size: usize,
    /// Maximum sequence length from the model definition; upper bound for the
    /// KV-cache capacity chosen in `ensure_kv`.
    pub max_seq_len: usize,
    /// Epsilon passed to every RMS-norm kernel.
    pub rms_norm_eps: f32,
    /// RoPE base; family-specific default when the definition omits it.
    pub rope_theta: f64,
    /// When true, per-layer `q_norm`/`k_norm` weights are loaded and
    /// `forward_layer` selects QK mode 1 instead of 2.
    pub has_qk_norm: bool,
    /// Sliding-window size forwarded to the attention config; defaults to 0
    /// when the definition has no `sliding_window` entry.
    // NOTE(review): 0 presumably means "no sliding window" — confirm against the kernel.
    pub sliding_window: usize,
}
60
61impl LlamaFamilyConfig {
62 pub fn to_runtime(&self) -> LlmRuntimeConfig {
63 LlmRuntimeConfig {
64 hidden_size: self.hidden_size,
65 num_layers: self.num_layers,
66 num_kv_heads: self.num_kv_heads,
67 head_dim: self.head_dim,
68 vocab_size: self.vocab_size,
69 max_seq_len: self.max_seq_len,
70 }
71 }
72
73 fn from_def_base(def: &crate::definition::ModelDefinition) -> LlamaFamilyConfigBase {
77 let num_kv_heads = def.num_key_value_heads.unwrap_or(def.num_attention_heads);
78 let head_dim = def
79 .extra_params
80 .get("head_dim")
81 .and_then(|v| v.as_u64())
82 .map(|v| v as usize)
83 .unwrap_or(def.hidden_size / def.num_attention_heads);
84 let sliding_window = def
87 .extra_params
88 .get("sliding_window")
89 .and_then(|v| v.as_u64())
90 .map(|v| v as usize)
91 .unwrap_or(0);
92
93 LlamaFamilyConfigBase {
94 hidden_size: def.hidden_size,
95 intermediate_size: def.intermediate_size,
96 num_heads: def.num_attention_heads,
97 num_kv_heads,
98 head_dim,
99 num_layers: def.num_hidden_layers,
100 vocab_size: def.vocab_size,
101 max_seq_len: def.max_position_embeddings,
102 rms_norm_eps: def.norm_eps as f32,
103 rope_theta_opt: def.rope_theta,
104 sliding_window,
105 }
106 }
107
108 fn from_base(b: LlamaFamilyConfigBase, rope_default: f64, has_qk_norm: bool) -> Self {
109 Self {
110 hidden_size: b.hidden_size,
111 intermediate_size: b.intermediate_size,
112 num_heads: b.num_heads,
113 num_kv_heads: b.num_kv_heads,
114 head_dim: b.head_dim,
115 num_layers: b.num_layers,
116 vocab_size: b.vocab_size,
117 max_seq_len: b.max_seq_len,
118 rms_norm_eps: b.rms_norm_eps,
119 rope_theta: b.rope_theta_opt.unwrap_or(rope_default),
120 has_qk_norm,
121 sliding_window: b.sliding_window,
122 }
123 }
124
125 pub fn qwen3_from_def(def: &crate::definition::ModelDefinition) -> Self {
127 Self::from_base(Self::from_def_base(def), 1_000_000.0, true)
128 }
129
130 pub fn llama_from_def(def: &crate::definition::ModelDefinition) -> Self {
134 Self::from_base(Self::from_def_base(def), 500_000.0, false)
135 }
136
137 pub fn qwen2_from_def(def: &crate::definition::ModelDefinition) -> Self {
139 Self::from_base(Self::from_def_base(def), 1_000_000.0, false)
140 }
141
142 pub fn mistral_from_def(def: &crate::definition::ModelDefinition) -> Self {
146 Self::from_base(Self::from_def_base(def), 10_000.0, false)
147 }
148}
149
/// Intermediate form of [`LlamaFamilyConfig`] produced by `from_def_base`:
/// the same fields, except that `rope_theta` is still optional and
/// `has_qk_norm` has not been decided. `from_base` fills those in per family.
struct LlamaFamilyConfigBase {
    hidden_size: usize,
    intermediate_size: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    num_layers: usize,
    vocab_size: usize,
    max_seq_len: usize,
    rms_norm_eps: f32,
    /// `rope_theta` exactly as found in the model definition (`None` when absent).
    rope_theta_opt: Option<f64>,
    sliding_window: usize,
}
163
/// Weights of one decoder layer, in the order `forward_layer` consumes them.
pub struct LlamaFamilyLayer<B: Backend> {
    /// RMS-norm weight applied to the residual before attention
    /// (`input_layernorm.weight`).
    pub input_ln_w: B::Buffer,
    /// Fused Q/K/V projection (`self_attn.qkv_proj`).
    pub qkv_proj: Box<dyn Linear<B>>,
    /// Optional per-head Q norm weight (`self_attn.q_norm.weight`); only
    /// attempted when the config has `has_qk_norm`, and `None` if loading fails.
    pub q_norm_w: Option<B::Buffer>,
    /// Optional per-head K norm weight (`self_attn.k_norm.weight`); same
    /// loading rules as `q_norm_w`.
    pub k_norm_w: Option<B::Buffer>,
    /// Attention output projection (`self_attn.o_proj`).
    pub o_proj: Box<dyn Linear<B>>,
    /// RMS-norm weight applied after the attention residual add
    /// (`post_attention_layernorm.weight`).
    pub post_ln_w: B::Buffer,
    /// Fused gate/up MLP projection (`mlp.gate_up_proj`).
    pub gate_up_proj: Box<dyn Linear<B>>,
    /// MLP down projection (`mlp.down_proj`).
    pub down_proj: Box<dyn Linear<B>>,
}
177
/// Cosine/sine tables handed to the RoPE kernels in `forward_layer`.
/// Built once per model by `build_rope_cache` (defined outside this view).
// NOTE(review): table layout (positions x head_dim/2?) is not visible here — confirm.
pub struct RopeCache<B: Backend> {
    /// Cosine table.
    pub cos: B::Buffer,
    /// Sine table.
    pub sin: B::Buffer,
}
183
/// Reusable activation buffers for a forward pass.
///
/// Sizes below are element counts as allocated by `alloc`, with
/// `t = max_tokens`, `h = hidden_size`, `im = intermediate_size`,
/// `q_dim = num_heads * head_dim`, `kv_dim = num_kv_heads * head_dim`.
pub struct LlamaFamilyScratch<B: Backend> {
    /// Residual stream (`t * h`). `forward_verify` / `prefill_internal`
    /// `take()` it for the duration of a pass and put it back afterwards, so
    /// it is `None` only while a pass is running.
    pub residual: Option<B::Buffer>,
    /// RMS-norm output (`t * h`), input to the QKV and gate/up projections.
    pub norm_out: B::Buffer,
    /// Fused QKV projection output (`t * (q_dim + 2 * kv_dim)`).
    pub qkv_out: B::Buffer,
    // Single-token variants of the attention buffers (`q_dim` / `kv_dim`
    // elements). Their consumers are not in the portion of the file shown here.
    pub q_single: B::Buffer,
    pub k_single: B::Buffer,
    pub v_single: B::Buffer,
    pub q_head_major_single: B::Buffer,
    pub k_head_major_single: B::Buffer,
    pub v_head_major_single: B::Buffer,
    pub attn_head_major_single: B::Buffer,
    pub attn_flat_single: B::Buffer,
    /// Logits for every token of a batch (`t * vocab_size`); written by
    /// `forward_verify`.
    pub batch_logits: B::Buffer,
    /// Q after `split_qkv` in the unfused fallback path (`t * q_dim`).
    pub q_buf: B::Buffer,
    /// K after `split_qkv` in the unfused fallback path (`t * kv_dim`).
    pub k_buf: B::Buffer,
    /// V after `split_qkv` in the unfused fallback path (`t * kv_dim`).
    pub v_buf: B::Buffer,
    /// Q after norm + RoPE, head-major (`num_heads * t * head_dim`).
    pub q_head_major: B::Buffer,
    /// K after norm + RoPE, head-major; only filled on the fallback paths that
    /// do not write straight into the KV cache.
    pub k_head_major: B::Buffer,
    /// V in head-major layout; same fallback-only use as `k_head_major`.
    pub v_head_major: B::Buffer,
    /// Attention output, head-major (`num_heads * t * head_dim`).
    pub attn_head_major_out: B::Buffer,
    /// Attention output transposed to token-major (`t * q_dim`); used when
    /// more than one token is processed.
    pub attn_flat: B::Buffer,
    /// Output of the attention output projection (`t * h`).
    pub o_proj_out: B::Buffer,
    /// Output of the fused gate/up projection (`t * 2 * im`).
    pub gate_up_out: B::Buffer,
    /// Output of the fused SiLU-multiply (`t * im`).
    pub silu_out: B::Buffer,
    /// Output of the MLP down projection (`t * h`).
    pub mlp_out: B::Buffer,
    /// Batched paged-attention query buffer (`max_seqs * q_dim`); `None` until
    /// `enable_paged_batch` runs.
    pub paged_batch_q: Option<B::Buffer>,
    /// Batched paged-attention output buffer (`max_seqs * q_dim`).
    pub paged_batch_o: Option<B::Buffer>,
    /// Batched block tables (`max_seqs * max_blocks_per_seq` u32 entries).
    pub paged_batch_block_tables: Option<B::Buffer>,
    /// Batched context lengths (`max_seqs` u32 entries).
    pub paged_batch_context_lens: Option<B::Buffer>,
    /// `max_blocks_per_seq` the paged batch buffers were sized for; 0 until enabled.
    pub paged_max_blocks_per_seq: usize,
    /// Hidden state of the last prompt token (`h`); filled by `prefill_internal`.
    pub last_hidden: B::Buffer,
    /// Final-norm output for `last_hidden` (`h`).
    pub last_normed: B::Buffer,
    /// Logits for a single token (`vocab_size`).
    pub logits: B::Buffer,
    /// Token capacity (`t`) this scratch set was allocated for.
    pub max_tokens: usize,
}
267
268impl<B: Backend> LlamaFamilyScratch<B> {
269 fn alloc(cfg: &LlamaFamilyConfig, max_tokens: usize) -> Self {
270 let h = cfg.hidden_size;
271 let im = cfg.intermediate_size;
272 let q_dim = cfg.num_heads * cfg.head_dim;
273 let kv_dim = cfg.num_kv_heads * cfg.head_dim;
274 let qkv_dim = q_dim + 2 * kv_dim;
275 let t = max_tokens;
276 Self {
277 residual: Some(B::alloc(t * h)),
278 norm_out: B::alloc(t * h),
279 qkv_out: B::alloc(t * qkv_dim),
280 q_buf: B::alloc(t * q_dim),
281 k_buf: B::alloc(t * kv_dim),
282 v_buf: B::alloc(t * kv_dim),
283 q_head_major: B::alloc(cfg.num_heads * t * cfg.head_dim),
284 k_head_major: B::alloc(cfg.num_kv_heads * t * cfg.head_dim),
285 v_head_major: B::alloc(cfg.num_kv_heads * t * cfg.head_dim),
286 attn_head_major_out: B::alloc(cfg.num_heads * t * cfg.head_dim),
287 attn_flat: B::alloc(t * q_dim),
288 o_proj_out: B::alloc(t * h),
289 gate_up_out: B::alloc(t * 2 * im),
290 silu_out: B::alloc(t * im),
291 mlp_out: B::alloc(t * h),
292 last_hidden: B::alloc(h),
293 last_normed: B::alloc(h),
294 logits: B::alloc(cfg.vocab_size),
295 q_single: B::alloc(q_dim),
296 k_single: B::alloc(kv_dim),
297 v_single: B::alloc(kv_dim),
298 q_head_major_single: B::alloc(q_dim),
299 k_head_major_single: B::alloc(kv_dim),
300 v_head_major_single: B::alloc(kv_dim),
301 attn_head_major_single: B::alloc(q_dim),
302 attn_flat_single: B::alloc(q_dim),
303 batch_logits: B::alloc(t * cfg.vocab_size),
304 paged_batch_q: None,
309 paged_batch_o: None,
310 paged_batch_block_tables: None,
311 paged_batch_context_lens: None,
312 paged_max_blocks_per_seq: 0,
313 max_tokens: t,
314 }
315 }
316
317 fn enable_paged_batch(
321 &mut self,
322 cfg: &LlamaFamilyConfig,
323 max_seqs: usize,
324 max_blocks_per_seq: usize,
325 ) {
326 if self.paged_batch_q.is_some() {
327 return;
328 }
329 let q_dim = cfg.num_heads * cfg.head_dim;
330 self.paged_batch_q = Some(B::alloc(max_seqs * q_dim));
331 self.paged_batch_o = Some(B::alloc(max_seqs * q_dim));
332 self.paged_batch_block_tables = Some(B::alloc_u32(max_seqs * max_blocks_per_seq));
333 self.paged_batch_context_lens = Some(B::alloc_u32(max_seqs));
334 self.paged_max_blocks_per_seq = max_blocks_per_seq;
335 }
336}
337
/// A Llama-family decoder-only model: weights, scratch buffers and per-sequence
/// KV-cache state for one backend `B`.
pub struct LlamaFamilyModel<B: Backend> {
    /// Architecture config the model was built from.
    pub cfg: LlamaFamilyConfig,
    /// `cfg.to_runtime()`, computed once at construction.
    pub runtime_cfg: LlmRuntimeConfig,

    /// Token embedding table; `None` for models built with `new_backbone_only`.
    pub embed: Option<B::Buffer>,
    /// Decoder layers, index = layer number.
    pub layers: Vec<LlamaFamilyLayer<B>>,
    /// Weight of the final RMS-norm (`model.norm.weight`).
    pub final_norm_w: B::Buffer,
    /// Output projection to vocabulary logits; `None` for backbone-only
    /// models. Loaded from the embedding table when the checkpoint has no
    /// `lm_head.weight`.
    pub lm_head: Option<Box<dyn Linear<B>>>,

    /// RoPE tables.
    pub rope: RopeCache<B>,
    /// Activation scratch; regrown by `ensure_scratch`.
    pub scratch: LlamaFamilyScratch<B>,

    /// Per-sequence KV caches keyed by cache id; each value holds one
    /// `KvCache` per layer.
    pub kv_caches: HashMap<String, Vec<KvCache<B>>>,
    /// Spare per-layer cache sets that `ensure_kv` reuses before allocating.
    // NOTE(review): the code that returns released caches to this pool is outside this view.
    kv_free_pool: Vec<Vec<KvCache<B>>>,

    /// Per-layer `(K pool, V pool)` buffers for paged KV; allocated on the
    /// first `ensure_kv` call that runs in paged mode.
    pub paged_pools: Option<Vec<(B::Buffer, B::Buffer)>>,
    /// Allocator handing out block indices into `paged_pools`; created
    /// together with the pools.
    pub paged_block_alloc: Option<std::sync::Mutex<crate::common::paged_pool::BlockAllocator>>,

    /// Warm-up counter compared against a threshold before graph capture in
    /// `decode_internal`; reset to 0 when the scratch buffers are reallocated.
    graph_warmup: usize,
    /// Disables further graph-capture attempts while true; cleared when the
    /// scratch buffers are reallocated.
    graph_capture_failed: bool,
}
397
398impl<B: Backend> LlamaFamilyModel<B> {
399 pub fn new(cfg: LlamaFamilyConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
404 {
409 let mut ctx = B::new_context();
410 B::reset_graph(&mut ctx);
411 }
412 let rope = build_rope_cache::<B>(&cfg);
413 let scratch = LlamaFamilyScratch::alloc(&cfg, 1); let embed = loader.load_tensor("model.embed_tokens.weight")?;
417
418 let mut layers = Vec::with_capacity(cfg.num_layers);
420 for li in 0..cfg.num_layers {
421 let prefix = format!("model.layers.{li}");
422 let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
423 let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
424 let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
425 let post_ln_w =
426 loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
427 let gate_up_proj = loader.load_linear(&format!("{prefix}.mlp.gate_up_proj"))?;
428 let down_proj = loader.load_linear(&format!("{prefix}.mlp.down_proj"))?;
429
430 let (q_norm_w, k_norm_w) = if cfg.has_qk_norm {
431 let q = loader
432 .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
433 .ok();
434 let k = loader
435 .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
436 .ok();
437 (q, k)
438 } else {
439 (None, None)
440 };
441
442 layers.push(LlamaFamilyLayer {
443 input_ln_w,
444 qkv_proj,
445 q_norm_w,
446 k_norm_w,
447 o_proj,
448 post_ln_w,
449 gate_up_proj,
450 down_proj,
451 });
452 }
453
454 let final_norm_w = loader.load_tensor("model.norm.weight")?;
455
456 let lm_head = if loader.has_tensor("lm_head.weight") {
464 loader.load_linear("lm_head")?
465 } else {
466 tracing::info!(
467 "LlamaFamilyModel: tied embeddings — loading model.embed_tokens.weight as lm_head"
468 );
469 let as_linear = loader.load_linear("model.embed_tokens")?;
470 if as_linear.out_features() != cfg.vocab_size
472 || as_linear.in_features() != cfg.hidden_size
473 {
474 return Err(ferrum_types::FerrumError::model(format!(
475 "tied embed shape mismatch: got [{}, {}], expected [{}, {}]",
476 as_linear.out_features(),
477 as_linear.in_features(),
478 cfg.vocab_size,
479 cfg.hidden_size
480 )));
481 }
482 as_linear
483 };
484
485 let runtime_cfg = cfg.to_runtime();
486 Ok(Self {
487 cfg,
488 runtime_cfg,
489 embed: Some(embed),
490 layers,
491 final_norm_w,
492 lm_head: Some(lm_head),
493 rope,
494 scratch,
495 kv_caches: HashMap::new(),
496 kv_free_pool: Vec::new(),
497 paged_pools: None,
498 paged_block_alloc: None,
499 graph_warmup: 0,
500 graph_capture_failed: false,
501 })
502 }
503
504 pub fn new_backbone_only(cfg: LlamaFamilyConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
516 {
518 let mut ctx = B::new_context();
519 B::reset_graph(&mut ctx);
520 }
521 let rope = build_rope_cache::<B>(&cfg);
522 let scratch = LlamaFamilyScratch::alloc(&cfg, 1);
523
524 let mut layers = Vec::with_capacity(cfg.num_layers);
525 for li in 0..cfg.num_layers {
526 let prefix = format!("model.layers.{li}");
527 let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
528 let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
529 let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
530 let post_ln_w =
531 loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
532 let gate_up_proj = loader.load_linear(&format!("{prefix}.mlp.gate_up_proj"))?;
533 let down_proj = loader.load_linear(&format!("{prefix}.mlp.down_proj"))?;
534
535 let (q_norm_w, k_norm_w) = if cfg.has_qk_norm {
536 let q = loader
537 .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
538 .ok();
539 let k = loader
540 .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
541 .ok();
542 (q, k)
543 } else {
544 (None, None)
545 };
546
547 layers.push(LlamaFamilyLayer {
548 input_ln_w,
549 qkv_proj,
550 q_norm_w,
551 k_norm_w,
552 o_proj,
553 post_ln_w,
554 gate_up_proj,
555 down_proj,
556 });
557 }
558
559 let final_norm_w = loader.load_tensor("model.norm.weight")?;
560
561 let runtime_cfg = cfg.to_runtime();
562 Ok(Self {
563 cfg,
564 runtime_cfg,
565 embed: None,
566 layers,
567 final_norm_w,
568 lm_head: None,
569 rope,
570 scratch,
571 kv_caches: HashMap::new(),
572 kv_free_pool: Vec::new(),
573 paged_pools: None,
574 paged_block_alloc: None,
575 graph_warmup: 0,
576 graph_capture_failed: false,
577 })
578 }
579
580 pub(crate) fn ensure_scratch(&mut self, tokens: usize) {
582 if self.scratch.max_tokens < tokens {
583 {
588 let mut ctx = B::new_context();
589 B::reset_graph(&mut ctx);
590 }
591 self.scratch = LlamaFamilyScratch::alloc(&self.cfg, tokens);
592 self.graph_warmup = 0;
593 self.graph_capture_failed = false;
594 }
595 }
596
/// Makes sure `self.kv_caches` has a per-layer cache set for `cache_id`.
///
/// Returns immediately when the id is already known. Otherwise it
/// 1. resolves the per-sequence capacity and paged/contiguous mode from the
///    environment,
/// 2. lazily allocates the shared paged pools, block allocator and paged-batch
///    scratch (paged mode only),
/// 3. takes a cache set from `kv_free_pool` or allocates a new one,
/// 4. in paged mode, reserves blocks and writes them to every layer's block table,
/// 5. resets the lengths and registers the set under `cache_id`.
///
/// If the paged block pool is exhausted the function logs to stderr and
/// returns WITHOUT registering a cache; a subsequent `forward_layer` call for
/// this id will then panic on its "ensure_kv must be called" expectation.
pub(crate) fn ensure_kv(&mut self, cache_id: &str) {
    if self.kv_caches.contains_key(cache_id) {
        return;
    }
    let nkv = self.cfg.num_kv_heads;
    let hd = self.cfg.head_dim;
    let model_max = self.cfg.max_seq_len;
    // Capacity used when FERRUM_KV_CAPACITY is unset or unparsable.
    const DEFAULT_KV_CAPACITY: usize = 512;
    // Per-sequence token capacity, never above the model's max_seq_len.
    let max = std::env::var("FERRUM_KV_CAPACITY")
        .ok()
        .and_then(|s| s.parse::<usize>().ok())
        .map(|cap| cap.min(model_max))
        .unwrap_or_else(|| model_max.min(DEFAULT_KV_CAPACITY));

    // Any value other than "0" forces paged mode on, "0" forces it off; when
    // the variable cannot be read the backend's own preference decides.
    let paged = std::env::var("FERRUM_METAL_PAGED_KV")
        .map(|v| v != "0")
        .unwrap_or_else(|_| B::supports_paged_kv());
    // Tokens per paged block.
    const PAGED_BLOCK_SIZE: usize = 16;

    // Number of sequences the shared pool is sized for (default 32).
    let max_seqs = std::env::var("FERRUM_PAGED_MAX_SEQS")
        .ok()
        .and_then(|s| s.parse::<usize>().ok())
        .unwrap_or(32);
    let max_blocks_per_seq = max.div_ceil(PAGED_BLOCK_SIZE);
    let total_pool_blocks = max_seqs * max_blocks_per_seq;

    // First paged request: allocate one (K, V) pool per layer plus the allocator.
    if paged && self.paged_pools.is_none() {
        let mut pools = Vec::with_capacity(self.cfg.num_layers);
        for _ in 0..self.cfg.num_layers {
            let pool_floats = total_pool_blocks * nkv * PAGED_BLOCK_SIZE * hd;
            pools.push((B::alloc(pool_floats), B::alloc(pool_floats)));
        }
        self.paged_pools = Some(pools);
        // NOTE(review): `as u32` truncates silently if total_pool_blocks exceeds u32::MAX.
        self.paged_block_alloc = Some(std::sync::Mutex::new(
            crate::common::paged_pool::BlockAllocator::new(total_pool_blocks as u32),
        ));
    }
    if paged {
        // No-op after the first call (see `enable_paged_batch`).
        self.scratch
            .enable_paged_batch(&self.cfg, max_seqs, max_blocks_per_seq);
    }

    // Reuse a spare cache set when one exists; otherwise build one per layer.
    // NOTE(review): a reused set keeps the capacity / paged layout it was
    // created with, which may differ from the values resolved above.
    let mut caches = self.kv_free_pool.pop().unwrap_or_else(|| {
        (0..self.cfg.num_layers)
            .map(|_| {
                if paged {
                    let mut block_table = B::alloc_u32(max_blocks_per_seq);
                    let mut context_lens = B::alloc_u32(1);
                    let mut bt_ctx = B::new_context();
                    B::write_u32(&mut bt_ctx, &mut context_lens, &[0u32]);
                    B::sync(&mut bt_ctx);
                    KvCache {
                        // Paged caches keep their data in the shared pools;
                        // k/v are 1-element placeholders.
                        k: B::alloc(1),
                        v: B::alloc(1),
                        len: 0,
                        capacity: max_blocks_per_seq * PAGED_BLOCK_SIZE,
                        num_kv_heads: nkv,
                        head_dim: hd,
                        block_size: PAGED_BLOCK_SIZE,
                        block_table: Some(block_table),
                        context_lens: Some(context_lens),
                        paged_block_indices: Vec::new(),
                    }
                } else {
                    KvCache {
                        k: B::alloc(nkv * max * hd),
                        v: B::alloc(nkv * max * hd),
                        len: 0,
                        capacity: max,
                        num_kv_heads: nkv,
                        head_dim: hd,
                        // block_size == 0 marks a contiguous (non-paged) cache.
                        block_size: 0,
                        block_table: None,
                        context_lens: None,
                        paged_block_indices: Vec::new(),
                    }
                }
            })
            .collect()
    });

    if paged {
        let alloc_arc = self
            .paged_block_alloc
            .as_ref()
            .expect("paged_block_alloc must be initialised when paged=true");
        // A poisoned mutex is recovered rather than propagated.
        let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
        let block_indices = match alloc.allocate_n(max_blocks_per_seq) {
            Ok(idx) => idx,
            Err(e) => {
                // Release the lock, return the unused set to the free pool and
                // give up without registering `cache_id`.
                drop(alloc);
                self.kv_free_pool.push(caches);
                eprintln!(
                    "[ferrum] paged KV pool exhausted on ensure_kv for \
                     cache_id={cache_id:?}: {e}. Increase \
                     FERRUM_PAGED_MAX_SEQS (currently {max_seqs}) or \
                     throttle concurrent requests.",
                );
                return;
            }
        };
        // Pad with block 0 so the table always has max_blocks_per_seq entries.
        let mut padded = block_indices.clone();
        padded.resize(max_blocks_per_seq, 0);
        let mut ctx_tmp = B::new_context();
        // Every layer shares the same block indices (each layer has its own pool).
        for c in caches.iter_mut() {
            if let Some(bt) = c.block_table.as_mut() {
                B::write_u32(&mut ctx_tmp, bt, &padded);
            }
            c.paged_block_indices = block_indices.clone();
        }
        B::sync(&mut ctx_tmp);
    }

    // Reset lengths; this matters for sets recycled from the free pool.
    for c in caches.iter_mut() {
        c.len = 0;
        if let Some(cl) = c.context_lens.as_mut() {
            let mut ctx_tmp = B::new_context();
            B::write_u32(&mut ctx_tmp, cl, &[0u32]);
            B::sync(&mut ctx_tmp);
        }
    }
    self.kv_caches.insert(cache_id.to_string(), caches);
}
798
799 #[allow(clippy::too_many_arguments)]
804 pub(crate) fn forward_layer(
805 &mut self,
806 ctx: &mut B::Context,
807 li: usize,
808 cache_id: &str,
809 residual: &mut B::Buffer,
810 pos_offset: usize,
811 tokens: usize,
812 ) {
813 let layer = &self.layers[li];
814 let cfg = &self.cfg;
815 let h = cfg.hidden_size;
816 let nh = cfg.num_heads;
817 let nkv = cfg.num_kv_heads;
818 let hd = cfg.head_dim;
819 let im = cfg.intermediate_size;
820 let eps = cfg.rms_norm_eps;
821 let q_dim = nh * hd;
822 let kv_dim = nkv * hd;
823
824 let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
826 B::sync(ctx);
827 Some(std::time::Instant::now())
828 } else {
829 None
830 };
831 B::rms_norm(
832 ctx,
833 residual,
834 &layer.input_ln_w,
835 eps,
836 &mut self.scratch.norm_out,
837 tokens,
838 h,
839 );
840 if let Some(t0) = _t0 {
841 B::sync(ctx);
842 NORM_TIME_US.fetch_add(
843 t0.elapsed().as_micros() as u64,
844 std::sync::atomic::Ordering::Relaxed,
845 );
846 NORM_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
847 }
848
849 let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
851 B::sync(ctx);
852 Some(std::time::Instant::now())
853 } else {
854 None
855 };
856 layer.qkv_proj.forward(
857 ctx,
858 &self.scratch.norm_out,
859 &mut self.scratch.qkv_out,
860 tokens,
861 );
862 if let Some(t0) = _t0 {
863 B::sync(ctx);
864 MATMUL_TIME_US.fetch_add(
865 t0.elapsed().as_micros() as u64,
866 std::sync::atomic::Ordering::Relaxed,
867 );
868 MATMUL_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
869 }
870
871 let qk_mode: i32 = if cfg.has_qk_norm { 1 } else { 2 };
884 let dummy = &layer.input_ln_w;
885 let q_norm_w = layer.q_norm_w.as_ref().unwrap_or(dummy);
886 let k_norm_w = layer.k_norm_w.as_ref().unwrap_or(dummy);
887
888 let paged_pool_ptr: Option<(*mut B::Buffer, *mut B::Buffer)> =
899 if let Some(pools) = self.paged_pools.as_mut() {
900 let pool = &mut pools[li];
901 Some((&mut pool.0 as *mut _, &mut pool.1 as *mut _))
902 } else {
903 None
904 };
905 let caches = self
906 .kv_caches
907 .get_mut(cache_id)
908 .expect("ensure_kv must be called before forward_layer");
909 let cache = &mut caches[li];
910 let cache_len_before = cache.len;
911 let cache_capacity = cache.capacity;
912
913 if cache_len_before + tokens > cache_capacity {
919 panic!(
920 "KV cache overflow on layer {li}: would write tokens [{cache_len_before}..{}) but capacity is {cache_capacity} (cache_id={cache_id:?}). Increase FERRUM_KV_CAPACITY or call /clear in the REPL.",
921 cache_len_before + tokens
922 );
923 }
924
925 let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
926 B::sync(ctx);
927 Some(std::time::Instant::now())
928 } else {
929 None
930 };
931 let used_qkv_into_cache = if cache.block_size > 0 {
936 let bt = cache
937 .block_table
938 .as_ref()
939 .expect("paged cache missing block_table");
940 let num_blocks_per_seq = cache.capacity / cache.block_size;
941 let (pool_k_ptr, pool_v_ptr) =
943 paged_pool_ptr.expect("paged_pools must be allocated when block_size > 0");
944 let pool_k = unsafe { &mut *pool_k_ptr };
947 let pool_v = unsafe { &mut *pool_v_ptr };
948 B::split_qkv_norm_rope_into_paged_cache(
949 ctx,
950 &self.scratch.qkv_out,
951 0, q_norm_w,
953 k_norm_w,
954 &self.rope.cos,
955 &self.rope.sin,
956 &mut self.scratch.q_head_major,
957 0, pool_k,
959 pool_v,
960 bt,
961 tokens,
962 nh,
963 nkv,
964 hd,
965 pos_offset,
966 eps,
967 qk_mode,
968 cache_len_before,
969 cache.block_size,
970 num_blocks_per_seq,
971 )
972 .is_ok()
973 } else {
974 B::split_qkv_norm_rope_into_cache(
975 ctx,
976 &self.scratch.qkv_out,
977 q_norm_w,
978 k_norm_w,
979 &self.rope.cos,
980 &self.rope.sin,
981 &mut self.scratch.q_head_major,
982 &mut cache.k,
983 &mut cache.v,
984 tokens,
985 nh,
986 nkv,
987 hd,
988 pos_offset,
989 eps,
990 qk_mode,
991 cache_len_before,
992 cache_capacity,
993 )
994 .is_ok()
995 };
996 if !used_qkv_into_cache {
997 let used_fused_qkv = B::split_qkv_norm_rope(
1000 ctx,
1001 &self.scratch.qkv_out,
1002 q_norm_w,
1003 k_norm_w,
1004 &self.rope.cos,
1005 &self.rope.sin,
1006 &mut self.scratch.q_head_major,
1007 &mut self.scratch.k_head_major,
1008 &mut self.scratch.v_head_major,
1009 tokens,
1010 nh,
1011 nkv,
1012 hd,
1013 pos_offset,
1014 eps,
1015 qk_mode,
1016 )
1017 .is_ok();
1018 if !used_fused_qkv {
1019 B::split_qkv(
1021 ctx,
1022 &self.scratch.qkv_out,
1023 &mut self.scratch.q_buf,
1024 &mut self.scratch.k_buf,
1025 &mut self.scratch.v_buf,
1026 tokens,
1027 q_dim,
1028 kv_dim,
1029 );
1030 B::qk_norm_rope(
1031 ctx,
1032 &self.scratch.q_buf,
1033 q_norm_w,
1034 &self.rope.cos,
1035 &self.rope.sin,
1036 &mut self.scratch.q_head_major,
1037 tokens,
1038 nh,
1039 hd,
1040 pos_offset,
1041 eps,
1042 qk_mode,
1043 );
1044 B::qk_norm_rope(
1045 ctx,
1046 &self.scratch.k_buf,
1047 k_norm_w,
1048 &self.rope.cos,
1049 &self.rope.sin,
1050 &mut self.scratch.k_head_major,
1051 tokens,
1052 nkv,
1053 hd,
1054 pos_offset,
1055 eps,
1056 qk_mode,
1057 );
1058 B::qk_norm_rope(
1059 ctx,
1060 &self.scratch.v_buf,
1061 dummy,
1062 &self.rope.cos,
1063 &self.rope.sin,
1064 &mut self.scratch.v_head_major,
1065 tokens,
1066 nkv,
1067 hd,
1068 pos_offset,
1069 eps,
1070 0,
1071 );
1072 }
1073 B::kv_cache_append_head_major(
1074 ctx,
1075 &mut cache.k,
1076 &mut cache.v,
1077 cache.len,
1078 cache.capacity,
1079 &self.scratch.k_head_major,
1080 &self.scratch.v_head_major,
1081 tokens,
1082 nkv,
1083 hd,
1084 );
1085 }
1086 if let Some(t0) = _t0 {
1087 B::sync(ctx);
1088 QKR_TIME_US.fetch_add(
1089 t0.elapsed().as_micros() as u64,
1090 std::sync::atomic::Ordering::Relaxed,
1091 );
1092 QKR_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1093 }
1094 cache.len += tokens;
1095 let kv_len = cache.len;
1096 let kv_stride = cache.capacity;
1097
1098 let _attn_t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1105 B::sync(ctx);
1106 Some(std::time::Instant::now())
1107 } else {
1108 None
1109 };
1110 if cache.block_size > 0 {
1111 let bt = cache
1112 .block_table
1113 .as_ref()
1114 .expect("paged cache missing block_table");
1115 let cl_buf = cache
1116 .context_lens
1117 .as_mut()
1118 .expect("paged cache missing context_lens");
1119 let num_blocks_per_seq = cache.capacity / cache.block_size;
1120 let (pool_k_ptr, pool_v_ptr) =
1122 paged_pool_ptr.expect("paged_pools must be allocated when block_size > 0");
1123 let pool_k = unsafe { &*pool_k_ptr };
1126 let pool_v = unsafe { &*pool_v_ptr };
1127 let final_kv_len = cache.len as u32;
1133 B::write_u32(ctx, cl_buf, &[final_kv_len]);
1134 B::paged_decode_attention(
1135 ctx,
1136 &self.scratch.q_head_major,
1137 pool_k,
1138 pool_v,
1139 &mut self.scratch.attn_head_major_out,
1140 bt,
1141 cl_buf,
1142 1, nh,
1144 nkv,
1145 hd,
1146 cache.block_size,
1147 num_blocks_per_seq,
1148 tokens, )
1150 .expect("paged_decode_attention");
1151 } else {
1152 let attn_cfg = ferrum_kernels::backend::AttnConfig {
1156 num_heads: nh,
1157 num_kv_heads: nkv,
1158 head_dim: hd,
1159 causal: true,
1160 scale: 1.0 / (hd as f32).sqrt(),
1161 kv_seq_stride: kv_stride,
1162 sliding_window: cfg.sliding_window,
1163 };
1164 B::flash_attention(
1165 ctx,
1166 &self.scratch.q_head_major,
1167 &cache.k,
1168 &cache.v,
1169 &mut self.scratch.attn_head_major_out,
1170 1,
1171 tokens,
1172 kv_len,
1173 pos_offset,
1174 &attn_cfg,
1175 );
1176 }
1177 if let Some(t0) = _attn_t0 {
1178 B::sync(ctx);
1179 ATTN_TIME_US.fetch_add(
1180 t0.elapsed().as_micros() as u64,
1181 std::sync::atomic::Ordering::Relaxed,
1182 );
1183 ATTN_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1184 }
1185
1186 let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1193 B::sync(ctx);
1194 Some(std::time::Instant::now())
1195 } else {
1196 None
1197 };
1198 let attn_token_major = if tokens == 1 {
1199 &self.scratch.attn_head_major_out
1200 } else {
1201 B::transpose_head_to_token(
1202 ctx,
1203 &self.scratch.attn_head_major_out,
1204 &mut self.scratch.attn_flat,
1205 tokens,
1206 nh,
1207 hd,
1208 );
1209 &self.scratch.attn_flat
1210 };
1211 if let Some(t0) = _t0 {
1212 B::sync(ctx);
1213 OTHER_TIME_US.fetch_add(
1214 t0.elapsed().as_micros() as u64,
1215 std::sync::atomic::Ordering::Relaxed,
1216 );
1217 OTHER_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1218 }
1219
1220 let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1222 B::sync(ctx);
1223 Some(std::time::Instant::now())
1224 } else {
1225 None
1226 };
1227 layer
1228 .o_proj
1229 .forward(ctx, attn_token_major, &mut self.scratch.o_proj_out, tokens);
1230 if let Some(t0) = _t0 {
1231 B::sync(ctx);
1232 MATMUL_TIME_US.fetch_add(
1233 t0.elapsed().as_micros() as u64,
1234 std::sync::atomic::Ordering::Relaxed,
1235 );
1236 MATMUL_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1237 }
1238
1239 let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1243 B::sync(ctx);
1244 Some(std::time::Instant::now())
1245 } else {
1246 None
1247 };
1248 B::fused_add_rms_norm(
1249 ctx,
1250 residual,
1251 &self.scratch.o_proj_out,
1252 &layer.post_ln_w,
1253 eps,
1254 &mut self.scratch.norm_out,
1255 tokens,
1256 h,
1257 );
1258 if let Some(t0) = _t0 {
1259 B::sync(ctx);
1260 NORM_TIME_US.fetch_add(
1261 t0.elapsed().as_micros() as u64,
1262 std::sync::atomic::Ordering::Relaxed,
1263 );
1264 NORM_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1265 }
1266
1267 let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1269 B::sync(ctx);
1270 Some(std::time::Instant::now())
1271 } else {
1272 None
1273 };
1274 layer.gate_up_proj.forward(
1275 ctx,
1276 &self.scratch.norm_out,
1277 &mut self.scratch.gate_up_out,
1278 tokens,
1279 );
1280 if let Some(t0) = _t0 {
1281 B::sync(ctx);
1282 MATMUL_TIME_US.fetch_add(
1283 t0.elapsed().as_micros() as u64,
1284 std::sync::atomic::Ordering::Relaxed,
1285 );
1286 MATMUL_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1287 }
1288
1289 let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1291 B::sync(ctx);
1292 Some(std::time::Instant::now())
1293 } else {
1294 None
1295 };
1296 B::fused_silu_mul_split(
1297 ctx,
1298 &self.scratch.gate_up_out,
1299 &mut self.scratch.silu_out,
1300 tokens,
1301 im,
1302 );
1303 if let Some(t0) = _t0 {
1304 B::sync(ctx);
1305 OTHER_TIME_US.fetch_add(
1306 t0.elapsed().as_micros() as u64,
1307 std::sync::atomic::Ordering::Relaxed,
1308 );
1309 OTHER_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1310 }
1311
1312 let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1314 B::sync(ctx);
1315 Some(std::time::Instant::now())
1316 } else {
1317 None
1318 };
1319 layer.down_proj.forward(
1320 ctx,
1321 &self.scratch.silu_out,
1322 &mut self.scratch.mlp_out,
1323 tokens,
1324 );
1325 if let Some(t0) = _t0 {
1326 B::sync(ctx);
1327 MATMUL_TIME_US.fetch_add(
1328 t0.elapsed().as_micros() as u64,
1329 std::sync::atomic::Ordering::Relaxed,
1330 );
1331 MATMUL_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1332 }
1333
1334 let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1336 B::sync(ctx);
1337 Some(std::time::Instant::now())
1338 } else {
1339 None
1340 };
1341 B::add_inplace(ctx, residual, &self.scratch.mlp_out, tokens * h);
1342 if let Some(t0) = _t0 {
1343 B::sync(ctx);
1344 OTHER_TIME_US.fetch_add(
1345 t0.elapsed().as_micros() as u64,
1346 std::sync::atomic::Ordering::Relaxed,
1347 );
1348 OTHER_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1349 }
1350 }
1351
1352 pub fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
1364 let seq_len = tokens.len();
1365 assert!(seq_len > 0, "forward_verify called with empty tokens");
1366 self.ensure_scratch(seq_len);
1367 self.ensure_kv(cache_id);
1368
1369 let h = self.cfg.hidden_size;
1370 let vocab = self.cfg.vocab_size;
1371
1372 let pos_offset = self
1373 .kv_caches
1374 .get(cache_id)
1375 .and_then(|layers| layers.first())
1376 .map(|c| c.len)
1377 .unwrap_or(0);
1378
1379 let mut ctx = B::new_context();
1380 let mut residual = self
1381 .scratch
1382 .residual
1383 .take()
1384 .expect("scratch residual missing (previous call didn't restore)");
1385
1386 let embed = self
1387 .embed
1388 .as_ref()
1389 .expect("forward_verify called on backbone-only model (no embed)");
1390 B::embedding_lookup(&mut ctx, embed, tokens, &mut residual, h);
1391
1392 for li in 0..self.cfg.num_layers {
1393 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
1394 }
1395
1396 B::rms_norm(
1399 &mut ctx,
1400 &residual,
1401 &self.final_norm_w,
1402 self.cfg.rms_norm_eps,
1403 &mut self.scratch.norm_out,
1404 seq_len,
1405 h,
1406 );
1407
1408 let lm_head = self
1412 .lm_head
1413 .as_ref()
1414 .expect("forward_verify called on backbone-only model (no lm_head)");
1415 lm_head.forward(
1416 &mut ctx,
1417 &self.scratch.norm_out,
1418 &mut self.scratch.batch_logits,
1419 seq_len,
1420 );
1421
1422 B::sync(&mut ctx);
1423 self.scratch.residual = Some(residual);
1424 B::to_vec(&self.scratch.batch_logits, seq_len * vocab)
1425 }
1426
1427 pub fn prefill_internal(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
1435 let seq_len = tokens.len();
1436 assert!(seq_len > 0, "prefill called with empty token list");
1437 self.ensure_scratch(seq_len);
1438 self.ensure_kv(cache_id);
1439
1440 let pos_offset = self
1443 .kv_caches
1444 .get(cache_id)
1445 .and_then(|layers| layers.first())
1446 .map(|c| c.len)
1447 .unwrap_or(0);
1448
1449 let h = self.cfg.hidden_size;
1450 let vocab = self.cfg.vocab_size;
1451 let mut ctx = B::new_context();
1452
1453 let mut residual = self
1460 .scratch
1461 .residual
1462 .take()
1463 .expect("scratch residual missing (previous call didn't restore)");
1464 let embed = self
1465 .embed
1466 .as_ref()
1467 .expect("prefill_internal called on backbone-only model (no embed)");
1468 B::embedding_lookup(&mut ctx, embed, tokens, &mut residual, h);
1469
1470 let prefill_profile = std::env::var("FERRUM_PREFILL_OP_PROFILE").is_ok();
1471 let prefill_t0 = if prefill_profile {
1472 B::sync(&mut ctx);
1473 Some(std::time::Instant::now())
1474 } else {
1475 None
1476 };
1477
1478 for li in 0..self.cfg.num_layers {
1479 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
1480 }
1481
1482 if let Some(t0) = prefill_t0 {
1483 B::sync(&mut ctx);
1484 let total_us = t0.elapsed().as_micros() as u64;
1485 let attn_us = ATTN_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1486 let attn_n = ATTN_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1487 let qkr_us = QKR_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1488 let qkr_n = QKR_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1489 let mm_us = MATMUL_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1490 let mm_n = MATMUL_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1491 let norm_us = NORM_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1492 let norm_n = NORM_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1493 let other_us = OTHER_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1494 let other_n = OTHER_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1495 eprintln!(
1496 "[prefill-profile] tokens={} layers total={} ms",
1497 seq_len,
1498 total_us / 1000
1499 );
1500 let bucket = |label: &str, n: u64, us: u64| {
1501 if n > 0 {
1502 eprintln!(
1503 "[prefill-profile] {label}: {} calls {} ms (avg {} us)",
1504 n,
1505 us / 1000,
1506 us / n
1507 );
1508 }
1509 };
1510 bucket("flash_attn", attn_n, attn_us);
1511 bucket("qk_norm_rope", qkr_n, qkr_us);
1512 bucket("matmuls", mm_n, mm_us);
1513 bucket("norms", norm_n, norm_us);
1514 bucket("other", other_n, other_us);
1515 }
1516
1517 B::copy_slice(
1519 &mut ctx,
1520 &residual,
1521 (seq_len - 1) * h,
1522 &mut self.scratch.last_hidden,
1523 0,
1524 h,
1525 );
1526
1527 B::rms_norm(
1529 &mut ctx,
1530 &self.scratch.last_hidden,
1531 &self.final_norm_w,
1532 self.cfg.rms_norm_eps,
1533 &mut self.scratch.last_normed,
1534 1,
1535 h,
1536 );
1537
1538 let lm_head = self
1540 .lm_head
1541 .as_ref()
1542 .expect("prefill_internal called on backbone-only model (no lm_head)");
1543 lm_head.forward(
1544 &mut ctx,
1545 &self.scratch.last_normed,
1546 &mut self.scratch.logits,
1547 1,
1548 );
1549
1550 B::sync(&mut ctx);
1557
1558 self.scratch.residual = Some(residual);
1560
1561 B::to_vec(&self.scratch.logits, vocab)
1562 }
1563
1564 pub fn decode_internal(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
1566 self.ensure_scratch(1);
1567 self.ensure_kv(cache_id);
1568
1569 let h = self.cfg.hidden_size;
1570 let vocab = self.cfg.vocab_size;
1571
1572 let mut ctx = B::new_context();
1575
1576 const GRAPH_WARMUP: usize = 3;
1581 let graph_enabled = std::env::var("FERRUM_CUDA_GRAPH").is_ok();
1582
1583 if graph_enabled {
1584 B::set_decode_state(&mut ctx, token, pos);
1587
1588 match B::replay_last_graph(&mut ctx) {
1590 Ok(true) => {
1591 B::sync(&mut ctx);
1592 return B::to_vec(&self.scratch.logits, vocab);
1593 }
1594 Ok(false) => { }
1595 Err(_) => { }
1596 }
1597 }
1598
1599 let should_capture =
1600 graph_enabled && !self.graph_capture_failed && self.graph_warmup >= GRAPH_WARMUP;
1601
1602 if should_capture {
1603 B::set_dev_state_mode(&mut ctx, true);
1604 if B::begin_graph_capture(&mut ctx).is_err() {
1605 self.graph_capture_failed = true;
1606 B::set_dev_state_mode(&mut ctx, false);
1607 }
1608 }
1609
1610 let mut residual = self
1616 .scratch
1617 .residual
1618 .take()
1619 .expect("scratch residual missing (previous call didn't restore)");
1620 let embed = self
1621 .embed
1622 .as_ref()
1623 .expect("decode_internal called on backbone-only model (no embed)");
1624 B::embedding_lookup(&mut ctx, embed, &[token], &mut residual, h);
1625
1626 let layer_profile = std::env::var("FERRUM_DECODE_LAYER_PROFILE").is_ok();
1630 let mut layer_times = if layer_profile {
1631 Some(Vec::with_capacity(self.cfg.num_layers))
1632 } else {
1633 None
1634 };
1635
1636 for li in 0..self.cfg.num_layers {
1637 if layer_profile {
1638 let t0 = std::time::Instant::now();
1639 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1640 B::sync(&mut ctx);
1641 let elapsed_us = t0.elapsed().as_micros() as u64;
1642 if let Some(v) = layer_times.as_mut() {
1643 v.push(elapsed_us);
1644 }
1645 } else {
1646 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1647 }
1648 }
1649 if let Some(times) = layer_times.take() {
1650 let sum: u64 = times.iter().sum();
1651 let avg = sum / times.len() as u64;
1652 let mn = *times.iter().min().unwrap_or(&0);
1653 let mx = *times.iter().max().unwrap_or(&0);
1654 eprintln!(
1655 "[layer-profile] {} layers total={} ms avg={} us min={} us max={} us",
1656 times.len(),
1657 sum / 1000,
1658 avg,
1659 mn,
1660 mx
1661 );
1662 for (i, t) in times.iter().enumerate() {
1663 eprint!("L{i}={}ms ", t / 1000);
1664 if (i + 1) % 6 == 0 {
1665 eprintln!();
1666 }
1667 }
1668 eprintln!();
1669 let attn_us = ATTN_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1670 let attn_n = ATTN_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1671 let qkr_us = QKR_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1672 let qkr_n = QKR_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1673 let mm_us = MATMUL_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1674 let mm_n = MATMUL_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1675 let norm_us = NORM_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1676 let norm_n = NORM_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1677 let other_us = OTHER_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1678 let other_n = OTHER_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1679 eprintln!(
1680 "[op-profile] flash_attn: {} calls {} ms (avg {} us)",
1681 attn_n,
1682 attn_us / 1000,
1683 if attn_n > 0 { attn_us / attn_n } else { 0 }
1684 );
1685 eprintln!(
1686 "[op-profile] qk_norm_rope: {} calls {} ms (avg {} us)",
1687 qkr_n,
1688 qkr_us / 1000,
1689 if qkr_n > 0 { qkr_us / qkr_n } else { 0 }
1690 );
1691 eprintln!(
1692 "[op-profile] matmuls (Linear::forward): {} calls {} ms (avg {} us)",
1693 mm_n,
1694 mm_us / 1000,
1695 if mm_n > 0 { mm_us / mm_n } else { 0 }
1696 );
1697 eprintln!(
1698 "[op-profile] norms (rms+fused_add_rms): {} calls {} ms (avg {} us)",
1699 norm_n,
1700 norm_us / 1000,
1701 if norm_n > 0 { norm_us / norm_n } else { 0 }
1702 );
1703 eprintln!(
1704 "[op-profile] other (split_qkv, kv_append, transpose, silu, add): {} calls {} ms (avg {} us)",
1705 other_n, other_us / 1000, if other_n > 0 { other_us / other_n } else { 0 }
1706 );
1707 }
1708
1709 B::rms_norm(
1710 &mut ctx,
1711 &residual,
1712 &self.final_norm_w,
1713 self.cfg.rms_norm_eps,
1714 &mut self.scratch.last_normed,
1715 1,
1716 h,
1717 );
1718
1719 let lm_head = self
1720 .lm_head
1721 .as_ref()
1722 .expect("decode_internal called on backbone-only model (no lm_head)");
1723 lm_head.forward(
1724 &mut ctx,
1725 &self.scratch.last_normed,
1726 &mut self.scratch.logits,
1727 1,
1728 );
1729
1730 if should_capture && !self.graph_capture_failed {
1731 if B::end_graph_capture(&mut ctx).is_err() {
1732 self.graph_capture_failed = true;
1733 } else {
1734 if B::replay_last_graph(&mut ctx).is_err() {
1741 self.graph_capture_failed = true;
1742 }
1743 }
1744 B::set_dev_state_mode(&mut ctx, false);
1745 } else {
1746 self.graph_warmup += 1;
1747 }
1748
1749 B::sync(&mut ctx);
1756 self.scratch.residual = Some(residual);
1757
1758 B::to_vec(&self.scratch.logits, vocab)
1759 }
1760
1761 pub fn prefill_from_embeds(
1770 &mut self,
1771 cache_id: &str,
1772 embeds: &[f32],
1773 seq_len: usize,
1774 ) -> Vec<f32> {
1775 let h = self.cfg.hidden_size;
1776 assert_eq!(
1777 embeds.len(),
1778 seq_len * h,
1779 "embeds length {} != seq_len * hidden_size {}",
1780 embeds.len(),
1781 seq_len * h
1782 );
1783 assert!(seq_len > 0, "prefill_from_embeds called with zero length");
1784
1785 self.ensure_scratch(seq_len);
1786 self.ensure_kv(cache_id);
1787
1788 let mut ctx = B::new_context();
1789 let mut residual = self
1790 .scratch
1791 .residual
1792 .take()
1793 .expect("scratch residual missing (previous call didn't restore)");
1794
1795 let embed_buf = B::from_slice(embeds);
1797 B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, seq_len * h);
1798
1799 for li in 0..self.cfg.num_layers {
1800 self.forward_layer(&mut ctx, li, cache_id, &mut residual, 0, seq_len);
1801 }
1802
1803 B::copy_slice(
1804 &mut ctx,
1805 &residual,
1806 (seq_len - 1) * h,
1807 &mut self.scratch.last_hidden,
1808 0,
1809 h,
1810 );
1811 B::sync(&mut ctx);
1812 self.scratch.residual = Some(residual);
1813 B::to_vec(&self.scratch.last_hidden, h)
1814 }
1815
1816 pub fn decode_from_embed(&mut self, cache_id: &str, embed: &[f32], pos: u32) -> Vec<f32> {
1820 let h = self.cfg.hidden_size;
1821 assert_eq!(
1822 embed.len(),
1823 h,
1824 "embed length {} != hidden_size {}",
1825 embed.len(),
1826 h
1827 );
1828
1829 self.ensure_scratch(1);
1830 self.ensure_kv(cache_id);
1831
1832 let mut ctx = B::new_context();
1833 let mut residual = self
1834 .scratch
1835 .residual
1836 .take()
1837 .expect("scratch residual missing (previous call didn't restore)");
1838
1839 let embed_buf = B::from_slice(embed);
1840 B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, h);
1841
1842 for li in 0..self.cfg.num_layers {
1843 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1844 }
1845
1846 B::copy_slice(&mut ctx, &residual, 0, &mut self.scratch.last_hidden, 0, h);
1847 B::sync(&mut ctx);
1848 self.scratch.residual = Some(residual);
1849 B::to_vec(&self.scratch.last_hidden, h)
1850 }
1851
1852 pub fn prefill_all_post_norm(
1863 &mut self,
1864 cache_id: &str,
1865 embeds: &[f32],
1866 seq_len: usize,
1867 pos_offset: usize,
1868 ) -> Vec<f32> {
1869 let h = self.cfg.hidden_size;
1870 assert_eq!(
1871 embeds.len(),
1872 seq_len * h,
1873 "embeds length {} != seq_len * hidden_size {}",
1874 embeds.len(),
1875 seq_len * h
1876 );
1877 assert!(seq_len > 0);
1878
1879 self.ensure_scratch(seq_len);
1880 self.ensure_kv(cache_id);
1881
1882 let mut ctx = B::new_context();
1883 let mut residual = self
1884 .scratch
1885 .residual
1886 .take()
1887 .expect("scratch residual missing (previous call didn't restore)");
1888
1889 let embed_buf = B::from_slice(embeds);
1890 B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, seq_len * h);
1891
1892 for li in 0..self.cfg.num_layers {
1893 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
1894 }
1895
1896 B::rms_norm(
1898 &mut ctx,
1899 &residual,
1900 &self.final_norm_w,
1901 self.cfg.rms_norm_eps,
1902 &mut self.scratch.norm_out,
1903 seq_len,
1904 h,
1905 );
1906 B::sync(&mut ctx);
1907 self.scratch.residual = Some(residual);
1908 B::to_vec(&self.scratch.norm_out, seq_len * h)
1909 }
1910
1911 pub fn decode_post_norm_from_embed(
1915 &mut self,
1916 cache_id: &str,
1917 embed: &[f32],
1918 pos: u32,
1919 ) -> Vec<f32> {
1920 let h = self.cfg.hidden_size;
1921 assert_eq!(embed.len(), h);
1922
1923 self.ensure_scratch(1);
1924 self.ensure_kv(cache_id);
1925
1926 let mut ctx = B::new_context();
1927 let mut residual = self
1928 .scratch
1929 .residual
1930 .take()
1931 .expect("scratch residual missing (previous call didn't restore)");
1932
1933 let embed_buf = B::from_slice(embed);
1934 B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, h);
1935
1936 for li in 0..self.cfg.num_layers {
1937 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1938 }
1939
1940 B::rms_norm(
1941 &mut ctx,
1942 &residual,
1943 &self.final_norm_w,
1944 self.cfg.rms_norm_eps,
1945 &mut self.scratch.last_normed,
1946 1,
1947 h,
1948 );
1949 B::sync(&mut ctx);
1950 self.scratch.residual = Some(residual);
1951 B::to_vec(&self.scratch.last_normed, h)
1952 }
1953
1954 pub fn decode_batch_internal(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
1962 let m = batch.len();
1963 if m == 0 {
1964 return Vec::new();
1965 }
1966 if m == 1 {
1967 let (cid, tok, pos) = &batch[0];
1968 return vec![self.decode_internal(cid, *tok, *pos)];
1969 }
1970
1971 for (cid, _, _) in batch {
1973 self.ensure_kv(cid);
1974 }
1975 self.ensure_scratch(m);
1976 let h = self.cfg.hidden_size;
1982 let vocab = self.cfg.vocab_size;
1983 let mut ctx = B::new_context();
1984
1985 let tokens: Vec<u32> = batch.iter().map(|(_, t, _)| *t).collect();
1987 let mut residual = self
1988 .scratch
1989 .residual
1990 .take()
1991 .expect("scratch residual missing (previous call didn't restore)");
1992 let embed = self
1993 .embed
1994 .as_ref()
1995 .expect("decode_batch_internal called on backbone-only model (no embed)");
1996 B::embedding_lookup(&mut ctx, embed, &tokens, &mut residual, h);
1997
1998 for li in 0..self.cfg.num_layers {
2000 self.forward_layer_batched_decode(&mut ctx, li, batch, &mut residual, m);
2001 }
2002
2003 B::rms_norm(
2005 &mut ctx,
2006 &residual,
2007 &self.final_norm_w,
2008 self.cfg.rms_norm_eps,
2009 &mut self.scratch.norm_out,
2010 m,
2011 h,
2012 );
2013
2014 let lm_head = self
2016 .lm_head
2017 .as_ref()
2018 .expect("decode_batch_internal called on backbone-only model (no lm_head)");
2019 lm_head.forward(
2020 &mut ctx,
2021 &self.scratch.norm_out,
2022 &mut self.scratch.batch_logits,
2023 m,
2024 );
2025
2026 B::sync(&mut ctx);
2028 self.scratch.residual = Some(residual);
2029
2030 let all = B::to_vec(&self.scratch.batch_logits, m * vocab);
2032 (0..m)
2033 .map(|i| all[i * vocab..(i + 1) * vocab].to_vec())
2034 .collect()
2035 }
2036
2037 fn forward_layer_batched_decode(
2039 &mut self,
2040 ctx: &mut B::Context,
2041 li: usize,
2042 batch: &[(String, u32, u32)],
2043 residual: &mut B::Buffer,
2044 m: usize,
2045 ) {
2046 let cfg = &self.cfg;
2047 let h = cfg.hidden_size;
2048 let nh = cfg.num_heads;
2049 let nkv = cfg.num_kv_heads;
2050 let hd = cfg.head_dim;
2051 let im = cfg.intermediate_size;
2052 let eps = cfg.rms_norm_eps;
2053 let q_dim = nh * hd;
2054 let kv_dim = nkv * hd;
2055
2056 let layer = &self.layers[li];
2057 let qk_mode: i32 = if cfg.has_qk_norm { 1 } else { 2 };
2058 let dummy_w = &layer.input_ln_w;
2059 let q_norm_w = layer.q_norm_w.as_ref().unwrap_or(dummy_w);
2060 let k_norm_w = layer.k_norm_w.as_ref().unwrap_or(dummy_w);
2061
2062 B::rms_norm(
2064 ctx,
2065 residual,
2066 &layer.input_ln_w,
2067 eps,
2068 &mut self.scratch.norm_out,
2069 m,
2070 h,
2071 );
2072
2073 layer
2075 .qkv_proj
2076 .forward(ctx, &self.scratch.norm_out, &mut self.scratch.qkv_out, m);
2077
2078 if let Some(pools) = self.paged_pools.as_mut() {
2098 let pool_ptr = (
2099 &mut pools[li].0 as *mut B::Buffer,
2100 &mut pools[li].1 as *mut B::Buffer,
2101 );
2102 let (pool_k, pool_v) = unsafe { (&mut *pool_ptr.0, &mut *pool_ptr.1) };
2104
2105 let qkv_stride = q_dim + 2 * kv_dim;
2106 let max_blocks_per_seq = self.scratch.paged_max_blocks_per_seq;
2107 let block_size = 16; let mut item_state: Vec<(u32, Vec<u32>)> = Vec::with_capacity(m);
2114 for (cache_id, _, _) in batch.iter() {
2115 let caches = self
2116 .kv_caches
2117 .get(cache_id)
2118 .expect("ensure_kv must be called before forward_layer_batched");
2119 let cache = &caches[li];
2120 item_state.push((cache.len as u32, cache.paged_block_indices.clone()));
2121 }
2122
2123 let q_head_major_size_bytes = (q_dim * std::mem::size_of::<f32>()) as u64;
2127 let qkv_stride_bytes = (qkv_stride * std::mem::size_of::<f32>()) as u64;
2128 for (i, (cache_id, _, pos)) in batch.iter().enumerate() {
2129 let pos_i = *pos as usize;
2130 let caches = self
2131 .kv_caches
2132 .get(cache_id)
2133 .expect("paged batched: cache not present");
2134 let cache = &caches[li];
2135 let bt = cache
2136 .block_table
2137 .as_ref()
2138 .expect("paged batched: block_table missing");
2139 let cache_len_before = cache.len;
2140 let block_table_ref = bt as *const B::Buffer;
2141 let bt_safe: &B::Buffer = unsafe { &*block_table_ref };
2145 B::split_qkv_norm_rope_into_paged_cache(
2146 ctx,
2147 &self.scratch.qkv_out,
2148 (i as u64) * qkv_stride_bytes,
2149 q_norm_w,
2150 k_norm_w,
2151 &self.rope.cos,
2152 &self.rope.sin,
2153 self.scratch
2154 .paged_batch_q
2155 .as_mut()
2156 .expect("paged_batch_q missing"),
2157 (i as u64) * q_head_major_size_bytes,
2158 pool_k,
2159 pool_v,
2160 bt_safe,
2161 1,
2162 nh,
2163 nkv,
2164 hd,
2165 pos_i,
2166 eps,
2167 qk_mode,
2168 cache_len_before,
2169 block_size,
2170 max_blocks_per_seq,
2171 )
2172 .expect("paged batched write");
2173 }
2174
2175 let mut stacked_bt: Vec<u32> = vec![0u32; m * max_blocks_per_seq];
2178 let mut stacked_cl: Vec<u32> = vec![0u32; m];
2179 for (i, (cache_id, _, _)) in batch.iter().enumerate() {
2180 let caches = self
2181 .kv_caches
2182 .get_mut(cache_id)
2183 .expect("paged batched: cache not present");
2184 let cache = &mut caches[li];
2185 cache.len += 1;
2186 let len = cache.len as u32;
2187 stacked_cl[i] = len;
2188 let blocks = &cache.paged_block_indices;
2189 let n_to_copy = blocks.len().min(max_blocks_per_seq);
2190 stacked_bt[i * max_blocks_per_seq..i * max_blocks_per_seq + n_to_copy]
2191 .copy_from_slice(&blocks[..n_to_copy]);
2192 }
2193 let bt_buf = self
2194 .scratch
2195 .paged_batch_block_tables
2196 .as_mut()
2197 .expect("paged_batch_block_tables missing");
2198 B::write_u32(ctx, bt_buf, &stacked_bt);
2199 let cl_buf = self
2200 .scratch
2201 .paged_batch_context_lens
2202 .as_mut()
2203 .expect("paged_batch_context_lens missing");
2204 B::write_u32(ctx, cl_buf, &stacked_cl);
2205
2206 let bt_ptr =
2208 self.scratch.paged_batch_block_tables.as_ref().unwrap() as *const B::Buffer;
2209 let cl_ptr =
2210 self.scratch.paged_batch_context_lens.as_ref().unwrap() as *const B::Buffer;
2211 let q_ptr = self.scratch.paged_batch_q.as_ref().unwrap() as *const B::Buffer;
2212 let o_ptr = self.scratch.paged_batch_o.as_mut().unwrap() as *mut B::Buffer;
2213 let bt_safe = unsafe { &*bt_ptr };
2216 let cl_safe = unsafe { &*cl_ptr };
2217 let q_safe = unsafe { &*q_ptr };
2218 let o_safe = unsafe { &mut *o_ptr };
2219 B::paged_decode_attention(
2220 ctx,
2221 q_safe,
2222 pool_k,
2223 pool_v,
2224 o_safe,
2225 bt_safe,
2226 cl_safe,
2227 m,
2228 nh,
2229 nkv,
2230 hd,
2231 block_size,
2232 max_blocks_per_seq,
2233 1, )
2235 .expect("paged batched decode");
2236
2237 for i in 0..m {
2241 B::copy_slice(
2242 ctx,
2243 self.scratch.paged_batch_o.as_ref().unwrap(),
2244 i * q_dim,
2245 &mut self.scratch.attn_flat,
2246 i * q_dim,
2247 q_dim,
2248 );
2249 }
2250
2251 return self.forward_layer_batched_decode_post_attn(ctx, li, residual, m);
2253 }
2254
2255 B::split_qkv(
2257 ctx,
2258 &self.scratch.qkv_out,
2259 &mut self.scratch.q_buf,
2260 &mut self.scratch.k_buf,
2261 &mut self.scratch.v_buf,
2262 m,
2263 q_dim,
2264 kv_dim,
2265 );
2266
2267 for (i, (cache_id, _token, pos)) in batch.iter().enumerate() {
2270 let pos_i = *pos as usize;
2271
2272 B::copy_slice(
2274 ctx,
2275 &self.scratch.q_buf,
2276 i * q_dim,
2277 &mut self.scratch.q_single,
2278 0,
2279 q_dim,
2280 );
2281 B::copy_slice(
2282 ctx,
2283 &self.scratch.k_buf,
2284 i * kv_dim,
2285 &mut self.scratch.k_single,
2286 0,
2287 kv_dim,
2288 );
2289 B::copy_slice(
2290 ctx,
2291 &self.scratch.v_buf,
2292 i * kv_dim,
2293 &mut self.scratch.v_single,
2294 0,
2295 kv_dim,
2296 );
2297
2298 B::qk_norm_rope(
2300 ctx,
2301 &self.scratch.q_single,
2302 q_norm_w,
2303 &self.rope.cos,
2304 &self.rope.sin,
2305 &mut self.scratch.q_head_major_single,
2306 1,
2307 nh,
2308 hd,
2309 pos_i,
2310 eps,
2311 qk_mode,
2312 );
2313 B::qk_norm_rope(
2314 ctx,
2315 &self.scratch.k_single,
2316 k_norm_w,
2317 &self.rope.cos,
2318 &self.rope.sin,
2319 &mut self.scratch.k_head_major_single,
2320 1,
2321 nkv,
2322 hd,
2323 pos_i,
2324 eps,
2325 qk_mode,
2326 );
2327 B::qk_norm_rope(
2328 ctx,
2329 &self.scratch.v_single,
2330 dummy_w,
2331 &self.rope.cos,
2332 &self.rope.sin,
2333 &mut self.scratch.v_head_major_single,
2334 1,
2335 nkv,
2336 hd,
2337 pos_i,
2338 eps,
2339 0,
2340 );
2341
2342 let caches = self
2344 .kv_caches
2345 .get_mut(cache_id)
2346 .expect("ensure_kv must be called before forward_layer_batched");
2347 let cache = &mut caches[li];
2348 B::kv_cache_append_head_major(
2349 ctx,
2350 &mut cache.k,
2351 &mut cache.v,
2352 cache.len,
2353 cache.capacity,
2354 &self.scratch.k_head_major_single,
2355 &self.scratch.v_head_major_single,
2356 1,
2357 nkv,
2358 hd,
2359 );
2360 cache.len += 1;
2361 let kv_len = cache.len;
2362 let kv_stride = cache.capacity;
2363
2364 let attn_cfg = ferrum_kernels::backend::AttnConfig {
2365 num_heads: nh,
2366 num_kv_heads: nkv,
2367 head_dim: hd,
2368 causal: true,
2369 scale: 1.0 / (hd as f32).sqrt(),
2370 kv_seq_stride: kv_stride,
2371 sliding_window: cfg.sliding_window,
2372 };
2373 B::flash_attention(
2374 ctx,
2375 &self.scratch.q_head_major_single,
2376 &cache.k,
2377 &cache.v,
2378 &mut self.scratch.attn_head_major_single,
2379 1,
2380 1,
2381 kv_len,
2382 pos_i,
2383 &attn_cfg,
2384 );
2385
2386 B::copy_slice(
2394 ctx,
2395 &self.scratch.attn_head_major_single,
2396 0,
2397 &mut self.scratch.attn_flat,
2398 i * q_dim,
2399 q_dim,
2400 );
2401 }
2402
2403 self.forward_layer_batched_decode_post_attn(ctx, li, residual, m);
2404 }
2405
2406 fn forward_layer_batched_decode_post_attn(
2407 &mut self,
2408 ctx: &mut B::Context,
2409 li: usize,
2410 residual: &mut B::Buffer,
2411 m: usize,
2412 ) {
2413 let cfg = &self.cfg;
2414 let h = cfg.hidden_size;
2415 let im = cfg.intermediate_size;
2416 let eps = cfg.rms_norm_eps;
2417 let layer = &self.layers[li];
2418
2419 layer.o_proj.forward(
2421 ctx,
2422 &self.scratch.attn_flat,
2423 &mut self.scratch.o_proj_out,
2424 m,
2425 );
2426
2427 B::fused_add_rms_norm(
2429 ctx,
2430 residual,
2431 &self.scratch.o_proj_out,
2432 &layer.post_ln_w,
2433 eps,
2434 &mut self.scratch.norm_out,
2435 m,
2436 h,
2437 );
2438
2439 layer.gate_up_proj.forward(
2441 ctx,
2442 &self.scratch.norm_out,
2443 &mut self.scratch.gate_up_out,
2444 m,
2445 );
2446
2447 B::fused_silu_mul_split(
2449 ctx,
2450 &self.scratch.gate_up_out,
2451 &mut self.scratch.silu_out,
2452 m,
2453 im,
2454 );
2455
2456 layer
2458 .down_proj
2459 .forward(ctx, &self.scratch.silu_out, &mut self.scratch.mlp_out, m);
2460
2461 B::add_inplace(ctx, residual, &self.scratch.mlp_out, m * h);
2463 }
2464}
2465
2466impl<B: Backend> DecoderOnlyLLM for LlamaFamilyModel<B> {
2467 fn config(&self) -> &LlmRuntimeConfig {
2468 &self.runtime_cfg
2469 }
2470
2471 fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
2472 self.ensure_scratch(max_tokens);
2477 self.ensure_kv(cache_id);
2478
2479 const WARMUP_CACHE: &str = "__ferrum_warmup__";
2480 let _ = self.prefill_internal(WARMUP_CACHE, &[0u32]);
2481 if let Some(mut caches) = self.kv_caches.remove(WARMUP_CACHE) {
2485 if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2486 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2487 if let Some(c0) = caches.first() {
2488 if !c0.paged_block_indices.is_empty() {
2489 alloc.free(&c0.paged_block_indices);
2490 }
2491 }
2492 for c in caches.iter_mut() {
2493 c.paged_block_indices.clear();
2494 }
2495 }
2496 self.kv_free_pool.push(caches);
2497 }
2498 }
2499
2500 fn kv_capacity(&self) -> usize {
2501 let model_max = self.cfg.max_seq_len;
2503 const DEFAULT_KV_CAPACITY: usize = 512;
2504 std::env::var("FERRUM_KV_CAPACITY")
2505 .ok()
2506 .and_then(|s| s.parse::<usize>().ok())
2507 .map(|cap| cap.min(model_max))
2508 .unwrap_or_else(|| model_max.min(DEFAULT_KV_CAPACITY))
2509 }
2510
2511 fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2512 self.prefill_internal(cache_id, tokens)
2513 }
2514
2515 fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
2516 self.decode_internal(cache_id, token, pos)
2517 }
2518
2519 fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
2520 self.decode_batch_internal(batch)
2521 }
2522
2523 fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2524 LlamaFamilyModel::<B>::forward_verify(self, cache_id, tokens)
2526 }
2527
2528 fn truncate_kv(&mut self, cache_id: &str, new_len: usize) {
2529 if let Some(caches) = self.kv_caches.get_mut(cache_id) {
2530 for c in caches.iter_mut() {
2531 if new_len < c.len {
2532 c.len = new_len;
2533 }
2534 }
2535 }
2536 let mut ctx = B::new_context();
2538 B::reset_graph(&mut ctx);
2539 self.graph_warmup = 0;
2540 self.graph_capture_failed = false;
2541 }
2542
2543 fn release(&mut self, cache_id: &str) {
2544 let mut ctx = B::new_context();
2550 B::sync(&mut ctx);
2551 B::reset_graph(&mut ctx);
2552 B::sync(&mut ctx);
2553 self.graph_warmup = 0;
2554 self.graph_capture_failed = false;
2555
2556 if let Some(mut caches) = self.kv_caches.remove(cache_id) {
2561 if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2562 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2563 if let Some(c0) = caches.first() {
2567 if !c0.paged_block_indices.is_empty() {
2568 alloc.free(&c0.paged_block_indices);
2569 }
2570 }
2571 for c in caches.iter_mut() {
2574 c.paged_block_indices.clear();
2575 }
2576 }
2577 self.kv_free_pool.push(caches);
2578 }
2579 }
2580
2581 fn reset(&mut self) {
2582 let mut ctx = B::new_context();
2584 B::sync(&mut ctx);
2585 B::reset_graph(&mut ctx);
2586 B::sync(&mut ctx);
2587 self.graph_warmup = 0;
2588 self.graph_capture_failed = false;
2589 self.kv_caches.clear();
2590 self.kv_free_pool.clear();
2591 }
2592}
2593
2594fn build_rope_cache<B: Backend>(cfg: &LlamaFamilyConfig) -> RopeCache<B> {
2595 let hd = cfg.head_dim;
2596 let half = hd / 2;
2597 let max = cfg.max_seq_len;
2598 let mut cos = vec![0.0f32; max * half];
2599 let mut sin = vec![0.0f32; max * half];
2600 for pos in 0..max {
2601 for i in 0..half {
2602 let freq = 1.0f64 / cfg.rope_theta.powf((2 * i) as f64 / hd as f64);
2603 let angle = pos as f64 * freq;
2604 cos[pos * half + i] = angle.cos() as f32;
2605 sin[pos * half + i] = angle.sin() as f32;
2606 }
2607 }
2608 RopeCache {
2609 cos: B::from_slice(&cos),
2610 sin: B::from_slice(&sin),
2611 }
2612}