// ferrum_types/config.rs

1//! Configuration types for Ferrum components
2
3use crate::{DataType, Device, ModelId, ModelInfo, SamplingParams, SamplingPresets};
4use serde::{Deserialize, Serialize};
5use std::{collections::HashMap, time::Duration};
6
7/// Engine configuration
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct EngineConfig {
10    pub model: EngineModelConfig,
11    pub scheduler: SchedulerConfig,
12    pub sampling: SamplingConfig,
13    pub backend: BackendConfig,
14    pub kv_cache: KvCacheConfig,
15    pub memory: MemoryConfig,
16    pub batching: BatchConfig,
17    pub monitoring: MonitoringConfig,
18}
19
20impl Default for EngineConfig {
21    fn default() -> Self {
22        Self {
23            model: EngineModelConfig::default(),
24            scheduler: SchedulerConfig::default(),
25            sampling: SamplingConfig::default(),
26            backend: BackendConfig::default(),
27            kv_cache: KvCacheConfig::default(),
28            memory: MemoryConfig::default(),
29            batching: BatchConfig::default(),
30            monitoring: MonitoringConfig::default(),
31        }
32    }
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct EngineModelConfig {
37    pub model_id: ModelId,
38    pub model_info: Option<ModelInfo>,
39    pub tokenizer: TokenizerConfig,
40}
41
42impl Default for EngineModelConfig {
43    fn default() -> Self {
44        Self {
45            model_id: ModelId::new("default"),
46            model_info: None,
47            tokenizer: TokenizerConfig::default(),
48        }
49    }
50}
51
52/// Scheduler configuration
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct SchedulerConfig {
55    /// Scheduling policy
56    pub policy: SchedulingPolicy,
57    /// Maximum waiting queue size
58    pub max_waiting_requests: usize,
59    /// Maximum running requests
60    pub max_running_requests: usize,
61    /// Enable request preemption
62    pub enable_preemption: bool,
63    /// Enable load balancing
64    pub enable_load_balancing: bool,
65    /// Fair share weights per client
66    pub fair_share_weights: HashMap<String, f32>,
67    /// SLA enforcement enabled
68    pub enable_sla_enforcement: bool,
69}
70
71impl Default for SchedulerConfig {
72    fn default() -> Self {
73        Self {
74            policy: SchedulingPolicy::Priority,
75            max_waiting_requests: 1000,
76            max_running_requests: 256,
77            enable_preemption: true,
78            enable_load_balancing: false,
79            fair_share_weights: HashMap::new(),
80            enable_sla_enforcement: false,
81        }
82    }
83}
84
85/// Scheduling policies
86#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
87pub enum SchedulingPolicy {
88    /// First-Come-First-Served
89    FCFS,
90    /// Priority-based scheduling
91    Priority,
92    /// Fair-share scheduling
93    FairShare,
94    /// Shortest-Job-First
95    SJF,
96    /// Round-Robin
97    RoundRobin,
98    /// Iteration-level continuous batching with preemption
99    ContinuousBatch,
100}
101
102/// KV Cache configuration
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct KvCacheConfig {
105    /// Cache implementation type
106    pub cache_type: KvCacheType,
107    /// Block size for paged attention
108    pub block_size: usize,
109    /// Maximum number of blocks
110    pub max_blocks: usize,
111    /// Enable cache compression
112    pub enable_compression: bool,
113    /// Compression ratio target
114    pub compression_ratio: f32,
115    /// Enable multi-level caching (GPU + CPU)
116    pub enable_multi_level: bool,
117    /// Swap threshold (when to move to CPU)
118    pub swap_threshold: f32,
119    /// Enable prefix caching
120    pub enable_prefix_caching: bool,
121    /// Prefix cache size
122    pub prefix_cache_size: usize,
123}
124
125impl Default for KvCacheConfig {
126    fn default() -> Self {
127        Self {
128            cache_type: KvCacheType::Contiguous,
129            block_size: 16,
130            max_blocks: 1024,
131            enable_compression: false,
132            compression_ratio: 0.5,
133            enable_multi_level: true,
134            swap_threshold: 0.8,
135            enable_prefix_caching: true,
136            prefix_cache_size: 100,
137        }
138    }
139}
140
141/// KV Cache implementation types
142#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
143pub enum KvCacheType {
144    /// Simple contiguous memory allocation
145    Contiguous,
146    /// Paged attention with block-based allocation
147    Paged,
148    /// Tree-based cache for prefix sharing
149    Tree,
150}
151
152/// Memory management configuration
153#[derive(Debug, Clone, Serialize, Deserialize)]
154pub struct MemoryConfig {
155    /// Memory pool size in bytes
156    pub pool_size: Option<usize>,
157    /// Enable memory pooling
158    pub enable_pooling: bool,
159    /// Memory alignment in bytes
160    pub alignment: usize,
161    /// Enable memory defragmentation
162    pub enable_defragmentation: bool,
163    /// Defragmentation threshold
164    pub defragmentation_threshold: f32,
165    /// Enable memory statistics tracking
166    pub enable_memory_stats: bool,
167    /// Memory pressure warning threshold
168    pub pressure_warning_threshold: f32,
169    /// Memory pressure critical threshold
170    pub pressure_critical_threshold: f32,
171}
172
173impl Default for MemoryConfig {
174    fn default() -> Self {
175        Self {
176            pool_size: None,
177            enable_pooling: true,
178            alignment: 256,
179            enable_defragmentation: false,
180            defragmentation_threshold: 0.7,
181            enable_memory_stats: true,
182            pressure_warning_threshold: 0.8,
183            pressure_critical_threshold: 0.95,
184        }
185    }
186}
187
188/// Backend configuration
189#[derive(Debug, Clone, Serialize, Deserialize)]
190pub struct BackendConfig {
191    /// Backend type
192    pub backend_type: BackendType,
193    /// Target device
194    pub device: Device,
195    /// Data type for computation
196    pub dtype: DataType,
197    /// Enable optimizations
198    pub enable_optimizations: bool,
199    /// Optimization level (0-3)
200    pub optimization_level: u8,
201    /// Enable CUDA graphs
202    pub enable_cuda_graphs: bool,
203    /// Enable kernel fusion
204    pub enable_kernel_fusion: bool,
205    /// Custom backend-specific options
206    pub backend_options: HashMap<String, serde_json::Value>,
207}
208
209impl Default for BackendConfig {
210    fn default() -> Self {
211        Self {
212            backend_type: BackendType::Candle,
213            device: Device::CPU,
214            dtype: DataType::FP16,
215            enable_optimizations: true,
216            optimization_level: 2,
217            enable_cuda_graphs: false,
218            enable_kernel_fusion: true,
219            backend_options: HashMap::new(),
220        }
221    }
222}
223
224/// Supported backend types
225#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
226pub enum BackendType {
227    /// Candle framework
228    Candle,
229    /// ONNX Runtime
230    OnnxRuntime,
231    /// TensorRT
232    TensorRT,
233    /// Custom backend
234    Custom,
235}
236
237/// Tokenizer configuration
238#[derive(Debug, Clone, Serialize, Deserialize)]
239pub struct TokenizerConfig {
240    /// Tokenizer type
241    pub tokenizer_type: TokenizerType,
242    /// Path to tokenizer files
243    pub tokenizer_path: Option<String>,
244    /// Enable fast tokenization
245    pub enable_fast: bool,
246    /// Add special tokens
247    pub add_special_tokens: bool,
248    /// Truncation strategy
249    pub truncation: Option<TruncationConfig>,
250    /// Padding strategy
251    pub padding: Option<PaddingConfig>,
252}
253
254impl Default for TokenizerConfig {
255    fn default() -> Self {
256        Self {
257            tokenizer_type: TokenizerType::BPE,
258            tokenizer_path: None,
259            enable_fast: true,
260            add_special_tokens: true,
261            truncation: None,
262            padding: None,
263        }
264    }
265}
266
267/// Tokenizer algorithms
268#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
269pub enum TokenizerType {
270    /// Byte Pair Encoding
271    BPE,
272    /// WordPiece tokenizer (BERT-style)
273    WordPiece,
274    /// SentencePiece tokenizer
275    SentencePiece,
276    /// Tiktoken tokenizer family
277    Tiktoken,
278    /// Any custom tokenizer implementation
279    Custom,
280}
281
282/// Truncation configuration
283#[derive(Debug, Clone, Serialize, Deserialize)]
284pub struct TruncationConfig {
285    /// Maximum length
286    pub max_length: usize,
287    /// Truncation strategy
288    pub strategy: TruncationStrategy,
289}
290
291/// Truncation strategies
292#[derive(Debug, Clone, Serialize, Deserialize)]
293pub enum TruncationStrategy {
294    /// Remove from the beginning
295    TruncateStart,
296    /// Remove from the end
297    TruncateEnd,
298    /// Remove from both sides
299    TruncateBoth,
300}
301
302/// Padding configuration
303#[derive(Debug, Clone, Serialize, Deserialize)]
304pub struct PaddingConfig {
305    /// Padding strategy
306    pub strategy: PaddingStrategy,
307    /// Padding token ID
308    pub token_id: u32,
309    /// Target length
310    pub target_length: Option<usize>,
311}
312
313/// Padding strategies
314#[derive(Debug, Clone, Serialize, Deserialize)]
315pub enum PaddingStrategy {
316    /// No padding
317    None,
318    /// Pad to maximum length in batch
319    MaxLength,
320    /// Pad to specific length
321    FixedLength,
322}
323
324/// Sampling configuration presets
325
326/// Security configuration
327#[derive(Debug, Clone, Serialize, Deserialize)]
328pub struct SecurityConfig {
329    /// Enable API authentication
330    pub enable_auth: bool,
331    /// API keys for authentication
332    pub api_keys: Vec<String>,
333    /// Enable rate limiting
334    pub enable_rate_limiting: bool,
335    /// Rate limit per client (requests per minute)
336    pub rate_limit_rpm: u32,
337    /// Enable content filtering
338    pub enable_content_filter: bool,
339    /// Maximum prompt length
340    pub max_prompt_length: usize,
341    /// Enable prompt validation
342    pub enable_prompt_validation: bool,
343    /// Allowed file extensions for uploads
344    pub allowed_extensions: Vec<String>,
345}
346
347impl Default for SecurityConfig {
348    fn default() -> Self {
349        Self {
350            enable_auth: false,
351            api_keys: vec![],
352            enable_rate_limiting: true,
353            rate_limit_rpm: 60,
354            enable_content_filter: false,
355            max_prompt_length: 32768,
356            enable_prompt_validation: true,
357            allowed_extensions: vec!["txt".to_string(), "json".to_string()],
358        }
359    }
360}
361
362#[derive(Debug, Clone, Serialize, Deserialize)]
363pub struct SamplingConfig {
364    pub default_params: SamplingParams,
365    pub presets: SamplingPresets,
366    pub enable_custom_processors: bool,
367}
368
369impl Default for SamplingConfig {
370    fn default() -> Self {
371        Self {
372            default_params: SamplingParams::default(),
373            presets: SamplingPresets::default(),
374            enable_custom_processors: false,
375        }
376    }
377}
378
379#[derive(Debug, Clone, Serialize, Deserialize)]
380pub struct MonitoringConfig {
381    pub enable_metrics: bool,
382    pub enable_tracing: bool,
383    pub export_interval: Duration,
384}
385
386impl Default for MonitoringConfig {
387    fn default() -> Self {
388        Self {
389            enable_metrics: true,
390            enable_tracing: true,
391            export_interval: Duration::from_secs(5),
392        }
393    }
394}
395
396#[derive(Debug, Clone, Serialize, Deserialize)]
397pub struct BatchConfig {
398    pub max_batch_size: usize,
399    pub max_wait_ms: u64,
400    pub enable_dynamic: bool,
401    pub enable_continuous: bool,
402}
403
404impl Default for BatchConfig {
405    fn default() -> Self {
406        Self {
407            max_batch_size: 16,
408            max_wait_ms: 8,
409            enable_dynamic: true,
410            enable_continuous: false,
411        }
412    }
413}