use async_trait::async_trait;
use candle_core::{Device as CandleDevice, Tensor};
use ferrum_interfaces::{
    model_executor::{
        AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorStatus,
        MemoryRequirements, PrefillInput, PrefillOutput,
    },
    KvCacheHandle, ModelExecutor, TensorRef,
};
use ferrum_types::{DataType, FerrumError, ModelInfo, Result};
use parking_lot::Mutex;
use std::{
    collections::HashMap,
    sync::{
        atomic::{AtomicU64, Ordering},
        Arc,
    },
};
use tracing::{debug, info};

use crate::architectures::qwen3::Qwen3ModelWrapper;
use crate::executor::common;

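/// Per-sequence bookkeeping: how many tokens are already resident in the KV
/// cache for a given cache id. Kept separate from the cache itself so it can
/// sit behind a cheap `Mutex<HashMap<..>>`.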
#[derive(Debug, Clone)]
struct Qwen3CacheState {
    sequence_length: usize,
}

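/// `ModelExecutor` implementation for Qwen3 models. Prefill and the fallback
/// decode path run through candle; when the `cuda` feature is enabled,
/// single-token decode steps can be routed through a graph-captured CUDA
/// decode runner instead.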
pub struct Qwen3ModelExecutor {
    model: Arc<Qwen3ModelWrapper>,
    info: ModelInfo,
    /// Per-sequence cache state, keyed by cache id.
    states: Mutex<HashMap<String, Qwen3CacheState>>,
    next_cache_id: AtomicU64,
    /// Lazily initialized CUDA decode runner (bypasses candle for decode).
    #[cfg(feature = "cuda")]
    cuda_runner: Mutex<Option<ferrum_cuda_kernels::cuda_decode::CudaDecodeRunner>>,
    /// Ensures runner initialization is attempted at most once.
    #[cfg(feature = "cuda")]
    cuda_runner_init_attempted: std::sync::atomic::AtomicBool,
}

impl Qwen3ModelExecutor {
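    /// Creates an executor for an already-loaded Qwen3 model.
    ///
    /// Usage sketch (the surrounding setup is illustrative, not a crate API):
    ///
    /// ```ignore
    /// let executor = Qwen3ModelExecutor::new(wrapper, info);
    /// let prefill = executor.prefill(&prefill_input).await?;
    /// let decode = executor.decode(&decode_input).await?;
    /// ```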
    pub fn new(model: Qwen3ModelWrapper, info: ModelInfo) -> Self {
        info!("Created Qwen3ModelExecutor for: {}", info.model_id);

        Self {
            model: Arc::new(model),
            info,
            states: Mutex::new(HashMap::new()),
            next_cache_id: AtomicU64::new(1),
            #[cfg(feature = "cuda")]
            cuda_runner: Mutex::new(None),
            #[cfg(feature = "cuda")]
            cuda_runner_init_attempted: std::sync::atomic::AtomicBool::new(false),
        }
    }

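    /// Lazily initializes the CUDA decode runner, returning `true` when a
    /// runner is available. Initialization is attempted at most once; on
    /// failure the executor keeps using the candle decode path. Setting
    /// `FERRUM_DISABLE_CUDA_RUNNER=1` disables the runner entirely.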
    #[cfg(feature = "cuda")]
    fn ensure_cuda_runner(&self) -> bool {
        if std::env::var("FERRUM_DISABLE_CUDA_RUNNER").unwrap_or_default() == "1" {
            return false;
        }
        if self.cuda_runner.lock().is_some() {
            return true;
        }
        if self
            .cuda_runner_init_attempted
            .swap(true, Ordering::Relaxed)
        {
            return false;
        }
        if !matches!(self.model.candle_device(), CandleDevice::Cuda(_)) {
            return false;
        }
        match self.model.create_decode_runner() {
            Ok(runner) => {
                info!("CUDA decode runner initialized; decode will bypass candle");
                *self.cuda_runner.lock() = Some(runner);
                true
            }
            Err(e) => {
                tracing::warn!("CUDA decode runner init failed, using candle path: {e}");
                false
            }
        }
    }

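    /// Releases all state for a finished sequence: the executor-side
    /// bookkeeping, the candle KV cache, and (when present) the CUDA runner's
    /// KV cache.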
    pub fn release_sequence(&self, cache_id: &str) {
        self.states.lock().remove(cache_id);
        self.model.release_cache(cache_id);
        #[cfg(feature = "cuda")]
        if let Some(ref mut runner) = *self.cuda_runner.lock() {
            runner.release_kv_cache(cache_id);
        }
        debug!("Released KV cache for sequence: {}", cache_id);
    }

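    /// Ensures the CUDA runner has a KV cache for `cache_id`, migrating the
    /// candle-side cache into the runner on first use. Cache ids of the form
    /// `<base>-clone-<n>` fall back to exporting the base sequence's cache.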
    #[cfg(feature = "cuda")]
    fn ensure_runner_kv_cache(&self, cache_id: &str, _seq_len: usize) -> Result<()> {
        use candle_core::Storage;

        let mut runner_guard = self.cuda_runner.lock();
        let runner = match runner_guard.as_mut() {
            Some(r) => r,
            None => return Ok(()),
        };

        if runner.has_kv_cache(cache_id) {
            return Ok(());
        }

        // Export the candle-side KV cache; for cloned sequences
        // (`<base>-clone-<n>`) fall back to the base sequence's cache.
        let kv_data_tensors = self
            .model
            .export_kv_cache(cache_id)
            .or_else(|| {
                cache_id
                    .rfind("-clone-")
                    .and_then(|pos| self.model.export_kv_cache(&cache_id[..pos]))
            })
            .ok_or_else(|| {
                FerrumError::model(format!("No candle KV cache to export for: {cache_id}"))
            })?;

        if kv_data_tensors.is_empty() {
            return Err(FerrumError::model("Empty KV cache export"));
        }
        let prefill_len = kv_data_tensors[0].2;
        let max_len = kv_data_tensors[0].3;

        // Pull the raw CUDA slices out of the exported tensors, one (K, V)
        // pair per layer.
        let mut kv_slices = Vec::new();
        for (k_tensor, v_tensor, _len, _max) in &kv_data_tensors {
            let (k_s, _) = k_tensor.storage_and_layout();
            let (v_s, _) = v_tensor.storage_and_layout();
            let k_cuda = match &*k_s {
                Storage::Cuda(cs) => cs
                    .as_cuda_slice::<half::f16>()
                    .map_err(|e| FerrumError::model(format!("KV slice extract: {e}")))?
                    .clone(),
                _ => return Err(FerrumError::model("KV cache not on CUDA")),
            };
            let v_cuda = match &*v_s {
                Storage::Cuda(cs) => cs
                    .as_cuda_slice::<half::f16>()
                    .map_err(|e| FerrumError::model(format!("KV slice extract: {e}")))?
                    .clone(),
                _ => return Err(FerrumError::model("KV cache not on CUDA")),
            };
            drop(k_s);
            drop(v_s);
            kv_slices.push((k_cuda, v_cuda));
        }

        runner
            .init_kv_cache(cache_id, kv_slices, prefill_len, max_len)
            .map_err(|e| FerrumError::model(format!("CUDA runner KV init failed: {e}")))?;

        // Seed executor-side state for clones created directly in the runner;
        // the base sequence's tracked length is the best available estimate.
        let mut states = self.states.lock();
        if !states.contains_key(cache_id) {
            let base_len = cache_id
                .rfind("-clone-")
                .and_then(|pos| states.get(&cache_id[..pos]).map(|s| s.sequence_length))
                .unwrap_or(prefill_len);
            states.insert(
                cache_id.to_string(),
                Qwen3CacheState {
                    sequence_length: base_len,
                },
            );
        }
        drop(states);

        debug!("Migrated KV cache to CUDA runner for sequence: {cache_id}");
        Ok(())
    }

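    // Thin wrappers around the shared helpers in `executor::common`.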
    fn tensor_to_tokens(&self, tensor: &TensorRef) -> Result<Vec<u32>> {
        common::tensor_to_tokens(tensor)
    }

    fn tokens_to_tensor(&self, tokens: &[u32]) -> Result<Tensor> {
        common::tokens_to_tensor(tokens, self.model.candle_device())
    }

    fn wrap_tensor(&self, tensor: Tensor) -> TensorRef {
        common::wrap_tensor(tensor)
    }
}

#[async_trait]
impl ModelExecutor for Qwen3ModelExecutor {
    fn info(&self) -> &ModelInfo {
        &self.info
    }

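    /// Runs the prefill pass: allocates a fresh cache id, feeds the prompt
    /// through the model, and returns the logits (normalized to a rank-3
    /// `[batch, seq, vocab]` shape) plus a handle to the new KV cache.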
    async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
        debug!(
            "Qwen3 Prefill: batch={}, seq_len={}",
            input.batch_size(),
            input.sequence_length()
        );

        let tokens = self.tensor_to_tokens(&input.input_ids)?;
        if tokens.is_empty() {
            return Err(FerrumError::model("Prefill input is empty"));
        }

        let cache_id = format!(
            "qwen3-cache-{}",
            self.next_cache_id.fetch_add(1, Ordering::Relaxed)
        );

        let input_tensor = self.tokens_to_tensor(&tokens)?;

        let logits = self
            .model
            .forward_prefill(&input_tensor, &cache_id)
            .map_err(|e| FerrumError::model(format!("Qwen3 prefill failed: {}", e)))?;

        // Normalize logits to rank 3 ([batch, seq, vocab]).
        let logits = match logits.dims().len() {
            2 => logits
                .unsqueeze(1)
                .map_err(|e| FerrumError::model(format!("Unsqueeze logits failed: {}", e)))?,
            3 => logits,
            rank => {
                return Err(FerrumError::model(format!(
                    "Unexpected Qwen3 prefill logits rank: {} (shape {:?})",
                    rank,
                    logits.dims()
                )))
            }
        };

        let logits_ref = self.wrap_tensor(logits);

        let cfg = self.model.config();
        let kv_handle = Arc::new(common::GenericKvCacheHandle::new(
            cfg.num_hidden_layers,
            cfg.num_attention_heads,
            cfg.head_dim,
            self.model.device().clone(),
            tokens.len(),
            cache_id.clone(),
        ));

        self.states.lock().insert(
            cache_id,
            Qwen3CacheState {
                sequence_length: tokens.len(),
            },
        );

        Ok(PrefillOutput::new(logits_ref, kv_handle))
    }

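    /// Runs one decode step. With the `cuda` feature enabled and a runner
    /// available, single-token steps go through the graph-captured CUDA path;
    /// everything else falls back to candle.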
    async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput> {
        debug!("Qwen3 Decode: batch={}", input.batch_size());

        let input_handle = input
            .kv_cache
            .as_any()
            .downcast_ref::<common::GenericKvCacheHandle>()
            .ok_or_else(|| FerrumError::model("Invalid KV cache handle type for Qwen3 executor"))?;
        let req_cache_id = input_handle.request_cache_id().to_string();

        // Resolve the current sequence length, seeding executor state from the
        // handle's block table if this cache id has not been seen yet.
        let seq_len = {
            let mut states = self.states.lock();
            if let Some(s) = states.get(&req_cache_id) {
                s.sequence_length
            } else {
                let len = input_handle.block_table().sequence_length;
                states.insert(
                    req_cache_id.clone(),
                    Qwen3CacheState {
                        sequence_length: len,
                    },
                );
                len
            }
        };

        let tokens = self.tensor_to_tokens(&input.input_ids)?;
        if tokens.is_empty() {
            return Err(FerrumError::model("Decode input is empty"));
        }

        // Optional top-5 logit logging, enabled with FERRUM_LOG_TOKENS=1;
        // shared by the CUDA and candle paths below.
        fn log_top5(logits: &Tensor, seq_len: usize, tag: &str) {
            if std::env::var("FERRUM_LOG_TOKENS").unwrap_or_default() != "1" {
                return;
            }
            if let Ok(flat) = logits.flatten_all() {
                if let Ok(vals) = flat.to_vec1::<half::f16>() {
                    let mut indexed: Vec<(usize, f32)> = vals
                        .iter()
                        .enumerate()
                        .map(|(i, v)| (i, v.to_f32()))
                        .collect();
                    indexed.sort_by(|a, b| b.1.total_cmp(&a.1));
                    let top5: Vec<String> = indexed
                        .iter()
                        .take(5)
                        .map(|(i, v)| format!("{}:{:.2}", i, v))
                        .collect();
                    tracing::info!("[{}] pos={} top5=[{}]", tag, seq_len, top5.join(", "));
                }
            }
        }

        // Fast path: single-token decode through the CUDA runner.
        #[cfg(feature = "cuda")]
        if tokens.len() == 1 && self.ensure_cuda_runner() {
            let token_id = tokens[0];

            self.ensure_runner_kv_cache(&req_cache_id, seq_len)?;

            let cuda_result = {
                let mut runner = self.cuda_runner.lock();
                if let Some(ref mut runner) = *runner {
                    Some(runner.decode_step_graphed(token_id, seq_len, &req_cache_id))
                } else {
                    None
                }
            };
            if let Some(Ok(logits_slice)) = cuda_result {
                // Wrap the raw CUDA slice in a candle tensor without copying.
                let cuda_dev = self
                    .model
                    .candle_device()
                    .as_cuda_device()
                    .map_err(|e| FerrumError::model(format!("Not CUDA device: {e}")))?;
                let storage = candle_core::cuda_backend::CudaStorage::wrap_cuda_slice(
                    logits_slice,
                    cuda_dev.clone(),
                );
                let logits_tensor = candle_core::Tensor::from_storage(
                    candle_core::Storage::Cuda(storage),
                    (1, 1, self.info.vocab_size),
                    candle_core::op::BackpropOp::none(),
                    false,
                );

                log_top5(&logits_tensor, seq_len, "CUDA");

                let logits_ref = self.wrap_tensor(logits_tensor);

                let new_seq_len = {
                    let mut states = self.states.lock();
                    if let Some(state) = states.get_mut(&req_cache_id) {
                        state.sequence_length += 1;
                        state.sequence_length
                    } else {
                        seq_len + 1
                    }
                };
                let new_handle = Arc::new(input_handle.with_sequence_length(new_seq_len));

                return Ok(DecodeOutput::new(logits_ref, new_handle));
            } else if let Some(Err(e)) = cuda_result {
                tracing::debug!("CUDA decode runner failed, falling back to candle: {e}");
            }
        }

        // Fallback: run the decode step through candle.
        let input_tensor = self.tokens_to_tensor(&tokens)?;

        let logits = self
            .model
            .forward_decode(&input_tensor, seq_len, &req_cache_id)
            .map_err(|e| FerrumError::model(format!("Qwen3 decode failed: {}", e)))?;

        log_top5(&logits, seq_len, "CANDLE");

        let logits_ref = self.wrap_tensor(logits);

        let new_seq_len = {
            let mut states = self.states.lock();
            if let Some(state) = states.get_mut(&req_cache_id) {
                state.sequence_length += tokens.len();
                state.sequence_length
            } else {
                seq_len + tokens.len()
            }
        };
        let new_handle = Arc::new(input_handle.with_sequence_length(new_seq_len));

        Ok(DecodeOutput::new(logits_ref, new_handle))
    }

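    /// Batched decode through the CUDA runner. Falls back to sequential
    /// `decode` calls when batching does not apply: a single input, no runner
    /// available, or paged-KV mode (FERRUM_PAGED_KV=1).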
    #[cfg(feature = "cuda")]
    async fn batch_decode(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>> {
        use ferrum_cuda_kernels::cuda_decode::BatchDecodeRequest;

        let paged_kv = std::env::var("FERRUM_PAGED_KV").map_or(false, |v| v == "1");
        if inputs.len() <= 1 || !self.ensure_cuda_runner() || paged_kv {
            let mut outputs = Vec::with_capacity(inputs.len());
            for input in inputs {
                outputs.push(self.decode(input).await?);
            }
            return Ok(outputs);
        }

        // First pass: validate inputs and make sure every sequence's KV cache
        // lives in the CUDA runner.
        let mut requests = Vec::with_capacity(inputs.len());
        let mut cache_ids = Vec::with_capacity(inputs.len());
        let mut seq_lens = Vec::with_capacity(inputs.len());

        for input in inputs {
            let handle = input
                .kv_cache
                .as_any()
                .downcast_ref::<common::GenericKvCacheHandle>()
                .ok_or_else(|| FerrumError::model("Invalid KV cache handle"))?;
            let cache_id = handle.request_cache_id().to_string();
            let seq_len = {
                let mut states = self.states.lock();
                if let Some(s) = states.get(&cache_id) {
                    s.sequence_length
                } else {
                    let len = handle.block_table().sequence_length;
                    states.insert(
                        cache_id.clone(),
                        Qwen3CacheState {
                            sequence_length: len,
                        },
                    );
                    len
                }
            };
            let tokens = self.tensor_to_tokens(&input.input_ids)?;
            if tokens.len() != 1 {
                return Err(FerrumError::model(
                    "batch_decode requires single-token inputs",
                ));
            }

            self.ensure_runner_kv_cache(&cache_id, seq_len)?;
            cache_ids.push(cache_id);
            seq_lens.push(seq_len);
            requests.push((tokens[0], seq_len));
        }

        let batch_requests: Vec<BatchDecodeRequest<'_>> = requests
            .iter()
            .zip(cache_ids.iter())
            .map(|((token_id, position), cache_key)| BatchDecodeRequest {
                token_id: *token_id,
                position: *position,
                cache_key: cache_key.as_str(),
            })
            .collect();

        // One fused decode step for the whole batch.
        let logits_slice = {
            let mut runner = self.cuda_runner.lock();
            let runner = runner
                .as_mut()
                .ok_or_else(|| FerrumError::model("CUDA runner not initialized"))?;
            runner
                .batch_decode_step(&batch_requests)
                .map_err(|e| FerrumError::model(format!("batch_decode_step: {e}")))?
        };

        // Wrap the packed [batch, 1, vocab] logits without copying.
        let batch = inputs.len();
        let vocab = self.info.vocab_size;
        let cuda_dev = self
            .model
            .candle_device()
            .as_cuda_device()
            .map_err(|e| FerrumError::model(format!("Not CUDA: {e}")))?;

        let storage =
            candle_core::cuda_backend::CudaStorage::wrap_cuda_slice(logits_slice, cuda_dev.clone());
        let logits_tensor = candle_core::Tensor::from_storage(
            candle_core::Storage::Cuda(storage),
            (batch, 1, vocab),
            candle_core::op::BackpropOp::none(),
            false,
        );

        // Second pass: slice out per-request logits and advance sequence state.
        let mut outputs = Vec::with_capacity(batch);
        for (i, input) in inputs.iter().enumerate() {
            let item_logits = logits_tensor
                .narrow(0, i, 1)
                .map_err(|e| FerrumError::model(format!("logits narrow: {e}")))?;
            let logits_ref = self.wrap_tensor(item_logits);

            let handle = input
                .kv_cache
                .as_any()
                .downcast_ref::<common::GenericKvCacheHandle>()
                .expect("handle type validated in the first pass");
            let new_seq_len = {
                let mut states = self.states.lock();
                if let Some(state) = states.get_mut(&cache_ids[i]) {
                    state.sequence_length += 1;
                    state.sequence_length
                } else {
                    seq_lens[i] + 1
                }
            };
            let new_handle = Arc::new(handle.with_sequence_length(new_seq_len));
            outputs.push(DecodeOutput::new(logits_ref, new_handle));
        }

        Ok(outputs)
    }

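    /// Without the `cuda` feature there is no batched kernel; decode the
    /// inputs sequentially.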
    #[cfg(not(feature = "cuda"))]
    async fn batch_decode(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>> {
        let mut outputs = Vec::with_capacity(inputs.len());
        for input in inputs {
            outputs.push(self.decode(input).await?);
        }
        Ok(outputs)
    }

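    /// Static capability report; the memory figures are rough heuristics
    /// (e.g. two bytes per parameter, assuming fp16 weights).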
    fn capabilities(&self) -> ExecutorCapabilities {
        ExecutorCapabilities {
            max_batch_size: 256,
            max_sequence_length: self.info.max_sequence_length,
            attention_mechanisms: vec![AttentionType::MultiHead, AttentionType::GroupedQuery],
            supports_dynamic_batching: true,
            supports_continuous_batching: true,
            supports_speculative_decoding: false,
            supports_tensor_parallelism: false,
            supports_pipeline_parallelism: false,
            supported_dtypes: vec![DataType::FP16, DataType::FP32, DataType::BF16],
            supported_devices: vec![self.info.device.clone()],
            memory_requirements: MemoryRequirements {
                parameter_memory: (self.info.num_parameters * 2) as u64,
                activation_memory_per_token: self.info.hidden_size * 4,
                kv_cache_memory_per_token: self.info.hidden_size * 2,
                overhead_memory: 256 * 1024 * 1024,
            },
        }
    }

    fn release_cache(&self, cache_id: &str) {
        self.release_sequence(cache_id);
    }

    fn status(&self) -> ExecutorStatus {
        common::default_executor_status()
    }
}