Skip to main content

ferrum_models/executor/
qwen2_executor.rs

//! Qwen2 model executor built on the Candle tensor library.
2
3use async_trait::async_trait;
4use candle_core::{Device as CandleDevice, Tensor};
5use ferrum_interfaces::{
6    kv_cache::{BlockTable, CacheHandleStats},
7    model_executor::{
8        AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorMemoryUsage,
9        ExecutorState, ExecutorStatus, MemoryRequirements, PrefillInput, PrefillOutput,
10    },
11    KvCacheHandle, ModelExecutor, TensorRef,
12};
13use ferrum_types::{DataType, Device, FerrumError, ModelInfo, Result};
14use parking_lot::Mutex;
15use std::{
16    sync::{
17        atomic::{AtomicU64, Ordering},
18        Arc,
19    },
20    time::Instant,
21};
22use tracing::{debug, info};
23
24use crate::{architectures::qwen2::Qwen2ModelWrapper, tensor_wrapper::CandleTensorWrapper};
25
/// Bookkeeping that `prefill` hands over to the `decode` calls that follow it.
///
/// A single instance describes the request that currently owns the model's cache.
#[derive(Debug, Clone)]
struct Qwen2CacheState {
    /// How many tokens the model has consumed so far for the active request.
    sequence_length: usize,
    /// Identifier minted by `prefill`; a decode handle must carry this exact id.
    cache_id: String,
}
34
35/// Candle-based Qwen2 model executor
36pub struct Qwen2ModelExecutor {
37    model: Arc<Qwen2ModelWrapper>,
38    info: ModelInfo,
39    state: Mutex<Option<Qwen2CacheState>>,
40    next_cache_id: AtomicU64,
41}
42
43impl Qwen2ModelExecutor {
44    /// Create new Qwen2 executor
45    pub fn new(model: Qwen2ModelWrapper, info: ModelInfo) -> Self {
46        info!("✅ Created Qwen2ModelExecutor for: {}", info.model_id);
47
48        Self {
49            model: Arc::new(model),
50            info,
51            state: Mutex::new(None),
52            next_cache_id: AtomicU64::new(1),
53        }
54    }
55
56    /// Extract token IDs from tensor reference (supports [batch, seq] tensors)
57    fn tensor_to_tokens(&self, tensor: &TensorRef) -> Result<Vec<u32>> {
58        if let Ok(tokens) = tensor.to_vec_u32() {
59            if tokens.is_empty() {
60                return Err(FerrumError::model("Input token tensor is empty"));
61            }
62            return Ok(tokens);
63        }
64
65        if let Ok(tokens_f32) = tensor.to_vec_f32() {
66            let tokens: Vec<u32> = tokens_f32.into_iter().map(|x| x as u32).collect();
67            if tokens.is_empty() {
68                return Err(FerrumError::model("Input token tensor is empty"));
69            }
70            return Ok(tokens);
71        }
72
73        Err(FerrumError::model(
74            "Unable to extract token IDs from input tensor",
75        ))
76    }
77
78    /// Create Candle tensor from token IDs on the correct device
79    fn tokens_to_tensor(&self, tokens: &[u32]) -> Result<Tensor> {
80        let base = Tensor::new(tokens, &CandleDevice::Cpu)
81            .map_err(|e| FerrumError::model(format!("Failed to create tensor: {}", e)))?
82            .unsqueeze(0)
83            .map_err(|e| FerrumError::model(format!("Failed to unsqueeze tensor: {}", e)))?
84            .to_dtype(candle_core::DType::I64)
85            .map_err(|e| FerrumError::model(format!("Failed to cast tokens to I64: {}", e)))?;
86
87        match self.model.candle_device() {
88            CandleDevice::Cpu => Ok(base),
89            CandleDevice::Cuda(dev) => base
90                .to_device(&CandleDevice::Cuda(dev.clone()))
91                .map_err(|e| FerrumError::model(format!("Failed to move tensor to CUDA: {}", e))),
92            CandleDevice::Metal(dev) => base
93                .to_device(&CandleDevice::Metal(dev.clone()))
94                .map_err(|e| FerrumError::model(format!("Failed to move tensor to Metal: {}", e))),
95        }
96    }
97
98    /// Convert Candle tensor to TensorRef wrapper
99    fn wrap_tensor(&self, tensor: Tensor) -> TensorRef {
100        Arc::new(CandleTensorWrapper::new(tensor))
101    }
102}
103
104#[async_trait]
105impl ModelExecutor for Qwen2ModelExecutor {
106    fn info(&self) -> &ModelInfo {
107        &self.info
108    }
109
110    async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
111        debug!(
112            "Qwen2 Prefill: batch={}, seq_len={}",
113            input.batch_size(),
114            input.sequence_length()
115        );
116
117        // Extract tokens and build tensor
118        let tokens = self.tensor_to_tokens(&input.input_ids)?;
119        if tokens.is_empty() {
120            return Err(FerrumError::model("Prefill input is empty"));
121        }
122
123        // Reset internal KV cache before new request
124        self.model.reset_cache()?;
125
126        let input_tensor = self.tokens_to_tensor(&tokens)?;
127
128        // Run forward pass with offset 0
129        let logits = self
130            .model
131            .forward_prefill(&input_tensor)
132            .map_err(|e| FerrumError::model(format!("Qwen2 prefill failed: {}", e)))?;
133
134        let logits = match logits.dims().len() {
135            2 => logits
136                .unsqueeze(1)
137                .map_err(|e| FerrumError::model(format!("Unsqueeze logits failed: {}", e)))?,
138            3 => logits,
139            dims => {
140                return Err(FerrumError::model(format!(
141                    "Unexpected Qwen2 prefill logits rank: {} (shape {:?})",
142                    dims,
143                    logits.dims()
144                )))
145            }
146        };
147
148        let logits_ref = self.wrap_tensor(logits);
149
150        let cache_id = format!(
151            "qwen2-cache-{}",
152            self.next_cache_id.fetch_add(1, Ordering::Relaxed)
153        );
154
155        // Create KV cache handle representing internal state
156        let kv_handle = Arc::new(Qwen2KvCacheHandle::new(
157            self.model.config(),
158            self.model.device().clone(),
159            tokens.len(),
160            cache_id.clone(),
161        ));
162
163        // Store state for subsequent decode steps
164        *self.state.lock() = Some(Qwen2CacheState {
165            sequence_length: tokens.len(),
166            cache_id,
167        });
168
169        Ok(PrefillOutput::new(logits_ref, kv_handle))
170    }
171
172    async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput> {
173        debug!("Qwen2 Decode: batch={}", input.batch_size());
174
175        let mut guard = self.state.lock();
176        let state = guard
177            .as_mut()
178            .ok_or_else(|| FerrumError::model("Decode called before prefill"))?;
179
180        let input_handle = input
181            .kv_cache
182            .as_any()
183            .downcast_ref::<Qwen2KvCacheHandle>()
184            .ok_or_else(|| FerrumError::model("Invalid KV cache handle type for Qwen2 executor"))?;
185        if input_handle.request_cache_id() != state.cache_id {
186            return Err(FerrumError::model(format!(
187                "KV cache handle mismatch: expected {}, got {}",
188                state.cache_id,
189                input_handle.request_cache_id()
190            )));
191        }
192
193        // Extract single token for decode
194        let tokens = self.tensor_to_tokens(&input.input_ids)?;
195        if tokens.is_empty() {
196            return Err(FerrumError::model("Decode input is empty"));
197        }
198
199        let input_tensor = self.tokens_to_tensor(&tokens)?;
200
201        let logits = self
202            .model
203            .forward_decode(&input_tensor, state.sequence_length)
204            .map_err(|e| FerrumError::model(format!("Qwen2 decode failed: {}", e)))?;
205
206        let logits_ref = self.wrap_tensor(logits);
207
208        // Update sequence length and KV handle
209        state.sequence_length += tokens.len();
210        let new_handle = Arc::new(input_handle.with_sequence_length(state.sequence_length));
211
212        Ok(DecodeOutput::new(logits_ref, new_handle))
213    }
214
215    fn capabilities(&self) -> ExecutorCapabilities {
216        ExecutorCapabilities {
217            max_batch_size: 1,
218            max_sequence_length: self.info.max_sequence_length,
219            attention_mechanisms: vec![AttentionType::MultiHead, AttentionType::GroupedQuery],
220            supports_dynamic_batching: false,
221            supports_continuous_batching: false,
222            supports_speculative_decoding: false,
223            supports_tensor_parallelism: false,
224            supports_pipeline_parallelism: false,
225            supported_dtypes: vec![DataType::FP16, DataType::FP32, DataType::BF16],
226            supported_devices: vec![self.info.device.clone()],
227            memory_requirements: MemoryRequirements {
228                parameter_memory: (self.info.num_parameters * 2) as u64,
229                activation_memory_per_token: self.info.hidden_size * 4,
230                kv_cache_memory_per_token: self.info.hidden_size * 2,
231                overhead_memory: 256 * 1024 * 1024, // 256MB placeholder
232            },
233        }
234    }
235
236    fn status(&self) -> ExecutorStatus {
237        ExecutorStatus {
238            state: ExecutorState::Ready,
239            is_ready: true,
240            current_batch_size: 0,
241            prefill_operations: 0,
242            decode_operations: 0,
243            avg_prefill_time_ms: 0.0,
244            avg_decode_time_ms: 0.0,
245            memory_usage: ExecutorMemoryUsage {
246                allocated_bytes: 0,
247                used_bytes: 0,
248                peak_bytes: 0,
249                utilization_percent: 0.0,
250            },
251            last_operation: Some(Instant::now()),
252        }
253    }
254}
255
256/// Lightweight KV cache handle for Qwen2 models (model maintains cache internally)
257#[derive(Debug, Clone)]
258struct Qwen2KvCacheHandle {
259    block_table: BlockTable,
260    num_layers: usize,
261    num_heads: usize,
262    head_dim: usize,
263    device: Device,
264    request_cache_id: String,
265}
266
267impl Qwen2KvCacheHandle {
268    fn new(
269        config: &candle_transformers::models::qwen2::Config,
270        device: CandleDevice,
271        seq_len: usize,
272        request_cache_id: String,
273    ) -> Self {
274        let mut block_table = BlockTable::new(16);
275        block_table.sequence_length = seq_len;
276
277        Self {
278            block_table,
279            num_layers: config.num_hidden_layers,
280            num_heads: config.num_attention_heads,
281            head_dim: config.hidden_size / config.num_attention_heads,
282            request_cache_id,
283            device: match device {
284                CandleDevice::Cpu => Device::CPU,
285                CandleDevice::Cuda(_dev) => Device::CUDA(0),
286                #[cfg(any(target_os = "macos", target_os = "ios"))]
287                CandleDevice::Metal(_) => Device::Metal,
288                #[cfg(not(any(target_os = "macos", target_os = "ios")))]
289                CandleDevice::Metal(_) => Device::CPU,
290            },
291        }
292    }
293
294    fn with_sequence_length(&self, seq_len: usize) -> Self {
295        let mut block_table = self.block_table.clone();
296        block_table.sequence_length = seq_len;
297
298        Self {
299            block_table,
300            num_layers: self.num_layers,
301            num_heads: self.num_heads,
302            head_dim: self.head_dim,
303            device: self.device.clone(),
304            request_cache_id: self.request_cache_id.clone(),
305        }
306    }
307
308    fn request_cache_id(&self) -> &str {
309        &self.request_cache_id
310    }
311}
312
313impl KvCacheHandle for Qwen2KvCacheHandle {
314    fn block_table(&self) -> &BlockTable {
315        &self.block_table
316    }
317
318    fn block_table_mut(&mut self) -> &mut BlockTable {
319        &mut self.block_table
320    }
321
322    fn as_any(&self) -> &dyn std::any::Any {
323        self
324    }
325
326    fn device(&self) -> Device {
327        self.device.clone()
328    }
329
330    fn num_layers(&self) -> usize {
331        self.num_layers
332    }
333
334    fn num_heads(&self) -> usize {
335        self.num_heads
336    }
337
338    fn head_dim(&self) -> usize {
339        self.head_dim
340    }
341
342    fn key_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
343        Ok(None)
344    }
345
346    fn value_cache(&self, _layer: usize) -> Result<Option<TensorRef>> {
347        Ok(None)
348    }
349
350    fn clone_handle(&self) -> Result<Arc<dyn KvCacheHandle>> {
351        Ok(Arc::new(self.clone()))
352    }
353
354    fn stats(&self) -> CacheHandleStats {
355        CacheHandleStats {
356            memory_bytes: 0,
357            blocks_allocated: self.block_table.num_blocks(),
358            tokens_stored: self.block_table.sequence_length,
359            utilization: 0.0,
360            last_access: Instant::now(),
361        }
362    }
363
364    fn is_valid(&self) -> bool {
365        true
366    }
367
368    fn cache_id(&self) -> String {
369        format!(
370            "{}-{}",
371            self.request_cache_id, self.block_table.sequence_length
372        )
373    }
374}