Skip to main content

ferrum_types/
config.rs

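//! Configuration types for Ferrum components: engine, model, scheduler, KV cache, memory,
//! backend, tokenizer, security, sampling, monitoring and batching settings.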
2
3use crate::{DataType, Device, ModelId, ModelInfo, SamplingParams, SamplingPresets};
4use serde::{Deserialize, Serialize};
5use std::{collections::HashMap, time::Duration};
6
7/// Engine configuration
8#[derive(Debug, Clone, Serialize, Deserialize, Default)]
9pub struct EngineConfig {
10    pub model: EngineModelConfig,
11    pub scheduler: SchedulerConfig,
12    pub sampling: SamplingConfig,
13    pub backend: BackendConfig,
14    pub kv_cache: KvCacheConfig,
15    pub memory: MemoryConfig,
16    pub batching: BatchConfig,
17    pub monitoring: MonitoringConfig,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct EngineModelConfig {
22    pub model_id: ModelId,
23    pub model_info: Option<ModelInfo>,
24    pub tokenizer: TokenizerConfig,
25}
26
27impl Default for EngineModelConfig {
28    fn default() -> Self {
29        Self {
30            model_id: ModelId::new("default"),
31            model_info: None,
32            tokenizer: TokenizerConfig::default(),
33        }
34    }
35}
36
37/// Scheduler configuration
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct SchedulerConfig {
40    /// Scheduling policy
41    pub policy: SchedulingPolicy,
42    /// Maximum waiting queue size
43    pub max_waiting_requests: usize,
44    /// Maximum running requests
45    pub max_running_requests: usize,
46    /// Enable request preemption
47    pub enable_preemption: bool,
48    /// Enable load balancing
49    pub enable_load_balancing: bool,
50    /// Fair share weights per client
51    pub fair_share_weights: HashMap<String, f32>,
52    /// SLA enforcement enabled
53    pub enable_sla_enforcement: bool,
54}
55
56impl Default for SchedulerConfig {
57    fn default() -> Self {
58        Self {
59            policy: SchedulingPolicy::Priority,
60            max_waiting_requests: 1000,
61            max_running_requests: 256,
62            enable_preemption: true,
63            enable_load_balancing: false,
64            fair_share_weights: HashMap::new(),
65            enable_sla_enforcement: false,
66        }
67    }
68}
69
70/// Scheduling policies
71#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
72pub enum SchedulingPolicy {
73    /// First-Come-First-Served
74    FCFS,
75    /// Priority-based scheduling
76    Priority,
77    /// Fair-share scheduling
78    FairShare,
79    /// Shortest-Job-First
80    SJF,
81    /// Round-Robin
82    RoundRobin,
83    /// Iteration-level continuous batching with preemption
84    ContinuousBatch,
85}
86
87/// KV Cache configuration
88#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct KvCacheConfig {
90    /// Cache implementation type
91    pub cache_type: KvCacheType,
92    /// Block size for paged attention
93    pub block_size: usize,
94    /// Maximum number of blocks
95    pub max_blocks: usize,
96    /// Enable cache compression
97    pub enable_compression: bool,
98    /// Compression ratio target
99    pub compression_ratio: f32,
100    /// Enable multi-level caching (GPU + CPU)
101    pub enable_multi_level: bool,
102    /// Swap threshold (when to move to CPU)
103    pub swap_threshold: f32,
104    /// Enable prefix caching
105    pub enable_prefix_caching: bool,
106    /// Prefix cache size
107    pub prefix_cache_size: usize,
108}
109
110impl Default for KvCacheConfig {
111    fn default() -> Self {
112        Self {
113            cache_type: KvCacheType::Contiguous,
114            block_size: 16,
115            max_blocks: 1024,
116            enable_compression: false,
117            compression_ratio: 0.5,
118            enable_multi_level: true,
119            swap_threshold: 0.8,
120            enable_prefix_caching: true,
121            prefix_cache_size: 100,
122        }
123    }
124}
125
126/// KV Cache implementation types
127#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
128pub enum KvCacheType {
129    /// Simple contiguous memory allocation
130    Contiguous,
131    /// Paged attention with block-based allocation
132    Paged,
133    /// Tree-based cache for prefix sharing
134    Tree,
135}
136
137/// Memory management configuration
138#[derive(Debug, Clone, Serialize, Deserialize)]
139pub struct MemoryConfig {
140    /// Memory pool size in bytes
141    pub pool_size: Option<usize>,
142    /// Enable memory pooling
143    pub enable_pooling: bool,
144    /// Memory alignment in bytes
145    pub alignment: usize,
146    /// Enable memory defragmentation
147    pub enable_defragmentation: bool,
148    /// Defragmentation threshold
149    pub defragmentation_threshold: f32,
150    /// Enable memory statistics tracking
151    pub enable_memory_stats: bool,
152    /// Memory pressure warning threshold
153    pub pressure_warning_threshold: f32,
154    /// Memory pressure critical threshold
155    pub pressure_critical_threshold: f32,
156}
157
158impl Default for MemoryConfig {
159    fn default() -> Self {
160        Self {
161            pool_size: None,
162            enable_pooling: true,
163            alignment: 256,
164            enable_defragmentation: false,
165            defragmentation_threshold: 0.7,
166            enable_memory_stats: true,
167            pressure_warning_threshold: 0.8,
168            pressure_critical_threshold: 0.95,
169        }
170    }
171}
172
173/// Backend configuration
174#[derive(Debug, Clone, Serialize, Deserialize)]
175pub struct BackendConfig {
176    /// Backend type
177    pub backend_type: BackendType,
178    /// Target device
179    pub device: Device,
180    /// Data type for computation
181    pub dtype: DataType,
182    /// Enable optimizations
183    pub enable_optimizations: bool,
184    /// Optimization level (0-3)
185    pub optimization_level: u8,
186    /// Enable CUDA graphs
187    pub enable_cuda_graphs: bool,
188    /// Enable kernel fusion
189    pub enable_kernel_fusion: bool,
190    /// Custom backend-specific options
191    pub backend_options: HashMap<String, serde_json::Value>,
192}
193
194impl Default for BackendConfig {
195    fn default() -> Self {
196        Self {
197            backend_type: BackendType::Candle,
198            device: Device::CPU,
199            dtype: DataType::FP16,
200            enable_optimizations: true,
201            optimization_level: 2,
202            enable_cuda_graphs: false,
203            enable_kernel_fusion: true,
204            backend_options: HashMap::new(),
205        }
206    }
207}
208
209/// Supported backend types
210#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
211pub enum BackendType {
212    /// Candle framework
213    Candle,
214    /// ONNX Runtime
215    OnnxRuntime,
216    /// TensorRT
217    TensorRT,
218    /// Custom backend
219    Custom,
220}
221
222/// Tokenizer configuration
223#[derive(Debug, Clone, Serialize, Deserialize)]
224pub struct TokenizerConfig {
225    /// Tokenizer type
226    pub tokenizer_type: TokenizerType,
227    /// Path to tokenizer files
228    pub tokenizer_path: Option<String>,
229    /// Enable fast tokenization
230    pub enable_fast: bool,
231    /// Add special tokens
232    pub add_special_tokens: bool,
233    /// Truncation strategy
234    pub truncation: Option<TruncationConfig>,
235    /// Padding strategy
236    pub padding: Option<PaddingConfig>,
237}
238
239impl Default for TokenizerConfig {
240    fn default() -> Self {
241        Self {
242            tokenizer_type: TokenizerType::BPE,
243            tokenizer_path: None,
244            enable_fast: true,
245            add_special_tokens: true,
246            truncation: None,
247            padding: None,
248        }
249    }
250}
251
252/// Tokenizer algorithms
253#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
254pub enum TokenizerType {
255    /// Byte Pair Encoding
256    BPE,
257    /// WordPiece tokenizer (BERT-style)
258    WordPiece,
259    /// SentencePiece tokenizer
260    SentencePiece,
261    /// Tiktoken tokenizer family
262    Tiktoken,
263    /// Any custom tokenizer implementation
264    Custom,
265}
266
267/// Truncation configuration
268#[derive(Debug, Clone, Serialize, Deserialize)]
269pub struct TruncationConfig {
270    /// Maximum length
271    pub max_length: usize,
272    /// Truncation strategy
273    pub strategy: TruncationStrategy,
274}
275
276/// Truncation strategies
277#[derive(Debug, Clone, Serialize, Deserialize)]
278pub enum TruncationStrategy {
279    /// Remove from the beginning
280    TruncateStart,
281    /// Remove from the end
282    TruncateEnd,
283    /// Remove from both sides
284    TruncateBoth,
285}
286
287/// Padding configuration
288#[derive(Debug, Clone, Serialize, Deserialize)]
289pub struct PaddingConfig {
290    /// Padding strategy
291    pub strategy: PaddingStrategy,
292    /// Padding token ID
293    pub token_id: u32,
294    /// Target length
295    pub target_length: Option<usize>,
296}
297
298/// Padding strategies
299#[derive(Debug, Clone, Serialize, Deserialize)]
300pub enum PaddingStrategy {
301    /// No padding
302    None,
303    /// Pad to maximum length in batch
304    MaxLength,
305    /// Pad to specific length
306    FixedLength,
307}
308
309/// Sampling configuration presets
310
311/// Security configuration
312#[derive(Debug, Clone, Serialize, Deserialize)]
313pub struct SecurityConfig {
314    /// Enable API authentication
315    pub enable_auth: bool,
316    /// API keys for authentication
317    pub api_keys: Vec<String>,
318    /// Enable rate limiting
319    pub enable_rate_limiting: bool,
320    /// Rate limit per client (requests per minute)
321    pub rate_limit_rpm: u32,
322    /// Enable content filtering
323    pub enable_content_filter: bool,
324    /// Maximum prompt length
325    pub max_prompt_length: usize,
326    /// Enable prompt validation
327    pub enable_prompt_validation: bool,
328    /// Allowed file extensions for uploads
329    pub allowed_extensions: Vec<String>,
330}
331
332impl Default for SecurityConfig {
333    fn default() -> Self {
334        Self {
335            enable_auth: false,
336            api_keys: vec![],
337            enable_rate_limiting: true,
338            rate_limit_rpm: 60,
339            enable_content_filter: false,
340            max_prompt_length: 32768,
341            enable_prompt_validation: true,
342            allowed_extensions: vec!["txt".to_string(), "json".to_string()],
343        }
344    }
345}
346
347#[derive(Debug, Clone, Serialize, Deserialize, Default)]
348pub struct SamplingConfig {
349    pub default_params: SamplingParams,
350    pub presets: SamplingPresets,
351    pub enable_custom_processors: bool,
352}
353
354#[derive(Debug, Clone, Serialize, Deserialize)]
355pub struct MonitoringConfig {
356    pub enable_metrics: bool,
357    pub enable_tracing: bool,
358    pub export_interval: Duration,
359}
360
361impl Default for MonitoringConfig {
362    fn default() -> Self {
363        Self {
364            enable_metrics: true,
365            enable_tracing: true,
366            export_interval: Duration::from_secs(5),
367        }
368    }
369}
370
371#[derive(Debug, Clone, Serialize, Deserialize)]
372pub struct BatchConfig {
373    pub max_batch_size: usize,
374    pub max_wait_ms: u64,
375    pub enable_dynamic: bool,
376    pub enable_continuous: bool,
377}
378
379impl Default for BatchConfig {
380    fn default() -> Self {
381        Self {
382            max_batch_size: 16,
383            max_wait_ms: 8,
384            enable_dynamic: true,
385            enable_continuous: false,
386        }
387    }
388}