agent_sdk_providers/
router.rs1use std::fmt::Write;
2
3use anyhow::Result;
4
5use crate::provider::LlmProvider;
6use agent_sdk_foundation::llm::{ChatOutcome, ChatRequest, ChatResponse, Message, Role};
7
/// Model tier a request can be dispatched to.
///
/// Each variant selects one of the providers held by [`ModelRouter`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelTier {
    /// Handled by the router's `fast` provider; recommended for
    /// [`TaskComplexity::Simple`] requests.
    Fast,
    /// Handled by the router's `capable` provider; recommended for
    /// [`TaskComplexity::Moderate`] requests.
    Capable,
    /// Handled by the router's `advanced` provider; recommended for
    /// [`TaskComplexity::Complex`] requests.
    Advanced,
}
14
/// Complexity estimate for a chat request, as reported by the classifier model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskComplexity {
    /// The classifier answered `SIMPLE`.
    Simple,
    /// The classifier answered `MODERATE`.
    Moderate,
    /// The classifier answered `COMPLEX`. Also used as the fallback whenever the
    /// classifier's reply cannot be used (no recognizable label, rate limiting,
    /// invalid request, server error).
    Complex,
}
21
22impl TaskComplexity {
23 #[must_use]
24 pub const fn recommended_tier(self) -> ModelTier {
25 match self {
26 Self::Simple => ModelTier::Fast,
27 Self::Moderate => ModelTier::Capable,
28 Self::Complex => ModelTier::Advanced,
29 }
30 }
31}
32
/// Routes chat requests to one of three providers based on a complexity estimate
/// obtained from a separate classifier provider.
///
/// `C` is the classifier's provider type, `S` is shared by the fast and capable
/// tiers, and `A` is the advanced tier's provider type.
pub struct ModelRouter<C, S, A> {
    /// Provider asked to label each request as simple, moderate or complex.
    classifier: C,
    /// Provider used for [`ModelTier::Fast`].
    fast: S,
    /// Provider used for [`ModelTier::Capable`].
    capable: S,
    /// Provider used for [`ModelTier::Advanced`].
    advanced: A,
}
39
40impl<C, S, A> ModelRouter<C, S, A>
41where
42 C: LlmProvider,
43 S: LlmProvider,
44 A: LlmProvider,
45{
46 pub const fn new(classifier: C, fast: S, capable: S, advanced: A) -> Self {
47 Self {
48 classifier,
49 fast,
50 capable,
51 advanced,
52 }
53 }
54
55 pub async fn classify(&self, request: &ChatRequest) -> Result<TaskComplexity> {
58 let classification_prompt = build_classification_prompt(request);
59
60 let classification_request = ChatRequest {
61 system: CLASSIFICATION_SYSTEM.to_owned(),
62 messages: vec![Message::user(classification_prompt)],
63 tools: None,
64 max_tokens: 50,
65 max_tokens_explicit: true,
66 session_id: None,
67 cached_content: None,
68 thinking: None,
69 tool_choice: None,
70 response_format: None,
71 };
72
73 match self.classifier.chat(classification_request).await? {
74 ChatOutcome::Success(response) => {
75 let complexity = parse_complexity(&response);
76 log::debug!(
77 "Model router classified request as {:?} using {}",
78 complexity,
79 self.classifier.model()
80 );
81 Ok(complexity)
82 }
83 ChatOutcome::RateLimited => {
84 log::warn!("Classifier rate limited, defaulting to Complex");
85 Ok(TaskComplexity::Complex)
86 }
87 ChatOutcome::InvalidRequest(e) => {
88 log::error!("Classifier invalid request: {e}, defaulting to Complex");
89 Ok(TaskComplexity::Complex)
90 }
91 ChatOutcome::ServerError(e) => {
92 log::error!("Classifier server error: {e}, defaulting to Complex");
93 Ok(TaskComplexity::Complex)
94 }
95 _ => {
98 log::error!("Classifier returned unrecognized outcome, defaulting to Complex");
99 Ok(TaskComplexity::Complex)
100 }
101 }
102 }
103
104 pub async fn route(&self, request: ChatRequest) -> Result<ChatOutcome> {
107 let complexity = self.classify(&request).await?;
108 let tier = complexity.recommended_tier();
109
110 log::info!("Routing request to {tier:?} tier (complexity: {complexity:?})");
111
112 match tier {
113 ModelTier::Fast => self.fast.chat(request).await,
114 ModelTier::Capable => self.capable.chat(request).await,
115 ModelTier::Advanced => self.advanced.chat(request).await,
116 }
117 }
118
119 pub async fn route_with_tier(
122 &self,
123 request: ChatRequest,
124 tier: ModelTier,
125 ) -> Result<ChatOutcome> {
126 match tier {
127 ModelTier::Fast => self.fast.chat(request).await,
128 ModelTier::Capable => self.capable.chat(request).await,
129 ModelTier::Advanced => self.advanced.chat(request).await,
130 }
131 }
132
133 #[must_use]
134 pub const fn fast_provider(&self) -> &S {
135 &self.fast
136 }
137
138 #[must_use]
139 pub const fn capable_provider(&self) -> &S {
140 &self.capable
141 }
142
143 #[must_use]
144 pub const fn advanced_provider(&self) -> &A {
145 &self.advanced
146 }
147}
148
/// System prompt sent to the classifier model.
///
/// It describes the three categories and asks for a single-word answer
/// (`SIMPLE`, `MODERATE` or `COMPLEX`), which `parse_complexity` then looks for
/// in the reply.
const CLASSIFICATION_SYSTEM: &str = "You are a task complexity classifier. Analyze the user's request and classify it as one of: SIMPLE, MODERATE, or COMPLEX.

SIMPLE tasks:
- Basic questions with factual answers
- Simple calculations
- Direct lookups or retrievals
- Yes/no questions
- Single-step operations

MODERATE tasks:
- Multi-step reasoning
- Summarization
- Basic analysis
- Comparisons
- Standard tool usage

COMPLEX tasks:
- Creative writing or content generation
- Multi-step planning
- Complex analysis or synthesis
- Nuanced decisions
- Tasks requiring deep domain knowledge
- Financial advice or calculations
- Multi-tool orchestration

Respond with ONLY one word: SIMPLE, MODERATE, or COMPLEX.";
175
176fn build_classification_prompt(request: &ChatRequest) -> String {
177 let mut prompt = String::new();
178
179 prompt.push_str("Classify this task:\n\n");
180
181 if !request.system.is_empty() {
182 prompt.push_str("System context: ");
183 prompt.push_str(&request.system[..request.system.len().min(200)]);
184 if request.system.len() > 200 {
185 prompt.push_str("...");
186 }
187 prompt.push_str("\n\n");
188 }
189
190 if let Some(last_user_message) = request.messages.iter().rev().find(|m| m.role == Role::User)
191 && let Some(text) = last_user_message.content.first_text()
192 {
193 prompt.push_str("User request: ");
194 prompt.push_str(&text[..text.len().min(500)]);
195 if text.len() > 500 {
196 prompt.push_str("...");
197 }
198 }
199
200 if let Some(tools) = &request.tools {
201 let _ = write!(prompt, "\n\nAvailable tools: {}", tools.len());
202 }
203
204 prompt
205}
206
207fn parse_complexity(response: &ChatResponse) -> TaskComplexity {
208 let text = response.first_text().unwrap_or("").to_uppercase();
209
210 if text.contains("SIMPLE") {
211 TaskComplexity::Simple
212 } else if text.contains("MODERATE") {
213 TaskComplexity::Moderate
214 } else {
215 TaskComplexity::Complex
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// Every complexity level maps to its matching model tier.
    #[test]
    fn complexity_to_tier() {
        let expectations = [
            (TaskComplexity::Simple, ModelTier::Fast),
            (TaskComplexity::Moderate, ModelTier::Capable),
            (TaskComplexity::Complex, ModelTier::Advanced),
        ];

        for (complexity, expected_tier) in expectations {
            assert_eq!(complexity.recommended_tier(), expected_tier);
        }
    }
}