Skip to main content

ferrum_models/executor/
tp_executor.rs

1//! Tensor-Parallel model executor.
2//!
3//! Wraps TpDecodeGroup for multi-GPU decode with NCCL all-reduce.
4//! Prefill uses candle on GPU 0; decode uses sharded runners on all GPUs.
5//!
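//! Feature-gated: every item in this file is compiled only with the `cuda` feature; the `tensor-parallel` feature presumably gates this module's declaration in the parent module (confirm there).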
7
8#[cfg(feature = "cuda")]
9use async_trait::async_trait;
10#[cfg(feature = "cuda")]
11use ferrum_interfaces::{
12    model_executor::{
13        AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorStatus,
14        MemoryRequirements, PrefillInput, PrefillOutput,
15    },
16    KvCacheHandle, ModelExecutor, TensorRef,
17};
18#[cfg(feature = "cuda")]
19use ferrum_types::{DataType, FerrumError, ModelInfo, Result};
20#[cfg(feature = "cuda")]
21use std::sync::{
22    atomic::{AtomicU64, Ordering},
23    Arc,
24};
25#[cfg(feature = "cuda")]
26use tracing::info;
27
28#[cfg(feature = "cuda")]
29use super::common::{self, GenericKvCacheHandle};
30#[cfg(feature = "cuda")]
31use crate::tensor_wrapper::CandleTensorWrapper;
32#[cfg(feature = "cuda")]
33use parking_lot::Mutex;
34
35#[cfg(feature = "cuda")]
36use ferrum_cuda_kernels::tp_decode::TpDecodeGroup;
37
/// Per-request bookkeeping kept by [`TpModelExecutor`], keyed by cache id.
///
/// Derives `Debug` so states can be logged while diagnosing decode issues, and
/// `Clone`/`Copy` because it is a single plain integer.
#[cfg(feature = "cuda")]
#[derive(Debug, Clone, Copy)]
struct TpCacheState {
    /// Number of tokens the executor believes are already in the KV cache for
    /// this request (prompt length after prefill, +1 per decode step).
    sequence_length: usize,
}
42
/// Tensor-parallel executor that splits the work between a single-model
/// prefill path and a multi-GPU decode path.
///
/// - Prefill: runs on the candle `LlamaModelWrapper` (on GPU 0 per the module
///   docs; the original notes mention candle FlashAttention-2).
/// - Decode: runs on every GPU through [`TpDecodeGroup`], which combines the
///   shards with an NCCL all-reduce.
#[cfg(feature = "cuda")]
pub struct TpModelExecutor {
    /// Candle model used for prefill; also supplies the device and config.
    model: Arc<crate::architectures::llama::LlamaModelWrapper>,
    /// Sharded decode runners (one per GPU), locked for each decode step.
    tp_group: Mutex<TpDecodeGroup>,
    /// Model metadata handed out by `ModelExecutor::info`.
    info: ModelInfo,
    /// Sequence-length bookkeeping for each live request, keyed by cache id.
    states: Mutex<std::collections::HashMap<String, TpCacheState>>,
    /// Counter used to mint unique `tp-cache-N` ids.
    next_cache_id: AtomicU64,
    /// Tensor-parallel degree; divides the per-GPU memory estimates.
    tp_size: usize,
}
58
#[cfg(feature = "cuda")]
impl TpModelExecutor {
    /// Assembles an executor from an already-loaded candle model and a ready
    /// [`TpDecodeGroup`].
    ///
    /// Cache ids are minted starting from 1. `tp_size` is stored unchanged and
    /// later used to scale the memory figures in `capabilities`.
    pub fn new(
        model: crate::architectures::llama::LlamaModelWrapper,
        tp_group: TpDecodeGroup,
        info: ModelInfo,
        tp_size: usize,
    ) -> Self {
        let executor = Self {
            model: Arc::new(model),
            tp_group: Mutex::new(tp_group),
            states: Mutex::new(std::collections::HashMap::new()),
            next_cache_id: AtomicU64::new(1),
            tp_size,
            info,
        };
        info!(
            "Created TpModelExecutor: tp_size={}, model={}",
            executor.tp_size, executor.info.model_id
        );
        executor
    }

    /// Builds a candle tensor from `token_ids` on the model's device and adds a
    /// leading unit axis with `unsqueeze(0)`.
    fn tokens_to_tensor(&self, token_ids: &[u32]) -> Result<candle_core::Tensor> {
        let flat = candle_core::Tensor::new(token_ids, self.model.device())
            .map_err(|e| FerrumError::model(format!("tensor: {e}")))?;
        flat.unsqueeze(0)
            .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))
    }

    /// Hides a candle tensor behind the executor-agnostic `TensorRef` handle.
    fn wrap_tensor(&self, tensor: candle_core::Tensor) -> TensorRef {
        let wrapper = CandleTensorWrapper::new(tensor);
        Arc::new(wrapper)
    }

    /// Reads token ids back out of a `TensorRef` using the shared executor helper.
    fn tensor_to_tokens(&self, tensor: &TensorRef) -> Result<Vec<u32>> {
        common::tensor_to_tokens(tensor)
    }
}
96
#[cfg(feature = "cuda")]
#[async_trait]
impl ModelExecutor for TpModelExecutor {
    fn info(&self) -> &ModelInfo {
        &self.info
    }

    /// Runs prefill for a whole prompt through the candle model.
    ///
    /// The TP runners are not involved here (see the KV-migration TODO in
    /// `decode`). A fresh `tp-cache-N` id is minted, the prompt length is
    /// recorded for it, and the id is returned inside a `GenericKvCacheHandle`.
    ///
    /// # Errors
    /// Fails on an empty prompt, on any candle error, or when the logits are
    /// neither rank 2 nor rank 3.
    async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
        // Prefill on GPU 0 via candle (all GPUs participate in TP decode only)
        let tokens = self.tensor_to_tokens(&input.input_ids)?;
        if tokens.is_empty() {
            return Err(FerrumError::model("Empty input"));
        }

        let cache_id = format!(
            "tp-cache-{}",
            self.next_cache_id.fetch_add(1, Ordering::Relaxed)
        );

        let input_tensor = self.tokens_to_tensor(&tokens)?;
        let logits = self.model.forward_prefill(&input_tensor, &cache_id)?;

        // Normalize to rank 3: rank-2 logits get a unit axis inserted at dim 1.
        let logits = match logits.dims().len() {
            2 => logits
                .unsqueeze(1)
                .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?,
            3 => logits,
            d => return Err(FerrumError::model(format!("Unexpected logits rank: {d}"))),
        };
        let logits_ref = self.wrap_tensor(logits);

        let cfg = self.model.config();
        let handle = Arc::new(GenericKvCacheHandle::new(
            cfg.num_hidden_layers,
            cfg.num_attention_heads,
            cfg.head_dim,
            self.model.device().clone(),
            tokens.len(),
            cache_id.clone(),
        ));

        self.states.lock().insert(
            cache_id,
            TpCacheState {
                sequence_length: tokens.len(),
            },
        );

        Ok(PrefillOutput::new(logits_ref, handle))
    }

    /// Runs one decode step on every TP rank via `TpDecodeGroup::decode_step`.
    ///
    /// Only the first token of `input.input_ids` is fed to the group. The
    /// current sequence length is taken from this executor's per-cache state;
    /// for an unknown cache id it is seeded from the handle's block table.
    /// On success the stored length and the returned handle advance by one.
    ///
    /// # Errors
    /// Fails when the KV handle is not a `GenericKvCacheHandle`, when the input
    /// holds no tokens, when the TP decode step fails, or when the model device
    /// is not CUDA.
    async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput> {
        let handle = input
            .kv_cache
            .as_any()
            .downcast_ref::<GenericKvCacheHandle>()
            .ok_or_else(|| FerrumError::model("Invalid KV handle"))?;
        let cache_id = handle.request_cache_id().to_string();

        // Validate the input before touching any state. Indexing `tokens[0]`
        // would panic on an empty decode input; report it like `prefill` does.
        let tokens = self.tensor_to_tokens(&input.input_ids)?;
        let token = tokens
            .first()
            .copied()
            .ok_or_else(|| FerrumError::model("Empty input"))?;

        // Known cache ids keep their tracked length; unknown ones are seeded
        // from the handle so later steps stay consistent.
        let seq_len = self
            .states
            .lock()
            .entry(cache_id.clone())
            .or_insert_with(|| TpCacheState {
                sequence_length: handle.block_table().sequence_length,
            })
            .sequence_length;

        // TODO: migrate KV from candle prefill to TP runners
        // For now, use tp_group.decode_step which requires KV already in runners.
        // The full integration needs ensure_runner_kv_cache per rank.

        let logits = {
            let mut group = self.tp_group.lock();
            group
                .decode_step(token, seq_len, &cache_id)
                .map_err(|e| FerrumError::model(format!("tp_decode: {e}")))?
        };

        let cuda_dev = self
            .model
            .candle_device()
            .as_cuda_device()
            .map_err(|e| FerrumError::model(format!("not CUDA: {e}")))?;
        let vocab = self.info.vocab_size;
        // Wrap the raw device buffer returned by the TP group as a (1, 1, vocab)
        // candle tensor without copying.
        let storage =
            candle_core::cuda_backend::CudaStorage::wrap_cuda_slice(logits, cuda_dev.clone());
        let logits_tensor = candle_core::Tensor::from_storage(
            candle_core::Storage::Cuda(storage),
            (1, 1, vocab),
            candle_core::op::BackpropOp::none(),
            false,
        );
        let logits_ref = self.wrap_tensor(logits_tensor);

        let new_seq_len = seq_len + 1;
        if let Some(state) = self.states.lock().get_mut(&cache_id) {
            state.sequence_length = new_seq_len;
        }
        let new_handle = Arc::new(handle.with_sequence_length(new_seq_len));
        Ok(DecodeOutput::new(logits_ref, new_handle))
    }

    /// Reports what this executor supports, with memory figures scaled down by
    /// the tensor-parallel degree.
    fn capabilities(&self) -> ExecutorCapabilities {
        // Clamp so a misconfigured `tp_size == 0` cannot cause a division-by-zero panic.
        let shards = self.tp_size.max(1);
        ExecutorCapabilities {
            max_batch_size: 1,
            max_sequence_length: self.info.max_sequence_length,
            attention_mechanisms: vec![AttentionType::MultiHead, AttentionType::GroupedQuery],
            supports_dynamic_batching: false,
            supports_continuous_batching: true,
            supports_speculative_decoding: false,
            supports_tensor_parallelism: true,
            supports_pipeline_parallelism: false,
            supported_dtypes: vec![DataType::FP16],
            supported_devices: vec![self.info.device.clone()],
            memory_requirements: MemoryRequirements {
                // `* 2` presumably = FP16 bytes per parameter.
                parameter_memory: self.info.num_parameters * 2 / shards as u64,
                activation_memory_per_token: 4 * self.info.hidden_size,
                // NOTE(review): `2 *` looks like K+V; unlike `parameter_memory`
                // there is no bytes-per-element factor, and with grouped-query
                // attention `hidden_size` overstates the KV width — confirm units.
                kv_cache_memory_per_token: 2 * self.info.num_layers * self.info.hidden_size
                    / shards,
                overhead_memory: 1024 * 1024 * 1024,
            },
        }
    }

    /// Drops every trace of `cache_id`: local bookkeeping, the candle model's
    /// cache, and the TP runners' KV cache.
    fn release_cache(&self, cache_id: &str) {
        self.states.lock().remove(cache_id);
        self.model.release_cache(cache_id);
        self.tp_group.lock().release_kv_cache(cache_id);
    }

    fn status(&self) -> ExecutorStatus {
        common::default_executor_status()
    }
}
239
240    fn status(&self) -> ExecutorStatus {
241        common::default_executor_status()
242    }
243}