Skip to main content

ferrum_models/executor/
candle_executor.rs

1//! Llama model executor using our custom Llama implementation.
2//!
3//! Uses GenericKvCacheHandle (like Qwen3) with per-request cache_id.
4//! Supports CUDA decode runner for GPU acceleration.
5
6use async_trait::async_trait;
7use candle_core::Tensor;
8use ferrum_interfaces::{
9    model_executor::{
10        AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorStatus,
11        MemoryRequirements, PrefillInput, PrefillOutput,
12    },
13    KvCacheHandle, ModelExecutor, TensorRef,
14};
15use ferrum_types::{DataType, Device, FerrumError, ModelInfo, Result};
16use std::collections::HashMap;
17use std::sync::{
18    atomic::{AtomicU64, Ordering},
19    Arc,
20};
21use tracing::{debug, info};
22
23use super::common::{self, GenericKvCacheHandle};
24use crate::architectures::llama::LlamaModelWrapper;
25use crate::tensor_wrapper::CandleTensorWrapper;
26use parking_lot::Mutex;
27
/// Per-request bookkeeping kept by the executor, keyed by cache id.
struct LlamaCacheState {
    /// Number of tokens recorded for the request so far: set to the prompt
    /// length by `prefill` and advanced by each `decode` step.
    sequence_length: usize,
}
31
/// Llama model executor
///
/// Wraps a [`LlamaModelWrapper`] and tracks one [`LlamaCacheState`] per
/// request cache id. With the `cuda` feature enabled it additionally owns a
/// single-GPU decode runner and a tensor-parallel decode group, both of which
/// start out as `None` and are created on demand by the `ensure_*` helpers.
pub struct CandleModelExecutor {
    /// The wrapped Llama model used for prefill and the candle decode fallback.
    model: Arc<LlamaModelWrapper>,
    /// Static model metadata returned by `info()` and used by `capabilities()`.
    info: ModelInfo,
    /// Sequence-length bookkeeping, keyed by cache id.
    states: Mutex<HashMap<String, LlamaCacheState>>,
    /// Counter used to mint `llama-cache-<n>` ids in `prefill`; starts at 1.
    next_cache_id: AtomicU64,
    /// Single-GPU CUDA decode runner; `None` until `ensure_cuda_runner` succeeds.
    #[cfg(feature = "cuda")]
    cuda_runner: Mutex<Option<ferrum_cuda_kernels::cuda_decode::CudaDecodeRunner>>,
    /// Tensor-parallel decode group; `None` until `ensure_tp_group` succeeds.
    #[cfg(feature = "cuda")]
    tp_group: Mutex<Option<ferrum_cuda_kernels::tp_decode::TpDecodeGroup>>,
}
43
44impl CandleModelExecutor {
45    pub fn new(model: LlamaModelWrapper, info: ModelInfo) -> Self {
46        info!("Created CandleModelExecutor (Llama) for: {}", info.model_id);
47        Self {
48            model: Arc::new(model),
49            info,
50            states: Mutex::new(HashMap::new()),
51            next_cache_id: AtomicU64::new(1),
52            #[cfg(feature = "cuda")]
53            cuda_runner: Mutex::new(None),
54            #[cfg(feature = "cuda")]
55            tp_group: Mutex::new(None),
56        }
57    }
58
59    fn tokens_to_tensor(&self, token_ids: &[u32]) -> Result<Tensor> {
60        Tensor::new(token_ids, self.model.device())
61            .map_err(|e| FerrumError::model(format!("tensor: {e}")))?
62            .unsqueeze(0)
63            .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))
64    }
65
66    fn wrap_tensor(&self, tensor: Tensor) -> TensorRef {
67        Arc::new(CandleTensorWrapper::new(tensor))
68    }
69
    /// Convert a `TensorRef` of token ids into a `Vec<u32>` by delegating to
    /// `common::tensor_to_tokens`.
    fn tensor_to_tokens(&self, tensor: &TensorRef) -> Result<Vec<u32>> {
        common::tensor_to_tokens(tensor)
    }
73
74    /// Get TP size: FERRUM_TP env overrides, otherwise auto-detect GPU count.
75    /// FERRUM_TP=0 or FERRUM_TP=1 explicitly disables TP.
76    #[cfg(feature = "cuda")]
77    fn tp_size() -> usize {
78        if let Ok(v) = std::env::var("FERRUM_TP") {
79            if let Ok(n) = v.parse::<usize>() {
80                return n;
81            }
82        }
83        // Auto-detect: use all available GPUs
84        candle_core::cuda_backend::cudarc::driver::CudaContext::device_count()
85            .map(|n| n as usize)
86            .unwrap_or(1)
87    }
88
89    #[cfg(not(feature = "cuda"))]
90    fn tp_size() -> usize {
91        0
92    }
93
94    /// Initialize TP decode group if FERRUM_TP > 1.
95    #[cfg(feature = "cuda")]
96    fn ensure_tp_group(&self) -> bool {
97        if self.tp_group.lock().is_some() {
98            return true;
99        }
100        let tp = Self::tp_size();
101        if tp <= 1 {
102            return false;
103        }
104
105        info!("Initializing tensor parallel group: tp_size={tp}");
106
107        let model_dir = match self.model.model_dir.as_ref() {
108            Some(d) => d.clone(),
109            None => {
110                tracing::warn!("TP requires model_dir");
111                return false;
112            }
113        };
114
115        // Load sharded weights for each rank
116        let loader = crate::loader::SafeTensorsLoader::new(&model_dir);
117        let vb = match loader.load_varbuilder(self.model.device(), self.model.dtype()) {
118            Ok(v) => v,
119            Err(e) => {
120                tracing::warn!("TP weight load failed: {e}");
121                return false;
122            }
123        };
124
125        let cfg = self.model.config();
126        let tp_cfg = crate::loader::tp_weight_loader::TpWeightConfig {
127            num_hidden_layers: cfg.num_hidden_layers,
128            hidden_size: cfg.hidden_size,
129            intermediate_size: cfg.intermediate_size,
130            num_attention_heads: cfg.num_attention_heads,
131            num_kv_heads: cfg.num_key_value_heads,
132            head_dim: cfg.head_dim,
133            vocab_size: cfg.vocab_size,
134            max_seq_len: cfg.max_position_embeddings,
135            rope_theta: cfg.rope_theta as f64,
136            // Auto-detect Q/K norm by probing safetensors
137            has_qk_norm: crate::loader::SafeTensorsLoader::new(&model_dir)
138                .load_varbuilder(self.model.device(), self.model.dtype())
139                .map(|vb| {
140                    vb.get(cfg.head_dim, "model.layers.0.self_attn.q_norm.weight")
141                        .is_ok()
142                })
143                .unwrap_or(false),
144            tp_size: tp,
145            rank: 0,
146        };
147        info!(
148            "TP config: has_qk_norm={}, head_dim={}, nq={}, nkv={}",
149            tp_cfg.has_qk_norm, tp_cfg.head_dim, tp_cfg.num_attention_heads, tp_cfg.num_kv_heads
150        );
151
152        // Each rank loads weights in its own thread (correct CUDA context).
153        // After loading, we sync replicated weights from rank 0 to other ranks
154        // to fix candle VarBuilder producing different BF16→F16 conversions
155        // when loaded independently on different GPUs.
156        type RankResult = candle_core::Result<(
157            ferrum_cuda_kernels::cuda_decode::CudaDecodeRunner,
158            std::sync::Arc<candle_core::cuda_backend::cudarc::driver::CudaStream>,
159        )>;
160
161        let mut handles: Vec<std::thread::JoinHandle<RankResult>> = Vec::with_capacity(tp);
162        let dtype = self.model.dtype();
163        for rank in 0..tp {
164            let mut rank_cfg = tp_cfg.clone();
165            rank_cfg.rank = rank;
166            let model_dir = model_dir.clone();
167
168            handles.push(std::thread::spawn(move || {
169                let device = candle_core::Device::new_cuda(rank)?;
170                let loader = crate::loader::SafeTensorsLoader::new(&model_dir);
171                let vb = loader
172                    .load_varbuilder(&device, dtype)
173                    .map_err(|e| candle_core::Error::Msg(format!("VB rank {rank}: {e}")))?;
174
175                let (weights, dims, stream) =
176                    crate::loader::tp_weight_loader::load_sharded_weights(&vb, &rank_cfg, &device)
177                        .map_err(|e| candle_core::Error::Msg(format!("shard {rank}: {e}")))?;
178
179                let cuda_dev = device.as_cuda_device()?.clone();
180
181                let runner = ferrum_cuda_kernels::cuda_decode::CudaDecodeRunner::new(
182                    weights,
183                    dims,
184                    cuda_dev,
185                    stream.clone(),
186                )?;
187                Ok((runner, stream))
188            }));
189        }
190
191        let mut runners = Vec::with_capacity(tp);
192        let mut nccl_streams = Vec::with_capacity(tp);
193        for (rank, handle) in handles.into_iter().enumerate() {
194            match handle.join() {
195                Ok(Ok((runner, stream))) => {
196                    runners.push(runner);
197                    nccl_streams.push(stream);
198                }
199                Ok(Err(e)) => {
200                    tracing::warn!("TP rank {rank} failed: {e}");
201                    return false;
202                }
203                Err(_) => {
204                    tracing::warn!("TP rank {rank} panicked");
205                    return false;
206                }
207            }
208        }
209
210        // Sync replicated weights: copy from rank 0 to all other ranks.
211        // Fixes candle VarBuilder BF16→F16 divergence across GPUs.
212        for rank in 1..tp {
213            let (src, dst) = if rank == 1 {
214                let (first, rest) = runners.split_at_mut(1);
215                (&first[0], &mut rest[0])
216            } else {
217                let (first, rest) = runners.split_at_mut(1);
218                (&first[0], &mut rest[rank - 1])
219            };
220            if let Err(e) = dst.sync_replicated_weights_from(src) {
221                tracing::warn!("TP weight sync rank 0→{rank} failed: {e}");
222                return false;
223            }
224            info!("Synced replicated weights: rank 0 → rank {rank}");
225        }
226
227        // Retain primary context for all GPUs on main thread.
228        // Init threads retained their own, but those may not carry over.
229        let _main_thread_devices: Vec<_> = (0..tp)
230            .filter_map(|r| candle_core::Device::new_cuda(r).ok())
231            .collect();
232
233        // Init NCCL using ncclCommInitAll (single thread, no deadlock)
234        let nccl_ranks = match ferrum_cuda_kernels::nccl_comm::NcclRank::init_all(nccl_streams) {
235            Ok(ranks) => ranks,
236            Err(e) => {
237                tracing::warn!("NCCL init_all failed: {e}");
238                return false;
239            }
240        };
241
242        match ferrum_cuda_kernels::tp_decode::TpDecodeGroup::new(runners, nccl_ranks) {
243            Ok(group) => {
244                info!("Tensor parallel group initialized: {tp} GPUs");
245                *self.tp_group.lock() = Some(group);
246                true
247            }
248            Err(e) => {
249                tracing::warn!("TpDecodeGroup init: {e}");
250                false
251            }
252        }
253    }
254
255    #[cfg(feature = "cuda")]
256    fn ensure_cuda_runner(&self) -> bool {
257        if self.cuda_runner.lock().is_some() {
258            return true;
259        }
260        if std::env::var("FERRUM_DISABLE_CUDA_RUNNER").map_or(false, |v| v == "1") {
261            return false;
262        }
263        match self.model.create_decode_runner() {
264            Ok(runner) => {
265                info!("CUDA decode runner initialized for Llama");
266                *self.cuda_runner.lock() = Some(runner);
267                true
268            }
269            Err(e) => {
270                tracing::warn!("CUDA runner init failed for Llama: {e}");
271                false
272            }
273        }
274    }
275
276    #[cfg(feature = "cuda")]
277    fn ensure_runner_kv_cache(&self, cache_id: &str, _seq_len: usize) -> Result<()> {
278        use candle_core::Storage;
279
280        let mut runner_guard = self.cuda_runner.lock();
281        let runner = match runner_guard.as_mut() {
282            Some(r) => r,
283            None => return Ok(()),
284        };
285        if runner.has_kv_cache(cache_id) {
286            return Ok(());
287        }
288
289        // Export from our custom model's per-request KV cache
290        let kv_data = self
291            .model
292            .export_kv_cache(cache_id)
293            .or_else(|| {
294                cache_id
295                    .rfind("-clone-")
296                    .and_then(|pos| self.model.export_kv_cache(&cache_id[..pos]))
297            })
298            .ok_or_else(|| FerrumError::model(format!("No KV cache to export for: {cache_id}")))?;
299
300        if kv_data.is_empty() {
301            return Err(FerrumError::model("Empty KV cache export"));
302        }
303        let prefill_len = kv_data[0].2;
304        let max_len = kv_data[0].3;
305
306        let mut kv_slices = Vec::new();
307        for (k_tensor, v_tensor, _len, _max) in &kv_data {
308            let (k_s, _) = k_tensor.storage_and_layout();
309            let (v_s, _) = v_tensor.storage_and_layout();
310            let k_cuda = match &*k_s {
311                Storage::Cuda(cs) => cs
312                    .as_cuda_slice::<half::f16>()
313                    .map_err(|e| FerrumError::model(format!("KV extract: {e}")))?
314                    .clone(),
315                _ => return Err(FerrumError::model("KV not on CUDA")),
316            };
317            let v_cuda = match &*v_s {
318                Storage::Cuda(cs) => cs
319                    .as_cuda_slice::<half::f16>()
320                    .map_err(|e| FerrumError::model(format!("KV extract: {e}")))?
321                    .clone(),
322                _ => return Err(FerrumError::model("KV not on CUDA")),
323            };
324            drop(k_s);
325            drop(v_s);
326            kv_slices.push((k_cuda, v_cuda));
327        }
328
329        runner
330            .init_kv_cache(cache_id, kv_slices, prefill_len, max_len)
331            .map_err(|e| FerrumError::model(format!("KV init: {e}")))?;
332
333        if !self.states.lock().contains_key(cache_id) {
334            self.states.lock().insert(
335                cache_id.to_string(),
336                LlamaCacheState {
337                    sequence_length: prefill_len,
338                },
339            );
340        }
341
342        debug!("Migrated Llama KV to CUDA runner: {cache_id}");
343        Ok(())
344    }
345
346    pub fn release_sequence(&self, cache_id: &str) {
347        self.states.lock().remove(cache_id);
348        self.model.release_cache(cache_id);
349        #[cfg(feature = "cuda")]
350        if let Some(ref mut runner) = *self.cuda_runner.lock() {
351            runner.release_kv_cache(cache_id);
352        }
353    }
354}
355
356#[async_trait]
357impl ModelExecutor for CandleModelExecutor {
    /// Return the model metadata this executor was constructed with.
    fn info(&self) -> &ModelInfo {
        &self.info
    }
361
362    async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
363        debug!("Llama Prefill: seq_len={}", input.sequence_length());
364
365        let tokens = self.tensor_to_tokens(&input.input_ids)?;
366        if tokens.is_empty() {
367            return Err(FerrumError::model("Empty input"));
368        }
369
370        let cache_id = format!(
371            "llama-cache-{}",
372            self.next_cache_id.fetch_add(1, Ordering::Relaxed)
373        );
374
375        let input_tensor = self.tokens_to_tensor(&tokens)?;
376        let logits = self.model.forward_prefill(&input_tensor, &cache_id)?;
377
378        let logits = match logits.dims().len() {
379            2 => logits
380                .unsqueeze(1)
381                .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?,
382            3 => logits,
383            d => return Err(FerrumError::model(format!("Unexpected logits rank: {d}"))),
384        };
385
386        let logits_ref = self.wrap_tensor(logits);
387        let cfg = self.model.config();
388        let handle = Arc::new(GenericKvCacheHandle::new(
389            cfg.num_hidden_layers,
390            cfg.num_attention_heads,
391            cfg.head_dim,
392            self.model.device().clone(),
393            tokens.len(),
394            cache_id.clone(),
395        ));
396
397        self.states.lock().insert(
398            cache_id,
399            LlamaCacheState {
400                sequence_length: tokens.len(),
401            },
402        );
403
404        Ok(PrefillOutput::new(logits_ref, handle))
405    }
406
    /// Run one decode step for the request identified by the KV handle.
    ///
    /// Paths are tried in this order:
    /// 1. (`cuda` feature) the tensor-parallel group, when `tp_size() > 1`
    ///    and `ensure_tp_group()` succeeds;
    /// 2. (`cuda` feature) the single-GPU CUDA decode runner, unless
    ///    `FERRUM_DISABLE_CUDA_RUNNER` is `"1"` or the runner cannot be built;
    /// 3. the candle `forward_decode` fallback.
    ///
    /// On success the recorded sequence length for the cache id is advanced
    /// and a new handle carrying the updated length is returned with the
    /// logits.
    ///
    /// # Errors
    /// Returns a model error when the handle is not a `GenericKvCacheHandle`,
    /// the input is empty, KV export/sharding/initialisation fails, or the
    /// selected decode path fails.
    ///
    /// NOTE(review): both accelerated paths feed only `tokens[0]` and advance
    /// the length by 1, whereas the fallback advances by `tokens.len()` —
    /// confirm callers always pass exactly one token.
    async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput> {
        let handle = input
            .kv_cache
            .as_any()
            .downcast_ref::<GenericKvCacheHandle>()
            .ok_or_else(|| FerrumError::model("Invalid KV handle for Llama"))?;
        let cache_id = handle.request_cache_id().to_string();

        // Current sequence length: taken from our bookkeeping, or seeded from
        // the handle's block table when this cache id has not been seen yet.
        let seq_len = {
            let mut states = self.states.lock();
            if let Some(s) = states.get(&cache_id) {
                s.sequence_length
            } else {
                let len = handle.block_table().sequence_length;
                states.insert(
                    cache_id.clone(),
                    LlamaCacheState {
                        sequence_length: len,
                    },
                );
                len
            }
        };

        let tokens = self.tensor_to_tokens(&input.input_ids)?;
        if tokens.is_empty() {
            return Err(FerrumError::model("Empty decode input"));
        }

        // Try TP path first (FERRUM_TP > 1)
        #[cfg(feature = "cuda")]
        {
            if Self::tp_size() > 1 && self.ensure_tp_group() {
                // Ensure KV cache exists on all TP ranks.
                // Export candle KV and shard by heads across ranks.
                {
                    let mut group = self.tp_group.lock();
                    if let Some(ref mut g) = *group {
                        if !g.has_kv_cache(&cache_id) {
                            let tp = g.world_size();
                            // Fall back to the id before the last "-clone-"
                            // marker when the exact id has no exported cache.
                            let kv_data = self
                                .model
                                .export_kv_cache(&cache_id)
                                .or_else(|| {
                                    cache_id.rfind("-clone-").and_then(|pos| {
                                        self.model.export_kv_cache(&cache_id[..pos])
                                    })
                                })
                                .ok_or_else(|| {
                                    FerrumError::model(format!("No KV for TP: {cache_id}"))
                                })?;

                            // NOTE(review): an empty export silently skips
                            // initialisation here, while the single-GPU path
                            // returns "Empty KV cache export" — confirm intent.
                            if !kv_data.is_empty() {
                                let prefill_len = kv_data[0].2;
                                let max_len = kv_data[0].3;
                                let num_kv_heads = self.model.config().num_key_value_heads;
                                // NOTE(review): integer division; assumes
                                // num_kv_heads is a non-zero multiple of the
                                // world size — TODO confirm.
                                let heads_per_rank = num_kv_heads / tp;

                                use ferrum_cuda_kernels::tp_decode::KvSource;

                                // One list of (K, V) sources per rank, filled
                                // in the iteration order of `kv_data`.
                                let mut per_rank_kv: Vec<Vec<(KvSource, KvSource)>> =
                                    (0..tp).map(|_| Vec::new()).collect();

                                for (k_tensor, v_tensor, _len, _max) in &kv_data {
                                    for rank in 0..tp {
                                        // Slice this rank's contiguous range
                                        // along dimension 1.
                                        let start = rank * heads_per_rank;
                                        let k_shard = k_tensor
                                            .narrow(1, start, heads_per_rank)
                                            .and_then(|t| t.contiguous())
                                            .map_err(|e| {
                                                FerrumError::model(format!("KV shard: {e}"))
                                            })?;
                                        let v_shard = v_tensor
                                            .narrow(1, start, heads_per_rank)
                                            .and_then(|t| t.contiguous())
                                            .map_err(|e| {
                                                FerrumError::model(format!("KV shard: {e}"))
                                            })?;

                                        if rank == 0 {
                                            // Same GPU: extract CudaSlice directly
                                            use candle_core::Storage;
                                            let (ks, _) = k_shard.storage_and_layout();
                                            let (vs, _) = v_shard.storage_and_layout();
                                            let kc = match &*ks {
                                                Storage::Cuda(cs) => cs
                                                    .as_cuda_slice::<half::f16>()
                                                    .map_err(|e| {
                                                        FerrumError::model(format!("KV: {e}"))
                                                    })?
                                                    .clone(),
                                                _ => return Err(FerrumError::model("KV not CUDA")),
                                            };
                                            let vc = match &*vs {
                                                Storage::Cuda(cs) => cs
                                                    .as_cuda_slice::<half::f16>()
                                                    .map_err(|e| {
                                                        FerrumError::model(format!("KV: {e}"))
                                                    })?
                                                    .clone(),
                                                _ => return Err(FerrumError::model("KV not CUDA")),
                                            };
                                            drop(ks);
                                            drop(vs);
                                            per_rank_kv[rank]
                                                .push((KvSource::Gpu(kc), KvSource::Gpu(vc)));
                                        } else {
                                            // Cross-GPU: D2H here, worker does H2D
                                            let k_host = k_shard
                                                .flatten_all()
                                                .and_then(|t| t.to_vec1::<half::f16>())
                                                .map_err(|e| {
                                                    FerrumError::model(format!("KV d2h: {e}"))
                                                })?;
                                            let v_host = v_shard
                                                .flatten_all()
                                                .and_then(|t| t.to_vec1::<half::f16>())
                                                .map_err(|e| {
                                                    FerrumError::model(format!("KV d2h: {e}"))
                                                })?;
                                            per_rank_kv[rank].push((
                                                KvSource::Host(k_host),
                                                KvSource::Host(v_host),
                                            ));
                                        }
                                    }
                                }
                                g.init_kv_cache(&cache_id, per_rank_kv, prefill_len, max_len)
                                    .map_err(|e| FerrumError::model(format!("TP KV init: {e}")))?;
                            }
                        }
                    }
                }

                // The lock taken above was released at the end of its block;
                // re-acquire it only for the decode step itself.
                let logits = {
                    let mut group = self.tp_group.lock();
                    let group = group
                        .as_mut()
                        .ok_or_else(|| FerrumError::model("TP group gone"))?;
                    group
                        .decode_step(tokens[0], seq_len, &cache_id)
                        .map_err(|e| FerrumError::model(format!("tp_decode: {e}")))?
                };

                // Wrap the returned CUDA slice as a (1, 1, vocab) tensor on
                // the model's device, with no backprop op attached.
                let cuda_dev = self
                    .model
                    .candle_device()
                    .as_cuda_device()
                    .map_err(|e| FerrumError::model(format!("not CUDA: {e}")))?;
                let vocab = self.info.vocab_size;
                let storage = candle_core::cuda_backend::CudaStorage::wrap_cuda_slice(
                    logits,
                    cuda_dev.clone(),
                );
                let logits_tensor = candle_core::Tensor::from_storage(
                    candle_core::Storage::Cuda(storage),
                    (1, 1, vocab),
                    candle_core::op::BackpropOp::none(),
                    false,
                );
                let logits_ref = self.wrap_tensor(logits_tensor);
                let new_seq_len = seq_len + 1;
                {
                    let mut states = self.states.lock();
                    if let Some(s) = states.get_mut(&cache_id) {
                        s.sequence_length = new_seq_len;
                    }
                }
                let new_handle = Arc::new(handle.with_sequence_length(new_seq_len));
                return Ok(DecodeOutput::new(logits_ref, new_handle));
            }
        }

        // Try single-GPU CUDA runner path
        #[cfg(feature = "cuda")]
        {
            // Checked here as well as inside `ensure_cuda_runner`, because
            // that helper returns `true` for an already-built runner without
            // consulting the variable.
            if std::env::var("FERRUM_DISABLE_CUDA_RUNNER").map_or(true, |v| v != "1") {
                if self.ensure_cuda_runner() {
                    self.ensure_runner_kv_cache(&cache_id, seq_len)?;

                    let logits = {
                        let mut runner = self.cuda_runner.lock();
                        let runner = runner
                            .as_mut()
                            .ok_or_else(|| FerrumError::model("CUDA runner gone"))?;
                        runner
                            .decode_step(tokens[0], seq_len, &cache_id)
                            .map_err(|e| FerrumError::model(format!("decode: {e}")))?
                    };

                    // Same wrapping as the TP path: (1, 1, vocab) on the
                    // model's CUDA device.
                    let cuda_dev = self
                        .model
                        .candle_device()
                        .as_cuda_device()
                        .map_err(|e| FerrumError::model(format!("not CUDA: {e}")))?;
                    let vocab = self.info.vocab_size;
                    let storage = candle_core::cuda_backend::CudaStorage::wrap_cuda_slice(
                        logits,
                        cuda_dev.clone(),
                    );
                    let logits_tensor = candle_core::Tensor::from_storage(
                        candle_core::Storage::Cuda(storage),
                        (1, 1, vocab),
                        candle_core::op::BackpropOp::none(),
                        false,
                    );

                    let logits_ref = self.wrap_tensor(logits_tensor);

                    let new_seq_len = seq_len + 1;
                    {
                        let mut states = self.states.lock();
                        if let Some(s) = states.get_mut(&cache_id) {
                            s.sequence_length = new_seq_len;
                        }
                    }
                    let new_handle = Arc::new(handle.with_sequence_length(new_seq_len));
                    return Ok(DecodeOutput::new(logits_ref, new_handle));
                }
            }
        }

        // Candle fallback
        let input_tensor = self.tokens_to_tensor(&tokens)?;
        let logits = self
            .model
            .forward_decode(&input_tensor, seq_len, &cache_id)?;
        // Rank-2 logits get a dimension inserted at index 1; any other rank is
        // passed through unchanged (unlike `prefill`, which rejects ranks
        // other than 2 and 3).
        let logits = match logits.dims().len() {
            2 => logits
                .unsqueeze(1)
                .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?,
            3 => logits,
            _ => logits,
        };
        let logits_ref = self.wrap_tensor(logits);

        let new_seq_len = seq_len + tokens.len();
        {
            let mut states = self.states.lock();
            if let Some(s) = states.get_mut(&cache_id) {
                s.sequence_length = new_seq_len;
            }
        }
        let new_handle = Arc::new(handle.with_sequence_length(new_seq_len));
        Ok(DecodeOutput::new(logits_ref, new_handle))
    }
653
654    fn capabilities(&self) -> ExecutorCapabilities {
655        ExecutorCapabilities {
656            max_batch_size: 1,
657            max_sequence_length: self.info.max_sequence_length,
658            attention_mechanisms: vec![AttentionType::MultiHead, AttentionType::GroupedQuery],
659            supports_dynamic_batching: false,
660            supports_continuous_batching: true,
661            supports_speculative_decoding: false,
662            supports_tensor_parallelism: false,
663            supports_pipeline_parallelism: false,
664            supported_dtypes: vec![DataType::FP16, DataType::FP32],
665            supported_devices: vec![self.info.device.clone()],
666            memory_requirements: MemoryRequirements {
667                parameter_memory: (self.info.num_parameters * 2) as u64,
668                activation_memory_per_token: 4 * self.info.hidden_size,
669                kv_cache_memory_per_token: 2 * self.info.num_layers * self.info.hidden_size,
670                overhead_memory: 1024 * 1024 * 1024,
671            },
672        }
673    }
674
    /// Release all state held for `cache_id` by delegating to
    /// `release_sequence`.
    fn release_cache(&self, cache_id: &str) {
        self.release_sequence(cache_id);
    }
678
    /// Report executor status using the shared default from
    /// `common::default_executor_status`.
    fn status(&self) -> ExecutorStatus {
        common::default_executor_status()
    }
682}