#[cfg(feature = "cuda")]
use async_trait::async_trait;
#[cfg(feature = "cuda")]
use ferrum_interfaces::{
    model_executor::{
        AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorStatus,
        MemoryRequirements, PrefillInput, PrefillOutput,
    },
    KvCacheHandle, ModelExecutor, TensorRef,
};
#[cfg(feature = "cuda")]
use ferrum_types::{DataType, FerrumError, ModelInfo, Result};
#[cfg(feature = "cuda")]
use std::sync::{
    atomic::{AtomicU64, Ordering},
    Arc,
};
#[cfg(feature = "cuda")]
use tracing::info;

#[cfg(feature = "cuda")]
use super::common::{self, GenericKvCacheHandle};
#[cfg(feature = "cuda")]
use crate::tensor_wrapper::CandleTensorWrapper;
#[cfg(feature = "cuda")]
use parking_lot::Mutex;

#[cfg(feature = "cuda")]
use ferrum_cuda_kernels::tp_decode::TpDecodeGroup;

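/// Per-request bookkeeping: the sequence length currently held in a
/// tensor-parallel KV cache.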
#[cfg(feature = "cuda")]
struct TpCacheState {
    sequence_length: usize,
}

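/// Tensor-parallel model executor (CUDA builds only).
///
/// Owns a Llama model whose weights are sharded across `tp_size` devices.
/// Prefill runs through the wrapped model directly, while decode steps go
/// through a shared [`TpDecodeGroup`].
///
/// Construction sketch (assumes the host crate has already loaded a sharded
/// model and initialized its decode group; the variable names are
/// illustrative, not a fixed API):
///
/// ```ignore
/// let executor = TpModelExecutor::new(model, tp_group, model_info, 2);
/// let output = executor.prefill(&prefill_input).await?;
/// ```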
#[cfg(feature = "cuda")]
pub struct TpModelExecutor {
    /// Wrapped Llama model, used for the prefill pass.
    model: Arc<crate::architectures::llama::LlamaModelWrapper>,
    /// Tensor-parallel decode group, locked for exclusive access per step.
    tp_group: Mutex<TpDecodeGroup>,
    info: ModelInfo,
    /// Sequence-length bookkeeping keyed by cache id.
    states: Mutex<std::collections::HashMap<String, TpCacheState>>,
    /// Monotonic counter used to mint unique cache ids.
    next_cache_id: AtomicU64,
    tp_size: usize,
}

#[cfg(feature = "cuda")]
impl TpModelExecutor {
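    /// Wraps an already-sharded model and its decode group into an executor.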
    pub fn new(
        model: crate::architectures::llama::LlamaModelWrapper,
        tp_group: TpDecodeGroup,
        info: ModelInfo,
        tp_size: usize,
    ) -> Self {
        info!(
            "Created TpModelExecutor: tp_size={}, model={}",
            tp_size, info.model_id
        );
        Self {
            model: Arc::new(model),
            tp_group: Mutex::new(tp_group),
            info,
            states: Mutex::new(std::collections::HashMap::new()),
            next_cache_id: AtomicU64::new(1),
            tp_size,
        }
    }

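    /// Builds a `[1, seq_len]` token tensor on the model's device.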
    fn tokens_to_tensor(&self, token_ids: &[u32]) -> Result<candle_core::Tensor> {
        candle_core::Tensor::new(token_ids, self.model.device())
            .map_err(|e| FerrumError::model(format!("tensor: {e}")))?
            .unsqueeze(0)
            .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))
    }

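    /// Adapts a candle tensor to the engine-facing `TensorRef` type.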
    fn wrap_tensor(&self, tensor: candle_core::Tensor) -> TensorRef {
        Arc::new(CandleTensorWrapper::new(tensor))
    }

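    /// Extracts token ids from an engine tensor via the shared helper.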
    fn tensor_to_tokens(&self, tensor: &TensorRef) -> Result<Vec<u32>> {
        common::tensor_to_tokens(tensor)
    }
}

#[cfg(feature = "cuda")]
#[async_trait]
impl ModelExecutor for TpModelExecutor {
    fn info(&self) -> &ModelInfo {
        &self.info
    }

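    /// Runs the prefill pass for a new request: materializes the prompt on
    /// device, executes the model forward pass, and registers a fresh KV
    /// cache entry for subsequent decode steps.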
    async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
        let tokens = self.tensor_to_tokens(&input.input_ids)?;
        if tokens.is_empty() {
            return Err(FerrumError::model("Empty input"));
        }

        // Mint a fresh cache id for this request.
        let cache_id = format!(
            "tp-cache-{}",
            self.next_cache_id.fetch_add(1, Ordering::Relaxed)
        );

        let input_tensor = self.tokens_to_tensor(&tokens)?;
        let logits = self.model.forward_prefill(&input_tensor, &cache_id)?;

        // Normalize logits to rank 3: [batch, seq, vocab].
        let logits = match logits.dims().len() {
            2 => logits
                .unsqueeze(1)
                .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?,
            3 => logits,
            d => return Err(FerrumError::model(format!("Unexpected logits rank: {d}"))),
        };
        let logits_ref = self.wrap_tensor(logits);

        let cfg = self.model.config();
        let handle = Arc::new(GenericKvCacheHandle::new(
            cfg.num_hidden_layers,
            cfg.num_attention_heads,
            cfg.head_dim,
            self.model.device().clone(),
            tokens.len(),
            cache_id.clone(),
        ));

        self.states.lock().insert(
            cache_id,
            TpCacheState {
                sequence_length: tokens.len(),
            },
        );

        Ok(PrefillOutput::new(logits_ref, handle))
    }

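    /// Runs one single-token, tensor-parallel decode step against an
    /// existing KV cache and returns logits of shape `[1, 1, vocab]`.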
    async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput> {
        let handle = input
            .kv_cache
            .as_any()
            .downcast_ref::<GenericKvCacheHandle>()
            .ok_or_else(|| FerrumError::model("Invalid KV handle"))?;
        let cache_id = handle.request_cache_id().to_string();

        // Look up the tracked sequence length, seeding it from the handle's
        // block table if this cache id has not been seen yet.
        let seq_len = {
            let mut states = self.states.lock();
            if let Some(s) = states.get(&cache_id) {
                s.sequence_length
            } else {
                let len = handle.block_table().sequence_length;
                states.insert(
                    cache_id.clone(),
                    TpCacheState {
                        sequence_length: len,
                    },
                );
                len
            }
        };

        let tokens = self.tensor_to_tokens(&input.input_ids)?;
        let token = *tokens
            .first()
            .ok_or_else(|| FerrumError::model("Empty decode input"))?;

        // Single-token decode through the tensor-parallel group; the lock
        // serializes collective decode steps.
        let logits = {
            let mut group = self.tp_group.lock();
            group
                .decode_step(token, seq_len, &cache_id)
                .map_err(|e| FerrumError::model(format!("tp_decode: {e}")))?
        };

        // Wrap the raw CUDA logits slice into a [1, 1, vocab] candle tensor
        // without copying.
        let cuda_dev = self
            .model
            .candle_device()
            .as_cuda_device()
            .map_err(|e| FerrumError::model(format!("not CUDA: {e}")))?;
        let vocab = self.info.vocab_size;
        let storage =
            candle_core::cuda_backend::CudaStorage::wrap_cuda_slice(logits, cuda_dev.clone());
        let logits_tensor = candle_core::Tensor::from_storage(
            candle_core::Storage::Cuda(storage),
            (1, 1, vocab),
            candle_core::op::BackpropOp::none(),
            false,
        );
        let logits_ref = self.wrap_tensor(logits_tensor);

        // Advance the tracked sequence length and hand back an updated handle.
        let new_seq_len = seq_len + 1;
        {
            let mut states = self.states.lock();
            if let Some(s) = states.get_mut(&cache_id) {
                s.sequence_length = new_seq_len;
            }
        }
        let new_handle = Arc::new(handle.with_sequence_length(new_seq_len));
        Ok(DecodeOutput::new(logits_ref, new_handle))
    }

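    /// Advertises supported features and rough per-rank memory estimates.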
    fn capabilities(&self) -> ExecutorCapabilities {
        ExecutorCapabilities {
            max_batch_size: 1,
            max_sequence_length: self.info.max_sequence_length,
            attention_mechanisms: vec![AttentionType::MultiHead, AttentionType::GroupedQuery],
            supports_dynamic_batching: false,
            supports_continuous_batching: true,
            supports_speculative_decoding: false,
            supports_tensor_parallelism: true,
            supports_pipeline_parallelism: false,
            supported_dtypes: vec![DataType::FP16],
            supported_devices: vec![self.info.device.clone()],
            memory_requirements: MemoryRequirements {
                // FP16 weights (2 bytes per parameter), sharded across ranks.
                parameter_memory: self.info.num_parameters * 2 / self.tp_size as u64,
                activation_memory_per_token: 4 * self.info.hidden_size,
                // K and V entries per token, sharded across ranks.
                kv_cache_memory_per_token: 2 * self.info.num_layers * self.info.hidden_size
                    / self.tp_size,
                overhead_memory: 1024 * 1024 * 1024,
            },
        }
    }

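    /// Drops all per-request state associated with `cache_id`: the local
    /// bookkeeping entry, the model's cache, and the TP group's KV cache.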
    fn release_cache(&self, cache_id: &str) {
        self.states.lock().remove(cache_id);
        self.model.release_cache(cache_id);
        self.tp_group.lock().release_kv_cache(cache_id);
    }

    fn status(&self) -> ExecutorStatus {
        common::default_executor_status()
    }
}