Skip to main content

ferrum_models/executor/
llm_executor.rs

1//! `LlmExecutor<M>` — adapts a `DecoderOnlyLLM` to the `ModelExecutor` trait
2//! the engine scheduler calls.
3//!
4//! This is the Model-as-Code equivalent of `GenericModelExecutor`: where
5//! `GenericModelExecutor` wraps a `Box<dyn RunnerInterface>` (legacy
6//! `ModelRunner<B>`), `LlmExecutor` wraps a `Box<dyn DecoderOnlyLLM>`
7//! (new-style per-model code such as `Qwen3Model<B>`).
8//!
9//! Tokens/logits are currently bridged through candle Tensor for
10//! `TensorRef` — Phase C will likely replace that with `SmallTensor` to
11//! drop candle from the hot path.
12
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::Arc;
15
16use parking_lot::Mutex;
17use tracing::debug;
18
19use ferrum_interfaces::{
20    model_executor::{
21        AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorStatus,
22        MemoryRequirements, PrefillInput, PrefillOutput,
23    },
24    ModelExecutor,
25};
26use ferrum_types::{DataType, FerrumError, ModelInfo, Result};
27
28use crate::common::DecoderOnlyLLM;
29
30use super::common::{self, GenericKvCacheHandle};
31
32/// Map a `ferrum_types::Device` to the matching `candle_core::Device`.
33/// Used when materialising KV cache handles so downstream readers see
34/// the real backend the model runs on (Metal / CUDA / CPU) rather than
35/// a hard-coded CPU placeholder.
36fn ferrum_device_to_candle(d: &ferrum_types::Device) -> candle_core::Device {
37    match d {
38        ferrum_types::Device::CPU => candle_core::Device::Cpu,
39        #[cfg(feature = "cuda")]
40        ferrum_types::Device::CUDA(i) => {
41            candle_core::Device::new_cuda(*i as usize).unwrap_or(candle_core::Device::Cpu)
42        }
43        #[cfg(not(feature = "cuda"))]
44        ferrum_types::Device::CUDA(_) => candle_core::Device::Cpu,
45        #[cfg(all(any(target_os = "macos", target_os = "ios"), feature = "metal"))]
46        ferrum_types::Device::Metal => {
47            candle_core::Device::new_metal(0).unwrap_or(candle_core::Device::Cpu)
48        }
49        _ => candle_core::Device::Cpu,
50    }
51}
52
53pub struct LlmExecutor {
54    model: Mutex<Box<dyn DecoderOnlyLLM>>,
55    info: ModelInfo,
56    next_cache_id: AtomicU64,
57}
58
59impl LlmExecutor {
60    pub fn new(model: Box<dyn DecoderOnlyLLM>, info: ModelInfo) -> Self {
61        Self {
62            model: Mutex::new(model),
63            info,
64            next_cache_id: AtomicU64::new(0),
65        }
66    }
67
68    fn gen_cache_id(&self) -> String {
69        format!(
70            "llm-cache-{}",
71            self.next_cache_id.fetch_add(1, Ordering::Relaxed)
72        )
73    }
74
75    /// Roll the KV cache for `cache_id` back to `new_len` positions.
76    /// Used by speculative decoding on partial rejection. The caller must
77    /// supply a `GenericKvCacheHandle` whose seq_len is also updated.
78    pub fn truncate_kv_for_cache_id(&self, cache_id: &str, new_len: usize) {
79        let mut model = self.model.lock();
80        model.truncate_kv(cache_id, new_len);
81    }
82}
83
84#[async_trait::async_trait]
85impl ModelExecutor for LlmExecutor {
86    fn info(&self) -> &ModelInfo {
87        &self.info
88    }
89
90    async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
91        let tokens = common::tensor_to_tokens(&input.input_ids)?;
92
93        // Reuse an existing cache_id when the caller supplies a KV handle
94        // (chunked prefill) — fresh id only on the very first call for a
95        // request. Without this, every chunk would create a new KV cache
96        // at position 0 and subsequent chunks wouldn't see prior tokens.
97        let supplied_handle_id = input.kv_cache.as_ref().and_then(|h| {
98            h.as_any()
99                .downcast_ref::<GenericKvCacheHandle>()
100                .map(|g| g.request_cache_id().to_string())
101        });
102        let cache_id = supplied_handle_id
103            .clone()
104            .unwrap_or_else(|| self.gen_cache_id());
105
106        let logits = {
107            let mut model = self.model.lock();
108            model.prefill(&cache_id, &tokens)
109        };
110
111        // Wrap logits as TensorRef: [1, 1, vocab_size]
112        let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
113            .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
114            .unsqueeze(0)
115            .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?
116            .unsqueeze(0)
117            .map_err(|e| FerrumError::model(format!("unsqueeze2: {e}")))?;
118        let logits_ref = common::wrap_tensor(logits_tensor);
119
120        let cfg = self.model.lock().config().clone();
121        // Sequence-length tracking across chunks: if the caller supplied a
122        // GenericKvCacheHandle (chunked prefill continuation), add this
123        // chunk's tokens to the prior length. Otherwise this is a fresh
124        // prefill so seq_len == this call's token count. Without this the
125        // handle would claim only the last chunk's length, misleading
126        // decode() into rewriting the KV at an earlier position.
127        let seq_len = input
128            .kv_cache
129            .as_ref()
130            .and_then(|h| h.as_any().downcast_ref::<GenericKvCacheHandle>())
131            .map(|g| {
132                use ferrum_interfaces::KvCacheHandle;
133                g.block_table().sequence_length + tokens.len()
134            })
135            .unwrap_or(tokens.len());
136
137        let kv_handle = Arc::new(GenericKvCacheHandle::new(
138            cfg.num_layers,
139            cfg.num_kv_heads,
140            cfg.head_dim,
141            candle_core::Device::Cpu,
142            seq_len,
143            cache_id,
144        ));
145
146        Ok(PrefillOutput::new(logits_ref, kv_handle))
147    }
148
149    async fn truncate_kv(
150        &self,
151        kv_cache: &Arc<dyn ferrum_interfaces::KvCacheHandle>,
152        new_len: usize,
153    ) -> Result<()> {
154        if let Some(g) = kv_cache.as_any().downcast_ref::<GenericKvCacheHandle>() {
155            let cache_id = g.request_cache_id();
156            self.model.lock().truncate_kv(cache_id, new_len);
157        }
158        Ok(())
159    }
160
161    async fn forward_verify(
162        &self,
163        inputs: &[ferrum_interfaces::model_executor::DecodeInput],
164    ) -> Result<Vec<ferrum_interfaces::model_executor::DecodeOutput>> {
165        if inputs.is_empty() {
166            return Ok(Vec::new());
167        }
168
169        // All inputs must share the same KV handle (speculative decoding
170        // contract). Extract cache_id + starting seq_len once.
171        let first_handle = inputs[0].kv_cache.clone();
172        let cache_id = first_handle
173            .as_any()
174            .downcast_ref::<GenericKvCacheHandle>()
175            .ok_or_else(|| {
176                FerrumError::model("forward_verify requires GenericKvCacheHandle input")
177            })?
178            .request_cache_id()
179            .to_string();
180        let start_seq = {
181            use ferrum_interfaces::KvCacheHandle;
182            first_handle.block_table().sequence_length
183        };
184
185        // Collect the N+1 token ids.
186        let mut token_ids: Vec<u32> = Vec::with_capacity(inputs.len());
187        for input in inputs {
188            let toks = common::tensor_to_tokens(&input.input_ids)?;
189            if toks.is_empty() {
190                return Err(FerrumError::model("forward_verify input token empty"));
191            }
192            token_ids.push(toks[0]);
193        }
194
195        // One model forward for all N+1 positions → flat seq_len*vocab.
196        let flat = {
197            let mut model = self.model.lock();
198            model.forward_verify(&cache_id, &token_ids)
199        };
200
201        let cfg = self.model.lock().config().clone();
202        let vocab = cfg.vocab_size;
203
204        // Record the actual backend device so downstream code that reads
205        // `KvCacheHandle::device()` sees Metal/CUDA/CPU matching the
206        // model's real location. The logits `Tensor` still wraps CPU data
207        // because `B::to_vec` already moved it off-device.
208        let candle_device = ferrum_device_to_candle(&self.info.device);
209
210        // Split the flat logits into per-position tensors, each wrapped
211        // with a handle whose seq_len reflects the positions written so
212        // far. Matches what the spec runner expects from sequential
213        // decode() calls.
214        let mut outputs = Vec::with_capacity(inputs.len());
215        for (i, _) in inputs.iter().enumerate() {
216            let row = &flat[i * vocab..(i + 1) * vocab];
217            let logits_tensor = candle_core::Tensor::new(row, &candle_core::Device::Cpu)
218                .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
219                .unsqueeze(0)
220                .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
221            let logits_ref = common::wrap_tensor(logits_tensor);
222            let handle = Arc::new(GenericKvCacheHandle::new(
223                cfg.num_layers,
224                cfg.num_kv_heads,
225                cfg.head_dim,
226                candle_device.clone(),
227                start_seq + i + 1,
228                cache_id.clone(),
229            ));
230            outputs.push(ferrum_interfaces::model_executor::DecodeOutput::new(
231                logits_ref, handle,
232            ));
233        }
234        Ok(outputs)
235    }
236
237    async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput> {
238        let input_handle = input
239            .kv_cache
240            .as_any()
241            .downcast_ref::<GenericKvCacheHandle>()
242            .ok_or_else(|| FerrumError::model("Invalid KV cache handle type"))?;
243
244        let cache_id = input_handle.request_cache_id().to_string();
245        let seq_len = {
246            use ferrum_interfaces::KvCacheHandle;
247            input_handle.block_table().sequence_length
248        };
249
250        let tokens = common::tensor_to_tokens(&input.input_ids)?;
251        if tokens.is_empty() {
252            return Err(FerrumError::model("Decode input is empty"));
253        }
254        let token = tokens[0];
255
256        debug!("LlmExecutor decode: token={token}, pos={seq_len}");
257
258        let logits = {
259            let mut model = self.model.lock();
260            model.decode(&cache_id, token, seq_len as u32)
261        };
262
263        let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
264            .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
265            .unsqueeze(0)
266            .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
267        let logits_ref = common::wrap_tensor(logits_tensor);
268
269        let kv_handle = Arc::new(input_handle.with_sequence_length(seq_len + 1));
270        Ok(DecodeOutput::new(logits_ref, kv_handle))
271    }
272
273    /// Override default fallback to acquire the model lock ONCE for the whole
274    /// batch, avoiding N round-trips through parking_lot. Does not yet do
275    /// true attention batching (each cache has its own kv_len), but removes
276    /// mutex churn that was serialising concurrent requests at async level.
277    async fn batch_decode(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>> {
278        if inputs.is_empty() {
279            return Ok(Vec::new());
280        }
281        // Pre-extract all per-input metadata OUTSIDE the lock — this is pure
282        // borrow/downcast work that doesn't touch the model.
283        struct Prep {
284            cache_id: String,
285            token: u32,
286            seq_len: u32,
287            handle: Arc<GenericKvCacheHandle>,
288        }
289        let mut prepped: Vec<Prep> = Vec::with_capacity(inputs.len());
290        for input in inputs {
291            let input_handle = input
292                .kv_cache
293                .as_any()
294                .downcast_ref::<GenericKvCacheHandle>()
295                .ok_or_else(|| FerrumError::model("Invalid KV cache handle type"))?;
296            use ferrum_interfaces::KvCacheHandle;
297            let seq_len = input_handle.block_table().sequence_length as u32;
298            let tokens = common::tensor_to_tokens(&input.input_ids)?;
299            if tokens.is_empty() {
300                return Err(FerrumError::model("Decode input is empty"));
301            }
302            prepped.push(Prep {
303                cache_id: input_handle.request_cache_id().to_string(),
304                token: tokens[0],
305                seq_len,
306                handle: Arc::new(input_handle.with_sequence_length((seq_len + 1) as usize)),
307            });
308        }
309
310        // One lock for the whole batch, dispatch to model's decode_batch —
311        // which implementations may fuse into a single forward pass (GEMMs
312        // with m=batch, per-item attention) for true concurrency speedup.
313        // Trait default falls back to sequential decode per item.
314        let all_logits: Vec<Vec<f32>> = {
315            let mut model = self.model.lock();
316            let tuples: Vec<(String, u32, u32)> = prepped
317                .iter()
318                .map(|p| (p.cache_id.clone(), p.token, p.seq_len))
319                .collect();
320            model.decode_batch(&tuples)
321        };
322
323        let mut outputs = Vec::with_capacity(prepped.len());
324        for (p, logits) in prepped.into_iter().zip(all_logits.into_iter()) {
325            debug!(
326                "LlmExecutor batch_decode: token={}, pos={}",
327                p.token, p.seq_len
328            );
329            let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
330                .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
331                .unsqueeze(0)
332                .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
333            let logits_ref = common::wrap_tensor(logits_tensor);
334            outputs.push(DecodeOutput::new(logits_ref, p.handle));
335        }
336        Ok(outputs)
337    }
338
339    fn release_cache(&self, cache_id: &str) {
340        self.model.lock().release(cache_id);
341    }
342
343    fn capabilities(&self) -> ExecutorCapabilities {
344        let cfg = self.model.lock().config().clone();
345        ExecutorCapabilities {
346            max_batch_size: 256,
347            max_sequence_length: cfg.max_seq_len,
348            attention_mechanisms: vec![AttentionType::GroupedQuery],
349            supports_dynamic_batching: true,
350            supports_continuous_batching: true,
351            supports_speculative_decoding: false,
352            supports_tensor_parallelism: false,
353            supports_pipeline_parallelism: false,
354            supported_dtypes: vec![DataType::FP32],
355            supported_devices: vec![self.info.device.clone()],
356            memory_requirements: MemoryRequirements {
357                parameter_memory: (self.info.num_parameters * 4) as u64,
358                activation_memory_per_token: cfg.hidden_size * 4,
359                kv_cache_memory_per_token: cfg.hidden_size * 2,
360                overhead_memory: 256 * 1024 * 1024,
361            },
362        }
363    }
364
365    fn status(&self) -> ExecutorStatus {
366        common::default_executor_status()
367    }
368}