1#![allow(missing_docs)]
7
8use std::collections::HashMap;
9use std::path::Path;
10
11#[derive(Debug, Clone, serde::Deserialize, Default)]
13pub struct RouterConfigFile {
14 pub enabled: Option<bool>,
15 pub default_profile: Option<String>,
16 pub classifier_model: Option<String>,
17 pub context_upgrade_threshold: Option<usize>,
18 pub max_session_budget: Option<f64>,
19 pub profiles: Option<toml::Value>,
20 pub weights: Option<toml::Value>,
21 pub pin_tier: Option<String>,
22 pub phase_bias: Option<f64>,
23}
24
25#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
27pub struct RouterConfig {
28 default_profile: String,
29 classifier_model: Option<String>,
30 context_upgrade_threshold: Option<usize>,
31 max_session_budget: Option<f64>,
32 profiles: HashMap<String, RouterProfile>,
33 weights: ScoringWeights,
34 pin_tier: Option<String>,
35 phase_bias: Option<f64>,
36}
37
38impl RouterConfig {
39 pub fn enabled(&self) -> Option<bool> {
41 Some(!self.profiles.is_empty())
42 }
43
44 pub fn default_profile(&self) -> &str {
46 &self.default_profile
47 }
48
49 pub fn get_profile(&self, name: &str) -> Option<&RouterProfile> {
51 self.profiles.get(name)
52 }
53
54 pub fn profiles(&self) -> &HashMap<String, RouterProfile> {
56 &self.profiles
57 }
58
59 pub fn weights(&self) -> &ScoringWeights {
61 &self.weights
62 }
63
64 pub fn classifier_model(&self) -> Option<&str> {
66 self.classifier_model.as_deref()
67 }
68
69 pub fn context_upgrade_threshold(&self) -> Option<usize> {
71 self.context_upgrade_threshold
72 }
73
74 pub fn max_session_budget(&self) -> Option<f64> {
76 self.max_session_budget
77 }
78
79 pub fn pin_tier(&self) -> Option<&str> {
81 self.pin_tier.as_deref()
82 }
83
84 pub fn phase_bias(&self) -> Option<f64> {
86 self.phase_bias
87 }
88}
89
90#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
91pub struct RouterProfile {
92 pub high: RoutedTierConfig,
93 pub medium: RoutedTierConfig,
94 pub low: RoutedTierConfig,
95}
96
97impl RouterProfile {
98 pub fn tier_config(&self, tier: &str) -> Option<&RoutedTierConfig> {
100 match tier {
101 "high" => Some(&self.high),
102 "medium" => Some(&self.medium),
103 "low" => Some(&self.low),
104 _ => None,
105 }
106 }
107}
108
109#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
110pub struct RoutedTierConfig {
111 pub model: String,
112 #[serde(default)]
113 pub thinking: Option<String>,
114 #[serde(default)]
115 pub fallbacks: Vec<String>,
116}
117
118#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
119pub struct ScoringWeights {
120 #[serde(default = "default_structural")]
121 pub structural: f64,
122 #[serde(default = "default_behavioral")]
123 pub behavioral: f64,
124 #[serde(default = "default_context")]
125 pub context_budget: f64,
126 #[serde(default = "default_vision")]
127 pub vision: f64,
128 #[serde(default = "default_message")]
129 pub message: f64,
130}
131
/// Default weight of the structural signal (serde default and `ScoringWeights::default`).
fn default_structural() -> f64 {
    0.25_f64
}

/// Default weight of the behavioral signal.
fn default_behavioral() -> f64 {
    0.20_f64
}

/// Default weight of the context-budget signal.
fn default_context() -> f64 {
    0.15_f64
}

/// Default weight of the vision signal.
fn default_vision() -> f64 {
    0.10_f64
}

/// Default weight of the message signal.
fn default_message() -> f64 {
    0.30_f64
}
147
148impl Default for ScoringWeights {
149 fn default() -> Self {
150 Self {
151 structural: default_structural(),
152 behavioral: default_behavioral(),
153 context_budget: default_context(),
154 vision: default_vision(),
155 message: default_message(),
156 }
157 }
158}
159
160pub fn load_router_config(global_dir: &Path, project_dir: &Path) -> Option<RouterConfig> {
163 let global_path = global_dir.join("settings.toml");
164 let project_path = project_dir.join(".oxi/settings.toml");
165
166 let global_cfg = read_toml_router(&global_path);
167 let project_cfg = read_toml_router(&project_path);
168
169 let base_enabled = global_cfg.as_ref().and_then(|c| c.enabled);
170 let override_enabled = project_cfg.as_ref().and_then(|c| c.enabled);
171
172 if base_enabled != Some(true) && override_enabled != Some(true) {
174 return None;
175 }
176
177 let default_name = project_cfg
178 .as_ref()
179 .and_then(|c| c.default_profile.clone())
180 .or_else(|| global_cfg.as_ref().and_then(|c| c.default_profile.clone()))
181 .unwrap_or_else(|| "auto".to_string());
182
183 let mut profiles: HashMap<String, RouterProfile> = HashMap::new();
184
185 if let Some(ref g) = global_cfg
187 && let Some(ref tbl) = g.profiles
188 && let Some(inner) = tbl.as_table()
189 {
190 for (name, value) in inner {
191 if let Some(profile) = parse_profile(value) {
192 profiles.insert(name.clone(), profile);
193 }
194 }
195 }
196
197 if let Some(ref p) = project_cfg
199 && let Some(ref tbl) = p.profiles
200 && let Some(inner) = tbl.as_table()
201 {
202 for (name, value) in inner {
203 if let Some(profile) = parse_profile(value) {
204 profiles.insert(name.clone(), profile);
205 }
206 }
207 }
208
209 if profiles.is_empty() {
210 return None;
211 }
212
213 let weights = project_cfg
214 .as_ref()
215 .and_then(|c| c.weights.as_ref())
216 .and_then(parse_weights)
217 .or_else(|| {
218 global_cfg
219 .as_ref()
220 .and_then(|c| c.weights.as_ref())
221 .and_then(parse_weights)
222 })
223 .unwrap_or_default();
224
225 let pin_tier = project_cfg
226 .as_ref()
227 .and_then(|c| c.pin_tier.as_ref())
228 .or_else(|| global_cfg.as_ref().and_then(|c| c.pin_tier.as_ref()))
229 .and_then(|s| parse_tier_str(s));
230
231 let phase_bias = project_cfg
232 .as_ref()
233 .and_then(|c| c.phase_bias)
234 .or_else(|| global_cfg.as_ref().and_then(|c| c.phase_bias));
235
236 Some(RouterConfig {
237 default_profile: default_name,
238 classifier_model: project_cfg
239 .as_ref()
240 .and_then(|c| c.classifier_model.clone())
241 .or_else(|| global_cfg.as_ref().and_then(|c| c.classifier_model.clone())),
242 context_upgrade_threshold: project_cfg
243 .as_ref()
244 .and_then(|c| c.context_upgrade_threshold)
245 .or_else(|| {
246 global_cfg
247 .as_ref()
248 .and_then(|c| c.context_upgrade_threshold)
249 }),
250 max_session_budget: project_cfg
251 .as_ref()
252 .and_then(|c| c.max_session_budget)
253 .or_else(|| global_cfg.as_ref().and_then(|c| c.max_session_budget)),
254 profiles,
255 weights,
256 pin_tier,
257 phase_bias,
258 })
259}
260
261fn read_toml_router(path: &Path) -> Option<RouterConfigFile> {
262 let content = std::fs::read_to_string(path).ok()?;
263 let toml: toml::Value = toml::from_str(&content).ok()?;
264 toml.get("router")?.clone().try_into().ok()
265}
266
267fn parse_profile(value: &toml::Value) -> Option<RouterProfile> {
268 let table = value.as_table()?;
269 Some(RouterProfile {
270 high: parse_tier(table.get("high"))?,
271 medium: parse_tier(table.get("medium"))?,
272 low: parse_tier(table.get("low"))?,
273 })
274}
275
276fn parse_tier(value: Option<&toml::Value>) -> Option<RoutedTierConfig> {
277 let table = value?.as_table()?;
278 Some(RoutedTierConfig {
279 model: table.get("model")?.as_str()?.to_string(),
280 thinking: table
281 .get("thinking")
282 .and_then(|v| v.as_str().map(String::from)),
283 fallbacks: table
284 .get("fallbacks")
285 .and_then(|v| v.as_array())
286 .map(|arr| {
287 arr.iter()
288 .filter_map(|v| v.as_str().map(String::from))
289 .collect()
290 })
291 .unwrap_or_default(),
292 })
293}
294
295fn parse_weights(value: &toml::Value) -> Option<ScoringWeights> {
296 let table = value.as_table()?;
297 Some(ScoringWeights {
298 structural: table.get("structural")?.as_float().unwrap_or(0.25),
299 behavioral: table.get("behavioral")?.as_float().unwrap_or(0.20),
300 context_budget: table.get("context_budget")?.as_float().unwrap_or(0.15),
301 vision: table
302 .get("vision")
303 .and_then(|v| v.as_float())
304 .unwrap_or(0.10),
305 message: table
306 .get("message")
307 .and_then(|v| v.as_float())
308 .unwrap_or(0.30),
309 })
310}
311
/// Normalises a tier name taken from configuration.
///
/// Surrounding whitespace is ignored and matching is case-insensitive. Returns the
/// lowercase name (`"high"`, `"medium"` or `"low"`), or `None` for anything else.
fn parse_tier_str(s: &str) -> Option<String> {
    // Lowercase once and reuse the same allocation as the return value.
    let tier = s.trim().to_lowercase();
    matches!(tier.as_str(), "high" | "medium" | "low").then_some(tier)
}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Global settings with routing enabled and one complete `auto` profile.
    const GLOBAL_SETTINGS: &str = r#"
[router]
enabled = true
default_profile = "auto"

[router.profiles.auto]
high.model = "anthropic/claude-sonnet-4"
medium.model = "anthropic/claude-haiku-4"
low.model = "google/gemini-2.0-flash"
"#;

    /// Creates temporary global and project directories (the latter with `.oxi/`) and
    /// writes the given settings files; `None` leaves that file absent. The returned
    /// guards must be kept alive for the directories to exist.
    fn setup(
        global: Option<&str>,
        project: Option<&str>,
    ) -> (tempfile::TempDir, tempfile::TempDir) {
        let global_dir = tempfile::tempdir().unwrap();
        let project_dir = tempfile::tempdir().unwrap();
        let oxi_dir = project_dir.path().join(".oxi");
        std::fs::create_dir_all(&oxi_dir).unwrap();
        if let Some(contents) = global {
            std::fs::write(global_dir.path().join("settings.toml"), contents).unwrap();
        }
        if let Some(contents) = project {
            std::fs::write(oxi_dir.join("settings.toml"), contents).unwrap();
        }
        (global_dir, project_dir)
    }

    #[test]
    fn parse_minimal_config() {
        let value: toml::Value = toml::from_str(GLOBAL_SETTINGS).unwrap();
        let cfg: RouterConfigFile = value.get("router").unwrap().clone().try_into().unwrap();
        assert_eq!(cfg.enabled, Some(true));
        assert_eq!(cfg.default_profile.as_deref(), Some("auto"));
        assert!(cfg.profiles.is_some());
    }

    #[test]
    fn returns_none_when_not_enabled() {
        let toml_str = r#"
[other]
value = 1
"#;
        let value: toml::Value = toml::from_str(toml_str).unwrap();
        let cfg: Option<RouterConfigFile> =
            value.get("router").and_then(|v| v.clone().try_into().ok());
        assert!(cfg.is_none());
    }

    #[test]
    fn load_router_config_merges_profiles() {
        let project = r#"
[router]
enabled = true

[router.profiles.auto]
high.model = "openai/gpt-4o"
medium.model = "anthropic/claude-haiku-4"
low.model = "google/gemini-2.0-flash"

[router.profiles.cheap]
high.model = "google/gemini-2.0-flash"
medium.model = "google/gemini-2.0-flash"
low.model = "google/gemini-2.0-flash"
"#;
        let (global_dir, project_dir) = setup(Some(GLOBAL_SETTINGS), Some(project));

        let config = load_router_config(global_dir.path(), project_dir.path())
            .expect("router is enabled and has profiles");
        // Not set by the project, so inherited from the global file.
        assert_eq!(config.default_profile, "auto");
        assert_eq!(config.profiles.len(), 2);
        assert!(config.profiles.contains_key("cheap"));
        // The project's `auto` profile replaces the global one wholesale.
        assert_eq!(config.profiles["auto"].high.model, "openai/gpt-4o");
    }

    #[test]
    fn load_router_config_returns_none_when_enabled_is_unset() {
        let global = r#"
[router.profiles.auto]
high.model = "a"
medium.model = "b"
low.model = "c"
"#;
        let (global_dir, project_dir) = setup(Some(global), None);
        assert!(load_router_config(global_dir.path(), project_dir.path()).is_none());
    }

    #[test]
    fn load_router_config_requires_a_profile() {
        let global = r#"
[router]
enabled = true
"#;
        let (global_dir, project_dir) = setup(Some(global), None);
        assert!(load_router_config(global_dir.path(), project_dir.path()).is_none());
    }

    #[test]
    fn parse_tier_str_normalizes_case() {
        assert_eq!(parse_tier_str("HIGH"), Some("high".to_string()));
        assert_eq!(parse_tier_str("low"), Some("low".to_string()));
        assert_eq!(parse_tier_str("bogus"), None);
    }

    #[test]
    fn parse_weights_reads_all_fields() {
        let value: toml::Value = toml::from_str(
            "structural = 0.1\nbehavioral = 0.2\ncontext_budget = 0.3\nvision = 0.15\nmessage = 0.25",
        )
        .unwrap();
        let weights = parse_weights(&value).expect("table with all weights");
        assert_eq!(weights.structural, 0.1);
        assert_eq!(weights.behavioral, 0.2);
        assert_eq!(weights.context_budget, 0.3);
        assert_eq!(weights.vision, 0.15);
        assert_eq!(weights.message, 0.25);
    }

    #[test]
    fn parse_tier_reads_optional_fields() {
        let value: toml::Value =
            toml::from_str("model = \"a\"\nthinking = \"deep\"\nfallbacks = [\"b\", 3, \"c\"]")
                .unwrap();
        let tier = parse_tier(Some(&value)).expect("tier with a model");
        assert_eq!(tier.model, "a");
        assert_eq!(tier.thinking.as_deref(), Some("deep"));
        // Non-string fallback entries are dropped.
        assert_eq!(tier.fallbacks, vec!["b".to_string(), "c".to_string()]);
    }
}