Skip to main content

agent_sdk/llm/
router.rs

1use std::fmt::Write;
2
3use anyhow::Result;
4
5use crate::llm::{ChatOutcome, ChatRequest, ChatResponse, LlmProvider, Message, Role};
6
/// Model tier that a request can be dispatched to.
///
/// [`TaskComplexity::recommended_tier`] maps every complexity level onto exactly one
/// of these variants.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ModelTier {
    /// Tier recommended for [`TaskComplexity::Simple`] requests.
    Fast,
    /// Tier recommended for [`TaskComplexity::Moderate`] requests.
    Capable,
    /// Tier recommended for [`TaskComplexity::Complex`] requests.
    Advanced,
}
13
/// Complexity level assigned to a chat request before it is routed.
///
/// `Complex` doubles as the fallback value in this file: it is what the reply parser
/// yields for unrecognised classifier output and what classification defaults to
/// when the classifier call does not succeed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TaskComplexity {
    /// Classifier answered `SIMPLE`.
    Simple,
    /// Classifier answered `MODERATE`.
    Moderate,
    /// Classifier answered `COMPLEX`, or no other label could be recognised.
    Complex,
}
20
21impl TaskComplexity {
22    #[must_use]
23    pub const fn recommended_tier(self) -> ModelTier {
24        match self {
25            Self::Simple => ModelTier::Fast,
26            Self::Moderate => ModelTier::Capable,
27            Self::Complex => ModelTier::Advanced,
28        }
29    }
30}
31
/// Holds a classifier provider plus one provider for each [`ModelTier`].
///
/// Type parameters:
/// - `C`: provider used to classify requests.
/// - `S`: provider type shared by the fast and capable tiers (both fields use the
///   same type).
/// - `A`: provider type for the advanced tier.
pub struct ModelRouter<C, S, A> {
    /// Provider asked to classify a request's complexity.
    classifier: C,
    /// Provider for [`ModelTier::Fast`].
    fast: S,
    /// Provider for [`ModelTier::Capable`].
    capable: S,
    /// Provider for [`ModelTier::Advanced`].
    advanced: A,
}
38
39impl<C, S, A> ModelRouter<C, S, A>
40where
41    C: LlmProvider,
42    S: LlmProvider,
43    A: LlmProvider,
44{
45    pub const fn new(classifier: C, fast: S, capable: S, advanced: A) -> Self {
46        Self {
47            classifier,
48            fast,
49            capable,
50            advanced,
51        }
52    }
53
54    /// # Errors
55    /// Returns an error if the LLM provider fails.
56    pub async fn classify(&self, request: &ChatRequest) -> Result<TaskComplexity> {
57        let classification_prompt = build_classification_prompt(request);
58
59        let classification_request = ChatRequest {
60            system: CLASSIFICATION_SYSTEM.to_owned(),
61            messages: vec![Message::user(classification_prompt)],
62            tools: None,
63            max_tokens: 50,
64            thinking: None,
65        };
66
67        match self.classifier.chat(classification_request).await? {
68            ChatOutcome::Success(response) => {
69                let complexity = parse_complexity(&response);
70                log::debug!(
71                    "Model router classified request as {:?} using {}",
72                    complexity,
73                    self.classifier.model()
74                );
75                Ok(complexity)
76            }
77            ChatOutcome::RateLimited => {
78                log::warn!("Classifier rate limited, defaulting to Complex");
79                Ok(TaskComplexity::Complex)
80            }
81            ChatOutcome::InvalidRequest(e) => {
82                log::error!("Classifier invalid request: {e}, defaulting to Complex");
83                Ok(TaskComplexity::Complex)
84            }
85            ChatOutcome::ServerError(e) => {
86                log::error!("Classifier server error: {e}, defaulting to Complex");
87                Ok(TaskComplexity::Complex)
88            }
89        }
90    }
91
92    /// # Errors
93    /// Returns an error if the LLM provider fails.
94    pub async fn route(&self, request: ChatRequest) -> Result<ChatOutcome> {
95        let complexity = self.classify(&request).await?;
96        let tier = complexity.recommended_tier();
97
98        log::info!("Routing request to {tier:?} tier (complexity: {complexity:?})");
99
100        match tier {
101            ModelTier::Fast => self.fast.chat(request).await,
102            ModelTier::Capable => self.capable.chat(request).await,
103            ModelTier::Advanced => self.advanced.chat(request).await,
104        }
105    }
106
107    /// # Errors
108    /// Returns an error if the LLM provider fails.
109    pub async fn route_with_tier(
110        &self,
111        request: ChatRequest,
112        tier: ModelTier,
113    ) -> Result<ChatOutcome> {
114        match tier {
115            ModelTier::Fast => self.fast.chat(request).await,
116            ModelTier::Capable => self.capable.chat(request).await,
117            ModelTier::Advanced => self.advanced.chat(request).await,
118        }
119    }
120
121    #[must_use]
122    pub const fn fast_provider(&self) -> &S {
123        &self.fast
124    }
125
126    #[must_use]
127    pub const fn capable_provider(&self) -> &S {
128        &self.capable
129    }
130
131    #[must_use]
132    pub const fn advanced_provider(&self) -> &A {
133        &self.advanced
134    }
135}
136
/// System prompt sent with every classification request.
///
/// It instructs the model to answer with a single word — `SIMPLE`, `MODERATE`, or
/// `COMPLEX` — which are the labels the reply parser in this file looks for.
const CLASSIFICATION_SYSTEM: &str = r"You are a task complexity classifier. Analyze the user's request and classify it as one of: SIMPLE, MODERATE, or COMPLEX.

SIMPLE tasks:
- Basic questions with factual answers
- Simple calculations
- Direct lookups or retrievals
- Yes/no questions
- Single-step operations

MODERATE tasks:
- Multi-step reasoning
- Summarization
- Basic analysis
- Comparisons
- Standard tool usage

COMPLEX tasks:
- Creative writing or content generation
- Multi-step planning
- Complex analysis or synthesis
- Nuanced decisions
- Tasks requiring deep domain knowledge
- Financial advice or calculations
- Multi-tool orchestration

Respond with ONLY one word: SIMPLE, MODERATE, or COMPLEX.";
163
164fn build_classification_prompt(request: &ChatRequest) -> String {
165    let mut prompt = String::new();
166
167    prompt.push_str("Classify this task:\n\n");
168
169    if !request.system.is_empty() {
170        prompt.push_str("System context: ");
171        prompt.push_str(&request.system[..request.system.len().min(200)]);
172        if request.system.len() > 200 {
173            prompt.push_str("...");
174        }
175        prompt.push_str("\n\n");
176    }
177
178    if let Some(last_user_message) = request.messages.iter().rev().find(|m| m.role == Role::User)
179        && let Some(text) = last_user_message.content.first_text()
180    {
181        prompt.push_str("User request: ");
182        prompt.push_str(&text[..text.len().min(500)]);
183        if text.len() > 500 {
184            prompt.push_str("...");
185        }
186    }
187
188    if let Some(tools) = &request.tools {
189        let _ = write!(prompt, "\n\nAvailable tools: {}", tools.len());
190    }
191
192    prompt
193}
194
195fn parse_complexity(response: &ChatResponse) -> TaskComplexity {
196    let text = response.first_text().unwrap_or("").to_uppercase();
197
198    if text.contains("SIMPLE") {
199        TaskComplexity::Simple
200    } else if text.contains("MODERATE") {
201        TaskComplexity::Moderate
202    } else {
203        TaskComplexity::Complex
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Each complexity level must map onto its matching model tier.
    #[test]
    fn complexity_to_tier() {
        let expected = [
            (TaskComplexity::Simple, ModelTier::Fast),
            (TaskComplexity::Moderate, ModelTier::Capable),
            (TaskComplexity::Complex, ModelTier::Advanced),
        ];

        for (complexity, tier) in expected {
            assert_eq!(complexity.recommended_tier(), tier);
        }
    }
}