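//! `agent_sdk/llm.rs`: the [`LlmProvider`] trait and the [`collect_stream`]
//! helper for folding a response stream back into a single outcome.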

1pub mod router;
2pub mod streaming;
3pub mod types;
4
5pub use router::{ModelRouter, ModelTier, TaskComplexity};
6pub use streaming::{StreamAccumulator, StreamBox, StreamDelta};
7pub use types::*;
8
9use anyhow::Result;
10use async_trait::async_trait;
11use futures::StreamExt;
12
13use crate::model_capabilities::{
14    ModelCapabilities, default_max_output_tokens, get_model_capabilities,
15};
16
17#[async_trait]
18pub trait LlmProvider: Send + Sync {
19    /// Non-streaming chat completion.
20    async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome>;
21
22    /// Streaming chat completion.
23    ///
24    /// Returns a stream of [`StreamDelta`] events. The default implementation
25    /// calls [`chat()`](Self::chat) and converts the result to a single-chunk stream.
26    ///
27    /// Providers should override this method to provide true streaming support.
28    fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
29        Box::pin(async_stream::stream! {
30            match self.chat(request).await {
31                Ok(outcome) => match outcome {
32                    ChatOutcome::Success(response) => {
33                        // Emit content as deltas
34                        for (idx, block) in response.content.iter().enumerate() {
35                            match block {
36                                ContentBlock::Text { text } => {
37                                    yield Ok(StreamDelta::TextDelta {
38                                        delta: text.clone(),
39                                        block_index: idx,
40                                    });
41                                }
42                                ContentBlock::Thinking { thinking, .. } => {
43                                    yield Ok(StreamDelta::ThinkingDelta {
44                                        delta: thinking.clone(),
45                                        block_index: idx,
46                                    });
47                                }
48                                ContentBlock::RedactedThinking { .. }
49                                | ContentBlock::ToolResult { .. }
50                                | ContentBlock::Image { .. }
51                                | ContentBlock::Document { .. } => {
52                                    // Not streamed in the default implementation
53                                }
54                                ContentBlock::ToolUse { id, name, input, thought_signature } => {
55                                    yield Ok(StreamDelta::ToolUseStart {
56                                        id: id.clone(),
57                                        name: name.clone(),
58                                        block_index: idx,
59                                        thought_signature: thought_signature.clone(),
60                                    });
61                                    yield Ok(StreamDelta::ToolInputDelta {
62                                        id: id.clone(),
63                                        delta: serde_json::to_string(input).unwrap_or_default(),
64                                        block_index: idx,
65                                    });
66                                }
67                            }
68                        }
69                        yield Ok(StreamDelta::Usage(response.usage));
70                        yield Ok(StreamDelta::Done {
71                            stop_reason: response.stop_reason,
72                        });
73                    }
74                    ChatOutcome::RateLimited => {
75                        yield Ok(StreamDelta::Error {
76                            message: "Rate limited".to_string(),
77                            recoverable: true,
78                        });
79                    }
80                    ChatOutcome::InvalidRequest(msg) => {
81                        yield Ok(StreamDelta::Error {
82                            message: msg,
83                            recoverable: false,
84                        });
85                    }
86                    ChatOutcome::ServerError(msg) => {
87                        yield Ok(StreamDelta::Error {
88                            message: msg,
89                            recoverable: true,
90                        });
91                    }
92                },
93                Err(e) => yield Err(e),
94            }
95        })
96    }
97
98    fn model(&self) -> &str;
99    fn provider(&self) -> &'static str;
100
101    /// Provider-owned thinking configuration, if any.
102    fn configured_thinking(&self) -> Option<&ThinkingConfig> {
103        None
104    }
105
106    /// Canonical capability metadata for this provider/model, if known.
107    fn capabilities(&self) -> Option<&'static ModelCapabilities> {
108        get_model_capabilities(self.provider(), self.model()).or_else(|| match self.provider() {
109            "openai-responses" => get_model_capabilities("openai", self.model()),
110            "vertex" if self.model().starts_with("claude-") => {
111                get_model_capabilities("anthropic", self.model())
112            }
113            "vertex" => get_model_capabilities("gemini", self.model()),
114            _ => None,
115        })
116    }
117
118    /// Validate a thinking configuration against the provider/model capabilities.
119    ///
120    /// # Errors
121    ///
122    /// Returns an error when the requested thinking mode is not supported by
123    /// the active provider/model capability set.
124    fn validate_thinking_config(&self, thinking: Option<&ThinkingConfig>) -> Result<()> {
125        let Some(thinking) = thinking else {
126            return Ok(());
127        };
128
129        if matches!(thinking.mode, ThinkingMode::Adaptive)
130            && !self
131                .capabilities()
132                .is_some_and(|caps| caps.supports_adaptive_thinking)
133        {
134            return Err(anyhow::anyhow!(
135                "adaptive thinking is not supported for provider={} model={}",
136                self.provider(),
137                self.model()
138            ));
139        }
140
141        Ok(())
142    }
143
144    /// Resolve the effective thinking configuration for a request.
145    ///
146    /// Request-level thinking overrides provider-owned defaults when present.
147    ///
148    /// # Errors
149    ///
150    /// Returns an error when the resolved thinking configuration is not
151    /// supported by the active provider/model capability set.
152    fn resolve_thinking_config(
153        &self,
154        request_thinking: Option<&ThinkingConfig>,
155    ) -> Result<Option<ThinkingConfig>> {
156        let thinking = request_thinking.or_else(|| self.configured_thinking());
157        self.validate_thinking_config(thinking)?;
158        Ok(thinking.cloned())
159    }
160
161    /// Default maximum output tokens for this provider/model when the caller
162    /// does not explicitly override `AgentConfig.max_tokens`.
163    fn default_max_tokens(&self) -> u32 {
164        self.capabilities()
165            .and_then(|caps| caps.max_output_tokens)
166            .or_else(|| default_max_output_tokens(self.provider(), self.model()))
167            .unwrap_or(4096)
168    }
169}
170
171/// Helper function to consume a stream and collect it into a `ChatResponse`.
172///
173/// This is useful for providers that want to test their streaming implementation
174/// or for cases where you need the full response after streaming.
175///
176/// # Errors
177///
178/// Returns an error if the stream yields an error result.
179pub async fn collect_stream(mut stream: StreamBox<'_>, model: String) -> Result<ChatOutcome> {
180    let mut accumulator = StreamAccumulator::new();
181    let mut last_error: Option<(String, bool)> = None;
182
183    while let Some(result) = stream.next().await {
184        match result {
185            Ok(delta) => {
186                if let StreamDelta::Error {
187                    message,
188                    recoverable,
189                } = &delta
190                {
191                    last_error = Some((message.clone(), *recoverable));
192                }
193                accumulator.apply(&delta);
194            }
195            Err(e) => return Err(e),
196        }
197    }
198
199    // If we encountered an error during streaming, return it
200    if let Some((message, recoverable)) = last_error {
201        if !recoverable {
202            return Ok(ChatOutcome::InvalidRequest(message));
203        }
204        // Check if it was a rate limit
205        if message.contains("Rate limited") || message.contains("rate limit") {
206            return Ok(ChatOutcome::RateLimited);
207        }
208        return Ok(ChatOutcome::ServerError(message));
209    }
210
211    // Extract usage and stop_reason before consuming the accumulator
212    let usage = accumulator.take_usage().unwrap_or(Usage {
213        input_tokens: 0,
214        output_tokens: 0,
215    });
216    let stop_reason = accumulator.take_stop_reason();
217    let content = accumulator.into_content_blocks();
218
219    // Log accumulated response for debugging
220    log::debug!(
221        "Collected stream response: model={} stop_reason={:?} usage={{input_tokens={}, output_tokens={}}} content_blocks={}",
222        model,
223        stop_reason,
224        usage.input_tokens,
225        usage.output_tokens,
226        content.len()
227    );
228    for (i, block) in content.iter().enumerate() {
229        match block {
230            ContentBlock::Text { text } => {
231                log::debug!("  content_block[{}]: Text (len={})", i, text.len());
232            }
233            ContentBlock::Thinking { thinking, .. } => {
234                log::debug!("  content_block[{}]: Thinking (len={})", i, thinking.len());
235            }
236            ContentBlock::RedactedThinking { .. } => {
237                log::debug!("  content_block[{i}]: RedactedThinking");
238            }
239            ContentBlock::ToolUse {
240                id, name, input, ..
241            } => {
242                log::debug!("  content_block[{i}]: ToolUse id={id} name={name} input={input}");
243            }
244            ContentBlock::ToolResult {
245                tool_use_id,
246                content: result_content,
247                is_error,
248            } => {
249                log::debug!(
250                    "  content_block[{}]: ToolResult tool_use_id={} is_error={:?} content_len={}",
251                    i,
252                    tool_use_id,
253                    is_error,
254                    result_content.len()
255                );
256            }
257            ContentBlock::Image { source } => {
258                log::debug!(
259                    "  content_block[{i}]: Image media_type={}",
260                    source.media_type
261                );
262            }
263            ContentBlock::Document { source } => {
264                log::debug!(
265                    "  content_block[{i}]: Document media_type={}",
266                    source.media_type
267                );
268            }
269        }
270    }
271
272    Ok(ChatOutcome::Success(ChatResponse {
273        id: String::new(),
274        content,
275        model,
276        stop_reason,
277        usage,
278    }))
279}