ferrum_interfaces/model_executor.rs
1//! Model execution interface with clear prefill/decode separation
2//!
3//! This module provides the ModelExecutor trait that replaces the "fat" Model
4//! interface, focusing purely on tensor operations without tokenization or sampling.
5
6use crate::{KvCacheHandle, TensorRef};
7use async_trait::async_trait;
8use ferrum_types::{ModelInfo, Result};
9use serde::{Deserialize, Serialize};
10use std::{collections::HashMap, sync::Arc};
11
12/// Input for prefill phase (processing the initial prompt)
13#[derive(Debug, Clone)]
14pub struct PrefillInput {
15 /// Input token IDs [batch_size, sequence_length]
16 pub input_ids: TensorRef,
17 /// Attention mask [batch_size, sequence_length] (optional)
18 pub attention_mask: Option<TensorRef>,
19 /// Position IDs [batch_size, sequence_length] (optional, for RoPE)
20 pub position_ids: Option<TensorRef>,
21 /// Pre-allocated KV cache handle (optional, for paged attention)
22 pub kv_cache: Option<Arc<dyn KvCacheHandle>>,
23 /// Request metadata that can affect model execution.
24 pub metadata: HashMap<String, serde_json::Value>,
25}
26
27impl PrefillInput {
28 /// Create new prefill input
29 pub fn new(input_ids: TensorRef) -> Self {
30 Self {
31 input_ids,
32 attention_mask: None,
33 position_ids: None,
34 kv_cache: None,
35 metadata: HashMap::new(),
36 }
37 }
38
39 /// Create prefill input with a pre-allocated KV cache handle.
40 pub fn with_kv_cache(mut self, kv_cache: Arc<dyn KvCacheHandle>) -> Self {
41 self.kv_cache = Some(kv_cache);
42 self
43 }
44
45 /// Attach request metadata.
46 pub fn with_metadata(mut self, metadata: HashMap<String, serde_json::Value>) -> Self {
47 self.metadata = metadata;
48 self
49 }
50
51 /// Add attention mask
52 pub fn with_attention_mask(mut self, mask: TensorRef) -> Self {
53 self.attention_mask = Some(mask);
54 self
55 }
56
57 /// Add position IDs
58 pub fn with_position_ids(mut self, positions: TensorRef) -> Self {
59 self.position_ids = Some(positions);
60 self
61 }
62
63 /// Get batch size
64 pub fn batch_size(&self) -> usize {
65 self.input_ids.shape()[0]
66 }
67
68 /// Get sequence length
69 pub fn sequence_length(&self) -> usize {
70 if self.input_ids.shape().len() >= 2 {
71 self.input_ids.shape()[1]
72 } else {
73 1
74 }
75 }
76}
77
78/// Output from prefill phase
79#[derive(Debug, Clone)]
80pub struct PrefillOutput {
81 /// Logits for all positions [batch_size, sequence_length, vocab_size]
82 pub logits: TensorRef,
83 /// KV cache handle populated with prompt states
84 pub kv_cache: Arc<dyn KvCacheHandle>,
85 /// Hidden states at each layer (optional, for analysis)
86 pub hidden_states: Option<Vec<TensorRef>>,
87 /// Attention weights (optional, for analysis)
88 pub attention_weights: Option<Vec<TensorRef>>,
89}
90
91impl PrefillOutput {
92 /// Create new prefill output
93 pub fn new(logits: TensorRef, kv_cache: Arc<dyn KvCacheHandle>) -> Self {
94 Self {
95 logits,
96 kv_cache,
97 hidden_states: None,
98 attention_weights: None,
99 }
100 }
101
102 /// Get logits for last position (for next token generation)
103 pub fn last_token_logits(&self) -> Result<TensorRef> {
104 let shape = self.logits.shape();
105 if shape.len() != 3 {
106 return Err(ferrum_types::FerrumError::backend(
107 "Expected 3D logits tensor [batch, seq, vocab]",
108 ));
109 }
110
111 let seq_len = shape[1];
112 if seq_len == 0 {
113 return Err(ferrum_types::FerrumError::backend("Empty sequence"));
114 }
115
116 // Extract last position: [batch, seq-1:seq, vocab] -> [batch, vocab]
117 self.logits
118 .view(&[0, seq_len - 1, 0], &[shape[0], seq_len, shape[2]])
119 }
120}
121
122/// Input for decode phase (generating one token at a time)
123#[derive(Debug, Clone)]
124pub struct DecodeInput {
125 /// Input token ID for current step [batch_size, 1]
126 pub input_ids: TensorRef,
127 /// Existing KV cache from previous steps
128 pub kv_cache: Arc<dyn KvCacheHandle>,
129 /// Position IDs for current step [batch_size, 1] (optional)
130 pub position_ids: Option<TensorRef>,
131 /// Request metadata that can affect model execution.
132 pub metadata: HashMap<String, serde_json::Value>,
133}
134
135impl DecodeInput {
136 /// Create new decode input
137 pub fn new(input_ids: TensorRef, kv_cache: Arc<dyn KvCacheHandle>) -> Self {
138 Self {
139 input_ids,
140 kv_cache,
141 position_ids: None,
142 metadata: HashMap::new(),
143 }
144 }
145
146 /// Add position IDs
147 pub fn with_position_ids(mut self, positions: TensorRef) -> Self {
148 self.position_ids = Some(positions);
149 self
150 }
151
152 /// Attach request metadata.
153 pub fn with_metadata(mut self, metadata: HashMap<String, serde_json::Value>) -> Self {
154 self.metadata = metadata;
155 self
156 }
157
158 /// Get batch size
159 pub fn batch_size(&self) -> usize {
160 self.input_ids.shape()[0]
161 }
162}
163
164/// One sequence's contribution to a unified mixed-batch forward.
165///
166/// A unified batch lets a single model forward pass process a mix of
167/// per-sequence work units: a prefill chunk (q_tokens.len() ≥ 1, possibly
168/// continuing from `pos_offset > 0` for chunked prefill) and a decode step
169/// (q_tokens.len() == 1, `pos_offset` = current cache length) coexist in
170/// the same call. The model layer concatenates all `q_tokens` into one
171/// [M_total, hidden] tensor and runs all GEMMs / norms once; only the
172/// attention kernel sees per-item segmentation.
173///
174/// This is the abstraction that enables vLLM-style chunked prefill where
175/// decode tokens for already-running sequences are produced in the same
176/// iter as a prefill chunk for a newly-arriving sequence.
177#[derive(Clone)]
178pub struct UnifiedBatchItem {
179 /// Identifier matching the sequence's KV cache (model-side keying).
180 pub seq_id: String,
181 /// Tokens to process this iter. For decode this is exactly 1 token;
182 /// for prefill (chunked or whole) this is the chunk's tokens.
183 pub q_tokens: Vec<u32>,
184 /// KV cache handle for this sequence.
185 pub kv_cache: Arc<dyn KvCacheHandle>,
186 /// Starting absolute position for the FIRST token in `q_tokens`.
187 /// 0 for a fresh prefill, `kv_len` for a decode step or a continuing
188 /// chunked-prefill slice.
189 pub pos_offset: usize,
190 /// True iff this item completes the request's prefill (or is a decode
191 /// item) — i.e. logits at the last token of `q_tokens` should be
192 /// returned for sampling. Intermediate prefill chunks set this false
193 /// to skip the lm_head + sampling path.
194 pub is_final_chunk: bool,
195 /// Request metadata that can affect model execution.
196 pub metadata: HashMap<String, serde_json::Value>,
197}
198
199impl std::fmt::Debug for UnifiedBatchItem {
200 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201 f.debug_struct("UnifiedBatchItem")
202 .field("seq_id", &self.seq_id)
203 .field("q_len", &self.q_tokens.len())
204 .field("pos_offset", &self.pos_offset)
205 .field("is_final_chunk", &self.is_final_chunk)
206 .finish()
207 }
208}
209
210/// A mixed-batch forward request: any combination of in-progress prefill
211/// chunks and decode steps. See [`UnifiedBatchItem`] for the per-item
212/// semantics. The producer (engine) groups all sequences active in this
213/// iter into a single batch; the consumer (model) runs one forward and
214/// returns per-item logits (only for items with `is_final_chunk = true`,
215/// in the order they appear in `items`).
216#[derive(Debug, Clone, Default)]
217pub struct UnifiedBatch {
218 pub items: Vec<UnifiedBatchItem>,
219}
220
221impl UnifiedBatch {
222 pub fn new() -> Self {
223 Self::default()
224 }
225
226 /// Total query tokens across all items — corresponds to the M dim of
227 /// the model's per-layer GEMMs in the unified forward.
228 pub fn total_q_tokens(&self) -> usize {
229 self.items.iter().map(|it| it.q_tokens.len()).sum()
230 }
231
232 /// Number of items that will produce a logits vector (decode items
233 /// always; prefill items only on their final chunk).
234 pub fn num_sampled_items(&self) -> usize {
235 self.items.iter().filter(|it| it.is_final_chunk).count()
236 }
237}
238
239/// Output from decode phase
240#[derive(Debug, Clone)]
241pub struct DecodeOutput {
242 /// Logits for next token [batch_size, vocab_size]
243 pub logits: TensorRef,
244 /// Updated KV cache with new token state
245 pub kv_cache: Arc<dyn KvCacheHandle>,
246 /// Hidden state for current token (optional)
247 pub hidden_state: Option<TensorRef>,
248 /// Attention weights for current token (optional)
249 pub attention_weights: Option<Vec<TensorRef>>,
250}
251
252impl DecodeOutput {
253 /// Create new decode output
254 pub fn new(logits: TensorRef, kv_cache: Arc<dyn KvCacheHandle>) -> Self {
255 Self {
256 logits,
257 kv_cache,
258 hidden_state: None,
259 attention_weights: None,
260 }
261 }
262}
263
264/// Core model executor trait focusing on tensor operations
265#[async_trait]
266pub trait ModelExecutor: Send + Sync {
267 /// Get model information and metadata
268 fn info(&self) -> &ModelInfo;
269
270 /// Whether this executor's backend can run the unified mixed prefill+decode
271 /// forward natively. When false, the engine routes Qwen3-MoE batches through
272 /// the legacy split path. Reported by the (backend-aware) executor so the
273 /// engine stays backend-agnostic — replaces a `cfg(target_os)` branch that
274 /// previously hard-coded "Metal/CPU lack native unified" in the hot path.
275 ///
276 /// Default false (conservative legacy path); accelerators with a native
277 /// unified forward override to true.
278 fn supports_native_unified_decode(&self) -> bool {
279 false
280 }
281
282 /// Per-request KV capacity in tokens when the executor owns a smaller
283 /// runtime cache window than the model's declared context length.
284 fn kv_capacity(&self) -> Option<usize> {
285 None
286 }
287
288 /// Execute prefill phase (process initial prompt)
289 async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput>;
290
291 /// Batch prefill: process multiple prompts' prefill in ONE forward pass.
292 ///
293 /// Default implementation falls back to per-request `prefill()` (serial,
294 /// which is the current behavior the engine sees today). Executors that
295 /// support unified mixed-batch forward (e.g. via `model.unified_forward`
296 /// over a varlen QKV path) should override this to amortize launch /
297 /// kernel-overhead across all `inputs` items in one call.
298 ///
299 /// Used by the continuous-batching engine to coalesce a cohort of new
300 /// prefills (apples M3 c=32 sees 32 simultaneous prefills as one logical
301 /// batch; the serial fallback runs each in ~47 ms while a true batched
302 /// path runs all 32 in ~100 ms).
303 async fn batch_prefill(&self, inputs: &[PrefillInput]) -> Result<Vec<PrefillOutput>> {
304 let mut outputs = Vec::with_capacity(inputs.len());
305 for input in inputs {
306 outputs.push(self.prefill(input).await?);
307 }
308 Ok(outputs)
309 }
310
311 /// Execute decode phase (generate next token)
312 async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput>;
313
314 /// Batch decode: process multiple sequences in one forward pass.
315 ///
316 /// Default implementation falls back to per-request `decode()`.
317 /// Executors with batched CUDA runners should override this.
318 async fn batch_decode(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>> {
319 let mut outputs = Vec::with_capacity(inputs.len());
320 for input in inputs {
321 outputs.push(self.decode(input).await?);
322 }
323 Ok(outputs)
324 }
325
326 /// Unified mixed-batch forward: process a [`UnifiedBatch`] containing
327 /// any combination of prefill chunks (one or more `q_tokens` per item,
328 /// possibly continuing from `pos_offset > 0`) and decode steps
329 /// (`q_tokens.len() == 1`, `is_final_chunk = true`) in a single model
330 /// forward pass.
331 ///
332 /// Returns one element per `batch.items[i]`:
333 /// - `Some(logits)` for items with `is_final_chunk = true` (the
334 /// request's final-position logits, ready for sampling)
335 /// - `None` for intermediate prefill chunks (no lm_head executed —
336 /// model only updates KV state)
337 ///
338 /// Default implementation returns `Err(unsupported)`. Concrete LLM
339 /// executors should override with either:
340 /// - A behavioral fallback that dispatches each chunk via existing
341 /// `prefill()` and groups decode items into `batch_decode()` (this
342 /// preserves current behavior; no perf change), OR
343 /// - A real unified-forward path that runs all items through one
344 /// `[M_total, hidden]` GEMM chain with a varlen attention kernel
345 /// (this is the chunked-prefill perf unlock).
346 async fn unified_decode(&self, _batch: &UnifiedBatch) -> Result<Vec<Option<Vec<f32>>>> {
347 Err(ferrum_types::FerrumError::unsupported(
348 "unified_decode not implemented for this executor",
349 ))
350 }
351
352 /// Optional: full forward pass (for non-autoregressive use cases)
353 async fn forward(&self, _input: &TensorRef) -> Result<TensorRef> {
354 // Default implementation not supported
355 Err(ferrum_types::FerrumError::unsupported(
356 "Full forward pass not supported by this executor",
357 ))
358 }
359
360 /// Roll the KV cache for this executor's sequence back to `new_len`.
361 /// Used by speculative decoding on partial rejection so the next
362 /// iteration sees a KV prefix that matches the accepted token stream.
363 /// Default: Ok(()) — executors that don't cache per-sequence state
364 /// (stub, mock) are inherently tolerant; real LLM executors override.
365 async fn truncate_kv(
366 &self,
367 _kv_cache: &std::sync::Arc<dyn crate::KvCacheHandle>,
368 _new_len: usize,
369 ) -> Result<()> {
370 Ok(())
371 }
372
373 /// Multi-position decode-verify: one forward over `N+1` tokens,
374 /// producing one logits row per position. Used by speculative
375 /// decoding's target path so we don't pay N+1 sequential forwards.
376 ///
377 /// Default falls back to N+1 sequential `decode()` calls — correct
378 /// but slow; real LLM executors override.
379 ///
380 /// Returns a `Vec<DecodeOutput>` of length `inputs.len()` with the
381 /// final KV handle attached to the last element.
382 async fn forward_verify(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>> {
383 let mut out = Vec::with_capacity(inputs.len());
384 for input in inputs {
385 out.push(self.decode(input).await?);
386 }
387 Ok(out)
388 }
389
390 /// Get executor capabilities
391 fn capabilities(&self) -> ExecutorCapabilities;
392
393 /// Get current executor status
394 fn status(&self) -> ExecutorStatus;
395
396 /// Optional model/executor cache metrics.
397 ///
398 /// Concrete LLM executors use this for model-level paged KV prefix reuse
399 /// counters. Default implementations keep non-autoregressive executors
400 /// and tests from needing cache-specific plumbing.
401 fn cache_metrics_snapshot(&self) -> Option<serde_json::Value> {
402 None
403 }
404
405 /// Optional LoRA runtime metrics.
406 fn lora_metrics_snapshot(&self) -> Option<serde_json::Value> {
407 None
408 }
409
410 /// Warm up executor (load model, allocate memory, etc.)
411 async fn warmup(&mut self) -> Result<()> {
412 // Default no-op implementation
413 Ok(())
414 }
415
416 /// Shutdown executor gracefully
417 async fn shutdown(&mut self) -> Result<()> {
418 // Default no-op implementation
419 Ok(())
420 }
421
422 /// Release KV cache and state for a completed sequence.
423 ///
424 /// Called by the engine when a request finishes (success or error) to free
425 /// GPU memory held by the sequence's KV cache. The `cache_id` matches the
426 /// value embedded in the `KvCacheHandle` returned by prefill/decode.
427 fn release_cache(&self, _cache_id: &str) {
428 // Default no-op — executors that manage per-sequence KV caches should override.
429 }
430}
431
/// Executor capabilities and configuration, as returned by
/// `ModelExecutor::capabilities`.
///
/// Plain data: nothing in this type checks that the flags agree with what the
/// executor actually does. Producers are responsible for reporting honestly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorCapabilities {
    /// Maximum supported batch size
    pub max_batch_size: usize,
    /// Maximum sequence length (presumably in tokens; confirm with producers)
    pub max_sequence_length: usize,
    /// Supported attention mechanisms
    pub attention_mechanisms: Vec<AttentionType>,
    /// Whether the executor supports dynamic batching
    pub supports_dynamic_batching: bool,
    /// Whether the executor supports continuous batching
    pub supports_continuous_batching: bool,
    /// Whether the executor supports speculative decoding
    pub supports_speculative_decoding: bool,
    /// Whether the executor supports tensor parallelism
    pub supports_tensor_parallelism: bool,
    /// Whether the executor supports pipeline parallelism
    pub supports_pipeline_parallelism: bool,
    /// Supported data types
    pub supported_dtypes: Vec<ferrum_types::DataType>,
    /// Supported devices
    pub supported_devices: Vec<ferrum_types::Device>,
    /// Memory requirement estimates; see `MemoryRequirements::calculate_total_memory`
    pub memory_requirements: MemoryRequirements,
}
458
459/// Attention mechanism types
460#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
461pub enum AttentionType {
462 /// Standard multi-head attention
463 MultiHead,
464 /// Multi-query attention (MQA)
465 MultiQuery,
466 /// Grouped-query attention (GQA)
467 GroupedQuery,
468 /// Flash attention
469 Flash,
470 /// Paged attention
471 Paged,
472 /// Sliding window attention
473 SlidingWindow,
474}
475
/// Memory requirements for model execution.
///
/// Per-token figures are scaled by batch size, sequence length and (for the
/// KV cache) layer count in `calculate_total_memory`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryRequirements {
    /// Model parameter memory in bytes
    pub parameter_memory: u64,
    /// Minimum activation memory per token (summed with byte counts in
    /// `calculate_total_memory`, so presumably bytes; confirm)
    pub activation_memory_per_token: usize,
    /// KV cache memory per token per layer (presumably bytes; confirm)
    pub kv_cache_memory_per_token: usize,
    /// Additional overhead memory (presumably bytes, like `parameter_memory`)
    pub overhead_memory: u64,
}
488
489impl MemoryRequirements {
490 /// Calculate total memory for given configuration
491 pub fn calculate_total_memory(
492 &self,
493 batch_size: usize,
494 sequence_length: usize,
495 num_layers: usize,
496 ) -> u64 {
497 let activation_mem =
498 (self.activation_memory_per_token * batch_size * sequence_length) as u64;
499 let kv_cache_mem =
500 (self.kv_cache_memory_per_token * batch_size * sequence_length * num_layers) as u64;
501
502 self.parameter_memory + activation_mem + kv_cache_mem + self.overhead_memory
503 }
504}
505
/// Executor status information, as returned by `ModelExecutor::status`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorStatus {
    /// Current executor state
    pub state: ExecutorState,
    /// Whether the executor is ready to accept requests
    pub is_ready: bool,
    /// Current batch size being processed
    pub current_batch_size: usize,
    /// Number of prefill operations completed
    pub prefill_operations: u64,
    /// Number of decode operations completed
    pub decode_operations: u64,
    /// Average prefill time in milliseconds
    pub avg_prefill_time_ms: f64,
    /// Average decode time in milliseconds
    pub avg_decode_time_ms: f64,
    /// Memory usage statistics
    pub memory_usage: ExecutorMemoryUsage,
    /// Timestamp of the last operation.
    ///
    /// Excluded from (de)serialization by `#[serde(skip)]`, since `Instant` is
    /// an opaque process-local monotonic reading. A deserialized status
    /// therefore always has `None` here.
    #[serde(skip)]
    pub last_operation: Option<std::time::Instant>,
}
529
/// Lifecycle state of an executor, reported in `ExecutorStatus::state`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExecutorState {
    /// Executor is initializing
    Initializing,
    /// Executor is ready to accept requests
    Ready,
    /// Executor is processing requests
    Busy,
    /// Executor encountered an error
    Error,
    /// Executor is shutting down
    Shutdown,
}
544
/// Executor memory usage snapshot, embedded in `ExecutorStatus`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorMemoryUsage {
    /// Total allocated memory in bytes
    pub allocated_bytes: usize,
    /// Currently used memory in bytes
    pub used_bytes: usize,
    /// Peak memory usage in bytes
    pub peak_bytes: usize,
    /// Memory utilization percentage. Presumably `used / allocated`; the
    /// scale (0..=100 vs 0.0..=1.0) is not enforced here, so confirm with
    /// producers.
    pub utilization_percent: f32,
}
557
/// Batch model executor for processing multiple requests efficiently.
///
/// NOTE(review): `batch_prefill` and `batch_decode` below have the same names
/// and signatures as the default methods on the `ModelExecutor` supertrait.
/// For a type (or `dyn BatchModelExecutor`) that has both, a plain call like
/// `exec.batch_decode(..)` is likely ambiguous (E0034) and needs
/// fully-qualified syntax, e.g. `BatchModelExecutor::batch_decode(&exec, ..)`.
/// Implementing one trait's version does not implement the other's.
#[async_trait]
pub trait BatchModelExecutor: ModelExecutor {
    /// Execute batch prefill for multiple sequences
    async fn batch_prefill(&self, inputs: &[PrefillInput]) -> Result<Vec<PrefillOutput>>;

    /// Execute batch decode for multiple sequences
    async fn batch_decode(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>>;

    /// Get the optimal batch size for current conditions
    fn optimal_batch_size(&self) -> usize;

    /// Check whether `batch_size` is supported
    fn supports_batch_size(&self, batch_size: usize) -> bool;
}
573
/// Speculative execution support: verify a run of draft tokens against this
/// (target) executor in one call.
#[async_trait]
pub trait SpeculativeExecutor: ModelExecutor {
    /// Execute speculative decoding with draft model output.
    ///
    /// `draft_tokens` are the draft model's proposals. The exact meaning and
    /// range of `acceptance_threshold` is implementation-defined; presumably
    /// a probability in 0.0..=1.0 (confirm with implementors).
    async fn speculative_decode(
        &self,
        input: &DecodeInput,
        draft_tokens: &[ferrum_types::TokenId],
        acceptance_threshold: f32,
    ) -> Result<SpeculativeDecodeOutput>;
}
585
/// Output from `SpeculativeExecutor::speculative_decode`.
#[derive(Debug, Clone)]
pub struct SpeculativeDecodeOutput {
    /// Accepted tokens (subset of the draft tokens)
    pub accepted_tokens: Vec<ferrum_types::TokenId>,
    /// Logits for the next token after the last accepted one
    pub next_logits: TensorRef,
    /// Updated KV cache
    pub kv_cache: Arc<dyn KvCacheHandle>,
    /// Number of draft tokens accepted. Presumably equal to
    /// `accepted_tokens.len()`; nothing here enforces that, so verify.
    pub acceptance_count: usize,
}
598
/// Factory that builds executors from an [`ExecutorConfig`].
#[async_trait]
pub trait ModelExecutorFactory: Send + Sync {
    /// Create an executor from the given configuration
    async fn create_executor(&self, config: &ExecutorConfig) -> Result<Box<dyn ModelExecutor>>;

    /// Create a batch-capable executor from the given configuration
    async fn create_batch_executor(
        &self,
        config: &ExecutorConfig,
    ) -> Result<Box<dyn BatchModelExecutor>>;

    /// Executor types this factory can build
    fn supported_types(&self) -> Vec<ExecutorType>;

    /// Validate a configuration, returning an error if this factory cannot use it
    fn validate_config(&self, config: &ExecutorConfig) -> Result<()>;
}
617
/// Executor configuration passed to `ModelExecutorFactory`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorConfig {
    /// Model information
    pub model_info: ModelInfo,
    /// Target device
    pub device: ferrum_types::Device,
    /// Data type for computation
    pub dtype: ferrum_types::DataType,
    /// Maximum batch size
    pub max_batch_size: usize,
    /// Maximum sequence length (presumably in tokens)
    pub max_sequence_length: usize,
    /// Runtime attention configuration
    pub attention_config: ExecutorAttentionConfig,
    /// Memory configuration
    pub memory_config: ExecutorMemoryConfig,
    /// Optimization settings
    pub optimization_config: OptimizationConfig,
    /// Additional executor-specific options (free-form; keys are defined by
    /// each executor implementation)
    pub executor_options: HashMap<String, serde_json::Value>,
}
640
/// Runtime attention configuration for a model executor.
///
/// Note: this differs from `ferrum_types::AttentionConfig`, which describes
/// the model architecture's attention configuration from config.json.
/// This type describes runtime execution settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorAttentionConfig {
    /// Type of attention to use
    pub attention_type: AttentionType,
    /// Enable flash attention if available
    pub enable_flash_attention: bool,
    /// Enable paged attention
    pub enable_paged_attention: bool,
    /// Block size for paged attention (presumably in tokens per block)
    pub block_size: Option<usize>,
    /// Sliding window size, if using sliding window attention
    pub sliding_window_size: Option<usize>,
}
659
/// Memory configuration for an executor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorMemoryConfig {
    /// Enable memory pooling
    pub enable_memory_pooling: bool,
    /// Memory pool size in bytes (`None` for auto)
    pub memory_pool_size: Option<usize>,
    /// Enable KV cache sharing
    pub enable_kv_cache_sharing: bool,
    /// Maximum memory usage percentage. The scale (0..=100 vs 0.0..=1.0) is
    /// not enforced by this type; confirm with the consuming executor.
    pub max_memory_usage: f32,
}
672
/// Optimization configuration for an executor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationConfig {
    /// Enable CUDA graphs (if supported)
    pub enable_cuda_graphs: bool,
    /// Enable kernel fusion
    pub enable_kernel_fusion: bool,
    /// Enable mixed precision
    pub enable_mixed_precision: bool,
    /// Optimization level, intended range 0-3 (the `u8` type does not
    /// enforce the upper bound)
    pub optimization_level: u8,
    /// Custom optimization flags (names are executor-specific)
    pub custom_flags: HashMap<String, bool>,
}
687
688/// Supported executor types
689#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
690pub enum ExecutorType {
691 /// Standard sequential executor
692 Sequential,
693 /// Batch executor for parallel processing
694 Batch,
695 /// Continuous batching executor
696 ContinuousBatch,
697 /// Speculative decoding executor
698 Speculative,
699 /// Pipeline parallel executor
700 PipelineParallel,
701 /// Tensor parallel executor
702 TensorParallel,
703}
704
/// Executor performance metrics, as returned by `ExecutorRegistry::get_metrics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorMetrics {
    /// Total operations executed
    pub total_operations: u64,
    /// Prefill operations
    pub prefill_operations: u64,
    /// Decode operations
    pub decode_operations: u64,
    /// Average prefill latency (ms)
    pub avg_prefill_latency: f64,
    /// Average decode latency (ms)
    pub avg_decode_latency: f64,
    /// P95 prefill latency (ms)
    pub p95_prefill_latency: f64,
    /// P95 decode latency (ms)
    pub p95_decode_latency: f64,
    /// Throughput (tokens per second)
    pub throughput_tps: f64,
    /// Memory efficiency as a ratio of used/allocated
    pub memory_efficiency: f32,
    /// Batch utilization (scale not specified here; presumably a 0.0..=1.0 ratio)
    pub batch_utilization: f32,
}
729
/// Registry for managing multiple named executors.
///
/// Registration and removal take `&mut self`; lookups take `&self`.
pub trait ExecutorRegistry: Send + Sync {
    /// Register an executor under `name`
    fn register(&mut self, name: &str, executor: Box<dyn ModelExecutor>) -> Result<()>;

    /// Borrow the executor registered under `name`, if any
    fn get(&self, name: &str) -> Option<&dyn ModelExecutor>;

    /// Remove and return the executor registered under `name`, if any
    fn remove(&mut self, name: &str) -> Option<Box<dyn ModelExecutor>>;

    /// List registered executor names
    fn list_names(&self) -> Vec<String>;

    /// Get metrics for the executor registered under `name`, if available
    fn get_metrics(&self, name: &str) -> Option<ExecutorMetrics>;
}