inference_lab/config/
mod.rs1pub mod hardware;
2pub mod model;
3pub mod scheduler;
4pub mod simulation;
5pub mod workload;
6
7pub use hardware::HardwareConfig;
8pub use model::ModelConfig;
9pub use scheduler::SchedulerConfig;
10pub use simulation::SimulationConfig;
11pub use workload::{LengthDistribution, WorkloadConfig};
12
13use serde::Deserialize;
14use std::fs;
15use std::path::Path;
16
17#[derive(Debug, Clone, Deserialize)]
19pub struct Config {
20 pub hardware: HardwareConfig,
21 pub model: ModelConfig,
22 pub scheduler: SchedulerConfig,
23 pub workload: WorkloadConfig,
24 #[serde(default)]
25 pub simulation: SimulationConfig,
26}
27
28impl Config {
29 pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self, Box<dyn std::error::Error>> {
31 let contents = fs::read_to_string(path)?;
32 let mut config: Config = toml::from_str(&contents)?;
33
34 config.hardware.compute_threshold();
36 config
37 .model
38 .compute_kv_cache_size(config.hardware.bytes_per_param);
39
40 let model_size_bytes = config.model.num_parameters * config.hardware.bytes_per_param as u64;
42 config.hardware.compute_kv_cache_capacity(model_size_bytes);
43
44 config
45 .scheduler
46 .set_default_prefill_threshold(config.model.max_seq_len);
47
48 Ok(config)
49 }
50
51 #[cfg(test)]
53 pub fn test_default() -> Self {
54 let mut hardware = HardwareConfig {
55 name: "Test GPU".to_string(),
56 compute_flops: 1e15,
57 memory_bandwidth: 1e12,
58 memory_capacity: 80_000_000_000,
59 kv_cache_capacity: 60_000_000_000,
60 gpu_memory_utilization: 0.9,
61 bytes_per_param: 2,
62 compute_bound_threshold: 0,
63 };
64 hardware.compute_threshold();
65
66 let mut model = ModelConfig {
67 name: "Test Model".to_string(),
68 num_parameters: 7_000_000_000,
69 num_active_parameters: None,
70 num_layers: 32,
71 hidden_dim: 4096,
72 num_heads: 32,
73 num_kv_heads: None,
74 max_seq_len: 2048,
75 sliding_window: None,
76 num_sliding_layers: None,
77 kv_cache_bytes_per_token: 0,
78 };
79 model.compute_kv_cache_size(hardware.bytes_per_param);
80
81 let mut scheduler = SchedulerConfig {
82 max_num_batched_tokens: 2048,
83 max_num_seqs: 128,
84 policy: "fcfs".to_string(),
85 enable_chunked_prefill: true,
86 long_prefill_token_threshold: 0,
87 max_num_partial_prefills: 1,
88 block_size: 16,
89 };
90 scheduler.set_default_prefill_threshold(model.max_seq_len);
91
92 let workload = WorkloadConfig {
93 arrival_pattern: "poisson".to_string(),
94 arrival_rate: 1.0,
95 num_concurrent_users: None,
96 input_len_dist: LengthDistribution::Fixed { value: 100 },
97 output_len_dist: LengthDistribution::Fixed { value: 50 },
98 num_requests: Some(10),
99 duration_secs: None,
100 seed: 42,
101 };
102
103 let simulation = SimulationConfig::default();
104
105 Config {
106 hardware,
107 model,
108 scheduler,
109 workload,
110 simulation,
111 }
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    /// `compute_threshold` must derive `compute_bound_threshold` from the
    /// FLOPs / bandwidth figures. The accepted window (strictly between 900 and
    /// 910) is consistent with roughly 2 × 1.513e15 / 3.35e12 ≈ 903. Confirm the
    /// exact formula against `HardwareConfig::compute_threshold`.
    #[test]
    fn test_hardware_compute_bound_threshold() {
        let mut gpu = HardwareConfig {
            name: "Test".to_string(),
            compute_flops: 1.513e15,
            memory_bandwidth: 3.35e12,
            memory_capacity: 80_000_000_000,
            kv_cache_capacity: 60_000_000_000,
            gpu_memory_utilization: 0.9,
            bytes_per_param: 2,
            compute_bound_threshold: 0,
        };
        gpu.compute_threshold();

        let threshold = gpu.compute_bound_threshold;
        assert!(threshold > 900);
        assert!(threshold < 910);
    }

    /// Per-token and per-sequence KV-cache sizing for a 32-layer, 4096-wide model
    /// stored at 2 bytes per parameter.
    #[test]
    fn test_model_kv_cache_calculation() {
        let mut llm = ModelConfig {
            name: "Test".to_string(),
            num_parameters: 7_000_000_000,
            num_active_parameters: None,
            num_layers: 32,
            hidden_dim: 4096,
            num_heads: 32,
            num_kv_heads: None,
            max_seq_len: 2048,
            sliding_window: None,
            num_sliding_layers: None,
            kv_cache_bytes_per_token: 0,
        };

        llm.compute_kv_cache_size(2);
        // 524_288 == 2 * 32 (layers) * 4096 (hidden_dim) * 2 (bytes), which is
        // consistent with one K and one V vector per layer.
        let per_token = llm.kv_cache_bytes_per_token;
        assert_eq!(per_token, 524_288);

        // 100 tokens at 524_288 bytes each.
        let hundred_tokens = llm.kv_cache_size_for_sequence(100);
        assert_eq!(hundred_tokens, 52_428_800);
    }

    /// The shared test fixture must come out with its derived fields populated.
    #[test]
    fn test_config_creation() {
        let Config {
            hardware, model, ..
        } = Config::test_default();
        assert!(hardware.compute_bound_threshold > 0);
        assert!(model.kv_cache_bytes_per_token > 0);
    }
}