Skip to main content

punch_runtime/
model_router.rs

1//! Smart model routing based on query complexity.
2//!
3//! Classifies user messages into tiers using keyword heuristics (no LLM call
4//! required) and selects the appropriate model configuration for each tier.
5//!
6//! - **Cheap**: Simple greetings, yes/no, short acknowledgements. Nano models.
7//! - **Mid**: Tool-calling messages (search, email, calendar, etc.).
8//! - **Expensive**: Complex reasoning (analysis, comparison, code review, etc.).
9
10use std::fmt;
11use std::sync::Arc;
12
13use tracing::debug;
14
15use punch_types::config::{ModelConfig, ModelRoutingConfig};
16use punch_types::{ContentPart, Message, PunchResult};
17
18use crate::driver::{LlmDriver, create_driver};
19
/// The complexity tier for a user message, ordered from cheapest to most
/// expensive model.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelTier {
    /// Simple responses, no tools needed. Suitable for nano models.
    Cheap,
    /// Tool calling required. Needs a model that reliably generates tool calls.
    Mid,
    /// Complex multi-step reasoning. Benefits from a frontier model.
    Expensive,
}

impl fmt::Display for ModelTier {
    /// Writes the lowercase tier name (`cheap`, `mid`, `expensive`).
    ///
    /// Goes through [`fmt::Formatter::pad`] so width/alignment flags are
    /// honoured, e.g. `format!("{:>9}", ModelTier::Mid)` yields `"      mid"`.
    /// Without flags the output is exactly the bare name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Cheap => "cheap",
            Self::Mid => "mid",
            Self::Expensive => "expensive",
        };
        f.pad(name)
    }
}
40
/// Lowercase phrases signalling that a message needs complex reasoning.
///
/// A message containing any of these is routed to the expensive tier; this
/// list is consulted before the tool-calling list, so it wins ties.
const EXPENSIVE_PATTERNS: &[&str] = &[
    // Analysis and explanation.
    "analyze",
    "compare",
    "summarize",
    "explain why",
    // Long-form generation and planning.
    "write a",
    "create a plan",
    // Investigating code or problems.
    "review",
    "debug",
    // Weighing options.
    "what are the pros and cons",
    // Software design work.
    "design",
    "refactor",
    "architect",
    // Judgement and improvement.
    "evaluate",
    "assess",
    "critique",
    "optimize",
    // Strategic or in-depth discussion (both spellings of trade-off).
    "trade-off",
    "tradeoff",
    "strategy",
    "deep dive",
];
64
/// Lowercase phrases signalling that a message will likely need tool calls.
///
/// Only consulted when no expensive pattern matched; a hit routes the message
/// to the mid tier.
const TOOL_PATTERNS: &[&str] = &[
    // Personal services and messaging.
    "check",
    "calendar",
    "email",
    "send",
    // Lookup.
    "search",
    "find",
    // Files.
    "file",
    "download",
    "read",
    // Scheduling.
    "schedule",
    "meeting",
    "remind",
    // Live data.
    "weather",
    "stock",
    "price",
    // Local system actions.
    "open",
    "run",
    "execute",
    "install",
    "list",
    // Retrieval phrasing. The trailing space in "my " keeps it from matching
    // words that merely begin with "my" (e.g. "myth").
    "my ",
    "show me",
    "look up",
    "fetch",
    "get the",
    // Mutations.
    "delete",
    "update",
    "upload",
];
71
72/// Smart model router that picks cheap / mid / expensive models based on
73/// message complexity.
74pub struct ModelRouter {
75    config: ModelRoutingConfig,
76}
77
78impl ModelRouter {
79    /// Create a new router from the routing configuration.
80    pub fn new(config: ModelRoutingConfig) -> Self {
81        Self { config }
82    }
83
84    /// Returns `true` if model routing is enabled.
85    pub fn is_enabled(&self) -> bool {
86        self.config.enabled
87    }
88
89    /// Classify a user message into a complexity tier using keyword heuristics.
90    ///
91    /// The classification is intentionally simple: pattern matching on the
92    /// lowercased message. No LLM call is made.
93    pub fn classify(message: &str) -> ModelTier {
94        let lower = message.to_lowercase();
95
96        // Check expensive patterns first (they take priority).
97        if EXPENSIVE_PATTERNS.iter().any(|p| lower.contains(p)) {
98            return ModelTier::Expensive;
99        }
100
101        // Check tool-calling patterns.
102        if TOOL_PATTERNS.iter().any(|p| lower.contains(p)) {
103            return ModelTier::Mid;
104        }
105
106        // Default: simple message, use cheap model.
107        ModelTier::Cheap
108    }
109
110    /// Classify a user message with context awareness: if the conversation
111    /// contains images (from screenshots, Telegram photos, etc.), force the
112    /// expensive tier so a vision-capable model handles them.
113    pub fn classify_with_context(message: &str, messages: &[Message]) -> ModelTier {
114        // If any message in the conversation has an image, force expensive tier.
115        let has_images = messages.iter().any(|m| {
116            m.has_images()
117                || m.content_parts
118                    .iter()
119                    .any(|p| matches!(p, ContentPart::Image { .. }))
120                || m.tool_results.iter().any(|tr| tr.image.is_some())
121        });
122        if has_images {
123            return ModelTier::Expensive;
124        }
125
126        // Also check tool results for png_base64 field (screenshot output).
127        let has_screenshot_output = messages.iter().any(|m| {
128            m.tool_results
129                .iter()
130                .any(|tr| tr.content.contains("png_base64"))
131        });
132        if has_screenshot_output {
133            return ModelTier::Expensive;
134        }
135
136        Self::classify(message)
137    }
138
139    /// Select the model config for a given tier. Returns `None` if the tier
140    /// has no model configured (caller should fall back to the default model).
141    pub fn select_model(&self, tier: ModelTier) -> Option<&ModelConfig> {
142        match tier {
143            ModelTier::Cheap => self.config.cheap.as_ref(),
144            ModelTier::Mid => self.config.mid.as_ref(),
145            ModelTier::Expensive => self.config.expensive.as_ref(),
146        }
147    }
148
149    /// Classify a message and create a driver for the selected tier.
150    ///
151    /// Returns `Some((tier, driver))` if routing is enabled and a tier-specific
152    /// model is configured. Returns `None` if routing is disabled or the tier
153    /// model is not configured (the caller should use the default driver).
154    pub fn route_message(&self, message: &str) -> Option<(ModelTier, ModelConfig)> {
155        self.route_message_with_context(message, &[])
156    }
157
158    /// Classify a message with conversation context and return the tier-specific
159    /// model config. This enables image-aware routing where conversations containing
160    /// screenshots or photos automatically escalate to vision-capable models.
161    pub fn route_message_with_context(
162        &self,
163        message: &str,
164        messages: &[Message],
165    ) -> Option<(ModelTier, ModelConfig)> {
166        if !self.config.enabled {
167            return None;
168        }
169
170        let tier = Self::classify_with_context(message, messages);
171        let model_config = self.select_model(tier)?;
172
173        debug!(
174            tier = %tier,
175            model = %model_config.model,
176            provider = %model_config.provider,
177            "model router selected"
178        );
179
180        Some((tier, model_config.clone()))
181    }
182
183    /// Create an LLM driver for a routed model config.
184    pub fn create_tier_driver(config: &ModelConfig) -> PunchResult<Arc<dyn LlmDriver>> {
185        create_driver(config)
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use punch_types::config::Provider;

    fn make_model_config(model: &str) -> ModelConfig {
        ModelConfig {
            provider: Provider::OpenAI,
            model: model.to_string(),
            api_key_env: Some("OPENAI_API_KEY".to_string()),
            base_url: None,
            max_tokens: Some(4096),
            temperature: Some(0.7),
        }
    }

    fn make_routing_config(enabled: bool) -> ModelRoutingConfig {
        ModelRoutingConfig {
            enabled,
            cheap: Some(make_model_config("gpt-4.1-nano")),
            mid: Some(make_model_config("gpt-4.1-mini")),
            expensive: Some(make_model_config("gpt-4.1")),
        }
    }

    /// Builds a user message carrying a single inline PNG image part.
    fn make_image_message() -> Message {
        Message::with_parts(
            punch_types::Role::User,
            "What's in this image?",
            vec![ContentPart::Image {
                media_type: "image/png".to_string(),
                data: "base64data".to_string(),
            }],
        )
    }

    // -----------------------------------------------------------------------
    // Classification tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_classify_greeting_is_cheap() {
        assert_eq!(ModelRouter::classify("hello"), ModelTier::Cheap);
        assert_eq!(ModelRouter::classify("hi there!"), ModelTier::Cheap);
        assert_eq!(ModelRouter::classify("thanks"), ModelTier::Cheap);
        assert_eq!(ModelRouter::classify("yes"), ModelTier::Cheap);
        assert_eq!(ModelRouter::classify("no"), ModelTier::Cheap);
        assert_eq!(ModelRouter::classify("ok"), ModelTier::Cheap);
        assert_eq!(ModelRouter::classify("good morning"), ModelTier::Cheap);
    }

    #[test]
    fn test_classify_empty_message_is_cheap() {
        assert_eq!(ModelRouter::classify(""), ModelTier::Cheap);
        assert_eq!(ModelRouter::classify("   "), ModelTier::Cheap);
    }

    #[test]
    fn test_classify_tool_patterns_are_mid() {
        assert_eq!(ModelRouter::classify("check my email"), ModelTier::Mid);
        assert_eq!(
            ModelRouter::classify("search for rust tutorials"),
            ModelTier::Mid
        );
        assert_eq!(ModelRouter::classify("schedule a meeting"), ModelTier::Mid);
        assert_eq!(ModelRouter::classify("what's the weather"), ModelTier::Mid);
        assert_eq!(ModelRouter::classify("find the file"), ModelTier::Mid);
        assert_eq!(ModelRouter::classify("show me my calendar"), ModelTier::Mid);
        assert_eq!(
            ModelRouter::classify("send an email to Bob"),
            ModelTier::Mid
        );
        assert_eq!(ModelRouter::classify("download the report"), ModelTier::Mid);
        assert_eq!(ModelRouter::classify("list all files"), ModelTier::Mid);
        assert_eq!(ModelRouter::classify("run the tests"), ModelTier::Mid);
    }

    #[test]
    fn test_classify_complex_patterns_are_expensive() {
        assert_eq!(
            ModelRouter::classify("analyze this data"),
            ModelTier::Expensive
        );
        assert_eq!(
            ModelRouter::classify("compare React vs Vue"),
            ModelTier::Expensive
        );
        assert_eq!(
            ModelRouter::classify("summarize the article"),
            ModelTier::Expensive
        );
        assert_eq!(
            ModelRouter::classify("explain why this fails"),
            ModelTier::Expensive
        );
        assert_eq!(
            ModelRouter::classify("write a blog post"),
            ModelTier::Expensive
        );
        assert_eq!(
            ModelRouter::classify("create a plan for migration"),
            ModelTier::Expensive
        );
        assert_eq!(
            ModelRouter::classify("review this code"),
            ModelTier::Expensive
        );
        assert_eq!(
            ModelRouter::classify("debug this issue"),
            ModelTier::Expensive
        );
        assert_eq!(
            ModelRouter::classify("what are the pros and cons of microservices"),
            ModelTier::Expensive
        );
        assert_eq!(
            ModelRouter::classify("design a REST API"),
            ModelTier::Expensive
        );
    }

    #[test]
    fn test_classify_is_case_insensitive() {
        assert_eq!(ModelRouter::classify("ANALYZE this"), ModelTier::Expensive);
        assert_eq!(ModelRouter::classify("Check My Email"), ModelTier::Mid);
        assert_eq!(ModelRouter::classify("HELLO"), ModelTier::Cheap);
    }

    #[test]
    fn test_expensive_takes_priority_over_mid() {
        // "review" is expensive, "search" is mid — expensive should win.
        assert_eq!(
            ModelRouter::classify("review and search the codebase"),
            ModelTier::Expensive
        );
        // "analyze" is expensive, "find" is mid — expensive should win.
        assert_eq!(
            ModelRouter::classify("find and analyze the logs"),
            ModelTier::Expensive
        );
    }

    // -----------------------------------------------------------------------
    // Router selection tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_is_enabled_reflects_config() {
        assert!(ModelRouter::new(make_routing_config(true)).is_enabled());
        assert!(!ModelRouter::new(make_routing_config(false)).is_enabled());
    }

    #[test]
    fn test_select_model_returns_correct_tier() {
        let router = ModelRouter::new(make_routing_config(true));

        let cheap = router.select_model(ModelTier::Cheap).unwrap();
        assert_eq!(cheap.model, "gpt-4.1-nano");

        let mid = router.select_model(ModelTier::Mid).unwrap();
        assert_eq!(mid.model, "gpt-4.1-mini");

        let expensive = router.select_model(ModelTier::Expensive).unwrap();
        assert_eq!(expensive.model, "gpt-4.1");
    }

    #[test]
    fn test_select_model_returns_none_when_not_configured() {
        let config = ModelRoutingConfig {
            enabled: true,
            cheap: Some(make_model_config("gpt-4.1-nano")),
            mid: None,
            expensive: None,
        };
        let router = ModelRouter::new(config);

        assert!(router.select_model(ModelTier::Cheap).is_some());
        assert!(router.select_model(ModelTier::Mid).is_none());
        assert!(router.select_model(ModelTier::Expensive).is_none());
    }

    #[test]
    fn test_route_message_disabled() {
        let router = ModelRouter::new(make_routing_config(false));
        assert!(router.route_message("analyze this").is_none());
    }

    #[test]
    fn test_route_message_enabled() {
        let router = ModelRouter::new(make_routing_config(true));

        let (tier, config) = router.route_message("hello").unwrap();
        assert_eq!(tier, ModelTier::Cheap);
        assert_eq!(config.model, "gpt-4.1-nano");

        let (tier, config) = router.route_message("check my email").unwrap();
        assert_eq!(tier, ModelTier::Mid);
        assert_eq!(config.model, "gpt-4.1-mini");

        let (tier, config) = router.route_message("analyze the data").unwrap();
        assert_eq!(tier, ModelTier::Expensive);
        assert_eq!(config.model, "gpt-4.1");
    }

    #[test]
    fn test_route_message_falls_back_when_tier_missing() {
        let config = ModelRoutingConfig {
            enabled: true,
            cheap: None,
            mid: Some(make_model_config("gpt-4.1-mini")),
            expensive: None,
        };
        let router = ModelRouter::new(config);

        // Cheap tier not configured — returns None (caller uses default).
        assert!(router.route_message("hello").is_none());

        // Mid tier configured — returns the mid model specifically.
        let (tier, routed) = router
            .route_message("search for files")
            .expect("mid tier is configured");
        assert_eq!(tier, ModelTier::Mid);
        assert_eq!(routed.model, "gpt-4.1-mini");

        // Expensive tier not configured — returns None.
        assert!(router.route_message("analyze this").is_none());
    }

    #[test]
    fn test_route_message_with_context_image_escalates_to_expensive_model() {
        let router = ModelRouter::new(make_routing_config(true));
        let messages = vec![make_image_message()];

        // "hello" alone would route to the cheap model; the image overrides it.
        let (tier, routed) = router
            .route_message_with_context("hello", &messages)
            .expect("expensive tier is configured");
        assert_eq!(tier, ModelTier::Expensive);
        assert_eq!(routed.model, "gpt-4.1");
    }

    #[test]
    fn test_route_message_with_context_disabled_ignores_images() {
        let router = ModelRouter::new(make_routing_config(false));
        let messages = vec![make_image_message()];
        assert!(router
            .route_message_with_context("hello", &messages)
            .is_none());
    }

    #[test]
    fn test_model_tier_display() {
        assert_eq!(ModelTier::Cheap.to_string(), "cheap");
        assert_eq!(ModelTier::Mid.to_string(), "mid");
        assert_eq!(ModelTier::Expensive.to_string(), "expensive");
    }

    #[test]
    fn test_default_routing_config_is_disabled() {
        let config = ModelRoutingConfig::default();
        assert!(!config.enabled);
        assert!(config.cheap.is_none());
        assert!(config.mid.is_none());
        assert!(config.expensive.is_none());
    }

    // -----------------------------------------------------------------------
    // Image detection tests (classify_with_context)
    // -----------------------------------------------------------------------

    #[test]
    fn test_classify_with_context_empty_history_matches_classify() {
        assert_eq!(
            ModelRouter::classify_with_context("analyze this", &[]),
            ModelTier::Expensive
        );
        assert_eq!(
            ModelRouter::classify_with_context("check my email", &[]),
            ModelTier::Mid
        );
        assert_eq!(
            ModelRouter::classify_with_context("hello", &[]),
            ModelTier::Cheap
        );
    }

    #[test]
    fn test_classify_with_context_no_images_is_normal() {
        let messages = vec![Message::new(punch_types::Role::User, "hello")];
        assert_eq!(
            ModelRouter::classify_with_context("hello", &messages),
            ModelTier::Cheap
        );
    }

    #[test]
    fn test_classify_with_context_image_forces_expensive() {
        let messages = vec![make_image_message()];
        // Even though "hello" would be Cheap, image presence forces Expensive.
        assert_eq!(
            ModelRouter::classify_with_context("hello", &messages),
            ModelTier::Expensive
        );
    }

    #[test]
    fn test_classify_with_context_tool_result_image_forces_expensive() {
        let mut msg = Message::new(punch_types::Role::Tool, "");
        msg.tool_results = vec![punch_types::ToolCallResult {
            id: "tc1".to_string(),
            content: "screenshot taken".to_string(),
            is_error: false,
            image: Some(ContentPart::Image {
                media_type: "image/png".to_string(),
                data: "base64data".to_string(),
            }),
        }];
        let messages = vec![msg];
        assert_eq!(
            ModelRouter::classify_with_context("ok", &messages),
            ModelTier::Expensive
        );
    }

    #[test]
    fn test_classify_with_context_png_base64_in_content() {
        let mut msg = Message::new(punch_types::Role::Tool, "");
        msg.tool_results = vec![punch_types::ToolCallResult {
            id: "tc1".to_string(),
            content: r#"{"png_base64": "iVBORw0KGgo=", "width": 1920}"#.to_string(),
            is_error: false,
            image: None,
        }];
        let messages = vec![msg];
        assert_eq!(
            ModelRouter::classify_with_context("ok", &messages),
            ModelTier::Expensive
        );
    }
}