Skip to main content

ferrum_models/executor/
llm_executor.rs

1//! `LlmExecutor<M>` — adapts a `DecoderOnlyLLM` to the `ModelExecutor` trait
2//! the engine scheduler calls.
3//!
4//! This is the Model-as-Code equivalent of `GenericModelExecutor`: where
5//! `GenericModelExecutor` wraps a `Box<dyn RunnerInterface>` (legacy
6//! `ModelRunner<B>`), `LlmExecutor` wraps a `Box<dyn DecoderOnlyLLM>`
7//! (new-style per-model code such as `Qwen3Model<B>`).
8//!
9//! Tokens/logits are currently bridged through candle Tensor for
10//! `TensorRef` — Phase C will likely replace that with `SmallTensor` to
11//! drop candle from the hot path.
12
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::Arc;
15
16use parking_lot::Mutex;
17use tracing::debug;
18
19use ferrum_interfaces::{
20    model_executor::{
21        AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorStatus,
22        MemoryRequirements, PrefillInput, PrefillOutput,
23    },
24    ModelExecutor,
25};
26use ferrum_types::{DataType, FerrumError, ModelInfo, Result};
27
28use crate::common::DecoderOnlyLLM;
29
30use super::common::{self, GenericKvCacheHandle};
31
32pub struct LlmExecutor {
33    model: Mutex<Box<dyn DecoderOnlyLLM>>,
34    info: ModelInfo,
35    next_cache_id: AtomicU64,
36}
37
38impl LlmExecutor {
39    pub fn new(model: Box<dyn DecoderOnlyLLM>, info: ModelInfo) -> Self {
40        Self {
41            model: Mutex::new(model),
42            info,
43            next_cache_id: AtomicU64::new(0),
44        }
45    }
46
47    fn gen_cache_id(&self) -> String {
48        format!(
49            "llm-cache-{}",
50            self.next_cache_id.fetch_add(1, Ordering::Relaxed)
51        )
52    }
53}
54
55#[async_trait::async_trait]
56impl ModelExecutor for LlmExecutor {
57    fn info(&self) -> &ModelInfo {
58        &self.info
59    }
60
61    async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
62        let tokens = common::tensor_to_tokens(&input.input_ids)?;
63        debug!("LlmExecutor prefill: {} tokens", tokens.len());
64
65        let cache_id = self.gen_cache_id();
66
67        let logits = {
68            let mut model = self.model.lock();
69            model.prefill(&cache_id, &tokens)
70        };
71
72        // Wrap logits as TensorRef: [1, 1, vocab_size]
73        let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
74            .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
75            .unsqueeze(0)
76            .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?
77            .unsqueeze(0)
78            .map_err(|e| FerrumError::model(format!("unsqueeze2: {e}")))?;
79        let logits_ref = common::wrap_tensor(logits_tensor);
80
81        let cfg = self.model.lock().config().clone();
82        // num_kv_heads for KV cache sizing; GenericKvCacheHandle's third arg
83        // is head count which here is the KV-head count.
84        let kv_handle = Arc::new(GenericKvCacheHandle::new(
85            cfg.num_layers,
86            cfg.num_kv_heads,
87            cfg.head_dim,
88            candle_core::Device::Cpu,
89            tokens.len(),
90            cache_id,
91        ));
92
93        Ok(PrefillOutput::new(logits_ref, kv_handle))
94    }
95
96    async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput> {
97        let input_handle = input
98            .kv_cache
99            .as_any()
100            .downcast_ref::<GenericKvCacheHandle>()
101            .ok_or_else(|| FerrumError::model("Invalid KV cache handle type"))?;
102
103        let cache_id = input_handle.request_cache_id().to_string();
104        let seq_len = {
105            use ferrum_interfaces::KvCacheHandle;
106            input_handle.block_table().sequence_length
107        };
108
109        let tokens = common::tensor_to_tokens(&input.input_ids)?;
110        if tokens.is_empty() {
111            return Err(FerrumError::model("Decode input is empty"));
112        }
113        let token = tokens[0];
114
115        debug!("LlmExecutor decode: token={token}, pos={seq_len}");
116
117        let logits = {
118            let mut model = self.model.lock();
119            model.decode(&cache_id, token, seq_len as u32)
120        };
121
122        let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
123            .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
124            .unsqueeze(0)
125            .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
126        let logits_ref = common::wrap_tensor(logits_tensor);
127
128        let kv_handle = Arc::new(input_handle.with_sequence_length(seq_len + 1));
129        Ok(DecodeOutput::new(logits_ref, kv_handle))
130    }
131
132    /// Override default fallback to acquire the model lock ONCE for the whole
133    /// batch, avoiding N round-trips through parking_lot. Does not yet do
134    /// true attention batching (each cache has its own kv_len), but removes
135    /// mutex churn that was serialising concurrent requests at async level.
136    async fn batch_decode(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>> {
137        if inputs.is_empty() {
138            return Ok(Vec::new());
139        }
140        // Pre-extract all per-input metadata OUTSIDE the lock — this is pure
141        // borrow/downcast work that doesn't touch the model.
142        struct Prep {
143            cache_id: String,
144            token: u32,
145            seq_len: u32,
146            handle: Arc<GenericKvCacheHandle>,
147        }
148        let mut prepped: Vec<Prep> = Vec::with_capacity(inputs.len());
149        for input in inputs {
150            let input_handle = input
151                .kv_cache
152                .as_any()
153                .downcast_ref::<GenericKvCacheHandle>()
154                .ok_or_else(|| FerrumError::model("Invalid KV cache handle type"))?;
155            use ferrum_interfaces::KvCacheHandle;
156            let seq_len = input_handle.block_table().sequence_length as u32;
157            let tokens = common::tensor_to_tokens(&input.input_ids)?;
158            if tokens.is_empty() {
159                return Err(FerrumError::model("Decode input is empty"));
160            }
161            prepped.push(Prep {
162                cache_id: input_handle.request_cache_id().to_string(),
163                token: tokens[0],
164                seq_len,
165                handle: Arc::new(input_handle.with_sequence_length((seq_len + 1) as usize)),
166            });
167        }
168
169        // One lock for the whole batch, dispatch to model's decode_batch —
170        // which implementations may fuse into a single forward pass (GEMMs
171        // with m=batch, per-item attention) for true concurrency speedup.
172        // Trait default falls back to sequential decode per item.
173        let all_logits: Vec<Vec<f32>> = {
174            let mut model = self.model.lock();
175            let tuples: Vec<(String, u32, u32)> = prepped
176                .iter()
177                .map(|p| (p.cache_id.clone(), p.token, p.seq_len))
178                .collect();
179            model.decode_batch(&tuples)
180        };
181
182        let mut outputs = Vec::with_capacity(prepped.len());
183        for (p, logits) in prepped.into_iter().zip(all_logits.into_iter()) {
184            debug!(
185                "LlmExecutor batch_decode: token={}, pos={}",
186                p.token, p.seq_len
187            );
188            let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
189                .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
190                .unsqueeze(0)
191                .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
192            let logits_ref = common::wrap_tensor(logits_tensor);
193            outputs.push(DecodeOutput::new(logits_ref, p.handle));
194        }
195        Ok(outputs)
196    }
197
198    fn release_cache(&self, cache_id: &str) {
199        self.model.lock().release(cache_id);
200    }
201
202    fn capabilities(&self) -> ExecutorCapabilities {
203        let cfg = self.model.lock().config().clone();
204        ExecutorCapabilities {
205            max_batch_size: 256,
206            max_sequence_length: cfg.max_seq_len,
207            attention_mechanisms: vec![AttentionType::GroupedQuery],
208            supports_dynamic_batching: true,
209            supports_continuous_batching: true,
210            supports_speculative_decoding: false,
211            supports_tensor_parallelism: false,
212            supports_pipeline_parallelism: false,
213            supported_dtypes: vec![DataType::FP32],
214            supported_devices: vec![self.info.device.clone()],
215            memory_requirements: MemoryRequirements {
216                parameter_memory: (self.info.num_parameters * 4) as u64,
217                activation_memory_per_token: cfg.hidden_size * 4,
218                kv_cache_memory_per_token: cfg.hidden_size * 2,
219                overhead_memory: 256 * 1024 * 1024,
220            },
221        }
222    }
223
224    fn status(&self) -> ExecutorStatus {
225        common::default_executor_status()
226    }
227}