1use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::{Arc, OnceLock};
15
16use parking_lot::Mutex;
17use tracing::debug;
18
19use ferrum_interfaces::{
20 model_executor::{
21 AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorStatus,
22 MemoryRequirements, PrefillInput, PrefillOutput, UnifiedBatch,
23 },
24 ModelExecutor,
25};
26use ferrum_types::{DataType, FerrumError, ModelInfo, Result};
27
28use crate::common::DecoderOnlyLLM;
29use crate::lora::ActiveLoraAdapter;
30
31use super::common::{self, GenericKvCacheHandle};
32
/// Profiling switches for the LLM executor, derived from environment variables.
///
/// Both flags are presence-based: a flag is on whenever its variable exists,
/// whatever the value (an empty string or `"0"` still enables it).
#[derive(Debug, Clone, PartialEq, Eq)]
struct LlmExecutorRuntimeEnv {
    /// Set when `FERRUM_BATCH_PREFILL_PROF` is present.
    batch_prefill_prof: bool,
    /// Set when `FERRUM_BATCH_DECODE_PROF` is present.
    batch_decode_prof: bool,
}

impl LlmExecutorRuntimeEnv {
    /// Builds the flags from the current process environment.
    ///
    /// Reads the environment through `std::env::vars_os` because
    /// `std::env::vars` panics as soon as it meets a variable whose name or
    /// value is not valid Unicode. Variables with a non-Unicode name are
    /// skipped (they can never match a flag name); values are never inspected.
    fn from_env() -> Self {
        let unicode_named = std::env::vars_os()
            .filter_map(|(key, value)| key.into_string().ok().map(|key| (key, value)));
        Self::from_env_vars(unicode_named)
    }

    /// Builds the flags from an arbitrary sequence of `(name, value)` pairs.
    ///
    /// Only the names are examined; the values are ignored entirely.
    fn from_env_vars<I, K, V>(vars: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
    {
        let mut batch_prefill_prof = false;
        let mut batch_decode_prof = false;

        for (key, _) in vars {
            match key.as_ref() {
                "FERRUM_BATCH_PREFILL_PROF" => batch_prefill_prof = true,
                "FERRUM_BATCH_DECODE_PROF" => batch_decode_prof = true,
                _ => {}
            }
        }

        Self {
            batch_prefill_prof,
            batch_decode_prof,
        }
    }
}
66
67fn llm_executor_runtime_env() -> &'static LlmExecutorRuntimeEnv {
68 static CONFIG: OnceLock<LlmExecutorRuntimeEnv> = OnceLock::new();
69 CONFIG.get_or_init(LlmExecutorRuntimeEnv::from_env)
70}
71
72fn active_lora_from_metadata(
73 metadata: &std::collections::HashMap<String, serde_json::Value>,
74) -> Result<Option<ActiveLoraAdapter>> {
75 let name = metadata
76 .get("ferrum_lora_adapter")
77 .and_then(|value| value.as_str());
78 let path = metadata
79 .get("ferrum_lora_path")
80 .and_then(|value| value.as_str());
81 match (name, path) {
82 (Some(name), Some(path)) => Ok(Some(ActiveLoraAdapter {
83 name: name.to_string(),
84 path: std::path::PathBuf::from(path),
85 })),
86 (None, None) => Ok(None),
87 _ => Err(FerrumError::model(
88 "incomplete LoRA metadata: expected ferrum_lora_adapter and ferrum_lora_path",
89 )),
90 }
91}
92
93fn metadata_requires_full_logits(
94 metadata: &std::collections::HashMap<String, serde_json::Value>,
95) -> bool {
96 metadata
97 .get("ferrum_require_full_logits")
98 .and_then(|value| value.as_bool())
99 .unwrap_or(false)
100}
101
102fn ferrum_device_to_candle(d: &ferrum_types::Device) -> candle_core::Device {
107 match d {
108 ferrum_types::Device::CPU => candle_core::Device::Cpu,
109 #[cfg(feature = "cuda")]
110 ferrum_types::Device::CUDA(i) => {
111 candle_core::Device::new_cuda(*i as usize).unwrap_or(candle_core::Device::Cpu)
112 }
113 #[cfg(not(feature = "cuda"))]
114 ferrum_types::Device::CUDA(_) => candle_core::Device::Cpu,
115 #[cfg(all(any(target_os = "macos", target_os = "ios"), feature = "metal"))]
116 ferrum_types::Device::Metal => {
117 candle_core::Device::new_metal(0).unwrap_or(candle_core::Device::Cpu)
118 }
119 _ => candle_core::Device::Cpu,
120 }
121}
122
123pub struct LlmExecutor {
124 model: Mutex<Box<dyn DecoderOnlyLLM>>,
125 info: ModelInfo,
126 next_cache_id: AtomicU64,
127}
128
129impl LlmExecutor {
130 pub fn new(model: Box<dyn DecoderOnlyLLM>, info: ModelInfo) -> Self {
131 Self {
132 model: Mutex::new(model),
133 info,
134 next_cache_id: AtomicU64::new(0),
135 }
136 }
137
138 fn gen_cache_id(&self) -> String {
139 format!(
140 "llm-cache-{}",
141 self.next_cache_id.fetch_add(1, Ordering::Relaxed)
142 )
143 }
144
145 pub fn truncate_kv_for_cache_id(&self, cache_id: &str, new_len: usize) {
149 let mut model = self.model.lock();
150 model.truncate_kv(cache_id, new_len);
151 }
152}
153
154#[async_trait::async_trait]
155impl ModelExecutor for LlmExecutor {
    /// Returns the model metadata stored in this executor.
    fn info(&self) -> &ModelInfo {
        &self.info
    }
159
    /// Reports the wrapped model's `kv_capacity()`; the result is always `Some`.
    ///
    /// Takes the model lock for the duration of the call.
    fn kv_capacity(&self) -> Option<usize> {
        Some(self.model.lock().kv_capacity())
    }
163
164 async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
165 let tokens = common::tensor_to_tokens(&input.input_ids)?;
166
167 let supplied_handle_id = input.kv_cache.as_ref().and_then(|h| {
172 h.as_any()
173 .downcast_ref::<GenericKvCacheHandle>()
174 .map(|g| g.request_cache_id().to_string())
175 });
176 let cache_id = supplied_handle_id
177 .clone()
178 .unwrap_or_else(|| self.gen_cache_id());
179
180 let prior_seq_len = input
183 .kv_cache
184 .as_ref()
185 .and_then(|h| h.as_any().downcast_ref::<GenericKvCacheHandle>())
186 .map(|g| {
187 use ferrum_interfaces::KvCacheHandle;
188 g.block_table().sequence_length
189 })
190 .unwrap_or(0);
191
192 let logits = {
198 let mut model = self.model.lock();
199 model.set_lora_adapter_for_cache(
200 &cache_id,
201 active_lora_from_metadata(&input.metadata)?,
202 )?;
203 let unified_item = vec![(cache_id.clone(), tokens.clone(), prior_seq_len, true)];
204 match model.unified_forward(&unified_item) {
205 Ok(mut per_item) => per_item
206 .pop()
207 .flatten()
208 .ok_or_else(|| FerrumError::model("unified_forward returned no logits"))?,
209 Err(FerrumError::Unsupported { .. }) => model.prefill(&cache_id, &tokens),
210 Err(e) => return Err(e),
211 }
212 };
213
214 let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
216 .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
217 .unsqueeze(0)
218 .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?
219 .unsqueeze(0)
220 .map_err(|e| FerrumError::model(format!("unsqueeze2: {e}")))?;
221 let logits_ref = common::wrap_tensor(logits_tensor);
222
223 let cfg = self.model.lock().config().clone();
224 let seq_len = input
231 .kv_cache
232 .as_ref()
233 .and_then(|h| h.as_any().downcast_ref::<GenericKvCacheHandle>())
234 .map(|g| {
235 use ferrum_interfaces::KvCacheHandle;
236 g.block_table().sequence_length + tokens.len()
237 })
238 .unwrap_or(tokens.len());
239
240 let kv_handle = Arc::new(GenericKvCacheHandle::new(
241 cfg.num_layers,
242 cfg.num_kv_heads,
243 cfg.head_dim,
244 candle_core::Device::Cpu,
245 seq_len,
246 cache_id,
247 ));
248
249 Ok(PrefillOutput::new(logits_ref, kv_handle))
250 }
251
252 async fn batch_prefill(&self, inputs: &[PrefillInput]) -> Result<Vec<PrefillOutput>> {
259 if inputs.is_empty() {
260 return Ok(Vec::new());
261 }
262
263 let mut cache_ids = Vec::with_capacity(inputs.len());
267 let mut prior_seq_lens = Vec::with_capacity(inputs.len());
268 let mut tokens_per_input = Vec::with_capacity(inputs.len());
269 let mut lora_per_input = Vec::with_capacity(inputs.len());
270 for input in inputs {
271 let tokens = common::tensor_to_tokens(&input.input_ids)?;
272 let supplied_handle_id = input.kv_cache.as_ref().and_then(|h| {
273 h.as_any()
274 .downcast_ref::<GenericKvCacheHandle>()
275 .map(|g| g.request_cache_id().to_string())
276 });
277 let cache_id = supplied_handle_id
278 .clone()
279 .unwrap_or_else(|| self.gen_cache_id());
280 let prior_seq_len = input
281 .kv_cache
282 .as_ref()
283 .and_then(|h| h.as_any().downcast_ref::<GenericKvCacheHandle>())
284 .map(|g| {
285 use ferrum_interfaces::KvCacheHandle;
286 g.block_table().sequence_length
287 })
288 .unwrap_or(0);
289 cache_ids.push(cache_id);
290 prior_seq_lens.push(prior_seq_len);
291 tokens_per_input.push(tokens);
292 lora_per_input.push(active_lora_from_metadata(&input.metadata)?);
293 }
294
295 let unified_items: Vec<(String, Vec<u32>, usize, bool)> = cache_ids
298 .iter()
299 .zip(tokens_per_input.iter())
300 .zip(prior_seq_lens.iter())
301 .map(|((cid, toks), &prior)| (cid.clone(), toks.clone(), prior, true))
302 .collect();
303
304 let nb_prof = llm_executor_runtime_env().batch_prefill_prof;
305 let bp_t0 = if nb_prof {
306 Some(std::time::Instant::now())
307 } else {
308 None
309 };
310 let mut took_fallback = false;
311 let per_item_logits: Vec<Vec<f32>> = {
312 let mut model = self.model.lock();
313 for (cache_id, adapter) in cache_ids.iter().zip(lora_per_input.iter()) {
314 model.set_lora_adapter_for_cache(cache_id, adapter.clone())?;
315 }
316 match model.unified_forward(&unified_items) {
317 Ok(per_item) => per_item
318 .into_iter()
319 .map(|opt| opt.expect("is_final_chunk=true must yield logits"))
320 .collect(),
321 Err(FerrumError::Unsupported { .. }) => {
322 took_fallback = true;
323 let mut out = Vec::with_capacity(inputs.len());
324 for (cid, toks) in cache_ids.iter().zip(tokens_per_input.iter()) {
325 out.push(model.prefill(cid, toks));
326 }
327 out
328 }
329 Err(e) => return Err(e),
330 }
331 };
332 if let Some(t0) = bp_t0 {
333 let total_q: usize = unified_items.iter().map(|it| it.1.len()).sum();
334 eprintln!(
335 "[batch-prefill] n_items={} total_q={} fallback={} elapsed={}us",
336 inputs.len(),
337 total_q,
338 took_fallback,
339 t0.elapsed().as_micros()
340 );
341 }
342
343 let cfg = self.model.lock().config().clone();
344 let mut outputs = Vec::with_capacity(inputs.len());
345 for (i, logits) in per_item_logits.into_iter().enumerate() {
346 let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
347 .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
348 .unsqueeze(0)
349 .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?
350 .unsqueeze(0)
351 .map_err(|e| FerrumError::model(format!("unsqueeze2: {e}")))?;
352 let logits_ref = common::wrap_tensor(logits_tensor);
353 let seq_len = inputs[i]
354 .kv_cache
355 .as_ref()
356 .and_then(|h| h.as_any().downcast_ref::<GenericKvCacheHandle>())
357 .map(|g| {
358 use ferrum_interfaces::KvCacheHandle;
359 g.block_table().sequence_length + tokens_per_input[i].len()
360 })
361 .unwrap_or(tokens_per_input[i].len());
362 let kv_handle = Arc::new(GenericKvCacheHandle::new(
363 cfg.num_layers,
364 cfg.num_kv_heads,
365 cfg.head_dim,
366 candle_core::Device::Cpu,
367 seq_len,
368 cache_ids[i].clone(),
369 ));
370 outputs.push(PrefillOutput::new(logits_ref, kv_handle));
371 }
372 Ok(outputs)
373 }
374
375 async fn truncate_kv(
376 &self,
377 kv_cache: &Arc<dyn ferrum_interfaces::KvCacheHandle>,
378 new_len: usize,
379 ) -> Result<()> {
380 if let Some(g) = kv_cache.as_any().downcast_ref::<GenericKvCacheHandle>() {
381 let cache_id = g.request_cache_id();
382 self.model.lock().truncate_kv(cache_id, new_len);
383 }
384 Ok(())
385 }
386
387 async fn forward_verify(
388 &self,
389 inputs: &[ferrum_interfaces::model_executor::DecodeInput],
390 ) -> Result<Vec<ferrum_interfaces::model_executor::DecodeOutput>> {
391 if inputs.is_empty() {
392 return Ok(Vec::new());
393 }
394
395 let first_handle = inputs[0].kv_cache.clone();
398 let cache_id = first_handle
399 .as_any()
400 .downcast_ref::<GenericKvCacheHandle>()
401 .ok_or_else(|| {
402 FerrumError::model("forward_verify requires GenericKvCacheHandle input")
403 })?
404 .request_cache_id()
405 .to_string();
406 let start_seq = {
407 use ferrum_interfaces::KvCacheHandle;
408 first_handle.block_table().sequence_length
409 };
410
411 let mut token_ids: Vec<u32> = Vec::with_capacity(inputs.len());
413 for input in inputs {
414 let toks = common::tensor_to_tokens(&input.input_ids)?;
415 if toks.is_empty() {
416 return Err(FerrumError::model("forward_verify input token empty"));
417 }
418 token_ids.push(toks[0]);
419 }
420
421 let flat = {
423 let mut model = self.model.lock();
424 model.set_lora_adapter_for_cache(
425 &cache_id,
426 active_lora_from_metadata(&inputs[0].metadata)?,
427 )?;
428 model.forward_verify(&cache_id, &token_ids)
429 };
430
431 let cfg = self.model.lock().config().clone();
432 let vocab = cfg.vocab_size;
433
434 let candle_device = ferrum_device_to_candle(&self.info.device);
439
440 let mut outputs = Vec::with_capacity(inputs.len());
445 for (i, _) in inputs.iter().enumerate() {
446 let row = &flat[i * vocab..(i + 1) * vocab];
447 let logits_tensor = candle_core::Tensor::new(row, &candle_core::Device::Cpu)
448 .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
449 .unsqueeze(0)
450 .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
451 let logits_ref = common::wrap_tensor(logits_tensor);
452 let handle = Arc::new(GenericKvCacheHandle::new(
453 cfg.num_layers,
454 cfg.num_kv_heads,
455 cfg.head_dim,
456 candle_device.clone(),
457 start_seq + i + 1,
458 cache_id.clone(),
459 ));
460 outputs.push(ferrum_interfaces::model_executor::DecodeOutput::new(
461 logits_ref, handle,
462 ));
463 }
464 Ok(outputs)
465 }
466
467 async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput> {
468 let input_handle = input
469 .kv_cache
470 .as_any()
471 .downcast_ref::<GenericKvCacheHandle>()
472 .ok_or_else(|| FerrumError::model("Invalid KV cache handle type"))?;
473
474 let cache_id = input_handle.request_cache_id().to_string();
475 let seq_len = {
476 use ferrum_interfaces::KvCacheHandle;
477 input_handle.block_table().sequence_length
478 };
479
480 let tokens = common::tensor_to_tokens(&input.input_ids)?;
481 if tokens.is_empty() {
482 return Err(FerrumError::model("Decode input is empty"));
483 }
484 let token = tokens[0];
485
486 debug!("LlmExecutor decode: token={token}, pos={seq_len}");
487
488 let logits = {
493 let mut model = self.model.lock();
494 model.set_lora_adapter_for_cache(
495 &cache_id,
496 active_lora_from_metadata(&input.metadata)?,
497 )?;
498 let unified_item = vec![(cache_id.clone(), vec![token], seq_len, true)];
499 match model.unified_forward(&unified_item) {
500 Ok(mut per_item) => per_item
501 .pop()
502 .flatten()
503 .ok_or_else(|| FerrumError::model("unified_forward returned no logits"))?,
504 Err(FerrumError::Unsupported { .. }) => {
505 model.decode(&cache_id, token, seq_len as u32)
506 }
507 Err(e) => return Err(e),
508 }
509 };
510
511 let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
512 .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
513 .unsqueeze(0)
514 .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
515 let logits_ref = common::wrap_tensor(logits_tensor);
516
517 let kv_handle = Arc::new(input_handle.with_sequence_length(seq_len + 1));
518 Ok(DecodeOutput::new(logits_ref, kv_handle))
519 }
520
521 async fn batch_decode(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>> {
526 if inputs.is_empty() {
527 return Ok(Vec::new());
528 }
529 let prof = llm_executor_runtime_env().batch_decode_prof;
530 let t0 = if prof {
531 Some(std::time::Instant::now())
532 } else {
533 None
534 };
535 struct Prep {
538 cache_id: String,
539 token: u32,
540 seq_len: u32,
541 lora: Option<ActiveLoraAdapter>,
542 requires_full_logits: bool,
543 handle: Arc<GenericKvCacheHandle>,
544 }
545 let mut prepped: Vec<Prep> = Vec::with_capacity(inputs.len());
546 for input in inputs {
547 let input_handle = input
548 .kv_cache
549 .as_any()
550 .downcast_ref::<GenericKvCacheHandle>()
551 .ok_or_else(|| FerrumError::model("Invalid KV cache handle type"))?;
552 use ferrum_interfaces::KvCacheHandle;
553 let seq_len = input_handle.block_table().sequence_length as u32;
554 let tokens = common::tensor_to_tokens(&input.input_ids)?;
555 if tokens.is_empty() {
556 return Err(FerrumError::model("Decode input is empty"));
557 }
558 prepped.push(Prep {
559 cache_id: input_handle.request_cache_id().to_string(),
560 token: tokens[0],
561 seq_len,
562 lora: active_lora_from_metadata(&input.metadata)?,
563 requires_full_logits: metadata_requires_full_logits(&input.metadata),
564 handle: Arc::new(input_handle.with_sequence_length((seq_len + 1) as usize)),
565 });
566 }
567 let t_prep = if prof {
568 Some(std::time::Instant::now())
569 } else {
570 None
571 };
572
573 let (all_logits, t_lock_acq, t_model_call): (Vec<Vec<f32>>, _, _) = {
579 let lock_t0 = if prof {
580 Some(std::time::Instant::now())
581 } else {
582 None
583 };
584 let mut model = self.model.lock();
585 let lock_acq = lock_t0.map(|t| t.elapsed());
586 let model_t0 = if prof {
587 Some(std::time::Instant::now())
588 } else {
589 None
590 };
591 for p in &prepped {
592 model.set_lora_adapter_for_cache(&p.cache_id, p.lora.clone())?;
593 }
594 let unified_items: Vec<(String, Vec<u32>, usize, bool)> = prepped
595 .iter()
596 .map(|p| (p.cache_id.clone(), vec![p.token], p.seq_len as usize, true))
597 .collect();
598 let logits = match model.unified_forward(&unified_items) {
599 Ok(per_item) => {
600 if per_item.len() != prepped.len() {
601 return Err(FerrumError::model(format!(
602 "unified_forward returned {} entries for {} items",
603 per_item.len(),
604 prepped.len(),
605 )));
606 }
607 let mut out = Vec::with_capacity(prepped.len());
608 for (i, opt) in per_item.into_iter().enumerate() {
609 out.push(opt.ok_or_else(|| {
610 FerrumError::model(format!(
611 "unified_forward returned None for decode item {i}"
612 ))
613 })?);
614 }
615 out
616 }
617 Err(FerrumError::Unsupported { .. }) => {
618 let tuples: Vec<(String, u32, u32)> = prepped
619 .iter()
620 .map(|p| (p.cache_id.clone(), p.token, p.seq_len))
621 .collect();
622 let force_full_logits = prepped.iter().any(|p| p.requires_full_logits);
623 model.decode_batch_with_full_logits(&tuples, force_full_logits)
624 }
625 Err(e) => return Err(e),
626 };
627 let model_call = model_t0.map(|t| t.elapsed());
628 (logits, lock_acq, model_call)
629 };
630 let t_model_done = if prof {
631 Some(std::time::Instant::now())
632 } else {
633 None
634 };
635
636 let m_count = prepped.len();
637 let mut outputs = Vec::with_capacity(m_count);
638 for (p, logits) in prepped.into_iter().zip(all_logits.into_iter()) {
639 debug!(
640 "LlmExecutor batch_decode: token={}, pos={}",
641 p.token, p.seq_len
642 );
643 let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
644 .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
645 .unsqueeze(0)
646 .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
647 let logits_ref = common::wrap_tensor(logits_tensor);
648 outputs.push(DecodeOutput::new(logits_ref, p.handle));
649 }
650 if let (Some(t0), Some(tp), Some(tm)) = (t0, t_prep, t_model_done) {
651 static EX_PROF_CALLS: std::sync::atomic::AtomicU64 =
652 std::sync::atomic::AtomicU64::new(0);
653 let n = EX_PROF_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
654 if n.is_multiple_of(8) {
655 let total = t0.elapsed().as_micros();
656 let prep = tp.duration_since(t0).as_micros();
657 let lock_acq = t_lock_acq.map(|d| d.as_micros()).unwrap_or(0);
658 let model_call = t_model_call.map(|d| d.as_micros()).unwrap_or(0);
659 let model_block = tm.duration_since(tp).as_micros();
660 let wrap = tm.elapsed().as_micros();
661 eprintln!(
662 "[exec-batch-decode-prof] call#{} m={} total={}us prep={}us model_block={}us(lock_acq={}us model_call={}us) wrap={}us",
663 n, m_count, total, prep, model_block, lock_acq, model_call, wrap,
664 );
665 }
666 }
667 Ok(outputs)
668 }
669
670 async fn unified_decode(&self, batch: &UnifiedBatch) -> Result<Vec<Option<Vec<f32>>>> {
687 let mut results: Vec<Option<Vec<f32>>> = vec![None; batch.items.len()];
688 if batch.items.is_empty() {
689 return Ok(results);
690 }
691
692 let unified_items: Vec<(String, Vec<u32>, usize, bool)> = batch
699 .items
700 .iter()
701 .map(|it| {
702 (
703 it.seq_id.clone(),
704 it.q_tokens.clone(),
705 it.pos_offset,
706 it.is_final_chunk,
707 )
708 })
709 .collect();
710 let model_result = {
711 let mut model = self.model.lock();
712 for item in &batch.items {
713 model.set_lora_adapter_for_cache(
714 &item.seq_id,
715 active_lora_from_metadata(&item.metadata)?,
716 )?;
717 }
718 model.unified_forward(&unified_items)
719 };
720 match model_result {
721 Ok(per_item) => {
722 if per_item.len() != batch.items.len() {
723 return Err(FerrumError::model(format!(
724 "unified_forward returned {} entries for {} items",
725 per_item.len(),
726 batch.items.len(),
727 )));
728 }
729 return Ok(per_item);
730 }
731 Err(FerrumError::Unsupported { .. }) => {
732 }
734 Err(e) => return Err(e),
735 }
736
737 let mut prefill_indices: Vec<usize> = Vec::new();
743 let mut decode_indices: Vec<usize> = Vec::new();
744 for (i, item) in batch.items.iter().enumerate() {
745 if item.q_tokens.len() == 1 && item.is_final_chunk {
746 decode_indices.push(i);
747 } else {
748 prefill_indices.push(i);
749 }
750 }
751
752 if !prefill_indices.is_empty() {
757 let mut model = self.model.lock();
758 for &i in &prefill_indices {
759 let item = &batch.items[i];
760 model.set_lora_adapter_for_cache(
761 &item.seq_id,
762 active_lora_from_metadata(&item.metadata)?,
763 )?;
764 let logits = model.prefill(&item.seq_id, &item.q_tokens);
765 if item.is_final_chunk {
766 results[i] = Some(logits);
767 }
768 }
769 }
770
771 if !decode_indices.is_empty() {
773 let tuples: Vec<(String, u32, u32)> = decode_indices
774 .iter()
775 .map(|&i| {
776 let it = &batch.items[i];
777 (it.seq_id.clone(), it.q_tokens[0], it.pos_offset as u32)
778 })
779 .collect();
780 let logits_vec = {
781 let mut model = self.model.lock();
782 for &i in &decode_indices {
783 let item = &batch.items[i];
784 model.set_lora_adapter_for_cache(
785 &item.seq_id,
786 active_lora_from_metadata(&item.metadata)?,
787 )?;
788 }
789 let force_full_logits = decode_indices
790 .iter()
791 .any(|&i| metadata_requires_full_logits(&batch.items[i].metadata));
792 model.decode_batch_with_full_logits(&tuples, force_full_logits)
793 };
794 for (j, &i) in decode_indices.iter().enumerate() {
795 results[i] = Some(logits_vec[j].clone());
796 }
797 }
798
799 Ok(results)
800 }
801
    /// Forwards `cache_id` to the wrapped model's `release`, holding the model
    /// lock for the duration of the call.
    fn release_cache(&self, cache_id: &str) {
        self.model.lock().release(cache_id);
    }
805
806 fn capabilities(&self) -> ExecutorCapabilities {
807 let cfg = self.model.lock().config().clone();
808 ExecutorCapabilities {
809 max_batch_size: 256,
810 max_sequence_length: cfg.max_seq_len,
811 attention_mechanisms: vec![AttentionType::GroupedQuery],
812 supports_dynamic_batching: true,
813 supports_continuous_batching: true,
814 supports_speculative_decoding: false,
815 supports_tensor_parallelism: false,
816 supports_pipeline_parallelism: false,
817 supported_dtypes: vec![DataType::FP32],
818 supported_devices: vec![self.info.device.clone()],
819 memory_requirements: MemoryRequirements {
820 parameter_memory: (self.info.num_parameters * 4) as u64,
821 activation_memory_per_token: cfg.hidden_size * 4,
822 kv_cache_memory_per_token: cfg.hidden_size * 2,
823 overhead_memory: 256 * 1024 * 1024,
824 },
825 }
826 }
827
    /// Returns the value of `common::default_executor_status()`; no state of
    /// this executor is consulted.
    fn status(&self) -> ExecutorStatus {
        common::default_executor_status()
    }
831
    /// Returns whatever the wrapped model's `cache_metrics_snapshot` reports,
    /// taking the model lock for the call.
    fn cache_metrics_snapshot(&self) -> Option<serde_json::Value> {
        self.model.lock().cache_metrics_snapshot()
    }
835
    /// Returns whatever the wrapped model's `lora_metrics_snapshot` reports,
    /// taking the model lock for the call.
    fn lora_metrics_snapshot(&self) -> Option<serde_json::Value> {
        self.model.lock().lora_metrics_snapshot()
    }
839}
840
#[cfg(test)]
mod tests {
    use super::*;

    /// A flag counts as set as soon as its variable exists, even with an empty
    /// or `"0"` value.
    #[test]
    fn llm_executor_runtime_env_parses_profile_flags_by_presence() {
        let vars = [
            ("FERRUM_BATCH_PREFILL_PROF", ""),
            ("FERRUM_BATCH_DECODE_PROF", "0"),
        ];

        let parsed = LlmExecutorRuntimeEnv::from_env_vars(vars);

        assert_eq!(
            parsed,
            LlmExecutorRuntimeEnv {
                batch_prefill_prof: true,
                batch_decode_prof: true,
            }
        );
    }

    /// Unrelated variables leave both flags off.
    #[test]
    fn llm_executor_runtime_env_defaults_profile_flags_off() {
        let vars = [("UNRELATED", "1")];

        let parsed = LlmExecutorRuntimeEnv::from_env_vars(vars);

        assert_eq!(
            parsed,
            LlmExecutorRuntimeEnv {
                batch_prefill_prof: false,
                batch_decode_prof: false,
            }
        );
    }
}