Skip to main content

agent_sdk_providers/
router.rs

1use std::fmt::Write;
2
3use anyhow::Result;
4use async_trait::async_trait;
5use futures::StreamExt;
6
7use crate::provider::LlmProvider;
8use crate::streaming::StreamBox;
9use agent_sdk_foundation::llm::{ChatOutcome, ChatRequest, ChatResponse, Message, Role};
10
/// A capability/cost tier a request can be routed to.
///
/// Tiers are declared cheapest-and-fastest ([`Fast`](Self::Fast)) to
/// most-capable-and-costly ([`Advanced`](Self::Advanced)). The derived [`Ord`]
/// follows that declaration order, so `ModelTier::Fast < ModelTier::Capable <
/// ModelTier::Advanced` and tiers can be compared or maxed directly.
/// [`TaskComplexity::recommended_tier`] maps a classified complexity onto one of
/// these.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ModelTier {
    /// Cheapest, fastest tier for trivial single-step work.
    Fast,
    /// Mid tier for multi-step reasoning, summarization, standard tool use.
    Capable,
    /// Most capable tier for creative, multi-step, or domain-heavy work.
    Advanced,
}
25
/// The complexity a classifier assigns to an incoming request.
///
/// Produced by [`ModelRouter::classify`]. Variants are declared least to most
/// demanding, and the derived [`Ord`] follows that order
/// (`Simple < Moderate < Complex`).
///
/// When the classifier answers with a failed outcome (rate limited, invalid
/// request, server error, or an outcome the router does not recognize), or with a
/// reply that names no known label, the router falls back to the conservative
/// [`TaskComplexity::Complex`]. This way a failed classification never silently
/// downgrades a hard request to a weak model. An `Err` returned by the classifier
/// call itself is propagated to the caller rather than defaulted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum TaskComplexity {
    /// Basic factual questions, lookups, single-step operations.
    Simple,
    /// Multi-step reasoning, summarization, standard tool usage.
    Moderate,
    /// Creative generation, planning, synthesis, deep domain knowledge.
    Complex,
}
41
42impl TaskComplexity {
43    #[must_use]
44    pub const fn recommended_tier(self) -> ModelTier {
45        match self {
46            Self::Simple => ModelTier::Fast,
47            Self::Moderate => ModelTier::Capable,
48            Self::Complex => ModelTier::Advanced,
49        }
50    }
51}
52
/// Dispatches every request to one of three model tiers, chosen per request by an
/// LLM classifier.
///
/// Each call costs one extra round-trip to `classifier`. The classifier labels the
/// work [`Simple`](TaskComplexity::Simple), [`Moderate`](TaskComplexity::Moderate),
/// or [`Complex`](TaskComplexity::Complex), and the request then goes to `fast`,
/// `capable`, or `advanced` respectively.
///
/// The classifier may answer with a failed outcome: rate limited, invalid request,
/// server error, or an outcome the router does not recognize. In that case the
/// router assumes `Complex` rather than risk a weaker model. An `Err` from the
/// classifier call itself is propagated to the caller (see
/// [`classify`](Self::classify)).
///
/// Because `ModelRouter` is itself an [`LlmProvider`], it can be used wherever a
/// `&dyn LlmProvider` is accepted (`run_structured`, `RefreshingProvider`, and so
/// on). Its [`chat`](LlmProvider::chat) classifies and then routes. Its
/// [`chat_stream`](LlmProvider::chat_stream) finishes classification before it
/// starts streaming from the selected tier.
///
/// Limitation: `fast` and `capable` are both of type `S`, and only `advanced` has
/// its own type `A`. Pairing, say, a Gemini fast tier with an `OpenAI` capable
/// tier therefore needs both behind one concrete type. Using
/// `Arc<dyn LlmProvider>` for all three tiers lets any providers be mixed.
pub struct ModelRouter<C, S, A> {
    /// Provider asked to label each request's complexity.
    classifier: C,
    /// Serves requests classified as [`TaskComplexity::Simple`].
    fast: S,
    /// Serves requests classified as [`TaskComplexity::Moderate`].
    capable: S,
    /// Serves requests classified as [`TaskComplexity::Complex`], including
    /// classifier fallbacks.
    advanced: A,
}
79
80impl<C, S, A> ModelRouter<C, S, A>
81where
82    C: LlmProvider,
83    S: LlmProvider,
84    A: LlmProvider,
85{
86    pub const fn new(classifier: C, fast: S, capable: S, advanced: A) -> Self {
87        Self {
88            classifier,
89            fast,
90            capable,
91            advanced,
92        }
93    }
94
95    /// # Errors
96    /// Returns an error if the LLM provider fails.
97    pub async fn classify(&self, request: &ChatRequest) -> Result<TaskComplexity> {
98        let classification_prompt = build_classification_prompt(request);
99
100        let classification_request = ChatRequest {
101            system: CLASSIFICATION_SYSTEM.to_owned(),
102            messages: vec![Message::user(classification_prompt)],
103            tools: None,
104            max_tokens: 50,
105            max_tokens_explicit: true,
106            session_id: None,
107            cached_content: None,
108            thinking: None,
109            tool_choice: None,
110            response_format: None,
111        };
112
113        match self.classifier.chat(classification_request).await? {
114            ChatOutcome::Success(response) => {
115                let complexity = parse_complexity(&response);
116                log::debug!(
117                    "Model router classified request as {:?} using {}",
118                    complexity,
119                    self.classifier.model()
120                );
121                Ok(complexity)
122            }
123            ChatOutcome::RateLimited => {
124                log::warn!("Classifier rate limited, defaulting to Complex");
125                Ok(TaskComplexity::Complex)
126            }
127            ChatOutcome::InvalidRequest(e) => {
128                log::error!("Classifier invalid request: {e}, defaulting to Complex");
129                Ok(TaskComplexity::Complex)
130            }
131            ChatOutcome::ServerError(e) => {
132                log::error!("Classifier server error: {e}, defaulting to Complex");
133                Ok(TaskComplexity::Complex)
134            }
135            // `ChatOutcome` is `#[non_exhaustive]`; an unrecognized outcome
136            // takes the same conservative fallback as the error variants.
137            _ => {
138                log::error!("Classifier returned unrecognized outcome, defaulting to Complex");
139                Ok(TaskComplexity::Complex)
140            }
141        }
142    }
143
144    /// # Errors
145    /// Returns an error if the LLM provider fails.
146    pub async fn route(&self, request: ChatRequest) -> Result<ChatOutcome> {
147        let complexity = self.classify(&request).await?;
148        let tier = complexity.recommended_tier();
149
150        log::info!("Routing request to {tier:?} tier (complexity: {complexity:?})");
151
152        match tier {
153            ModelTier::Fast => self.fast.chat(request).await,
154            ModelTier::Capable => self.capable.chat(request).await,
155            ModelTier::Advanced => self.advanced.chat(request).await,
156        }
157    }
158
159    /// # Errors
160    /// Returns an error if the LLM provider fails.
161    pub async fn route_with_tier(
162        &self,
163        request: ChatRequest,
164        tier: ModelTier,
165    ) -> Result<ChatOutcome> {
166        match tier {
167            ModelTier::Fast => self.fast.chat(request).await,
168            ModelTier::Capable => self.capable.chat(request).await,
169            ModelTier::Advanced => self.advanced.chat(request).await,
170        }
171    }
172
173    #[must_use]
174    pub const fn fast_provider(&self) -> &S {
175        &self.fast
176    }
177
178    #[must_use]
179    pub const fn capable_provider(&self) -> &S {
180        &self.capable
181    }
182
183    #[must_use]
184    pub const fn advanced_provider(&self) -> &A {
185        &self.advanced
186    }
187}
188
189#[async_trait]
190impl<C, S, A> LlmProvider for ModelRouter<C, S, A>
191where
192    C: LlmProvider,
193    S: LlmProvider,
194    A: LlmProvider,
195{
196    async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome> {
197        self.route(request).await
198    }
199
200    fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
201        Box::pin(async_stream::stream! {
202            let tier = match self.classify(&request).await {
203                Ok(complexity) => complexity.recommended_tier(),
204                Err(error) => {
205                    yield Err(error);
206                    return;
207                }
208            };
209            log::info!("Streaming request to {tier:?} tier");
210            let mut stream = match tier {
211                ModelTier::Fast => self.fast.chat_stream(request),
212                ModelTier::Capable => self.capable.chat_stream(request),
213                ModelTier::Advanced => self.advanced.chat_stream(request),
214            };
215            while let Some(item) = stream.next().await {
216                yield item;
217            }
218        })
219    }
220
221    /// Reports the `capable` (mid) tier's model as the router's representative
222    /// model identifier.
223    fn model(&self) -> &str {
224        self.capable.model()
225    }
226
227    /// Reports the `capable` (mid) tier's provider as the router's representative
228    /// provider identifier.
229    fn provider(&self) -> &'static str {
230        self.capable.provider()
231    }
232}
233
/// System prompt for the classifier call made by `ModelRouter::classify`.
///
/// It lists example task kinds for each label and asks for a single-word reply:
/// `SIMPLE`, `MODERATE`, or `COMPLEX`.
const CLASSIFICATION_SYSTEM: &str = concat!(
    "You are a task complexity classifier. Analyze the user's request and classify it as one of: SIMPLE, MODERATE, or COMPLEX.\n",
    "\n",
    "SIMPLE tasks:\n",
    "- Basic questions with factual answers\n",
    "- Simple calculations\n",
    "- Direct lookups or retrievals\n",
    "- Yes/no questions\n",
    "- Single-step operations\n",
    "\n",
    "MODERATE tasks:\n",
    "- Multi-step reasoning\n",
    "- Summarization\n",
    "- Basic analysis\n",
    "- Comparisons\n",
    "- Standard tool usage\n",
    "\n",
    "COMPLEX tasks:\n",
    "- Creative writing or content generation\n",
    "- Multi-step planning\n",
    "- Complex analysis or synthesis\n",
    "- Nuanced decisions\n",
    "- Tasks requiring deep domain knowledge\n",
    "- Financial advice or calculations\n",
    "- Multi-tool orchestration\n",
    "\n",
    "Respond with ONLY one word: SIMPLE, MODERATE, or COMPLEX.",
);
260
261fn build_classification_prompt(request: &ChatRequest) -> String {
262    let mut prompt = String::new();
263
264    prompt.push_str("Classify this task:\n\n");
265
266    if !request.system.is_empty() {
267        prompt.push_str("System context: ");
268        let truncated = truncate_on_char_boundary(&request.system, 200);
269        prompt.push_str(truncated);
270        if truncated.len() < request.system.len() {
271            prompt.push_str("...");
272        }
273        prompt.push_str("\n\n");
274    }
275
276    if let Some(last_user_message) = request.messages.iter().rev().find(|m| m.role == Role::User)
277        && let Some(text) = last_user_message.content.first_text()
278    {
279        prompt.push_str("User request: ");
280        let truncated = truncate_on_char_boundary(text, 500);
281        prompt.push_str(truncated);
282        if truncated.len() < text.len() {
283            prompt.push_str("...");
284        }
285    }
286
287    if let Some(tools) = &request.tools {
288        let _ = write!(prompt, "\n\nAvailable tools: {}", tools.len());
289    }
290
291    prompt
292}
293
/// Return the longest prefix of `s` that is at most `max_bytes` long and ends on
/// a UTF-8 character boundary.
///
/// Slicing a `&str` at a byte index inside a multi-byte character (emoji, CJK,
/// accented text) panics, so the cut point is moved back to the nearest boundary
/// at or below `max_bytes`. If `s` already fits, it is returned whole.
fn truncate_on_char_boundary(s: &str, max_bytes: usize) -> &str {
    if max_bytes >= s.len() {
        return s;
    }
    // Index 0 is always a char boundary, so the search cannot come up empty.
    let cut = (0..=max_bytes)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    &s[..cut]
}
307
308fn parse_complexity(response: &ChatResponse) -> TaskComplexity {
309    let text = response.first_text().unwrap_or("").to_uppercase();
310
311    if text.contains("SIMPLE") {
312        TaskComplexity::Simple
313    } else if text.contains("MODERATE") {
314        TaskComplexity::Moderate
315    } else {
316        TaskComplexity::Complex
317    }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn complexity_to_tier() {
        assert_eq!(TaskComplexity::Simple.recommended_tier(), ModelTier::Fast);
        assert_eq!(
            TaskComplexity::Moderate.recommended_tier(),
            ModelTier::Capable
        );
        assert_eq!(
            TaskComplexity::Complex.recommended_tier(),
            ModelTier::Advanced
        );
    }

    #[test]
    fn truncate_on_char_boundary_never_splits_multibyte_char() {
        // "😀" is a 4-byte character. Truncating at byte 1, 2, or 3 would land
        // inside it and panic with naive `&s[..n]`; the helper must back off to a
        // valid boundary instead.
        let s = "😀😀😀";
        for max in 0..=s.len() {
            let truncated = truncate_on_char_boundary(s, max);
            // Must be a valid prefix of the original (never panics, always UTF-8).
            assert!(s.starts_with(truncated));
            assert!(truncated.len() <= max);
        }
        assert_eq!(truncate_on_char_boundary(s, 4), "😀");
        assert_eq!(truncate_on_char_boundary(s, 5), "😀");
        assert_eq!(truncate_on_char_boundary(s, 100), s);
    }

    #[test]
    fn build_classification_prompt_handles_multibyte_at_limit() {
        // Both truncation limits (200 bytes of system prompt, 500 bytes of user
        // text) must land inside a multi-byte character, so that the test really
        // exercises the char-boundary back-off for each section.
        //
        // "é" is 2 bytes. A leading ASCII byte shifts every "é" to start on an odd
        // offset, which makes byte 200 the second byte of a character. (Without
        // the prefix, byte 200 of "é".repeat(150) would be a boundary.)
        let system = format!("a{}", "é".repeat(150)); // 301 bytes
        assert!(!system.is_char_boundary(200));
        // "日本語" chars are 3 bytes each, and 500 is not a multiple of 3.
        let user = "日本語".repeat(300);
        assert!(!user.is_char_boundary(500));

        let request = ChatRequest::new(system, vec![Message::user(user)]);
        // The bug manifested as a panic; reaching these assertions means no panic.
        let prompt = build_classification_prompt(&request);
        assert!(prompt.contains("System context:"));
        assert!(prompt.contains("User request: 日本語"));
        // Both sections were cut, so both carry the ellipsis marker.
        assert!(prompt.contains("...\n\n"));
        assert!(prompt.ends_with("..."));
    }
}