1use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::{Arc, OnceLock};
15
16use parking_lot::{Mutex, MutexGuard};
17use tracing::debug;
18
19use ferrum_interfaces::{
20 model_executor::{
21 AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorStatus,
22 MemoryRequirements, PrefillInput, PrefillOutput, UnifiedBatch,
23 },
24 ModelExecutor,
25};
26use ferrum_types::{DataType, FerrumError, ModelInfo, Result};
27
28use crate::common::DecoderOnlyLLM;
29use crate::lora::ActiveLoraAdapter;
30
31use super::common::{self, GenericKvCacheHandle};
32
/// Process-wide profiling switches read from the environment.
///
/// A switch is enabled by the mere *presence* of its variable; the value is
/// never inspected, so `FERRUM_BATCH_DECODE_PROF=0` still enables decode
/// profiling.
#[derive(Debug, Clone, PartialEq, Eq)]
struct LlmExecutorRuntimeEnv {
    /// Set when `FERRUM_BATCH_PREFILL_PROF` is present: `batch_prefill` prints
    /// a per-call timing line to stderr.
    batch_prefill_prof: bool,
    /// Set when `FERRUM_BATCH_DECODE_PROF` is present: `batch_decode` prints a
    /// sampled timing breakdown to stderr.
    batch_decode_prof: bool,
}

impl LlmExecutorRuntimeEnv {
    /// Reads the switches from the current process environment.
    fn from_env() -> Self {
        // `std::env::vars()` panics while iterating if any key *or value* in
        // the environment is not valid Unicode, which would let an unrelated
        // variable crash the executor. Iterate OS strings instead and skip
        // non-UTF-8 keys (they can never equal one of our ASCII flag names);
        // values are passed through untouched because they are never read.
        Self::from_env_vars(
            std::env::vars_os()
                .filter_map(|(key, value)| key.into_string().ok().map(|key| (key, value))),
        )
    }

    /// Builds the switches from an arbitrary `(key, value)` sequence.
    ///
    /// Only the keys matter; unknown keys are ignored. This is the testable
    /// core of [`LlmExecutorRuntimeEnv::from_env`].
    fn from_env_vars<I, K, V>(vars: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
    {
        let mut batch_prefill_prof = false;
        let mut batch_decode_prof = false;

        for (key, _) in vars {
            match key.as_ref() {
                "FERRUM_BATCH_PREFILL_PROF" => batch_prefill_prof = true,
                "FERRUM_BATCH_DECODE_PROF" => batch_decode_prof = true,
                _ => {}
            }
        }

        Self {
            batch_prefill_prof,
            batch_decode_prof,
        }
    }
}
66
67fn llm_executor_runtime_env() -> &'static LlmExecutorRuntimeEnv {
68 static CONFIG: OnceLock<LlmExecutorRuntimeEnv> = OnceLock::new();
69 CONFIG.get_or_init(LlmExecutorRuntimeEnv::from_env)
70}
71
72fn active_lora_from_metadata(
73 metadata: &std::collections::HashMap<String, serde_json::Value>,
74) -> Result<Option<ActiveLoraAdapter>> {
75 let name = metadata
76 .get("ferrum_lora_adapter")
77 .and_then(|value| value.as_str());
78 let path = metadata
79 .get("ferrum_lora_path")
80 .and_then(|value| value.as_str());
81 match (name, path) {
82 (Some(name), Some(path)) => Ok(Some(ActiveLoraAdapter {
83 name: name.to_string(),
84 path: std::path::PathBuf::from(path),
85 })),
86 (None, None) => Ok(None),
87 _ => Err(FerrumError::model(
88 "incomplete LoRA metadata: expected ferrum_lora_adapter and ferrum_lora_path",
89 )),
90 }
91}
92
93fn metadata_requires_full_logits(
94 metadata: &std::collections::HashMap<String, serde_json::Value>,
95) -> bool {
96 metadata
97 .get("ferrum_require_full_logits")
98 .and_then(|value| value.as_bool())
99 .unwrap_or(false)
100}
101
102fn ferrum_device_to_candle(d: &ferrum_types::Device) -> candle_core::Device {
107 match d {
108 ferrum_types::Device::CPU => candle_core::Device::Cpu,
109 #[cfg(feature = "cuda")]
110 ferrum_types::Device::CUDA(i) => {
111 candle_core::Device::new_cuda(*i as usize).unwrap_or(candle_core::Device::Cpu)
112 }
113 #[cfg(not(feature = "cuda"))]
114 ferrum_types::Device::CUDA(_) => candle_core::Device::Cpu,
115 #[cfg(all(any(target_os = "macos", target_os = "ios"), feature = "metal"))]
116 ferrum_types::Device::Metal => {
117 candle_core::Device::new_metal(0).unwrap_or(candle_core::Device::Cpu)
118 }
119 _ => candle_core::Device::Cpu,
120 }
121}
122
/// `ModelExecutor` adapter around a single boxed [`DecoderOnlyLLM`].
///
/// Every call into the model goes through one `parking_lot::Mutex`, so model
/// work is serialized. The time callers spend waiting for that lock is
/// accumulated and attached to `cache_metrics_snapshot` output under
/// `executor_model_lock`.
pub struct LlmExecutor {
    /// The wrapped model; all calls into it are serialized by this mutex.
    model: Mutex<Box<dyn DecoderOnlyLLM>>,
    /// Model metadata returned by `ModelExecutor::info`.
    info: ModelInfo,
    /// Counter used to mint `llm-cache-{n}` ids for sequences that arrive
    /// without a KV-cache handle.
    next_cache_id: AtomicU64,
    /// Sum of observed model-lock wait times, in whole microseconds.
    total_model_lock_wait_us: AtomicU64,
    /// Number of lock acquisitions folded into `total_model_lock_wait_us`.
    model_lock_wait_samples: AtomicU64,
}
130
131impl LlmExecutor {
132 pub fn new(model: Box<dyn DecoderOnlyLLM>, info: ModelInfo) -> Self {
133 Self {
134 model: Mutex::new(model),
135 info,
136 next_cache_id: AtomicU64::new(0),
137 total_model_lock_wait_us: AtomicU64::new(0),
138 model_lock_wait_samples: AtomicU64::new(0),
139 }
140 }
141
142 fn lock_model(&self) -> MutexGuard<'_, Box<dyn DecoderOnlyLLM>> {
143 let start = std::time::Instant::now();
144 let guard = self.model.lock();
145 self.record_model_lock_wait(start.elapsed());
146 guard
147 }
148
149 fn record_model_lock_wait(&self, duration: std::time::Duration) {
150 self.total_model_lock_wait_us.fetch_add(
151 duration.as_micros().min(u64::MAX as u128) as u64,
152 Ordering::Relaxed,
153 );
154 self.model_lock_wait_samples.fetch_add(1, Ordering::Relaxed);
155 }
156
157 fn model_lock_metrics_json(&self) -> serde_json::Value {
158 let samples = self.model_lock_wait_samples.load(Ordering::Relaxed);
159 let total_us = self.total_model_lock_wait_us.load(Ordering::Relaxed);
160 serde_json::json!({
161 "schema_version": 1,
162 "samples": samples,
163 "total_wait_time_us": total_us,
164 "avg_wait_time_ms": if samples == 0 {
165 0.0
166 } else {
167 total_us as f64 / samples as f64 / 1000.0
168 },
169 })
170 }
171
172 fn attach_model_lock_metrics(&self, mut snapshot: serde_json::Value) -> serde_json::Value {
173 let lock_metrics = self.model_lock_metrics_json();
174 if let Some(obj) = snapshot.as_object_mut() {
175 obj.insert("executor_model_lock".to_string(), lock_metrics);
176 snapshot
177 } else {
178 serde_json::json!({
179 "cache_metrics": snapshot,
180 "executor_model_lock": lock_metrics,
181 })
182 }
183 }
184
185 fn gen_cache_id(&self) -> String {
186 format!(
187 "llm-cache-{}",
188 self.next_cache_id.fetch_add(1, Ordering::Relaxed)
189 )
190 }
191
192 pub fn truncate_kv_for_cache_id(&self, cache_id: &str, new_len: usize) {
196 let mut model = self.lock_model();
197 model.truncate_kv(cache_id, new_len);
198 }
199}
200
201#[async_trait::async_trait]
202impl ModelExecutor for LlmExecutor {
203 fn info(&self) -> &ModelInfo {
204 &self.info
205 }
206
207 fn supports_native_unified_decode(&self) -> bool {
208 matches!(self.info.device, ferrum_types::Device::CUDA(_))
212 }
213
214 fn kv_capacity(&self) -> Option<usize> {
215 Some(self.lock_model().kv_capacity())
216 }
217
218 async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
219 let tokens = common::tensor_to_tokens(&input.input_ids)?;
220
221 let supplied_handle_id = input.kv_cache.as_ref().and_then(|h| {
226 h.as_any()
227 .downcast_ref::<GenericKvCacheHandle>()
228 .map(|g| g.request_cache_id().to_string())
229 });
230 let cache_id = supplied_handle_id
231 .clone()
232 .unwrap_or_else(|| self.gen_cache_id());
233
234 let prior_seq_len = input
237 .kv_cache
238 .as_ref()
239 .and_then(|h| h.as_any().downcast_ref::<GenericKvCacheHandle>())
240 .map(|g| {
241 use ferrum_interfaces::KvCacheHandle;
242 g.block_table().sequence_length
243 })
244 .unwrap_or(0);
245
246 let force_full_logits = metadata_requires_full_logits(&input.metadata);
251 let logits = {
252 let mut model = self.lock_model();
253 model.set_lora_adapter_for_cache(
254 &cache_id,
255 active_lora_from_metadata(&input.metadata)?,
256 )?;
257 if force_full_logits {
258 model.prefill(&cache_id, &tokens)
259 } else {
260 let unified_item = vec![(cache_id.clone(), tokens.clone(), prior_seq_len, true)];
261 match model.unified_forward(&unified_item) {
262 Ok(mut per_item) => per_item
263 .pop()
264 .flatten()
265 .ok_or_else(|| FerrumError::model("unified_forward returned no logits"))?,
266 Err(FerrumError::Unsupported { .. }) => model.prefill(&cache_id, &tokens),
267 Err(e) => return Err(e),
268 }
269 }
270 };
271
272 let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
274 .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
275 .unsqueeze(0)
276 .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?
277 .unsqueeze(0)
278 .map_err(|e| FerrumError::model(format!("unsqueeze2: {e}")))?;
279 let logits_ref = common::wrap_tensor(logits_tensor);
280
281 let cfg = self.lock_model().config().clone();
282 let seq_len = input
289 .kv_cache
290 .as_ref()
291 .and_then(|h| h.as_any().downcast_ref::<GenericKvCacheHandle>())
292 .map(|g| {
293 use ferrum_interfaces::KvCacheHandle;
294 g.block_table().sequence_length + tokens.len()
295 })
296 .unwrap_or(tokens.len());
297
298 let kv_handle = Arc::new(GenericKvCacheHandle::new(
299 cfg.num_layers,
300 cfg.num_kv_heads,
301 cfg.head_dim,
302 candle_core::Device::Cpu,
303 seq_len,
304 cache_id,
305 ));
306
307 Ok(PrefillOutput::new(logits_ref, kv_handle))
308 }
309
310 async fn batch_prefill(&self, inputs: &[PrefillInput]) -> Result<Vec<PrefillOutput>> {
317 if inputs.is_empty() {
318 return Ok(Vec::new());
319 }
320 let force_full_logits = inputs
321 .iter()
322 .any(|input| metadata_requires_full_logits(&input.metadata));
323
324 let mut cache_ids = Vec::with_capacity(inputs.len());
328 let mut prior_seq_lens = Vec::with_capacity(inputs.len());
329 let mut tokens_per_input = Vec::with_capacity(inputs.len());
330 let mut lora_per_input = Vec::with_capacity(inputs.len());
331 for input in inputs {
332 let tokens = common::tensor_to_tokens(&input.input_ids)?;
333 let supplied_handle_id = input.kv_cache.as_ref().and_then(|h| {
334 h.as_any()
335 .downcast_ref::<GenericKvCacheHandle>()
336 .map(|g| g.request_cache_id().to_string())
337 });
338 let cache_id = supplied_handle_id
339 .clone()
340 .unwrap_or_else(|| self.gen_cache_id());
341 let prior_seq_len = input
342 .kv_cache
343 .as_ref()
344 .and_then(|h| h.as_any().downcast_ref::<GenericKvCacheHandle>())
345 .map(|g| {
346 use ferrum_interfaces::KvCacheHandle;
347 g.block_table().sequence_length
348 })
349 .unwrap_or(0);
350 cache_ids.push(cache_id);
351 prior_seq_lens.push(prior_seq_len);
352 tokens_per_input.push(tokens);
353 lora_per_input.push(active_lora_from_metadata(&input.metadata)?);
354 }
355
356 let unified_items: Vec<(String, Vec<u32>, usize, bool)> = cache_ids
359 .iter()
360 .zip(tokens_per_input.iter())
361 .zip(prior_seq_lens.iter())
362 .map(|((cid, toks), &prior)| (cid.clone(), toks.clone(), prior, true))
363 .collect();
364
365 let nb_prof = llm_executor_runtime_env().batch_prefill_prof;
366 let bp_t0 = if nb_prof {
367 Some(std::time::Instant::now())
368 } else {
369 None
370 };
371 let mut took_fallback = false;
372 let per_item_logits: Vec<Vec<f32>> = {
373 let mut model = self.lock_model();
374 for (cache_id, adapter) in cache_ids.iter().zip(lora_per_input.iter()) {
375 model.set_lora_adapter_for_cache(cache_id, adapter.clone())?;
376 }
377 if force_full_logits {
378 took_fallback = true;
379 let mut out = Vec::with_capacity(inputs.len());
380 for (cid, toks) in cache_ids.iter().zip(tokens_per_input.iter()) {
381 out.push(model.prefill(cid, toks));
382 }
383 out
384 } else {
385 match model.unified_forward(&unified_items) {
386 Ok(per_item) => per_item
387 .into_iter()
388 .map(|opt| opt.expect("is_final_chunk=true must yield logits"))
389 .collect(),
390 Err(FerrumError::Unsupported { .. }) => {
391 took_fallback = true;
392 let mut out = Vec::with_capacity(inputs.len());
393 for (cid, toks) in cache_ids.iter().zip(tokens_per_input.iter()) {
394 out.push(model.prefill(cid, toks));
395 }
396 out
397 }
398 Err(e) => return Err(e),
399 }
400 }
401 };
402 if let Some(t0) = bp_t0 {
403 let total_q: usize = unified_items.iter().map(|it| it.1.len()).sum();
404 eprintln!(
405 "[batch-prefill] n_items={} total_q={} fallback={} elapsed={}us",
406 inputs.len(),
407 total_q,
408 took_fallback,
409 t0.elapsed().as_micros()
410 );
411 }
412
413 let cfg = self.lock_model().config().clone();
414 let mut outputs = Vec::with_capacity(inputs.len());
415 for (i, logits) in per_item_logits.into_iter().enumerate() {
416 let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
417 .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
418 .unsqueeze(0)
419 .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?
420 .unsqueeze(0)
421 .map_err(|e| FerrumError::model(format!("unsqueeze2: {e}")))?;
422 let logits_ref = common::wrap_tensor(logits_tensor);
423 let seq_len = inputs[i]
424 .kv_cache
425 .as_ref()
426 .and_then(|h| h.as_any().downcast_ref::<GenericKvCacheHandle>())
427 .map(|g| {
428 use ferrum_interfaces::KvCacheHandle;
429 g.block_table().sequence_length + tokens_per_input[i].len()
430 })
431 .unwrap_or(tokens_per_input[i].len());
432 let kv_handle = Arc::new(GenericKvCacheHandle::new(
433 cfg.num_layers,
434 cfg.num_kv_heads,
435 cfg.head_dim,
436 candle_core::Device::Cpu,
437 seq_len,
438 cache_ids[i].clone(),
439 ));
440 outputs.push(PrefillOutput::new(logits_ref, kv_handle));
441 }
442 Ok(outputs)
443 }
444
445 async fn truncate_kv(
446 &self,
447 kv_cache: &Arc<dyn ferrum_interfaces::KvCacheHandle>,
448 new_len: usize,
449 ) -> Result<()> {
450 if let Some(g) = kv_cache.as_any().downcast_ref::<GenericKvCacheHandle>() {
451 let cache_id = g.request_cache_id();
452 self.lock_model().truncate_kv(cache_id, new_len);
453 }
454 Ok(())
455 }
456
457 async fn forward_verify(
458 &self,
459 inputs: &[ferrum_interfaces::model_executor::DecodeInput],
460 ) -> Result<Vec<ferrum_interfaces::model_executor::DecodeOutput>> {
461 if inputs.is_empty() {
462 return Ok(Vec::new());
463 }
464
465 let first_handle = inputs[0].kv_cache.clone();
468 let cache_id = first_handle
469 .as_any()
470 .downcast_ref::<GenericKvCacheHandle>()
471 .ok_or_else(|| {
472 FerrumError::model("forward_verify requires GenericKvCacheHandle input")
473 })?
474 .request_cache_id()
475 .to_string();
476 let start_seq = {
477 use ferrum_interfaces::KvCacheHandle;
478 first_handle.block_table().sequence_length
479 };
480
481 let mut token_ids: Vec<u32> = Vec::with_capacity(inputs.len());
483 for input in inputs {
484 let toks = common::tensor_to_tokens(&input.input_ids)?;
485 if toks.is_empty() {
486 return Err(FerrumError::model("forward_verify input token empty"));
487 }
488 token_ids.push(toks[0]);
489 }
490
491 let flat = {
493 let mut model = self.lock_model();
494 model.set_lora_adapter_for_cache(
495 &cache_id,
496 active_lora_from_metadata(&inputs[0].metadata)?,
497 )?;
498 model.forward_verify(&cache_id, &token_ids)
499 };
500
501 let cfg = self.lock_model().config().clone();
502 let vocab = cfg.vocab_size;
503
504 let candle_device = ferrum_device_to_candle(&self.info.device);
509
510 let mut outputs = Vec::with_capacity(inputs.len());
515 for (i, _) in inputs.iter().enumerate() {
516 let row = &flat[i * vocab..(i + 1) * vocab];
517 let logits_tensor = candle_core::Tensor::new(row, &candle_core::Device::Cpu)
518 .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
519 .unsqueeze(0)
520 .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
521 let logits_ref = common::wrap_tensor(logits_tensor);
522 let handle = Arc::new(GenericKvCacheHandle::new(
523 cfg.num_layers,
524 cfg.num_kv_heads,
525 cfg.head_dim,
526 candle_device.clone(),
527 start_seq + i + 1,
528 cache_id.clone(),
529 ));
530 outputs.push(ferrum_interfaces::model_executor::DecodeOutput::new(
531 logits_ref, handle,
532 ));
533 }
534 Ok(outputs)
535 }
536
537 async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput> {
538 let input_handle = input
539 .kv_cache
540 .as_any()
541 .downcast_ref::<GenericKvCacheHandle>()
542 .ok_or_else(|| FerrumError::model("Invalid KV cache handle type"))?;
543
544 let cache_id = input_handle.request_cache_id().to_string();
545 let seq_len = {
546 use ferrum_interfaces::KvCacheHandle;
547 input_handle.block_table().sequence_length
548 };
549
550 let tokens = common::tensor_to_tokens(&input.input_ids)?;
551 if tokens.is_empty() {
552 return Err(FerrumError::model("Decode input is empty"));
553 }
554 let token = tokens[0];
555
556 debug!("LlmExecutor decode: token={token}, pos={seq_len}");
557
558 let force_full_logits = metadata_requires_full_logits(&input.metadata);
561 let logits = {
562 let mut model = self.lock_model();
563 model.set_lora_adapter_for_cache(
564 &cache_id,
565 active_lora_from_metadata(&input.metadata)?,
566 )?;
567 if force_full_logits {
568 model.decode(&cache_id, token, seq_len as u32)
569 } else {
570 let unified_item = vec![(cache_id.clone(), vec![token], seq_len, true)];
571 match model.unified_forward(&unified_item) {
572 Ok(mut per_item) => per_item
573 .pop()
574 .flatten()
575 .ok_or_else(|| FerrumError::model("unified_forward returned no logits"))?,
576 Err(FerrumError::Unsupported { .. }) => {
577 model.decode(&cache_id, token, seq_len as u32)
578 }
579 Err(e) => return Err(e),
580 }
581 }
582 };
583
584 let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
585 .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
586 .unsqueeze(0)
587 .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
588 let logits_ref = common::wrap_tensor(logits_tensor);
589
590 let kv_handle = Arc::new(input_handle.with_sequence_length(seq_len + 1));
591 Ok(DecodeOutput::new(logits_ref, kv_handle))
592 }
593
594 async fn batch_decode(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>> {
599 if inputs.is_empty() {
600 return Ok(Vec::new());
601 }
602 let prof = llm_executor_runtime_env().batch_decode_prof;
603 let t0 = if prof {
604 Some(std::time::Instant::now())
605 } else {
606 None
607 };
608 struct Prep {
611 cache_id: String,
612 token: u32,
613 seq_len: u32,
614 lora: Option<ActiveLoraAdapter>,
615 requires_full_logits: bool,
616 handle: Arc<GenericKvCacheHandle>,
617 }
618 let mut prepped: Vec<Prep> = Vec::with_capacity(inputs.len());
619 for input in inputs {
620 let input_handle = input
621 .kv_cache
622 .as_any()
623 .downcast_ref::<GenericKvCacheHandle>()
624 .ok_or_else(|| FerrumError::model("Invalid KV cache handle type"))?;
625 use ferrum_interfaces::KvCacheHandle;
626 let seq_len = input_handle.block_table().sequence_length as u32;
627 let tokens = common::tensor_to_tokens(&input.input_ids)?;
628 if tokens.is_empty() {
629 return Err(FerrumError::model("Decode input is empty"));
630 }
631 prepped.push(Prep {
632 cache_id: input_handle.request_cache_id().to_string(),
633 token: tokens[0],
634 seq_len,
635 lora: active_lora_from_metadata(&input.metadata)?,
636 requires_full_logits: metadata_requires_full_logits(&input.metadata),
637 handle: Arc::new(input_handle.with_sequence_length((seq_len + 1) as usize)),
638 });
639 }
640 let t_prep = if prof {
641 Some(std::time::Instant::now())
642 } else {
643 None
644 };
645
646 let (all_logits, t_lock_acq, t_model_call): (Vec<Vec<f32>>, _, _) = {
652 let lock_t0 = if prof {
653 Some(std::time::Instant::now())
654 } else {
655 None
656 };
657 let mut model = self.lock_model();
658 let lock_acq = lock_t0.map(|t| t.elapsed());
659 let model_t0 = if prof {
660 Some(std::time::Instant::now())
661 } else {
662 None
663 };
664 for p in &prepped {
665 model.set_lora_adapter_for_cache(&p.cache_id, p.lora.clone())?;
666 }
667 let unified_items: Vec<(String, Vec<u32>, usize, bool)> = prepped
668 .iter()
669 .map(|p| (p.cache_id.clone(), vec![p.token], p.seq_len as usize, true))
670 .collect();
671 let tuples: Vec<(String, u32, u32)> = prepped
672 .iter()
673 .map(|p| (p.cache_id.clone(), p.token, p.seq_len))
674 .collect();
675 let force_full_logits = prepped.iter().any(|p| p.requires_full_logits);
676 let logits = if force_full_logits {
677 model.decode_batch_with_full_logits(&tuples, true)
678 } else {
679 match model.unified_forward(&unified_items) {
680 Ok(per_item) => {
681 if per_item.len() != prepped.len() {
682 return Err(FerrumError::model(format!(
683 "unified_forward returned {} entries for {} items",
684 per_item.len(),
685 prepped.len(),
686 )));
687 }
688 let mut out = Vec::with_capacity(prepped.len());
689 for (i, opt) in per_item.into_iter().enumerate() {
690 out.push(opt.ok_or_else(|| {
691 FerrumError::model(format!(
692 "unified_forward returned None for decode item {i}"
693 ))
694 })?);
695 }
696 out
697 }
698 Err(FerrumError::Unsupported { .. }) => {
699 model.decode_batch_with_full_logits(&tuples, false)
700 }
701 Err(e) => return Err(e),
702 }
703 };
704 let model_call = model_t0.map(|t| t.elapsed());
705 (logits, lock_acq, model_call)
706 };
707 let t_model_done = if prof {
708 Some(std::time::Instant::now())
709 } else {
710 None
711 };
712
713 let m_count = prepped.len();
714 let mut outputs = Vec::with_capacity(m_count);
715 for (p, logits) in prepped.into_iter().zip(all_logits.into_iter()) {
716 debug!(
717 "LlmExecutor batch_decode: token={}, pos={}",
718 p.token, p.seq_len
719 );
720 let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
721 .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
722 .unsqueeze(0)
723 .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
724 let logits_ref = common::wrap_tensor(logits_tensor);
725 outputs.push(DecodeOutput::new(logits_ref, p.handle));
726 }
727 if let (Some(t0), Some(tp), Some(tm)) = (t0, t_prep, t_model_done) {
728 static EX_PROF_CALLS: std::sync::atomic::AtomicU64 =
729 std::sync::atomic::AtomicU64::new(0);
730 let n = EX_PROF_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
731 if n.is_multiple_of(8) {
732 let total = t0.elapsed().as_micros();
733 let prep = tp.duration_since(t0).as_micros();
734 let lock_acq = t_lock_acq.map(|d| d.as_micros()).unwrap_or(0);
735 let model_call = t_model_call.map(|d| d.as_micros()).unwrap_or(0);
736 let model_block = tm.duration_since(tp).as_micros();
737 let wrap = tm.elapsed().as_micros();
738 eprintln!(
739 "[exec-batch-decode-prof] call#{} m={} total={}us prep={}us model_block={}us(lock_acq={}us model_call={}us) wrap={}us",
740 n, m_count, total, prep, model_block, lock_acq, model_call, wrap,
741 );
742 }
743 }
744 Ok(outputs)
745 }
746
747 async fn unified_decode(&self, batch: &UnifiedBatch) -> Result<Vec<Option<Vec<f32>>>> {
764 let mut results: Vec<Option<Vec<f32>>> = vec![None; batch.items.len()];
765 if batch.items.is_empty() {
766 return Ok(results);
767 }
768
769 let unified_items: Vec<(String, Vec<u32>, usize, bool)> = batch
776 .items
777 .iter()
778 .map(|it| {
779 (
780 it.seq_id.clone(),
781 it.q_tokens.clone(),
782 it.pos_offset,
783 it.is_final_chunk,
784 )
785 })
786 .collect();
787 let force_full_logits = batch
788 .items
789 .iter()
790 .any(|item| metadata_requires_full_logits(&item.metadata));
791 if !force_full_logits {
792 let model_result = {
793 let mut model = self.lock_model();
794 for item in &batch.items {
795 model.set_lora_adapter_for_cache(
796 &item.seq_id,
797 active_lora_from_metadata(&item.metadata)?,
798 )?;
799 }
800 model.unified_forward(&unified_items)
801 };
802 match model_result {
803 Ok(per_item) => {
804 if per_item.len() != batch.items.len() {
805 return Err(FerrumError::model(format!(
806 "unified_forward returned {} entries for {} items",
807 per_item.len(),
808 batch.items.len(),
809 )));
810 }
811 return Ok(per_item);
812 }
813 Err(FerrumError::Unsupported { .. }) => {
814 }
816 Err(e) => return Err(e),
817 }
818 }
819
820 let mut prefill_indices: Vec<usize> = Vec::new();
826 let mut decode_indices: Vec<usize> = Vec::new();
827 for (i, item) in batch.items.iter().enumerate() {
828 if item.q_tokens.len() == 1 && item.is_final_chunk {
829 decode_indices.push(i);
830 } else {
831 prefill_indices.push(i);
832 }
833 }
834
835 if !prefill_indices.is_empty() {
840 let mut model = self.lock_model();
841 for &i in &prefill_indices {
842 let item = &batch.items[i];
843 model.set_lora_adapter_for_cache(
844 &item.seq_id,
845 active_lora_from_metadata(&item.metadata)?,
846 )?;
847 let logits = model.prefill(&item.seq_id, &item.q_tokens);
848 if item.is_final_chunk {
849 results[i] = Some(logits);
850 }
851 }
852 }
853
854 if !decode_indices.is_empty() {
856 let tuples: Vec<(String, u32, u32)> = decode_indices
857 .iter()
858 .map(|&i| {
859 let it = &batch.items[i];
860 (it.seq_id.clone(), it.q_tokens[0], it.pos_offset as u32)
861 })
862 .collect();
863 let logits_vec = {
864 let mut model = self.lock_model();
865 for &i in &decode_indices {
866 let item = &batch.items[i];
867 model.set_lora_adapter_for_cache(
868 &item.seq_id,
869 active_lora_from_metadata(&item.metadata)?,
870 )?;
871 }
872 let force_full_logits = decode_indices
873 .iter()
874 .any(|&i| metadata_requires_full_logits(&batch.items[i].metadata));
875 model.decode_batch_with_full_logits(&tuples, force_full_logits)
876 };
877 for (j, &i) in decode_indices.iter().enumerate() {
878 results[i] = Some(logits_vec[j].clone());
879 }
880 }
881
882 Ok(results)
883 }
884
885 fn release_cache(&self, cache_id: &str) {
886 self.lock_model().release(cache_id);
887 }
888
889 fn capabilities(&self) -> ExecutorCapabilities {
890 let cfg = self.lock_model().config().clone();
891 ExecutorCapabilities {
892 max_batch_size: 256,
893 max_sequence_length: cfg.max_seq_len,
894 attention_mechanisms: vec![AttentionType::GroupedQuery],
895 supports_dynamic_batching: true,
896 supports_continuous_batching: true,
897 supports_speculative_decoding: false,
898 supports_tensor_parallelism: false,
899 supports_pipeline_parallelism: false,
900 supported_dtypes: vec![DataType::FP32],
901 supported_devices: vec![self.info.device.clone()],
902 memory_requirements: MemoryRequirements {
903 parameter_memory: (self.info.num_parameters * 4) as u64,
904 activation_memory_per_token: cfg.hidden_size * 4,
905 kv_cache_memory_per_token: cfg.hidden_size * 2,
906 overhead_memory: 256 * 1024 * 1024,
907 },
908 }
909 }
910
911 fn status(&self) -> ExecutorStatus {
912 common::default_executor_status()
913 }
914
915 fn cache_metrics_snapshot(&self) -> Option<serde_json::Value> {
916 let snapshot = self.lock_model().cache_metrics_snapshot()?;
917 Some(self.attach_model_lock_metrics(snapshot))
918 }
919
920 fn lora_metrics_snapshot(&self) -> Option<serde_json::Value> {
921 self.lock_model().lora_metrics_snapshot()
922 }
923}
924
#[cfg(test)]
mod tests {
    // Unit tests for `LlmExecutor`. They drive the executor with an in-process
    // `RecordingLlm` stub that counts which model entry points were hit, so the
    // routing decisions (unified forward vs. prefill/decode fallbacks) can be
    // asserted without a real model.
    use super::*;
    use std::collections::HashMap;

    use ferrum_interfaces::model_executor::{DecodeInput, PrefillInput, UnifiedBatchItem};
    use ferrum_interfaces::KvCacheHandle;
    use ferrum_testkit::MockTensor;
    use ferrum_types::{Device, ModelId, ModelType};

    /// Call counters shared between a test and its `RecordingLlm`.
    #[derive(Default)]
    struct RecordingCalls {
        unified_forward: usize,
        prefill: usize,
        decode: usize,
        /// The `force_full_logits` flag of each `decode_batch_with_full_logits` call, in order.
        decode_batch_force_full_logits: Vec<bool>,
    }

    /// Test double for `DecoderOnlyLLM` that records calls and returns fixed logits.
    struct RecordingLlm {
        calls: Arc<Mutex<RecordingCalls>>,
        config: crate::common::LlmRuntimeConfig,
    }

    impl RecordingLlm {
        /// Builds a stub with a tiny config (vocab 4, max sequence length 16).
        fn new(calls: Arc<Mutex<RecordingCalls>>) -> Self {
            Self {
                calls,
                config: crate::common::LlmRuntimeConfig {
                    hidden_size: 4,
                    num_layers: 1,
                    num_kv_heads: 1,
                    head_dim: 4,
                    vocab_size: 4,
                    max_seq_len: 16,
                },
            }
        }
    }

    impl DecoderOnlyLLM for RecordingLlm {
        fn config(&self) -> &crate::common::LlmRuntimeConfig {
            &self.config
        }

        // Full-vocab (4-entry) logits, distinguishable from the unified path's `[99.0]`.
        fn prefill(&mut self, _cache_id: &str, _tokens: &[u32]) -> Vec<f32> {
            self.calls.lock().prefill += 1;
            vec![0.0, 1.0, 2.0, 3.0]
        }

        fn decode(&mut self, _cache_id: &str, _token: u32, _pos: u32) -> Vec<f32> {
            self.calls.lock().decode += 1;
            vec![3.0, 2.0, 1.0, 0.0]
        }

        fn decode_batch_with_full_logits(
            &mut self,
            batch: &[(String, u32, u32)],
            force_full_logits: bool,
        ) -> Vec<Vec<f32>> {
            self.calls
                .lock()
                .decode_batch_force_full_logits
                .push(force_full_logits);
            batch.iter().map(|_| vec![3.0, 2.0, 1.0, 0.0]).collect()
        }

        // Returns a single-element `[99.0]` row for final chunks, so a test can tell
        // from the logits length which path produced them.
        fn unified_forward(
            &mut self,
            items: &[(String, Vec<u32>, usize, bool)],
        ) -> std::result::Result<Vec<Option<Vec<f32>>>, FerrumError> {
            self.calls.lock().unified_forward += 1;
            Ok(items
                .iter()
                .map(|(_, _, _, is_final_chunk)| is_final_chunk.then_some(vec![99.0]))
                .collect())
        }

        fn release(&mut self, _cache_id: &str) {}

        fn cache_metrics_snapshot(&self) -> Option<serde_json::Value> {
            Some(serde_json::json!({
                "position": "recording-test-cache",
            }))
        }
    }

    /// Minimal CPU `ModelInfo` matching `RecordingLlm`'s config.
    fn test_model_info() -> ModelInfo {
        ModelInfo {
            model_id: ModelId("recording".to_string()),
            model_type: ModelType::Custom("recording".to_string()),
            num_parameters: 0,
            hidden_size: 4,
            num_layers: 1,
            num_heads: 1,
            num_kv_heads: 1,
            vocab_size: 4,
            max_sequence_length: 16,
            dtype: DataType::FP32,
            device: Device::CPU,
            version: None,
            license: None,
            metadata: HashMap::new(),
        }
    }

    fn recording_executor(calls: Arc<Mutex<RecordingCalls>>) -> LlmExecutor {
        LlmExecutor::new(Box::new(RecordingLlm::new(calls)), test_model_info())
    }

    /// Request metadata that sets `ferrum_require_full_logits: true`.
    fn full_logits_metadata() -> HashMap<String, serde_json::Value> {
        HashMap::from([(
            "ferrum_require_full_logits".to_string(),
            serde_json::json!(true),
        )])
    }

    /// A `GenericKvCacheHandle` for `cache_id` positioned at `seq_len`.
    fn test_kv_handle(cache_id: &str, seq_len: usize) -> Arc<dyn KvCacheHandle> {
        Arc::new(GenericKvCacheHandle::new(
            1,
            1,
            4,
            candle_core::Device::Cpu,
            seq_len,
            cache_id.to_string(),
        ))
    }

    /// Flags are enabled by presence alone; even "" and "0" turn them on.
    #[test]
    fn llm_executor_runtime_env_parses_profile_flags_by_presence() {
        let env = LlmExecutorRuntimeEnv::from_env_vars([
            ("FERRUM_BATCH_PREFILL_PROF", ""),
            ("FERRUM_BATCH_DECODE_PROF", "0"),
        ]);

        assert!(env.batch_prefill_prof);
        assert!(env.batch_decode_prof);
    }

    #[test]
    fn llm_executor_runtime_env_defaults_profile_flags_off() {
        let env = LlmExecutorRuntimeEnv::from_env_vars([("UNRELATED", "1")]);

        assert!(!env.batch_prefill_prof);
        assert!(!env.batch_decode_prof);
    }

    /// Full-logits prefill must bypass `unified_forward`; the 4-entry logits
    /// confirm the stub's `prefill` produced them.
    #[test]
    fn prefill_skips_unified_forward_when_full_logits_required() {
        let calls = Arc::new(Mutex::new(RecordingCalls::default()));
        let executor = recording_executor(calls.clone());
        let input = PrefillInput::new(MockTensor::from_u32(&[1, 2], &[2]).into_ref())
            .with_metadata(full_logits_metadata());

        let output = tokio_test::block_on(executor.prefill(&input)).unwrap();

        assert_eq!(
            output
                .last_token_logits()
                .unwrap()
                .to_vec_f32()
                .unwrap()
                .len(),
            4
        );
        let calls = calls.lock();
        assert_eq!(calls.unified_forward, 0);
        assert_eq!(calls.prefill, 1);
    }

    #[test]
    fn decode_skips_unified_forward_when_full_logits_required() {
        let calls = Arc::new(Mutex::new(RecordingCalls::default()));
        let executor = recording_executor(calls.clone());
        let input = DecodeInput::new(
            MockTensor::from_u32(&[7], &[1]).into_ref(),
            test_kv_handle("decode-cache", 3),
        )
        .with_metadata(full_logits_metadata());

        let output = tokio_test::block_on(executor.decode(&input)).unwrap();

        assert_eq!(output.logits.to_vec_f32().unwrap().len(), 4);
        let calls = calls.lock();
        assert_eq!(calls.unified_forward, 0);
        assert_eq!(calls.decode, 1);
    }

    /// A single-token final chunk needing full logits is routed to
    /// `decode_batch_with_full_logits` with `force_full_logits = true`.
    #[test]
    fn unified_decode_skips_unified_forward_when_full_logits_required() {
        let calls = Arc::new(Mutex::new(RecordingCalls::default()));
        let executor = recording_executor(calls.clone());
        let mut batch = UnifiedBatch::new();
        batch.items.push(UnifiedBatchItem {
            seq_id: "decode-cache".to_string(),
            q_tokens: vec![7],
            kv_cache: test_kv_handle("decode-cache", 3),
            pos_offset: 3,
            is_final_chunk: true,
            metadata: full_logits_metadata(),
        });

        let output = tokio_test::block_on(executor.unified_decode(&batch)).unwrap();

        assert_eq!(output[0].as_ref().unwrap().len(), 4);
        let calls = calls.lock();
        assert_eq!(calls.unified_forward, 0);
        assert_eq!(calls.decode_batch_force_full_logits, vec![true]);
    }

    /// `kv_capacity` and `cache_metrics_snapshot` each take the model lock, so
    /// at least two wait samples should be reported alongside the stub's own
    /// cache metrics.
    #[test]
    fn cache_metrics_snapshot_includes_model_lock_wait_metrics() {
        let calls = Arc::new(Mutex::new(RecordingCalls::default()));
        let executor = recording_executor(calls);

        assert_eq!(executor.kv_capacity(), Some(16));
        let metrics = executor.cache_metrics_snapshot().unwrap();

        assert_eq!(metrics["position"], "recording-test-cache");
        assert_eq!(metrics["executor_model_lock"]["schema_version"], 1);
        assert!(
            metrics["executor_model_lock"]["samples"].as_u64().unwrap() >= 2,
            "metrics: {metrics}"
        );
        assert!(
            metrics["executor_model_lock"]["total_wait_time_us"]
                .as_u64()
                .is_some(),
            "metrics: {metrics}"
        );
        assert!(
            metrics["executor_model_lock"]["avg_wait_time_ms"].is_number(),
            "metrics: {metrics}"
        );
    }
}