Skip to main content

agentzero_core/
routing.rs

1use regex::Regex;
2use serde::{Deserialize, Serialize};
3use tracing::debug;
4
5/// A model route entry mapping a hint to a specific provider+model.
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct ModelRoute {
8    pub hint: String,
9    pub provider: String,
10    pub model: String,
11    pub max_tokens: Option<usize>,
12    pub api_key: Option<String>,
13    pub transport: Option<String>,
14}
15
16/// An embedding route entry mapping a hint to an embedding provider+model.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct EmbeddingRoute {
19    pub hint: String,
20    pub provider: String,
21    pub model: String,
22    pub dimensions: Option<usize>,
23    pub api_key: Option<String>,
24}
25
26/// A query classification rule for automatic model routing.
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct ClassificationRule {
29    pub hint: String,
30    #[serde(default)]
31    pub keywords: Vec<String>,
32    #[serde(default)]
33    pub patterns: Vec<String>,
34    pub min_length: Option<usize>,
35    pub max_length: Option<usize>,
36    #[serde(default)]
37    pub priority: i32,
38}
39
/// Resolved route for a model request.
///
/// `Debug` is implemented by hand so that `api_key` is never printed; it shows
/// only whether a key is present.
#[derive(Clone)]
pub struct ResolvedRoute {
    /// Provider identifier of the selected route.
    pub provider: String,
    /// Model identifier of the selected route.
    pub model: String,
    /// Max-token setting of the selected route, if any.
    pub max_tokens: Option<usize>,
    /// API key of the selected route, if any. Redacted in `Debug` output.
    pub api_key: Option<String>,
    /// Transport setting of the selected route, if any.
    pub transport: Option<String>,
    /// The route's own `hint` string (as configured, not as requested).
    pub matched_hint: String,
}

impl std::fmt::Debug for ResolvedRoute {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ResolvedRoute")
            .field("provider", &self.provider)
            .field("model", &self.model)
            .field("max_tokens", &self.max_tokens)
            // Keep Some/None visible for diagnostics, but never the secret itself.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("transport", &self.transport)
            .field("matched_hint", &self.matched_hint)
            .finish()
    }
}
50
/// Resolved route for an embedding request.
///
/// `Debug` is implemented by hand so that `api_key` is never printed; it shows
/// only whether a key is present.
#[derive(Clone)]
pub struct ResolvedEmbeddingRoute {
    /// Provider identifier of the selected embedding route.
    pub provider: String,
    /// Model identifier of the selected embedding route.
    pub model: String,
    /// Embedding dimensionality of the selected route, if configured.
    pub dimensions: Option<usize>,
    /// API key of the selected route, if any. Redacted in `Debug` output.
    pub api_key: Option<String>,
    /// The route's own `hint` string (as configured, not as requested).
    pub matched_hint: String,
}

impl std::fmt::Debug for ResolvedEmbeddingRoute {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ResolvedEmbeddingRoute")
            .field("provider", &self.provider)
            .field("model", &self.model)
            .field("dimensions", &self.dimensions)
            // Keep Some/None visible for diagnostics, but never the secret itself.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("matched_hint", &self.matched_hint)
            .finish()
    }
}
60
61/// The model router resolves hints and classifies queries to select routes.
62#[derive(Debug, Clone, Default)]
63pub struct ModelRouter {
64    pub model_routes: Vec<ModelRoute>,
65    pub embedding_routes: Vec<EmbeddingRoute>,
66    pub classification_rules: Vec<ClassificationRule>,
67    pub classification_enabled: bool,
68}
69
70impl ModelRouter {
71    /// Resolve a model route by explicit hint name.
72    pub fn resolve_hint(&self, hint: &str) -> Option<ResolvedRoute> {
73        self.model_routes
74            .iter()
75            .find(|r| r.hint.eq_ignore_ascii_case(hint))
76            .map(|r| ResolvedRoute {
77                provider: r.provider.clone(),
78                model: r.model.clone(),
79                max_tokens: r.max_tokens,
80                api_key: r.api_key.clone(),
81                transport: r.transport.clone(),
82                matched_hint: r.hint.clone(),
83            })
84    }
85
86    /// Resolve an embedding route by explicit hint name.
87    pub fn resolve_embedding_hint(&self, hint: &str) -> Option<ResolvedEmbeddingRoute> {
88        self.embedding_routes
89            .iter()
90            .find(|r| r.hint.eq_ignore_ascii_case(hint))
91            .map(|r| ResolvedEmbeddingRoute {
92                provider: r.provider.clone(),
93                model: r.model.clone(),
94                dimensions: r.dimensions,
95                api_key: r.api_key.clone(),
96                matched_hint: r.hint.clone(),
97            })
98    }
99
100    /// Classify a query and return the best matching hint.
101    pub fn classify_query(&self, query: &str) -> Option<String> {
102        if !self.classification_enabled || self.classification_rules.is_empty() {
103            return None;
104        }
105
106        let query_lower = query.to_lowercase();
107        let query_len = query.len();
108
109        let mut best_hint: Option<&str> = None;
110        let mut best_priority = i32::MIN;
111
112        for rule in &self.classification_rules {
113            // Length filters.
114            if let Some(min) = rule.min_length {
115                if query_len < min {
116                    continue;
117                }
118            }
119            if let Some(max) = rule.max_length {
120                if query_len > max {
121                    continue;
122                }
123            }
124
125            // Keyword match (any keyword present).
126            let keyword_match = rule.keywords.is_empty()
127                || rule
128                    .keywords
129                    .iter()
130                    .any(|kw| query_lower.contains(&kw.to_lowercase()));
131
132            // Pattern match (any regex matches).
133            let pattern_match = rule.patterns.is_empty()
134                || rule
135                    .patterns
136                    .iter()
137                    .any(|p| Regex::new(p).map(|re| re.is_match(query)).unwrap_or(false));
138
139            if keyword_match && pattern_match && rule.priority > best_priority {
140                best_priority = rule.priority;
141                best_hint = Some(&rule.hint);
142            }
143        }
144
145        if let Some(hint) = best_hint {
146            debug!(hint, "query classified");
147        }
148
149        best_hint.map(String::from)
150    }
151
152    /// Classify a query and resolve to a model route in one step.
153    pub fn route_query(&self, query: &str) -> Option<ResolvedRoute> {
154        self.classify_query(query)
155            .and_then(|hint| self.resolve_hint(&hint))
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    fn router() -> ModelRouter {
        ModelRouter {
            model_routes: vec![
                ModelRoute {
                    hint: "reasoning".into(),
                    provider: "openrouter".into(),
                    model: "anthropic/claude-opus-4-6".into(),
                    max_tokens: Some(8192),
                    api_key: None,
                    transport: None,
                },
                ModelRoute {
                    hint: "fast".into(),
                    provider: "openrouter".into(),
                    model: "anthropic/claude-haiku-4-5".into(),
                    max_tokens: None,
                    api_key: None,
                    transport: None,
                },
                ModelRoute {
                    hint: "code".into(),
                    provider: "openrouter".into(),
                    model: "anthropic/claude-sonnet-4-6".into(),
                    max_tokens: Some(16384),
                    api_key: None,
                    transport: None,
                },
            ],
            embedding_routes: vec![EmbeddingRoute {
                hint: "default".into(),
                provider: "openai".into(),
                model: "text-embedding-3-small".into(),
                dimensions: Some(1536),
                api_key: None,
            }],
            classification_rules: vec![
                ClassificationRule {
                    hint: "reasoning".into(),
                    keywords: vec!["explain".into(), "why".into(), "analyze".into()],
                    patterns: Vec::new(),
                    min_length: Some(50),
                    max_length: None,
                    priority: 10,
                },
                ClassificationRule {
                    hint: "fast".into(),
                    keywords: vec![],
                    patterns: Vec::new(),
                    min_length: None,
                    max_length: Some(20),
                    priority: 5,
                },
                ClassificationRule {
                    hint: "code".into(),
                    keywords: vec![
                        "implement".into(),
                        "function".into(),
                        "code".into(),
                        "fix".into(),
                    ],
                    patterns: Vec::new(),
                    min_length: None,
                    max_length: None,
                    priority: 8,
                },
            ],
            classification_enabled: true,
        }
    }

    /// Build an unbounded rule (no length limits) for ad-hoc test scenarios.
    fn rule(hint: &str, keywords: &[&str], patterns: &[&str], priority: i32) -> ClassificationRule {
        ClassificationRule {
            hint: hint.to_string(),
            keywords: keywords.iter().map(|k| k.to_string()).collect(),
            patterns: patterns.iter().map(|p| p.to_string()).collect(),
            min_length: None,
            max_length: None,
            priority,
        }
    }

    #[test]
    fn resolve_hint_finds_matching_route() {
        let r = router();
        let route = r.resolve_hint("fast").unwrap();
        assert_eq!(route.model, "anthropic/claude-haiku-4-5");
    }

    #[test]
    fn resolve_hint_returns_none_for_unknown() {
        let r = router();
        assert!(r.resolve_hint("nonexistent").is_none());
    }

    #[test]
    fn resolve_hint_is_case_insensitive() {
        let r = router();
        let route = r.resolve_hint("Reasoning").unwrap();
        assert_eq!(route.matched_hint, "reasoning");
        assert_eq!(route.max_tokens, Some(8192));
    }

    #[test]
    fn resolve_embedding_hint() {
        let r = router();
        let route = r.resolve_embedding_hint("default").unwrap();
        assert_eq!(route.model, "text-embedding-3-small");
        assert_eq!(route.dimensions, Some(1536));
    }

    #[test]
    fn resolve_embedding_hint_returns_none_for_unknown() {
        assert!(router().resolve_embedding_hint("missing").is_none());
    }

    #[test]
    fn classify_query_by_keywords() {
        let r = router();
        let hint = r
            .classify_query("please implement a function to parse JSON")
            .unwrap();
        assert_eq!(hint, "code");
    }

    #[test]
    fn classify_query_keywords_are_case_insensitive() {
        let r = router();
        assert_eq!(
            r.classify_query("Please IMPLEMENT this parser now").as_deref(),
            Some("code")
        );
    }

    #[test]
    fn classify_query_by_length() {
        let r = router();
        let hint = r.classify_query("hello").unwrap();
        assert_eq!(hint, "fast");
    }

    #[test]
    fn classify_query_priority_wins() {
        let r = router();
        // Long query with "explain" and "implement" → reasoning wins (priority 10 > code 8).
        let hint = r
            .classify_query(
                "explain why this function fails and implement a fix for the memory leak issue",
            )
            .unwrap();
        assert_eq!(hint, "reasoning");
    }

    #[test]
    fn classify_query_ties_keep_first_rule() {
        let mut r = router();
        // Same priority as "code" but listed after it, so "code" must still win.
        r.classification_rules.push(rule("fast", &["implement"], &[], 8));
        assert_eq!(
            r.classify_query("implement a function to sort arrays").as_deref(),
            Some("code")
        );
    }

    #[test]
    fn classify_query_by_pattern() {
        let mut r = router();
        // Without this rule the query matches nothing (too long for "fast",
        // too short for "reasoning", no "code" keywords).
        r.classification_rules
            .push(rule("reasoning", &[], &[r"\bfn\s+\w+\("], 20));
        assert_eq!(
            r.classify_query("what is wrong with fn main() here").as_deref(),
            Some("reasoning")
        );
    }

    #[test]
    fn classify_query_ignores_invalid_pattern() {
        let mut r = router();
        // "(" does not compile; the high-priority rule must simply never match.
        r.classification_rules.push(rule("reasoning", &[], &["("], 100));
        assert_eq!(r.classify_query("hello").as_deref(), Some("fast"));
    }

    #[test]
    fn classify_query_returns_none_when_nothing_matches() {
        let r = router();
        assert!(r
            .classify_query("tell me a long story about dragons")
            .is_none());
    }

    #[test]
    fn classify_query_none_without_rules() {
        let mut r = router();
        r.classification_rules.clear();
        assert!(r.classify_query("hello").is_none());
    }

    #[test]
    fn classify_disabled_returns_none() {
        let mut r = router();
        r.classification_enabled = false;
        assert!(r.classify_query("explain this").is_none());
    }

    #[test]
    fn route_query_resolves_end_to_end() {
        let r = router();
        let route = r
            .route_query("implement a function to sort arrays")
            .unwrap();
        assert_eq!(route.model, "anthropic/claude-sonnet-4-6");
        assert_eq!(route.matched_hint, "code");
    }

    #[test]
    fn route_query_none_when_hint_has_no_route() {
        let mut r = router();
        r.classification_rules.push(rule("unrouted", &["zebra"], &[], 50));
        assert_eq!(
            r.classify_query("tell me about the zebra").as_deref(),
            Some("unrouted")
        );
        assert!(r.route_query("tell me about the zebra").is_none());
    }
}