1use std::fmt::Write;
2
3use anyhow::Result;
4use async_trait::async_trait;
5use futures::StreamExt;
6
7use crate::provider::LlmProvider;
8use crate::streaming::StreamBox;
9use agent_sdk_foundation::llm::{ChatOutcome, ChatRequest, ChatResponse, Message, Role};
10
/// Identifies which of a [`ModelRouter`]'s three providers serves a request.
///
/// Each variant corresponds one-to-one with a provider held by the router:
/// `Fast` with its `fast` provider, `Capable` with `capable`, and `Advanced`
/// with `advanced`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ModelTier {
    /// Served by the router's `fast` provider.
    Fast,
    /// Served by the router's `capable` provider.
    Capable,
    /// Served by the router's `advanced` provider.
    Advanced,
}
25
/// How demanding a request is, as judged by the router's classifier model.
///
/// The variants mirror the `SIMPLE` / `MODERATE` / `COMPLEX` labels the
/// classifier is asked to answer with. `Complex` doubles as the fallback
/// whenever classification does not produce a usable answer.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TaskComplexity {
    /// The classifier answered `SIMPLE`.
    Simple,
    /// The classifier answered `MODERATE`.
    Moderate,
    /// The classifier answered `COMPLEX`, or its answer was unusable.
    Complex,
}
41
42impl TaskComplexity {
43 #[must_use]
44 pub const fn recommended_tier(self) -> ModelTier {
45 match self {
46 Self::Simple => ModelTier::Fast,
47 Self::Moderate => ModelTier::Capable,
48 Self::Complex => ModelTier::Advanced,
49 }
50 }
51}
52
/// Sends each chat request to one of three providers, chosen by asking a
/// classifier model how complex the request is.
///
/// The classifier receives a short description of the request (a truncated
/// system prompt, the latest user text, and the tool count). Its answer
/// becomes a [`TaskComplexity`], which selects a [`ModelTier`] and thus a
/// provider:
///
/// | Complexity | Tier       | Provider   |
/// |------------|------------|------------|
/// | Simple     | `Fast`     | `fast`     |
/// | Moderate   | `Capable`  | `capable`  |
/// | Complex    | `Advanced` | `advanced` |
///
/// `fast` and `capable` share the provider type `S`, while `advanced` may be
/// a different provider type `A`. The router itself implements
/// `LlmProvider`, so it can stand in wherever a single provider is expected.
pub struct ModelRouter<C, S, A> {
    /// Provider that answers the complexity-classification prompt.
    classifier: C,
    /// Provider for [`ModelTier::Fast`].
    fast: S,
    /// Provider for [`ModelTier::Capable`].
    capable: S,
    /// Provider for [`ModelTier::Advanced`].
    advanced: A,
}
79
80impl<C, S, A> ModelRouter<C, S, A>
81where
82 C: LlmProvider,
83 S: LlmProvider,
84 A: LlmProvider,
85{
86 pub const fn new(classifier: C, fast: S, capable: S, advanced: A) -> Self {
87 Self {
88 classifier,
89 fast,
90 capable,
91 advanced,
92 }
93 }
94
95 pub async fn classify(&self, request: &ChatRequest) -> Result<TaskComplexity> {
98 let classification_prompt = build_classification_prompt(request);
99
100 let classification_request = ChatRequest {
101 system: CLASSIFICATION_SYSTEM.to_owned(),
102 messages: vec![Message::user(classification_prompt)],
103 tools: None,
104 max_tokens: 50,
105 max_tokens_explicit: true,
106 session_id: None,
107 cached_content: None,
108 thinking: None,
109 tool_choice: None,
110 response_format: None,
111 };
112
113 match self.classifier.chat(classification_request).await? {
114 ChatOutcome::Success(response) => {
115 let complexity = parse_complexity(&response);
116 log::debug!(
117 "Model router classified request as {:?} using {}",
118 complexity,
119 self.classifier.model()
120 );
121 Ok(complexity)
122 }
123 ChatOutcome::RateLimited => {
124 log::warn!("Classifier rate limited, defaulting to Complex");
125 Ok(TaskComplexity::Complex)
126 }
127 ChatOutcome::InvalidRequest(e) => {
128 log::error!("Classifier invalid request: {e}, defaulting to Complex");
129 Ok(TaskComplexity::Complex)
130 }
131 ChatOutcome::ServerError(e) => {
132 log::error!("Classifier server error: {e}, defaulting to Complex");
133 Ok(TaskComplexity::Complex)
134 }
135 _ => {
138 log::error!("Classifier returned unrecognized outcome, defaulting to Complex");
139 Ok(TaskComplexity::Complex)
140 }
141 }
142 }
143
144 pub async fn route(&self, request: ChatRequest) -> Result<ChatOutcome> {
147 let complexity = self.classify(&request).await?;
148 let tier = complexity.recommended_tier();
149
150 log::info!("Routing request to {tier:?} tier (complexity: {complexity:?})");
151
152 match tier {
153 ModelTier::Fast => self.fast.chat(request).await,
154 ModelTier::Capable => self.capable.chat(request).await,
155 ModelTier::Advanced => self.advanced.chat(request).await,
156 }
157 }
158
159 pub async fn route_with_tier(
162 &self,
163 request: ChatRequest,
164 tier: ModelTier,
165 ) -> Result<ChatOutcome> {
166 match tier {
167 ModelTier::Fast => self.fast.chat(request).await,
168 ModelTier::Capable => self.capable.chat(request).await,
169 ModelTier::Advanced => self.advanced.chat(request).await,
170 }
171 }
172
173 #[must_use]
174 pub const fn fast_provider(&self) -> &S {
175 &self.fast
176 }
177
178 #[must_use]
179 pub const fn capable_provider(&self) -> &S {
180 &self.capable
181 }
182
183 #[must_use]
184 pub const fn advanced_provider(&self) -> &A {
185 &self.advanced
186 }
187}
188
189#[async_trait]
190impl<C, S, A> LlmProvider for ModelRouter<C, S, A>
191where
192 C: LlmProvider,
193 S: LlmProvider,
194 A: LlmProvider,
195{
196 async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome> {
197 self.route(request).await
198 }
199
200 fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
201 Box::pin(async_stream::stream! {
202 let tier = match self.classify(&request).await {
203 Ok(complexity) => complexity.recommended_tier(),
204 Err(error) => {
205 yield Err(error);
206 return;
207 }
208 };
209 log::info!("Streaming request to {tier:?} tier");
210 let mut stream = match tier {
211 ModelTier::Fast => self.fast.chat_stream(request),
212 ModelTier::Capable => self.capable.chat_stream(request),
213 ModelTier::Advanced => self.advanced.chat_stream(request),
214 };
215 while let Some(item) = stream.next().await {
216 yield item;
217 }
218 })
219 }
220
221 fn model(&self) -> &str {
224 self.capable.model()
225 }
226
227 fn provider(&self) -> &'static str {
230 self.capable.provider()
231 }
232}
233
// NOTE(review): "Simple calculations" is listed under SIMPLE while
// "Financial advice or calculations" is listed under COMPLEX; the classifier
// model has to resolve that overlap on its own.
/// System prompt sent to the classifier model by `ModelRouter::classify`.
///
/// It gives example task kinds for each label and asks for a one-word reply
/// (`SIMPLE`, `MODERATE`, or `COMPLEX`). `parse_complexity` then searches the
/// reply for those labels.
const CLASSIFICATION_SYSTEM: &str = r"You are a task complexity classifier. Analyze the user's request and classify it as one of: SIMPLE, MODERATE, or COMPLEX.

SIMPLE tasks:
- Basic questions with factual answers
- Simple calculations
- Direct lookups or retrievals
- Yes/no questions
- Single-step operations

MODERATE tasks:
- Multi-step reasoning
- Summarization
- Basic analysis
- Comparisons
- Standard tool usage

COMPLEX tasks:
- Creative writing or content generation
- Multi-step planning
- Complex analysis or synthesis
- Nuanced decisions
- Tasks requiring deep domain knowledge
- Financial advice or calculations
- Multi-tool orchestration

Respond with ONLY one word: SIMPLE, MODERATE, or COMPLEX.";
260
261fn build_classification_prompt(request: &ChatRequest) -> String {
262 let mut prompt = String::new();
263
264 prompt.push_str("Classify this task:\n\n");
265
266 if !request.system.is_empty() {
267 prompt.push_str("System context: ");
268 let truncated = truncate_on_char_boundary(&request.system, 200);
269 prompt.push_str(truncated);
270 if truncated.len() < request.system.len() {
271 prompt.push_str("...");
272 }
273 prompt.push_str("\n\n");
274 }
275
276 if let Some(last_user_message) = request.messages.iter().rev().find(|m| m.role == Role::User)
277 && let Some(text) = last_user_message.content.first_text()
278 {
279 prompt.push_str("User request: ");
280 let truncated = truncate_on_char_boundary(text, 500);
281 prompt.push_str(truncated);
282 if truncated.len() < text.len() {
283 prompt.push_str("...");
284 }
285 }
286
287 if let Some(tools) = &request.tools {
288 let _ = write!(prompt, "\n\nAvailable tools: {}", tools.len());
289 }
290
291 prompt
292}
293
/// Returns the longest prefix of `s` that is at most `max_bytes` bytes long
/// and ends on a UTF-8 character boundary.
///
/// Slicing a `str` at a byte index inside a multi-byte character panics, so
/// the cut point is moved backwards to the nearest boundary. Index 0 is
/// always a boundary, so the result may be empty but this never panics.
fn truncate_on_char_boundary(s: &str, max_bytes: usize) -> &str {
    if max_bytes >= s.len() {
        s
    } else {
        // 0 is always a char boundary, so the fallback is never actually used.
        let cut = (0..=max_bytes)
            .rev()
            .find(|&idx| s.is_char_boundary(idx))
            .unwrap_or(0);
        &s[..cut]
    }
}
307
308fn parse_complexity(response: &ChatResponse) -> TaskComplexity {
309 let text = response.first_text().unwrap_or("").to_uppercase();
310
311 if text.contains("SIMPLE") {
312 TaskComplexity::Simple
313 } else if text.contains("MODERATE") {
314 TaskComplexity::Moderate
315 } else {
316 TaskComplexity::Complex
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn complexity_to_tier() {
        assert_eq!(TaskComplexity::Simple.recommended_tier(), ModelTier::Fast);
        assert_eq!(
            TaskComplexity::Moderate.recommended_tier(),
            ModelTier::Capable
        );
        assert_eq!(
            TaskComplexity::Complex.recommended_tier(),
            ModelTier::Advanced
        );
    }

    #[test]
    fn truncate_on_char_boundary_never_splits_multibyte_char() {
        // Each emoji is 4 bytes, so only multiples of 4 are char boundaries.
        let s = "😀😀😀";
        for max in 0..=s.len() {
            let truncated = truncate_on_char_boundary(s, max);
            // The result is always a prefix of the input...
            assert!(s.starts_with(truncated));
            // ...and never exceeds the byte budget.
            assert!(truncated.len() <= max);
        }
        assert_eq!(truncate_on_char_boundary(s, 4), "😀");
        assert_eq!(truncate_on_char_boundary(s, 5), "😀");
        assert_eq!(truncate_on_char_boundary(s, 100), s);
    }

    #[test]
    fn truncate_on_char_boundary_handles_ascii_and_empty() {
        assert_eq!(truncate_on_char_boundary("", 0), "");
        assert_eq!(truncate_on_char_boundary("", 10), "");
        assert_eq!(truncate_on_char_boundary("hello", 0), "");
        assert_eq!(truncate_on_char_boundary("hello", 3), "hel");
        assert_eq!(truncate_on_char_boundary("hello", 5), "hello");
    }

    #[test]
    fn build_classification_prompt_handles_multibyte_at_limit() {
        // "é" is 2 bytes: 150 of them make 300 bytes, over the 200-byte
        // system-context budget.
        let system = "é".repeat(150);
        // "日本語" is 9 bytes: 300 repetitions exceed the 500-byte budget, and
        // 500 is not a multiple of 3, so the cut lands mid-character.
        let request = ChatRequest::new(system, vec![Message::user("日本語".repeat(300))]);

        // Must not panic when the byte limit falls inside a character.
        let prompt = build_classification_prompt(&request);

        assert!(prompt.contains("System context:"));
        assert!(prompt.ends_with("..."));
    }

    #[test]
    fn build_classification_prompt_leaves_short_text_untruncated() {
        let request = ChatRequest::new("short".to_owned(), vec![Message::user("hi".to_owned())]);

        let prompt = build_classification_prompt(&request);

        assert!(prompt.starts_with("Classify this task:\n\n"));
        assert!(prompt.contains("System context: short\n\n"));
        assert!(prompt.contains("User request: hi"));
        assert!(!prompt.contains("..."));
    }
}
365}