1use std::fmt::Write;
2
3use anyhow::Result;
4
5use crate::llm::{ChatOutcome, ChatRequest, ChatResponse, LlmProvider, Message, Role};
6
/// Tier of model a request can be dispatched to.
///
/// [`TaskComplexity::recommended_tier`] maps each complexity level onto one
/// of these tiers.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ModelTier {
    /// Tier recommended for [`TaskComplexity::Simple`] requests.
    Fast,
    /// Tier recommended for [`TaskComplexity::Moderate`] requests.
    Capable,
    /// Tier recommended for [`TaskComplexity::Complex`] requests.
    Advanced,
}
13
/// Complexity level assigned to an incoming chat request.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TaskComplexity {
    /// The request was classified as `SIMPLE`.
    Simple,
    /// The request was classified as `MODERATE`.
    Moderate,
    /// The request was classified as `COMPLEX`; also the fallback level.
    Complex,
}
20
21impl TaskComplexity {
22 #[must_use]
23 pub const fn recommended_tier(self) -> ModelTier {
24 match self {
25 Self::Simple => ModelTier::Fast,
26 Self::Moderate => ModelTier::Capable,
27 Self::Complex => ModelTier::Advanced,
28 }
29 }
30}
31
/// Routes chat requests to one of three providers based on task complexity.
///
/// `C` is the provider type used for classification, `S` is shared by the
/// fast and capable tiers, and `A` is the advanced-tier provider type.
pub struct ModelRouter<C, S, A> {
    /// Provider asked to classify the complexity of each request.
    classifier: C,
    /// Provider serving the `Fast` tier.
    fast: S,
    /// Provider serving the `Capable` tier (same type as `fast`).
    capable: S,
    /// Provider serving the `Advanced` tier.
    advanced: A,
}
38
39impl<C, S, A> ModelRouter<C, S, A>
40where
41 C: LlmProvider,
42 S: LlmProvider,
43 A: LlmProvider,
44{
45 pub const fn new(classifier: C, fast: S, capable: S, advanced: A) -> Self {
46 Self {
47 classifier,
48 fast,
49 capable,
50 advanced,
51 }
52 }
53
54 pub async fn classify(&self, request: &ChatRequest) -> Result<TaskComplexity> {
57 let classification_prompt = build_classification_prompt(request);
58
59 let classification_request = ChatRequest {
60 system: CLASSIFICATION_SYSTEM.to_owned(),
61 messages: vec![Message::user(classification_prompt)],
62 tools: None,
63 max_tokens: 50,
64 };
65
66 match self.classifier.chat(classification_request).await? {
67 ChatOutcome::Success(response) => {
68 let complexity = parse_complexity(&response);
69 log::debug!(
70 "Model router classified request as {:?} using {}",
71 complexity,
72 self.classifier.model()
73 );
74 Ok(complexity)
75 }
76 ChatOutcome::RateLimited => {
77 log::warn!("Classifier rate limited, defaulting to Complex");
78 Ok(TaskComplexity::Complex)
79 }
80 ChatOutcome::InvalidRequest(e) => {
81 log::error!("Classifier invalid request: {e}, defaulting to Complex");
82 Ok(TaskComplexity::Complex)
83 }
84 ChatOutcome::ServerError(e) => {
85 log::error!("Classifier server error: {e}, defaulting to Complex");
86 Ok(TaskComplexity::Complex)
87 }
88 }
89 }
90
91 pub async fn route(&self, request: ChatRequest) -> Result<ChatOutcome> {
94 let complexity = self.classify(&request).await?;
95 let tier = complexity.recommended_tier();
96
97 log::info!("Routing request to {tier:?} tier (complexity: {complexity:?})");
98
99 match tier {
100 ModelTier::Fast => self.fast.chat(request).await,
101 ModelTier::Capable => self.capable.chat(request).await,
102 ModelTier::Advanced => self.advanced.chat(request).await,
103 }
104 }
105
106 pub async fn route_with_tier(
109 &self,
110 request: ChatRequest,
111 tier: ModelTier,
112 ) -> Result<ChatOutcome> {
113 match tier {
114 ModelTier::Fast => self.fast.chat(request).await,
115 ModelTier::Capable => self.capable.chat(request).await,
116 ModelTier::Advanced => self.advanced.chat(request).await,
117 }
118 }
119
120 #[must_use]
121 pub const fn fast_provider(&self) -> &S {
122 &self.fast
123 }
124
125 #[must_use]
126 pub const fn capable_provider(&self) -> &S {
127 &self.capable
128 }
129
130 #[must_use]
131 pub const fn advanced_provider(&self) -> &A {
132 &self.advanced
133 }
134}
135
/// System prompt sent to the classifier provider.
///
/// It instructs the model to answer with exactly one of `SIMPLE`, `MODERATE`
/// or `COMPLEX`, which are the keywords `parse_complexity` looks for.
const CLASSIFICATION_SYSTEM: &str = r"You are a task complexity classifier. Analyze the user's request and classify it as one of: SIMPLE, MODERATE, or COMPLEX.

SIMPLE tasks:
- Basic questions with factual answers
- Simple calculations
- Direct lookups or retrievals
- Yes/no questions
- Single-step operations

MODERATE tasks:
- Multi-step reasoning
- Summarization
- Basic analysis
- Comparisons
- Standard tool usage

COMPLEX tasks:
- Creative writing or content generation
- Multi-step planning
- Complex analysis or synthesis
- Nuanced decisions
- Tasks requiring deep domain knowledge
- Financial advice or calculations
- Multi-tool orchestration

Respond with ONLY one word: SIMPLE, MODERATE, or COMPLEX.";
162
163fn build_classification_prompt(request: &ChatRequest) -> String {
164 let mut prompt = String::new();
165
166 prompt.push_str("Classify this task:\n\n");
167
168 if !request.system.is_empty() {
169 prompt.push_str("System context: ");
170 prompt.push_str(&request.system[..request.system.len().min(200)]);
171 if request.system.len() > 200 {
172 prompt.push_str("...");
173 }
174 prompt.push_str("\n\n");
175 }
176
177 if let Some(last_user_message) = request.messages.iter().rev().find(|m| m.role == Role::User)
178 && let Some(text) = last_user_message.content.first_text()
179 {
180 prompt.push_str("User request: ");
181 prompt.push_str(&text[..text.len().min(500)]);
182 if text.len() > 500 {
183 prompt.push_str("...");
184 }
185 }
186
187 if let Some(tools) = &request.tools {
188 let _ = write!(prompt, "\n\nAvailable tools: {}", tools.len());
189 }
190
191 prompt
192}
193
194fn parse_complexity(response: &ChatResponse) -> TaskComplexity {
195 let text = response.first_text().unwrap_or("").to_uppercase();
196
197 if text.contains("SIMPLE") {
198 TaskComplexity::Simple
199 } else if text.contains("MODERATE") {
200 TaskComplexity::Moderate
201 } else {
202 TaskComplexity::Complex
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Every complexity level must map onto its matching model tier.
    #[test]
    fn complexity_to_tier() {
        let expectations = [
            (TaskComplexity::Simple, ModelTier::Fast),
            (TaskComplexity::Moderate, ModelTier::Capable),
            (TaskComplexity::Complex, ModelTier::Advanced),
        ];

        for (complexity, tier) in expectations {
            assert_eq!(complexity.recommended_tier(), tier);
        }
    }
}
222}