1use std::collections::HashMap;
20use std::sync::{atomic::AtomicU64, OnceLock};
21
22use ferrum_interfaces::kv_dtype::{KvFp16, KvInt8};
23use ferrum_kernels::backend::{
24 Backend, BackendGraph, BackendInt8KvOps, BackendMoeFused, BackendPagedKv, BackendQuantGguf,
25 BackendQuantMarlin, KvCache, KvLayer, LlmBackend, MoeLlmBackend, QuantLlmBackend,
26 MAX_LAYERS_FOR_GRAPH,
27};
28
/// Graph key reserved for the single-item path.
///
/// NOTE(review): nothing in this part of the file reads the key — confirm its
/// exact meaning against the graph capture/replay code.
pub(crate) const SINGLE_ITEM_GRAPH_KEY: u64 = 0;

// Batched-graph counters. They start at zero; the code that increments them is
// not in this part of the file (presumably replay vs. eager fallback — confirm).
pub(crate) static BATCHED_GRAPH_REPLAY_COUNT: AtomicU64 = AtomicU64::new(0);
pub(crate) static BATCHED_GRAPH_EAGER_COUNT: AtomicU64 = AtomicU64::new(0);

// Per-op profiling accumulators. Each `*_TIME_US` sums elapsed microseconds and
// the matching `*_CALLS` counts how many samples were added. `forward_layer`
// feeds the ATTN / QKR / MATMUL / NORM pairs when decode op profiling is on.

// Attention.
pub(crate) static ATTN_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static ATTN_CALLS: AtomicU64 = AtomicU64::new(0);

// Contiguous KV write step.
pub(crate) static QKR_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static QKR_CALLS: AtomicU64 = AtomicU64::new(0);

// Linear projections.
pub(crate) static MATMUL_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static MATMUL_CALLS: AtomicU64 = AtomicU64::new(0);

// RMS norm.
pub(crate) static NORM_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static NORM_CALLS: AtomicU64 = AtomicU64::new(0);

// Everything else. NOTE(review): no writer is visible in this part of the file.
pub(crate) static OTHER_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static OTHER_CALLS: AtomicU64 = AtomicU64::new(0);
47use ferrum_quantization::{Linear, WeightLoader};
48use ferrum_types::Result;
49
50use crate::common::{DecoderOnlyLLM, LlmRuntimeConfig};
51
52const DEFAULT_KV_CAPACITY: usize = 512;
53
54#[derive(Debug, Clone, PartialEq, Eq)]
55struct LlamaFamilyRuntimeEnv {
56 kv_capacity: Option<usize>,
57 metal_paged_kv: Option<bool>,
58 paged_max_seqs: usize,
59 decode_op_profile: bool,
60 prefill_op_profile: bool,
61 cuda_graph: bool,
62 decode_layer_profile: bool,
63}
64
65impl LlamaFamilyRuntimeEnv {
66 fn from_env() -> Self {
67 Self::from_env_vars(std::env::vars())
68 }
69
70 fn from_env_vars<I, K, V>(vars: I) -> Self
71 where
72 I: IntoIterator<Item = (K, V)>,
73 K: AsRef<str>,
74 V: AsRef<str>,
75 {
76 let mut config = Self {
77 kv_capacity: None,
78 metal_paged_kv: None,
79 paged_max_seqs: 32,
80 decode_op_profile: false,
81 prefill_op_profile: false,
82 cuda_graph: false,
83 decode_layer_profile: false,
84 };
85 for (name, value) in vars {
86 let value = value.as_ref();
87 match name.as_ref() {
88 "FERRUM_KV_CAPACITY" => config.kv_capacity = value.parse::<usize>().ok(),
89 "FERRUM_METAL_PAGED_KV" => config.metal_paged_kv = Some(value != "0"),
90 "FERRUM_PAGED_MAX_SEQS" => {
91 if let Ok(max_seqs) = value.parse::<usize>() {
92 config.paged_max_seqs = max_seqs;
93 }
94 }
95 "FERRUM_DECODE_OP_PROFILE" => config.decode_op_profile = true,
96 "FERRUM_PREFILL_OP_PROFILE" => config.prefill_op_profile = true,
97 "FERRUM_CUDA_GRAPH" => config.cuda_graph = true,
98 "FERRUM_DECODE_LAYER_PROFILE" => config.decode_layer_profile = true,
99 _ => {}
100 }
101 }
102 config
103 }
104
105 fn kv_capacity_for_model(&self, model_max: usize) -> usize {
106 self.kv_capacity
107 .map(|cap| cap.min(model_max))
108 .unwrap_or_else(|| model_max.min(DEFAULT_KV_CAPACITY))
109 }
110
111 fn paged_kv_enabled<B: BackendPagedKv>(&self) -> bool {
112 self.metal_paged_kv
113 .unwrap_or_else(|| B::supports_paged_kv())
114 }
115}
116
117fn llama_family_runtime_env() -> &'static LlamaFamilyRuntimeEnv {
118 static CONFIG: OnceLock<LlamaFamilyRuntimeEnv> = OnceLock::new();
119 CONFIG.get_or_init(LlamaFamilyRuntimeEnv::from_env)
120}
121
122#[derive(Clone, Debug, PartialEq)]
125pub struct LlamaFamilyConfig {
126 pub hidden_size: usize,
127 pub intermediate_size: usize,
128 pub num_heads: usize,
129 pub num_kv_heads: usize,
130 pub head_dim: usize,
131 pub num_layers: usize,
132 pub vocab_size: usize,
133 pub max_seq_len: usize,
134 pub rms_norm_eps: f32,
135 pub rope_theta: f64,
136 pub has_qk_norm: bool,
139 pub sliding_window: usize,
142}
143
144impl LlamaFamilyConfig {
145 pub fn to_runtime(&self) -> LlmRuntimeConfig {
146 LlmRuntimeConfig {
147 hidden_size: self.hidden_size,
148 num_layers: self.num_layers,
149 num_kv_heads: self.num_kv_heads,
150 head_dim: self.head_dim,
151 vocab_size: self.vocab_size,
152 max_seq_len: self.max_seq_len,
153 }
154 }
155
156 fn from_def_base(def: &crate::definition::ModelDefinition) -> LlamaFamilyConfigBase {
160 let num_kv_heads = def.num_key_value_heads.unwrap_or(def.num_attention_heads);
161 let head_dim = def
162 .extra_params
163 .get("head_dim")
164 .and_then(|v| v.as_u64())
165 .map(|v| v as usize)
166 .unwrap_or(def.hidden_size / def.num_attention_heads);
167 let sliding_window = def
170 .extra_params
171 .get("sliding_window")
172 .and_then(|v| v.as_u64())
173 .map(|v| v as usize)
174 .unwrap_or(0);
175
176 LlamaFamilyConfigBase {
177 hidden_size: def.hidden_size,
178 intermediate_size: def.intermediate_size,
179 num_heads: def.num_attention_heads,
180 num_kv_heads,
181 head_dim,
182 num_layers: def.num_hidden_layers,
183 vocab_size: def.vocab_size,
184 max_seq_len: def.max_position_embeddings,
185 rms_norm_eps: def.norm_eps as f32,
186 rope_theta_opt: def.rope_theta,
187 sliding_window,
188 }
189 }
190
191 fn from_base(b: LlamaFamilyConfigBase, rope_default: f64, has_qk_norm: bool) -> Self {
192 Self {
193 hidden_size: b.hidden_size,
194 intermediate_size: b.intermediate_size,
195 num_heads: b.num_heads,
196 num_kv_heads: b.num_kv_heads,
197 head_dim: b.head_dim,
198 num_layers: b.num_layers,
199 vocab_size: b.vocab_size,
200 max_seq_len: b.max_seq_len,
201 rms_norm_eps: b.rms_norm_eps,
202 rope_theta: b.rope_theta_opt.unwrap_or(rope_default),
203 has_qk_norm,
204 sliding_window: b.sliding_window,
205 }
206 }
207
208 pub fn qwen3_from_def(def: &crate::definition::ModelDefinition) -> Self {
210 Self::from_base(Self::from_def_base(def), 1_000_000.0, true)
211 }
212
213 pub fn llama_from_def(def: &crate::definition::ModelDefinition) -> Self {
217 Self::from_base(Self::from_def_base(def), 500_000.0, false)
218 }
219
220 pub fn qwen2_from_def(def: &crate::definition::ModelDefinition) -> Self {
222 Self::from_base(Self::from_def_base(def), 1_000_000.0, false)
223 }
224
225 pub fn mistral_from_def(def: &crate::definition::ModelDefinition) -> Self {
229 Self::from_base(Self::from_def_base(def), 10_000.0, false)
230 }
231}
232
233struct LlamaFamilyConfigBase {
234 hidden_size: usize,
235 intermediate_size: usize,
236 num_heads: usize,
237 num_kv_heads: usize,
238 head_dim: usize,
239 num_layers: usize,
240 vocab_size: usize,
241 max_seq_len: usize,
242 rms_norm_eps: f32,
243 rope_theta_opt: Option<f64>,
244 sliding_window: usize,
245}
246
247pub struct LlamaFamilyLayer<B: QuantLlmBackend + BackendMoeFused> {
250 pub input_ln_w: B::Buffer,
251 pub qkv_proj: Box<dyn Linear<B>>,
252 pub q_norm_w: Option<B::Buffer>,
254 pub k_norm_w: Option<B::Buffer>,
255 pub o_proj: Box<dyn Linear<B>>,
256 pub post_ln_w: B::Buffer,
257 pub gate_up_proj: Box<dyn Linear<B>>,
258 pub down_proj: Box<dyn Linear<B>>,
259}
260
261pub struct RopeCache<B: QuantLlmBackend + BackendMoeFused> {
263 pub cos: B::Buffer,
264 pub sin: B::Buffer,
265}
266
267pub struct LlamaFamilyScratch<B: QuantLlmBackend + BackendMoeFused> {
273 pub residual: Option<B::Buffer>,
284 pub norm_out: B::Buffer,
285 pub qkv_out: B::Buffer,
286 pub q_single: B::Buffer,
294 pub k_single: B::Buffer,
295 pub v_single: B::Buffer,
296 pub q_head_major_single: B::Buffer,
297 pub k_head_major_single: B::Buffer,
298 pub v_head_major_single: B::Buffer,
299 pub attn_head_major_single: B::Buffer,
300 pub attn_flat_single: B::Buffer,
301 pub batch_logits: B::Buffer,
304 pub q_buf: B::Buffer,
306 pub k_buf: B::Buffer,
307 pub v_buf: B::Buffer,
308 pub q_head_major: B::Buffer,
310 pub k_head_major: B::Buffer,
313 pub v_head_major: B::Buffer,
314 pub attn_head_major_out: B::Buffer,
316 pub attn_flat: B::Buffer,
318 pub o_proj_out: B::Buffer,
319 pub gate_up_out: B::Buffer,
320 pub silu_out: B::Buffer,
321 pub mlp_out: B::Buffer,
322 pub paged_batch_q: Option<B::Buffer>,
328 pub paged_batch_o: Option<B::Buffer>,
329 pub paged_batch_block_tables: Option<B::Buffer>,
333 pub paged_batch_context_lens: Option<B::Buffer>,
336 pub paged_max_blocks_per_seq: usize,
339 pub paged_max_seqs: usize,
345 pub batch_positions: B::Buffer,
350 pub batch_tokens: B::Buffer,
354 pub batch_kv_lens_pre: B::Buffer,
358 pub batch_kv_lens_post: B::Buffer,
363 pub q_normed_batched: B::Buffer,
367 pub k_normed_batched: B::Buffer,
368 pub v_normed_batched: B::Buffer,
369
370 pub unified_capacity: usize, pub unified_residual: Option<B::Buffer>,
377 pub unified_norm_out: Option<B::Buffer>,
378 pub unified_qkv_out: Option<B::Buffer>,
379 pub unified_packed_q: Option<B::Buffer>,
380 pub unified_attn_out: Option<B::Buffer>,
381 pub unified_o_proj_out: Option<B::Buffer>,
382 pub unified_gate_up_out: Option<B::Buffer>,
383 pub unified_silu_out: Option<B::Buffer>,
384 pub unified_mlp_out: Option<B::Buffer>,
385 pub unified_cu_seqlens_q: Option<B::Buffer>,
390 pub unified_pos_offsets: Option<B::Buffer>,
391 pub unified_block_tables: Option<B::Buffer>,
392 pub unified_packed_normed: Option<B::Buffer>,
395 pub unified_packed_logits: Option<B::Buffer>,
397 pub last_hidden: B::Buffer,
401 pub last_normed: B::Buffer,
403 pub logits: B::Buffer,
405 pub max_tokens: usize,
407}
408
409impl<B: QuantLlmBackend + BackendMoeFused> LlamaFamilyScratch<B> {
410 fn alloc(cfg: &LlamaFamilyConfig, max_tokens: usize) -> Self {
411 let h = cfg.hidden_size;
412 let im = cfg.intermediate_size;
413 let q_dim = cfg.num_heads * cfg.head_dim;
414 let kv_dim = cfg.num_kv_heads * cfg.head_dim;
415 let qkv_dim = q_dim + 2 * kv_dim;
416 let t = max_tokens;
417 Self {
418 residual: Some(B::alloc(t * h)),
419 norm_out: B::alloc(t * h),
420 qkv_out: B::alloc(t * qkv_dim),
421 q_buf: B::alloc(t * q_dim),
422 k_buf: B::alloc(t * kv_dim),
423 v_buf: B::alloc(t * kv_dim),
424 q_head_major: B::alloc(cfg.num_heads * t * cfg.head_dim),
425 k_head_major: B::alloc(cfg.num_kv_heads * t * cfg.head_dim),
426 v_head_major: B::alloc(cfg.num_kv_heads * t * cfg.head_dim),
427 attn_head_major_out: B::alloc(cfg.num_heads * t * cfg.head_dim),
428 attn_flat: B::alloc(t * q_dim),
429 o_proj_out: B::alloc(t * h),
430 gate_up_out: B::alloc(t * 2 * im),
431 silu_out: B::alloc(t * im),
432 mlp_out: B::alloc(t * h),
433 last_hidden: B::alloc(h),
434 last_normed: B::alloc(h),
435 logits: B::alloc(cfg.vocab_size),
436 q_single: B::alloc(q_dim),
437 k_single: B::alloc(kv_dim),
438 v_single: B::alloc(kv_dim),
439 q_head_major_single: B::alloc(q_dim),
440 k_head_major_single: B::alloc(kv_dim),
441 v_head_major_single: B::alloc(kv_dim),
442 attn_head_major_single: B::alloc(q_dim),
443 attn_flat_single: B::alloc(q_dim),
444 batch_logits: B::alloc(t * cfg.vocab_size),
445 paged_batch_q: None,
450 paged_batch_o: None,
451 paged_batch_block_tables: None,
452 paged_batch_context_lens: None,
453 paged_max_blocks_per_seq: 0,
454 paged_max_seqs: 0,
455 batch_positions: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
456 batch_tokens: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
457 batch_kv_lens_pre: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
458 batch_kv_lens_post: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
459 q_normed_batched: B::alloc(t * q_dim),
460 k_normed_batched: B::alloc(t * kv_dim),
461 v_normed_batched: B::alloc(t * kv_dim),
462 unified_capacity: 0,
463 unified_residual: None,
464 unified_norm_out: None,
465 unified_qkv_out: None,
466 unified_packed_q: None,
467 unified_attn_out: None,
468 unified_o_proj_out: None,
469 unified_gate_up_out: None,
470 unified_silu_out: None,
471 unified_mlp_out: None,
472 unified_cu_seqlens_q: None,
473 unified_pos_offsets: None,
474 unified_block_tables: None,
475 unified_packed_normed: None,
476 unified_packed_logits: None,
477 max_tokens: t,
478 }
479 }
480
481 pub(crate) fn ensure_unified_scratch(
485 &mut self,
486 cfg: &LlamaFamilyConfig,
487 m_total: usize,
488 max_seqs: usize,
489 max_blocks_per_seq: usize,
490 ) {
491 if m_total <= self.unified_capacity
492 && self.unified_residual.is_some()
493 && self.unified_cu_seqlens_q.is_some()
494 {
495 return;
496 }
497 let cap = m_total.max(self.unified_capacity).max(1);
498 let h = cfg.hidden_size;
499 let q_dim = cfg.num_heads * cfg.head_dim;
500 let kv_dim = cfg.num_kv_heads * cfg.head_dim;
501 let qkv_dim = q_dim + 2 * kv_dim;
502 let im = cfg.intermediate_size;
503 let v = cfg.vocab_size;
504 self.unified_residual = Some(B::alloc(cap * h));
505 self.unified_norm_out = Some(B::alloc(cap * h));
506 self.unified_qkv_out = Some(B::alloc(cap * qkv_dim));
507 self.unified_packed_q = Some(B::alloc(cap * q_dim));
508 self.unified_attn_out = Some(B::alloc(cap * q_dim));
509 self.unified_o_proj_out = Some(B::alloc(cap * h));
510 self.unified_gate_up_out = Some(B::alloc(cap * 2 * im));
511 self.unified_silu_out = Some(B::alloc(cap * im));
512 self.unified_mlp_out = Some(B::alloc(cap * h));
513 if self.unified_cu_seqlens_q.is_none() {
514 self.unified_cu_seqlens_q = Some(B::alloc_typed(
515 ferrum_kernels::backend::Dtype::U32,
516 max_seqs + 1,
517 ));
518 self.unified_pos_offsets = Some(B::alloc_typed(
519 ferrum_kernels::backend::Dtype::U32,
520 max_seqs,
521 ));
522 self.unified_block_tables = Some(B::alloc_typed(
523 ferrum_kernels::backend::Dtype::U32,
524 max_seqs * max_blocks_per_seq,
525 ));
526 self.unified_packed_normed = Some(B::alloc(max_seqs * h));
527 self.unified_packed_logits = Some(B::alloc(max_seqs * v));
528 }
529 self.unified_capacity = cap;
530 }
531
532 fn enable_paged_batch(
536 &mut self,
537 cfg: &LlamaFamilyConfig,
538 max_seqs: usize,
539 max_blocks_per_seq: usize,
540 ) {
541 if self.paged_batch_q.is_some() {
542 return;
543 }
544 let q_dim = cfg.num_heads * cfg.head_dim;
545 self.paged_batch_q = Some(B::alloc(max_seqs * q_dim));
546 self.paged_batch_o = Some(B::alloc(max_seqs * q_dim));
547 self.paged_batch_block_tables = Some(B::alloc_typed(
548 ferrum_kernels::backend::Dtype::U32,
549 max_seqs * max_blocks_per_seq,
550 ));
551 self.paged_batch_context_lens = Some(B::alloc_typed(
552 ferrum_kernels::backend::Dtype::U32,
553 max_seqs,
554 ));
555 self.paged_max_blocks_per_seq = max_blocks_per_seq;
556 self.paged_max_seqs = max_seqs;
557 }
558}
559
560pub struct LlamaFamilyModel<B: MoeLlmBackend, K: KvLayer<B> = KvFp16> {
575 pub cfg: LlamaFamilyConfig,
576 pub runtime_cfg: LlmRuntimeConfig,
577
578 pub embed: Option<B::Buffer>,
582 pub layers: Vec<LlamaFamilyLayer<B>>,
583 pub final_norm_w: B::Buffer,
584 pub lm_head: Option<Box<dyn Linear<B>>>,
586
587 pub rope: RopeCache<B>,
588 pub scratch: LlamaFamilyScratch<B>,
589
590 pub kv_caches: HashMap<String, Vec<K::Layer>>,
599 kv_free_pool: Vec<Vec<K::Layer>>,
604
605 pub paged_pools: Option<Vec<(B::Buffer, B::Buffer)>>,
617 pub paged_block_alloc: Option<std::sync::Mutex<crate::common::paged_pool::BlockAllocator>>,
621 pub paged_dims: Option<(usize, usize)>,
627
628 pub(crate) graph_warmup: usize,
632 pub(crate) graph_capture_failed: bool,
635 pub(crate) batched_graph_warmup: usize,
637 pub(crate) batched_graph_failed: bool,
639 pub(crate) batched_graph_keys_seen: std::collections::HashSet<u64>,
643 pub(crate) batched_pointers_for: Option<Vec<String>>,
648 pub(crate) unified_graph_warmup: usize,
653 pub(crate) unified_graph_failed: bool,
654 pub(crate) unified_graph_keys_seen: std::collections::HashSet<u64>,
655}
656
657impl<B: MoeLlmBackend, K: KvLayer<B>> LlamaFamilyModel<B, K> {
658 pub fn new(cfg: LlamaFamilyConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
663 {
668 let mut ctx = B::new_context();
669 B::reset_all_graphs(&mut ctx);
670 }
671 let rope = build_rope_cache::<B>(&cfg);
672 let scratch = LlamaFamilyScratch::alloc(&cfg, 1); let embed = loader.load_tensor("model.embed_tokens.weight")?;
676
677 let mut layers = Vec::with_capacity(cfg.num_layers);
679 for li in 0..cfg.num_layers {
680 let prefix = format!("model.layers.{li}");
681 let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
682 let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
683 let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
684 let post_ln_w =
685 loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
686 let gate_up_proj = loader.load_linear(&format!("{prefix}.mlp.gate_up_proj"))?;
687 let down_proj = loader.load_linear(&format!("{prefix}.mlp.down_proj"))?;
688
689 let (q_norm_w, k_norm_w) = if cfg.has_qk_norm {
690 let q = loader
691 .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
692 .ok();
693 let k = loader
694 .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
695 .ok();
696 (q, k)
697 } else {
698 (None, None)
699 };
700
701 layers.push(LlamaFamilyLayer {
702 input_ln_w,
703 qkv_proj,
704 q_norm_w,
705 k_norm_w,
706 o_proj,
707 post_ln_w,
708 gate_up_proj,
709 down_proj,
710 });
711 }
712
713 let final_norm_w = loader.load_tensor("model.norm.weight")?;
714
715 let lm_head = if loader.has_tensor("lm_head.weight") {
723 loader.load_linear("lm_head")?
724 } else {
725 tracing::info!(
726 "LlamaFamilyModel: tied embeddings — loading model.embed_tokens.weight as lm_head"
727 );
728 let as_linear = loader.load_linear("model.embed_tokens")?;
729 if as_linear.out_features() != cfg.vocab_size
731 || as_linear.in_features() != cfg.hidden_size
732 {
733 return Err(ferrum_types::FerrumError::model(format!(
734 "tied embed shape mismatch: got [{}, {}], expected [{}, {}]",
735 as_linear.out_features(),
736 as_linear.in_features(),
737 cfg.vocab_size,
738 cfg.hidden_size
739 )));
740 }
741 as_linear
742 };
743
744 let runtime_cfg = cfg.to_runtime();
745 Ok(Self {
746 cfg,
747 runtime_cfg,
748 embed: Some(embed),
749 layers,
750 final_norm_w,
751 lm_head: Some(lm_head),
752 rope,
753 scratch,
754 kv_caches: HashMap::new(),
755 kv_free_pool: Vec::new(),
756 paged_pools: None,
757 paged_block_alloc: None,
758 paged_dims: None,
759 graph_warmup: 0,
760 graph_capture_failed: false,
761 batched_graph_warmup: 0,
762 batched_graph_failed: false,
763 batched_graph_keys_seen: std::collections::HashSet::new(),
764 batched_pointers_for: None,
765 unified_graph_warmup: 0,
766 unified_graph_failed: false,
767 unified_graph_keys_seen: std::collections::HashSet::new(),
768 })
769 }
770
771 pub fn new_backbone_only(cfg: LlamaFamilyConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
783 {
785 let mut ctx = B::new_context();
786 B::reset_all_graphs(&mut ctx);
787 }
788 let rope = build_rope_cache::<B>(&cfg);
789 let scratch = LlamaFamilyScratch::alloc(&cfg, 1);
790
791 let mut layers = Vec::with_capacity(cfg.num_layers);
792 for li in 0..cfg.num_layers {
793 let prefix = format!("model.layers.{li}");
794 let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
795 let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
796 let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
797 let post_ln_w =
798 loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
799 let gate_up_proj = loader.load_linear(&format!("{prefix}.mlp.gate_up_proj"))?;
800 let down_proj = loader.load_linear(&format!("{prefix}.mlp.down_proj"))?;
801
802 let (q_norm_w, k_norm_w) = if cfg.has_qk_norm {
803 let q = loader
804 .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
805 .ok();
806 let k = loader
807 .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
808 .ok();
809 (q, k)
810 } else {
811 (None, None)
812 };
813
814 layers.push(LlamaFamilyLayer {
815 input_ln_w,
816 qkv_proj,
817 q_norm_w,
818 k_norm_w,
819 o_proj,
820 post_ln_w,
821 gate_up_proj,
822 down_proj,
823 });
824 }
825
826 let final_norm_w = loader.load_tensor("model.norm.weight")?;
827
828 let runtime_cfg = cfg.to_runtime();
829 Ok(Self {
830 cfg,
831 runtime_cfg,
832 embed: None,
833 layers,
834 final_norm_w,
835 lm_head: None,
836 rope,
837 scratch,
838 kv_caches: HashMap::new(),
839 kv_free_pool: Vec::new(),
840 paged_pools: None,
841 paged_block_alloc: None,
842 paged_dims: None,
843 graph_warmup: 0,
844 graph_capture_failed: false,
845 batched_graph_warmup: 0,
846 batched_graph_failed: false,
847 batched_graph_keys_seen: std::collections::HashSet::new(),
848 batched_pointers_for: None,
849 unified_graph_warmup: 0,
850 unified_graph_failed: false,
851 unified_graph_keys_seen: std::collections::HashSet::new(),
852 })
853 }
854
855 pub(crate) fn ensure_scratch(&mut self, tokens: usize) {
857 if self.scratch.max_tokens < tokens {
858 {
863 let mut ctx = B::new_context();
864 B::reset_all_graphs(&mut ctx);
865 }
866 self.scratch = LlamaFamilyScratch::alloc(&self.cfg, tokens);
867 self.graph_warmup = 0;
868 self.graph_capture_failed = false;
869 self.batched_graph_keys_seen.clear();
870 self.batched_graph_warmup = 0;
871 self.batched_graph_failed = false;
872 self.unified_graph_keys_seen.clear();
873 self.unified_graph_warmup = 0;
874 self.unified_graph_failed = false;
875 if let Some((max_seqs, max_blocks_per_seq)) = self.paged_dims {
880 self.scratch
881 .enable_paged_batch(&self.cfg, max_seqs, max_blocks_per_seq);
882 }
883 }
884 }
885
886 pub(crate) fn ensure_kv(&mut self, cache_id: &str) {
890 if self.kv_caches.contains_key(cache_id) {
891 return;
892 }
893 let nkv = self.cfg.num_kv_heads;
894 let hd = self.cfg.head_dim;
895 let model_max = self.cfg.max_seq_len;
902 let runtime_env = llama_family_runtime_env();
911 let max = runtime_env.kv_capacity_for_model(model_max);
912
913 let paged = runtime_env.paged_kv_enabled::<B>();
929 const PAGED_BLOCK_SIZE: usize = 16;
930
931 let max_seqs = runtime_env.paged_max_seqs;
939 let max_blocks_per_seq = max.div_ceil(PAGED_BLOCK_SIZE);
940 let total_pool_blocks = max_seqs * max_blocks_per_seq;
941
942 if paged && self.paged_pools.is_none() {
949 let mut pools = Vec::with_capacity(self.cfg.num_layers);
950 for _ in 0..self.cfg.num_layers {
951 let pool_floats = total_pool_blocks * nkv * PAGED_BLOCK_SIZE * hd;
952 pools.push((B::alloc(pool_floats), B::alloc(pool_floats)));
953 }
954 self.paged_pools = Some(pools);
955 self.paged_block_alloc = Some(std::sync::Mutex::new(
956 crate::common::paged_pool::BlockAllocator::new(total_pool_blocks as u32),
957 ));
958 }
959 if paged {
965 self.scratch
966 .enable_paged_batch(&self.cfg, max_seqs, max_blocks_per_seq);
967 self.paged_dims = Some((max_seqs, max_blocks_per_seq));
970 }
971
972 let mut caches = self.kv_free_pool.pop().unwrap_or_else(|| {
980 (0..self.cfg.num_layers)
981 .map(|_| {
982 if paged {
983 K::alloc_paged(max_blocks_per_seq, PAGED_BLOCK_SIZE, nkv, hd)
984 } else {
985 K::alloc_contig(max, nkv, hd)
986 }
987 })
988 .collect()
989 });
990
991 if paged {
997 let alloc_arc = self
998 .paged_block_alloc
999 .as_ref()
1000 .expect("paged_block_alloc must be initialised when paged=true");
1001 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
1005 let block_indices = match alloc.allocate_n(max_blocks_per_seq) {
1006 Ok(idx) => idx,
1007 Err(e) => {
1008 drop(alloc);
1015 self.kv_free_pool.push(caches);
1016 eprintln!(
1017 "[ferrum] paged KV pool exhausted on ensure_kv for \
1018 cache_id={cache_id:?}: {e}. Increase \
1019 FERRUM_PAGED_MAX_SEQS (currently {max_seqs}) or \
1020 throttle concurrent requests.",
1021 );
1022 return;
1023 }
1024 };
1025 let mut padded = block_indices.clone();
1030 padded.resize(max_blocks_per_seq, 0);
1031 let mut ctx_tmp = B::new_context();
1032 for c in caches.iter_mut() {
1033 if let Some(bt) = K::block_table_mut(c) {
1034 B::write_typed::<u32>(&mut ctx_tmp, bt, &padded);
1035 }
1036 *K::paged_block_indices_mut(c) = block_indices.clone();
1037 }
1038 B::sync(&mut ctx_tmp);
1039 }
1040
1041 for c in caches.iter_mut() {
1045 K::set_len(c, 0);
1046 if let Some(cl) = K::context_lens_mut(c) {
1047 let mut ctx_tmp = B::new_context();
1048 B::write_typed::<u32>(&mut ctx_tmp, cl, &[0u32]);
1049 B::sync(&mut ctx_tmp);
1050 }
1051 }
1052 self.kv_caches.insert(cache_id.to_string(), caches);
1053 }
1054
1055 #[allow(clippy::too_many_arguments)]
1060 pub(crate) fn forward_layer(
1061 &mut self,
1062 ctx: &mut B::Context,
1063 li: usize,
1064 cache_id: &str,
1065 residual: &mut B::Buffer,
1066 pos_offset: usize,
1067 tokens: usize,
1068 ) {
1069 let layer = &self.layers[li];
1070 let cfg = &self.cfg;
1071 let h = cfg.hidden_size;
1072 let nh = cfg.num_heads;
1073 let nkv = cfg.num_kv_heads;
1074 let hd = cfg.head_dim;
1075 let im = cfg.intermediate_size;
1076 let eps = cfg.rms_norm_eps;
1077 let q_dim = nh * hd;
1078 let kv_dim = nkv * hd;
1079
1080 let _t0 = if llama_family_runtime_env().decode_op_profile {
1082 B::sync(ctx);
1083 Some(std::time::Instant::now())
1084 } else {
1085 None
1086 };
1087 B::rms_norm(
1088 ctx,
1089 residual,
1090 &layer.input_ln_w,
1091 eps,
1092 &mut self.scratch.norm_out,
1093 tokens,
1094 h,
1095 );
1096 if let Some(t0) = _t0 {
1097 B::sync(ctx);
1098 NORM_TIME_US.fetch_add(
1099 t0.elapsed().as_micros() as u64,
1100 std::sync::atomic::Ordering::Relaxed,
1101 );
1102 NORM_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1103 }
1104
1105 let _t0 = if llama_family_runtime_env().decode_op_profile {
1107 B::sync(ctx);
1108 Some(std::time::Instant::now())
1109 } else {
1110 None
1111 };
1112 layer.qkv_proj.forward(
1113 ctx,
1114 &self.scratch.norm_out,
1115 &mut self.scratch.qkv_out,
1116 tokens,
1117 );
1118 if let Some(t0) = _t0 {
1119 B::sync(ctx);
1120 MATMUL_TIME_US.fetch_add(
1121 t0.elapsed().as_micros() as u64,
1122 std::sync::atomic::Ordering::Relaxed,
1123 );
1124 MATMUL_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1125 }
1126
1127 let qk_mode: i32 = if cfg.has_qk_norm { 1 } else { 2 };
1140 let dummy = &layer.input_ln_w;
1141 let q_norm_w = layer.q_norm_w.as_ref().unwrap_or(dummy);
1142 let k_norm_w = layer.k_norm_w.as_ref().unwrap_or(dummy);
1143
1144 let paged_pool_ptr: Option<(*mut B::Buffer, *mut B::Buffer)> =
1155 if let Some(pools) = self.paged_pools.as_mut() {
1156 let pool = &mut pools[li];
1157 Some((&mut pool.0 as *mut _, &mut pool.1 as *mut _))
1158 } else {
1159 None
1160 };
1161 let caches = self
1162 .kv_caches
1163 .get_mut(cache_id)
1164 .expect("ensure_kv must be called before forward_layer");
1165 let cache_len_before = K::len(&caches[li]);
1168 let cache_capacity = K::capacity(&caches[li]);
1169 let cache_block_size = K::block_size(&caches[li]);
1170
1171 if cache_len_before + tokens > cache_capacity {
1177 panic!(
1178 "KV cache overflow on layer {li}: would write tokens [{cache_len_before}..{}) but capacity is {cache_capacity} (cache_id={cache_id:?}). Increase FERRUM_KV_CAPACITY or call /clear in the REPL.",
1179 cache_len_before + tokens
1180 );
1181 }
1182
1183 if cache_block_size > 0 {
1188 let (pool_k_ptr, pool_v_ptr) =
1189 paged_pool_ptr.expect("paged_pools must be allocated when block_size > 0");
1190 let pool_k = unsafe { &mut *pool_k_ptr };
1193 let pool_v = unsafe { &mut *pool_v_ptr };
1194
1195 K::paged_write(
1196 ctx,
1197 &mut caches[li],
1198 &self.scratch.qkv_out,
1199 q_norm_w,
1200 k_norm_w,
1201 &self.rope.cos,
1202 &self.rope.sin,
1203 &mut self.scratch.q_head_major,
1204 &mut self.scratch.k_head_major,
1205 &mut self.scratch.v_head_major,
1206 pool_k,
1207 pool_v,
1208 tokens,
1209 nh,
1210 nkv,
1211 hd,
1212 pos_offset,
1213 eps,
1214 qk_mode,
1215 )
1216 .expect("K::paged_write");
1217
1218 let new_len = cache_len_before + tokens;
1219 K::set_len(&mut caches[li], new_len);
1220
1221 let pool_k_imm = unsafe { &*pool_k_ptr };
1222 let pool_v_imm = unsafe { &*pool_v_ptr };
1223 K::paged_decode_attention(
1224 ctx,
1225 &mut caches[li],
1226 &self.scratch.q_head_major,
1227 pool_k_imm,
1228 pool_v_imm,
1229 &mut self.scratch.attn_head_major_out,
1230 nh,
1231 nkv,
1232 hd,
1233 new_len,
1234 tokens,
1235 )
1236 .expect("K::paged_decode_attention");
1237
1238 return self.forward_layer_post_attn(ctx, li, residual, tokens);
1239 }
1240
1241 let _qkr_t0 = if llama_family_runtime_env().decode_op_profile {
1244 B::sync(ctx);
1245 Some(std::time::Instant::now())
1246 } else {
1247 None
1248 };
1249 K::contig_write(
1250 ctx,
1251 &mut caches[li],
1252 &self.scratch.qkv_out,
1253 q_norm_w,
1254 k_norm_w,
1255 &self.rope.cos,
1256 &self.rope.sin,
1257 &mut self.scratch.q_head_major,
1258 &mut self.scratch.k_head_major,
1259 &mut self.scratch.v_head_major,
1260 &mut self.scratch.q_buf,
1261 &mut self.scratch.k_buf,
1262 &mut self.scratch.v_buf,
1263 tokens,
1264 nh,
1265 nkv,
1266 hd,
1267 pos_offset,
1268 eps,
1269 qk_mode,
1270 )
1271 .expect("K::contig_write");
1272 if let Some(t0) = _qkr_t0 {
1273 B::sync(ctx);
1274 QKR_TIME_US.fetch_add(
1275 t0.elapsed().as_micros() as u64,
1276 std::sync::atomic::Ordering::Relaxed,
1277 );
1278 QKR_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1279 }
1280 let new_len = cache_len_before + tokens;
1281 K::set_len(&mut caches[li], new_len);
1282 let kv_stride = cache_capacity;
1283
1284 let _attn_t0 = if llama_family_runtime_env().decode_op_profile {
1285 B::sync(ctx);
1286 Some(std::time::Instant::now())
1287 } else {
1288 None
1289 };
1290 let attn_cfg = ferrum_kernels::backend::AttnConfig {
1291 num_heads: nh,
1292 num_kv_heads: nkv,
1293 head_dim: hd,
1294 causal: true,
1295 scale: 1.0 / (hd as f32).sqrt(),
1296 kv_seq_stride: kv_stride,
1297 sliding_window: cfg.sliding_window,
1298 };
1299 K::contig_decode_attention(
1300 ctx,
1301 &caches[li],
1302 &self.scratch.q_head_major,
1303 &mut self.scratch.attn_head_major_out,
1304 attn_cfg,
1305 tokens,
1306 pos_offset,
1307 )
1308 .expect("K::contig_decode_attention");
1309 let _ = q_dim;
1310 let _ = kv_dim;
1311 let _ = dummy;
1312 if let Some(t0) = _attn_t0 {
1313 B::sync(ctx);
1314 ATTN_TIME_US.fetch_add(
1315 t0.elapsed().as_micros() as u64,
1316 std::sync::atomic::Ordering::Relaxed,
1317 );
1318 ATTN_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1319 }
1320
1321 self.forward_layer_post_attn(ctx, li, residual, tokens);
1322 }
1323
1324 pub(crate) fn forward_layer_post_attn(
1329 &mut self,
1330 ctx: &mut B::Context,
1331 li: usize,
1332 residual: &mut B::Buffer,
1333 tokens: usize,
1334 ) {
1335 let layer = &self.layers[li];
1336 let cfg = &self.cfg;
1337 let h = cfg.hidden_size;
1338 let nh = cfg.num_heads;
1339 let hd = cfg.head_dim;
1340 let im = cfg.intermediate_size;
1341 let eps = cfg.rms_norm_eps;
1342
1343 let attn_token_major = if tokens == 1 {
1345 &self.scratch.attn_head_major_out
1346 } else {
1347 B::transpose_head_to_token(
1348 ctx,
1349 &self.scratch.attn_head_major_out,
1350 &mut self.scratch.attn_flat,
1351 tokens,
1352 nh,
1353 hd,
1354 );
1355 &self.scratch.attn_flat
1356 };
1357
1358 layer
1360 .o_proj
1361 .forward(ctx, attn_token_major, &mut self.scratch.o_proj_out, tokens);
1362
1363 B::fused_add_rms_norm(
1365 ctx,
1366 residual,
1367 &self.scratch.o_proj_out,
1368 &layer.post_ln_w,
1369 eps,
1370 &mut self.scratch.norm_out,
1371 tokens,
1372 h,
1373 );
1374
1375 layer.gate_up_proj.forward(
1377 ctx,
1378 &self.scratch.norm_out,
1379 &mut self.scratch.gate_up_out,
1380 tokens,
1381 );
1382
1383 B::fused_silu_mul_split(
1385 ctx,
1386 &self.scratch.gate_up_out,
1387 &mut self.scratch.silu_out,
1388 tokens,
1389 im,
1390 );
1391
1392 layer.down_proj.forward(
1394 ctx,
1395 &self.scratch.silu_out,
1396 &mut self.scratch.mlp_out,
1397 tokens,
1398 );
1399
1400 B::add_inplace(ctx, residual, &self.scratch.mlp_out, tokens * h);
1402 }
1403
1404 pub fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
1416 let seq_len = tokens.len();
1417 assert!(seq_len > 0, "forward_verify called with empty tokens");
1418 self.ensure_scratch(seq_len);
1419 self.ensure_kv(cache_id);
1420
1421 let h = self.cfg.hidden_size;
1422 let vocab = self.cfg.vocab_size;
1423
1424 let pos_offset = self
1425 .kv_caches
1426 .get(cache_id)
1427 .and_then(|layers| layers.first())
1428 .map(|c| K::len(c))
1429 .unwrap_or(0);
1430
1431 let mut ctx = B::new_context();
1432 let mut residual = self
1433 .scratch
1434 .residual
1435 .take()
1436 .expect("scratch residual missing (previous call didn't restore)");
1437
1438 let embed = self
1439 .embed
1440 .as_ref()
1441 .expect("forward_verify called on backbone-only model (no embed)");
1442 B::embedding_lookup(&mut ctx, embed, tokens, &mut residual, h);
1443
1444 for li in 0..self.cfg.num_layers {
1445 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
1446 }
1447
1448 B::rms_norm(
1451 &mut ctx,
1452 &residual,
1453 &self.final_norm_w,
1454 self.cfg.rms_norm_eps,
1455 &mut self.scratch.norm_out,
1456 seq_len,
1457 h,
1458 );
1459
1460 let lm_head = self
1464 .lm_head
1465 .as_ref()
1466 .expect("forward_verify called on backbone-only model (no lm_head)");
1467 lm_head.forward(
1468 &mut ctx,
1469 &self.scratch.norm_out,
1470 &mut self.scratch.batch_logits,
1471 seq_len,
1472 );
1473
1474 B::sync(&mut ctx);
1475 self.scratch.residual = Some(residual);
1476 B::to_vec(&self.scratch.batch_logits, seq_len * vocab)
1477 }
1478
1479 pub fn prefill_internal(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
1487 let seq_len = tokens.len();
1488 assert!(seq_len > 0, "prefill called with empty token list");
1489 self.ensure_scratch(seq_len);
1490 self.ensure_kv(cache_id);
1491
1492 let pos_offset = self
1495 .kv_caches
1496 .get(cache_id)
1497 .and_then(|layers| layers.first())
1498 .map(|c| K::len(c))
1499 .unwrap_or(0);
1500
1501 let h = self.cfg.hidden_size;
1502 let vocab = self.cfg.vocab_size;
1503 let mut ctx = B::new_context();
1504
1505 let mut residual = self
1512 .scratch
1513 .residual
1514 .take()
1515 .expect("scratch residual missing (previous call didn't restore)");
1516 let embed = self
1517 .embed
1518 .as_ref()
1519 .expect("prefill_internal called on backbone-only model (no embed)");
1520 B::embedding_lookup(&mut ctx, embed, tokens, &mut residual, h);
1521
1522 let prefill_profile = llama_family_runtime_env().prefill_op_profile;
1523 let prefill_t0 = if prefill_profile {
1524 B::sync(&mut ctx);
1525 Some(std::time::Instant::now())
1526 } else {
1527 None
1528 };
1529
1530 for li in 0..self.cfg.num_layers {
1531 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
1532 }
1533
1534 if let Some(t0) = prefill_t0 {
1535 B::sync(&mut ctx);
1536 let total_us = t0.elapsed().as_micros() as u64;
1537 let attn_us = ATTN_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1538 let attn_n = ATTN_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1539 let qkr_us = QKR_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1540 let qkr_n = QKR_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1541 let mm_us = MATMUL_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1542 let mm_n = MATMUL_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1543 let norm_us = NORM_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1544 let norm_n = NORM_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1545 let other_us = OTHER_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1546 let other_n = OTHER_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1547 eprintln!(
1548 "[prefill-profile] tokens={} layers total={} ms",
1549 seq_len,
1550 total_us / 1000
1551 );
1552 let bucket = |label: &str, n: u64, us: u64| {
1553 if n > 0 {
1554 eprintln!(
1555 "[prefill-profile] {label}: {} calls {} ms (avg {} us)",
1556 n,
1557 us / 1000,
1558 us / n
1559 );
1560 }
1561 };
1562 bucket("flash_attn", attn_n, attn_us);
1563 bucket("qk_norm_rope", qkr_n, qkr_us);
1564 bucket("matmuls", mm_n, mm_us);
1565 bucket("norms", norm_n, norm_us);
1566 bucket("other", other_n, other_us);
1567 }
1568
1569 B::copy_slice(
1571 &mut ctx,
1572 &residual,
1573 (seq_len - 1) * h,
1574 &mut self.scratch.last_hidden,
1575 0,
1576 h,
1577 );
1578
1579 B::rms_norm(
1581 &mut ctx,
1582 &self.scratch.last_hidden,
1583 &self.final_norm_w,
1584 self.cfg.rms_norm_eps,
1585 &mut self.scratch.last_normed,
1586 1,
1587 h,
1588 );
1589
1590 let lm_head = self
1592 .lm_head
1593 .as_ref()
1594 .expect("prefill_internal called on backbone-only model (no lm_head)");
1595 lm_head.forward(
1596 &mut ctx,
1597 &self.scratch.last_normed,
1598 &mut self.scratch.logits,
1599 1,
1600 );
1601
1602 B::sync(&mut ctx);
1609
1610 self.scratch.residual = Some(residual);
1612
1613 B::to_vec(&self.scratch.logits, vocab)
1614 }
1615
1616 pub fn decode_internal(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
1618 self.ensure_scratch(1);
1619 self.ensure_kv(cache_id);
1620
1621 let h = self.cfg.hidden_size;
1622 let vocab = self.cfg.vocab_size;
1623
1624 let mut ctx = B::new_context();
1627
1628 const GRAPH_WARMUP: usize = 3;
1633 let graph_enabled = llama_family_runtime_env().cuda_graph;
1634
1635 if graph_enabled {
1636 B::set_decode_state(&mut ctx, token, pos);
1639
1640 match B::replay_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY) {
1644 Ok(true) => {
1645 B::sync(&mut ctx);
1646 return B::to_vec(&self.scratch.logits, vocab);
1647 }
1648 Ok(false) => { }
1649 Err(_) => { }
1650 }
1651 }
1652
1653 let should_capture =
1654 graph_enabled && !self.graph_capture_failed && self.graph_warmup >= GRAPH_WARMUP;
1655
1656 if should_capture {
1657 B::set_dev_state_mode(&mut ctx, true);
1658 if B::begin_graph_capture(&mut ctx).is_err() {
1659 self.graph_capture_failed = true;
1660 B::set_dev_state_mode(&mut ctx, false);
1661 }
1662 }
1663
1664 let mut residual = self
1670 .scratch
1671 .residual
1672 .take()
1673 .expect("scratch residual missing (previous call didn't restore)");
1674 let embed = self
1675 .embed
1676 .as_ref()
1677 .expect("decode_internal called on backbone-only model (no embed)");
1678 B::embedding_lookup(&mut ctx, embed, &[token], &mut residual, h);
1679
1680 let layer_profile = llama_family_runtime_env().decode_layer_profile;
1684 let mut layer_times = if layer_profile {
1685 Some(Vec::with_capacity(self.cfg.num_layers))
1686 } else {
1687 None
1688 };
1689
1690 for li in 0..self.cfg.num_layers {
1691 if layer_profile {
1692 let t0 = std::time::Instant::now();
1693 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1694 B::sync(&mut ctx);
1695 let elapsed_us = t0.elapsed().as_micros() as u64;
1696 if let Some(v) = layer_times.as_mut() {
1697 v.push(elapsed_us);
1698 }
1699 } else {
1700 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1701 }
1702 }
1703 if let Some(times) = layer_times.take() {
1704 let sum: u64 = times.iter().sum();
1705 let avg = sum / times.len() as u64;
1706 let mn = *times.iter().min().unwrap_or(&0);
1707 let mx = *times.iter().max().unwrap_or(&0);
1708 eprintln!(
1709 "[layer-profile] {} layers total={} ms avg={} us min={} us max={} us",
1710 times.len(),
1711 sum / 1000,
1712 avg,
1713 mn,
1714 mx
1715 );
1716 for (i, t) in times.iter().enumerate() {
1717 eprint!("L{i}={}ms ", t / 1000);
1718 if (i + 1) % 6 == 0 {
1719 eprintln!();
1720 }
1721 }
1722 eprintln!();
1723 let attn_us = ATTN_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1724 let attn_n = ATTN_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1725 let qkr_us = QKR_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1726 let qkr_n = QKR_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1727 let mm_us = MATMUL_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1728 let mm_n = MATMUL_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1729 let norm_us = NORM_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1730 let norm_n = NORM_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1731 let other_us = OTHER_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1732 let other_n = OTHER_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1733 eprintln!(
1734 "[op-profile] flash_attn: {} calls {} ms (avg {} us)",
1735 attn_n,
1736 attn_us / 1000,
1737 if attn_n > 0 { attn_us / attn_n } else { 0 }
1738 );
1739 eprintln!(
1740 "[op-profile] qk_norm_rope: {} calls {} ms (avg {} us)",
1741 qkr_n,
1742 qkr_us / 1000,
1743 if qkr_n > 0 { qkr_us / qkr_n } else { 0 }
1744 );
1745 eprintln!(
1746 "[op-profile] matmuls (Linear::forward): {} calls {} ms (avg {} us)",
1747 mm_n,
1748 mm_us / 1000,
1749 if mm_n > 0 { mm_us / mm_n } else { 0 }
1750 );
1751 eprintln!(
1752 "[op-profile] norms (rms+fused_add_rms): {} calls {} ms (avg {} us)",
1753 norm_n,
1754 norm_us / 1000,
1755 if norm_n > 0 { norm_us / norm_n } else { 0 }
1756 );
1757 eprintln!(
1758 "[op-profile] other (split_qkv, kv_append, transpose, silu, add): {} calls {} ms (avg {} us)",
1759 other_n, other_us / 1000, if other_n > 0 { other_us / other_n } else { 0 }
1760 );
1761 }
1762
1763 B::rms_norm(
1764 &mut ctx,
1765 &residual,
1766 &self.final_norm_w,
1767 self.cfg.rms_norm_eps,
1768 &mut self.scratch.last_normed,
1769 1,
1770 h,
1771 );
1772
1773 let lm_head = self
1774 .lm_head
1775 .as_ref()
1776 .expect("decode_internal called on backbone-only model (no lm_head)");
1777 lm_head.forward(
1778 &mut ctx,
1779 &self.scratch.last_normed,
1780 &mut self.scratch.logits,
1781 1,
1782 );
1783
1784 if should_capture && !self.graph_capture_failed {
1785 if B::end_graph_capture(&mut ctx, SINGLE_ITEM_GRAPH_KEY).is_err() {
1786 self.graph_capture_failed = true;
1787 } else {
1788 if B::replay_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY).is_err() {
1795 self.graph_capture_failed = true;
1796 }
1797 }
1798 B::set_dev_state_mode(&mut ctx, false);
1799 } else {
1800 self.graph_warmup += 1;
1801 }
1802
1803 B::sync(&mut ctx);
1810 self.scratch.residual = Some(residual);
1811
1812 B::to_vec(&self.scratch.logits, vocab)
1813 }
1814
1815 pub fn prefill_from_embeds(
1824 &mut self,
1825 cache_id: &str,
1826 embeds: &[f32],
1827 seq_len: usize,
1828 ) -> Vec<f32> {
1829 let h = self.cfg.hidden_size;
1830 assert_eq!(
1831 embeds.len(),
1832 seq_len * h,
1833 "embeds length {} != seq_len * hidden_size {}",
1834 embeds.len(),
1835 seq_len * h
1836 );
1837 assert!(seq_len > 0, "prefill_from_embeds called with zero length");
1838
1839 self.ensure_scratch(seq_len);
1840 self.ensure_kv(cache_id);
1841
1842 let mut ctx = B::new_context();
1843 let mut residual = self
1844 .scratch
1845 .residual
1846 .take()
1847 .expect("scratch residual missing (previous call didn't restore)");
1848
1849 let embed_buf = B::from_slice(embeds);
1851 B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, seq_len * h);
1852
1853 for li in 0..self.cfg.num_layers {
1854 self.forward_layer(&mut ctx, li, cache_id, &mut residual, 0, seq_len);
1855 }
1856
1857 B::copy_slice(
1858 &mut ctx,
1859 &residual,
1860 (seq_len - 1) * h,
1861 &mut self.scratch.last_hidden,
1862 0,
1863 h,
1864 );
1865 B::sync(&mut ctx);
1866 self.scratch.residual = Some(residual);
1867 B::to_vec(&self.scratch.last_hidden, h)
1868 }
1869
1870 pub fn decode_from_embed(&mut self, cache_id: &str, embed: &[f32], pos: u32) -> Vec<f32> {
1874 let h = self.cfg.hidden_size;
1875 assert_eq!(
1876 embed.len(),
1877 h,
1878 "embed length {} != hidden_size {}",
1879 embed.len(),
1880 h
1881 );
1882
1883 self.ensure_scratch(1);
1884 self.ensure_kv(cache_id);
1885
1886 let mut ctx = B::new_context();
1887 let mut residual = self
1888 .scratch
1889 .residual
1890 .take()
1891 .expect("scratch residual missing (previous call didn't restore)");
1892
1893 let embed_buf = B::from_slice(embed);
1894 B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, h);
1895
1896 for li in 0..self.cfg.num_layers {
1897 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1898 }
1899
1900 B::copy_slice(&mut ctx, &residual, 0, &mut self.scratch.last_hidden, 0, h);
1901 B::sync(&mut ctx);
1902 self.scratch.residual = Some(residual);
1903 B::to_vec(&self.scratch.last_hidden, h)
1904 }
1905
1906 pub fn prefill_all_post_norm(
1917 &mut self,
1918 cache_id: &str,
1919 embeds: &[f32],
1920 seq_len: usize,
1921 pos_offset: usize,
1922 ) -> Vec<f32> {
1923 let h = self.cfg.hidden_size;
1924 assert_eq!(
1925 embeds.len(),
1926 seq_len * h,
1927 "embeds length {} != seq_len * hidden_size {}",
1928 embeds.len(),
1929 seq_len * h
1930 );
1931 assert!(seq_len > 0);
1932
1933 self.ensure_scratch(seq_len);
1934 self.ensure_kv(cache_id);
1935
1936 let mut ctx = B::new_context();
1937 let mut residual = self
1938 .scratch
1939 .residual
1940 .take()
1941 .expect("scratch residual missing (previous call didn't restore)");
1942
1943 let embed_buf = B::from_slice(embeds);
1944 B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, seq_len * h);
1945
1946 for li in 0..self.cfg.num_layers {
1947 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
1948 }
1949
1950 B::rms_norm(
1952 &mut ctx,
1953 &residual,
1954 &self.final_norm_w,
1955 self.cfg.rms_norm_eps,
1956 &mut self.scratch.norm_out,
1957 seq_len,
1958 h,
1959 );
1960 B::sync(&mut ctx);
1961 self.scratch.residual = Some(residual);
1962 B::to_vec(&self.scratch.norm_out, seq_len * h)
1963 }
1964
1965 pub fn decode_post_norm_from_embed(
1969 &mut self,
1970 cache_id: &str,
1971 embed: &[f32],
1972 pos: u32,
1973 ) -> Vec<f32> {
1974 let h = self.cfg.hidden_size;
1975 assert_eq!(embed.len(), h);
1976
1977 self.ensure_scratch(1);
1978 self.ensure_kv(cache_id);
1979
1980 let mut ctx = B::new_context();
1981 let mut residual = self
1982 .scratch
1983 .residual
1984 .take()
1985 .expect("scratch residual missing (previous call didn't restore)");
1986
1987 let embed_buf = B::from_slice(embed);
1988 B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, h);
1989
1990 for li in 0..self.cfg.num_layers {
1991 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1992 }
1993
1994 B::rms_norm(
1995 &mut ctx,
1996 &residual,
1997 &self.final_norm_w,
1998 self.cfg.rms_norm_eps,
1999 &mut self.scratch.last_normed,
2000 1,
2001 h,
2002 );
2003 B::sync(&mut ctx);
2004 self.scratch.residual = Some(residual);
2005 B::to_vec(&self.scratch.last_normed, h)
2006 }
2007}
2008
2009impl<B: MoeLlmBackend> DecoderOnlyLLM for LlamaFamilyModel<B, KvFp16> {
2011 fn config(&self) -> &LlmRuntimeConfig {
2012 &self.runtime_cfg
2013 }
2014
2015 fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
2016 self.ensure_scratch(max_tokens);
2017 self.ensure_kv(cache_id);
2018
2019 const WARMUP_CACHE: &str = "__ferrum_warmup__";
2020 let _ = self.prefill_internal(WARMUP_CACHE, &[0u32]);
2021 if let Some(mut caches) = self.kv_caches.remove(WARMUP_CACHE) {
2022 if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2023 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2024 if let Some(c0) = caches.first() {
2025 if !c0.paged_block_indices.is_empty() {
2026 alloc.free(&c0.paged_block_indices);
2027 }
2028 }
2029 for c in caches.iter_mut() {
2030 c.paged_block_indices.clear();
2031 }
2032 }
2033 self.kv_free_pool.push(caches);
2034 }
2035 }
2036
2037 fn kv_capacity(&self) -> usize {
2038 llama_family_runtime_env().kv_capacity_for_model(self.cfg.max_seq_len)
2039 }
2040
2041 fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2042 self.prefill_internal(cache_id, tokens)
2043 }
2044
2045 fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
2046 self.decode_internal(cache_id, token, pos)
2047 }
2048
2049 fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
2050 self.decode_batch_internal(batch)
2051 }
2052
2053 fn unified_forward(
2054 &mut self,
2055 items: &[(String, Vec<u32>, usize, bool)],
2056 ) -> std::result::Result<Vec<Option<Vec<f32>>>, ferrum_types::FerrumError> {
2057 if items.is_empty() {
2058 return Ok(Vec::new());
2059 }
2060 if !B::supports_varlen_qkv() {
2061 return Err(ferrum_types::FerrumError::unsupported(
2062 "LlamaFamilyModel::unified_forward: backend lacks varlen \
2063 QKV kernels. Engine will fall back to per-item dispatch.",
2064 ));
2065 }
2066 self.ensure_kv(&items[0].0);
2067 if self.paged_pools.is_none() {
2068 return Err(ferrum_types::FerrumError::unsupported(
2069 "LlamaFamilyModel::unified_forward: paged KV required; \
2070 enable via FERRUM_METAL_PAGED_KV=1 (cross-backend env). \
2071 Engine will fall back to per-item dispatch.",
2072 ));
2073 }
2074 for (cid, _, _, _) in items {
2075 self.ensure_kv(cid);
2076 if !self.kv_caches.contains_key(cid) {
2077 return Err(ferrum_types::FerrumError::resource_exhausted(format!(
2078 "paged KV pool exhausted for cache_id={cid:?}; back off"
2079 )));
2080 }
2081 }
2082 Ok(self.unified_forward_internal(items))
2083 }
2084
2085 fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2086 LlamaFamilyModel::<B, KvFp16>::forward_verify(self, cache_id, tokens)
2087 }
2088
2089 fn truncate_kv(&mut self, cache_id: &str, new_len: usize) {
2090 if let Some(caches) = self.kv_caches.get_mut(cache_id) {
2091 for c in caches.iter_mut() {
2092 if new_len < c.len {
2093 c.len = new_len;
2094 }
2095 }
2096 }
2097 let mut ctx = B::new_context();
2098 B::reset_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY);
2099 self.graph_warmup = 0;
2100 self.graph_capture_failed = false;
2101 }
2102
2103 fn release(&mut self, cache_id: &str) {
2104 let mut ctx = B::new_context();
2105 B::sync(&mut ctx);
2106 self.graph_warmup = 0;
2107 self.graph_capture_failed = false;
2108 if let Some(mut caches) = self.kv_caches.remove(cache_id) {
2109 if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2110 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2111 if let Some(c0) = caches.first() {
2112 if !c0.paged_block_indices.is_empty() {
2113 alloc.free(&c0.paged_block_indices);
2114 }
2115 }
2116 for c in caches.iter_mut() {
2117 c.paged_block_indices.clear();
2118 }
2119 }
2120 self.kv_free_pool.push(caches);
2121 }
2122 }
2123
2124 fn reset(&mut self) {
2125 let mut ctx = B::new_context();
2126 B::sync(&mut ctx);
2127 B::reset_all_graphs(&mut ctx);
2128 B::sync(&mut ctx);
2129 self.graph_warmup = 0;
2130 self.graph_capture_failed = false;
2131 self.batched_graph_keys_seen.clear();
2132 self.batched_graph_warmup = 0;
2133 self.batched_graph_failed = false;
2134 self.kv_caches.clear();
2135 self.kv_free_pool.clear();
2136 }
2137}
2138
2139impl<B: MoeLlmBackend + BackendInt8KvOps> DecoderOnlyLLM for LlamaFamilyModel<B, KvInt8> {
2143 fn config(&self) -> &LlmRuntimeConfig {
2144 &self.runtime_cfg
2145 }
2146
2147 fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
2148 self.ensure_scratch(max_tokens);
2149 self.ensure_kv(cache_id);
2150
2151 const WARMUP_CACHE: &str = "__ferrum_warmup__";
2152 let _ = self.prefill_internal(WARMUP_CACHE, &[0u32]);
2153 if let Some(mut caches) = self.kv_caches.remove(WARMUP_CACHE) {
2154 if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2155 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2156 if let Some(c0) = caches.first() {
2157 if !c0.paged_block_indices.is_empty() {
2158 alloc.free(&c0.paged_block_indices);
2159 }
2160 }
2161 for c in caches.iter_mut() {
2162 c.paged_block_indices.clear();
2163 }
2164 }
2165 self.kv_free_pool.push(caches);
2166 }
2167 }
2168
2169 fn kv_capacity(&self) -> usize {
2170 llama_family_runtime_env().kv_capacity_for_model(self.cfg.max_seq_len)
2171 }
2172
2173 fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2174 self.prefill_internal(cache_id, tokens)
2175 }
2176
2177 fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
2178 self.decode_internal(cache_id, token, pos)
2179 }
2180
2181 fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2184 LlamaFamilyModel::<B, KvInt8>::forward_verify(self, cache_id, tokens)
2185 }
2186
2187 fn truncate_kv(&mut self, cache_id: &str, new_len: usize) {
2188 if let Some(caches) = self.kv_caches.get_mut(cache_id) {
2189 for c in caches.iter_mut() {
2190 if new_len < c.len {
2191 c.len = new_len;
2192 }
2193 }
2194 }
2195 let mut ctx = B::new_context();
2196 B::reset_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY);
2197 self.graph_warmup = 0;
2198 self.graph_capture_failed = false;
2199 }
2200
2201 fn release(&mut self, cache_id: &str) {
2202 let mut ctx = B::new_context();
2203 B::sync(&mut ctx);
2204 self.graph_warmup = 0;
2205 self.graph_capture_failed = false;
2206 if let Some(mut caches) = self.kv_caches.remove(cache_id) {
2207 if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2208 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2209 if let Some(c0) = caches.first() {
2210 if !c0.paged_block_indices.is_empty() {
2211 alloc.free(&c0.paged_block_indices);
2212 }
2213 }
2214 for c in caches.iter_mut() {
2215 c.paged_block_indices.clear();
2216 }
2217 }
2218 self.kv_free_pool.push(caches);
2219 }
2220 }
2221
2222 fn reset(&mut self) {
2223 let mut ctx = B::new_context();
2224 B::sync(&mut ctx);
2225 B::reset_all_graphs(&mut ctx);
2226 B::sync(&mut ctx);
2227 self.graph_warmup = 0;
2228 self.graph_capture_failed = false;
2229 self.kv_caches.clear();
2230 self.kv_free_pool.clear();
2231 }
2232}
2233
2234fn build_rope_cache<B: QuantLlmBackend + BackendMoeFused>(cfg: &LlamaFamilyConfig) -> RopeCache<B> {
2235 let hd = cfg.head_dim;
2236 let half = hd / 2;
2237 let max = cfg.max_seq_len;
2238 let mut cos = vec![0.0f32; max * half];
2239 let mut sin = vec![0.0f32; max * half];
2240 for pos in 0..max {
2241 for i in 0..half {
2242 let freq = 1.0f64 / cfg.rope_theta.powf((2 * i) as f64 / hd as f64);
2243 let angle = pos as f64 * freq;
2244 cos[pos * half + i] = angle.cos() as f32;
2245 sin[pos * half + i] = angle.sin() as f32;
2246 }
2247 }
2248 RopeCache {
2249 cos: B::from_slice(&cos),
2250 sin: B::from_slice(&sin),
2251 }
2252}
2253
#[cfg(test)]
mod tests {
    use super::{LlamaFamilyRuntimeEnv, DEFAULT_KV_CAPACITY};

    /// Every recognised variable is supplied. The numeric knobs and the paged
    /// KV switch must parse to the given values, while the profiling and
    /// CUDA-graph switches are expected to be on even though their values are
    /// `"0"`, `""` and `"false"`.
    #[test]
    fn llama_family_runtime_env_parses_startup_knobs() {
        let vars = [
            ("FERRUM_KV_CAPACITY", "4096"),
            ("FERRUM_METAL_PAGED_KV", "0"),
            ("FERRUM_PAGED_MAX_SEQS", "64"),
            ("FERRUM_DECODE_OP_PROFILE", "0"),
            ("FERRUM_PREFILL_OP_PROFILE", ""),
            ("FERRUM_CUDA_GRAPH", ""),
            ("FERRUM_DECODE_LAYER_PROFILE", "false"),
        ];
        let env = LlamaFamilyRuntimeEnv::from_env_vars(vars);

        // Parsed values.
        assert_eq!(env.kv_capacity, Some(4096));
        assert_eq!(env.metal_paged_kv, Some(false));
        assert_eq!(env.paged_max_seqs, 64);

        // Switches.
        assert!(env.decode_op_profile);
        assert!(env.prefill_op_profile);
        assert!(env.cuda_graph);
        assert!(env.decode_layer_profile);

        // A configured capacity of 4096 yields 2048 for a model limit of 2048.
        assert_eq!(env.kv_capacity_for_model(2048), 2048);
    }

    /// Unparsable numeric values leave `kv_capacity` unset and
    /// `paged_max_seqs` at 32; with no configured capacity, a model limit of
    /// twice `DEFAULT_KV_CAPACITY` resolves to `DEFAULT_KV_CAPACITY`.
    #[test]
    fn llama_family_runtime_env_uses_defaults_for_invalid_values() {
        let vars = [
            ("FERRUM_KV_CAPACITY", "bad"),
            ("FERRUM_PAGED_MAX_SEQS", "bad"),
            ("FERRUM_METAL_PAGED_KV", "1"),
        ];
        let env = LlamaFamilyRuntimeEnv::from_env_vars(vars);

        assert_eq!(env.kv_capacity, None);
        assert_eq!(env.metal_paged_kv, Some(true));
        assert_eq!(env.paged_max_seqs, 32);

        let model_limit = DEFAULT_KV_CAPACITY * 2;
        assert_eq!(env.kv_capacity_for_model(model_limit), DEFAULT_KV_CAPACITY);
    }
}