1use crate::{KvCacheHandle, TensorRef};
7use async_trait::async_trait;
8use ferrum_types::{ModelInfo, Result};
9use serde::{Deserialize, Serialize};
10use std::{collections::HashMap, sync::Arc};
11
12#[derive(Debug, Clone)]
14pub struct PrefillInput {
15 pub input_ids: TensorRef,
17 pub attention_mask: Option<TensorRef>,
19 pub position_ids: Option<TensorRef>,
21 pub kv_cache: Option<Arc<dyn KvCacheHandle>>,
23}
24
25impl PrefillInput {
26 pub fn new(input_ids: TensorRef) -> Self {
28 Self {
29 input_ids,
30 attention_mask: None,
31 position_ids: None,
32 kv_cache: None,
33 }
34 }
35
36 pub fn with_kv_cache(mut self, kv_cache: Arc<dyn KvCacheHandle>) -> Self {
38 self.kv_cache = Some(kv_cache);
39 self
40 }
41
42 pub fn with_attention_mask(mut self, mask: TensorRef) -> Self {
44 self.attention_mask = Some(mask);
45 self
46 }
47
48 pub fn with_position_ids(mut self, positions: TensorRef) -> Self {
50 self.position_ids = Some(positions);
51 self
52 }
53
54 pub fn batch_size(&self) -> usize {
56 self.input_ids.shape()[0]
57 }
58
59 pub fn sequence_length(&self) -> usize {
61 if self.input_ids.shape().len() >= 2 {
62 self.input_ids.shape()[1]
63 } else {
64 1
65 }
66 }
67}
68
69#[derive(Debug, Clone)]
71pub struct PrefillOutput {
72 pub logits: TensorRef,
74 pub kv_cache: Arc<dyn KvCacheHandle>,
76 pub hidden_states: Option<Vec<TensorRef>>,
78 pub attention_weights: Option<Vec<TensorRef>>,
80}
81
82impl PrefillOutput {
83 pub fn new(logits: TensorRef, kv_cache: Arc<dyn KvCacheHandle>) -> Self {
85 Self {
86 logits,
87 kv_cache,
88 hidden_states: None,
89 attention_weights: None,
90 }
91 }
92
93 pub fn last_token_logits(&self) -> Result<TensorRef> {
95 let shape = self.logits.shape();
96 if shape.len() != 3 {
97 return Err(ferrum_types::FerrumError::backend(
98 "Expected 3D logits tensor [batch, seq, vocab]",
99 ));
100 }
101
102 let seq_len = shape[1];
103 if seq_len == 0 {
104 return Err(ferrum_types::FerrumError::backend("Empty sequence"));
105 }
106
107 self.logits
109 .view(&[0, seq_len - 1, 0], &[shape[0], seq_len, shape[2]])
110 }
111}
112
113#[derive(Debug, Clone)]
115pub struct DecodeInput {
116 pub input_ids: TensorRef,
118 pub kv_cache: Arc<dyn KvCacheHandle>,
120 pub position_ids: Option<TensorRef>,
122}
123
124impl DecodeInput {
125 pub fn new(input_ids: TensorRef, kv_cache: Arc<dyn KvCacheHandle>) -> Self {
127 Self {
128 input_ids,
129 kv_cache,
130 position_ids: None,
131 }
132 }
133
134 pub fn with_position_ids(mut self, positions: TensorRef) -> Self {
136 self.position_ids = Some(positions);
137 self
138 }
139
140 pub fn batch_size(&self) -> usize {
142 self.input_ids.shape()[0]
143 }
144}
145
146#[derive(Debug, Clone)]
148pub struct DecodeOutput {
149 pub logits: TensorRef,
151 pub kv_cache: Arc<dyn KvCacheHandle>,
153 pub hidden_state: Option<TensorRef>,
155 pub attention_weights: Option<Vec<TensorRef>>,
157}
158
159impl DecodeOutput {
160 pub fn new(logits: TensorRef, kv_cache: Arc<dyn KvCacheHandle>) -> Self {
162 Self {
163 logits,
164 kv_cache,
165 hidden_state: None,
166 attention_weights: None,
167 }
168 }
169}
170
171#[async_trait]
173pub trait ModelExecutor: Send + Sync {
174 fn info(&self) -> &ModelInfo;
176
177 async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput>;
179
180 async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput>;
182
183 async fn batch_decode(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>> {
188 let mut outputs = Vec::with_capacity(inputs.len());
189 for input in inputs {
190 outputs.push(self.decode(input).await?);
191 }
192 Ok(outputs)
193 }
194
195 async fn forward(&self, _input: &TensorRef) -> Result<TensorRef> {
197 Err(ferrum_types::FerrumError::unsupported(
199 "Full forward pass not supported by this executor",
200 ))
201 }
202
203 async fn truncate_kv(
209 &self,
210 _kv_cache: &std::sync::Arc<dyn crate::KvCacheHandle>,
211 _new_len: usize,
212 ) -> Result<()> {
213 Ok(())
214 }
215
216 async fn forward_verify(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>> {
226 let mut out = Vec::with_capacity(inputs.len());
227 for input in inputs {
228 out.push(self.decode(input).await?);
229 }
230 Ok(out)
231 }
232
233 fn capabilities(&self) -> ExecutorCapabilities;
235
236 fn status(&self) -> ExecutorStatus;
238
239 async fn warmup(&mut self) -> Result<()> {
241 Ok(())
243 }
244
245 async fn shutdown(&mut self) -> Result<()> {
247 Ok(())
249 }
250
251 fn release_cache(&self, _cache_id: &str) {
257 }
259}
260
/// Static description of what an executor supports, as returned by
/// `ModelExecutor::capabilities`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorCapabilities {
    /// Largest batch size the executor accepts.
    pub max_batch_size: usize,
    /// Longest sequence the executor accepts (presumably in tokens — TODO confirm).
    pub max_sequence_length: usize,
    /// Attention implementations available in this executor.
    pub attention_mechanisms: Vec<AttentionType>,
    /// Whether dynamic batching is supported.
    pub supports_dynamic_batching: bool,
    /// Whether continuous batching is supported.
    pub supports_continuous_batching: bool,
    /// Whether speculative decoding is supported.
    pub supports_speculative_decoding: bool,
    /// Whether tensor parallelism is supported.
    pub supports_tensor_parallelism: bool,
    /// Whether pipeline parallelism is supported.
    pub supports_pipeline_parallelism: bool,
    /// Data types the executor can run in.
    pub supported_dtypes: Vec<ferrum_types::DataType>,
    /// Devices the executor can run on.
    pub supported_devices: Vec<ferrum_types::Device>,
    /// Memory cost model (see `MemoryRequirements::calculate_total_memory`).
    pub memory_requirements: MemoryRequirements,
}
287
288#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
290pub enum AttentionType {
291 MultiHead,
293 MultiQuery,
295 GroupedQuery,
297 Flash,
299 Paged,
301 SlidingWindow,
303}
304
/// Memory cost model for an executor.
///
/// All quantities share one unit (presumably bytes — TODO confirm).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryRequirements {
    /// Fixed memory for model parameters, independent of batch and sequence.
    pub parameter_memory: u64,
    /// Activation memory per token; scaled by batch size × sequence length.
    pub activation_memory_per_token: usize,
    /// KV cache memory per token per layer; scaled by batch × sequence × layers.
    pub kv_cache_memory_per_token: usize,
    /// Fixed additional overhead.
    pub overhead_memory: u64,
}
317
318impl MemoryRequirements {
319 pub fn calculate_total_memory(
321 &self,
322 batch_size: usize,
323 sequence_length: usize,
324 num_layers: usize,
325 ) -> u64 {
326 let activation_mem =
327 (self.activation_memory_per_token * batch_size * sequence_length) as u64;
328 let kv_cache_mem =
329 (self.kv_cache_memory_per_token * batch_size * sequence_length * num_layers) as u64;
330
331 self.parameter_memory + activation_mem + kv_cache_mem + self.overhead_memory
332 }
333}
334
/// Runtime status snapshot, as returned by `ModelExecutor::status`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorStatus {
    /// Lifecycle state.
    pub state: ExecutorState,
    /// Whether the executor reports itself ready. This is stored separately
    /// from `state`; nothing here enforces consistency between the two.
    pub is_ready: bool,
    /// Batch size currently being processed.
    pub current_batch_size: usize,
    /// Number of prefill operations performed.
    pub prefill_operations: u64,
    /// Number of decode operations performed.
    pub decode_operations: u64,
    /// Average prefill time in milliseconds.
    pub avg_prefill_time_ms: f64,
    /// Average decode time in milliseconds.
    pub avg_decode_time_ms: f64,
    /// Memory usage snapshot.
    pub memory_usage: ExecutorMemoryUsage,
    /// Presumably the time of the most recent operation (TODO confirm).
    /// `Instant` is not serializable, so this field is skipped by serde and
    /// deserializes as `None`.
    #[serde(skip)]
    pub last_operation: Option<std::time::Instant>,
}
358
/// Lifecycle state of an executor.
// Variant order is significant for index-based serde formats; append new
// variants rather than reordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExecutorState {
    /// Still initializing.
    Initializing,
    /// Ready to accept work.
    Ready,
    /// Currently processing work.
    Busy,
    /// In an error state.
    Error,
    /// Shut down.
    Shutdown,
}
373
/// Memory usage snapshot of an executor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorMemoryUsage {
    /// Bytes allocated.
    pub allocated_bytes: usize,
    /// Bytes in use.
    pub used_bytes: usize,
    /// Peak bytes observed.
    pub peak_bytes: usize,
    /// Utilization as a percentage (presumably used / allocated × 100 — TODO confirm).
    pub utilization_percent: f32,
}
386
387#[async_trait]
389pub trait BatchModelExecutor: ModelExecutor {
390 async fn batch_prefill(&self, inputs: &[PrefillInput]) -> Result<Vec<PrefillOutput>>;
392
393 async fn batch_decode(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>>;
395
396 fn optimal_batch_size(&self) -> usize;
398
399 fn supports_batch_size(&self, batch_size: usize) -> bool;
401}
402
/// Executor able to check draft tokens in a speculative-decoding step.
#[async_trait]
pub trait SpeculativeExecutor: ModelExecutor {
    /// Decodes `input` while evaluating `draft_tokens`.
    ///
    /// `acceptance_threshold` controls which drafts are accepted; its exact
    /// semantics (probability cut-off? ratio?) are implementation-defined —
    /// TODO confirm with implementors.
    async fn speculative_decode(
        &self,
        input: &DecodeInput,
        draft_tokens: &[ferrum_types::TokenId],
        acceptance_threshold: f32,
    ) -> Result<SpeculativeDecodeOutput>;
}
414
/// Result of `SpeculativeExecutor::speculative_decode`.
#[derive(Debug, Clone)]
pub struct SpeculativeDecodeOutput {
    /// Draft tokens that were accepted.
    pub accepted_tokens: Vec<ferrum_types::TokenId>,
    /// Logits for the next position.
    pub next_logits: TensorRef,
    /// KV cache handle after the step.
    pub kv_cache: Arc<dyn KvCacheHandle>,
    /// Number of accepted tokens (presumably equal to `accepted_tokens.len()`
    /// — TODO confirm; nothing here enforces it).
    pub acceptance_count: usize,
}
427
428#[async_trait]
430pub trait ModelExecutorFactory: Send + Sync {
431 async fn create_executor(&self, config: &ExecutorConfig) -> Result<Box<dyn ModelExecutor>>;
433
434 async fn create_batch_executor(
436 &self,
437 config: &ExecutorConfig,
438 ) -> Result<Box<dyn BatchModelExecutor>>;
439
440 fn supported_types(&self) -> Vec<ExecutorType>;
442
443 fn validate_config(&self, config: &ExecutorConfig) -> Result<()>;
445}
446
/// Configuration passed to a [`ModelExecutorFactory`] when building an executor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorConfig {
    /// Model to execute.
    pub model_info: ModelInfo,
    /// Target device.
    pub device: ferrum_types::Device,
    /// Compute data type.
    pub dtype: ferrum_types::DataType,
    /// Maximum batch size to allow.
    pub max_batch_size: usize,
    /// Maximum sequence length to allow.
    pub max_sequence_length: usize,
    /// Attention settings.
    pub attention_config: ExecutorAttentionConfig,
    /// Memory settings.
    pub memory_config: ExecutorMemoryConfig,
    /// Optimization settings.
    pub optimization_config: OptimizationConfig,
    /// Free-form, executor-specific options keyed by name.
    pub executor_options: HashMap<String, serde_json::Value>,
}
469
/// Attention-related executor settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorAttentionConfig {
    /// Attention implementation to use.
    pub attention_type: AttentionType,
    /// Whether to enable flash attention.
    pub enable_flash_attention: bool,
    /// Whether to enable paged attention.
    pub enable_paged_attention: bool,
    /// Optional block size (presumably for paged attention — TODO confirm).
    pub block_size: Option<usize>,
    /// Optional sliding-window size.
    pub sliding_window_size: Option<usize>,
}
488
/// Memory-related executor settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorMemoryConfig {
    /// Whether to enable memory pooling.
    pub enable_memory_pooling: bool,
    /// Optional pool size (unit not specified here — presumably bytes, TODO confirm).
    pub memory_pool_size: Option<usize>,
    /// Whether KV caches may be shared.
    pub enable_kv_cache_sharing: bool,
    /// Memory usage cap (presumably a fraction of device memory in 0.0..=1.0 —
    /// TODO confirm).
    pub max_memory_usage: f32,
}
501
/// Optimization toggles for an executor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationConfig {
    /// Whether to enable CUDA graphs.
    pub enable_cuda_graphs: bool,
    /// Whether to enable kernel fusion.
    pub enable_kernel_fusion: bool,
    /// Whether to enable mixed precision.
    pub enable_mixed_precision: bool,
    /// Optimization level (interpretation is executor-specific — TODO confirm range).
    pub optimization_level: u8,
    /// Additional named boolean flags for executor-specific features.
    pub custom_flags: HashMap<String, bool>,
}
516
517#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
519pub enum ExecutorType {
520 Sequential,
522 Batch,
524 ContinuousBatch,
526 Speculative,
528 PipelineParallel,
530 TensorParallel,
532}
533
/// Aggregated performance metrics for an executor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorMetrics {
    /// Total operations performed.
    pub total_operations: u64,
    /// Prefill operations performed.
    pub prefill_operations: u64,
    /// Decode operations performed.
    pub decode_operations: u64,
    /// Average prefill latency (unit not specified here — presumably ms, TODO confirm).
    pub avg_prefill_latency: f64,
    /// Average decode latency (same unit as `avg_prefill_latency`).
    pub avg_decode_latency: f64,
    /// 95th-percentile prefill latency.
    pub p95_prefill_latency: f64,
    /// 95th-percentile decode latency.
    pub p95_decode_latency: f64,
    /// Throughput (presumably tokens per second — TODO confirm).
    pub throughput_tps: f64,
    /// Memory efficiency score (range not specified here — TODO confirm).
    pub memory_efficiency: f32,
    /// Batch utilization score (range not specified here — TODO confirm).
    pub batch_utilization: f32,
}
558
559pub trait ExecutorRegistry: Send + Sync {
561 fn register(&mut self, name: &str, executor: Box<dyn ModelExecutor>) -> Result<()>;
563
564 fn get(&self, name: &str) -> Option<&dyn ModelExecutor>;
566
567 fn remove(&mut self, name: &str) -> Option<Box<dyn ModelExecutor>>;
569
570 fn list_names(&self) -> Vec<String>;
572
573 fn get_metrics(&self, name: &str) -> Option<ExecutorMetrics>;
575}