// ferrum_models/executor/qwen3_executor.rs
1//! Qwen3 model executor using Candle
2
3use async_trait::async_trait;
4use candle_core::{Device as CandleDevice, Tensor};
5use ferrum_interfaces::{
6    model_executor::{
7        AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorStatus,
8        MemoryRequirements, PrefillInput, PrefillOutput,
9    },
10    KvCacheHandle, ModelExecutor, TensorRef,
11};
12use ferrum_types::{DataType, FerrumError, ModelInfo, Result};
13use parking_lot::Mutex;
14use std::{
15    collections::HashMap,
16    sync::{
17        atomic::{AtomicU64, Ordering},
18        Arc,
19    },
20};
21use tracing::{debug, info};
22
23use crate::architectures::qwen3::Qwen3ModelWrapper;
24use crate::executor::common;
25
/// Per-sequence bookkeeping kept alongside the model's KV cache.
#[derive(Debug, Clone)]
struct Qwen3CacheState {
    /// Number of tokens currently stored in this sequence's KV cache;
    /// advanced by prefill/decode as tokens are appended.
    sequence_length: usize,
}
30
/// Candle-based Qwen3 model executor with multi-sequence support.
///
/// Each active sequence gets its own KV cache keyed by a unique cache_id.
/// This allows concurrent prefill and decode across many sequences without
/// one sequence's prefill destroying another's KV cache.
///
/// On CUDA devices, lazily creates a `CudaDecodeRunner` that bypasses candle
/// for the decode hot path, using cuBLAS + custom kernels with pre-allocated
/// buffers and optional CUDA Graph acceleration.
pub struct Qwen3ModelExecutor {
    /// Loaded Qwen3 weights + forward passes (shared, immutable after load).
    model: Arc<Qwen3ModelWrapper>,
    /// Static model metadata (model_id, vocab size, device, limits).
    info: ModelInfo,
    /// Per-sequence state keyed by cache_id; entries are created on prefill
    /// (or lazily on decode for prefix-cache clones) and removed on release.
    states: Mutex<HashMap<String, Qwen3CacheState>>,
    /// Monotonic counter used to mint unique cache_ids for new sequences.
    next_cache_id: AtomicU64,
    /// CUDA decode runner (created lazily on first CUDA decode call).
    #[cfg(feature = "cuda")]
    cuda_runner: Mutex<Option<ferrum_cuda_kernels::cuda_decode::CudaDecodeRunner>>,
    /// Whether CUDA runner initialization has been attempted (avoid retrying on failure).
    #[cfg(feature = "cuda")]
    cuda_runner_init_attempted: std::sync::atomic::AtomicBool,
}
52
53impl Qwen3ModelExecutor {
54    pub fn new(model: Qwen3ModelWrapper, info: ModelInfo) -> Self {
55        info!("Created Qwen3ModelExecutor for: {}", info.model_id);
56
57        Self {
58            model: Arc::new(model),
59            info,
60            states: Mutex::new(HashMap::new()),
61            next_cache_id: AtomicU64::new(1),
62            #[cfg(feature = "cuda")]
63            cuda_runner: Mutex::new(None),
64            #[cfg(feature = "cuda")]
65            cuda_runner_init_attempted: std::sync::atomic::AtomicBool::new(false),
66        }
67    }
68
69    /// Try to initialize the CUDA decode runner (lazy, first call only).
70    /// Returns true if runner is available for use.
71    ///
72    #[cfg(feature = "cuda")]
73    fn ensure_cuda_runner(&self) -> bool {
74        // FERRUM_DISABLE_CUDA_RUNNER=1 → always use candle path
75        if std::env::var("FERRUM_DISABLE_CUDA_RUNNER").unwrap_or_default() == "1" {
76            return false;
77        }
78        if self.cuda_runner.lock().is_some() {
79            return true;
80        }
81        if self
82            .cuda_runner_init_attempted
83            .swap(true, Ordering::Relaxed)
84        {
85            return false;
86        }
87        if !matches!(self.model.candle_device(), CandleDevice::Cuda(_)) {
88            return false;
89        }
90        match self.model.create_decode_runner() {
91            Ok(runner) => {
92                info!("CUDA decode runner initialized — decode will bypass candle");
93                *self.cuda_runner.lock() = Some(runner);
94                true
95            }
96            Err(e) => {
97                tracing::warn!("CUDA decode runner init failed, using candle path: {e}");
98                false
99            }
100        }
101    }
102
103    /// Release a sequence's KV cache, freeing GPU memory.
104    /// Should be called when a request completes.
105    pub fn release_sequence(&self, cache_id: &str) {
106        self.states.lock().remove(cache_id);
107        self.model.release_cache(cache_id);
108        // Also release from CUDA decode runner if active
109        #[cfg(feature = "cuda")]
110        if let Some(ref mut runner) = *self.cuda_runner.lock() {
111            runner.release_kv_cache(cache_id);
112        }
113        tracing::warn!("Released KV cache for sequence: {}", cache_id);
114    }
115
116    /// Ensure the CUDA decode runner has KV cache for a sequence.
117    /// On first call for a sequence, migrates KV data from candle's PreAllocKvCache.
118    #[cfg(feature = "cuda")]
119    fn ensure_runner_kv_cache(&self, cache_id: &str, _seq_len: usize) -> Result<()> {
120        use candle_core::Storage;
121
122        let mut runner_guard = self.cuda_runner.lock();
123        let runner = match runner_guard.as_mut() {
124            Some(r) => r,
125            None => return Ok(()), // No runner, will fall through to candle
126        };
127
128        // Check if runner already has KV cache for this sequence
129        if runner.has_kv_cache(cache_id) {
130            return Ok(());
131        }
132
133        // Export KV data from candle model.
134        // For prefix cache clones (cache_id like "qwen3-cache-2-clone-0"),
135        // try the base ID ("qwen3-cache-2") since candle only has the original.
136        let kv_data_tensors = self
137            .model
138            .export_kv_cache(cache_id)
139            .or_else(|| {
140                // Strip "-clone-N" suffix and try original
141                cache_id
142                    .rfind("-clone-")
143                    .and_then(|pos| self.model.export_kv_cache(&cache_id[..pos]))
144            })
145            .ok_or_else(|| {
146                FerrumError::model(format!("No candle KV cache to export for: {cache_id}"))
147            })?;
148
149        if kv_data_tensors.is_empty() {
150            return Err(FerrumError::model("Empty KV cache export"));
151        }
152        let prefill_len = kv_data_tensors[0].2;
153        let max_len = kv_data_tensors[0].3;
154
155        // Extract CudaSlice from each layer's K/V tensors.
156        // clone() on CudaSlice does a D2D copy — we get independent buffers.
157        let mut kv_slices = Vec::new();
158        for (k_tensor, v_tensor, _len, _max) in &kv_data_tensors {
159            let (k_s, _) = k_tensor.storage_and_layout();
160            let (v_s, _) = v_tensor.storage_and_layout();
161            let k_cuda = match &*k_s {
162                Storage::Cuda(cs) => cs
163                    .as_cuda_slice::<half::f16>()
164                    .map_err(|e| FerrumError::model(format!("KV slice extract: {e}")))?
165                    .clone(),
166                _ => return Err(FerrumError::model("KV cache not on CUDA")),
167            };
168            let v_cuda = match &*v_s {
169                Storage::Cuda(cs) => cs
170                    .as_cuda_slice::<half::f16>()
171                    .map_err(|e| FerrumError::model(format!("KV slice extract: {e}")))?
172                    .clone(),
173                _ => return Err(FerrumError::model("KV cache not on CUDA")),
174            };
175            drop(k_s);
176            drop(v_s);
177            kv_slices.push((k_cuda, v_cuda));
178        }
179
180        runner
181            .init_kv_cache(cache_id, kv_slices, prefill_len, max_len)
182            .map_err(|e| FerrumError::model(format!("CUDA runner KV init failed: {e}")))?;
183
184        // For prefix cache clones: register clone ID in executor states
185        // so decode() can find the sequence_length for this cache_id.
186        if !self.states.lock().contains_key(cache_id) {
187            let base_len = cache_id
188                .rfind("-clone-")
189                .and_then(|pos| {
190                    self.states
191                        .lock()
192                        .get(&cache_id[..pos])
193                        .map(|s| s.sequence_length)
194                })
195                .unwrap_or(prefill_len);
196            self.states.lock().insert(
197                cache_id.to_string(),
198                Qwen3CacheState {
199                    sequence_length: base_len,
200                },
201            );
202        }
203
204        debug!("Migrated KV cache to CUDA runner for sequence: {cache_id}");
205        Ok(())
206    }
207
208    fn tensor_to_tokens(&self, tensor: &TensorRef) -> Result<Vec<u32>> {
209        common::tensor_to_tokens(tensor)
210    }
211
212    fn tokens_to_tensor(&self, tokens: &[u32]) -> Result<Tensor> {
213        common::tokens_to_tensor(tokens, self.model.candle_device())
214    }
215
216    fn wrap_tensor(&self, tensor: Tensor) -> TensorRef {
217        common::wrap_tensor(tensor)
218    }
219}
220
221#[async_trait]
222impl ModelExecutor for Qwen3ModelExecutor {
223    fn info(&self) -> &ModelInfo {
224        &self.info
225    }
226
227    async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
228        debug!(
229            "Qwen3 Prefill: batch={}, seq_len={}",
230            input.batch_size(),
231            input.sequence_length()
232        );
233
234        let tokens = self.tensor_to_tokens(&input.input_ids)?;
235        if tokens.is_empty() {
236            return Err(FerrumError::model("Prefill input is empty"));
237        }
238
239        let cache_id = format!(
240            "qwen3-cache-{}",
241            self.next_cache_id.fetch_add(1, Ordering::Relaxed)
242        );
243
244        let input_tensor = self.tokens_to_tensor(&tokens)?;
245
246        // Each sequence gets its own KV cache slot; no need to clear other sequences.
247        let logits = self
248            .model
249            .forward_prefill(&input_tensor, &cache_id)
250            .map_err(|e| FerrumError::model(format!("Qwen3 prefill failed: {}", e)))?;
251
252        let logits = match logits.dims().len() {
253            2 => logits
254                .unsqueeze(1)
255                .map_err(|e| FerrumError::model(format!("Unsqueeze logits failed: {}", e)))?,
256            3 => logits,
257            dims => {
258                return Err(FerrumError::model(format!(
259                    "Unexpected Qwen3 prefill logits rank: {} (shape {:?})",
260                    dims,
261                    logits.dims()
262                )))
263            }
264        };
265
266        let logits_ref = self.wrap_tensor(logits);
267
268        let cfg = self.model.config();
269        let kv_handle = Arc::new(common::GenericKvCacheHandle::new(
270            cfg.num_hidden_layers,
271            cfg.num_attention_heads,
272            cfg.head_dim,
273            self.model.device().clone(),
274            tokens.len(),
275            cache_id.clone(),
276        ));
277
278        self.states.lock().insert(
279            cache_id,
280            Qwen3CacheState {
281                sequence_length: tokens.len(),
282            },
283        );
284
285        Ok(PrefillOutput::new(logits_ref, kv_handle))
286    }
287
288    async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput> {
289        debug!("Qwen3 Decode: batch={}", input.batch_size());
290
291        let input_handle = input
292            .kv_cache
293            .as_any()
294            .downcast_ref::<common::GenericKvCacheHandle>()
295            .ok_or_else(|| FerrumError::model("Invalid KV cache handle type for Qwen3 executor"))?;
296        let req_cache_id = input_handle.request_cache_id().to_string();
297
298        let seq_len = {
299            let mut states = self.states.lock();
300            if let Some(s) = states.get(&req_cache_id) {
301                s.sequence_length
302            } else {
303                // For prefix cache clones: base entry may have been released.
304                // Use KV handle's sequence_length and register clone state.
305                let len = input_handle.block_table().sequence_length;
306                states.insert(
307                    req_cache_id.clone(),
308                    Qwen3CacheState {
309                        sequence_length: len,
310                    },
311                );
312                len
313            }
314        };
315
316        let tokens = self.tensor_to_tokens(&input.input_ids)?;
317        if tokens.is_empty() {
318            return Err(FerrumError::model("Decode input is empty"));
319        }
320
321        // Try CUDA decode runner path (bypasses candle for the hot path).
322        // Falls back to candle if runner fails (e.g., KV cache not yet initialized).
323        #[cfg(feature = "cuda")]
324        if tokens.len() == 1 && self.ensure_cuda_runner() {
325            let token_id = tokens[0];
326
327            // Ensure CUDA runner has KV cache for this sequence
328            // (migrated from candle's PreAllocKvCache on first decode call)
329            self.ensure_runner_kv_cache(&req_cache_id, seq_len)?;
330
331            let cuda_result = {
332                let mut runner = self.cuda_runner.lock();
333                if let Some(ref mut runner) = *runner {
334                    Some(runner.decode_step_graphed(token_id, seq_len, &req_cache_id))
335                } else {
336                    None
337                }
338            };
339            if let Some(Ok(logits_slice)) = cuda_result {
340                // Wrap CudaSlice into candle Tensor (zero-copy, stays on GPU)
341                let cuda_dev = self
342                    .model
343                    .candle_device()
344                    .as_cuda_device()
345                    .map_err(|e| FerrumError::model(format!("Not CUDA device: {e}")))?;
346                let storage = candle_core::cuda_backend::CudaStorage::wrap_cuda_slice(
347                    logits_slice,
348                    cuda_dev.clone(),
349                );
350                let logits_tensor = candle_core::Tensor::from_storage(
351                    candle_core::Storage::Cuda(storage),
352                    (1, 1, self.info.vocab_size),
353                    candle_core::op::BackpropOp::none(),
354                    false,
355                );
356
357                // FERRUM_LOG_TOKENS=1 → log argmax token for every decode step
358                if std::env::var("FERRUM_LOG_TOKENS").unwrap_or_default() == "1" || seq_len == 13 {
359                    if let Ok(flat) = logits_tensor.flatten_all() {
360                        if let Ok(vals) = flat.to_vec1::<half::f16>() {
361                            let mut indexed: Vec<(usize, f32)> = vals
362                                .iter()
363                                .enumerate()
364                                .map(|(i, v)| (i, v.to_f32()))
365                                .collect();
366                            indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
367                            let top5: Vec<String> = indexed[..5]
368                                .iter()
369                                .map(|(i, v)| format!("{}:{:.2}", i, v))
370                                .collect();
371                            tracing::info!("[CUDA] pos={} top5=[{}]", seq_len, top5.join(", "));
372                        }
373                    }
374                }
375
376                let logits_ref = self.wrap_tensor(logits_tensor);
377
378                let new_seq_len = {
379                    let mut states = self.states.lock();
380                    if let Some(state) = states.get_mut(&req_cache_id) {
381                        state.sequence_length += 1;
382                        state.sequence_length
383                    } else {
384                        seq_len + 1
385                    }
386                };
387                let new_handle = Arc::new(input_handle.with_sequence_length(new_seq_len));
388
389                return Ok(DecodeOutput::new(logits_ref, new_handle));
390            } else if let Some(Err(e)) = cuda_result {
391                // CUDA runner failed — log and fall through to candle path
392                tracing::debug!("CUDA decode runner failed, falling back to candle: {e}");
393            }
394        }
395
396        // Fallback: standard candle decode path
397        let input_tensor = self.tokens_to_tensor(&tokens)?;
398
399        let logits = self
400            .model
401            .forward_decode(&input_tensor, seq_len, &req_cache_id)
402            .map_err(|e| FerrumError::model(format!("Qwen3 decode failed: {}", e)))?;
403
404        if std::env::var("FERRUM_LOG_TOKENS").unwrap_or_default() == "1" || seq_len == 13 {
405            if let Ok(flat) = logits.flatten_all() {
406                if let Ok(vals) = flat.to_vec1::<half::f16>() {
407                    let mut indexed: Vec<(usize, f32)> = vals
408                        .iter()
409                        .enumerate()
410                        .map(|(i, v)| (i, v.to_f32()))
411                        .collect();
412                    indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
413                    let top5: Vec<String> = indexed[..5]
414                        .iter()
415                        .map(|(i, v)| format!("{}:{:.2}", i, v))
416                        .collect();
417                    tracing::info!("[CANDLE] pos={} top5=[{}]", seq_len, top5.join(", "));
418                }
419            }
420        }
421
422        let logits_ref = self.wrap_tensor(logits);
423
424        let new_seq_len = {
425            let mut states = self.states.lock();
426            if let Some(state) = states.get_mut(&req_cache_id) {
427                state.sequence_length += tokens.len();
428                state.sequence_length
429            } else {
430                seq_len + tokens.len()
431            }
432        };
433        let new_handle = Arc::new(input_handle.with_sequence_length(new_seq_len));
434
435        Ok(DecodeOutput::new(logits_ref, new_handle))
436    }
437
438    #[cfg(feature = "cuda")]
439    async fn batch_decode(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>> {
440        use ferrum_cuda_kernels::cuda_decode::BatchDecodeRequest;
441
442        // Fallback to per-request when: batch=1, no CUDA runner, or paged KV
443        // (batch_decode_step doesn't support paged KV yet)
444        let paged_kv = std::env::var("FERRUM_PAGED_KV").map_or(false, |v| v == "1");
445        if inputs.len() <= 1 || !self.ensure_cuda_runner() || paged_kv {
446            let mut outputs = Vec::with_capacity(inputs.len());
447            for input in inputs {
448                outputs.push(self.decode(input).await?);
449            }
450            return Ok(outputs);
451        }
452
453        // Extract cache_ids, positions, and tokens for all inputs
454        let mut requests = Vec::with_capacity(inputs.len());
455        let mut cache_ids = Vec::with_capacity(inputs.len());
456        let mut seq_lens = Vec::with_capacity(inputs.len());
457
458        for input in inputs {
459            let handle = input
460                .kv_cache
461                .as_any()
462                .downcast_ref::<common::GenericKvCacheHandle>()
463                .ok_or_else(|| FerrumError::model("Invalid KV cache handle"))?;
464            let cache_id = handle.request_cache_id().to_string();
465            let seq_len = {
466                let mut states = self.states.lock();
467                if let Some(s) = states.get(&cache_id) {
468                    s.sequence_length
469                } else {
470                    let len = handle.block_table().sequence_length;
471                    states.insert(
472                        cache_id.clone(),
473                        Qwen3CacheState {
474                            sequence_length: len,
475                        },
476                    );
477                    len
478                }
479            };
480            let tokens = self.tensor_to_tokens(&input.input_ids)?;
481            if tokens.len() != 1 {
482                return Err(FerrumError::model(
483                    "batch_decode requires single-token inputs",
484                ));
485            }
486
487            self.ensure_runner_kv_cache(&cache_id, seq_len)?;
488            cache_ids.push(cache_id);
489            seq_lens.push(seq_len);
490            requests.push((tokens[0], seq_len));
491        }
492
493        // Build BatchDecodeRequests
494        let batch_requests: Vec<BatchDecodeRequest<'_>> = requests
495            .iter()
496            .zip(cache_ids.iter())
497            .map(|((token_id, position), cache_key)| BatchDecodeRequest {
498                token_id: *token_id,
499                position: *position,
500                cache_key: cache_key.as_str(),
501            })
502            .collect();
503
504        // Call runner.batch_decode_step
505        let logits_slice = {
506            let mut runner = self.cuda_runner.lock();
507            let runner = runner
508                .as_mut()
509                .ok_or_else(|| FerrumError::model("CUDA runner not initialized"))?;
510            runner
511                .batch_decode_step(&batch_requests)
512                .map_err(|e| FerrumError::model(format!("batch_decode_step: {e}")))?
513        };
514
515        // Split [B * vocab] logits into per-request outputs
516        let batch = inputs.len();
517        let vocab = self.info.vocab_size;
518        let cuda_dev = self
519            .model
520            .candle_device()
521            .as_cuda_device()
522            .map_err(|e| FerrumError::model(format!("Not CUDA: {e}")))?;
523
524        let storage =
525            candle_core::cuda_backend::CudaStorage::wrap_cuda_slice(logits_slice, cuda_dev.clone());
526        let logits_tensor = candle_core::Tensor::from_storage(
527            candle_core::Storage::Cuda(storage),
528            (batch, 1, vocab),
529            candle_core::op::BackpropOp::none(),
530            false,
531        );
532
533        let mut outputs = Vec::with_capacity(batch);
534        for (i, input) in inputs.iter().enumerate() {
535            let item_logits = logits_tensor
536                .narrow(0, i, 1)
537                .map_err(|e| FerrumError::model(format!("logits narrow: {e}")))?;
538            let logits_ref = self.wrap_tensor(item_logits);
539
540            let handle = input
541                .kv_cache
542                .as_any()
543                .downcast_ref::<common::GenericKvCacheHandle>()
544                .unwrap();
545            let new_seq_len = {
546                let mut states = self.states.lock();
547                if let Some(state) = states.get_mut(&cache_ids[i]) {
548                    state.sequence_length += 1;
549                    state.sequence_length
550                } else {
551                    seq_lens[i] + 1
552                }
553            };
554            let new_handle = Arc::new(handle.with_sequence_length(new_seq_len));
555            outputs.push(DecodeOutput::new(logits_ref, new_handle));
556        }
557
558        Ok(outputs)
559    }
560
561    #[cfg(not(feature = "cuda"))]
562    async fn batch_decode(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>> {
563        let mut outputs = Vec::with_capacity(inputs.len());
564        for input in inputs {
565            outputs.push(self.decode(input).await?);
566        }
567        Ok(outputs)
568    }
569
570    fn capabilities(&self) -> ExecutorCapabilities {
571        ExecutorCapabilities {
572            max_batch_size: 256,
573            max_sequence_length: self.info.max_sequence_length,
574            attention_mechanisms: vec![AttentionType::MultiHead, AttentionType::GroupedQuery],
575            supports_dynamic_batching: true,
576            supports_continuous_batching: true,
577            supports_speculative_decoding: false,
578            supports_tensor_parallelism: false,
579            supports_pipeline_parallelism: false,
580            supported_dtypes: vec![DataType::FP16, DataType::FP32, DataType::BF16],
581            supported_devices: vec![self.info.device.clone()],
582            memory_requirements: MemoryRequirements {
583                parameter_memory: (self.info.num_parameters * 2) as u64,
584                activation_memory_per_token: self.info.hidden_size * 4,
585                kv_cache_memory_per_token: self.info.hidden_size * 2,
586                overhead_memory: 256 * 1024 * 1024,
587            },
588        }
589    }
590
591    fn release_cache(&self, cache_id: &str) {
592        self.release_sequence(cache_id);
593    }
594
595    fn status(&self) -> ExecutorStatus {
596        common::default_executor_status()
597    }
598}
599
600// Qwen3KvCacheHandle replaced by common::GenericKvCacheHandle