use crate::{DataType, Device, ModelId, ModelInfo, SamplingParams, SamplingPresets};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, time::Duration};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EngineConfig {
pub model: EngineModelConfig,
pub scheduler: SchedulerConfig,
pub sampling: SamplingConfig,
pub backend: BackendConfig,
pub kv_cache: KvCacheConfig,
pub memory: MemoryConfig,
pub batching: BatchConfig,
pub monitoring: MonitoringConfig,
}
impl Default for EngineConfig {
fn default() -> Self {
Self {
model: EngineModelConfig::default(),
scheduler: SchedulerConfig::default(),
sampling: SamplingConfig::default(),
backend: BackendConfig::default(),
kv_cache: KvCacheConfig::default(),
memory: MemoryConfig::default(),
batching: BatchConfig::default(),
monitoring: MonitoringConfig::default(),
}
}
}
/// Identifies the model to serve and how its text input is tokenized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EngineModelConfig {
    /// Identifier of the model.
    pub model_id: ModelId,
    /// Optional model metadata; `None` by default.
    pub model_info: Option<ModelInfo>,
    /// Tokenizer settings for this model.
    pub tokenizer: TokenizerConfig,
}

impl Default for EngineModelConfig {
    /// Placeholder id `"default"`, no metadata, default tokenizer settings.
    fn default() -> Self {
        let model_id = ModelId::new("default");
        let tokenizer = TokenizerConfig::default();
        Self {
            model_id,
            model_info: None,
            tokenizer,
        }
    }
}
/// Controls how incoming requests are queued and admitted for execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchedulerConfig {
    /// Scheduling policy; defaults to [`SchedulingPolicy::Priority`].
    pub policy: SchedulingPolicy,
    /// Maximum number of waiting requests (default 1000).
    pub max_waiting_requests: usize,
    /// Maximum number of running requests (default 256).
    pub max_running_requests: usize,
    /// Whether preemption is enabled (default `true`).
    pub enable_preemption: bool,
    /// Whether load balancing is enabled (default `false`).
    pub enable_load_balancing: bool,
    /// Per-key weights, presumably consumed by [`SchedulingPolicy::FairShare`]
    /// (confirm against the scheduler); empty by default.
    pub fair_share_weights: HashMap<String, f32>,
    /// Whether SLA enforcement is enabled (default `false`).
    pub enable_sla_enforcement: bool,
}

impl Default for SchedulerConfig {
    fn default() -> Self {
        let policy = SchedulingPolicy::Priority;
        let fair_share_weights = HashMap::default();
        Self {
            policy,
            max_waiting_requests: 1_000,
            max_running_requests: 256,
            enable_preemption: true,
            enable_load_balancing: false,
            fair_share_weights,
            enable_sla_enforcement: false,
        }
    }
}
/// Request-ordering policy selected by [`SchedulerConfig::policy`].
///
/// The behavior of each variant is defined by the scheduler implementation, not
/// here; the notes below only expand the abbreviated names.
// Variant names double as the serde wire format, so the upper-case acronyms stay.
#[allow(clippy::upper_case_acronyms)]
// `PartialEq`/`Eq` added for parity with `TokenizerType`, so policies can be compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SchedulingPolicy {
    /// First-come, first-served (by name).
    FCFS,
    Priority,
    FairShare,
    /// Shortest job first (by name).
    SJF,
    RoundRobin,
    ContinuousBatch,
}
/// KV-cache layout, capacity and reuse settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KvCacheConfig {
    /// Cache layout; defaults to [`KvCacheType::Contiguous`].
    pub cache_type: KvCacheType,
    /// Block size (default 16); units are defined by the cache implementation.
    pub block_size: usize,
    /// Maximum number of blocks (default 1024).
    pub max_blocks: usize,
    /// Whether compression is enabled (default `false`).
    pub enable_compression: bool,
    /// Compression ratio (default 0.5); presumably only read when
    /// `enable_compression` is set — confirm against the cache.
    pub compression_ratio: f32,
    /// Whether multi-level caching is enabled (default `true`).
    pub enable_multi_level: bool,
    /// Swap threshold (default 0.8); presumably a utilization fraction — confirm.
    pub swap_threshold: f32,
    /// Whether prefix caching is enabled (default `true`).
    pub enable_prefix_caching: bool,
    /// Prefix cache size (default 100); units are not established here.
    pub prefix_cache_size: usize,
}

impl Default for KvCacheConfig {
    fn default() -> Self {
        let (enable_compression, compression_ratio) = (false, 0.5);
        let (enable_prefix_caching, prefix_cache_size) = (true, 100);
        Self {
            cache_type: KvCacheType::Contiguous,
            block_size: 16,
            max_blocks: 1_024,
            enable_compression,
            compression_ratio,
            enable_multi_level: true,
            swap_threshold: 0.8,
            enable_prefix_caching,
            prefix_cache_size,
        }
    }
}
/// Storage layout of the KV cache, selected by [`KvCacheConfig::cache_type`].
///
/// How each layout is realized is defined by the cache implementation, not here.
// `PartialEq`/`Eq` added for parity with `TokenizerType`, so layouts can be compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum KvCacheType {
    Contiguous,
    Paged,
    Tree,
}
/// Memory pooling, alignment, defragmentation and pressure-threshold settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryConfig {
    /// Optional pool size; `None` by default. Units are not established here.
    pub pool_size: Option<usize>,
    /// Whether pooling is enabled (default `true`).
    pub enable_pooling: bool,
    /// Alignment (default 256); presumably in bytes — confirm.
    pub alignment: usize,
    /// Whether defragmentation is enabled (default `false`).
    pub enable_defragmentation: bool,
    /// Defragmentation threshold (default 0.7).
    pub defragmentation_threshold: f32,
    /// Whether memory statistics are collected (default `true`).
    pub enable_memory_stats: bool,
    /// Pressure warning threshold (default 0.8).
    pub pressure_warning_threshold: f32,
    /// Pressure critical threshold (default 0.95). Nothing here enforces
    /// `warning <= critical`; NOTE(review): consider validating at load time.
    pub pressure_critical_threshold: f32,
}

impl Default for MemoryConfig {
    fn default() -> Self {
        let (pressure_warning_threshold, pressure_critical_threshold) = (0.8, 0.95);
        let (enable_defragmentation, defragmentation_threshold) = (false, 0.7);
        Self {
            pool_size: None,
            enable_pooling: true,
            alignment: 256,
            enable_defragmentation,
            defragmentation_threshold,
            enable_memory_stats: true,
            pressure_warning_threshold,
            pressure_critical_threshold,
        }
    }
}
/// Selects the compute backend, target device and data type, plus backend tuning flags.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendConfig {
    /// Backend implementation; defaults to [`BackendType::Candle`].
    pub backend_type: BackendType,
    /// Target device; defaults to `Device::CPU`.
    pub device: Device,
    /// Data type; defaults to `DataType::FP16`.
    pub dtype: DataType,
    /// Whether optimizations are enabled (default `true`).
    pub enable_optimizations: bool,
    /// Optimization level (default 2).
    pub optimization_level: u8,
    /// Whether CUDA graphs are enabled (default `false`).
    pub enable_cuda_graphs: bool,
    /// Whether kernel fusion is enabled (default `true`).
    pub enable_kernel_fusion: bool,
    /// Free-form backend-specific options; empty by default.
    pub backend_options: HashMap<String, serde_json::Value>,
}

impl Default for BackendConfig {
    fn default() -> Self {
        let backend_options = HashMap::default();
        Self {
            backend_type: BackendType::Candle,
            device: Device::CPU,
            dtype: DataType::FP16,
            enable_optimizations: true,
            optimization_level: 2,
            enable_cuda_graphs: false,
            enable_kernel_fusion: true,
            backend_options,
        }
    }
}
/// Compute backend implementation, selected by [`BackendConfig::backend_type`].
// `PartialEq`/`Eq` added for parity with `TokenizerType`, so backends can be compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BackendType {
    Candle,
    OnnxRuntime,
    TensorRT,
    /// A backend outside the built-in set; presumably configured through
    /// `BackendConfig::backend_options` — confirm.
    Custom,
}
/// Tokenizer selection and encoding options.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizerConfig {
    /// Tokenizer family; defaults to [`TokenizerType::BPE`].
    pub tokenizer_type: TokenizerType,
    /// Optional path to tokenizer files; `None` by default.
    pub tokenizer_path: Option<String>,
    /// Whether the fast tokenizer is enabled (default `true`).
    pub enable_fast: bool,
    /// Whether special tokens are added (default `true`).
    pub add_special_tokens: bool,
    /// Optional truncation settings; `None` (disabled) by default.
    pub truncation: Option<TruncationConfig>,
    /// Optional padding settings; `None` (disabled) by default.
    pub padding: Option<PaddingConfig>,
}

impl Default for TokenizerConfig {
    fn default() -> Self {
        let (enable_fast, add_special_tokens) = (true, true);
        Self {
            tokenizer_type: TokenizerType::BPE,
            tokenizer_path: None,
            enable_fast,
            add_special_tokens,
            truncation: None,
            padding: None,
        }
    }
}
/// Tokenizer family, selected by [`TokenizerConfig::tokenizer_type`].
// Variant names double as the serde wire format, so `BPE` keeps its upper-case spelling.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub enum TokenizerType {
    /// Byte-pair encoding (by name).
    BPE,
    WordPiece,
    SentencePiece,
    Tiktoken,
    /// A tokenizer outside the built-in set; presumably loaded from
    /// `TokenizerConfig::tokenizer_path` — confirm.
    Custom,
}
/// Truncation settings applied by the tokenizer.
// `PartialEq`/`Eq` make configs comparable, matching `TokenizerType`; they rely on
// `TruncationStrategy` deriving the same traits below (same block, applied together).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TruncationConfig {
    /// Maximum length to keep; presumably counted in tokens — confirm.
    pub max_length: usize,
    /// Which side(s) to truncate from.
    pub strategy: TruncationStrategy,
}

/// Which side(s) of an over-long input are removed.
// Fieldless enum: `Copy`/`PartialEq`/`Eq` added for parity with the other config enums.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TruncationStrategy {
    TruncateStart,
    TruncateEnd,
    TruncateBoth,
}
/// Padding settings applied by the tokenizer.
///
/// NOTE(review): `target_length` is presumably only meaningful with
/// [`PaddingStrategy::FixedLength`], but nothing in the types enforces that pairing.
// `PartialEq`/`Eq` make configs comparable; they rely on `PaddingStrategy`
// deriving the same traits below (same block, applied together).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PaddingConfig {
    /// How the padded length is chosen.
    pub strategy: PaddingStrategy,
    /// Token id used for padding.
    pub token_id: u32,
    /// Optional explicit target length.
    pub target_length: Option<usize>,
}

/// How the padded length is chosen.
///
/// The `None` variant is unrelated to `Option::None`; avoid glob-importing the variants.
// Fieldless enum: `Copy`/`PartialEq`/`Eq` added for parity with the other config enums.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PaddingStrategy {
    None,
    MaxLength,
    FixedLength,
}
/// Authentication, rate-limiting and input-validation settings.
///
/// `Debug` is implemented by hand so that API keys never appear in `{:?}` output
/// (logs, panics, tracing); only the number of configured keys is shown.
// NOTE(review): `Serialize` still emits `api_keys` verbatim; be careful where
// serialized configs are written.
#[derive(Clone, Serialize, Deserialize)]
pub struct SecurityConfig {
    /// Whether authentication is enabled (default `false`).
    pub enable_auth: bool,
    /// Accepted API keys; empty by default. Redacted in `Debug` output.
    pub api_keys: Vec<String>,
    /// Whether rate limiting is enabled (default `true`).
    pub enable_rate_limiting: bool,
    /// Rate limit (default 60); presumably requests per minute, per the `_rpm` suffix.
    pub rate_limit_rpm: u32,
    /// Whether content filtering is enabled (default `false`).
    pub enable_content_filter: bool,
    /// Maximum prompt length (default 32768); units are not established here.
    pub max_prompt_length: usize,
    /// Whether prompt validation is enabled (default `true`).
    pub enable_prompt_validation: bool,
    /// Allowed file extensions (default `["txt", "json"]`).
    pub allowed_extensions: Vec<String>,
}

impl std::fmt::Debug for SecurityConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SecurityConfig")
            .field("enable_auth", &self.enable_auth)
            // Secrets must not leak through Debug; show only how many keys exist.
            .field(
                "api_keys",
                &format_args!("[REDACTED; {} key(s)]", self.api_keys.len()),
            )
            .field("enable_rate_limiting", &self.enable_rate_limiting)
            .field("rate_limit_rpm", &self.rate_limit_rpm)
            .field("enable_content_filter", &self.enable_content_filter)
            .field("max_prompt_length", &self.max_prompt_length)
            .field("enable_prompt_validation", &self.enable_prompt_validation)
            .field("allowed_extensions", &self.allowed_extensions)
            .finish()
    }
}

impl Default for SecurityConfig {
    fn default() -> Self {
        Self {
            enable_auth: false,
            api_keys: vec![],
            enable_rate_limiting: true,
            rate_limit_rpm: 60,
            enable_content_filter: false,
            max_prompt_length: 32768,
            enable_prompt_validation: true,
            allowed_extensions: vec!["txt".to_string(), "json".to_string()],
        }
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplingConfig {
pub default_params: SamplingParams,
pub presets: SamplingPresets,
pub enable_custom_processors: bool,
}
impl Default for SamplingConfig {
fn default() -> Self {
Self {
default_params: SamplingParams::default(),
presets: SamplingPresets::default(),
enable_custom_processors: false,
}
}
}
/// Metrics and tracing settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MonitoringConfig {
    /// Whether metrics are enabled (default `true`).
    pub enable_metrics: bool,
    /// Whether tracing is enabled (default `true`).
    pub enable_tracing: bool,
    /// Export interval (default 5 seconds).
    // NOTE(review): serde encodes `std::time::Duration` as a `{ secs, nanos }`
    // struct, so config files must use that shape for this field.
    pub export_interval: Duration,
}

impl Default for MonitoringConfig {
    fn default() -> Self {
        let export_interval = Duration::from_millis(5_000);
        Self {
            enable_metrics: true,
            enable_tracing: true,
            export_interval,
        }
    }
}
/// Limits that govern how requests are grouped into batches.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchConfig {
    /// Maximum batch size (default 16).
    pub max_batch_size: usize,
    /// Maximum wait (default 8); milliseconds, per the `_ms` suffix.
    pub max_wait_ms: u64,
    /// Whether dynamic batching is enabled (default `true`).
    pub enable_dynamic: bool,
    /// Whether continuous batching is enabled (default `false`).
    pub enable_continuous: bool,
}

impl Default for BatchConfig {
    fn default() -> Self {
        let (enable_dynamic, enable_continuous) = (true, false);
        Self {
            max_batch_size: 16,
            max_wait_ms: 8,
            enable_dynamic,
            enable_continuous,
        }
    }
}