1use std::fmt::Write;
2
3use anyhow::Result;
4use async_trait::async_trait;
5use futures::StreamExt;
6
7use crate::provider::LlmProvider;
8use crate::streaming::StreamBox;
9use agent_sdk_foundation::llm::{ChatOutcome, ChatRequest, ChatResponse, Message, Role};
10
/// Which of a [`ModelRouter`]'s three tier providers should serve a request.
///
/// Tiers are normally derived from a [`TaskComplexity`] via
/// [`TaskComplexity::recommended_tier`], but can also be chosen directly.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ModelTier {
    /// Served by the router's `fast` provider; recommended for simple tasks.
    Fast,
    /// Served by the router's `capable` provider; recommended for moderate tasks.
    Capable,
    /// Served by the router's `advanced` provider; recommended for complex tasks.
    Advanced,
}
25
/// Complexity label assigned to a request by the router's classifier.
///
/// The variants correspond to the classifier's one-word answers
/// `SIMPLE`, `MODERATE` and `COMPLEX`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TaskComplexity {
    /// The classifier answered `SIMPLE`.
    Simple,
    /// The classifier answered `MODERATE`.
    Moderate,
    /// The classifier answered `COMPLEX`, or classification did not succeed.
    Complex,
}
41
42impl TaskComplexity {
43 #[must_use]
44 pub const fn recommended_tier(self) -> ModelTier {
45 match self {
46 Self::Simple => ModelTier::Fast,
47 Self::Moderate => ModelTier::Capable,
48 Self::Complex => ModelTier::Advanced,
49 }
50 }
51}
52
/// Routes chat requests to one of three providers according to an
/// LLM-estimated task complexity.
///
/// A dedicated `classifier` provider labels each request with a
/// [`TaskComplexity`]; that label's [`TaskComplexity::recommended_tier`]
/// picks which of `fast`, `capable` or `advanced` handles the request.
pub struct ModelRouter<C, S, A> {
    /// Provider used only to classify incoming requests.
    classifier: C,
    /// Handles requests routed to [`ModelTier::Fast`].
    fast: S,
    /// Handles requests routed to [`ModelTier::Capable`].
    capable: S,
    /// Handles requests routed to [`ModelTier::Advanced`]; may be a different provider type.
    advanced: A,
}
79
80impl<C, S, A> ModelRouter<C, S, A>
81where
82 C: LlmProvider,
83 S: LlmProvider,
84 A: LlmProvider,
85{
86 pub const fn new(classifier: C, fast: S, capable: S, advanced: A) -> Self {
87 Self {
88 classifier,
89 fast,
90 capable,
91 advanced,
92 }
93 }
94
95 pub async fn classify(&self, request: &ChatRequest) -> Result<TaskComplexity> {
98 let classification_prompt = build_classification_prompt(request);
99
100 let classification_request = ChatRequest {
101 system: CLASSIFICATION_SYSTEM.to_owned(),
102 messages: vec![Message::user(classification_prompt)],
103 tools: None,
104 max_tokens: 50,
105 max_tokens_explicit: true,
106 session_id: None,
107 cached_content: None,
108 thinking: None,
109 tool_choice: None,
110 response_format: None,
111 cache: None,
112 };
113
114 match self.classifier.chat(classification_request).await? {
115 ChatOutcome::Success(response) => {
116 let complexity = parse_complexity(&response);
117 log::debug!(
118 "Model router classified request as {:?} using {}",
119 complexity,
120 self.classifier.model()
121 );
122 Ok(complexity)
123 }
124 ChatOutcome::RateLimited(_) => {
125 log::warn!("Classifier rate limited, defaulting to Complex");
126 Ok(TaskComplexity::Complex)
127 }
128 ChatOutcome::InvalidRequest(e) => {
129 log::error!("Classifier invalid request: {e}, defaulting to Complex");
130 Ok(TaskComplexity::Complex)
131 }
132 ChatOutcome::ServerError(e) => {
133 log::error!("Classifier server error: {e}, defaulting to Complex");
134 Ok(TaskComplexity::Complex)
135 }
136 _ => {
139 log::error!("Classifier returned unrecognized outcome, defaulting to Complex");
140 Ok(TaskComplexity::Complex)
141 }
142 }
143 }
144
145 pub async fn route(&self, request: ChatRequest) -> Result<ChatOutcome> {
148 let complexity = self.classify(&request).await?;
149 let tier = complexity.recommended_tier();
150
151 log::info!("Routing request to {tier:?} tier (complexity: {complexity:?})");
152
153 match tier {
154 ModelTier::Fast => self.fast.chat(request).await,
155 ModelTier::Capable => self.capable.chat(request).await,
156 ModelTier::Advanced => self.advanced.chat(request).await,
157 }
158 }
159
160 pub async fn route_with_tier(
163 &self,
164 request: ChatRequest,
165 tier: ModelTier,
166 ) -> Result<ChatOutcome> {
167 match tier {
168 ModelTier::Fast => self.fast.chat(request).await,
169 ModelTier::Capable => self.capable.chat(request).await,
170 ModelTier::Advanced => self.advanced.chat(request).await,
171 }
172 }
173
174 #[must_use]
175 pub const fn fast_provider(&self) -> &S {
176 &self.fast
177 }
178
179 #[must_use]
180 pub const fn capable_provider(&self) -> &S {
181 &self.capable
182 }
183
184 #[must_use]
185 pub const fn advanced_provider(&self) -> &A {
186 &self.advanced
187 }
188}
189
190#[async_trait]
191impl<C, S, A> LlmProvider for ModelRouter<C, S, A>
192where
193 C: LlmProvider,
194 S: LlmProvider,
195 A: LlmProvider,
196{
197 async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome> {
198 self.route(request).await
199 }
200
201 fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
202 Box::pin(async_stream::stream! {
203 let tier = match self.classify(&request).await {
204 Ok(complexity) => complexity.recommended_tier(),
205 Err(error) => {
206 yield Err(error);
207 return;
208 }
209 };
210 log::info!("Streaming request to {tier:?} tier");
211 let mut stream = match tier {
212 ModelTier::Fast => self.fast.chat_stream(request),
213 ModelTier::Capable => self.capable.chat_stream(request),
214 ModelTier::Advanced => self.advanced.chat_stream(request),
215 };
216 while let Some(item) = stream.next().await {
217 yield item;
218 }
219 })
220 }
221
222 fn model(&self) -> &str {
225 self.capable.model()
226 }
227
228 fn provider(&self) -> &'static str {
231 self.capable.provider()
232 }
233}
234
/// System prompt sent with every classification request made by
/// `ModelRouter::classify`.
///
/// It asks the classifier model to reply with exactly one of the words
/// `SIMPLE`, `MODERATE` or `COMPLEX`; `parse_complexity` turns that reply into
/// a [`TaskComplexity`].
///
/// NOTE(review): "Simple calculations" (SIMPLE) and "Financial advice or
/// calculations" (COMPLEX) overlap, so the model must disambiguate
/// calculation tasks on its own.
const CLASSIFICATION_SYSTEM: &str = r"You are a task complexity classifier. Analyze the user's request and classify it as one of: SIMPLE, MODERATE, or COMPLEX.

SIMPLE tasks:
- Basic questions with factual answers
- Simple calculations
- Direct lookups or retrievals
- Yes/no questions
- Single-step operations

MODERATE tasks:
- Multi-step reasoning
- Summarization
- Basic analysis
- Comparisons
- Standard tool usage

COMPLEX tasks:
- Creative writing or content generation
- Multi-step planning
- Complex analysis or synthesis
- Nuanced decisions
- Tasks requiring deep domain knowledge
- Financial advice or calculations
- Multi-tool orchestration

Respond with ONLY one word: SIMPLE, MODERATE, or COMPLEX.";
261
262fn build_classification_prompt(request: &ChatRequest) -> String {
263 let mut prompt = String::new();
264
265 prompt.push_str("Classify this task:\n\n");
266
267 if !request.system.is_empty() {
268 prompt.push_str("System context: ");
269 let truncated = truncate_on_char_boundary(&request.system, 200);
270 prompt.push_str(truncated);
271 if truncated.len() < request.system.len() {
272 prompt.push_str("...");
273 }
274 prompt.push_str("\n\n");
275 }
276
277 if let Some(last_user_message) = request.messages.iter().rev().find(|m| m.role == Role::User)
278 && let Some(text) = last_user_message.content.first_text()
279 {
280 prompt.push_str("User request: ");
281 let truncated = truncate_on_char_boundary(text, 500);
282 prompt.push_str(truncated);
283 if truncated.len() < text.len() {
284 prompt.push_str("...");
285 }
286 }
287
288 if let Some(tools) = &request.tools {
289 let _ = write!(prompt, "\n\nAvailable tools: {}", tools.len());
290 }
291
292 prompt
293}
294
/// Returns the longest prefix of `s` that is at most `max_bytes` bytes long
/// and ends on a UTF-8 character boundary.
///
/// If `s` already fits, it is returned whole. Otherwise the cut point moves
/// left from `max_bytes` until it lands on a character start, so a multi-byte
/// character is never split. The result may therefore be shorter than
/// `max_bytes`, and is empty when `max_bytes` is smaller than the first
/// character.
fn truncate_on_char_boundary(s: &str, max_bytes: usize) -> &str {
    if max_bytes >= s.len() {
        return s;
    }
    // Index 0 is always a char boundary, so the search cannot come up empty.
    let cut = (0..=max_bytes)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    &s[..cut]
}
308
309fn parse_complexity(response: &ChatResponse) -> TaskComplexity {
310 let text = response.first_text().unwrap_or("").to_uppercase();
311
312 if text.contains("SIMPLE") {
313 TaskComplexity::Simple
314 } else if text.contains("MODERATE") {
315 TaskComplexity::Moderate
316 } else {
317 TaskComplexity::Complex
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn complexity_to_tier() {
        let expected = [
            (TaskComplexity::Simple, ModelTier::Fast),
            (TaskComplexity::Moderate, ModelTier::Capable),
            (TaskComplexity::Complex, ModelTier::Advanced),
        ];
        for (complexity, tier) in expected {
            assert_eq!(complexity.recommended_tier(), tier);
        }
    }

    #[test]
    fn truncate_on_char_boundary_never_splits_multibyte_char() {
        // Three 4-byte emoji: the only valid cut points are 0, 4, 8 and 12.
        let emoji = "😀😀😀";
        for limit in 0..=emoji.len() {
            let kept = truncate_on_char_boundary(emoji, limit);
            // Slicing would have panicked on a non-boundary; also check prefix and bound.
            assert!(emoji.starts_with(kept));
            assert!(kept.len() <= limit);
        }
        assert_eq!(truncate_on_char_boundary(emoji, 4), "😀");
        assert_eq!(truncate_on_char_boundary(emoji, 5), "😀");
        assert_eq!(truncate_on_char_boundary(emoji, 100), emoji);
    }

    #[test]
    fn build_classification_prompt_handles_multibyte_at_limit() {
        // 300 bytes of 2-byte chars and 2700 bytes of 3-byte chars: both exceed
        // their limits, and the 500-byte user limit falls mid-character.
        let system = "é".repeat(150);
        let user_text = "日本語".repeat(300);
        let request = ChatRequest::new(system, vec![Message::user(user_text)]);

        let prompt = build_classification_prompt(&request);

        assert!(prompt.contains("System context:"));
        assert!(prompt.ends_with("..."));
    }
}
366}