Skip to main content

agent_sdk_providers/
router.rs

1use std::fmt::Write;
2
3use anyhow::Result;
4
5use crate::provider::LlmProvider;
6use agent_sdk_foundation::llm::{ChatOutcome, ChatRequest, ChatResponse, Message, Role};
7
/// Provider tier a request can be dispatched to.
///
/// Each variant corresponds to one provider held by `ModelRouter`:
/// `Fast` → `fast`, `Capable` → `capable`, `Advanced` → `advanced`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelTier {
    /// Tier chosen for [`TaskComplexity::Simple`] requests.
    Fast,
    /// Tier chosen for [`TaskComplexity::Moderate`] requests.
    Capable,
    /// Tier chosen for [`TaskComplexity::Complex`] requests and for every
    /// classifier failure.
    Advanced,
}

/// Complexity label the classifier assigns to an incoming request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskComplexity {
    /// Classifier answered `SIMPLE`.
    Simple,
    /// Classifier answered `MODERATE`.
    Moderate,
    /// Classifier answered `COMPLEX`, gave no recognizable label, or failed.
    Complex,
}

impl TaskComplexity {
    /// Returns the tier a request of this complexity should be routed to.
    ///
    /// The mapping is one-to-one: `Simple` → `Fast`, `Moderate` → `Capable`,
    /// `Complex` → `Advanced`.
    #[must_use]
    pub const fn recommended_tier(self) -> ModelTier {
        match self {
            TaskComplexity::Complex => ModelTier::Advanced,
            TaskComplexity::Moderate => ModelTier::Capable,
            TaskComplexity::Simple => ModelTier::Fast,
        }
    }
}
32
/// Routes chat requests to one of three providers according to an
/// LLM-judged task complexity.
///
/// `classifier` labels each request with a [`TaskComplexity`]; the request is
/// then sent to `fast`, `capable` or `advanced` according to
/// [`TaskComplexity::recommended_tier`]. The fast and capable tiers share one
/// provider type (`S`), while the advanced tier may use a different one (`A`).
pub struct ModelRouter<C, S, A> {
    /// Provider used only to classify incoming requests.
    classifier: C,
    /// Provider that serves [`ModelTier::Fast`].
    fast: S,
    /// Provider that serves [`ModelTier::Capable`].
    capable: S,
    /// Provider that serves [`ModelTier::Advanced`].
    advanced: A,
}
39
40impl<C, S, A> ModelRouter<C, S, A>
41where
42    C: LlmProvider,
43    S: LlmProvider,
44    A: LlmProvider,
45{
46    pub const fn new(classifier: C, fast: S, capable: S, advanced: A) -> Self {
47        Self {
48            classifier,
49            fast,
50            capable,
51            advanced,
52        }
53    }
54
55    /// # Errors
56    /// Returns an error if the LLM provider fails.
57    pub async fn classify(&self, request: &ChatRequest) -> Result<TaskComplexity> {
58        let classification_prompt = build_classification_prompt(request);
59
60        let classification_request = ChatRequest {
61            system: CLASSIFICATION_SYSTEM.to_owned(),
62            messages: vec![Message::user(classification_prompt)],
63            tools: None,
64            max_tokens: 50,
65            max_tokens_explicit: true,
66            session_id: None,
67            cached_content: None,
68            thinking: None,
69            tool_choice: None,
70            response_format: None,
71        };
72
73        match self.classifier.chat(classification_request).await? {
74            ChatOutcome::Success(response) => {
75                let complexity = parse_complexity(&response);
76                log::debug!(
77                    "Model router classified request as {:?} using {}",
78                    complexity,
79                    self.classifier.model()
80                );
81                Ok(complexity)
82            }
83            ChatOutcome::RateLimited => {
84                log::warn!("Classifier rate limited, defaulting to Complex");
85                Ok(TaskComplexity::Complex)
86            }
87            ChatOutcome::InvalidRequest(e) => {
88                log::error!("Classifier invalid request: {e}, defaulting to Complex");
89                Ok(TaskComplexity::Complex)
90            }
91            ChatOutcome::ServerError(e) => {
92                log::error!("Classifier server error: {e}, defaulting to Complex");
93                Ok(TaskComplexity::Complex)
94            }
95            // `ChatOutcome` is `#[non_exhaustive]`; an unrecognized outcome
96            // takes the same conservative fallback as the error variants.
97            _ => {
98                log::error!("Classifier returned unrecognized outcome, defaulting to Complex");
99                Ok(TaskComplexity::Complex)
100            }
101        }
102    }
103
104    /// # Errors
105    /// Returns an error if the LLM provider fails.
106    pub async fn route(&self, request: ChatRequest) -> Result<ChatOutcome> {
107        let complexity = self.classify(&request).await?;
108        let tier = complexity.recommended_tier();
109
110        log::info!("Routing request to {tier:?} tier (complexity: {complexity:?})");
111
112        match tier {
113            ModelTier::Fast => self.fast.chat(request).await,
114            ModelTier::Capable => self.capable.chat(request).await,
115            ModelTier::Advanced => self.advanced.chat(request).await,
116        }
117    }
118
119    /// # Errors
120    /// Returns an error if the LLM provider fails.
121    pub async fn route_with_tier(
122        &self,
123        request: ChatRequest,
124        tier: ModelTier,
125    ) -> Result<ChatOutcome> {
126        match tier {
127            ModelTier::Fast => self.fast.chat(request).await,
128            ModelTier::Capable => self.capable.chat(request).await,
129            ModelTier::Advanced => self.advanced.chat(request).await,
130        }
131    }
132
133    #[must_use]
134    pub const fn fast_provider(&self) -> &S {
135        &self.fast
136    }
137
138    #[must_use]
139    pub const fn capable_provider(&self) -> &S {
140        &self.capable
141    }
142
143    #[must_use]
144    pub const fn advanced_provider(&self) -> &A {
145        &self.advanced
146    }
147}
148
/// System prompt for the classifier provider.
///
/// It describes the three complexity buckets and asks for a one-word answer
/// (`SIMPLE`, `MODERATE` or `COMPLEX`), which `parse_complexity` then matches.
const CLASSIFICATION_SYSTEM: &str = r#"You are a task complexity classifier. Analyze the user's request and classify it as one of: SIMPLE, MODERATE, or COMPLEX.

SIMPLE tasks:
- Basic questions with factual answers
- Simple calculations
- Direct lookups or retrievals
- Yes/no questions
- Single-step operations

MODERATE tasks:
- Multi-step reasoning
- Summarization
- Basic analysis
- Comparisons
- Standard tool usage

COMPLEX tasks:
- Creative writing or content generation
- Multi-step planning
- Complex analysis or synthesis
- Nuanced decisions
- Tasks requiring deep domain knowledge
- Financial advice or calculations
- Multi-tool orchestration

Respond with ONLY one word: SIMPLE, MODERATE, or COMPLEX."#;
175
176fn build_classification_prompt(request: &ChatRequest) -> String {
177    let mut prompt = String::new();
178
179    prompt.push_str("Classify this task:\n\n");
180
181    if !request.system.is_empty() {
182        prompt.push_str("System context: ");
183        prompt.push_str(&request.system[..request.system.len().min(200)]);
184        if request.system.len() > 200 {
185            prompt.push_str("...");
186        }
187        prompt.push_str("\n\n");
188    }
189
190    if let Some(last_user_message) = request.messages.iter().rev().find(|m| m.role == Role::User)
191        && let Some(text) = last_user_message.content.first_text()
192    {
193        prompt.push_str("User request: ");
194        prompt.push_str(&text[..text.len().min(500)]);
195        if text.len() > 500 {
196            prompt.push_str("...");
197        }
198    }
199
200    if let Some(tools) = &request.tools {
201        let _ = write!(prompt, "\n\nAvailable tools: {}", tools.len());
202    }
203
204    prompt
205}
206
207fn parse_complexity(response: &ChatResponse) -> TaskComplexity {
208    let text = response.first_text().unwrap_or("").to_uppercase();
209
210    if text.contains("SIMPLE") {
211        TaskComplexity::Simple
212    } else if text.contains("MODERATE") {
213        TaskComplexity::Moderate
214    } else {
215        TaskComplexity::Complex
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// Each complexity level maps to exactly one tier.
    #[test]
    fn complexity_to_tier() {
        let expected = [
            (TaskComplexity::Simple, ModelTier::Fast),
            (TaskComplexity::Moderate, ModelTier::Capable),
            (TaskComplexity::Complex, ModelTier::Advanced),
        ];

        for (complexity, tier) in expected {
            assert_eq!(complexity.recommended_tier(), tier);
        }
    }
}