Skip to main content

ferrum_models/executor/
llm_executor.rs

1//! `LlmExecutor<M>` — adapts a `DecoderOnlyLLM` to the `ModelExecutor` trait
2//! the engine scheduler calls.
3//!
4//! This is the Model-as-Code equivalent of `GenericModelExecutor`: where
5//! `GenericModelExecutor` wraps a `Box<dyn RunnerInterface>` (legacy
6//! `ModelRunner<B>`), `LlmExecutor` wraps a `Box<dyn DecoderOnlyLLM>`
7//! (new-style per-model code such as `Qwen3Model<B>`).
8//!
9//! Tokens/logits are currently bridged through candle Tensor for
10//! `TensorRef` — Phase C will likely replace that with `SmallTensor` to
11//! drop candle from the hot path.
12
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::{Arc, OnceLock};
15
16use parking_lot::Mutex;
17use tracing::debug;
18
19use ferrum_interfaces::{
20    model_executor::{
21        AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorStatus,
22        MemoryRequirements, PrefillInput, PrefillOutput, UnifiedBatch,
23    },
24    ModelExecutor,
25};
26use ferrum_types::{DataType, FerrumError, ModelInfo, Result};
27
28use crate::common::DecoderOnlyLLM;
29
30use super::common::{self, GenericKvCacheHandle};
31
/// Profiling switches for the LLM executor, derived from environment variables.
///
/// A flag is enabled by the mere *presence* of its variable; the value is
/// never inspected, so `FERRUM_BATCH_DECODE_PROF=0` still turns profiling on.
#[derive(Debug, Clone, PartialEq, Eq)]
struct LlmExecutorRuntimeEnv {
    /// True when `FERRUM_BATCH_PREFILL_PROF` is present.
    batch_prefill_prof: bool,
    /// True when `FERRUM_BATCH_DECODE_PROF` is present.
    batch_decode_prof: bool,
}

impl LlmExecutorRuntimeEnv {
    /// Builds the configuration from the current process environment.
    ///
    /// Iterates `std::env::vars_os` instead of `std::env::vars`: the latter
    /// panics as soon as *any* variable (even an unrelated one) has a key or
    /// value that is not valid Unicode. Keys that are not valid Unicode cannot
    /// match either flag name, so they are simply skipped; values are passed
    /// through untouched because they are never read.
    fn from_env() -> Self {
        Self::from_env_vars(
            std::env::vars_os().filter_map(|(key, value)| Some((key.into_string().ok()?, value))),
        )
    }

    /// Builds the configuration from an arbitrary `(key, value)` sequence.
    ///
    /// Only the keys are examined. Unknown keys are ignored, and a flag stays
    /// `false` unless its exact variable name appears at least once.
    fn from_env_vars<I, K, V>(vars: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
    {
        let mut batch_prefill_prof = false;
        let mut batch_decode_prof = false;

        for (key, _) in vars {
            match key.as_ref() {
                "FERRUM_BATCH_PREFILL_PROF" => batch_prefill_prof = true,
                "FERRUM_BATCH_DECODE_PROF" => batch_decode_prof = true,
                _ => {}
            }
        }

        Self {
            batch_prefill_prof,
            batch_decode_prof,
        }
    }
}
65
66fn llm_executor_runtime_env() -> &'static LlmExecutorRuntimeEnv {
67    static CONFIG: OnceLock<LlmExecutorRuntimeEnv> = OnceLock::new();
68    CONFIG.get_or_init(LlmExecutorRuntimeEnv::from_env)
69}
70
71/// Map a `ferrum_types::Device` to the matching `candle_core::Device`.
72/// Used when materialising KV cache handles so downstream readers see
73/// the real backend the model runs on (Metal / CUDA / CPU) rather than
74/// a hard-coded CPU placeholder.
75fn ferrum_device_to_candle(d: &ferrum_types::Device) -> candle_core::Device {
76    match d {
77        ferrum_types::Device::CPU => candle_core::Device::Cpu,
78        #[cfg(feature = "cuda")]
79        ferrum_types::Device::CUDA(i) => {
80            candle_core::Device::new_cuda(*i as usize).unwrap_or(candle_core::Device::Cpu)
81        }
82        #[cfg(not(feature = "cuda"))]
83        ferrum_types::Device::CUDA(_) => candle_core::Device::Cpu,
84        #[cfg(all(any(target_os = "macos", target_os = "ios"), feature = "metal"))]
85        ferrum_types::Device::Metal => {
86            candle_core::Device::new_metal(0).unwrap_or(candle_core::Device::Cpu)
87        }
88        _ => candle_core::Device::Cpu,
89    }
90}
91
92pub struct LlmExecutor {
93    model: Mutex<Box<dyn DecoderOnlyLLM>>,
94    info: ModelInfo,
95    next_cache_id: AtomicU64,
96}
97
98impl LlmExecutor {
99    pub fn new(model: Box<dyn DecoderOnlyLLM>, info: ModelInfo) -> Self {
100        Self {
101            model: Mutex::new(model),
102            info,
103            next_cache_id: AtomicU64::new(0),
104        }
105    }
106
107    fn gen_cache_id(&self) -> String {
108        format!(
109            "llm-cache-{}",
110            self.next_cache_id.fetch_add(1, Ordering::Relaxed)
111        )
112    }
113
114    /// Roll the KV cache for `cache_id` back to `new_len` positions.
115    /// Used by speculative decoding on partial rejection. The caller must
116    /// supply a `GenericKvCacheHandle` whose seq_len is also updated.
117    pub fn truncate_kv_for_cache_id(&self, cache_id: &str, new_len: usize) {
118        let mut model = self.model.lock();
119        model.truncate_kv(cache_id, new_len);
120    }
121}
122
123#[async_trait::async_trait]
124impl ModelExecutor for LlmExecutor {
125    fn info(&self) -> &ModelInfo {
126        &self.info
127    }
128
129    async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
130        let tokens = common::tensor_to_tokens(&input.input_ids)?;
131
132        // Reuse an existing cache_id when the caller supplies a KV handle
133        // (chunked prefill) — fresh id only on the very first call for a
134        // request. Without this, every chunk would create a new KV cache
135        // at position 0 and subsequent chunks wouldn't see prior tokens.
136        let supplied_handle_id = input.kv_cache.as_ref().and_then(|h| {
137            h.as_any()
138                .downcast_ref::<GenericKvCacheHandle>()
139                .map(|g| g.request_cache_id().to_string())
140        });
141        let cache_id = supplied_handle_id
142            .clone()
143            .unwrap_or_else(|| self.gen_cache_id());
144
145        // For chunked-prefill continuation, the prior KV length is the seq
146        // length already in the supplied handle; for fresh prefill it's 0.
147        let prior_seq_len = input
148            .kv_cache
149            .as_ref()
150            .and_then(|h| h.as_any().downcast_ref::<GenericKvCacheHandle>())
151            .map(|g| {
152                use ferrum_interfaces::KvCacheHandle;
153                g.block_table().sequence_length
154            })
155            .unwrap_or(0);
156
157        // Try the unified_forward path first: when the model has it wired
158        // (paged KV pools allocated), this routes through the chunked-prefill
159        // varlen kernel — the same code that handles mixed prefill+decode
160        // batches. On Unsupported, fall back to the legacy single-item
161        // prefill path so contig-KV configs keep their existing behaviour.
162        let logits = {
163            let mut model = self.model.lock();
164            let unified_item = vec![(cache_id.clone(), tokens.clone(), prior_seq_len, true)];
165            match model.unified_forward(&unified_item) {
166                Ok(mut per_item) => per_item
167                    .pop()
168                    .flatten()
169                    .ok_or_else(|| FerrumError::model("unified_forward returned no logits"))?,
170                Err(FerrumError::Unsupported { .. }) => model.prefill(&cache_id, &tokens),
171                Err(e) => return Err(e),
172            }
173        };
174
175        // Wrap logits as TensorRef: [1, 1, vocab_size]
176        let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
177            .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
178            .unsqueeze(0)
179            .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?
180            .unsqueeze(0)
181            .map_err(|e| FerrumError::model(format!("unsqueeze2: {e}")))?;
182        let logits_ref = common::wrap_tensor(logits_tensor);
183
184        let cfg = self.model.lock().config().clone();
185        // Sequence-length tracking across chunks: if the caller supplied a
186        // GenericKvCacheHandle (chunked prefill continuation), add this
187        // chunk's tokens to the prior length. Otherwise this is a fresh
188        // prefill so seq_len == this call's token count. Without this the
189        // handle would claim only the last chunk's length, misleading
190        // decode() into rewriting the KV at an earlier position.
191        let seq_len = input
192            .kv_cache
193            .as_ref()
194            .and_then(|h| h.as_any().downcast_ref::<GenericKvCacheHandle>())
195            .map(|g| {
196                use ferrum_interfaces::KvCacheHandle;
197                g.block_table().sequence_length + tokens.len()
198            })
199            .unwrap_or(tokens.len());
200
201        let kv_handle = Arc::new(GenericKvCacheHandle::new(
202            cfg.num_layers,
203            cfg.num_kv_heads,
204            cfg.head_dim,
205            candle_core::Device::Cpu,
206            seq_len,
207            cache_id,
208        ));
209
210        Ok(PrefillOutput::new(logits_ref, kv_handle))
211    }
212
213    /// Batched prefill: combine all prompts into ONE `model.unified_forward`
214    /// call so launch / kernel-overhead is amortized across the cohort.
215    ///
216    /// Falls back to the trait default (serial per-item) when the model
217    /// returns `Err(unsupported)` from `unified_forward` — e.g. Qwen3MoeModel
218    /// today, until Phase 2 adds its native unified path.
219    async fn batch_prefill(&self, inputs: &[PrefillInput]) -> Result<Vec<PrefillOutput>> {
220        if inputs.is_empty() {
221            return Ok(Vec::new());
222        }
223
224        // Per-input: derive cache_id (reuse supplied handle's id or generate
225        // fresh) + prior_seq_len. Mirrors the single-prefill path so chunked
226        // prefill continuations route correctly when batched.
227        let mut cache_ids = Vec::with_capacity(inputs.len());
228        let mut prior_seq_lens = Vec::with_capacity(inputs.len());
229        let mut tokens_per_input = Vec::with_capacity(inputs.len());
230        for input in inputs {
231            let tokens = common::tensor_to_tokens(&input.input_ids)?;
232            let supplied_handle_id = input.kv_cache.as_ref().and_then(|h| {
233                h.as_any()
234                    .downcast_ref::<GenericKvCacheHandle>()
235                    .map(|g| g.request_cache_id().to_string())
236            });
237            let cache_id = supplied_handle_id
238                .clone()
239                .unwrap_or_else(|| self.gen_cache_id());
240            let prior_seq_len = input
241                .kv_cache
242                .as_ref()
243                .and_then(|h| h.as_any().downcast_ref::<GenericKvCacheHandle>())
244                .map(|g| {
245                    use ferrum_interfaces::KvCacheHandle;
246                    g.block_table().sequence_length
247                })
248                .unwrap_or(0);
249            cache_ids.push(cache_id);
250            prior_seq_lens.push(prior_seq_len);
251            tokens_per_input.push(tokens);
252        }
253
254        // Build unified items and ONE `unified_forward` call. If the model
255        // doesn't support it, fall back to the trait-default serial path.
256        let unified_items: Vec<(String, Vec<u32>, usize, bool)> = cache_ids
257            .iter()
258            .zip(tokens_per_input.iter())
259            .zip(prior_seq_lens.iter())
260            .map(|((cid, toks), &prior)| (cid.clone(), toks.clone(), prior, true))
261            .collect();
262
263        let nb_prof = llm_executor_runtime_env().batch_prefill_prof;
264        let bp_t0 = if nb_prof {
265            Some(std::time::Instant::now())
266        } else {
267            None
268        };
269        let mut took_fallback = false;
270        let per_item_logits: Vec<Vec<f32>> = {
271            let mut model = self.model.lock();
272            match model.unified_forward(&unified_items) {
273                Ok(per_item) => per_item
274                    .into_iter()
275                    .map(|opt| opt.expect("is_final_chunk=true must yield logits"))
276                    .collect(),
277                Err(FerrumError::Unsupported { .. }) => {
278                    took_fallback = true;
279                    let mut out = Vec::with_capacity(inputs.len());
280                    for (cid, toks) in cache_ids.iter().zip(tokens_per_input.iter()) {
281                        out.push(model.prefill(cid, toks));
282                    }
283                    out
284                }
285                Err(e) => return Err(e),
286            }
287        };
288        if let Some(t0) = bp_t0 {
289            let total_q: usize = unified_items.iter().map(|it| it.1.len()).sum();
290            eprintln!(
291                "[batch-prefill] n_items={} total_q={} fallback={} elapsed={}us",
292                inputs.len(),
293                total_q,
294                took_fallback,
295                t0.elapsed().as_micros()
296            );
297        }
298
299        let cfg = self.model.lock().config().clone();
300        let mut outputs = Vec::with_capacity(inputs.len());
301        for (i, logits) in per_item_logits.into_iter().enumerate() {
302            let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
303                .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
304                .unsqueeze(0)
305                .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?
306                .unsqueeze(0)
307                .map_err(|e| FerrumError::model(format!("unsqueeze2: {e}")))?;
308            let logits_ref = common::wrap_tensor(logits_tensor);
309            let seq_len = inputs[i]
310                .kv_cache
311                .as_ref()
312                .and_then(|h| h.as_any().downcast_ref::<GenericKvCacheHandle>())
313                .map(|g| {
314                    use ferrum_interfaces::KvCacheHandle;
315                    g.block_table().sequence_length + tokens_per_input[i].len()
316                })
317                .unwrap_or(tokens_per_input[i].len());
318            let kv_handle = Arc::new(GenericKvCacheHandle::new(
319                cfg.num_layers,
320                cfg.num_kv_heads,
321                cfg.head_dim,
322                candle_core::Device::Cpu,
323                seq_len,
324                cache_ids[i].clone(),
325            ));
326            outputs.push(PrefillOutput::new(logits_ref, kv_handle));
327        }
328        Ok(outputs)
329    }
330
331    async fn truncate_kv(
332        &self,
333        kv_cache: &Arc<dyn ferrum_interfaces::KvCacheHandle>,
334        new_len: usize,
335    ) -> Result<()> {
336        if let Some(g) = kv_cache.as_any().downcast_ref::<GenericKvCacheHandle>() {
337            let cache_id = g.request_cache_id();
338            self.model.lock().truncate_kv(cache_id, new_len);
339        }
340        Ok(())
341    }
342
343    async fn forward_verify(
344        &self,
345        inputs: &[ferrum_interfaces::model_executor::DecodeInput],
346    ) -> Result<Vec<ferrum_interfaces::model_executor::DecodeOutput>> {
347        if inputs.is_empty() {
348            return Ok(Vec::new());
349        }
350
351        // All inputs must share the same KV handle (speculative decoding
352        // contract). Extract cache_id + starting seq_len once.
353        let first_handle = inputs[0].kv_cache.clone();
354        let cache_id = first_handle
355            .as_any()
356            .downcast_ref::<GenericKvCacheHandle>()
357            .ok_or_else(|| {
358                FerrumError::model("forward_verify requires GenericKvCacheHandle input")
359            })?
360            .request_cache_id()
361            .to_string();
362        let start_seq = {
363            use ferrum_interfaces::KvCacheHandle;
364            first_handle.block_table().sequence_length
365        };
366
367        // Collect the N+1 token ids.
368        let mut token_ids: Vec<u32> = Vec::with_capacity(inputs.len());
369        for input in inputs {
370            let toks = common::tensor_to_tokens(&input.input_ids)?;
371            if toks.is_empty() {
372                return Err(FerrumError::model("forward_verify input token empty"));
373            }
374            token_ids.push(toks[0]);
375        }
376
377        // One model forward for all N+1 positions → flat seq_len*vocab.
378        let flat = {
379            let mut model = self.model.lock();
380            model.forward_verify(&cache_id, &token_ids)
381        };
382
383        let cfg = self.model.lock().config().clone();
384        let vocab = cfg.vocab_size;
385
386        // Record the actual backend device so downstream code that reads
387        // `KvCacheHandle::device()` sees Metal/CUDA/CPU matching the
388        // model's real location. The logits `Tensor` still wraps CPU data
389        // because `B::to_vec` already moved it off-device.
390        let candle_device = ferrum_device_to_candle(&self.info.device);
391
392        // Split the flat logits into per-position tensors, each wrapped
393        // with a handle whose seq_len reflects the positions written so
394        // far. Matches what the spec runner expects from sequential
395        // decode() calls.
396        let mut outputs = Vec::with_capacity(inputs.len());
397        for (i, _) in inputs.iter().enumerate() {
398            let row = &flat[i * vocab..(i + 1) * vocab];
399            let logits_tensor = candle_core::Tensor::new(row, &candle_core::Device::Cpu)
400                .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
401                .unsqueeze(0)
402                .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
403            let logits_ref = common::wrap_tensor(logits_tensor);
404            let handle = Arc::new(GenericKvCacheHandle::new(
405                cfg.num_layers,
406                cfg.num_kv_heads,
407                cfg.head_dim,
408                candle_device.clone(),
409                start_seq + i + 1,
410                cache_id.clone(),
411            ));
412            outputs.push(ferrum_interfaces::model_executor::DecodeOutput::new(
413                logits_ref, handle,
414            ));
415        }
416        Ok(outputs)
417    }
418
419    async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput> {
420        let input_handle = input
421            .kv_cache
422            .as_any()
423            .downcast_ref::<GenericKvCacheHandle>()
424            .ok_or_else(|| FerrumError::model("Invalid KV cache handle type"))?;
425
426        let cache_id = input_handle.request_cache_id().to_string();
427        let seq_len = {
428            use ferrum_interfaces::KvCacheHandle;
429            input_handle.block_table().sequence_length
430        };
431
432        let tokens = common::tensor_to_tokens(&input.input_ids)?;
433        if tokens.is_empty() {
434            return Err(FerrumError::model("Decode input is empty"));
435        }
436        let token = tokens[0];
437
438        debug!("LlmExecutor decode: token={token}, pos={seq_len}");
439
440        // Try unified_forward first so paged-KV configs route the single
441        // decode through the same varlen kernel used by batched mixed
442        // batches. Falls back to legacy paged_decode_attention for contig
443        // configs that haven't wired unified_forward.
444        let logits = {
445            let mut model = self.model.lock();
446            let unified_item = vec![(cache_id.clone(), vec![token], seq_len, true)];
447            match model.unified_forward(&unified_item) {
448                Ok(mut per_item) => per_item
449                    .pop()
450                    .flatten()
451                    .ok_or_else(|| FerrumError::model("unified_forward returned no logits"))?,
452                Err(FerrumError::Unsupported { .. }) => {
453                    model.decode(&cache_id, token, seq_len as u32)
454                }
455                Err(e) => return Err(e),
456            }
457        };
458
459        let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
460            .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
461            .unsqueeze(0)
462            .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
463        let logits_ref = common::wrap_tensor(logits_tensor);
464
465        let kv_handle = Arc::new(input_handle.with_sequence_length(seq_len + 1));
466        Ok(DecodeOutput::new(logits_ref, kv_handle))
467    }
468
469    /// Override default fallback to acquire the model lock ONCE for the whole
470    /// batch, avoiding N round-trips through parking_lot. Does not yet do
471    /// true attention batching (each cache has its own kv_len), but removes
472    /// mutex churn that was serialising concurrent requests at async level.
473    async fn batch_decode(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>> {
474        if inputs.is_empty() {
475            return Ok(Vec::new());
476        }
477        let prof = llm_executor_runtime_env().batch_decode_prof;
478        let t0 = if prof {
479            Some(std::time::Instant::now())
480        } else {
481            None
482        };
483        // Pre-extract all per-input metadata OUTSIDE the lock — this is pure
484        // borrow/downcast work that doesn't touch the model.
485        struct Prep {
486            cache_id: String,
487            token: u32,
488            seq_len: u32,
489            handle: Arc<GenericKvCacheHandle>,
490        }
491        let mut prepped: Vec<Prep> = Vec::with_capacity(inputs.len());
492        for input in inputs {
493            let input_handle = input
494                .kv_cache
495                .as_any()
496                .downcast_ref::<GenericKvCacheHandle>()
497                .ok_or_else(|| FerrumError::model("Invalid KV cache handle type"))?;
498            use ferrum_interfaces::KvCacheHandle;
499            let seq_len = input_handle.block_table().sequence_length as u32;
500            let tokens = common::tensor_to_tokens(&input.input_ids)?;
501            if tokens.is_empty() {
502                return Err(FerrumError::model("Decode input is empty"));
503            }
504            prepped.push(Prep {
505                cache_id: input_handle.request_cache_id().to_string(),
506                token: tokens[0],
507                seq_len,
508                handle: Arc::new(input_handle.with_sequence_length((seq_len + 1) as usize)),
509            });
510        }
511        let t_prep = if prof {
512            Some(std::time::Instant::now())
513        } else {
514            None
515        };
516
517        // One lock for the whole batch. Try unified_forward first: paged
518        // configs route through the varlen kernel (single mixed dispatch
519        // for the whole batch); contig configs fall back to model's
520        // legacy decode_batch (separate paged_decode_attention call per
521        // item, batched matmul for QKV/MLP).
522        let (all_logits, t_lock_acq, t_model_call): (Vec<Vec<f32>>, _, _) = {
523            let lock_t0 = if prof {
524                Some(std::time::Instant::now())
525            } else {
526                None
527            };
528            let mut model = self.model.lock();
529            let lock_acq = lock_t0.map(|t| t.elapsed());
530            let model_t0 = if prof {
531                Some(std::time::Instant::now())
532            } else {
533                None
534            };
535            let unified_items: Vec<(String, Vec<u32>, usize, bool)> = prepped
536                .iter()
537                .map(|p| (p.cache_id.clone(), vec![p.token], p.seq_len as usize, true))
538                .collect();
539            let logits = match model.unified_forward(&unified_items) {
540                Ok(per_item) => {
541                    if per_item.len() != prepped.len() {
542                        return Err(FerrumError::model(format!(
543                            "unified_forward returned {} entries for {} items",
544                            per_item.len(),
545                            prepped.len(),
546                        )));
547                    }
548                    let mut out = Vec::with_capacity(prepped.len());
549                    for (i, opt) in per_item.into_iter().enumerate() {
550                        out.push(opt.ok_or_else(|| {
551                            FerrumError::model(format!(
552                                "unified_forward returned None for decode item {i}"
553                            ))
554                        })?);
555                    }
556                    out
557                }
558                Err(FerrumError::Unsupported { .. }) => {
559                    let tuples: Vec<(String, u32, u32)> = prepped
560                        .iter()
561                        .map(|p| (p.cache_id.clone(), p.token, p.seq_len))
562                        .collect();
563                    model.decode_batch(&tuples)
564                }
565                Err(e) => return Err(e),
566            };
567            let model_call = model_t0.map(|t| t.elapsed());
568            (logits, lock_acq, model_call)
569        };
570        let t_model_done = if prof {
571            Some(std::time::Instant::now())
572        } else {
573            None
574        };
575
576        let m_count = prepped.len();
577        let mut outputs = Vec::with_capacity(m_count);
578        for (p, logits) in prepped.into_iter().zip(all_logits.into_iter()) {
579            debug!(
580                "LlmExecutor batch_decode: token={}, pos={}",
581                p.token, p.seq_len
582            );
583            let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
584                .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
585                .unsqueeze(0)
586                .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
587            let logits_ref = common::wrap_tensor(logits_tensor);
588            outputs.push(DecodeOutput::new(logits_ref, p.handle));
589        }
590        if let (Some(t0), Some(tp), Some(tm)) = (t0, t_prep, t_model_done) {
591            static EX_PROF_CALLS: std::sync::atomic::AtomicU64 =
592                std::sync::atomic::AtomicU64::new(0);
593            let n = EX_PROF_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
594            if n.is_multiple_of(8) {
595                let total = t0.elapsed().as_micros();
596                let prep = tp.duration_since(t0).as_micros();
597                let lock_acq = t_lock_acq.map(|d| d.as_micros()).unwrap_or(0);
598                let model_call = t_model_call.map(|d| d.as_micros()).unwrap_or(0);
599                let model_block = tm.duration_since(tp).as_micros();
600                let wrap = tm.elapsed().as_micros();
601                eprintln!(
602                    "[exec-batch-decode-prof] call#{} m={} total={}us prep={}us model_block={}us(lock_acq={}us model_call={}us) wrap={}us",
603                    n, m_count, total, prep, model_block, lock_acq, model_call, wrap,
604                );
605            }
606        }
607        Ok(outputs)
608    }
609
610    /// Unified mixed-batch dispatch (chunked-prefill API).
611    ///
612    /// This impl is a behavior-preserving FALLBACK over the existing
613    /// trait methods on `DecoderOnlyLLM`: prefill items go through
614    /// `model.prefill(seq_id, &q_tokens)` (one at a time, mirroring the
615    /// engine's current sequential prefill loop), decode items
616    /// (`q_len == 1 && is_final_chunk`) are grouped into a single
617    /// `model.decode_batch(...)` call. Net behavior is identical to the
618    /// engine's pre-Phase-13 path; this just changes WHO orchestrates
619    /// the prefill/decode split (caller → unified_decode) so the engine
620    /// can converge on a single call.
621    ///
622    /// The real performance unlock comes in Step 5 when models override
623    /// this with a true unified-forward (one [M_total, hidden] forward
624    /// + varlen attention) — at that point the kernel-level mix replaces
625    /// the host-side serial dispatch here.
626    async fn unified_decode(&self, batch: &UnifiedBatch) -> Result<Vec<Option<Vec<f32>>>> {
627        let mut results: Vec<Option<Vec<f32>>> = vec![None; batch.items.len()];
628        if batch.items.is_empty() {
629            return Ok(results);
630        }
631
632        // ── Real unified path (Step 5b+): if the model implements
633        // `DecoderOnlyLLM::unified_forward`, route the entire batch
634        // through one model forward (mixed prefill chunks + decode
635        // tokens in a single [M_total, hidden] pass). The model returns
636        // `Err(unsupported)` if it hasn't been wired yet — fall through
637        // to the behaviour-preserving fallback below.
638        let unified_items: Vec<(String, Vec<u32>, usize, bool)> = batch
639            .items
640            .iter()
641            .map(|it| {
642                (
643                    it.seq_id.clone(),
644                    it.q_tokens.clone(),
645                    it.pos_offset,
646                    it.is_final_chunk,
647                )
648            })
649            .collect();
650        let model_result = {
651            let mut model = self.model.lock();
652            model.unified_forward(&unified_items)
653        };
654        match model_result {
655            Ok(per_item) => {
656                if per_item.len() != batch.items.len() {
657                    return Err(FerrumError::model(format!(
658                        "unified_forward returned {} entries for {} items",
659                        per_item.len(),
660                        batch.items.len(),
661                    )));
662                }
663                return Ok(per_item);
664            }
665            Err(FerrumError::Unsupported { .. }) => {
666                // Fall through to the dispatch fallback below.
667            }
668            Err(e) => return Err(e),
669        }
670
671        // Partition: pure decode items vs prefill chunks.
672        // A "decode" item has q_len == 1 AND is_final_chunk == true.
673        // Anything else (chunked prefill mid-stream OR a single-token
674        // prefill that returns logits) goes through the per-item prefill
675        // path so the model receives the right pos_offset behaviour.
676        let mut prefill_indices: Vec<usize> = Vec::new();
677        let mut decode_indices: Vec<usize> = Vec::new();
678        for (i, item) in batch.items.iter().enumerate() {
679            if item.q_tokens.len() == 1 && item.is_final_chunk {
680                decode_indices.push(i);
681            } else {
682                prefill_indices.push(i);
683            }
684        }
685
686        // Prefill items — sequential, mirrors current engine behaviour.
687        // Held under a single model lock to amortise lock acquire across
688        // all prefills in this batch (we may revisit per-call locking
689        // when chunked-prefill becomes the perf-critical path).
690        if !prefill_indices.is_empty() {
691            let mut model = self.model.lock();
692            for &i in &prefill_indices {
693                let item = &batch.items[i];
694                let logits = model.prefill(&item.seq_id, &item.q_tokens);
695                if item.is_final_chunk {
696                    results[i] = Some(logits);
697                }
698            }
699        }
700
701        // Decode items — single batched dispatch.
702        if !decode_indices.is_empty() {
703            let tuples: Vec<(String, u32, u32)> = decode_indices
704                .iter()
705                .map(|&i| {
706                    let it = &batch.items[i];
707                    (it.seq_id.clone(), it.q_tokens[0], it.pos_offset as u32)
708                })
709                .collect();
710            let logits_vec = {
711                let mut model = self.model.lock();
712                model.decode_batch(&tuples)
713            };
714            for (j, &i) in decode_indices.iter().enumerate() {
715                results[i] = Some(logits_vec[j].clone());
716            }
717        }
718
719        Ok(results)
720    }
721
722    fn release_cache(&self, cache_id: &str) {
723        self.model.lock().release(cache_id);
724    }
725
726    fn capabilities(&self) -> ExecutorCapabilities {
727        let cfg = self.model.lock().config().clone();
728        ExecutorCapabilities {
729            max_batch_size: 256,
730            max_sequence_length: cfg.max_seq_len,
731            attention_mechanisms: vec![AttentionType::GroupedQuery],
732            supports_dynamic_batching: true,
733            supports_continuous_batching: true,
734            supports_speculative_decoding: false,
735            supports_tensor_parallelism: false,
736            supports_pipeline_parallelism: false,
737            supported_dtypes: vec![DataType::FP32],
738            supported_devices: vec![self.info.device.clone()],
739            memory_requirements: MemoryRequirements {
740                parameter_memory: (self.info.num_parameters * 4) as u64,
741                activation_memory_per_token: cfg.hidden_size * 4,
742                kv_cache_memory_per_token: cfg.hidden_size * 2,
743                overhead_memory: 256 * 1024 * 1024,
744            },
745        }
746    }
747
748    fn status(&self) -> ExecutorStatus {
749        common::default_executor_status()
750    }
751}
752
#[cfg(test)]
mod tests {
    use super::*;

    /// Both flags switch on when their variable exists, whatever its value
    /// (an empty string and "0" are used on purpose).
    #[test]
    fn llm_executor_runtime_env_parses_profile_flags_by_presence() {
        let vars = vec![
            ("FERRUM_BATCH_PREFILL_PROF", ""),
            ("FERRUM_BATCH_DECODE_PROF", "0"),
        ];
        let parsed = LlmExecutorRuntimeEnv::from_env_vars(vars);

        assert!(parsed.batch_prefill_prof);
        assert!(parsed.batch_decode_prof);
    }

    /// Variables with other names leave both flags off.
    #[test]
    fn llm_executor_runtime_env_defaults_profile_flags_off() {
        let vars = vec![("UNRELATED", "1")];
        let parsed = LlmExecutorRuntimeEnv::from_env_vars(vars);

        assert!(!parsed.batch_prefill_prof);
        assert!(!parsed.batch_decode_prof);
    }
}
775}