// ferrum_types/config.rs

//! Configuration types for Ferrum components

use crate::{
    parse_bool_env_value, parse_usize_env_value, DataType, Device, ModelId, ModelInfo,
    RuntimeConfigSnapshot, SamplingParams, SamplingPresets,
};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, time::Duration};

10/// Engine configuration
11#[derive(Debug, Clone, Serialize, Deserialize, Default)]
12pub struct EngineConfig {
13    pub model: EngineModelConfig,
14    pub scheduler: SchedulerConfig,
15    pub sampling: SamplingConfig,
16    pub backend: BackendConfig,
17    pub kv_cache: KvCacheConfig,
18    pub memory: MemoryConfig,
19    pub batching: BatchConfig,
20    pub monitoring: MonitoringConfig,
21}
22
23impl EngineConfig {
24    pub fn apply_runtime_config_snapshot(
25        &mut self,
26        snapshot: &RuntimeConfigSnapshot,
27    ) -> std::result::Result<(), String> {
28        self.scheduler.apply_runtime_config_snapshot(snapshot)?;
29        if let Some(value) = runtime_config_value(snapshot, "FERRUM_KV_MAX_BLOCKS") {
30            self.kv_cache.max_blocks =
31                parse_required_positive_usize("FERRUM_KV_MAX_BLOCKS", value)?;
32        }
33        if let Some(value) = runtime_config_value(snapshot, "FERRUM_MAX_BATCHED_TOKENS") {
34            self.batching.max_num_batched_tokens =
35                parse_required_positive_usize("FERRUM_MAX_BATCHED_TOKENS", value)?;
36        }
37        Ok(())
38    }
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct EngineModelConfig {
43    pub model_id: ModelId,
44    pub model_info: Option<ModelInfo>,
45    pub tokenizer: TokenizerConfig,
46}
47
48impl Default for EngineModelConfig {
49    fn default() -> Self {
50        Self {
51            model_id: ModelId::new("default"),
52            model_info: None,
53            tokenizer: TokenizerConfig::default(),
54        }
55    }
56}
57
58/// Scheduler configuration
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct SchedulerConfig {
61    /// Scheduling policy
62    pub policy: SchedulingPolicy,
63    /// Maximum waiting queue size
64    pub max_waiting_requests: usize,
65    /// Maximum running requests
66    pub max_running_requests: usize,
67    /// Enable request preemption
68    pub enable_preemption: bool,
69    /// Enable load balancing
70    pub enable_load_balancing: bool,
71    /// Fair share weights per client
72    pub fair_share_weights: HashMap<String, f32>,
73    /// SLA enforcement enabled
74    pub enable_sla_enforcement: bool,
75    /// Use prompt-token metadata for initial continuous-batch admission estimates.
76    #[serde(default)]
77    pub prompt_token_estimate: bool,
78    /// Prefer new prefills over early decodes until this many requests are active.
79    #[serde(default)]
80    pub prefill_first_until_active: Option<usize>,
81    /// Cap prefill admission chunks only while decode requests are already active.
82    #[serde(default)]
83    pub active_decode_prefill_chunk: Option<usize>,
84    /// Emit diagnostic scheduler None/SOME decisions.
85    #[serde(default)]
86    pub scheduler_none_prof: bool,
87}
88
89impl Default for SchedulerConfig {
90    fn default() -> Self {
91        Self {
92            policy: SchedulingPolicy::Priority,
93            max_waiting_requests: 1000,
94            max_running_requests: 32,
95            enable_preemption: true,
96            enable_load_balancing: false,
97            fair_share_weights: HashMap::new(),
98            enable_sla_enforcement: false,
99            prompt_token_estimate: false,
100            prefill_first_until_active: None,
101            active_decode_prefill_chunk: None,
102            scheduler_none_prof: false,
103        }
104    }
105}
106
107impl SchedulerConfig {
108    pub fn apply_runtime_config_snapshot(
109        &mut self,
110        snapshot: &RuntimeConfigSnapshot,
111    ) -> std::result::Result<(), String> {
112        if let Some(value) = runtime_config_value(snapshot, "FERRUM_SCHED_PROMPT_TOKEN_ESTIMATE") {
113            self.prompt_token_estimate = parse_bool_env_value(value)
114                .map_err(|reason| format!("FERRUM_SCHED_PROMPT_TOKEN_ESTIMATE: {reason}"))?;
115        }
116        if let Some(value) =
117            runtime_config_value(snapshot, "FERRUM_SCHED_PREFILL_FIRST_UNTIL_ACTIVE")
118        {
119            self.prefill_first_until_active =
120                parse_optional_positive_usize("FERRUM_SCHED_PREFILL_FIRST_UNTIL_ACTIVE", value)?;
121        }
122        if let Some(value) = runtime_config_value(snapshot, "FERRUM_ACTIVE_DECODE_PREFILL_CHUNK") {
123            self.active_decode_prefill_chunk =
124                parse_optional_positive_usize("FERRUM_ACTIVE_DECODE_PREFILL_CHUNK", value)?;
125        }
126        if let Some(value) = runtime_config_value(snapshot, "FERRUM_SCHED_NONE_PROF") {
127            self.scheduler_none_prof = parse_presence_bool(value)
128                .map_err(|reason| format!("FERRUM_SCHED_NONE_PROF: {reason}"))?;
129        }
130        Ok(())
131    }
132}
133
134fn runtime_config_value<'a>(snapshot: &'a RuntimeConfigSnapshot, key: &str) -> Option<&'a str> {
135    snapshot
136        .entries
137        .iter()
138        .find(|entry| entry.key == key)
139        .map(|entry| entry.effective_value.as_str())
140}
141
142fn parse_optional_positive_usize(
143    key: &str,
144    value: &str,
145) -> std::result::Result<Option<usize>, String> {
146    let parsed = parse_usize_env_value(value).map_err(|reason| format!("{key}: {reason}"))?;
147    Ok((parsed > 0).then_some(parsed))
148}
149
150fn parse_required_positive_usize(key: &str, value: &str) -> std::result::Result<usize, String> {
151    let parsed = parse_usize_env_value(value).map_err(|reason| format!("{key}: {reason}"))?;
152    if parsed == 0 {
153        Err(format!("{key}: must be greater than zero"))
154    } else {
155        Ok(parsed)
156    }
157}
158
159fn parse_presence_bool(value: &str) -> std::result::Result<bool, String> {
160    if value.trim().is_empty() {
161        Ok(true)
162    } else {
163        parse_bool_env_value(value)
164    }
165}
166
167/// Scheduling policies
168#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
169pub enum SchedulingPolicy {
170    /// First-Come-First-Served
171    FCFS,
172    /// Priority-based scheduling
173    Priority,
174    /// Fair-share scheduling
175    FairShare,
176    /// Shortest-Job-First
177    SJF,
178    /// Round-Robin
179    RoundRobin,
180    /// Iteration-level continuous batching with preemption
181    ContinuousBatch,
182}
183
184/// KV Cache configuration
185#[derive(Debug, Clone, Serialize, Deserialize)]
186pub struct KvCacheConfig {
187    /// Cache implementation type
188    pub cache_type: KvCacheType,
189    /// Element dtype (Dim 5 polymorphism point). FP16 is the
190    /// validated production path; INT8 / FP8 require a backend impl
191    /// of `BackendKvDtype<KvInt8>` / `BackendKvDtype<KvFp8>` and a
192    /// model wired through `KvCacheQuant<B, K>`.
193    #[serde(default)]
194    pub dtype: KvCacheDtype,
195    /// Block size for paged attention
196    pub block_size: usize,
197    /// Maximum number of blocks
198    pub max_blocks: usize,
199    /// Enable cache compression
200    pub enable_compression: bool,
201    /// Compression ratio target
202    pub compression_ratio: f32,
203    /// Enable multi-level caching (GPU + CPU)
204    pub enable_multi_level: bool,
205    /// Swap threshold (when to move to CPU)
206    pub swap_threshold: f32,
207    /// Enable prefix caching
208    pub enable_prefix_caching: bool,
209    /// Prefix cache size
210    pub prefix_cache_size: usize,
211}
212
213impl Default for KvCacheConfig {
214    fn default() -> Self {
215        // 2048 blocks covers c=32 ShareGPT prompts (~32Ɨ500/16 = 1000
216        // blocks). The previous 1024 floor crashed at c≄16 on real
217        // workloads with "Block pool exhausted". Runtime overrides are
218        // applied through EngineConfig::apply_runtime_config_snapshot.
219        Self {
220            cache_type: KvCacheType::Contiguous,
221            dtype: KvCacheDtype::default(),
222            block_size: 16,
223            max_blocks: 2048,
224            enable_compression: false,
225            compression_ratio: 0.5,
226            enable_multi_level: true,
227            swap_threshold: 0.8,
228            enable_prefix_caching: true,
229            prefix_cache_size: 100,
230        }
231    }
232}
233
234/// KV Cache implementation types
235#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
236pub enum KvCacheType {
237    /// Simple contiguous memory allocation
238    Contiguous,
239    /// Paged attention with block-based allocation
240    Paged,
241    /// Tree-based cache for prefix sharing
242    Tree,
243}
244
245/// KV Cache element dtype (Dim 5 polymorphism point).
246///
247/// Mirrors `ferrum_interfaces::kv_dtype::KvDtypeKind` markers but
248/// lives here because `KvCacheConfig` is part of the user-facing
249/// `EngineConfig` and needs `Serialize` / `Deserialize`.
250#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
251#[serde(rename_all = "lowercase")]
252pub enum KvCacheDtype {
253    /// FP16 K/V — the validated production path on every backend.
254    #[default]
255    Fp16,
256    /// BF16 K/V — same memory cost as FP16, slightly different precision.
257    /// Marker only; no backend impl ships yet.
258    Bf16,
259    /// INT8 K/V with per-token per-kv-head FP16 scale (vLLM-style).
260    /// Halves KV memory at small (<1%) accuracy hit. CUDA kernels
261    /// land via `BackendKvDtype<KvInt8>` (PR #131); model wire-up
262    /// (`KvCacheQuant<B, KvInt8>` through the model decode loop) is
263    /// the only remaining step.
264    Int8,
265    /// FP8 (E4M3) K/V. Marker only; CUDA kernels pending.
266    Fp8,
267}
268
269impl KvCacheDtype {
270    /// Parse from a CLI / env-var string. Accepts fp16/f16/bf16/int8/fp8/f8e4m3.
271    pub fn parse(s: &str) -> Option<Self> {
272        match s.trim().to_ascii_lowercase().as_str() {
273            "fp16" | "f16" | "float16" => Some(Self::Fp16),
274            "bf16" | "bfloat16" => Some(Self::Bf16),
275            "int8" | "i8" => Some(Self::Int8),
276            "fp8" | "f8" | "f8e4m3" | "e4m3" => Some(Self::Fp8),
277            _ => None,
278        }
279    }
280
281    /// Short label for display + telemetry.
282    pub fn as_str(&self) -> &'static str {
283        match self {
284            Self::Fp16 => "fp16",
285            Self::Bf16 => "bf16",
286            Self::Int8 => "int8",
287            Self::Fp8 => "fp8",
288        }
289    }
290}
291
292/// Memory management configuration
293#[derive(Debug, Clone, Serialize, Deserialize)]
294pub struct MemoryConfig {
295    /// Memory pool size in bytes
296    pub pool_size: Option<usize>,
297    /// Enable memory pooling
298    pub enable_pooling: bool,
299    /// Memory alignment in bytes
300    pub alignment: usize,
301    /// Enable memory defragmentation
302    pub enable_defragmentation: bool,
303    /// Defragmentation threshold
304    pub defragmentation_threshold: f32,
305    /// Enable memory statistics tracking
306    pub enable_memory_stats: bool,
307    /// Memory pressure warning threshold
308    pub pressure_warning_threshold: f32,
309    /// Memory pressure critical threshold
310    pub pressure_critical_threshold: f32,
311}
312
313impl Default for MemoryConfig {
314    fn default() -> Self {
315        Self {
316            pool_size: None,
317            enable_pooling: true,
318            alignment: 256,
319            enable_defragmentation: false,
320            defragmentation_threshold: 0.7,
321            enable_memory_stats: true,
322            pressure_warning_threshold: 0.8,
323            pressure_critical_threshold: 0.95,
324        }
325    }
326}
327
328/// Backend configuration
329#[derive(Debug, Clone, Serialize, Deserialize)]
330pub struct BackendConfig {
331    /// Backend type
332    pub backend_type: BackendType,
333    /// Target device
334    pub device: Device,
335    /// Data type for computation
336    pub dtype: DataType,
337    /// Enable optimizations
338    pub enable_optimizations: bool,
339    /// Optimization level (0-3)
340    pub optimization_level: u8,
341    /// Enable CUDA graphs
342    pub enable_cuda_graphs: bool,
343    /// Enable kernel fusion
344    pub enable_kernel_fusion: bool,
345    /// Custom backend-specific options
346    pub backend_options: HashMap<String, serde_json::Value>,
347}
348
349impl Default for BackendConfig {
350    fn default() -> Self {
351        Self {
352            backend_type: BackendType::Candle,
353            device: Device::CPU,
354            dtype: DataType::FP16,
355            enable_optimizations: true,
356            optimization_level: 2,
357            enable_cuda_graphs: false,
358            enable_kernel_fusion: true,
359            backend_options: HashMap::new(),
360        }
361    }
362}
363
364/// Supported backend types
365#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
366pub enum BackendType {
367    /// Candle framework
368    Candle,
369    /// ONNX Runtime
370    OnnxRuntime,
371    /// TensorRT
372    TensorRT,
373    /// Custom backend
374    Custom,
375}
376
377/// Tokenizer configuration
378#[derive(Debug, Clone, Serialize, Deserialize)]
379pub struct TokenizerConfig {
380    /// Tokenizer type
381    pub tokenizer_type: TokenizerType,
382    /// Path to tokenizer files
383    pub tokenizer_path: Option<String>,
384    /// Enable fast tokenization
385    pub enable_fast: bool,
386    /// Add special tokens
387    pub add_special_tokens: bool,
388    /// Truncation strategy
389    pub truncation: Option<TruncationConfig>,
390    /// Padding strategy
391    pub padding: Option<PaddingConfig>,
392}
393
394impl Default for TokenizerConfig {
395    fn default() -> Self {
396        Self {
397            tokenizer_type: TokenizerType::BPE,
398            tokenizer_path: None,
399            enable_fast: true,
400            add_special_tokens: true,
401            truncation: None,
402            padding: None,
403        }
404    }
405}
406
407/// Tokenizer algorithms
408#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
409pub enum TokenizerType {
410    /// Byte Pair Encoding
411    BPE,
412    /// WordPiece tokenizer (BERT-style)
413    WordPiece,
414    /// SentencePiece tokenizer
415    SentencePiece,
416    /// Tiktoken tokenizer family
417    Tiktoken,
418    /// Any custom tokenizer implementation
419    Custom,
420}
421
422/// Truncation configuration
423#[derive(Debug, Clone, Serialize, Deserialize)]
424pub struct TruncationConfig {
425    /// Maximum length
426    pub max_length: usize,
427    /// Truncation strategy
428    pub strategy: TruncationStrategy,
429}
430
431/// Truncation strategies
432#[derive(Debug, Clone, Serialize, Deserialize)]
433pub enum TruncationStrategy {
434    /// Remove from the beginning
435    TruncateStart,
436    /// Remove from the end
437    TruncateEnd,
438    /// Remove from both sides
439    TruncateBoth,
440}
441
442/// Padding configuration
443#[derive(Debug, Clone, Serialize, Deserialize)]
444pub struct PaddingConfig {
445    /// Padding strategy
446    pub strategy: PaddingStrategy,
447    /// Padding token ID
448    pub token_id: u32,
449    /// Target length
450    pub target_length: Option<usize>,
451}
452
453/// Padding strategies
454#[derive(Debug, Clone, Serialize, Deserialize)]
455pub enum PaddingStrategy {
456    /// No padding
457    None,
458    /// Pad to maximum length in batch
459    MaxLength,
460    /// Pad to specific length
461    FixedLength,
462}
463
464/// Sampling configuration presets
465
466/// Security configuration
467#[derive(Debug, Clone, Serialize, Deserialize)]
468pub struct SecurityConfig {
469    /// Enable API authentication
470    pub enable_auth: bool,
471    /// API keys for authentication
472    pub api_keys: Vec<String>,
473    /// Enable rate limiting
474    pub enable_rate_limiting: bool,
475    /// Rate limit per client (requests per minute)
476    pub rate_limit_rpm: u32,
477    /// Enable content filtering
478    pub enable_content_filter: bool,
479    /// Maximum prompt length
480    pub max_prompt_length: usize,
481    /// Enable prompt validation
482    pub enable_prompt_validation: bool,
483    /// Allowed file extensions for uploads
484    pub allowed_extensions: Vec<String>,
485}
486
487impl Default for SecurityConfig {
488    fn default() -> Self {
489        Self {
490            enable_auth: false,
491            api_keys: vec![],
492            enable_rate_limiting: true,
493            rate_limit_rpm: 60,
494            enable_content_filter: false,
495            max_prompt_length: 32768,
496            enable_prompt_validation: true,
497            allowed_extensions: vec!["txt".to_string(), "json".to_string()],
498        }
499    }
500}
501
502#[derive(Debug, Clone, Serialize, Deserialize, Default)]
503pub struct SamplingConfig {
504    pub default_params: SamplingParams,
505    pub presets: SamplingPresets,
506    pub enable_custom_processors: bool,
507}
508
509#[derive(Debug, Clone, Serialize, Deserialize)]
510pub struct MonitoringConfig {
511    pub enable_metrics: bool,
512    pub enable_tracing: bool,
513    pub export_interval: Duration,
514}
515
516impl Default for MonitoringConfig {
517    fn default() -> Self {
518        Self {
519            enable_metrics: true,
520            enable_tracing: true,
521            export_interval: Duration::from_secs(5),
522        }
523    }
524}
525
526#[derive(Debug, Clone, Serialize, Deserialize)]
527pub struct BatchConfig {
528    pub max_batch_size: usize,
529    pub max_wait_ms: u64,
530    pub enable_dynamic: bool,
531    pub enable_continuous: bool,
532    /// vLLM-style per-iteration token budget. The scheduler emits a
533    /// mixed prefill+decode batch summing to at most this many Q
534    /// tokens (decode = 1 each, prefill chunk = its chunk size).
535    /// Default 2048. Runtime snapshots can override this with
536    /// `FERRUM_MAX_BATCHED_TOKENS`, usually from the GPU autosizer or a
537    /// named workload preset rather than a user hand-written env bundle.
538    #[serde(default = "BatchConfig::default_max_num_batched_tokens")]
539    pub max_num_batched_tokens: usize,
540}
541
542impl BatchConfig {
543    fn default_max_num_batched_tokens() -> usize {
544        2048
545    }
546}
547
548impl Default for BatchConfig {
549    fn default() -> Self {
550        Self {
551            max_batch_size: 32,
552            max_wait_ms: 8,
553            enable_dynamic: true,
554            enable_continuous: false,
555            max_num_batched_tokens: Self::default_max_num_batched_tokens(),
556        }
557    }
558}