use async_trait::async_trait;
use candle_core::{Device as CandleDevice, Tensor};
use ferrum_interfaces::{
    kv_cache::{BlockTable, CacheHandleStats},
    model_executor::{
        AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorMemoryUsage,
        ExecutorState, ExecutorStatus, MemoryRequirements, PrefillInput, PrefillOutput,
    },
    KvCacheHandle, ModelExecutor, TensorRef,
};
use ferrum_types::{DataType, Device, FerrumError, ModelInfo, Result};
use parking_lot::Mutex;
use std::{
    sync::{
        atomic::{AtomicU64, Ordering},
        Arc,
    },
    time::Instant,
};
use tracing::{debug, info};

use crate::{architectures::qwen2::Qwen2ModelWrapper, tensor_wrapper::CandleTensorWrapper};

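/// Per-request generation state shared between `prefill` and `decode`.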
#[derive(Debug, Clone)]
struct Qwen2CacheState {
    /// Number of tokens already materialized in the model's internal KV cache.
    sequence_length: usize,
    /// Identifier tying this state to the KV cache handle returned by `prefill`.
    cache_id: String,
}

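/// [`ModelExecutor`] implementation for Qwen2 models running on Candle.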
pub struct Qwen2ModelExecutor {
    model: Arc<Qwen2ModelWrapper>,
    info: ModelInfo,
    /// State of the single in-flight sequence; this executor does not batch.
    state: Mutex<Option<Qwen2CacheState>>,
    /// Monotonic counter used to mint unique cache IDs.
    next_cache_id: AtomicU64,
}

impl Qwen2ModelExecutor {
    pub fn new(model: Qwen2ModelWrapper, info: ModelInfo) -> Self {
        info!("✅ Created Qwen2ModelExecutor for: {}", info.model_id);

        Self {
            model: Arc::new(model),
            info,
            state: Mutex::new(None),
            next_cache_id: AtomicU64::new(1),
        }
    }

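    /// Extracts token IDs from an input tensor, accepting either u32 or f32 storage.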
    fn tensor_to_tokens(&self, tensor: &TensorRef) -> Result<Vec<u32>> {
        if let Ok(tokens) = tensor.to_vec_u32() {
            if tokens.is_empty() {
                return Err(FerrumError::model("Input token tensor is empty"));
            }
            return Ok(tokens);
        }

        // Fall back to f32 storage, truncating each value to a token ID.
        if let Ok(tokens_f32) = tensor.to_vec_f32() {
            let tokens: Vec<u32> = tokens_f32.into_iter().map(|x| x as u32).collect();
            if tokens.is_empty() {
                return Err(FerrumError::model("Input token tensor is empty"));
            }
            return Ok(tokens);
        }

        Err(FerrumError::model(
            "Unable to extract token IDs from input tensor",
        ))
    }

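    /// Builds a `[1, seq_len]` I64 token tensor on CPU, then moves it to the model's device.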
    fn tokens_to_tensor(&self, tokens: &[u32]) -> Result<Tensor> {
        let base = Tensor::new(tokens, &CandleDevice::Cpu)
            .map_err(|e| FerrumError::model(format!("Failed to create tensor: {}", e)))?
            .unsqueeze(0)
            .map_err(|e| FerrumError::model(format!("Failed to unsqueeze tensor: {}", e)))?
            .to_dtype(candle_core::DType::I64)
            .map_err(|e| FerrumError::model(format!("Failed to cast tokens to I64: {}", e)))?;

        match self.model.candle_device() {
            CandleDevice::Cpu => Ok(base),
            CandleDevice::Cuda(dev) => base
                .to_device(&CandleDevice::Cuda(dev.clone()))
                .map_err(|e| FerrumError::model(format!("Failed to move tensor to CUDA: {}", e))),
            CandleDevice::Metal(dev) => base
                .to_device(&CandleDevice::Metal(dev.clone()))
                .map_err(|e| FerrumError::model(format!("Failed to move tensor to Metal: {}", e))),
        }
    }

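    /// Wraps a Candle tensor in the crate's type-erased [`TensorRef`].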
    fn wrap_tensor(&self, tensor: Tensor) -> TensorRef {
        Arc::new(CandleTensorWrapper::new(tensor))
    }
}

#[async_trait]
impl ModelExecutor for Qwen2ModelExecutor {
    fn info(&self) -> &ModelInfo {
        &self.info
    }

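    /// Runs the full-prompt forward pass and installs fresh per-request cache state.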
    async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
        debug!(
            "Qwen2 Prefill: batch={}, seq_len={}",
            input.batch_size(),
            input.sequence_length()
        );

        let tokens = self.tensor_to_tokens(&input.input_ids)?;
        if tokens.is_empty() {
            return Err(FerrumError::model("Prefill input is empty"));
        }

        // Each prefill starts a fresh sequence, so drop any cached KV state first.
        self.model.reset_cache()?;

        let input_tensor = self.tokens_to_tensor(&tokens)?;

        let logits = self
            .model
            .forward_prefill(&input_tensor)
            .map_err(|e| FerrumError::model(format!("Qwen2 prefill failed: {}", e)))?;

        // Normalize logits to rank 3 ([batch, seq, vocab]) regardless of what the model returns.
        let logits = match logits.dims().len() {
            2 => logits
                .unsqueeze(1)
                .map_err(|e| FerrumError::model(format!("Unsqueeze logits failed: {}", e)))?,
            3 => logits,
            rank => {
                return Err(FerrumError::model(format!(
                    "Unexpected Qwen2 prefill logits rank: {} (shape {:?})",
                    rank,
                    logits.dims()
                )))
            }
        };

        let logits_ref = self.wrap_tensor(logits);

        let cache_id = format!(
            "qwen2-cache-{}",
            self.next_cache_id.fetch_add(1, Ordering::Relaxed)
        );

        let kv_handle = Arc::new(Qwen2KvCacheHandle::new(
            self.model.config(),
            self.model.device().clone(),
            tokens.len(),
            cache_id.clone(),
        ));

        *self.state.lock() = Some(Qwen2CacheState {
            sequence_length: tokens.len(),
            cache_id,
        });

        Ok(PrefillOutput::new(logits_ref, kv_handle))
    }

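    /// Runs one incremental decode step against the cache created by `prefill`.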
    async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput> {
        debug!("Qwen2 Decode: batch={}", input.batch_size());

        let mut guard = self.state.lock();
        let state = guard
            .as_mut()
            .ok_or_else(|| FerrumError::model("Decode called before prefill"))?;

        // Reject handles from other sequences: this executor holds a single KV cache.
        let input_handle = input
            .kv_cache
            .as_any()
            .downcast_ref::<Qwen2KvCacheHandle>()
            .ok_or_else(|| FerrumError::model("Invalid KV cache handle type for Qwen2 executor"))?;
        if input_handle.request_cache_id() != state.cache_id {
            return Err(FerrumError::model(format!(
                "KV cache handle mismatch: expected {}, got {}",
                state.cache_id,
                input_handle.request_cache_id()
            )));
        }

        let tokens = self.tensor_to_tokens(&input.input_ids)?;
        if tokens.is_empty() {
            return Err(FerrumError::model("Decode input is empty"));
        }

        let input_tensor = self.tokens_to_tensor(&tokens)?;

        let logits = self
            .model
            .forward_decode(&input_tensor, state.sequence_length)
            .map_err(|e| FerrumError::model(format!("Qwen2 decode failed: {}", e)))?;

        let logits_ref = self.wrap_tensor(logits);

        // Advance the tracked sequence length and hand back an updated cache handle.
        state.sequence_length += tokens.len();
        let new_handle = Arc::new(input_handle.with_sequence_length(state.sequence_length));

        Ok(DecodeOutput::new(logits_ref, new_handle))
    }

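    /// Advertises single-sequence (batch size 1) capabilities with rough memory estimates.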
    fn capabilities(&self) -> ExecutorCapabilities {
        ExecutorCapabilities {
            max_batch_size: 1,
            max_sequence_length: self.info.max_sequence_length,
            attention_mechanisms: vec![AttentionType::MultiHead, AttentionType::GroupedQuery],
            supports_dynamic_batching: false,
            supports_continuous_batching: false,
            supports_speculative_decoding: false,
            supports_tensor_parallelism: false,
            supports_pipeline_parallelism: false,
            supported_dtypes: vec![DataType::FP16, DataType::FP32, DataType::BF16],
            supported_devices: vec![self.info.device.clone()],
            memory_requirements: MemoryRequirements {
                // Rough estimate: 2 bytes per parameter assumes FP16/BF16 weights.
                parameter_memory: (self.info.num_parameters * 2) as u64,
                activation_memory_per_token: self.info.hidden_size * 4,
                kv_cache_memory_per_token: self.info.hidden_size * 2,
                overhead_memory: 256 * 1024 * 1024, // 256 MiB
            },
        }
    }

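    /// Reports a static "ready" status; operation counters and timings are not tracked yet.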
    fn status(&self) -> ExecutorStatus {
        ExecutorStatus {
            state: ExecutorState::Ready,
            is_ready: true,
            current_batch_size: 0,
            prefill_operations: 0,
            decode_operations: 0,
            avg_prefill_time_ms: 0.0,
            avg_decode_time_ms: 0.0,
            memory_usage: ExecutorMemoryUsage {
                allocated_bytes: 0,
                used_bytes: 0,
                peak_bytes: 0,
                utilization_percent: 0.0,
            },
            last_operation: Some(Instant::now()),
        }
    }
}

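/// KV cache handle for Qwen2. The KV tensors themselves are held inside the
/// Candle model, so this handle only carries sequence metadata and a request identity.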
#[derive(Debug, Clone)]
struct Qwen2KvCacheHandle {
    block_table: BlockTable,
    num_layers: usize,
    num_heads: usize,
    head_dim: usize,
    device: Device,
    request_cache_id: String,
}

impl Qwen2KvCacheHandle {
    fn new(
        config: &candle_transformers::models::qwen2::Config,
        device: CandleDevice,
        seq_len: usize,
        request_cache_id: String,
    ) -> Self {
        let mut block_table = BlockTable::new(16);
        block_table.sequence_length = seq_len;

        Self {
            block_table,
            num_layers: config.num_hidden_layers,
            num_heads: config.num_attention_heads,
            head_dim: config.hidden_size / config.num_attention_heads,
            request_cache_id,
            device: match device {
                CandleDevice::Cpu => Device::CPU,
                // The CUDA ordinal is not threaded through here yet; default to device 0.
                CandleDevice::Cuda(_dev) => Device::CUDA(0),
                #[cfg(any(target_os = "macos", target_os = "ios"))]
                CandleDevice::Metal(_) => Device::Metal,
                // Metal tensors cannot exist off Apple platforms; report CPU instead.
                #[cfg(not(any(target_os = "macos", target_os = "ios")))]
                CandleDevice::Metal(_) => Device::CPU,
            },
        }
    }

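    /// Returns a copy of this handle with an updated logical sequence length.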
    fn with_sequence_length(&self, seq_len: usize) -> Self {
        let mut block_table = self.block_table.clone();
        block_table.sequence_length = seq_len;

        Self {
            block_table,
            num_layers: self.num_layers,
            num_heads: self.num_heads,
            head_dim: self.head_dim,
            device: self.device.clone(),
            request_cache_id: self.request_cache_id.clone(),
        }
    }

    fn request_cache_id(&self) -> &str {
        &self.request_cache_id
    }
}

impl KvCacheHandle for Qwen2KvCacheHandle {
    fn block_table(&self) -> &BlockTable {
        &self.block_table
    }

    fn block_table_mut(&mut self) -> &mut BlockTable {
        &mut self.block_table
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn device(&self) -> Device {
        self.device.clone()
    }

    fn num_layers(&self) -> usize {
        self.num_layers
    }

    fn num_heads(&self) -> usize {
        self.num_heads
    }

    fn head_dim(&self) -> usize {
        self.head_dim
    }

    // The KV tensors are owned by the Candle model itself, so no per-layer
    // views are exposed through this handle.
    fn key_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
        Ok(None)
    }

    fn value_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
        Ok(None)
    }

    fn clone_handle(&self) -> Result<Arc<dyn KvCacheHandle>> {
        Ok(Arc::new(self.clone()))
    }

    fn stats(&self) -> CacheHandleStats {
        CacheHandleStats {
            // Memory is accounted for inside the model, not in this handle.
            memory_bytes: 0,
            blocks_allocated: self.block_table.num_blocks(),
            tokens_stored: self.block_table.sequence_length,
            utilization: 0.0,
            last_access: Instant::now(),
        }
    }

    fn is_valid(&self) -> bool {
        true
    }

    fn cache_id(&self) -> String {
        format!(
            "{}-{}",
            self.request_cache_id, self.block_table.sequence_length
        )
    }
}