// inference_lab/config/mod.rs

1pub mod hardware;
2pub mod model;
3pub mod scheduler;
4pub mod simulation;
5pub mod workload;
6
7pub use hardware::HardwareConfig;
8pub use model::ModelConfig;
9pub use scheduler::SchedulerConfig;
10pub use simulation::SimulationConfig;
11pub use workload::{LengthDistribution, WorkloadConfig};
12
13use serde::Deserialize;
14use std::fs;
15use std::path::Path;
16
17/// Top-level configuration that aggregates all sub-configs
18#[derive(Debug, Clone, Deserialize)]
19pub struct Config {
20    pub hardware: HardwareConfig,
21    pub model: ModelConfig,
22    pub scheduler: SchedulerConfig,
23    pub workload: WorkloadConfig,
24    #[serde(default)]
25    pub simulation: SimulationConfig,
26}
27
28impl Config {
29    /// Load configuration from a TOML file
30    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self, Box<dyn std::error::Error>> {
31        let contents = fs::read_to_string(path)?;
32        let mut config: Config = toml::from_str(&contents)?;
33
34        // Compute derived fields
35        config
36            .model
37            .compute_kv_cache_size(config.hardware.bytes_per_param);
38
39        // Compute KV cache capacity if not explicitly set
40        let model_size_bytes = config.model.num_parameters * config.hardware.bytes_per_param as u64;
41        config.hardware.compute_kv_cache_capacity(model_size_bytes);
42
43        config
44            .scheduler
45            .set_default_prefill_threshold(config.model.max_seq_len);
46
47        Ok(config)
48    }
49
50    /// Get a default configuration for testing
51    #[cfg(test)]
52    pub fn test_default() -> Self {
53        let hardware = HardwareConfig {
54            name: "Test GPU".to_string(),
55            compute_flops: 1e15,
56            memory_bandwidth: 1e12,
57            memory_capacity: 80_000_000_000,
58            kv_cache_capacity: 60_000_000_000,
59            gpu_memory_utilization: 0.9,
60            bytes_per_param: 2,
61        };
62
63        let mut model = ModelConfig {
64            name: "Test Model".to_string(),
65            num_parameters: 7_000_000_000,
66            num_active_parameters: None,
67            num_layers: 32,
68            hidden_dim: 4096,
69            num_heads: 32,
70            num_kv_heads: None,
71            max_seq_len: 2048,
72            sliding_window: None,
73            num_sliding_layers: None,
74            kv_cache_bytes_per_token: 0,
75        };
76        model.compute_kv_cache_size(hardware.bytes_per_param);
77
78        let mut scheduler = SchedulerConfig {
79            max_num_batched_tokens: 2048,
80            max_num_seqs: 128,
81            policy: "fcfs".to_string(),
82            enable_chunked_prefill: true,
83            long_prefill_token_threshold: 0,
84            max_num_partial_prefills: 1,
85            block_size: 16,
86            enable_preemption_free: false,
87        };
88        scheduler.set_default_prefill_threshold(model.max_seq_len);
89
90        let workload = WorkloadConfig {
91            dataset_path: None,
92            arrival_pattern: "poisson".to_string(),
93            arrival_rate: 1.0,
94            num_concurrent_users: None,
95            input_len_dist: LengthDistribution::Fixed { value: 100 },
96            output_len_dist: LengthDistribution::Fixed { value: 50 },
97            num_requests: Some(10),
98            duration_secs: None,
99            seed: 42,
100        };
101
102        let simulation = SimulationConfig::default();
103
104        Config {
105            hardware,
106            model,
107            scheduler,
108            workload,
109            simulation,
110        }
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline model fixture: 7B parameters, 32 layers, hidden size 4096, 32 heads, no
    /// `num_kv_heads` override and no sliding window. `kv_cache_bytes_per_token` starts at 0
    /// so that `compute_kv_cache_size` is what fills it in.
    fn base_model(name: &str) -> ModelConfig {
        ModelConfig {
            name: name.to_string(),
            num_parameters: 7_000_000_000,
            num_active_parameters: None,
            num_layers: 32,
            hidden_dim: 4096,
            num_heads: 32,
            num_kv_heads: None,
            max_seq_len: 2048,
            sliding_window: None,
            num_sliding_layers: None,
            kv_cache_bytes_per_token: 0,
        }
    }

    #[test]
    fn test_model_kv_cache_calculation() {
        let mut model = base_model("Test");
        model.compute_kv_cache_size(2); // bf16

        // 2 (K+V) * 4096 (hidden) * 2 (bytes) * 32 (layers) = 524,288 bytes per token
        assert_eq!(model.kv_cache_bytes_per_token, 524_288);

        // For a 100-token sequence: 524,288 * 100
        let size = model.kv_cache_size_for_sequence(100);
        assert_eq!(size, 52_428_800);
    }

    #[test]
    fn test_config_creation() {
        let config = Config::test_default();
        // `test_default` uses the same architecture and 2 bytes per param as `base_model`, so
        // the derived per-token size must match `test_model_kv_cache_calculation` exactly.
        // A bare `> 0` check would not catch a wrong formula or a wrong byte width.
        assert_eq!(config.model.kv_cache_bytes_per_token, 524_288);
    }

    #[test]
    fn test_sliding_window_kv_cache_uses_byte_units() {
        let mut model = ModelConfig {
            num_layers: 4,
            hidden_dim: 16,
            num_heads: 4,
            num_kv_heads: Some(2),
            sliding_window: Some(8),
            num_sliding_layers: Some(2),
            ..base_model("Sliding")
        };
        model.compute_kv_cache_size(2);

        // Expected total. This assumes head_dim = hidden_dim / num_heads = 4 and that
        // sliding layers cache at most `sliding_window` tokens:
        //   per layer per token = 2 (K+V) * 2 kv heads * 4 head_dim * 2 bytes = 32 B
        //   2 full layers    * 10 tokens         * 32 B = 640
        //   2 sliding layers * min(10, 8) tokens * 32 B = 512
        let size = model.kv_cache_size_for_sequence(10);
        assert_eq!(size, 1_152);
    }
}