1use std::fmt;
11use std::sync::Arc;
12
13use tracing::debug;
14
15use punch_types::config::{ModelConfig, ModelRoutingConfig};
16use punch_types::{ContentPart, Message, PunchResult};
17
18use crate::driver::{LlmDriver, create_driver};
19
/// Cost tier a request is routed to.
///
/// Each tier maps to an optional model in [`ModelRoutingConfig`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ModelTier {
    /// Requests that match no routing keyword (greetings, acknowledgements).
    Cheap,
    /// Requests that look like tool usage (email, calendar, files, search, ...).
    Mid,
    /// Requests that ask for deeper reasoning, or conversations carrying images.
    Expensive,
}
30
31impl fmt::Display for ModelTier {
32 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33 match self {
34 Self::Cheap => write!(f, "cheap"),
35 Self::Mid => write!(f, "mid"),
36 Self::Expensive => write!(f, "expensive"),
37 }
38 }
39}
40
/// Lowercase keywords that mark a request as needing the expensive tier.
///
/// `ModelRouter::classify` checks these with plain substring containment
/// against the lowercased message, so multi-word entries must appear verbatim
/// (including their spacing) to match.
const EXPENSIVE_PATTERNS: &[&str] = &[
    // Analysis and explanation.
    "analyze",
    "compare",
    "summarize",
    "explain why",
    // Long-form authoring and planning.
    "write a",
    "create a plan",
    // Code inspection.
    "review",
    "debug",
    // Open-ended judgement questions.
    "what are the pros and cons",
    // Design and engineering work.
    "design",
    "refactor",
    "architect",
    // Evaluation.
    "evaluate",
    "assess",
    "critique",
    "optimize",
    // Trade-offs and strategy (both spellings of trade-off).
    "trade-off",
    "tradeoff",
    "strategy",
    "deep dive",
];
64
/// Lowercase keywords that suggest a tool-using request (mid tier).
///
/// Like `EXPENSIVE_PATTERNS`, these are matched as plain substrings of the
/// lowercased message, and are only consulted when no expensive keyword hit.
const TOOL_PATTERNS: &[&str] = &[
    "check", "calendar", "email", "send", "search", "find", "file",
    "download", "read", "schedule", "meeting", "remind", "weather", "stock",
    "price", "open", "run", "execute", "install", "list",
    // Trailing space so "my" does not match words such as "myth" or "myself".
    "my ",
    "show me", "look up", "fetch", "get the", "delete", "update", "upload",
];
71
72pub struct ModelRouter {
75 config: ModelRoutingConfig,
76}
77
78impl ModelRouter {
79 pub fn new(config: ModelRoutingConfig) -> Self {
81 Self { config }
82 }
83
84 pub fn is_enabled(&self) -> bool {
86 self.config.enabled
87 }
88
89 pub fn classify(message: &str) -> ModelTier {
94 let lower = message.to_lowercase();
95
96 if EXPENSIVE_PATTERNS.iter().any(|p| lower.contains(p)) {
98 return ModelTier::Expensive;
99 }
100
101 if TOOL_PATTERNS.iter().any(|p| lower.contains(p)) {
103 return ModelTier::Mid;
104 }
105
106 ModelTier::Cheap
108 }
109
110 pub fn classify_with_context(message: &str, messages: &[Message]) -> ModelTier {
114 let has_images = messages.iter().any(|m| {
116 m.has_images()
117 || m.content_parts
118 .iter()
119 .any(|p| matches!(p, ContentPart::Image { .. }))
120 || m.tool_results.iter().any(|tr| tr.image.is_some())
121 });
122 if has_images {
123 return ModelTier::Expensive;
124 }
125
126 let has_screenshot_output = messages.iter().any(|m| {
128 m.tool_results
129 .iter()
130 .any(|tr| tr.content.contains("png_base64"))
131 });
132 if has_screenshot_output {
133 return ModelTier::Expensive;
134 }
135
136 Self::classify(message)
137 }
138
139 pub fn select_model(&self, tier: ModelTier) -> Option<&ModelConfig> {
142 match tier {
143 ModelTier::Cheap => self.config.cheap.as_ref(),
144 ModelTier::Mid => self.config.mid.as_ref(),
145 ModelTier::Expensive => self.config.expensive.as_ref(),
146 }
147 }
148
149 pub fn route_message(&self, message: &str) -> Option<(ModelTier, ModelConfig)> {
155 self.route_message_with_context(message, &[])
156 }
157
158 pub fn route_message_with_context(
162 &self,
163 message: &str,
164 messages: &[Message],
165 ) -> Option<(ModelTier, ModelConfig)> {
166 if !self.config.enabled {
167 return None;
168 }
169
170 let tier = Self::classify_with_context(message, messages);
171 let model_config = self.select_model(tier)?;
172
173 debug!(
174 tier = %tier,
175 model = %model_config.model,
176 provider = %model_config.provider,
177 "model router selected"
178 );
179
180 Some((tier, model_config.clone()))
181 }
182
183 pub fn create_tier_driver(config: &ModelConfig) -> PunchResult<Arc<dyn LlmDriver>> {
185 create_driver(config)
186 }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use punch_types::config::Provider;

    /// Builds an OpenAI model config for `model` with fixed test parameters.
    fn make_model_config(model: &str) -> ModelConfig {
        ModelConfig {
            provider: Provider::OpenAI,
            model: model.to_string(),
            api_key_env: Some("OPENAI_API_KEY".to_string()),
            base_url: None,
            max_tokens: Some(4096),
            temperature: Some(0.7),
        }
    }

    /// Builds a routing config with all three tiers populated.
    fn make_routing_config(enabled: bool) -> ModelRoutingConfig {
        ModelRoutingConfig {
            enabled,
            cheap: Some(make_model_config("gpt-4.1-nano")),
            mid: Some(make_model_config("gpt-4.1-mini")),
            expensive: Some(make_model_config("gpt-4.1")),
        }
    }

    /// Builds a user message carrying one inline PNG image part.
    fn make_image_message() -> Message {
        Message::with_parts(
            punch_types::Role::User,
            "What's in this image?",
            vec![ContentPart::Image {
                media_type: "image/png".to_string(),
                data: "base64data".to_string(),
            }],
        )
    }

    #[test]
    fn test_classify_greeting_is_cheap() {
        assert_eq!(ModelRouter::classify("hello"), ModelTier::Cheap);
        assert_eq!(ModelRouter::classify("hi there!"), ModelTier::Cheap);
        assert_eq!(ModelRouter::classify("thanks"), ModelTier::Cheap);
        assert_eq!(ModelRouter::classify("yes"), ModelTier::Cheap);
        assert_eq!(ModelRouter::classify("no"), ModelTier::Cheap);
        assert_eq!(ModelRouter::classify("ok"), ModelTier::Cheap);
        assert_eq!(ModelRouter::classify("good morning"), ModelTier::Cheap);
    }

    #[test]
    fn test_classify_tool_patterns_are_mid() {
        assert_eq!(ModelRouter::classify("check my email"), ModelTier::Mid);
        assert_eq!(
            ModelRouter::classify("search for rust tutorials"),
            ModelTier::Mid
        );
        assert_eq!(ModelRouter::classify("schedule a meeting"), ModelTier::Mid);
        assert_eq!(ModelRouter::classify("what's the weather"), ModelTier::Mid);
        assert_eq!(ModelRouter::classify("find the file"), ModelTier::Mid);
        assert_eq!(ModelRouter::classify("show me my calendar"), ModelTier::Mid);
        assert_eq!(
            ModelRouter::classify("send an email to Bob"),
            ModelTier::Mid
        );
        assert_eq!(ModelRouter::classify("download the report"), ModelTier::Mid);
        assert_eq!(ModelRouter::classify("list all files"), ModelTier::Mid);
        assert_eq!(ModelRouter::classify("run the tests"), ModelTier::Mid);
    }

    #[test]
    fn test_classify_complex_patterns_are_expensive() {
        assert_eq!(
            ModelRouter::classify("analyze this data"),
            ModelTier::Expensive
        );
        assert_eq!(
            ModelRouter::classify("compare React vs Vue"),
            ModelTier::Expensive
        );
        assert_eq!(
            ModelRouter::classify("summarize the article"),
            ModelTier::Expensive
        );
        assert_eq!(
            ModelRouter::classify("explain why this fails"),
            ModelTier::Expensive
        );
        assert_eq!(
            ModelRouter::classify("write a blog post"),
            ModelTier::Expensive
        );
        assert_eq!(
            ModelRouter::classify("create a plan for migration"),
            ModelTier::Expensive
        );
        assert_eq!(
            ModelRouter::classify("review this code"),
            ModelTier::Expensive
        );
        assert_eq!(
            ModelRouter::classify("debug this issue"),
            ModelTier::Expensive
        );
        assert_eq!(
            ModelRouter::classify("what are the pros and cons of microservices"),
            ModelTier::Expensive
        );
        assert_eq!(
            ModelRouter::classify("design a REST API"),
            ModelTier::Expensive
        );
    }

    #[test]
    fn test_classify_is_case_insensitive() {
        assert_eq!(ModelRouter::classify("ANALYZE this"), ModelTier::Expensive);
        assert_eq!(ModelRouter::classify("Check My Email"), ModelTier::Mid);
        assert_eq!(ModelRouter::classify("HELLO"), ModelTier::Cheap);
    }

    #[test]
    fn test_expensive_takes_priority_over_mid() {
        // "review" (expensive) wins over "search" (mid).
        assert_eq!(
            ModelRouter::classify("review and search the codebase"),
            ModelTier::Expensive
        );
        // "analyze" (expensive) wins over "find" (mid).
        assert_eq!(
            ModelRouter::classify("find and analyze the logs"),
            ModelTier::Expensive
        );
    }

    #[test]
    fn test_select_model_returns_correct_tier() {
        let router = ModelRouter::new(make_routing_config(true));

        let cheap = router.select_model(ModelTier::Cheap).unwrap();
        assert_eq!(cheap.model, "gpt-4.1-nano");

        let mid = router.select_model(ModelTier::Mid).unwrap();
        assert_eq!(mid.model, "gpt-4.1-mini");

        let expensive = router.select_model(ModelTier::Expensive).unwrap();
        assert_eq!(expensive.model, "gpt-4.1");
    }

    #[test]
    fn test_select_model_returns_none_when_not_configured() {
        let config = ModelRoutingConfig {
            enabled: true,
            cheap: Some(make_model_config("gpt-4.1-nano")),
            mid: None,
            expensive: None,
        };
        let router = ModelRouter::new(config);

        assert!(router.select_model(ModelTier::Cheap).is_some());
        assert!(router.select_model(ModelTier::Mid).is_none());
        assert!(router.select_model(ModelTier::Expensive).is_none());
    }

    #[test]
    fn test_is_enabled_reflects_config() {
        assert!(ModelRouter::new(make_routing_config(true)).is_enabled());
        assert!(!ModelRouter::new(make_routing_config(false)).is_enabled());
    }

    #[test]
    fn test_route_message_disabled() {
        let router = ModelRouter::new(make_routing_config(false));
        assert!(router.route_message("analyze this").is_none());
    }

    #[test]
    fn test_route_message_enabled() {
        let router = ModelRouter::new(make_routing_config(true));

        let (tier, config) = router.route_message("hello").unwrap();
        assert_eq!(tier, ModelTier::Cheap);
        assert_eq!(config.model, "gpt-4.1-nano");

        let (tier, config) = router.route_message("check my email").unwrap();
        assert_eq!(tier, ModelTier::Mid);
        assert_eq!(config.model, "gpt-4.1-mini");

        let (tier, config) = router.route_message("analyze the data").unwrap();
        assert_eq!(tier, ModelTier::Expensive);
        assert_eq!(config.model, "gpt-4.1");
    }

    #[test]
    fn test_route_message_falls_back_when_tier_missing() {
        let config = ModelRoutingConfig {
            enabled: true,
            cheap: None,
            mid: Some(make_model_config("gpt-4.1-mini")),
            expensive: None,
        };
        let router = ModelRouter::new(config);

        // Cheap tier is unconfigured: the router yields `None` (presumably so
        // the caller falls back to its default model).
        assert!(router.route_message("hello").is_none());

        // Mid tier is configured, so tool-like messages are routed.
        let result = router.route_message("search for files");
        assert!(result.is_some());

        // Expensive tier is unconfigured as well.
        assert!(router.route_message("analyze this").is_none());
    }

    #[test]
    fn test_route_message_with_context_image_routes_expensive() {
        let router = ModelRouter::new(make_routing_config(true));
        let messages = vec![make_image_message()];

        // A cheap-looking prompt is upgraded because the history holds an image.
        let (tier, config) = router
            .route_message_with_context("hello", &messages)
            .unwrap();
        assert_eq!(tier, ModelTier::Expensive);
        assert_eq!(config.model, "gpt-4.1");
    }

    #[test]
    fn test_route_message_with_context_disabled_ignores_images() {
        let router = ModelRouter::new(make_routing_config(false));
        let messages = vec![make_image_message()];
        assert!(router
            .route_message_with_context("hello", &messages)
            .is_none());
    }

    #[test]
    fn test_model_tier_display() {
        assert_eq!(ModelTier::Cheap.to_string(), "cheap");
        assert_eq!(ModelTier::Mid.to_string(), "mid");
        assert_eq!(ModelTier::Expensive.to_string(), "expensive");
    }

    #[test]
    fn test_default_routing_config_is_disabled() {
        let config = ModelRoutingConfig::default();
        assert!(!config.enabled);
        assert!(config.cheap.is_none());
        assert!(config.mid.is_none());
        assert!(config.expensive.is_none());
    }

    #[test]
    fn test_classify_with_context_empty_history_matches_classify() {
        for prompt in ["hello", "check my email", "analyze this data"] {
            assert_eq!(
                ModelRouter::classify_with_context(prompt, &[]),
                ModelRouter::classify(prompt)
            );
        }
    }

    #[test]
    fn test_classify_with_context_no_images_is_normal() {
        let messages = vec![Message::new(punch_types::Role::User, "hello")];
        assert_eq!(
            ModelRouter::classify_with_context("hello", &messages),
            ModelTier::Cheap
        );
    }

    #[test]
    fn test_classify_with_context_image_forces_expensive() {
        let messages = vec![make_image_message()];
        // A greeting alone would be cheap; the image in history forces expensive.
        assert_eq!(
            ModelRouter::classify_with_context("hello", &messages),
            ModelTier::Expensive
        );
    }

    #[test]
    fn test_classify_with_context_tool_result_image_forces_expensive() {
        let mut msg = Message::new(punch_types::Role::Tool, "");
        msg.tool_results = vec![punch_types::ToolCallResult {
            id: "tc1".to_string(),
            content: "screenshot taken".to_string(),
            is_error: false,
            image: Some(ContentPart::Image {
                media_type: "image/png".to_string(),
                data: "base64data".to_string(),
            }),
        }];
        let messages = vec![msg];
        assert_eq!(
            ModelRouter::classify_with_context("ok", &messages),
            ModelTier::Expensive
        );
    }

    #[test]
    fn test_classify_with_context_png_base64_in_content() {
        let mut msg = Message::new(punch_types::Role::Tool, "");
        msg.tool_results = vec![punch_types::ToolCallResult {
            id: "tc1".to_string(),
            content: r#"{"png_base64": "iVBORw0KGgo=", "width": 1920}"#.to_string(),
            is_error: false,
            image: None,
        }];
        let messages = vec![msg];
        assert_eq!(
            ModelRouter::classify_with_context("ok", &messages),
            ModelTier::Expensive
        );
    }
}
472}