Skip to main content

ferrum_interfaces/
model_executor.rs

//! Model execution interface with clear prefill/decode separation
//!
//! This module provides the `ModelExecutor` trait that replaces the "fat" `Model`
//! interface, focusing purely on tensor operations without tokenization or sampling.
5
6use crate::{KvCacheHandle, TensorRef};
7use async_trait::async_trait;
8use ferrum_types::{ModelInfo, Result};
9use serde::{Deserialize, Serialize};
10use std::{collections::HashMap, sync::Arc};
11
12/// Input for prefill phase (processing the initial prompt)
13#[derive(Debug, Clone)]
14pub struct PrefillInput {
15    /// Input token IDs [batch_size, sequence_length]
16    pub input_ids: TensorRef,
17    /// Attention mask [batch_size, sequence_length] (optional)
18    pub attention_mask: Option<TensorRef>,
19    /// Position IDs [batch_size, sequence_length] (optional, for RoPE)
20    pub position_ids: Option<TensorRef>,
21    /// Pre-allocated KV cache handle (optional, for paged attention)
22    pub kv_cache: Option<Arc<dyn KvCacheHandle>>,
23}
24
25impl PrefillInput {
26    /// Create new prefill input
27    pub fn new(input_ids: TensorRef) -> Self {
28        Self {
29            input_ids,
30            attention_mask: None,
31            position_ids: None,
32            kv_cache: None,
33        }
34    }
35
36    /// Create prefill input with a pre-allocated KV cache handle.
37    pub fn with_kv_cache(mut self, kv_cache: Arc<dyn KvCacheHandle>) -> Self {
38        self.kv_cache = Some(kv_cache);
39        self
40    }
41
42    /// Add attention mask
43    pub fn with_attention_mask(mut self, mask: TensorRef) -> Self {
44        self.attention_mask = Some(mask);
45        self
46    }
47
48    /// Add position IDs
49    pub fn with_position_ids(mut self, positions: TensorRef) -> Self {
50        self.position_ids = Some(positions);
51        self
52    }
53
54    /// Get batch size
55    pub fn batch_size(&self) -> usize {
56        self.input_ids.shape()[0]
57    }
58
59    /// Get sequence length
60    pub fn sequence_length(&self) -> usize {
61        if self.input_ids.shape().len() >= 2 {
62            self.input_ids.shape()[1]
63        } else {
64            1
65        }
66    }
67}
68
69/// Output from prefill phase
70#[derive(Debug, Clone)]
71pub struct PrefillOutput {
72    /// Logits for all positions [batch_size, sequence_length, vocab_size]
73    pub logits: TensorRef,
74    /// KV cache handle populated with prompt states
75    pub kv_cache: Arc<dyn KvCacheHandle>,
76    /// Hidden states at each layer (optional, for analysis)
77    pub hidden_states: Option<Vec<TensorRef>>,
78    /// Attention weights (optional, for analysis)
79    pub attention_weights: Option<Vec<TensorRef>>,
80}
81
82impl PrefillOutput {
83    /// Create new prefill output
84    pub fn new(logits: TensorRef, kv_cache: Arc<dyn KvCacheHandle>) -> Self {
85        Self {
86            logits,
87            kv_cache,
88            hidden_states: None,
89            attention_weights: None,
90        }
91    }
92
93    /// Get logits for last position (for next token generation)
94    pub fn last_token_logits(&self) -> Result<TensorRef> {
95        let shape = self.logits.shape();
96        if shape.len() != 3 {
97            return Err(ferrum_types::FerrumError::backend(
98                "Expected 3D logits tensor [batch, seq, vocab]",
99            ));
100        }
101
102        let seq_len = shape[1];
103        if seq_len == 0 {
104            return Err(ferrum_types::FerrumError::backend("Empty sequence"));
105        }
106
107        // Extract last position: [batch, seq-1:seq, vocab] -> [batch, vocab]
108        self.logits
109            .view(&[0, seq_len - 1, 0], &[shape[0], seq_len, shape[2]])
110    }
111}
112
113/// Input for decode phase (generating one token at a time)
114#[derive(Debug, Clone)]
115pub struct DecodeInput {
116    /// Input token ID for current step [batch_size, 1]
117    pub input_ids: TensorRef,
118    /// Existing KV cache from previous steps
119    pub kv_cache: Arc<dyn KvCacheHandle>,
120    /// Position IDs for current step [batch_size, 1] (optional)
121    pub position_ids: Option<TensorRef>,
122}
123
124impl DecodeInput {
125    /// Create new decode input
126    pub fn new(input_ids: TensorRef, kv_cache: Arc<dyn KvCacheHandle>) -> Self {
127        Self {
128            input_ids,
129            kv_cache,
130            position_ids: None,
131        }
132    }
133
134    /// Add position IDs
135    pub fn with_position_ids(mut self, positions: TensorRef) -> Self {
136        self.position_ids = Some(positions);
137        self
138    }
139
140    /// Get batch size
141    pub fn batch_size(&self) -> usize {
142        self.input_ids.shape()[0]
143    }
144}
145
146/// Output from decode phase
147#[derive(Debug, Clone)]
148pub struct DecodeOutput {
149    /// Logits for next token [batch_size, vocab_size]
150    pub logits: TensorRef,
151    /// Updated KV cache with new token state
152    pub kv_cache: Arc<dyn KvCacheHandle>,
153    /// Hidden state for current token (optional)
154    pub hidden_state: Option<TensorRef>,
155    /// Attention weights for current token (optional)
156    pub attention_weights: Option<Vec<TensorRef>>,
157}
158
159impl DecodeOutput {
160    /// Create new decode output
161    pub fn new(logits: TensorRef, kv_cache: Arc<dyn KvCacheHandle>) -> Self {
162        Self {
163            logits,
164            kv_cache,
165            hidden_state: None,
166            attention_weights: None,
167        }
168    }
169}
170
171/// Core model executor trait focusing on tensor operations
172#[async_trait]
173pub trait ModelExecutor: Send + Sync {
174    /// Get model information and metadata
175    fn info(&self) -> &ModelInfo;
176
177    /// Execute prefill phase (process initial prompt)
178    async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput>;
179
180    /// Execute decode phase (generate next token)
181    async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput>;
182
183    /// Batch decode: process multiple sequences in one forward pass.
184    ///
185    /// Default implementation falls back to per-request `decode()`.
186    /// Executors with batched CUDA runners should override this.
187    async fn batch_decode(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>> {
188        let mut outputs = Vec::with_capacity(inputs.len());
189        for input in inputs {
190            outputs.push(self.decode(input).await?);
191        }
192        Ok(outputs)
193    }
194
195    /// Optional: full forward pass (for non-autoregressive use cases)
196    async fn forward(&self, _input: &TensorRef) -> Result<TensorRef> {
197        // Default implementation not supported
198        Err(ferrum_types::FerrumError::unsupported(
199            "Full forward pass not supported by this executor",
200        ))
201    }
202
203    /// Roll the KV cache for this executor's sequence back to `new_len`.
204    /// Used by speculative decoding on partial rejection so the next
205    /// iteration sees a KV prefix that matches the accepted token stream.
206    /// Default: Ok(()) — executors that don't cache per-sequence state
207    /// (stub, mock) are inherently tolerant; real LLM executors override.
208    async fn truncate_kv(
209        &self,
210        _kv_cache: &std::sync::Arc<dyn crate::KvCacheHandle>,
211        _new_len: usize,
212    ) -> Result<()> {
213        Ok(())
214    }
215
216    /// Multi-position decode-verify: one forward over `N+1` tokens,
217    /// producing one logits row per position. Used by speculative
218    /// decoding's target path so we don't pay N+1 sequential forwards.
219    ///
220    /// Default falls back to N+1 sequential `decode()` calls — correct
221    /// but slow; real LLM executors override.
222    ///
223    /// Returns a `Vec<DecodeOutput>` of length `inputs.len()` with the
224    /// final KV handle attached to the last element.
225    async fn forward_verify(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>> {
226        let mut out = Vec::with_capacity(inputs.len());
227        for input in inputs {
228            out.push(self.decode(input).await?);
229        }
230        Ok(out)
231    }
232
233    /// Get executor capabilities
234    fn capabilities(&self) -> ExecutorCapabilities;
235
236    /// Get current executor status
237    fn status(&self) -> ExecutorStatus;
238
239    /// Warm up executor (load model, allocate memory, etc.)
240    async fn warmup(&mut self) -> Result<()> {
241        // Default no-op implementation
242        Ok(())
243    }
244
245    /// Shutdown executor gracefully
246    async fn shutdown(&mut self) -> Result<()> {
247        // Default no-op implementation
248        Ok(())
249    }
250
251    /// Release KV cache and state for a completed sequence.
252    ///
253    /// Called by the engine when a request finishes (success or error) to free
254    /// GPU memory held by the sequence's KV cache. The `cache_id` matches the
255    /// value embedded in the `KvCacheHandle` returned by prefill/decode.
256    fn release_cache(&self, _cache_id: &str) {
257        // Default no-op — executors that manage per-sequence KV caches should override.
258    }
259}
260
261/// Executor capabilities and configuration
262#[derive(Debug, Clone, Serialize, Deserialize)]
263pub struct ExecutorCapabilities {
264    /// Maximum supported batch size
265    pub max_batch_size: usize,
266    /// Maximum sequence length
267    pub max_sequence_length: usize,
268    /// Supported attention mechanisms
269    pub attention_mechanisms: Vec<AttentionType>,
270    /// Whether executor supports dynamic batching
271    pub supports_dynamic_batching: bool,
272    /// Whether executor supports continuous batching
273    pub supports_continuous_batching: bool,
274    /// Whether executor supports speculative decoding
275    pub supports_speculative_decoding: bool,
276    /// Whether executor supports tensor parallelism
277    pub supports_tensor_parallelism: bool,
278    /// Whether executor supports pipeline parallelism
279    pub supports_pipeline_parallelism: bool,
280    /// Supported data types
281    pub supported_dtypes: Vec<ferrum_types::DataType>,
282    /// Supported devices
283    pub supported_devices: Vec<ferrum_types::Device>,
284    /// Memory requirements estimation
285    pub memory_requirements: MemoryRequirements,
286}
287
288/// Attention mechanism types
289#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
290pub enum AttentionType {
291    /// Standard multi-head attention
292    MultiHead,
293    /// Multi-query attention (MQA)
294    MultiQuery,
295    /// Grouped-query attention (GQA)
296    GroupedQuery,
297    /// Flash attention
298    Flash,
299    /// Paged attention
300    Paged,
301    /// Sliding window attention
302    SlidingWindow,
303}
304
305/// Memory requirements for model execution
306#[derive(Debug, Clone, Serialize, Deserialize)]
307pub struct MemoryRequirements {
308    /// Model parameter memory in bytes
309    pub parameter_memory: u64,
310    /// Minimum activation memory per token
311    pub activation_memory_per_token: usize,
312    /// KV cache memory per token per layer
313    pub kv_cache_memory_per_token: usize,
314    /// Additional overhead memory
315    pub overhead_memory: u64,
316}
317
318impl MemoryRequirements {
319    /// Calculate total memory for given configuration
320    pub fn calculate_total_memory(
321        &self,
322        batch_size: usize,
323        sequence_length: usize,
324        num_layers: usize,
325    ) -> u64 {
326        let activation_mem =
327            (self.activation_memory_per_token * batch_size * sequence_length) as u64;
328        let kv_cache_mem =
329            (self.kv_cache_memory_per_token * batch_size * sequence_length * num_layers) as u64;
330
331        self.parameter_memory + activation_mem + kv_cache_mem + self.overhead_memory
332    }
333}
334
335/// Executor status information
336#[derive(Debug, Clone, Serialize, Deserialize)]
337pub struct ExecutorStatus {
338    /// Current executor state
339    pub state: ExecutorState,
340    /// Whether executor is ready to accept requests
341    pub is_ready: bool,
342    /// Current batch size being processed
343    pub current_batch_size: usize,
344    /// Number of prefill operations completed
345    pub prefill_operations: u64,
346    /// Number of decode operations completed
347    pub decode_operations: u64,
348    /// Average prefill time in milliseconds
349    pub avg_prefill_time_ms: f64,
350    /// Average decode time in milliseconds
351    pub avg_decode_time_ms: f64,
352    /// Memory usage statistics
353    pub memory_usage: ExecutorMemoryUsage,
354    /// Last operation timestamp
355    #[serde(skip)]
356    pub last_operation: Option<std::time::Instant>,
357}
358
359/// Executor state
360#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
361pub enum ExecutorState {
362    /// Executor is initializing
363    Initializing,
364    /// Executor is ready to accept requests
365    Ready,
366    /// Executor is processing requests
367    Busy,
368    /// Executor encountered an error
369    Error,
370    /// Executor is shutting down
371    Shutdown,
372}
373
374/// Executor memory usage
375#[derive(Debug, Clone, Serialize, Deserialize)]
376pub struct ExecutorMemoryUsage {
377    /// Total allocated memory in bytes
378    pub allocated_bytes: usize,
379    /// Currently used memory in bytes
380    pub used_bytes: usize,
381    /// Peak memory usage
382    pub peak_bytes: usize,
383    /// Memory utilization percentage
384    pub utilization_percent: f32,
385}
386
387/// Batch model executor for processing multiple requests efficiently
388#[async_trait]
389pub trait BatchModelExecutor: ModelExecutor {
390    /// Execute batch prefill for multiple sequences
391    async fn batch_prefill(&self, inputs: &[PrefillInput]) -> Result<Vec<PrefillOutput>>;
392
393    /// Execute batch decode for multiple sequences
394    async fn batch_decode(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>>;
395
396    /// Get optimal batch size for current conditions
397    fn optimal_batch_size(&self) -> usize;
398
399    /// Check if batch size is supported
400    fn supports_batch_size(&self, batch_size: usize) -> bool;
401}
402
403/// Speculative execution support
404#[async_trait]
405pub trait SpeculativeExecutor: ModelExecutor {
406    /// Execute speculative decoding with draft model
407    async fn speculative_decode(
408        &self,
409        input: &DecodeInput,
410        draft_tokens: &[ferrum_types::TokenId],
411        acceptance_threshold: f32,
412    ) -> Result<SpeculativeDecodeOutput>;
413}
414
415/// Output from speculative decoding
416#[derive(Debug, Clone)]
417pub struct SpeculativeDecodeOutput {
418    /// Accepted tokens (subset of draft tokens)
419    pub accepted_tokens: Vec<ferrum_types::TokenId>,
420    /// Logits for the next token after last accepted
421    pub next_logits: TensorRef,
422    /// Updated KV cache
423    pub kv_cache: Arc<dyn KvCacheHandle>,
424    /// Number of draft tokens accepted
425    pub acceptance_count: usize,
426}
427
428/// Model executor factory
429#[async_trait]
430pub trait ModelExecutorFactory: Send + Sync {
431    /// Create executor from model configuration
432    async fn create_executor(&self, config: &ExecutorConfig) -> Result<Box<dyn ModelExecutor>>;
433
434    /// Create batch executor
435    async fn create_batch_executor(
436        &self,
437        config: &ExecutorConfig,
438    ) -> Result<Box<dyn BatchModelExecutor>>;
439
440    /// Get supported executor types
441    fn supported_types(&self) -> Vec<ExecutorType>;
442
443    /// Validate configuration
444    fn validate_config(&self, config: &ExecutorConfig) -> Result<()>;
445}
446
447/// Executor configuration
448#[derive(Debug, Clone, Serialize, Deserialize)]
449pub struct ExecutorConfig {
450    /// Model information
451    pub model_info: ModelInfo,
452    /// Target device
453    pub device: ferrum_types::Device,
454    /// Data type for computation
455    pub dtype: ferrum_types::DataType,
456    /// Maximum batch size
457    pub max_batch_size: usize,
458    /// Maximum sequence length
459    pub max_sequence_length: usize,
460    /// Attention configuration
461    pub attention_config: ExecutorAttentionConfig,
462    /// Memory configuration
463    pub memory_config: ExecutorMemoryConfig,
464    /// Optimization settings
465    pub optimization_config: OptimizationConfig,
466    /// Additional executor-specific options
467    pub executor_options: HashMap<String, serde_json::Value>,
468}
469
470/// Runtime attention configuration for model executor
471///
472/// Note: This is different from ferrum_types::AttentionConfig which describes
473/// the model architecture's attention configuration from config.json.
474/// This type describes the runtime execution settings.
475#[derive(Debug, Clone, Serialize, Deserialize)]
476pub struct ExecutorAttentionConfig {
477    /// Type of attention to use
478    pub attention_type: AttentionType,
479    /// Enable flash attention if available
480    pub enable_flash_attention: bool,
481    /// Enable paged attention
482    pub enable_paged_attention: bool,
483    /// Block size for paged attention
484    pub block_size: Option<usize>,
485    /// Sliding window size (if using sliding window attention)
486    pub sliding_window_size: Option<usize>,
487}
488
489/// Memory configuration for executor
490#[derive(Debug, Clone, Serialize, Deserialize)]
491pub struct ExecutorMemoryConfig {
492    /// Enable memory pooling
493    pub enable_memory_pooling: bool,
494    /// Memory pool size in bytes (None for auto)
495    pub memory_pool_size: Option<usize>,
496    /// Enable KV cache sharing
497    pub enable_kv_cache_sharing: bool,
498    /// Maximum memory usage percentage
499    pub max_memory_usage: f32,
500}
501
502/// Optimization configuration
503#[derive(Debug, Clone, Serialize, Deserialize)]
504pub struct OptimizationConfig {
505    /// Enable CUDA graphs (if supported)
506    pub enable_cuda_graphs: bool,
507    /// Enable kernel fusion
508    pub enable_kernel_fusion: bool,
509    /// Enable mixed precision
510    pub enable_mixed_precision: bool,
511    /// Optimization level (0-3)
512    pub optimization_level: u8,
513    /// Custom optimization flags
514    pub custom_flags: HashMap<String, bool>,
515}
516
517/// Supported executor types
518#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
519pub enum ExecutorType {
520    /// Standard sequential executor
521    Sequential,
522    /// Batch executor for parallel processing
523    Batch,
524    /// Continuous batching executor
525    ContinuousBatch,
526    /// Speculative decoding executor
527    Speculative,
528    /// Pipeline parallel executor
529    PipelineParallel,
530    /// Tensor parallel executor
531    TensorParallel,
532}
533
534/// Executor performance metrics
535#[derive(Debug, Clone, Serialize, Deserialize)]
536pub struct ExecutorMetrics {
537    /// Total operations executed
538    pub total_operations: u64,
539    /// Prefill operations
540    pub prefill_operations: u64,
541    /// Decode operations
542    pub decode_operations: u64,
543    /// Average prefill latency (ms)
544    pub avg_prefill_latency: f64,
545    /// Average decode latency (ms)
546    pub avg_decode_latency: f64,
547    /// P95 prefill latency (ms)
548    pub p95_prefill_latency: f64,
549    /// P95 decode latency (ms)
550    pub p95_decode_latency: f64,
551    /// Throughput (tokens per second)
552    pub throughput_tps: f64,
553    /// Memory efficiency (used/allocated)
554    pub memory_efficiency: f32,
555    /// Batch utilization
556    pub batch_utilization: f32,
557}
558
559/// Executor registry for managing multiple executors
560pub trait ExecutorRegistry: Send + Sync {
561    /// Register executor with name
562    fn register(&mut self, name: &str, executor: Box<dyn ModelExecutor>) -> Result<()>;
563
564    /// Get executor by name
565    fn get(&self, name: &str) -> Option<&dyn ModelExecutor>;
566
567    /// Remove executor by name
568    fn remove(&mut self, name: &str) -> Option<Box<dyn ModelExecutor>>;
569
570    /// List registered executor names
571    fn list_names(&self) -> Vec<String>;
572
573    /// Get executor metrics
574    fn get_metrics(&self, name: &str) -> Option<ExecutorMetrics>;
575}