1use regex::Regex;
2use serde::{Deserialize, Serialize};
3use tracing::debug;
4
5#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
8#[serde(rename_all = "lowercase")]
9pub enum PrivacyLevel {
10 Local,
12 Cloud,
14 #[default]
16 Either,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct ModelRoute {
22 pub hint: String,
23 pub provider: String,
24 pub model: String,
25 pub max_tokens: Option<usize>,
26 pub api_key: Option<String>,
27 pub transport: Option<String>,
28 #[serde(default)]
30 pub privacy_level: PrivacyLevel,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct EmbeddingRoute {
36 pub hint: String,
37 pub provider: String,
38 pub model: String,
39 pub dimensions: Option<usize>,
40 pub api_key: Option<String>,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct ClassificationRule {
46 pub hint: String,
47 #[serde(default)]
48 pub keywords: Vec<String>,
49 #[serde(default)]
50 pub patterns: Vec<String>,
51 pub min_length: Option<usize>,
52 pub max_length: Option<usize>,
53 #[serde(default)]
54 pub priority: i32,
55}
56
/// A model route selected by the router, detached from configuration.
///
/// `Debug` is implemented by hand (instead of derived) so that `api_key`
/// is never written to logs; only its presence is shown.
#[derive(Clone)]
pub struct ResolvedRoute {
    pub provider: String,
    pub model: String,
    pub max_tokens: Option<usize>,
    /// Secret credential; redacted in `Debug` output.
    pub api_key: Option<String>,
    pub transport: Option<String>,
    /// The configured hint of the route that was selected.
    pub matched_hint: String,
}

impl std::fmt::Debug for ResolvedRoute {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ResolvedRoute")
            .field("provider", &self.provider)
            .field("model", &self.model)
            .field("max_tokens", &self.max_tokens)
            // Never print the secret itself, only whether one is configured.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("transport", &self.transport)
            .field("matched_hint", &self.matched_hint)
            .finish()
    }
}
67
/// An embedding route selected by the router, detached from configuration.
///
/// `Debug` is implemented by hand (instead of derived) so that `api_key`
/// is never written to logs; only its presence is shown.
#[derive(Clone)]
pub struct ResolvedEmbeddingRoute {
    pub provider: String,
    pub model: String,
    pub dimensions: Option<usize>,
    /// Secret credential; redacted in `Debug` output.
    pub api_key: Option<String>,
    /// The configured hint of the route that was selected.
    pub matched_hint: String,
}

impl std::fmt::Debug for ResolvedEmbeddingRoute {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ResolvedEmbeddingRoute")
            .field("provider", &self.provider)
            .field("model", &self.model)
            .field("dimensions", &self.dimensions)
            // Never print the secret itself, only whether one is configured.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("matched_hint", &self.matched_hint)
            .finish()
    }
}
77
78#[derive(Debug, Clone, Default)]
80pub struct ModelRouter {
81 pub model_routes: Vec<ModelRoute>,
82 pub embedding_routes: Vec<EmbeddingRoute>,
83 pub classification_rules: Vec<ClassificationRule>,
84 pub classification_enabled: bool,
85}
86
87impl ModelRouter {
88 pub fn resolve_hint(&self, hint: &str) -> Option<ResolvedRoute> {
90 self.model_routes
91 .iter()
92 .find(|r| r.hint.eq_ignore_ascii_case(hint))
93 .map(|r| ResolvedRoute {
94 provider: r.provider.clone(),
95 model: r.model.clone(),
96 max_tokens: r.max_tokens,
97 api_key: r.api_key.clone(),
98 transport: r.transport.clone(),
99 matched_hint: r.hint.clone(),
100 })
101 }
102
103 pub fn resolve_embedding_hint(&self, hint: &str) -> Option<ResolvedEmbeddingRoute> {
105 self.embedding_routes
106 .iter()
107 .find(|r| r.hint.eq_ignore_ascii_case(hint))
108 .map(|r| ResolvedEmbeddingRoute {
109 provider: r.provider.clone(),
110 model: r.model.clone(),
111 dimensions: r.dimensions,
112 api_key: r.api_key.clone(),
113 matched_hint: r.hint.clone(),
114 })
115 }
116
117 pub fn classify_query(&self, query: &str) -> Option<String> {
119 if !self.classification_enabled || self.classification_rules.is_empty() {
120 return None;
121 }
122
123 let query_lower = query.to_lowercase();
124 let query_len = query.len();
125
126 let mut best_hint: Option<&str> = None;
127 let mut best_priority = i32::MIN;
128
129 for rule in &self.classification_rules {
130 if let Some(min) = rule.min_length {
132 if query_len < min {
133 continue;
134 }
135 }
136 if let Some(max) = rule.max_length {
137 if query_len > max {
138 continue;
139 }
140 }
141
142 let keyword_match = rule.keywords.is_empty()
144 || rule
145 .keywords
146 .iter()
147 .any(|kw| query_lower.contains(&kw.to_lowercase()));
148
149 let pattern_match = rule.patterns.is_empty()
151 || rule
152 .patterns
153 .iter()
154 .any(|p| Regex::new(p).map(|re| re.is_match(query)).unwrap_or(false));
155
156 if keyword_match && pattern_match && rule.priority > best_priority {
157 best_priority = rule.priority;
158 best_hint = Some(&rule.hint);
159 }
160 }
161
162 if let Some(hint) = best_hint {
163 debug!(hint, "query classified");
164 }
165
166 best_hint.map(String::from)
167 }
168
169 pub fn route_query(&self, query: &str) -> Option<ResolvedRoute> {
171 self.classify_query(query)
172 .and_then(|hint| self.resolve_hint(&hint))
173 }
174
175 pub fn resolve_hint_with_privacy(
181 &self,
182 hint: &str,
183 privacy_mode: &str,
184 ) -> Option<ResolvedRoute> {
185 let candidates: Vec<&ModelRoute> = self
186 .model_routes
187 .iter()
188 .filter(|r| r.hint.eq_ignore_ascii_case(hint))
189 .collect();
190
191 match privacy_mode {
192 "local_only" => candidates
193 .iter()
194 .find(|r| r.privacy_level == PrivacyLevel::Local)
195 .map(|r| self.route_to_resolved(r)),
196 "private" => {
197 candidates
199 .iter()
200 .find(|r| r.privacy_level == PrivacyLevel::Local)
201 .or_else(|| {
202 candidates
203 .iter()
204 .find(|r| r.privacy_level != PrivacyLevel::Local)
205 })
206 .map(|r| self.route_to_resolved(r))
207 }
208 _ => candidates.first().map(|r| self.route_to_resolved(r)),
209 }
210 }
211
212 pub fn route_query_with_privacy(
214 &self,
215 query: &str,
216 privacy_mode: &str,
217 ) -> Option<ResolvedRoute> {
218 self.classify_query(query)
219 .and_then(|hint| self.resolve_hint_with_privacy(&hint, privacy_mode))
220 }
221
222 fn route_to_resolved(&self, r: &ModelRoute) -> ResolvedRoute {
223 ResolvedRoute {
224 provider: r.provider.clone(),
225 model: r.model.clone(),
226 max_tokens: r.max_tokens,
227 api_key: r.api_key.clone(),
228 transport: r.transport.clone(),
229 matched_hint: r.hint.clone(),
230 }
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Router with one `Either` route per hint and three classification rules.
    fn router() -> ModelRouter {
        ModelRouter {
            model_routes: vec![
                ModelRoute {
                    hint: "reasoning".into(),
                    provider: "openrouter".into(),
                    model: "anthropic/claude-opus-4-6".into(),
                    max_tokens: Some(8192),
                    api_key: None,
                    transport: None,
                    privacy_level: PrivacyLevel::Either,
                },
                ModelRoute {
                    hint: "fast".into(),
                    provider: "openrouter".into(),
                    model: "anthropic/claude-haiku-4-5".into(),
                    max_tokens: None,
                    api_key: None,
                    transport: None,
                    privacy_level: PrivacyLevel::Either,
                },
                ModelRoute {
                    hint: "code".into(),
                    provider: "openrouter".into(),
                    model: "anthropic/claude-sonnet-4-6".into(),
                    max_tokens: Some(16384),
                    api_key: None,
                    transport: None,
                    privacy_level: PrivacyLevel::Either,
                },
            ],
            embedding_routes: vec![EmbeddingRoute {
                hint: "default".into(),
                provider: "openai".into(),
                model: "text-embedding-3-small".into(),
                dimensions: Some(1536),
                api_key: None,
            }],
            classification_rules: vec![
                ClassificationRule {
                    hint: "reasoning".into(),
                    keywords: vec!["explain".into(), "why".into(), "analyze".into()],
                    patterns: Vec::new(),
                    min_length: Some(50),
                    max_length: None,
                    priority: 10,
                },
                ClassificationRule {
                    hint: "fast".into(),
                    keywords: vec![],
                    patterns: Vec::new(),
                    min_length: None,
                    max_length: Some(20),
                    priority: 5,
                },
                ClassificationRule {
                    hint: "code".into(),
                    keywords: vec![
                        "implement".into(),
                        "function".into(),
                        "code".into(),
                        "fix".into(),
                    ],
                    patterns: Vec::new(),
                    min_length: None,
                    max_length: None,
                    priority: 8,
                },
            ],
            classification_enabled: true,
        }
    }

    #[test]
    fn resolve_hint_finds_matching_route() {
        let r = router();
        let route = r.resolve_hint("fast").unwrap();
        assert_eq!(route.model, "anthropic/claude-haiku-4-5");
    }

    #[test]
    fn resolve_hint_is_ascii_case_insensitive() {
        let r = router();
        let route = r.resolve_hint("FAST").unwrap();
        assert_eq!(route.model, "anthropic/claude-haiku-4-5");
        assert_eq!(route.matched_hint, "fast");
    }

    #[test]
    fn resolve_hint_returns_none_for_unknown() {
        let r = router();
        assert!(r.resolve_hint("nonexistent").is_none());
    }

    #[test]
    fn resolve_embedding_hint() {
        let r = router();
        let route = r.resolve_embedding_hint("default").unwrap();
        assert_eq!(route.model, "text-embedding-3-small");
        assert_eq!(route.dimensions, Some(1536));
    }

    #[test]
    fn classify_query_by_keywords() {
        let r = router();
        let hint = r
            .classify_query("please implement a function to parse JSON")
            .unwrap();
        assert_eq!(hint, "code");
    }

    #[test]
    fn classify_query_by_length() {
        let r = router();
        let hint = r.classify_query("hello").unwrap();
        assert_eq!(hint, "fast");
    }

    #[test]
    fn classify_query_priority_wins() {
        let r = router();
        // Matches both "reasoning" (priority 10) and "code" (priority 8).
        let hint = r
            .classify_query(
                "explain why this function fails and implement a fix for the memory leak issue",
            )
            .unwrap();
        assert_eq!(hint, "reasoning");
    }

    #[test]
    fn classify_query_by_pattern() {
        let mut r = router();
        r.classification_rules.push(ClassificationRule {
            hint: "numeric".into(),
            keywords: vec![],
            patterns: vec![r"^\d+$".into()],
            min_length: None,
            max_length: None,
            priority: 20,
        });
        assert_eq!(r.classify_query("12345").as_deref(), Some("numeric"));
        // The pattern does not match, so the length-based "fast" rule applies.
        assert_eq!(r.classify_query("hello").as_deref(), Some("fast"));
    }

    #[test]
    fn classify_ignores_invalid_patterns() {
        let mut r = router();
        r.classification_rules.push(ClassificationRule {
            hint: "broken".into(),
            keywords: vec![],
            patterns: vec!["(unclosed".into()],
            min_length: None,
            max_length: None,
            priority: 100,
        });
        assert_eq!(r.classify_query("hello").as_deref(), Some("fast"));
    }

    #[test]
    fn classify_disabled_returns_none() {
        let mut r = router();
        r.classification_enabled = false;
        assert!(r.classify_query("explain this").is_none());
    }

    #[test]
    fn route_query_resolves_end_to_end() {
        let r = router();
        let route = r
            .route_query("implement a function to sort arrays")
            .unwrap();
        assert_eq!(route.model, "anthropic/claude-sonnet-4-6");
        assert_eq!(route.matched_hint, "code");
    }

    /// Router where "fast" has a local and a cloud route and "reasoning" has
    /// a single `Either` route.
    fn privacy_router() -> ModelRouter {
        ModelRouter {
            model_routes: vec![
                ModelRoute {
                    hint: "fast".into(),
                    provider: "ollama".into(),
                    model: "llama3.2".into(),
                    max_tokens: None,
                    api_key: None,
                    transport: None,
                    privacy_level: PrivacyLevel::Local,
                },
                ModelRoute {
                    hint: "fast".into(),
                    provider: "anthropic".into(),
                    model: "claude-haiku-4-5".into(),
                    max_tokens: None,
                    api_key: None,
                    transport: None,
                    privacy_level: PrivacyLevel::Cloud,
                },
                ModelRoute {
                    hint: "reasoning".into(),
                    provider: "openai".into(),
                    model: "o1".into(),
                    max_tokens: Some(8192),
                    api_key: None,
                    transport: None,
                    privacy_level: PrivacyLevel::Either,
                },
            ],
            embedding_routes: vec![],
            classification_rules: vec![],
            classification_enabled: false,
        }
    }

    /// Router with a single cloud route under the hint "cloud-only".
    fn cloud_only_router() -> ModelRouter {
        ModelRouter {
            model_routes: vec![ModelRoute {
                hint: "cloud-only".into(),
                provider: "anthropic".into(),
                model: "claude-sonnet-4-6".into(),
                max_tokens: None,
                api_key: None,
                transport: None,
                privacy_level: PrivacyLevel::Cloud,
            }],
            ..Default::default()
        }
    }

    #[test]
    fn private_mode_prefers_local_route() {
        let r = privacy_router();
        let route = r
            .resolve_hint_with_privacy("fast", "private")
            .expect("should resolve");
        assert_eq!(route.provider, "ollama", "private should prefer local");
    }

    #[test]
    fn private_mode_falls_through_to_cloud() {
        let r = privacy_router();
        // No local route exists for "reasoning", so the non-local one is used.
        let route = r
            .resolve_hint_with_privacy("reasoning", "private")
            .expect("should fall through");
        assert_eq!(route.provider, "openai");
    }

    #[test]
    fn private_mode_falls_back_to_cloud_only_route() {
        let r = cloud_only_router();
        let route = r
            .resolve_hint_with_privacy("cloud-only", "private")
            .expect("private should fall back to a cloud route");
        assert_eq!(route.provider, "anthropic");
    }

    #[test]
    fn local_only_blocks_cloud_routes() {
        let r = privacy_router();
        let route = r.resolve_hint_with_privacy("fast", "local_only");
        assert_eq!(
            route.as_ref().map(|r| r.provider.as_str()),
            Some("ollama"),
            "local_only should only return local routes"
        );
    }

    #[test]
    fn local_only_returns_none_for_cloud_only() {
        let r = cloud_only_router();
        assert!(
            r.resolve_hint_with_privacy("cloud-only", "local_only")
                .is_none(),
            "local_only should not resolve cloud-only routes"
        );
    }

    #[test]
    fn local_only_excludes_either_routes() {
        let r = router();
        assert!(
            r.resolve_hint_with_privacy("fast", "local_only").is_none(),
            "local_only should not resolve routes without an explicit Local level"
        );
    }

    #[test]
    fn off_mode_allows_all_routes() {
        let r = privacy_router();
        let route = r
            .resolve_hint_with_privacy("fast", "off")
            .expect("should resolve");
        // No filtering: the first configured route for the hint is returned.
        assert_eq!(route.provider, "ollama");
    }

    #[test]
    fn privacy_level_defaults_to_either() {
        assert_eq!(PrivacyLevel::default(), PrivacyLevel::Either);

        let r = router();
        let route = r
            .resolve_hint_with_privacy("fast", "off")
            .expect("Either routes should be available in off mode");
        assert_eq!(route.provider, "openrouter");
        let route2 = r
            .resolve_hint_with_privacy("fast", "private")
            .expect("Either routes should be available in private mode");
        assert_eq!(route2.provider, "openrouter");
    }
}