Skip to main content

ferrum_types/
config.rs

1//! Configuration types for Ferrum components
2
3use crate::{
4    parse_bool_env_value, parse_usize_env_value, DataType, Device, ModelId, ModelInfo,
5    RuntimeConfigSnapshot, SamplingParams, SamplingPresets,
6};
7use serde::{Deserialize, Serialize};
8use std::{collections::HashMap, time::Duration};
9
/// Engine configuration.
///
/// Aggregates one configuration section per subsystem. `Default` is derived,
/// so each section starts from its own `Default` implementation.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct EngineConfig {
    /// Model identifier, optional model metadata and tokenizer settings.
    pub model: EngineModelConfig,
    /// Scheduling policy, queue limits and scheduler tuning flags.
    pub scheduler: SchedulerConfig,
    /// Default sampling parameters and sampling presets.
    pub sampling: SamplingConfig,
    /// Backend type, device, dtype and optimization switches.
    pub backend: BackendConfig,
    /// KV cache type, dtype and sizing.
    pub kv_cache: KvCacheConfig,
    /// Memory pooling, alignment and pressure thresholds.
    pub memory: MemoryConfig,
    /// Batch sizing and the per-iteration token budget.
    pub batching: BatchConfig,
    /// Metrics / tracing switches and the export interval.
    pub monitoring: MonitoringConfig,
}
22
23impl EngineConfig {
24    pub fn apply_runtime_config_snapshot(
25        &mut self,
26        snapshot: &RuntimeConfigSnapshot,
27    ) -> std::result::Result<(), String> {
28        self.scheduler.apply_runtime_config_snapshot(snapshot)?;
29        if let Some(value) = runtime_config_value(snapshot, "FERRUM_KV_MAX_BLOCKS") {
30            self.kv_cache.max_blocks =
31                parse_required_positive_usize("FERRUM_KV_MAX_BLOCKS", value)?;
32        }
33        if let Some(value) = runtime_config_value(snapshot, "FERRUM_MAX_BATCHED_TOKENS") {
34            self.batching.max_num_batched_tokens =
35                parse_required_positive_usize("FERRUM_MAX_BATCHED_TOKENS", value)?;
36        }
37        if let Some(value) = runtime_config_value(snapshot, "FERRUM_PAGED_MAX_SEQS") {
38            self.scheduler.max_running_requests =
39                parse_required_positive_usize("FERRUM_PAGED_MAX_SEQS", value)?;
40        }
41        Ok(())
42    }
43}
44
/// Model-related engine settings: model identifier, optional model metadata
/// and tokenizer configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EngineModelConfig {
    /// Identifier of the model (`"default"` in the `Default` impl).
    pub model_id: ModelId,
    /// Optional model metadata; `None` in the `Default` impl.
    pub model_info: Option<ModelInfo>,
    /// Tokenizer settings.
    pub tokenizer: TokenizerConfig,
}
51
52impl Default for EngineModelConfig {
53    fn default() -> Self {
54        Self {
55            model_id: ModelId::new("default"),
56            model_info: None,
57            tokenizer: TokenizerConfig::default(),
58        }
59    }
60}
61
/// Scheduler configuration.
///
/// The last four fields are marked `#[serde(default)]`, so serialized configs
/// that omit them still deserialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchedulerConfig {
    /// Scheduling policy
    pub policy: SchedulingPolicy,
    /// Maximum waiting queue size
    pub max_waiting_requests: usize,
    /// Maximum running requests.
    /// Runtime override key: `FERRUM_PAGED_MAX_SEQS` (must be greater than zero).
    pub max_running_requests: usize,
    /// Enable request preemption
    pub enable_preemption: bool,
    /// Enable load balancing
    pub enable_load_balancing: bool,
    /// Fair share weights per client
    pub fair_share_weights: HashMap<String, f32>,
    /// SLA enforcement enabled
    pub enable_sla_enforcement: bool,
    /// Use prompt-token metadata for initial continuous-batch admission estimates.
    /// Runtime override key: `FERRUM_SCHED_PROMPT_TOKEN_ESTIMATE`.
    #[serde(default)]
    pub prompt_token_estimate: bool,
    /// Prefer new prefills over early decodes until this many requests are active.
    /// Runtime override key: `FERRUM_SCHED_PREFILL_FIRST_UNTIL_ACTIVE`
    /// (a value of `0` resets this to `None`).
    #[serde(default)]
    pub prefill_first_until_active: Option<usize>,
    /// Cap prefill admission chunks only while decode requests are already active.
    /// Runtime override key: `FERRUM_ACTIVE_DECODE_PREFILL_CHUNK`
    /// (a value of `0` resets this to `None`).
    #[serde(default)]
    pub active_decode_prefill_chunk: Option<usize>,
    /// Emit diagnostic scheduler None/SOME decisions.
    /// Runtime override key: `FERRUM_SCHED_NONE_PROF` (an empty value counts as `true`).
    #[serde(default)]
    pub scheduler_none_prof: bool,
}
92
93impl Default for SchedulerConfig {
94    fn default() -> Self {
95        Self {
96            policy: SchedulingPolicy::Priority,
97            max_waiting_requests: 1000,
98            max_running_requests: 32,
99            enable_preemption: true,
100            enable_load_balancing: false,
101            fair_share_weights: HashMap::new(),
102            enable_sla_enforcement: false,
103            prompt_token_estimate: false,
104            prefill_first_until_active: None,
105            active_decode_prefill_chunk: None,
106            scheduler_none_prof: false,
107        }
108    }
109}
110
111impl SchedulerConfig {
112    pub fn apply_runtime_config_snapshot(
113        &mut self,
114        snapshot: &RuntimeConfigSnapshot,
115    ) -> std::result::Result<(), String> {
116        if let Some(value) = runtime_config_value(snapshot, "FERRUM_SCHED_PROMPT_TOKEN_ESTIMATE") {
117            self.prompt_token_estimate = parse_bool_env_value(value)
118                .map_err(|reason| format!("FERRUM_SCHED_PROMPT_TOKEN_ESTIMATE: {reason}"))?;
119        }
120        if let Some(value) =
121            runtime_config_value(snapshot, "FERRUM_SCHED_PREFILL_FIRST_UNTIL_ACTIVE")
122        {
123            self.prefill_first_until_active =
124                parse_optional_positive_usize("FERRUM_SCHED_PREFILL_FIRST_UNTIL_ACTIVE", value)?;
125        }
126        if let Some(value) = runtime_config_value(snapshot, "FERRUM_ACTIVE_DECODE_PREFILL_CHUNK") {
127            self.active_decode_prefill_chunk =
128                parse_optional_positive_usize("FERRUM_ACTIVE_DECODE_PREFILL_CHUNK", value)?;
129        }
130        if let Some(value) = runtime_config_value(snapshot, "FERRUM_SCHED_NONE_PROF") {
131            self.scheduler_none_prof = parse_presence_bool(value)
132                .map_err(|reason| format!("FERRUM_SCHED_NONE_PROF: {reason}"))?;
133        }
134        Ok(())
135    }
136}
137
138fn runtime_config_value<'a>(snapshot: &'a RuntimeConfigSnapshot, key: &str) -> Option<&'a str> {
139    snapshot
140        .entries
141        .iter()
142        .find(|entry| entry.key == key)
143        .map(|entry| entry.effective_value.as_str())
144}
145
146fn parse_optional_positive_usize(
147    key: &str,
148    value: &str,
149) -> std::result::Result<Option<usize>, String> {
150    let parsed = parse_usize_env_value(value).map_err(|reason| format!("{key}: {reason}"))?;
151    Ok((parsed > 0).then_some(parsed))
152}
153
154fn parse_required_positive_usize(key: &str, value: &str) -> std::result::Result<usize, String> {
155    let parsed = parse_usize_env_value(value).map_err(|reason| format!("{key}: {reason}"))?;
156    if parsed == 0 {
157        Err(format!("{key}: must be greater than zero"))
158    } else {
159        Ok(parsed)
160    }
161}
162
163fn parse_presence_bool(value: &str) -> std::result::Result<bool, String> {
164    if value.trim().is_empty() {
165        Ok(true)
166    } else {
167        parse_bool_env_value(value)
168    }
169}
170
/// Scheduling policies selectable through `SchedulerConfig::policy`.
///
/// `SchedulerConfig::default()` selects [`SchedulingPolicy::Priority`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum SchedulingPolicy {
    /// First-Come-First-Served
    FCFS,
    /// Priority-based scheduling
    Priority,
    /// Fair-share scheduling
    FairShare,
    /// Shortest-Job-First
    SJF,
    /// Round-Robin
    RoundRobin,
    /// Iteration-level continuous batching with preemption
    ContinuousBatch,
}
187
/// KV cache configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KvCacheConfig {
    /// Cache implementation type
    pub cache_type: KvCacheType,
    /// Element dtype (Dim 5 polymorphism point). FP16 is the
    /// validated production path; INT8 / FP8 require a backend impl
    /// of `BackendKvDtype<KvInt8>` / `BackendKvDtype<KvFp8>` and a
    /// model wired through `KvCacheQuant<B, K>`.
    ///
    /// Marked `#[serde(default)]`: configs that omit it get `KvCacheDtype::default()`.
    #[serde(default)]
    pub dtype: KvCacheDtype,
    /// Block size for paged attention
    pub block_size: usize,
    /// Maximum number of blocks.
    /// Runtime override key: `FERRUM_KV_MAX_BLOCKS` (must be greater than zero).
    pub max_blocks: usize,
    /// Enable cache compression
    pub enable_compression: bool,
    /// Compression ratio target
    pub compression_ratio: f32,
    /// Enable multi-level caching (GPU + CPU)
    pub enable_multi_level: bool,
    /// Swap threshold (when to move to CPU)
    pub swap_threshold: f32,
    /// Enable prefix caching
    pub enable_prefix_caching: bool,
    /// Prefix cache size
    pub prefix_cache_size: usize,
}
216
217impl Default for KvCacheConfig {
218    fn default() -> Self {
219        // 2048 blocks covers c=32 ShareGPT prompts (~32Ɨ500/16 = 1000
220        // blocks). The previous 1024 floor crashed at c≄16 on real
221        // workloads with "Block pool exhausted". Runtime overrides are
222        // applied through EngineConfig::apply_runtime_config_snapshot.
223        Self {
224            cache_type: KvCacheType::Contiguous,
225            dtype: KvCacheDtype::default(),
226            block_size: 16,
227            max_blocks: 2048,
228            enable_compression: false,
229            compression_ratio: 0.5,
230            enable_multi_level: true,
231            swap_threshold: 0.8,
232            enable_prefix_caching: true,
233            prefix_cache_size: 100,
234        }
235    }
236}
237
/// KV cache implementation types.
///
/// `KvCacheConfig::default()` selects [`KvCacheType::Contiguous`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum KvCacheType {
    /// Simple contiguous memory allocation
    Contiguous,
    /// Paged attention with block-based allocation
    Paged,
    /// Tree-based cache for prefix sharing
    Tree,
}
248
/// KV Cache element dtype (Dim 5 polymorphism point).
///
/// Mirrors `ferrum_interfaces::kv_dtype::KvDtypeKind` markers but
/// lives here because `KvCacheConfig` is part of the user-facing
/// `EngineConfig` and needs `Serialize` / `Deserialize`.
///
/// With `rename_all = "lowercase"` the variants (de)serialize as
/// `fp16`, `bf16`, `int8` and `fp8`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum KvCacheDtype {
    /// FP16 K/V — the validated production path on every backend.
    /// This is the `Default` variant.
    #[default]
    Fp16,
    /// BF16 K/V — same memory cost as FP16, slightly different precision.
    /// Marker only; no backend impl ships yet.
    Bf16,
    /// INT8 K/V with per-token per-kv-head FP16 scale (vLLM-style).
    /// Halves KV memory at small (<1%) accuracy hit. CUDA kernels
    /// land via `BackendKvDtype<KvInt8>` (PR #131); model wire-up
    /// (`KvCacheQuant<B, KvInt8>` through the model decode loop) is
    /// the only remaining step.
    Int8,
    /// FP8 (E4M3) K/V. Marker only; CUDA kernels pending.
    Fp8,
}
272
273impl KvCacheDtype {
274    /// Parse from a CLI / env-var string. Accepts fp16/f16/bf16/int8/fp8/f8e4m3.
275    pub fn parse(s: &str) -> Option<Self> {
276        match s.trim().to_ascii_lowercase().as_str() {
277            "fp16" | "f16" | "float16" => Some(Self::Fp16),
278            "bf16" | "bfloat16" => Some(Self::Bf16),
279            "int8" | "i8" => Some(Self::Int8),
280            "fp8" | "f8" | "f8e4m3" | "e4m3" => Some(Self::Fp8),
281            _ => None,
282        }
283    }
284
285    /// Short label for display + telemetry.
286    pub fn as_str(&self) -> &'static str {
287        match self {
288            Self::Fp16 => "fp16",
289            Self::Bf16 => "bf16",
290            Self::Int8 => "int8",
291            Self::Fp8 => "fp8",
292        }
293    }
294}
295
/// Memory management configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryConfig {
    /// Memory pool size in bytes; `None` in the `Default` impl.
    pub pool_size: Option<usize>,
    /// Enable memory pooling
    pub enable_pooling: bool,
    /// Memory alignment in bytes
    pub alignment: usize,
    /// Enable memory defragmentation
    pub enable_defragmentation: bool,
    /// Defragmentation threshold
    pub defragmentation_threshold: f32,
    /// Enable memory statistics tracking
    pub enable_memory_stats: bool,
    /// Memory pressure warning threshold
    pub pressure_warning_threshold: f32,
    /// Memory pressure critical threshold
    pub pressure_critical_threshold: f32,
}
316
317impl Default for MemoryConfig {
318    fn default() -> Self {
319        Self {
320            pool_size: None,
321            enable_pooling: true,
322            alignment: 256,
323            enable_defragmentation: false,
324            defragmentation_threshold: 0.7,
325            enable_memory_stats: true,
326            pressure_warning_threshold: 0.8,
327            pressure_critical_threshold: 0.95,
328        }
329    }
330}
331
/// Backend configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendConfig {
    /// Backend type
    pub backend_type: BackendType,
    /// Target device
    pub device: Device,
    /// Data type for computation
    pub dtype: DataType,
    /// Enable optimizations
    pub enable_optimizations: bool,
    /// Optimization level (0-3). The range is documentation only; this type
    /// does not validate it.
    pub optimization_level: u8,
    /// Enable CUDA graphs
    pub enable_cuda_graphs: bool,
    /// Enable kernel fusion
    pub enable_kernel_fusion: bool,
    /// Custom backend-specific options, keyed by option name with arbitrary
    /// JSON values.
    pub backend_options: HashMap<String, serde_json::Value>,
}
352
353impl Default for BackendConfig {
354    fn default() -> Self {
355        Self {
356            backend_type: BackendType::Candle,
357            device: Device::CPU,
358            dtype: DataType::FP16,
359            enable_optimizations: true,
360            optimization_level: 2,
361            enable_cuda_graphs: false,
362            enable_kernel_fusion: true,
363            backend_options: HashMap::new(),
364        }
365    }
366}
367
/// Supported backend types.
///
/// `BackendConfig::default()` selects [`BackendType::Candle`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum BackendType {
    /// Candle framework
    Candle,
    /// ONNX Runtime
    OnnxRuntime,
    /// TensorRT
    TensorRT,
    /// Custom backend
    Custom,
}
380
/// Tokenizer configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizerConfig {
    /// Tokenizer type
    pub tokenizer_type: TokenizerType,
    /// Path to tokenizer files; `None` in the `Default` impl.
    pub tokenizer_path: Option<String>,
    /// Enable fast tokenization
    pub enable_fast: bool,
    /// Add special tokens
    pub add_special_tokens: bool,
    /// Truncation strategy; `None` in the `Default` impl.
    pub truncation: Option<TruncationConfig>,
    /// Padding strategy; `None` in the `Default` impl.
    pub padding: Option<PaddingConfig>,
}
397
398impl Default for TokenizerConfig {
399    fn default() -> Self {
400        Self {
401            tokenizer_type: TokenizerType::BPE,
402            tokenizer_path: None,
403            enable_fast: true,
404            add_special_tokens: true,
405            truncation: None,
406            padding: None,
407        }
408    }
409}
410
/// Tokenizer algorithms.
///
/// `TokenizerConfig::default()` selects [`TokenizerType::BPE`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TokenizerType {
    /// Byte Pair Encoding
    BPE,
    /// WordPiece tokenizer (BERT-style)
    WordPiece,
    /// SentencePiece tokenizer
    SentencePiece,
    /// Tiktoken tokenizer family
    Tiktoken,
    /// Any custom tokenizer implementation
    Custom,
}
425
/// Truncation configuration: a length limit plus the strategy used to
/// enforce it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TruncationConfig {
    /// Maximum length — presumably counted in tokens; confirm against the tokenizer.
    pub max_length: usize,
    /// Truncation strategy
    pub strategy: TruncationStrategy,
}
434
/// Truncation strategies: which part of an over-long input is removed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TruncationStrategy {
    /// Remove from the beginning
    TruncateStart,
    /// Remove from the end
    TruncateEnd,
    /// Remove from both sides
    TruncateBoth,
}
445
/// Padding configuration: strategy, pad token and optional target length.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PaddingConfig {
    /// Padding strategy
    pub strategy: PaddingStrategy,
    /// Padding token ID
    pub token_id: u32,
    /// Target length — presumably consulted by `PaddingStrategy::FixedLength`;
    /// confirm against the tokenizer implementation.
    pub target_length: Option<usize>,
}
456
/// Padding strategies.
///
/// Note that the `None` variant is `PaddingStrategy::None`, not
/// `Option::None`; qualify it where both are in scope.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PaddingStrategy {
    /// No padding
    None,
    /// Pad to maximum length in batch
    MaxLength,
    /// Pad to specific length
    FixedLength,
}
467
// Sampling configuration (default parameters and presets) is defined by `SamplingConfig` in this file.
469
/// Security configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityConfig {
    /// Enable API authentication (off in the `Default` impl).
    pub enable_auth: bool,
    /// API keys for authentication. Stored as plain strings, and the derived
    /// `Debug` / `Serialize` impls emit them verbatim.
    pub api_keys: Vec<String>,
    /// Enable rate limiting
    pub enable_rate_limiting: bool,
    /// Rate limit per client (requests per minute)
    pub rate_limit_rpm: u32,
    /// Enable content filtering
    pub enable_content_filter: bool,
    /// Maximum prompt length — unit (characters vs. tokens) is not stated
    /// here; confirm against the validator.
    pub max_prompt_length: usize,
    /// Enable prompt validation
    pub enable_prompt_validation: bool,
    /// Allowed file extensions for uploads (the defaults are written without
    /// a leading dot).
    pub allowed_extensions: Vec<String>,
}
490
491impl Default for SecurityConfig {
492    fn default() -> Self {
493        Self {
494            enable_auth: false,
495            api_keys: vec![],
496            enable_rate_limiting: true,
497            rate_limit_rpm: 60,
498            enable_content_filter: false,
499            max_prompt_length: 32768,
500            enable_prompt_validation: true,
501            allowed_extensions: vec!["txt".to_string(), "json".to_string()],
502        }
503    }
504}
505
/// Sampling configuration: default parameters, presets and a switch for
/// custom processors.
///
/// `Default` is derived, so each field starts from its type's own default
/// (`false` for the boolean).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SamplingConfig {
    /// Sampling parameters — presumably applied when a request supplies none;
    /// confirm against the request path.
    pub default_params: SamplingParams,
    /// Sampling presets.
    pub presets: SamplingPresets,
    /// Enables custom processors — presumably custom logits/sampling
    /// processors; confirm against the sampler.
    pub enable_custom_processors: bool,
}
512
/// Monitoring configuration: metrics and tracing switches plus the export
/// interval.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MonitoringConfig {
    /// Enable metrics (on in the `Default` impl).
    pub enable_metrics: bool,
    /// Enable tracing (on in the `Default` impl).
    pub enable_tracing: bool,
    /// Export interval (5 seconds in the `Default` impl).
    pub export_interval: Duration,
}
519
520impl Default for MonitoringConfig {
521    fn default() -> Self {
522        Self {
523            enable_metrics: true,
524            enable_tracing: true,
525            export_interval: Duration::from_secs(5),
526        }
527    }
528}
529
/// Batching configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchConfig {
    /// Maximum batch size — presumably counted in requests; confirm against
    /// the batcher.
    pub max_batch_size: usize,
    /// Maximum wait in milliseconds — presumably how long a partial batch may
    /// wait before dispatch; confirm against the batcher.
    pub max_wait_ms: u64,
    /// Enable dynamic batching.
    pub enable_dynamic: bool,
    /// Enable continuous batching (off in the `Default` impl).
    pub enable_continuous: bool,
    /// vLLM-style per-iteration token budget. The scheduler emits a
    /// mixed prefill+decode batch summing to at most this many Q
    /// tokens (decode = 1 each, prefill chunk = its chunk size).
    /// Default 2048. Runtime snapshots can override this with
    /// `FERRUM_MAX_BATCHED_TOKENS`, usually from the GPU autosizer or a
    /// named workload preset rather than a user hand-written env bundle.
    ///
    /// When the field is missing from serialized input, serde fills it from
    /// `BatchConfig::default_max_num_batched_tokens`.
    #[serde(default = "BatchConfig::default_max_num_batched_tokens")]
    pub max_num_batched_tokens: usize,
}
545
546impl BatchConfig {
547    fn default_max_num_batched_tokens() -> usize {
548        2048
549    }
550}
551
552impl Default for BatchConfig {
553    fn default() -> Self {
554        Self {
555            max_batch_size: 32,
556            max_wait_ms: 8,
557            enable_dynamic: true,
558            enable_continuous: false,
559            max_num_batched_tokens: Self::default_max_num_batched_tokens(),
560        }
561    }
562}