1use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::Arc;
15
16use parking_lot::Mutex;
17use tracing::debug;
18
19use ferrum_interfaces::{
20 model_executor::{
21 AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorStatus,
22 MemoryRequirements, PrefillInput, PrefillOutput,
23 },
24 ModelExecutor,
25};
26use ferrum_types::{DataType, FerrumError, ModelInfo, Result};
27
28use crate::common::DecoderOnlyLLM;
29
30use super::common::{self, GenericKvCacheHandle};
31
32fn ferrum_device_to_candle(d: &ferrum_types::Device) -> candle_core::Device {
37 match d {
38 ferrum_types::Device::CPU => candle_core::Device::Cpu,
39 #[cfg(feature = "cuda")]
40 ferrum_types::Device::CUDA(i) => {
41 candle_core::Device::new_cuda(*i as usize).unwrap_or(candle_core::Device::Cpu)
42 }
43 #[cfg(not(feature = "cuda"))]
44 ferrum_types::Device::CUDA(_) => candle_core::Device::Cpu,
45 #[cfg(all(any(target_os = "macos", target_os = "ios"), feature = "metal"))]
46 ferrum_types::Device::Metal => {
47 candle_core::Device::new_metal(0).unwrap_or(candle_core::Device::Cpu)
48 }
49 _ => candle_core::Device::Cpu,
50 }
51}
52
/// [`ModelExecutor`] adapter around a [`DecoderOnlyLLM`].
///
/// The model is addressed per request by a cache-id string (passed to its `prefill`,
/// `decode`, `truncate_kv` and `release` calls). The executor hands those ids out
/// inside [`GenericKvCacheHandle`]s together with the current sequence length.
pub struct LlmExecutor {
    /// The wrapped model. Every call into it goes through this lock, so model calls
    /// from concurrent requests are serialised.
    model: Mutex<Box<dyn DecoderOnlyLLM>>,
    /// Model metadata returned by `ModelExecutor::info`.
    info: ModelInfo,
    /// Counter used to mint fresh `llm-cache-<n>` ids for requests without a handle.
    next_cache_id: AtomicU64,
}
58
59impl LlmExecutor {
60 pub fn new(model: Box<dyn DecoderOnlyLLM>, info: ModelInfo) -> Self {
61 Self {
62 model: Mutex::new(model),
63 info,
64 next_cache_id: AtomicU64::new(0),
65 }
66 }
67
68 fn gen_cache_id(&self) -> String {
69 format!(
70 "llm-cache-{}",
71 self.next_cache_id.fetch_add(1, Ordering::Relaxed)
72 )
73 }
74
75 pub fn truncate_kv_for_cache_id(&self, cache_id: &str, new_len: usize) {
79 let mut model = self.model.lock();
80 model.truncate_kv(cache_id, new_len);
81 }
82}
83
84#[async_trait::async_trait]
85impl ModelExecutor for LlmExecutor {
86 fn info(&self) -> &ModelInfo {
87 &self.info
88 }
89
90 async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
91 let tokens = common::tensor_to_tokens(&input.input_ids)?;
92
93 let supplied_handle_id = input.kv_cache.as_ref().and_then(|h| {
98 h.as_any()
99 .downcast_ref::<GenericKvCacheHandle>()
100 .map(|g| g.request_cache_id().to_string())
101 });
102 let cache_id = supplied_handle_id
103 .clone()
104 .unwrap_or_else(|| self.gen_cache_id());
105
106 let logits = {
107 let mut model = self.model.lock();
108 model.prefill(&cache_id, &tokens)
109 };
110
111 let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
113 .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
114 .unsqueeze(0)
115 .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?
116 .unsqueeze(0)
117 .map_err(|e| FerrumError::model(format!("unsqueeze2: {e}")))?;
118 let logits_ref = common::wrap_tensor(logits_tensor);
119
120 let cfg = self.model.lock().config().clone();
121 let seq_len = input
128 .kv_cache
129 .as_ref()
130 .and_then(|h| h.as_any().downcast_ref::<GenericKvCacheHandle>())
131 .map(|g| {
132 use ferrum_interfaces::KvCacheHandle;
133 g.block_table().sequence_length + tokens.len()
134 })
135 .unwrap_or(tokens.len());
136
137 let kv_handle = Arc::new(GenericKvCacheHandle::new(
138 cfg.num_layers,
139 cfg.num_kv_heads,
140 cfg.head_dim,
141 candle_core::Device::Cpu,
142 seq_len,
143 cache_id,
144 ));
145
146 Ok(PrefillOutput::new(logits_ref, kv_handle))
147 }
148
149 async fn truncate_kv(
150 &self,
151 kv_cache: &Arc<dyn ferrum_interfaces::KvCacheHandle>,
152 new_len: usize,
153 ) -> Result<()> {
154 if let Some(g) = kv_cache.as_any().downcast_ref::<GenericKvCacheHandle>() {
155 let cache_id = g.request_cache_id();
156 self.model.lock().truncate_kv(cache_id, new_len);
157 }
158 Ok(())
159 }
160
161 async fn forward_verify(
162 &self,
163 inputs: &[ferrum_interfaces::model_executor::DecodeInput],
164 ) -> Result<Vec<ferrum_interfaces::model_executor::DecodeOutput>> {
165 if inputs.is_empty() {
166 return Ok(Vec::new());
167 }
168
169 let first_handle = inputs[0].kv_cache.clone();
172 let cache_id = first_handle
173 .as_any()
174 .downcast_ref::<GenericKvCacheHandle>()
175 .ok_or_else(|| {
176 FerrumError::model("forward_verify requires GenericKvCacheHandle input")
177 })?
178 .request_cache_id()
179 .to_string();
180 let start_seq = {
181 use ferrum_interfaces::KvCacheHandle;
182 first_handle.block_table().sequence_length
183 };
184
185 let mut token_ids: Vec<u32> = Vec::with_capacity(inputs.len());
187 for input in inputs {
188 let toks = common::tensor_to_tokens(&input.input_ids)?;
189 if toks.is_empty() {
190 return Err(FerrumError::model("forward_verify input token empty"));
191 }
192 token_ids.push(toks[0]);
193 }
194
195 let flat = {
197 let mut model = self.model.lock();
198 model.forward_verify(&cache_id, &token_ids)
199 };
200
201 let cfg = self.model.lock().config().clone();
202 let vocab = cfg.vocab_size;
203
204 let candle_device = ferrum_device_to_candle(&self.info.device);
209
210 let mut outputs = Vec::with_capacity(inputs.len());
215 for (i, _) in inputs.iter().enumerate() {
216 let row = &flat[i * vocab..(i + 1) * vocab];
217 let logits_tensor = candle_core::Tensor::new(row, &candle_core::Device::Cpu)
218 .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
219 .unsqueeze(0)
220 .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
221 let logits_ref = common::wrap_tensor(logits_tensor);
222 let handle = Arc::new(GenericKvCacheHandle::new(
223 cfg.num_layers,
224 cfg.num_kv_heads,
225 cfg.head_dim,
226 candle_device.clone(),
227 start_seq + i + 1,
228 cache_id.clone(),
229 ));
230 outputs.push(ferrum_interfaces::model_executor::DecodeOutput::new(
231 logits_ref, handle,
232 ));
233 }
234 Ok(outputs)
235 }
236
237 async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput> {
238 let input_handle = input
239 .kv_cache
240 .as_any()
241 .downcast_ref::<GenericKvCacheHandle>()
242 .ok_or_else(|| FerrumError::model("Invalid KV cache handle type"))?;
243
244 let cache_id = input_handle.request_cache_id().to_string();
245 let seq_len = {
246 use ferrum_interfaces::KvCacheHandle;
247 input_handle.block_table().sequence_length
248 };
249
250 let tokens = common::tensor_to_tokens(&input.input_ids)?;
251 if tokens.is_empty() {
252 return Err(FerrumError::model("Decode input is empty"));
253 }
254 let token = tokens[0];
255
256 debug!("LlmExecutor decode: token={token}, pos={seq_len}");
257
258 let logits = {
259 let mut model = self.model.lock();
260 model.decode(&cache_id, token, seq_len as u32)
261 };
262
263 let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
264 .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
265 .unsqueeze(0)
266 .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
267 let logits_ref = common::wrap_tensor(logits_tensor);
268
269 let kv_handle = Arc::new(input_handle.with_sequence_length(seq_len + 1));
270 Ok(DecodeOutput::new(logits_ref, kv_handle))
271 }
272
273 async fn batch_decode(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>> {
278 if inputs.is_empty() {
279 return Ok(Vec::new());
280 }
281 struct Prep {
284 cache_id: String,
285 token: u32,
286 seq_len: u32,
287 handle: Arc<GenericKvCacheHandle>,
288 }
289 let mut prepped: Vec<Prep> = Vec::with_capacity(inputs.len());
290 for input in inputs {
291 let input_handle = input
292 .kv_cache
293 .as_any()
294 .downcast_ref::<GenericKvCacheHandle>()
295 .ok_or_else(|| FerrumError::model("Invalid KV cache handle type"))?;
296 use ferrum_interfaces::KvCacheHandle;
297 let seq_len = input_handle.block_table().sequence_length as u32;
298 let tokens = common::tensor_to_tokens(&input.input_ids)?;
299 if tokens.is_empty() {
300 return Err(FerrumError::model("Decode input is empty"));
301 }
302 prepped.push(Prep {
303 cache_id: input_handle.request_cache_id().to_string(),
304 token: tokens[0],
305 seq_len,
306 handle: Arc::new(input_handle.with_sequence_length((seq_len + 1) as usize)),
307 });
308 }
309
310 let all_logits: Vec<Vec<f32>> = {
315 let mut model = self.model.lock();
316 let tuples: Vec<(String, u32, u32)> = prepped
317 .iter()
318 .map(|p| (p.cache_id.clone(), p.token, p.seq_len))
319 .collect();
320 model.decode_batch(&tuples)
321 };
322
323 let mut outputs = Vec::with_capacity(prepped.len());
324 for (p, logits) in prepped.into_iter().zip(all_logits.into_iter()) {
325 debug!(
326 "LlmExecutor batch_decode: token={}, pos={}",
327 p.token, p.seq_len
328 );
329 let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
330 .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
331 .unsqueeze(0)
332 .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
333 let logits_ref = common::wrap_tensor(logits_tensor);
334 outputs.push(DecodeOutput::new(logits_ref, p.handle));
335 }
336 Ok(outputs)
337 }
338
339 fn release_cache(&self, cache_id: &str) {
340 self.model.lock().release(cache_id);
341 }
342
343 fn capabilities(&self) -> ExecutorCapabilities {
344 let cfg = self.model.lock().config().clone();
345 ExecutorCapabilities {
346 max_batch_size: 256,
347 max_sequence_length: cfg.max_seq_len,
348 attention_mechanisms: vec![AttentionType::GroupedQuery],
349 supports_dynamic_batching: true,
350 supports_continuous_batching: true,
351 supports_speculative_decoding: false,
352 supports_tensor_parallelism: false,
353 supports_pipeline_parallelism: false,
354 supported_dtypes: vec![DataType::FP32],
355 supported_devices: vec![self.info.device.clone()],
356 memory_requirements: MemoryRequirements {
357 parameter_memory: (self.info.num_parameters * 4) as u64,
358 activation_memory_per_token: cfg.hidden_size * 4,
359 kv_cache_memory_per_token: cfg.hidden_size * 2,
360 overhead_memory: 256 * 1024 * 1024,
361 },
362 }
363 }
364
365 fn status(&self) -> ExecutorStatus {
366 common::default_executor_status()
367 }
368}