Skip to main content

ferrum_interfaces/
model_executor.rs

//! Model execution interface with clear prefill/decode separation.
//!
//! This module provides the `ModelExecutor` trait, which replaces the "fat" `Model`
//! interface and focuses purely on tensor operations, without tokenization or sampling.
5
6use crate::{KvCacheHandle, TensorRef};
7use async_trait::async_trait;
8use ferrum_types::{ModelInfo, Result};
9use serde::{Deserialize, Serialize};
10use std::{collections::HashMap, sync::Arc};
11
/// Input for the prefill phase: processing the initial prompt.
///
/// Only `input_ids` is required. `PrefillInput::new` leaves the optional
/// fields `None`/empty; the `with_*` builder methods fill them in.
#[derive(Debug, Clone)]
pub struct PrefillInput {
    /// Input token IDs, shaped `[batch_size, sequence_length]`.
    pub input_ids: TensorRef,
    /// Optional attention mask, shaped `[batch_size, sequence_length]`.
    pub attention_mask: Option<TensorRef>,
    /// Optional position IDs, shaped `[batch_size, sequence_length]` (for RoPE).
    pub position_ids: Option<TensorRef>,
    /// Optional pre-allocated KV cache handle (for paged attention).
    /// NOTE(review): presumably the executor allocates its own cache when this
    /// is `None` — confirm per executor.
    pub kv_cache: Option<Arc<dyn KvCacheHandle>>,
    /// Request metadata that can affect model execution, as free-form JSON
    /// values keyed by name.
    pub metadata: HashMap<String, serde_json::Value>,
}
26
27impl PrefillInput {
28    /// Create new prefill input
29    pub fn new(input_ids: TensorRef) -> Self {
30        Self {
31            input_ids,
32            attention_mask: None,
33            position_ids: None,
34            kv_cache: None,
35            metadata: HashMap::new(),
36        }
37    }
38
39    /// Create prefill input with a pre-allocated KV cache handle.
40    pub fn with_kv_cache(mut self, kv_cache: Arc<dyn KvCacheHandle>) -> Self {
41        self.kv_cache = Some(kv_cache);
42        self
43    }
44
45    /// Attach request metadata.
46    pub fn with_metadata(mut self, metadata: HashMap<String, serde_json::Value>) -> Self {
47        self.metadata = metadata;
48        self
49    }
50
51    /// Add attention mask
52    pub fn with_attention_mask(mut self, mask: TensorRef) -> Self {
53        self.attention_mask = Some(mask);
54        self
55    }
56
57    /// Add position IDs
58    pub fn with_position_ids(mut self, positions: TensorRef) -> Self {
59        self.position_ids = Some(positions);
60        self
61    }
62
63    /// Get batch size
64    pub fn batch_size(&self) -> usize {
65        self.input_ids.shape()[0]
66    }
67
68    /// Get sequence length
69    pub fn sequence_length(&self) -> usize {
70        if self.input_ids.shape().len() >= 2 {
71            self.input_ids.shape()[1]
72        } else {
73            1
74        }
75    }
76}
77
/// Output from the prefill phase.
///
/// `PrefillOutput::new` leaves the optional analysis fields
/// (`hidden_states`, `attention_weights`) as `None`.
#[derive(Debug, Clone)]
pub struct PrefillOutput {
    /// Logits for all positions, shaped `[batch_size, sequence_length, vocab_size]`.
    pub logits: TensorRef,
    /// KV cache handle populated with the prompt's states. Unlike
    /// `PrefillInput::kv_cache` this is always present.
    /// NOTE(review): presumably this is the handle callers pass on to decode — confirm.
    pub kv_cache: Arc<dyn KvCacheHandle>,
    /// Hidden states at each layer (optional, for analysis).
    pub hidden_states: Option<Vec<TensorRef>>,
    /// Attention weights (optional, for analysis).
    pub attention_weights: Option<Vec<TensorRef>>,
}
90
91impl PrefillOutput {
92    /// Create new prefill output
93    pub fn new(logits: TensorRef, kv_cache: Arc<dyn KvCacheHandle>) -> Self {
94        Self {
95            logits,
96            kv_cache,
97            hidden_states: None,
98            attention_weights: None,
99        }
100    }
101
102    /// Get logits for last position (for next token generation)
103    pub fn last_token_logits(&self) -> Result<TensorRef> {
104        let shape = self.logits.shape();
105        if shape.len() != 3 {
106            return Err(ferrum_types::FerrumError::backend(
107                "Expected 3D logits tensor [batch, seq, vocab]",
108            ));
109        }
110
111        let seq_len = shape[1];
112        if seq_len == 0 {
113            return Err(ferrum_types::FerrumError::backend("Empty sequence"));
114        }
115
116        // Extract last position: [batch, seq-1:seq, vocab] -> [batch, vocab]
117        self.logits
118            .view(&[0, seq_len - 1, 0], &[shape[0], seq_len, shape[2]])
119    }
120}
121
/// Input for the decode phase: generating one token per step.
///
/// Unlike `PrefillInput`, the KV cache handle is mandatory here (its type is
/// not wrapped in `Option`).
#[derive(Debug, Clone)]
pub struct DecodeInput {
    /// Input token ID for the current step, shaped `[batch_size, 1]`.
    pub input_ids: TensorRef,
    /// Existing KV cache from previous steps.
    pub kv_cache: Arc<dyn KvCacheHandle>,
    /// Optional position IDs for the current step, shaped `[batch_size, 1]`.
    pub position_ids: Option<TensorRef>,
    /// Request metadata that can affect model execution, as free-form JSON
    /// values keyed by name.
    pub metadata: HashMap<String, serde_json::Value>,
}
134
135impl DecodeInput {
136    /// Create new decode input
137    pub fn new(input_ids: TensorRef, kv_cache: Arc<dyn KvCacheHandle>) -> Self {
138        Self {
139            input_ids,
140            kv_cache,
141            position_ids: None,
142            metadata: HashMap::new(),
143        }
144    }
145
146    /// Add position IDs
147    pub fn with_position_ids(mut self, positions: TensorRef) -> Self {
148        self.position_ids = Some(positions);
149        self
150    }
151
152    /// Attach request metadata.
153    pub fn with_metadata(mut self, metadata: HashMap<String, serde_json::Value>) -> Self {
154        self.metadata = metadata;
155        self
156    }
157
158    /// Get batch size
159    pub fn batch_size(&self) -> usize {
160        self.input_ids.shape()[0]
161    }
162}
163
/// One sequence's contribution to a unified mixed-batch forward.
///
/// A unified batch lets a single model forward pass process a mix of
/// per-sequence work units in the same call:
/// - a prefill chunk: `q_tokens.len() >= 1`, possibly continuing from
///   `pos_offset > 0` for chunked prefill;
/// - a decode step: `q_tokens.len() == 1`, `pos_offset` = current cache length.
///
/// The model layer concatenates all `q_tokens` into one `[M_total, hidden]`
/// tensor and runs all GEMMs / norms once. Only the attention kernel sees
/// per-item segmentation.
///
/// This is the abstraction that enables vLLM-style chunked prefill. Decode
/// tokens for already-running sequences are produced in the same iteration as
/// a prefill chunk for a newly arriving sequence.
///
/// `Debug` is not derived. The hand-written impl prints only `seq_id`, `q_len`
/// (the token count), `pos_offset` and `is_final_chunk`.
#[derive(Clone)]
pub struct UnifiedBatchItem {
    /// Identifier matching the sequence's KV cache (model-side keying).
    pub seq_id: String,
    /// Tokens to process this iteration. For decode this is exactly 1 token;
    /// for prefill (chunked or whole) it is the chunk's tokens.
    pub q_tokens: Vec<u32>,
    /// KV cache handle for this sequence.
    pub kv_cache: Arc<dyn KvCacheHandle>,
    /// Starting absolute position of the FIRST token in `q_tokens`.
    /// It is 0 for a fresh prefill, and `kv_len` for a decode step or a
    /// continuing chunked-prefill slice.
    pub pos_offset: usize,
    /// True iff this item completes the request's prefill or is a decode item,
    /// i.e. logits at the last token of `q_tokens` should be returned for
    /// sampling. Intermediate prefill chunks set this to false, which skips
    /// the lm_head + sampling path.
    pub is_final_chunk: bool,
    /// Request metadata that can affect model execution.
    pub metadata: HashMap<String, serde_json::Value>,
}
198
199impl std::fmt::Debug for UnifiedBatchItem {
200    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201        f.debug_struct("UnifiedBatchItem")
202            .field("seq_id", &self.seq_id)
203            .field("q_len", &self.q_tokens.len())
204            .field("pos_offset", &self.pos_offset)
205            .field("is_final_chunk", &self.is_final_chunk)
206            .finish()
207    }
208}
209
/// A mixed-batch forward request: any combination of in-progress prefill
/// chunks and decode steps. See [`UnifiedBatchItem`] for the per-item
/// semantics.
///
/// The producer (engine) groups all sequences active in this iteration into
/// a single batch, and the consumer (model) runs one forward.
/// `ModelExecutor::unified_decode` returns one element per item, in `items`
/// order:
/// - `Some(logits)` for items with `is_final_chunk = true`;
/// - `None` for intermediate prefill chunks.
#[derive(Debug, Clone, Default)]
pub struct UnifiedBatch {
    /// Work units for this forward, in the order results are returned.
    pub items: Vec<UnifiedBatchItem>,
}
220
221impl UnifiedBatch {
222    pub fn new() -> Self {
223        Self::default()
224    }
225
226    /// Total query tokens across all items — corresponds to the M dim of
227    /// the model's per-layer GEMMs in the unified forward.
228    pub fn total_q_tokens(&self) -> usize {
229        self.items.iter().map(|it| it.q_tokens.len()).sum()
230    }
231
232    /// Number of items that will produce a logits vector (decode items
233    /// always; prefill items only on their final chunk).
234    pub fn num_sampled_items(&self) -> usize {
235        self.items.iter().filter(|it| it.is_final_chunk).count()
236    }
237}
238
/// Output from the decode phase.
///
/// `DecodeOutput::new` leaves `hidden_state` and `attention_weights` as `None`.
#[derive(Debug, Clone)]
pub struct DecodeOutput {
    /// Logits for the next token, shaped `[batch_size, vocab_size]`.
    pub logits: TensorRef,
    /// KV cache updated with the new token's state.
    pub kv_cache: Arc<dyn KvCacheHandle>,
    /// Hidden state for the current token (optional).
    pub hidden_state: Option<TensorRef>,
    /// Attention weights for the current token (optional).
    pub attention_weights: Option<Vec<TensorRef>>,
}
251
252impl DecodeOutput {
253    /// Create new decode output
254    pub fn new(logits: TensorRef, kv_cache: Arc<dyn KvCacheHandle>) -> Self {
255        Self {
256            logits,
257            kv_cache,
258            hidden_state: None,
259            attention_weights: None,
260        }
261    }
262}
263
264/// Core model executor trait focusing on tensor operations
265#[async_trait]
266pub trait ModelExecutor: Send + Sync {
267    /// Get model information and metadata
268    fn info(&self) -> &ModelInfo;
269
270    /// Whether this executor's backend can run the unified mixed prefill+decode
271    /// forward natively. When false, the engine routes Qwen3-MoE batches through
272    /// the legacy split path. Reported by the (backend-aware) executor so the
273    /// engine stays backend-agnostic — replaces a `cfg(target_os)` branch that
274    /// previously hard-coded "Metal/CPU lack native unified" in the hot path.
275    ///
276    /// Default false (conservative legacy path); accelerators with a native
277    /// unified forward override to true.
278    fn supports_native_unified_decode(&self) -> bool {
279        false
280    }
281
282    /// Per-request KV capacity in tokens when the executor owns a smaller
283    /// runtime cache window than the model's declared context length.
284    fn kv_capacity(&self) -> Option<usize> {
285        None
286    }
287
288    /// Execute prefill phase (process initial prompt)
289    async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput>;
290
291    /// Batch prefill: process multiple prompts' prefill in ONE forward pass.
292    ///
293    /// Default implementation falls back to per-request `prefill()` (serial,
294    /// which is the current behavior the engine sees today). Executors that
295    /// support unified mixed-batch forward (e.g. via `model.unified_forward`
296    /// over a varlen QKV path) should override this to amortize launch /
297    /// kernel-overhead across all `inputs` items in one call.
298    ///
299    /// Used by the continuous-batching engine to coalesce a cohort of new
300    /// prefills (apples M3 c=32 sees 32 simultaneous prefills as one logical
301    /// batch; the serial fallback runs each in ~47 ms while a true batched
302    /// path runs all 32 in ~100 ms).
303    async fn batch_prefill(&self, inputs: &[PrefillInput]) -> Result<Vec<PrefillOutput>> {
304        let mut outputs = Vec::with_capacity(inputs.len());
305        for input in inputs {
306            outputs.push(self.prefill(input).await?);
307        }
308        Ok(outputs)
309    }
310
311    /// Execute decode phase (generate next token)
312    async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput>;
313
314    /// Batch decode: process multiple sequences in one forward pass.
315    ///
316    /// Default implementation falls back to per-request `decode()`.
317    /// Executors with batched CUDA runners should override this.
318    async fn batch_decode(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>> {
319        let mut outputs = Vec::with_capacity(inputs.len());
320        for input in inputs {
321            outputs.push(self.decode(input).await?);
322        }
323        Ok(outputs)
324    }
325
326    /// Unified mixed-batch forward: process a [`UnifiedBatch`] containing
327    /// any combination of prefill chunks (one or more `q_tokens` per item,
328    /// possibly continuing from `pos_offset > 0`) and decode steps
329    /// (`q_tokens.len() == 1`, `is_final_chunk = true`) in a single model
330    /// forward pass.
331    ///
332    /// Returns one element per `batch.items[i]`:
333    /// - `Some(logits)` for items with `is_final_chunk = true` (the
334    ///   request's final-position logits, ready for sampling)
335    /// - `None` for intermediate prefill chunks (no lm_head executed —
336    ///   model only updates KV state)
337    ///
338    /// Default implementation returns `Err(unsupported)`. Concrete LLM
339    /// executors should override with either:
340    /// - A behavioral fallback that dispatches each chunk via existing
341    ///   `prefill()` and groups decode items into `batch_decode()` (this
342    ///   preserves current behavior; no perf change), OR
343    /// - A real unified-forward path that runs all items through one
344    ///   `[M_total, hidden]` GEMM chain with a varlen attention kernel
345    ///   (this is the chunked-prefill perf unlock).
346    async fn unified_decode(&self, _batch: &UnifiedBatch) -> Result<Vec<Option<Vec<f32>>>> {
347        Err(ferrum_types::FerrumError::unsupported(
348            "unified_decode not implemented for this executor",
349        ))
350    }
351
352    /// Optional: full forward pass (for non-autoregressive use cases)
353    async fn forward(&self, _input: &TensorRef) -> Result<TensorRef> {
354        // Default implementation not supported
355        Err(ferrum_types::FerrumError::unsupported(
356            "Full forward pass not supported by this executor",
357        ))
358    }
359
360    /// Roll the KV cache for this executor's sequence back to `new_len`.
361    /// Used by speculative decoding on partial rejection so the next
362    /// iteration sees a KV prefix that matches the accepted token stream.
363    /// Default: Ok(()) — executors that don't cache per-sequence state
364    /// (stub, mock) are inherently tolerant; real LLM executors override.
365    async fn truncate_kv(
366        &self,
367        _kv_cache: &std::sync::Arc<dyn crate::KvCacheHandle>,
368        _new_len: usize,
369    ) -> Result<()> {
370        Ok(())
371    }
372
373    /// Multi-position decode-verify: one forward over `N+1` tokens,
374    /// producing one logits row per position. Used by speculative
375    /// decoding's target path so we don't pay N+1 sequential forwards.
376    ///
377    /// Default falls back to N+1 sequential `decode()` calls — correct
378    /// but slow; real LLM executors override.
379    ///
380    /// Returns a `Vec<DecodeOutput>` of length `inputs.len()` with the
381    /// final KV handle attached to the last element.
382    async fn forward_verify(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>> {
383        let mut out = Vec::with_capacity(inputs.len());
384        for input in inputs {
385            out.push(self.decode(input).await?);
386        }
387        Ok(out)
388    }
389
390    /// Get executor capabilities
391    fn capabilities(&self) -> ExecutorCapabilities;
392
393    /// Get current executor status
394    fn status(&self) -> ExecutorStatus;
395
396    /// Optional model/executor cache metrics.
397    ///
398    /// Concrete LLM executors use this for model-level paged KV prefix reuse
399    /// counters. Default implementations keep non-autoregressive executors
400    /// and tests from needing cache-specific plumbing.
401    fn cache_metrics_snapshot(&self) -> Option<serde_json::Value> {
402        None
403    }
404
405    /// Optional LoRA runtime metrics.
406    fn lora_metrics_snapshot(&self) -> Option<serde_json::Value> {
407        None
408    }
409
410    /// Warm up executor (load model, allocate memory, etc.)
411    async fn warmup(&mut self) -> Result<()> {
412        // Default no-op implementation
413        Ok(())
414    }
415
416    /// Shutdown executor gracefully
417    async fn shutdown(&mut self) -> Result<()> {
418        // Default no-op implementation
419        Ok(())
420    }
421
422    /// Release KV cache and state for a completed sequence.
423    ///
424    /// Called by the engine when a request finishes (success or error) to free
425    /// GPU memory held by the sequence's KV cache. The `cache_id` matches the
426    /// value embedded in the `KvCacheHandle` returned by prefill/decode.
427    fn release_cache(&self, _cache_id: &str) {
428        // Default no-op — executors that manage per-sequence KV caches should override.
429    }
430}
431
/// Executor capabilities and configuration, as returned by
/// `ModelExecutor::capabilities`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorCapabilities {
    /// Maximum supported batch size.
    pub max_batch_size: usize,
    /// Maximum sequence length (presumably in tokens — confirm).
    pub max_sequence_length: usize,
    /// Attention mechanisms this executor supports.
    pub attention_mechanisms: Vec<AttentionType>,
    /// Whether the executor supports dynamic batching.
    pub supports_dynamic_batching: bool,
    /// Whether the executor supports continuous batching.
    pub supports_continuous_batching: bool,
    /// Whether the executor supports speculative decoding.
    pub supports_speculative_decoding: bool,
    /// Whether the executor supports tensor parallelism.
    pub supports_tensor_parallelism: bool,
    /// Whether the executor supports pipeline parallelism.
    pub supports_pipeline_parallelism: bool,
    /// Supported data types.
    pub supported_dtypes: Vec<ferrum_types::DataType>,
    /// Supported devices.
    pub supported_devices: Vec<ferrum_types::Device>,
    /// Memory requirements estimate; `MemoryRequirements::calculate_total_memory`
    /// turns it into a total for a given workload.
    pub memory_requirements: MemoryRequirements,
}
458
459/// Attention mechanism types
460#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
461pub enum AttentionType {
462    /// Standard multi-head attention
463    MultiHead,
464    /// Multi-query attention (MQA)
465    MultiQuery,
466    /// Grouped-query attention (GQA)
467    GroupedQuery,
468    /// Flash attention
469    Flash,
470    /// Paged attention
471    Paged,
472    /// Sliding window attention
473    SlidingWindow,
474}
475
/// Memory requirements for model execution.
///
/// `MemoryRequirements::calculate_total_memory` combines these per-token and
/// fixed costs into a total for a given batch size, sequence length and
/// layer count.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryRequirements {
    /// Model parameter memory in bytes.
    pub parameter_memory: u64,
    /// Minimum activation memory per token (presumably bytes — confirm).
    pub activation_memory_per_token: usize,
    /// KV cache memory per token per layer (presumably bytes). The total
    /// estimate multiplies it by the layer count.
    pub kv_cache_memory_per_token: usize,
    /// Additional overhead memory (presumably bytes).
    pub overhead_memory: u64,
}
488
489impl MemoryRequirements {
490    /// Calculate total memory for given configuration
491    pub fn calculate_total_memory(
492        &self,
493        batch_size: usize,
494        sequence_length: usize,
495        num_layers: usize,
496    ) -> u64 {
497        let activation_mem =
498            (self.activation_memory_per_token * batch_size * sequence_length) as u64;
499        let kv_cache_mem =
500            (self.kv_cache_memory_per_token * batch_size * sequence_length * num_layers) as u64;
501
502        self.parameter_memory + activation_mem + kv_cache_mem + self.overhead_memory
503    }
504}
505
/// Executor status snapshot, as returned by `ModelExecutor::status`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorStatus {
    /// Current executor state.
    pub state: ExecutorState,
    /// Whether the executor is ready to accept requests.
    pub is_ready: bool,
    /// Batch size currently being processed.
    pub current_batch_size: usize,
    /// Number of prefill operations completed.
    pub prefill_operations: u64,
    /// Number of decode operations completed.
    pub decode_operations: u64,
    /// Average prefill time in milliseconds.
    pub avg_prefill_time_ms: f64,
    /// Average decode time in milliseconds.
    pub avg_decode_time_ms: f64,
    /// Memory usage statistics.
    pub memory_usage: ExecutorMemoryUsage,
    /// Timestamp of the last operation.
    ///
    /// `Instant` has no serde support, so the field is skipped: it is omitted
    /// when serializing and comes back as `None` when deserializing.
    #[serde(skip)]
    pub last_operation: Option<std::time::Instant>,
}
529
/// Coarse executor lifecycle state, reported in `ExecutorStatus::state`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExecutorState {
    /// The executor is initializing.
    Initializing,
    /// The executor is ready to accept requests.
    Ready,
    /// The executor is processing requests.
    Busy,
    /// The executor encountered an error.
    Error,
    /// The executor is shutting down.
    Shutdown,
}
544
/// Executor memory usage statistics, reported in `ExecutorStatus::memory_usage`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorMemoryUsage {
    /// Total allocated memory in bytes.
    pub allocated_bytes: usize,
    /// Currently used memory in bytes.
    pub used_bytes: usize,
    /// Peak memory usage (presumably in bytes, like the other fields).
    pub peak_bytes: usize,
    /// Memory utilization percentage.
    /// NOTE(review): the scale (0–100 vs. 0.0–1.0) is not established here — confirm.
    pub utilization_percent: f32,
}
557
/// Batch model executor for processing multiple requests efficiently.
///
/// NOTE(review): `batch_prefill` and `batch_decode` here have the same names
/// and signatures as the default-implemented methods on the `ModelExecutor`
/// supertrait. They are separate trait items, so:
/// - an implementor provides each pair independently, and the two can diverge;
/// - callers with both traits in scope may need fully-qualified syntax, e.g.
///   `BatchModelExecutor::batch_decode(&exec, inputs)`, to choose one.
///
/// Consider consolidating on the supertrait methods.
#[async_trait]
pub trait BatchModelExecutor: ModelExecutor {
    /// Execute batch prefill for multiple sequences.
    async fn batch_prefill(&self, inputs: &[PrefillInput]) -> Result<Vec<PrefillOutput>>;

    /// Execute batch decode for multiple sequences.
    async fn batch_decode(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>>;

    /// Get the optimal batch size for current conditions.
    fn optimal_batch_size(&self) -> usize;

    /// Check whether `batch_size` is supported.
    fn supports_batch_size(&self, batch_size: usize) -> bool;
}
573
/// Speculative execution support: check a run of draft tokens against this
/// model in one call.
#[async_trait]
pub trait SpeculativeExecutor: ModelExecutor {
    /// Execute speculative decoding for `input` with `draft_tokens` proposed by
    /// a draft model.
    ///
    /// Returns the accepted draft tokens, plus logits for the token after the
    /// last accepted one (see `SpeculativeDecodeOutput`).
    ///
    /// NOTE(review): how `acceptance_threshold` is applied (probability ratio,
    /// raw probability, ...) is implementation-defined — confirm per implementor.
    async fn speculative_decode(
        &self,
        input: &DecodeInput,
        draft_tokens: &[ferrum_types::TokenId],
        acceptance_threshold: f32,
    ) -> Result<SpeculativeDecodeOutput>;
}
585
/// Output from speculative decoding (`SpeculativeExecutor::speculative_decode`).
#[derive(Debug, Clone)]
pub struct SpeculativeDecodeOutput {
    /// Accepted tokens (a subset of the draft tokens).
    pub accepted_tokens: Vec<ferrum_types::TokenId>,
    /// Logits for the next token after the last accepted one.
    pub next_logits: TensorRef,
    /// Updated KV cache.
    pub kv_cache: Arc<dyn KvCacheHandle>,
    /// Number of draft tokens accepted.
    /// NOTE(review): presumably equals `accepted_tokens.len()`; the type does
    /// not enforce this — confirm with implementors.
    pub acceptance_count: usize,
}
598
/// Factory that builds executors from an [`ExecutorConfig`].
#[async_trait]
pub trait ModelExecutorFactory: Send + Sync {
    /// Create an executor for `config`.
    async fn create_executor(&self, config: &ExecutorConfig) -> Result<Box<dyn ModelExecutor>>;

    /// Create an executor that also implements [`BatchModelExecutor`].
    async fn create_batch_executor(
        &self,
        config: &ExecutorConfig,
    ) -> Result<Box<dyn BatchModelExecutor>>;

    /// Executor types this factory can produce.
    fn supported_types(&self) -> Vec<ExecutorType>;

    /// Validate `config`, returning an error if it is not acceptable to this factory.
    fn validate_config(&self, config: &ExecutorConfig) -> Result<()>;
}
617
/// Executor configuration, consumed by `ModelExecutorFactory` to build an executor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorConfig {
    /// Model information.
    pub model_info: ModelInfo,
    /// Target device.
    pub device: ferrum_types::Device,
    /// Data type used for computation.
    pub dtype: ferrum_types::DataType,
    /// Maximum batch size.
    pub max_batch_size: usize,
    /// Maximum sequence length (presumably in tokens — confirm).
    pub max_sequence_length: usize,
    /// Runtime attention settings.
    pub attention_config: ExecutorAttentionConfig,
    /// Memory settings.
    pub memory_config: ExecutorMemoryConfig,
    /// Optimization settings.
    pub optimization_config: OptimizationConfig,
    /// Additional executor-specific options, as free-form JSON values keyed by name.
    pub executor_options: HashMap<String, serde_json::Value>,
}
640
/// Runtime attention configuration for a model executor.
///
/// This differs from `ferrum_types::AttentionConfig`:
/// - `ferrum_types::AttentionConfig` describes the model architecture's
///   attention, as read from `config.json`;
/// - this type describes runtime execution settings.
///
/// NOTE(review): `attention_type` and the `enable_*` flags can disagree (e.g.
/// `attention_type = Paged` with `enable_paged_attention = false`). Which one
/// wins is up to the executor — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorAttentionConfig {
    /// Type of attention to use.
    pub attention_type: AttentionType,
    /// Enable flash attention if available.
    pub enable_flash_attention: bool,
    /// Enable paged attention.
    pub enable_paged_attention: bool,
    /// Block size for paged attention (presumably in tokens — confirm).
    pub block_size: Option<usize>,
    /// Sliding-window size, when using sliding-window attention.
    pub sliding_window_size: Option<usize>,
}
659
/// Memory configuration for an executor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorMemoryConfig {
    /// Enable memory pooling.
    pub enable_memory_pooling: bool,
    /// Memory pool size in bytes (`None` for automatic sizing).
    pub memory_pool_size: Option<usize>,
    /// Enable KV cache sharing.
    pub enable_kv_cache_sharing: bool,
    /// Maximum memory usage, as a percentage.
    /// NOTE(review): the scale (0–100 vs. 0.0–1.0) is not established here — confirm.
    pub max_memory_usage: f32,
}
672
/// Optimization configuration for an executor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationConfig {
    /// Enable CUDA graphs (if supported).
    pub enable_cuda_graphs: bool,
    /// Enable kernel fusion.
    pub enable_kernel_fusion: bool,
    /// Enable mixed precision.
    pub enable_mixed_precision: bool,
    /// Optimization level, documented as 0–3. The `u8` type itself does not
    /// enforce that range.
    pub optimization_level: u8,
    /// Custom on/off optimization flags, keyed by name.
    pub custom_flags: HashMap<String, bool>,
}
687
688/// Supported executor types
689#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
690pub enum ExecutorType {
691    /// Standard sequential executor
692    Sequential,
693    /// Batch executor for parallel processing
694    Batch,
695    /// Continuous batching executor
696    ContinuousBatch,
697    /// Speculative decoding executor
698    Speculative,
699    /// Pipeline parallel executor
700    PipelineParallel,
701    /// Tensor parallel executor
702    TensorParallel,
703}
704
/// Executor performance metrics, as returned by `ExecutorRegistry::get_metrics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorMetrics {
    /// Total operations executed.
    pub total_operations: u64,
    /// Prefill operations executed.
    pub prefill_operations: u64,
    /// Decode operations executed.
    pub decode_operations: u64,
    /// Average prefill latency, in milliseconds.
    pub avg_prefill_latency: f64,
    /// Average decode latency, in milliseconds.
    pub avg_decode_latency: f64,
    /// 95th-percentile prefill latency, in milliseconds.
    pub p95_prefill_latency: f64,
    /// 95th-percentile decode latency, in milliseconds.
    pub p95_decode_latency: f64,
    /// Throughput, in tokens per second.
    pub throughput_tps: f64,
    /// Memory efficiency, as the ratio used / allocated.
    pub memory_efficiency: f32,
    /// Batch utilization.
    /// NOTE(review): presumably a fraction of `max_batch_size` — confirm the definition.
    pub batch_utilization: f32,
}
729
/// Registry for managing multiple named executors.
///
/// `register` and `remove` take `&mut self`, so a registry shared across
/// threads needs exclusive access (e.g. a lock) to mutate it.
pub trait ExecutorRegistry: Send + Sync {
    /// Register `executor` under `name`.
    ///
    /// NOTE(review): behavior on a duplicate `name` (error vs. replace) is
    /// implementation-defined — confirm per implementor.
    fn register(&mut self, name: &str, executor: Box<dyn ModelExecutor>) -> Result<()>;

    /// Borrow the executor registered under `name`, if any.
    fn get(&self, name: &str) -> Option<&dyn ModelExecutor>;

    /// Remove and return the executor registered under `name`, if any.
    fn remove(&mut self, name: &str) -> Option<Box<dyn ModelExecutor>>;

    /// List the names of all registered executors.
    fn list_names(&self) -> Vec<String>;

    /// Get metrics for the executor registered under `name`, if available.
    fn get_metrics(&self, name: &str) -> Option<ExecutorMetrics>;
}