Skip to main content

agent_sdk_providers/
router.rs

1use std::fmt::Write;
2
3use anyhow::Result;
4use async_trait::async_trait;
5use futures::StreamExt;
6
7use crate::provider::LlmProvider;
8use crate::streaming::StreamBox;
9use agent_sdk_foundation::llm::{ChatOutcome, ChatRequest, ChatResponse, Message, Role};
10
/// A capability/cost tier a request can be routed to.
///
/// Variants are declared from cheapest-and-fastest (`Fast`) to
/// most-capable-and-costly (`Advanced`). The derived `PartialOrd`/`Ord`
/// follow that declaration order, so `ModelTier::Fast < ModelTier::Advanced`;
/// reordering the variants would change the comparison results.
/// [`TaskComplexity::recommended_tier`] maps a classified complexity onto one
/// of these tiers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ModelTier {
    /// Cheapest, fastest tier for trivial single-step work.
    Fast,
    /// Mid tier for multi-step reasoning, summarization, standard tool use.
    Capable,
    /// Most capable tier for creative, multi-step, or domain-heavy work.
    Advanced,
}
25
/// How demanding the classifier judged an incoming request to be.
///
/// Values come from [`ModelRouter::classify`]. When the classifier reports a
/// rate limit, an invalid request, a server error, or an unrecognized
/// outcome, the router uses [`TaskComplexity::Complex`] instead, so a failed
/// classification never silently sends a hard request to a weak model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskComplexity {
    /// Basic factual questions, lookups, and single-step operations.
    Simple,
    /// Multi-step reasoning, summarization, and standard tool usage.
    Moderate,
    /// Creative generation, planning, synthesis, and deep domain knowledge.
    Complex,
}
41
42impl TaskComplexity {
43    #[must_use]
44    pub const fn recommended_tier(self) -> ModelTier {
45        match self {
46            Self::Simple => ModelTier::Fast,
47            Self::Moderate => ModelTier::Capable,
48            Self::Complex => ModelTier::Advanced,
49        }
50    }
51}
52
/// An [`LlmProvider`] that picks a model tier per request using an LLM
/// classifier.
///
/// The router holds one `classifier` provider and three tier providers. Each
/// request costs one additional classifier call, which labels the work as
/// [`Simple`](TaskComplexity::Simple), [`Moderate`](TaskComplexity::Moderate),
/// or [`Complex`](TaskComplexity::Complex); the request is then sent to the
/// `fast`, `capable`, or `advanced` provider respectively. When the
/// classifier call fails (error or rate limit) the request is treated as
/// `Complex`, so it is never under-served because of a classifier problem
/// (see [`classify`](Self::classify)).
///
/// Because `ModelRouter` implements [`LlmProvider`] itself, it can stand in
/// wherever a `&dyn LlmProvider` is accepted (`run_structured`,
/// `RefreshingProvider`, etc.). [`chat`](LlmProvider::chat) classifies and
/// then routes; [`chat_stream`](LlmProvider::chat_stream) classifies first
/// and then streams from the chosen tier.
///
/// Note: `fast` and `capable` share the single provider type `S`, while
/// `advanced` has its own type `A`. Pairing, say, a Gemini fast tier with an
/// `OpenAI` capable tier therefore needs both behind one concrete type; use
/// `Arc<dyn LlmProvider>` for all three tiers to mix providers freely.
pub struct ModelRouter<C, S, A> {
    /// Provider that answers the classification request.
    classifier: C,
    /// Provider serving [`ModelTier::Fast`].
    fast: S,
    /// Provider serving [`ModelTier::Capable`].
    capable: S,
    /// Provider serving [`ModelTier::Advanced`].
    advanced: A,
}
79
80impl<C, S, A> ModelRouter<C, S, A>
81where
82    C: LlmProvider,
83    S: LlmProvider,
84    A: LlmProvider,
85{
86    pub const fn new(classifier: C, fast: S, capable: S, advanced: A) -> Self {
87        Self {
88            classifier,
89            fast,
90            capable,
91            advanced,
92        }
93    }
94
95    /// # Errors
96    /// Returns an error if the LLM provider fails.
97    pub async fn classify(&self, request: &ChatRequest) -> Result<TaskComplexity> {
98        let classification_prompt = build_classification_prompt(request);
99
100        let classification_request = ChatRequest {
101            system: CLASSIFICATION_SYSTEM.to_owned(),
102            messages: vec![Message::user(classification_prompt)],
103            tools: None,
104            max_tokens: 50,
105            max_tokens_explicit: true,
106            session_id: None,
107            cached_content: None,
108            thinking: None,
109            tool_choice: None,
110            response_format: None,
111            cache: None,
112        };
113
114        match self.classifier.chat(classification_request).await? {
115            ChatOutcome::Success(response) => {
116                let complexity = parse_complexity(&response);
117                log::debug!(
118                    "Model router classified request as {:?} using {}",
119                    complexity,
120                    self.classifier.model()
121                );
122                Ok(complexity)
123            }
124            ChatOutcome::RateLimited(_) => {
125                log::warn!("Classifier rate limited, defaulting to Complex");
126                Ok(TaskComplexity::Complex)
127            }
128            ChatOutcome::InvalidRequest(e) => {
129                log::error!("Classifier invalid request: {e}, defaulting to Complex");
130                Ok(TaskComplexity::Complex)
131            }
132            ChatOutcome::ServerError(e) => {
133                log::error!("Classifier server error: {e}, defaulting to Complex");
134                Ok(TaskComplexity::Complex)
135            }
136            // `ChatOutcome` is `#[non_exhaustive]`; an unrecognized outcome
137            // takes the same conservative fallback as the error variants.
138            _ => {
139                log::error!("Classifier returned unrecognized outcome, defaulting to Complex");
140                Ok(TaskComplexity::Complex)
141            }
142        }
143    }
144
145    /// # Errors
146    /// Returns an error if the LLM provider fails.
147    pub async fn route(&self, request: ChatRequest) -> Result<ChatOutcome> {
148        let complexity = self.classify(&request).await?;
149        let tier = complexity.recommended_tier();
150
151        log::info!("Routing request to {tier:?} tier (complexity: {complexity:?})");
152
153        match tier {
154            ModelTier::Fast => self.fast.chat(request).await,
155            ModelTier::Capable => self.capable.chat(request).await,
156            ModelTier::Advanced => self.advanced.chat(request).await,
157        }
158    }
159
160    /// # Errors
161    /// Returns an error if the LLM provider fails.
162    pub async fn route_with_tier(
163        &self,
164        request: ChatRequest,
165        tier: ModelTier,
166    ) -> Result<ChatOutcome> {
167        match tier {
168            ModelTier::Fast => self.fast.chat(request).await,
169            ModelTier::Capable => self.capable.chat(request).await,
170            ModelTier::Advanced => self.advanced.chat(request).await,
171        }
172    }
173
174    #[must_use]
175    pub const fn fast_provider(&self) -> &S {
176        &self.fast
177    }
178
179    #[must_use]
180    pub const fn capable_provider(&self) -> &S {
181        &self.capable
182    }
183
184    #[must_use]
185    pub const fn advanced_provider(&self) -> &A {
186        &self.advanced
187    }
188}
189
190#[async_trait]
191impl<C, S, A> LlmProvider for ModelRouter<C, S, A>
192where
193    C: LlmProvider,
194    S: LlmProvider,
195    A: LlmProvider,
196{
197    async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome> {
198        self.route(request).await
199    }
200
201    fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
202        Box::pin(async_stream::stream! {
203            let tier = match self.classify(&request).await {
204                Ok(complexity) => complexity.recommended_tier(),
205                Err(error) => {
206                    yield Err(error);
207                    return;
208                }
209            };
210            log::info!("Streaming request to {tier:?} tier");
211            let mut stream = match tier {
212                ModelTier::Fast => self.fast.chat_stream(request),
213                ModelTier::Capable => self.capable.chat_stream(request),
214                ModelTier::Advanced => self.advanced.chat_stream(request),
215            };
216            while let Some(item) = stream.next().await {
217                yield item;
218            }
219        })
220    }
221
222    /// Reports the `capable` (mid) tier's model as the router's representative
223    /// model identifier.
224    fn model(&self) -> &str {
225        self.capable.model()
226    }
227
228    /// Reports the `capable` (mid) tier's provider as the router's representative
229    /// provider identifier.
230    fn provider(&self) -> &'static str {
231        self.capable.provider()
232    }
233}
234
/// System prompt for the classifier call made by `ModelRouter::classify`.
///
/// It describes the three labels and tells the model to reply with exactly
/// one of the words `SIMPLE`, `MODERATE`, or `COMPLEX`; `parse_complexity`
/// looks for those words in the reply.
const CLASSIFICATION_SYSTEM: &str = r"You are a task complexity classifier. Analyze the user's request and classify it as one of: SIMPLE, MODERATE, or COMPLEX.

SIMPLE tasks:
- Basic questions with factual answers
- Simple calculations
- Direct lookups or retrievals
- Yes/no questions
- Single-step operations

MODERATE tasks:
- Multi-step reasoning
- Summarization
- Basic analysis
- Comparisons
- Standard tool usage

COMPLEX tasks:
- Creative writing or content generation
- Multi-step planning
- Complex analysis or synthesis
- Nuanced decisions
- Tasks requiring deep domain knowledge
- Financial advice or calculations
- Multi-tool orchestration

Respond with ONLY one word: SIMPLE, MODERATE, or COMPLEX.";
261
262fn build_classification_prompt(request: &ChatRequest) -> String {
263    let mut prompt = String::new();
264
265    prompt.push_str("Classify this task:\n\n");
266
267    if !request.system.is_empty() {
268        prompt.push_str("System context: ");
269        let truncated = truncate_on_char_boundary(&request.system, 200);
270        prompt.push_str(truncated);
271        if truncated.len() < request.system.len() {
272            prompt.push_str("...");
273        }
274        prompt.push_str("\n\n");
275    }
276
277    if let Some(last_user_message) = request.messages.iter().rev().find(|m| m.role == Role::User)
278        && let Some(text) = last_user_message.content.first_text()
279    {
280        prompt.push_str("User request: ");
281        let truncated = truncate_on_char_boundary(text, 500);
282        prompt.push_str(truncated);
283        if truncated.len() < text.len() {
284            prompt.push_str("...");
285        }
286    }
287
288    if let Some(tools) = &request.tools {
289        let _ = write!(prompt, "\n\nAvailable tools: {}", tools.len());
290    }
291
292    prompt
293}
294
/// Returns the longest prefix of `s` that is at most `max_bytes` long and
/// ends on a UTF-8 character boundary.
///
/// When `max_bytes` falls inside a multi-byte character (emoji, CJK, accented
/// text), the cut moves back to the start of that character, so slicing never
/// panics. A string that already fits is returned unchanged.
fn truncate_on_char_boundary(s: &str, max_bytes: usize) -> &str {
    if max_bytes >= s.len() {
        return s;
    }
    // Offset 0 is always a character boundary, so the search cannot come up
    // empty; `unwrap_or(0)` only spells that out.
    let cut = (0..=max_bytes)
        .rev()
        .find(|&offset| s.is_char_boundary(offset))
        .unwrap_or(0);
    &s[..cut]
}
308
309fn parse_complexity(response: &ChatResponse) -> TaskComplexity {
310    let text = response.first_text().unwrap_or("").to_uppercase();
311
312    if text.contains("SIMPLE") {
313        TaskComplexity::Simple
314    } else if text.contains("MODERATE") {
315        TaskComplexity::Moderate
316    } else {
317        TaskComplexity::Complex
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Each complexity maps to the tier of matching rank.
    #[test]
    fn complexity_to_tier() {
        assert_eq!(TaskComplexity::Simple.recommended_tier(), ModelTier::Fast);
        assert_eq!(
            TaskComplexity::Moderate.recommended_tier(),
            ModelTier::Capable
        );
        assert_eq!(
            TaskComplexity::Complex.recommended_tier(),
            ModelTier::Advanced
        );
    }

    #[test]
    fn truncate_on_char_boundary_never_splits_multibyte_char() {
        // "😀" is a 4-byte character. Truncating at byte 1, 2, or 3 would land
        // inside it and panic with naive `&s[..n]`; the helper must back off to a
        // valid boundary instead.
        let s = "😀😀😀";
        for max in 0..=s.len() {
            let truncated = truncate_on_char_boundary(s, max);
            // Must be a valid prefix of the original (never panics, always UTF-8).
            assert!(s.starts_with(truncated));
            assert!(truncated.len() <= max);
        }
        assert_eq!(truncate_on_char_boundary(s, 4), "😀");
        assert_eq!(truncate_on_char_boundary(s, 5), "😀");
        assert_eq!(truncate_on_char_boundary(s, 100), s);
    }

    #[test]
    fn build_classification_prompt_handles_multibyte_at_limit() {
        // One ASCII byte followed by 2-byte characters puts every later
        // character start at an odd offset, so the 200-byte system limit lands
        // in the middle of a character. (A plain run of "é" would not: its
        // characters start at even offsets, making byte 200 a boundary.)
        let system = format!("a{}", "é".repeat(150)); // 301 bytes
        assert!(!system.is_char_boundary(200));

        // 3-byte characters: 500 is not a multiple of 3, so the 500-byte user
        // limit also lands in the middle of a character.
        let user = "日本語".repeat(300);
        assert!(!user.is_char_boundary(500));

        let request = ChatRequest::new(system, vec![Message::user(user)]);
        // The bug manifested as a panic; reaching the assertions means no panic.
        let prompt = build_classification_prompt(&request);

        // The system context backs off from byte 200 to byte 199: the ASCII
        // byte plus 99 whole "é" characters, followed by the ellipsis.
        let clipped_system = format!("System context: a{}...\n\n", "é".repeat(99));
        assert!(prompt.contains(&clipped_system));
        assert!(prompt.ends_with("..."));
    }
}
366}