Skip to main content

ferrum_models/executor/
llm_executor.rs

1//! `LlmExecutor<M>` — adapts a `DecoderOnlyLLM` to the `ModelExecutor` trait
2//! the engine scheduler calls.
3//!
4//! This is the Model-as-Code equivalent of `GenericModelExecutor`: where
5//! `GenericModelExecutor` wraps a `Box<dyn RunnerInterface>` (legacy
6//! `ModelRunner<B>`), `LlmExecutor` wraps a `Box<dyn DecoderOnlyLLM>`
7//! (new-style per-model code such as `Qwen3Model<B>`).
8//!
9//! Tokens/logits are currently bridged through candle Tensor for
10//! `TensorRef` — Phase C will likely replace that with `SmallTensor` to
11//! drop candle from the hot path.
12
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::{Arc, OnceLock};
15
16use parking_lot::{Mutex, MutexGuard};
17use tracing::debug;
18
19use ferrum_interfaces::{
20    model_executor::{
21        AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorStatus,
22        MemoryRequirements, PrefillInput, PrefillOutput, UnifiedBatch,
23    },
24    ModelExecutor,
25};
26use ferrum_types::{DataType, FerrumError, ModelInfo, Result};
27
28use crate::common::DecoderOnlyLLM;
29use crate::lora::ActiveLoraAdapter;
30
31use super::common::{self, GenericKvCacheHandle};
32
/// Profiling switches for the executor, read from the process environment.
///
/// Both flags are presence-based: a flag is `true` as soon as the matching
/// variable exists, whatever its value (an empty string or `0` still enables
/// it).
#[derive(Debug, Clone, PartialEq, Eq)]
struct LlmExecutorRuntimeEnv {
    /// `true` when `FERRUM_BATCH_PREFILL_PROF` is present.
    batch_prefill_prof: bool,
    /// `true` when `FERRUM_BATCH_DECODE_PROF` is present.
    batch_decode_prof: bool,
}

impl LlmExecutorRuntimeEnv {
    /// Build the flags from the current process environment.
    ///
    /// Iterates `std::env::vars_os` instead of `std::env::vars`: the latter
    /// panics as soon as *any* variable in the environment is not valid
    /// Unicode, which would take the executor down because of a variable it
    /// does not even read. Keys that are not valid Unicode are skipped (they
    /// cannot match a flag name anyway) and values are never decoded.
    fn from_env() -> Self {
        Self::from_env_vars(
            std::env::vars_os()
                .filter_map(|(key, value)| key.into_string().ok().map(|key| (key, value))),
        )
    }

    /// Build the flags from an arbitrary `(key, value)` sequence.
    ///
    /// Only the keys are inspected; values are ignored. Unknown keys are
    /// skipped. Kept generic so callers can feed it literal pairs.
    fn from_env_vars<I, K, V>(vars: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
    {
        let mut batch_prefill_prof = false;
        let mut batch_decode_prof = false;

        for (key, _) in vars {
            match key.as_ref() {
                "FERRUM_BATCH_PREFILL_PROF" => batch_prefill_prof = true,
                "FERRUM_BATCH_DECODE_PROF" => batch_decode_prof = true,
                _ => {}
            }
        }

        Self {
            batch_prefill_prof,
            batch_decode_prof,
        }
    }
}
66
67fn llm_executor_runtime_env() -> &'static LlmExecutorRuntimeEnv {
68    static CONFIG: OnceLock<LlmExecutorRuntimeEnv> = OnceLock::new();
69    CONFIG.get_or_init(LlmExecutorRuntimeEnv::from_env)
70}
71
72fn active_lora_from_metadata(
73    metadata: &std::collections::HashMap<String, serde_json::Value>,
74) -> Result<Option<ActiveLoraAdapter>> {
75    let name = metadata
76        .get("ferrum_lora_adapter")
77        .and_then(|value| value.as_str());
78    let path = metadata
79        .get("ferrum_lora_path")
80        .and_then(|value| value.as_str());
81    match (name, path) {
82        (Some(name), Some(path)) => Ok(Some(ActiveLoraAdapter {
83            name: name.to_string(),
84            path: std::path::PathBuf::from(path),
85        })),
86        (None, None) => Ok(None),
87        _ => Err(FerrumError::model(
88            "incomplete LoRA metadata: expected ferrum_lora_adapter and ferrum_lora_path",
89        )),
90    }
91}
92
93fn metadata_requires_full_logits(
94    metadata: &std::collections::HashMap<String, serde_json::Value>,
95) -> bool {
96    metadata
97        .get("ferrum_require_full_logits")
98        .and_then(|value| value.as_bool())
99        .unwrap_or(false)
100}
101
102/// Map a `ferrum_types::Device` to the matching `candle_core::Device`.
103/// Used when materialising KV cache handles so downstream readers see
104/// the real backend the model runs on (Metal / CUDA / CPU) rather than
105/// a hard-coded CPU placeholder.
106fn ferrum_device_to_candle(d: &ferrum_types::Device) -> candle_core::Device {
107    match d {
108        ferrum_types::Device::CPU => candle_core::Device::Cpu,
109        #[cfg(feature = "cuda")]
110        ferrum_types::Device::CUDA(i) => {
111            candle_core::Device::new_cuda(*i as usize).unwrap_or(candle_core::Device::Cpu)
112        }
113        #[cfg(not(feature = "cuda"))]
114        ferrum_types::Device::CUDA(_) => candle_core::Device::Cpu,
115        #[cfg(all(any(target_os = "macos", target_os = "ios"), feature = "metal"))]
116        ferrum_types::Device::Metal => {
117            candle_core::Device::new_metal(0).unwrap_or(candle_core::Device::Cpu)
118        }
119        _ => candle_core::Device::Cpu,
120    }
121}
122
123pub struct LlmExecutor {
124    model: Mutex<Box<dyn DecoderOnlyLLM>>,
125    info: ModelInfo,
126    next_cache_id: AtomicU64,
127    total_model_lock_wait_us: AtomicU64,
128    model_lock_wait_samples: AtomicU64,
129}
130
131impl LlmExecutor {
132    pub fn new(model: Box<dyn DecoderOnlyLLM>, info: ModelInfo) -> Self {
133        Self {
134            model: Mutex::new(model),
135            info,
136            next_cache_id: AtomicU64::new(0),
137            total_model_lock_wait_us: AtomicU64::new(0),
138            model_lock_wait_samples: AtomicU64::new(0),
139        }
140    }
141
142    fn lock_model(&self) -> MutexGuard<'_, Box<dyn DecoderOnlyLLM>> {
143        let start = std::time::Instant::now();
144        let guard = self.model.lock();
145        self.record_model_lock_wait(start.elapsed());
146        guard
147    }
148
149    fn record_model_lock_wait(&self, duration: std::time::Duration) {
150        self.total_model_lock_wait_us.fetch_add(
151            duration.as_micros().min(u64::MAX as u128) as u64,
152            Ordering::Relaxed,
153        );
154        self.model_lock_wait_samples.fetch_add(1, Ordering::Relaxed);
155    }
156
157    fn model_lock_metrics_json(&self) -> serde_json::Value {
158        let samples = self.model_lock_wait_samples.load(Ordering::Relaxed);
159        let total_us = self.total_model_lock_wait_us.load(Ordering::Relaxed);
160        serde_json::json!({
161            "schema_version": 1,
162            "samples": samples,
163            "total_wait_time_us": total_us,
164            "avg_wait_time_ms": if samples == 0 {
165                0.0
166            } else {
167                total_us as f64 / samples as f64 / 1000.0
168            },
169        })
170    }
171
172    fn attach_model_lock_metrics(&self, mut snapshot: serde_json::Value) -> serde_json::Value {
173        let lock_metrics = self.model_lock_metrics_json();
174        if let Some(obj) = snapshot.as_object_mut() {
175            obj.insert("executor_model_lock".to_string(), lock_metrics);
176            snapshot
177        } else {
178            serde_json::json!({
179                "cache_metrics": snapshot,
180                "executor_model_lock": lock_metrics,
181            })
182        }
183    }
184
185    fn gen_cache_id(&self) -> String {
186        format!(
187            "llm-cache-{}",
188            self.next_cache_id.fetch_add(1, Ordering::Relaxed)
189        )
190    }
191
192    /// Roll the KV cache for `cache_id` back to `new_len` positions.
193    /// Used by speculative decoding on partial rejection. The caller must
194    /// supply a `GenericKvCacheHandle` whose seq_len is also updated.
195    pub fn truncate_kv_for_cache_id(&self, cache_id: &str, new_len: usize) {
196        let mut model = self.lock_model();
197        model.truncate_kv(cache_id, new_len);
198    }
199}
200
201#[async_trait::async_trait]
202impl ModelExecutor for LlmExecutor {
203    fn info(&self) -> &ModelInfo {
204        &self.info
205    }
206
207    fn supports_native_unified_decode(&self) -> bool {
208        // CUDA has a native unified mixed prefill+decode forward; CPU and Metal
209        // use the legacy split path. The device→capability mapping lives here
210        // (the executor is backend-aware) so the engine needs no platform cfg.
211        matches!(self.info.device, ferrum_types::Device::CUDA(_))
212    }
213
214    fn kv_capacity(&self) -> Option<usize> {
215        Some(self.lock_model().kv_capacity())
216    }
217
218    async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
219        let tokens = common::tensor_to_tokens(&input.input_ids)?;
220
221        // Reuse an existing cache_id when the caller supplies a KV handle
222        // (chunked prefill) — fresh id only on the very first call for a
223        // request. Without this, every chunk would create a new KV cache
224        // at position 0 and subsequent chunks wouldn't see prior tokens.
225        let supplied_handle_id = input.kv_cache.as_ref().and_then(|h| {
226            h.as_any()
227                .downcast_ref::<GenericKvCacheHandle>()
228                .map(|g| g.request_cache_id().to_string())
229        });
230        let cache_id = supplied_handle_id
231            .clone()
232            .unwrap_or_else(|| self.gen_cache_id());
233
234        // For chunked-prefill continuation, the prior KV length is the seq
235        // length already in the supplied handle; for fresh prefill it's 0.
236        let prior_seq_len = input
237            .kv_cache
238            .as_ref()
239            .and_then(|h| h.as_any().downcast_ref::<GenericKvCacheHandle>())
240            .map(|g| {
241                use ferrum_interfaces::KvCacheHandle;
242                g.block_table().sequence_length
243            })
244            .unwrap_or(0);
245
246        // Try the unified_forward path first when the caller can accept the
247        // model's fast readback path. Requests that need logits processors or
248        // token masks require full logits so the engine sampler can enforce
249        // them; unified_forward currently has no per-item metadata channel.
250        let force_full_logits = metadata_requires_full_logits(&input.metadata);
251        let logits = {
252            let mut model = self.lock_model();
253            model.set_lora_adapter_for_cache(
254                &cache_id,
255                active_lora_from_metadata(&input.metadata)?,
256            )?;
257            if force_full_logits {
258                model.prefill(&cache_id, &tokens)
259            } else {
260                let unified_item = vec![(cache_id.clone(), tokens.clone(), prior_seq_len, true)];
261                match model.unified_forward(&unified_item) {
262                    Ok(mut per_item) => per_item
263                        .pop()
264                        .flatten()
265                        .ok_or_else(|| FerrumError::model("unified_forward returned no logits"))?,
266                    Err(FerrumError::Unsupported { .. }) => model.prefill(&cache_id, &tokens),
267                    Err(e) => return Err(e),
268                }
269            }
270        };
271
272        // Wrap logits as TensorRef: [1, 1, vocab_size]
273        let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
274            .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
275            .unsqueeze(0)
276            .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?
277            .unsqueeze(0)
278            .map_err(|e| FerrumError::model(format!("unsqueeze2: {e}")))?;
279        let logits_ref = common::wrap_tensor(logits_tensor);
280
281        let cfg = self.lock_model().config().clone();
282        // Sequence-length tracking across chunks: if the caller supplied a
283        // GenericKvCacheHandle (chunked prefill continuation), add this
284        // chunk's tokens to the prior length. Otherwise this is a fresh
285        // prefill so seq_len == this call's token count. Without this the
286        // handle would claim only the last chunk's length, misleading
287        // decode() into rewriting the KV at an earlier position.
288        let seq_len = input
289            .kv_cache
290            .as_ref()
291            .and_then(|h| h.as_any().downcast_ref::<GenericKvCacheHandle>())
292            .map(|g| {
293                use ferrum_interfaces::KvCacheHandle;
294                g.block_table().sequence_length + tokens.len()
295            })
296            .unwrap_or(tokens.len());
297
298        let kv_handle = Arc::new(GenericKvCacheHandle::new(
299            cfg.num_layers,
300            cfg.num_kv_heads,
301            cfg.head_dim,
302            candle_core::Device::Cpu,
303            seq_len,
304            cache_id,
305        ));
306
307        Ok(PrefillOutput::new(logits_ref, kv_handle))
308    }
309
310    /// Batched prefill: combine all prompts into ONE `model.unified_forward`
311    /// call so launch / kernel-overhead is amortized across the cohort.
312    ///
313    /// Falls back to the trait default (serial per-item) when the model
314    /// returns `Err(unsupported)` from `unified_forward` — e.g. Qwen3MoeModel
315    /// today, until Phase 2 adds its native unified path.
316    async fn batch_prefill(&self, inputs: &[PrefillInput]) -> Result<Vec<PrefillOutput>> {
317        if inputs.is_empty() {
318            return Ok(Vec::new());
319        }
320        let force_full_logits = inputs
321            .iter()
322            .any(|input| metadata_requires_full_logits(&input.metadata));
323
324        // Per-input: derive cache_id (reuse supplied handle's id or generate
325        // fresh) + prior_seq_len. Mirrors the single-prefill path so chunked
326        // prefill continuations route correctly when batched.
327        let mut cache_ids = Vec::with_capacity(inputs.len());
328        let mut prior_seq_lens = Vec::with_capacity(inputs.len());
329        let mut tokens_per_input = Vec::with_capacity(inputs.len());
330        let mut lora_per_input = Vec::with_capacity(inputs.len());
331        for input in inputs {
332            let tokens = common::tensor_to_tokens(&input.input_ids)?;
333            let supplied_handle_id = input.kv_cache.as_ref().and_then(|h| {
334                h.as_any()
335                    .downcast_ref::<GenericKvCacheHandle>()
336                    .map(|g| g.request_cache_id().to_string())
337            });
338            let cache_id = supplied_handle_id
339                .clone()
340                .unwrap_or_else(|| self.gen_cache_id());
341            let prior_seq_len = input
342                .kv_cache
343                .as_ref()
344                .and_then(|h| h.as_any().downcast_ref::<GenericKvCacheHandle>())
345                .map(|g| {
346                    use ferrum_interfaces::KvCacheHandle;
347                    g.block_table().sequence_length
348                })
349                .unwrap_or(0);
350            cache_ids.push(cache_id);
351            prior_seq_lens.push(prior_seq_len);
352            tokens_per_input.push(tokens);
353            lora_per_input.push(active_lora_from_metadata(&input.metadata)?);
354        }
355
356        // Build unified items and ONE `unified_forward` call. If the model
357        // doesn't support it, fall back to the trait-default serial path.
358        let unified_items: Vec<(String, Vec<u32>, usize, bool)> = cache_ids
359            .iter()
360            .zip(tokens_per_input.iter())
361            .zip(prior_seq_lens.iter())
362            .map(|((cid, toks), &prior)| (cid.clone(), toks.clone(), prior, true))
363            .collect();
364
365        let nb_prof = llm_executor_runtime_env().batch_prefill_prof;
366        let bp_t0 = if nb_prof {
367            Some(std::time::Instant::now())
368        } else {
369            None
370        };
371        let mut took_fallback = false;
372        let per_item_logits: Vec<Vec<f32>> = {
373            let mut model = self.lock_model();
374            for (cache_id, adapter) in cache_ids.iter().zip(lora_per_input.iter()) {
375                model.set_lora_adapter_for_cache(cache_id, adapter.clone())?;
376            }
377            if force_full_logits {
378                took_fallback = true;
379                let mut out = Vec::with_capacity(inputs.len());
380                for (cid, toks) in cache_ids.iter().zip(tokens_per_input.iter()) {
381                    out.push(model.prefill(cid, toks));
382                }
383                out
384            } else {
385                match model.unified_forward(&unified_items) {
386                    Ok(per_item) => per_item
387                        .into_iter()
388                        .map(|opt| opt.expect("is_final_chunk=true must yield logits"))
389                        .collect(),
390                    Err(FerrumError::Unsupported { .. }) => {
391                        took_fallback = true;
392                        let mut out = Vec::with_capacity(inputs.len());
393                        for (cid, toks) in cache_ids.iter().zip(tokens_per_input.iter()) {
394                            out.push(model.prefill(cid, toks));
395                        }
396                        out
397                    }
398                    Err(e) => return Err(e),
399                }
400            }
401        };
402        if let Some(t0) = bp_t0 {
403            let total_q: usize = unified_items.iter().map(|it| it.1.len()).sum();
404            eprintln!(
405                "[batch-prefill] n_items={} total_q={} fallback={} elapsed={}us",
406                inputs.len(),
407                total_q,
408                took_fallback,
409                t0.elapsed().as_micros()
410            );
411        }
412
413        let cfg = self.lock_model().config().clone();
414        let mut outputs = Vec::with_capacity(inputs.len());
415        for (i, logits) in per_item_logits.into_iter().enumerate() {
416            let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
417                .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
418                .unsqueeze(0)
419                .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?
420                .unsqueeze(0)
421                .map_err(|e| FerrumError::model(format!("unsqueeze2: {e}")))?;
422            let logits_ref = common::wrap_tensor(logits_tensor);
423            let seq_len = inputs[i]
424                .kv_cache
425                .as_ref()
426                .and_then(|h| h.as_any().downcast_ref::<GenericKvCacheHandle>())
427                .map(|g| {
428                    use ferrum_interfaces::KvCacheHandle;
429                    g.block_table().sequence_length + tokens_per_input[i].len()
430                })
431                .unwrap_or(tokens_per_input[i].len());
432            let kv_handle = Arc::new(GenericKvCacheHandle::new(
433                cfg.num_layers,
434                cfg.num_kv_heads,
435                cfg.head_dim,
436                candle_core::Device::Cpu,
437                seq_len,
438                cache_ids[i].clone(),
439            ));
440            outputs.push(PrefillOutput::new(logits_ref, kv_handle));
441        }
442        Ok(outputs)
443    }
444
445    async fn truncate_kv(
446        &self,
447        kv_cache: &Arc<dyn ferrum_interfaces::KvCacheHandle>,
448        new_len: usize,
449    ) -> Result<()> {
450        if let Some(g) = kv_cache.as_any().downcast_ref::<GenericKvCacheHandle>() {
451            let cache_id = g.request_cache_id();
452            self.lock_model().truncate_kv(cache_id, new_len);
453        }
454        Ok(())
455    }
456
457    async fn forward_verify(
458        &self,
459        inputs: &[ferrum_interfaces::model_executor::DecodeInput],
460    ) -> Result<Vec<ferrum_interfaces::model_executor::DecodeOutput>> {
461        if inputs.is_empty() {
462            return Ok(Vec::new());
463        }
464
465        // All inputs must share the same KV handle (speculative decoding
466        // contract). Extract cache_id + starting seq_len once.
467        let first_handle = inputs[0].kv_cache.clone();
468        let cache_id = first_handle
469            .as_any()
470            .downcast_ref::<GenericKvCacheHandle>()
471            .ok_or_else(|| {
472                FerrumError::model("forward_verify requires GenericKvCacheHandle input")
473            })?
474            .request_cache_id()
475            .to_string();
476        let start_seq = {
477            use ferrum_interfaces::KvCacheHandle;
478            first_handle.block_table().sequence_length
479        };
480
481        // Collect the N+1 token ids.
482        let mut token_ids: Vec<u32> = Vec::with_capacity(inputs.len());
483        for input in inputs {
484            let toks = common::tensor_to_tokens(&input.input_ids)?;
485            if toks.is_empty() {
486                return Err(FerrumError::model("forward_verify input token empty"));
487            }
488            token_ids.push(toks[0]);
489        }
490
491        // One model forward for all N+1 positions → flat seq_len*vocab.
492        let flat = {
493            let mut model = self.lock_model();
494            model.set_lora_adapter_for_cache(
495                &cache_id,
496                active_lora_from_metadata(&inputs[0].metadata)?,
497            )?;
498            model.forward_verify(&cache_id, &token_ids)
499        };
500
501        let cfg = self.lock_model().config().clone();
502        let vocab = cfg.vocab_size;
503
504        // Record the actual backend device so downstream code that reads
505        // `KvCacheHandle::device()` sees Metal/CUDA/CPU matching the
506        // model's real location. The logits `Tensor` still wraps CPU data
507        // because `B::to_vec` already moved it off-device.
508        let candle_device = ferrum_device_to_candle(&self.info.device);
509
510        // Split the flat logits into per-position tensors, each wrapped
511        // with a handle whose seq_len reflects the positions written so
512        // far. Matches what the spec runner expects from sequential
513        // decode() calls.
514        let mut outputs = Vec::with_capacity(inputs.len());
515        for (i, _) in inputs.iter().enumerate() {
516            let row = &flat[i * vocab..(i + 1) * vocab];
517            let logits_tensor = candle_core::Tensor::new(row, &candle_core::Device::Cpu)
518                .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
519                .unsqueeze(0)
520                .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
521            let logits_ref = common::wrap_tensor(logits_tensor);
522            let handle = Arc::new(GenericKvCacheHandle::new(
523                cfg.num_layers,
524                cfg.num_kv_heads,
525                cfg.head_dim,
526                candle_device.clone(),
527                start_seq + i + 1,
528                cache_id.clone(),
529            ));
530            outputs.push(ferrum_interfaces::model_executor::DecodeOutput::new(
531                logits_ref, handle,
532            ));
533        }
534        Ok(outputs)
535    }
536
537    async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput> {
538        let input_handle = input
539            .kv_cache
540            .as_any()
541            .downcast_ref::<GenericKvCacheHandle>()
542            .ok_or_else(|| FerrumError::model("Invalid KV cache handle type"))?;
543
544        let cache_id = input_handle.request_cache_id().to_string();
545        let seq_len = {
546            use ferrum_interfaces::KvCacheHandle;
547            input_handle.block_table().sequence_length
548        };
549
550        let tokens = common::tensor_to_tokens(&input.input_ids)?;
551        if tokens.is_empty() {
552            return Err(FerrumError::model("Decode input is empty"));
553        }
554        let token = tokens[0];
555
556        debug!("LlmExecutor decode: token={token}, pos={seq_len}");
557
558        // Try unified_forward first unless the engine needs full logits for
559        // masks/processors. The direct decode path returns vocabulary logits.
560        let force_full_logits = metadata_requires_full_logits(&input.metadata);
561        let logits = {
562            let mut model = self.lock_model();
563            model.set_lora_adapter_for_cache(
564                &cache_id,
565                active_lora_from_metadata(&input.metadata)?,
566            )?;
567            if force_full_logits {
568                model.decode(&cache_id, token, seq_len as u32)
569            } else {
570                let unified_item = vec![(cache_id.clone(), vec![token], seq_len, true)];
571                match model.unified_forward(&unified_item) {
572                    Ok(mut per_item) => per_item
573                        .pop()
574                        .flatten()
575                        .ok_or_else(|| FerrumError::model("unified_forward returned no logits"))?,
576                    Err(FerrumError::Unsupported { .. }) => {
577                        model.decode(&cache_id, token, seq_len as u32)
578                    }
579                    Err(e) => return Err(e),
580                }
581            }
582        };
583
584        let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
585            .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
586            .unsqueeze(0)
587            .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
588        let logits_ref = common::wrap_tensor(logits_tensor);
589
590        let kv_handle = Arc::new(input_handle.with_sequence_length(seq_len + 1));
591        Ok(DecodeOutput::new(logits_ref, kv_handle))
592    }
593
594    /// Override default fallback to acquire the model lock ONCE for the whole
595    /// batch, avoiding N round-trips through parking_lot. Does not yet do
596    /// true attention batching (each cache has its own kv_len), but removes
597    /// mutex churn that was serialising concurrent requests at async level.
598    async fn batch_decode(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>> {
599        if inputs.is_empty() {
600            return Ok(Vec::new());
601        }
602        let prof = llm_executor_runtime_env().batch_decode_prof;
603        let t0 = if prof {
604            Some(std::time::Instant::now())
605        } else {
606            None
607        };
608        // Pre-extract all per-input metadata OUTSIDE the lock — this is pure
609        // borrow/downcast work that doesn't touch the model.
610        struct Prep {
611            cache_id: String,
612            token: u32,
613            seq_len: u32,
614            lora: Option<ActiveLoraAdapter>,
615            requires_full_logits: bool,
616            handle: Arc<GenericKvCacheHandle>,
617        }
618        let mut prepped: Vec<Prep> = Vec::with_capacity(inputs.len());
619        for input in inputs {
620            let input_handle = input
621                .kv_cache
622                .as_any()
623                .downcast_ref::<GenericKvCacheHandle>()
624                .ok_or_else(|| FerrumError::model("Invalid KV cache handle type"))?;
625            use ferrum_interfaces::KvCacheHandle;
626            let seq_len = input_handle.block_table().sequence_length as u32;
627            let tokens = common::tensor_to_tokens(&input.input_ids)?;
628            if tokens.is_empty() {
629                return Err(FerrumError::model("Decode input is empty"));
630            }
631            prepped.push(Prep {
632                cache_id: input_handle.request_cache_id().to_string(),
633                token: tokens[0],
634                seq_len,
635                lora: active_lora_from_metadata(&input.metadata)?,
636                requires_full_logits: metadata_requires_full_logits(&input.metadata),
637                handle: Arc::new(input_handle.with_sequence_length((seq_len + 1) as usize)),
638            });
639        }
640        let t_prep = if prof {
641            Some(std::time::Instant::now())
642        } else {
643            None
644        };
645
646        // One lock for the whole batch. Try unified_forward first: paged
647        // configs route through the varlen kernel (single mixed dispatch
648        // for the whole batch); contig configs fall back to model's
649        // legacy decode_batch (separate paged_decode_attention call per
650        // item, batched matmul for QKV/MLP).
651        let (all_logits, t_lock_acq, t_model_call): (Vec<Vec<f32>>, _, _) = {
652            let lock_t0 = if prof {
653                Some(std::time::Instant::now())
654            } else {
655                None
656            };
657            let mut model = self.lock_model();
658            let lock_acq = lock_t0.map(|t| t.elapsed());
659            let model_t0 = if prof {
660                Some(std::time::Instant::now())
661            } else {
662                None
663            };
664            for p in &prepped {
665                model.set_lora_adapter_for_cache(&p.cache_id, p.lora.clone())?;
666            }
667            let unified_items: Vec<(String, Vec<u32>, usize, bool)> = prepped
668                .iter()
669                .map(|p| (p.cache_id.clone(), vec![p.token], p.seq_len as usize, true))
670                .collect();
671            let tuples: Vec<(String, u32, u32)> = prepped
672                .iter()
673                .map(|p| (p.cache_id.clone(), p.token, p.seq_len))
674                .collect();
675            let force_full_logits = prepped.iter().any(|p| p.requires_full_logits);
676            let logits = if force_full_logits {
677                model.decode_batch_with_full_logits(&tuples, true)
678            } else {
679                match model.unified_forward(&unified_items) {
680                    Ok(per_item) => {
681                        if per_item.len() != prepped.len() {
682                            return Err(FerrumError::model(format!(
683                                "unified_forward returned {} entries for {} items",
684                                per_item.len(),
685                                prepped.len(),
686                            )));
687                        }
688                        let mut out = Vec::with_capacity(prepped.len());
689                        for (i, opt) in per_item.into_iter().enumerate() {
690                            out.push(opt.ok_or_else(|| {
691                                FerrumError::model(format!(
692                                    "unified_forward returned None for decode item {i}"
693                                ))
694                            })?);
695                        }
696                        out
697                    }
698                    Err(FerrumError::Unsupported { .. }) => {
699                        model.decode_batch_with_full_logits(&tuples, false)
700                    }
701                    Err(e) => return Err(e),
702                }
703            };
704            let model_call = model_t0.map(|t| t.elapsed());
705            (logits, lock_acq, model_call)
706        };
707        let t_model_done = if prof {
708            Some(std::time::Instant::now())
709        } else {
710            None
711        };
712
713        let m_count = prepped.len();
714        let mut outputs = Vec::with_capacity(m_count);
715        for (p, logits) in prepped.into_iter().zip(all_logits.into_iter()) {
716            debug!(
717                "LlmExecutor batch_decode: token={}, pos={}",
718                p.token, p.seq_len
719            );
720            let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
721                .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
722                .unsqueeze(0)
723                .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
724            let logits_ref = common::wrap_tensor(logits_tensor);
725            outputs.push(DecodeOutput::new(logits_ref, p.handle));
726        }
727        if let (Some(t0), Some(tp), Some(tm)) = (t0, t_prep, t_model_done) {
728            static EX_PROF_CALLS: std::sync::atomic::AtomicU64 =
729                std::sync::atomic::AtomicU64::new(0);
730            let n = EX_PROF_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
731            if n.is_multiple_of(8) {
732                let total = t0.elapsed().as_micros();
733                let prep = tp.duration_since(t0).as_micros();
734                let lock_acq = t_lock_acq.map(|d| d.as_micros()).unwrap_or(0);
735                let model_call = t_model_call.map(|d| d.as_micros()).unwrap_or(0);
736                let model_block = tm.duration_since(tp).as_micros();
737                let wrap = tm.elapsed().as_micros();
738                eprintln!(
739                    "[exec-batch-decode-prof] call#{} m={} total={}us prep={}us model_block={}us(lock_acq={}us model_call={}us) wrap={}us",
740                    n, m_count, total, prep, model_block, lock_acq, model_call, wrap,
741                );
742            }
743        }
744        Ok(outputs)
745    }
746
747    /// Unified mixed-batch dispatch (chunked-prefill API).
748    ///
749    /// This impl is a behavior-preserving FALLBACK over the existing
750    /// trait methods on `DecoderOnlyLLM`: prefill items go through
751    /// `model.prefill(seq_id, &q_tokens)` (one at a time, mirroring the
752    /// engine's current sequential prefill loop), decode items
753    /// (`q_len == 1 && is_final_chunk`) are grouped into a single
754    /// `model.decode_batch(...)` call. Net behavior is identical to the
755    /// engine's pre-Phase-13 path; this just changes WHO orchestrates
756    /// the prefill/decode split (caller → unified_decode) so the engine
757    /// can converge on a single call.
758    ///
759    /// The real performance unlock comes in Step 5 when models override
760    /// this with a true unified-forward (one [M_total, hidden] forward
761    /// + varlen attention) — at that point the kernel-level mix replaces
762    /// the host-side serial dispatch here.
763    async fn unified_decode(&self, batch: &UnifiedBatch) -> Result<Vec<Option<Vec<f32>>>> {
764        let mut results: Vec<Option<Vec<f32>>> = vec![None; batch.items.len()];
765        if batch.items.is_empty() {
766            return Ok(results);
767        }
768
769        // ── Real unified path (Step 5b+): if the model implements
770        // `DecoderOnlyLLM::unified_forward`, route the entire batch
771        // through one model forward (mixed prefill chunks + decode
772        // tokens in a single [M_total, hidden] pass). The model returns
773        // `Err(unsupported)` if it hasn't been wired yet — fall through
774        // to the behaviour-preserving fallback below.
775        let unified_items: Vec<(String, Vec<u32>, usize, bool)> = batch
776            .items
777            .iter()
778            .map(|it| {
779                (
780                    it.seq_id.clone(),
781                    it.q_tokens.clone(),
782                    it.pos_offset,
783                    it.is_final_chunk,
784                )
785            })
786            .collect();
787        let force_full_logits = batch
788            .items
789            .iter()
790            .any(|item| metadata_requires_full_logits(&item.metadata));
791        if !force_full_logits {
792            let model_result = {
793                let mut model = self.lock_model();
794                for item in &batch.items {
795                    model.set_lora_adapter_for_cache(
796                        &item.seq_id,
797                        active_lora_from_metadata(&item.metadata)?,
798                    )?;
799                }
800                model.unified_forward(&unified_items)
801            };
802            match model_result {
803                Ok(per_item) => {
804                    if per_item.len() != batch.items.len() {
805                        return Err(FerrumError::model(format!(
806                            "unified_forward returned {} entries for {} items",
807                            per_item.len(),
808                            batch.items.len(),
809                        )));
810                    }
811                    return Ok(per_item);
812                }
813                Err(FerrumError::Unsupported { .. }) => {
814                    // Fall through to the dispatch fallback below.
815                }
816                Err(e) => return Err(e),
817            }
818        }
819
820        // Partition: pure decode items vs prefill chunks.
821        // A "decode" item has q_len == 1 AND is_final_chunk == true.
822        // Anything else (chunked prefill mid-stream OR a single-token
823        // prefill that returns logits) goes through the per-item prefill
824        // path so the model receives the right pos_offset behaviour.
825        let mut prefill_indices: Vec<usize> = Vec::new();
826        let mut decode_indices: Vec<usize> = Vec::new();
827        for (i, item) in batch.items.iter().enumerate() {
828            if item.q_tokens.len() == 1 && item.is_final_chunk {
829                decode_indices.push(i);
830            } else {
831                prefill_indices.push(i);
832            }
833        }
834
835        // Prefill items — sequential, mirrors current engine behaviour.
836        // Held under a single model lock to amortise lock acquire across
837        // all prefills in this batch (we may revisit per-call locking
838        // when chunked-prefill becomes the perf-critical path).
839        if !prefill_indices.is_empty() {
840            let mut model = self.lock_model();
841            for &i in &prefill_indices {
842                let item = &batch.items[i];
843                model.set_lora_adapter_for_cache(
844                    &item.seq_id,
845                    active_lora_from_metadata(&item.metadata)?,
846                )?;
847                let logits = model.prefill(&item.seq_id, &item.q_tokens);
848                if item.is_final_chunk {
849                    results[i] = Some(logits);
850                }
851            }
852        }
853
854        // Decode items — single batched dispatch.
855        if !decode_indices.is_empty() {
856            let tuples: Vec<(String, u32, u32)> = decode_indices
857                .iter()
858                .map(|&i| {
859                    let it = &batch.items[i];
860                    (it.seq_id.clone(), it.q_tokens[0], it.pos_offset as u32)
861                })
862                .collect();
863            let logits_vec = {
864                let mut model = self.lock_model();
865                for &i in &decode_indices {
866                    let item = &batch.items[i];
867                    model.set_lora_adapter_for_cache(
868                        &item.seq_id,
869                        active_lora_from_metadata(&item.metadata)?,
870                    )?;
871                }
872                let force_full_logits = decode_indices
873                    .iter()
874                    .any(|&i| metadata_requires_full_logits(&batch.items[i].metadata));
875                model.decode_batch_with_full_logits(&tuples, force_full_logits)
876            };
877            for (j, &i) in decode_indices.iter().enumerate() {
878                results[i] = Some(logits_vec[j].clone());
879            }
880        }
881
882        Ok(results)
883    }
884
885    fn release_cache(&self, cache_id: &str) {
886        self.lock_model().release(cache_id);
887    }
888
889    fn capabilities(&self) -> ExecutorCapabilities {
890        let cfg = self.lock_model().config().clone();
891        ExecutorCapabilities {
892            max_batch_size: 256,
893            max_sequence_length: cfg.max_seq_len,
894            attention_mechanisms: vec![AttentionType::GroupedQuery],
895            supports_dynamic_batching: true,
896            supports_continuous_batching: true,
897            supports_speculative_decoding: false,
898            supports_tensor_parallelism: false,
899            supports_pipeline_parallelism: false,
900            supported_dtypes: vec![DataType::FP32],
901            supported_devices: vec![self.info.device.clone()],
902            memory_requirements: MemoryRequirements {
903                parameter_memory: (self.info.num_parameters * 4) as u64,
904                activation_memory_per_token: cfg.hidden_size * 4,
905                kv_cache_memory_per_token: cfg.hidden_size * 2,
906                overhead_memory: 256 * 1024 * 1024,
907            },
908        }
909    }
910
911    fn status(&self) -> ExecutorStatus {
912        common::default_executor_status()
913    }
914
915    fn cache_metrics_snapshot(&self) -> Option<serde_json::Value> {
916        let snapshot = self.lock_model().cache_metrics_snapshot()?;
917        Some(self.attach_model_lock_metrics(snapshot))
918    }
919
920    fn lora_metrics_snapshot(&self) -> Option<serde_json::Value> {
921        self.lock_model().lora_metrics_snapshot()
922    }
923}
924
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    use ferrum_interfaces::model_executor::{DecodeInput, PrefillInput, UnifiedBatchItem};
    use ferrum_interfaces::KvCacheHandle;
    use ferrum_testkit::MockTensor;
    use ferrum_types::{Device, ModelId, ModelType};

    /// Call counters shared between a test and its `RecordingLlm`.
    type SharedCalls = Arc<Mutex<RecordingCalls>>;

    /// Records which `DecoderOnlyLLM` entry points the executor invoked.
    #[derive(Default)]
    struct RecordingCalls {
        unified_forward: usize,
        prefill: usize,
        decode: usize,
        /// One entry per batched decode call: the `force_full_logits` flag it received.
        decode_batch_force_full_logits: Vec<bool>,
    }

    /// `DecoderOnlyLLM` test double that counts calls and returns canned logits.
    struct RecordingLlm {
        calls: SharedCalls,
        config: crate::common::LlmRuntimeConfig,
    }

    impl RecordingLlm {
        /// Build a double with a tiny fixed config that reports into `calls`.
        fn new(calls: SharedCalls) -> Self {
            let config = crate::common::LlmRuntimeConfig {
                hidden_size: 4,
                num_layers: 1,
                num_kv_heads: 1,
                head_dim: 4,
                vocab_size: 4,
                max_seq_len: 16,
            };
            Self { calls, config }
        }
    }

    impl DecoderOnlyLLM for RecordingLlm {
        fn config(&self) -> &crate::common::LlmRuntimeConfig {
            &self.config
        }

        fn prefill(&mut self, _cache_id: &str, _tokens: &[u32]) -> Vec<f32> {
            let mut recorded = self.calls.lock();
            recorded.prefill += 1;
            vec![0.0, 1.0, 2.0, 3.0]
        }

        fn decode(&mut self, _cache_id: &str, _token: u32, _pos: u32) -> Vec<f32> {
            let mut recorded = self.calls.lock();
            recorded.decode += 1;
            vec![3.0, 2.0, 1.0, 0.0]
        }

        fn decode_batch_with_full_logits(
            &mut self,
            batch: &[(String, u32, u32)],
            force_full_logits: bool,
        ) -> Vec<Vec<f32>> {
            let mut recorded = self.calls.lock();
            recorded
                .decode_batch_force_full_logits
                .push(force_full_logits);
            // One canned logits row per requested sequence.
            let mut rows = Vec::with_capacity(batch.len());
            for _ in batch {
                rows.push(vec![3.0, 2.0, 1.0, 0.0]);
            }
            rows
        }

        fn unified_forward(
            &mut self,
            items: &[(String, Vec<u32>, usize, bool)],
        ) -> std::result::Result<Vec<Option<Vec<f32>>>, FerrumError> {
            self.calls.lock().unified_forward += 1;
            // Only final chunks produce logits; the sentinel value makes this
            // path distinguishable from the prefill/decode doubles.
            let per_item = items
                .iter()
                .map(|item| if item.3 { Some(vec![99.0]) } else { None })
                .collect();
            Ok(per_item)
        }

        fn release(&mut self, _cache_id: &str) {}

        fn cache_metrics_snapshot(&self) -> Option<serde_json::Value> {
            let snapshot = serde_json::json!({
                "position": "recording-test-cache",
            });
            Some(snapshot)
        }
    }

    /// Fresh, zeroed call counters.
    fn new_calls() -> SharedCalls {
        Arc::new(Mutex::new(RecordingCalls::default()))
    }

    /// Minimal `ModelInfo` matching the `RecordingLlm` config.
    fn test_model_info() -> ModelInfo {
        ModelInfo {
            model_id: ModelId("recording".to_string()),
            model_type: ModelType::Custom("recording".to_string()),
            num_parameters: 0,
            hidden_size: 4,
            num_layers: 1,
            num_heads: 1,
            num_kv_heads: 1,
            vocab_size: 4,
            max_sequence_length: 16,
            dtype: DataType::FP32,
            device: Device::CPU,
            version: None,
            license: None,
            metadata: HashMap::new(),
        }
    }

    /// Executor wrapping a `RecordingLlm` that reports into `calls`.
    fn recording_executor(calls: SharedCalls) -> LlmExecutor {
        let model = RecordingLlm::new(calls);
        LlmExecutor::new(Box::new(model), test_model_info())
    }

    /// Request metadata carrying the flag that demands full logits.
    fn full_logits_metadata() -> HashMap<String, serde_json::Value> {
        let mut metadata = HashMap::new();
        metadata.insert(
            "ferrum_require_full_logits".to_string(),
            serde_json::json!(true),
        );
        metadata
    }

    /// CPU KV-cache handle for `cache_id` reporting `seq_len` cached tokens.
    fn test_kv_handle(cache_id: &str, seq_len: usize) -> Arc<dyn KvCacheHandle> {
        let handle = GenericKvCacheHandle::new(
            1,
            1,
            4,
            candle_core::Device::Cpu,
            seq_len,
            cache_id.to_string(),
        );
        Arc::new(handle)
    }

    #[test]
    fn llm_executor_runtime_env_parses_profile_flags_by_presence() {
        // The values are irrelevant: presence of the variable enables the flag.
        let vars = [
            ("FERRUM_BATCH_PREFILL_PROF", ""),
            ("FERRUM_BATCH_DECODE_PROF", "0"),
        ];

        let env = LlmExecutorRuntimeEnv::from_env_vars(vars);

        assert_eq!(
            env,
            LlmExecutorRuntimeEnv {
                batch_prefill_prof: true,
                batch_decode_prof: true,
            }
        );
    }

    #[test]
    fn llm_executor_runtime_env_defaults_profile_flags_off() {
        let env = LlmExecutorRuntimeEnv::from_env_vars([("UNRELATED", "1")]);

        assert_eq!(
            env,
            LlmExecutorRuntimeEnv {
                batch_prefill_prof: false,
                batch_decode_prof: false,
            }
        );
    }

    #[test]
    fn prefill_skips_unified_forward_when_full_logits_required() {
        let calls = new_calls();
        let executor = recording_executor(Arc::clone(&calls));
        let tokens = MockTensor::from_u32(&[1, 2], &[2]).into_ref();
        let input = PrefillInput::new(tokens).with_metadata(full_logits_metadata());

        let output = tokio_test::block_on(executor.prefill(&input)).unwrap();

        let logits = output.last_token_logits().unwrap().to_vec_f32().unwrap();
        assert_eq!(logits.len(), 4);
        let recorded = calls.lock();
        assert_eq!(recorded.unified_forward, 0);
        assert_eq!(recorded.prefill, 1);
    }

    #[test]
    fn decode_skips_unified_forward_when_full_logits_required() {
        let calls = new_calls();
        let executor = recording_executor(Arc::clone(&calls));
        let token = MockTensor::from_u32(&[7], &[1]).into_ref();
        let input = DecodeInput::new(token, test_kv_handle("decode-cache", 3))
            .with_metadata(full_logits_metadata());

        let output = tokio_test::block_on(executor.decode(&input)).unwrap();

        let logits = output.logits.to_vec_f32().unwrap();
        assert_eq!(logits.len(), 4);
        let recorded = calls.lock();
        assert_eq!(recorded.unified_forward, 0);
        assert_eq!(recorded.decode, 1);
    }

    #[test]
    fn unified_decode_skips_unified_forward_when_full_logits_required() {
        let calls = new_calls();
        let executor = recording_executor(Arc::clone(&calls));
        let decode_item = UnifiedBatchItem {
            seq_id: "decode-cache".to_string(),
            q_tokens: vec![7],
            kv_cache: test_kv_handle("decode-cache", 3),
            pos_offset: 3,
            is_final_chunk: true,
            metadata: full_logits_metadata(),
        };
        let mut batch = UnifiedBatch::new();
        batch.items.push(decode_item);

        let output = tokio_test::block_on(executor.unified_decode(&batch)).unwrap();

        let first = output[0].as_ref().unwrap();
        assert_eq!(first.len(), 4);
        let recorded = calls.lock();
        assert_eq!(recorded.unified_forward, 0);
        assert_eq!(recorded.decode_batch_force_full_logits, vec![true]);
    }

    #[test]
    fn cache_metrics_snapshot_includes_model_lock_wait_metrics() {
        let executor = recording_executor(new_calls());

        assert_eq!(executor.kv_capacity(), Some(16));
        let metrics = executor.cache_metrics_snapshot().unwrap();

        // The model's own snapshot must survive alongside the lock metrics.
        assert_eq!(metrics["position"], "recording-test-cache");

        let lock_metrics = &metrics["executor_model_lock"];
        assert_eq!(lock_metrics["schema_version"], 1);
        assert!(
            lock_metrics["samples"].as_u64().unwrap() >= 2,
            "metrics: {metrics}"
        );
        assert!(
            lock_metrics["total_wait_time_us"].as_u64().is_some(),
            "metrics: {metrics}"
        );
        assert!(
            lock_metrics["avg_wait_time_ms"].is_number(),
            "metrics: {metrics}"
        );
    }
}