Skip to main content

oxi/store/
router_config.rs

1//! Router configuration loading for oxi-store.
2//!
//! Reads the `[router]` table from the global `settings.toml` and the
//! project's `.oxi/settings.toml` and merges them: project values override
//! global ones, and profiles are merged by name.
5
6#![allow(missing_docs)]
7
8use std::collections::HashMap;
9use std::path::Path;
10
/// TOML representation of the router config section.
///
/// This is the raw, unvalidated `[router]` table as it appears in a single
/// `settings.toml`. Every field is optional so a file may set only a subset of
/// keys; absent keys deserialize to `None`. [`load_router_config`] combines a
/// project copy of this struct with a global one to build a [`RouterConfig`].
#[derive(Debug, Clone, serde::Deserialize, Default)]
pub struct RouterConfigFile {
    /// `enabled` key. How the global and project values combine is decided
    /// by `load_router_config`.
    pub enabled: Option<bool>,
    /// Default profile name; the loader falls back to `"auto"` when neither
    /// file sets it.
    pub default_profile: Option<String>,
    /// `classifier_model` key, passed through unvalidated.
    pub classifier_model: Option<String>,
    /// Passed through unvalidated.
    /// NOTE(review): unit (tokens?) is not visible in this file — confirm.
    pub context_upgrade_threshold: Option<usize>,
    /// Passed through unvalidated.
    /// NOTE(review): currency/unit is not visible in this file — confirm.
    pub max_session_budget: Option<f64>,
    /// Raw `[router.profiles]` table, one sub-table per profile name. Kept as
    /// a `toml::Value` so the loader can parse each profile individually
    /// (see `parse_profile`) and skip malformed ones.
    pub profiles: Option<toml::Value>,
    /// Raw `[router.weights]` table, turned into [`ScoringWeights`] by the
    /// loader.
    pub weights: Option<toml::Value>,
    /// Tier to pin routing to. The loader accepts only `high`, `medium` or
    /// `low` (case-insensitive) and drops any other value.
    pub pin_tier: Option<String>,
    /// Passed through unvalidated.
    pub phase_bias: Option<f64>,
}
24
/// Fully resolved router config with all required fields.
///
/// Produced by [`load_router_config`]. Each optional field holds the project
/// value when set, otherwise the global one.
///
/// NOTE(review): the derived `Default` leaves `default_profile` as an empty
/// string and `profiles` empty, whereas the loader uses `"auto"` and never
/// returns a config without profiles — confirm `Default` is only used as a
/// placeholder.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct RouterConfig {
    /// Default profile name (`"auto"` when neither file sets it).
    default_profile: String,
    /// Classifier model identifier, if configured.
    classifier_model: Option<String>,
    /// Passed through from the settings files unvalidated.
    context_upgrade_threshold: Option<usize>,
    /// Passed through from the settings files unvalidated.
    max_session_budget: Option<f64>,
    /// Profiles keyed by name; project entries replace same-named global ones.
    profiles: HashMap<String, RouterProfile>,
    /// Scoring weights (built-in defaults when neither file configures them).
    weights: ScoringWeights,
    /// Lowercased `high`/`medium`/`low`, or `None` when unset or invalid.
    pin_tier: Option<String>,
    /// Passed through from the settings files unvalidated.
    phase_bias: Option<f64>,
}
37
38impl RouterConfig {
39    /// Returns whether router is enabled (default: true).
40    pub fn enabled(&self) -> Option<bool> {
41        Some(!self.profiles.is_empty())
42    }
43
44    /// Get the default profile name.
45    pub fn default_profile(&self) -> &str {
46        &self.default_profile
47    }
48
49    /// Get a profile by name.
50    pub fn get_profile(&self, name: &str) -> Option<&RouterProfile> {
51        self.profiles.get(name)
52    }
53
54    /// Get all profiles map.
55    pub fn profiles(&self) -> &HashMap<String, RouterProfile> {
56        &self.profiles
57    }
58
59    /// Get scoring weights.
60    pub fn weights(&self) -> &ScoringWeights {
61        &self.weights
62    }
63
64    /// Get the classifier model.
65    pub fn classifier_model(&self) -> Option<&str> {
66        self.classifier_model.as_deref()
67    }
68
69    /// Get context upgrade threshold.
70    pub fn context_upgrade_threshold(&self) -> Option<usize> {
71        self.context_upgrade_threshold
72    }
73
74    /// Get max session budget.
75    pub fn max_session_budget(&self) -> Option<f64> {
76        self.max_session_budget
77    }
78
79    /// Get pinned tier as string ("high", "medium", "low").
80    pub fn pin_tier(&self) -> Option<&str> {
81        self.pin_tier.as_deref()
82    }
83
84    /// Get phase bias.
85    pub fn phase_bias(&self) -> Option<f64> {
86        self.phase_bias
87    }
88}
89
/// Model assignments for one named profile, one entry per tier.
///
/// All three tiers are required: the loader skips a profile whose table lacks
/// any of `high`, `medium` or `low`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RouterProfile {
    /// Configuration for the `high` tier.
    pub high: RoutedTierConfig,
    /// Configuration for the `medium` tier.
    pub medium: RoutedTierConfig,
    /// Configuration for the `low` tier.
    pub low: RoutedTierConfig,
}
96
97impl RouterProfile {
98    /// Get tier config by tier name.
99    pub fn tier_config(&self, tier: &str) -> Option<&RoutedTierConfig> {
100        match tier {
101            "high" => Some(&self.high),
102            "medium" => Some(&self.medium),
103            "low" => Some(&self.low),
104            _ => None,
105        }
106    }
107}
108
/// Model selection for a single tier of a [`RouterProfile`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RoutedTierConfig {
    /// Model identifier (this file's tests use values such as
    /// `"anthropic/claude-sonnet-4"`). Required.
    pub model: String,
    /// Optional `thinking` setting, kept as a string.
    /// NOTE(review): accepted values are not validated here — confirm with
    /// the consumer.
    #[serde(default)]
    pub thinking: Option<String>,
    /// Alternative model identifiers; empty when the key is absent.
    /// NOTE(review): presumably tried in order when `model` fails — confirm.
    #[serde(default)]
    pub fallbacks: Vec<String>,
}
117
/// Relative weights of the routing score components.
///
/// When deserialized via serde, each missing key falls back to its
/// `default_*` function below; the built-in values sum to 1.0.
/// NOTE(review): how each component is scored lives outside this file.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ScoringWeights {
    /// Structural component weight (default 0.25).
    #[serde(default = "default_structural")]
    pub structural: f64,
    /// Behavioral component weight (default 0.20).
    #[serde(default = "default_behavioral")]
    pub behavioral: f64,
    /// Context-budget component weight (default 0.15).
    #[serde(default = "default_context")]
    pub context_budget: f64,
    /// Vision component weight (default 0.10).
    #[serde(default = "default_vision")]
    pub vision: f64,
    /// Message component weight (default 0.30).
    #[serde(default = "default_message")]
    pub message: f64,
}
131
// Built-in scoring weights. Together they sum to 1.0. The functions below exist
// because `#[serde(default = "...")]` needs a function path; the `Default` impl
// of `ScoringWeights` calls the same functions.
const STRUCTURAL_WEIGHT: f64 = 0.25;
const BEHAVIORAL_WEIGHT: f64 = 0.20;
const CONTEXT_WEIGHT: f64 = 0.15;
const VISION_WEIGHT: f64 = 0.10;
const MESSAGE_WEIGHT: f64 = 0.30;

/// Default weight for `ScoringWeights::structural`.
fn default_structural() -> f64 {
    STRUCTURAL_WEIGHT
}

/// Default weight for `ScoringWeights::behavioral`.
fn default_behavioral() -> f64 {
    BEHAVIORAL_WEIGHT
}

/// Default weight for `ScoringWeights::context_budget`.
fn default_context() -> f64 {
    CONTEXT_WEIGHT
}

/// Default weight for `ScoringWeights::vision`.
fn default_vision() -> f64 {
    VISION_WEIGHT
}

/// Default weight for `ScoringWeights::message`.
fn default_message() -> f64 {
    MESSAGE_WEIGHT
}
147
148impl Default for ScoringWeights {
149    fn default() -> Self {
150        Self {
151            structural: default_structural(),
152            behavioral: default_behavioral(),
153            context_budget: default_context(),
154            vision: default_vision(),
155            message: default_message(),
156        }
157    }
158}
159
160/// Load router config from global and project settings directories.
161/// Returns `None` if no router config is found in either file.
162pub fn load_router_config(global_dir: &Path, project_dir: &Path) -> Option<RouterConfig> {
163    let global_path = global_dir.join("settings.toml");
164    let project_path = project_dir.join(".oxi/settings.toml");
165
166    let global_cfg = read_toml_router(&global_path);
167    let project_cfg = read_toml_router(&project_path);
168
169    let base_enabled = global_cfg.as_ref().and_then(|c| c.enabled);
170    let override_enabled = project_cfg.as_ref().and_then(|c| c.enabled);
171
172    // If neither file has `enabled = true`, skip router.
173    if base_enabled != Some(true) && override_enabled != Some(true) {
174        return None;
175    }
176
177    let default_name = project_cfg
178        .as_ref()
179        .and_then(|c| c.default_profile.clone())
180        .or_else(|| global_cfg.as_ref().and_then(|c| c.default_profile.clone()))
181        .unwrap_or_else(|| "auto".to_string());
182
183    let mut profiles: HashMap<String, RouterProfile> = HashMap::new();
184
185    // Merge global profiles.
186    if let Some(ref g) = global_cfg {
187        if let Some(ref tbl) = g.profiles {
188            if let Some(inner) = tbl.as_table() {
189                for (name, value) in inner {
190                    if let Some(profile) = parse_profile(value) {
191                        profiles.insert(name.clone(), profile);
192                    }
193                }
194            }
195        }
196    }
197
198    // Merge/override project profiles.
199    if let Some(ref p) = project_cfg {
200        if let Some(ref tbl) = p.profiles {
201            if let Some(inner) = tbl.as_table() {
202                for (name, value) in inner {
203                    if let Some(profile) = parse_profile(value) {
204                        profiles.insert(name.clone(), profile);
205                    }
206                }
207            }
208        }
209    }
210
211    if profiles.is_empty() {
212        return None;
213    }
214
215    let weights = project_cfg
216        .as_ref()
217        .and_then(|c| c.weights.as_ref())
218        .and_then(parse_weights)
219        .or_else(|| {
220            global_cfg
221                .as_ref()
222                .and_then(|c| c.weights.as_ref())
223                .and_then(parse_weights)
224        })
225        .unwrap_or_default();
226
227    let pin_tier = project_cfg
228        .as_ref()
229        .and_then(|c| c.pin_tier.as_ref())
230        .or_else(|| global_cfg.as_ref().and_then(|c| c.pin_tier.as_ref()))
231        .and_then(|s| parse_tier_str(s));
232
233    let phase_bias = project_cfg
234        .as_ref()
235        .and_then(|c| c.phase_bias)
236        .or_else(|| global_cfg.as_ref().and_then(|c| c.phase_bias));
237
238    Some(RouterConfig {
239        default_profile: default_name,
240        classifier_model: project_cfg
241            .as_ref()
242            .and_then(|c| c.classifier_model.clone())
243            .or_else(|| global_cfg.as_ref().and_then(|c| c.classifier_model.clone())),
244        context_upgrade_threshold: project_cfg
245            .as_ref()
246            .and_then(|c| c.context_upgrade_threshold)
247            .or_else(|| {
248                global_cfg
249                    .as_ref()
250                    .and_then(|c| c.context_upgrade_threshold)
251            }),
252        max_session_budget: project_cfg
253            .as_ref()
254            .and_then(|c| c.max_session_budget)
255            .or_else(|| global_cfg.as_ref().and_then(|c| c.max_session_budget)),
256        profiles,
257        weights,
258        pin_tier,
259        phase_bias,
260    })
261}
262
263fn read_toml_router(path: &Path) -> Option<RouterConfigFile> {
264    let content = std::fs::read_to_string(path).ok()?;
265    let toml: toml::Value = toml::from_str(&content).ok()?;
266    toml.get("router")?.clone().try_into().ok()
267}
268
269fn parse_profile(value: &toml::Value) -> Option<RouterProfile> {
270    let table = value.as_table()?;
271    Some(RouterProfile {
272        high: parse_tier(table.get("high"))?,
273        medium: parse_tier(table.get("medium"))?,
274        low: parse_tier(table.get("low"))?,
275    })
276}
277
278fn parse_tier(value: Option<&toml::Value>) -> Option<RoutedTierConfig> {
279    let table = value?.as_table()?;
280    Some(RoutedTierConfig {
281        model: table.get("model")?.as_str()?.to_string(),
282        thinking: table
283            .get("thinking")
284            .and_then(|v| v.as_str().map(String::from)),
285        fallbacks: table
286            .get("fallbacks")
287            .and_then(|v| v.as_array())
288            .map(|arr| {
289                arr.iter()
290                    .filter_map(|v| v.as_str().map(String::from))
291                    .collect()
292            })
293            .unwrap_or_default(),
294    })
295}
296
297fn parse_weights(value: &toml::Value) -> Option<ScoringWeights> {
298    let table = value.as_table()?;
299    Some(ScoringWeights {
300        structural: table.get("structural")?.as_float().unwrap_or(0.25),
301        behavioral: table.get("behavioral")?.as_float().unwrap_or(0.20),
302        context_budget: table.get("context_budget")?.as_float().unwrap_or(0.15),
303        vision: table
304            .get("vision")
305            .and_then(|v| v.as_float())
306            .unwrap_or(0.10),
307        message: table
308            .get("message")
309            .and_then(|v| v.as_float())
310            .unwrap_or(0.30),
311    })
312}
313
/// Normalizes a tier name for `pin_tier`.
///
/// Matching is case-insensitive. Returns the lowercase name when it is one of
/// `high`, `medium` or `low`, and `None` for anything else.
fn parse_tier_str(s: &str) -> Option<String> {
    // Lowercase once and hand back that same allocation on success.
    let tier = s.to_lowercase();
    if matches!(tier.as_str(), "high" | "medium" | "low") {
        Some(tier)
    } else {
        None
    }
}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `toml_str` and deserializes its `[router]` table, if present.
    fn router_section(toml_str: &str) -> Option<RouterConfigFile> {
        let value: toml::Value = toml::from_str(toml_str).expect("test TOML must be valid");
        value.get("router").and_then(|v| v.clone().try_into().ok())
    }

    /// Builds a tier config with only a model set.
    fn tier(model: &str) -> RoutedTierConfig {
        RoutedTierConfig {
            model: model.to_string(),
            thinking: None,
            fallbacks: Vec::new(),
        }
    }

    #[test]
    fn parse_minimal_config() {
        let cfg = router_section(
            r#"
[router]
enabled = true
default_profile = "auto"

[router.profiles.auto]
high.model = "anthropic/claude-sonnet-4"
medium.model = "anthropic/claude-sonnet-4"
low.model = "google/gemini-2.0-flash"
"#,
        )
        .expect("[router] section should deserialize");
        // Pin the value, not just its presence: `enabled = false` must fail here.
        assert_eq!(cfg.enabled, Some(true));
        assert_eq!(cfg.default_profile.as_deref(), Some("auto"));

        let auto = cfg
            .profiles
            .as_ref()
            .and_then(|p| p.get("auto"))
            .and_then(parse_profile)
            .expect("auto profile should parse");
        assert_eq!(auto.low.model, "google/gemini-2.0-flash");
        assert!(auto.high.thinking.is_none());
        assert!(auto.high.fallbacks.is_empty());
    }

    #[test]
    fn missing_router_section_yields_none() {
        assert!(router_section("[other]\nvalue = 1\n").is_none());
    }

    #[test]
    fn parse_tier_str_normalizes_case() {
        assert_eq!(parse_tier_str("HIGH").as_deref(), Some("high"));
        assert_eq!(parse_tier_str("Medium").as_deref(), Some("medium"));
        assert_eq!(parse_tier_str("low").as_deref(), Some("low"));
        assert_eq!(parse_tier_str("ultra"), None);
    }

    #[test]
    fn parse_weights_reads_complete_table() {
        let value: toml::Value = toml::from_str(
            "structural = 0.4\nbehavioral = 0.1\ncontext_budget = 0.1\nvision = 0.2\nmessage = 0.2\n",
        )
        .unwrap();
        let weights = parse_weights(&value).expect("a table should parse");
        assert_eq!(weights.structural, 0.4);
        assert_eq!(weights.behavioral, 0.1);
        assert_eq!(weights.context_budget, 0.1);
        assert_eq!(weights.vision, 0.2);
        assert_eq!(weights.message, 0.2);
    }

    #[test]
    fn tier_config_matches_known_tier_names() {
        let profile = RouterProfile {
            high: tier("h"),
            medium: tier("m"),
            low: tier("l"),
        };
        assert_eq!(profile.tier_config("high").map(|t| t.model.as_str()), Some("h"));
        assert_eq!(profile.tier_config("medium").map(|t| t.model.as_str()), Some("m"));
        assert_eq!(profile.tier_config("low").map(|t| t.model.as_str()), Some("l"));
        assert!(profile.tier_config("HIGH").is_none());
    }

    #[test]
    fn load_router_config_returns_none_without_enabled_flag() {
        // TempDir values delete their directory when dropped.
        let global_dir = tempfile::tempdir().unwrap();
        let project_dir = tempfile::tempdir().unwrap();
        std::fs::write(
            global_dir.path().join("settings.toml"),
            r#"
[router.profiles.auto]
high.model = "a"
medium.model = "b"
low.model = "c"
"#,
        )
        .unwrap();
        assert!(load_router_config(global_dir.path(), project_dir.path()).is_none());
    }

    #[test]
    fn load_router_config_merges_profiles() {
        // TempDir values delete their directory when dropped.
        let global_dir = tempfile::tempdir().unwrap();
        let project_dir = tempfile::tempdir().unwrap();
        let oxi_dir = project_dir.path().join(".oxi");
        std::fs::create_dir_all(&oxi_dir).unwrap();

        std::fs::write(
            global_dir.path().join("settings.toml"),
            r#"
[router]
enabled = true
default_profile = "auto"

[router.profiles.auto]
high.model = "anthropic/claude-sonnet-4"
medium.model = "anthropic/claude-haiku-4"
low.model = "google/gemini-2.0-flash"
"#,
        )
        .unwrap();
        std::fs::write(
            oxi_dir.join("settings.toml"),
            r#"
[router]
enabled = true
"#,
        )
        .unwrap();

        let config = load_router_config(global_dir.path(), project_dir.path())
            .expect("router should be enabled");
        assert_eq!(config.default_profile(), "auto");
        assert!(config.profiles().contains_key("auto"));
        assert_eq!(config.enabled(), Some(true));
    }

    #[test]
    fn project_profile_overrides_global_profile() {
        let global_dir = tempfile::tempdir().unwrap();
        let project_dir = tempfile::tempdir().unwrap();
        let oxi_dir = project_dir.path().join(".oxi");
        std::fs::create_dir_all(&oxi_dir).unwrap();

        std::fs::write(
            global_dir.path().join("settings.toml"),
            r#"
[router]
enabled = true

[router.profiles.auto]
high.model = "global-high"
medium.model = "global-medium"
low.model = "global-low"

[router.profiles.cheap]
high.model = "g"
medium.model = "g"
low.model = "g"
"#,
        )
        .unwrap();
        std::fs::write(
            oxi_dir.join("settings.toml"),
            r#"
[router.profiles.auto]
high.model = "project-high"
medium.model = "project-medium"
low.model = "project-low"
"#,
        )
        .unwrap();

        let config = load_router_config(global_dir.path(), project_dir.path())
            .expect("router is enabled globally");
        assert_eq!(
            config.get_profile("auto").map(|p| p.high.model.as_str()),
            Some("project-high")
        );
        assert!(config.get_profile("cheap").is_some());
    }
}