agent_sdk/llm/
router.rs

1use std::fmt::Write;
2
3use anyhow::Result;
4
5use crate::llm::{ChatOutcome, ChatRequest, ChatResponse, LlmProvider, Message, Role};
6
/// Model tier a request can be dispatched to by [`ModelRouter`].
///
/// Each variant selects one of the router's tier providers.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ModelTier {
    /// Dispatches to the router's `fast` provider.
    Fast,
    /// Dispatches to the router's `capable` provider.
    Capable,
    /// Dispatches to the router's `advanced` provider.
    Advanced,
}
13
/// Complexity label assigned to a request before it is routed.
///
/// Use [`TaskComplexity::recommended_tier`] to turn a label into a
/// [`ModelTier`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TaskComplexity {
    /// Recommended tier: [`ModelTier::Fast`].
    Simple,
    /// Recommended tier: [`ModelTier::Capable`].
    Moderate,
    /// Recommended tier: [`ModelTier::Advanced`]. This is also the label used
    /// whenever classification cannot produce a usable answer.
    Complex,
}
20
21impl TaskComplexity {
22    #[must_use]
23    pub const fn recommended_tier(self) -> ModelTier {
24        match self {
25            Self::Simple => ModelTier::Fast,
26            Self::Moderate => ModelTier::Capable,
27            Self::Complex => ModelTier::Advanced,
28        }
29    }
30}
31
/// Routes chat requests to one of three tier providers, using a separate
/// classifier provider to decide which tier a request needs.
///
/// `C` is the classifier's type, `S` is shared by the `fast` and `capable`
/// providers, and `A` is the `advanced` provider's type.
pub struct ModelRouter<C, S, A> {
    /// Provider asked to label a request's complexity.
    classifier: C,
    /// Provider used for [`ModelTier::Fast`].
    fast: S,
    /// Provider used for [`ModelTier::Capable`]; same type as `fast`.
    capable: S,
    /// Provider used for [`ModelTier::Advanced`].
    advanced: A,
}
38
39impl<C, S, A> ModelRouter<C, S, A>
40where
41    C: LlmProvider,
42    S: LlmProvider,
43    A: LlmProvider,
44{
45    pub const fn new(classifier: C, fast: S, capable: S, advanced: A) -> Self {
46        Self {
47            classifier,
48            fast,
49            capable,
50            advanced,
51        }
52    }
53
54    /// # Errors
55    /// Returns an error if the LLM provider fails.
56    pub async fn classify(&self, request: &ChatRequest) -> Result<TaskComplexity> {
57        let classification_prompt = build_classification_prompt(request);
58
59        let classification_request = ChatRequest {
60            system: CLASSIFICATION_SYSTEM.to_owned(),
61            messages: vec![Message::user(classification_prompt)],
62            tools: None,
63            max_tokens: 50,
64        };
65
66        match self.classifier.chat(classification_request).await? {
67            ChatOutcome::Success(response) => {
68                let complexity = parse_complexity(&response);
69                log::debug!(
70                    "Model router classified request as {:?} using {}",
71                    complexity,
72                    self.classifier.model()
73                );
74                Ok(complexity)
75            }
76            ChatOutcome::RateLimited => {
77                log::warn!("Classifier rate limited, defaulting to Complex");
78                Ok(TaskComplexity::Complex)
79            }
80            ChatOutcome::InvalidRequest(e) => {
81                log::error!("Classifier invalid request: {e}, defaulting to Complex");
82                Ok(TaskComplexity::Complex)
83            }
84            ChatOutcome::ServerError(e) => {
85                log::error!("Classifier server error: {e}, defaulting to Complex");
86                Ok(TaskComplexity::Complex)
87            }
88        }
89    }
90
91    /// # Errors
92    /// Returns an error if the LLM provider fails.
93    pub async fn route(&self, request: ChatRequest) -> Result<ChatOutcome> {
94        let complexity = self.classify(&request).await?;
95        let tier = complexity.recommended_tier();
96
97        log::info!("Routing request to {tier:?} tier (complexity: {complexity:?})");
98
99        match tier {
100            ModelTier::Fast => self.fast.chat(request).await,
101            ModelTier::Capable => self.capable.chat(request).await,
102            ModelTier::Advanced => self.advanced.chat(request).await,
103        }
104    }
105
106    /// # Errors
107    /// Returns an error if the LLM provider fails.
108    pub async fn route_with_tier(
109        &self,
110        request: ChatRequest,
111        tier: ModelTier,
112    ) -> Result<ChatOutcome> {
113        match tier {
114            ModelTier::Fast => self.fast.chat(request).await,
115            ModelTier::Capable => self.capable.chat(request).await,
116            ModelTier::Advanced => self.advanced.chat(request).await,
117        }
118    }
119
120    #[must_use]
121    pub const fn fast_provider(&self) -> &S {
122        &self.fast
123    }
124
125    #[must_use]
126    pub const fn capable_provider(&self) -> &S {
127        &self.capable
128    }
129
130    #[must_use]
131    pub const fn advanced_provider(&self) -> &A {
132        &self.advanced
133    }
134}
135
/// System prompt used for the classification request.
///
/// It describes the three labels and tells the model to answer with exactly
/// one word: `SIMPLE`, `MODERATE` or `COMPLEX`.
const CLASSIFICATION_SYSTEM: &str = r"You are a task complexity classifier. Analyze the user's request and classify it as one of: SIMPLE, MODERATE, or COMPLEX.

SIMPLE tasks:
- Basic questions with factual answers
- Simple calculations
- Direct lookups or retrievals
- Yes/no questions
- Single-step operations

MODERATE tasks:
- Multi-step reasoning
- Summarization
- Basic analysis
- Comparisons
- Standard tool usage

COMPLEX tasks:
- Creative writing or content generation
- Multi-step planning
- Complex analysis or synthesis
- Nuanced decisions
- Tasks requiring deep domain knowledge
- Financial advice or calculations
- Multi-tool orchestration

Respond with ONLY one word: SIMPLE, MODERATE, or COMPLEX.";
162
163fn build_classification_prompt(request: &ChatRequest) -> String {
164    let mut prompt = String::new();
165
166    prompt.push_str("Classify this task:\n\n");
167
168    if !request.system.is_empty() {
169        prompt.push_str("System context: ");
170        prompt.push_str(&request.system[..request.system.len().min(200)]);
171        if request.system.len() > 200 {
172            prompt.push_str("...");
173        }
174        prompt.push_str("\n\n");
175    }
176
177    if let Some(last_user_message) = request.messages.iter().rev().find(|m| m.role == Role::User)
178        && let Some(text) = last_user_message.content.first_text()
179    {
180        prompt.push_str("User request: ");
181        prompt.push_str(&text[..text.len().min(500)]);
182        if text.len() > 500 {
183            prompt.push_str("...");
184        }
185    }
186
187    if let Some(tools) = &request.tools {
188        let _ = write!(prompt, "\n\nAvailable tools: {}", tools.len());
189    }
190
191    prompt
192}
193
194fn parse_complexity(response: &ChatResponse) -> TaskComplexity {
195    let text = response.first_text().unwrap_or("").to_uppercase();
196
197    if text.contains("SIMPLE") {
198        TaskComplexity::Simple
199    } else if text.contains("MODERATE") {
200        TaskComplexity::Moderate
201    } else {
202        TaskComplexity::Complex
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Every complexity level must map onto its matching model tier.
    #[test]
    fn complexity_to_tier() {
        let expected = [
            (TaskComplexity::Simple, ModelTier::Fast),
            (TaskComplexity::Moderate, ModelTier::Capable),
            (TaskComplexity::Complex, ModelTier::Advanced),
        ];

        for &(complexity, tier) in &expected {
            assert_eq!(complexity.recommended_tier(), tier);
        }
    }
}