1use std::fmt::Write;
2
3use anyhow::Result;
4
5use crate::llm::{ChatOutcome, ChatRequest, ChatResponse, LlmProvider, Message, Role};
6
/// A provider tier that a request can be routed to.
///
/// Tiers are listed from cheapest/fastest to most capable.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ModelTier {
    /// Tier chosen for [`TaskComplexity::Simple`] requests.
    Fast,
    /// Tier chosen for [`TaskComplexity::Moderate`] requests.
    Capable,
    /// Tier chosen for [`TaskComplexity::Complex`] requests.
    Advanced,
}
13
/// How demanding a request is, as judged by the classifier model.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TaskComplexity {
    /// Classifier answered `SIMPLE`.
    Simple,
    /// Classifier answered `MODERATE`.
    Moderate,
    /// Classifier answered `COMPLEX`, or classification could not be completed.
    Complex,
}
20
21impl TaskComplexity {
22 #[must_use]
23 pub const fn recommended_tier(self) -> ModelTier {
24 match self {
25 Self::Simple => ModelTier::Fast,
26 Self::Moderate => ModelTier::Capable,
27 Self::Complex => ModelTier::Advanced,
28 }
29 }
30}
31
/// Routes chat requests to one of three providers based on a classifier
/// model's judgement of task complexity.
///
/// `C` is the classifier provider, `S` is the provider type shared by the fast
/// and capable tiers, and `A` is the advanced-tier provider type.
pub struct ModelRouter<C, S, A> {
    /// Provider queried to classify incoming requests.
    classifier: C,
    /// Provider serving the `ModelTier::Fast` tier.
    fast: S,
    /// Provider serving the `ModelTier::Capable` tier.
    capable: S,
    /// Provider serving the `ModelTier::Advanced` tier.
    advanced: A,
}
38
39impl<C, S, A> ModelRouter<C, S, A>
40where
41 C: LlmProvider,
42 S: LlmProvider,
43 A: LlmProvider,
44{
45 pub const fn new(classifier: C, fast: S, capable: S, advanced: A) -> Self {
46 Self {
47 classifier,
48 fast,
49 capable,
50 advanced,
51 }
52 }
53
54 pub async fn classify(&self, request: &ChatRequest) -> Result<TaskComplexity> {
57 let classification_prompt = build_classification_prompt(request);
58
59 let classification_request = ChatRequest {
60 system: CLASSIFICATION_SYSTEM.to_owned(),
61 messages: vec![Message::user(classification_prompt)],
62 tools: None,
63 max_tokens: 50,
64 thinking: None,
65 };
66
67 match self.classifier.chat(classification_request).await? {
68 ChatOutcome::Success(response) => {
69 let complexity = parse_complexity(&response);
70 log::debug!(
71 "Model router classified request as {:?} using {}",
72 complexity,
73 self.classifier.model()
74 );
75 Ok(complexity)
76 }
77 ChatOutcome::RateLimited => {
78 log::warn!("Classifier rate limited, defaulting to Complex");
79 Ok(TaskComplexity::Complex)
80 }
81 ChatOutcome::InvalidRequest(e) => {
82 log::error!("Classifier invalid request: {e}, defaulting to Complex");
83 Ok(TaskComplexity::Complex)
84 }
85 ChatOutcome::ServerError(e) => {
86 log::error!("Classifier server error: {e}, defaulting to Complex");
87 Ok(TaskComplexity::Complex)
88 }
89 }
90 }
91
92 pub async fn route(&self, request: ChatRequest) -> Result<ChatOutcome> {
95 let complexity = self.classify(&request).await?;
96 let tier = complexity.recommended_tier();
97
98 log::info!("Routing request to {tier:?} tier (complexity: {complexity:?})");
99
100 match tier {
101 ModelTier::Fast => self.fast.chat(request).await,
102 ModelTier::Capable => self.capable.chat(request).await,
103 ModelTier::Advanced => self.advanced.chat(request).await,
104 }
105 }
106
107 pub async fn route_with_tier(
110 &self,
111 request: ChatRequest,
112 tier: ModelTier,
113 ) -> Result<ChatOutcome> {
114 match tier {
115 ModelTier::Fast => self.fast.chat(request).await,
116 ModelTier::Capable => self.capable.chat(request).await,
117 ModelTier::Advanced => self.advanced.chat(request).await,
118 }
119 }
120
121 #[must_use]
122 pub const fn fast_provider(&self) -> &S {
123 &self.fast
124 }
125
126 #[must_use]
127 pub const fn capable_provider(&self) -> &S {
128 &self.capable
129 }
130
131 #[must_use]
132 pub const fn advanced_provider(&self) -> &A {
133 &self.advanced
134 }
135}
136
/// System prompt sent to the classifier provider by `ModelRouter::classify`.
///
/// It asks the model to reply with exactly one of `SIMPLE`, `MODERATE`, or
/// `COMPLEX`; the reply is then interpreted by `parse_complexity`.
const CLASSIFICATION_SYSTEM: &str = r"You are a task complexity classifier. Analyze the user's request and classify it as one of: SIMPLE, MODERATE, or COMPLEX.

SIMPLE tasks:
- Basic questions with factual answers
- Simple calculations
- Direct lookups or retrievals
- Yes/no questions
- Single-step operations

MODERATE tasks:
- Multi-step reasoning
- Summarization
- Basic analysis
- Comparisons
- Standard tool usage

COMPLEX tasks:
- Creative writing or content generation
- Multi-step planning
- Complex analysis or synthesis
- Nuanced decisions
- Tasks requiring deep domain knowledge
- Financial advice or calculations
- Multi-tool orchestration

Respond with ONLY one word: SIMPLE, MODERATE, or COMPLEX.";
163
164fn build_classification_prompt(request: &ChatRequest) -> String {
165 let mut prompt = String::new();
166
167 prompt.push_str("Classify this task:\n\n");
168
169 if !request.system.is_empty() {
170 prompt.push_str("System context: ");
171 prompt.push_str(&request.system[..request.system.len().min(200)]);
172 if request.system.len() > 200 {
173 prompt.push_str("...");
174 }
175 prompt.push_str("\n\n");
176 }
177
178 if let Some(last_user_message) = request.messages.iter().rev().find(|m| m.role == Role::User)
179 && let Some(text) = last_user_message.content.first_text()
180 {
181 prompt.push_str("User request: ");
182 prompt.push_str(&text[..text.len().min(500)]);
183 if text.len() > 500 {
184 prompt.push_str("...");
185 }
186 }
187
188 if let Some(tools) = &request.tools {
189 let _ = write!(prompt, "\n\nAvailable tools: {}", tools.len());
190 }
191
192 prompt
193}
194
195fn parse_complexity(response: &ChatResponse) -> TaskComplexity {
196 let text = response.first_text().unwrap_or("").to_uppercase();
197
198 if text.contains("SIMPLE") {
199 TaskComplexity::Simple
200 } else if text.contains("MODERATE") {
201 TaskComplexity::Moderate
202 } else {
203 TaskComplexity::Complex
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Every complexity level maps to its dedicated tier.
    #[test]
    fn complexity_to_tier() {
        let expectations = [
            (TaskComplexity::Simple, ModelTier::Fast),
            (TaskComplexity::Moderate, ModelTier::Capable),
            (TaskComplexity::Complex, ModelTier::Advanced),
        ];

        for (complexity, tier) in expectations {
            assert_eq!(complexity.recommended_tier(), tier);
        }
    }
}