use crate::{KvCacheHandle, TensorRef};
use async_trait::async_trait;
use ferrum_types::{ModelInfo, Result};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, sync::Arc};
/// Inputs to a prefill (prompt-processing) pass.
///
/// Only `input_ids` is required; the optional fields can be attached with the
/// `with_*` builder methods on `PrefillInput`.
#[derive(Debug, Clone)]
pub struct PrefillInput {
    /// Token ids to process; the leading dimension is the batch axis
    /// (see `batch_size()` / `sequence_length()`).
    pub input_ids: TensorRef,
    /// Optional attention mask.
    pub attention_mask: Option<TensorRef>,
    /// Optional explicit position ids; `None` lets the executor derive them.
    pub position_ids: Option<TensorRef>,
    /// Optional pre-existing KV cache to continue from.
    pub kv_cache: Option<Arc<dyn KvCacheHandle>>,
}
impl PrefillInput {
    /// Create a prefill input holding only `input_ids`; all optional fields
    /// start as `None` and can be filled in with the `with_*` builders.
    pub fn new(input_ids: TensorRef) -> Self {
        Self {
            input_ids,
            attention_mask: None,
            position_ids: None,
            kv_cache: None,
        }
    }

    /// Attach an existing KV cache to continue from (builder style).
    pub fn with_kv_cache(mut self, kv_cache: Arc<dyn KvCacheHandle>) -> Self {
        self.kv_cache = Some(kv_cache);
        self
    }

    /// Attach an attention mask (builder style).
    pub fn with_attention_mask(mut self, mask: TensorRef) -> Self {
        self.attention_mask = Some(mask);
        self
    }

    /// Attach explicit position ids (builder style).
    pub fn with_position_ids(mut self, positions: TensorRef) -> Self {
        self.position_ids = Some(positions);
        self
    }

    /// Batch size, read from the leading dimension of `input_ids`.
    ///
    /// # Panics
    /// Panics if `input_ids` has no dimensions.
    pub fn batch_size(&self) -> usize {
        self.input_ids.shape()[0]
    }

    /// Sequence length from the second dimension of `input_ids`; a tensor
    /// with fewer than two dimensions is treated as a single-token step.
    pub fn sequence_length(&self) -> usize {
        // Same behavior as the manual `len() >= 2` branch, but with a single
        // `shape()` call and the idiomatic checked lookup.
        self.input_ids.shape().get(1).copied().unwrap_or(1)
    }
}
/// Outputs of a prefill pass.
#[derive(Debug, Clone)]
pub struct PrefillOutput {
    /// Logits from the pass; `last_token_logits` expects `[batch, seq, vocab]`.
    pub logits: TensorRef,
    /// KV cache populated during prefill, reused by subsequent decode steps.
    pub kv_cache: Arc<dyn KvCacheHandle>,
    /// Optional hidden states — presumably one entry per layer; confirm with
    /// the producing executor.
    pub hidden_states: Option<Vec<TensorRef>>,
    /// Optional attention weights, if the executor exposes them.
    pub attention_weights: Option<Vec<TensorRef>>,
}
impl PrefillOutput {
    /// Bundle `logits` with the populated KV cache; extra outputs default to `None`.
    pub fn new(logits: TensorRef, kv_cache: Arc<dyn KvCacheHandle>) -> Self {
        Self {
            logits,
            kv_cache,
            hidden_states: None,
            attention_weights: None,
        }
    }

    /// Slice out the logits for the final position of each sequence.
    ///
    /// # Errors
    /// Returns a backend error if `logits` is not 3-D (`[batch, seq, vocab]`)
    /// or the sequence axis is empty.
    ///
    /// NOTE(review): the second argument to `view` carries `seq_len` on the
    /// sequence axis while the start index is `seq_len - 1`. That is correct
    /// if `view(start, end)` takes exclusive end bounds, but if it takes
    /// `(offset, extent)` the extent should be `[shape[0], 1, shape[2]]` —
    /// confirm against `TensorRef::view`.
    pub fn last_token_logits(&self) -> Result<TensorRef> {
        let shape = self.logits.shape();
        // Expect [batch, seq, vocab].
        if shape.len() != 3 {
            return Err(ferrum_types::FerrumError::backend(
                "Expected 3D logits tensor [batch, seq, vocab]",
            ));
        }
        let seq_len = shape[1];
        if seq_len == 0 {
            return Err(ferrum_types::FerrumError::backend("Empty sequence"));
        }
        self.logits
            .view(&[0, seq_len - 1, 0], &[shape[0], seq_len, shape[2]])
    }
}
/// Inputs to a single decode (token-generation) step.
#[derive(Debug, Clone)]
pub struct DecodeInput {
    /// Newly generated token ids to feed; leading dimension is the batch axis.
    pub input_ids: TensorRef,
    /// KV cache produced by prefill / earlier decode steps (required here,
    /// unlike the optional cache on `PrefillInput`).
    pub kv_cache: Arc<dyn KvCacheHandle>,
    /// Optional explicit position ids; `None` lets the executor derive them.
    pub position_ids: Option<TensorRef>,
}
impl DecodeInput {
pub fn new(input_ids: TensorRef, kv_cache: Arc<dyn KvCacheHandle>) -> Self {
Self {
input_ids,
kv_cache,
position_ids: None,
}
}
pub fn with_position_ids(mut self, positions: TensorRef) -> Self {
self.position_ids = Some(positions);
self
}
pub fn batch_size(&self) -> usize {
self.input_ids.shape()[0]
}
}
/// Outputs of a single decode step.
#[derive(Debug, Clone)]
pub struct DecodeOutput {
    /// Logits for the step.
    pub logits: TensorRef,
    /// Updated KV cache to feed into the next step.
    pub kv_cache: Arc<dyn KvCacheHandle>,
    /// Optional hidden state for the step (single tensor, unlike the
    /// per-layer `Vec` on `PrefillOutput`).
    pub hidden_state: Option<TensorRef>,
    /// Optional attention weights, if the executor exposes them.
    pub attention_weights: Option<Vec<TensorRef>>,
}
impl DecodeOutput {
pub fn new(logits: TensorRef, kv_cache: Arc<dyn KvCacheHandle>) -> Self {
Self {
logits,
kv_cache,
hidden_state: None,
attention_weights: None,
}
}
}
/// Core async execution interface for a loaded model: prefill a prompt, then
/// decode token-by-token against the resulting KV cache.
///
/// NOTE(review): the default `batch_decode` here shares its name with the
/// required method on `BatchModelExecutor`; call sites using both traits may
/// need fully-qualified syntax to disambiguate — consider renaming one.
#[async_trait]
pub trait ModelExecutor: Send + Sync {
    /// Static metadata about the loaded model.
    fn info(&self) -> &ModelInfo;

    /// Run the prefill pass over a prompt.
    async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput>;

    /// Run a single decode step.
    async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput>;

    /// Decode several inputs; the default runs `decode` sequentially and
    /// stops at the first error.
    async fn batch_decode(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>> {
        let mut outputs = Vec::with_capacity(inputs.len());
        for input in inputs {
            outputs.push(self.decode(input).await?);
        }
        Ok(outputs)
    }

    /// Raw forward pass; unsupported unless an implementation overrides it.
    async fn forward(&self, _input: &TensorRef) -> Result<TensorRef> {
        Err(ferrum_types::FerrumError::unsupported(
            "Full forward pass not supported by this executor",
        ))
    }

    /// Capability flags and limits advertised by this executor.
    fn capabilities(&self) -> ExecutorCapabilities;

    /// Current runtime status snapshot.
    fn status(&self) -> ExecutorStatus;

    /// Optional warmup hook; the default is a no-op.
    async fn warmup(&mut self) -> Result<()> {
        Ok(())
    }

    /// Optional shutdown hook; the default is a no-op.
    async fn shutdown(&mut self) -> Result<()> {
        Ok(())
    }

    /// Release any resources tied to `cache_id`; the default is a no-op.
    fn release_cache(&self, _cache_id: &str) {
    }
}
/// Limits and feature flags advertised by a `ModelExecutor`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorCapabilities {
    /// Largest batch the executor accepts.
    pub max_batch_size: usize,
    /// Longest sequence the executor accepts.
    pub max_sequence_length: usize,
    /// Attention implementations the executor can run.
    pub attention_mechanisms: Vec<AttentionType>,
    pub supports_dynamic_batching: bool,
    pub supports_continuous_batching: bool,
    pub supports_speculative_decoding: bool,
    pub supports_tensor_parallelism: bool,
    pub supports_pipeline_parallelism: bool,
    /// Data types the executor can compute in.
    pub supported_dtypes: Vec<ferrum_types::DataType>,
    /// Devices the executor can run on.
    pub supported_devices: Vec<ferrum_types::Device>,
    /// Memory footprint model used for admission/planning.
    pub memory_requirements: MemoryRequirements,
}
/// Attention mechanism variants an executor may implement.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum AttentionType {
    MultiHead,
    MultiQuery,
    GroupedQuery,
    Flash,
    Paged,
    SlidingWindow,
}
/// Memory footprint model for an executor.
///
/// Units are not stated here — presumably bytes throughout; confirm against
/// the producers of these values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryRequirements {
    /// Fixed cost of the model weights.
    pub parameter_memory: u64,
    /// Activation cost per token (scaled by batch × sequence length).
    pub activation_memory_per_token: usize,
    /// KV-cache cost per token (scaled by batch × sequence length × layers).
    pub kv_cache_memory_per_token: usize,
    /// Fixed runtime overhead.
    pub overhead_memory: u64,
}
impl MemoryRequirements {
    /// Estimate total memory for running `batch_size` sequences of
    /// `sequence_length` tokens through `num_layers` layers.
    ///
    /// Sum of parameter, activation, KV-cache, and overhead memory (same
    /// units as the fields). Operands are widened to `u64` before
    /// multiplying and the arithmetic saturates, so oversized requests
    /// cannot wrap on 32-bit `usize` or panic in debug builds (the original
    /// multiplied in `usize` and cast afterwards).
    pub fn calculate_total_memory(
        &self,
        batch_size: usize,
        sequence_length: usize,
        num_layers: usize,
    ) -> u64 {
        // Total tokens processed across the batch.
        let tokens = (batch_size as u64).saturating_mul(sequence_length as u64);
        let activation_mem = (self.activation_memory_per_token as u64).saturating_mul(tokens);
        // KV cache additionally scales with the layer count.
        let kv_cache_mem = (self.kv_cache_memory_per_token as u64)
            .saturating_mul(tokens)
            .saturating_mul(num_layers as u64);
        self.parameter_memory
            .saturating_add(activation_mem)
            .saturating_add(kv_cache_mem)
            .saturating_add(self.overhead_memory)
    }
}
/// Point-in-time runtime status reported by `ModelExecutor::status`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorStatus {
    /// Lifecycle state.
    pub state: ExecutorState,
    /// Whether the executor can accept work right now.
    pub is_ready: bool,
    /// Size of the batch currently being processed.
    pub current_batch_size: usize,
    /// Total prefill operations performed.
    pub prefill_operations: u64,
    /// Total decode operations performed.
    pub decode_operations: u64,
    /// Average prefill latency in milliseconds.
    pub avg_prefill_time_ms: f64,
    /// Average decode latency in milliseconds.
    pub avg_decode_time_ms: f64,
    /// Memory accounting snapshot.
    pub memory_usage: ExecutorMemoryUsage,
    // Not serialized: `Instant` is not serde-compatible; deserializes as None.
    #[serde(skip)]
    pub last_operation: Option<std::time::Instant>,
}
/// Executor lifecycle states.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExecutorState {
    Initializing,
    Ready,
    Busy,
    Error,
    Shutdown,
}
/// Memory accounting snapshot for an executor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorMemoryUsage {
    /// Bytes currently allocated.
    pub allocated_bytes: usize,
    /// Bytes actually in use.
    pub used_bytes: usize,
    /// High-water mark of usage in bytes.
    pub peak_bytes: usize,
    /// Utilization as a percentage — presumably used/allocated × 100; confirm
    /// with the reporter.
    pub utilization_percent: f32,
}
/// Executor extension for explicit batched execution.
///
/// NOTE(review): `batch_decode` here shadows the default method of the same
/// name on `ModelExecutor`; implementors provide both, and call sites may
/// need fully-qualified syntax to pick one — consider renaming.
#[async_trait]
pub trait BatchModelExecutor: ModelExecutor {
    /// Prefill several prompts in one call.
    async fn batch_prefill(&self, inputs: &[PrefillInput]) -> Result<Vec<PrefillOutput>>;

    /// Decode several inputs in one call.
    async fn batch_decode(&self, inputs: &[DecodeInput]) -> Result<Vec<DecodeOutput>>;

    /// Batch size this executor performs best at.
    fn optimal_batch_size(&self) -> usize;

    /// Whether `batch_size` is acceptable to this executor.
    fn supports_batch_size(&self, batch_size: usize) -> bool;
}
/// Executor extension for speculative decoding with a draft-token proposal.
#[async_trait]
pub trait SpeculativeExecutor: ModelExecutor {
    /// Verify `draft_tokens` against the model in one pass.
    ///
    /// `acceptance_threshold` tunes how permissive acceptance is — exact
    /// semantics are implementation-defined; confirm with the implementor.
    async fn speculative_decode(
        &self,
        input: &DecodeInput,
        draft_tokens: &[ferrum_types::TokenId],
        acceptance_threshold: f32,
    ) -> Result<SpeculativeDecodeOutput>;
}
/// Result of one speculative-decode verification pass.
#[derive(Debug, Clone)]
pub struct SpeculativeDecodeOutput {
    /// Draft tokens that were accepted.
    pub accepted_tokens: Vec<ferrum_types::TokenId>,
    /// Logits for choosing the next token after the accepted run.
    pub next_logits: TensorRef,
    /// Updated KV cache.
    pub kv_cache: Arc<dyn KvCacheHandle>,
    /// Number of accepted tokens (mirrors `accepted_tokens.len()` —
    /// TODO confirm they are kept in sync by producers).
    pub acceptance_count: usize,
}
/// Factory for constructing executors from an `ExecutorConfig`.
#[async_trait]
pub trait ModelExecutorFactory: Send + Sync {
    /// Build a plain executor for `config`.
    async fn create_executor(&self, config: &ExecutorConfig) -> Result<Box<dyn ModelExecutor>>;

    /// Build a batch-capable executor for `config`.
    async fn create_batch_executor(
        &self,
        config: &ExecutorConfig,
    ) -> Result<Box<dyn BatchModelExecutor>>;

    /// Executor types this factory can produce.
    fn supported_types(&self) -> Vec<ExecutorType>;

    /// Check `config` without building anything.
    fn validate_config(&self, config: &ExecutorConfig) -> Result<()>;
}
/// Full configuration handed to a `ModelExecutorFactory`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorConfig {
    /// Model to execute.
    pub model_info: ModelInfo,
    /// Target device.
    pub device: ferrum_types::Device,
    /// Compute data type.
    pub dtype: ferrum_types::DataType,
    /// Upper bound on batch size.
    pub max_batch_size: usize,
    /// Upper bound on sequence length.
    pub max_sequence_length: usize,
    /// Attention mechanism settings.
    pub attention_config: ExecutorAttentionConfig,
    /// Memory pooling/sharing settings.
    pub memory_config: ExecutorMemoryConfig,
    /// Codegen/precision optimization settings.
    pub optimization_config: OptimizationConfig,
    /// Backend-specific free-form options keyed by name.
    pub executor_options: HashMap<String, serde_json::Value>,
}
/// Attention-related executor settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorAttentionConfig {
    /// Primary attention mechanism to use.
    pub attention_type: AttentionType,
    pub enable_flash_attention: bool,
    pub enable_paged_attention: bool,
    /// Page/block size — presumably only meaningful when paged attention is
    /// enabled; confirm with the backend.
    pub block_size: Option<usize>,
    /// Window size for sliding-window attention, if used.
    pub sliding_window_size: Option<usize>,
}
/// Memory management settings for an executor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorMemoryConfig {
    pub enable_memory_pooling: bool,
    /// Pool size when pooling is enabled — units not stated here; presumably
    /// bytes, confirm with the backend.
    pub memory_pool_size: Option<usize>,
    pub enable_kv_cache_sharing: bool,
    /// Maximum memory usage — presumably a 0.0–1.0 fraction of device memory;
    /// confirm with the backend.
    pub max_memory_usage: f32,
}
/// Codegen and precision optimization settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationConfig {
    pub enable_cuda_graphs: bool,
    pub enable_kernel_fusion: bool,
    pub enable_mixed_precision: bool,
    /// Aggressiveness knob — scale and meaning are backend-defined; confirm
    /// with the consumer.
    pub optimization_level: u8,
    /// Backend-specific boolean flags keyed by name.
    pub custom_flags: HashMap<String, bool>,
}
/// Kinds of executor a factory can produce.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum ExecutorType {
    Sequential,
    Batch,
    ContinuousBatch,
    Speculative,
    PipelineParallel,
    TensorParallel,
}
/// Aggregated performance metrics for an executor.
///
/// Latency units are not stated here — presumably milliseconds, matching the
/// `*_time_ms` fields on `ExecutorStatus`; confirm with the producer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutorMetrics {
    pub total_operations: u64,
    pub prefill_operations: u64,
    pub decode_operations: u64,
    pub avg_prefill_latency: f64,
    pub avg_decode_latency: f64,
    /// 95th-percentile prefill latency.
    pub p95_prefill_latency: f64,
    /// 95th-percentile decode latency.
    pub p95_decode_latency: f64,
    /// Throughput in tokens per second.
    pub throughput_tps: f64,
    pub memory_efficiency: f32,
    pub batch_utilization: f32,
}
/// Named registry of live executors.
pub trait ExecutorRegistry: Send + Sync {
    /// Register `executor` under `name`; errors are implementation-defined
    /// (e.g. duplicate names — confirm with the implementor).
    fn register(&mut self, name: &str, executor: Box<dyn ModelExecutor>) -> Result<()>;

    /// Borrow the executor registered under `name`, if any.
    fn get(&self, name: &str) -> Option<&dyn ModelExecutor>;

    /// Remove and return the executor registered under `name`, if any.
    fn remove(&mut self, name: &str) -> Option<Box<dyn ModelExecutor>>;

    /// Names of all registered executors.
    fn list_names(&self) -> Vec<String>;

    /// Metrics for the executor registered under `name`, if available.
    fn get_metrics(&self, name: &str) -> Option<ExecutorMetrics>;
}