Skip to main content

oxi/store/
router_config.rs

//! Router configuration loading for oxi-store.
//!
//! Reads the `[router]` section from the global and project `settings.toml`
//! files and merges them (project settings override global ones).
5
6#![allow(missing_docs)]
7
8use std::collections::HashMap;
9use std::path::Path;
10
11/// TOML representation of the router config section.
12#[derive(Debug, Clone, serde::Deserialize, Default)]
13pub struct RouterConfigFile {
14    pub enabled: Option<bool>,
15    pub default_profile: Option<String>,
16    pub classifier_model: Option<String>,
17    pub context_upgrade_threshold: Option<usize>,
18    pub max_session_budget: Option<f64>,
19    pub profiles: Option<toml::Value>,
20    pub weights: Option<toml::Value>,
21    pub pin_tier: Option<String>,
22    pub phase_bias: Option<f64>,
23}
24
25/// Fully resolved router config with all required fields.
26#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
27pub struct RouterConfig {
28    default_profile: String,
29    classifier_model: Option<String>,
30    context_upgrade_threshold: Option<usize>,
31    max_session_budget: Option<f64>,
32    profiles: HashMap<String, RouterProfile>,
33    weights: ScoringWeights,
34    pin_tier: Option<String>,
35    phase_bias: Option<f64>,
36}
37
38impl RouterConfig {
39    /// Returns whether router is enabled (default: true).
40    pub fn enabled(&self) -> Option<bool> {
41        Some(!self.profiles.is_empty())
42    }
43
44    /// Get the default profile name.
45    pub fn default_profile(&self) -> &str {
46        &self.default_profile
47    }
48
49    /// Get a profile by name.
50    pub fn get_profile(&self, name: &str) -> Option<&RouterProfile> {
51        self.profiles.get(name)
52    }
53
54    /// Get all profiles map.
55    pub fn profiles(&self) -> &HashMap<String, RouterProfile> {
56        &self.profiles
57    }
58
59    /// Get scoring weights.
60    pub fn weights(&self) -> &ScoringWeights {
61        &self.weights
62    }
63
64    /// Get the classifier model.
65    pub fn classifier_model(&self) -> Option<&str> {
66        self.classifier_model.as_deref()
67    }
68
69    /// Get context upgrade threshold.
70    pub fn context_upgrade_threshold(&self) -> Option<usize> {
71        self.context_upgrade_threshold
72    }
73
74    /// Get max session budget.
75    pub fn max_session_budget(&self) -> Option<f64> {
76        self.max_session_budget
77    }
78
79    /// Get pinned tier as string ("high", "medium", "low").
80    pub fn pin_tier(&self) -> Option<&str> {
81        self.pin_tier.as_deref()
82    }
83
84    /// Get phase bias.
85    pub fn phase_bias(&self) -> Option<f64> {
86        self.phase_bias
87    }
88}
89
90#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
91pub struct RouterProfile {
92    pub high: RoutedTierConfig,
93    pub medium: RoutedTierConfig,
94    pub low: RoutedTierConfig,
95}
96
97impl RouterProfile {
98    /// Get tier config by tier name.
99    pub fn tier_config(&self, tier: &str) -> Option<&RoutedTierConfig> {
100        match tier {
101            "high" => Some(&self.high),
102            "medium" => Some(&self.medium),
103            "low" => Some(&self.low),
104            _ => None,
105        }
106    }
107}
108
109#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
110pub struct RoutedTierConfig {
111    pub model: String,
112    #[serde(default)]
113    pub thinking: Option<String>,
114    #[serde(default)]
115    pub fallbacks: Vec<String>,
116}
117
118#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
119pub struct ScoringWeights {
120    #[serde(default = "default_structural")]
121    pub structural: f64,
122    #[serde(default = "default_behavioral")]
123    pub behavioral: f64,
124    #[serde(default = "default_context")]
125    pub context_budget: f64,
126    #[serde(default = "default_vision")]
127    pub vision: f64,
128    #[serde(default = "default_message")]
129    pub message: f64,
130}
131
/// Default for `ScoringWeights::structural`.
fn default_structural() -> f64 {
    0.25
}

/// Default for `ScoringWeights::behavioral`.
fn default_behavioral() -> f64 {
    0.20
}

/// Default for `ScoringWeights::context_budget`.
fn default_context() -> f64 {
    0.15
}

/// Default for `ScoringWeights::vision`.
fn default_vision() -> f64 {
    0.10
}

/// Default for `ScoringWeights::message`.
fn default_message() -> f64 {
    0.30
}
147
148impl Default for ScoringWeights {
149    fn default() -> Self {
150        Self {
151            structural: default_structural(),
152            behavioral: default_behavioral(),
153            context_budget: default_context(),
154            vision: default_vision(),
155            message: default_message(),
156        }
157    }
158}
159
160/// Load router config from global and project settings directories.
161/// Returns `None` if no router config is found in either file.
162pub fn load_router_config(global_dir: &Path, project_dir: &Path) -> Option<RouterConfig> {
163    let global_path = global_dir.join("settings.toml");
164    let project_path = project_dir.join(".oxi/settings.toml");
165
166    let global_cfg = read_toml_router(&global_path);
167    let project_cfg = read_toml_router(&project_path);
168
169    let base_enabled = global_cfg.as_ref().and_then(|c| c.enabled);
170    let override_enabled = project_cfg.as_ref().and_then(|c| c.enabled);
171
172    // If neither file has `enabled = true`, skip router.
173    if base_enabled != Some(true) && override_enabled != Some(true) {
174        return None;
175    }
176
177    let default_name = project_cfg
178        .as_ref()
179        .and_then(|c| c.default_profile.clone())
180        .or_else(|| global_cfg.as_ref().and_then(|c| c.default_profile.clone()))
181        .unwrap_or_else(|| "auto".to_string());
182
183    let mut profiles: HashMap<String, RouterProfile> = HashMap::new();
184
185    // Merge global profiles.
186    if let Some(ref g) = global_cfg
187        && let Some(ref tbl) = g.profiles
188        && let Some(inner) = tbl.as_table()
189    {
190        for (name, value) in inner {
191            if let Some(profile) = parse_profile(value) {
192                profiles.insert(name.clone(), profile);
193            }
194        }
195    }
196
197    // Merge/override project profiles.
198    if let Some(ref p) = project_cfg
199        && let Some(ref tbl) = p.profiles
200        && let Some(inner) = tbl.as_table()
201    {
202        for (name, value) in inner {
203            if let Some(profile) = parse_profile(value) {
204                profiles.insert(name.clone(), profile);
205            }
206        }
207    }
208
209    if profiles.is_empty() {
210        return None;
211    }
212
213    let weights = project_cfg
214        .as_ref()
215        .and_then(|c| c.weights.as_ref())
216        .and_then(parse_weights)
217        .or_else(|| {
218            global_cfg
219                .as_ref()
220                .and_then(|c| c.weights.as_ref())
221                .and_then(parse_weights)
222        })
223        .unwrap_or_default();
224
225    let pin_tier = project_cfg
226        .as_ref()
227        .and_then(|c| c.pin_tier.as_ref())
228        .or_else(|| global_cfg.as_ref().and_then(|c| c.pin_tier.as_ref()))
229        .and_then(|s| parse_tier_str(s));
230
231    let phase_bias = project_cfg
232        .as_ref()
233        .and_then(|c| c.phase_bias)
234        .or_else(|| global_cfg.as_ref().and_then(|c| c.phase_bias));
235
236    Some(RouterConfig {
237        default_profile: default_name,
238        classifier_model: project_cfg
239            .as_ref()
240            .and_then(|c| c.classifier_model.clone())
241            .or_else(|| global_cfg.as_ref().and_then(|c| c.classifier_model.clone())),
242        context_upgrade_threshold: project_cfg
243            .as_ref()
244            .and_then(|c| c.context_upgrade_threshold)
245            .or_else(|| {
246                global_cfg
247                    .as_ref()
248                    .and_then(|c| c.context_upgrade_threshold)
249            }),
250        max_session_budget: project_cfg
251            .as_ref()
252            .and_then(|c| c.max_session_budget)
253            .or_else(|| global_cfg.as_ref().and_then(|c| c.max_session_budget)),
254        profiles,
255        weights,
256        pin_tier,
257        phase_bias,
258    })
259}
260
261fn read_toml_router(path: &Path) -> Option<RouterConfigFile> {
262    let content = std::fs::read_to_string(path).ok()?;
263    let toml: toml::Value = toml::from_str(&content).ok()?;
264    toml.get("router")?.clone().try_into().ok()
265}
266
267fn parse_profile(value: &toml::Value) -> Option<RouterProfile> {
268    let table = value.as_table()?;
269    Some(RouterProfile {
270        high: parse_tier(table.get("high"))?,
271        medium: parse_tier(table.get("medium"))?,
272        low: parse_tier(table.get("low"))?,
273    })
274}
275
276fn parse_tier(value: Option<&toml::Value>) -> Option<RoutedTierConfig> {
277    let table = value?.as_table()?;
278    Some(RoutedTierConfig {
279        model: table.get("model")?.as_str()?.to_string(),
280        thinking: table
281            .get("thinking")
282            .and_then(|v| v.as_str().map(String::from)),
283        fallbacks: table
284            .get("fallbacks")
285            .and_then(|v| v.as_array())
286            .map(|arr| {
287                arr.iter()
288                    .filter_map(|v| v.as_str().map(String::from))
289                    .collect()
290            })
291            .unwrap_or_default(),
292    })
293}
294
295fn parse_weights(value: &toml::Value) -> Option<ScoringWeights> {
296    let table = value.as_table()?;
297    Some(ScoringWeights {
298        structural: table.get("structural")?.as_float().unwrap_or(0.25),
299        behavioral: table.get("behavioral")?.as_float().unwrap_or(0.20),
300        context_budget: table.get("context_budget")?.as_float().unwrap_or(0.15),
301        vision: table
302            .get("vision")
303            .and_then(|v| v.as_float())
304            .unwrap_or(0.10),
305        message: table
306            .get("message")
307            .and_then(|v| v.as_float())
308            .unwrap_or(0.30),
309    })
310}
311
/// Validates and normalizes a tier name.
///
/// Matching is case-insensitive. Returns the lowercase name when `s` is
/// `high`, `medium` or `low`, and `None` for anything else. Surrounding
/// whitespace is not trimmed.
fn parse_tier_str(s: &str) -> Option<String> {
    // Lowercase once and reuse the allocation as the return value.
    let tier = s.to_lowercase();
    if matches!(tier.as_str(), "high" | "medium" | "low") {
        Some(tier)
    } else {
        None
    }
}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal settings file: router enabled with one complete `auto` profile.
    const MINIMAL_SETTINGS: &str = r#"
[router]
enabled = true
default_profile = "auto"

[router.profiles.auto]
high.model = "anthropic/claude-sonnet-4"
medium.model = "anthropic/claude-sonnet-4"
low.model = "google/gemini-2.0-flash"
"#;

    /// Settings file without any `[router]` section.
    const NO_ROUTER_SETTINGS: &str = r#"
[other]
value = 1
"#;

    /// Global settings used by the merge test.
    const GLOBAL_SETTINGS: &str = r#"
[router]
enabled = true
default_profile = "auto"

[router.profiles.auto]
high.model = "anthropic/claude-sonnet-4"
medium.model = "anthropic/claude-haiku-4"
low.model = "google/gemini-2.0-flash"
"#;

    /// Project settings used by the merge test: enables the router only.
    const PROJECT_SETTINGS: &str = r#"
[router]
enabled = true
"#;

    /// Extracts and deserializes the `router` table of `source`, if present.
    fn router_section(source: &str) -> Option<RouterConfigFile> {
        let document: toml::Value = toml::from_str(source).unwrap();
        document
            .get("router")
            .and_then(|section| section.clone().try_into().ok())
    }

    #[test]
    fn parse_minimal_config() {
        let cfg = router_section(MINIMAL_SETTINGS).unwrap();
        assert!(cfg.enabled.is_some());
        assert_eq!(cfg.default_profile.as_deref().unwrap(), "auto");
    }

    #[test]
    fn returns_none_when_not_enabled() {
        // A file with no `[router]` section yields no router config at all.
        assert!(router_section(NO_ROUTER_SETTINGS).is_none());
    }

    #[test]
    fn load_router_config_merges_profiles() {
        let global_dir = tempfile::tempdir().unwrap();
        let project_dir = tempfile::tempdir().unwrap();

        let project_settings_dir = project_dir.path().join(".oxi");
        std::fs::create_dir_all(&project_settings_dir).unwrap();

        std::fs::write(global_dir.path().join("settings.toml"), GLOBAL_SETTINGS).unwrap();
        std::fs::write(project_settings_dir.join("settings.toml"), PROJECT_SETTINGS).unwrap();

        let loaded = load_router_config(global_dir.path(), project_dir.path());
        assert!(loaded.is_some());

        let config = loaded.unwrap();
        assert_eq!(config.default_profile, "auto");
        assert!(config.profiles.contains_key("auto"));

        // The `tempfile::TempDir` guards remove both directories when dropped.
    }
}