Skip to main content

agent_sdk_providers/
provider.rs

1//! LLM provider trait and streaming helpers.
2//!
3//! This module defines the [`LlmProvider`] trait that all LLM backends implement,
4//! as well as the [`collect_stream`] helper for consuming a streaming response.
5
6use agent_sdk_foundation::llm::{
7    ChatOutcome, ChatRequest, ChatResponse, ContentBlock, ThinkingConfig, ThinkingMode, Usage,
8};
9use anyhow::Result;
10use async_trait::async_trait;
11use futures::StreamExt;
12
13use crate::model_capabilities::{
14    ModelCapabilities, default_max_output_tokens, get_model_capabilities,
15};
16use crate::streaming::{StreamAccumulator, StreamBox, StreamDelta, StreamErrorKind};
17
/// Strategy a provider uses to honor a
/// [`ResponseFormat`](agent_sdk_foundation::llm::ResponseFormat)
/// structured-output request.
///
/// The structured-output runner reads this value to decide both how the
/// request is shaped and which part of the response carries the final
/// structured value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StructuredOutputSupport {
    /// The schema is enforced by the provider itself (JSON-mode /
    /// structured-outputs) whenever `request.response_format` is set. The
    /// structured value is the JSON found in the assistant's text output.
    Native,
    /// The provider offers no native JSON-schema mode. The runner instead
    /// injects one forced "respond" tool whose `input_schema` is the output
    /// schema, and takes the structured value from that tool call's input.
    ToolForcing,
}
34
35#[async_trait]
36pub trait LlmProvider: Send + Sync {
37    /// Non-streaming chat completion.
38    async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome>;
39
40    /// Streaming chat completion.
41    ///
42    /// Returns a stream of [`StreamDelta`] events. The default implementation
43    /// calls [`chat()`](Self::chat) and converts the result to a single-chunk stream.
44    ///
45    /// Providers should override this method to provide true streaming support.
46    fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
47        Box::pin(async_stream::stream! {
48            match self.chat(request).await {
49                Ok(outcome) => match outcome {
50                    ChatOutcome::Success(response) => {
51                        // Emit content as deltas
52                        for (idx, block) in response.content.iter().enumerate() {
53                            match block {
54                                ContentBlock::Text { text } => {
55                                    yield Ok(StreamDelta::TextDelta {
56                                        delta: text.clone(),
57                                        block_index: idx,
58                                    });
59                                }
60                                ContentBlock::Thinking { thinking, .. } => {
61                                    yield Ok(StreamDelta::ThinkingDelta {
62                                        delta: thinking.clone(),
63                                        block_index: idx,
64                                    });
65                                }
66                                ContentBlock::RedactedThinking { .. }
67                                | ContentBlock::ToolResult { .. }
68                                | ContentBlock::Image { .. }
69                                | ContentBlock::Document { .. } => {
70                                    // Not streamed in the default implementation
71                                }
72                                ContentBlock::ToolUse { id, name, input, thought_signature } => {
73                                    yield Ok(StreamDelta::ToolUseStart {
74                                        id: id.clone(),
75                                        name: name.clone(),
76                                        block_index: idx,
77                                        thought_signature: thought_signature.clone(),
78                                    });
79                                    yield Ok(StreamDelta::ToolInputDelta {
80                                        id: id.clone(),
81                                        delta: serde_json::to_string(input).unwrap_or_default(),
82                                        block_index: idx,
83                                    });
84                                }
85                                // `ContentBlock` is `#[non_exhaustive]`; a future
86                                // block kind we cannot stream is skipped rather than
87                                // panicking the default fallback.
88                                _ => {
89                                    log::warn!(
90                                        "chat_stream fallback skipping unrecognized content block at index {idx}"
91                                    );
92                                }
93                            }
94                        }
95                        yield Ok(StreamDelta::Usage(response.usage));
96                        yield Ok(StreamDelta::Done {
97                            stop_reason: response.stop_reason,
98                        });
99                    }
100                    ChatOutcome::RateLimited => {
101                        yield Ok(StreamDelta::Error {
102                            message: "Rate limited".to_string(),
103                            kind: StreamErrorKind::RateLimited,
104                        });
105                    }
106                    ChatOutcome::InvalidRequest(msg) => {
107                        yield Ok(StreamDelta::Error {
108                            message: msg,
109                            kind: StreamErrorKind::InvalidRequest,
110                        });
111                    }
112                    ChatOutcome::ServerError(msg) => {
113                        yield Ok(StreamDelta::Error {
114                            message: msg,
115                            kind: StreamErrorKind::ServerError,
116                        });
117                    }
118                    // `ChatOutcome` is `#[non_exhaustive]`; an outcome this SDK
119                    // version does not model is surfaced as an unclassified
120                    // (non-recoverable) stream error rather than dropped.
121                    _ => {
122                        yield Ok(StreamDelta::Error {
123                            message: "Unrecognized chat outcome".to_string(),
124                            kind: StreamErrorKind::Unknown,
125                        });
126                    }
127                },
128                Err(e) => yield Err(e),
129            }
130        })
131    }
132
133    fn model(&self) -> &str;
134    fn provider(&self) -> &'static str;
135
136    /// Provider-owned thinking configuration, if any.
137    fn configured_thinking(&self) -> Option<&ThinkingConfig> {
138        None
139    }
140
141    /// Canonical capability metadata for this provider/model, if known.
142    fn capabilities(&self) -> Option<&'static ModelCapabilities> {
143        get_model_capabilities(self.provider(), self.model()).or_else(|| match self.provider() {
144            "openai-responses" | "openai-codex" => get_model_capabilities("openai", self.model()),
145            "vertex" if self.model().starts_with("claude-") => {
146                get_model_capabilities("anthropic", self.model())
147            }
148            "vertex" => get_model_capabilities("gemini", self.model()),
149            _ => None,
150        })
151    }
152
153    /// Validate a thinking configuration against the provider/model capabilities.
154    ///
155    /// # Errors
156    ///
157    /// Returns an error when the requested thinking mode is not supported by
158    /// the active provider/model capability set.
159    fn validate_thinking_config(&self, thinking: Option<&ThinkingConfig>) -> Result<()> {
160        let Some(thinking) = thinking else {
161            return Ok(());
162        };
163
164        if self
165            .capabilities()
166            .is_some_and(|caps| !caps.supports_thinking)
167        {
168            return Err(anyhow::anyhow!(
169                "thinking is not supported for provider={} model={}",
170                self.provider(),
171                self.model()
172            ));
173        }
174
175        if matches!(thinking.mode, ThinkingMode::Adaptive)
176            && !self
177                .capabilities()
178                .is_some_and(|caps| caps.supports_adaptive_thinking)
179        {
180            return Err(anyhow::anyhow!(
181                "adaptive thinking is not supported for provider={} model={}",
182                self.provider(),
183                self.model()
184            ));
185        }
186
187        Ok(())
188    }
189
190    /// Resolve the effective thinking configuration for a request.
191    ///
192    /// Request-level thinking overrides provider-owned defaults when present.
193    ///
194    /// # Errors
195    ///
196    /// Returns an error when the resolved thinking configuration is not
197    /// supported by the active provider/model capability set.
198    fn resolve_thinking_config(
199        &self,
200        request_thinking: Option<&ThinkingConfig>,
201    ) -> Result<Option<ThinkingConfig>> {
202        let thinking = request_thinking.or_else(|| self.configured_thinking());
203        self.validate_thinking_config(thinking)?;
204        Ok(thinking.cloned())
205    }
206
207    /// Default maximum output tokens for this provider/model when the caller
208    /// does not explicitly override `AgentConfig.max_tokens`.
209    fn default_max_tokens(&self) -> u32 {
210        self.capabilities()
211            .and_then(|caps| caps.max_output_tokens)
212            .or_else(|| default_max_output_tokens(self.provider(), self.model()))
213            .unwrap_or(4096)
214    }
215
216    /// How this provider satisfies a structured-output
217    /// ([`ResponseFormat`](agent_sdk_foundation::llm::ResponseFormat)) request.
218    ///
219    /// Providers with a native JSON-schema / JSON-mode wire field
220    /// (OpenAI-family, Gemini, Vertex) report
221    /// [`StructuredOutputSupport::Native`] and consume
222    /// `request.response_format` directly. Providers without one (Anthropic)
223    /// report [`StructuredOutputSupport::ToolForcing`] so the runner forces a
224    /// single "respond" tool whose schema is the output schema. The default
225    /// is the conservative [`StructuredOutputSupport::ToolForcing`], which
226    /// works for any tool-capable provider.
227    fn structured_output_support(&self) -> StructuredOutputSupport {
228        match self.provider() {
229            "openai" | "openai-responses" | "openai-codex" | "gemini" => {
230                StructuredOutputSupport::Native
231            }
232            // Vertex multiplexes Anthropic and Gemini models. Only the Gemini
233            // side has a native structured-output field; Claude-on-Vertex uses
234            // the Messages API shape, which has no `response_format`.
235            "vertex" if !self.model().starts_with("claude-") => StructuredOutputSupport::Native,
236            _ => StructuredOutputSupport::ToolForcing,
237        }
238    }
239}
240
241/// Helper function to consume a stream and collect it into a `ChatResponse`.
242///
243/// This is useful for providers that want to test their streaming implementation
244/// or for cases where you need the full response after streaming.
245///
246/// # Errors
247///
248/// Returns an error if the stream yields an error result.
249pub async fn collect_stream(mut stream: StreamBox<'_>, model: String) -> Result<ChatOutcome> {
250    let mut accumulator = StreamAccumulator::new();
251    let mut last_error: Option<(String, StreamErrorKind)> = None;
252
253    while let Some(result) = stream.next().await {
254        match result {
255            Ok(delta) => {
256                if let StreamDelta::Error { message, kind } = &delta {
257                    last_error = Some((message.clone(), *kind));
258                }
259                accumulator.apply(&delta);
260            }
261            Err(e) => return Err(e),
262        }
263    }
264
265    // If we encountered an error during streaming, map kind directly
266    // to the corresponding `ChatOutcome` variant.  No string-matching
267    // heuristic is needed because the kind already records the
268    // category at the construction site.
269    if let Some((message, kind)) = last_error {
270        return Ok(match kind {
271            StreamErrorKind::RateLimited => ChatOutcome::RateLimited,
272            StreamErrorKind::InvalidRequest => ChatOutcome::InvalidRequest(message),
273            // `StreamErrorKind::ServerError`, plus the `#[non_exhaustive]`
274            // catch-all (`Unknown` / future kinds): an unclassified error is
275            // treated as a (non-recoverable) server error so the caller still
276            // surfaces the failure rather than silently succeeding.
277            _ => ChatOutcome::ServerError(message),
278        });
279    }
280
281    // Extract usage and stop_reason before consuming the accumulator
282    let usage = accumulator.take_usage().unwrap_or(Usage {
283        input_tokens: 0,
284        output_tokens: 0,
285        cached_input_tokens: 0,
286        cache_creation_input_tokens: 0,
287    });
288    let stop_reason = accumulator.take_stop_reason();
289    let content = accumulator.into_content_blocks();
290
291    // Log accumulated response for debugging
292    log::debug!(
293        "Collected stream response: model={} stop_reason={:?} usage={{input_tokens={}, output_tokens={}}} content_blocks={}",
294        model,
295        stop_reason,
296        usage.input_tokens,
297        usage.output_tokens,
298        content.len()
299    );
300    for (i, block) in content.iter().enumerate() {
301        match block {
302            ContentBlock::Text { text } => {
303                log::debug!("  content_block[{}]: Text (len={})", i, text.len());
304            }
305            ContentBlock::Thinking { thinking, .. } => {
306                log::debug!("  content_block[{}]: Thinking (len={})", i, thinking.len());
307            }
308            ContentBlock::RedactedThinking { .. } => {
309                log::debug!("  content_block[{i}]: RedactedThinking");
310            }
311            ContentBlock::ToolUse {
312                id, name, input, ..
313            } => {
314                log::debug!("  content_block[{i}]: ToolUse id={id} name={name} input={input}");
315            }
316            ContentBlock::ToolResult {
317                tool_use_id,
318                content: result_content,
319                is_error,
320            } => {
321                log::debug!(
322                    "  content_block[{}]: ToolResult tool_use_id={} is_error={:?} content_len={}",
323                    i,
324                    tool_use_id,
325                    is_error,
326                    result_content.len()
327                );
328            }
329            ContentBlock::Image { source } => {
330                log::debug!(
331                    "  content_block[{i}]: Image media_type={}",
332                    source.media_type
333                );
334            }
335            ContentBlock::Document { source } => {
336                log::debug!(
337                    "  content_block[{i}]: Document media_type={}",
338                    source.media_type
339                );
340            }
341            // `ContentBlock` is `#[non_exhaustive]`; log unknown future block
342            // kinds generically so the debug dump stays exhaustive.
343            _ => {
344                log::debug!("  content_block[{i}]: <unrecognized block kind>");
345            }
346        }
347    }
348
349    Ok(ChatOutcome::Success(ChatResponse {
350        id: String::new(),
351        content,
352        model,
353        stop_reason,
354        usage,
355    }))
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use async_trait::async_trait;

    /// Provider double with a configurable identity. Only the default trait
    /// methods are under test, so `chat` returns a placeholder outcome.
    struct Stub {
        provider: &'static str,
        model: &'static str,
    }

    #[async_trait]
    impl LlmProvider for Stub {
        async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
            let placeholder = ChatOutcome::ServerError("unused".to_owned());
            Ok(placeholder)
        }

        fn model(&self) -> &str {
            self.model
        }

        fn provider(&self) -> &'static str {
            self.provider
        }
    }

    /// Structured-output strategy reported for a `provider`/`model` pair.
    fn support_for(provider: &'static str, model: &'static str) -> StructuredOutputSupport {
        let stub = Stub { provider, model };
        stub.structured_output_support()
    }

    #[test]
    fn native_providers_report_native_support() {
        const NATIVE_PROVIDERS: [&str; 4] =
            ["openai", "openai-responses", "openai-codex", "gemini"];

        for provider in NATIVE_PROVIDERS {
            let reported = support_for(provider, "any-model");
            assert_eq!(
                reported,
                StructuredOutputSupport::Native,
                "{provider} should be native"
            );
        }
    }

    #[test]
    fn anthropic_reports_tool_forcing() {
        let reported = support_for("anthropic", "claude-sonnet-4-5");
        assert_eq!(reported, StructuredOutputSupport::ToolForcing);
    }

    #[test]
    fn vertex_is_native_for_gemini_models_and_tool_forcing_for_claude() {
        // Vertex hosts both model families, so the model name decides.
        let cases = [
            ("gemini-3-flash-preview", StructuredOutputSupport::Native),
            ("claude-sonnet-4-5", StructuredOutputSupport::ToolForcing),
        ];

        for (model, expected) in cases {
            assert_eq!(support_for("vertex", model), expected);
        }
    }

    #[test]
    fn unknown_provider_defaults_to_tool_forcing() {
        let reported = support_for("some-new-provider", "x");
        assert_eq!(reported, StructuredOutputSupport::ToolForcing);
    }
}