Skip to main content

ferrum_models/executor/
llm_executor.rs

1//! `LlmExecutor<M>` — adapts a `DecoderOnlyLLM` to the `ModelExecutor` trait
2//! the engine scheduler calls.
3//!
4//! This is the Model-as-Code equivalent of `GenericModelExecutor`: where
5//! `GenericModelExecutor` wraps a `Box<dyn RunnerInterface>` (legacy
6//! `ModelRunner<B>`), `LlmExecutor` wraps a `Box<dyn DecoderOnlyLLM>`
7//! (new-style per-model code such as `Qwen3Model<B>`).
8//!
9//! Tokens/logits are currently bridged through candle Tensor for
10//! `TensorRef` — Phase C will likely replace that with `SmallTensor` to
11//! drop candle from the hot path.
12
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::{Arc, OnceLock};
15
16use parking_lot::Mutex;
17use tracing::debug;
18
19use ferrum_interfaces::{
20    model_executor::{
21        AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorStatus,
22        MemoryRequirements, PrefillInput, PrefillOutput, UnifiedBatch,
23    },
24    ModelExecutor,
25};
26use ferrum_types::{DataType, FerrumError, ModelInfo, Result};
27
28use crate::common::DecoderOnlyLLM;
29use crate::lora::ActiveLoraAdapter;
30
31use super::common::{self, GenericKvCacheHandle};
32
/// Profiling switches for `LlmExecutor`, read from the process environment.
///
/// Each flag is enabled by the mere *presence* of its variable; the value is
/// ignored, so `FERRUM_BATCH_PREFILL_PROF=0` still turns prefill profiling on.
#[derive(Debug, Clone, PartialEq, Eq)]
struct LlmExecutorRuntimeEnv {
    /// `FERRUM_BATCH_PREFILL_PROF`: print a timing line for every `batch_prefill` call.
    batch_prefill_prof: bool,
    /// `FERRUM_BATCH_DECODE_PROF`: print a (sampled) timing breakdown for `batch_decode`.
    batch_decode_prof: bool,
}

impl LlmExecutorRuntimeEnv {
    /// Snapshot the flags from the current process environment.
    ///
    /// Uses `std::env::vars_os` rather than `std::env::vars`: the latter panics
    /// while iterating if *any* variable name or value is not valid Unicode,
    /// which would take the executor down because of an unrelated variable.
    /// Names that are not valid Unicode can never equal one of our ASCII flag
    /// names, so they are simply skipped.
    fn from_env() -> Self {
        Self::from_env_vars(
            std::env::vars_os()
                .filter_map(|(key, value)| key.into_string().ok().map(|key| (key, value))),
        )
    }

    /// Build the flags from an arbitrary `(name, value)` iterator.
    ///
    /// Only the names are inspected; values are never read, which is why `V`
    /// carries no bounds.
    fn from_env_vars<I, K, V>(vars: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
    {
        let mut env = Self {
            batch_prefill_prof: false,
            batch_decode_prof: false,
        };

        for (key, _) in vars {
            match key.as_ref() {
                "FERRUM_BATCH_PREFILL_PROF" => env.batch_prefill_prof = true,
                "FERRUM_BATCH_DECODE_PROF" => env.batch_decode_prof = true,
                _ => {}
            }
        }

        env
    }
}
66
67fn llm_executor_runtime_env() -> &'static LlmExecutorRuntimeEnv {
68    static CONFIG: OnceLock<LlmExecutorRuntimeEnv> = OnceLock::new();
69    CONFIG.get_or_init(LlmExecutorRuntimeEnv::from_env)
70}
71
72fn active_lora_from_metadata(
73    metadata: &std::collections::HashMap<String, serde_json::Value>,
74) -> Result<Option<ActiveLoraAdapter>> {
75    let name = metadata
76        .get("ferrum_lora_adapter")
77        .and_then(|value| value.as_str());
78    let path = metadata
79        .get("ferrum_lora_path")
80        .and_then(|value| value.as_str());
81    match (name, path) {
82        (Some(name), Some(path)) => Ok(Some(ActiveLoraAdapter {
83            name: name.to_string(),
84            path: std::path::PathBuf::from(path),
85        })),
86        (None, None) => Ok(None),
87        _ => Err(FerrumError::model(
88            "incomplete LoRA metadata: expected ferrum_lora_adapter and ferrum_lora_path",
89        )),
90    }
91}
92
93fn metadata_requires_full_logits(
94    metadata: &std::collections::HashMap<String, serde_json::Value>,
95) -> bool {
96    metadata
97        .get("ferrum_require_full_logits")
98        .and_then(|value| value.as_bool())
99        .unwrap_or(false)
100}
101
102/// Map a `ferrum_types::Device` to the matching `candle_core::Device`.
103/// Used when materialising KV cache handles so downstream readers see
104/// the real backend the model runs on (Metal / CUDA / CPU) rather than
105/// a hard-coded CPU placeholder.
106fn ferrum_device_to_candle(d: &ferrum_types::Device) -> candle_core::Device {
107    match d {
108        ferrum_types::Device::CPU => candle_core::Device::Cpu,
109        #[cfg(feature = "cuda")]
110        ferrum_types::Device::CUDA(i) => {
111            candle_core::Device::new_cuda(*i as usize).unwrap_or(candle_core::Device::Cpu)
112        }
113        #[cfg(not(feature = "cuda"))]
114        ferrum_types::Device::CUDA(_) => candle_core::Device::Cpu,
115        #[cfg(all(any(target_os = "macos", target_os = "ios"), feature = "metal"))]
116        ferrum_types::Device::Metal => {
117            candle_core::Device::new_metal(0).unwrap_or(candle_core::Device::Cpu)
118        }
119        _ => candle_core::Device::Cpu,
120    }
121}
122
123pub struct LlmExecutor {
124    model: Mutex<Box<dyn DecoderOnlyLLM>>,
125    info: ModelInfo,
126    next_cache_id: AtomicU64,
127}
128
129impl LlmExecutor {
130    pub fn new(model: Box<dyn DecoderOnlyLLM>, info: ModelInfo) -> Self {
131        Self {
132            model: Mutex::new(model),
133            info,
134            next_cache_id: AtomicU64::new(0),
135        }
136    }
137
138    fn gen_cache_id(&self) -> String {
139        format!(
140            "llm-cache-{}",
141            self.next_cache_id.fetch_add(1, Ordering::Relaxed)
142        )
143    }
144
145    /// Roll the KV cache for `cache_id` back to `new_len` positions.
146    /// Used by speculative decoding on partial rejection. The caller must
147    /// supply a `GenericKvCacheHandle` whose seq_len is also updated.
148    pub fn truncate_kv_for_cache_id(&self, cache_id: &str, new_len: usize) {
149        let mut model = self.model.lock();
150        model.truncate_kv(cache_id, new_len);
151    }
152}
153
154#[async_trait::async_trait]
155impl ModelExecutor for LlmExecutor {
156    fn info(&self) -> &ModelInfo {
157        &self.info
158    }
159
160    fn kv_capacity(&self) -> Option<usize> {
161        Some(self.model.lock().kv_capacity())
162    }
163
164    async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
165        let tokens = common::tensor_to_tokens(&input.input_ids)?;
166
167        // Reuse an existing cache_id when the caller supplies a KV handle
168        // (chunked prefill) — fresh id only on the very first call for a
169        // request. Without this, every chunk would create a new KV cache
170        // at position 0 and subsequent chunks wouldn't see prior tokens.
171        let supplied_handle_id = input.kv_cache.as_ref().and_then(|h| {
172            h.as_any()
173                .downcast_ref::<GenericKvCacheHandle>()
174                .map(|g| g.request_cache_id().to_string())
175        });
176        let cache_id = supplied_handle_id
177            .clone()
178            .unwrap_or_else(|| self.gen_cache_id());
179
180        // For chunked-prefill continuation, the prior KV length is the seq
181        // length already in the supplied handle; for fresh prefill it's 0.
182        let prior_seq_len = input
183            .kv_cache
184            .as_ref()
185            .and_then(|h| h.as_any().downcast_ref::<GenericKvCacheHandle>())
186            .map(|g| {
187                use ferrum_interfaces::KvCacheHandle;
188                g.block_table().sequence_length
189            })
190            .unwrap_or(0);
191
192        // Try the unified_forward path first: when the model has it wired
193        // (paged KV pools allocated), this routes through the chunked-prefill
194        // varlen kernel — the same code that handles mixed prefill+decode
195        // batches. On Unsupported, fall back to the legacy single-item
196        // prefill path so contig-KV configs keep their existing behaviour.
197        let logits = {
198            let mut model = self.model.lock();
199            model.set_lora_adapter_for_cache(
200                &cache_id,
201                active_lora_from_metadata(&input.metadata)?,
202            )?;
203            let unified_item = vec![(cache_id.clone(), tokens.clone(), prior_seq_len, true)];
204            match model.unified_forward(&unified_item) {
205                Ok(mut per_item) => per_item
206                    .pop()
207                    .flatten()
208                    .ok_or_else(|| FerrumError::model("unified_forward returned no logits"))?,
209                Err(FerrumError::Unsupported { .. }) => model.prefill(&cache_id, &tokens),
210                Err(e) => return Err(e),
211            }
212        };
213
214        // Wrap logits as TensorRef: [1, 1, vocab_size]
215        let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
216            .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
217            .unsqueeze(0)
218            .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?
219            .unsqueeze(0)
220            .map_err(|e| FerrumError::model(format!("unsqueeze2: {e}")))?;
221        let logits_ref = common::wrap_tensor(logits_tensor);
222
223        let cfg = self.model.lock().config().clone();
224        // Sequence-length tracking across chunks: if the caller supplied a
225        // GenericKvCacheHandle (chunked prefill continuation), add this
226        // chunk's tokens to the prior length. Otherwise this is a fresh
227        // prefill so seq_len == this call's token count. Without this the
228        // handle would claim only the last chunk's length, misleading
229        // decode() into rewriting the KV at an earlier position.
230        let seq_len = input
231            .kv_cache
232            .as_ref()
233            .and_then(|h| h.as_any().downcast_ref::<GenericKvCacheHandle>())
234            .map(|g| {
235                use ferrum_interfaces::KvCacheHandle;
236                g.block_table().sequence_length + tokens.len()
237            })
238            .unwrap_or(tokens.len());
239
240        let kv_handle = Arc::new(GenericKvCacheHandle::new(
241            cfg.num_layers,
242            cfg.num_kv_heads,
243            cfg.head_dim,
244            candle_core::Device::Cpu,
245            seq_len,
246            cache_id,
247        ));
248
249        Ok(PrefillOutput::new(logits_ref, kv_handle))
250    }
251
252    /// Batched prefill: combine all prompts into ONE `model.unified_forward`
253    /// call so launch / kernel-overhead is amortized across the cohort.
254    ///
255    /// Falls back to the trait default (serial per-item) when the model
256    /// returns `Err(unsupported)` from `unified_forward` — e.g. Qwen3MoeModel
257    /// today, until Phase 2 adds its native unified path.
258    async fn batch_prefill(&self, inputs: &[PrefillInput]) -> Result<Vec<PrefillOutput>> {
259        if inputs.is_empty() {
260            return Ok(Vec::new());
261        }
262
263        // Per-input: derive cache_id (reuse supplied handle's id or generate
264        // fresh) + prior_seq_len. Mirrors the single-prefill path so chunked
265        // prefill continuations route correctly when batched.
266        let mut cache_ids = Vec::with_capacity(inputs.len());
267        let mut prior_seq_lens = Vec::with_capacity(inputs.len());
268        let mut tokens_per_input = Vec::with_capacity(inputs.len());
269        let mut lora_per_input = Vec::with_capacity(inputs.len());
270        for input in inputs {
271            let tokens = common::tensor_to_tokens(&input.input_ids)?;
272            let supplied_handle_id = input.kv_cache.as_ref().and_then(|h| {
273                h.as_any()
274                    .downcast_ref::<GenericKvCacheHandle>()
275                    .map(|g| g.request_cache_id().to_string())
276            });
277            let cache_id = supplied_handle_id
278                .clone()
279                .unwrap_or_else(|| self.gen_cache_id());
280            let prior_seq_len = input
281                .kv_cache
282                .as_ref()
283                .and_then(|h| h.as_any().downcast_ref::<GenericKvCacheHandle>())
284                .map(|g| {
285                    use ferrum_interfaces::KvCacheHandle;
286                    g.block_table().sequence_length
287                })
288                .unwrap_or(0);
289            cache_ids.push(cache_id);
290            prior_seq_lens.push(prior_seq_len);
291            tokens_per_input.push(tokens);
292            lora_per_input.push(active_lora_from_metadata(&input.metadata)?);
293        }
294
295        // Build unified items and ONE `unified_forward` call. If the model
296        // doesn't support it, fall back to the trait-default serial path.
297        let unified_items: Vec<(String, Vec<u32>, usize, bool)> = cache_ids
298            .iter()
299            .zip(tokens_per_input.iter())
300            .zip(prior_seq_lens.iter())
301            .map(|((cid, toks), &prior)| (cid.clone(), toks.clone(), prior, true))
302            .collect();
303
304        let nb_prof = llm_executor_runtime_env().batch_prefill_prof;
305        let bp_t0 = if nb_prof {
306            Some(std::time::Instant::now())
307        } else {
308            None
309        };
310        let mut took_fallback = false;
311        let per_item_logits: Vec<Vec<f32>> = {
312            let mut model = self.model.lock();
313            for (cache_id, adapter) in cache_ids.iter().zip(lora_per_input.iter()) {
314                model.set_lora_adapter_for_cache(cache_id, adapter.clone())?;
315            }
316            match model.unified_forward(&unified_items) {
317                Ok(per_item) => per_item
318                    .into_iter()
319                    .map(|opt| opt.expect("is_final_chunk=true must yield logits"))
320                    .collect(),
321                Err(FerrumError::Unsupported { .. }) => {
322                    took_fallback = true;
323                    let mut out = Vec::with_capacity(inputs.len());
324                    for (cid, toks) in cache_ids.iter().zip(tokens_per_input.iter()) {
325                        out.push(model.prefill(cid, toks));
326                    }
327                    out
328                }
329                Err(e) => return Err(e),
330            }
331        };
332        if let Some(t0) = bp_t0 {
333            let total_q: usize = unified_items.iter().map(|it| it.1.len()).sum();
334            eprintln!(
335                "[batch-prefill] n_items={} total_q={} fallback={} elapsed={}us",
336                inputs.len(),
337                total_q,
338                took_fallback,
339                t0.elapsed().as_micros()
340            );
341        }
342
343        let cfg = self.model.lock().config().clone();
344        let mut outputs = Vec::with_capacity(inputs.len());
345        for (i, logits) in per_item_logits.into_iter().enumerate() {
346            let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
347                .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
348                .unsqueeze(0)
349                .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?
350                .unsqueeze(0)
351                .map_err(|e| FerrumError::model(format!("unsqueeze2: {e}")))?;
352            let logits_ref = common::wrap_tensor(logits_tensor);
353            let seq_len = inputs[i]
354                .kv_cache
355                .as_ref()
356                .and_then(|h| h.as_any().downcast_ref::<GenericKvCacheHandle>())
357                .map(|g| {
358                    use ferrum_interfaces::KvCacheHandle;
359                    g.block_table().sequence_length + tokens_per_input[i].len()
360                })
361                .unwrap_or(tokens_per_input[i].len());
362            let kv_handle = Arc::new(GenericKvCacheHandle::new(
363                cfg.num_layers,
364                cfg.num_kv_heads,
365                cfg.head_dim,
366                candle_core::Device::Cpu,
367                seq_len,
368                cache_ids[i].clone(),
369            ));
370            outputs.push(PrefillOutput::new(logits_ref, kv_handle));
371        }
372        Ok(outputs)
373    }
374
375    async fn truncate_kv(
376        &self,
377        kv_cache: &Arc<dyn ferrum_interfaces::KvCacheHandle>,
378        new_len: usize,
379    ) -> Result<()> {
380        if let Some(g) = kv_cache.as_any().downcast_ref::<GenericKvCacheHandle>() {
381            let cache_id = g.request_cache_id();
382            self.model.lock().truncate_kv(cache_id, new_len);
383        }
384        Ok(())
385    }
386
387    async fn forward_verify(
388        &self,
389        inputs: &[ferrum_interfaces::model_executor::DecodeInput],
390    ) -> Result<Vec<ferrum_interfaces::model_executor::DecodeOutput>> {
391        if inputs.is_empty() {
392            return Ok(Vec::new());
393        }
394
395        // All inputs must share the same KV handle (speculative decoding
396        // contract). Extract cache_id + starting seq_len once.
397        let first_handle = inputs[0].kv_cache.clone();
398        let cache_id = first_handle
399            .as_any()
400            .downcast_ref::<GenericKvCacheHandle>()
401            .ok_or_else(|| {
402                FerrumError::model("forward_verify requires GenericKvCacheHandle input")
403            })?
404            .request_cache_id()
405            .to_string();
406        let start_seq = {
407            use ferrum_interfaces::KvCacheHandle;
408            first_handle.block_table().sequence_length
409        };
410
411        // Collect the N+1 token ids.
412        let mut token_ids: Vec<u32> = Vec::with_capacity(inputs.len());
413        for input in inputs {
414            let toks = common::tensor_to_tokens(&input.input_ids)?;
415            if toks.is_empty() {
416                return Err(FerrumError::model("forward_verify input token empty"));
417            }
418            token_ids.push(toks[0]);
419        }
420
421        // One model forward for all N+1 positions → flat seq_len*vocab.
422        let flat = {
423            let mut model = self.model.lock();
424            model.set_lora_adapter_for_cache(
425                &cache_id,
426                active_lora_from_metadata(&inputs[0].metadata)?,
427            )?;
428            model.forward_verify(&cache_id, &token_ids)
429        };
430
431        let cfg = self.model.lock().config().clone();
432        let vocab = cfg.vocab_size;
433
434        // Record the actual backend device so downstream code that reads
435        // `KvCacheHandle::device()` sees Metal/CUDA/CPU matching the
436        // model's real location. The logits `Tensor` still wraps CPU data
437        // because `B::to_vec` already moved it off-device.
438        let candle_device = ferrum_device_to_candle(&self.info.device);
439
440        // Split the flat logits into per-position tensors, each wrapped
441        // with a handle whose seq_len reflects the positions written so
442        // far. Matches what the spec runner expects from sequential
443        // decode() calls.
444        let mut outputs = Vec::with_capacity(inputs.len());
445        for (i, _) in inputs.iter().enumerate() {
446            let row = &flat[i * vocab..(i + 1) * vocab];
447            let logits_tensor = candle_core::Tensor::new(row, &candle_core::Device::Cpu)
448                .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
449                .unsqueeze(0)
450                .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
451            let logits_ref = common::wrap_tensor(logits_tensor);
452            let handle = Arc::new(GenericKvCacheHandle::new(
453                cfg.num_layers,
454                cfg.num_kv_heads,
455                cfg.head_dim,
456                candle_device.clone(),
457                start_seq + i + 1,
458                cache_id.clone(),
459            ));
460            outputs.push(ferrum_interfaces::model_executor::DecodeOutput::new(
461                logits_ref, handle,
462            ));
463        }
464        Ok(outputs)
465    }
466
467    async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput> {
468        let input_handle = input
469            .kv_cache
470            .as_any()
471            .downcast_ref::<GenericKvCacheHandle>()
472            .ok_or_else(|| FerrumError::model("Invalid KV cache handle type"))?;
473
474        let cache_id = input_handle.request_cache_id().to_string();
475        let seq_len = {
476            use ferrum_interfaces::KvCacheHandle;
477            input_handle.block_table().sequence_length
478        };
479
480        let tokens = common::tensor_to_tokens(&input.input_ids)?;
481        if tokens.is_empty() {
482            return Err(FerrumError::model("Decode input is empty"));
483        }
484        let token = tokens[0];
485
486        debug!("LlmExecutor decode: token={token}, pos={seq_len}");
487
488        // Try unified_forward first so paged-KV configs route the single
489        // decode through the same varlen kernel used by batched mixed
490        // batches. Falls back to legacy paged_decode_attention for contig
491        // configs that haven't wired unified_forward.
492        let logits = {
493            let mut model = self.model.lock();
494            model.set_lora_adapter_for_cache(
495                &cache_id,
496                active_lora_from_metadata(&input.metadata)?,
497            )?;
498            let unified_item = vec![(cache_id.clone(), vec![token], seq_len, true)];
499            match model.unified_forward(&unified_item) {
500                Ok(mut per_item) => per_item
501                    .pop()
502                    .flatten()
503                    .ok_or_else(|| FerrumError::model("unified_forward returned no logits"))?,
504                Err(FerrumError::Unsupported { .. }) => {
505                    model.decode(&cache_id, token, seq_len as u32)
506                }
507                Err(e) => return Err(e),
508            }
509        };
510
511        let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
512            .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
513            .unsqueeze(0)
514            .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
515        let logits_ref = common::wrap_tensor(logits_tensor);
516
517        let kv_handle = Arc::new(input_handle.with_sequence_length(seq_len + 1));
518        Ok(DecodeOutput::new(logits_ref, kv_handle))
519    }
520
521    /// Override default fallback to acquire the model lock ONCE for the whole
522    /// batch, avoiding N round-trips through parking_lot. Does not yet do
523    /// true attention batching (each cache has its own kv_len), but removes
524    /// mutex churn that was serialising concurrent requests at async level.
525    async fn batch_decode(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>> {
526        if inputs.is_empty() {
527            return Ok(Vec::new());
528        }
529        let prof = llm_executor_runtime_env().batch_decode_prof;
530        let t0 = if prof {
531            Some(std::time::Instant::now())
532        } else {
533            None
534        };
535        // Pre-extract all per-input metadata OUTSIDE the lock — this is pure
536        // borrow/downcast work that doesn't touch the model.
537        struct Prep {
538            cache_id: String,
539            token: u32,
540            seq_len: u32,
541            lora: Option<ActiveLoraAdapter>,
542            requires_full_logits: bool,
543            handle: Arc<GenericKvCacheHandle>,
544        }
545        let mut prepped: Vec<Prep> = Vec::with_capacity(inputs.len());
546        for input in inputs {
547            let input_handle = input
548                .kv_cache
549                .as_any()
550                .downcast_ref::<GenericKvCacheHandle>()
551                .ok_or_else(|| FerrumError::model("Invalid KV cache handle type"))?;
552            use ferrum_interfaces::KvCacheHandle;
553            let seq_len = input_handle.block_table().sequence_length as u32;
554            let tokens = common::tensor_to_tokens(&input.input_ids)?;
555            if tokens.is_empty() {
556                return Err(FerrumError::model("Decode input is empty"));
557            }
558            prepped.push(Prep {
559                cache_id: input_handle.request_cache_id().to_string(),
560                token: tokens[0],
561                seq_len,
562                lora: active_lora_from_metadata(&input.metadata)?,
563                requires_full_logits: metadata_requires_full_logits(&input.metadata),
564                handle: Arc::new(input_handle.with_sequence_length((seq_len + 1) as usize)),
565            });
566        }
567        let t_prep = if prof {
568            Some(std::time::Instant::now())
569        } else {
570            None
571        };
572
573        // One lock for the whole batch. Try unified_forward first: paged
574        // configs route through the varlen kernel (single mixed dispatch
575        // for the whole batch); contig configs fall back to model's
576        // legacy decode_batch (separate paged_decode_attention call per
577        // item, batched matmul for QKV/MLP).
578        let (all_logits, t_lock_acq, t_model_call): (Vec<Vec<f32>>, _, _) = {
579            let lock_t0 = if prof {
580                Some(std::time::Instant::now())
581            } else {
582                None
583            };
584            let mut model = self.model.lock();
585            let lock_acq = lock_t0.map(|t| t.elapsed());
586            let model_t0 = if prof {
587                Some(std::time::Instant::now())
588            } else {
589                None
590            };
591            for p in &prepped {
592                model.set_lora_adapter_for_cache(&p.cache_id, p.lora.clone())?;
593            }
594            let unified_items: Vec<(String, Vec<u32>, usize, bool)> = prepped
595                .iter()
596                .map(|p| (p.cache_id.clone(), vec![p.token], p.seq_len as usize, true))
597                .collect();
598            let logits = match model.unified_forward(&unified_items) {
599                Ok(per_item) => {
600                    if per_item.len() != prepped.len() {
601                        return Err(FerrumError::model(format!(
602                            "unified_forward returned {} entries for {} items",
603                            per_item.len(),
604                            prepped.len(),
605                        )));
606                    }
607                    let mut out = Vec::with_capacity(prepped.len());
608                    for (i, opt) in per_item.into_iter().enumerate() {
609                        out.push(opt.ok_or_else(|| {
610                            FerrumError::model(format!(
611                                "unified_forward returned None for decode item {i}"
612                            ))
613                        })?);
614                    }
615                    out
616                }
617                Err(FerrumError::Unsupported { .. }) => {
618                    let tuples: Vec<(String, u32, u32)> = prepped
619                        .iter()
620                        .map(|p| (p.cache_id.clone(), p.token, p.seq_len))
621                        .collect();
622                    let force_full_logits = prepped.iter().any(|p| p.requires_full_logits);
623                    model.decode_batch_with_full_logits(&tuples, force_full_logits)
624                }
625                Err(e) => return Err(e),
626            };
627            let model_call = model_t0.map(|t| t.elapsed());
628            (logits, lock_acq, model_call)
629        };
630        let t_model_done = if prof {
631            Some(std::time::Instant::now())
632        } else {
633            None
634        };
635
636        let m_count = prepped.len();
637        let mut outputs = Vec::with_capacity(m_count);
638        for (p, logits) in prepped.into_iter().zip(all_logits.into_iter()) {
639            debug!(
640                "LlmExecutor batch_decode: token={}, pos={}",
641                p.token, p.seq_len
642            );
643            let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
644                .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
645                .unsqueeze(0)
646                .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
647            let logits_ref = common::wrap_tensor(logits_tensor);
648            outputs.push(DecodeOutput::new(logits_ref, p.handle));
649        }
650        if let (Some(t0), Some(tp), Some(tm)) = (t0, t_prep, t_model_done) {
651            static EX_PROF_CALLS: std::sync::atomic::AtomicU64 =
652                std::sync::atomic::AtomicU64::new(0);
653            let n = EX_PROF_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
654            if n.is_multiple_of(8) {
655                let total = t0.elapsed().as_micros();
656                let prep = tp.duration_since(t0).as_micros();
657                let lock_acq = t_lock_acq.map(|d| d.as_micros()).unwrap_or(0);
658                let model_call = t_model_call.map(|d| d.as_micros()).unwrap_or(0);
659                let model_block = tm.duration_since(tp).as_micros();
660                let wrap = tm.elapsed().as_micros();
661                eprintln!(
662                    "[exec-batch-decode-prof] call#{} m={} total={}us prep={}us model_block={}us(lock_acq={}us model_call={}us) wrap={}us",
663                    n, m_count, total, prep, model_block, lock_acq, model_call, wrap,
664                );
665            }
666        }
667        Ok(outputs)
668    }
669
670    /// Unified mixed-batch dispatch (chunked-prefill API).
671    ///
672    /// This impl is a behavior-preserving FALLBACK over the existing
673    /// trait methods on `DecoderOnlyLLM`: prefill items go through
674    /// `model.prefill(seq_id, &q_tokens)` (one at a time, mirroring the
675    /// engine's current sequential prefill loop), decode items
676    /// (`q_len == 1 && is_final_chunk`) are grouped into a single
677    /// `model.decode_batch(...)` call. Net behavior is identical to the
678    /// engine's pre-Phase-13 path; this just changes WHO orchestrates
679    /// the prefill/decode split (caller → unified_decode) so the engine
680    /// can converge on a single call.
681    ///
682    /// The real performance unlock comes in Step 5 when models override
683    /// this with a true unified-forward (one [M_total, hidden] forward
684    /// + varlen attention) — at that point the kernel-level mix replaces
685    /// the host-side serial dispatch here.
686    async fn unified_decode(&self, batch: &UnifiedBatch) -> Result<Vec<Option<Vec<f32>>>> {
687        let mut results: Vec<Option<Vec<f32>>> = vec![None; batch.items.len()];
688        if batch.items.is_empty() {
689            return Ok(results);
690        }
691
692        // ── Real unified path (Step 5b+): if the model implements
693        // `DecoderOnlyLLM::unified_forward`, route the entire batch
694        // through one model forward (mixed prefill chunks + decode
695        // tokens in a single [M_total, hidden] pass). The model returns
696        // `Err(unsupported)` if it hasn't been wired yet — fall through
697        // to the behaviour-preserving fallback below.
698        let unified_items: Vec<(String, Vec<u32>, usize, bool)> = batch
699            .items
700            .iter()
701            .map(|it| {
702                (
703                    it.seq_id.clone(),
704                    it.q_tokens.clone(),
705                    it.pos_offset,
706                    it.is_final_chunk,
707                )
708            })
709            .collect();
710        let model_result = {
711            let mut model = self.model.lock();
712            for item in &batch.items {
713                model.set_lora_adapter_for_cache(
714                    &item.seq_id,
715                    active_lora_from_metadata(&item.metadata)?,
716                )?;
717            }
718            model.unified_forward(&unified_items)
719        };
720        match model_result {
721            Ok(per_item) => {
722                if per_item.len() != batch.items.len() {
723                    return Err(FerrumError::model(format!(
724                        "unified_forward returned {} entries for {} items",
725                        per_item.len(),
726                        batch.items.len(),
727                    )));
728                }
729                return Ok(per_item);
730            }
731            Err(FerrumError::Unsupported { .. }) => {
732                // Fall through to the dispatch fallback below.
733            }
734            Err(e) => return Err(e),
735        }
736
737        // Partition: pure decode items vs prefill chunks.
738        // A "decode" item has q_len == 1 AND is_final_chunk == true.
739        // Anything else (chunked prefill mid-stream OR a single-token
740        // prefill that returns logits) goes through the per-item prefill
741        // path so the model receives the right pos_offset behaviour.
742        let mut prefill_indices: Vec<usize> = Vec::new();
743        let mut decode_indices: Vec<usize> = Vec::new();
744        for (i, item) in batch.items.iter().enumerate() {
745            if item.q_tokens.len() == 1 && item.is_final_chunk {
746                decode_indices.push(i);
747            } else {
748                prefill_indices.push(i);
749            }
750        }
751
752        // Prefill items — sequential, mirrors current engine behaviour.
753        // Held under a single model lock to amortise lock acquire across
754        // all prefills in this batch (we may revisit per-call locking
755        // when chunked-prefill becomes the perf-critical path).
756        if !prefill_indices.is_empty() {
757            let mut model = self.model.lock();
758            for &i in &prefill_indices {
759                let item = &batch.items[i];
760                model.set_lora_adapter_for_cache(
761                    &item.seq_id,
762                    active_lora_from_metadata(&item.metadata)?,
763                )?;
764                let logits = model.prefill(&item.seq_id, &item.q_tokens);
765                if item.is_final_chunk {
766                    results[i] = Some(logits);
767                }
768            }
769        }
770
771        // Decode items — single batched dispatch.
772        if !decode_indices.is_empty() {
773            let tuples: Vec<(String, u32, u32)> = decode_indices
774                .iter()
775                .map(|&i| {
776                    let it = &batch.items[i];
777                    (it.seq_id.clone(), it.q_tokens[0], it.pos_offset as u32)
778                })
779                .collect();
780            let logits_vec = {
781                let mut model = self.model.lock();
782                for &i in &decode_indices {
783                    let item = &batch.items[i];
784                    model.set_lora_adapter_for_cache(
785                        &item.seq_id,
786                        active_lora_from_metadata(&item.metadata)?,
787                    )?;
788                }
789                let force_full_logits = decode_indices
790                    .iter()
791                    .any(|&i| metadata_requires_full_logits(&batch.items[i].metadata));
792                model.decode_batch_with_full_logits(&tuples, force_full_logits)
793            };
794            for (j, &i) in decode_indices.iter().enumerate() {
795                results[i] = Some(logits_vec[j].clone());
796            }
797        }
798
799        Ok(results)
800    }
801
802    fn release_cache(&self, cache_id: &str) {
803        self.model.lock().release(cache_id);
804    }
805
806    fn capabilities(&self) -> ExecutorCapabilities {
807        let cfg = self.model.lock().config().clone();
808        ExecutorCapabilities {
809            max_batch_size: 256,
810            max_sequence_length: cfg.max_seq_len,
811            attention_mechanisms: vec![AttentionType::GroupedQuery],
812            supports_dynamic_batching: true,
813            supports_continuous_batching: true,
814            supports_speculative_decoding: false,
815            supports_tensor_parallelism: false,
816            supports_pipeline_parallelism: false,
817            supported_dtypes: vec![DataType::FP32],
818            supported_devices: vec![self.info.device.clone()],
819            memory_requirements: MemoryRequirements {
820                parameter_memory: (self.info.num_parameters * 4) as u64,
821                activation_memory_per_token: cfg.hidden_size * 4,
822                kv_cache_memory_per_token: cfg.hidden_size * 2,
823                overhead_memory: 256 * 1024 * 1024,
824            },
825        }
826    }
827
828    fn status(&self) -> ExecutorStatus {
829        common::default_executor_status()
830    }
831
832    fn cache_metrics_snapshot(&self) -> Option<serde_json::Value> {
833        self.model.lock().cache_metrics_snapshot()
834    }
835
836    fn lora_metrics_snapshot(&self) -> Option<serde_json::Value> {
837        self.model.lock().lora_metrics_snapshot()
838    }
839}
840
#[cfg(test)]
mod tests {
    use super::LlmExecutorRuntimeEnv;

    #[test]
    fn llm_executor_runtime_env_parses_profile_flags_by_presence() {
        // The flag values are deliberately "" and "0": only the presence of
        // the key matters, so both must still switch profiling on.
        let vars = vec![
            ("FERRUM_BATCH_PREFILL_PROF", ""),
            ("FERRUM_BATCH_DECODE_PROF", "0"),
        ];

        let parsed = LlmExecutorRuntimeEnv::from_env_vars(vars);

        assert_eq!(
            parsed,
            LlmExecutorRuntimeEnv {
                batch_prefill_prof: true,
                batch_decode_prof: true,
            }
        );
    }

    #[test]
    fn llm_executor_runtime_env_defaults_profile_flags_off() {
        // Unrelated variables must leave every profiling flag disabled.
        let parsed = LlmExecutorRuntimeEnv::from_env_vars(vec![("UNRELATED", "1")]);

        assert_eq!(
            parsed,
            LlmExecutorRuntimeEnv {
                batch_prefill_prof: false,
                batch_decode_prof: false,
            }
        );
    }
}