1use std::fmt::Write;
2
3use anyhow::Result;
4
5use crate::llm::{ChatOutcome, ChatRequest, ChatResponse, LlmProvider, Message, Role};
6
/// Tier of model a request can be dispatched to.
///
/// [`TaskComplexity::recommended_tier`] maps each complexity level onto exactly
/// one of these variants.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ModelTier {
    /// Tier recommended for [`TaskComplexity::Simple`] requests.
    Fast,
    /// Tier recommended for [`TaskComplexity::Moderate`] requests.
    Capable,
    /// Tier recommended for [`TaskComplexity::Complex`] requests.
    Advanced,
}
13
/// Complexity label assigned to an incoming request.
///
/// The variants mirror the three labels the classifier model is asked to
/// answer with: `SIMPLE`, `MODERATE` and `COMPLEX`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TaskComplexity {
    /// The classifier answered `SIMPLE`.
    Simple,
    /// The classifier answered `MODERATE`.
    Moderate,
    /// The classifier answered `COMPLEX`; also the fallback label when no
    /// other label is recognised or the classifier call does not succeed.
    Complex,
}
20
21impl TaskComplexity {
22 #[must_use]
23 pub const fn recommended_tier(self) -> ModelTier {
24 match self {
25 Self::Simple => ModelTier::Fast,
26 Self::Moderate => ModelTier::Capable,
27 Self::Complex => ModelTier::Advanced,
28 }
29 }
30}
31
/// Routes chat requests to one of three providers based on task complexity.
///
/// `C` is the provider used to classify requests, `S` is the provider type
/// shared by the fast and capable tiers, and `A` is the provider type of the
/// advanced tier.
pub struct ModelRouter<C, S, A> {
    /// Provider asked to label a request's complexity.
    classifier: C,
    /// Provider serving [`ModelTier::Fast`].
    fast: S,
    /// Provider serving [`ModelTier::Capable`].
    capable: S,
    /// Provider serving [`ModelTier::Advanced`].
    advanced: A,
}
38
39impl<C, S, A> ModelRouter<C, S, A>
40where
41 C: LlmProvider,
42 S: LlmProvider,
43 A: LlmProvider,
44{
45 pub const fn new(classifier: C, fast: S, capable: S, advanced: A) -> Self {
46 Self {
47 classifier,
48 fast,
49 capable,
50 advanced,
51 }
52 }
53
54 pub async fn classify(&self, request: &ChatRequest) -> Result<TaskComplexity> {
57 let classification_prompt = build_classification_prompt(request);
58
59 let classification_request = ChatRequest {
60 system: CLASSIFICATION_SYSTEM.to_owned(),
61 messages: vec![Message::user(classification_prompt)],
62 tools: None,
63 max_tokens: 50,
64 max_tokens_explicit: true,
65 session_id: None,
66 cached_content: None,
67 thinking: None,
68 };
69
70 match self.classifier.chat(classification_request).await? {
71 ChatOutcome::Success(response) => {
72 let complexity = parse_complexity(&response);
73 log::debug!(
74 "Model router classified request as {:?} using {}",
75 complexity,
76 self.classifier.model()
77 );
78 Ok(complexity)
79 }
80 ChatOutcome::RateLimited => {
81 log::warn!("Classifier rate limited, defaulting to Complex");
82 Ok(TaskComplexity::Complex)
83 }
84 ChatOutcome::InvalidRequest(e) => {
85 log::error!("Classifier invalid request: {e}, defaulting to Complex");
86 Ok(TaskComplexity::Complex)
87 }
88 ChatOutcome::ServerError(e) => {
89 log::error!("Classifier server error: {e}, defaulting to Complex");
90 Ok(TaskComplexity::Complex)
91 }
92 }
93 }
94
95 pub async fn route(&self, request: ChatRequest) -> Result<ChatOutcome> {
98 let complexity = self.classify(&request).await?;
99 let tier = complexity.recommended_tier();
100
101 log::info!("Routing request to {tier:?} tier (complexity: {complexity:?})");
102
103 match tier {
104 ModelTier::Fast => self.fast.chat(request).await,
105 ModelTier::Capable => self.capable.chat(request).await,
106 ModelTier::Advanced => self.advanced.chat(request).await,
107 }
108 }
109
110 pub async fn route_with_tier(
113 &self,
114 request: ChatRequest,
115 tier: ModelTier,
116 ) -> Result<ChatOutcome> {
117 match tier {
118 ModelTier::Fast => self.fast.chat(request).await,
119 ModelTier::Capable => self.capable.chat(request).await,
120 ModelTier::Advanced => self.advanced.chat(request).await,
121 }
122 }
123
124 #[must_use]
125 pub const fn fast_provider(&self) -> &S {
126 &self.fast
127 }
128
129 #[must_use]
130 pub const fn capable_provider(&self) -> &S {
131 &self.capable
132 }
133
134 #[must_use]
135 pub const fn advanced_provider(&self) -> &A {
136 &self.advanced
137 }
138}
139
/// System prompt sent to the classifier model.
///
/// It describes the three complexity labels and instructs the model to reply
/// with exactly one of the words `SIMPLE`, `MODERATE` or `COMPLEX`.
const CLASSIFICATION_SYSTEM: &str = r"You are a task complexity classifier. Analyze the user's request and classify it as one of: SIMPLE, MODERATE, or COMPLEX.

SIMPLE tasks:
- Basic questions with factual answers
- Simple calculations
- Direct lookups or retrievals
- Yes/no questions
- Single-step operations

MODERATE tasks:
- Multi-step reasoning
- Summarization
- Basic analysis
- Comparisons
- Standard tool usage

COMPLEX tasks:
- Creative writing or content generation
- Multi-step planning
- Complex analysis or synthesis
- Nuanced decisions
- Tasks requiring deep domain knowledge
- Financial advice or calculations
- Multi-tool orchestration

Respond with ONLY one word: SIMPLE, MODERATE, or COMPLEX.";
166
167fn build_classification_prompt(request: &ChatRequest) -> String {
168 let mut prompt = String::new();
169
170 prompt.push_str("Classify this task:\n\n");
171
172 if !request.system.is_empty() {
173 prompt.push_str("System context: ");
174 prompt.push_str(&request.system[..request.system.len().min(200)]);
175 if request.system.len() > 200 {
176 prompt.push_str("...");
177 }
178 prompt.push_str("\n\n");
179 }
180
181 if let Some(last_user_message) = request.messages.iter().rev().find(|m| m.role == Role::User)
182 && let Some(text) = last_user_message.content.first_text()
183 {
184 prompt.push_str("User request: ");
185 prompt.push_str(&text[..text.len().min(500)]);
186 if text.len() > 500 {
187 prompt.push_str("...");
188 }
189 }
190
191 if let Some(tools) = &request.tools {
192 let _ = write!(prompt, "\n\nAvailable tools: {}", tools.len());
193 }
194
195 prompt
196}
197
198fn parse_complexity(response: &ChatResponse) -> TaskComplexity {
199 let text = response.first_text().unwrap_or("").to_uppercase();
200
201 if text.contains("SIMPLE") {
202 TaskComplexity::Simple
203 } else if text.contains("MODERATE") {
204 TaskComplexity::Moderate
205 } else {
206 TaskComplexity::Complex
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    /// Every complexity level must map to its matching model tier.
    #[test]
    fn complexity_to_tier() {
        let expected = [
            (TaskComplexity::Simple, ModelTier::Fast),
            (TaskComplexity::Moderate, ModelTier::Capable),
            (TaskComplexity::Complex, ModelTier::Advanced),
        ];

        for (complexity, tier) in expected {
            assert_eq!(complexity.recommended_tier(), tier);
        }
    }
}