Skip to main content

ferrum_types/
config.rs

1//! Configuration types for Ferrum components
2
3use crate::{
4    parse_bool_env_value, parse_usize_env_value, DataType, Device, ModelId, ModelInfo,
5    RuntimeConfigSnapshot, SamplingParams, SamplingPresets,
6};
7use serde::{Deserialize, Serialize};
8use std::{collections::HashMap, time::Duration};
9
10/// Engine runtime knobs the CLI/autosizer resolves and injects via the
11/// runtime-config snapshot. The continuous engine reads these from the typed
12/// config instead of `std::env::vars()`, so the env bridge stays at the
13/// composition root and tests can vary the knobs per `EngineConfig`.
14#[derive(Debug, Clone, Default, Serialize, Deserialize)]
15pub struct RuntimeKnobs {
16    pub kv_capacity: Option<usize>,
17    pub max_model_len: Option<usize>,
18    pub chunked_prefill_size: Option<usize>,
19    pub batch_decode_prof: bool,
20    pub next_batch_prof: bool,
21    pub rbd_prof: bool,
22    pub unified_post_prof: bool,
23    pub prefix_cache_enabled: bool,
24
25    // Engine-build composition knobs. Previously read directly from the
26    // environment by `builder.rs` (FERRUM_MODEL_PATH / FERRUM_SPEC_DRAFT /
27    // FERRUM_SPEC_N) and `registry.rs` (FERRUM_DTYPE / FERRUM_METAL_DTYPE /
28    // FERRUM_TP). The CLI composition root now resolves them into this typed
29    // field so the engine builder and component registry read the snapshot,
30    // not `std::env`.
31    pub model_path: Option<String>,
32    pub spec_draft: Option<String>,
33    pub spec_n: Option<usize>,
34    pub dtype: Option<String>,
35    pub metal_dtype: Option<String>,
36    pub tp: Option<usize>,
37}
38
39/// Engine configuration
40#[derive(Debug, Clone, Serialize, Deserialize, Default)]
41pub struct EngineConfig {
42    pub model: EngineModelConfig,
43    pub scheduler: SchedulerConfig,
44    pub sampling: SamplingConfig,
45    pub backend: BackendConfig,
46    pub kv_cache: KvCacheConfig,
47    pub memory: MemoryConfig,
48    pub batching: BatchConfig,
49    pub monitoring: MonitoringConfig,
50    #[serde(default)]
51    pub runtime: RuntimeKnobs,
52}
53
54impl EngineConfig {
55    pub fn apply_runtime_config_snapshot(
56        &mut self,
57        snapshot: &RuntimeConfigSnapshot,
58    ) -> std::result::Result<(), String> {
59        self.scheduler.apply_runtime_config_snapshot(snapshot)?;
60        if let Some(value) = runtime_config_value(snapshot, "FERRUM_KV_MAX_BLOCKS") {
61            self.kv_cache.max_blocks =
62                parse_required_positive_usize("FERRUM_KV_MAX_BLOCKS", value)?;
63        }
64        if let Some(value) = runtime_config_value(snapshot, "FERRUM_MAX_BATCHED_TOKENS") {
65            self.batching.max_num_batched_tokens =
66                parse_required_positive_usize("FERRUM_MAX_BATCHED_TOKENS", value)?;
67        }
68        if let Some(value) = runtime_config_value(snapshot, "FERRUM_PAGED_MAX_SEQS") {
69            self.scheduler.max_running_requests =
70                parse_required_positive_usize("FERRUM_PAGED_MAX_SEQS", value)?;
71        }
72        // Engine runtime knobs (previously read by the engine from env). The
73        // CLI/autosizer resolves these into the snapshot; the engine reads the
74        // typed `runtime` field instead of `std::env::vars()`.
75        if let Some(value) = runtime_config_value(snapshot, "FERRUM_KV_CAPACITY") {
76            self.runtime.kv_capacity =
77                Some(parse_required_positive_usize("FERRUM_KV_CAPACITY", value)?);
78        }
79        if let Some(value) = runtime_config_value(snapshot, "FERRUM_MAX_MODEL_LEN") {
80            self.runtime.max_model_len = Some(parse_required_positive_usize(
81                "FERRUM_MAX_MODEL_LEN",
82                value,
83            )?);
84        }
85        if let Some(value) = runtime_config_value(snapshot, "FERRUM_CHUNKED_PREFILL") {
86            self.runtime.chunked_prefill_size =
87                parse_usize_env_value(value).ok().filter(|&v| v > 0);
88        }
89        self.runtime.batch_decode_prof |=
90            runtime_config_value(snapshot, "FERRUM_BATCH_DECODE_PROF").is_some();
91        self.runtime.next_batch_prof |=
92            runtime_config_value(snapshot, "FERRUM_NEXT_BATCH_PROF").is_some();
93        self.runtime.rbd_prof |= runtime_config_value(snapshot, "FERRUM_RBD_PROF").is_some();
94        self.runtime.unified_post_prof |=
95            runtime_config_value(snapshot, "FERRUM_UNIFIED_POST_PROF").is_some();
96        self.runtime.prefix_cache_enabled |=
97            runtime_config_value(snapshot, "FERRUM_WHOLE_PROMPT_PREFIX_CACHE")
98                .map(|v| v == "1")
99                .unwrap_or(false);
100
101        // Engine-build composition knobs (previously read by builder.rs /
102        // registry.rs from env). Only overwrite when the key is present so a
103        // later snapshot apply without the key keeps an earlier value.
104        if let Some(value) = runtime_config_value(snapshot, "FERRUM_MODEL_PATH") {
105            self.runtime.model_path = Some(value.to_string());
106        }
107        if let Some(value) = runtime_config_value(snapshot, "FERRUM_SPEC_DRAFT") {
108            self.runtime.spec_draft = if value.is_empty() {
109                None
110            } else {
111                Some(value.to_string())
112            };
113        }
114        if let Some(value) = runtime_config_value(snapshot, "FERRUM_SPEC_N") {
115            self.runtime.spec_n = value.parse::<usize>().ok();
116        }
117        if let Some(value) = runtime_config_value(snapshot, "FERRUM_DTYPE") {
118            self.runtime.dtype = Some(value.to_string());
119        }
120        if let Some(value) = runtime_config_value(snapshot, "FERRUM_METAL_DTYPE") {
121            self.runtime.metal_dtype = Some(value.to_string());
122        }
123        if let Some(value) = runtime_config_value(snapshot, "FERRUM_TP") {
124            self.runtime.tp = value.parse::<usize>().ok();
125        }
126
127        // Publish the resolved snapshot process-wide. Model code (which is not
128        // threaded an EngineConfig) reads `active_runtime_snapshot()` for the
129        // remaining FERRUM_* toggles instead of `std::env`, keeping the env
130        // bridge at this single composition-root call.
131        crate::install_runtime_snapshot(snapshot.clone());
132        Ok(())
133    }
134}
135
136#[derive(Debug, Clone, Serialize, Deserialize)]
137pub struct EngineModelConfig {
138    pub model_id: ModelId,
139    pub model_info: Option<ModelInfo>,
140    pub tokenizer: TokenizerConfig,
141}
142
143impl Default for EngineModelConfig {
144    fn default() -> Self {
145        Self {
146            model_id: ModelId::new("default"),
147            model_info: None,
148            tokenizer: TokenizerConfig::default(),
149        }
150    }
151}
152
153/// Scheduler configuration
154#[derive(Debug, Clone, Serialize, Deserialize)]
155pub struct SchedulerConfig {
156    /// Scheduling policy
157    pub policy: SchedulingPolicy,
158    /// Maximum waiting queue size
159    pub max_waiting_requests: usize,
160    /// Maximum running requests
161    pub max_running_requests: usize,
162    /// Enable request preemption
163    pub enable_preemption: bool,
164    /// Enable load balancing
165    pub enable_load_balancing: bool,
166    /// Fair share weights per client
167    pub fair_share_weights: HashMap<String, f32>,
168    /// SLA enforcement enabled
169    pub enable_sla_enforcement: bool,
170    /// Use prompt-token metadata for initial continuous-batch admission estimates.
171    #[serde(default)]
172    pub prompt_token_estimate: bool,
173    /// Prefer new prefills over early decodes until this many requests are active.
174    #[serde(default)]
175    pub prefill_first_until_active: Option<usize>,
176    /// Cap prefill admission chunks only while decode requests are already active.
177    #[serde(default)]
178    pub active_decode_prefill_chunk: Option<usize>,
179    /// Emit diagnostic scheduler None/SOME decisions.
180    #[serde(default)]
181    pub scheduler_none_prof: bool,
182}
183
184impl Default for SchedulerConfig {
185    fn default() -> Self {
186        Self {
187            policy: SchedulingPolicy::Priority,
188            max_waiting_requests: 1000,
189            max_running_requests: 32,
190            enable_preemption: true,
191            enable_load_balancing: false,
192            fair_share_weights: HashMap::new(),
193            enable_sla_enforcement: false,
194            prompt_token_estimate: false,
195            prefill_first_until_active: None,
196            active_decode_prefill_chunk: None,
197            scheduler_none_prof: false,
198        }
199    }
200}
201
202impl SchedulerConfig {
203    pub fn apply_runtime_config_snapshot(
204        &mut self,
205        snapshot: &RuntimeConfigSnapshot,
206    ) -> std::result::Result<(), String> {
207        if let Some(value) = runtime_config_value(snapshot, "FERRUM_SCHED_PROMPT_TOKEN_ESTIMATE") {
208            self.prompt_token_estimate = parse_bool_env_value(value)
209                .map_err(|reason| format!("FERRUM_SCHED_PROMPT_TOKEN_ESTIMATE: {reason}"))?;
210        }
211        if let Some(value) =
212            runtime_config_value(snapshot, "FERRUM_SCHED_PREFILL_FIRST_UNTIL_ACTIVE")
213        {
214            self.prefill_first_until_active =
215                parse_optional_positive_usize("FERRUM_SCHED_PREFILL_FIRST_UNTIL_ACTIVE", value)?;
216        }
217        if let Some(value) = runtime_config_value(snapshot, "FERRUM_ACTIVE_DECODE_PREFILL_CHUNK") {
218            self.active_decode_prefill_chunk =
219                parse_optional_positive_usize("FERRUM_ACTIVE_DECODE_PREFILL_CHUNK", value)?;
220        }
221        if let Some(value) = runtime_config_value(snapshot, "FERRUM_SCHED_NONE_PROF") {
222            self.scheduler_none_prof = parse_presence_bool(value)
223                .map_err(|reason| format!("FERRUM_SCHED_NONE_PROF: {reason}"))?;
224        }
225        Ok(())
226    }
227}
228
229fn runtime_config_value<'a>(snapshot: &'a RuntimeConfigSnapshot, key: &str) -> Option<&'a str> {
230    snapshot
231        .entries
232        .iter()
233        .find(|entry| entry.key == key)
234        .map(|entry| entry.effective_value.as_str())
235}
236
237fn parse_optional_positive_usize(
238    key: &str,
239    value: &str,
240) -> std::result::Result<Option<usize>, String> {
241    let parsed = parse_usize_env_value(value).map_err(|reason| format!("{key}: {reason}"))?;
242    Ok((parsed > 0).then_some(parsed))
243}
244
245fn parse_required_positive_usize(key: &str, value: &str) -> std::result::Result<usize, String> {
246    let parsed = parse_usize_env_value(value).map_err(|reason| format!("{key}: {reason}"))?;
247    if parsed == 0 {
248        Err(format!("{key}: must be greater than zero"))
249    } else {
250        Ok(parsed)
251    }
252}
253
254fn parse_presence_bool(value: &str) -> std::result::Result<bool, String> {
255    if value.trim().is_empty() {
256        Ok(true)
257    } else {
258        parse_bool_env_value(value)
259    }
260}
261
262/// Scheduling policies
263#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
264pub enum SchedulingPolicy {
265    /// First-Come-First-Served
266    FCFS,
267    /// Priority-based scheduling
268    Priority,
269    /// Fair-share scheduling
270    FairShare,
271    /// Shortest-Job-First
272    SJF,
273    /// Round-Robin
274    RoundRobin,
275    /// Iteration-level continuous batching with preemption
276    ContinuousBatch,
277}
278
279/// KV Cache configuration
280#[derive(Debug, Clone, Serialize, Deserialize)]
281pub struct KvCacheConfig {
282    /// Cache implementation type
283    pub cache_type: KvCacheType,
284    /// Element dtype (Dim 5 polymorphism point). FP16 is the
285    /// validated production path; INT8 / FP8 require a backend impl
286    /// of `BackendKvDtype<KvInt8>` / `BackendKvDtype<KvFp8>` and a
287    /// model wired through `KvCacheQuant<B, K>`.
288    #[serde(default)]
289    pub dtype: KvCacheDtype,
290    /// Block size for paged attention
291    pub block_size: usize,
292    /// Maximum number of blocks
293    pub max_blocks: usize,
294    /// Enable cache compression
295    pub enable_compression: bool,
296    /// Compression ratio target
297    pub compression_ratio: f32,
298    /// Enable multi-level caching (GPU + CPU)
299    pub enable_multi_level: bool,
300    /// Swap threshold (when to move to CPU)
301    pub swap_threshold: f32,
302    /// Enable prefix caching
303    pub enable_prefix_caching: bool,
304    /// Prefix cache size
305    pub prefix_cache_size: usize,
306}
307
308impl Default for KvCacheConfig {
309    fn default() -> Self {
310        // 2048 blocks covers c=32 ShareGPT prompts (~32Ɨ500/16 = 1000
311        // blocks). The previous 1024 floor crashed at c≄16 on real
312        // workloads with "Block pool exhausted". Runtime overrides are
313        // applied through EngineConfig::apply_runtime_config_snapshot.
314        Self {
315            cache_type: KvCacheType::Contiguous,
316            dtype: KvCacheDtype::default(),
317            block_size: 16,
318            max_blocks: 2048,
319            enable_compression: false,
320            compression_ratio: 0.5,
321            enable_multi_level: true,
322            swap_threshold: 0.8,
323            enable_prefix_caching: true,
324            prefix_cache_size: 100,
325        }
326    }
327}
328
329/// KV Cache implementation types
330#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
331pub enum KvCacheType {
332    /// Simple contiguous memory allocation
333    Contiguous,
334    /// Paged attention with block-based allocation
335    Paged,
336    /// Tree-based cache for prefix sharing
337    Tree,
338}
339
340/// KV Cache element dtype (Dim 5 polymorphism point).
341///
342/// Mirrors `ferrum_interfaces::kv_dtype::KvDtypeKind` markers but
343/// lives here because `KvCacheConfig` is part of the user-facing
344/// `EngineConfig` and needs `Serialize` / `Deserialize`.
345#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
346#[serde(rename_all = "lowercase")]
347pub enum KvCacheDtype {
348    /// FP16 K/V — the validated production path on every backend.
349    #[default]
350    Fp16,
351    /// BF16 K/V — same memory cost as FP16, slightly different precision.
352    /// Marker only; no backend impl ships yet.
353    Bf16,
354    /// INT8 K/V with per-token per-kv-head FP16 scale (vLLM-style).
355    /// Halves KV memory at small (<1%) accuracy hit. CUDA kernels
356    /// land via `BackendKvDtype<KvInt8>` (PR #131); model wire-up
357    /// (`KvCacheQuant<B, KvInt8>` through the model decode loop) is
358    /// the only remaining step.
359    Int8,
360    /// FP8 (E4M3) K/V. Marker only; CUDA kernels pending.
361    Fp8,
362}
363
364impl KvCacheDtype {
365    /// Parse from a CLI / env-var string. Accepts fp16/f16/bf16/int8/fp8/f8e4m3.
366    pub fn parse(s: &str) -> Option<Self> {
367        match s.trim().to_ascii_lowercase().as_str() {
368            "fp16" | "f16" | "float16" => Some(Self::Fp16),
369            "bf16" | "bfloat16" => Some(Self::Bf16),
370            "int8" | "i8" => Some(Self::Int8),
371            "fp8" | "f8" | "f8e4m3" | "e4m3" => Some(Self::Fp8),
372            _ => None,
373        }
374    }
375
376    /// Short label for display + telemetry.
377    pub fn as_str(&self) -> &'static str {
378        match self {
379            Self::Fp16 => "fp16",
380            Self::Bf16 => "bf16",
381            Self::Int8 => "int8",
382            Self::Fp8 => "fp8",
383        }
384    }
385}
386
387/// Memory management configuration
388#[derive(Debug, Clone, Serialize, Deserialize)]
389pub struct MemoryConfig {
390    /// Memory pool size in bytes
391    pub pool_size: Option<usize>,
392    /// Enable memory pooling
393    pub enable_pooling: bool,
394    /// Memory alignment in bytes
395    pub alignment: usize,
396    /// Enable memory defragmentation
397    pub enable_defragmentation: bool,
398    /// Defragmentation threshold
399    pub defragmentation_threshold: f32,
400    /// Enable memory statistics tracking
401    pub enable_memory_stats: bool,
402    /// Memory pressure warning threshold
403    pub pressure_warning_threshold: f32,
404    /// Memory pressure critical threshold
405    pub pressure_critical_threshold: f32,
406}
407
408impl Default for MemoryConfig {
409    fn default() -> Self {
410        Self {
411            pool_size: None,
412            enable_pooling: true,
413            alignment: 256,
414            enable_defragmentation: false,
415            defragmentation_threshold: 0.7,
416            enable_memory_stats: true,
417            pressure_warning_threshold: 0.8,
418            pressure_critical_threshold: 0.95,
419        }
420    }
421}
422
423/// Backend configuration
424#[derive(Debug, Clone, Serialize, Deserialize)]
425pub struct BackendConfig {
426    /// Backend type
427    pub backend_type: BackendType,
428    /// Target device
429    pub device: Device,
430    /// Data type for computation
431    pub dtype: DataType,
432    /// Enable optimizations
433    pub enable_optimizations: bool,
434    /// Optimization level (0-3)
435    pub optimization_level: u8,
436    /// Enable CUDA graphs
437    pub enable_cuda_graphs: bool,
438    /// Enable kernel fusion
439    pub enable_kernel_fusion: bool,
440    /// Custom backend-specific options
441    pub backend_options: HashMap<String, serde_json::Value>,
442}
443
444impl Default for BackendConfig {
445    fn default() -> Self {
446        Self {
447            backend_type: BackendType::Candle,
448            device: Device::CPU,
449            dtype: DataType::FP16,
450            enable_optimizations: true,
451            optimization_level: 2,
452            enable_cuda_graphs: false,
453            enable_kernel_fusion: true,
454            backend_options: HashMap::new(),
455        }
456    }
457}
458
459/// Supported backend types
460#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
461pub enum BackendType {
462    /// Candle framework
463    Candle,
464    /// ONNX Runtime
465    OnnxRuntime,
466    /// TensorRT
467    TensorRT,
468    /// Custom backend
469    Custom,
470}
471
472/// Tokenizer configuration
473#[derive(Debug, Clone, Serialize, Deserialize)]
474pub struct TokenizerConfig {
475    /// Tokenizer type
476    pub tokenizer_type: TokenizerType,
477    /// Path to tokenizer files
478    pub tokenizer_path: Option<String>,
479    /// Enable fast tokenization
480    pub enable_fast: bool,
481    /// Add special tokens
482    pub add_special_tokens: bool,
483    /// Truncation strategy
484    pub truncation: Option<TruncationConfig>,
485    /// Padding strategy
486    pub padding: Option<PaddingConfig>,
487}
488
489impl Default for TokenizerConfig {
490    fn default() -> Self {
491        Self {
492            tokenizer_type: TokenizerType::BPE,
493            tokenizer_path: None,
494            enable_fast: true,
495            add_special_tokens: true,
496            truncation: None,
497            padding: None,
498        }
499    }
500}
501
502/// Tokenizer algorithms
503#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
504pub enum TokenizerType {
505    /// Byte Pair Encoding
506    BPE,
507    /// WordPiece tokenizer (BERT-style)
508    WordPiece,
509    /// SentencePiece tokenizer
510    SentencePiece,
511    /// Tiktoken tokenizer family
512    Tiktoken,
513    /// Any custom tokenizer implementation
514    Custom,
515}
516
517/// Truncation configuration
518#[derive(Debug, Clone, Serialize, Deserialize)]
519pub struct TruncationConfig {
520    /// Maximum length
521    pub max_length: usize,
522    /// Truncation strategy
523    pub strategy: TruncationStrategy,
524}
525
526/// Truncation strategies
527#[derive(Debug, Clone, Serialize, Deserialize)]
528pub enum TruncationStrategy {
529    /// Remove from the beginning
530    TruncateStart,
531    /// Remove from the end
532    TruncateEnd,
533    /// Remove from both sides
534    TruncateBoth,
535}
536
537/// Padding configuration
538#[derive(Debug, Clone, Serialize, Deserialize)]
539pub struct PaddingConfig {
540    /// Padding strategy
541    pub strategy: PaddingStrategy,
542    /// Padding token ID
543    pub token_id: u32,
544    /// Target length
545    pub target_length: Option<usize>,
546}
547
548/// Padding strategies
549#[derive(Debug, Clone, Serialize, Deserialize)]
550pub enum PaddingStrategy {
551    /// No padding
552    None,
553    /// Pad to maximum length in batch
554    MaxLength,
555    /// Pad to specific length
556    FixedLength,
557}
558
// Sampling configuration presets: see `SamplingPresets`, imported from the crate root.
560
561/// Security configuration
562#[derive(Debug, Clone, Serialize, Deserialize)]
563pub struct SecurityConfig {
564    /// Enable API authentication
565    pub enable_auth: bool,
566    /// API keys for authentication
567    pub api_keys: Vec<String>,
568    /// Enable rate limiting
569    pub enable_rate_limiting: bool,
570    /// Rate limit per client (requests per minute)
571    pub rate_limit_rpm: u32,
572    /// Enable content filtering
573    pub enable_content_filter: bool,
574    /// Maximum prompt length
575    pub max_prompt_length: usize,
576    /// Enable prompt validation
577    pub enable_prompt_validation: bool,
578    /// Allowed file extensions for uploads
579    pub allowed_extensions: Vec<String>,
580}
581
582impl Default for SecurityConfig {
583    fn default() -> Self {
584        Self {
585            enable_auth: false,
586            api_keys: vec![],
587            enable_rate_limiting: true,
588            rate_limit_rpm: 60,
589            enable_content_filter: false,
590            max_prompt_length: 32768,
591            enable_prompt_validation: true,
592            allowed_extensions: vec!["txt".to_string(), "json".to_string()],
593        }
594    }
595}
596
597#[derive(Debug, Clone, Serialize, Deserialize, Default)]
598pub struct SamplingConfig {
599    pub default_params: SamplingParams,
600    pub presets: SamplingPresets,
601    pub enable_custom_processors: bool,
602}
603
604#[derive(Debug, Clone, Serialize, Deserialize)]
605pub struct MonitoringConfig {
606    pub enable_metrics: bool,
607    pub enable_tracing: bool,
608    pub export_interval: Duration,
609}
610
611impl Default for MonitoringConfig {
612    fn default() -> Self {
613        Self {
614            enable_metrics: true,
615            enable_tracing: true,
616            export_interval: Duration::from_secs(5),
617        }
618    }
619}
620
621#[derive(Debug, Clone, Serialize, Deserialize)]
622pub struct BatchConfig {
623    pub max_batch_size: usize,
624    pub max_wait_ms: u64,
625    pub enable_dynamic: bool,
626    pub enable_continuous: bool,
627    /// vLLM-style per-iteration token budget. The scheduler emits a
628    /// mixed prefill+decode batch summing to at most this many Q
629    /// tokens (decode = 1 each, prefill chunk = its chunk size).
630    /// Default 2048. Runtime snapshots can override this with
631    /// `FERRUM_MAX_BATCHED_TOKENS`, usually from the GPU autosizer or a
632    /// named workload preset rather than a user hand-written env bundle.
633    #[serde(default = "BatchConfig::default_max_num_batched_tokens")]
634    pub max_num_batched_tokens: usize,
635}
636
637impl BatchConfig {
638    fn default_max_num_batched_tokens() -> usize {
639        2048
640    }
641}
642
643impl Default for BatchConfig {
644    fn default() -> Self {
645        Self {
646            max_batch_size: 32,
647            max_wait_ms: 8,
648            enable_dynamic: true,
649            enable_continuous: false,
650            max_num_batched_tokens: Self::default_max_num_batched_tokens(),
651        }
652    }
653}