1use regex::Regex;
2use serde::{Deserialize, Serialize};
3use tracing::debug;
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct ModelRoute {
8 pub hint: String,
9 pub provider: String,
10 pub model: String,
11 pub max_tokens: Option<usize>,
12 pub api_key: Option<String>,
13 pub transport: Option<String>,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct EmbeddingRoute {
19 pub hint: String,
20 pub provider: String,
21 pub model: String,
22 pub dimensions: Option<usize>,
23 pub api_key: Option<String>,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct ClassificationRule {
29 pub hint: String,
30 #[serde(default)]
31 pub keywords: Vec<String>,
32 #[serde(default)]
33 pub patterns: Vec<String>,
34 pub min_length: Option<usize>,
35 pub max_length: Option<usize>,
36 #[serde(default)]
37 pub priority: i32,
38}
39
/// Result of resolving a model hint: an owned copy of the matching route's
/// settings plus the hint that matched.
#[derive(Debug, Clone)]
pub struct ResolvedRoute {
    /// Provider identifier copied from the matched route.
    pub provider: String,
    /// Model identifier copied from the matched route.
    pub model: String,
    /// Token limit copied from the matched route, if it had one.
    pub max_tokens: Option<usize>,
    /// API key copied from the matched route, if it had one.
    pub api_key: Option<String>,
    /// Transport selector copied from the matched route, if it had one.
    pub transport: Option<String>,
    /// The route's own `hint` value (as configured, not as requested).
    pub matched_hint: String,
}
50
/// Result of resolving an embedding hint: an owned copy of the matching
/// route's settings plus the hint that matched.
#[derive(Debug, Clone)]
pub struct ResolvedEmbeddingRoute {
    /// Provider identifier copied from the matched route.
    pub provider: String,
    /// Model identifier copied from the matched route.
    pub model: String,
    /// Embedding dimensionality copied from the matched route, if it had one.
    pub dimensions: Option<usize>,
    /// API key copied from the matched route, if it had one.
    pub api_key: Option<String>,
    /// The route's own `hint` value (as configured, not as requested).
    pub matched_hint: String,
}
60
61#[derive(Debug, Clone, Default)]
63pub struct ModelRouter {
64 pub model_routes: Vec<ModelRoute>,
65 pub embedding_routes: Vec<EmbeddingRoute>,
66 pub classification_rules: Vec<ClassificationRule>,
67 pub classification_enabled: bool,
68}
69
70impl ModelRouter {
71 pub fn resolve_hint(&self, hint: &str) -> Option<ResolvedRoute> {
73 self.model_routes
74 .iter()
75 .find(|r| r.hint.eq_ignore_ascii_case(hint))
76 .map(|r| ResolvedRoute {
77 provider: r.provider.clone(),
78 model: r.model.clone(),
79 max_tokens: r.max_tokens,
80 api_key: r.api_key.clone(),
81 transport: r.transport.clone(),
82 matched_hint: r.hint.clone(),
83 })
84 }
85
86 pub fn resolve_embedding_hint(&self, hint: &str) -> Option<ResolvedEmbeddingRoute> {
88 self.embedding_routes
89 .iter()
90 .find(|r| r.hint.eq_ignore_ascii_case(hint))
91 .map(|r| ResolvedEmbeddingRoute {
92 provider: r.provider.clone(),
93 model: r.model.clone(),
94 dimensions: r.dimensions,
95 api_key: r.api_key.clone(),
96 matched_hint: r.hint.clone(),
97 })
98 }
99
100 pub fn classify_query(&self, query: &str) -> Option<String> {
102 if !self.classification_enabled || self.classification_rules.is_empty() {
103 return None;
104 }
105
106 let query_lower = query.to_lowercase();
107 let query_len = query.len();
108
109 let mut best_hint: Option<&str> = None;
110 let mut best_priority = i32::MIN;
111
112 for rule in &self.classification_rules {
113 if let Some(min) = rule.min_length {
115 if query_len < min {
116 continue;
117 }
118 }
119 if let Some(max) = rule.max_length {
120 if query_len > max {
121 continue;
122 }
123 }
124
125 let keyword_match = rule.keywords.is_empty()
127 || rule
128 .keywords
129 .iter()
130 .any(|kw| query_lower.contains(&kw.to_lowercase()));
131
132 let pattern_match = rule.patterns.is_empty()
134 || rule
135 .patterns
136 .iter()
137 .any(|p| Regex::new(p).map(|re| re.is_match(query)).unwrap_or(false));
138
139 if keyword_match && pattern_match && rule.priority > best_priority {
140 best_priority = rule.priority;
141 best_hint = Some(&rule.hint);
142 }
143 }
144
145 if let Some(hint) = best_hint {
146 debug!(hint, "query classified");
147 }
148
149 best_hint.map(String::from)
150 }
151
152 pub fn route_query(&self, query: &str) -> Option<ResolvedRoute> {
154 self.classify_query(query)
155 .and_then(|hint| self.resolve_hint(&hint))
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a router with three model routes, one embedding route and three
    /// classification rules ("reasoning" > "code" > "fast" by priority).
    fn router() -> ModelRouter {
        ModelRouter {
            model_routes: vec![
                ModelRoute {
                    hint: "reasoning".into(),
                    provider: "openrouter".into(),
                    model: "anthropic/claude-opus-4-6".into(),
                    max_tokens: Some(8192),
                    api_key: None,
                    transport: None,
                },
                ModelRoute {
                    hint: "fast".into(),
                    provider: "openrouter".into(),
                    model: "anthropic/claude-haiku-4-5".into(),
                    max_tokens: None,
                    api_key: None,
                    transport: None,
                },
                ModelRoute {
                    hint: "code".into(),
                    provider: "openrouter".into(),
                    model: "anthropic/claude-sonnet-4-6".into(),
                    max_tokens: Some(16384),
                    api_key: None,
                    transport: None,
                },
            ],
            embedding_routes: vec![EmbeddingRoute {
                hint: "default".into(),
                provider: "openai".into(),
                model: "text-embedding-3-small".into(),
                dimensions: Some(1536),
                api_key: None,
            }],
            classification_rules: vec![
                ClassificationRule {
                    hint: "reasoning".into(),
                    keywords: vec!["explain".into(), "why".into(), "analyze".into()],
                    patterns: Vec::new(),
                    min_length: Some(50),
                    max_length: None,
                    priority: 10,
                },
                ClassificationRule {
                    hint: "fast".into(),
                    keywords: vec![],
                    patterns: Vec::new(),
                    min_length: None,
                    max_length: Some(20),
                    priority: 5,
                },
                ClassificationRule {
                    hint: "code".into(),
                    keywords: vec![
                        "implement".into(),
                        "function".into(),
                        "code".into(),
                        "fix".into(),
                    ],
                    patterns: Vec::new(),
                    min_length: None,
                    max_length: None,
                    priority: 8,
                },
            ],
            classification_enabled: true,
        }
    }

    /// Builds a router whose only classification rule is pattern-based.
    fn pattern_router(pattern: &str) -> ModelRouter {
        let mut r = router();
        r.classification_rules = vec![ClassificationRule {
            hint: "code".into(),
            keywords: Vec::new(),
            patterns: vec![pattern.into()],
            min_length: None,
            max_length: None,
            priority: 1,
        }];
        r
    }

    #[test]
    fn resolve_hint_finds_matching_route() {
        let r = router();
        let route = r.resolve_hint("fast").unwrap();
        assert_eq!(route.model, "anthropic/claude-haiku-4-5");
    }

    #[test]
    fn resolve_hint_ignores_ascii_case() {
        let r = router();
        let route = r.resolve_hint("FAST").unwrap();
        // The configured spelling is reported, not the requested one.
        assert_eq!(route.matched_hint, "fast");
        assert_eq!(route.model, "anthropic/claude-haiku-4-5");
    }

    #[test]
    fn resolve_hint_returns_none_for_unknown() {
        let r = router();
        assert!(r.resolve_hint("nonexistent").is_none());
    }

    #[test]
    fn resolve_embedding_hint() {
        let r = router();
        let route = r.resolve_embedding_hint("default").unwrap();
        assert_eq!(route.model, "text-embedding-3-small");
        assert_eq!(route.dimensions, Some(1536));
    }

    #[test]
    fn classify_query_by_keywords() {
        let r = router();
        let hint = r
            .classify_query("please implement a function to parse JSON")
            .unwrap();
        assert_eq!(hint, "code");
    }

    #[test]
    fn classify_query_by_length() {
        let r = router();
        let hint = r.classify_query("hello").unwrap();
        assert_eq!(hint, "fast");
    }

    #[test]
    fn classify_query_priority_wins() {
        let r = router();
        let hint = r
            .classify_query(
                "explain why this function fails and implement a fix for the memory leak issue",
            )
            .unwrap();
        assert_eq!(hint, "reasoning");
    }

    #[test]
    fn classify_query_by_pattern() {
        let r = pattern_router(r"^\s*fn\s+\w+");
        assert_eq!(r.classify_query("fn main() {}").as_deref(), Some("code"));
        assert!(r.classify_query("hello there").is_none());
    }

    #[test]
    fn classify_query_ignores_invalid_pattern() {
        // "(" is not a valid regular expression, so the rule can never match.
        let r = pattern_router("(");
        assert!(r.classify_query("anything (").is_none());
    }

    #[test]
    fn classify_disabled_returns_none() {
        let mut r = router();
        r.classification_enabled = false;
        assert!(r.classify_query("explain this").is_none());
    }

    #[test]
    fn route_query_resolves_end_to_end() {
        let r = router();
        let route = r
            .route_query("implement a function to sort arrays")
            .unwrap();
        assert_eq!(route.model, "anthropic/claude-sonnet-4-6");
        assert_eq!(route.matched_hint, "code");
    }

    #[test]
    fn route_query_returns_none_without_matching_rule() {
        let r = router();
        // Longer than the "fast" limit, shorter than the "reasoning" minimum,
        // and free of every configured keyword.
        assert!(r
            .route_query("tell me about the weather in Paris today")
            .is_none());
    }
}