1use crate::{KvCacheHandle, TensorRef};
7use async_trait::async_trait;
8use ferrum_types::{ModelInfo, Result};
9use serde::{Deserialize, Serialize};
10use std::{collections::HashMap, sync::Arc};
11
/// Input bundle for a prefill (prompt-processing) pass.
///
/// Only `input_ids` is mandatory; the optional tensors and cache are
/// attached via the builder methods on the `impl`.
#[derive(Debug, Clone)]
pub struct PrefillInput {
    /// Token ids to prefill. Indexed as `shape()[0]` = batch and, when the
    /// tensor has >= 2 dims, `shape()[1]` = sequence length (see accessors).
    pub input_ids: TensorRef,
    /// Optional attention mask; `None` means the executor decides.
    pub attention_mask: Option<TensorRef>,
    /// Optional explicit position ids; `None` lets the executor derive them.
    pub position_ids: Option<TensorRef>,
    /// Optional pre-existing KV cache to continue from.
    pub kv_cache: Option<Arc<dyn KvCacheHandle>>,
}
24
25impl PrefillInput {
26 pub fn new(input_ids: TensorRef) -> Self {
28 Self {
29 input_ids,
30 attention_mask: None,
31 position_ids: None,
32 kv_cache: None,
33 }
34 }
35
36 pub fn with_kv_cache(mut self, kv_cache: Arc<dyn KvCacheHandle>) -> Self {
38 self.kv_cache = Some(kv_cache);
39 self
40 }
41
42 pub fn with_attention_mask(mut self, mask: TensorRef) -> Self {
44 self.attention_mask = Some(mask);
45 self
46 }
47
48 pub fn with_position_ids(mut self, positions: TensorRef) -> Self {
50 self.position_ids = Some(positions);
51 self
52 }
53
54 pub fn batch_size(&self) -> usize {
56 self.input_ids.shape()[0]
57 }
58
59 pub fn sequence_length(&self) -> usize {
61 if self.input_ids.shape().len() >= 2 {
62 self.input_ids.shape()[1]
63 } else {
64 1
65 }
66 }
67}
68
/// Result of a prefill pass: logits for the prompt plus the populated
/// KV cache, with optional per-layer diagnostics.
#[derive(Debug, Clone)]
pub struct PrefillOutput {
    /// Logits tensor; `last_token_logits` expects it to be 3-D
    /// `[batch, seq, vocab]`.
    pub logits: TensorRef,
    /// KV cache produced/updated by the prefill, to be fed into decode.
    pub kv_cache: Arc<dyn KvCacheHandle>,
    /// Optional hidden states (presumably one entry per layer — confirm
    /// with the executor implementations).
    pub hidden_states: Option<Vec<TensorRef>>,
    /// Optional attention weights, populated only when requested.
    pub attention_weights: Option<Vec<TensorRef>>,
}
81
impl PrefillOutput {
    /// Wraps the mandatory prefill results; diagnostic fields default to `None`.
    pub fn new(logits: TensorRef, kv_cache: Arc<dyn KvCacheHandle>) -> Self {
        Self {
            logits,
            kv_cache,
            hidden_states: None,
            attention_weights: None,
        }
    }

    /// Extracts the logits of the final sequence position, i.e. the
    /// distribution used to sample the first generated token.
    ///
    /// # Errors
    /// Returns a backend error if `logits` is not 3-D or the sequence
    /// dimension is empty.
    pub fn last_token_logits(&self) -> Result<TensorRef> {
        let shape = self.logits.shape();
        if shape.len() != 3 {
            return Err(ferrum_types::FerrumError::backend(
                "Expected 3D logits tensor [batch, seq, vocab]",
            ));
        }

        let seq_len = shape[1];
        if seq_len == 0 {
            return Err(ferrum_types::FerrumError::backend("Empty sequence"));
        }

        // NOTE(review): this assumes `view(start, end)` takes exclusive end
        // indices, yielding a [batch, 1, vocab] slice. If the second argument
        // is instead a *size*, `seq_len` here would read past the end —
        // confirm against TensorRef::view's contract.
        self.logits
            .view(&[0, seq_len - 1, 0], &[shape[0], seq_len, shape[2]])
    }
}
112
/// Input bundle for a single decode (token-generation) step.
#[derive(Debug, Clone)]
pub struct DecodeInput {
    /// Token ids for this step; first dimension is the batch size.
    pub input_ids: TensorRef,
    /// KV cache carried over from prefill / previous decode steps (required).
    pub kv_cache: Arc<dyn KvCacheHandle>,
    /// Optional explicit position ids; `None` lets the executor derive them.
    pub position_ids: Option<TensorRef>,
}
123
124impl DecodeInput {
125 pub fn new(input_ids: TensorRef, kv_cache: Arc<dyn KvCacheHandle>) -> Self {
127 Self {
128 input_ids,
129 kv_cache,
130 position_ids: None,
131 }
132 }
133
134 pub fn with_position_ids(mut self, positions: TensorRef) -> Self {
136 self.position_ids = Some(positions);
137 self
138 }
139
140 pub fn batch_size(&self) -> usize {
142 self.input_ids.shape()[0]
143 }
144}
145
/// Result of a single decode step.
#[derive(Debug, Clone)]
pub struct DecodeOutput {
    /// Logits for the newly decoded position(s).
    pub logits: TensorRef,
    /// KV cache updated with this step's keys/values.
    pub kv_cache: Arc<dyn KvCacheHandle>,
    /// Optional hidden state for this step (single tensor, unlike the
    /// per-layer `Vec` on `PrefillOutput`).
    pub hidden_state: Option<TensorRef>,
    /// Optional attention weights, populated only when requested.
    pub attention_weights: Option<Vec<TensorRef>>,
}
158
159impl DecodeOutput {
160 pub fn new(logits: TensorRef, kv_cache: Arc<dyn KvCacheHandle>) -> Self {
162 Self {
163 logits,
164 kv_cache,
165 hidden_state: None,
166 attention_weights: None,
167 }
168 }
169}
170
/// Core execution interface for a loaded model: prefill a prompt, then
/// decode tokens step by step. Implementations must be `Send + Sync`.
#[async_trait]
pub trait ModelExecutor: Send + Sync {
    /// Static metadata about the loaded model.
    fn info(&self) -> &ModelInfo;

    /// Runs a prefill pass over the prompt, producing logits and a KV cache.
    async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput>;

    /// Runs one decode step, extending the supplied KV cache.
    async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput>;

    /// Decodes several inputs. Default implementation processes them
    /// sequentially and fails fast on the first error; implementors may
    /// override with a true batched kernel.
    async fn batch_decode(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>> {
        let mut outputs = Vec::with_capacity(inputs.len());
        for input in inputs {
            outputs.push(self.decode(input).await?);
        }
        Ok(outputs)
    }

    /// Raw forward pass. Unsupported by default — only executors that can
    /// run a full forward outside the prefill/decode split override this.
    async fn forward(&self, _input: &TensorRef) -> Result<TensorRef> {
        Err(ferrum_types::FerrumError::unsupported(
            "Full forward pass not supported by this executor",
        ))
    }

    /// Feature/limit description used by schedulers to plan batching.
    fn capabilities(&self) -> ExecutorCapabilities;

    /// Current runtime status snapshot.
    fn status(&self) -> ExecutorStatus;

    /// Optional warm-up hook (e.g. priming kernels); no-op by default.
    async fn warmup(&mut self) -> Result<()> {
        Ok(())
    }

    /// Optional graceful-shutdown hook; no-op by default.
    async fn shutdown(&mut self) -> Result<()> {
        Ok(())
    }

    /// Releases any executor-held state for the given cache id.
    /// Default is a no-op; infallible by design (best-effort cleanup).
    fn release_cache(&self, _cache_id: &str) {
    }
}
230
/// Declarative description of what an executor supports, consumed by
/// schedulers/routers to plan batching and placement.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorCapabilities {
    /// Largest batch the executor will accept.
    pub max_batch_size: usize,
    /// Longest sequence (in tokens) the executor will accept.
    pub max_sequence_length: usize,
    /// Attention implementations available on this executor.
    pub attention_mechanisms: Vec<AttentionType>,
    pub supports_dynamic_batching: bool,
    pub supports_continuous_batching: bool,
    pub supports_speculative_decoding: bool,
    pub supports_tensor_parallelism: bool,
    pub supports_pipeline_parallelism: bool,
    /// Data types the executor can compute in.
    pub supported_dtypes: Vec<ferrum_types::DataType>,
    /// Devices the executor can run on.
    pub supported_devices: Vec<ferrum_types::Device>,
    /// Memory sizing model used for admission control.
    pub memory_requirements: MemoryRequirements,
}
257
/// Attention mechanism variants an executor may implement.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum AttentionType {
    MultiHead,
    MultiQuery,
    GroupedQuery,
    Flash,
    Paged,
    SlidingWindow,
}
274
/// Memory sizing model for an executor. All quantities are presumably in
/// bytes (per-token fields scale with batch x sequence — see
/// `calculate_total_memory`); confirm units with producers of this struct.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryRequirements {
    /// Fixed cost of the model weights.
    pub parameter_memory: u64,
    /// Activation cost per token (scaled by batch x sequence).
    pub activation_memory_per_token: usize,
    /// KV-cache cost per token (scaled by batch x sequence x layers).
    pub kv_cache_memory_per_token: usize,
    /// Fixed overhead (allocator slack, workspace, etc.).
    pub overhead_memory: u64,
}
287
288impl MemoryRequirements {
289 pub fn calculate_total_memory(
291 &self,
292 batch_size: usize,
293 sequence_length: usize,
294 num_layers: usize,
295 ) -> u64 {
296 let activation_mem =
297 (self.activation_memory_per_token * batch_size * sequence_length) as u64;
298 let kv_cache_mem =
299 (self.kv_cache_memory_per_token * batch_size * sequence_length * num_layers) as u64;
300
301 self.parameter_memory + activation_mem + kv_cache_mem + self.overhead_memory
302 }
303}
304
/// Point-in-time runtime status of an executor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorStatus {
    /// Lifecycle state machine position.
    pub state: ExecutorState,
    /// Whether the executor can accept work right now.
    pub is_ready: bool,
    /// Number of requests currently in flight as a batch.
    pub current_batch_size: usize,
    /// Running count of prefill calls served.
    pub prefill_operations: u64,
    /// Running count of decode calls served.
    pub decode_operations: u64,
    /// Rolling average prefill latency, in milliseconds.
    pub avg_prefill_time_ms: f64,
    /// Rolling average decode latency, in milliseconds.
    pub avg_decode_time_ms: f64,
    /// Snapshot of memory accounting.
    pub memory_usage: ExecutorMemoryUsage,
    // `Instant` has no serde impls, so it is excluded from (de)serialization
    // and deserializes as `None`.
    #[serde(skip)]
    pub last_operation: Option<std::time::Instant>,
}
328
/// Executor lifecycle states.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExecutorState {
    Initializing,
    Ready,
    Busy,
    Error,
    Shutdown,
}
343
/// Memory accounting snapshot for a running executor (byte quantities).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorMemoryUsage {
    /// Bytes reserved from the device/allocator.
    pub allocated_bytes: usize,
    /// Bytes actually in use.
    pub used_bytes: usize,
    /// High-water mark of `used_bytes`.
    pub peak_bytes: usize,
    /// Utilization as a percentage (presumably used/allocated * 100 —
    /// confirm with the producer).
    pub utilization_percent: f32,
}
356
/// Extension trait for executors with native (non-sequential) batching.
#[async_trait]
pub trait BatchModelExecutor: ModelExecutor {
    /// Prefills several inputs as one batched operation.
    async fn batch_prefill(&self, inputs: &[PrefillInput]) -> Result<Vec<PrefillOutput>>;

    /// Decodes several inputs as one batched operation.
    ///
    /// NOTE(review): this redeclares `ModelExecutor::batch_decode` with the
    /// same signature; calls through `dyn BatchModelExecutor` need
    /// disambiguation and the supertrait default is shadowed — consider
    /// renaming or removing one of the two.
    async fn batch_decode(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>>;

    /// Batch size at which this executor performs best.
    fn optimal_batch_size(&self) -> usize;

    /// Whether the executor can handle the given batch size.
    fn supports_batch_size(&self, batch_size: usize) -> bool;
}
372
/// Extension trait for executors that can verify draft tokens
/// (speculative decoding).
#[async_trait]
pub trait SpeculativeExecutor: ModelExecutor {
    /// Verifies `draft_tokens` against the target model in one step.
    /// `acceptance_threshold` tunes how aggressively drafts are accepted
    /// (exact semantics are implementation-defined — see implementors).
    async fn speculative_decode(
        &self,
        input: &DecodeInput,
        draft_tokens: &[ferrum_types::TokenId],
        acceptance_threshold: f32,
    ) -> Result<SpeculativeDecodeOutput>;
}
384
/// Result of one speculative-decoding verification step.
#[derive(Debug, Clone)]
pub struct SpeculativeDecodeOutput {
    /// Draft tokens the target model accepted, in order.
    pub accepted_tokens: Vec<ferrum_types::TokenId>,
    /// Logits for the position following the last accepted token.
    pub next_logits: TensorRef,
    /// KV cache advanced past the accepted tokens.
    pub kv_cache: Arc<dyn KvCacheHandle>,
    /// Number of accepted draft tokens (should equal
    /// `accepted_tokens.len()` — kept separately by the producer).
    pub acceptance_count: usize,
}
397
/// Factory for constructing executors from a validated configuration.
#[async_trait]
pub trait ModelExecutorFactory: Send + Sync {
    /// Builds a plain executor for `config`.
    async fn create_executor(&self, config: &ExecutorConfig) -> Result<Box<dyn ModelExecutor>>;

    /// Builds a batching-capable executor for `config`.
    async fn create_batch_executor(
        &self,
        config: &ExecutorConfig,
    ) -> Result<Box<dyn BatchModelExecutor>>;

    /// Executor variants this factory knows how to build.
    fn supported_types(&self) -> Vec<ExecutorType>;

    /// Checks `config` without building anything; `Err` explains why the
    /// configuration is unusable.
    fn validate_config(&self, config: &ExecutorConfig) -> Result<()>;
}
416
/// Full configuration handed to `ModelExecutorFactory` when building an
/// executor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorConfig {
    /// Model to execute.
    pub model_info: ModelInfo,
    /// Target device.
    pub device: ferrum_types::Device,
    /// Compute data type.
    pub dtype: ferrum_types::DataType,
    /// Upper bound on batch size.
    pub max_batch_size: usize,
    /// Upper bound on sequence length (tokens).
    pub max_sequence_length: usize,
    pub attention_config: ExecutorAttentionConfig,
    pub memory_config: ExecutorMemoryConfig,
    pub optimization_config: OptimizationConfig,
    /// Free-form, executor-specific options keyed by name.
    pub executor_options: HashMap<String, serde_json::Value>,
}
439
/// Attention-related knobs for executor construction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorAttentionConfig {
    /// Primary attention mechanism to use.
    pub attention_type: AttentionType,
    pub enable_flash_attention: bool,
    pub enable_paged_attention: bool,
    /// Page/block size for paged attention; `None` = implementation default.
    pub block_size: Option<usize>,
    /// Window size for sliding-window attention; `None` = not windowed.
    pub sliding_window_size: Option<usize>,
}
458
/// Memory-management knobs for executor construction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorMemoryConfig {
    pub enable_memory_pooling: bool,
    /// Pool size when pooling is enabled; `None` = implementation default.
    /// (Units not specified here — presumably bytes; confirm with consumers.)
    pub memory_pool_size: Option<usize>,
    pub enable_kv_cache_sharing: bool,
    /// Cap on memory usage, presumably a fraction in [0, 1] of available
    /// memory — confirm with consumers.
    pub max_memory_usage: f32,
}
471
/// Optimization toggles for executor construction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationConfig {
    pub enable_cuda_graphs: bool,
    pub enable_kernel_fusion: bool,
    pub enable_mixed_precision: bool,
    /// Coarse optimization level; range/meaning is implementation-defined.
    pub optimization_level: u8,
    /// Named boolean flags for executor-specific experiments.
    pub custom_flags: HashMap<String, bool>,
}
486
/// Executor variants a factory can produce.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum ExecutorType {
    Sequential,
    Batch,
    ContinuousBatch,
    Speculative,
    PipelineParallel,
    TensorParallel,
}
503
/// Aggregated performance metrics for an executor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorMetrics {
    /// All operations served (prefill + decode).
    pub total_operations: u64,
    pub prefill_operations: u64,
    pub decode_operations: u64,
    /// Average latencies; units not stated here — presumably milliseconds,
    /// matching `ExecutorStatus::avg_*_time_ms` — confirm with the producer.
    pub avg_prefill_latency: f64,
    pub avg_decode_latency: f64,
    /// 95th-percentile latencies (same units as the averages).
    pub p95_prefill_latency: f64,
    pub p95_decode_latency: f64,
    /// Throughput in tokens per second.
    pub throughput_tps: f64,
    /// Memory efficiency ratio (producer-defined).
    pub memory_efficiency: f32,
    /// How full batches are on average (producer-defined ratio).
    pub batch_utilization: f32,
}
528
/// Name-keyed registry of live executors.
pub trait ExecutorRegistry: Send + Sync {
    /// Registers an executor under `name`; `Err` semantics (e.g. duplicate
    /// name) are implementation-defined.
    fn register(&mut self, name: &str, executor: Box<dyn ModelExecutor>) -> Result<()>;

    /// Looks up an executor by name.
    fn get(&self, name: &str) -> Option<&dyn ModelExecutor>;

    /// Removes and returns the executor registered under `name`, if any.
    fn remove(&mut self, name: &str) -> Option<Box<dyn ModelExecutor>>;

    /// Names of all registered executors.
    fn list_names(&self) -> Vec<String>;

    /// Metrics for the named executor, if registered.
    fn get_metrics(&self, name: &str) -> Option<ExecutorMetrics>;
}