1use std::collections::HashMap;
20use std::sync::atomic::AtomicU64;
21
22use ferrum_kernels::backend::{Backend, KvCache};
23
// Per-op profiling accumulators: total microseconds and call count per bucket.
// `forward_layer` adds to them only when FERRUM_DECODE_OP_PROFILE is set (each op is
// bracketed by `B::sync` calls); `prefill_internal` swaps them back to 0 and prints
// them when FERRUM_PREFILL_OP_PROFILE is set.
//
// Attention kernels (paged decode attention or flash attention).
static ATTN_TIME_US: AtomicU64 = AtomicU64::new(0);
static ATTN_CALLS: AtomicU64 = AtomicU64::new(0);
// QKV split + QK-norm/RoPE + KV-cache write.
static QKR_TIME_US: AtomicU64 = AtomicU64::new(0);
static QKR_CALLS: AtomicU64 = AtomicU64::new(0);
// Linear projections (QKV, O, gate/up, down).
static MATMUL_TIME_US: AtomicU64 = AtomicU64::new(0);
static MATMUL_CALLS: AtomicU64 = AtomicU64::new(0);
// RMSNorms (input norm and the fused residual-add + post-attention norm).
static NORM_TIME_US: AtomicU64 = AtomicU64::new(0);
static NORM_CALLS: AtomicU64 = AtomicU64::new(0);
// Everything else: head-to-token transpose, fused SiLU-mul, residual add.
static OTHER_TIME_US: AtomicU64 = AtomicU64::new(0);
static OTHER_CALLS: AtomicU64 = AtomicU64::new(0);
34use ferrum_quantization::{Linear, WeightLoader};
35use ferrum_types::Result;
36
37use crate::common::{DecoderOnlyLLM, LlmRuntimeConfig};
38
/// Shape and hyper-parameters for the Llama-style decoder families handled in
/// this file (see the `*_from_def` constructors: Qwen3, Llama, Qwen2, Mistral).
#[derive(Clone, Debug, PartialEq)]
pub struct LlamaFamilyConfig {
    /// Width of the residual stream.
    pub hidden_size: usize,
    /// Width of the MLP hidden layer; `gate_up_proj` produces `2 * intermediate_size`.
    pub intermediate_size: usize,
    /// Number of query heads.
    pub num_heads: usize,
    /// Number of key/value heads (defaults to `num_heads`; may be smaller).
    pub num_kv_heads: usize,
    /// Per-head dimension.
    pub head_dim: usize,
    /// Number of decoder layers.
    pub num_layers: usize,
    /// Vocabulary size (LM-head output width).
    pub vocab_size: usize,
    /// Maximum positions (`max_position_embeddings`); upper bound for KV capacity.
    pub max_seq_len: usize,
    /// Epsilon used by every RMSNorm.
    pub rms_norm_eps: f32,
    /// RoPE base frequency.
    pub rope_theta: f64,
    /// Whether layers load `q_norm`/`k_norm` weights and run with QK-norm
    /// (`qk_mode` 1 instead of 2 in `forward_layer`).
    pub has_qk_norm: bool,
    /// Attention window passed to `flash_attention`; 0 when the checkpoint
    /// gives none (NOTE(review): presumably 0 means full attention — confirm
    /// against the backend kernel).
    pub sliding_window: usize,
}
60
61impl LlamaFamilyConfig {
62 pub fn to_runtime(&self) -> LlmRuntimeConfig {
63 LlmRuntimeConfig {
64 hidden_size: self.hidden_size,
65 num_layers: self.num_layers,
66 num_kv_heads: self.num_kv_heads,
67 head_dim: self.head_dim,
68 vocab_size: self.vocab_size,
69 max_seq_len: self.max_seq_len,
70 }
71 }
72
73 fn from_def_base(def: &crate::definition::ModelDefinition) -> LlamaFamilyConfigBase {
77 let num_kv_heads = def.num_key_value_heads.unwrap_or(def.num_attention_heads);
78 let head_dim = def
79 .extra_params
80 .get("head_dim")
81 .and_then(|v| v.as_u64())
82 .map(|v| v as usize)
83 .unwrap_or(def.hidden_size / def.num_attention_heads);
84 let sliding_window = def
87 .extra_params
88 .get("sliding_window")
89 .and_then(|v| v.as_u64())
90 .map(|v| v as usize)
91 .unwrap_or(0);
92
93 LlamaFamilyConfigBase {
94 hidden_size: def.hidden_size,
95 intermediate_size: def.intermediate_size,
96 num_heads: def.num_attention_heads,
97 num_kv_heads,
98 head_dim,
99 num_layers: def.num_hidden_layers,
100 vocab_size: def.vocab_size,
101 max_seq_len: def.max_position_embeddings,
102 rms_norm_eps: def.norm_eps as f32,
103 rope_theta_opt: def.rope_theta,
104 sliding_window,
105 }
106 }
107
108 fn from_base(b: LlamaFamilyConfigBase, rope_default: f64, has_qk_norm: bool) -> Self {
109 Self {
110 hidden_size: b.hidden_size,
111 intermediate_size: b.intermediate_size,
112 num_heads: b.num_heads,
113 num_kv_heads: b.num_kv_heads,
114 head_dim: b.head_dim,
115 num_layers: b.num_layers,
116 vocab_size: b.vocab_size,
117 max_seq_len: b.max_seq_len,
118 rms_norm_eps: b.rms_norm_eps,
119 rope_theta: b.rope_theta_opt.unwrap_or(rope_default),
120 has_qk_norm,
121 sliding_window: b.sliding_window,
122 }
123 }
124
125 pub fn qwen3_from_def(def: &crate::definition::ModelDefinition) -> Self {
127 Self::from_base(Self::from_def_base(def), 1_000_000.0, true)
128 }
129
130 pub fn llama_from_def(def: &crate::definition::ModelDefinition) -> Self {
134 Self::from_base(Self::from_def_base(def), 500_000.0, false)
135 }
136
137 pub fn qwen2_from_def(def: &crate::definition::ModelDefinition) -> Self {
139 Self::from_base(Self::from_def_base(def), 1_000_000.0, false)
140 }
141
142 pub fn mistral_from_def(def: &crate::definition::ModelDefinition) -> Self {
146 Self::from_base(Self::from_def_base(def), 10_000.0, false)
147 }
148}
149
/// Family-independent fields parsed from a `ModelDefinition` by
/// `LlamaFamilyConfig::from_def_base`, before `from_base` applies the
/// per-family defaults (RoPE base, QK-norm flag).
struct LlamaFamilyConfigBase {
    hidden_size: usize,
    intermediate_size: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    num_layers: usize,
    vocab_size: usize,
    max_seq_len: usize,
    rms_norm_eps: f32,
    // RoPE base from the definition, if it specified one; `from_base`
    // substitutes the family default otherwise.
    rope_theta_opt: Option<f64>,
    sliding_window: usize,
}
163
/// Weights of one decoder layer, loaded from `model.layers.{i}.*`.
pub struct LlamaFamilyLayer<B: Backend> {
    /// Pre-attention RMSNorm weight (`input_layernorm`).
    pub input_ln_w: B::Buffer,
    /// Fused Q/K/V projection; its output is `q_dim + 2 * kv_dim` wide per token.
    pub qkv_proj: Box<dyn Linear<B>>,
    /// Q-norm weight (`self_attn.q_norm`); loaded only when `has_qk_norm`.
    pub q_norm_w: Option<B::Buffer>,
    /// K-norm weight (`self_attn.k_norm`); loaded only when `has_qk_norm`.
    pub k_norm_w: Option<B::Buffer>,
    /// Attention output projection.
    pub o_proj: Box<dyn Linear<B>>,
    /// Post-attention RMSNorm weight, applied (fused with the residual add) before the MLP.
    pub post_ln_w: B::Buffer,
    /// Fused gate/up MLP projection; output is `2 * intermediate_size` wide per token.
    pub gate_up_proj: Box<dyn Linear<B>>,
    /// MLP down projection back to `hidden_size`.
    pub down_proj: Box<dyn Linear<B>>,
}
177
/// Rotary-position-embedding tables built by `build_rope_cache` and passed to
/// the QK-norm/RoPE kernels in `forward_layer`.
/// NOTE(review): table layout (positions × rotary dims) is not visible here — confirm.
pub struct RopeCache<B: Backend> {
    /// Cosine table.
    pub cos: B::Buffer,
    /// Sine table.
    pub sin: B::Buffer,
}
183
/// Reusable activation buffers for one model instance.
///
/// `LlamaFamilyScratch::alloc` sizes them for up to `max_tokens` tokens per
/// forward call; `LlamaFamilyModel::ensure_scratch` reallocates when a call needs
/// more. Element counts below are what `alloc` requests
/// (`q_dim = num_heads * head_dim`, `kv_dim = num_kv_heads * head_dim`).
pub struct LlamaFamilyScratch<B: Backend> {
    /// Residual stream, `max_tokens * hidden_size`. Wrapped in `Option` so a
    /// forward call can `take()` it while mutably borrowing the model, then
    /// put it back when done.
    pub residual: Option<B::Buffer>,
    /// RMSNorm output feeding the QKV / gate-up projections (and the LM head in
    /// `forward_verify`), `max_tokens * hidden_size`.
    pub norm_out: B::Buffer,
    /// Fused QKV projection output, `max_tokens * (q_dim + 2 * kv_dim)`.
    pub qkv_out: B::Buffer,
    // Single-token variants (`q_dim` / `kv_dim` elements). Not referenced in this
    // section of the file — presumably used by another decode path; verify.
    pub q_single: B::Buffer,
    pub k_single: B::Buffer,
    pub v_single: B::Buffer,
    pub q_head_major_single: B::Buffer,
    pub k_head_major_single: B::Buffer,
    pub v_head_major_single: B::Buffer,
    pub attn_head_major_single: B::Buffer,
    pub attn_flat_single: B::Buffer,
    /// Logits for every position, `max_tokens * vocab_size`; written by
    /// `forward_verify`. NOTE(review): this is allocated even for plain prefill,
    /// so long prompts reserve a very large buffer.
    pub batch_logits: B::Buffer,
    /// Token-major Q/K/V from `split_qkv` (unfused fallback path only).
    pub q_buf: B::Buffer,
    pub k_buf: B::Buffer,
    pub v_buf: B::Buffer,
    /// Head-major Q after norm/RoPE; the attention query input.
    pub q_head_major: B::Buffer,
    /// Head-major K/V produced by the fallback path before being appended to
    /// the KV cache.
    pub k_head_major: B::Buffer,
    pub v_head_major: B::Buffer,
    /// Attention output in head-major layout.
    pub attn_head_major_out: B::Buffer,
    /// Token-major copy of the attention output (used when `tokens > 1`).
    pub attn_flat: B::Buffer,
    /// O projection output, `max_tokens * hidden_size`.
    pub o_proj_out: B::Buffer,
    /// Fused gate/up output, `max_tokens * 2 * intermediate_size`.
    pub gate_up_out: B::Buffer,
    /// SiLU-mul output, `max_tokens * intermediate_size`.
    pub silu_out: B::Buffer,
    /// Down projection output, `max_tokens * hidden_size`.
    pub mlp_out: B::Buffer,
    // Paged-batch buffers, `None` until `enable_paged_batch`: Q/O are
    // `max_seqs * q_dim`, block tables `max_seqs * max_blocks_per_seq` u32s,
    // context lengths `max_seqs` u32s. Not read in this section — presumably
    // batched paged decode; verify.
    pub paged_batch_q: Option<B::Buffer>,
    pub paged_batch_o: Option<B::Buffer>,
    pub paged_batch_block_tables: Option<B::Buffer>,
    pub paged_batch_context_lens: Option<B::Buffer>,
    /// Row length of `paged_batch_block_tables`; 0 until enabled.
    pub paged_max_blocks_per_seq: usize,
    /// Last token's hidden state (`hidden_size`), copied out by prefill.
    pub last_hidden: B::Buffer,
    /// `last_hidden` after the final RMSNorm.
    pub last_normed: B::Buffer,
    /// Single-row logits (`vocab_size`) returned by prefill and decode.
    pub logits: B::Buffer,
    /// Token capacity these buffers were sized for.
    pub max_tokens: usize,
}
267
268impl<B: Backend> LlamaFamilyScratch<B> {
269 fn alloc(cfg: &LlamaFamilyConfig, max_tokens: usize) -> Self {
270 let h = cfg.hidden_size;
271 let im = cfg.intermediate_size;
272 let q_dim = cfg.num_heads * cfg.head_dim;
273 let kv_dim = cfg.num_kv_heads * cfg.head_dim;
274 let qkv_dim = q_dim + 2 * kv_dim;
275 let t = max_tokens;
276 Self {
277 residual: Some(B::alloc(t * h)),
278 norm_out: B::alloc(t * h),
279 qkv_out: B::alloc(t * qkv_dim),
280 q_buf: B::alloc(t * q_dim),
281 k_buf: B::alloc(t * kv_dim),
282 v_buf: B::alloc(t * kv_dim),
283 q_head_major: B::alloc(cfg.num_heads * t * cfg.head_dim),
284 k_head_major: B::alloc(cfg.num_kv_heads * t * cfg.head_dim),
285 v_head_major: B::alloc(cfg.num_kv_heads * t * cfg.head_dim),
286 attn_head_major_out: B::alloc(cfg.num_heads * t * cfg.head_dim),
287 attn_flat: B::alloc(t * q_dim),
288 o_proj_out: B::alloc(t * h),
289 gate_up_out: B::alloc(t * 2 * im),
290 silu_out: B::alloc(t * im),
291 mlp_out: B::alloc(t * h),
292 last_hidden: B::alloc(h),
293 last_normed: B::alloc(h),
294 logits: B::alloc(cfg.vocab_size),
295 q_single: B::alloc(q_dim),
296 k_single: B::alloc(kv_dim),
297 v_single: B::alloc(kv_dim),
298 q_head_major_single: B::alloc(q_dim),
299 k_head_major_single: B::alloc(kv_dim),
300 v_head_major_single: B::alloc(kv_dim),
301 attn_head_major_single: B::alloc(q_dim),
302 attn_flat_single: B::alloc(q_dim),
303 batch_logits: B::alloc(t * cfg.vocab_size),
304 paged_batch_q: None,
309 paged_batch_o: None,
310 paged_batch_block_tables: None,
311 paged_batch_context_lens: None,
312 paged_max_blocks_per_seq: 0,
313 max_tokens: t,
314 }
315 }
316
317 fn enable_paged_batch(
321 &mut self,
322 cfg: &LlamaFamilyConfig,
323 max_seqs: usize,
324 max_blocks_per_seq: usize,
325 ) {
326 if self.paged_batch_q.is_some() {
327 return;
328 }
329 let q_dim = cfg.num_heads * cfg.head_dim;
330 self.paged_batch_q = Some(B::alloc(max_seqs * q_dim));
331 self.paged_batch_o = Some(B::alloc(max_seqs * q_dim));
332 self.paged_batch_block_tables = Some(B::alloc_u32(max_seqs * max_blocks_per_seq));
333 self.paged_batch_context_lens = Some(B::alloc_u32(max_seqs));
334 self.paged_max_blocks_per_seq = max_blocks_per_seq;
335 }
336}
337
/// A Llama-style decoder-only transformer (Qwen3 / Llama / Qwen2 / Mistral
/// configurations) running on backend `B`.
pub struct LlamaFamilyModel<B: Backend> {
    /// Model hyper-parameters.
    pub cfg: LlamaFamilyConfig,
    /// `cfg.to_runtime()`, computed at construction.
    pub runtime_cfg: LlmRuntimeConfig,

    /// Token-embedding table; `None` for models built with `new_backbone_only`.
    pub embed: Option<B::Buffer>,
    /// Decoder layers, in order.
    pub layers: Vec<LlamaFamilyLayer<B>>,
    /// Final RMSNorm weight applied before the LM head.
    pub final_norm_w: B::Buffer,
    /// Projection to vocabulary logits (possibly tied to the embedding);
    /// `None` for backbone-only models.
    pub lm_head: Option<Box<dyn Linear<B>>>,

    /// Precomputed RoPE cos/sin tables.
    pub rope: RopeCache<B>,
    /// Activation scratch, grown on demand by `ensure_scratch`.
    pub scratch: LlamaFamilyScratch<B>,

    /// KV caches keyed by cache id; one `KvCache` per layer.
    pub kv_caches: HashMap<String, Vec<KvCache<B>>>,
    /// Per-layer cache sets that `ensure_kv` reuses before allocating new ones.
    /// NOTE(review): presumably refilled when a cache id is released elsewhere
    /// in this file — confirm.
    kv_free_pool: Vec<Vec<KvCache<B>>>,

    /// Paged-KV mode only: one `(K pool, V pool)` pair per layer, shared by all
    /// paged caches and addressed through each cache's block table.
    pub paged_pools: Option<Vec<(B::Buffer, B::Buffer)>>,
    /// Paged-KV mode only: allocator handing out block indices into `paged_pools`.
    pub paged_block_alloc: Option<std::sync::Mutex<crate::common::paged_pool::BlockAllocator>>,

    /// Graph warm-up counter compared against `GRAPH_WARMUP` in
    /// `decode_internal` before capture is attempted; reset by `ensure_scratch`.
    graph_warmup: usize,
    /// Set when `begin_graph_capture` fails so decode stops trying to capture;
    /// cleared by `ensure_scratch`.
    graph_capture_failed: bool,
}
397
398impl<B: Backend> LlamaFamilyModel<B> {
399 pub fn new(cfg: LlamaFamilyConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
404 {
409 let mut ctx = B::new_context();
410 B::reset_graph(&mut ctx);
411 }
412 let rope = build_rope_cache::<B>(&cfg);
413 let scratch = LlamaFamilyScratch::alloc(&cfg, 1); let embed = loader.load_tensor("model.embed_tokens.weight")?;
417
418 let mut layers = Vec::with_capacity(cfg.num_layers);
420 for li in 0..cfg.num_layers {
421 let prefix = format!("model.layers.{li}");
422 let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
423 let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
424 let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
425 let post_ln_w =
426 loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
427 let gate_up_proj = loader.load_linear(&format!("{prefix}.mlp.gate_up_proj"))?;
428 let down_proj = loader.load_linear(&format!("{prefix}.mlp.down_proj"))?;
429
430 let (q_norm_w, k_norm_w) = if cfg.has_qk_norm {
431 let q = loader
432 .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
433 .ok();
434 let k = loader
435 .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
436 .ok();
437 (q, k)
438 } else {
439 (None, None)
440 };
441
442 layers.push(LlamaFamilyLayer {
443 input_ln_w,
444 qkv_proj,
445 q_norm_w,
446 k_norm_w,
447 o_proj,
448 post_ln_w,
449 gate_up_proj,
450 down_proj,
451 });
452 }
453
454 let final_norm_w = loader.load_tensor("model.norm.weight")?;
455
456 let lm_head = if loader.has_tensor("lm_head.weight") {
464 loader.load_linear("lm_head")?
465 } else {
466 tracing::info!(
467 "LlamaFamilyModel: tied embeddings — loading model.embed_tokens.weight as lm_head"
468 );
469 let as_linear = loader.load_linear("model.embed_tokens")?;
470 if as_linear.out_features() != cfg.vocab_size
472 || as_linear.in_features() != cfg.hidden_size
473 {
474 return Err(ferrum_types::FerrumError::model(format!(
475 "tied embed shape mismatch: got [{}, {}], expected [{}, {}]",
476 as_linear.out_features(),
477 as_linear.in_features(),
478 cfg.vocab_size,
479 cfg.hidden_size
480 )));
481 }
482 as_linear
483 };
484
485 let runtime_cfg = cfg.to_runtime();
486 Ok(Self {
487 cfg,
488 runtime_cfg,
489 embed: Some(embed),
490 layers,
491 final_norm_w,
492 lm_head: Some(lm_head),
493 rope,
494 scratch,
495 kv_caches: HashMap::new(),
496 kv_free_pool: Vec::new(),
497 paged_pools: None,
498 paged_block_alloc: None,
499 graph_warmup: 0,
500 graph_capture_failed: false,
501 })
502 }
503
504 pub fn new_backbone_only(cfg: LlamaFamilyConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
516 {
518 let mut ctx = B::new_context();
519 B::reset_graph(&mut ctx);
520 }
521 let rope = build_rope_cache::<B>(&cfg);
522 let scratch = LlamaFamilyScratch::alloc(&cfg, 1);
523
524 let mut layers = Vec::with_capacity(cfg.num_layers);
525 for li in 0..cfg.num_layers {
526 let prefix = format!("model.layers.{li}");
527 let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
528 let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
529 let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
530 let post_ln_w =
531 loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
532 let gate_up_proj = loader.load_linear(&format!("{prefix}.mlp.gate_up_proj"))?;
533 let down_proj = loader.load_linear(&format!("{prefix}.mlp.down_proj"))?;
534
535 let (q_norm_w, k_norm_w) = if cfg.has_qk_norm {
536 let q = loader
537 .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
538 .ok();
539 let k = loader
540 .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
541 .ok();
542 (q, k)
543 } else {
544 (None, None)
545 };
546
547 layers.push(LlamaFamilyLayer {
548 input_ln_w,
549 qkv_proj,
550 q_norm_w,
551 k_norm_w,
552 o_proj,
553 post_ln_w,
554 gate_up_proj,
555 down_proj,
556 });
557 }
558
559 let final_norm_w = loader.load_tensor("model.norm.weight")?;
560
561 let runtime_cfg = cfg.to_runtime();
562 Ok(Self {
563 cfg,
564 runtime_cfg,
565 embed: None,
566 layers,
567 final_norm_w,
568 lm_head: None,
569 rope,
570 scratch,
571 kv_caches: HashMap::new(),
572 kv_free_pool: Vec::new(),
573 paged_pools: None,
574 paged_block_alloc: None,
575 graph_warmup: 0,
576 graph_capture_failed: false,
577 })
578 }
579
580 pub(crate) fn ensure_scratch(&mut self, tokens: usize) {
582 if self.scratch.max_tokens < tokens {
583 {
588 let mut ctx = B::new_context();
589 B::reset_graph(&mut ctx);
590 }
591 self.scratch = LlamaFamilyScratch::alloc(&self.cfg, tokens);
592 self.graph_warmup = 0;
593 self.graph_capture_failed = false;
594 }
595 }
596
597 pub(crate) fn ensure_kv(&mut self, cache_id: &str) {
601 if self.kv_caches.contains_key(cache_id) {
602 return;
603 }
604 let nkv = self.cfg.num_kv_heads;
605 let hd = self.cfg.head_dim;
606 let model_max = self.cfg.max_seq_len;
613 const DEFAULT_KV_CAPACITY: usize = 4096;
614 let max = std::env::var("FERRUM_KV_CAPACITY")
615 .ok()
616 .and_then(|s| s.parse::<usize>().ok())
617 .map(|cap| cap.min(model_max))
618 .unwrap_or_else(|| model_max.min(DEFAULT_KV_CAPACITY));
619
620 let paged = std::env::var("FERRUM_METAL_PAGED_KV")
631 .map(|v| v == "1")
632 .unwrap_or(false);
633 const PAGED_BLOCK_SIZE: usize = 16;
634
635 let max_seqs = std::env::var("FERRUM_PAGED_MAX_SEQS")
638 .ok()
639 .and_then(|s| s.parse::<usize>().ok())
640 .unwrap_or(16);
641 let max_blocks_per_seq = max.div_ceil(PAGED_BLOCK_SIZE);
642 let total_pool_blocks = max_seqs * max_blocks_per_seq;
643
644 if paged && self.paged_pools.is_none() {
651 let mut pools = Vec::with_capacity(self.cfg.num_layers);
652 for _ in 0..self.cfg.num_layers {
653 let pool_floats = total_pool_blocks * nkv * PAGED_BLOCK_SIZE * hd;
654 pools.push((B::alloc(pool_floats), B::alloc(pool_floats)));
655 }
656 self.paged_pools = Some(pools);
657 self.paged_block_alloc = Some(std::sync::Mutex::new(
658 crate::common::paged_pool::BlockAllocator::new(total_pool_blocks as u32),
659 ));
660 }
661 if paged {
667 self.scratch
668 .enable_paged_batch(&self.cfg, max_seqs, max_blocks_per_seq);
669 }
670
671 let mut caches = self.kv_free_pool.pop().unwrap_or_else(|| {
674 (0..self.cfg.num_layers)
675 .map(|_| {
676 if paged {
677 let mut block_table = B::alloc_u32(max_blocks_per_seq);
683 let mut context_lens = B::alloc_u32(1);
684 let mut bt_ctx = B::new_context();
685 B::write_u32(&mut bt_ctx, &mut context_lens, &[0u32]);
686 B::sync(&mut bt_ctx);
687 KvCache {
688 k: B::alloc(1),
689 v: B::alloc(1),
690 len: 0,
691 capacity: max_blocks_per_seq * PAGED_BLOCK_SIZE,
692 num_kv_heads: nkv,
693 head_dim: hd,
694 block_size: PAGED_BLOCK_SIZE,
695 block_table: Some(block_table),
696 context_lens: Some(context_lens),
697 paged_block_indices: Vec::new(),
698 }
699 } else {
700 KvCache {
701 k: B::alloc(nkv * max * hd),
702 v: B::alloc(nkv * max * hd),
703 len: 0,
704 capacity: max,
705 num_kv_heads: nkv,
706 head_dim: hd,
707 block_size: 0,
708 block_table: None,
709 context_lens: None,
710 paged_block_indices: Vec::new(),
711 }
712 }
713 })
714 .collect()
715 });
716
717 if paged {
723 let alloc_arc = self
724 .paged_block_alloc
725 .as_ref()
726 .expect("paged_block_alloc must be initialised when paged=true");
727 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
731 let block_indices = match alloc.allocate_n(max_blocks_per_seq) {
732 Ok(idx) => idx,
733 Err(e) => {
734 drop(alloc);
741 self.kv_free_pool.push(caches);
742 eprintln!(
743 "[ferrum] paged KV pool exhausted on ensure_kv for \
744 cache_id={cache_id:?}: {e}. Increase \
745 FERRUM_PAGED_MAX_SEQS (currently {max_seqs}) or \
746 throttle concurrent requests.",
747 );
748 return;
749 }
750 };
751 let mut padded = block_indices.clone();
756 padded.resize(max_blocks_per_seq, 0);
757 let mut ctx_tmp = B::new_context();
758 for c in caches.iter_mut() {
759 if let Some(bt) = c.block_table.as_mut() {
760 B::write_u32(&mut ctx_tmp, bt, &padded);
761 }
762 c.paged_block_indices = block_indices.clone();
763 }
764 B::sync(&mut ctx_tmp);
765 }
766
767 for c in caches.iter_mut() {
771 c.len = 0;
772 if let Some(cl) = c.context_lens.as_mut() {
773 let mut ctx_tmp = B::new_context();
774 B::write_u32(&mut ctx_tmp, cl, &[0u32]);
775 B::sync(&mut ctx_tmp);
776 }
777 }
778 self.kv_caches.insert(cache_id.to_string(), caches);
779 }
780
781 #[allow(clippy::too_many_arguments)]
786 pub(crate) fn forward_layer(
787 &mut self,
788 ctx: &mut B::Context,
789 li: usize,
790 cache_id: &str,
791 residual: &mut B::Buffer,
792 pos_offset: usize,
793 tokens: usize,
794 ) {
795 let layer = &self.layers[li];
796 let cfg = &self.cfg;
797 let h = cfg.hidden_size;
798 let nh = cfg.num_heads;
799 let nkv = cfg.num_kv_heads;
800 let hd = cfg.head_dim;
801 let im = cfg.intermediate_size;
802 let eps = cfg.rms_norm_eps;
803 let q_dim = nh * hd;
804 let kv_dim = nkv * hd;
805
806 let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
808 B::sync(ctx);
809 Some(std::time::Instant::now())
810 } else {
811 None
812 };
813 B::rms_norm(
814 ctx,
815 residual,
816 &layer.input_ln_w,
817 eps,
818 &mut self.scratch.norm_out,
819 tokens,
820 h,
821 );
822 if let Some(t0) = _t0 {
823 B::sync(ctx);
824 NORM_TIME_US.fetch_add(
825 t0.elapsed().as_micros() as u64,
826 std::sync::atomic::Ordering::Relaxed,
827 );
828 NORM_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
829 }
830
831 let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
833 B::sync(ctx);
834 Some(std::time::Instant::now())
835 } else {
836 None
837 };
838 layer.qkv_proj.forward(
839 ctx,
840 &self.scratch.norm_out,
841 &mut self.scratch.qkv_out,
842 tokens,
843 );
844 if let Some(t0) = _t0 {
845 B::sync(ctx);
846 MATMUL_TIME_US.fetch_add(
847 t0.elapsed().as_micros() as u64,
848 std::sync::atomic::Ordering::Relaxed,
849 );
850 MATMUL_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
851 }
852
853 let qk_mode: i32 = if cfg.has_qk_norm { 1 } else { 2 };
866 let dummy = &layer.input_ln_w;
867 let q_norm_w = layer.q_norm_w.as_ref().unwrap_or(dummy);
868 let k_norm_w = layer.k_norm_w.as_ref().unwrap_or(dummy);
869
870 let paged_pool_ptr: Option<(*mut B::Buffer, *mut B::Buffer)> =
881 if let Some(pools) = self.paged_pools.as_mut() {
882 let pool = &mut pools[li];
883 Some((&mut pool.0 as *mut _, &mut pool.1 as *mut _))
884 } else {
885 None
886 };
887 let caches = self
888 .kv_caches
889 .get_mut(cache_id)
890 .expect("ensure_kv must be called before forward_layer");
891 let cache = &mut caches[li];
892 let cache_len_before = cache.len;
893 let cache_capacity = cache.capacity;
894
895 if cache_len_before + tokens > cache_capacity {
901 panic!(
902 "KV cache overflow on layer {li}: would write tokens [{cache_len_before}..{}) but capacity is {cache_capacity} (cache_id={cache_id:?}). Increase FERRUM_KV_CAPACITY or call /clear in the REPL.",
903 cache_len_before + tokens
904 );
905 }
906
907 let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
908 B::sync(ctx);
909 Some(std::time::Instant::now())
910 } else {
911 None
912 };
913 let used_qkv_into_cache = if cache.block_size > 0 {
918 let bt = cache
919 .block_table
920 .as_ref()
921 .expect("paged cache missing block_table");
922 let num_blocks_per_seq = cache.capacity / cache.block_size;
923 let (pool_k_ptr, pool_v_ptr) =
925 paged_pool_ptr.expect("paged_pools must be allocated when block_size > 0");
926 let pool_k = unsafe { &mut *pool_k_ptr };
929 let pool_v = unsafe { &mut *pool_v_ptr };
930 B::split_qkv_norm_rope_into_paged_cache(
931 ctx,
932 &self.scratch.qkv_out,
933 0, q_norm_w,
935 k_norm_w,
936 &self.rope.cos,
937 &self.rope.sin,
938 &mut self.scratch.q_head_major,
939 0, pool_k,
941 pool_v,
942 bt,
943 tokens,
944 nh,
945 nkv,
946 hd,
947 pos_offset,
948 eps,
949 qk_mode,
950 cache_len_before,
951 cache.block_size,
952 num_blocks_per_seq,
953 )
954 .is_ok()
955 } else {
956 B::split_qkv_norm_rope_into_cache(
957 ctx,
958 &self.scratch.qkv_out,
959 q_norm_w,
960 k_norm_w,
961 &self.rope.cos,
962 &self.rope.sin,
963 &mut self.scratch.q_head_major,
964 &mut cache.k,
965 &mut cache.v,
966 tokens,
967 nh,
968 nkv,
969 hd,
970 pos_offset,
971 eps,
972 qk_mode,
973 cache_len_before,
974 cache_capacity,
975 )
976 .is_ok()
977 };
978 if !used_qkv_into_cache {
979 let used_fused_qkv = B::split_qkv_norm_rope(
982 ctx,
983 &self.scratch.qkv_out,
984 q_norm_w,
985 k_norm_w,
986 &self.rope.cos,
987 &self.rope.sin,
988 &mut self.scratch.q_head_major,
989 &mut self.scratch.k_head_major,
990 &mut self.scratch.v_head_major,
991 tokens,
992 nh,
993 nkv,
994 hd,
995 pos_offset,
996 eps,
997 qk_mode,
998 )
999 .is_ok();
1000 if !used_fused_qkv {
1001 B::split_qkv(
1003 ctx,
1004 &self.scratch.qkv_out,
1005 &mut self.scratch.q_buf,
1006 &mut self.scratch.k_buf,
1007 &mut self.scratch.v_buf,
1008 tokens,
1009 q_dim,
1010 kv_dim,
1011 );
1012 B::qk_norm_rope(
1013 ctx,
1014 &self.scratch.q_buf,
1015 q_norm_w,
1016 &self.rope.cos,
1017 &self.rope.sin,
1018 &mut self.scratch.q_head_major,
1019 tokens,
1020 nh,
1021 hd,
1022 pos_offset,
1023 eps,
1024 qk_mode,
1025 );
1026 B::qk_norm_rope(
1027 ctx,
1028 &self.scratch.k_buf,
1029 k_norm_w,
1030 &self.rope.cos,
1031 &self.rope.sin,
1032 &mut self.scratch.k_head_major,
1033 tokens,
1034 nkv,
1035 hd,
1036 pos_offset,
1037 eps,
1038 qk_mode,
1039 );
1040 B::qk_norm_rope(
1041 ctx,
1042 &self.scratch.v_buf,
1043 dummy,
1044 &self.rope.cos,
1045 &self.rope.sin,
1046 &mut self.scratch.v_head_major,
1047 tokens,
1048 nkv,
1049 hd,
1050 pos_offset,
1051 eps,
1052 0,
1053 );
1054 }
1055 B::kv_cache_append_head_major(
1056 ctx,
1057 &mut cache.k,
1058 &mut cache.v,
1059 cache.len,
1060 cache.capacity,
1061 &self.scratch.k_head_major,
1062 &self.scratch.v_head_major,
1063 tokens,
1064 nkv,
1065 hd,
1066 );
1067 }
1068 if let Some(t0) = _t0 {
1069 B::sync(ctx);
1070 QKR_TIME_US.fetch_add(
1071 t0.elapsed().as_micros() as u64,
1072 std::sync::atomic::Ordering::Relaxed,
1073 );
1074 QKR_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1075 }
1076 cache.len += tokens;
1077 let kv_len = cache.len;
1078 let kv_stride = cache.capacity;
1079
1080 let _attn_t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1087 B::sync(ctx);
1088 Some(std::time::Instant::now())
1089 } else {
1090 None
1091 };
1092 if cache.block_size > 0 {
1093 let bt = cache
1094 .block_table
1095 .as_ref()
1096 .expect("paged cache missing block_table");
1097 let cl_buf = cache
1098 .context_lens
1099 .as_mut()
1100 .expect("paged cache missing context_lens");
1101 let num_blocks_per_seq = cache.capacity / cache.block_size;
1102 let (pool_k_ptr, pool_v_ptr) =
1104 paged_pool_ptr.expect("paged_pools must be allocated when block_size > 0");
1105 let pool_k = unsafe { &*pool_k_ptr };
1108 let pool_v = unsafe { &*pool_v_ptr };
1109 let final_kv_len = cache.len as u32;
1115 B::write_u32(ctx, cl_buf, &[final_kv_len]);
1116 B::paged_decode_attention(
1117 ctx,
1118 &self.scratch.q_head_major,
1119 pool_k,
1120 pool_v,
1121 &mut self.scratch.attn_head_major_out,
1122 bt,
1123 cl_buf,
1124 1, nh,
1126 nkv,
1127 hd,
1128 cache.block_size,
1129 num_blocks_per_seq,
1130 tokens, )
1132 .expect("paged_decode_attention");
1133 } else {
1134 let attn_cfg = ferrum_kernels::backend::AttnConfig {
1138 num_heads: nh,
1139 num_kv_heads: nkv,
1140 head_dim: hd,
1141 causal: true,
1142 scale: 1.0 / (hd as f32).sqrt(),
1143 kv_seq_stride: kv_stride,
1144 sliding_window: cfg.sliding_window,
1145 };
1146 B::flash_attention(
1147 ctx,
1148 &self.scratch.q_head_major,
1149 &cache.k,
1150 &cache.v,
1151 &mut self.scratch.attn_head_major_out,
1152 1,
1153 tokens,
1154 kv_len,
1155 pos_offset,
1156 &attn_cfg,
1157 );
1158 }
1159 if let Some(t0) = _attn_t0 {
1160 B::sync(ctx);
1161 ATTN_TIME_US.fetch_add(
1162 t0.elapsed().as_micros() as u64,
1163 std::sync::atomic::Ordering::Relaxed,
1164 );
1165 ATTN_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1166 }
1167
1168 let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1175 B::sync(ctx);
1176 Some(std::time::Instant::now())
1177 } else {
1178 None
1179 };
1180 let attn_token_major = if tokens == 1 {
1181 &self.scratch.attn_head_major_out
1182 } else {
1183 B::transpose_head_to_token(
1184 ctx,
1185 &self.scratch.attn_head_major_out,
1186 &mut self.scratch.attn_flat,
1187 tokens,
1188 nh,
1189 hd,
1190 );
1191 &self.scratch.attn_flat
1192 };
1193 if let Some(t0) = _t0 {
1194 B::sync(ctx);
1195 OTHER_TIME_US.fetch_add(
1196 t0.elapsed().as_micros() as u64,
1197 std::sync::atomic::Ordering::Relaxed,
1198 );
1199 OTHER_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1200 }
1201
1202 let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1204 B::sync(ctx);
1205 Some(std::time::Instant::now())
1206 } else {
1207 None
1208 };
1209 layer
1210 .o_proj
1211 .forward(ctx, attn_token_major, &mut self.scratch.o_proj_out, tokens);
1212 if let Some(t0) = _t0 {
1213 B::sync(ctx);
1214 MATMUL_TIME_US.fetch_add(
1215 t0.elapsed().as_micros() as u64,
1216 std::sync::atomic::Ordering::Relaxed,
1217 );
1218 MATMUL_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1219 }
1220
1221 let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1225 B::sync(ctx);
1226 Some(std::time::Instant::now())
1227 } else {
1228 None
1229 };
1230 B::fused_add_rms_norm(
1231 ctx,
1232 residual,
1233 &self.scratch.o_proj_out,
1234 &layer.post_ln_w,
1235 eps,
1236 &mut self.scratch.norm_out,
1237 tokens,
1238 h,
1239 );
1240 if let Some(t0) = _t0 {
1241 B::sync(ctx);
1242 NORM_TIME_US.fetch_add(
1243 t0.elapsed().as_micros() as u64,
1244 std::sync::atomic::Ordering::Relaxed,
1245 );
1246 NORM_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1247 }
1248
1249 let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1251 B::sync(ctx);
1252 Some(std::time::Instant::now())
1253 } else {
1254 None
1255 };
1256 layer.gate_up_proj.forward(
1257 ctx,
1258 &self.scratch.norm_out,
1259 &mut self.scratch.gate_up_out,
1260 tokens,
1261 );
1262 if let Some(t0) = _t0 {
1263 B::sync(ctx);
1264 MATMUL_TIME_US.fetch_add(
1265 t0.elapsed().as_micros() as u64,
1266 std::sync::atomic::Ordering::Relaxed,
1267 );
1268 MATMUL_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1269 }
1270
1271 let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1273 B::sync(ctx);
1274 Some(std::time::Instant::now())
1275 } else {
1276 None
1277 };
1278 B::fused_silu_mul_split(
1279 ctx,
1280 &self.scratch.gate_up_out,
1281 &mut self.scratch.silu_out,
1282 tokens,
1283 im,
1284 );
1285 if let Some(t0) = _t0 {
1286 B::sync(ctx);
1287 OTHER_TIME_US.fetch_add(
1288 t0.elapsed().as_micros() as u64,
1289 std::sync::atomic::Ordering::Relaxed,
1290 );
1291 OTHER_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1292 }
1293
1294 let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1296 B::sync(ctx);
1297 Some(std::time::Instant::now())
1298 } else {
1299 None
1300 };
1301 layer.down_proj.forward(
1302 ctx,
1303 &self.scratch.silu_out,
1304 &mut self.scratch.mlp_out,
1305 tokens,
1306 );
1307 if let Some(t0) = _t0 {
1308 B::sync(ctx);
1309 MATMUL_TIME_US.fetch_add(
1310 t0.elapsed().as_micros() as u64,
1311 std::sync::atomic::Ordering::Relaxed,
1312 );
1313 MATMUL_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1314 }
1315
1316 let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1318 B::sync(ctx);
1319 Some(std::time::Instant::now())
1320 } else {
1321 None
1322 };
1323 B::add_inplace(ctx, residual, &self.scratch.mlp_out, tokens * h);
1324 if let Some(t0) = _t0 {
1325 B::sync(ctx);
1326 OTHER_TIME_US.fetch_add(
1327 t0.elapsed().as_micros() as u64,
1328 std::sync::atomic::Ordering::Relaxed,
1329 );
1330 OTHER_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1331 }
1332 }
1333
1334 pub fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
1346 let seq_len = tokens.len();
1347 assert!(seq_len > 0, "forward_verify called with empty tokens");
1348 self.ensure_scratch(seq_len);
1349 self.ensure_kv(cache_id);
1350
1351 let h = self.cfg.hidden_size;
1352 let vocab = self.cfg.vocab_size;
1353
1354 let pos_offset = self
1355 .kv_caches
1356 .get(cache_id)
1357 .and_then(|layers| layers.first())
1358 .map(|c| c.len)
1359 .unwrap_or(0);
1360
1361 let mut ctx = B::new_context();
1362 let mut residual = self
1363 .scratch
1364 .residual
1365 .take()
1366 .expect("scratch residual missing (previous call didn't restore)");
1367
1368 let embed = self
1369 .embed
1370 .as_ref()
1371 .expect("forward_verify called on backbone-only model (no embed)");
1372 B::embedding_lookup(&mut ctx, embed, tokens, &mut residual, h);
1373
1374 for li in 0..self.cfg.num_layers {
1375 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
1376 }
1377
1378 B::rms_norm(
1381 &mut ctx,
1382 &residual,
1383 &self.final_norm_w,
1384 self.cfg.rms_norm_eps,
1385 &mut self.scratch.norm_out,
1386 seq_len,
1387 h,
1388 );
1389
1390 let lm_head = self
1394 .lm_head
1395 .as_ref()
1396 .expect("forward_verify called on backbone-only model (no lm_head)");
1397 lm_head.forward(
1398 &mut ctx,
1399 &self.scratch.norm_out,
1400 &mut self.scratch.batch_logits,
1401 seq_len,
1402 );
1403
1404 B::sync(&mut ctx);
1405 self.scratch.residual = Some(residual);
1406 B::to_vec(&self.scratch.batch_logits, seq_len * vocab)
1407 }
1408
1409 pub fn prefill_internal(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
1417 let seq_len = tokens.len();
1418 assert!(seq_len > 0, "prefill called with empty token list");
1419 self.ensure_scratch(seq_len);
1420 self.ensure_kv(cache_id);
1421
1422 let pos_offset = self
1425 .kv_caches
1426 .get(cache_id)
1427 .and_then(|layers| layers.first())
1428 .map(|c| c.len)
1429 .unwrap_or(0);
1430
1431 let h = self.cfg.hidden_size;
1432 let vocab = self.cfg.vocab_size;
1433 let mut ctx = B::new_context();
1434
1435 let mut residual = self
1442 .scratch
1443 .residual
1444 .take()
1445 .expect("scratch residual missing (previous call didn't restore)");
1446 let embed = self
1447 .embed
1448 .as_ref()
1449 .expect("prefill_internal called on backbone-only model (no embed)");
1450 B::embedding_lookup(&mut ctx, embed, tokens, &mut residual, h);
1451
1452 let prefill_profile = std::env::var("FERRUM_PREFILL_OP_PROFILE").is_ok();
1453 let prefill_t0 = if prefill_profile {
1454 B::sync(&mut ctx);
1455 Some(std::time::Instant::now())
1456 } else {
1457 None
1458 };
1459
1460 for li in 0..self.cfg.num_layers {
1461 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
1462 }
1463
1464 if let Some(t0) = prefill_t0 {
1465 B::sync(&mut ctx);
1466 let total_us = t0.elapsed().as_micros() as u64;
1467 let attn_us = ATTN_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1468 let attn_n = ATTN_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1469 let qkr_us = QKR_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1470 let qkr_n = QKR_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1471 let mm_us = MATMUL_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1472 let mm_n = MATMUL_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1473 let norm_us = NORM_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1474 let norm_n = NORM_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1475 let other_us = OTHER_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1476 let other_n = OTHER_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1477 eprintln!(
1478 "[prefill-profile] tokens={} layers total={} ms",
1479 seq_len,
1480 total_us / 1000
1481 );
1482 let bucket = |label: &str, n: u64, us: u64| {
1483 if n > 0 {
1484 eprintln!(
1485 "[prefill-profile] {label}: {} calls {} ms (avg {} us)",
1486 n,
1487 us / 1000,
1488 us / n
1489 );
1490 }
1491 };
1492 bucket("flash_attn", attn_n, attn_us);
1493 bucket("qk_norm_rope", qkr_n, qkr_us);
1494 bucket("matmuls", mm_n, mm_us);
1495 bucket("norms", norm_n, norm_us);
1496 bucket("other", other_n, other_us);
1497 }
1498
1499 B::copy_slice(
1501 &mut ctx,
1502 &residual,
1503 (seq_len - 1) * h,
1504 &mut self.scratch.last_hidden,
1505 0,
1506 h,
1507 );
1508
1509 B::rms_norm(
1511 &mut ctx,
1512 &self.scratch.last_hidden,
1513 &self.final_norm_w,
1514 self.cfg.rms_norm_eps,
1515 &mut self.scratch.last_normed,
1516 1,
1517 h,
1518 );
1519
1520 let lm_head = self
1522 .lm_head
1523 .as_ref()
1524 .expect("prefill_internal called on backbone-only model (no lm_head)");
1525 lm_head.forward(
1526 &mut ctx,
1527 &self.scratch.last_normed,
1528 &mut self.scratch.logits,
1529 1,
1530 );
1531
1532 B::sync(&mut ctx);
1539
1540 self.scratch.residual = Some(residual);
1542
1543 B::to_vec(&self.scratch.logits, vocab)
1544 }
1545
1546 pub fn decode_internal(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
1548 self.ensure_scratch(1);
1549 self.ensure_kv(cache_id);
1550
1551 let h = self.cfg.hidden_size;
1552 let vocab = self.cfg.vocab_size;
1553
1554 let mut ctx = B::new_context();
1557
1558 const GRAPH_WARMUP: usize = 3;
1563 let graph_enabled = std::env::var("FERRUM_CUDA_GRAPH").is_ok();
1564
1565 if graph_enabled {
1566 B::set_decode_state(&mut ctx, token, pos);
1569
1570 match B::replay_last_graph(&mut ctx) {
1572 Ok(true) => {
1573 B::sync(&mut ctx);
1574 return B::to_vec(&self.scratch.logits, vocab);
1575 }
1576 Ok(false) => { }
1577 Err(_) => { }
1578 }
1579 }
1580
1581 let should_capture =
1582 graph_enabled && !self.graph_capture_failed && self.graph_warmup >= GRAPH_WARMUP;
1583
1584 if should_capture {
1585 B::set_dev_state_mode(&mut ctx, true);
1586 if B::begin_graph_capture(&mut ctx).is_err() {
1587 self.graph_capture_failed = true;
1588 B::set_dev_state_mode(&mut ctx, false);
1589 }
1590 }
1591
1592 let mut residual = self
1598 .scratch
1599 .residual
1600 .take()
1601 .expect("scratch residual missing (previous call didn't restore)");
1602 let embed = self
1603 .embed
1604 .as_ref()
1605 .expect("decode_internal called on backbone-only model (no embed)");
1606 B::embedding_lookup(&mut ctx, embed, &[token], &mut residual, h);
1607
1608 let layer_profile = std::env::var("FERRUM_DECODE_LAYER_PROFILE").is_ok();
1612 let mut layer_times = if layer_profile {
1613 Some(Vec::with_capacity(self.cfg.num_layers))
1614 } else {
1615 None
1616 };
1617
1618 for li in 0..self.cfg.num_layers {
1619 if layer_profile {
1620 let t0 = std::time::Instant::now();
1621 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1622 B::sync(&mut ctx);
1623 let elapsed_us = t0.elapsed().as_micros() as u64;
1624 if let Some(v) = layer_times.as_mut() {
1625 v.push(elapsed_us);
1626 }
1627 } else {
1628 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1629 }
1630 }
1631 if let Some(times) = layer_times.take() {
1632 let sum: u64 = times.iter().sum();
1633 let avg = sum / times.len() as u64;
1634 let mn = *times.iter().min().unwrap_or(&0);
1635 let mx = *times.iter().max().unwrap_or(&0);
1636 eprintln!(
1637 "[layer-profile] {} layers total={} ms avg={} us min={} us max={} us",
1638 times.len(),
1639 sum / 1000,
1640 avg,
1641 mn,
1642 mx
1643 );
1644 for (i, t) in times.iter().enumerate() {
1645 eprint!("L{i}={}ms ", t / 1000);
1646 if (i + 1) % 6 == 0 {
1647 eprintln!();
1648 }
1649 }
1650 eprintln!();
1651 let attn_us = ATTN_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1652 let attn_n = ATTN_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1653 let qkr_us = QKR_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1654 let qkr_n = QKR_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1655 let mm_us = MATMUL_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1656 let mm_n = MATMUL_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1657 let norm_us = NORM_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1658 let norm_n = NORM_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1659 let other_us = OTHER_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1660 let other_n = OTHER_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1661 eprintln!(
1662 "[op-profile] flash_attn: {} calls {} ms (avg {} us)",
1663 attn_n,
1664 attn_us / 1000,
1665 if attn_n > 0 { attn_us / attn_n } else { 0 }
1666 );
1667 eprintln!(
1668 "[op-profile] qk_norm_rope: {} calls {} ms (avg {} us)",
1669 qkr_n,
1670 qkr_us / 1000,
1671 if qkr_n > 0 { qkr_us / qkr_n } else { 0 }
1672 );
1673 eprintln!(
1674 "[op-profile] matmuls (Linear::forward): {} calls {} ms (avg {} us)",
1675 mm_n,
1676 mm_us / 1000,
1677 if mm_n > 0 { mm_us / mm_n } else { 0 }
1678 );
1679 eprintln!(
1680 "[op-profile] norms (rms+fused_add_rms): {} calls {} ms (avg {} us)",
1681 norm_n,
1682 norm_us / 1000,
1683 if norm_n > 0 { norm_us / norm_n } else { 0 }
1684 );
1685 eprintln!(
1686 "[op-profile] other (split_qkv, kv_append, transpose, silu, add): {} calls {} ms (avg {} us)",
1687 other_n, other_us / 1000, if other_n > 0 { other_us / other_n } else { 0 }
1688 );
1689 }
1690
1691 B::rms_norm(
1692 &mut ctx,
1693 &residual,
1694 &self.final_norm_w,
1695 self.cfg.rms_norm_eps,
1696 &mut self.scratch.last_normed,
1697 1,
1698 h,
1699 );
1700
1701 let lm_head = self
1702 .lm_head
1703 .as_ref()
1704 .expect("decode_internal called on backbone-only model (no lm_head)");
1705 lm_head.forward(
1706 &mut ctx,
1707 &self.scratch.last_normed,
1708 &mut self.scratch.logits,
1709 1,
1710 );
1711
1712 if should_capture && !self.graph_capture_failed {
1713 if B::end_graph_capture(&mut ctx).is_err() {
1714 self.graph_capture_failed = true;
1715 } else {
1716 if B::replay_last_graph(&mut ctx).is_err() {
1723 self.graph_capture_failed = true;
1724 }
1725 }
1726 B::set_dev_state_mode(&mut ctx, false);
1727 } else {
1728 self.graph_warmup += 1;
1729 }
1730
1731 B::sync(&mut ctx);
1738 self.scratch.residual = Some(residual);
1739
1740 B::to_vec(&self.scratch.logits, vocab)
1741 }
1742
1743 pub fn prefill_from_embeds(
1752 &mut self,
1753 cache_id: &str,
1754 embeds: &[f32],
1755 seq_len: usize,
1756 ) -> Vec<f32> {
1757 let h = self.cfg.hidden_size;
1758 assert_eq!(
1759 embeds.len(),
1760 seq_len * h,
1761 "embeds length {} != seq_len * hidden_size {}",
1762 embeds.len(),
1763 seq_len * h
1764 );
1765 assert!(seq_len > 0, "prefill_from_embeds called with zero length");
1766
1767 self.ensure_scratch(seq_len);
1768 self.ensure_kv(cache_id);
1769
1770 let mut ctx = B::new_context();
1771 let mut residual = self
1772 .scratch
1773 .residual
1774 .take()
1775 .expect("scratch residual missing (previous call didn't restore)");
1776
1777 let embed_buf = B::from_slice(embeds);
1779 B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, seq_len * h);
1780
1781 for li in 0..self.cfg.num_layers {
1782 self.forward_layer(&mut ctx, li, cache_id, &mut residual, 0, seq_len);
1783 }
1784
1785 B::copy_slice(
1786 &mut ctx,
1787 &residual,
1788 (seq_len - 1) * h,
1789 &mut self.scratch.last_hidden,
1790 0,
1791 h,
1792 );
1793 B::sync(&mut ctx);
1794 self.scratch.residual = Some(residual);
1795 B::to_vec(&self.scratch.last_hidden, h)
1796 }
1797
1798 pub fn decode_from_embed(&mut self, cache_id: &str, embed: &[f32], pos: u32) -> Vec<f32> {
1802 let h = self.cfg.hidden_size;
1803 assert_eq!(
1804 embed.len(),
1805 h,
1806 "embed length {} != hidden_size {}",
1807 embed.len(),
1808 h
1809 );
1810
1811 self.ensure_scratch(1);
1812 self.ensure_kv(cache_id);
1813
1814 let mut ctx = B::new_context();
1815 let mut residual = self
1816 .scratch
1817 .residual
1818 .take()
1819 .expect("scratch residual missing (previous call didn't restore)");
1820
1821 let embed_buf = B::from_slice(embed);
1822 B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, h);
1823
1824 for li in 0..self.cfg.num_layers {
1825 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1826 }
1827
1828 B::copy_slice(&mut ctx, &residual, 0, &mut self.scratch.last_hidden, 0, h);
1829 B::sync(&mut ctx);
1830 self.scratch.residual = Some(residual);
1831 B::to_vec(&self.scratch.last_hidden, h)
1832 }
1833
1834 pub fn prefill_all_post_norm(
1845 &mut self,
1846 cache_id: &str,
1847 embeds: &[f32],
1848 seq_len: usize,
1849 pos_offset: usize,
1850 ) -> Vec<f32> {
1851 let h = self.cfg.hidden_size;
1852 assert_eq!(
1853 embeds.len(),
1854 seq_len * h,
1855 "embeds length {} != seq_len * hidden_size {}",
1856 embeds.len(),
1857 seq_len * h
1858 );
1859 assert!(seq_len > 0);
1860
1861 self.ensure_scratch(seq_len);
1862 self.ensure_kv(cache_id);
1863
1864 let mut ctx = B::new_context();
1865 let mut residual = self
1866 .scratch
1867 .residual
1868 .take()
1869 .expect("scratch residual missing (previous call didn't restore)");
1870
1871 let embed_buf = B::from_slice(embeds);
1872 B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, seq_len * h);
1873
1874 for li in 0..self.cfg.num_layers {
1875 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
1876 }
1877
1878 B::rms_norm(
1880 &mut ctx,
1881 &residual,
1882 &self.final_norm_w,
1883 self.cfg.rms_norm_eps,
1884 &mut self.scratch.norm_out,
1885 seq_len,
1886 h,
1887 );
1888 B::sync(&mut ctx);
1889 self.scratch.residual = Some(residual);
1890 B::to_vec(&self.scratch.norm_out, seq_len * h)
1891 }
1892
1893 pub fn decode_post_norm_from_embed(
1897 &mut self,
1898 cache_id: &str,
1899 embed: &[f32],
1900 pos: u32,
1901 ) -> Vec<f32> {
1902 let h = self.cfg.hidden_size;
1903 assert_eq!(embed.len(), h);
1904
1905 self.ensure_scratch(1);
1906 self.ensure_kv(cache_id);
1907
1908 let mut ctx = B::new_context();
1909 let mut residual = self
1910 .scratch
1911 .residual
1912 .take()
1913 .expect("scratch residual missing (previous call didn't restore)");
1914
1915 let embed_buf = B::from_slice(embed);
1916 B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, h);
1917
1918 for li in 0..self.cfg.num_layers {
1919 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1920 }
1921
1922 B::rms_norm(
1923 &mut ctx,
1924 &residual,
1925 &self.final_norm_w,
1926 self.cfg.rms_norm_eps,
1927 &mut self.scratch.last_normed,
1928 1,
1929 h,
1930 );
1931 B::sync(&mut ctx);
1932 self.scratch.residual = Some(residual);
1933 B::to_vec(&self.scratch.last_normed, h)
1934 }
1935
1936 pub fn decode_batch_internal(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
1944 let m = batch.len();
1945 if m == 0 {
1946 return Vec::new();
1947 }
1948 if m == 1 {
1949 let (cid, tok, pos) = &batch[0];
1950 return vec![self.decode_internal(cid, *tok, *pos)];
1951 }
1952
1953 for (cid, _, _) in batch {
1955 self.ensure_kv(cid);
1956 }
1957 self.ensure_scratch(m);
1958 let h = self.cfg.hidden_size;
1964 let vocab = self.cfg.vocab_size;
1965 let mut ctx = B::new_context();
1966
1967 let tokens: Vec<u32> = batch.iter().map(|(_, t, _)| *t).collect();
1969 let mut residual = self
1970 .scratch
1971 .residual
1972 .take()
1973 .expect("scratch residual missing (previous call didn't restore)");
1974 let embed = self
1975 .embed
1976 .as_ref()
1977 .expect("decode_batch_internal called on backbone-only model (no embed)");
1978 B::embedding_lookup(&mut ctx, embed, &tokens, &mut residual, h);
1979
1980 for li in 0..self.cfg.num_layers {
1982 self.forward_layer_batched_decode(&mut ctx, li, batch, &mut residual, m);
1983 }
1984
1985 B::rms_norm(
1987 &mut ctx,
1988 &residual,
1989 &self.final_norm_w,
1990 self.cfg.rms_norm_eps,
1991 &mut self.scratch.norm_out,
1992 m,
1993 h,
1994 );
1995
1996 let lm_head = self
1998 .lm_head
1999 .as_ref()
2000 .expect("decode_batch_internal called on backbone-only model (no lm_head)");
2001 lm_head.forward(
2002 &mut ctx,
2003 &self.scratch.norm_out,
2004 &mut self.scratch.batch_logits,
2005 m,
2006 );
2007
2008 B::sync(&mut ctx);
2010 self.scratch.residual = Some(residual);
2011
2012 let all = B::to_vec(&self.scratch.batch_logits, m * vocab);
2014 (0..m)
2015 .map(|i| all[i * vocab..(i + 1) * vocab].to_vec())
2016 .collect()
2017 }
2018
2019 fn forward_layer_batched_decode(
2021 &mut self,
2022 ctx: &mut B::Context,
2023 li: usize,
2024 batch: &[(String, u32, u32)],
2025 residual: &mut B::Buffer,
2026 m: usize,
2027 ) {
2028 let cfg = &self.cfg;
2029 let h = cfg.hidden_size;
2030 let nh = cfg.num_heads;
2031 let nkv = cfg.num_kv_heads;
2032 let hd = cfg.head_dim;
2033 let im = cfg.intermediate_size;
2034 let eps = cfg.rms_norm_eps;
2035 let q_dim = nh * hd;
2036 let kv_dim = nkv * hd;
2037
2038 let layer = &self.layers[li];
2039 let qk_mode: i32 = if cfg.has_qk_norm { 1 } else { 2 };
2040 let dummy_w = &layer.input_ln_w;
2041 let q_norm_w = layer.q_norm_w.as_ref().unwrap_or(dummy_w);
2042 let k_norm_w = layer.k_norm_w.as_ref().unwrap_or(dummy_w);
2043
2044 B::rms_norm(
2046 ctx,
2047 residual,
2048 &layer.input_ln_w,
2049 eps,
2050 &mut self.scratch.norm_out,
2051 m,
2052 h,
2053 );
2054
2055 layer
2057 .qkv_proj
2058 .forward(ctx, &self.scratch.norm_out, &mut self.scratch.qkv_out, m);
2059
2060 if let Some(pools) = self.paged_pools.as_mut() {
2080 let pool_ptr = (
2081 &mut pools[li].0 as *mut B::Buffer,
2082 &mut pools[li].1 as *mut B::Buffer,
2083 );
2084 let (pool_k, pool_v) = unsafe { (&mut *pool_ptr.0, &mut *pool_ptr.1) };
2086
2087 let qkv_stride = q_dim + 2 * kv_dim;
2088 let max_blocks_per_seq = self.scratch.paged_max_blocks_per_seq;
2089 let block_size = 16; let mut item_state: Vec<(u32, Vec<u32>)> = Vec::with_capacity(m);
2096 for (cache_id, _, _) in batch.iter() {
2097 let caches = self
2098 .kv_caches
2099 .get(cache_id)
2100 .expect("ensure_kv must be called before forward_layer_batched");
2101 let cache = &caches[li];
2102 item_state.push((cache.len as u32, cache.paged_block_indices.clone()));
2103 }
2104
2105 let q_head_major_size_bytes = (q_dim * std::mem::size_of::<f32>()) as u64;
2109 let qkv_stride_bytes = (qkv_stride * std::mem::size_of::<f32>()) as u64;
2110 for (i, (cache_id, _, pos)) in batch.iter().enumerate() {
2111 let pos_i = *pos as usize;
2112 let caches = self
2113 .kv_caches
2114 .get(cache_id)
2115 .expect("paged batched: cache not present");
2116 let cache = &caches[li];
2117 let bt = cache
2118 .block_table
2119 .as_ref()
2120 .expect("paged batched: block_table missing");
2121 let cache_len_before = cache.len;
2122 let block_table_ref = bt as *const B::Buffer;
2123 let bt_safe: &B::Buffer = unsafe { &*block_table_ref };
2127 B::split_qkv_norm_rope_into_paged_cache(
2128 ctx,
2129 &self.scratch.qkv_out,
2130 (i as u64) * qkv_stride_bytes,
2131 q_norm_w,
2132 k_norm_w,
2133 &self.rope.cos,
2134 &self.rope.sin,
2135 self.scratch
2136 .paged_batch_q
2137 .as_mut()
2138 .expect("paged_batch_q missing"),
2139 (i as u64) * q_head_major_size_bytes,
2140 pool_k,
2141 pool_v,
2142 bt_safe,
2143 1,
2144 nh,
2145 nkv,
2146 hd,
2147 pos_i,
2148 eps,
2149 qk_mode,
2150 cache_len_before,
2151 block_size,
2152 max_blocks_per_seq,
2153 )
2154 .expect("paged batched write");
2155 }
2156
2157 let mut stacked_bt: Vec<u32> = vec![0u32; m * max_blocks_per_seq];
2160 let mut stacked_cl: Vec<u32> = vec![0u32; m];
2161 for (i, (cache_id, _, _)) in batch.iter().enumerate() {
2162 let caches = self
2163 .kv_caches
2164 .get_mut(cache_id)
2165 .expect("paged batched: cache not present");
2166 let cache = &mut caches[li];
2167 cache.len += 1;
2168 let len = cache.len as u32;
2169 stacked_cl[i] = len;
2170 let blocks = &cache.paged_block_indices;
2171 let n_to_copy = blocks.len().min(max_blocks_per_seq);
2172 stacked_bt[i * max_blocks_per_seq..i * max_blocks_per_seq + n_to_copy]
2173 .copy_from_slice(&blocks[..n_to_copy]);
2174 }
2175 let bt_buf = self
2176 .scratch
2177 .paged_batch_block_tables
2178 .as_mut()
2179 .expect("paged_batch_block_tables missing");
2180 B::write_u32(ctx, bt_buf, &stacked_bt);
2181 let cl_buf = self
2182 .scratch
2183 .paged_batch_context_lens
2184 .as_mut()
2185 .expect("paged_batch_context_lens missing");
2186 B::write_u32(ctx, cl_buf, &stacked_cl);
2187
2188 let bt_ptr =
2190 self.scratch.paged_batch_block_tables.as_ref().unwrap() as *const B::Buffer;
2191 let cl_ptr =
2192 self.scratch.paged_batch_context_lens.as_ref().unwrap() as *const B::Buffer;
2193 let q_ptr = self.scratch.paged_batch_q.as_ref().unwrap() as *const B::Buffer;
2194 let o_ptr = self.scratch.paged_batch_o.as_mut().unwrap() as *mut B::Buffer;
2195 let bt_safe = unsafe { &*bt_ptr };
2198 let cl_safe = unsafe { &*cl_ptr };
2199 let q_safe = unsafe { &*q_ptr };
2200 let o_safe = unsafe { &mut *o_ptr };
2201 B::paged_decode_attention(
2202 ctx,
2203 q_safe,
2204 pool_k,
2205 pool_v,
2206 o_safe,
2207 bt_safe,
2208 cl_safe,
2209 m,
2210 nh,
2211 nkv,
2212 hd,
2213 block_size,
2214 max_blocks_per_seq,
2215 1, )
2217 .expect("paged batched decode");
2218
2219 for i in 0..m {
2223 B::copy_slice(
2224 ctx,
2225 self.scratch.paged_batch_o.as_ref().unwrap(),
2226 i * q_dim,
2227 &mut self.scratch.attn_flat,
2228 i * q_dim,
2229 q_dim,
2230 );
2231 }
2232
2233 return self.forward_layer_batched_decode_post_attn(ctx, li, residual, m);
2235 }
2236
2237 B::split_qkv(
2239 ctx,
2240 &self.scratch.qkv_out,
2241 &mut self.scratch.q_buf,
2242 &mut self.scratch.k_buf,
2243 &mut self.scratch.v_buf,
2244 m,
2245 q_dim,
2246 kv_dim,
2247 );
2248
2249 for (i, (cache_id, _token, pos)) in batch.iter().enumerate() {
2252 let pos_i = *pos as usize;
2253
2254 B::copy_slice(
2256 ctx,
2257 &self.scratch.q_buf,
2258 i * q_dim,
2259 &mut self.scratch.q_single,
2260 0,
2261 q_dim,
2262 );
2263 B::copy_slice(
2264 ctx,
2265 &self.scratch.k_buf,
2266 i * kv_dim,
2267 &mut self.scratch.k_single,
2268 0,
2269 kv_dim,
2270 );
2271 B::copy_slice(
2272 ctx,
2273 &self.scratch.v_buf,
2274 i * kv_dim,
2275 &mut self.scratch.v_single,
2276 0,
2277 kv_dim,
2278 );
2279
2280 B::qk_norm_rope(
2282 ctx,
2283 &self.scratch.q_single,
2284 q_norm_w,
2285 &self.rope.cos,
2286 &self.rope.sin,
2287 &mut self.scratch.q_head_major_single,
2288 1,
2289 nh,
2290 hd,
2291 pos_i,
2292 eps,
2293 qk_mode,
2294 );
2295 B::qk_norm_rope(
2296 ctx,
2297 &self.scratch.k_single,
2298 k_norm_w,
2299 &self.rope.cos,
2300 &self.rope.sin,
2301 &mut self.scratch.k_head_major_single,
2302 1,
2303 nkv,
2304 hd,
2305 pos_i,
2306 eps,
2307 qk_mode,
2308 );
2309 B::qk_norm_rope(
2310 ctx,
2311 &self.scratch.v_single,
2312 dummy_w,
2313 &self.rope.cos,
2314 &self.rope.sin,
2315 &mut self.scratch.v_head_major_single,
2316 1,
2317 nkv,
2318 hd,
2319 pos_i,
2320 eps,
2321 0,
2322 );
2323
2324 let caches = self
2326 .kv_caches
2327 .get_mut(cache_id)
2328 .expect("ensure_kv must be called before forward_layer_batched");
2329 let cache = &mut caches[li];
2330 B::kv_cache_append_head_major(
2331 ctx,
2332 &mut cache.k,
2333 &mut cache.v,
2334 cache.len,
2335 cache.capacity,
2336 &self.scratch.k_head_major_single,
2337 &self.scratch.v_head_major_single,
2338 1,
2339 nkv,
2340 hd,
2341 );
2342 cache.len += 1;
2343 let kv_len = cache.len;
2344 let kv_stride = cache.capacity;
2345
2346 let attn_cfg = ferrum_kernels::backend::AttnConfig {
2347 num_heads: nh,
2348 num_kv_heads: nkv,
2349 head_dim: hd,
2350 causal: true,
2351 scale: 1.0 / (hd as f32).sqrt(),
2352 kv_seq_stride: kv_stride,
2353 sliding_window: cfg.sliding_window,
2354 };
2355 B::flash_attention(
2356 ctx,
2357 &self.scratch.q_head_major_single,
2358 &cache.k,
2359 &cache.v,
2360 &mut self.scratch.attn_head_major_single,
2361 1,
2362 1,
2363 kv_len,
2364 pos_i,
2365 &attn_cfg,
2366 );
2367
2368 B::copy_slice(
2376 ctx,
2377 &self.scratch.attn_head_major_single,
2378 0,
2379 &mut self.scratch.attn_flat,
2380 i * q_dim,
2381 q_dim,
2382 );
2383 }
2384
2385 self.forward_layer_batched_decode_post_attn(ctx, li, residual, m);
2386 }
2387
2388 fn forward_layer_batched_decode_post_attn(
2389 &mut self,
2390 ctx: &mut B::Context,
2391 li: usize,
2392 residual: &mut B::Buffer,
2393 m: usize,
2394 ) {
2395 let cfg = &self.cfg;
2396 let h = cfg.hidden_size;
2397 let im = cfg.intermediate_size;
2398 let eps = cfg.rms_norm_eps;
2399 let layer = &self.layers[li];
2400
2401 layer.o_proj.forward(
2403 ctx,
2404 &self.scratch.attn_flat,
2405 &mut self.scratch.o_proj_out,
2406 m,
2407 );
2408
2409 B::fused_add_rms_norm(
2411 ctx,
2412 residual,
2413 &self.scratch.o_proj_out,
2414 &layer.post_ln_w,
2415 eps,
2416 &mut self.scratch.norm_out,
2417 m,
2418 h,
2419 );
2420
2421 layer.gate_up_proj.forward(
2423 ctx,
2424 &self.scratch.norm_out,
2425 &mut self.scratch.gate_up_out,
2426 m,
2427 );
2428
2429 B::fused_silu_mul_split(
2431 ctx,
2432 &self.scratch.gate_up_out,
2433 &mut self.scratch.silu_out,
2434 m,
2435 im,
2436 );
2437
2438 layer
2440 .down_proj
2441 .forward(ctx, &self.scratch.silu_out, &mut self.scratch.mlp_out, m);
2442
2443 B::add_inplace(ctx, residual, &self.scratch.mlp_out, m * h);
2445 }
2446}
2447
2448impl<B: Backend> DecoderOnlyLLM for LlamaFamilyModel<B> {
2449 fn config(&self) -> &LlmRuntimeConfig {
2450 &self.runtime_cfg
2451 }
2452
2453 fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
2454 self.ensure_scratch(max_tokens);
2459 self.ensure_kv(cache_id);
2460
2461 const WARMUP_CACHE: &str = "__ferrum_warmup__";
2462 let _ = self.prefill_internal(WARMUP_CACHE, &[0u32]);
2463 if let Some(mut caches) = self.kv_caches.remove(WARMUP_CACHE) {
2467 if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2468 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2469 if let Some(c0) = caches.first() {
2470 if !c0.paged_block_indices.is_empty() {
2471 alloc.free(&c0.paged_block_indices);
2472 }
2473 }
2474 for c in caches.iter_mut() {
2475 c.paged_block_indices.clear();
2476 }
2477 }
2478 self.kv_free_pool.push(caches);
2479 }
2480 }
2481
2482 fn kv_capacity(&self) -> usize {
2483 let model_max = self.cfg.max_seq_len;
2485 const DEFAULT_KV_CAPACITY: usize = 4096;
2486 std::env::var("FERRUM_KV_CAPACITY")
2487 .ok()
2488 .and_then(|s| s.parse::<usize>().ok())
2489 .map(|cap| cap.min(model_max))
2490 .unwrap_or_else(|| model_max.min(DEFAULT_KV_CAPACITY))
2491 }
2492
2493 fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2494 self.prefill_internal(cache_id, tokens)
2495 }
2496
2497 fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
2498 self.decode_internal(cache_id, token, pos)
2499 }
2500
2501 fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
2502 self.decode_batch_internal(batch)
2503 }
2504
2505 fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2506 LlamaFamilyModel::<B>::forward_verify(self, cache_id, tokens)
2508 }
2509
2510 fn truncate_kv(&mut self, cache_id: &str, new_len: usize) {
2511 if let Some(caches) = self.kv_caches.get_mut(cache_id) {
2512 for c in caches.iter_mut() {
2513 if new_len < c.len {
2514 c.len = new_len;
2515 }
2516 }
2517 }
2518 let mut ctx = B::new_context();
2520 B::reset_graph(&mut ctx);
2521 self.graph_warmup = 0;
2522 self.graph_capture_failed = false;
2523 }
2524
2525 fn release(&mut self, cache_id: &str) {
2526 let mut ctx = B::new_context();
2532 B::sync(&mut ctx);
2533 B::reset_graph(&mut ctx);
2534 B::sync(&mut ctx);
2535 self.graph_warmup = 0;
2536 self.graph_capture_failed = false;
2537
2538 if let Some(mut caches) = self.kv_caches.remove(cache_id) {
2543 if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2544 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2545 if let Some(c0) = caches.first() {
2549 if !c0.paged_block_indices.is_empty() {
2550 alloc.free(&c0.paged_block_indices);
2551 }
2552 }
2553 for c in caches.iter_mut() {
2556 c.paged_block_indices.clear();
2557 }
2558 }
2559 self.kv_free_pool.push(caches);
2560 }
2561 }
2562
2563 fn reset(&mut self) {
2564 let mut ctx = B::new_context();
2566 B::sync(&mut ctx);
2567 B::reset_graph(&mut ctx);
2568 B::sync(&mut ctx);
2569 self.graph_warmup = 0;
2570 self.graph_capture_failed = false;
2571 self.kv_caches.clear();
2572 self.kv_free_pool.clear();
2573 }
2574}
2575
2576fn build_rope_cache<B: Backend>(cfg: &LlamaFamilyConfig) -> RopeCache<B> {
2577 let hd = cfg.head_dim;
2578 let half = hd / 2;
2579 let max = cfg.max_seq_len;
2580 let mut cos = vec![0.0f32; max * half];
2581 let mut sin = vec![0.0f32; max * half];
2582 for pos in 0..max {
2583 for i in 0..half {
2584 let freq = 1.0f64 / cfg.rope_theta.powf((2 * i) as f64 / hd as f64);
2585 let angle = pos as f64 * freq;
2586 cos[pos * half + i] = angle.cos() as f32;
2587 sin[pos * half + i] = angle.sin() as f32;
2588 }
2589 }
2590 RopeCache {
2591 cos: B::from_slice(&cos),
2592 sin: B::from_slice(&sin),
2593 }
2594}