ferrum_models/executor/
llm_executor.rs1use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::Arc;
15
16use parking_lot::Mutex;
17use tracing::debug;
18
19use ferrum_interfaces::{
20 model_executor::{
21 AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorStatus,
22 MemoryRequirements, PrefillInput, PrefillOutput,
23 },
24 ModelExecutor,
25};
26use ferrum_types::{DataType, FerrumError, ModelInfo, Result};
27
28use crate::common::DecoderOnlyLLM;
29
30use super::common::{self, GenericKvCacheHandle};
31
32pub struct LlmExecutor {
33 model: Mutex<Box<dyn DecoderOnlyLLM>>,
34 info: ModelInfo,
35 next_cache_id: AtomicU64,
36}
37
38impl LlmExecutor {
39 pub fn new(model: Box<dyn DecoderOnlyLLM>, info: ModelInfo) -> Self {
40 Self {
41 model: Mutex::new(model),
42 info,
43 next_cache_id: AtomicU64::new(0),
44 }
45 }
46
47 fn gen_cache_id(&self) -> String {
48 format!(
49 "llm-cache-{}",
50 self.next_cache_id.fetch_add(1, Ordering::Relaxed)
51 )
52 }
53}
54
55#[async_trait::async_trait]
56impl ModelExecutor for LlmExecutor {
57 fn info(&self) -> &ModelInfo {
58 &self.info
59 }
60
61 async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
62 let tokens = common::tensor_to_tokens(&input.input_ids)?;
63 debug!("LlmExecutor prefill: {} tokens", tokens.len());
64
65 let cache_id = self.gen_cache_id();
66
67 let logits = {
68 let mut model = self.model.lock();
69 model.prefill(&cache_id, &tokens)
70 };
71
72 let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
74 .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
75 .unsqueeze(0)
76 .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?
77 .unsqueeze(0)
78 .map_err(|e| FerrumError::model(format!("unsqueeze2: {e}")))?;
79 let logits_ref = common::wrap_tensor(logits_tensor);
80
81 let cfg = self.model.lock().config().clone();
82 let kv_handle = Arc::new(GenericKvCacheHandle::new(
85 cfg.num_layers,
86 cfg.num_kv_heads,
87 cfg.head_dim,
88 candle_core::Device::Cpu,
89 tokens.len(),
90 cache_id,
91 ));
92
93 Ok(PrefillOutput::new(logits_ref, kv_handle))
94 }
95
96 async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput> {
97 let input_handle = input
98 .kv_cache
99 .as_any()
100 .downcast_ref::<GenericKvCacheHandle>()
101 .ok_or_else(|| FerrumError::model("Invalid KV cache handle type"))?;
102
103 let cache_id = input_handle.request_cache_id().to_string();
104 let seq_len = {
105 use ferrum_interfaces::KvCacheHandle;
106 input_handle.block_table().sequence_length
107 };
108
109 let tokens = common::tensor_to_tokens(&input.input_ids)?;
110 if tokens.is_empty() {
111 return Err(FerrumError::model("Decode input is empty"));
112 }
113 let token = tokens[0];
114
115 debug!("LlmExecutor decode: token={token}, pos={seq_len}");
116
117 let logits = {
118 let mut model = self.model.lock();
119 model.decode(&cache_id, token, seq_len as u32)
120 };
121
122 let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
123 .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
124 .unsqueeze(0)
125 .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
126 let logits_ref = common::wrap_tensor(logits_tensor);
127
128 let kv_handle = Arc::new(input_handle.with_sequence_length(seq_len + 1));
129 Ok(DecodeOutput::new(logits_ref, kv_handle))
130 }
131
132 async fn batch_decode(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>> {
137 if inputs.is_empty() {
138 return Ok(Vec::new());
139 }
140 struct Prep {
143 cache_id: String,
144 token: u32,
145 seq_len: u32,
146 handle: Arc<GenericKvCacheHandle>,
147 }
148 let mut prepped: Vec<Prep> = Vec::with_capacity(inputs.len());
149 for input in inputs {
150 let input_handle = input
151 .kv_cache
152 .as_any()
153 .downcast_ref::<GenericKvCacheHandle>()
154 .ok_or_else(|| FerrumError::model("Invalid KV cache handle type"))?;
155 use ferrum_interfaces::KvCacheHandle;
156 let seq_len = input_handle.block_table().sequence_length as u32;
157 let tokens = common::tensor_to_tokens(&input.input_ids)?;
158 if tokens.is_empty() {
159 return Err(FerrumError::model("Decode input is empty"));
160 }
161 prepped.push(Prep {
162 cache_id: input_handle.request_cache_id().to_string(),
163 token: tokens[0],
164 seq_len,
165 handle: Arc::new(input_handle.with_sequence_length((seq_len + 1) as usize)),
166 });
167 }
168
169 let all_logits: Vec<Vec<f32>> = {
174 let mut model = self.model.lock();
175 let tuples: Vec<(String, u32, u32)> = prepped
176 .iter()
177 .map(|p| (p.cache_id.clone(), p.token, p.seq_len))
178 .collect();
179 model.decode_batch(&tuples)
180 };
181
182 let mut outputs = Vec::with_capacity(prepped.len());
183 for (p, logits) in prepped.into_iter().zip(all_logits.into_iter()) {
184 debug!(
185 "LlmExecutor batch_decode: token={}, pos={}",
186 p.token, p.seq_len
187 );
188 let logits_tensor = candle_core::Tensor::new(&logits[..], &candle_core::Device::Cpu)
189 .map_err(|e| FerrumError::model(format!("logits tensor: {e}")))?
190 .unsqueeze(0)
191 .map_err(|e| FerrumError::model(format!("unsqueeze: {e}")))?;
192 let logits_ref = common::wrap_tensor(logits_tensor);
193 outputs.push(DecodeOutput::new(logits_ref, p.handle));
194 }
195 Ok(outputs)
196 }
197
198 fn release_cache(&self, cache_id: &str) {
199 self.model.lock().release(cache_id);
200 }
201
202 fn capabilities(&self) -> ExecutorCapabilities {
203 let cfg = self.model.lock().config().clone();
204 ExecutorCapabilities {
205 max_batch_size: 256,
206 max_sequence_length: cfg.max_seq_len,
207 attention_mechanisms: vec![AttentionType::GroupedQuery],
208 supports_dynamic_batching: true,
209 supports_continuous_batching: true,
210 supports_speculative_decoding: false,
211 supports_tensor_parallelism: false,
212 supports_pipeline_parallelism: false,
213 supported_dtypes: vec![DataType::FP32],
214 supported_devices: vec![self.info.device.clone()],
215 memory_requirements: MemoryRequirements {
216 parameter_memory: (self.info.num_parameters * 4) as u64,
217 activation_memory_per_token: cfg.hidden_size * 4,
218 kv_cache_memory_per_token: cfg.hidden_size * 2,
219 overhead_memory: 256 * 1024 * 1024,
220 },
221 }
222 }
223
224 fn status(&self) -> ExecutorStatus {
225 common::default_executor_status()
226 }
227}