Skip to main content

ferrum_interfaces/
model_executor.rs

1//! Model execution interface with clear prefill/decode separation
2//!
3//! This module provides the ModelExecutor trait that replaces the "fat" Model
4//! interface, focusing purely on tensor operations without tokenization or sampling.
5
6use crate::{KvCacheHandle, TensorRef};
7use async_trait::async_trait;
8use ferrum_types::{ModelInfo, Result};
9use serde::{Deserialize, Serialize};
10use std::{collections::HashMap, sync::Arc};
11
/// Input for the prefill phase (processing the initial prompt).
///
/// Built with [`PrefillInput::new`] plus the `with_*` builder methods; every
/// optional field starts unset.
#[derive(Debug, Clone)]
pub struct PrefillInput {
    /// Input token IDs [batch_size, sequence_length]
    pub input_ids: TensorRef,
    /// Attention mask [batch_size, sequence_length] (optional)
    pub attention_mask: Option<TensorRef>,
    /// Position IDs [batch_size, sequence_length] (optional, for RoPE)
    pub position_ids: Option<TensorRef>,
    /// Pre-allocated KV cache handle (optional, for paged attention)
    pub kv_cache: Option<Arc<dyn KvCacheHandle>>,
    /// Request metadata that can affect model execution.
    /// Free-form key/value pairs; which keys an executor honors is not
    /// defined in this module.
    pub metadata: HashMap<String, serde_json::Value>,
}
26
27impl PrefillInput {
28    /// Create new prefill input
29    pub fn new(input_ids: TensorRef) -> Self {
30        Self {
31            input_ids,
32            attention_mask: None,
33            position_ids: None,
34            kv_cache: None,
35            metadata: HashMap::new(),
36        }
37    }
38
39    /// Create prefill input with a pre-allocated KV cache handle.
40    pub fn with_kv_cache(mut self, kv_cache: Arc<dyn KvCacheHandle>) -> Self {
41        self.kv_cache = Some(kv_cache);
42        self
43    }
44
45    /// Attach request metadata.
46    pub fn with_metadata(mut self, metadata: HashMap<String, serde_json::Value>) -> Self {
47        self.metadata = metadata;
48        self
49    }
50
51    /// Add attention mask
52    pub fn with_attention_mask(mut self, mask: TensorRef) -> Self {
53        self.attention_mask = Some(mask);
54        self
55    }
56
57    /// Add position IDs
58    pub fn with_position_ids(mut self, positions: TensorRef) -> Self {
59        self.position_ids = Some(positions);
60        self
61    }
62
63    /// Get batch size
64    pub fn batch_size(&self) -> usize {
65        self.input_ids.shape()[0]
66    }
67
68    /// Get sequence length
69    pub fn sequence_length(&self) -> usize {
70        if self.input_ids.shape().len() >= 2 {
71            self.input_ids.shape()[1]
72        } else {
73            1
74        }
75    }
76}
77
/// Output from the prefill phase.
///
/// [`PrefillOutput::new`] leaves the optional analysis fields as `None`.
#[derive(Debug, Clone)]
pub struct PrefillOutput {
    /// Logits for all positions [batch_size, sequence_length, vocab_size]
    pub logits: TensorRef,
    /// KV cache handle populated with prompt states
    pub kv_cache: Arc<dyn KvCacheHandle>,
    /// Hidden states at each layer (optional, for analysis)
    pub hidden_states: Option<Vec<TensorRef>>,
    /// Attention weights (optional, for analysis)
    pub attention_weights: Option<Vec<TensorRef>>,
}
90
impl PrefillOutput {
    /// Create new prefill output.
    ///
    /// `hidden_states` and `attention_weights` are left as `None`.
    pub fn new(logits: TensorRef, kv_cache: Arc<dyn KvCacheHandle>) -> Self {
        Self {
            logits,
            kv_cache,
            hidden_states: None,
            attention_weights: None,
        }
    }

    /// Get logits for the last position (for next token generation).
    ///
    /// # Errors
    ///
    /// Returns a backend error if `logits` is not 3-D (`[batch, seq, vocab]`)
    /// or if its sequence dimension is 0. Any error from `view` is returned
    /// unchanged.
    pub fn last_token_logits(&self) -> Result<TensorRef> {
        let shape = self.logits.shape();
        if shape.len() != 3 {
            return Err(ferrum_types::FerrumError::backend(
                "Expected 3D logits tensor [batch, seq, vocab]",
            ));
        }

        let seq_len = shape[1];
        // Guards the `seq_len - 1` below against usize underflow.
        if seq_len == 0 {
            return Err(ferrum_types::FerrumError::backend("Empty sequence"));
        }

        // Extract the last position along the sequence axis.
        // NOTE(review): if `view(start, end)` takes per-dimension half-open
        // ranges, this selects `[0..batch, seq_len-1..seq_len, 0..vocab]`.
        // The result would then be `[batch, 1, vocab]`, not `[batch, vocab]`,
        // unless `view` drops unit dimensions. Confirm against the tensor
        // trait's `view` contract before relying on the rank.
        self.logits
            .view(&[0, seq_len - 1, 0], &[shape[0], seq_len, shape[2]])
    }
}
121
/// Input for the decode phase (generating one token at a time).
///
/// Unlike [`PrefillInput`], a KV cache handle is mandatory here: decode
/// continues from state produced by earlier steps.
#[derive(Debug, Clone)]
pub struct DecodeInput {
    /// Input token ID for current step [batch_size, 1]
    pub input_ids: TensorRef,
    /// Existing KV cache from previous steps
    pub kv_cache: Arc<dyn KvCacheHandle>,
    /// Position IDs for current step [batch_size, 1] (optional)
    pub position_ids: Option<TensorRef>,
    /// Request metadata that can affect model execution.
    /// Free-form key/value pairs; which keys an executor honors is not
    /// defined in this module.
    pub metadata: HashMap<String, serde_json::Value>,
}
134
135impl DecodeInput {
136    /// Create new decode input
137    pub fn new(input_ids: TensorRef, kv_cache: Arc<dyn KvCacheHandle>) -> Self {
138        Self {
139            input_ids,
140            kv_cache,
141            position_ids: None,
142            metadata: HashMap::new(),
143        }
144    }
145
146    /// Add position IDs
147    pub fn with_position_ids(mut self, positions: TensorRef) -> Self {
148        self.position_ids = Some(positions);
149        self
150    }
151
152    /// Attach request metadata.
153    pub fn with_metadata(mut self, metadata: HashMap<String, serde_json::Value>) -> Self {
154        self.metadata = metadata;
155        self
156    }
157
158    /// Get batch size
159    pub fn batch_size(&self) -> usize {
160        self.input_ids.shape()[0]
161    }
162}
163
/// One sequence's contribution to a unified mixed-batch forward.
///
/// A unified batch lets a single model forward pass process a mix of
/// per-sequence work units in the same call:
/// - a prefill chunk (`q_tokens.len() ≥ 1`), possibly continuing from
///   `pos_offset > 0` for chunked prefill;
/// - a decode step (`q_tokens.len() == 1`, `pos_offset` = current cache length).
///
/// The model layer concatenates all `q_tokens` into one [M_total, hidden]
/// tensor and runs all GEMMs / norms once; only the attention kernel sees
/// per-item segmentation.
///
/// This is the abstraction that enables vLLM-style chunked prefill. Decode
/// tokens for already-running sequences are produced in the same iteration
/// as a prefill chunk for a newly arriving sequence.
///
/// Only `Clone` is derived; `Debug` is implemented by hand so that logs show
/// the token count rather than every token.
#[derive(Clone)]
pub struct UnifiedBatchItem {
    /// Identifier matching the sequence's KV cache (model-side keying).
    pub seq_id: String,
    /// Tokens to process this iteration. For decode this is exactly 1 token;
    /// for prefill (chunked or whole) this is the chunk's tokens.
    pub q_tokens: Vec<u32>,
    /// KV cache handle for this sequence.
    pub kv_cache: Arc<dyn KvCacheHandle>,
    /// Starting absolute position for the FIRST token in `q_tokens`.
    /// 0 for a fresh prefill; `kv_len` for a decode step or a continuing
    /// chunked-prefill slice.
    pub pos_offset: usize,
    /// True iff this item completes the request's prefill (or is a decode
    /// item), i.e. logits at the last token of `q_tokens` should be
    /// returned for sampling. Intermediate prefill chunks set this to false
    /// to skip the lm_head + sampling path.
    pub is_final_chunk: bool,
    /// Request metadata that can affect model execution.
    pub metadata: HashMap<String, serde_json::Value>,
}
198
199impl std::fmt::Debug for UnifiedBatchItem {
200    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201        f.debug_struct("UnifiedBatchItem")
202            .field("seq_id", &self.seq_id)
203            .field("q_len", &self.q_tokens.len())
204            .field("pos_offset", &self.pos_offset)
205            .field("is_final_chunk", &self.is_final_chunk)
206            .finish()
207    }
208}
209
/// A mixed-batch forward request: any combination of in-progress prefill
/// chunks and decode steps. See [`UnifiedBatchItem`] for the per-item
/// semantics.
///
/// The producer (engine) groups all sequences active in this iteration into
/// a single batch. The consumer (model) runs one forward and returns
/// per-item logits, only for items with `is_final_chunk = true`, in the order
/// they appear in `items`.
#[derive(Debug, Clone, Default)]
pub struct UnifiedBatch {
    /// Work items for this forward, in result order.
    pub items: Vec<UnifiedBatchItem>,
}
220
221impl UnifiedBatch {
222    pub fn new() -> Self {
223        Self::default()
224    }
225
226    /// Total query tokens across all items — corresponds to the M dim of
227    /// the model's per-layer GEMMs in the unified forward.
228    pub fn total_q_tokens(&self) -> usize {
229        self.items.iter().map(|it| it.q_tokens.len()).sum()
230    }
231
232    /// Number of items that will produce a logits vector (decode items
233    /// always; prefill items only on their final chunk).
234    pub fn num_sampled_items(&self) -> usize {
235        self.items.iter().filter(|it| it.is_final_chunk).count()
236    }
237}
238
/// Output from the decode phase.
///
/// [`DecodeOutput::new`] leaves the optional analysis fields as `None`.
#[derive(Debug, Clone)]
pub struct DecodeOutput {
    /// Logits for next token [batch_size, vocab_size]
    pub logits: TensorRef,
    /// Updated KV cache with new token state
    pub kv_cache: Arc<dyn KvCacheHandle>,
    /// Hidden state for current token (optional)
    pub hidden_state: Option<TensorRef>,
    /// Attention weights for current token (optional)
    pub attention_weights: Option<Vec<TensorRef>>,
}
251
252impl DecodeOutput {
253    /// Create new decode output
254    pub fn new(logits: TensorRef, kv_cache: Arc<dyn KvCacheHandle>) -> Self {
255        Self {
256            logits,
257            kv_cache,
258            hidden_state: None,
259            attention_weights: None,
260        }
261    }
262}
263
264/// Core model executor trait focusing on tensor operations
265#[async_trait]
266pub trait ModelExecutor: Send + Sync {
267    /// Get model information and metadata
268    fn info(&self) -> &ModelInfo;
269
270    /// Per-request KV capacity in tokens when the executor owns a smaller
271    /// runtime cache window than the model's declared context length.
272    fn kv_capacity(&self) -> Option<usize> {
273        None
274    }
275
276    /// Execute prefill phase (process initial prompt)
277    async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput>;
278
279    /// Batch prefill: process multiple prompts' prefill in ONE forward pass.
280    ///
281    /// Default implementation falls back to per-request `prefill()` (serial,
282    /// which is the current behavior the engine sees today). Executors that
283    /// support unified mixed-batch forward (e.g. via `model.unified_forward`
284    /// over a varlen QKV path) should override this to amortize launch /
285    /// kernel-overhead across all `inputs` items in one call.
286    ///
287    /// Used by the continuous-batching engine to coalesce a cohort of new
288    /// prefills (apples M3 c=32 sees 32 simultaneous prefills as one logical
289    /// batch; the serial fallback runs each in ~47 ms while a true batched
290    /// path runs all 32 in ~100 ms).
291    async fn batch_prefill(&self, inputs: &[PrefillInput]) -> Result<Vec<PrefillOutput>> {
292        let mut outputs = Vec::with_capacity(inputs.len());
293        for input in inputs {
294            outputs.push(self.prefill(input).await?);
295        }
296        Ok(outputs)
297    }
298
299    /// Execute decode phase (generate next token)
300    async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput>;
301
302    /// Batch decode: process multiple sequences in one forward pass.
303    ///
304    /// Default implementation falls back to per-request `decode()`.
305    /// Executors with batched CUDA runners should override this.
306    async fn batch_decode(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>> {
307        let mut outputs = Vec::with_capacity(inputs.len());
308        for input in inputs {
309            outputs.push(self.decode(input).await?);
310        }
311        Ok(outputs)
312    }
313
314    /// Unified mixed-batch forward: process a [`UnifiedBatch`] containing
315    /// any combination of prefill chunks (one or more `q_tokens` per item,
316    /// possibly continuing from `pos_offset > 0`) and decode steps
317    /// (`q_tokens.len() == 1`, `is_final_chunk = true`) in a single model
318    /// forward pass.
319    ///
320    /// Returns one element per `batch.items[i]`:
321    /// - `Some(logits)` for items with `is_final_chunk = true` (the
322    ///   request's final-position logits, ready for sampling)
323    /// - `None` for intermediate prefill chunks (no lm_head executed —
324    ///   model only updates KV state)
325    ///
326    /// Default implementation returns `Err(unsupported)`. Concrete LLM
327    /// executors should override with either:
328    /// - A behavioral fallback that dispatches each chunk via existing
329    ///   `prefill()` and groups decode items into `batch_decode()` (this
330    ///   preserves current behavior; no perf change), OR
331    /// - A real unified-forward path that runs all items through one
332    ///   `[M_total, hidden]` GEMM chain with a varlen attention kernel
333    ///   (this is the chunked-prefill perf unlock).
334    async fn unified_decode(&self, _batch: &UnifiedBatch) -> Result<Vec<Option<Vec<f32>>>> {
335        Err(ferrum_types::FerrumError::unsupported(
336            "unified_decode not implemented for this executor",
337        ))
338    }
339
340    /// Optional: full forward pass (for non-autoregressive use cases)
341    async fn forward(&self, _input: &TensorRef) -> Result<TensorRef> {
342        // Default implementation not supported
343        Err(ferrum_types::FerrumError::unsupported(
344            "Full forward pass not supported by this executor",
345        ))
346    }
347
348    /// Roll the KV cache for this executor's sequence back to `new_len`.
349    /// Used by speculative decoding on partial rejection so the next
350    /// iteration sees a KV prefix that matches the accepted token stream.
351    /// Default: Ok(()) — executors that don't cache per-sequence state
352    /// (stub, mock) are inherently tolerant; real LLM executors override.
353    async fn truncate_kv(
354        &self,
355        _kv_cache: &std::sync::Arc<dyn crate::KvCacheHandle>,
356        _new_len: usize,
357    ) -> Result<()> {
358        Ok(())
359    }
360
361    /// Multi-position decode-verify: one forward over `N+1` tokens,
362    /// producing one logits row per position. Used by speculative
363    /// decoding's target path so we don't pay N+1 sequential forwards.
364    ///
365    /// Default falls back to N+1 sequential `decode()` calls — correct
366    /// but slow; real LLM executors override.
367    ///
368    /// Returns a `Vec<DecodeOutput>` of length `inputs.len()` with the
369    /// final KV handle attached to the last element.
370    async fn forward_verify(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>> {
371        let mut out = Vec::with_capacity(inputs.len());
372        for input in inputs {
373            out.push(self.decode(input).await?);
374        }
375        Ok(out)
376    }
377
378    /// Get executor capabilities
379    fn capabilities(&self) -> ExecutorCapabilities;
380
381    /// Get current executor status
382    fn status(&self) -> ExecutorStatus;
383
384    /// Optional model/executor cache metrics.
385    ///
386    /// Concrete LLM executors use this for model-level paged KV prefix reuse
387    /// counters. Default implementations keep non-autoregressive executors
388    /// and tests from needing cache-specific plumbing.
389    fn cache_metrics_snapshot(&self) -> Option<serde_json::Value> {
390        None
391    }
392
393    /// Optional LoRA runtime metrics.
394    fn lora_metrics_snapshot(&self) -> Option<serde_json::Value> {
395        None
396    }
397
398    /// Warm up executor (load model, allocate memory, etc.)
399    async fn warmup(&mut self) -> Result<()> {
400        // Default no-op implementation
401        Ok(())
402    }
403
404    /// Shutdown executor gracefully
405    async fn shutdown(&mut self) -> Result<()> {
406        // Default no-op implementation
407        Ok(())
408    }
409
410    /// Release KV cache and state for a completed sequence.
411    ///
412    /// Called by the engine when a request finishes (success or error) to free
413    /// GPU memory held by the sequence's KV cache. The `cache_id` matches the
414    /// value embedded in the `KvCacheHandle` returned by prefill/decode.
415    fn release_cache(&self, _cache_id: &str) {
416        // Default no-op — executors that manage per-sequence KV caches should override.
417    }
418}
419
/// Executor capabilities and configuration.
///
/// Returned by `ModelExecutor::capabilities`. For live runtime figures,
/// `ModelExecutor::status` returns an `ExecutorStatus`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorCapabilities {
    /// Maximum supported batch size
    pub max_batch_size: usize,
    /// Maximum sequence length
    pub max_sequence_length: usize,
    /// Supported attention mechanisms
    pub attention_mechanisms: Vec<AttentionType>,
    /// Whether executor supports dynamic batching
    pub supports_dynamic_batching: bool,
    /// Whether executor supports continuous batching
    pub supports_continuous_batching: bool,
    /// Whether executor supports speculative decoding
    pub supports_speculative_decoding: bool,
    /// Whether executor supports tensor parallelism
    pub supports_tensor_parallelism: bool,
    /// Whether executor supports pipeline parallelism
    pub supports_pipeline_parallelism: bool,
    /// Supported data types
    pub supported_dtypes: Vec<ferrum_types::DataType>,
    /// Supported devices
    pub supported_devices: Vec<ferrum_types::Device>,
    /// Memory requirements estimation
    /// (see `MemoryRequirements::calculate_total_memory`).
    pub memory_requirements: MemoryRequirements,
}
446
447/// Attention mechanism types
448#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
449pub enum AttentionType {
450    /// Standard multi-head attention
451    MultiHead,
452    /// Multi-query attention (MQA)
453    MultiQuery,
454    /// Grouped-query attention (GQA)
455    GroupedQuery,
456    /// Flash attention
457    Flash,
458    /// Paged attention
459    Paged,
460    /// Sliding window attention
461    SlidingWindow,
462}
463
/// Memory requirements for model execution.
///
/// `calculate_total_memory` combines the per-token figures with batch size,
/// sequence length and layer count.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryRequirements {
    /// Model parameter memory in bytes
    pub parameter_memory: u64,
    /// Minimum activation memory per token
    /// (presumably bytes, like `parameter_memory` — confirm with producers).
    pub activation_memory_per_token: usize,
    /// KV cache memory per token per layer
    /// (presumably bytes; multiplied by the layer count when totalling).
    pub kv_cache_memory_per_token: usize,
    /// Additional overhead memory (presumably bytes).
    pub overhead_memory: u64,
}
476
477impl MemoryRequirements {
478    /// Calculate total memory for given configuration
479    pub fn calculate_total_memory(
480        &self,
481        batch_size: usize,
482        sequence_length: usize,
483        num_layers: usize,
484    ) -> u64 {
485        let activation_mem =
486            (self.activation_memory_per_token * batch_size * sequence_length) as u64;
487        let kv_cache_mem =
488            (self.kv_cache_memory_per_token * batch_size * sequence_length * num_layers) as u64;
489
490        self.parameter_memory + activation_mem + kv_cache_mem + self.overhead_memory
491    }
492}
493
/// Executor status information.
///
/// Snapshot returned by `ModelExecutor::status`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorStatus {
    /// Current executor state
    pub state: ExecutorState,
    /// Whether executor is ready to accept requests
    pub is_ready: bool,
    /// Current batch size being processed
    pub current_batch_size: usize,
    /// Number of prefill operations completed
    pub prefill_operations: u64,
    /// Number of decode operations completed
    pub decode_operations: u64,
    /// Average prefill time in milliseconds
    pub avg_prefill_time_ms: f64,
    /// Average decode time in milliseconds
    pub avg_decode_time_ms: f64,
    /// Memory usage statistics
    pub memory_usage: ExecutorMemoryUsage,
    /// Last operation timestamp.
    ///
    /// Excluded from serialization by `#[serde(skip)]`; a deserialized
    /// status therefore always has `None` here.
    #[serde(skip)]
    pub last_operation: Option<std::time::Instant>,
}
517
/// Executor state, as reported in `ExecutorStatus::state`.
///
/// Derives `PartialEq`/`Eq`, so states can be compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExecutorState {
    /// Executor is initializing
    Initializing,
    /// Executor is ready to accept requests
    Ready,
    /// Executor is processing requests
    Busy,
    /// Executor encountered an error
    Error,
    /// Executor is shutting down
    Shutdown,
}
532
/// Executor memory usage, as reported in `ExecutorStatus::memory_usage`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorMemoryUsage {
    /// Total allocated memory in bytes
    pub allocated_bytes: usize,
    /// Currently used memory in bytes
    pub used_bytes: usize,
    /// Peak memory usage (presumably bytes, like the fields above)
    pub peak_bytes: usize,
    /// Memory utilization percentage
    /// (NOTE(review): scale 0–100 vs 0.0–1.0 is not fixed by this type).
    pub utilization_percent: f32,
}
545
/// Batch model executor for processing multiple requests efficiently.
///
/// NOTE(review): the supertrait `ModelExecutor` already declares
/// `batch_prefill` and `batch_decode` with identical signatures (and default
/// bodies). The methods below are *separate* trait items, so an implementor
/// of both traits has two of each. Plain method-call syntax
/// (`executor.batch_prefill(..)`) can then be ambiguous (rustc E0034) —
/// confirm with the toolchain in use. Fully-qualified calls such as
/// `BatchModelExecutor::batch_prefill(&executor, ..)` avoid the question.
/// Consider whether these redeclarations are still needed.
#[async_trait]
pub trait BatchModelExecutor: ModelExecutor {
    /// Execute batch prefill for multiple sequences
    async fn batch_prefill(&self, inputs: &[PrefillInput]) -> Result<Vec<PrefillOutput>>;

    /// Execute batch decode for multiple sequences
    async fn batch_decode(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>>;

    /// Get optimal batch size for current conditions
    fn optimal_batch_size(&self) -> usize;

    /// Check if batch size is supported
    fn supports_batch_size(&self, batch_size: usize) -> bool;
}
561
/// Speculative execution support.
#[async_trait]
pub trait SpeculativeExecutor: ModelExecutor {
    /// Execute speculative decoding with draft model.
    ///
    /// Verifies `draft_tokens` starting from the state in `input`. The
    /// accepted subset is reported in `SpeculativeDecodeOutput`.
    /// NOTE(review): how `acceptance_threshold` is applied (e.g. a
    /// probability cutoff) is implementation-defined; confirm per implementor.
    async fn speculative_decode(
        &self,
        input: &DecodeInput,
        draft_tokens: &[ferrum_types::TokenId],
        acceptance_threshold: f32,
    ) -> Result<SpeculativeDecodeOutput>;
}
573
/// Output from speculative decoding.
#[derive(Debug, Clone)]
pub struct SpeculativeDecodeOutput {
    /// Accepted tokens (subset of draft tokens)
    pub accepted_tokens: Vec<ferrum_types::TokenId>,
    /// Logits for the next token after last accepted
    pub next_logits: TensorRef,
    /// Updated KV cache
    pub kv_cache: Arc<dyn KvCacheHandle>,
    /// Number of draft tokens accepted
    /// (presumably equal to `accepted_tokens.len()`; not enforced by the type).
    pub acceptance_count: usize,
}
586
/// Model executor factory.
///
/// Builds boxed executors from an `ExecutorConfig`.
#[async_trait]
pub trait ModelExecutorFactory: Send + Sync {
    /// Create executor from model configuration
    async fn create_executor(&self, config: &ExecutorConfig) -> Result<Box<dyn ModelExecutor>>;

    /// Create batch executor
    async fn create_batch_executor(
        &self,
        config: &ExecutorConfig,
    ) -> Result<Box<dyn BatchModelExecutor>>;

    /// Get supported executor types
    fn supported_types(&self) -> Vec<ExecutorType>;

    /// Validate configuration (`Ok(())` if acceptable, an error otherwise)
    fn validate_config(&self, config: &ExecutorConfig) -> Result<()>;
}
605
/// Executor configuration, consumed by `ModelExecutorFactory`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorConfig {
    /// Model information
    pub model_info: ModelInfo,
    /// Target device
    pub device: ferrum_types::Device,
    /// Data type for computation
    pub dtype: ferrum_types::DataType,
    /// Maximum batch size
    pub max_batch_size: usize,
    /// Maximum sequence length
    pub max_sequence_length: usize,
    /// Attention configuration
    pub attention_config: ExecutorAttentionConfig,
    /// Memory configuration
    pub memory_config: ExecutorMemoryConfig,
    /// Optimization settings
    pub optimization_config: OptimizationConfig,
    /// Additional executor-specific options (free-form; keys are not
    /// defined in this module)
    pub executor_options: HashMap<String, serde_json::Value>,
}
628
/// Runtime attention configuration for the model executor.
///
/// Note: this is different from `ferrum_types::AttentionConfig`, which
/// describes the model architecture's attention configuration from
/// config.json. This type describes the runtime execution settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorAttentionConfig {
    /// Type of attention to use
    pub attention_type: AttentionType,
    /// Enable flash attention if available
    pub enable_flash_attention: bool,
    /// Enable paged attention
    pub enable_paged_attention: bool,
    /// Block size for paged attention
    /// (presumably in tokens; meaning of `None` is executor-defined)
    pub block_size: Option<usize>,
    /// Sliding window size (if using sliding window attention)
    pub sliding_window_size: Option<usize>,
}
647
/// Memory configuration for the executor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorMemoryConfig {
    /// Enable memory pooling
    pub enable_memory_pooling: bool,
    /// Memory pool size in bytes (None for auto)
    pub memory_pool_size: Option<usize>,
    /// Enable KV cache sharing
    pub enable_kv_cache_sharing: bool,
    /// Maximum memory usage percentage
    /// (NOTE(review): scale 0–100 vs 0.0–1.0 is not fixed by this type).
    pub max_memory_usage: f32,
}
660
/// Optimization configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationConfig {
    /// Enable CUDA graphs (if supported)
    pub enable_cuda_graphs: bool,
    /// Enable kernel fusion
    pub enable_kernel_fusion: bool,
    /// Enable mixed precision
    pub enable_mixed_precision: bool,
    /// Optimization level (0-3); the `u8` type itself does not enforce
    /// this range
    pub optimization_level: u8,
    /// Custom optimization flags (name -> enabled)
    pub custom_flags: HashMap<String, bool>,
}
675
676/// Supported executor types
677#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
678pub enum ExecutorType {
679    /// Standard sequential executor
680    Sequential,
681    /// Batch executor for parallel processing
682    Batch,
683    /// Continuous batching executor
684    ContinuousBatch,
685    /// Speculative decoding executor
686    Speculative,
687    /// Pipeline parallel executor
688    PipelineParallel,
689    /// Tensor parallel executor
690    TensorParallel,
691}
692
/// Executor performance metrics, as returned by `ExecutorRegistry::get_metrics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorMetrics {
    /// Total operations executed
    pub total_operations: u64,
    /// Prefill operations
    pub prefill_operations: u64,
    /// Decode operations
    pub decode_operations: u64,
    /// Average prefill latency (ms)
    pub avg_prefill_latency: f64,
    /// Average decode latency (ms)
    pub avg_decode_latency: f64,
    /// P95 prefill latency (ms)
    pub p95_prefill_latency: f64,
    /// P95 decode latency (ms)
    pub p95_decode_latency: f64,
    /// Throughput (tokens per second)
    pub throughput_tps: f64,
    /// Memory efficiency (used/allocated)
    pub memory_efficiency: f32,
    /// Batch utilization (NOTE(review): scale not fixed by this type)
    pub batch_utilization: f32,
}
717
/// Executor registry for managing multiple executors.
///
/// Executors are owned as `Box<dyn ModelExecutor>` and addressed by name.
/// Mutating operations take `&mut self`; lookups only borrow.
pub trait ExecutorRegistry: Send + Sync {
    /// Register executor with name.
    ///
    /// NOTE(review): behavior on a duplicate name (error vs. replace) is not
    /// specified by this trait.
    fn register(&mut self, name: &str, executor: Box<dyn ModelExecutor>) -> Result<()>;

    /// Get executor by name (presumably `None` when the name is unknown)
    fn get(&self, name: &str) -> Option<&dyn ModelExecutor>;

    /// Remove executor by name, handing ownership back to the caller
    fn remove(&mut self, name: &str) -> Option<Box<dyn ModelExecutor>>;

    /// List registered executor names
    fn list_names(&self) -> Vec<String>;

    /// Get executor metrics
    fn get_metrics(&self, name: &str) -> Option<ExecutorMetrics>;
}