1use std::collections::HashMap;
20use std::ops::Range;
21use std::sync::atomic::AtomicU64;
22
23use ferrum_interfaces::kv_dtype::{KvDtypeKind, KvFp16, KvInt8};
24use ferrum_kernels::backend::{
25 Backend, BackendGraph, BackendInt8KvOps, BackendMoeFused, BackendPagedKv, BackendQuantGguf,
26 BackendQuantMarlin, KvCache, KvLayer, LlmBackend, MoeLlmBackend, QuantLlmBackend,
27 MAX_LAYERS_FOR_GRAPH,
28};
29
/// Graph key value reserved, by its name, for the single-item (non-batched) path.
/// NOTE(review): confirm its meaning at the graph capture/replay sites.
pub(crate) const SINGLE_ITEM_GRAPH_KEY: u64 = 0;

/// Process-wide counter for batched steps served by replaying a captured graph
/// (as the name indicates; confirm at the increment sites).
pub(crate) static BATCHED_GRAPH_REPLAY_COUNT: AtomicU64 = AtomicU64::new(0);
/// Process-wide counter for batched steps that ran eagerly instead of replaying
/// (as the name indicates; confirm at the increment sites).
pub(crate) static BATCHED_GRAPH_EAGER_COUNT: AtomicU64 = AtomicU64::new(0);

// Op-profiling accumulators: each `*_TIME_US` total (microseconds, per its name)
// pairs with a `*_CALLS` counter for the same op category.
// NOTE(review): these are updated outside this chunk; presumably only while op
// profiling is enabled.
pub(crate) static ATTN_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static ATTN_CALLS: AtomicU64 = AtomicU64::new(0);
pub(crate) static QKR_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static QKR_CALLS: AtomicU64 = AtomicU64::new(0);
pub(crate) static MATMUL_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static MATMUL_CALLS: AtomicU64 = AtomicU64::new(0);
pub(crate) static NORM_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static NORM_CALLS: AtomicU64 = AtomicU64::new(0);
pub(crate) static OTHER_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static OTHER_CALLS: AtomicU64 = AtomicU64::new(0);
48use ferrum_quantization::{Linear, WeightLoader};
49use ferrum_types::Result;
50
51use crate::common::paged_pool::block_hash_chain;
52use crate::common::{DecoderOnlyLLM, LlmRuntimeConfig};
53use crate::lora::{load_runtime_lora_adapter, ActiveLoraAdapter, RuntimeLoraAdapter};
54
/// Fallback per-sequence KV capacity used when `FERRUM_KV_CAPACITY` is not set.
/// It is still capped by the model's `max_seq_len` (see `kv_capacity_for_model`).
const DEFAULT_KV_CAPACITY: usize = 512;
56
/// Whole microseconds elapsed since `t0`, clamped to the range `1..=u64::MAX`.
///
/// Durations too large for `u64` saturate to `u64::MAX`. Sub-microsecond readings
/// are reported as 1, never 0.
pub(crate) fn elapsed_micros_u64_floor1(t0: std::time::Instant) -> u64 {
    let micros = t0.elapsed().as_micros();
    let saturated = u64::try_from(micros).unwrap_or(u64::MAX);
    saturated.max(1)
}
60
/// Microsecond timings recorded for a stage hidden-state bridge.
///
/// It holds the total bridge time plus host-copy and device-copy components, as the
/// field names indicate. NOTE(review): confirm at the recording sites.
#[derive(Debug, Clone, Copy, Default)]
pub(crate) struct LlamaStageHiddenBridgeTiming {
    pub(crate) bridge_us: u64,
    pub(crate) host_copy_us: u64,
    pub(crate) device_copy_us: u64,
}

impl LlamaStageHiddenBridgeTiming {
    /// Field-wise sum of `self` and `other`.
    ///
    /// Every field saturates at `u64::MAX` instead of overflowing.
    pub(crate) fn add(self, other: Self) -> Self {
        let Self {
            bridge_us,
            host_copy_us,
            device_copy_us,
        } = other;
        Self {
            bridge_us: bridge_us.saturating_add(self.bridge_us),
            host_copy_us: host_copy_us.saturating_add(self.host_copy_us),
            device_copy_us: device_copy_us.saturating_add(self.device_copy_us),
        }
    }
}
77
78#[derive(Debug, Clone, PartialEq, Eq)]
79pub(crate) struct LlamaFamilyRuntimeEnv {
80 pub(crate) kv_capacity: Option<usize>,
81 pub(crate) metal_paged_kv: Option<bool>,
82 pub(crate) paged_max_seqs: usize,
83 pub(crate) decode_op_profile: bool,
84 pub(crate) prefill_op_profile: bool,
85 pub(crate) prefix_cache: bool,
86 pub(crate) cuda_graph: bool,
87 pub(crate) decode_layer_profile: bool,
88}
89
90impl LlamaFamilyRuntimeEnv {
91 fn from_env() -> Self {
95 Self::from_runtime_config_snapshot(&ferrum_types::active_runtime_snapshot())
96 }
97
98 fn from_runtime_config_snapshot(snapshot: &ferrum_types::RuntimeConfigSnapshot) -> Self {
99 Self::from_env_vars(
100 snapshot
101 .entries
102 .iter()
103 .map(|e| (e.key.as_str(), e.effective_value.as_str())),
104 )
105 }
106
107 fn from_env_vars<I, K, V>(vars: I) -> Self
108 where
109 I: IntoIterator<Item = (K, V)>,
110 K: AsRef<str>,
111 V: AsRef<str>,
112 {
113 let mut config = Self {
114 kv_capacity: None,
115 metal_paged_kv: None,
116 paged_max_seqs: 32,
117 decode_op_profile: false,
118 prefill_op_profile: false,
119 prefix_cache: false,
120 cuda_graph: false,
121 decode_layer_profile: false,
122 };
123 for (name, value) in vars {
124 let value = value.as_ref();
125 match name.as_ref() {
126 "FERRUM_KV_CAPACITY" => config.kv_capacity = value.parse::<usize>().ok(),
127 "FERRUM_METAL_PAGED_KV" => config.metal_paged_kv = Some(value != "0"),
128 "FERRUM_PAGED_MAX_SEQS" => {
129 if let Ok(max_seqs) = value.parse::<usize>() {
130 config.paged_max_seqs = max_seqs;
131 }
132 }
133 "FERRUM_DECODE_OP_PROFILE" => config.decode_op_profile = true,
134 "FERRUM_PREFILL_OP_PROFILE" => config.prefill_op_profile = true,
135 "FERRUM_PREFIX_CACHE" => config.prefix_cache = value == "1",
136 "FERRUM_CUDA_GRAPH" => config.cuda_graph = true,
137 "FERRUM_DECODE_LAYER_PROFILE" => config.decode_layer_profile = true,
138 _ => {}
139 }
140 }
141 config
142 }
143
144 fn kv_capacity_for_model(&self, model_max: usize) -> usize {
145 self.kv_capacity
146 .map(|cap| cap.min(model_max))
147 .unwrap_or_else(|| model_max.min(DEFAULT_KV_CAPACITY))
148 }
149
150 fn paged_kv_enabled<B: BackendPagedKv>(&self) -> bool {
151 self.metal_paged_kv
152 .unwrap_or_else(|| B::supports_paged_kv())
153 }
154}
155
156pub(crate) fn llama_family_decode_op_profile_enabled() -> bool {
161 LlamaFamilyRuntimeEnv::from_env().decode_op_profile
162}
163
/// RoPE frequency-scaling schemes understood by the Llama-family runtime.
#[derive(Clone, Debug, PartialEq)]
pub enum RopeScalingConfig {
    /// Llama-3 style scaling.
    ///
    /// Parsed from a `rope_scaling` object whose type is `"llama3"`.
    Llama3 {
        factor: f64,
        low_freq_factor: f64,
        high_freq_factor: f64,
        original_max_position_embeddings: f64,
    },
}

impl RopeScalingConfig {
    /// Default Llama-3 parameters.
    ///
    /// factor 8, low/high frequency factors 1 and 4, original context length 8192.
    pub fn llama3_default() -> Self {
        let (low_freq_factor, high_freq_factor) = (1.0, 4.0);
        Self::Llama3 {
            original_max_position_embeddings: 8192.0,
            factor: 8.0,
            low_freq_factor,
            high_freq_factor,
        }
    }
}
185
186#[derive(Clone, Debug, PartialEq)]
189pub struct LlamaFamilyConfig {
190 pub hidden_size: usize,
191 pub intermediate_size: usize,
192 pub num_heads: usize,
193 pub num_kv_heads: usize,
194 pub head_dim: usize,
195 pub num_layers: usize,
196 pub vocab_size: usize,
197 pub max_seq_len: usize,
198 pub rms_norm_eps: f32,
199 pub rope_theta: f64,
200 pub rope_scaling: Option<RopeScalingConfig>,
201 pub rope_interleaved: bool,
205 pub has_qk_norm: bool,
208 pub sliding_window: usize,
211}
212
213impl LlamaFamilyConfig {
214 pub fn to_runtime(&self) -> LlmRuntimeConfig {
215 LlmRuntimeConfig {
216 hidden_size: self.hidden_size,
217 num_layers: self.num_layers,
218 num_kv_heads: self.num_kv_heads,
219 head_dim: self.head_dim,
220 vocab_size: self.vocab_size,
221 max_seq_len: self.max_seq_len,
222 }
223 }
224
225 fn from_def_base(def: &crate::definition::ModelDefinition) -> LlamaFamilyConfigBase {
229 let num_kv_heads = def.num_key_value_heads.unwrap_or(def.num_attention_heads);
230 let head_dim = def
231 .extra_params
232 .get("head_dim")
233 .and_then(|v| v.as_u64())
234 .map(|v| v as usize)
235 .unwrap_or(def.hidden_size / def.num_attention_heads);
236 let sliding_window = def
239 .extra_params
240 .get("sliding_window")
241 .and_then(|v| v.as_u64())
242 .map(|v| v as usize)
243 .unwrap_or(0);
244
245 LlamaFamilyConfigBase {
246 hidden_size: def.hidden_size,
247 intermediate_size: def.intermediate_size,
248 num_heads: def.num_attention_heads,
249 num_kv_heads,
250 head_dim,
251 num_layers: def.num_hidden_layers,
252 vocab_size: def.vocab_size,
253 max_seq_len: def.max_position_embeddings,
254 rms_norm_eps: def.norm_eps as f32,
255 rope_theta_opt: def.rope_theta,
256 rope_scaling: rope_scaling_from_model_def(def),
257 sliding_window,
258 }
259 }
260
261 fn from_base(b: LlamaFamilyConfigBase, rope_default: f64, has_qk_norm: bool) -> Self {
262 Self {
263 hidden_size: b.hidden_size,
264 intermediate_size: b.intermediate_size,
265 num_heads: b.num_heads,
266 num_kv_heads: b.num_kv_heads,
267 head_dim: b.head_dim,
268 num_layers: b.num_layers,
269 vocab_size: b.vocab_size,
270 max_seq_len: b.max_seq_len,
271 rms_norm_eps: b.rms_norm_eps,
272 rope_theta: b.rope_theta_opt.unwrap_or(rope_default),
273 rope_scaling: b.rope_scaling,
274 rope_interleaved: false,
275 has_qk_norm,
276 sliding_window: b.sliding_window,
277 }
278 }
279
280 pub fn qwen3_from_def(def: &crate::definition::ModelDefinition) -> Self {
282 Self::from_base(Self::from_def_base(def), 1_000_000.0, true)
283 }
284
285 pub fn llama_from_def(def: &crate::definition::ModelDefinition) -> Self {
289 Self::from_base(Self::from_def_base(def), 500_000.0, false)
290 }
291
292 pub fn qwen2_from_def(def: &crate::definition::ModelDefinition) -> Self {
294 Self::from_base(Self::from_def_base(def), 1_000_000.0, false)
295 }
296
297 pub fn mistral_from_def(def: &crate::definition::ModelDefinition) -> Self {
301 Self::from_base(Self::from_def_base(def), 10_000.0, false)
302 }
303}
304
/// Architecture fields shared by every Llama-family variant, extracted from a
/// `ModelDefinition` before per-variant defaults (RoPE θ, QK-norm) are applied.
struct LlamaFamilyConfigBase {
    hidden_size: usize,
    intermediate_size: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    num_layers: usize,
    vocab_size: usize,
    max_seq_len: usize,
    rms_norm_eps: f32,
    // `rope_theta` from the definition, if present; `from_base` substitutes the
    // variant's default when it is `None`.
    rope_theta_opt: Option<f64>,
    // Parsed `rope_scaling` block (only the llama3 type is recognised).
    rope_scaling: Option<RopeScalingConfig>,
    // `sliding_window` from `extra_params`, or 0 when absent.
    sliding_window: usize,
}
319
320fn rope_scaling_from_model_def(
321 def: &crate::definition::ModelDefinition,
322) -> Option<RopeScalingConfig> {
323 let value = def.extra_params.get("rope_scaling")?;
324 let obj = value.as_object()?;
325 let rope_type = obj
326 .get("rope_type")
327 .or_else(|| obj.get("type"))
328 .and_then(|v| v.as_str())?;
329 if rope_type != "llama3" {
330 return None;
331 }
332 let factor = json_f64(obj.get("factor"))?;
333 let low_freq_factor = json_f64(obj.get("low_freq_factor"))?;
334 let high_freq_factor = json_f64(obj.get("high_freq_factor"))?;
335 let original_max_position_embeddings = json_f64(obj.get("original_max_position_embeddings"))
336 .or_else(|| {
337 def.extra_params
338 .get("original_max_position_embeddings")
339 .and_then(|v| json_f64(Some(v)))
340 })
341 .unwrap_or(8192.0);
342 if factor <= 0.0
343 || low_freq_factor <= 0.0
344 || high_freq_factor <= low_freq_factor
345 || original_max_position_embeddings <= 0.0
346 {
347 return None;
348 }
349 Some(RopeScalingConfig::Llama3 {
350 factor,
351 low_freq_factor,
352 high_freq_factor,
353 original_max_position_embeddings,
354 })
355}
356
357fn json_f64(value: Option<&serde_json::Value>) -> Option<f64> {
358 match value? {
359 serde_json::Value::Number(n) => n.as_f64(),
360 _ => None,
361 }
362}
363
/// Weights of one decoder layer.
///
/// Loaded from tensors under `model.layers.{li}` by `load_llama_family_layers`.
pub struct LlamaFamilyLayer<B: QuantLlmBackend + BackendMoeFused> {
    /// `input_layernorm.weight`.
    pub input_ln_w: B::Buffer,
    /// Fused `self_attn.qkv_proj` linear.
    pub qkv_proj: Box<dyn Linear<B>>,
    /// `self_attn.q_norm.weight`.
    ///
    /// Only attempted when `cfg.has_qk_norm`; `None` if not loaded.
    pub q_norm_w: Option<B::Buffer>,
    /// `self_attn.k_norm.weight`.
    ///
    /// Only attempted when `cfg.has_qk_norm`; `None` if not loaded.
    pub k_norm_w: Option<B::Buffer>,
    /// `self_attn.o_proj` linear.
    pub o_proj: Box<dyn Linear<B>>,
    /// `post_attention_layernorm.weight`.
    pub post_ln_w: B::Buffer,
    /// Fused `mlp.gate_up_proj` linear.
    pub gate_up_proj: Box<dyn Linear<B>>,
    /// `mlp.down_proj` linear.
    pub down_proj: Box<dyn Linear<B>>,
}
377
/// Which part of the model a `LlamaFamilyModel` instance loads.
///
/// The `source_layers` range uses global layer indices. The two flags select
/// whether the token embedding and the LM head are loaded.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LlamaFamilyLayerStageConfig {
    pub source_layers: Range<usize>,
    pub load_embedding: bool,
    pub load_lm_head: bool,
}

impl LlamaFamilyLayerStageConfig {
    /// Whole model: every layer plus embedding and LM head.
    pub fn full(num_layers: usize) -> Self {
        Self::pipeline_stage(0..num_layers, true, true)
    }

    /// Every layer, without embedding or LM head.
    pub fn backbone(num_layers: usize) -> Self {
        Self::pipeline_stage(0..num_layers, false, false)
    }

    /// One pipeline stage.
    ///
    /// The first stage owns the embedding and the last stage owns the LM head.
    pub fn pipeline_stage(
        source_layers: Range<usize>,
        is_first_stage: bool,
        is_last_stage: bool,
    ) -> Self {
        let (load_embedding, load_lm_head) = (is_first_stage, is_last_stage);
        Self {
            source_layers,
            load_embedding,
            load_lm_head,
        }
    }
}
414
415fn load_llama_family_layers<B: MoeLlmBackend>(
416 cfg: &LlamaFamilyConfig,
417 loader: &dyn WeightLoader<B>,
418 source_layers: Range<usize>,
419) -> Result<Vec<LlamaFamilyLayer<B>>> {
420 if source_layers.start > source_layers.end || source_layers.end > cfg.num_layers {
421 return Err(ferrum_types::FerrumError::model(format!(
422 "llama layer range {}..{} is outside model layer count {}",
423 source_layers.start, source_layers.end, cfg.num_layers
424 )));
425 }
426
427 let mut layers = Vec::with_capacity(source_layers.end.saturating_sub(source_layers.start));
428 for li in source_layers {
429 let prefix = format!("model.layers.{li}");
430 let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
431 let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
432 let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
433 let post_ln_w = loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
434 let gate_up_proj = loader.load_linear(&format!("{prefix}.mlp.gate_up_proj"))?;
435 let down_proj = loader.load_linear(&format!("{prefix}.mlp.down_proj"))?;
436
437 let (q_norm_w, k_norm_w) = if cfg.has_qk_norm {
438 let q = loader
439 .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
440 .ok();
441 let k = loader
442 .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
443 .ok();
444 (q, k)
445 } else {
446 (None, None)
447 };
448
449 layers.push(LlamaFamilyLayer {
450 input_ln_w,
451 qkv_proj,
452 q_norm_w,
453 k_norm_w,
454 o_proj,
455 post_ln_w,
456 gate_up_proj,
457 down_proj,
458 });
459 }
460 Ok(layers)
461}
462
463fn load_llama_family_lm_head<B: MoeLlmBackend>(
464 cfg: &LlamaFamilyConfig,
465 loader: &dyn WeightLoader<B>,
466) -> Result<Box<dyn Linear<B>>> {
467 let lm_head = if loader.has_tensor("lm_head.weight") {
475 loader.load_linear("lm_head")?
476 } else {
477 tracing::info!(
478 "LlamaFamilyModel: tied embeddings — loading model.embed_tokens.weight as lm_head"
479 );
480 let as_linear = loader.load_linear("model.embed_tokens")?;
481 if as_linear.out_features() != cfg.vocab_size || as_linear.in_features() != cfg.hidden_size
483 {
484 return Err(ferrum_types::FerrumError::model(format!(
485 "tied embed shape mismatch: got [{}, {}], expected [{}, {}]",
486 as_linear.out_features(),
487 as_linear.in_features(),
488 cfg.vocab_size,
489 cfg.hidden_size
490 )));
491 }
492 as_linear
493 };
494 Ok(lm_head)
495}
496
/// Precomputed rotary-embedding tables for a model, built by `build_rope_cache`.
///
/// NOTE(review): the table layout is defined by `build_rope_cache`, which is not
/// visible here.
pub struct RopeCache<B: QuantLlmBackend + BackendMoeFused> {
    /// Cosine table.
    pub cos: B::Buffer,
    /// Sine table.
    pub sin: B::Buffer,
}
502
/// Reusable backend buffers for the forward passes.
///
/// Sizes below are in `B::alloc` units, with `t = max_tokens`, `h = hidden_size`,
/// `im = intermediate_size`, `q_dim = num_heads * head_dim` and
/// `kv_dim = num_kv_heads * head_dim`.
pub struct LlamaFamilyScratch<B: QuantLlmBackend + BackendMoeFused> {
    /// `t * h`; allocated as `Some` by `alloc`.
    pub residual: Option<B::Buffer>,
    /// `t * h`.
    pub norm_out: B::Buffer,
    /// `t * (q_dim + 2 * kv_dim)`: fused QKV projection output.
    pub qkv_out: B::Buffer,
    // Single-token buffers: Q-sized ones hold `q_dim`, K/V-sized ones hold `kv_dim`.
    pub q_single: B::Buffer,
    pub k_single: B::Buffer,
    pub v_single: B::Buffer,
    pub q_head_major_single: B::Buffer,
    pub k_head_major_single: B::Buffer,
    pub v_head_major_single: B::Buffer,
    pub attn_head_major_single: B::Buffer,
    pub attn_flat_single: B::Buffer,
    /// `t * vocab_size`.
    pub batch_logits: B::Buffer,
    /// `t * q_dim`.
    pub q_buf: B::Buffer,
    /// `t * kv_dim`.
    pub k_buf: B::Buffer,
    /// `t * kv_dim`.
    pub v_buf: B::Buffer,
    /// `num_heads * t * head_dim`.
    pub q_head_major: B::Buffer,
    /// `num_kv_heads * t * head_dim`.
    pub k_head_major: B::Buffer,
    /// `num_kv_heads * t * head_dim`.
    pub v_head_major: B::Buffer,
    /// `num_heads * t * head_dim`.
    pub attn_head_major_out: B::Buffer,
    /// `t * q_dim`.
    pub attn_flat: B::Buffer,
    /// `t * h`.
    pub o_proj_out: B::Buffer,
    /// `t * 2 * im`: fused gate/up projection output.
    pub gate_up_out: B::Buffer,
    /// `t * im`.
    pub silu_out: B::Buffer,
    /// `t * h`.
    pub mlp_out: B::Buffer,
    // Paged-batch buffers: `None` until `enable_paged_batch` runs.
    /// `max_seqs * q_dim`.
    pub paged_batch_q: Option<B::Buffer>,
    /// `max_seqs * q_dim`.
    pub paged_batch_o: Option<B::Buffer>,
    /// `U32`, `max_seqs * max_blocks_per_seq`.
    pub paged_batch_block_tables: Option<B::Buffer>,
    /// `U32`, `max_seqs`.
    pub paged_batch_context_lens: Option<B::Buffer>,
    /// Dimension recorded by `enable_paged_batch` (0 until then).
    pub paged_max_blocks_per_seq: usize,
    /// Dimension recorded by `enable_paged_batch` (0 until then).
    pub paged_max_seqs: usize,
    // Per-token `U32` metadata, each `max(t, 1)` slots.
    pub batch_positions: B::Buffer,
    pub batch_tokens: B::Buffer,
    pub batch_kv_lens_pre: B::Buffer,
    pub batch_kv_lens_post: B::Buffer,
    /// `t * q_dim`.
    pub q_normed_batched: B::Buffer,
    /// `t * kv_dim`.
    pub k_normed_batched: B::Buffer,
    /// `t * kv_dim`.
    pub v_normed_batched: B::Buffer,

    // Unified-path buffers: `None` until `ensure_unified_scratch` runs.
    /// Row capacity of the `unified_*` row buffers (0 until first allocation).
    pub unified_capacity: usize,
    pub unified_residual: Option<B::Buffer>,
    pub unified_norm_out: Option<B::Buffer>,
    pub unified_qkv_out: Option<B::Buffer>,
    pub unified_packed_q: Option<B::Buffer>,
    pub unified_attn_out: Option<B::Buffer>,
    pub unified_o_proj_out: Option<B::Buffer>,
    pub unified_gate_up_out: Option<B::Buffer>,
    pub unified_silu_out: Option<B::Buffer>,
    pub unified_mlp_out: Option<B::Buffer>,
    /// `U32`, `max_seqs + 1`.
    pub unified_cu_seqlens_q: Option<B::Buffer>,
    /// `U32`, `max_seqs`.
    pub unified_pos_offsets: Option<B::Buffer>,
    /// `U32`, `max_seqs * max_blocks_per_seq`.
    pub unified_block_tables: Option<B::Buffer>,
    /// `max_seqs * h`.
    pub unified_packed_normed: Option<B::Buffer>,
    /// `max_seqs * vocab_size`.
    pub unified_packed_logits: Option<B::Buffer>,
    /// `h`.
    pub last_hidden: B::Buffer,
    /// `h`.
    pub last_normed: B::Buffer,
    /// `vocab_size`.
    pub logits: B::Buffer,
    /// Token capacity `t` these buffers were allocated for.
    pub max_tokens: usize,
}
644
impl<B: QuantLlmBackend + BackendMoeFused> LlamaFamilyScratch<B> {
    /// Allocates every fixed-size scratch buffer for up to `max_tokens` tokens.
    ///
    /// The `paged_batch_*` and `unified_*` buffers start as `None`. They are created
    /// later by `enable_paged_batch` and `ensure_unified_scratch`. The `U32`
    /// per-token metadata buffers get at least one slot even when `max_tokens` is 0.
    fn alloc(cfg: &LlamaFamilyConfig, max_tokens: usize) -> Self {
        let h = cfg.hidden_size;
        let im = cfg.intermediate_size;
        let q_dim = cfg.num_heads * cfg.head_dim;
        let kv_dim = cfg.num_kv_heads * cfg.head_dim;
        // Width of the fused projection output: Q plus K plus V.
        let qkv_dim = q_dim + 2 * kv_dim;
        let t = max_tokens;
        // NOTE(review): field initialisers are evaluated in the order written, so this
        // order is also the buffer allocation order.
        Self {
            residual: Some(B::alloc(t * h)),
            norm_out: B::alloc(t * h),
            qkv_out: B::alloc(t * qkv_dim),
            q_buf: B::alloc(t * q_dim),
            k_buf: B::alloc(t * kv_dim),
            v_buf: B::alloc(t * kv_dim),
            q_head_major: B::alloc(cfg.num_heads * t * cfg.head_dim),
            k_head_major: B::alloc(cfg.num_kv_heads * t * cfg.head_dim),
            v_head_major: B::alloc(cfg.num_kv_heads * t * cfg.head_dim),
            attn_head_major_out: B::alloc(cfg.num_heads * t * cfg.head_dim),
            attn_flat: B::alloc(t * q_dim),
            o_proj_out: B::alloc(t * h),
            gate_up_out: B::alloc(t * 2 * im),
            silu_out: B::alloc(t * im),
            mlp_out: B::alloc(t * h),
            last_hidden: B::alloc(h),
            last_normed: B::alloc(h),
            logits: B::alloc(cfg.vocab_size),
            q_single: B::alloc(q_dim),
            k_single: B::alloc(kv_dim),
            v_single: B::alloc(kv_dim),
            q_head_major_single: B::alloc(q_dim),
            k_head_major_single: B::alloc(kv_dim),
            v_head_major_single: B::alloc(kv_dim),
            attn_head_major_single: B::alloc(q_dim),
            attn_flat_single: B::alloc(q_dim),
            batch_logits: B::alloc(t * cfg.vocab_size),
            paged_batch_q: None,
            paged_batch_o: None,
            paged_batch_block_tables: None,
            paged_batch_context_lens: None,
            paged_max_blocks_per_seq: 0,
            paged_max_seqs: 0,
            batch_positions: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
            batch_tokens: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
            batch_kv_lens_pre: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
            batch_kv_lens_post: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
            q_normed_batched: B::alloc(t * q_dim),
            k_normed_batched: B::alloc(t * kv_dim),
            v_normed_batched: B::alloc(t * kv_dim),
            unified_capacity: 0,
            unified_residual: None,
            unified_norm_out: None,
            unified_qkv_out: None,
            unified_packed_q: None,
            unified_attn_out: None,
            unified_o_proj_out: None,
            unified_gate_up_out: None,
            unified_silu_out: None,
            unified_mlp_out: None,
            unified_cu_seqlens_q: None,
            unified_pos_offsets: None,
            unified_block_tables: None,
            unified_packed_normed: None,
            unified_packed_logits: None,
            max_tokens: t,
        }
    }

    /// Ensures the `unified_*` row buffers can hold at least `m_total` rows.
    ///
    /// This is a no-op when `m_total` already fits and the buffers exist. Otherwise the
    /// row buffers are reallocated at `max(m_total, unified_capacity, 1)` rows, so the
    /// capacity never shrinks.
    ///
    /// NOTE(review): the per-sequence buffers are only allocated on the first call:
    /// `unified_cu_seqlens_q`, `unified_pos_offsets`, `unified_block_tables`,
    /// `unified_packed_normed` and `unified_packed_logits`. They are sized from that
    /// call's `max_seqs` / `max_blocks_per_seq`, and later calls with larger values do
    /// not resize them. Confirm that callers always pass the same dimensions.
    pub(crate) fn ensure_unified_scratch(
        &mut self,
        cfg: &LlamaFamilyConfig,
        m_total: usize,
        max_seqs: usize,
        max_blocks_per_seq: usize,
    ) {
        if m_total <= self.unified_capacity
            && self.unified_residual.is_some()
            && self.unified_cu_seqlens_q.is_some()
        {
            return;
        }
        // Grow-only: never drop below the capacity already allocated.
        let cap = m_total.max(self.unified_capacity).max(1);
        let h = cfg.hidden_size;
        let q_dim = cfg.num_heads * cfg.head_dim;
        let kv_dim = cfg.num_kv_heads * cfg.head_dim;
        let qkv_dim = q_dim + 2 * kv_dim;
        let im = cfg.intermediate_size;
        let v = cfg.vocab_size;
        self.unified_residual = Some(B::alloc(cap * h));
        self.unified_norm_out = Some(B::alloc(cap * h));
        self.unified_qkv_out = Some(B::alloc(cap * qkv_dim));
        self.unified_packed_q = Some(B::alloc(cap * q_dim));
        self.unified_attn_out = Some(B::alloc(cap * q_dim));
        self.unified_o_proj_out = Some(B::alloc(cap * h));
        self.unified_gate_up_out = Some(B::alloc(cap * 2 * im));
        self.unified_silu_out = Some(B::alloc(cap * im));
        self.unified_mlp_out = Some(B::alloc(cap * h));
        // Per-sequence metadata: allocated once, then kept across row-capacity growth.
        if self.unified_cu_seqlens_q.is_none() {
            self.unified_cu_seqlens_q = Some(B::alloc_typed(
                ferrum_kernels::backend::Dtype::U32,
                max_seqs + 1,
            ));
            self.unified_pos_offsets = Some(B::alloc_typed(
                ferrum_kernels::backend::Dtype::U32,
                max_seqs,
            ));
            self.unified_block_tables = Some(B::alloc_typed(
                ferrum_kernels::backend::Dtype::U32,
                max_seqs * max_blocks_per_seq,
            ));
            self.unified_packed_normed = Some(B::alloc(max_seqs * h));
            self.unified_packed_logits = Some(B::alloc(max_seqs * v));
        }
        self.unified_capacity = cap;
    }

    /// Allocates the `paged_batch_*` buffers and records both dimensions.
    ///
    /// The buffers cover `max_seqs` sequences of up to `max_blocks_per_seq` blocks.
    /// This is one-shot: it returns immediately once `paged_batch_q` exists, even if
    /// called again with different dimensions.
    fn enable_paged_batch(
        &mut self,
        cfg: &LlamaFamilyConfig,
        max_seqs: usize,
        max_blocks_per_seq: usize,
    ) {
        if self.paged_batch_q.is_some() {
            return;
        }
        let q_dim = cfg.num_heads * cfg.head_dim;
        self.paged_batch_q = Some(B::alloc(max_seqs * q_dim));
        self.paged_batch_o = Some(B::alloc(max_seqs * q_dim));
        self.paged_batch_block_tables = Some(B::alloc_typed(
            ferrum_kernels::backend::Dtype::U32,
            max_seqs * max_blocks_per_seq,
        ));
        self.paged_batch_context_lens = Some(B::alloc_typed(
            ferrum_kernels::backend::Dtype::U32,
            max_seqs,
        ));
        self.paged_max_blocks_per_seq = max_blocks_per_seq;
        self.paged_max_seqs = max_seqs;
    }
}
795
/// A Llama-family decoder: either the whole model or one layer stage of it.
///
/// The model holds weights, scratch buffers, per-request KV caches and runtime
/// bookkeeping.
pub struct LlamaFamilyModel<B: MoeLlmBackend, K: KvLayer<B> = KvFp16> {
    pub cfg: LlamaFamilyConfig,
    /// `cfg.to_runtime()`, computed at construction.
    pub runtime_cfg: LlmRuntimeConfig,

    /// `B::supports_varlen_qkv()`, captured at construction.
    pub(crate) supports_varlen_qkv: bool,
    /// `B::supports_llama_family_batched_decode()`, captured at construction.
    pub(crate) supports_batched_decode: bool,
    /// Runtime knobs read once at construction from the runtime-config snapshot.
    pub(crate) runtime_env: LlamaFamilyRuntimeEnv,
    /// Batched-forward settings, read once at construction.
    pub(crate) batched_cfg: super::llama_family_forward_batched::LlamaBatchedRuntimeConfig,

    /// Token embedding; `Some` only when the stage loads it.
    pub embed: Option<B::Buffer>,
    /// First global layer index held in `layers`.
    pub layer_source_start: usize,
    /// One past the last global layer index held in `layers`.
    pub layer_source_end: usize,
    /// Local layers; `layers[i]` is global layer `layer_source_start + i`.
    pub layers: Vec<LlamaFamilyLayer<B>>,
    /// `model.norm.weight`; always loaded, whatever the stage.
    pub final_norm_w: B::Buffer,
    /// Output projection; `Some` only when the stage loads it.
    pub lm_head: Option<Box<dyn Linear<B>>>,

    pub rope: RopeCache<B>,
    /// Replaced wholesale by `ensure_scratch` when more tokens are needed.
    pub scratch: LlamaFamilyScratch<B>,

    /// Per cache id: one `K::Layer` per local layer.
    pub kv_caches: HashMap<String, Vec<K::Layer>>,
    /// Recycled per-layer cache sets; `ensure_kv` pops from here before allocating.
    kv_free_pool: Vec<Vec<K::Layer>>,

    /// Shared paged-KV pools, one buffer pair per local layer, created on first paged
    /// `ensure_kv`. NOTE(review): presumably (key, value) order — confirm.
    pub paged_pools: Option<Vec<(B::Buffer, B::Buffer)>>,
    /// Block allocator for the paged pools; also stores prefix-cache block hashes.
    pub paged_block_alloc: Option<std::sync::Mutex<crate::common::paged_pool::BlockAllocator>>,
    /// `(max_seqs, max_blocks_per_seq)` recorded by paged `ensure_kv`.
    ///
    /// `ensure_scratch` uses it to re-enable the paged-batch buffers after a realloc.
    pub paged_dims: Option<(usize, usize)>,

    // Graph-capture bookkeeping; `ensure_scratch` resets all of it when it reallocates.
    pub(crate) graph_warmup: usize,
    pub(crate) graph_capture_failed: bool,
    pub(crate) batched_graph_warmup: usize,
    pub(crate) batched_graph_failed: bool,
    pub(crate) batched_graph_keys_seen: std::collections::HashSet<u64>,
    pub(crate) batched_pointers_for: Option<Vec<String>>,
    pub(crate) unified_graph_warmup: usize,
    pub(crate) unified_graph_failed: bool,
    pub(crate) unified_graph_keys_seen: std::collections::HashSet<u64>,

    // Prefix-cache statistics, updated by `record_prefix_cache_probe`.
    prefix_cache_hits: u64,
    prefix_cache_misses: u64,
    prefix_cache_saved_prefill_tokens: u64,

    // LoRA state. NOTE(review): maintained outside this chunk — confirm semantics there.
    lora_adapters: HashMap<String, RuntimeLoraAdapter<B>>,
    lora_cache_adapters: HashMap<String, String>,
    lora_projection_applications: u64,
}
918
919impl<B: MoeLlmBackend, K: KvLayer<B>> LlamaFamilyModel<B, K> {
920 pub fn new(cfg: LlamaFamilyConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
925 let num_layers = cfg.num_layers;
926 Self::new_with_stage_config(cfg, loader, LlamaFamilyLayerStageConfig::full(num_layers))
927 }
928
929 pub fn new_backbone_only(cfg: LlamaFamilyConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
941 let num_layers = cfg.num_layers;
942 Self::new_with_stage_config(
943 cfg,
944 loader,
945 LlamaFamilyLayerStageConfig::backbone(num_layers),
946 )
947 }
948
949 pub fn new_layer_stage(
953 cfg: LlamaFamilyConfig,
954 loader: &dyn WeightLoader<B>,
955 stage: LlamaFamilyLayerStageConfig,
956 ) -> Result<Self> {
957 Self::new_with_stage_config(cfg, loader, stage)
958 }
959
960 fn new_with_stage_config(
961 cfg: LlamaFamilyConfig,
962 loader: &dyn WeightLoader<B>,
963 stage: LlamaFamilyLayerStageConfig,
964 ) -> Result<Self> {
965 if stage.source_layers.is_empty() {
966 return Err(ferrum_types::FerrumError::model(
967 "llama layer stage must include at least one source layer",
968 ));
969 }
970 if stage.source_layers.end > cfg.num_layers {
971 return Err(ferrum_types::FerrumError::model(format!(
972 "llama layer stage range {}..{} is outside model layer count {}",
973 stage.source_layers.start, stage.source_layers.end, cfg.num_layers
974 )));
975 }
976
977 {
982 let mut ctx = B::new_context();
983 B::reset_all_graphs(&mut ctx);
984 }
985 let rope = build_rope_cache::<B>(&cfg);
986 let scratch = LlamaFamilyScratch::alloc(&cfg, 1);
987 let embed = if stage.load_embedding {
988 Some(loader.load_tensor("model.embed_tokens.weight")?)
989 } else {
990 None
991 };
992 let layers = load_llama_family_layers(&cfg, loader, stage.source_layers.clone())?;
993
994 let lm_head = if stage.load_lm_head {
995 Some(load_llama_family_lm_head(&cfg, loader)?)
996 } else {
997 None
998 };
999 let final_norm_w = loader.load_tensor("model.norm.weight")?;
1000
1001 let layer_source_start = stage.source_layers.start;
1002 let layer_source_end = stage.source_layers.end;
1003 let runtime_cfg = cfg.to_runtime();
1004 let supports_varlen_qkv = B::supports_varlen_qkv();
1007 let supports_batched_decode = B::supports_llama_family_batched_decode();
1008 let runtime_env = LlamaFamilyRuntimeEnv::from_env();
1009 let batched_cfg =
1010 super::llama_family_forward_batched::LlamaBatchedRuntimeConfig::from_env();
1011 Ok(Self {
1012 cfg,
1013 runtime_cfg,
1014 supports_varlen_qkv,
1015 supports_batched_decode,
1016 runtime_env,
1017 batched_cfg,
1018 embed,
1019 layer_source_start,
1020 layer_source_end,
1021 layers,
1022 final_norm_w,
1023 lm_head,
1024 rope,
1025 scratch,
1026 kv_caches: HashMap::new(),
1027 kv_free_pool: Vec::new(),
1028 paged_pools: None,
1029 paged_block_alloc: None,
1030 paged_dims: None,
1031 graph_warmup: 0,
1032 graph_capture_failed: false,
1033 batched_graph_warmup: 0,
1034 batched_graph_failed: false,
1035 batched_graph_keys_seen: std::collections::HashSet::new(),
1036 batched_pointers_for: None,
1037 unified_graph_warmup: 0,
1038 unified_graph_failed: false,
1039 unified_graph_keys_seen: std::collections::HashSet::new(),
1040 prefix_cache_hits: 0,
1041 prefix_cache_misses: 0,
1042 prefix_cache_saved_prefill_tokens: 0,
1043 lora_adapters: HashMap::new(),
1044 lora_cache_adapters: HashMap::new(),
1045 lora_projection_applications: 0,
1046 })
1047 }
1048
1049 pub fn source_layer_range(&self) -> Range<usize> {
1050 self.layer_source_start..self.layer_source_end
1051 }
1052
1053 pub fn local_layer_count(&self) -> usize {
1054 self.layers.len()
1055 }
1056
1057 pub fn cache_len(&self, cache_id: &str) -> usize {
1058 self.kv_caches
1059 .get(cache_id)
1060 .and_then(|layers| layers.first())
1061 .map(K::len)
1062 .unwrap_or(0)
1063 }
1064
1065 fn local_layer_indices(&self) -> Range<usize> {
1066 0..self.local_layer_count()
1067 }
1068
1069 fn source_layer_index(&self, local_layer_index: usize) -> usize {
1070 self.layer_source_start + local_layer_index
1071 }
1072
1073 fn local_layer_index_for_source(&self, source_layer_index: usize) -> Option<usize> {
1074 if source_layer_index < self.layer_source_start
1075 || source_layer_index >= self.layer_source_end
1076 {
1077 return None;
1078 }
1079 Some(source_layer_index - self.layer_source_start)
1080 }
1081
1082 pub(crate) fn ensure_scratch(&mut self, tokens: usize) {
1084 if self.scratch.max_tokens < tokens {
1085 {
1090 let mut ctx = B::new_context();
1091 B::reset_all_graphs(&mut ctx);
1092 }
1093 self.scratch = LlamaFamilyScratch::alloc(&self.cfg, tokens);
1094 self.graph_warmup = 0;
1095 self.graph_capture_failed = false;
1096 self.batched_graph_keys_seen.clear();
1097 self.batched_graph_warmup = 0;
1098 self.batched_graph_failed = false;
1099 self.unified_graph_keys_seen.clear();
1100 self.unified_graph_warmup = 0;
1101 self.unified_graph_failed = false;
1102 if let Some((max_seqs, max_blocks_per_seq)) = self.paged_dims {
1107 self.scratch
1108 .enable_paged_batch(&self.cfg, max_seqs, max_blocks_per_seq);
1109 }
1110 }
1111 }
1112
1113 pub(crate) fn ensure_kv(&mut self, cache_id: &str) {
1117 if self.kv_caches.contains_key(cache_id) {
1118 return;
1119 }
1120 let nkv = self.cfg.num_kv_heads;
1121 let hd = self.cfg.head_dim;
1122 let model_max = self.cfg.max_seq_len;
1129 let runtime_env = &self.runtime_env;
1138 let max = runtime_env.kv_capacity_for_model(model_max);
1139
1140 let paged = runtime_env.paged_kv_enabled::<B>();
1156 const PAGED_BLOCK_SIZE: usize = 16;
1157
1158 let max_seqs = runtime_env.paged_max_seqs;
1166 let max_blocks_per_seq = max.div_ceil(PAGED_BLOCK_SIZE);
1167 let total_pool_blocks = max_seqs * max_blocks_per_seq;
1168
1169 if paged && self.paged_pools.is_none() {
1176 let mut pools = Vec::with_capacity(self.local_layer_count());
1177 for _ in self.local_layer_indices() {
1178 let pool_floats = total_pool_blocks * nkv * PAGED_BLOCK_SIZE * hd;
1179 pools.push((B::alloc(pool_floats), B::alloc(pool_floats)));
1180 }
1181 self.paged_pools = Some(pools);
1182 self.paged_block_alloc = Some(std::sync::Mutex::new(
1183 crate::common::paged_pool::BlockAllocator::new(total_pool_blocks as u32),
1184 ));
1185 }
1186 if paged {
1192 self.scratch
1193 .enable_paged_batch(&self.cfg, max_seqs, max_blocks_per_seq);
1194 self.paged_dims = Some((max_seqs, max_blocks_per_seq));
1197 }
1198
1199 let mut caches = self.kv_free_pool.pop().unwrap_or_else(|| {
1207 self.local_layer_indices()
1208 .map(|_| {
1209 if paged {
1210 K::alloc_paged(max_blocks_per_seq, PAGED_BLOCK_SIZE, nkv, hd)
1211 } else {
1212 K::alloc_contig(max, nkv, hd)
1213 }
1214 })
1215 .collect()
1216 });
1217
1218 if paged {
1224 let alloc_arc = self
1225 .paged_block_alloc
1226 .as_ref()
1227 .expect("paged_block_alloc must be initialised when paged=true");
1228 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
1232 let block_indices = match alloc.allocate_n(max_blocks_per_seq) {
1233 Ok(idx) => idx,
1234 Err(e) => {
1235 drop(alloc);
1242 self.kv_free_pool.push(caches);
1243 eprintln!(
1244 "[ferrum] paged KV pool exhausted on ensure_kv for \
1245 cache_id={cache_id:?}: {e}. Increase \
1246 FERRUM_PAGED_MAX_SEQS (currently {max_seqs}) or \
1247 throttle concurrent requests.",
1248 );
1249 return;
1250 }
1251 };
1252 let mut padded = block_indices.clone();
1257 padded.resize(max_blocks_per_seq, 0);
1258 let mut ctx_tmp = B::new_context();
1259 for c in caches.iter_mut() {
1260 if let Some(bt) = K::block_table_mut(c) {
1261 B::write_typed::<u32>(&mut ctx_tmp, bt, &padded);
1262 }
1263 *K::paged_block_indices_mut(c) = block_indices.clone();
1264 }
1265 B::sync(&mut ctx_tmp);
1266 }
1267
1268 for c in caches.iter_mut() {
1272 K::set_len(c, 0);
1273 if let Some(cl) = K::context_lens_mut(c) {
1274 let mut ctx_tmp = B::new_context();
1275 B::write_typed::<u32>(&mut ctx_tmp, cl, &[0u32]);
1276 B::sync(&mut ctx_tmp);
1277 }
1278 }
1279 self.kv_caches.insert(cache_id.to_string(), caches);
1280 }
1281
1282 fn record_prefix_cache_probe(&mut self, saved_tokens: usize) {
1283 if saved_tokens > 0 {
1284 self.prefix_cache_hits += 1;
1285 self.prefix_cache_saved_prefill_tokens += saved_tokens as u64;
1286 } else {
1287 self.prefix_cache_misses += 1;
1288 }
1289 }
1290
1291 fn try_acquire_prefix_cache(&mut self, cache_id: &str, tokens: &[u32]) -> usize {
1292 let Some(alloc_arc) = self.paged_block_alloc.as_ref() else {
1293 return 0;
1294 };
1295 let caches = match self.kv_caches.get(cache_id) {
1296 Some(caches) => caches,
1297 None => return 0,
1298 };
1299 let block_size = caches.first().map(K::block_size).unwrap_or(0);
1300 if block_size == 0 {
1301 return 0;
1302 }
1303
1304 let token_ids: Vec<ferrum_types::TokenId> = tokens
1305 .iter()
1306 .map(|&token| ferrum_types::TokenId::new(token))
1307 .collect();
1308 let hashes = block_hash_chain(&token_ids, block_size);
1309 if hashes.is_empty() {
1310 return 0;
1311 }
1312
1313 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
1314 let mut matched = Vec::with_capacity(hashes.len());
1315 for hash in hashes {
1316 match alloc.try_acquire_by_hash(hash) {
1317 Some(block) => matched.push(block),
1318 None => break,
1319 }
1320 }
1321 if matched.is_empty() {
1322 return 0;
1323 }
1324 let n_matched = matched.len();
1325
1326 let displaced = caches
1327 .first()
1328 .map(|cache| K::paged_block_indices(cache)[..n_matched].to_vec())
1329 .unwrap_or_default();
1330 alloc.free(&displaced);
1331 drop(alloc);
1332
1333 let caches_mut = self.kv_caches.get_mut(cache_id).expect("cache present");
1334 let max_blocks = caches_mut
1335 .first()
1336 .map(|cache| K::paged_block_indices(cache).len())
1337 .unwrap_or(0);
1338 let new_len = n_matched * block_size;
1339 let mut ctx = B::new_context();
1340 for cache in caches_mut.iter_mut() {
1341 {
1342 let indices = K::paged_block_indices_mut(cache);
1343 for (idx, &block) in matched.iter().enumerate() {
1344 indices[idx] = block;
1345 }
1346 }
1347 K::set_len(cache, new_len);
1348 let padded = {
1349 let mut padded = K::paged_block_indices(cache).to_vec();
1350 padded.resize(max_blocks, 0);
1351 padded
1352 };
1353 if let Some(block_table) = K::block_table_mut(cache) {
1354 B::write_typed::<u32>(&mut ctx, block_table, &padded);
1355 }
1356 if let Some(context_lens) = K::context_lens_mut(cache) {
1357 B::write_typed::<u32>(&mut ctx, context_lens, &[new_len as u32]);
1358 }
1359 }
1360 B::sync(&mut ctx);
1361
1362 new_len
1363 }
1364
1365 fn register_prefix_cache(
1366 &mut self,
1367 cache_id: &str,
1368 all_tokens: &[u32],
1369 prior_cached_tokens: usize,
1370 ) {
1371 let Some(alloc_arc) = self.paged_block_alloc.as_ref() else {
1372 return;
1373 };
1374 let caches = match self.kv_caches.get(cache_id) {
1375 Some(caches) => caches,
1376 None => return,
1377 };
1378 let cache0 = match caches.first() {
1379 Some(cache) => cache,
1380 None => return,
1381 };
1382 let block_size = K::block_size(cache0);
1383 if block_size == 0 {
1384 return;
1385 }
1386
1387 let token_ids: Vec<ferrum_types::TokenId> = all_tokens
1388 .iter()
1389 .map(|&token| ferrum_types::TokenId::new(token))
1390 .collect();
1391 let hashes = block_hash_chain(&token_ids, block_size);
1392 if hashes.is_empty() {
1393 return;
1394 }
1395
1396 let start_block = prior_cached_tokens / block_size;
1397 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
1398 for i in start_block..hashes.len().min(K::paged_block_indices(cache0).len()) {
1399 let block_end_token = (i + 1) * block_size;
1400 if block_end_token > K::len(cache0) {
1401 break;
1402 }
1403 alloc.register_block_hash(K::paged_block_indices(cache0)[i], hashes[i]);
1404 }
1405 }
1406
1407 fn prefix_cache_snapshot_json(&self) -> serde_json::Value {
1408 let (entries, block_size) = self
1409 .paged_block_alloc
1410 .as_ref()
1411 .and_then(|alloc| {
1412 let alloc = alloc.lock().ok()?;
1413 let block_size = self
1414 .kv_caches
1415 .values()
1416 .find_map(|layers| layers.first().map(K::block_size))
1417 .unwrap_or(16);
1418 Some((alloc.hash_table_size() as u64, block_size))
1419 })
1420 .unwrap_or((0, 16));
1421 let bytes_per_entry = (block_size
1422 * self.local_layer_count()
1423 * self.cfg.num_kv_heads
1424 * self.cfg.head_dim
1425 * K::BYTES_PER_ELEM
1426 * 2) as u64;
1427 serde_json::json!({
1428 "position": "real-kv-reuse",
1429 "source": "llama-family-paged-block-prefix-cache",
1430 "enabled": self.runtime_env.prefix_cache,
1431 "hits": self.prefix_cache_hits,
1432 "misses": self.prefix_cache_misses,
1433 "evictions": 0u64,
1434 "saved_prefill_tokens": self.prefix_cache_saved_prefill_tokens,
1435 "entries": entries,
1436 "bytes": entries.saturating_mul(bytes_per_entry),
1437 "block_size": block_size,
1438 "kv_dtype": K::NAME,
1439 })
1440 }
1441
1442 fn lora_projection_shape(
1443 &self,
1444 layer_index: usize,
1445 target_module: &str,
1446 ) -> Option<(usize, usize)> {
1447 let local_layer_index = self.local_layer_index_for_source(layer_index)?;
1448 let layer = self.layers.get(local_layer_index)?;
1449 match target_module {
1450 "qkv_proj" => Some((layer.qkv_proj.in_features(), layer.qkv_proj.out_features())),
1451 "o_proj" => Some((layer.o_proj.in_features(), layer.o_proj.out_features())),
1452 "gate_up_proj" => Some((
1453 layer.gate_up_proj.in_features(),
1454 layer.gate_up_proj.out_features(),
1455 )),
1456 "down_proj" => Some((
1457 layer.down_proj.in_features(),
1458 layer.down_proj.out_features(),
1459 )),
1460 _ => None,
1461 }
1462 }
1463
1464 fn validate_lora_adapter(&self, adapter: &RuntimeLoraAdapter<B>) -> Result<()> {
1465 if adapter.linears.is_empty() {
1466 return Err(ferrum_types::FerrumError::config(format!(
1467 "LoRA adapter {} has no runtime tensors",
1468 adapter.name
1469 )));
1470 }
1471 for linear in &adapter.linears {
1472 let layer_index = linear.layer_index.ok_or_else(|| {
1473 ferrum_types::FerrumError::config(format!(
1474 "LoRA tensor for target {} must include model.layers.<N> in its tensor name",
1475 linear.target_module
1476 ))
1477 })?;
1478 if layer_index < self.cfg.num_layers
1479 && !self.source_layer_range().contains(&layer_index)
1480 {
1481 continue;
1482 }
1483 let Some((expected_in, expected_out)) =
1484 self.lora_projection_shape(layer_index, &linear.target_module)
1485 else {
1486 return Err(ferrum_types::FerrumError::unsupported(format!(
1487 "LoRA target {} is not supported by Llama-family runtime; supported targets: qkv_proj, o_proj, gate_up_proj, down_proj",
1488 linear.target_module
1489 )));
1490 };
1491 if linear.in_features != expected_in || linear.out_features != expected_out {
1492 return Err(ferrum_types::FerrumError::config(format!(
1493 "LoRA tensor shape mismatch for layer {} target {}: got out={} in={}, expected out={} in={}",
1494 layer_index,
1495 linear.target_module,
1496 linear.out_features,
1497 linear.in_features,
1498 expected_out,
1499 expected_in
1500 )));
1501 }
1502 }
1503 Ok(())
1504 }
1505
1506 fn ensure_lora_adapter_loaded(&mut self, adapter: ActiveLoraAdapter) -> Result<()> {
1507 if self.lora_adapters.contains_key(&adapter.name) {
1508 return Ok(());
1509 }
1510 let runtime = load_runtime_lora_adapter::<B>(&adapter)?;
1511 self.validate_lora_adapter(&runtime)?;
1512 self.lora_adapters.insert(adapter.name.clone(), runtime);
1513 Ok(())
1514 }
1515
1516 fn active_lora_adapter_for_cache(&self, cache_id: &str) -> Option<&RuntimeLoraAdapter<B>> {
1517 let adapter_name = self.lora_cache_adapters.get(cache_id)?;
1518 self.lora_adapters.get(adapter_name)
1519 }
1520
1521 fn active_lora_adapter_ptr_for_cache(
1522 &self,
1523 cache_id: &str,
1524 ) -> Option<*const RuntimeLoraAdapter<B>> {
1525 self.active_lora_adapter_for_cache(cache_id)
1526 .map(|adapter| adapter as *const RuntimeLoraAdapter<B>)
1527 }
1528
1529 fn lora_metrics_snapshot_json(&self) -> serde_json::Value {
1530 serde_json::json!({
1531 "enabled": !self.lora_adapters.is_empty(),
1532 "adapter_count": self.lora_adapters.len() as u64,
1533 "active_cache_bindings": self.lora_cache_adapters.len() as u64,
1534 "projection_applications": self.lora_projection_applications,
1535 "position": "real-inference",
1536 "source": "llama-family-runtime-lora",
1537 })
1538 }
1539
1540 #[allow(clippy::too_many_arguments)]
1545 pub(crate) fn forward_layer(
1546 &mut self,
1547 ctx: &mut B::Context,
1548 li: usize,
1549 cache_id: &str,
1550 residual: &mut B::Buffer,
1551 pos_offset: usize,
1552 tokens: usize,
1553 ) {
1554 let source_li = self.source_layer_index(li);
1555 let layer = &self.layers[li];
1556 let cfg = &self.cfg;
1557 let h = cfg.hidden_size;
1558 let nh = cfg.num_heads;
1559 let nkv = cfg.num_kv_heads;
1560 let hd = cfg.head_dim;
1561 let im = cfg.intermediate_size;
1562 let eps = cfg.rms_norm_eps;
1563 let q_dim = nh * hd;
1564 let kv_dim = nkv * hd;
1565
1566 let _t0 = if self.runtime_env.decode_op_profile {
1568 B::sync(ctx);
1569 Some(std::time::Instant::now())
1570 } else {
1571 None
1572 };
1573 B::rms_norm(
1574 ctx,
1575 residual,
1576 &layer.input_ln_w,
1577 eps,
1578 &mut self.scratch.norm_out,
1579 tokens,
1580 h,
1581 );
1582 if let Some(t0) = _t0 {
1583 B::sync(ctx);
1584 NORM_TIME_US.fetch_add(
1585 t0.elapsed().as_micros() as u64,
1586 std::sync::atomic::Ordering::Relaxed,
1587 );
1588 NORM_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1589 }
1590
1591 let _t0 = if self.runtime_env.decode_op_profile {
1593 B::sync(ctx);
1594 Some(std::time::Instant::now())
1595 } else {
1596 None
1597 };
1598 layer.qkv_proj.forward(
1599 ctx,
1600 &self.scratch.norm_out,
1601 &mut self.scratch.qkv_out,
1602 tokens,
1603 );
1604 if let Some(adapter) = self.active_lora_adapter_ptr_for_cache(cache_id) {
1605 let applied = unsafe { &*adapter }
1608 .apply_projection(
1609 ctx,
1610 source_li,
1611 "qkv_proj",
1612 &self.scratch.norm_out,
1613 &mut self.scratch.qkv_out,
1614 tokens,
1615 )
1616 .expect("validated LoRA qkv_proj");
1617 self.lora_projection_applications += applied as u64;
1618 }
1619 if let Some(t0) = _t0 {
1620 B::sync(ctx);
1621 MATMUL_TIME_US.fetch_add(
1622 t0.elapsed().as_micros() as u64,
1623 std::sync::atomic::Ordering::Relaxed,
1624 );
1625 MATMUL_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1626 }
1627
1628 let qk_mode: i32 = if cfg.has_qk_norm {
1643 1
1644 } else if cfg.rope_interleaved {
1645 3
1646 } else {
1647 2
1648 };
1649 let dummy = &layer.input_ln_w;
1650 let q_norm_w = layer.q_norm_w.as_ref().unwrap_or(dummy);
1651 let k_norm_w = layer.k_norm_w.as_ref().unwrap_or(dummy);
1652
1653 let paged_pool_ptr: Option<(*mut B::Buffer, *mut B::Buffer)> =
1664 if let Some(pools) = self.paged_pools.as_mut() {
1665 let pool = &mut pools[li];
1666 Some((&mut pool.0 as *mut _, &mut pool.1 as *mut _))
1667 } else {
1668 None
1669 };
1670 let caches = self
1671 .kv_caches
1672 .get_mut(cache_id)
1673 .expect("ensure_kv must be called before forward_layer");
1674 let cache_len_before = K::len(&caches[li]);
1677 let cache_capacity = K::capacity(&caches[li]);
1678 let cache_block_size = K::block_size(&caches[li]);
1679
1680 if cache_len_before + tokens > cache_capacity {
1686 panic!(
1687 "KV cache overflow on source layer {source_li} (local layer {li}): would write tokens [{cache_len_before}..{}) but capacity is {cache_capacity} (cache_id={cache_id:?}). Increase FERRUM_KV_CAPACITY or call /clear in the REPL.",
1688 cache_len_before + tokens
1689 );
1690 }
1691
1692 if cache_block_size > 0 {
1697 let (pool_k_ptr, pool_v_ptr) =
1698 paged_pool_ptr.expect("paged_pools must be allocated when block_size > 0");
1699 let pool_k = unsafe { &mut *pool_k_ptr };
1702 let pool_v = unsafe { &mut *pool_v_ptr };
1703
1704 K::paged_write(
1705 ctx,
1706 &mut caches[li],
1707 &self.scratch.qkv_out,
1708 q_norm_w,
1709 k_norm_w,
1710 &self.rope.cos,
1711 &self.rope.sin,
1712 &mut self.scratch.q_head_major,
1713 &mut self.scratch.k_head_major,
1714 &mut self.scratch.v_head_major,
1715 pool_k,
1716 pool_v,
1717 tokens,
1718 nh,
1719 nkv,
1720 hd,
1721 pos_offset,
1722 eps,
1723 qk_mode,
1724 )
1725 .expect("K::paged_write");
1726
1727 let new_len = cache_len_before + tokens;
1728 K::set_len(&mut caches[li], new_len);
1729
1730 let pool_k_imm = unsafe { &*pool_k_ptr };
1731 let pool_v_imm = unsafe { &*pool_v_ptr };
1732 K::paged_decode_attention(
1733 ctx,
1734 &mut caches[li],
1735 &self.scratch.q_head_major,
1736 pool_k_imm,
1737 pool_v_imm,
1738 &mut self.scratch.attn_head_major_out,
1739 nh,
1740 nkv,
1741 hd,
1742 new_len,
1743 tokens,
1744 )
1745 .expect("K::paged_decode_attention");
1746
1747 return self.forward_layer_post_attn(ctx, li, cache_id, residual, tokens);
1748 }
1749
1750 let _qkr_t0 = if self.runtime_env.decode_op_profile {
1753 B::sync(ctx);
1754 Some(std::time::Instant::now())
1755 } else {
1756 None
1757 };
1758 K::contig_write(
1759 ctx,
1760 &mut caches[li],
1761 &self.scratch.qkv_out,
1762 q_norm_w,
1763 k_norm_w,
1764 &self.rope.cos,
1765 &self.rope.sin,
1766 &mut self.scratch.q_head_major,
1767 &mut self.scratch.k_head_major,
1768 &mut self.scratch.v_head_major,
1769 &mut self.scratch.q_buf,
1770 &mut self.scratch.k_buf,
1771 &mut self.scratch.v_buf,
1772 tokens,
1773 nh,
1774 nkv,
1775 hd,
1776 pos_offset,
1777 eps,
1778 qk_mode,
1779 )
1780 .expect("K::contig_write");
1781 if let Some(t0) = _qkr_t0 {
1782 B::sync(ctx);
1783 QKR_TIME_US.fetch_add(
1784 t0.elapsed().as_micros() as u64,
1785 std::sync::atomic::Ordering::Relaxed,
1786 );
1787 QKR_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1788 }
1789 let new_len = cache_len_before + tokens;
1790 K::set_len(&mut caches[li], new_len);
1791 let kv_stride = cache_capacity;
1792
1793 let _attn_t0 = if self.runtime_env.decode_op_profile {
1794 B::sync(ctx);
1795 Some(std::time::Instant::now())
1796 } else {
1797 None
1798 };
1799 let attn_cfg = ferrum_kernels::backend::AttnConfig {
1800 num_heads: nh,
1801 num_kv_heads: nkv,
1802 head_dim: hd,
1803 causal: true,
1804 scale: 1.0 / (hd as f32).sqrt(),
1805 kv_seq_stride: kv_stride,
1806 sliding_window: cfg.sliding_window,
1807 };
1808 K::contig_decode_attention(
1809 ctx,
1810 &caches[li],
1811 &self.scratch.q_head_major,
1812 &mut self.scratch.attn_head_major_out,
1813 attn_cfg,
1814 tokens,
1815 pos_offset,
1816 )
1817 .expect("K::contig_decode_attention");
1818 let _ = q_dim;
1819 let _ = kv_dim;
1820 let _ = dummy;
1821 if let Some(t0) = _attn_t0 {
1822 B::sync(ctx);
1823 ATTN_TIME_US.fetch_add(
1824 t0.elapsed().as_micros() as u64,
1825 std::sync::atomic::Ordering::Relaxed,
1826 );
1827 ATTN_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1828 }
1829
1830 self.forward_layer_post_attn(ctx, li, cache_id, residual, tokens);
1831 }
1832
1833 pub(crate) fn forward_layer_post_attn(
1838 &mut self,
1839 ctx: &mut B::Context,
1840 li: usize,
1841 cache_id: &str,
1842 residual: &mut B::Buffer,
1843 tokens: usize,
1844 ) {
1845 let source_li = self.source_layer_index(li);
1846 let layer = &self.layers[li];
1847 let cfg = &self.cfg;
1848 let h = cfg.hidden_size;
1849 let nh = cfg.num_heads;
1850 let hd = cfg.head_dim;
1851 let im = cfg.intermediate_size;
1852 let eps = cfg.rms_norm_eps;
1853
1854 let attn_token_major = if tokens == 1 {
1856 &self.scratch.attn_head_major_out
1857 } else {
1858 B::transpose_head_to_token(
1859 ctx,
1860 &self.scratch.attn_head_major_out,
1861 &mut self.scratch.attn_flat,
1862 tokens,
1863 nh,
1864 hd,
1865 );
1866 &self.scratch.attn_flat
1867 };
1868
1869 layer
1871 .o_proj
1872 .forward(ctx, attn_token_major, &mut self.scratch.o_proj_out, tokens);
1873 if let Some(adapter) = self.active_lora_adapter_ptr_for_cache(cache_id) {
1874 let applied = unsafe { &*adapter }
1876 .apply_projection(
1877 ctx,
1878 source_li,
1879 "o_proj",
1880 attn_token_major,
1881 &mut self.scratch.o_proj_out,
1882 tokens,
1883 )
1884 .expect("validated LoRA o_proj");
1885 self.lora_projection_applications += applied as u64;
1886 }
1887
1888 B::fused_add_rms_norm(
1890 ctx,
1891 residual,
1892 &self.scratch.o_proj_out,
1893 &layer.post_ln_w,
1894 eps,
1895 &mut self.scratch.norm_out,
1896 tokens,
1897 h,
1898 );
1899
1900 layer.gate_up_proj.forward(
1902 ctx,
1903 &self.scratch.norm_out,
1904 &mut self.scratch.gate_up_out,
1905 tokens,
1906 );
1907 if let Some(adapter) = self.active_lora_adapter_ptr_for_cache(cache_id) {
1908 let applied = unsafe { &*adapter }
1910 .apply_projection(
1911 ctx,
1912 source_li,
1913 "gate_up_proj",
1914 &self.scratch.norm_out,
1915 &mut self.scratch.gate_up_out,
1916 tokens,
1917 )
1918 .expect("validated LoRA gate_up_proj");
1919 self.lora_projection_applications += applied as u64;
1920 }
1921
1922 B::fused_silu_mul_split(
1924 ctx,
1925 &self.scratch.gate_up_out,
1926 &mut self.scratch.silu_out,
1927 tokens,
1928 im,
1929 );
1930
1931 layer.down_proj.forward(
1933 ctx,
1934 &self.scratch.silu_out,
1935 &mut self.scratch.mlp_out,
1936 tokens,
1937 );
1938 if let Some(adapter) = self.active_lora_adapter_ptr_for_cache(cache_id) {
1939 let applied = unsafe { &*adapter }
1941 .apply_projection(
1942 ctx,
1943 source_li,
1944 "down_proj",
1945 &self.scratch.silu_out,
1946 &mut self.scratch.mlp_out,
1947 tokens,
1948 )
1949 .expect("validated LoRA down_proj");
1950 self.lora_projection_applications += applied as u64;
1951 }
1952
1953 B::add_inplace(ctx, residual, &self.scratch.mlp_out, tokens * h);
1955 }
1956
1957 pub fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
1969 let seq_len = tokens.len();
1970 assert!(seq_len > 0, "forward_verify called with empty tokens");
1971 self.ensure_scratch(seq_len);
1972 self.ensure_kv(cache_id);
1973
1974 let h = self.cfg.hidden_size;
1975 let vocab = self.cfg.vocab_size;
1976
1977 let pos_offset = self
1978 .kv_caches
1979 .get(cache_id)
1980 .and_then(|layers| layers.first())
1981 .map(|c| K::len(c))
1982 .unwrap_or(0);
1983
1984 let mut ctx = B::new_context();
1985 let mut residual = self
1986 .scratch
1987 .residual
1988 .take()
1989 .expect("scratch residual missing (previous call didn't restore)");
1990
1991 let embed = self
1992 .embed
1993 .as_ref()
1994 .expect("forward_verify called on backbone-only model (no embed)");
1995 B::embedding_lookup(&mut ctx, embed, tokens, &mut residual, h);
1996
1997 for li in self.local_layer_indices() {
1998 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
1999 }
2000
2001 B::rms_norm(
2004 &mut ctx,
2005 &residual,
2006 &self.final_norm_w,
2007 self.cfg.rms_norm_eps,
2008 &mut self.scratch.norm_out,
2009 seq_len,
2010 h,
2011 );
2012
2013 let lm_head = self
2017 .lm_head
2018 .as_ref()
2019 .expect("forward_verify called on backbone-only model (no lm_head)");
2020 lm_head.forward(
2021 &mut ctx,
2022 &self.scratch.norm_out,
2023 &mut self.scratch.batch_logits,
2024 seq_len,
2025 );
2026
2027 B::sync(&mut ctx);
2028 self.scratch.residual = Some(residual);
2029 B::to_vec(&self.scratch.batch_logits, seq_len * vocab)
2030 }
2031
2032 pub fn prefill_internal(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2040 assert!(!tokens.is_empty(), "prefill called with empty token list");
2041 self.ensure_kv(cache_id);
2042
2043 let cache_len_before = self
2044 .kv_caches
2045 .get(cache_id)
2046 .and_then(|layers| layers.first())
2047 .map(K::len)
2048 .unwrap_or(0);
2049 let mut cached_prefix_tokens = if self.runtime_env.prefix_cache && cache_len_before == 0 {
2050 self.try_acquire_prefix_cache(cache_id, tokens)
2051 } else {
2052 0
2053 };
2054 if cached_prefix_tokens >= tokens.len() {
2055 let block_size = self
2056 .kv_caches
2057 .get(cache_id)
2058 .and_then(|layers| layers.first())
2059 .map(K::block_size)
2060 .unwrap_or(16);
2061 cached_prefix_tokens = cached_prefix_tokens
2062 .saturating_sub(block_size)
2063 .min(tokens.len() - 1);
2064 }
2065 if self.runtime_env.prefix_cache && cache_len_before == 0 {
2066 self.record_prefix_cache_probe(cached_prefix_tokens);
2067 }
2068
2069 if cached_prefix_tokens > 0 {
2070 let caches_mut = self.kv_caches.get_mut(cache_id).expect("cache present");
2071 let mut ctx_tmp = B::new_context();
2072 for cache in caches_mut.iter_mut() {
2073 if K::len(cache) != cached_prefix_tokens {
2074 K::set_len(cache, cached_prefix_tokens);
2075 if let Some(context_lens) = K::context_lens_mut(cache) {
2076 B::write_typed::<u32>(
2077 &mut ctx_tmp,
2078 context_lens,
2079 &[cached_prefix_tokens as u32],
2080 );
2081 }
2082 }
2083 }
2084 B::sync(&mut ctx_tmp);
2085 }
2086
2087 let suffix_tokens = &tokens[cached_prefix_tokens..];
2088 let seq_len = suffix_tokens.len();
2089 assert!(
2090 seq_len > 0,
2091 "prefix cache must leave at least one suffix token"
2092 );
2093 self.ensure_scratch(seq_len);
2094
2095 let pos_offset = self
2098 .kv_caches
2099 .get(cache_id)
2100 .and_then(|layers| layers.first())
2101 .map(|c| K::len(c))
2102 .unwrap_or(0);
2103
2104 let h = self.cfg.hidden_size;
2105 let vocab = self.cfg.vocab_size;
2106 let mut ctx = B::new_context();
2107
2108 let mut residual = self
2115 .scratch
2116 .residual
2117 .take()
2118 .expect("scratch residual missing (previous call didn't restore)");
2119 let embed = self
2120 .embed
2121 .as_ref()
2122 .expect("prefill_internal called on backbone-only model (no embed)");
2123 B::embedding_lookup(&mut ctx, embed, suffix_tokens, &mut residual, h);
2124
2125 let prefill_profile = self.runtime_env.prefill_op_profile;
2126 let prefill_t0 = if prefill_profile {
2127 B::sync(&mut ctx);
2128 Some(std::time::Instant::now())
2129 } else {
2130 None
2131 };
2132
2133 for li in self.local_layer_indices() {
2134 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
2135 }
2136
2137 if let Some(t0) = prefill_t0 {
2138 B::sync(&mut ctx);
2139 let total_us = t0.elapsed().as_micros() as u64;
2140 let attn_us = ATTN_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2141 let attn_n = ATTN_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2142 let qkr_us = QKR_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2143 let qkr_n = QKR_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2144 let mm_us = MATMUL_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2145 let mm_n = MATMUL_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2146 let norm_us = NORM_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2147 let norm_n = NORM_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2148 let other_us = OTHER_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2149 let other_n = OTHER_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2150 eprintln!(
2151 "[prefill-profile] tokens={} layers total={} ms",
2152 seq_len,
2153 total_us / 1000
2154 );
2155 let bucket = |label: &str, n: u64, us: u64| {
2156 if n > 0 {
2157 eprintln!(
2158 "[prefill-profile] {label}: {} calls {} ms (avg {} us)",
2159 n,
2160 us / 1000,
2161 us / n
2162 );
2163 }
2164 };
2165 bucket("flash_attn", attn_n, attn_us);
2166 bucket("qk_norm_rope", qkr_n, qkr_us);
2167 bucket("matmuls", mm_n, mm_us);
2168 bucket("norms", norm_n, norm_us);
2169 bucket("other", other_n, other_us);
2170 }
2171
2172 B::copy_slice(
2174 &mut ctx,
2175 &residual,
2176 (seq_len - 1) * h,
2177 &mut self.scratch.last_hidden,
2178 0,
2179 h,
2180 );
2181
2182 B::rms_norm(
2184 &mut ctx,
2185 &self.scratch.last_hidden,
2186 &self.final_norm_w,
2187 self.cfg.rms_norm_eps,
2188 &mut self.scratch.last_normed,
2189 1,
2190 h,
2191 );
2192
2193 let lm_head = self
2195 .lm_head
2196 .as_ref()
2197 .expect("prefill_internal called on backbone-only model (no lm_head)");
2198 lm_head.forward(
2199 &mut ctx,
2200 &self.scratch.last_normed,
2201 &mut self.scratch.logits,
2202 1,
2203 );
2204
2205 B::sync(&mut ctx);
2212
2213 self.scratch.residual = Some(residual);
2215 if self.runtime_env.prefix_cache && cache_len_before == 0 {
2216 self.register_prefix_cache(cache_id, tokens, cached_prefix_tokens);
2217 }
2218
2219 B::to_vec(&self.scratch.logits, vocab)
2220 }
2221
2222 pub fn decode_internal(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
2224 self.ensure_scratch(1);
2225 self.ensure_kv(cache_id);
2226
2227 let h = self.cfg.hidden_size;
2228 let vocab = self.cfg.vocab_size;
2229
2230 let mut ctx = B::new_context();
2233
2234 const GRAPH_WARMUP: usize = 3;
2239 let graph_enabled = self.runtime_env.cuda_graph;
2240
2241 if graph_enabled {
2242 B::set_decode_state(&mut ctx, token, pos);
2245
2246 match B::replay_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY) {
2250 Ok(true) => {
2251 B::sync(&mut ctx);
2252 return B::to_vec(&self.scratch.logits, vocab);
2253 }
2254 Ok(false) => { }
2255 Err(_) => { }
2256 }
2257 }
2258
2259 let should_capture =
2260 graph_enabled && !self.graph_capture_failed && self.graph_warmup >= GRAPH_WARMUP;
2261
2262 if should_capture {
2263 B::set_dev_state_mode(&mut ctx, true);
2264 if B::begin_graph_capture(&mut ctx).is_err() {
2265 self.graph_capture_failed = true;
2266 B::set_dev_state_mode(&mut ctx, false);
2267 }
2268 }
2269
2270 let mut residual = self
2276 .scratch
2277 .residual
2278 .take()
2279 .expect("scratch residual missing (previous call didn't restore)");
2280 let embed = self
2281 .embed
2282 .as_ref()
2283 .expect("decode_internal called on backbone-only model (no embed)");
2284 B::embedding_lookup(&mut ctx, embed, &[token], &mut residual, h);
2285
2286 let layer_profile = self.runtime_env.decode_layer_profile;
2290 let mut layer_times = if layer_profile {
2291 Some(Vec::with_capacity(self.local_layer_count()))
2292 } else {
2293 None
2294 };
2295
2296 for li in self.local_layer_indices() {
2297 if layer_profile {
2298 let t0 = std::time::Instant::now();
2299 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
2300 B::sync(&mut ctx);
2301 let elapsed_us = t0.elapsed().as_micros() as u64;
2302 if let Some(v) = layer_times.as_mut() {
2303 v.push(elapsed_us);
2304 }
2305 } else {
2306 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
2307 }
2308 }
2309 if let Some(times) = layer_times.take() {
2310 let sum: u64 = times.iter().sum();
2311 let avg = sum / times.len() as u64;
2312 let mn = *times.iter().min().unwrap_or(&0);
2313 let mx = *times.iter().max().unwrap_or(&0);
2314 eprintln!(
2315 "[layer-profile] {} layers total={} ms avg={} us min={} us max={} us",
2316 times.len(),
2317 sum / 1000,
2318 avg,
2319 mn,
2320 mx
2321 );
2322 for (i, t) in times.iter().enumerate() {
2323 eprint!("L{i}={}ms ", t / 1000);
2324 if (i + 1) % 6 == 0 {
2325 eprintln!();
2326 }
2327 }
2328 eprintln!();
2329 let attn_us = ATTN_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2330 let attn_n = ATTN_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2331 let qkr_us = QKR_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2332 let qkr_n = QKR_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2333 let mm_us = MATMUL_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2334 let mm_n = MATMUL_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2335 let norm_us = NORM_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2336 let norm_n = NORM_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2337 let other_us = OTHER_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2338 let other_n = OTHER_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2339 eprintln!(
2340 "[op-profile] flash_attn: {} calls {} ms (avg {} us)",
2341 attn_n,
2342 attn_us / 1000,
2343 if attn_n > 0 { attn_us / attn_n } else { 0 }
2344 );
2345 eprintln!(
2346 "[op-profile] qk_norm_rope: {} calls {} ms (avg {} us)",
2347 qkr_n,
2348 qkr_us / 1000,
2349 if qkr_n > 0 { qkr_us / qkr_n } else { 0 }
2350 );
2351 eprintln!(
2352 "[op-profile] matmuls (Linear::forward): {} calls {} ms (avg {} us)",
2353 mm_n,
2354 mm_us / 1000,
2355 if mm_n > 0 { mm_us / mm_n } else { 0 }
2356 );
2357 eprintln!(
2358 "[op-profile] norms (rms+fused_add_rms): {} calls {} ms (avg {} us)",
2359 norm_n,
2360 norm_us / 1000,
2361 if norm_n > 0 { norm_us / norm_n } else { 0 }
2362 );
2363 eprintln!(
2364 "[op-profile] other (split_qkv, kv_append, transpose, silu, add): {} calls {} ms (avg {} us)",
2365 other_n, other_us / 1000, if other_n > 0 { other_us / other_n } else { 0 }
2366 );
2367 }
2368
2369 B::rms_norm(
2370 &mut ctx,
2371 &residual,
2372 &self.final_norm_w,
2373 self.cfg.rms_norm_eps,
2374 &mut self.scratch.last_normed,
2375 1,
2376 h,
2377 );
2378
2379 let lm_head = self
2380 .lm_head
2381 .as_ref()
2382 .expect("decode_internal called on backbone-only model (no lm_head)");
2383 lm_head.forward(
2384 &mut ctx,
2385 &self.scratch.last_normed,
2386 &mut self.scratch.logits,
2387 1,
2388 );
2389
2390 if should_capture && !self.graph_capture_failed {
2391 if B::end_graph_capture(&mut ctx, SINGLE_ITEM_GRAPH_KEY).is_err() {
2392 self.graph_capture_failed = true;
2393 } else {
2394 if B::replay_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY).is_err() {
2401 self.graph_capture_failed = true;
2402 }
2403 }
2404 B::set_dev_state_mode(&mut ctx, false);
2405 } else {
2406 self.graph_warmup += 1;
2407 }
2408
2409 B::sync(&mut ctx);
2416 self.scratch.residual = Some(residual);
2417
2418 B::to_vec(&self.scratch.logits, vocab)
2419 }
2420
2421 fn stage_tokens_to_hidden(
2422 &mut self,
2423 cache_id: &str,
2424 tokens: &[u32],
2425 pos_offset: usize,
2426 ) -> Vec<f32> {
2427 let seq_len = tokens.len();
2428 assert!(seq_len > 0, "stage token forward called with zero length");
2429 self.ensure_scratch(seq_len);
2430 self.ensure_kv(cache_id);
2431
2432 let h = self.cfg.hidden_size;
2433 let mut ctx = B::new_context();
2434 let mut residual = self
2435 .scratch
2436 .residual
2437 .take()
2438 .expect("scratch residual missing (previous call didn't restore)");
2439 let embed = self
2440 .embed
2441 .as_ref()
2442 .expect("stage token forward called on stage without embedding");
2443 B::embedding_lookup(&mut ctx, embed, tokens, &mut residual, h);
2444
2445 for li in self.local_layer_indices() {
2446 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
2447 }
2448
2449 B::sync(&mut ctx);
2450 let out = B::to_vec(&residual, seq_len * h);
2451 self.scratch.residual = Some(residual);
2452 out
2453 }
2454
2455 pub fn prefill_stage_tokens_to_hidden(
2458 &mut self,
2459 cache_id: &str,
2460 tokens: &[u32],
2461 pos_offset: usize,
2462 ) -> Vec<f32> {
2463 self.stage_tokens_to_hidden(cache_id, tokens, pos_offset)
2464 }
2465
2466 pub fn decode_stage_token_to_hidden(
2469 &mut self,
2470 cache_id: &str,
2471 token: u32,
2472 pos: u32,
2473 ) -> Vec<f32> {
2474 self.stage_tokens_to_hidden(cache_id, &[token], pos as usize)
2475 }
2476
2477 fn stage_hidden_from_host(
2478 &mut self,
2479 cache_id: &str,
2480 hidden: &[f32],
2481 seq_len: usize,
2482 pos_offset: usize,
2483 ) -> Vec<f32> {
2484 self.stage_hidden_from_host_with_timing(cache_id, hidden, seq_len, pos_offset)
2485 .0
2486 }
2487
2488 fn stage_hidden_from_host_with_timing(
2489 &mut self,
2490 cache_id: &str,
2491 hidden: &[f32],
2492 seq_len: usize,
2493 pos_offset: usize,
2494 ) -> (Vec<f32>, LlamaStageHiddenBridgeTiming) {
2495 let h = self.cfg.hidden_size;
2496 assert_eq!(
2497 hidden.len(),
2498 seq_len * h,
2499 "hidden length {} != seq_len * hidden_size {}",
2500 hidden.len(),
2501 seq_len * h
2502 );
2503 assert!(seq_len > 0, "stage hidden forward called with zero length");
2504
2505 self.ensure_scratch(seq_len);
2506 self.ensure_kv(cache_id);
2507
2508 let mut ctx = B::new_context();
2509 let mut residual = self
2510 .scratch
2511 .residual
2512 .take()
2513 .expect("scratch residual missing (previous call didn't restore)");
2514
2515 let bridge_t0 = std::time::Instant::now();
2516 let host_copy_t0 = std::time::Instant::now();
2517 let hidden_buf = B::from_slice(hidden);
2518 let host_copy_us = elapsed_micros_u64_floor1(host_copy_t0);
2519 let device_copy_t0 = std::time::Instant::now();
2520 B::copy_slice(&mut ctx, &hidden_buf, 0, &mut residual, 0, seq_len * h);
2521 let device_copy_us = elapsed_micros_u64_floor1(device_copy_t0);
2522 let bridge_timing = LlamaStageHiddenBridgeTiming {
2523 bridge_us: elapsed_micros_u64_floor1(bridge_t0),
2524 host_copy_us,
2525 device_copy_us,
2526 };
2527
2528 for li in self.local_layer_indices() {
2529 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
2530 }
2531
2532 B::sync(&mut ctx);
2533 let out = B::to_vec(&residual, seq_len * h);
2534 self.scratch.residual = Some(residual);
2535 (out, bridge_timing)
2536 }
2537
2538 pub fn prefill_stage_hidden_from_host(
2542 &mut self,
2543 cache_id: &str,
2544 hidden: &[f32],
2545 seq_len: usize,
2546 pos_offset: usize,
2547 ) -> Vec<f32> {
2548 self.stage_hidden_from_host(cache_id, hidden, seq_len, pos_offset)
2549 }
2550
2551 pub fn decode_stage_hidden_from_host(
2554 &mut self,
2555 cache_id: &str,
2556 hidden: &[f32],
2557 pos: u32,
2558 ) -> Vec<f32> {
2559 self.stage_hidden_from_host(cache_id, hidden, 1, pos as usize)
2560 }
2561
2562 pub(crate) fn decode_stage_hidden_from_host_with_timing(
2563 &mut self,
2564 cache_id: &str,
2565 hidden: &[f32],
2566 pos: u32,
2567 ) -> (Vec<f32>, LlamaStageHiddenBridgeTiming) {
2568 self.stage_hidden_from_host_with_timing(cache_id, hidden, 1, pos as usize)
2569 }
2570
2571 pub fn logits_from_hidden(&mut self, hidden: &[f32]) -> Vec<f32> {
2575 let h = self.cfg.hidden_size;
2576 let vocab = self.cfg.vocab_size;
2577 assert_eq!(
2578 hidden.len(),
2579 h,
2580 "hidden length {} != hidden_size {}",
2581 hidden.len(),
2582 h
2583 );
2584 self.ensure_scratch(1);
2585
2586 let mut ctx = B::new_context();
2587 let hidden_buf = B::from_slice(hidden);
2588 B::rms_norm(
2589 &mut ctx,
2590 &hidden_buf,
2591 &self.final_norm_w,
2592 self.cfg.rms_norm_eps,
2593 &mut self.scratch.last_normed,
2594 1,
2595 h,
2596 );
2597 let lm_head = self
2598 .lm_head
2599 .as_ref()
2600 .expect("logits_from_hidden called on stage without lm_head");
2601 lm_head.forward(
2602 &mut ctx,
2603 &self.scratch.last_normed,
2604 &mut self.scratch.logits,
2605 1,
2606 );
2607 B::sync(&mut ctx);
2608 B::to_vec(&self.scratch.logits, vocab)
2609 }
2610
2611 pub fn prefill_from_embeds(
2620 &mut self,
2621 cache_id: &str,
2622 embeds: &[f32],
2623 seq_len: usize,
2624 ) -> Vec<f32> {
2625 let h = self.cfg.hidden_size;
2626 assert_eq!(
2627 embeds.len(),
2628 seq_len * h,
2629 "embeds length {} != seq_len * hidden_size {}",
2630 embeds.len(),
2631 seq_len * h
2632 );
2633 assert!(seq_len > 0, "prefill_from_embeds called with zero length");
2634
2635 self.ensure_scratch(seq_len);
2636 self.ensure_kv(cache_id);
2637
2638 let mut ctx = B::new_context();
2639 let mut residual = self
2640 .scratch
2641 .residual
2642 .take()
2643 .expect("scratch residual missing (previous call didn't restore)");
2644
2645 let embed_buf = B::from_slice(embeds);
2647 B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, seq_len * h);
2648
2649 for li in self.local_layer_indices() {
2650 self.forward_layer(&mut ctx, li, cache_id, &mut residual, 0, seq_len);
2651 }
2652
2653 B::copy_slice(
2654 &mut ctx,
2655 &residual,
2656 (seq_len - 1) * h,
2657 &mut self.scratch.last_hidden,
2658 0,
2659 h,
2660 );
2661 B::sync(&mut ctx);
2662 self.scratch.residual = Some(residual);
2663 B::to_vec(&self.scratch.last_hidden, h)
2664 }
2665
2666 pub fn decode_from_embed(&mut self, cache_id: &str, embed: &[f32], pos: u32) -> Vec<f32> {
2670 let h = self.cfg.hidden_size;
2671 assert_eq!(
2672 embed.len(),
2673 h,
2674 "embed length {} != hidden_size {}",
2675 embed.len(),
2676 h
2677 );
2678
2679 self.ensure_scratch(1);
2680 self.ensure_kv(cache_id);
2681
2682 let mut ctx = B::new_context();
2683 let mut residual = self
2684 .scratch
2685 .residual
2686 .take()
2687 .expect("scratch residual missing (previous call didn't restore)");
2688
2689 let embed_buf = B::from_slice(embed);
2690 B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, h);
2691
2692 for li in self.local_layer_indices() {
2693 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
2694 }
2695
2696 B::copy_slice(&mut ctx, &residual, 0, &mut self.scratch.last_hidden, 0, h);
2697 B::sync(&mut ctx);
2698 self.scratch.residual = Some(residual);
2699 B::to_vec(&self.scratch.last_hidden, h)
2700 }
2701
2702 pub fn prefill_all_post_norm(
2713 &mut self,
2714 cache_id: &str,
2715 embeds: &[f32],
2716 seq_len: usize,
2717 pos_offset: usize,
2718 ) -> Vec<f32> {
2719 let h = self.cfg.hidden_size;
2720 assert_eq!(
2721 embeds.len(),
2722 seq_len * h,
2723 "embeds length {} != seq_len * hidden_size {}",
2724 embeds.len(),
2725 seq_len * h
2726 );
2727 assert!(seq_len > 0);
2728
2729 self.ensure_scratch(seq_len);
2730 self.ensure_kv(cache_id);
2731
2732 let mut ctx = B::new_context();
2733 let mut residual = self
2734 .scratch
2735 .residual
2736 .take()
2737 .expect("scratch residual missing (previous call didn't restore)");
2738
2739 let embed_buf = B::from_slice(embeds);
2740 B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, seq_len * h);
2741
2742 for li in self.local_layer_indices() {
2743 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
2744 }
2745
2746 B::rms_norm(
2748 &mut ctx,
2749 &residual,
2750 &self.final_norm_w,
2751 self.cfg.rms_norm_eps,
2752 &mut self.scratch.norm_out,
2753 seq_len,
2754 h,
2755 );
2756 B::sync(&mut ctx);
2757 self.scratch.residual = Some(residual);
2758 B::to_vec(&self.scratch.norm_out, seq_len * h)
2759 }
2760
2761 pub fn decode_post_norm_from_embed(
2765 &mut self,
2766 cache_id: &str,
2767 embed: &[f32],
2768 pos: u32,
2769 ) -> Vec<f32> {
2770 let h = self.cfg.hidden_size;
2771 assert_eq!(embed.len(), h);
2772
2773 self.ensure_scratch(1);
2774 self.ensure_kv(cache_id);
2775
2776 let mut ctx = B::new_context();
2777 let mut residual = self
2778 .scratch
2779 .residual
2780 .take()
2781 .expect("scratch residual missing (previous call didn't restore)");
2782
2783 let embed_buf = B::from_slice(embed);
2784 B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, h);
2785
2786 for li in self.local_layer_indices() {
2787 self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
2788 }
2789
2790 B::rms_norm(
2791 &mut ctx,
2792 &residual,
2793 &self.final_norm_w,
2794 self.cfg.rms_norm_eps,
2795 &mut self.scratch.last_normed,
2796 1,
2797 h,
2798 );
2799 B::sync(&mut ctx);
2800 self.scratch.residual = Some(residual);
2801 B::to_vec(&self.scratch.last_normed, h)
2802 }
2803}
2804
2805impl<B: MoeLlmBackend> DecoderOnlyLLM for LlamaFamilyModel<B, KvFp16> {
2807 fn config(&self) -> &LlmRuntimeConfig {
2808 &self.runtime_cfg
2809 }
2810
2811 fn cache_metrics_snapshot(&self) -> Option<serde_json::Value> {
2812 Some(self.prefix_cache_snapshot_json())
2813 }
2814
2815 fn lora_metrics_snapshot(&self) -> Option<serde_json::Value> {
2816 Some(self.lora_metrics_snapshot_json())
2817 }
2818
2819 fn set_lora_adapter_for_cache(
2820 &mut self,
2821 cache_id: &str,
2822 adapter: Option<ActiveLoraAdapter>,
2823 ) -> std::result::Result<(), ferrum_types::FerrumError> {
2824 if let Some(adapter) = adapter {
2825 self.ensure_lora_adapter_loaded(adapter.clone())?;
2826 self.lora_cache_adapters
2827 .insert(cache_id.to_string(), adapter.name);
2828 } else {
2829 self.lora_cache_adapters.remove(cache_id);
2830 }
2831 Ok(())
2832 }
2833
2834 fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
2835 self.ensure_scratch(max_tokens);
2836 self.ensure_kv(cache_id);
2837
2838 const WARMUP_CACHE: &str = "__ferrum_warmup__";
2839 let _ = self.prefill_internal(WARMUP_CACHE, &[0u32]);
2840 if let Some(mut caches) = self.kv_caches.remove(WARMUP_CACHE) {
2841 if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2842 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2843 if let Some(c0) = caches.first() {
2844 if !c0.paged_block_indices.is_empty() {
2845 alloc.free(&c0.paged_block_indices);
2846 }
2847 }
2848 for c in caches.iter_mut() {
2849 c.paged_block_indices.clear();
2850 }
2851 }
2852 self.kv_free_pool.push(caches);
2853 }
2854 }
2855
2856 fn kv_capacity(&self) -> usize {
2857 self.runtime_env.kv_capacity_for_model(self.cfg.max_seq_len)
2858 }
2859
2860 fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2861 self.prefill_internal(cache_id, tokens)
2862 }
2863
2864 fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
2865 self.decode_internal(cache_id, token, pos)
2866 }
2867
2868 fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
2869 self.decode_batch_with_full_logits(batch, false)
2870 }
2871
2872 fn decode_batch_with_full_logits(
2873 &mut self,
2874 batch: &[(String, u32, u32)],
2875 force_full_logits: bool,
2876 ) -> Vec<Vec<f32>> {
2877 self.decode_batch_internal_with_full_logits(batch, force_full_logits)
2878 }
2879
2880 fn unified_forward(
2881 &mut self,
2882 items: &[(String, Vec<u32>, usize, bool)],
2883 ) -> std::result::Result<Vec<Option<Vec<f32>>>, ferrum_types::FerrumError> {
2884 if items.is_empty() {
2885 return Ok(Vec::new());
2886 }
2887 if self.runtime_env.prefix_cache
2888 && items
2889 .iter()
2890 .any(|(_, tokens, pos_offset, _)| *pos_offset == 0 && tokens.len() > 1)
2891 {
2892 return Err(ferrum_types::FerrumError::unsupported(
2893 "LlamaFamilyModel::unified_forward: fresh prefill with prefix cache enabled \
2894 routes through prefill_internal so real paged-block KV reuse can probe and \
2895 register block hashes",
2896 ));
2897 }
2898 if !self.supports_varlen_qkv {
2899 return Err(ferrum_types::FerrumError::unsupported(
2900 "LlamaFamilyModel::unified_forward: backend lacks varlen \
2901 QKV kernels. Engine will fall back to per-item dispatch.",
2902 ));
2903 }
2904 if items
2905 .iter()
2906 .any(|(cache_id, _, _, _)| self.active_lora_adapter_for_cache(cache_id).is_some())
2907 {
2908 return Err(ferrum_types::FerrumError::unsupported(
2909 "LlamaFamilyModel::unified_forward: active LoRA adapter routes through \
2910 per-item dispatch until unified LoRA supports row-selective adapters.",
2911 ));
2912 }
2913 self.ensure_kv(&items[0].0);
2914 if self.paged_pools.is_none() {
2915 return Err(ferrum_types::FerrumError::unsupported(
2916 "LlamaFamilyModel::unified_forward: paged KV required; \
2917 enable via FERRUM_METAL_PAGED_KV=1 (cross-backend env). \
2918 Engine will fall back to per-item dispatch.",
2919 ));
2920 }
2921 for (cid, _, _, _) in items {
2922 self.ensure_kv(cid);
2923 if !self.kv_caches.contains_key(cid) {
2924 return Err(ferrum_types::FerrumError::resource_exhausted(format!(
2925 "paged KV pool exhausted for cache_id={cid:?}; back off"
2926 )));
2927 }
2928 }
2929 Ok(self.unified_forward_internal(items))
2930 }
2931
2932 fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2933 LlamaFamilyModel::<B, KvFp16>::forward_verify(self, cache_id, tokens)
2934 }
2935
2936 fn truncate_kv(&mut self, cache_id: &str, new_len: usize) {
2937 if let Some(caches) = self.kv_caches.get_mut(cache_id) {
2938 for c in caches.iter_mut() {
2939 if new_len < c.len {
2940 c.len = new_len;
2941 }
2942 }
2943 }
2944 let mut ctx = B::new_context();
2945 B::reset_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY);
2946 self.graph_warmup = 0;
2947 self.graph_capture_failed = false;
2948 }
2949
2950 fn release(&mut self, cache_id: &str) {
2951 let mut ctx = B::new_context();
2952 B::sync(&mut ctx);
2953 self.graph_warmup = 0;
2954 self.graph_capture_failed = false;
2955 self.lora_cache_adapters.remove(cache_id);
2956 if let Some(mut caches) = self.kv_caches.remove(cache_id) {
2957 if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2958 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2959 if let Some(c0) = caches.first() {
2960 if !c0.paged_block_indices.is_empty() {
2961 alloc.free(&c0.paged_block_indices);
2962 }
2963 }
2964 for c in caches.iter_mut() {
2965 c.paged_block_indices.clear();
2966 }
2967 }
2968 self.kv_free_pool.push(caches);
2969 }
2970 }
2971
2972 fn reset(&mut self) {
2973 let mut ctx = B::new_context();
2974 B::sync(&mut ctx);
2975 B::reset_all_graphs(&mut ctx);
2976 B::sync(&mut ctx);
2977 self.graph_warmup = 0;
2978 self.graph_capture_failed = false;
2979 self.batched_graph_keys_seen.clear();
2980 self.batched_graph_warmup = 0;
2981 self.batched_graph_failed = false;
2982 self.kv_caches.clear();
2983 self.kv_free_pool.clear();
2984 self.lora_cache_adapters.clear();
2985 }
2986}
2987
2988impl<B: MoeLlmBackend + BackendInt8KvOps> DecoderOnlyLLM for LlamaFamilyModel<B, KvInt8> {
2992 fn config(&self) -> &LlmRuntimeConfig {
2993 &self.runtime_cfg
2994 }
2995
2996 fn cache_metrics_snapshot(&self) -> Option<serde_json::Value> {
2997 Some(self.prefix_cache_snapshot_json())
2998 }
2999
3000 fn lora_metrics_snapshot(&self) -> Option<serde_json::Value> {
3001 Some(self.lora_metrics_snapshot_json())
3002 }
3003
3004 fn set_lora_adapter_for_cache(
3005 &mut self,
3006 cache_id: &str,
3007 adapter: Option<ActiveLoraAdapter>,
3008 ) -> std::result::Result<(), ferrum_types::FerrumError> {
3009 if let Some(adapter) = adapter {
3010 self.ensure_lora_adapter_loaded(adapter.clone())?;
3011 self.lora_cache_adapters
3012 .insert(cache_id.to_string(), adapter.name);
3013 } else {
3014 self.lora_cache_adapters.remove(cache_id);
3015 }
3016 Ok(())
3017 }
3018
3019 fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
3020 self.ensure_scratch(max_tokens);
3021 self.ensure_kv(cache_id);
3022
3023 const WARMUP_CACHE: &str = "__ferrum_warmup__";
3024 let _ = self.prefill_internal(WARMUP_CACHE, &[0u32]);
3025 if let Some(mut caches) = self.kv_caches.remove(WARMUP_CACHE) {
3026 if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
3027 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
3028 if let Some(c0) = caches.first() {
3029 if !c0.paged_block_indices.is_empty() {
3030 alloc.free(&c0.paged_block_indices);
3031 }
3032 }
3033 for c in caches.iter_mut() {
3034 c.paged_block_indices.clear();
3035 }
3036 }
3037 self.kv_free_pool.push(caches);
3038 }
3039 }
3040
3041 fn kv_capacity(&self) -> usize {
3042 self.runtime_env.kv_capacity_for_model(self.cfg.max_seq_len)
3043 }
3044
3045 fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
3046 self.prefill_internal(cache_id, tokens)
3047 }
3048
3049 fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
3050 self.decode_internal(cache_id, token, pos)
3051 }
3052
3053 fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
3056 LlamaFamilyModel::<B, KvInt8>::forward_verify(self, cache_id, tokens)
3057 }
3058
3059 fn truncate_kv(&mut self, cache_id: &str, new_len: usize) {
3060 if let Some(caches) = self.kv_caches.get_mut(cache_id) {
3061 for c in caches.iter_mut() {
3062 if new_len < c.len {
3063 c.len = new_len;
3064 }
3065 }
3066 }
3067 let mut ctx = B::new_context();
3068 B::reset_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY);
3069 self.graph_warmup = 0;
3070 self.graph_capture_failed = false;
3071 }
3072
3073 fn release(&mut self, cache_id: &str) {
3074 let mut ctx = B::new_context();
3075 B::sync(&mut ctx);
3076 self.graph_warmup = 0;
3077 self.graph_capture_failed = false;
3078 self.lora_cache_adapters.remove(cache_id);
3079 if let Some(mut caches) = self.kv_caches.remove(cache_id) {
3080 if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
3081 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
3082 if let Some(c0) = caches.first() {
3083 if !c0.paged_block_indices.is_empty() {
3084 alloc.free(&c0.paged_block_indices);
3085 }
3086 }
3087 for c in caches.iter_mut() {
3088 c.paged_block_indices.clear();
3089 }
3090 }
3091 self.kv_free_pool.push(caches);
3092 }
3093 }
3094
3095 fn reset(&mut self) {
3096 let mut ctx = B::new_context();
3097 B::sync(&mut ctx);
3098 B::reset_all_graphs(&mut ctx);
3099 B::sync(&mut ctx);
3100 self.graph_warmup = 0;
3101 self.graph_capture_failed = false;
3102 self.kv_caches.clear();
3103 self.kv_free_pool.clear();
3104 self.lora_cache_adapters.clear();
3105 }
3106}
3107
3108fn build_rope_cache<B: QuantLlmBackend + BackendMoeFused>(cfg: &LlamaFamilyConfig) -> RopeCache<B> {
3109 let hd = cfg.head_dim;
3110 let half = hd / 2;
3111 let max = cfg.max_seq_len;
3112 let mut cos = vec![0.0f32; max * half];
3113 let mut sin = vec![0.0f32; max * half];
3114 for pos in 0..max {
3115 for i in 0..half {
3116 let freq = rope_freq(cfg, i);
3117 let angle = pos as f64 * freq;
3118 cos[pos * half + i] = angle.cos() as f32;
3119 sin[pos * half + i] = angle.sin() as f32;
3120 }
3121 }
3122 RopeCache {
3123 cos: B::from_slice(&cos),
3124 sin: B::from_slice(&sin),
3125 }
3126}
3127
3128fn rope_freq(cfg: &LlamaFamilyConfig, pair_idx: usize) -> f64 {
3129 let base_freq = 1.0f64
3130 / cfg
3131 .rope_theta
3132 .powf((2 * pair_idx) as f64 / cfg.head_dim as f64);
3133 match &cfg.rope_scaling {
3134 Some(RopeScalingConfig::Llama3 {
3135 factor,
3136 low_freq_factor,
3137 high_freq_factor,
3138 original_max_position_embeddings,
3139 }) => scale_llama3_rope_freq(
3140 base_freq,
3141 *factor,
3142 *low_freq_factor,
3143 *high_freq_factor,
3144 *original_max_position_embeddings,
3145 ),
3146 None => base_freq,
3147 }
3148}
3149
/// Applies Llama 3 RoPE frequency scaling to a single frequency `freq`.
///
/// - Short wavelengths (below `original_max_position_embeddings / high_freq_factor`)
///   are returned unchanged.
/// - Long wavelengths (above `original_max_position_embeddings / low_freq_factor`)
///   are divided by `factor`.
/// - Wavelengths in between blend linearly between the two results.
fn scale_llama3_rope_freq(
    freq: f64,
    factor: f64,
    low_freq_factor: f64,
    high_freq_factor: f64,
    original_max_position_embeddings: f64,
) -> f64 {
    let wavelen = std::f64::consts::TAU / freq;
    let short_wavelen_cutoff = original_max_position_embeddings / high_freq_factor;
    let long_wavelen_cutoff = original_max_position_embeddings / low_freq_factor;

    if wavelen < short_wavelen_cutoff {
        return freq;
    }
    if wavelen > long_wavelen_cutoff {
        return freq / factor;
    }

    // Medium band: `blend` runs from 0 (fully scaled) to 1 (unscaled).
    let blend = (original_max_position_embeddings / wavelen - low_freq_factor)
        / (high_freq_factor - low_freq_factor);
    (1.0 - blend) * freq / factor + blend * freq
}
3170
#[cfg(test)]
mod tests {
    // Unit tests for layer-range loading, pipeline-stage construction, the
    // hidden-state bridge / logits entry points, and runtime-env parsing.
    // All tests run on `CpuBackend` with a 1-wide toy config.
    use std::sync::Mutex;

    use ferrum_kernels::backend::cpu::CpuBackend;
    use ferrum_quantization::{DenseLinear, QuantConfig, WeightLoader};
    use ferrum_types::Result;

    use super::{
        load_llama_family_layers, LlamaFamilyConfig, LlamaFamilyLayerStageConfig, LlamaFamilyModel,
        LlamaFamilyRuntimeEnv, DEFAULT_KV_CAPACITY,
    };

    /// Fake `WeightLoader` that records the names of every tensor and linear it is
    /// asked for, in request order, and returns 1x1 dummy weights.
    #[derive(Default)]
    struct RecordingLoader {
        // Names passed to `load_tensor`, in call order.
        tensors: Mutex<Vec<String>>,
        // Names passed to `load_linear`, in call order.
        linears: Mutex<Vec<String>>,
    }

    impl WeightLoader<CpuBackend> for RecordingLoader {
        fn load_tensor(&self, name: &str) -> Result<Vec<f32>> {
            self.tensors.lock().unwrap().push(name.to_string());
            Ok(vec![1.0])
        }

        fn load_linear(
            &self,
            name: &str,
        ) -> Result<Box<dyn ferrum_quantization::Linear<CpuBackend>>> {
            self.linears.lock().unwrap().push(name.to_string());
            Ok(Box::new(DenseLinear::<CpuBackend>::from_rows(&[1.0], 1, 1)))
        }

        // Reports every optional tensor as absent.
        fn has_tensor(&self, _name: &str) -> bool {
            false
        }

        // No quantization config: weights are treated as dense.
        fn quant_config(&self) -> Option<&QuantConfig> {
            None
        }
    }

    /// Minimal config where every dimension is 1 and only the layer count and
    /// QK-norm flag vary.
    fn test_llama_config(num_layers: usize, has_qk_norm: bool) -> LlamaFamilyConfig {
        LlamaFamilyConfig {
            hidden_size: 1,
            intermediate_size: 1,
            num_heads: 1,
            num_kv_heads: 1,
            head_dim: 1,
            num_layers,
            vocab_size: 1,
            max_seq_len: 1,
            rms_norm_eps: 1e-5,
            rope_theta: 10_000.0,
            rope_scaling: None,
            rope_interleaved: false,
            has_qk_norm,
            sliding_window: 0,
        }
    }

    // Loading source layers 2..4 must request exactly those layers' weights, in
    // order, including the per-layer QK-norm tensors when `has_qk_norm` is set.
    #[test]
    fn llama_family_layer_loader_uses_source_layer_range() {
        let cfg = test_llama_config(5, true);
        let loader = RecordingLoader::default();

        let layers = load_llama_family_layers(&cfg, &loader, 2..4).unwrap();

        assert_eq!(layers.len(), 2);
        assert!(layers.iter().all(|layer| layer.q_norm_w.is_some()));
        assert!(layers.iter().all(|layer| layer.k_norm_w.is_some()));
        assert_eq!(
            loader.tensors.into_inner().unwrap(),
            vec![
                "model.layers.2.input_layernorm.weight",
                "model.layers.2.post_attention_layernorm.weight",
                "model.layers.2.self_attn.q_norm.weight",
                "model.layers.2.self_attn.k_norm.weight",
                "model.layers.3.input_layernorm.weight",
                "model.layers.3.post_attention_layernorm.weight",
                "model.layers.3.self_attn.q_norm.weight",
                "model.layers.3.self_attn.k_norm.weight",
            ]
        );
        assert_eq!(
            loader.linears.into_inner().unwrap(),
            vec![
                "model.layers.2.self_attn.qkv_proj",
                "model.layers.2.self_attn.o_proj",
                "model.layers.2.mlp.gate_up_proj",
                "model.layers.2.mlp.down_proj",
                "model.layers.3.self_attn.qkv_proj",
                "model.layers.3.self_attn.o_proj",
                "model.layers.3.mlp.gate_up_proj",
                "model.layers.3.mlp.down_proj",
            ]
        );
    }

    // A range that extends past `num_layers` must be rejected with an error that
    // mentions the layer count.
    #[test]
    fn llama_family_layer_loader_rejects_out_of_bounds_range() {
        let cfg = test_llama_config(2, false);
        let loader = RecordingLoader::default();

        let err = match load_llama_family_layers(&cfg, &loader, 1..3) {
            Ok(_) => panic!("expected out-of-bounds layer range to fail"),
            Err(err) => err,
        };

        assert!(
            err.to_string().contains("outside model layer count"),
            "{err}"
        );
    }

    // A full (non-staged) model covers every source layer, maps local indices
    // 1:1 to source indices, and allocates one KV cache per local layer.
    #[test]
    fn llama_family_full_model_records_source_layer_range() {
        let cfg = test_llama_config(3, false);
        let loader = RecordingLoader::default();
        let mut model = LlamaFamilyModel::<CpuBackend>::new(cfg, &loader).unwrap();

        assert_eq!(model.source_layer_range(), 0..3);
        assert_eq!(model.local_layer_count(), 3);
        assert_eq!(model.source_layer_index(2), 2);

        model.ensure_kv("test-cache");
        assert_eq!(
            model.kv_caches["test-cache"].len(),
            model.local_layer_count()
        );
    }

    // A last pipeline stage (layers 2..5, no embedding, with LM head) loads only
    // its layers plus the final norm. The LM head is expected to load from
    // `model.embed_tokens` (presumably the tied-embedding fallback, since
    // `has_tensor` reports nothing present).
    #[test]
    fn llama_family_layer_stage_loads_only_requested_weights() {
        let cfg = test_llama_config(5, false);
        let loader = RecordingLoader::default();

        let model = LlamaFamilyModel::<CpuBackend>::new_layer_stage(
            cfg,
            &loader,
            LlamaFamilyLayerStageConfig::pipeline_stage(2..5, false, true),
        )
        .unwrap();

        assert_eq!(model.source_layer_range(), 2..5);
        assert_eq!(model.local_layer_count(), 3);
        assert!(model.embed.is_none());
        assert!(model.lm_head.is_some());
        assert_eq!(
            loader.tensors.into_inner().unwrap(),
            vec![
                "model.layers.2.input_layernorm.weight",
                "model.layers.2.post_attention_layernorm.weight",
                "model.layers.3.input_layernorm.weight",
                "model.layers.3.post_attention_layernorm.weight",
                "model.layers.4.input_layernorm.weight",
                "model.layers.4.post_attention_layernorm.weight",
                "model.norm.weight",
            ]
        );
        assert_eq!(
            loader.linears.into_inner().unwrap(),
            vec![
                "model.layers.2.self_attn.qkv_proj",
                "model.layers.2.self_attn.o_proj",
                "model.layers.2.mlp.gate_up_proj",
                "model.layers.2.mlp.down_proj",
                "model.layers.3.self_attn.qkv_proj",
                "model.layers.3.self_attn.o_proj",
                "model.layers.3.mlp.gate_up_proj",
                "model.layers.3.mlp.down_proj",
                "model.layers.4.self_attn.qkv_proj",
                "model.layers.4.self_attn.o_proj",
                "model.layers.4.mlp.gate_up_proj",
                "model.layers.4.mlp.down_proj",
                "model.embed_tokens",
            ]
        );
    }

    // A middle stage can run a host-side hidden-state chunk through its layers,
    // producing finite output and a KV cache sized to its local layers.
    #[test]
    fn llama_family_layer_stage_runs_hidden_forward_bridge() {
        let cfg = test_llama_config(1, false);
        let loader = RecordingLoader::default();
        let mut model = LlamaFamilyModel::<CpuBackend>::new_layer_stage(
            cfg,
            &loader,
            LlamaFamilyLayerStageConfig::pipeline_stage(0..1, false, false),
        )
        .unwrap();

        let hidden = model.prefill_stage_hidden_from_host("stage-cache", &[1.0], 1, 0);

        assert_eq!(hidden.len(), 1);
        assert!(hidden[0].is_finite());
        assert_eq!(
            model.kv_caches["stage-cache"].len(),
            model.local_layer_count()
        );
    }

    // A stage with an LM head turns a single hidden vector into `vocab_size`
    // finite logits.
    #[test]
    fn llama_family_last_stage_projects_hidden_to_logits() {
        let cfg = test_llama_config(1, false);
        let loader = RecordingLoader::default();
        let mut model = LlamaFamilyModel::<CpuBackend>::new_layer_stage(
            cfg,
            &loader,
            LlamaFamilyLayerStageConfig::pipeline_stage(0..1, false, true),
        )
        .unwrap();

        let logits = model.logits_from_hidden(&[2.0]);

        assert_eq!(logits.len(), 1);
        assert!(logits[0].is_finite());
    }

    // Startup knob parsing. Per these assertions, the profiling/graph flags are
    // expected to be enabled whenever the variable is present, even for values
    // like "0", "" or "false". The configured KV capacity (4096) is expected to be
    // capped at the model's max length (2048).
    #[test]
    fn llama_family_runtime_env_parses_startup_knobs() {
        let env = LlamaFamilyRuntimeEnv::from_env_vars([
            ("FERRUM_KV_CAPACITY", "4096"),
            ("FERRUM_METAL_PAGED_KV", "0"),
            ("FERRUM_PAGED_MAX_SEQS", "64"),
            ("FERRUM_DECODE_OP_PROFILE", "0"),
            ("FERRUM_PREFILL_OP_PROFILE", ""),
            ("FERRUM_CUDA_GRAPH", ""),
            ("FERRUM_DECODE_LAYER_PROFILE", "false"),
        ]);

        assert_eq!(env.kv_capacity, Some(4096));
        assert_eq!(env.metal_paged_kv, Some(false));
        assert_eq!(env.paged_max_seqs, 64);
        assert!(env.decode_op_profile);
        assert!(env.prefill_op_profile);
        assert!(env.cuda_graph);
        assert!(env.decode_layer_profile);
        assert_eq!(env.kv_capacity_for_model(2048), 2048);
    }

    // Unparseable numeric values fall back to defaults: no explicit KV capacity,
    // 32 paged sequences. Without an explicit capacity, a model twice
    // `DEFAULT_KV_CAPACITY` long is expected to get `DEFAULT_KV_CAPACITY`.
    #[test]
    fn llama_family_runtime_env_uses_defaults_for_invalid_values() {
        let env = LlamaFamilyRuntimeEnv::from_env_vars([
            ("FERRUM_KV_CAPACITY", "bad"),
            ("FERRUM_PAGED_MAX_SEQS", "bad"),
            ("FERRUM_METAL_PAGED_KV", "1"),
        ]);

        assert_eq!(env.kv_capacity, None);
        assert_eq!(env.metal_paged_kv, Some(true));
        assert_eq!(env.paged_max_seqs, 32);
        assert_eq!(
            env.kv_capacity_for_model(DEFAULT_KV_CAPACITY * 2),
            DEFAULT_KV_CAPACITY
        );
    }
}