// agent_sdk/llm/router.rs

1use std::fmt::Write;
2
3use anyhow::Result;
4
5use crate::llm::{ChatOutcome, ChatRequest, ChatResponse, LlmProvider, Message, Role};
6
/// The tiers a [`ModelRouter`] can dispatch a request to.
///
/// Each variant selects one of the router's providers in
/// [`ModelRouter::route_with_tier`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ModelTier {
    /// Served by the router's `fast` provider; recommended for
    /// [`TaskComplexity::Simple`] tasks.
    Fast,
    /// Served by the router's `capable` provider; recommended for
    /// [`TaskComplexity::Moderate`] tasks.
    Capable,
    /// Served by the router's `advanced` provider; recommended for
    /// [`TaskComplexity::Complex`] tasks.
    Advanced,
}
13
/// How demanding a request is, as judged by the router's classifier.
///
/// Converted to a [`ModelTier`] via [`TaskComplexity::recommended_tier`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TaskComplexity {
    /// Classifier label `SIMPLE`.
    Simple,
    /// Classifier label `MODERATE`.
    Moderate,
    /// Classifier label `COMPLEX`; also the fallback when classification
    /// fails or yields no recognizable label.
    Complex,
}
20
21impl TaskComplexity {
22    #[must_use]
23    pub const fn recommended_tier(self) -> ModelTier {
24        match self {
25            Self::Simple => ModelTier::Fast,
26            Self::Moderate => ModelTier::Capable,
27            Self::Complex => ModelTier::Advanced,
28        }
29    }
30}
31
/// Routes chat requests to one of three providers based on a complexity
/// label produced by a separate classifier provider.
///
/// `fast` and `capable` share one provider type `S`; `advanced` may use a
/// different type `A`. The derived `Debug` impl is only available when the
/// classifier and all provider types implement `Debug`.
#[derive(Debug)]
pub struct ModelRouter<C, S, A> {
    /// Provider queried to label each request's complexity.
    classifier: C,
    /// Provider serving [`ModelTier::Fast`].
    fast: S,
    /// Provider serving [`ModelTier::Capable`].
    capable: S,
    /// Provider serving [`ModelTier::Advanced`].
    advanced: A,
}
38
39impl<C, S, A> ModelRouter<C, S, A>
40where
41    C: LlmProvider,
42    S: LlmProvider,
43    A: LlmProvider,
44{
45    pub const fn new(classifier: C, fast: S, capable: S, advanced: A) -> Self {
46        Self {
47            classifier,
48            fast,
49            capable,
50            advanced,
51        }
52    }
53
54    /// # Errors
55    /// Returns an error if the LLM provider fails.
56    pub async fn classify(&self, request: &ChatRequest) -> Result<TaskComplexity> {
57        let classification_prompt = build_classification_prompt(request);
58
59        let classification_request = ChatRequest {
60            system: CLASSIFICATION_SYSTEM.to_owned(),
61            messages: vec![Message::user(classification_prompt)],
62            tools: None,
63            max_tokens: 50,
64            max_tokens_explicit: true,
65            session_id: None,
66            cached_content: None,
67            thinking: None,
68        };
69
70        match self.classifier.chat(classification_request).await? {
71            ChatOutcome::Success(response) => {
72                let complexity = parse_complexity(&response);
73                log::debug!(
74                    "Model router classified request as {:?} using {}",
75                    complexity,
76                    self.classifier.model()
77                );
78                Ok(complexity)
79            }
80            ChatOutcome::RateLimited => {
81                log::warn!("Classifier rate limited, defaulting to Complex");
82                Ok(TaskComplexity::Complex)
83            }
84            ChatOutcome::InvalidRequest(e) => {
85                log::error!("Classifier invalid request: {e}, defaulting to Complex");
86                Ok(TaskComplexity::Complex)
87            }
88            ChatOutcome::ServerError(e) => {
89                log::error!("Classifier server error: {e}, defaulting to Complex");
90                Ok(TaskComplexity::Complex)
91            }
92        }
93    }
94
95    /// # Errors
96    /// Returns an error if the LLM provider fails.
97    pub async fn route(&self, request: ChatRequest) -> Result<ChatOutcome> {
98        let complexity = self.classify(&request).await?;
99        let tier = complexity.recommended_tier();
100
101        log::info!("Routing request to {tier:?} tier (complexity: {complexity:?})");
102
103        match tier {
104            ModelTier::Fast => self.fast.chat(request).await,
105            ModelTier::Capable => self.capable.chat(request).await,
106            ModelTier::Advanced => self.advanced.chat(request).await,
107        }
108    }
109
110    /// # Errors
111    /// Returns an error if the LLM provider fails.
112    pub async fn route_with_tier(
113        &self,
114        request: ChatRequest,
115        tier: ModelTier,
116    ) -> Result<ChatOutcome> {
117        match tier {
118            ModelTier::Fast => self.fast.chat(request).await,
119            ModelTier::Capable => self.capable.chat(request).await,
120            ModelTier::Advanced => self.advanced.chat(request).await,
121        }
122    }
123
124    #[must_use]
125    pub const fn fast_provider(&self) -> &S {
126        &self.fast
127    }
128
129    #[must_use]
130    pub const fn capable_provider(&self) -> &S {
131        &self.capable
132    }
133
134    #[must_use]
135    pub const fn advanced_provider(&self) -> &A {
136        &self.advanced
137    }
138}
139
/// System prompt sent to the complexity classifier.
///
/// Gives example task kinds for each label and instructs the model to reply
/// with exactly one word: `SIMPLE`, `MODERATE` or `COMPLEX`. The reply is
/// matched against those exact upper-case labels by `parse_complexity`, so
/// keep the label spellings here in sync with that parser.
const CLASSIFICATION_SYSTEM: &str = r"You are a task complexity classifier. Analyze the user's request and classify it as one of: SIMPLE, MODERATE, or COMPLEX.

SIMPLE tasks:
- Basic questions with factual answers
- Simple calculations
- Direct lookups or retrievals
- Yes/no questions
- Single-step operations

MODERATE tasks:
- Multi-step reasoning
- Summarization
- Basic analysis
- Comparisons
- Standard tool usage

COMPLEX tasks:
- Creative writing or content generation
- Multi-step planning
- Complex analysis or synthesis
- Nuanced decisions
- Tasks requiring deep domain knowledge
- Financial advice or calculations
- Multi-tool orchestration

Respond with ONLY one word: SIMPLE, MODERATE, or COMPLEX.";
166
167fn build_classification_prompt(request: &ChatRequest) -> String {
168    let mut prompt = String::new();
169
170    prompt.push_str("Classify this task:\n\n");
171
172    if !request.system.is_empty() {
173        prompt.push_str("System context: ");
174        prompt.push_str(&request.system[..request.system.len().min(200)]);
175        if request.system.len() > 200 {
176            prompt.push_str("...");
177        }
178        prompt.push_str("\n\n");
179    }
180
181    if let Some(last_user_message) = request.messages.iter().rev().find(|m| m.role == Role::User)
182        && let Some(text) = last_user_message.content.first_text()
183    {
184        prompt.push_str("User request: ");
185        prompt.push_str(&text[..text.len().min(500)]);
186        if text.len() > 500 {
187            prompt.push_str("...");
188        }
189    }
190
191    if let Some(tools) = &request.tools {
192        let _ = write!(prompt, "\n\nAvailable tools: {}", tools.len());
193    }
194
195    prompt
196}
197
198fn parse_complexity(response: &ChatResponse) -> TaskComplexity {
199    let text = response.first_text().unwrap_or("").to_uppercase();
200
201    if text.contains("SIMPLE") {
202        TaskComplexity::Simple
203    } else if text.contains("MODERATE") {
204        TaskComplexity::Moderate
205    } else {
206        TaskComplexity::Complex
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    /// Every complexity level must map to its dedicated model tier.
    #[test]
    fn complexity_to_tier() {
        let cases = [
            (TaskComplexity::Simple, ModelTier::Fast),
            (TaskComplexity::Moderate, ModelTier::Capable),
            (TaskComplexity::Complex, ModelTier::Advanced),
        ];

        for (complexity, expected_tier) in cases {
            assert_eq!(
                complexity.recommended_tier(),
                expected_tier,
                "unexpected tier for {complexity:?}"
            );
        }
    }
}