1use std::collections::HashMap;
20use std::sync::{atomic::AtomicU64, OnceLock};
21
22use ferrum_interfaces::kv_dtype::{KvFp16, KvInt8};
23use ferrum_kernels::backend::{
24 Backend, BackendGraph, BackendInt8KvOps, BackendMoeFused, BackendPagedKv, BackendQuantGguf,
25 BackendQuantMarlin, KvCache, KvLayer, LlmBackend, MoeLlmBackend, QuantLlmBackend,
26 MAX_LAYERS_FOR_GRAPH,
27};
28
/// Graph-cache key reserved for the single-sequence path (meaning inferred
/// from the name — it is not referenced in this part of the file; TODO confirm).
pub(crate) const SINGLE_ITEM_GRAPH_KEY: u64 = 0;

/// Counter of batched steps served by graph replay. Not incremented by the
/// code visible here — presumably updated by the batched decode path.
pub(crate) static BATCHED_GRAPH_REPLAY_COUNT: AtomicU64 = AtomicU64::new(0);
/// Counter of batched steps that ran eagerly instead of replaying a graph
/// (presumably; not incremented by the code visible here).
pub(crate) static BATCHED_GRAPH_EAGER_COUNT: AtomicU64 = AtomicU64::new(0);

// Per-op profiling accumulators (microseconds + call counts). When
// `FERRUM_DECODE_OP_PROFILE` is set, `forward_layer` brackets the input
// RMSNorm (NORM_*), the QKV projection (MATMUL_*), the contiguous KV write
// (QKR_*) and contiguous attention (ATTN_*) with `B::sync` and adds the
// elapsed time here. OTHER_* is not touched by the code visible here.
pub(crate) static ATTN_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static ATTN_CALLS: AtomicU64 = AtomicU64::new(0);
pub(crate) static QKR_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static QKR_CALLS: AtomicU64 = AtomicU64::new(0);
pub(crate) static MATMUL_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static MATMUL_CALLS: AtomicU64 = AtomicU64::new(0);
pub(crate) static NORM_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static NORM_CALLS: AtomicU64 = AtomicU64::new(0);
pub(crate) static OTHER_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static OTHER_CALLS: AtomicU64 = AtomicU64::new(0);
47use ferrum_quantization::{Linear, WeightLoader};
48use ferrum_types::Result;
49
50use crate::common::{DecoderOnlyLLM, LlmRuntimeConfig};
51
/// Per-sequence KV-cache capacity (tokens) used when `FERRUM_KV_CAPACITY` is
/// not set; it is still clamped to the model's `max_seq_len`.
const DEFAULT_KV_CAPACITY: usize = 512;
53
/// Environment-driven runtime knobs for the Llama-family model. Parsed once
/// per process and cached (see `llama_family_runtime_env`).
#[derive(Debug, Clone, PartialEq, Eq)]
struct LlamaFamilyRuntimeEnv {
    /// `FERRUM_KV_CAPACITY`: per-sequence KV capacity override, clamped to
    /// the model's `max_seq_len`. `None` means use `DEFAULT_KV_CAPACITY`.
    kv_capacity: Option<usize>,
    /// `FERRUM_METAL_PAGED_KV`: force paged KV on (any value but `"0"`) or
    /// off (`"0"`). `None` defers to `B::supports_paged_kv()`.
    metal_paged_kv: Option<bool>,
    /// `FERRUM_PAGED_MAX_SEQS`: number of sequences the shared paged KV pool
    /// is sized for (default 32).
    paged_max_seqs: usize,
    /// `FERRUM_DECODE_OP_PROFILE`: enable per-op timing in `forward_layer`.
    decode_op_profile: bool,
    /// `FERRUM_PREFILL_OP_PROFILE` (not read by the code visible here).
    prefill_op_profile: bool,
    /// `FERRUM_CUDA_GRAPH` (not read by the code visible here).
    cuda_graph: bool,
    /// `FERRUM_DECODE_LAYER_PROFILE` (not read by the code visible here).
    decode_layer_profile: bool,
}
64
65impl LlamaFamilyRuntimeEnv {
66 fn from_env() -> Self {
67 Self::from_env_vars(std::env::vars())
68 }
69
70 fn from_env_vars<I, K, V>(vars: I) -> Self
71 where
72 I: IntoIterator<Item = (K, V)>,
73 K: AsRef<str>,
74 V: AsRef<str>,
75 {
76 let mut config = Self {
77 kv_capacity: None,
78 metal_paged_kv: None,
79 paged_max_seqs: 32,
80 decode_op_profile: false,
81 prefill_op_profile: false,
82 cuda_graph: false,
83 decode_layer_profile: false,
84 };
85 for (name, value) in vars {
86 let value = value.as_ref();
87 match name.as_ref() {
88 "FERRUM_KV_CAPACITY" => config.kv_capacity = value.parse::<usize>().ok(),
89 "FERRUM_METAL_PAGED_KV" => config.metal_paged_kv = Some(value != "0"),
90 "FERRUM_PAGED_MAX_SEQS" => {
91 if let Ok(max_seqs) = value.parse::<usize>() {
92 config.paged_max_seqs = max_seqs;
93 }
94 }
95 "FERRUM_DECODE_OP_PROFILE" => config.decode_op_profile = true,
96 "FERRUM_PREFILL_OP_PROFILE" => config.prefill_op_profile = true,
97 "FERRUM_CUDA_GRAPH" => config.cuda_graph = true,
98 "FERRUM_DECODE_LAYER_PROFILE" => config.decode_layer_profile = true,
99 _ => {}
100 }
101 }
102 config
103 }
104
105 fn kv_capacity_for_model(&self, model_max: usize) -> usize {
106 self.kv_capacity
107 .map(|cap| cap.min(model_max))
108 .unwrap_or_else(|| model_max.min(DEFAULT_KV_CAPACITY))
109 }
110
111 fn paged_kv_enabled<B: BackendPagedKv>(&self) -> bool {
112 self.metal_paged_kv
113 .unwrap_or_else(|| B::supports_paged_kv())
114 }
115}
116
117fn llama_family_runtime_env() -> &'static LlamaFamilyRuntimeEnv {
118 static CONFIG: OnceLock<LlamaFamilyRuntimeEnv> = OnceLock::new();
119 CONFIG.get_or_init(LlamaFamilyRuntimeEnv::from_env)
120}
121
/// RoPE frequency-scaling variants understood by the Llama-family model.
#[derive(Clone, Debug, PartialEq)]
pub enum RopeScalingConfig {
    /// Llama-3 style scaling, parsed from a `rope_scaling` object whose
    /// `rope_type` (or legacy `type`) is `"llama3"`. Field names mirror the
    /// keys of that object.
    Llama3 {
        /// `rope_scaling.factor`.
        factor: f64,
        /// `rope_scaling.low_freq_factor`.
        low_freq_factor: f64,
        /// `rope_scaling.high_freq_factor` (validated to exceed `low_freq_factor`).
        high_freq_factor: f64,
        /// `rope_scaling.original_max_position_embeddings` (8192 when absent).
        original_max_position_embeddings: f64,
    },
}

impl RopeScalingConfig {
    /// Llama-3 scaling with factor 8, low/high frequency factors 1 and 4, and
    /// an original context of 8192 positions.
    pub fn llama3_default() -> Self {
        const FACTOR: f64 = 8.0;
        const LOW_FREQ_FACTOR: f64 = 1.0;
        const HIGH_FREQ_FACTOR: f64 = 4.0;
        const ORIGINAL_MAX_POSITIONS: f64 = 8192.0;

        Self::Llama3 {
            factor: FACTOR,
            low_freq_factor: LOW_FREQ_FACTOR,
            high_freq_factor: HIGH_FREQ_FACTOR,
            original_max_position_embeddings: ORIGINAL_MAX_POSITIONS,
        }
    }
}
143
/// Architecture hyper-parameters shared by the Llama / Qwen2 / Qwen3 /
/// Mistral family. Built from a model definition by the `*_from_def`
/// constructors.
#[derive(Clone, Debug, PartialEq)]
pub struct LlamaFamilyConfig {
    pub hidden_size: usize,
    /// MLP inner width; the fused gate/up projection emits `2 *` this per token.
    pub intermediate_size: usize,
    /// Number of query heads.
    pub num_heads: usize,
    /// Number of key/value heads (equals `num_heads` when the definition
    /// does not specify it).
    pub num_kv_heads: usize,
    pub head_dim: usize,
    pub num_layers: usize,
    pub vocab_size: usize,
    /// Maximum positions supported by the model (`max_position_embeddings`);
    /// upper bound for the KV capacity.
    pub max_seq_len: usize,
    pub rms_norm_eps: f32,
    pub rope_theta: f64,
    /// Optional RoPE frequency scaling (currently only Llama-3 style).
    pub rope_scaling: Option<RopeScalingConfig>,
    /// Selects the interleaved RoPE kernel mode in `forward_layer`; the
    /// `*_from_def` constructors always set it to `false`.
    pub rope_interleaved: bool,
    /// Whether each layer carries per-head `q_norm` / `k_norm` weights
    /// (enabled by `qwen3_from_def`).
    pub has_qk_norm: bool,
    /// Attention window in tokens for the contiguous-KV path; 0 = disabled.
    pub sliding_window: usize,
}
170
171impl LlamaFamilyConfig {
172 pub fn to_runtime(&self) -> LlmRuntimeConfig {
173 LlmRuntimeConfig {
174 hidden_size: self.hidden_size,
175 num_layers: self.num_layers,
176 num_kv_heads: self.num_kv_heads,
177 head_dim: self.head_dim,
178 vocab_size: self.vocab_size,
179 max_seq_len: self.max_seq_len,
180 }
181 }
182
183 fn from_def_base(def: &crate::definition::ModelDefinition) -> LlamaFamilyConfigBase {
187 let num_kv_heads = def.num_key_value_heads.unwrap_or(def.num_attention_heads);
188 let head_dim = def
189 .extra_params
190 .get("head_dim")
191 .and_then(|v| v.as_u64())
192 .map(|v| v as usize)
193 .unwrap_or(def.hidden_size / def.num_attention_heads);
194 let sliding_window = def
197 .extra_params
198 .get("sliding_window")
199 .and_then(|v| v.as_u64())
200 .map(|v| v as usize)
201 .unwrap_or(0);
202
203 LlamaFamilyConfigBase {
204 hidden_size: def.hidden_size,
205 intermediate_size: def.intermediate_size,
206 num_heads: def.num_attention_heads,
207 num_kv_heads,
208 head_dim,
209 num_layers: def.num_hidden_layers,
210 vocab_size: def.vocab_size,
211 max_seq_len: def.max_position_embeddings,
212 rms_norm_eps: def.norm_eps as f32,
213 rope_theta_opt: def.rope_theta,
214 rope_scaling: rope_scaling_from_model_def(def),
215 sliding_window,
216 }
217 }
218
219 fn from_base(b: LlamaFamilyConfigBase, rope_default: f64, has_qk_norm: bool) -> Self {
220 Self {
221 hidden_size: b.hidden_size,
222 intermediate_size: b.intermediate_size,
223 num_heads: b.num_heads,
224 num_kv_heads: b.num_kv_heads,
225 head_dim: b.head_dim,
226 num_layers: b.num_layers,
227 vocab_size: b.vocab_size,
228 max_seq_len: b.max_seq_len,
229 rms_norm_eps: b.rms_norm_eps,
230 rope_theta: b.rope_theta_opt.unwrap_or(rope_default),
231 rope_scaling: b.rope_scaling,
232 rope_interleaved: false,
233 has_qk_norm,
234 sliding_window: b.sliding_window,
235 }
236 }
237
238 pub fn qwen3_from_def(def: &crate::definition::ModelDefinition) -> Self {
240 Self::from_base(Self::from_def_base(def), 1_000_000.0, true)
241 }
242
243 pub fn llama_from_def(def: &crate::definition::ModelDefinition) -> Self {
247 Self::from_base(Self::from_def_base(def), 500_000.0, false)
248 }
249
250 pub fn qwen2_from_def(def: &crate::definition::ModelDefinition) -> Self {
252 Self::from_base(Self::from_def_base(def), 1_000_000.0, false)
253 }
254
255 pub fn mistral_from_def(def: &crate::definition::ModelDefinition) -> Self {
259 Self::from_base(Self::from_def_base(def), 10_000.0, false)
260 }
261}
262
/// Intermediate result of `LlamaFamilyConfig::from_def_base`: every field
/// that does not depend on the model family. `rope_theta_opt` stays optional
/// so `from_base` can apply a per-family default.
struct LlamaFamilyConfigBase {
    hidden_size: usize,
    intermediate_size: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    num_layers: usize,
    vocab_size: usize,
    max_seq_len: usize,
    rms_norm_eps: f32,
    /// `rope_theta` from the definition, if present.
    rope_theta_opt: Option<f64>,
    rope_scaling: Option<RopeScalingConfig>,
    /// 0 = no sliding window.
    sliding_window: usize,
}
277
278fn rope_scaling_from_model_def(
279 def: &crate::definition::ModelDefinition,
280) -> Option<RopeScalingConfig> {
281 let value = def.extra_params.get("rope_scaling")?;
282 let obj = value.as_object()?;
283 let rope_type = obj
284 .get("rope_type")
285 .or_else(|| obj.get("type"))
286 .and_then(|v| v.as_str())?;
287 if rope_type != "llama3" {
288 return None;
289 }
290 let factor = json_f64(obj.get("factor"))?;
291 let low_freq_factor = json_f64(obj.get("low_freq_factor"))?;
292 let high_freq_factor = json_f64(obj.get("high_freq_factor"))?;
293 let original_max_position_embeddings = json_f64(obj.get("original_max_position_embeddings"))
294 .or_else(|| {
295 def.extra_params
296 .get("original_max_position_embeddings")
297 .and_then(|v| json_f64(Some(v)))
298 })
299 .unwrap_or(8192.0);
300 if factor <= 0.0
301 || low_freq_factor <= 0.0
302 || high_freq_factor <= low_freq_factor
303 || original_max_position_embeddings <= 0.0
304 {
305 return None;
306 }
307 Some(RopeScalingConfig::Llama3 {
308 factor,
309 low_freq_factor,
310 high_freq_factor,
311 original_max_position_embeddings,
312 })
313}
314
315fn json_f64(value: Option<&serde_json::Value>) -> Option<f64> {
316 match value? {
317 serde_json::Value::Number(n) => n.as_f64(),
318 _ => None,
319 }
320}
321
/// Weights of one decoder layer, loaded from `model.layers.{i}.*`.
pub struct LlamaFamilyLayer<B: QuantLlmBackend + BackendMoeFused> {
    /// Pre-attention RMSNorm weight (`input_layernorm.weight`).
    pub input_ln_w: B::Buffer,
    /// Fused Q/K/V projection (`self_attn.qkv_proj`); writes into
    /// `scratch.qkv_out`, sized `(q_dim + 2 * kv_dim)` per token.
    pub qkv_proj: Box<dyn Linear<B>>,
    /// Per-head query norm weight; `Some` only when `cfg.has_qk_norm`.
    pub q_norm_w: Option<B::Buffer>,
    /// Per-head key norm weight; `Some` only when `cfg.has_qk_norm`.
    pub k_norm_w: Option<B::Buffer>,
    /// Attention output projection (`self_attn.o_proj`).
    pub o_proj: Box<dyn Linear<B>>,
    /// Post-attention RMSNorm weight (`post_attention_layernorm.weight`).
    pub post_ln_w: B::Buffer,
    /// Fused gate+up MLP projection (`mlp.gate_up_proj`), `2 * intermediate_size` wide.
    pub gate_up_proj: Box<dyn Linear<B>>,
    /// MLP down projection (`mlp.down_proj`).
    pub down_proj: Box<dyn Linear<B>>,
}
335
/// Precomputed RoPE tables produced by `build_rope_cache` and passed to the
/// KV-write kernels. Layout is defined by that builder (not visible here).
pub struct RopeCache<B: QuantLlmBackend + BackendMoeFused> {
    pub cos: B::Buffer,
    pub sin: B::Buffer,
}
341
/// Reusable device buffers for one model, sized in `alloc` for at most
/// `max_tokens` tokens per forward call. Paged-batch and `unified_*` buffers
/// start as `None` and are allocated lazily. Sizes below are the element
/// counts passed to `B::alloc` / `B::alloc_typed`.
pub struct LlamaFamilyScratch<B: QuantLlmBackend + BackendMoeFused> {
    /// Residual stream (`max_tokens * hidden_size`). An `Option` so a forward
    /// pass can `take()` it and hand it back afterwards.
    pub residual: Option<B::Buffer>,
    /// RMSNorm output (`max_tokens * hidden_size`).
    pub norm_out: B::Buffer,
    /// Fused QKV projection output (`max_tokens * (q_dim + 2 * kv_dim)`).
    pub qkv_out: B::Buffer,
    // Single-row variants (one token: `q_dim` / `kv_dim`), not used by the
    // code visible here.
    pub q_single: B::Buffer,
    pub k_single: B::Buffer,
    pub v_single: B::Buffer,
    pub q_head_major_single: B::Buffer,
    pub k_head_major_single: B::Buffer,
    pub v_head_major_single: B::Buffer,
    pub attn_head_major_single: B::Buffer,
    pub attn_flat_single: B::Buffer,
    /// Logits for every token of a call (`max_tokens * vocab_size`).
    pub batch_logits: B::Buffer,
    /// Token-major Q staging (`max_tokens * q_dim`).
    pub q_buf: B::Buffer,
    /// Token-major K/V staging (`max_tokens * kv_dim`).
    pub k_buf: B::Buffer,
    pub v_buf: B::Buffer,
    /// Head-major Q (`num_heads * max_tokens * head_dim`) fed to attention.
    pub q_head_major: B::Buffer,
    /// Head-major K/V (`num_kv_heads * max_tokens * head_dim`).
    pub k_head_major: B::Buffer,
    pub v_head_major: B::Buffer,
    /// Attention output in head-major layout.
    pub attn_head_major_out: B::Buffer,
    /// Attention output transposed to token-major (used when `tokens > 1`).
    pub attn_flat: B::Buffer,
    /// Output projection result (`max_tokens * hidden_size`).
    pub o_proj_out: B::Buffer,
    /// Gate and up projections side by side (`max_tokens * 2 * intermediate_size`).
    pub gate_up_out: B::Buffer,
    /// Output of `fused_silu_mul_split` (`max_tokens * intermediate_size`).
    pub silu_out: B::Buffer,
    /// Down projection output (`max_tokens * hidden_size`).
    pub mlp_out: B::Buffer,
    // Paged batched-decode buffers, created by `enable_paged_batch`.
    pub paged_batch_q: Option<B::Buffer>,
    pub paged_batch_o: Option<B::Buffer>,
    /// `u32`, `paged_max_seqs * paged_max_blocks_per_seq` entries.
    pub paged_batch_block_tables: Option<B::Buffer>,
    /// `u32`, one entry per sequence.
    pub paged_batch_context_lens: Option<B::Buffer>,
    pub paged_max_blocks_per_seq: usize,
    pub paged_max_seqs: usize,
    // `u32` per-token side buffers (at least one element each).
    pub batch_positions: B::Buffer,
    pub batch_tokens: B::Buffer,
    pub batch_kv_lens_pre: B::Buffer,
    pub batch_kv_lens_post: B::Buffer,
    // Batched normalised Q/K/V (`max_tokens * q_dim` / `* kv_dim`).
    pub q_normed_batched: B::Buffer,
    pub k_normed_batched: B::Buffer,
    pub v_normed_batched: B::Buffer,

    /// Token capacity of the `unified_*` buffers (0 until first use).
    pub unified_capacity: usize,
    // Unified (packed multi-sequence) buffers, created by
    // `ensure_unified_scratch`.
    pub unified_residual: Option<B::Buffer>,
    pub unified_norm_out: Option<B::Buffer>,
    pub unified_qkv_out: Option<B::Buffer>,
    pub unified_packed_q: Option<B::Buffer>,
    pub unified_attn_out: Option<B::Buffer>,
    pub unified_o_proj_out: Option<B::Buffer>,
    pub unified_gate_up_out: Option<B::Buffer>,
    pub unified_silu_out: Option<B::Buffer>,
    pub unified_mlp_out: Option<B::Buffer>,
    /// `u32`, `max_seqs + 1` entries.
    pub unified_cu_seqlens_q: Option<B::Buffer>,
    pub unified_pos_offsets: Option<B::Buffer>,
    pub unified_block_tables: Option<B::Buffer>,
    pub unified_packed_normed: Option<B::Buffer>,
    pub unified_packed_logits: Option<B::Buffer>,
    /// Single-row hidden state (`hidden_size`).
    pub last_hidden: B::Buffer,
    /// Single-row normalised hidden state (`hidden_size`).
    pub last_normed: B::Buffer,
    /// Single-row logits (`vocab_size`).
    pub logits: B::Buffer,
    /// Token capacity of the fixed buffers above.
    pub max_tokens: usize,
}
483
484impl<B: QuantLlmBackend + BackendMoeFused> LlamaFamilyScratch<B> {
485 fn alloc(cfg: &LlamaFamilyConfig, max_tokens: usize) -> Self {
486 let h = cfg.hidden_size;
487 let im = cfg.intermediate_size;
488 let q_dim = cfg.num_heads * cfg.head_dim;
489 let kv_dim = cfg.num_kv_heads * cfg.head_dim;
490 let qkv_dim = q_dim + 2 * kv_dim;
491 let t = max_tokens;
492 Self {
493 residual: Some(B::alloc(t * h)),
494 norm_out: B::alloc(t * h),
495 qkv_out: B::alloc(t * qkv_dim),
496 q_buf: B::alloc(t * q_dim),
497 k_buf: B::alloc(t * kv_dim),
498 v_buf: B::alloc(t * kv_dim),
499 q_head_major: B::alloc(cfg.num_heads * t * cfg.head_dim),
500 k_head_major: B::alloc(cfg.num_kv_heads * t * cfg.head_dim),
501 v_head_major: B::alloc(cfg.num_kv_heads * t * cfg.head_dim),
502 attn_head_major_out: B::alloc(cfg.num_heads * t * cfg.head_dim),
503 attn_flat: B::alloc(t * q_dim),
504 o_proj_out: B::alloc(t * h),
505 gate_up_out: B::alloc(t * 2 * im),
506 silu_out: B::alloc(t * im),
507 mlp_out: B::alloc(t * h),
508 last_hidden: B::alloc(h),
509 last_normed: B::alloc(h),
510 logits: B::alloc(cfg.vocab_size),
511 q_single: B::alloc(q_dim),
512 k_single: B::alloc(kv_dim),
513 v_single: B::alloc(kv_dim),
514 q_head_major_single: B::alloc(q_dim),
515 k_head_major_single: B::alloc(kv_dim),
516 v_head_major_single: B::alloc(kv_dim),
517 attn_head_major_single: B::alloc(q_dim),
518 attn_flat_single: B::alloc(q_dim),
519 batch_logits: B::alloc(t * cfg.vocab_size),
520 paged_batch_q: None,
525 paged_batch_o: None,
526 paged_batch_block_tables: None,
527 paged_batch_context_lens: None,
528 paged_max_blocks_per_seq: 0,
529 paged_max_seqs: 0,
530 batch_positions: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
531 batch_tokens: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
532 batch_kv_lens_pre: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
533 batch_kv_lens_post: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
534 q_normed_batched: B::alloc(t * q_dim),
535 k_normed_batched: B::alloc(t * kv_dim),
536 v_normed_batched: B::alloc(t * kv_dim),
537 unified_capacity: 0,
538 unified_residual: None,
539 unified_norm_out: None,
540 unified_qkv_out: None,
541 unified_packed_q: None,
542 unified_attn_out: None,
543 unified_o_proj_out: None,
544 unified_gate_up_out: None,
545 unified_silu_out: None,
546 unified_mlp_out: None,
547 unified_cu_seqlens_q: None,
548 unified_pos_offsets: None,
549 unified_block_tables: None,
550 unified_packed_normed: None,
551 unified_packed_logits: None,
552 max_tokens: t,
553 }
554 }
555
556 pub(crate) fn ensure_unified_scratch(
560 &mut self,
561 cfg: &LlamaFamilyConfig,
562 m_total: usize,
563 max_seqs: usize,
564 max_blocks_per_seq: usize,
565 ) {
566 if m_total <= self.unified_capacity
567 && self.unified_residual.is_some()
568 && self.unified_cu_seqlens_q.is_some()
569 {
570 return;
571 }
572 let cap = m_total.max(self.unified_capacity).max(1);
573 let h = cfg.hidden_size;
574 let q_dim = cfg.num_heads * cfg.head_dim;
575 let kv_dim = cfg.num_kv_heads * cfg.head_dim;
576 let qkv_dim = q_dim + 2 * kv_dim;
577 let im = cfg.intermediate_size;
578 let v = cfg.vocab_size;
579 self.unified_residual = Some(B::alloc(cap * h));
580 self.unified_norm_out = Some(B::alloc(cap * h));
581 self.unified_qkv_out = Some(B::alloc(cap * qkv_dim));
582 self.unified_packed_q = Some(B::alloc(cap * q_dim));
583 self.unified_attn_out = Some(B::alloc(cap * q_dim));
584 self.unified_o_proj_out = Some(B::alloc(cap * h));
585 self.unified_gate_up_out = Some(B::alloc(cap * 2 * im));
586 self.unified_silu_out = Some(B::alloc(cap * im));
587 self.unified_mlp_out = Some(B::alloc(cap * h));
588 if self.unified_cu_seqlens_q.is_none() {
589 self.unified_cu_seqlens_q = Some(B::alloc_typed(
590 ferrum_kernels::backend::Dtype::U32,
591 max_seqs + 1,
592 ));
593 self.unified_pos_offsets = Some(B::alloc_typed(
594 ferrum_kernels::backend::Dtype::U32,
595 max_seqs,
596 ));
597 self.unified_block_tables = Some(B::alloc_typed(
598 ferrum_kernels::backend::Dtype::U32,
599 max_seqs * max_blocks_per_seq,
600 ));
601 self.unified_packed_normed = Some(B::alloc(max_seqs * h));
602 self.unified_packed_logits = Some(B::alloc(max_seqs * v));
603 }
604 self.unified_capacity = cap;
605 }
606
607 fn enable_paged_batch(
611 &mut self,
612 cfg: &LlamaFamilyConfig,
613 max_seqs: usize,
614 max_blocks_per_seq: usize,
615 ) {
616 if self.paged_batch_q.is_some() {
617 return;
618 }
619 let q_dim = cfg.num_heads * cfg.head_dim;
620 self.paged_batch_q = Some(B::alloc(max_seqs * q_dim));
621 self.paged_batch_o = Some(B::alloc(max_seqs * q_dim));
622 self.paged_batch_block_tables = Some(B::alloc_typed(
623 ferrum_kernels::backend::Dtype::U32,
624 max_seqs * max_blocks_per_seq,
625 ));
626 self.paged_batch_context_lens = Some(B::alloc_typed(
627 ferrum_kernels::backend::Dtype::U32,
628 max_seqs,
629 ));
630 self.paged_max_blocks_per_seq = max_blocks_per_seq;
631 self.paged_max_seqs = max_seqs;
632 }
633}
634
/// Llama-family decoder-only model (Llama, Qwen2, Qwen3 and Mistral configs)
/// generic over compute backend `B` and KV-cache storage type `K` (default
/// `KvFp16`).
pub struct LlamaFamilyModel<B: MoeLlmBackend, K: KvLayer<B> = KvFp16> {
    pub cfg: LlamaFamilyConfig,
    /// Derived from `cfg` via `LlamaFamilyConfig::to_runtime`.
    pub runtime_cfg: LlmRuntimeConfig,

    /// Token embedding table; `None` for models built with `new_backbone_only`.
    pub embed: Option<B::Buffer>,
    pub layers: Vec<LlamaFamilyLayer<B>>,
    pub final_norm_w: B::Buffer,
    /// LM head (possibly the tied embedding); `None` for backbone-only models.
    pub lm_head: Option<Box<dyn Linear<B>>>,

    pub rope: RopeCache<B>,
    /// Grown on demand by `ensure_scratch`.
    pub scratch: LlamaFamilyScratch<B>,

    /// Per-request layer caches keyed by cache id (inserted by `ensure_kv`).
    pub kv_caches: HashMap<String, Vec<K::Layer>>,
    /// Recycled per-layer cache sets. `ensure_kv` pops from here before
    /// allocating and pushes back if paged block allocation fails.
    kv_free_pool: Vec<Vec<K::Layer>>,

    /// Per-layer `(K pool, V pool)` for paged KV, allocated on the first
    /// paged `ensure_kv`.
    pub paged_pools: Option<Vec<(B::Buffer, B::Buffer)>>,
    /// Block allocator over the shared paged pools.
    pub paged_block_alloc: Option<std::sync::Mutex<crate::common::paged_pool::BlockAllocator>>,
    /// `(max_seqs, max_blocks_per_seq)` recorded once paged KV is enabled;
    /// used by `ensure_scratch` to recreate paged-batch buffers.
    pub paged_dims: Option<(usize, usize)>,

    // Graph-capture bookkeeping for the single, batched and unified paths.
    // All of it is reset by `ensure_scratch` when the scratch is reallocated.
    pub(crate) graph_warmup: usize,
    pub(crate) graph_capture_failed: bool,
    pub(crate) batched_graph_warmup: usize,
    pub(crate) batched_graph_failed: bool,
    pub(crate) batched_graph_keys_seen: std::collections::HashSet<u64>,
    pub(crate) batched_pointers_for: Option<Vec<String>>,
    pub(crate) unified_graph_warmup: usize,
    pub(crate) unified_graph_failed: bool,
    pub(crate) unified_graph_keys_seen: std::collections::HashSet<u64>,
}
731
732impl<B: MoeLlmBackend, K: KvLayer<B>> LlamaFamilyModel<B, K> {
733 pub fn new(cfg: LlamaFamilyConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
738 {
743 let mut ctx = B::new_context();
744 B::reset_all_graphs(&mut ctx);
745 }
746 let rope = build_rope_cache::<B>(&cfg);
747 let scratch = LlamaFamilyScratch::alloc(&cfg, 1); let embed = loader.load_tensor("model.embed_tokens.weight")?;
751
752 let mut layers = Vec::with_capacity(cfg.num_layers);
754 for li in 0..cfg.num_layers {
755 let prefix = format!("model.layers.{li}");
756 let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
757 let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
758 let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
759 let post_ln_w =
760 loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
761 let gate_up_proj = loader.load_linear(&format!("{prefix}.mlp.gate_up_proj"))?;
762 let down_proj = loader.load_linear(&format!("{prefix}.mlp.down_proj"))?;
763
764 let (q_norm_w, k_norm_w) = if cfg.has_qk_norm {
765 let q = loader
766 .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
767 .ok();
768 let k = loader
769 .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
770 .ok();
771 (q, k)
772 } else {
773 (None, None)
774 };
775
776 layers.push(LlamaFamilyLayer {
777 input_ln_w,
778 qkv_proj,
779 q_norm_w,
780 k_norm_w,
781 o_proj,
782 post_ln_w,
783 gate_up_proj,
784 down_proj,
785 });
786 }
787
788 let final_norm_w = loader.load_tensor("model.norm.weight")?;
789
790 let lm_head = if loader.has_tensor("lm_head.weight") {
798 loader.load_linear("lm_head")?
799 } else {
800 tracing::info!(
801 "LlamaFamilyModel: tied embeddings — loading model.embed_tokens.weight as lm_head"
802 );
803 let as_linear = loader.load_linear("model.embed_tokens")?;
804 if as_linear.out_features() != cfg.vocab_size
806 || as_linear.in_features() != cfg.hidden_size
807 {
808 return Err(ferrum_types::FerrumError::model(format!(
809 "tied embed shape mismatch: got [{}, {}], expected [{}, {}]",
810 as_linear.out_features(),
811 as_linear.in_features(),
812 cfg.vocab_size,
813 cfg.hidden_size
814 )));
815 }
816 as_linear
817 };
818
819 let runtime_cfg = cfg.to_runtime();
820 Ok(Self {
821 cfg,
822 runtime_cfg,
823 embed: Some(embed),
824 layers,
825 final_norm_w,
826 lm_head: Some(lm_head),
827 rope,
828 scratch,
829 kv_caches: HashMap::new(),
830 kv_free_pool: Vec::new(),
831 paged_pools: None,
832 paged_block_alloc: None,
833 paged_dims: None,
834 graph_warmup: 0,
835 graph_capture_failed: false,
836 batched_graph_warmup: 0,
837 batched_graph_failed: false,
838 batched_graph_keys_seen: std::collections::HashSet::new(),
839 batched_pointers_for: None,
840 unified_graph_warmup: 0,
841 unified_graph_failed: false,
842 unified_graph_keys_seen: std::collections::HashSet::new(),
843 })
844 }
845
846 pub fn new_backbone_only(cfg: LlamaFamilyConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
858 {
860 let mut ctx = B::new_context();
861 B::reset_all_graphs(&mut ctx);
862 }
863 let rope = build_rope_cache::<B>(&cfg);
864 let scratch = LlamaFamilyScratch::alloc(&cfg, 1);
865
866 let mut layers = Vec::with_capacity(cfg.num_layers);
867 for li in 0..cfg.num_layers {
868 let prefix = format!("model.layers.{li}");
869 let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
870 let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
871 let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
872 let post_ln_w =
873 loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
874 let gate_up_proj = loader.load_linear(&format!("{prefix}.mlp.gate_up_proj"))?;
875 let down_proj = loader.load_linear(&format!("{prefix}.mlp.down_proj"))?;
876
877 let (q_norm_w, k_norm_w) = if cfg.has_qk_norm {
878 let q = loader
879 .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
880 .ok();
881 let k = loader
882 .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
883 .ok();
884 (q, k)
885 } else {
886 (None, None)
887 };
888
889 layers.push(LlamaFamilyLayer {
890 input_ln_w,
891 qkv_proj,
892 q_norm_w,
893 k_norm_w,
894 o_proj,
895 post_ln_w,
896 gate_up_proj,
897 down_proj,
898 });
899 }
900
901 let final_norm_w = loader.load_tensor("model.norm.weight")?;
902
903 let runtime_cfg = cfg.to_runtime();
904 Ok(Self {
905 cfg,
906 runtime_cfg,
907 embed: None,
908 layers,
909 final_norm_w,
910 lm_head: None,
911 rope,
912 scratch,
913 kv_caches: HashMap::new(),
914 kv_free_pool: Vec::new(),
915 paged_pools: None,
916 paged_block_alloc: None,
917 paged_dims: None,
918 graph_warmup: 0,
919 graph_capture_failed: false,
920 batched_graph_warmup: 0,
921 batched_graph_failed: false,
922 batched_graph_keys_seen: std::collections::HashSet::new(),
923 batched_pointers_for: None,
924 unified_graph_warmup: 0,
925 unified_graph_failed: false,
926 unified_graph_keys_seen: std::collections::HashSet::new(),
927 })
928 }
929
930 pub(crate) fn ensure_scratch(&mut self, tokens: usize) {
932 if self.scratch.max_tokens < tokens {
933 {
938 let mut ctx = B::new_context();
939 B::reset_all_graphs(&mut ctx);
940 }
941 self.scratch = LlamaFamilyScratch::alloc(&self.cfg, tokens);
942 self.graph_warmup = 0;
943 self.graph_capture_failed = false;
944 self.batched_graph_keys_seen.clear();
945 self.batched_graph_warmup = 0;
946 self.batched_graph_failed = false;
947 self.unified_graph_keys_seen.clear();
948 self.unified_graph_warmup = 0;
949 self.unified_graph_failed = false;
950 if let Some((max_seqs, max_blocks_per_seq)) = self.paged_dims {
955 self.scratch
956 .enable_paged_batch(&self.cfg, max_seqs, max_blocks_per_seq);
957 }
958 }
959 }
960
961 pub(crate) fn ensure_kv(&mut self, cache_id: &str) {
965 if self.kv_caches.contains_key(cache_id) {
966 return;
967 }
968 let nkv = self.cfg.num_kv_heads;
969 let hd = self.cfg.head_dim;
970 let model_max = self.cfg.max_seq_len;
977 let runtime_env = llama_family_runtime_env();
986 let max = runtime_env.kv_capacity_for_model(model_max);
987
988 let paged = runtime_env.paged_kv_enabled::<B>();
1004 const PAGED_BLOCK_SIZE: usize = 16;
1005
1006 let max_seqs = runtime_env.paged_max_seqs;
1014 let max_blocks_per_seq = max.div_ceil(PAGED_BLOCK_SIZE);
1015 let total_pool_blocks = max_seqs * max_blocks_per_seq;
1016
1017 if paged && self.paged_pools.is_none() {
1024 let mut pools = Vec::with_capacity(self.cfg.num_layers);
1025 for _ in 0..self.cfg.num_layers {
1026 let pool_floats = total_pool_blocks * nkv * PAGED_BLOCK_SIZE * hd;
1027 pools.push((B::alloc(pool_floats), B::alloc(pool_floats)));
1028 }
1029 self.paged_pools = Some(pools);
1030 self.paged_block_alloc = Some(std::sync::Mutex::new(
1031 crate::common::paged_pool::BlockAllocator::new(total_pool_blocks as u32),
1032 ));
1033 }
1034 if paged {
1040 self.scratch
1041 .enable_paged_batch(&self.cfg, max_seqs, max_blocks_per_seq);
1042 self.paged_dims = Some((max_seqs, max_blocks_per_seq));
1045 }
1046
1047 let mut caches = self.kv_free_pool.pop().unwrap_or_else(|| {
1055 (0..self.cfg.num_layers)
1056 .map(|_| {
1057 if paged {
1058 K::alloc_paged(max_blocks_per_seq, PAGED_BLOCK_SIZE, nkv, hd)
1059 } else {
1060 K::alloc_contig(max, nkv, hd)
1061 }
1062 })
1063 .collect()
1064 });
1065
1066 if paged {
1072 let alloc_arc = self
1073 .paged_block_alloc
1074 .as_ref()
1075 .expect("paged_block_alloc must be initialised when paged=true");
1076 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
1080 let block_indices = match alloc.allocate_n(max_blocks_per_seq) {
1081 Ok(idx) => idx,
1082 Err(e) => {
1083 drop(alloc);
1090 self.kv_free_pool.push(caches);
1091 eprintln!(
1092 "[ferrum] paged KV pool exhausted on ensure_kv for \
1093 cache_id={cache_id:?}: {e}. Increase \
1094 FERRUM_PAGED_MAX_SEQS (currently {max_seqs}) or \
1095 throttle concurrent requests.",
1096 );
1097 return;
1098 }
1099 };
1100 let mut padded = block_indices.clone();
1105 padded.resize(max_blocks_per_seq, 0);
1106 let mut ctx_tmp = B::new_context();
1107 for c in caches.iter_mut() {
1108 if let Some(bt) = K::block_table_mut(c) {
1109 B::write_typed::<u32>(&mut ctx_tmp, bt, &padded);
1110 }
1111 *K::paged_block_indices_mut(c) = block_indices.clone();
1112 }
1113 B::sync(&mut ctx_tmp);
1114 }
1115
1116 for c in caches.iter_mut() {
1120 K::set_len(c, 0);
1121 if let Some(cl) = K::context_lens_mut(c) {
1122 let mut ctx_tmp = B::new_context();
1123 B::write_typed::<u32>(&mut ctx_tmp, cl, &[0u32]);
1124 B::sync(&mut ctx_tmp);
1125 }
1126 }
1127 self.kv_caches.insert(cache_id.to_string(), caches);
1128 }
1129
1130 #[allow(clippy::too_many_arguments)]
1135 pub(crate) fn forward_layer(
1136 &mut self,
1137 ctx: &mut B::Context,
1138 li: usize,
1139 cache_id: &str,
1140 residual: &mut B::Buffer,
1141 pos_offset: usize,
1142 tokens: usize,
1143 ) {
1144 let layer = &self.layers[li];
1145 let cfg = &self.cfg;
1146 let h = cfg.hidden_size;
1147 let nh = cfg.num_heads;
1148 let nkv = cfg.num_kv_heads;
1149 let hd = cfg.head_dim;
1150 let im = cfg.intermediate_size;
1151 let eps = cfg.rms_norm_eps;
1152 let q_dim = nh * hd;
1153 let kv_dim = nkv * hd;
1154
1155 let _t0 = if llama_family_runtime_env().decode_op_profile {
1157 B::sync(ctx);
1158 Some(std::time::Instant::now())
1159 } else {
1160 None
1161 };
1162 B::rms_norm(
1163 ctx,
1164 residual,
1165 &layer.input_ln_w,
1166 eps,
1167 &mut self.scratch.norm_out,
1168 tokens,
1169 h,
1170 );
1171 if let Some(t0) = _t0 {
1172 B::sync(ctx);
1173 NORM_TIME_US.fetch_add(
1174 t0.elapsed().as_micros() as u64,
1175 std::sync::atomic::Ordering::Relaxed,
1176 );
1177 NORM_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1178 }
1179
1180 let _t0 = if llama_family_runtime_env().decode_op_profile {
1182 B::sync(ctx);
1183 Some(std::time::Instant::now())
1184 } else {
1185 None
1186 };
1187 layer.qkv_proj.forward(
1188 ctx,
1189 &self.scratch.norm_out,
1190 &mut self.scratch.qkv_out,
1191 tokens,
1192 );
1193 if let Some(t0) = _t0 {
1194 B::sync(ctx);
1195 MATMUL_TIME_US.fetch_add(
1196 t0.elapsed().as_micros() as u64,
1197 std::sync::atomic::Ordering::Relaxed,
1198 );
1199 MATMUL_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1200 }
1201
1202 let qk_mode: i32 = if cfg.has_qk_norm {
1217 1
1218 } else if cfg.rope_interleaved {
1219 3
1220 } else {
1221 2
1222 };
1223 let dummy = &layer.input_ln_w;
1224 let q_norm_w = layer.q_norm_w.as_ref().unwrap_or(dummy);
1225 let k_norm_w = layer.k_norm_w.as_ref().unwrap_or(dummy);
1226
1227 let paged_pool_ptr: Option<(*mut B::Buffer, *mut B::Buffer)> =
1238 if let Some(pools) = self.paged_pools.as_mut() {
1239 let pool = &mut pools[li];
1240 Some((&mut pool.0 as *mut _, &mut pool.1 as *mut _))
1241 } else {
1242 None
1243 };
1244 let caches = self
1245 .kv_caches
1246 .get_mut(cache_id)
1247 .expect("ensure_kv must be called before forward_layer");
1248 let cache_len_before = K::len(&caches[li]);
1251 let cache_capacity = K::capacity(&caches[li]);
1252 let cache_block_size = K::block_size(&caches[li]);
1253
1254 if cache_len_before + tokens > cache_capacity {
1260 panic!(
1261 "KV cache overflow on layer {li}: would write tokens [{cache_len_before}..{}) but capacity is {cache_capacity} (cache_id={cache_id:?}). Increase FERRUM_KV_CAPACITY or call /clear in the REPL.",
1262 cache_len_before + tokens
1263 );
1264 }
1265
1266 if cache_block_size > 0 {
1271 let (pool_k_ptr, pool_v_ptr) =
1272 paged_pool_ptr.expect("paged_pools must be allocated when block_size > 0");
1273 let pool_k = unsafe { &mut *pool_k_ptr };
1276 let pool_v = unsafe { &mut *pool_v_ptr };
1277
1278 K::paged_write(
1279 ctx,
1280 &mut caches[li],
1281 &self.scratch.qkv_out,
1282 q_norm_w,
1283 k_norm_w,
1284 &self.rope.cos,
1285 &self.rope.sin,
1286 &mut self.scratch.q_head_major,
1287 &mut self.scratch.k_head_major,
1288 &mut self.scratch.v_head_major,
1289 pool_k,
1290 pool_v,
1291 tokens,
1292 nh,
1293 nkv,
1294 hd,
1295 pos_offset,
1296 eps,
1297 qk_mode,
1298 )
1299 .expect("K::paged_write");
1300
1301 let new_len = cache_len_before + tokens;
1302 K::set_len(&mut caches[li], new_len);
1303
1304 let pool_k_imm = unsafe { &*pool_k_ptr };
1305 let pool_v_imm = unsafe { &*pool_v_ptr };
1306 K::paged_decode_attention(
1307 ctx,
1308 &mut caches[li],
1309 &self.scratch.q_head_major,
1310 pool_k_imm,
1311 pool_v_imm,
1312 &mut self.scratch.attn_head_major_out,
1313 nh,
1314 nkv,
1315 hd,
1316 new_len,
1317 tokens,
1318 )
1319 .expect("K::paged_decode_attention");
1320
1321 return self.forward_layer_post_attn(ctx, li, residual, tokens);
1322 }
1323
1324 let _qkr_t0 = if llama_family_runtime_env().decode_op_profile {
1327 B::sync(ctx);
1328 Some(std::time::Instant::now())
1329 } else {
1330 None
1331 };
1332 K::contig_write(
1333 ctx,
1334 &mut caches[li],
1335 &self.scratch.qkv_out,
1336 q_norm_w,
1337 k_norm_w,
1338 &self.rope.cos,
1339 &self.rope.sin,
1340 &mut self.scratch.q_head_major,
1341 &mut self.scratch.k_head_major,
1342 &mut self.scratch.v_head_major,
1343 &mut self.scratch.q_buf,
1344 &mut self.scratch.k_buf,
1345 &mut self.scratch.v_buf,
1346 tokens,
1347 nh,
1348 nkv,
1349 hd,
1350 pos_offset,
1351 eps,
1352 qk_mode,
1353 )
1354 .expect("K::contig_write");
1355 if let Some(t0) = _qkr_t0 {
1356 B::sync(ctx);
1357 QKR_TIME_US.fetch_add(
1358 t0.elapsed().as_micros() as u64,
1359 std::sync::atomic::Ordering::Relaxed,
1360 );
1361 QKR_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1362 }
1363 let new_len = cache_len_before + tokens;
1364 K::set_len(&mut caches[li], new_len);
1365 let kv_stride = cache_capacity;
1366
1367 let _attn_t0 = if llama_family_runtime_env().decode_op_profile {
1368 B::sync(ctx);
1369 Some(std::time::Instant::now())
1370 } else {
1371 None
1372 };
1373 let attn_cfg = ferrum_kernels::backend::AttnConfig {
1374 num_heads: nh,
1375 num_kv_heads: nkv,
1376 head_dim: hd,
1377 causal: true,
1378 scale: 1.0 / (hd as f32).sqrt(),
1379 kv_seq_stride: kv_stride,
1380 sliding_window: cfg.sliding_window,
1381 };
1382 K::contig_decode_attention(
1383 ctx,
1384 &caches[li],
1385 &self.scratch.q_head_major,
1386 &mut self.scratch.attn_head_major_out,
1387 attn_cfg,
1388 tokens,
1389 pos_offset,
1390 )
1391 .expect("K::contig_decode_attention");
1392 let _ = q_dim;
1393 let _ = kv_dim;
1394 let _ = dummy;
1395 if let Some(t0) = _attn_t0 {
1396 B::sync(ctx);
1397 ATTN_TIME_US.fetch_add(
1398 t0.elapsed().as_micros() as u64,
1399 std::sync::atomic::Ordering::Relaxed,
1400 );
1401 ATTN_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1402 }
1403
1404 self.forward_layer_post_attn(ctx, li, residual, tokens);
1405 }
1406
1407 pub(crate) fn forward_layer_post_attn(
1412 &mut self,
1413 ctx: &mut B::Context,
1414 li: usize,
1415 residual: &mut B::Buffer,
1416 tokens: usize,
1417 ) {
1418 let layer = &self.layers[li];
1419 let cfg = &self.cfg;
1420 let h = cfg.hidden_size;
1421 let nh = cfg.num_heads;
1422 let hd = cfg.head_dim;
1423 let im = cfg.intermediate_size;
1424 let eps = cfg.rms_norm_eps;
1425
1426 let attn_token_major = if tokens == 1 {
1428 &self.scratch.attn_head_major_out
1429 } else {
1430 B::transpose_head_to_token(
1431 ctx,
1432 &self.scratch.attn_head_major_out,
1433 &mut self.scratch.attn_flat,
1434 tokens,
1435 nh,
1436 hd,
1437 );
1438 &self.scratch.attn_flat
1439 };
1440
1441 layer
1443 .o_proj
1444 .forward(ctx, attn_token_major, &mut self.scratch.o_proj_out, tokens);
1445
1446 B::fused_add_rms_norm(
1448 ctx,
1449 residual,
1450 &self.scratch.o_proj_out,
1451 &layer.post_ln_w,
1452 eps,
1453 &mut self.scratch.norm_out,
1454 tokens,
1455 h,
1456 );
1457
1458 layer.gate_up_proj.forward(
1460 ctx,
1461 &self.scratch.norm_out,
1462 &mut self.scratch.gate_up_out,
1463 tokens,
1464 );
1465
1466 B::fused_silu_mul_split(
1468 ctx,
1469 &self.scratch.gate_up_out,
1470 &mut self.scratch.silu_out,
1471 tokens,
1472 im,
1473 );
1474
1475 layer.down_proj.forward(
1477 ctx,
1478 &self.scratch.silu_out,
1479 &mut self.scratch.mlp_out,
1480 tokens,
1481 );
1482
1483 B::add_inplace(ctx, residual, &self.scratch.mlp_out, tokens * h);
1485 }
1486
1487 pub fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
1499 let seq_len = tokens.len();
1500 assert!(seq_len > 0, "forward_verify called with empty tokens");
1501 self.ensure_scratch(seq_len);
1502 self.ensure_kv(cache_id);
1503
1504 let h = self.cfg.hidden_size;
1505 let vocab = self.cfg.vocab_size;
1506
1507 let pos_offset = self
1508 .kv_caches
1509 .get(cache_id)
1510 .and_then(|layers| layers.first())
1511 .map(|c| K::len(c))
1512 .unwrap_or(0);
1513
1514 let mut ctx = B::new_context();
1515 let mut residual = self
1516 .scratch
1517 .residual
1518 .take()
1519 .expect("scratch residual missing (previous call didn't restore)");
1520
1521 let embed = self
1522 .embed
1523 .as_ref()
1524 .expect("forward_verify called on backbone-only model (no embed)");
1525 B::embedding_lookup(&mut ctx, embed, tokens, &mut residual, h);
1526
1527 for li in 0..self.cfg.num_layers {
1528 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
1529 }
1530
1531 B::rms_norm(
1534 &mut ctx,
1535 &residual,
1536 &self.final_norm_w,
1537 self.cfg.rms_norm_eps,
1538 &mut self.scratch.norm_out,
1539 seq_len,
1540 h,
1541 );
1542
1543 let lm_head = self
1547 .lm_head
1548 .as_ref()
1549 .expect("forward_verify called on backbone-only model (no lm_head)");
1550 lm_head.forward(
1551 &mut ctx,
1552 &self.scratch.norm_out,
1553 &mut self.scratch.batch_logits,
1554 seq_len,
1555 );
1556
1557 B::sync(&mut ctx);
1558 self.scratch.residual = Some(residual);
1559 B::to_vec(&self.scratch.batch_logits, seq_len * vocab)
1560 }
1561
1562 pub fn prefill_internal(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
1570 let seq_len = tokens.len();
1571 assert!(seq_len > 0, "prefill called with empty token list");
1572 self.ensure_scratch(seq_len);
1573 self.ensure_kv(cache_id);
1574
1575 let pos_offset = self
1578 .kv_caches
1579 .get(cache_id)
1580 .and_then(|layers| layers.first())
1581 .map(|c| K::len(c))
1582 .unwrap_or(0);
1583
1584 let h = self.cfg.hidden_size;
1585 let vocab = self.cfg.vocab_size;
1586 let mut ctx = B::new_context();
1587
1588 let mut residual = self
1595 .scratch
1596 .residual
1597 .take()
1598 .expect("scratch residual missing (previous call didn't restore)");
1599 let embed = self
1600 .embed
1601 .as_ref()
1602 .expect("prefill_internal called on backbone-only model (no embed)");
1603 B::embedding_lookup(&mut ctx, embed, tokens, &mut residual, h);
1604
1605 let prefill_profile = llama_family_runtime_env().prefill_op_profile;
1606 let prefill_t0 = if prefill_profile {
1607 B::sync(&mut ctx);
1608 Some(std::time::Instant::now())
1609 } else {
1610 None
1611 };
1612
1613 for li in 0..self.cfg.num_layers {
1614 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
1615 }
1616
1617 if let Some(t0) = prefill_t0 {
1618 B::sync(&mut ctx);
1619 let total_us = t0.elapsed().as_micros() as u64;
1620 let attn_us = ATTN_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1621 let attn_n = ATTN_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1622 let qkr_us = QKR_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1623 let qkr_n = QKR_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1624 let mm_us = MATMUL_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1625 let mm_n = MATMUL_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1626 let norm_us = NORM_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1627 let norm_n = NORM_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1628 let other_us = OTHER_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1629 let other_n = OTHER_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1630 eprintln!(
1631 "[prefill-profile] tokens={} layers total={} ms",
1632 seq_len,
1633 total_us / 1000
1634 );
1635 let bucket = |label: &str, n: u64, us: u64| {
1636 if n > 0 {
1637 eprintln!(
1638 "[prefill-profile] {label}: {} calls {} ms (avg {} us)",
1639 n,
1640 us / 1000,
1641 us / n
1642 );
1643 }
1644 };
1645 bucket("flash_attn", attn_n, attn_us);
1646 bucket("qk_norm_rope", qkr_n, qkr_us);
1647 bucket("matmuls", mm_n, mm_us);
1648 bucket("norms", norm_n, norm_us);
1649 bucket("other", other_n, other_us);
1650 }
1651
1652 B::copy_slice(
1654 &mut ctx,
1655 &residual,
1656 (seq_len - 1) * h,
1657 &mut self.scratch.last_hidden,
1658 0,
1659 h,
1660 );
1661
1662 B::rms_norm(
1664 &mut ctx,
1665 &self.scratch.last_hidden,
1666 &self.final_norm_w,
1667 self.cfg.rms_norm_eps,
1668 &mut self.scratch.last_normed,
1669 1,
1670 h,
1671 );
1672
1673 let lm_head = self
1675 .lm_head
1676 .as_ref()
1677 .expect("prefill_internal called on backbone-only model (no lm_head)");
1678 lm_head.forward(
1679 &mut ctx,
1680 &self.scratch.last_normed,
1681 &mut self.scratch.logits,
1682 1,
1683 );
1684
1685 B::sync(&mut ctx);
1692
1693 self.scratch.residual = Some(residual);
1695
1696 B::to_vec(&self.scratch.logits, vocab)
1697 }
1698
    /// Decodes a single `token` at position `pos` for `cache_id` and returns the
    /// next-token logits (`vocab_size` floats).
    ///
    /// When the `cuda_graph` runtime flag is set, the call first tries to
    /// replay a previously captured single-item graph
    /// (`SINGLE_ITEM_GRAPH_KEY`). If that replay succeeds, the logits are read
    /// straight from scratch.
    ///
    /// Otherwise the layers run eagerly. Each non-capturing call bumps
    /// `self.graph_warmup`. Once it reaches `GRAPH_WARMUP`, the next pass is
    /// recorded into a graph, unless a previous capture attempt failed.
    ///
    /// With the `decode_layer_profile` flag, per-layer wall times and the
    /// shared op-profile counters are printed to stderr.
    ///
    /// # Panics
    /// Panics in any of these cases:
    /// - the model has no embedding table or LM head;
    /// - the scratch residual buffer was not restored by a previous call;
    /// - layer profiling is on and `num_layers` is 0 (the average divides by
    ///   the layer count).
    pub fn decode_internal(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
        self.ensure_scratch(1);
        self.ensure_kv(cache_id);

        let h = self.cfg.hidden_size;
        let vocab = self.cfg.vocab_size;

        let mut ctx = B::new_context();

        // Number of non-capturing decode calls required before a capture is attempted.
        const GRAPH_WARMUP: usize = 3;
        let graph_enabled = llama_family_runtime_env().cuda_graph;

        if graph_enabled {
            // Stage token/pos for a replayed graph to consume (presumably
            // device-side decode state; confirm against the backend).
            B::set_decode_state(&mut ctx, token, pos);

            match B::replay_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY) {
                Ok(true) => {
                    B::sync(&mut ctx);
                    return B::to_vec(&self.scratch.logits, vocab);
                }
                // No graph captured yet: fall through to the eager/capture path.
                Ok(false) => { }
                // Replay failure is treated as best-effort: fall back to eager.
                Err(_) => { }
            }
        }

        let should_capture =
            graph_enabled && !self.graph_capture_failed && self.graph_warmup >= GRAPH_WARMUP;

        if should_capture {
            // Dev-state mode is switched on for the duration of the capture and
            // switched back off afterwards (or immediately if capture fails to start).
            B::set_dev_state_mode(&mut ctx, true);
            if B::begin_graph_capture(&mut ctx).is_err() {
                // Disables capture until truncate_kv/release/reset clear the flag.
                self.graph_capture_failed = true;
                B::set_dev_state_mode(&mut ctx, false);
            }
        }

        // Moved out of scratch for the duration of the pass; restored before returning.
        let mut residual = self
            .scratch
            .residual
            .take()
            .expect("scratch residual missing (previous call didn't restore)");
        let embed = self
            .embed
            .as_ref()
            .expect("decode_internal called on backbone-only model (no embed)");
        B::embedding_lookup(&mut ctx, embed, &[token], &mut residual, h);

        let layer_profile = llama_family_runtime_env().decode_layer_profile;
        let mut layer_times = if layer_profile {
            Some(Vec::with_capacity(self.cfg.num_layers))
        } else {
            None
        };

        for li in 0..self.cfg.num_layers {
            if layer_profile {
                // Syncing after every layer serializes the pass so each timing is
                // per-layer. NOTE(review): if this runs while a graph capture is
                // open, the sync lands inside the capture region; confirm the
                // backend tolerates that.
                let t0 = std::time::Instant::now();
                self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
                B::sync(&mut ctx);
                let elapsed_us = t0.elapsed().as_micros() as u64;
                if let Some(v) = layer_times.as_mut() {
                    v.push(elapsed_us);
                }
            } else {
                self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
            }
        }
        if let Some(times) = layer_times.take() {
            let sum: u64 = times.iter().sum();
            // Panics with a divide-by-zero if num_layers == 0.
            let avg = sum / times.len() as u64;
            let mn = *times.iter().min().unwrap_or(&0);
            let mx = *times.iter().max().unwrap_or(&0);
            eprintln!(
                "[layer-profile] {} layers total={} ms avg={} us min={} us max={} us",
                times.len(),
                sum / 1000,
                avg,
                mn,
                mx
            );
            // Per-layer times, six per output line.
            for (i, t) in times.iter().enumerate() {
                eprint!("L{i}={}ms ", t / 1000);
                if (i + 1) % 6 == 0 {
                    eprintln!();
                }
            }
            eprintln!();
            // Drain (reset to 0) the shared op counters. The attention and
            // qk-norm/rope timers in forward_layer only accumulate when the
            // `decode_op_profile` flag is also set.
            let attn_us = ATTN_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
            let attn_n = ATTN_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
            let qkr_us = QKR_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
            let qkr_n = QKR_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
            let mm_us = MATMUL_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
            let mm_n = MATMUL_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
            let norm_us = NORM_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
            let norm_n = NORM_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
            let other_us = OTHER_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
            let other_n = OTHER_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
            eprintln!(
                "[op-profile] flash_attn: {} calls {} ms (avg {} us)",
                attn_n,
                attn_us / 1000,
                if attn_n > 0 { attn_us / attn_n } else { 0 }
            );
            eprintln!(
                "[op-profile] qk_norm_rope: {} calls {} ms (avg {} us)",
                qkr_n,
                qkr_us / 1000,
                if qkr_n > 0 { qkr_us / qkr_n } else { 0 }
            );
            eprintln!(
                "[op-profile] matmuls (Linear::forward): {} calls {} ms (avg {} us)",
                mm_n,
                mm_us / 1000,
                if mm_n > 0 { mm_us / mm_n } else { 0 }
            );
            eprintln!(
                "[op-profile] norms (rms+fused_add_rms): {} calls {} ms (avg {} us)",
                norm_n,
                norm_us / 1000,
                if norm_n > 0 { norm_us / norm_n } else { 0 }
            );
            eprintln!(
                "[op-profile] other (split_qkv, kv_append, transpose, silu, add): {} calls {} ms (avg {} us)",
                other_n, other_us / 1000, if other_n > 0 { other_us / other_n } else { 0 }
            );
        }

        B::rms_norm(
            &mut ctx,
            &residual,
            &self.final_norm_w,
            self.cfg.rms_norm_eps,
            &mut self.scratch.last_normed,
            1,
            h,
        );

        let lm_head = self
            .lm_head
            .as_ref()
            .expect("decode_internal called on backbone-only model (no lm_head)");
        lm_head.forward(
            &mut ctx,
            &self.scratch.last_normed,
            &mut self.scratch.logits,
            1,
        );

        if should_capture && !self.graph_capture_failed {
            // Close the capture and immediately replay it to produce this
            // call's logits. NOTE(review): if the backend's capture only records
            // work without executing it (as CUDA stream capture does), a failure
            // in end_graph_capture or in this replay leaves `scratch.logits`
            // uncomputed for this call; only the failure flag is set here.
            // Confirm the backend's semantics.
            if B::end_graph_capture(&mut ctx, SINGLE_ITEM_GRAPH_KEY).is_err() {
                self.graph_capture_failed = true;
            } else {
                if B::replay_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY).is_err() {
                    self.graph_capture_failed = true;
                }
            }
            B::set_dev_state_mode(&mut ctx, false);
        } else {
            // Counts every non-capturing pass, including when graphs are disabled.
            self.graph_warmup += 1;
        }

        B::sync(&mut ctx);
        self.scratch.residual = Some(residual);

        B::to_vec(&self.scratch.logits, vocab)
    }
1897
1898 pub fn prefill_from_embeds(
1907 &mut self,
1908 cache_id: &str,
1909 embeds: &[f32],
1910 seq_len: usize,
1911 ) -> Vec<f32> {
1912 let h = self.cfg.hidden_size;
1913 assert_eq!(
1914 embeds.len(),
1915 seq_len * h,
1916 "embeds length {} != seq_len * hidden_size {}",
1917 embeds.len(),
1918 seq_len * h
1919 );
1920 assert!(seq_len > 0, "prefill_from_embeds called with zero length");
1921
1922 self.ensure_scratch(seq_len);
1923 self.ensure_kv(cache_id);
1924
1925 let mut ctx = B::new_context();
1926 let mut residual = self
1927 .scratch
1928 .residual
1929 .take()
1930 .expect("scratch residual missing (previous call didn't restore)");
1931
1932 let embed_buf = B::from_slice(embeds);
1934 B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, seq_len * h);
1935
1936 for li in 0..self.cfg.num_layers {
1937 self.forward_layer(&mut ctx, li, cache_id, &mut residual, 0, seq_len);
1938 }
1939
1940 B::copy_slice(
1941 &mut ctx,
1942 &residual,
1943 (seq_len - 1) * h,
1944 &mut self.scratch.last_hidden,
1945 0,
1946 h,
1947 );
1948 B::sync(&mut ctx);
1949 self.scratch.residual = Some(residual);
1950 B::to_vec(&self.scratch.last_hidden, h)
1951 }
1952
1953 pub fn decode_from_embed(&mut self, cache_id: &str, embed: &[f32], pos: u32) -> Vec<f32> {
1957 let h = self.cfg.hidden_size;
1958 assert_eq!(
1959 embed.len(),
1960 h,
1961 "embed length {} != hidden_size {}",
1962 embed.len(),
1963 h
1964 );
1965
1966 self.ensure_scratch(1);
1967 self.ensure_kv(cache_id);
1968
1969 let mut ctx = B::new_context();
1970 let mut residual = self
1971 .scratch
1972 .residual
1973 .take()
1974 .expect("scratch residual missing (previous call didn't restore)");
1975
1976 let embed_buf = B::from_slice(embed);
1977 B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, h);
1978
1979 for li in 0..self.cfg.num_layers {
1980 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1981 }
1982
1983 B::copy_slice(&mut ctx, &residual, 0, &mut self.scratch.last_hidden, 0, h);
1984 B::sync(&mut ctx);
1985 self.scratch.residual = Some(residual);
1986 B::to_vec(&self.scratch.last_hidden, h)
1987 }
1988
1989 pub fn prefill_all_post_norm(
2000 &mut self,
2001 cache_id: &str,
2002 embeds: &[f32],
2003 seq_len: usize,
2004 pos_offset: usize,
2005 ) -> Vec<f32> {
2006 let h = self.cfg.hidden_size;
2007 assert_eq!(
2008 embeds.len(),
2009 seq_len * h,
2010 "embeds length {} != seq_len * hidden_size {}",
2011 embeds.len(),
2012 seq_len * h
2013 );
2014 assert!(seq_len > 0);
2015
2016 self.ensure_scratch(seq_len);
2017 self.ensure_kv(cache_id);
2018
2019 let mut ctx = B::new_context();
2020 let mut residual = self
2021 .scratch
2022 .residual
2023 .take()
2024 .expect("scratch residual missing (previous call didn't restore)");
2025
2026 let embed_buf = B::from_slice(embeds);
2027 B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, seq_len * h);
2028
2029 for li in 0..self.cfg.num_layers {
2030 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
2031 }
2032
2033 B::rms_norm(
2035 &mut ctx,
2036 &residual,
2037 &self.final_norm_w,
2038 self.cfg.rms_norm_eps,
2039 &mut self.scratch.norm_out,
2040 seq_len,
2041 h,
2042 );
2043 B::sync(&mut ctx);
2044 self.scratch.residual = Some(residual);
2045 B::to_vec(&self.scratch.norm_out, seq_len * h)
2046 }
2047
2048 pub fn decode_post_norm_from_embed(
2052 &mut self,
2053 cache_id: &str,
2054 embed: &[f32],
2055 pos: u32,
2056 ) -> Vec<f32> {
2057 let h = self.cfg.hidden_size;
2058 assert_eq!(embed.len(), h);
2059
2060 self.ensure_scratch(1);
2061 self.ensure_kv(cache_id);
2062
2063 let mut ctx = B::new_context();
2064 let mut residual = self
2065 .scratch
2066 .residual
2067 .take()
2068 .expect("scratch residual missing (previous call didn't restore)");
2069
2070 let embed_buf = B::from_slice(embed);
2071 B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, h);
2072
2073 for li in 0..self.cfg.num_layers {
2074 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
2075 }
2076
2077 B::rms_norm(
2078 &mut ctx,
2079 &residual,
2080 &self.final_norm_w,
2081 self.cfg.rms_norm_eps,
2082 &mut self.scratch.last_normed,
2083 1,
2084 h,
2085 );
2086 B::sync(&mut ctx);
2087 self.scratch.residual = Some(residual);
2088 B::to_vec(&self.scratch.last_normed, h)
2089 }
2090}
2091
2092impl<B: MoeLlmBackend> DecoderOnlyLLM for LlamaFamilyModel<B, KvFp16> {
2094 fn config(&self) -> &LlmRuntimeConfig {
2095 &self.runtime_cfg
2096 }
2097
2098 fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
2099 self.ensure_scratch(max_tokens);
2100 self.ensure_kv(cache_id);
2101
2102 const WARMUP_CACHE: &str = "__ferrum_warmup__";
2103 let _ = self.prefill_internal(WARMUP_CACHE, &[0u32]);
2104 if let Some(mut caches) = self.kv_caches.remove(WARMUP_CACHE) {
2105 if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2106 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2107 if let Some(c0) = caches.first() {
2108 if !c0.paged_block_indices.is_empty() {
2109 alloc.free(&c0.paged_block_indices);
2110 }
2111 }
2112 for c in caches.iter_mut() {
2113 c.paged_block_indices.clear();
2114 }
2115 }
2116 self.kv_free_pool.push(caches);
2117 }
2118 }
2119
2120 fn kv_capacity(&self) -> usize {
2121 llama_family_runtime_env().kv_capacity_for_model(self.cfg.max_seq_len)
2122 }
2123
2124 fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2125 self.prefill_internal(cache_id, tokens)
2126 }
2127
2128 fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
2129 self.decode_internal(cache_id, token, pos)
2130 }
2131
2132 fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
2133 self.decode_batch_internal(batch)
2134 }
2135
2136 fn unified_forward(
2137 &mut self,
2138 items: &[(String, Vec<u32>, usize, bool)],
2139 ) -> std::result::Result<Vec<Option<Vec<f32>>>, ferrum_types::FerrumError> {
2140 if items.is_empty() {
2141 return Ok(Vec::new());
2142 }
2143 if !B::supports_varlen_qkv() {
2144 return Err(ferrum_types::FerrumError::unsupported(
2145 "LlamaFamilyModel::unified_forward: backend lacks varlen \
2146 QKV kernels. Engine will fall back to per-item dispatch.",
2147 ));
2148 }
2149 self.ensure_kv(&items[0].0);
2150 if self.paged_pools.is_none() {
2151 return Err(ferrum_types::FerrumError::unsupported(
2152 "LlamaFamilyModel::unified_forward: paged KV required; \
2153 enable via FERRUM_METAL_PAGED_KV=1 (cross-backend env). \
2154 Engine will fall back to per-item dispatch.",
2155 ));
2156 }
2157 for (cid, _, _, _) in items {
2158 self.ensure_kv(cid);
2159 if !self.kv_caches.contains_key(cid) {
2160 return Err(ferrum_types::FerrumError::resource_exhausted(format!(
2161 "paged KV pool exhausted for cache_id={cid:?}; back off"
2162 )));
2163 }
2164 }
2165 Ok(self.unified_forward_internal(items))
2166 }
2167
2168 fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2169 LlamaFamilyModel::<B, KvFp16>::forward_verify(self, cache_id, tokens)
2170 }
2171
2172 fn truncate_kv(&mut self, cache_id: &str, new_len: usize) {
2173 if let Some(caches) = self.kv_caches.get_mut(cache_id) {
2174 for c in caches.iter_mut() {
2175 if new_len < c.len {
2176 c.len = new_len;
2177 }
2178 }
2179 }
2180 let mut ctx = B::new_context();
2181 B::reset_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY);
2182 self.graph_warmup = 0;
2183 self.graph_capture_failed = false;
2184 }
2185
2186 fn release(&mut self, cache_id: &str) {
2187 let mut ctx = B::new_context();
2188 B::sync(&mut ctx);
2189 self.graph_warmup = 0;
2190 self.graph_capture_failed = false;
2191 if let Some(mut caches) = self.kv_caches.remove(cache_id) {
2192 if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2193 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2194 if let Some(c0) = caches.first() {
2195 if !c0.paged_block_indices.is_empty() {
2196 alloc.free(&c0.paged_block_indices);
2197 }
2198 }
2199 for c in caches.iter_mut() {
2200 c.paged_block_indices.clear();
2201 }
2202 }
2203 self.kv_free_pool.push(caches);
2204 }
2205 }
2206
2207 fn reset(&mut self) {
2208 let mut ctx = B::new_context();
2209 B::sync(&mut ctx);
2210 B::reset_all_graphs(&mut ctx);
2211 B::sync(&mut ctx);
2212 self.graph_warmup = 0;
2213 self.graph_capture_failed = false;
2214 self.batched_graph_keys_seen.clear();
2215 self.batched_graph_warmup = 0;
2216 self.batched_graph_failed = false;
2217 self.kv_caches.clear();
2218 self.kv_free_pool.clear();
2219 }
2220}
2221
2222impl<B: MoeLlmBackend + BackendInt8KvOps> DecoderOnlyLLM for LlamaFamilyModel<B, KvInt8> {
2226 fn config(&self) -> &LlmRuntimeConfig {
2227 &self.runtime_cfg
2228 }
2229
2230 fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
2231 self.ensure_scratch(max_tokens);
2232 self.ensure_kv(cache_id);
2233
2234 const WARMUP_CACHE: &str = "__ferrum_warmup__";
2235 let _ = self.prefill_internal(WARMUP_CACHE, &[0u32]);
2236 if let Some(mut caches) = self.kv_caches.remove(WARMUP_CACHE) {
2237 if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2238 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2239 if let Some(c0) = caches.first() {
2240 if !c0.paged_block_indices.is_empty() {
2241 alloc.free(&c0.paged_block_indices);
2242 }
2243 }
2244 for c in caches.iter_mut() {
2245 c.paged_block_indices.clear();
2246 }
2247 }
2248 self.kv_free_pool.push(caches);
2249 }
2250 }
2251
2252 fn kv_capacity(&self) -> usize {
2253 llama_family_runtime_env().kv_capacity_for_model(self.cfg.max_seq_len)
2254 }
2255
2256 fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2257 self.prefill_internal(cache_id, tokens)
2258 }
2259
2260 fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
2261 self.decode_internal(cache_id, token, pos)
2262 }
2263
2264 fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2267 LlamaFamilyModel::<B, KvInt8>::forward_verify(self, cache_id, tokens)
2268 }
2269
2270 fn truncate_kv(&mut self, cache_id: &str, new_len: usize) {
2271 if let Some(caches) = self.kv_caches.get_mut(cache_id) {
2272 for c in caches.iter_mut() {
2273 if new_len < c.len {
2274 c.len = new_len;
2275 }
2276 }
2277 }
2278 let mut ctx = B::new_context();
2279 B::reset_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY);
2280 self.graph_warmup = 0;
2281 self.graph_capture_failed = false;
2282 }
2283
2284 fn release(&mut self, cache_id: &str) {
2285 let mut ctx = B::new_context();
2286 B::sync(&mut ctx);
2287 self.graph_warmup = 0;
2288 self.graph_capture_failed = false;
2289 if let Some(mut caches) = self.kv_caches.remove(cache_id) {
2290 if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2291 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2292 if let Some(c0) = caches.first() {
2293 if !c0.paged_block_indices.is_empty() {
2294 alloc.free(&c0.paged_block_indices);
2295 }
2296 }
2297 for c in caches.iter_mut() {
2298 c.paged_block_indices.clear();
2299 }
2300 }
2301 self.kv_free_pool.push(caches);
2302 }
2303 }
2304
2305 fn reset(&mut self) {
2306 let mut ctx = B::new_context();
2307 B::sync(&mut ctx);
2308 B::reset_all_graphs(&mut ctx);
2309 B::sync(&mut ctx);
2310 self.graph_warmup = 0;
2311 self.graph_capture_failed = false;
2312 self.kv_caches.clear();
2313 self.kv_free_pool.clear();
2314 }
2315}
2316
2317fn build_rope_cache<B: QuantLlmBackend + BackendMoeFused>(cfg: &LlamaFamilyConfig) -> RopeCache<B> {
2318 let hd = cfg.head_dim;
2319 let half = hd / 2;
2320 let max = cfg.max_seq_len;
2321 let mut cos = vec![0.0f32; max * half];
2322 let mut sin = vec![0.0f32; max * half];
2323 for pos in 0..max {
2324 for i in 0..half {
2325 let freq = rope_freq(cfg, i);
2326 let angle = pos as f64 * freq;
2327 cos[pos * half + i] = angle.cos() as f32;
2328 sin[pos * half + i] = angle.sin() as f32;
2329 }
2330 }
2331 RopeCache {
2332 cos: B::from_slice(&cos),
2333 sin: B::from_slice(&sin),
2334 }
2335}
2336
2337fn rope_freq(cfg: &LlamaFamilyConfig, pair_idx: usize) -> f64 {
2338 let base_freq = 1.0f64
2339 / cfg
2340 .rope_theta
2341 .powf((2 * pair_idx) as f64 / cfg.head_dim as f64);
2342 match &cfg.rope_scaling {
2343 Some(RopeScalingConfig::Llama3 {
2344 factor,
2345 low_freq_factor,
2346 high_freq_factor,
2347 original_max_position_embeddings,
2348 }) => scale_llama3_rope_freq(
2349 base_freq,
2350 *factor,
2351 *low_freq_factor,
2352 *high_freq_factor,
2353 *original_max_position_embeddings,
2354 ),
2355 None => base_freq,
2356 }
2357}
2358
/// Applies Llama 3 RoPE frequency scaling to a single frequency.
///
/// The wavelength is compared against the bounds derived from
/// `original_max_position_embeddings`:
/// - wavelengths shorter than `original / high_freq_factor` keep `freq`
///   unchanged;
/// - wavelengths longer than `original / low_freq_factor` are divided by
///   `factor`;
/// - wavelengths in between interpolate linearly between the two results.
fn scale_llama3_rope_freq(
    freq: f64,
    factor: f64,
    low_freq_factor: f64,
    high_freq_factor: f64,
    original_max_position_embeddings: f64,
) -> f64 {
    let wavelen = 2.0 * std::f64::consts::PI / freq;

    let high_freq_wavelen = original_max_position_embeddings / high_freq_factor;
    if wavelen < high_freq_wavelen {
        return freq;
    }

    let low_freq_wavelen = original_max_position_embeddings / low_freq_factor;
    if wavelen > low_freq_wavelen {
        return freq / factor;
    }

    // Blend weight: 0 at the low-frequency edge, 1 at the high-frequency edge.
    let smooth = (original_max_position_embeddings / wavelen - low_freq_factor)
        / (high_freq_factor - low_freq_factor);
    (1.0 - smooth) * freq / factor + smooth * freq
}
2379
#[cfg(test)]
mod tests {
    use super::{LlamaFamilyRuntimeEnv, DEFAULT_KV_CAPACITY};

    /// Pins how `from_env_vars` parses each startup knob.
    ///
    /// The profiling and graph flags are set to "0", "" and "false" here, yet
    /// each one is asserted `true`. These flags are therefore enabled by the
    /// variable's presence, not by its value. `FERRUM_METAL_PAGED_KV=0`, by
    /// contrast, parses to `Some(false)`.
    #[test]
    fn llama_family_runtime_env_parses_startup_knobs() {
        let env = LlamaFamilyRuntimeEnv::from_env_vars([
            ("FERRUM_KV_CAPACITY", "4096"),
            ("FERRUM_METAL_PAGED_KV", "0"),
            ("FERRUM_PAGED_MAX_SEQS", "64"),
            ("FERRUM_DECODE_OP_PROFILE", "0"),
            ("FERRUM_PREFILL_OP_PROFILE", ""),
            ("FERRUM_CUDA_GRAPH", ""),
            ("FERRUM_DECODE_LAYER_PROFILE", "false"),
        ]);

        assert_eq!(env.kv_capacity, Some(4096));
        assert_eq!(env.metal_paged_kv, Some(false));
        assert_eq!(env.paged_max_seqs, 64);
        // Presence-enabled flags: true regardless of the value given above.
        assert!(env.decode_op_profile);
        assert!(env.prefill_op_profile);
        assert!(env.cuda_graph);
        assert!(env.decode_layer_profile);
        // The explicit 4096 is capped to the model's 2048-position limit.
        assert_eq!(env.kv_capacity_for_model(2048), 2048);
    }

    /// Unparseable numeric values fall back to defaults instead of failing.
    ///
    /// The capacity falls back to "unset", so `kv_capacity_for_model` yields
    /// `DEFAULT_KV_CAPACITY` for a larger model. `paged_max_seqs` falls back to
    /// 32.
    #[test]
    fn llama_family_runtime_env_uses_defaults_for_invalid_values() {
        let env = LlamaFamilyRuntimeEnv::from_env_vars([
            ("FERRUM_KV_CAPACITY", "bad"),
            ("FERRUM_PAGED_MAX_SEQS", "bad"),
            ("FERRUM_METAL_PAGED_KV", "1"),
        ]);

        assert_eq!(env.kv_capacity, None);
        assert_eq!(env.metal_paged_kv, Some(true));
        assert_eq!(env.paged_max_seqs, 32);
        assert_eq!(
            env.kv_capacity_for_model(DEFAULT_KV_CAPACITY * 2),
            DEFAULT_KV_CAPACITY
        );
    }
}