1use std::collections::HashMap;
20use std::sync::{atomic::AtomicU64, OnceLock};
21
22use ferrum_interfaces::kv_dtype::{KvDtypeKind, KvFp16, KvInt8};
23use ferrum_kernels::backend::{
24 Backend, BackendGraph, BackendInt8KvOps, BackendMoeFused, BackendPagedKv, BackendQuantGguf,
25 BackendQuantMarlin, KvCache, KvLayer, LlmBackend, MoeLlmBackend, QuantLlmBackend,
26 MAX_LAYERS_FOR_GRAPH,
27};
28
/// Graph key with the fixed value `0`.
/// NOTE(review): by its name this is the key used for the single-item graph;
/// the lookup site is outside this view — confirm.
pub(crate) const SINGLE_ITEM_GRAPH_KEY: u64 = 0;

/// Process-wide counter, initialised to zero.
/// NOTE(review): presumably counts batched-graph replays; the increment site
/// is outside this view — confirm.
pub(crate) static BATCHED_GRAPH_REPLAY_COUNT: AtomicU64 = AtomicU64::new(0);
/// Process-wide counter, initialised to zero.
/// NOTE(review): presumably counts batched steps that ran eagerly instead of
/// through a graph — confirm against the increment site.
pub(crate) static BATCHED_GRAPH_EAGER_COUNT: AtomicU64 = AtomicU64::new(0);

// Per-category profiling accumulators, all starting at zero.
// NOTE(review): the `_TIME_US` / `_CALLS` suffixes suggest accumulated
// microseconds and call counts per op category — confirm where they are
// updated and read.
pub(crate) static ATTN_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static ATTN_CALLS: AtomicU64 = AtomicU64::new(0);
pub(crate) static QKR_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static QKR_CALLS: AtomicU64 = AtomicU64::new(0);
pub(crate) static MATMUL_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static MATMUL_CALLS: AtomicU64 = AtomicU64::new(0);
pub(crate) static NORM_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static NORM_CALLS: AtomicU64 = AtomicU64::new(0);
pub(crate) static OTHER_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static OTHER_CALLS: AtomicU64 = AtomicU64::new(0);
47use ferrum_quantization::{Linear, WeightLoader};
48use ferrum_types::Result;
49
50use crate::common::paged_pool::block_hash_chain;
51use crate::common::{DecoderOnlyLLM, LlmRuntimeConfig};
52use crate::lora::{load_runtime_lora_adapter, ActiveLoraAdapter, RuntimeLoraAdapter};
53
54const DEFAULT_KV_CAPACITY: usize = 512;
55
56#[derive(Debug, Clone, PartialEq, Eq)]
57struct LlamaFamilyRuntimeEnv {
58 kv_capacity: Option<usize>,
59 metal_paged_kv: Option<bool>,
60 paged_max_seqs: usize,
61 decode_op_profile: bool,
62 prefill_op_profile: bool,
63 prefix_cache: bool,
64 cuda_graph: bool,
65 decode_layer_profile: bool,
66}
67
68impl LlamaFamilyRuntimeEnv {
69 fn from_env() -> Self {
70 Self::from_env_vars(std::env::vars())
71 }
72
73 fn from_env_vars<I, K, V>(vars: I) -> Self
74 where
75 I: IntoIterator<Item = (K, V)>,
76 K: AsRef<str>,
77 V: AsRef<str>,
78 {
79 let mut config = Self {
80 kv_capacity: None,
81 metal_paged_kv: None,
82 paged_max_seqs: 32,
83 decode_op_profile: false,
84 prefill_op_profile: false,
85 prefix_cache: false,
86 cuda_graph: false,
87 decode_layer_profile: false,
88 };
89 for (name, value) in vars {
90 let value = value.as_ref();
91 match name.as_ref() {
92 "FERRUM_KV_CAPACITY" => config.kv_capacity = value.parse::<usize>().ok(),
93 "FERRUM_METAL_PAGED_KV" => config.metal_paged_kv = Some(value != "0"),
94 "FERRUM_PAGED_MAX_SEQS" => {
95 if let Ok(max_seqs) = value.parse::<usize>() {
96 config.paged_max_seqs = max_seqs;
97 }
98 }
99 "FERRUM_DECODE_OP_PROFILE" => config.decode_op_profile = true,
100 "FERRUM_PREFILL_OP_PROFILE" => config.prefill_op_profile = true,
101 "FERRUM_PREFIX_CACHE" => config.prefix_cache = value == "1",
102 "FERRUM_CUDA_GRAPH" => config.cuda_graph = true,
103 "FERRUM_DECODE_LAYER_PROFILE" => config.decode_layer_profile = true,
104 _ => {}
105 }
106 }
107 config
108 }
109
110 fn kv_capacity_for_model(&self, model_max: usize) -> usize {
111 self.kv_capacity
112 .map(|cap| cap.min(model_max))
113 .unwrap_or_else(|| model_max.min(DEFAULT_KV_CAPACITY))
114 }
115
116 fn paged_kv_enabled<B: BackendPagedKv>(&self) -> bool {
117 self.metal_paged_kv
118 .unwrap_or_else(|| B::supports_paged_kv())
119 }
120}
121
122fn llama_family_runtime_env() -> &'static LlamaFamilyRuntimeEnv {
123 static CONFIG: OnceLock<LlamaFamilyRuntimeEnv> = OnceLock::new();
124 CONFIG.get_or_init(LlamaFamilyRuntimeEnv::from_env)
125}
126
/// RoPE scaling settings attached to a model configuration.
#[derive(Clone, Debug, PartialEq)]
pub enum RopeScalingConfig {
    /// Scaling of type `"llama3"`; the four values correspond to the
    /// same-named keys of the model's `rope_scaling` object.
    Llama3 {
        factor: f64,
        low_freq_factor: f64,
        high_freq_factor: f64,
        original_max_position_embeddings: f64,
    },
}

impl RopeScalingConfig {
    /// Returns the `Llama3` variant filled with this crate's default values
    /// (factor 8, low/high frequency factors 1 and 4, original context 8192).
    pub fn llama3_default() -> Self {
        let factor = 8.0;
        let low_freq_factor = 1.0;
        let high_freq_factor = 4.0;
        let original_max_position_embeddings = 8192.0;
        RopeScalingConfig::Llama3 {
            factor,
            low_freq_factor,
            high_freq_factor,
            original_max_position_embeddings,
        }
    }
}
148
149#[derive(Clone, Debug, PartialEq)]
152pub struct LlamaFamilyConfig {
153 pub hidden_size: usize,
154 pub intermediate_size: usize,
155 pub num_heads: usize,
156 pub num_kv_heads: usize,
157 pub head_dim: usize,
158 pub num_layers: usize,
159 pub vocab_size: usize,
160 pub max_seq_len: usize,
161 pub rms_norm_eps: f32,
162 pub rope_theta: f64,
163 pub rope_scaling: Option<RopeScalingConfig>,
164 pub rope_interleaved: bool,
168 pub has_qk_norm: bool,
171 pub sliding_window: usize,
174}
175
176impl LlamaFamilyConfig {
177 pub fn to_runtime(&self) -> LlmRuntimeConfig {
178 LlmRuntimeConfig {
179 hidden_size: self.hidden_size,
180 num_layers: self.num_layers,
181 num_kv_heads: self.num_kv_heads,
182 head_dim: self.head_dim,
183 vocab_size: self.vocab_size,
184 max_seq_len: self.max_seq_len,
185 }
186 }
187
188 fn from_def_base(def: &crate::definition::ModelDefinition) -> LlamaFamilyConfigBase {
192 let num_kv_heads = def.num_key_value_heads.unwrap_or(def.num_attention_heads);
193 let head_dim = def
194 .extra_params
195 .get("head_dim")
196 .and_then(|v| v.as_u64())
197 .map(|v| v as usize)
198 .unwrap_or(def.hidden_size / def.num_attention_heads);
199 let sliding_window = def
202 .extra_params
203 .get("sliding_window")
204 .and_then(|v| v.as_u64())
205 .map(|v| v as usize)
206 .unwrap_or(0);
207
208 LlamaFamilyConfigBase {
209 hidden_size: def.hidden_size,
210 intermediate_size: def.intermediate_size,
211 num_heads: def.num_attention_heads,
212 num_kv_heads,
213 head_dim,
214 num_layers: def.num_hidden_layers,
215 vocab_size: def.vocab_size,
216 max_seq_len: def.max_position_embeddings,
217 rms_norm_eps: def.norm_eps as f32,
218 rope_theta_opt: def.rope_theta,
219 rope_scaling: rope_scaling_from_model_def(def),
220 sliding_window,
221 }
222 }
223
224 fn from_base(b: LlamaFamilyConfigBase, rope_default: f64, has_qk_norm: bool) -> Self {
225 Self {
226 hidden_size: b.hidden_size,
227 intermediate_size: b.intermediate_size,
228 num_heads: b.num_heads,
229 num_kv_heads: b.num_kv_heads,
230 head_dim: b.head_dim,
231 num_layers: b.num_layers,
232 vocab_size: b.vocab_size,
233 max_seq_len: b.max_seq_len,
234 rms_norm_eps: b.rms_norm_eps,
235 rope_theta: b.rope_theta_opt.unwrap_or(rope_default),
236 rope_scaling: b.rope_scaling,
237 rope_interleaved: false,
238 has_qk_norm,
239 sliding_window: b.sliding_window,
240 }
241 }
242
243 pub fn qwen3_from_def(def: &crate::definition::ModelDefinition) -> Self {
245 Self::from_base(Self::from_def_base(def), 1_000_000.0, true)
246 }
247
248 pub fn llama_from_def(def: &crate::definition::ModelDefinition) -> Self {
252 Self::from_base(Self::from_def_base(def), 500_000.0, false)
253 }
254
255 pub fn qwen2_from_def(def: &crate::definition::ModelDefinition) -> Self {
257 Self::from_base(Self::from_def_base(def), 1_000_000.0, false)
258 }
259
260 pub fn mistral_from_def(def: &crate::definition::ModelDefinition) -> Self {
264 Self::from_base(Self::from_def_base(def), 10_000.0, false)
265 }
266}
267
268struct LlamaFamilyConfigBase {
269 hidden_size: usize,
270 intermediate_size: usize,
271 num_heads: usize,
272 num_kv_heads: usize,
273 head_dim: usize,
274 num_layers: usize,
275 vocab_size: usize,
276 max_seq_len: usize,
277 rms_norm_eps: f32,
278 rope_theta_opt: Option<f64>,
279 rope_scaling: Option<RopeScalingConfig>,
280 sliding_window: usize,
281}
282
283fn rope_scaling_from_model_def(
284 def: &crate::definition::ModelDefinition,
285) -> Option<RopeScalingConfig> {
286 let value = def.extra_params.get("rope_scaling")?;
287 let obj = value.as_object()?;
288 let rope_type = obj
289 .get("rope_type")
290 .or_else(|| obj.get("type"))
291 .and_then(|v| v.as_str())?;
292 if rope_type != "llama3" {
293 return None;
294 }
295 let factor = json_f64(obj.get("factor"))?;
296 let low_freq_factor = json_f64(obj.get("low_freq_factor"))?;
297 let high_freq_factor = json_f64(obj.get("high_freq_factor"))?;
298 let original_max_position_embeddings = json_f64(obj.get("original_max_position_embeddings"))
299 .or_else(|| {
300 def.extra_params
301 .get("original_max_position_embeddings")
302 .and_then(|v| json_f64(Some(v)))
303 })
304 .unwrap_or(8192.0);
305 if factor <= 0.0
306 || low_freq_factor <= 0.0
307 || high_freq_factor <= low_freq_factor
308 || original_max_position_embeddings <= 0.0
309 {
310 return None;
311 }
312 Some(RopeScalingConfig::Llama3 {
313 factor,
314 low_freq_factor,
315 high_freq_factor,
316 original_max_position_embeddings,
317 })
318}
319
320fn json_f64(value: Option<&serde_json::Value>) -> Option<f64> {
321 match value? {
322 serde_json::Value::Number(n) => n.as_f64(),
323 _ => None,
324 }
325}
326
/// Weights of one decoder layer, as loaded by the model constructors from
/// tensors under `model.layers.<N>`.
pub struct LlamaFamilyLayer<B: QuantLlmBackend + BackendMoeFused> {
    /// `input_layernorm.weight`.
    pub input_ln_w: B::Buffer,
    /// `self_attn.qkv_proj` linear.
    pub qkv_proj: Box<dyn Linear<B>>,
    /// `self_attn.q_norm.weight`; only attempted when the config has
    /// `has_qk_norm`, and `None` if that load fails.
    pub q_norm_w: Option<B::Buffer>,
    /// `self_attn.k_norm.weight`; same rules as `q_norm_w`.
    pub k_norm_w: Option<B::Buffer>,
    /// `self_attn.o_proj` linear.
    pub o_proj: Box<dyn Linear<B>>,
    /// `post_attention_layernorm.weight`.
    pub post_ln_w: B::Buffer,
    /// `mlp.gate_up_proj` linear.
    pub gate_up_proj: Box<dyn Linear<B>>,
    /// `mlp.down_proj` linear.
    pub down_proj: Box<dyn Linear<B>>,
}
340
/// Pair of backend buffers holding the rotary-embedding tables.
/// NOTE(review): presumably precomputed cosine and sine values per position;
/// the builder (`build_rope_cache`) is outside this view — confirm layout.
pub struct RopeCache<B: QuantLlmBackend + BackendMoeFused> {
    pub cos: B::Buffer,
    pub sin: B::Buffer,
}
346
/// Reusable working buffers for the forward pass.
///
/// Sizes quoted below are the element counts passed to the backend allocator
/// in `alloc`, with `T = max_tokens`, `H = hidden_size`,
/// `Q = num_heads * head_dim` and `KV = num_kv_heads * head_dim`.
pub struct LlamaFamilyScratch<B: QuantLlmBackend + BackendMoeFused> {
    /// `T * H`; always `Some` after `alloc`.
    pub residual: Option<B::Buffer>,
    /// `T * H`.
    pub norm_out: B::Buffer,
    /// `T * (Q + 2 * KV)`.
    pub qkv_out: B::Buffer,
    // Fixed-size variants (`Q` / `KV` elements, independent of `T`).
    pub q_single: B::Buffer,
    pub k_single: B::Buffer,
    pub v_single: B::Buffer,
    pub q_head_major_single: B::Buffer,
    pub k_head_major_single: B::Buffer,
    pub v_head_major_single: B::Buffer,
    pub attn_head_major_single: B::Buffer,
    pub attn_flat_single: B::Buffer,
    /// `T * vocab_size`.
    pub batch_logits: B::Buffer,
    // Per-token Q/K/V buffers: `T * Q` for q, `T * KV` for k and v.
    pub q_buf: B::Buffer,
    pub k_buf: B::Buffer,
    pub v_buf: B::Buffer,
    // Same element counts as above (num_heads * T * head_dim etc.).
    pub q_head_major: B::Buffer,
    pub k_head_major: B::Buffer,
    pub v_head_major: B::Buffer,
    pub attn_head_major_out: B::Buffer,
    /// `T * Q`.
    pub attn_flat: B::Buffer,
    /// `T * H`.
    pub o_proj_out: B::Buffer,
    /// `T * 2 * intermediate_size`.
    pub gate_up_out: B::Buffer,
    /// `T * intermediate_size`.
    pub silu_out: B::Buffer,
    /// `T * H`.
    pub mlp_out: B::Buffer,
    // Paged-batch buffers: `None` until `enable_paged_batch` runs, which
    // sizes them from `max_seqs` and `max_blocks_per_seq`.
    pub paged_batch_q: Option<B::Buffer>,
    pub paged_batch_o: Option<B::Buffer>,
    pub paged_batch_block_tables: Option<B::Buffer>,
    pub paged_batch_context_lens: Option<B::Buffer>,
    /// Value given to `enable_paged_batch`; 0 before that.
    pub paged_max_blocks_per_seq: usize,
    /// Value given to `enable_paged_batch`; 0 before that.
    pub paged_max_seqs: usize,
    // U32 buffers with `max(T, 1)` elements each.
    pub batch_positions: B::Buffer,
    pub batch_tokens: B::Buffer,
    pub batch_kv_lens_pre: B::Buffer,
    pub batch_kv_lens_post: B::Buffer,
    // `T * Q` for q, `T * KV` for k and v.
    pub q_normed_batched: B::Buffer,
    pub k_normed_batched: B::Buffer,
    pub v_normed_batched: B::Buffer,

    // "Unified" buffers: all `None` / capacity 0 after `alloc`; created and
    // grown on demand by `ensure_unified_scratch`.
    pub unified_capacity: usize,
    pub unified_residual: Option<B::Buffer>,
    pub unified_norm_out: Option<B::Buffer>,
    pub unified_qkv_out: Option<B::Buffer>,
    pub unified_packed_q: Option<B::Buffer>,
    pub unified_attn_out: Option<B::Buffer>,
    pub unified_o_proj_out: Option<B::Buffer>,
    pub unified_gate_up_out: Option<B::Buffer>,
    pub unified_silu_out: Option<B::Buffer>,
    pub unified_mlp_out: Option<B::Buffer>,
    pub unified_cu_seqlens_q: Option<B::Buffer>,
    pub unified_pos_offsets: Option<B::Buffer>,
    pub unified_block_tables: Option<B::Buffer>,
    pub unified_packed_normed: Option<B::Buffer>,
    pub unified_packed_logits: Option<B::Buffer>,
    /// `H` elements.
    pub last_hidden: B::Buffer,
    /// `H` elements.
    pub last_normed: B::Buffer,
    /// `vocab_size` elements.
    pub logits: B::Buffer,
    /// The `max_tokens` value the `T`-sized buffers were allocated for.
    pub max_tokens: usize,
}
488
489impl<B: QuantLlmBackend + BackendMoeFused> LlamaFamilyScratch<B> {
490 fn alloc(cfg: &LlamaFamilyConfig, max_tokens: usize) -> Self {
491 let h = cfg.hidden_size;
492 let im = cfg.intermediate_size;
493 let q_dim = cfg.num_heads * cfg.head_dim;
494 let kv_dim = cfg.num_kv_heads * cfg.head_dim;
495 let qkv_dim = q_dim + 2 * kv_dim;
496 let t = max_tokens;
497 Self {
498 residual: Some(B::alloc(t * h)),
499 norm_out: B::alloc(t * h),
500 qkv_out: B::alloc(t * qkv_dim),
501 q_buf: B::alloc(t * q_dim),
502 k_buf: B::alloc(t * kv_dim),
503 v_buf: B::alloc(t * kv_dim),
504 q_head_major: B::alloc(cfg.num_heads * t * cfg.head_dim),
505 k_head_major: B::alloc(cfg.num_kv_heads * t * cfg.head_dim),
506 v_head_major: B::alloc(cfg.num_kv_heads * t * cfg.head_dim),
507 attn_head_major_out: B::alloc(cfg.num_heads * t * cfg.head_dim),
508 attn_flat: B::alloc(t * q_dim),
509 o_proj_out: B::alloc(t * h),
510 gate_up_out: B::alloc(t * 2 * im),
511 silu_out: B::alloc(t * im),
512 mlp_out: B::alloc(t * h),
513 last_hidden: B::alloc(h),
514 last_normed: B::alloc(h),
515 logits: B::alloc(cfg.vocab_size),
516 q_single: B::alloc(q_dim),
517 k_single: B::alloc(kv_dim),
518 v_single: B::alloc(kv_dim),
519 q_head_major_single: B::alloc(q_dim),
520 k_head_major_single: B::alloc(kv_dim),
521 v_head_major_single: B::alloc(kv_dim),
522 attn_head_major_single: B::alloc(q_dim),
523 attn_flat_single: B::alloc(q_dim),
524 batch_logits: B::alloc(t * cfg.vocab_size),
525 paged_batch_q: None,
530 paged_batch_o: None,
531 paged_batch_block_tables: None,
532 paged_batch_context_lens: None,
533 paged_max_blocks_per_seq: 0,
534 paged_max_seqs: 0,
535 batch_positions: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
536 batch_tokens: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
537 batch_kv_lens_pre: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
538 batch_kv_lens_post: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
539 q_normed_batched: B::alloc(t * q_dim),
540 k_normed_batched: B::alloc(t * kv_dim),
541 v_normed_batched: B::alloc(t * kv_dim),
542 unified_capacity: 0,
543 unified_residual: None,
544 unified_norm_out: None,
545 unified_qkv_out: None,
546 unified_packed_q: None,
547 unified_attn_out: None,
548 unified_o_proj_out: None,
549 unified_gate_up_out: None,
550 unified_silu_out: None,
551 unified_mlp_out: None,
552 unified_cu_seqlens_q: None,
553 unified_pos_offsets: None,
554 unified_block_tables: None,
555 unified_packed_normed: None,
556 unified_packed_logits: None,
557 max_tokens: t,
558 }
559 }
560
561 pub(crate) fn ensure_unified_scratch(
565 &mut self,
566 cfg: &LlamaFamilyConfig,
567 m_total: usize,
568 max_seqs: usize,
569 max_blocks_per_seq: usize,
570 ) {
571 if m_total <= self.unified_capacity
572 && self.unified_residual.is_some()
573 && self.unified_cu_seqlens_q.is_some()
574 {
575 return;
576 }
577 let cap = m_total.max(self.unified_capacity).max(1);
578 let h = cfg.hidden_size;
579 let q_dim = cfg.num_heads * cfg.head_dim;
580 let kv_dim = cfg.num_kv_heads * cfg.head_dim;
581 let qkv_dim = q_dim + 2 * kv_dim;
582 let im = cfg.intermediate_size;
583 let v = cfg.vocab_size;
584 self.unified_residual = Some(B::alloc(cap * h));
585 self.unified_norm_out = Some(B::alloc(cap * h));
586 self.unified_qkv_out = Some(B::alloc(cap * qkv_dim));
587 self.unified_packed_q = Some(B::alloc(cap * q_dim));
588 self.unified_attn_out = Some(B::alloc(cap * q_dim));
589 self.unified_o_proj_out = Some(B::alloc(cap * h));
590 self.unified_gate_up_out = Some(B::alloc(cap * 2 * im));
591 self.unified_silu_out = Some(B::alloc(cap * im));
592 self.unified_mlp_out = Some(B::alloc(cap * h));
593 if self.unified_cu_seqlens_q.is_none() {
594 self.unified_cu_seqlens_q = Some(B::alloc_typed(
595 ferrum_kernels::backend::Dtype::U32,
596 max_seqs + 1,
597 ));
598 self.unified_pos_offsets = Some(B::alloc_typed(
599 ferrum_kernels::backend::Dtype::U32,
600 max_seqs,
601 ));
602 self.unified_block_tables = Some(B::alloc_typed(
603 ferrum_kernels::backend::Dtype::U32,
604 max_seqs * max_blocks_per_seq,
605 ));
606 self.unified_packed_normed = Some(B::alloc(max_seqs * h));
607 self.unified_packed_logits = Some(B::alloc(max_seqs * v));
608 }
609 self.unified_capacity = cap;
610 }
611
612 fn enable_paged_batch(
616 &mut self,
617 cfg: &LlamaFamilyConfig,
618 max_seqs: usize,
619 max_blocks_per_seq: usize,
620 ) {
621 if self.paged_batch_q.is_some() {
622 return;
623 }
624 let q_dim = cfg.num_heads * cfg.head_dim;
625 self.paged_batch_q = Some(B::alloc(max_seqs * q_dim));
626 self.paged_batch_o = Some(B::alloc(max_seqs * q_dim));
627 self.paged_batch_block_tables = Some(B::alloc_typed(
628 ferrum_kernels::backend::Dtype::U32,
629 max_seqs * max_blocks_per_seq,
630 ));
631 self.paged_batch_context_lens = Some(B::alloc_typed(
632 ferrum_kernels::backend::Dtype::U32,
633 max_seqs,
634 ));
635 self.paged_max_blocks_per_seq = max_blocks_per_seq;
636 self.paged_max_seqs = max_seqs;
637 }
638}
639
/// A loaded Llama-family decoder together with its runtime state: weights,
/// scratch buffers, per-request KV caches, graph bookkeeping, prefix-cache
/// counters and LoRA adapters. `K` selects the KV layer type and defaults to
/// `KvFp16`.
pub struct LlamaFamilyModel<B: MoeLlmBackend, K: KvLayer<B> = KvFp16> {
    pub cfg: LlamaFamilyConfig,
    /// `cfg.to_runtime()`, computed once at construction.
    pub runtime_cfg: LlmRuntimeConfig,

    /// `model.embed_tokens.weight`; `None` for backbone-only models.
    pub embed: Option<B::Buffer>,
    pub layers: Vec<LlamaFamilyLayer<B>>,
    /// `model.norm.weight`.
    pub final_norm_w: B::Buffer,
    /// Output projection; `None` for backbone-only models.
    pub lm_head: Option<Box<dyn Linear<B>>>,

    pub rope: RopeCache<B>,
    pub scratch: LlamaFamilyScratch<B>,

    /// Per-layer KV caches, keyed by cache id (see `ensure_kv`).
    pub kv_caches: HashMap<String, Vec<K::Layer>>,
    /// Spare per-layer cache sets; `ensure_kv` takes one from here before
    /// allocating new ones.
    kv_free_pool: Vec<Vec<K::Layer>>,

    /// One buffer pair per layer, allocated on the first paged `ensure_kv`.
    /// NOTE(review): presumably the key pool and value pool — confirm.
    pub paged_pools: Option<Vec<(B::Buffer, B::Buffer)>>,
    /// Block allocator created together with `paged_pools`.
    pub paged_block_alloc: Option<std::sync::Mutex<crate::common::paged_pool::BlockAllocator>>,
    /// `(max_seqs, max_blocks_per_seq)` used for the paged-batch scratch;
    /// lets `ensure_scratch` re-enable it after re-allocating scratch.
    pub paged_dims: Option<(usize, usize)>,

    // Graph bookkeeping. `ensure_scratch` resets every field in this group
    // except `batched_pointers_for` whenever scratch is re-allocated.
    pub(crate) graph_warmup: usize,
    pub(crate) graph_capture_failed: bool,
    pub(crate) batched_graph_warmup: usize,
    pub(crate) batched_graph_failed: bool,
    pub(crate) batched_graph_keys_seen: std::collections::HashSet<u64>,
    pub(crate) batched_pointers_for: Option<Vec<String>>,
    pub(crate) unified_graph_warmup: usize,
    pub(crate) unified_graph_failed: bool,
    pub(crate) unified_graph_keys_seen: std::collections::HashSet<u64>,

    /// Probes that reused at least one cached token.
    prefix_cache_hits: u64,
    /// Probes that reused nothing.
    prefix_cache_misses: u64,
    /// Total tokens reported as saved across all hits.
    prefix_cache_saved_prefill_tokens: u64,

    /// Loaded and validated LoRA adapters, keyed by adapter name.
    lora_adapters: HashMap<String, RuntimeLoraAdapter<B>>,
    /// Maps a cache id to the name of its adapter.
    lora_cache_adapters: HashMap<String, String>,
    /// NOTE(review): presumably counts LoRA projection applications; the
    /// update site is outside this view — confirm.
    lora_projection_applications: u64,
}
746
747impl<B: MoeLlmBackend, K: KvLayer<B>> LlamaFamilyModel<B, K> {
748 pub fn new(cfg: LlamaFamilyConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
753 {
758 let mut ctx = B::new_context();
759 B::reset_all_graphs(&mut ctx);
760 }
761 let rope = build_rope_cache::<B>(&cfg);
762 let scratch = LlamaFamilyScratch::alloc(&cfg, 1); let embed = loader.load_tensor("model.embed_tokens.weight")?;
766
767 let mut layers = Vec::with_capacity(cfg.num_layers);
769 for li in 0..cfg.num_layers {
770 let prefix = format!("model.layers.{li}");
771 let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
772 let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
773 let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
774 let post_ln_w =
775 loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
776 let gate_up_proj = loader.load_linear(&format!("{prefix}.mlp.gate_up_proj"))?;
777 let down_proj = loader.load_linear(&format!("{prefix}.mlp.down_proj"))?;
778
779 let (q_norm_w, k_norm_w) = if cfg.has_qk_norm {
780 let q = loader
781 .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
782 .ok();
783 let k = loader
784 .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
785 .ok();
786 (q, k)
787 } else {
788 (None, None)
789 };
790
791 layers.push(LlamaFamilyLayer {
792 input_ln_w,
793 qkv_proj,
794 q_norm_w,
795 k_norm_w,
796 o_proj,
797 post_ln_w,
798 gate_up_proj,
799 down_proj,
800 });
801 }
802
803 let final_norm_w = loader.load_tensor("model.norm.weight")?;
804
805 let lm_head = if loader.has_tensor("lm_head.weight") {
813 loader.load_linear("lm_head")?
814 } else {
815 tracing::info!(
816 "LlamaFamilyModel: tied embeddings — loading model.embed_tokens.weight as lm_head"
817 );
818 let as_linear = loader.load_linear("model.embed_tokens")?;
819 if as_linear.out_features() != cfg.vocab_size
821 || as_linear.in_features() != cfg.hidden_size
822 {
823 return Err(ferrum_types::FerrumError::model(format!(
824 "tied embed shape mismatch: got [{}, {}], expected [{}, {}]",
825 as_linear.out_features(),
826 as_linear.in_features(),
827 cfg.vocab_size,
828 cfg.hidden_size
829 )));
830 }
831 as_linear
832 };
833
834 let runtime_cfg = cfg.to_runtime();
835 Ok(Self {
836 cfg,
837 runtime_cfg,
838 embed: Some(embed),
839 layers,
840 final_norm_w,
841 lm_head: Some(lm_head),
842 rope,
843 scratch,
844 kv_caches: HashMap::new(),
845 kv_free_pool: Vec::new(),
846 paged_pools: None,
847 paged_block_alloc: None,
848 paged_dims: None,
849 graph_warmup: 0,
850 graph_capture_failed: false,
851 batched_graph_warmup: 0,
852 batched_graph_failed: false,
853 batched_graph_keys_seen: std::collections::HashSet::new(),
854 batched_pointers_for: None,
855 unified_graph_warmup: 0,
856 unified_graph_failed: false,
857 unified_graph_keys_seen: std::collections::HashSet::new(),
858 prefix_cache_hits: 0,
859 prefix_cache_misses: 0,
860 prefix_cache_saved_prefill_tokens: 0,
861 lora_adapters: HashMap::new(),
862 lora_cache_adapters: HashMap::new(),
863 lora_projection_applications: 0,
864 })
865 }
866
867 pub fn new_backbone_only(cfg: LlamaFamilyConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
879 {
881 let mut ctx = B::new_context();
882 B::reset_all_graphs(&mut ctx);
883 }
884 let rope = build_rope_cache::<B>(&cfg);
885 let scratch = LlamaFamilyScratch::alloc(&cfg, 1);
886
887 let mut layers = Vec::with_capacity(cfg.num_layers);
888 for li in 0..cfg.num_layers {
889 let prefix = format!("model.layers.{li}");
890 let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
891 let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
892 let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
893 let post_ln_w =
894 loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
895 let gate_up_proj = loader.load_linear(&format!("{prefix}.mlp.gate_up_proj"))?;
896 let down_proj = loader.load_linear(&format!("{prefix}.mlp.down_proj"))?;
897
898 let (q_norm_w, k_norm_w) = if cfg.has_qk_norm {
899 let q = loader
900 .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
901 .ok();
902 let k = loader
903 .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
904 .ok();
905 (q, k)
906 } else {
907 (None, None)
908 };
909
910 layers.push(LlamaFamilyLayer {
911 input_ln_w,
912 qkv_proj,
913 q_norm_w,
914 k_norm_w,
915 o_proj,
916 post_ln_w,
917 gate_up_proj,
918 down_proj,
919 });
920 }
921
922 let final_norm_w = loader.load_tensor("model.norm.weight")?;
923
924 let runtime_cfg = cfg.to_runtime();
925 Ok(Self {
926 cfg,
927 runtime_cfg,
928 embed: None,
929 layers,
930 final_norm_w,
931 lm_head: None,
932 rope,
933 scratch,
934 kv_caches: HashMap::new(),
935 kv_free_pool: Vec::new(),
936 paged_pools: None,
937 paged_block_alloc: None,
938 paged_dims: None,
939 graph_warmup: 0,
940 graph_capture_failed: false,
941 batched_graph_warmup: 0,
942 batched_graph_failed: false,
943 batched_graph_keys_seen: std::collections::HashSet::new(),
944 batched_pointers_for: None,
945 unified_graph_warmup: 0,
946 unified_graph_failed: false,
947 unified_graph_keys_seen: std::collections::HashSet::new(),
948 prefix_cache_hits: 0,
949 prefix_cache_misses: 0,
950 prefix_cache_saved_prefill_tokens: 0,
951 lora_adapters: HashMap::new(),
952 lora_cache_adapters: HashMap::new(),
953 lora_projection_applications: 0,
954 })
955 }
956
957 pub(crate) fn ensure_scratch(&mut self, tokens: usize) {
959 if self.scratch.max_tokens < tokens {
960 {
965 let mut ctx = B::new_context();
966 B::reset_all_graphs(&mut ctx);
967 }
968 self.scratch = LlamaFamilyScratch::alloc(&self.cfg, tokens);
969 self.graph_warmup = 0;
970 self.graph_capture_failed = false;
971 self.batched_graph_keys_seen.clear();
972 self.batched_graph_warmup = 0;
973 self.batched_graph_failed = false;
974 self.unified_graph_keys_seen.clear();
975 self.unified_graph_warmup = 0;
976 self.unified_graph_failed = false;
977 if let Some((max_seqs, max_blocks_per_seq)) = self.paged_dims {
982 self.scratch
983 .enable_paged_batch(&self.cfg, max_seqs, max_blocks_per_seq);
984 }
985 }
986 }
987
988 pub(crate) fn ensure_kv(&mut self, cache_id: &str) {
992 if self.kv_caches.contains_key(cache_id) {
993 return;
994 }
995 let nkv = self.cfg.num_kv_heads;
996 let hd = self.cfg.head_dim;
997 let model_max = self.cfg.max_seq_len;
1004 let runtime_env = llama_family_runtime_env();
1013 let max = runtime_env.kv_capacity_for_model(model_max);
1014
1015 let paged = runtime_env.paged_kv_enabled::<B>();
1031 const PAGED_BLOCK_SIZE: usize = 16;
1032
1033 let max_seqs = runtime_env.paged_max_seqs;
1041 let max_blocks_per_seq = max.div_ceil(PAGED_BLOCK_SIZE);
1042 let total_pool_blocks = max_seqs * max_blocks_per_seq;
1043
1044 if paged && self.paged_pools.is_none() {
1051 let mut pools = Vec::with_capacity(self.cfg.num_layers);
1052 for _ in 0..self.cfg.num_layers {
1053 let pool_floats = total_pool_blocks * nkv * PAGED_BLOCK_SIZE * hd;
1054 pools.push((B::alloc(pool_floats), B::alloc(pool_floats)));
1055 }
1056 self.paged_pools = Some(pools);
1057 self.paged_block_alloc = Some(std::sync::Mutex::new(
1058 crate::common::paged_pool::BlockAllocator::new(total_pool_blocks as u32),
1059 ));
1060 }
1061 if paged {
1067 self.scratch
1068 .enable_paged_batch(&self.cfg, max_seqs, max_blocks_per_seq);
1069 self.paged_dims = Some((max_seqs, max_blocks_per_seq));
1072 }
1073
1074 let mut caches = self.kv_free_pool.pop().unwrap_or_else(|| {
1082 (0..self.cfg.num_layers)
1083 .map(|_| {
1084 if paged {
1085 K::alloc_paged(max_blocks_per_seq, PAGED_BLOCK_SIZE, nkv, hd)
1086 } else {
1087 K::alloc_contig(max, nkv, hd)
1088 }
1089 })
1090 .collect()
1091 });
1092
1093 if paged {
1099 let alloc_arc = self
1100 .paged_block_alloc
1101 .as_ref()
1102 .expect("paged_block_alloc must be initialised when paged=true");
1103 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
1107 let block_indices = match alloc.allocate_n(max_blocks_per_seq) {
1108 Ok(idx) => idx,
1109 Err(e) => {
1110 drop(alloc);
1117 self.kv_free_pool.push(caches);
1118 eprintln!(
1119 "[ferrum] paged KV pool exhausted on ensure_kv for \
1120 cache_id={cache_id:?}: {e}. Increase \
1121 FERRUM_PAGED_MAX_SEQS (currently {max_seqs}) or \
1122 throttle concurrent requests.",
1123 );
1124 return;
1125 }
1126 };
1127 let mut padded = block_indices.clone();
1132 padded.resize(max_blocks_per_seq, 0);
1133 let mut ctx_tmp = B::new_context();
1134 for c in caches.iter_mut() {
1135 if let Some(bt) = K::block_table_mut(c) {
1136 B::write_typed::<u32>(&mut ctx_tmp, bt, &padded);
1137 }
1138 *K::paged_block_indices_mut(c) = block_indices.clone();
1139 }
1140 B::sync(&mut ctx_tmp);
1141 }
1142
1143 for c in caches.iter_mut() {
1147 K::set_len(c, 0);
1148 if let Some(cl) = K::context_lens_mut(c) {
1149 let mut ctx_tmp = B::new_context();
1150 B::write_typed::<u32>(&mut ctx_tmp, cl, &[0u32]);
1151 B::sync(&mut ctx_tmp);
1152 }
1153 }
1154 self.kv_caches.insert(cache_id.to_string(), caches);
1155 }
1156
1157 fn record_prefix_cache_probe(&mut self, saved_tokens: usize) {
1158 if saved_tokens > 0 {
1159 self.prefix_cache_hits += 1;
1160 self.prefix_cache_saved_prefill_tokens += saved_tokens as u64;
1161 } else {
1162 self.prefix_cache_misses += 1;
1163 }
1164 }
1165
1166 fn try_acquire_prefix_cache(&mut self, cache_id: &str, tokens: &[u32]) -> usize {
1167 let Some(alloc_arc) = self.paged_block_alloc.as_ref() else {
1168 return 0;
1169 };
1170 let caches = match self.kv_caches.get(cache_id) {
1171 Some(caches) => caches,
1172 None => return 0,
1173 };
1174 let block_size = caches.first().map(K::block_size).unwrap_or(0);
1175 if block_size == 0 {
1176 return 0;
1177 }
1178
1179 let token_ids: Vec<ferrum_types::TokenId> = tokens
1180 .iter()
1181 .map(|&token| ferrum_types::TokenId::new(token))
1182 .collect();
1183 let hashes = block_hash_chain(&token_ids, block_size);
1184 if hashes.is_empty() {
1185 return 0;
1186 }
1187
1188 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
1189 let mut matched = Vec::with_capacity(hashes.len());
1190 for hash in hashes {
1191 match alloc.try_acquire_by_hash(hash) {
1192 Some(block) => matched.push(block),
1193 None => break,
1194 }
1195 }
1196 if matched.is_empty() {
1197 return 0;
1198 }
1199 let n_matched = matched.len();
1200
1201 let displaced = caches
1202 .first()
1203 .map(|cache| K::paged_block_indices(cache)[..n_matched].to_vec())
1204 .unwrap_or_default();
1205 alloc.free(&displaced);
1206 drop(alloc);
1207
1208 let caches_mut = self.kv_caches.get_mut(cache_id).expect("cache present");
1209 let max_blocks = caches_mut
1210 .first()
1211 .map(|cache| K::paged_block_indices(cache).len())
1212 .unwrap_or(0);
1213 let new_len = n_matched * block_size;
1214 let mut ctx = B::new_context();
1215 for cache in caches_mut.iter_mut() {
1216 {
1217 let indices = K::paged_block_indices_mut(cache);
1218 for (idx, &block) in matched.iter().enumerate() {
1219 indices[idx] = block;
1220 }
1221 }
1222 K::set_len(cache, new_len);
1223 let padded = {
1224 let mut padded = K::paged_block_indices(cache).to_vec();
1225 padded.resize(max_blocks, 0);
1226 padded
1227 };
1228 if let Some(block_table) = K::block_table_mut(cache) {
1229 B::write_typed::<u32>(&mut ctx, block_table, &padded);
1230 }
1231 if let Some(context_lens) = K::context_lens_mut(cache) {
1232 B::write_typed::<u32>(&mut ctx, context_lens, &[new_len as u32]);
1233 }
1234 }
1235 B::sync(&mut ctx);
1236
1237 new_len
1238 }
1239
1240 fn register_prefix_cache(
1241 &mut self,
1242 cache_id: &str,
1243 all_tokens: &[u32],
1244 prior_cached_tokens: usize,
1245 ) {
1246 let Some(alloc_arc) = self.paged_block_alloc.as_ref() else {
1247 return;
1248 };
1249 let caches = match self.kv_caches.get(cache_id) {
1250 Some(caches) => caches,
1251 None => return,
1252 };
1253 let cache0 = match caches.first() {
1254 Some(cache) => cache,
1255 None => return,
1256 };
1257 let block_size = K::block_size(cache0);
1258 if block_size == 0 {
1259 return;
1260 }
1261
1262 let token_ids: Vec<ferrum_types::TokenId> = all_tokens
1263 .iter()
1264 .map(|&token| ferrum_types::TokenId::new(token))
1265 .collect();
1266 let hashes = block_hash_chain(&token_ids, block_size);
1267 if hashes.is_empty() {
1268 return;
1269 }
1270
1271 let start_block = prior_cached_tokens / block_size;
1272 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
1273 for i in start_block..hashes.len().min(K::paged_block_indices(cache0).len()) {
1274 let block_end_token = (i + 1) * block_size;
1275 if block_end_token > K::len(cache0) {
1276 break;
1277 }
1278 alloc.register_block_hash(K::paged_block_indices(cache0)[i], hashes[i]);
1279 }
1280 }
1281
1282 fn prefix_cache_snapshot_json(&self) -> serde_json::Value {
1283 let (entries, block_size) = self
1284 .paged_block_alloc
1285 .as_ref()
1286 .and_then(|alloc| {
1287 let alloc = alloc.lock().ok()?;
1288 let block_size = self
1289 .kv_caches
1290 .values()
1291 .find_map(|layers| layers.first().map(K::block_size))
1292 .unwrap_or(16);
1293 Some((alloc.hash_table_size() as u64, block_size))
1294 })
1295 .unwrap_or((0, 16));
1296 let bytes_per_entry = (block_size
1297 * self.cfg.num_layers
1298 * self.cfg.num_kv_heads
1299 * self.cfg.head_dim
1300 * K::BYTES_PER_ELEM
1301 * 2) as u64;
1302 serde_json::json!({
1303 "position": "real-kv-reuse",
1304 "source": "llama-family-paged-block-prefix-cache",
1305 "enabled": llama_family_runtime_env().prefix_cache,
1306 "hits": self.prefix_cache_hits,
1307 "misses": self.prefix_cache_misses,
1308 "evictions": 0u64,
1309 "saved_prefill_tokens": self.prefix_cache_saved_prefill_tokens,
1310 "entries": entries,
1311 "bytes": entries.saturating_mul(bytes_per_entry),
1312 "block_size": block_size,
1313 "kv_dtype": K::NAME,
1314 })
1315 }
1316
1317 fn lora_projection_shape(
1318 &self,
1319 layer_index: usize,
1320 target_module: &str,
1321 ) -> Option<(usize, usize)> {
1322 let layer = self.layers.get(layer_index)?;
1323 match target_module {
1324 "qkv_proj" => Some((layer.qkv_proj.in_features(), layer.qkv_proj.out_features())),
1325 "o_proj" => Some((layer.o_proj.in_features(), layer.o_proj.out_features())),
1326 "gate_up_proj" => Some((
1327 layer.gate_up_proj.in_features(),
1328 layer.gate_up_proj.out_features(),
1329 )),
1330 "down_proj" => Some((
1331 layer.down_proj.in_features(),
1332 layer.down_proj.out_features(),
1333 )),
1334 _ => None,
1335 }
1336 }
1337
1338 fn validate_lora_adapter(&self, adapter: &RuntimeLoraAdapter<B>) -> Result<()> {
1339 if adapter.linears.is_empty() {
1340 return Err(ferrum_types::FerrumError::config(format!(
1341 "LoRA adapter {} has no runtime tensors",
1342 adapter.name
1343 )));
1344 }
1345 for linear in &adapter.linears {
1346 let layer_index = linear.layer_index.ok_or_else(|| {
1347 ferrum_types::FerrumError::config(format!(
1348 "LoRA tensor for target {} must include model.layers.<N> in its tensor name",
1349 linear.target_module
1350 ))
1351 })?;
1352 let Some((expected_in, expected_out)) =
1353 self.lora_projection_shape(layer_index, &linear.target_module)
1354 else {
1355 return Err(ferrum_types::FerrumError::unsupported(format!(
1356 "LoRA target {} is not supported by Llama-family runtime; supported targets: qkv_proj, o_proj, gate_up_proj, down_proj",
1357 linear.target_module
1358 )));
1359 };
1360 if linear.in_features != expected_in || linear.out_features != expected_out {
1361 return Err(ferrum_types::FerrumError::config(format!(
1362 "LoRA tensor shape mismatch for layer {} target {}: got out={} in={}, expected out={} in={}",
1363 layer_index,
1364 linear.target_module,
1365 linear.out_features,
1366 linear.in_features,
1367 expected_out,
1368 expected_in
1369 )));
1370 }
1371 }
1372 Ok(())
1373 }
1374
1375 fn ensure_lora_adapter_loaded(&mut self, adapter: ActiveLoraAdapter) -> Result<()> {
1376 if self.lora_adapters.contains_key(&adapter.name) {
1377 return Ok(());
1378 }
1379 let runtime = load_runtime_lora_adapter::<B>(&adapter)?;
1380 self.validate_lora_adapter(&runtime)?;
1381 self.lora_adapters.insert(adapter.name.clone(), runtime);
1382 Ok(())
1383 }
1384
1385 fn active_lora_adapter_for_cache(&self, cache_id: &str) -> Option<&RuntimeLoraAdapter<B>> {
1386 let adapter_name = self.lora_cache_adapters.get(cache_id)?;
1387 self.lora_adapters.get(adapter_name)
1388 }
1389
1390 fn active_lora_adapter_ptr_for_cache(
1391 &self,
1392 cache_id: &str,
1393 ) -> Option<*const RuntimeLoraAdapter<B>> {
1394 self.active_lora_adapter_for_cache(cache_id)
1395 .map(|adapter| adapter as *const RuntimeLoraAdapter<B>)
1396 }
1397
1398 fn lora_metrics_snapshot_json(&self) -> serde_json::Value {
1399 serde_json::json!({
1400 "enabled": !self.lora_adapters.is_empty(),
1401 "adapter_count": self.lora_adapters.len() as u64,
1402 "active_cache_bindings": self.lora_cache_adapters.len() as u64,
1403 "projection_applications": self.lora_projection_applications,
1404 "position": "real-inference",
1405 "source": "llama-family-runtime-lora",
1406 })
1407 }
1408
1409 #[allow(clippy::too_many_arguments)]
1414 pub(crate) fn forward_layer(
1415 &mut self,
1416 ctx: &mut B::Context,
1417 li: usize,
1418 cache_id: &str,
1419 residual: &mut B::Buffer,
1420 pos_offset: usize,
1421 tokens: usize,
1422 ) {
1423 let layer = &self.layers[li];
1424 let cfg = &self.cfg;
1425 let h = cfg.hidden_size;
1426 let nh = cfg.num_heads;
1427 let nkv = cfg.num_kv_heads;
1428 let hd = cfg.head_dim;
1429 let im = cfg.intermediate_size;
1430 let eps = cfg.rms_norm_eps;
1431 let q_dim = nh * hd;
1432 let kv_dim = nkv * hd;
1433
1434 let _t0 = if llama_family_runtime_env().decode_op_profile {
1436 B::sync(ctx);
1437 Some(std::time::Instant::now())
1438 } else {
1439 None
1440 };
1441 B::rms_norm(
1442 ctx,
1443 residual,
1444 &layer.input_ln_w,
1445 eps,
1446 &mut self.scratch.norm_out,
1447 tokens,
1448 h,
1449 );
1450 if let Some(t0) = _t0 {
1451 B::sync(ctx);
1452 NORM_TIME_US.fetch_add(
1453 t0.elapsed().as_micros() as u64,
1454 std::sync::atomic::Ordering::Relaxed,
1455 );
1456 NORM_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1457 }
1458
1459 let _t0 = if llama_family_runtime_env().decode_op_profile {
1461 B::sync(ctx);
1462 Some(std::time::Instant::now())
1463 } else {
1464 None
1465 };
1466 layer.qkv_proj.forward(
1467 ctx,
1468 &self.scratch.norm_out,
1469 &mut self.scratch.qkv_out,
1470 tokens,
1471 );
1472 if let Some(adapter) = self.active_lora_adapter_ptr_for_cache(cache_id) {
1473 let applied = unsafe { &*adapter }
1476 .apply_projection(
1477 ctx,
1478 li,
1479 "qkv_proj",
1480 &self.scratch.norm_out,
1481 &mut self.scratch.qkv_out,
1482 tokens,
1483 )
1484 .expect("validated LoRA qkv_proj");
1485 self.lora_projection_applications += applied as u64;
1486 }
1487 if let Some(t0) = _t0 {
1488 B::sync(ctx);
1489 MATMUL_TIME_US.fetch_add(
1490 t0.elapsed().as_micros() as u64,
1491 std::sync::atomic::Ordering::Relaxed,
1492 );
1493 MATMUL_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1494 }
1495
1496 let qk_mode: i32 = if cfg.has_qk_norm {
1511 1
1512 } else if cfg.rope_interleaved {
1513 3
1514 } else {
1515 2
1516 };
1517 let dummy = &layer.input_ln_w;
1518 let q_norm_w = layer.q_norm_w.as_ref().unwrap_or(dummy);
1519 let k_norm_w = layer.k_norm_w.as_ref().unwrap_or(dummy);
1520
1521 let paged_pool_ptr: Option<(*mut B::Buffer, *mut B::Buffer)> =
1532 if let Some(pools) = self.paged_pools.as_mut() {
1533 let pool = &mut pools[li];
1534 Some((&mut pool.0 as *mut _, &mut pool.1 as *mut _))
1535 } else {
1536 None
1537 };
1538 let caches = self
1539 .kv_caches
1540 .get_mut(cache_id)
1541 .expect("ensure_kv must be called before forward_layer");
1542 let cache_len_before = K::len(&caches[li]);
1545 let cache_capacity = K::capacity(&caches[li]);
1546 let cache_block_size = K::block_size(&caches[li]);
1547
1548 if cache_len_before + tokens > cache_capacity {
1554 panic!(
1555 "KV cache overflow on layer {li}: would write tokens [{cache_len_before}..{}) but capacity is {cache_capacity} (cache_id={cache_id:?}). Increase FERRUM_KV_CAPACITY or call /clear in the REPL.",
1556 cache_len_before + tokens
1557 );
1558 }
1559
1560 if cache_block_size > 0 {
1565 let (pool_k_ptr, pool_v_ptr) =
1566 paged_pool_ptr.expect("paged_pools must be allocated when block_size > 0");
1567 let pool_k = unsafe { &mut *pool_k_ptr };
1570 let pool_v = unsafe { &mut *pool_v_ptr };
1571
1572 K::paged_write(
1573 ctx,
1574 &mut caches[li],
1575 &self.scratch.qkv_out,
1576 q_norm_w,
1577 k_norm_w,
1578 &self.rope.cos,
1579 &self.rope.sin,
1580 &mut self.scratch.q_head_major,
1581 &mut self.scratch.k_head_major,
1582 &mut self.scratch.v_head_major,
1583 pool_k,
1584 pool_v,
1585 tokens,
1586 nh,
1587 nkv,
1588 hd,
1589 pos_offset,
1590 eps,
1591 qk_mode,
1592 )
1593 .expect("K::paged_write");
1594
1595 let new_len = cache_len_before + tokens;
1596 K::set_len(&mut caches[li], new_len);
1597
1598 let pool_k_imm = unsafe { &*pool_k_ptr };
1599 let pool_v_imm = unsafe { &*pool_v_ptr };
1600 K::paged_decode_attention(
1601 ctx,
1602 &mut caches[li],
1603 &self.scratch.q_head_major,
1604 pool_k_imm,
1605 pool_v_imm,
1606 &mut self.scratch.attn_head_major_out,
1607 nh,
1608 nkv,
1609 hd,
1610 new_len,
1611 tokens,
1612 )
1613 .expect("K::paged_decode_attention");
1614
1615 return self.forward_layer_post_attn(ctx, li, cache_id, residual, tokens);
1616 }
1617
1618 let _qkr_t0 = if llama_family_runtime_env().decode_op_profile {
1621 B::sync(ctx);
1622 Some(std::time::Instant::now())
1623 } else {
1624 None
1625 };
1626 K::contig_write(
1627 ctx,
1628 &mut caches[li],
1629 &self.scratch.qkv_out,
1630 q_norm_w,
1631 k_norm_w,
1632 &self.rope.cos,
1633 &self.rope.sin,
1634 &mut self.scratch.q_head_major,
1635 &mut self.scratch.k_head_major,
1636 &mut self.scratch.v_head_major,
1637 &mut self.scratch.q_buf,
1638 &mut self.scratch.k_buf,
1639 &mut self.scratch.v_buf,
1640 tokens,
1641 nh,
1642 nkv,
1643 hd,
1644 pos_offset,
1645 eps,
1646 qk_mode,
1647 )
1648 .expect("K::contig_write");
1649 if let Some(t0) = _qkr_t0 {
1650 B::sync(ctx);
1651 QKR_TIME_US.fetch_add(
1652 t0.elapsed().as_micros() as u64,
1653 std::sync::atomic::Ordering::Relaxed,
1654 );
1655 QKR_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1656 }
1657 let new_len = cache_len_before + tokens;
1658 K::set_len(&mut caches[li], new_len);
1659 let kv_stride = cache_capacity;
1660
1661 let _attn_t0 = if llama_family_runtime_env().decode_op_profile {
1662 B::sync(ctx);
1663 Some(std::time::Instant::now())
1664 } else {
1665 None
1666 };
1667 let attn_cfg = ferrum_kernels::backend::AttnConfig {
1668 num_heads: nh,
1669 num_kv_heads: nkv,
1670 head_dim: hd,
1671 causal: true,
1672 scale: 1.0 / (hd as f32).sqrt(),
1673 kv_seq_stride: kv_stride,
1674 sliding_window: cfg.sliding_window,
1675 };
1676 K::contig_decode_attention(
1677 ctx,
1678 &caches[li],
1679 &self.scratch.q_head_major,
1680 &mut self.scratch.attn_head_major_out,
1681 attn_cfg,
1682 tokens,
1683 pos_offset,
1684 )
1685 .expect("K::contig_decode_attention");
1686 let _ = q_dim;
1687 let _ = kv_dim;
1688 let _ = dummy;
1689 if let Some(t0) = _attn_t0 {
1690 B::sync(ctx);
1691 ATTN_TIME_US.fetch_add(
1692 t0.elapsed().as_micros() as u64,
1693 std::sync::atomic::Ordering::Relaxed,
1694 );
1695 ATTN_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1696 }
1697
1698 self.forward_layer_post_attn(ctx, li, cache_id, residual, tokens);
1699 }
1700
1701 pub(crate) fn forward_layer_post_attn(
1706 &mut self,
1707 ctx: &mut B::Context,
1708 li: usize,
1709 cache_id: &str,
1710 residual: &mut B::Buffer,
1711 tokens: usize,
1712 ) {
1713 let layer = &self.layers[li];
1714 let cfg = &self.cfg;
1715 let h = cfg.hidden_size;
1716 let nh = cfg.num_heads;
1717 let hd = cfg.head_dim;
1718 let im = cfg.intermediate_size;
1719 let eps = cfg.rms_norm_eps;
1720
1721 let attn_token_major = if tokens == 1 {
1723 &self.scratch.attn_head_major_out
1724 } else {
1725 B::transpose_head_to_token(
1726 ctx,
1727 &self.scratch.attn_head_major_out,
1728 &mut self.scratch.attn_flat,
1729 tokens,
1730 nh,
1731 hd,
1732 );
1733 &self.scratch.attn_flat
1734 };
1735
1736 layer
1738 .o_proj
1739 .forward(ctx, attn_token_major, &mut self.scratch.o_proj_out, tokens);
1740 if let Some(adapter) = self.active_lora_adapter_ptr_for_cache(cache_id) {
1741 let applied = unsafe { &*adapter }
1743 .apply_projection(
1744 ctx,
1745 li,
1746 "o_proj",
1747 attn_token_major,
1748 &mut self.scratch.o_proj_out,
1749 tokens,
1750 )
1751 .expect("validated LoRA o_proj");
1752 self.lora_projection_applications += applied as u64;
1753 }
1754
1755 B::fused_add_rms_norm(
1757 ctx,
1758 residual,
1759 &self.scratch.o_proj_out,
1760 &layer.post_ln_w,
1761 eps,
1762 &mut self.scratch.norm_out,
1763 tokens,
1764 h,
1765 );
1766
1767 layer.gate_up_proj.forward(
1769 ctx,
1770 &self.scratch.norm_out,
1771 &mut self.scratch.gate_up_out,
1772 tokens,
1773 );
1774 if let Some(adapter) = self.active_lora_adapter_ptr_for_cache(cache_id) {
1775 let applied = unsafe { &*adapter }
1777 .apply_projection(
1778 ctx,
1779 li,
1780 "gate_up_proj",
1781 &self.scratch.norm_out,
1782 &mut self.scratch.gate_up_out,
1783 tokens,
1784 )
1785 .expect("validated LoRA gate_up_proj");
1786 self.lora_projection_applications += applied as u64;
1787 }
1788
1789 B::fused_silu_mul_split(
1791 ctx,
1792 &self.scratch.gate_up_out,
1793 &mut self.scratch.silu_out,
1794 tokens,
1795 im,
1796 );
1797
1798 layer.down_proj.forward(
1800 ctx,
1801 &self.scratch.silu_out,
1802 &mut self.scratch.mlp_out,
1803 tokens,
1804 );
1805 if let Some(adapter) = self.active_lora_adapter_ptr_for_cache(cache_id) {
1806 let applied = unsafe { &*adapter }
1808 .apply_projection(
1809 ctx,
1810 li,
1811 "down_proj",
1812 &self.scratch.silu_out,
1813 &mut self.scratch.mlp_out,
1814 tokens,
1815 )
1816 .expect("validated LoRA down_proj");
1817 self.lora_projection_applications += applied as u64;
1818 }
1819
1820 B::add_inplace(ctx, residual, &self.scratch.mlp_out, tokens * h);
1822 }
1823
1824 pub fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
1836 let seq_len = tokens.len();
1837 assert!(seq_len > 0, "forward_verify called with empty tokens");
1838 self.ensure_scratch(seq_len);
1839 self.ensure_kv(cache_id);
1840
1841 let h = self.cfg.hidden_size;
1842 let vocab = self.cfg.vocab_size;
1843
1844 let pos_offset = self
1845 .kv_caches
1846 .get(cache_id)
1847 .and_then(|layers| layers.first())
1848 .map(|c| K::len(c))
1849 .unwrap_or(0);
1850
1851 let mut ctx = B::new_context();
1852 let mut residual = self
1853 .scratch
1854 .residual
1855 .take()
1856 .expect("scratch residual missing (previous call didn't restore)");
1857
1858 let embed = self
1859 .embed
1860 .as_ref()
1861 .expect("forward_verify called on backbone-only model (no embed)");
1862 B::embedding_lookup(&mut ctx, embed, tokens, &mut residual, h);
1863
1864 for li in 0..self.cfg.num_layers {
1865 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
1866 }
1867
1868 B::rms_norm(
1871 &mut ctx,
1872 &residual,
1873 &self.final_norm_w,
1874 self.cfg.rms_norm_eps,
1875 &mut self.scratch.norm_out,
1876 seq_len,
1877 h,
1878 );
1879
1880 let lm_head = self
1884 .lm_head
1885 .as_ref()
1886 .expect("forward_verify called on backbone-only model (no lm_head)");
1887 lm_head.forward(
1888 &mut ctx,
1889 &self.scratch.norm_out,
1890 &mut self.scratch.batch_logits,
1891 seq_len,
1892 );
1893
1894 B::sync(&mut ctx);
1895 self.scratch.residual = Some(residual);
1896 B::to_vec(&self.scratch.batch_logits, seq_len * vocab)
1897 }
1898
1899 pub fn prefill_internal(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
1907 assert!(!tokens.is_empty(), "prefill called with empty token list");
1908 self.ensure_kv(cache_id);
1909
1910 let cache_len_before = self
1911 .kv_caches
1912 .get(cache_id)
1913 .and_then(|layers| layers.first())
1914 .map(K::len)
1915 .unwrap_or(0);
1916 let mut cached_prefix_tokens =
1917 if llama_family_runtime_env().prefix_cache && cache_len_before == 0 {
1918 self.try_acquire_prefix_cache(cache_id, tokens)
1919 } else {
1920 0
1921 };
1922 if cached_prefix_tokens >= tokens.len() {
1923 let block_size = self
1924 .kv_caches
1925 .get(cache_id)
1926 .and_then(|layers| layers.first())
1927 .map(K::block_size)
1928 .unwrap_or(16);
1929 cached_prefix_tokens = cached_prefix_tokens
1930 .saturating_sub(block_size)
1931 .min(tokens.len() - 1);
1932 }
1933 if llama_family_runtime_env().prefix_cache && cache_len_before == 0 {
1934 self.record_prefix_cache_probe(cached_prefix_tokens);
1935 }
1936
1937 if cached_prefix_tokens > 0 {
1938 let caches_mut = self.kv_caches.get_mut(cache_id).expect("cache present");
1939 let mut ctx_tmp = B::new_context();
1940 for cache in caches_mut.iter_mut() {
1941 if K::len(cache) != cached_prefix_tokens {
1942 K::set_len(cache, cached_prefix_tokens);
1943 if let Some(context_lens) = K::context_lens_mut(cache) {
1944 B::write_typed::<u32>(
1945 &mut ctx_tmp,
1946 context_lens,
1947 &[cached_prefix_tokens as u32],
1948 );
1949 }
1950 }
1951 }
1952 B::sync(&mut ctx_tmp);
1953 }
1954
1955 let suffix_tokens = &tokens[cached_prefix_tokens..];
1956 let seq_len = suffix_tokens.len();
1957 assert!(
1958 seq_len > 0,
1959 "prefix cache must leave at least one suffix token"
1960 );
1961 self.ensure_scratch(seq_len);
1962
1963 let pos_offset = self
1966 .kv_caches
1967 .get(cache_id)
1968 .and_then(|layers| layers.first())
1969 .map(|c| K::len(c))
1970 .unwrap_or(0);
1971
1972 let h = self.cfg.hidden_size;
1973 let vocab = self.cfg.vocab_size;
1974 let mut ctx = B::new_context();
1975
1976 let mut residual = self
1983 .scratch
1984 .residual
1985 .take()
1986 .expect("scratch residual missing (previous call didn't restore)");
1987 let embed = self
1988 .embed
1989 .as_ref()
1990 .expect("prefill_internal called on backbone-only model (no embed)");
1991 B::embedding_lookup(&mut ctx, embed, suffix_tokens, &mut residual, h);
1992
1993 let prefill_profile = llama_family_runtime_env().prefill_op_profile;
1994 let prefill_t0 = if prefill_profile {
1995 B::sync(&mut ctx);
1996 Some(std::time::Instant::now())
1997 } else {
1998 None
1999 };
2000
2001 for li in 0..self.cfg.num_layers {
2002 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
2003 }
2004
2005 if let Some(t0) = prefill_t0 {
2006 B::sync(&mut ctx);
2007 let total_us = t0.elapsed().as_micros() as u64;
2008 let attn_us = ATTN_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2009 let attn_n = ATTN_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2010 let qkr_us = QKR_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2011 let qkr_n = QKR_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2012 let mm_us = MATMUL_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2013 let mm_n = MATMUL_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2014 let norm_us = NORM_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2015 let norm_n = NORM_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2016 let other_us = OTHER_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2017 let other_n = OTHER_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2018 eprintln!(
2019 "[prefill-profile] tokens={} layers total={} ms",
2020 seq_len,
2021 total_us / 1000
2022 );
2023 let bucket = |label: &str, n: u64, us: u64| {
2024 if n > 0 {
2025 eprintln!(
2026 "[prefill-profile] {label}: {} calls {} ms (avg {} us)",
2027 n,
2028 us / 1000,
2029 us / n
2030 );
2031 }
2032 };
2033 bucket("flash_attn", attn_n, attn_us);
2034 bucket("qk_norm_rope", qkr_n, qkr_us);
2035 bucket("matmuls", mm_n, mm_us);
2036 bucket("norms", norm_n, norm_us);
2037 bucket("other", other_n, other_us);
2038 }
2039
2040 B::copy_slice(
2042 &mut ctx,
2043 &residual,
2044 (seq_len - 1) * h,
2045 &mut self.scratch.last_hidden,
2046 0,
2047 h,
2048 );
2049
2050 B::rms_norm(
2052 &mut ctx,
2053 &self.scratch.last_hidden,
2054 &self.final_norm_w,
2055 self.cfg.rms_norm_eps,
2056 &mut self.scratch.last_normed,
2057 1,
2058 h,
2059 );
2060
2061 let lm_head = self
2063 .lm_head
2064 .as_ref()
2065 .expect("prefill_internal called on backbone-only model (no lm_head)");
2066 lm_head.forward(
2067 &mut ctx,
2068 &self.scratch.last_normed,
2069 &mut self.scratch.logits,
2070 1,
2071 );
2072
2073 B::sync(&mut ctx);
2080
2081 self.scratch.residual = Some(residual);
2083 if llama_family_runtime_env().prefix_cache && cache_len_before == 0 {
2084 self.register_prefix_cache(cache_id, tokens, cached_prefix_tokens);
2085 }
2086
2087 B::to_vec(&self.scratch.logits, vocab)
2088 }
2089
2090 pub fn decode_internal(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
2092 self.ensure_scratch(1);
2093 self.ensure_kv(cache_id);
2094
2095 let h = self.cfg.hidden_size;
2096 let vocab = self.cfg.vocab_size;
2097
2098 let mut ctx = B::new_context();
2101
2102 const GRAPH_WARMUP: usize = 3;
2107 let graph_enabled = llama_family_runtime_env().cuda_graph;
2108
2109 if graph_enabled {
2110 B::set_decode_state(&mut ctx, token, pos);
2113
2114 match B::replay_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY) {
2118 Ok(true) => {
2119 B::sync(&mut ctx);
2120 return B::to_vec(&self.scratch.logits, vocab);
2121 }
2122 Ok(false) => { }
2123 Err(_) => { }
2124 }
2125 }
2126
2127 let should_capture =
2128 graph_enabled && !self.graph_capture_failed && self.graph_warmup >= GRAPH_WARMUP;
2129
2130 if should_capture {
2131 B::set_dev_state_mode(&mut ctx, true);
2132 if B::begin_graph_capture(&mut ctx).is_err() {
2133 self.graph_capture_failed = true;
2134 B::set_dev_state_mode(&mut ctx, false);
2135 }
2136 }
2137
2138 let mut residual = self
2144 .scratch
2145 .residual
2146 .take()
2147 .expect("scratch residual missing (previous call didn't restore)");
2148 let embed = self
2149 .embed
2150 .as_ref()
2151 .expect("decode_internal called on backbone-only model (no embed)");
2152 B::embedding_lookup(&mut ctx, embed, &[token], &mut residual, h);
2153
2154 let layer_profile = llama_family_runtime_env().decode_layer_profile;
2158 let mut layer_times = if layer_profile {
2159 Some(Vec::with_capacity(self.cfg.num_layers))
2160 } else {
2161 None
2162 };
2163
2164 for li in 0..self.cfg.num_layers {
2165 if layer_profile {
2166 let t0 = std::time::Instant::now();
2167 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
2168 B::sync(&mut ctx);
2169 let elapsed_us = t0.elapsed().as_micros() as u64;
2170 if let Some(v) = layer_times.as_mut() {
2171 v.push(elapsed_us);
2172 }
2173 } else {
2174 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
2175 }
2176 }
2177 if let Some(times) = layer_times.take() {
2178 let sum: u64 = times.iter().sum();
2179 let avg = sum / times.len() as u64;
2180 let mn = *times.iter().min().unwrap_or(&0);
2181 let mx = *times.iter().max().unwrap_or(&0);
2182 eprintln!(
2183 "[layer-profile] {} layers total={} ms avg={} us min={} us max={} us",
2184 times.len(),
2185 sum / 1000,
2186 avg,
2187 mn,
2188 mx
2189 );
2190 for (i, t) in times.iter().enumerate() {
2191 eprint!("L{i}={}ms ", t / 1000);
2192 if (i + 1) % 6 == 0 {
2193 eprintln!();
2194 }
2195 }
2196 eprintln!();
2197 let attn_us = ATTN_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2198 let attn_n = ATTN_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2199 let qkr_us = QKR_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2200 let qkr_n = QKR_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2201 let mm_us = MATMUL_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2202 let mm_n = MATMUL_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2203 let norm_us = NORM_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2204 let norm_n = NORM_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2205 let other_us = OTHER_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2206 let other_n = OTHER_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2207 eprintln!(
2208 "[op-profile] flash_attn: {} calls {} ms (avg {} us)",
2209 attn_n,
2210 attn_us / 1000,
2211 if attn_n > 0 { attn_us / attn_n } else { 0 }
2212 );
2213 eprintln!(
2214 "[op-profile] qk_norm_rope: {} calls {} ms (avg {} us)",
2215 qkr_n,
2216 qkr_us / 1000,
2217 if qkr_n > 0 { qkr_us / qkr_n } else { 0 }
2218 );
2219 eprintln!(
2220 "[op-profile] matmuls (Linear::forward): {} calls {} ms (avg {} us)",
2221 mm_n,
2222 mm_us / 1000,
2223 if mm_n > 0 { mm_us / mm_n } else { 0 }
2224 );
2225 eprintln!(
2226 "[op-profile] norms (rms+fused_add_rms): {} calls {} ms (avg {} us)",
2227 norm_n,
2228 norm_us / 1000,
2229 if norm_n > 0 { norm_us / norm_n } else { 0 }
2230 );
2231 eprintln!(
2232 "[op-profile] other (split_qkv, kv_append, transpose, silu, add): {} calls {} ms (avg {} us)",
2233 other_n, other_us / 1000, if other_n > 0 { other_us / other_n } else { 0 }
2234 );
2235 }
2236
2237 B::rms_norm(
2238 &mut ctx,
2239 &residual,
2240 &self.final_norm_w,
2241 self.cfg.rms_norm_eps,
2242 &mut self.scratch.last_normed,
2243 1,
2244 h,
2245 );
2246
2247 let lm_head = self
2248 .lm_head
2249 .as_ref()
2250 .expect("decode_internal called on backbone-only model (no lm_head)");
2251 lm_head.forward(
2252 &mut ctx,
2253 &self.scratch.last_normed,
2254 &mut self.scratch.logits,
2255 1,
2256 );
2257
2258 if should_capture && !self.graph_capture_failed {
2259 if B::end_graph_capture(&mut ctx, SINGLE_ITEM_GRAPH_KEY).is_err() {
2260 self.graph_capture_failed = true;
2261 } else {
2262 if B::replay_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY).is_err() {
2269 self.graph_capture_failed = true;
2270 }
2271 }
2272 B::set_dev_state_mode(&mut ctx, false);
2273 } else {
2274 self.graph_warmup += 1;
2275 }
2276
2277 B::sync(&mut ctx);
2284 self.scratch.residual = Some(residual);
2285
2286 B::to_vec(&self.scratch.logits, vocab)
2287 }
2288
2289 pub fn prefill_from_embeds(
2298 &mut self,
2299 cache_id: &str,
2300 embeds: &[f32],
2301 seq_len: usize,
2302 ) -> Vec<f32> {
2303 let h = self.cfg.hidden_size;
2304 assert_eq!(
2305 embeds.len(),
2306 seq_len * h,
2307 "embeds length {} != seq_len * hidden_size {}",
2308 embeds.len(),
2309 seq_len * h
2310 );
2311 assert!(seq_len > 0, "prefill_from_embeds called with zero length");
2312
2313 self.ensure_scratch(seq_len);
2314 self.ensure_kv(cache_id);
2315
2316 let mut ctx = B::new_context();
2317 let mut residual = self
2318 .scratch
2319 .residual
2320 .take()
2321 .expect("scratch residual missing (previous call didn't restore)");
2322
2323 let embed_buf = B::from_slice(embeds);
2325 B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, seq_len * h);
2326
2327 for li in 0..self.cfg.num_layers {
2328 self.forward_layer(&mut ctx, li, cache_id, &mut residual, 0, seq_len);
2329 }
2330
2331 B::copy_slice(
2332 &mut ctx,
2333 &residual,
2334 (seq_len - 1) * h,
2335 &mut self.scratch.last_hidden,
2336 0,
2337 h,
2338 );
2339 B::sync(&mut ctx);
2340 self.scratch.residual = Some(residual);
2341 B::to_vec(&self.scratch.last_hidden, h)
2342 }
2343
2344 pub fn decode_from_embed(&mut self, cache_id: &str, embed: &[f32], pos: u32) -> Vec<f32> {
2348 let h = self.cfg.hidden_size;
2349 assert_eq!(
2350 embed.len(),
2351 h,
2352 "embed length {} != hidden_size {}",
2353 embed.len(),
2354 h
2355 );
2356
2357 self.ensure_scratch(1);
2358 self.ensure_kv(cache_id);
2359
2360 let mut ctx = B::new_context();
2361 let mut residual = self
2362 .scratch
2363 .residual
2364 .take()
2365 .expect("scratch residual missing (previous call didn't restore)");
2366
2367 let embed_buf = B::from_slice(embed);
2368 B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, h);
2369
2370 for li in 0..self.cfg.num_layers {
2371 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
2372 }
2373
2374 B::copy_slice(&mut ctx, &residual, 0, &mut self.scratch.last_hidden, 0, h);
2375 B::sync(&mut ctx);
2376 self.scratch.residual = Some(residual);
2377 B::to_vec(&self.scratch.last_hidden, h)
2378 }
2379
2380 pub fn prefill_all_post_norm(
2391 &mut self,
2392 cache_id: &str,
2393 embeds: &[f32],
2394 seq_len: usize,
2395 pos_offset: usize,
2396 ) -> Vec<f32> {
2397 let h = self.cfg.hidden_size;
2398 assert_eq!(
2399 embeds.len(),
2400 seq_len * h,
2401 "embeds length {} != seq_len * hidden_size {}",
2402 embeds.len(),
2403 seq_len * h
2404 );
2405 assert!(seq_len > 0);
2406
2407 self.ensure_scratch(seq_len);
2408 self.ensure_kv(cache_id);
2409
2410 let mut ctx = B::new_context();
2411 let mut residual = self
2412 .scratch
2413 .residual
2414 .take()
2415 .expect("scratch residual missing (previous call didn't restore)");
2416
2417 let embed_buf = B::from_slice(embeds);
2418 B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, seq_len * h);
2419
2420 for li in 0..self.cfg.num_layers {
2421 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
2422 }
2423
2424 B::rms_norm(
2426 &mut ctx,
2427 &residual,
2428 &self.final_norm_w,
2429 self.cfg.rms_norm_eps,
2430 &mut self.scratch.norm_out,
2431 seq_len,
2432 h,
2433 );
2434 B::sync(&mut ctx);
2435 self.scratch.residual = Some(residual);
2436 B::to_vec(&self.scratch.norm_out, seq_len * h)
2437 }
2438
2439 pub fn decode_post_norm_from_embed(
2443 &mut self,
2444 cache_id: &str,
2445 embed: &[f32],
2446 pos: u32,
2447 ) -> Vec<f32> {
2448 let h = self.cfg.hidden_size;
2449 assert_eq!(embed.len(), h);
2450
2451 self.ensure_scratch(1);
2452 self.ensure_kv(cache_id);
2453
2454 let mut ctx = B::new_context();
2455 let mut residual = self
2456 .scratch
2457 .residual
2458 .take()
2459 .expect("scratch residual missing (previous call didn't restore)");
2460
2461 let embed_buf = B::from_slice(embed);
2462 B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, h);
2463
2464 for li in 0..self.cfg.num_layers {
2465 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
2466 }
2467
2468 B::rms_norm(
2469 &mut ctx,
2470 &residual,
2471 &self.final_norm_w,
2472 self.cfg.rms_norm_eps,
2473 &mut self.scratch.last_normed,
2474 1,
2475 h,
2476 );
2477 B::sync(&mut ctx);
2478 self.scratch.residual = Some(residual);
2479 B::to_vec(&self.scratch.last_normed, h)
2480 }
2481}
2482
2483impl<B: MoeLlmBackend> DecoderOnlyLLM for LlamaFamilyModel<B, KvFp16> {
2485 fn config(&self) -> &LlmRuntimeConfig {
2486 &self.runtime_cfg
2487 }
2488
2489 fn cache_metrics_snapshot(&self) -> Option<serde_json::Value> {
2490 Some(self.prefix_cache_snapshot_json())
2491 }
2492
2493 fn lora_metrics_snapshot(&self) -> Option<serde_json::Value> {
2494 Some(self.lora_metrics_snapshot_json())
2495 }
2496
2497 fn set_lora_adapter_for_cache(
2498 &mut self,
2499 cache_id: &str,
2500 adapter: Option<ActiveLoraAdapter>,
2501 ) -> std::result::Result<(), ferrum_types::FerrumError> {
2502 if let Some(adapter) = adapter {
2503 self.ensure_lora_adapter_loaded(adapter.clone())?;
2504 self.lora_cache_adapters
2505 .insert(cache_id.to_string(), adapter.name);
2506 } else {
2507 self.lora_cache_adapters.remove(cache_id);
2508 }
2509 Ok(())
2510 }
2511
2512 fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
2513 self.ensure_scratch(max_tokens);
2514 self.ensure_kv(cache_id);
2515
2516 const WARMUP_CACHE: &str = "__ferrum_warmup__";
2517 let _ = self.prefill_internal(WARMUP_CACHE, &[0u32]);
2518 if let Some(mut caches) = self.kv_caches.remove(WARMUP_CACHE) {
2519 if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2520 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2521 if let Some(c0) = caches.first() {
2522 if !c0.paged_block_indices.is_empty() {
2523 alloc.free(&c0.paged_block_indices);
2524 }
2525 }
2526 for c in caches.iter_mut() {
2527 c.paged_block_indices.clear();
2528 }
2529 }
2530 self.kv_free_pool.push(caches);
2531 }
2532 }
2533
2534 fn kv_capacity(&self) -> usize {
2535 llama_family_runtime_env().kv_capacity_for_model(self.cfg.max_seq_len)
2536 }
2537
2538 fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2539 self.prefill_internal(cache_id, tokens)
2540 }
2541
2542 fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
2543 self.decode_internal(cache_id, token, pos)
2544 }
2545
2546 fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
2547 self.decode_batch_with_full_logits(batch, false)
2548 }
2549
2550 fn decode_batch_with_full_logits(
2551 &mut self,
2552 batch: &[(String, u32, u32)],
2553 force_full_logits: bool,
2554 ) -> Vec<Vec<f32>> {
2555 self.decode_batch_internal_with_full_logits(batch, force_full_logits)
2556 }
2557
2558 fn unified_forward(
2559 &mut self,
2560 items: &[(String, Vec<u32>, usize, bool)],
2561 ) -> std::result::Result<Vec<Option<Vec<f32>>>, ferrum_types::FerrumError> {
2562 if items.is_empty() {
2563 return Ok(Vec::new());
2564 }
2565 if llama_family_runtime_env().prefix_cache
2566 && items
2567 .iter()
2568 .any(|(_, tokens, pos_offset, _)| *pos_offset == 0 && tokens.len() > 1)
2569 {
2570 return Err(ferrum_types::FerrumError::unsupported(
2571 "LlamaFamilyModel::unified_forward: fresh prefill with prefix cache enabled \
2572 routes through prefill_internal so real paged-block KV reuse can probe and \
2573 register block hashes",
2574 ));
2575 }
2576 if !B::supports_varlen_qkv() {
2577 return Err(ferrum_types::FerrumError::unsupported(
2578 "LlamaFamilyModel::unified_forward: backend lacks varlen \
2579 QKV kernels. Engine will fall back to per-item dispatch.",
2580 ));
2581 }
2582 if items
2583 .iter()
2584 .any(|(cache_id, _, _, _)| self.active_lora_adapter_for_cache(cache_id).is_some())
2585 {
2586 return Err(ferrum_types::FerrumError::unsupported(
2587 "LlamaFamilyModel::unified_forward: active LoRA adapter routes through \
2588 per-item dispatch until unified LoRA supports row-selective adapters.",
2589 ));
2590 }
2591 self.ensure_kv(&items[0].0);
2592 if self.paged_pools.is_none() {
2593 return Err(ferrum_types::FerrumError::unsupported(
2594 "LlamaFamilyModel::unified_forward: paged KV required; \
2595 enable via FERRUM_METAL_PAGED_KV=1 (cross-backend env). \
2596 Engine will fall back to per-item dispatch.",
2597 ));
2598 }
2599 for (cid, _, _, _) in items {
2600 self.ensure_kv(cid);
2601 if !self.kv_caches.contains_key(cid) {
2602 return Err(ferrum_types::FerrumError::resource_exhausted(format!(
2603 "paged KV pool exhausted for cache_id={cid:?}; back off"
2604 )));
2605 }
2606 }
2607 Ok(self.unified_forward_internal(items))
2608 }
2609
2610 fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2611 LlamaFamilyModel::<B, KvFp16>::forward_verify(self, cache_id, tokens)
2612 }
2613
2614 fn truncate_kv(&mut self, cache_id: &str, new_len: usize) {
2615 if let Some(caches) = self.kv_caches.get_mut(cache_id) {
2616 for c in caches.iter_mut() {
2617 if new_len < c.len {
2618 c.len = new_len;
2619 }
2620 }
2621 }
2622 let mut ctx = B::new_context();
2623 B::reset_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY);
2624 self.graph_warmup = 0;
2625 self.graph_capture_failed = false;
2626 }
2627
2628 fn release(&mut self, cache_id: &str) {
2629 let mut ctx = B::new_context();
2630 B::sync(&mut ctx);
2631 self.graph_warmup = 0;
2632 self.graph_capture_failed = false;
2633 self.lora_cache_adapters.remove(cache_id);
2634 if let Some(mut caches) = self.kv_caches.remove(cache_id) {
2635 if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2636 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2637 if let Some(c0) = caches.first() {
2638 if !c0.paged_block_indices.is_empty() {
2639 alloc.free(&c0.paged_block_indices);
2640 }
2641 }
2642 for c in caches.iter_mut() {
2643 c.paged_block_indices.clear();
2644 }
2645 }
2646 self.kv_free_pool.push(caches);
2647 }
2648 }
2649
2650 fn reset(&mut self) {
2651 let mut ctx = B::new_context();
2652 B::sync(&mut ctx);
2653 B::reset_all_graphs(&mut ctx);
2654 B::sync(&mut ctx);
2655 self.graph_warmup = 0;
2656 self.graph_capture_failed = false;
2657 self.batched_graph_keys_seen.clear();
2658 self.batched_graph_warmup = 0;
2659 self.batched_graph_failed = false;
2660 self.kv_caches.clear();
2661 self.kv_free_pool.clear();
2662 self.lora_cache_adapters.clear();
2663 }
2664}
2665
2666impl<B: MoeLlmBackend + BackendInt8KvOps> DecoderOnlyLLM for LlamaFamilyModel<B, KvInt8> {
2670 fn config(&self) -> &LlmRuntimeConfig {
2671 &self.runtime_cfg
2672 }
2673
2674 fn cache_metrics_snapshot(&self) -> Option<serde_json::Value> {
2675 Some(self.prefix_cache_snapshot_json())
2676 }
2677
2678 fn lora_metrics_snapshot(&self) -> Option<serde_json::Value> {
2679 Some(self.lora_metrics_snapshot_json())
2680 }
2681
2682 fn set_lora_adapter_for_cache(
2683 &mut self,
2684 cache_id: &str,
2685 adapter: Option<ActiveLoraAdapter>,
2686 ) -> std::result::Result<(), ferrum_types::FerrumError> {
2687 if let Some(adapter) = adapter {
2688 self.ensure_lora_adapter_loaded(adapter.clone())?;
2689 self.lora_cache_adapters
2690 .insert(cache_id.to_string(), adapter.name);
2691 } else {
2692 self.lora_cache_adapters.remove(cache_id);
2693 }
2694 Ok(())
2695 }
2696
2697 fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
2698 self.ensure_scratch(max_tokens);
2699 self.ensure_kv(cache_id);
2700
2701 const WARMUP_CACHE: &str = "__ferrum_warmup__";
2702 let _ = self.prefill_internal(WARMUP_CACHE, &[0u32]);
2703 if let Some(mut caches) = self.kv_caches.remove(WARMUP_CACHE) {
2704 if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2705 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2706 if let Some(c0) = caches.first() {
2707 if !c0.paged_block_indices.is_empty() {
2708 alloc.free(&c0.paged_block_indices);
2709 }
2710 }
2711 for c in caches.iter_mut() {
2712 c.paged_block_indices.clear();
2713 }
2714 }
2715 self.kv_free_pool.push(caches);
2716 }
2717 }
2718
2719 fn kv_capacity(&self) -> usize {
2720 llama_family_runtime_env().kv_capacity_for_model(self.cfg.max_seq_len)
2721 }
2722
2723 fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2724 self.prefill_internal(cache_id, tokens)
2725 }
2726
2727 fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
2728 self.decode_internal(cache_id, token, pos)
2729 }
2730
2731 fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2734 LlamaFamilyModel::<B, KvInt8>::forward_verify(self, cache_id, tokens)
2735 }
2736
2737 fn truncate_kv(&mut self, cache_id: &str, new_len: usize) {
2738 if let Some(caches) = self.kv_caches.get_mut(cache_id) {
2739 for c in caches.iter_mut() {
2740 if new_len < c.len {
2741 c.len = new_len;
2742 }
2743 }
2744 }
2745 let mut ctx = B::new_context();
2746 B::reset_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY);
2747 self.graph_warmup = 0;
2748 self.graph_capture_failed = false;
2749 }
2750
2751 fn release(&mut self, cache_id: &str) {
2752 let mut ctx = B::new_context();
2753 B::sync(&mut ctx);
2754 self.graph_warmup = 0;
2755 self.graph_capture_failed = false;
2756 self.lora_cache_adapters.remove(cache_id);
2757 if let Some(mut caches) = self.kv_caches.remove(cache_id) {
2758 if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2759 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2760 if let Some(c0) = caches.first() {
2761 if !c0.paged_block_indices.is_empty() {
2762 alloc.free(&c0.paged_block_indices);
2763 }
2764 }
2765 for c in caches.iter_mut() {
2766 c.paged_block_indices.clear();
2767 }
2768 }
2769 self.kv_free_pool.push(caches);
2770 }
2771 }
2772
2773 fn reset(&mut self) {
2774 let mut ctx = B::new_context();
2775 B::sync(&mut ctx);
2776 B::reset_all_graphs(&mut ctx);
2777 B::sync(&mut ctx);
2778 self.graph_warmup = 0;
2779 self.graph_capture_failed = false;
2780 self.kv_caches.clear();
2781 self.kv_free_pool.clear();
2782 self.lora_cache_adapters.clear();
2783 }
2784}
2785
2786fn build_rope_cache<B: QuantLlmBackend + BackendMoeFused>(cfg: &LlamaFamilyConfig) -> RopeCache<B> {
2787 let hd = cfg.head_dim;
2788 let half = hd / 2;
2789 let max = cfg.max_seq_len;
2790 let mut cos = vec![0.0f32; max * half];
2791 let mut sin = vec![0.0f32; max * half];
2792 for pos in 0..max {
2793 for i in 0..half {
2794 let freq = rope_freq(cfg, i);
2795 let angle = pos as f64 * freq;
2796 cos[pos * half + i] = angle.cos() as f32;
2797 sin[pos * half + i] = angle.sin() as f32;
2798 }
2799 }
2800 RopeCache {
2801 cos: B::from_slice(&cos),
2802 sin: B::from_slice(&sin),
2803 }
2804}
2805
2806fn rope_freq(cfg: &LlamaFamilyConfig, pair_idx: usize) -> f64 {
2807 let base_freq = 1.0f64
2808 / cfg
2809 .rope_theta
2810 .powf((2 * pair_idx) as f64 / cfg.head_dim as f64);
2811 match &cfg.rope_scaling {
2812 Some(RopeScalingConfig::Llama3 {
2813 factor,
2814 low_freq_factor,
2815 high_freq_factor,
2816 original_max_position_embeddings,
2817 }) => scale_llama3_rope_freq(
2818 base_freq,
2819 *factor,
2820 *low_freq_factor,
2821 *high_freq_factor,
2822 *original_max_position_embeddings,
2823 ),
2824 None => base_freq,
2825 }
2826}
2827
/// Applies Llama-3-style frequency scaling to a single rotary frequency.
///
/// With `wavelen = 2π / freq`, the result is chosen from three bands:
///
/// * `wavelen < original_max_position_embeddings / high_freq_factor`:
///   `freq` is returned unchanged.
/// * `wavelen > original_max_position_embeddings / low_freq_factor`:
///   `freq / factor` is returned.
/// * otherwise: a linear blend between `freq / factor` and `freq`, weighted
///   by `smooth`, which runs from 0 at the low-frequency edge to 1 at the
///   high-frequency edge.
fn scale_llama3_rope_freq(
    freq: f64,
    factor: f64,
    low_freq_factor: f64,
    high_freq_factor: f64,
    original_max_position_embeddings: f64,
) -> f64 {
    let wavelen = 2.0 * std::f64::consts::PI / freq;

    // High-frequency band: left untouched.
    let high_freq_wavelen = original_max_position_embeddings / high_freq_factor;
    if wavelen < high_freq_wavelen {
        return freq;
    }

    // Low-frequency band: fully scaled.
    let low_freq_wavelen = original_max_position_embeddings / low_freq_factor;
    if wavelen > low_freq_wavelen {
        return freq / factor;
    }

    // Transition band: blend the scaled and unscaled frequency.
    let smooth = (original_max_position_embeddings / wavelen - low_freq_factor)
        / (high_freq_factor - low_freq_factor);
    let scaled_share = (1.0 - smooth) * freq / factor;
    let unscaled_share = smooth * freq;
    scaled_share + unscaled_share
}
2848
#[cfg(test)]
mod tests {
    use super::{LlamaFamilyRuntimeEnv, DEFAULT_KV_CAPACITY};

    /// Well-formed numeric knobs are parsed, `FERRUM_METAL_PAGED_KV=0` maps to
    /// `Some(false)`, and the profile/graph flags come out `true` even for
    /// the values `"0"`, `""` and `"false"`.
    #[test]
    fn llama_family_runtime_env_parses_startup_knobs() {
        let vars = [
            ("FERRUM_KV_CAPACITY", "4096"),
            ("FERRUM_METAL_PAGED_KV", "0"),
            ("FERRUM_PAGED_MAX_SEQS", "64"),
            ("FERRUM_DECODE_OP_PROFILE", "0"),
            ("FERRUM_PREFILL_OP_PROFILE", ""),
            ("FERRUM_CUDA_GRAPH", ""),
            ("FERRUM_DECODE_LAYER_PROFILE", "false"),
        ];
        let parsed = LlamaFamilyRuntimeEnv::from_env_vars(vars);

        // Valued knobs.
        assert_eq!(parsed.kv_capacity, Some(4096));
        assert_eq!(parsed.metal_paged_kv, Some(false));
        assert_eq!(parsed.paged_max_seqs, 64);

        // Flags.
        assert!(parsed.decode_op_profile);
        assert!(parsed.prefill_op_profile);
        assert!(parsed.cuda_graph);
        assert!(parsed.decode_layer_profile);

        // A configured capacity of 4096 with a model limit of 2048 gives 2048.
        let model_limit = 2048;
        assert_eq!(parsed.kv_capacity_for_model(model_limit), model_limit);
    }

    /// Unparseable numeric knobs fall back to their defaults (`None` for the
    /// KV capacity, 32 for the paged sequence count).
    #[test]
    fn llama_family_runtime_env_uses_defaults_for_invalid_values() {
        let vars = [
            ("FERRUM_KV_CAPACITY", "bad"),
            ("FERRUM_PAGED_MAX_SEQS", "bad"),
            ("FERRUM_METAL_PAGED_KV", "1"),
        ];
        let parsed = LlamaFamilyRuntimeEnv::from_env_vars(vars);

        assert_eq!(parsed.kv_capacity, None);
        assert_eq!(parsed.metal_paged_kv, Some(true));
        assert_eq!(parsed.paged_max_seqs, 32);

        // With no explicit capacity, a roomier model resolves to the default.
        let roomy_model_limit = DEFAULT_KV_CAPACITY * 2;
        assert_eq!(
            parsed.kv_capacity_for_model(roomy_model_limit),
            DEFAULT_KV_CAPACITY
        );
    }
}