1#![allow(missing_docs)]
7
8use std::collections::HashMap;
9use std::path::Path;
10
/// Raw `[router]` section of a single `settings.toml`, exactly as deserialized.
///
/// One instance is produced per settings file by `read_toml_router`;
/// `load_router_config` then combines a global and a project instance into a
/// [`RouterConfig`]. Every field is optional, so a file may set any subset of keys.
#[derive(Debug, Clone, serde::Deserialize, Default)]
pub struct RouterConfigFile {
    /// Whether this settings layer turns routing on.
    pub enabled: Option<bool>,
    /// Name of the profile to use by default (the merged config falls back to `"auto"`).
    pub default_profile: Option<String>,
    /// Model identifier for classification; passed through unchanged.
    /// NOTE(review): the consumer is not visible in this module.
    pub classifier_model: Option<String>,
    /// Threshold passed through unchanged.
    /// NOTE(review): units not defined here — presumably a context size in tokens; confirm.
    pub context_upgrade_threshold: Option<usize>,
    /// Session budget limit passed through unchanged; units/currency not defined here.
    pub max_session_budget: Option<f64>,
    /// Raw `[router.profiles]` table; each entry is parsed later with `parse_profile`.
    pub profiles: Option<toml::Value>,
    /// Raw `[router.weights]` table; parsed later with `parse_weights`.
    pub weights: Option<toml::Value>,
    /// Tier to pin to; later validated by `parse_tier_str` (`high`/`medium`/`low`,
    /// case-insensitive).
    pub pin_tier: Option<String>,
    /// Bias value passed through unchanged.
    /// NOTE(review): semantics are not defined in this module.
    pub phase_bias: Option<f64>,
}
24
/// Effective router configuration built by `load_router_config` from the global
/// and project settings files.
///
/// Fields are private and exposed through read-only accessors.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct RouterConfig {
    /// Profile name used when none is chosen explicitly (`"auto"` if unset in both files).
    default_profile: String,
    /// Optional classifier model identifier, passed through from settings.
    classifier_model: Option<String>,
    /// Optional threshold passed through from settings (units not defined here).
    context_upgrade_threshold: Option<usize>,
    /// Optional session budget passed through from settings (units not defined here).
    max_session_budget: Option<f64>,
    /// Successfully parsed profiles, keyed by profile name.
    profiles: HashMap<String, RouterProfile>,
    /// Scoring weights; built-in defaults when no weights table parses.
    weights: ScoringWeights,
    /// Lower-cased pinned tier (`"high"`, `"medium"` or `"low"`) if one was valid.
    pin_tier: Option<String>,
    /// Optional bias value passed through from settings.
    phase_bias: Option<f64>,
}
37
38impl RouterConfig {
39 pub fn enabled(&self) -> Option<bool> {
41 Some(!self.profiles.is_empty())
42 }
43
44 pub fn default_profile(&self) -> &str {
46 &self.default_profile
47 }
48
49 pub fn get_profile(&self, name: &str) -> Option<&RouterProfile> {
51 self.profiles.get(name)
52 }
53
54 pub fn profiles(&self) -> &HashMap<String, RouterProfile> {
56 &self.profiles
57 }
58
59 pub fn weights(&self) -> &ScoringWeights {
61 &self.weights
62 }
63
64 pub fn classifier_model(&self) -> Option<&str> {
66 self.classifier_model.as_deref()
67 }
68
69 pub fn context_upgrade_threshold(&self) -> Option<usize> {
71 self.context_upgrade_threshold
72 }
73
74 pub fn max_session_budget(&self) -> Option<f64> {
76 self.max_session_budget
77 }
78
79 pub fn pin_tier(&self) -> Option<&str> {
81 self.pin_tier.as_deref()
82 }
83
84 pub fn phase_bias(&self) -> Option<f64> {
86 self.phase_bias
87 }
88}
89
/// A named routing profile: one model configuration for each of the three tiers.
///
/// When loaded from settings via `parse_profile`, all three tiers are required.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RouterProfile {
    /// Configuration for the `"high"` tier.
    pub high: RoutedTierConfig,
    /// Configuration for the `"medium"` tier.
    pub medium: RoutedTierConfig,
    /// Configuration for the `"low"` tier.
    pub low: RoutedTierConfig,
}
96
97impl RouterProfile {
98 pub fn tier_config(&self, tier: &str) -> Option<&RoutedTierConfig> {
100 match tier {
101 "high" => Some(&self.high),
102 "medium" => Some(&self.medium),
103 "low" => Some(&self.low),
104 _ => None,
105 }
106 }
107}
108
/// Model settings for a single routing tier.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RoutedTierConfig {
    /// Model identifier for this tier (required when parsed from settings).
    pub model: String,
    /// Optional thinking setting kept as a raw string.
    /// NOTE(review): accepted values are not validated in this module.
    #[serde(default)]
    pub thinking: Option<String>,
    /// Alternative model identifiers; empty when not configured.
    /// NOTE(review): presumably tried in order when `model` is unavailable — confirm with the router.
    #[serde(default)]
    pub fallbacks: Vec<String>,
}
117
/// Relative weights for the scoring components.
///
/// Each missing field falls back to its `default_*` function during
/// deserialization. The built-in defaults (0.25 + 0.20 + 0.15 + 0.10 + 0.30) sum to 1.0.
/// NOTE(review): whether the scorer normalizes user-supplied weights is not visible here.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ScoringWeights {
    /// Weight of the structural component (default 0.25).
    #[serde(default = "default_structural")]
    pub structural: f64,
    /// Weight of the behavioral component (default 0.20).
    #[serde(default = "default_behavioral")]
    pub behavioral: f64,
    /// Weight of the context-budget component (default 0.15).
    #[serde(default = "default_context")]
    pub context_budget: f64,
    /// Weight of the vision component (default 0.10).
    #[serde(default = "default_vision")]
    pub vision: f64,
    /// Weight of the message component (default 0.30).
    #[serde(default = "default_message")]
    pub message: f64,
}
131
/// Built-in weight for [`ScoringWeights::structural`].
const STRUCTURAL_WEIGHT: f64 = 0.25;
/// Built-in weight for [`ScoringWeights::behavioral`].
const BEHAVIORAL_WEIGHT: f64 = 0.20;
/// Built-in weight for [`ScoringWeights::context_budget`].
const CONTEXT_BUDGET_WEIGHT: f64 = 0.15;
/// Built-in weight for [`ScoringWeights::vision`].
const VISION_WEIGHT: f64 = 0.10;
/// Built-in weight for [`ScoringWeights::message`].
const MESSAGE_WEIGHT: f64 = 0.30;

// `#[serde(default = "...")]` needs a function path, so each constant is exposed
// through a one-line function. Together the defaults sum to 1.0.

/// Default for the `structural` weight.
fn default_structural() -> f64 {
    STRUCTURAL_WEIGHT
}

/// Default for the `behavioral` weight.
fn default_behavioral() -> f64 {
    BEHAVIORAL_WEIGHT
}

/// Default for the `context_budget` weight.
fn default_context() -> f64 {
    CONTEXT_BUDGET_WEIGHT
}

/// Default for the `vision` weight.
fn default_vision() -> f64 {
    VISION_WEIGHT
}

/// Default for the `message` weight.
fn default_message() -> f64 {
    MESSAGE_WEIGHT
}
147
148impl Default for ScoringWeights {
149 fn default() -> Self {
150 Self {
151 structural: default_structural(),
152 behavioral: default_behavioral(),
153 context_budget: default_context(),
154 vision: default_vision(),
155 message: default_message(),
156 }
157 }
158}
159
160pub fn load_router_config(global_dir: &Path, project_dir: &Path) -> Option<RouterConfig> {
163 let global_path = global_dir.join("settings.toml");
164 let project_path = project_dir.join(".oxi/settings.toml");
165
166 let global_cfg = read_toml_router(&global_path);
167 let project_cfg = read_toml_router(&project_path);
168
169 let base_enabled = global_cfg.as_ref().and_then(|c| c.enabled);
170 let override_enabled = project_cfg.as_ref().and_then(|c| c.enabled);
171
172 if base_enabled != Some(true) && override_enabled != Some(true) {
174 return None;
175 }
176
177 let default_name = project_cfg
178 .as_ref()
179 .and_then(|c| c.default_profile.clone())
180 .or_else(|| global_cfg.as_ref().and_then(|c| c.default_profile.clone()))
181 .unwrap_or_else(|| "auto".to_string());
182
183 let mut profiles: HashMap<String, RouterProfile> = HashMap::new();
184
185 if let Some(ref g) = global_cfg {
187 if let Some(ref tbl) = g.profiles {
188 if let Some(inner) = tbl.as_table() {
189 for (name, value) in inner {
190 if let Some(profile) = parse_profile(value) {
191 profiles.insert(name.clone(), profile);
192 }
193 }
194 }
195 }
196 }
197
198 if let Some(ref p) = project_cfg {
200 if let Some(ref tbl) = p.profiles {
201 if let Some(inner) = tbl.as_table() {
202 for (name, value) in inner {
203 if let Some(profile) = parse_profile(value) {
204 profiles.insert(name.clone(), profile);
205 }
206 }
207 }
208 }
209 }
210
211 if profiles.is_empty() {
212 return None;
213 }
214
215 let weights = project_cfg
216 .as_ref()
217 .and_then(|c| c.weights.as_ref())
218 .and_then(parse_weights)
219 .or_else(|| {
220 global_cfg
221 .as_ref()
222 .and_then(|c| c.weights.as_ref())
223 .and_then(parse_weights)
224 })
225 .unwrap_or_default();
226
227 let pin_tier = project_cfg
228 .as_ref()
229 .and_then(|c| c.pin_tier.as_ref())
230 .or_else(|| global_cfg.as_ref().and_then(|c| c.pin_tier.as_ref()))
231 .and_then(|s| parse_tier_str(s));
232
233 let phase_bias = project_cfg
234 .as_ref()
235 .and_then(|c| c.phase_bias)
236 .or_else(|| global_cfg.as_ref().and_then(|c| c.phase_bias));
237
238 Some(RouterConfig {
239 default_profile: default_name,
240 classifier_model: project_cfg
241 .as_ref()
242 .and_then(|c| c.classifier_model.clone())
243 .or_else(|| global_cfg.as_ref().and_then(|c| c.classifier_model.clone())),
244 context_upgrade_threshold: project_cfg
245 .as_ref()
246 .and_then(|c| c.context_upgrade_threshold)
247 .or_else(|| {
248 global_cfg
249 .as_ref()
250 .and_then(|c| c.context_upgrade_threshold)
251 }),
252 max_session_budget: project_cfg
253 .as_ref()
254 .and_then(|c| c.max_session_budget)
255 .or_else(|| global_cfg.as_ref().and_then(|c| c.max_session_budget)),
256 profiles,
257 weights,
258 pin_tier,
259 phase_bias,
260 })
261}
262
263fn read_toml_router(path: &Path) -> Option<RouterConfigFile> {
264 let content = std::fs::read_to_string(path).ok()?;
265 let toml: toml::Value = toml::from_str(&content).ok()?;
266 toml.get("router")?.clone().try_into().ok()
267}
268
269fn parse_profile(value: &toml::Value) -> Option<RouterProfile> {
270 let table = value.as_table()?;
271 Some(RouterProfile {
272 high: parse_tier(table.get("high"))?,
273 medium: parse_tier(table.get("medium"))?,
274 low: parse_tier(table.get("low"))?,
275 })
276}
277
278fn parse_tier(value: Option<&toml::Value>) -> Option<RoutedTierConfig> {
279 let table = value?.as_table()?;
280 Some(RoutedTierConfig {
281 model: table.get("model")?.as_str()?.to_string(),
282 thinking: table
283 .get("thinking")
284 .and_then(|v| v.as_str().map(String::from)),
285 fallbacks: table
286 .get("fallbacks")
287 .and_then(|v| v.as_array())
288 .map(|arr| {
289 arr.iter()
290 .filter_map(|v| v.as_str().map(String::from))
291 .collect()
292 })
293 .unwrap_or_default(),
294 })
295}
296
297fn parse_weights(value: &toml::Value) -> Option<ScoringWeights> {
298 let table = value.as_table()?;
299 Some(ScoringWeights {
300 structural: table.get("structural")?.as_float().unwrap_or(0.25),
301 behavioral: table.get("behavioral")?.as_float().unwrap_or(0.20),
302 context_budget: table.get("context_budget")?.as_float().unwrap_or(0.15),
303 vision: table
304 .get("vision")
305 .and_then(|v| v.as_float())
306 .unwrap_or(0.10),
307 message: table
308 .get("message")
309 .and_then(|v| v.as_float())
310 .unwrap_or(0.30),
311 })
312}
313
/// Normalizes a tier name to lowercase and accepts only `high`, `medium` or `low`.
///
/// Matching is case-insensitive but whitespace-sensitive; any other input yields `None`.
fn parse_tier_str(s: &str) -> Option<String> {
    let tier = s.to_lowercase();
    if matches!(tier.as_str(), "high" | "medium" | "low") {
        Some(tier)
    } else {
        None
    }
}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes `contents` to `<dir>/settings.toml`, creating `dir` if needed.
    fn write_settings(dir: &std::path::Path, contents: &str) {
        std::fs::create_dir_all(dir).unwrap();
        std::fs::write(dir.join("settings.toml"), contents).unwrap();
    }

    #[test]
    fn parse_minimal_config() {
        let toml_str = r#"
[router]
enabled = true
default_profile = "auto"

[router.profiles.auto]
high.model = "anthropic/claude-sonnet-4"
medium.model = "anthropic/claude-sonnet-4"
low.model = "google/gemini-2.0-flash"
"#;
        let value: toml::Value = toml::from_str(toml_str).unwrap();
        let cfg: RouterConfigFile = value.get("router").unwrap().clone().try_into().unwrap();
        assert_eq!(cfg.enabled, Some(true));
        assert_eq!(cfg.default_profile.as_deref(), Some("auto"));
        assert!(cfg.profiles.is_some());
    }

    #[test]
    fn returns_none_when_not_enabled() {
        let toml_str = r#"
[other]
value = 1
"#;
        let value: toml::Value = toml::from_str(toml_str).unwrap();
        let cfg: Option<RouterConfigFile> =
            value.get("router").and_then(|v| v.clone().try_into().ok());
        assert!(cfg.is_none());
    }

    #[test]
    fn load_router_config_returns_none_without_settings() {
        let global_dir = tempfile::tempdir().unwrap();
        let project_dir = tempfile::tempdir().unwrap();
        assert!(load_router_config(global_dir.path(), project_dir.path()).is_none());
    }

    #[test]
    fn load_router_config_returns_none_when_disabled() {
        let global_dir = tempfile::tempdir().unwrap();
        let project_dir = tempfile::tempdir().unwrap();
        write_settings(
            global_dir.path(),
            r#"
[router]
enabled = false

[router.profiles.auto]
high.model = "a"
medium.model = "b"
low.model = "c"
"#,
        );
        assert!(load_router_config(global_dir.path(), project_dir.path()).is_none());
    }

    #[test]
    fn load_router_config_merges_profiles() {
        let global_dir = tempfile::tempdir().unwrap();
        let project_dir = tempfile::tempdir().unwrap();
        write_settings(
            global_dir.path(),
            r#"
[router]
enabled = true
default_profile = "auto"

[router.profiles.auto]
high.model = "anthropic/claude-sonnet-4"
medium.model = "anthropic/claude-haiku-4"
low.model = "google/gemini-2.0-flash"
"#,
        );
        write_settings(
            &project_dir.path().join(".oxi"),
            r#"
[router]
enabled = true
"#,
        );

        let config = load_router_config(global_dir.path(), project_dir.path())
            .expect("router enabled with a valid global profile");
        assert_eq!(config.default_profile, "auto");
        assert!(config.profiles.contains_key("auto"));
        assert_eq!(config.profiles["auto"].medium.model, "anthropic/claude-haiku-4");
    }

    #[test]
    fn project_profile_replaces_global_profile() {
        let global_dir = tempfile::tempdir().unwrap();
        let project_dir = tempfile::tempdir().unwrap();
        write_settings(
            global_dir.path(),
            r#"
[router]
enabled = true

[router.profiles.auto]
high.model = "global-high"
medium.model = "global-medium"
low.model = "global-low"
"#,
        );
        write_settings(
            &project_dir.path().join(".oxi"),
            r#"
[router.profiles.auto]
high.model = "project-high"
medium.model = "project-medium"
low.model = "project-low"
"#,
        );

        let config = load_router_config(global_dir.path(), project_dir.path())
            .expect("router enabled globally");
        // Neither file sets `default_profile`, so the built-in fallback applies.
        assert_eq!(config.default_profile(), "auto");
        let auto = config.get_profile("auto").expect("auto profile present");
        assert_eq!(auto.high.model, "project-high");
        assert_eq!(auto.low.model, "project-low");
    }

    #[test]
    fn parse_profile_rejects_missing_tier() {
        let value: toml::Value = toml::from_str(
            r#"
high.model = "a"
medium.model = "b"
"#,
        )
        .unwrap();
        assert!(parse_profile(&value).is_none());
    }
}