// agent_sdk/llm.rs — core LLM provider trait and streaming helpers.

1pub mod attachments;
2pub mod router;
3pub mod streaming;
4pub mod types;
5
6pub use router::{ModelRouter, ModelTier, TaskComplexity};
7pub use streaming::{StreamAccumulator, StreamBox, StreamDelta};
8pub use types::*;
9
10use anyhow::Result;
11use async_trait::async_trait;
12use futures::StreamExt;
13
14use crate::model_capabilities::{
15    ModelCapabilities, default_max_output_tokens, get_model_capabilities,
16};
17
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Non-streaming chat completion.
    async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome>;

    /// Streaming chat completion.
    ///
    /// Returns a stream of [`StreamDelta`] events. The default implementation
    /// calls [`chat()`](Self::chat) and converts the result to a single-chunk stream.
    ///
    /// Providers should override this method to provide true streaming support.
    fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
        Box::pin(async_stream::stream! {
            match self.chat(request).await {
                Ok(outcome) => match outcome {
                    ChatOutcome::Success(response) => {
                        // Emit content as deltas
                        for (idx, block) in response.content.iter().enumerate() {
                            match block {
                                ContentBlock::Text { text } => {
                                    // The whole text block is emitted as a single delta.
                                    yield Ok(StreamDelta::TextDelta {
                                        delta: text.clone(),
                                        block_index: idx,
                                    });
                                }
                                ContentBlock::Thinking { thinking, .. } => {
                                    yield Ok(StreamDelta::ThinkingDelta {
                                        delta: thinking.clone(),
                                        block_index: idx,
                                    });
                                }
                                ContentBlock::RedactedThinking { .. }
                                | ContentBlock::ToolResult { .. }
                                | ContentBlock::Image { .. }
                                | ContentBlock::Document { .. } => {
                                    // Not streamed in the default implementation
                                }
                                ContentBlock::ToolUse { id, name, input, thought_signature } => {
                                    // A tool call becomes a start event followed by the
                                    // full input serialized as one JSON delta. A
                                    // serialization failure degrades to an empty delta.
                                    yield Ok(StreamDelta::ToolUseStart {
                                        id: id.clone(),
                                        name: name.clone(),
                                        block_index: idx,
                                        thought_signature: thought_signature.clone(),
                                    });
                                    yield Ok(StreamDelta::ToolInputDelta {
                                        id: id.clone(),
                                        delta: serde_json::to_string(input).unwrap_or_default(),
                                        block_index: idx,
                                    });
                                }
                            }
                        }
                        // Usage and stop reason are always emitted last, mirroring
                        // the shape of a real streaming response.
                        yield Ok(StreamDelta::Usage(response.usage));
                        yield Ok(StreamDelta::Done {
                            stop_reason: response.stop_reason,
                        });
                    }
                    ChatOutcome::RateLimited => {
                        // Recoverable: callers may retry after backing off.
                        yield Ok(StreamDelta::Error {
                            message: "Rate limited".to_string(),
                            recoverable: true,
                        });
                    }
                    ChatOutcome::InvalidRequest(msg) => {
                        // Not recoverable: the request itself was rejected.
                        yield Ok(StreamDelta::Error {
                            message: msg,
                            recoverable: false,
                        });
                    }
                    ChatOutcome::ServerError(msg) => {
                        yield Ok(StreamDelta::Error {
                            message: msg,
                            recoverable: true,
                        });
                    }
                },
                Err(e) => yield Err(e),
            }
        })
    }

    /// Identifier of the model this provider is configured for.
    fn model(&self) -> &str;
    /// Stable provider key (e.g. "openai-responses", "vertex") used for
    /// capability lookups.
    fn provider(&self) -> &'static str;

    /// Provider-owned thinking configuration, if any.
    fn configured_thinking(&self) -> Option<&ThinkingConfig> {
        None
    }

    /// Canonical capability metadata for this provider/model, if known.
    ///
    /// Falls back through provider aliases when the direct lookup misses:
    /// "openai-responses" reuses the "openai" table, and "vertex" maps to
    /// "anthropic" for `claude-*` models or "gemini" otherwise.
    fn capabilities(&self) -> Option<&'static ModelCapabilities> {
        get_model_capabilities(self.provider(), self.model()).or_else(|| match self.provider() {
            "openai-responses" => get_model_capabilities("openai", self.model()),
            "vertex" if self.model().starts_with("claude-") => {
                get_model_capabilities("anthropic", self.model())
            }
            "vertex" => get_model_capabilities("gemini", self.model()),
            _ => None,
        })
    }

    /// Validate a thinking configuration against the provider/model capabilities.
    ///
    /// # Errors
    ///
    /// Returns an error when the requested thinking mode is not supported by
    /// the active provider/model capability set.
    fn validate_thinking_config(&self, thinking: Option<&ThinkingConfig>) -> Result<()> {
        let Some(thinking) = thinking else {
            return Ok(());
        };

        // Adaptive mode requires explicit capability support; a missing
        // capability entry is treated as unsupported.
        if matches!(thinking.mode, ThinkingMode::Adaptive)
            && !self
                .capabilities()
                .is_some_and(|caps| caps.supports_adaptive_thinking)
        {
            return Err(anyhow::anyhow!(
                "adaptive thinking is not supported for provider={} model={}",
                self.provider(),
                self.model()
            ));
        }

        Ok(())
    }

    /// Resolve the effective thinking configuration for a request.
    ///
    /// Request-level thinking overrides provider-owned defaults when present.
    ///
    /// # Errors
    ///
    /// Returns an error when the resolved thinking configuration is not
    /// supported by the active provider/model capability set.
    fn resolve_thinking_config(
        &self,
        request_thinking: Option<&ThinkingConfig>,
    ) -> Result<Option<ThinkingConfig>> {
        let thinking = request_thinking.or_else(|| self.configured_thinking());
        self.validate_thinking_config(thinking)?;
        Ok(thinking.cloned())
    }

    /// Default maximum output tokens for this provider/model when the caller
    /// does not explicitly override `AgentConfig.max_tokens`.
    ///
    /// Resolution order: capability table, then the per-provider default
    /// lookup, then a hard fallback of 4096.
    fn default_max_tokens(&self) -> u32 {
        self.capabilities()
            .and_then(|caps| caps.max_output_tokens)
            .or_else(|| default_max_output_tokens(self.provider(), self.model()))
            .unwrap_or(4096)
    }
}
171
172/// Helper function to consume a stream and collect it into a `ChatResponse`.
173///
174/// This is useful for providers that want to test their streaming implementation
175/// or for cases where you need the full response after streaming.
176///
177/// # Errors
178///
179/// Returns an error if the stream yields an error result.
180pub async fn collect_stream(mut stream: StreamBox<'_>, model: String) -> Result<ChatOutcome> {
181    let mut accumulator = StreamAccumulator::new();
182    let mut last_error: Option<(String, bool)> = None;
183
184    while let Some(result) = stream.next().await {
185        match result {
186            Ok(delta) => {
187                if let StreamDelta::Error {
188                    message,
189                    recoverable,
190                } = &delta
191                {
192                    last_error = Some((message.clone(), *recoverable));
193                }
194                accumulator.apply(&delta);
195            }
196            Err(e) => return Err(e),
197        }
198    }
199
200    // If we encountered an error during streaming, return it
201    if let Some((message, recoverable)) = last_error {
202        if !recoverable {
203            return Ok(ChatOutcome::InvalidRequest(message));
204        }
205        // Check if it was a rate limit
206        if message.contains("Rate limited") || message.contains("rate limit") {
207            return Ok(ChatOutcome::RateLimited);
208        }
209        return Ok(ChatOutcome::ServerError(message));
210    }
211
212    // Extract usage and stop_reason before consuming the accumulator
213    let usage = accumulator.take_usage().unwrap_or(Usage {
214        input_tokens: 0,
215        output_tokens: 0,
216    });
217    let stop_reason = accumulator.take_stop_reason();
218    let content = accumulator.into_content_blocks();
219
220    // Log accumulated response for debugging
221    log::debug!(
222        "Collected stream response: model={} stop_reason={:?} usage={{input_tokens={}, output_tokens={}}} content_blocks={}",
223        model,
224        stop_reason,
225        usage.input_tokens,
226        usage.output_tokens,
227        content.len()
228    );
229    for (i, block) in content.iter().enumerate() {
230        match block {
231            ContentBlock::Text { text } => {
232                log::debug!("  content_block[{}]: Text (len={})", i, text.len());
233            }
234            ContentBlock::Thinking { thinking, .. } => {
235                log::debug!("  content_block[{}]: Thinking (len={})", i, thinking.len());
236            }
237            ContentBlock::RedactedThinking { .. } => {
238                log::debug!("  content_block[{i}]: RedactedThinking");
239            }
240            ContentBlock::ToolUse {
241                id, name, input, ..
242            } => {
243                log::debug!("  content_block[{i}]: ToolUse id={id} name={name} input={input}");
244            }
245            ContentBlock::ToolResult {
246                tool_use_id,
247                content: result_content,
248                is_error,
249            } => {
250                log::debug!(
251                    "  content_block[{}]: ToolResult tool_use_id={} is_error={:?} content_len={}",
252                    i,
253                    tool_use_id,
254                    is_error,
255                    result_content.len()
256                );
257            }
258            ContentBlock::Image { source } => {
259                log::debug!(
260                    "  content_block[{i}]: Image media_type={}",
261                    source.media_type
262                );
263            }
264            ContentBlock::Document { source } => {
265                log::debug!(
266                    "  content_block[{i}]: Document media_type={}",
267                    source.media_type
268                );
269            }
270        }
271    }
272
273    Ok(ChatOutcome::Success(ChatResponse {
274        id: String::new(),
275        content,
276        model,
277        stop_reason,
278        usage,
279    }))
280}