Skip to main content

agentzero_core/
routing.rs

1use regex::Regex;
2use serde::{Deserialize, Serialize};
3use tracing::debug;
4
5/// Privacy level for model routes — controls which routes are eligible
6/// based on the active privacy mode.
7#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
8#[serde(rename_all = "lowercase")]
9pub enum PrivacyLevel {
10    /// Only eligible when running locally (local_only mode).
11    Local,
12    /// Only eligible when cloud access is allowed.
13    Cloud,
14    /// Eligible in any privacy mode (default).
15    #[default]
16    Either,
17}
18
19/// A model route entry mapping a hint to a specific provider+model.
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct ModelRoute {
22    pub hint: String,
23    pub provider: String,
24    pub model: String,
25    pub max_tokens: Option<usize>,
26    pub api_key: Option<String>,
27    pub transport: Option<String>,
28    /// Privacy level controlling route eligibility by privacy mode.
29    #[serde(default)]
30    pub privacy_level: PrivacyLevel,
31}
32
33/// An embedding route entry mapping a hint to an embedding provider+model.
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct EmbeddingRoute {
36    pub hint: String,
37    pub provider: String,
38    pub model: String,
39    pub dimensions: Option<usize>,
40    pub api_key: Option<String>,
41}
42
43/// A query classification rule for automatic model routing.
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct ClassificationRule {
46    pub hint: String,
47    #[serde(default)]
48    pub keywords: Vec<String>,
49    #[serde(default)]
50    pub patterns: Vec<String>,
51    pub min_length: Option<usize>,
52    pub max_length: Option<usize>,
53    #[serde(default)]
54    pub priority: i32,
55}
56
/// Resolved route for a model request.
///
/// `Debug` is implemented by hand so that `api_key` is shown as
/// `<redacted>` instead of leaking the secret into logs.
#[derive(Clone)]
pub struct ResolvedRoute {
    /// Provider identifier of the selected route.
    pub provider: String,
    /// Model identifier of the selected route.
    pub model: String,
    /// `max_tokens` value configured on the selected route.
    pub max_tokens: Option<usize>,
    /// API key configured on the selected route. Secret: redacted in `Debug`.
    pub api_key: Option<String>,
    /// Transport selector configured on the selected route.
    pub transport: Option<String>,
    /// Hint of the route that was selected, as spelled in its configuration.
    pub matched_hint: String,
}

impl std::fmt::Debug for ResolvedRoute {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ResolvedRoute")
            .field("provider", &self.provider)
            .field("model", &self.model)
            .field("max_tokens", &self.max_tokens)
            // Never print the key itself; only whether one is configured.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("transport", &self.transport)
            .field("matched_hint", &self.matched_hint)
            .finish()
    }
}
67
/// Resolved route for an embedding request.
///
/// `Debug` is implemented by hand so that `api_key` is shown as
/// `<redacted>` instead of leaking the secret into logs.
#[derive(Clone)]
pub struct ResolvedEmbeddingRoute {
    /// Embedding provider identifier of the selected route.
    pub provider: String,
    /// Embedding model identifier of the selected route.
    pub model: String,
    /// Embedding dimensionality configured on the selected route.
    pub dimensions: Option<usize>,
    /// API key configured on the selected route. Secret: redacted in `Debug`.
    pub api_key: Option<String>,
    /// Hint of the route that was selected, as spelled in its configuration.
    pub matched_hint: String,
}

impl std::fmt::Debug for ResolvedEmbeddingRoute {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ResolvedEmbeddingRoute")
            .field("provider", &self.provider)
            .field("model", &self.model)
            .field("dimensions", &self.dimensions)
            // Never print the key itself; only whether one is configured.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("matched_hint", &self.matched_hint)
            .finish()
    }
}
77
78/// The model router resolves hints and classifies queries to select routes.
79#[derive(Debug, Clone, Default)]
80pub struct ModelRouter {
81    pub model_routes: Vec<ModelRoute>,
82    pub embedding_routes: Vec<EmbeddingRoute>,
83    pub classification_rules: Vec<ClassificationRule>,
84    pub classification_enabled: bool,
85}
86
87impl ModelRouter {
88    /// Resolve a model route by explicit hint name.
89    pub fn resolve_hint(&self, hint: &str) -> Option<ResolvedRoute> {
90        self.model_routes
91            .iter()
92            .find(|r| r.hint.eq_ignore_ascii_case(hint))
93            .map(|r| ResolvedRoute {
94                provider: r.provider.clone(),
95                model: r.model.clone(),
96                max_tokens: r.max_tokens,
97                api_key: r.api_key.clone(),
98                transport: r.transport.clone(),
99                matched_hint: r.hint.clone(),
100            })
101    }
102
103    /// Resolve an embedding route by explicit hint name.
104    pub fn resolve_embedding_hint(&self, hint: &str) -> Option<ResolvedEmbeddingRoute> {
105        self.embedding_routes
106            .iter()
107            .find(|r| r.hint.eq_ignore_ascii_case(hint))
108            .map(|r| ResolvedEmbeddingRoute {
109                provider: r.provider.clone(),
110                model: r.model.clone(),
111                dimensions: r.dimensions,
112                api_key: r.api_key.clone(),
113                matched_hint: r.hint.clone(),
114            })
115    }
116
117    /// Classify a query and return the best matching hint.
118    pub fn classify_query(&self, query: &str) -> Option<String> {
119        if !self.classification_enabled || self.classification_rules.is_empty() {
120            return None;
121        }
122
123        let query_lower = query.to_lowercase();
124        let query_len = query.len();
125
126        let mut best_hint: Option<&str> = None;
127        let mut best_priority = i32::MIN;
128
129        for rule in &self.classification_rules {
130            // Length filters.
131            if let Some(min) = rule.min_length {
132                if query_len < min {
133                    continue;
134                }
135            }
136            if let Some(max) = rule.max_length {
137                if query_len > max {
138                    continue;
139                }
140            }
141
142            // Keyword match (any keyword present).
143            let keyword_match = rule.keywords.is_empty()
144                || rule
145                    .keywords
146                    .iter()
147                    .any(|kw| query_lower.contains(&kw.to_lowercase()));
148
149            // Pattern match (any regex matches).
150            let pattern_match = rule.patterns.is_empty()
151                || rule
152                    .patterns
153                    .iter()
154                    .any(|p| Regex::new(p).map(|re| re.is_match(query)).unwrap_or(false));
155
156            if keyword_match && pattern_match && rule.priority > best_priority {
157                best_priority = rule.priority;
158                best_hint = Some(&rule.hint);
159            }
160        }
161
162        if let Some(hint) = best_hint {
163            debug!(hint, "query classified");
164        }
165
166        best_hint.map(String::from)
167    }
168
169    /// Classify a query and resolve to a model route in one step.
170    pub fn route_query(&self, query: &str) -> Option<ResolvedRoute> {
171        self.classify_query(query)
172            .and_then(|hint| self.resolve_hint(&hint))
173    }
174
175    /// Resolve a hint with privacy filtering.
176    ///
177    /// - `"local_only"`: only `Local` routes
178    /// - `"private"`: prefer `Local`, fall through to `Cloud`
179    /// - `"off"` / other: all routes (current behavior)
180    pub fn resolve_hint_with_privacy(
181        &self,
182        hint: &str,
183        privacy_mode: &str,
184    ) -> Option<ResolvedRoute> {
185        let candidates: Vec<&ModelRoute> = self
186            .model_routes
187            .iter()
188            .filter(|r| r.hint.eq_ignore_ascii_case(hint))
189            .collect();
190
191        match privacy_mode {
192            "local_only" => candidates
193                .iter()
194                .find(|r| r.privacy_level == PrivacyLevel::Local)
195                .map(|r| self.route_to_resolved(r)),
196            "private" => {
197                // Prefer local, fall through to cloud/either.
198                candidates
199                    .iter()
200                    .find(|r| r.privacy_level == PrivacyLevel::Local)
201                    .or_else(|| {
202                        candidates
203                            .iter()
204                            .find(|r| r.privacy_level != PrivacyLevel::Local)
205                    })
206                    .map(|r| self.route_to_resolved(r))
207            }
208            _ => candidates.first().map(|r| self.route_to_resolved(r)),
209        }
210    }
211
212    /// Classify a query and resolve with privacy filtering.
213    pub fn route_query_with_privacy(
214        &self,
215        query: &str,
216        privacy_mode: &str,
217    ) -> Option<ResolvedRoute> {
218        self.classify_query(query)
219            .and_then(|hint| self.resolve_hint_with_privacy(&hint, privacy_mode))
220    }
221
222    fn route_to_resolved(&self, r: &ModelRoute) -> ResolvedRoute {
223        ResolvedRoute {
224            provider: r.provider.clone(),
225            model: r.model.clone(),
226            max_tokens: r.max_tokens,
227            api_key: r.api_key.clone(),
228            transport: r.transport.clone(),
229            matched_hint: r.hint.clone(),
230        }
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    fn router() -> ModelRouter {
        ModelRouter {
            model_routes: vec![
                ModelRoute {
                    hint: "reasoning".into(),
                    provider: "openrouter".into(),
                    model: "anthropic/claude-opus-4-6".into(),
                    max_tokens: Some(8192),
                    api_key: None,
                    transport: None,
                    privacy_level: PrivacyLevel::Either,
                },
                ModelRoute {
                    hint: "fast".into(),
                    provider: "openrouter".into(),
                    model: "anthropic/claude-haiku-4-5".into(),
                    max_tokens: None,
                    api_key: None,
                    transport: None,
                    privacy_level: PrivacyLevel::Either,
                },
                ModelRoute {
                    hint: "code".into(),
                    provider: "openrouter".into(),
                    model: "anthropic/claude-sonnet-4-6".into(),
                    max_tokens: Some(16384),
                    api_key: None,
                    transport: None,
                    privacy_level: PrivacyLevel::Either,
                },
            ],
            embedding_routes: vec![EmbeddingRoute {
                hint: "default".into(),
                provider: "openai".into(),
                model: "text-embedding-3-small".into(),
                dimensions: Some(1536),
                api_key: None,
            }],
            classification_rules: vec![
                ClassificationRule {
                    hint: "reasoning".into(),
                    keywords: vec!["explain".into(), "why".into(), "analyze".into()],
                    patterns: Vec::new(),
                    min_length: Some(50),
                    max_length: None,
                    priority: 10,
                },
                ClassificationRule {
                    hint: "fast".into(),
                    keywords: vec![],
                    patterns: Vec::new(),
                    min_length: None,
                    max_length: Some(20),
                    priority: 5,
                },
                ClassificationRule {
                    hint: "code".into(),
                    keywords: vec![
                        "implement".into(),
                        "function".into(),
                        "code".into(),
                        "fix".into(),
                    ],
                    patterns: Vec::new(),
                    min_length: None,
                    max_length: None,
                    priority: 8,
                },
            ],
            classification_enabled: true,
        }
    }

    #[test]
    fn resolve_hint_finds_matching_route() {
        let r = router();
        let route = r.resolve_hint("fast").unwrap();
        assert_eq!(route.model, "anthropic/claude-haiku-4-5");
    }

    #[test]
    fn resolve_hint_is_ascii_case_insensitive() {
        let r = router();
        let route = r.resolve_hint("FAST").unwrap();
        assert_eq!(route.matched_hint, "fast");
    }

    #[test]
    fn resolve_hint_returns_none_for_unknown() {
        let r = router();
        assert!(r.resolve_hint("nonexistent").is_none());
    }

    #[test]
    fn resolve_embedding_hint() {
        let r = router();
        let route = r.resolve_embedding_hint("default").unwrap();
        assert_eq!(route.model, "text-embedding-3-small");
        assert_eq!(route.dimensions, Some(1536));
    }

    #[test]
    fn classify_query_by_keywords() {
        let r = router();
        let hint = r
            .classify_query("please implement a function to parse JSON")
            .unwrap();
        assert_eq!(hint, "code");
    }

    #[test]
    fn classify_query_by_length() {
        let r = router();
        let hint = r.classify_query("hello").unwrap();
        assert_eq!(hint, "fast");
    }

    #[test]
    fn classify_query_priority_wins() {
        let r = router();
        // Long query with "explain" and "implement" → reasoning wins (priority 10 > code 8).
        let hint = r
            .classify_query(
                "explain why this function fails and implement a fix for the memory leak issue",
            )
            .unwrap();
        assert_eq!(hint, "reasoning");
    }

    #[test]
    fn classify_disabled_returns_none() {
        let mut r = router();
        r.classification_enabled = false;
        assert!(r.classify_query("explain this").is_none());
    }

    #[test]
    fn route_query_resolves_end_to_end() {
        let r = router();
        let route = r
            .route_query("implement a function to sort arrays")
            .unwrap();
        assert_eq!(route.model, "anthropic/claude-sonnet-4-6");
        assert_eq!(route.matched_hint, "code");
    }

    // --- Privacy-aware routing tests ---

    fn privacy_router() -> ModelRouter {
        ModelRouter {
            model_routes: vec![
                ModelRoute {
                    hint: "fast".into(),
                    provider: "ollama".into(),
                    model: "llama3.2".into(),
                    max_tokens: None,
                    api_key: None,
                    transport: None,
                    privacy_level: PrivacyLevel::Local,
                },
                ModelRoute {
                    hint: "fast".into(),
                    provider: "anthropic".into(),
                    model: "claude-haiku-4-5".into(),
                    max_tokens: None,
                    api_key: None,
                    transport: None,
                    privacy_level: PrivacyLevel::Cloud,
                },
                ModelRoute {
                    hint: "reasoning".into(),
                    provider: "openai".into(),
                    model: "o1".into(),
                    max_tokens: Some(8192),
                    api_key: None,
                    transport: None,
                    privacy_level: PrivacyLevel::Either,
                },
            ],
            embedding_routes: vec![],
            classification_rules: vec![],
            classification_enabled: false,
        }
    }

    #[test]
    fn private_mode_prefers_local_route() {
        let r = privacy_router();
        let route = r
            .resolve_hint_with_privacy("fast", "private")
            .expect("should resolve");
        assert_eq!(route.provider, "ollama", "private should prefer local");
    }

    #[test]
    fn private_mode_falls_through_to_cloud() {
        let r = privacy_router();
        // "reasoning" has no Local route, only Either — should still resolve.
        let route = r
            .resolve_hint_with_privacy("reasoning", "private")
            .expect("should fall through");
        assert_eq!(route.provider, "openai");
    }

    #[test]
    fn local_only_blocks_cloud_routes() {
        let r = privacy_router();
        let route = r.resolve_hint_with_privacy("fast", "local_only");
        assert_eq!(
            route.as_ref().map(|r| r.provider.as_str()),
            Some("ollama"),
            "local_only should only return local routes"
        );
    }

    #[test]
    fn local_only_rejects_either_routes() {
        let r = privacy_router();
        // "reasoning" is only available as an `Either` route, which is not
        // explicitly local and must therefore be excluded.
        assert!(
            r.resolve_hint_with_privacy("reasoning", "local_only")
                .is_none(),
            "local_only should not resolve Either routes"
        );
    }

    #[test]
    fn local_only_returns_none_for_cloud_only() {
        let r = ModelRouter {
            model_routes: vec![ModelRoute {
                hint: "cloud-only".into(),
                provider: "anthropic".into(),
                model: "claude-sonnet-4-6".into(),
                max_tokens: None,
                api_key: None,
                transport: None,
                privacy_level: PrivacyLevel::Cloud,
            }],
            ..Default::default()
        };
        assert!(
            r.resolve_hint_with_privacy("cloud-only", "local_only")
                .is_none(),
            "local_only should not resolve cloud-only routes"
        );
    }

    #[test]
    fn off_mode_allows_all_routes() {
        let r = privacy_router();
        let route = r
            .resolve_hint_with_privacy("fast", "off")
            .expect("should resolve");
        // "off" mode returns the first matching route (local in this case).
        assert_eq!(route.provider, "ollama");
    }

    #[test]
    fn privacy_level_defaults_to_either() {
        assert_eq!(PrivacyLevel::default(), PrivacyLevel::Either);

        // The shared test router uses `Either` for every route.
        let r = router();
        // "off" mode accepts all privacy levels including Either.
        let route = r
            .resolve_hint_with_privacy("fast", "off")
            .expect("Either routes should be available in off mode");
        assert_eq!(route.provider, "openrouter");
        // "private" also accepts Either (falls through from Local).
        let route2 = r
            .resolve_hint_with_privacy("fast", "private")
            .expect("Either routes should be available in private mode");
        assert_eq!(route2.provider, "openrouter");
    }
}
491}