Skip to main content

agent_sdk_providers/
provider.rs

1//! LLM provider trait and streaming helpers.
2//!
3//! This module defines the [`LlmProvider`] trait that all LLM backends implement,
4//! as well as the [`collect_stream`] helper for consuming a streaming response.
5
6use agent_sdk_foundation::llm::{
7    ChatOutcome, ChatRequest, ChatResponse, ContentBlock, ThinkingConfig, ThinkingMode, Usage,
8};
9use anyhow::Result;
10use async_trait::async_trait;
11use futures::StreamExt;
12use serde::{Deserialize, Serialize};
13
14use crate::model_capabilities::{
15    ModelCapabilities, default_max_output_tokens, get_model_capabilities,
16};
17use crate::streaming::{StreamAccumulator, StreamBox, StreamDelta, StreamErrorKind};
18
/// Strategy a provider uses to honour a structured-output request
/// ([`ResponseFormat`](agent_sdk_foundation::llm::ResponseFormat)).
///
/// The structured-output runner reads this value to choose two things: how
/// to shape the outgoing request, and where in the response to find the
/// final structured value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StructuredOutputSupport {
    /// Schema enforcement happens on the provider side (JSON mode /
    /// structured outputs), driven by `request.response_format`. The
    /// structured value is the JSON found in the assistant's text output.
    Native,
    /// No native JSON-schema mode exists for this provider. The runner adds
    /// exactly one forced "respond" tool whose `input_schema` equals the
    /// output schema. It then takes the structured value from that tool
    /// call's input.
    ToolForcing,
}
35
36/// A single model entry returned by a provider's live model-listing endpoint.
37///
38/// This is the *dynamic* counterpart to the static
39/// [`ModelCapabilities`] table: it
40/// is populated from the provider's own `/models` API at runtime, so newly
41/// shipped models appear without an SDK code change. Fields beyond `id` are
42/// optional because not every provider's listing endpoint reports them.
43#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
44pub struct ModelInfo {
45    /// The model identifier as the provider's chat endpoint expects it.
46    pub id: String,
47    /// Human-friendly display name, when the listing endpoint provides one.
48    pub display_name: Option<String>,
49    /// Maximum total context window in tokens, when reported.
50    pub context_window: Option<u32>,
51    /// Maximum output tokens per response, when reported.
52    pub max_output_tokens: Option<u32>,
53}
54
55#[async_trait]
56pub trait LlmProvider: Send + Sync {
57    /// Non-streaming chat completion.
58    async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome>;
59
60    /// List the models the provider currently exposes, queried live from the
61    /// provider's own model-listing endpoint.
62    ///
63    /// The default implementation returns an error: model discovery is an
64    /// additive capability, so a provider that has not implemented it stays
65    /// source-compatible while reporting that the operation is unsupported.
66    ///
67    /// # Errors
68    ///
69    /// Returns an error when the provider does not support live model listing,
70    /// or when the underlying HTTP request or response parsing fails.
71    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
72        Err(anyhow::anyhow!(
73            "list_models is not supported for provider {}",
74            self.provider()
75        ))
76    }
77
78    /// Streaming chat completion.
79    ///
80    /// Returns a stream of [`StreamDelta`] events. The default implementation
81    /// calls [`chat()`](Self::chat) and converts the result to a single-chunk stream.
82    ///
83    /// Providers should override this method to provide true streaming support.
84    fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
85        Box::pin(async_stream::stream! {
86            match self.chat(request).await {
87                Ok(outcome) => match outcome {
88                    ChatOutcome::Success(response) => {
89                        // Emit content as deltas
90                        for (idx, block) in response.content.iter().enumerate() {
91                            match block {
92                                ContentBlock::Text { text } => {
93                                    yield Ok(StreamDelta::TextDelta {
94                                        delta: text.clone(),
95                                        block_index: idx,
96                                    });
97                                }
98                                ContentBlock::Thinking { thinking, .. } => {
99                                    yield Ok(StreamDelta::ThinkingDelta {
100                                        delta: thinking.clone(),
101                                        block_index: idx,
102                                    });
103                                }
104                                ContentBlock::RedactedThinking { .. }
105                                | ContentBlock::ToolResult { .. }
106                                | ContentBlock::Image { .. }
107                                | ContentBlock::Document { .. } => {
108                                    // Not streamed in the default implementation
109                                }
110                                ContentBlock::ToolUse { id, name, input, thought_signature } => {
111                                    yield Ok(StreamDelta::ToolUseStart {
112                                        id: id.clone(),
113                                        name: name.clone(),
114                                        block_index: idx,
115                                        thought_signature: thought_signature.clone(),
116                                    });
117                                    yield Ok(StreamDelta::ToolInputDelta {
118                                        id: id.clone(),
119                                        delta: serde_json::to_string(input).unwrap_or_default(),
120                                        block_index: idx,
121                                    });
122                                }
123                                // `ContentBlock` is `#[non_exhaustive]`; a future
124                                // block kind we cannot stream is skipped rather than
125                                // panicking the default fallback.
126                                _ => {
127                                    log::warn!(
128                                        "chat_stream fallback skipping unrecognized content block at index {idx}"
129                                    );
130                                }
131                            }
132                        }
133                        yield Ok(StreamDelta::Usage(response.usage));
134                        yield Ok(StreamDelta::Done {
135                            stop_reason: response.stop_reason,
136                        });
137                    }
138                    ChatOutcome::RateLimited(_) => {
139                        yield Ok(StreamDelta::Error {
140                            message: "Rate limited".to_string(),
141                            kind: StreamErrorKind::RateLimited,
142                        });
143                    }
144                    ChatOutcome::InvalidRequest(msg) => {
145                        yield Ok(StreamDelta::Error {
146                            message: msg,
147                            kind: StreamErrorKind::InvalidRequest,
148                        });
149                    }
150                    ChatOutcome::ServerError(msg) => {
151                        yield Ok(StreamDelta::Error {
152                            message: msg,
153                            kind: StreamErrorKind::ServerError,
154                        });
155                    }
156                    // `ChatOutcome` is `#[non_exhaustive]`; an outcome this SDK
157                    // version does not model is surfaced as an unclassified
158                    // (non-recoverable) stream error rather than dropped.
159                    _ => {
160                        yield Ok(StreamDelta::Error {
161                            message: "Unrecognized chat outcome".to_string(),
162                            kind: StreamErrorKind::Unknown,
163                        });
164                    }
165                },
166                Err(e) => yield Err(e),
167            }
168        })
169    }
170
171    fn model(&self) -> &str;
172    fn provider(&self) -> &'static str;
173
174    /// Provider-owned thinking configuration, if any.
175    fn configured_thinking(&self) -> Option<&ThinkingConfig> {
176        None
177    }
178
179    /// Canonical capability metadata for this provider/model, if known.
180    fn capabilities(&self) -> Option<&'static ModelCapabilities> {
181        get_model_capabilities(self.provider(), self.model()).or_else(|| match self.provider() {
182            "openai-responses" | "openai-codex" => get_model_capabilities("openai", self.model()),
183            "vertex" if self.model().starts_with("claude-") => {
184                get_model_capabilities("anthropic", self.model())
185            }
186            "vertex" => get_model_capabilities("gemini", self.model()),
187            _ => None,
188        })
189    }
190
191    /// Validate a thinking configuration against the provider/model capabilities.
192    ///
193    /// # Errors
194    ///
195    /// Returns an error when the requested thinking mode is not supported by
196    /// the active provider/model capability set.
197    fn validate_thinking_config(&self, thinking: Option<&ThinkingConfig>) -> Result<()> {
198        let Some(thinking) = thinking else {
199            return Ok(());
200        };
201
202        if self
203            .capabilities()
204            .is_some_and(|caps| !caps.supports_thinking)
205        {
206            return Err(anyhow::anyhow!(
207                "thinking is not supported for provider={} model={}",
208                self.provider(),
209                self.model()
210            ));
211        }
212
213        if matches!(thinking.mode, ThinkingMode::Adaptive)
214            && !self
215                .capabilities()
216                .is_some_and(|caps| caps.supports_adaptive_thinking)
217        {
218            return Err(anyhow::anyhow!(
219                "adaptive thinking is not supported for provider={} model={}",
220                self.provider(),
221                self.model()
222            ));
223        }
224
225        Ok(())
226    }
227
228    /// Resolve the effective thinking configuration for a request.
229    ///
230    /// Request-level thinking overrides provider-owned defaults when present.
231    ///
232    /// # Errors
233    ///
234    /// Returns an error when the resolved thinking configuration is not
235    /// supported by the active provider/model capability set.
236    fn resolve_thinking_config(
237        &self,
238        request_thinking: Option<&ThinkingConfig>,
239    ) -> Result<Option<ThinkingConfig>> {
240        let thinking = request_thinking.or_else(|| self.configured_thinking());
241        self.validate_thinking_config(thinking)?;
242        Ok(thinking.cloned())
243    }
244
245    /// Default maximum output tokens for this provider/model when the caller
246    /// does not explicitly override `AgentConfig.max_tokens`.
247    fn default_max_tokens(&self) -> u32 {
248        self.capabilities()
249            .and_then(|caps| caps.max_output_tokens)
250            .or_else(|| default_max_output_tokens(self.provider(), self.model()))
251            .unwrap_or(4096)
252    }
253
254    /// How this provider satisfies a structured-output
255    /// ([`ResponseFormat`](agent_sdk_foundation::llm::ResponseFormat)) request.
256    ///
257    /// Providers with a native JSON-schema / JSON-mode wire field
258    /// (OpenAI-family, Gemini, Vertex) report
259    /// [`StructuredOutputSupport::Native`] and consume
260    /// `request.response_format` directly. Providers without one (Anthropic)
261    /// report [`StructuredOutputSupport::ToolForcing`] so the runner forces a
262    /// single "respond" tool whose schema is the output schema. The default
263    /// is the conservative [`StructuredOutputSupport::ToolForcing`], which
264    /// works for any tool-capable provider.
265    fn structured_output_support(&self) -> StructuredOutputSupport {
266        match self.provider() {
267            "openai" | "openai-responses" | "openai-codex" | "gemini" => {
268                StructuredOutputSupport::Native
269            }
270            // Vertex multiplexes Anthropic and Gemini models. Only the Gemini
271            // side has a native structured-output field; Claude-on-Vertex uses
272            // the Messages API shape, which has no `response_format`.
273            "vertex" if !self.model().starts_with("claude-") => StructuredOutputSupport::Native,
274            _ => StructuredOutputSupport::ToolForcing,
275        }
276    }
277}
278
279/// Helper function to consume a stream and collect it into a `ChatResponse`.
280///
281/// This is useful for providers that want to test their streaming implementation
282/// or for cases where you need the full response after streaming.
283///
284/// # Errors
285///
286/// Returns an error if the stream yields an error result.
287pub async fn collect_stream(mut stream: StreamBox<'_>, model: String) -> Result<ChatOutcome> {
288    let mut accumulator = StreamAccumulator::new();
289    let mut last_error: Option<(String, StreamErrorKind)> = None;
290
291    while let Some(result) = stream.next().await {
292        match result {
293            Ok(delta) => {
294                if let StreamDelta::Error { message, kind } = &delta {
295                    last_error = Some((message.clone(), *kind));
296                }
297                accumulator.apply(&delta);
298            }
299            Err(e) => return Err(e),
300        }
301    }
302
303    // If we encountered an error during streaming, map kind directly
304    // to the corresponding `ChatOutcome` variant.  No string-matching
305    // heuristic is needed because the kind already records the
306    // category at the construction site.
307    if let Some((message, kind)) = last_error {
308        return Ok(match kind {
309            // The streaming error channel does not carry a `Retry-After`, so
310            // the reconstructed outcome reports no server-supplied delay.
311            StreamErrorKind::RateLimited => ChatOutcome::RateLimited(None),
312            StreamErrorKind::InvalidRequest => ChatOutcome::InvalidRequest(message),
313            // `StreamErrorKind::ServerError`, plus the `#[non_exhaustive]`
314            // catch-all (`Unknown` / future kinds): an unclassified error is
315            // treated as a (non-recoverable) server error so the caller still
316            // surfaces the failure rather than silently succeeding.
317            _ => ChatOutcome::ServerError(message),
318        });
319    }
320
321    // Extract usage and stop_reason before consuming the accumulator
322    let usage = accumulator.take_usage().unwrap_or(Usage {
323        input_tokens: 0,
324        output_tokens: 0,
325        cached_input_tokens: 0,
326        cache_creation_input_tokens: 0,
327    });
328    let stop_reason = accumulator.take_stop_reason();
329    let content = accumulator.into_content_blocks();
330
331    // Log accumulated response for debugging
332    log::debug!(
333        "Collected stream response: model={} stop_reason={:?} usage={{input_tokens={}, output_tokens={}}} content_blocks={}",
334        model,
335        stop_reason,
336        usage.input_tokens,
337        usage.output_tokens,
338        content.len()
339    );
340    for (i, block) in content.iter().enumerate() {
341        match block {
342            ContentBlock::Text { text } => {
343                log::debug!("  content_block[{}]: Text (len={})", i, text.len());
344            }
345            ContentBlock::Thinking { thinking, .. } => {
346                log::debug!("  content_block[{}]: Thinking (len={})", i, thinking.len());
347            }
348            ContentBlock::RedactedThinking { .. } => {
349                log::debug!("  content_block[{i}]: RedactedThinking");
350            }
351            ContentBlock::ToolUse {
352                id, name, input, ..
353            } => {
354                log::debug!("  content_block[{i}]: ToolUse id={id} name={name} input={input}");
355            }
356            ContentBlock::ToolResult {
357                tool_use_id,
358                content: result_content,
359                is_error,
360            } => {
361                log::debug!(
362                    "  content_block[{}]: ToolResult tool_use_id={} is_error={:?} content_len={}",
363                    i,
364                    tool_use_id,
365                    is_error,
366                    result_content.len()
367                );
368            }
369            ContentBlock::Image { source } => {
370                log::debug!(
371                    "  content_block[{i}]: Image media_type={}",
372                    source.media_type
373                );
374            }
375            ContentBlock::Document { source } => {
376                log::debug!(
377                    "  content_block[{i}]: Document media_type={}",
378                    source.media_type
379                );
380            }
381            // `ContentBlock` is `#[non_exhaustive]`; log unknown future block
382            // kinds generically so the debug dump stays exhaustive.
383            _ => {
384                log::debug!("  content_block[{i}]: <unrecognized block kind>");
385            }
386        }
387    }
388
389    Ok(ChatOutcome::Success(ChatResponse {
390        id: String::new(),
391        content,
392        model,
393        stop_reason,
394        usage,
395    }))
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use async_trait::async_trait;

    /// Minimal provider: only the identity methods matter for these tests,
    /// so `chat` is never exercised.
    struct Stub {
        provider: &'static str,
        model: &'static str,
    }

    #[async_trait]
    impl LlmProvider for Stub {
        async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
            Ok(ChatOutcome::ServerError("unused".to_owned()))
        }

        fn model(&self) -> &str {
            self.model
        }

        fn provider(&self) -> &'static str {
            self.provider
        }
    }

    fn support_for(provider: &'static str, model: &'static str) -> StructuredOutputSupport {
        Stub { provider, model }.structured_output_support()
    }

    #[test]
    fn native_providers_report_native_support() {
        for provider in ["openai", "openai-responses", "openai-codex", "gemini"] {
            assert_eq!(
                support_for(provider, "any-model"),
                StructuredOutputSupport::Native,
                "{provider} should be native"
            );
        }
    }

    #[test]
    fn anthropic_reports_tool_forcing() {
        assert_eq!(
            support_for("anthropic", "claude-sonnet-4-5"),
            StructuredOutputSupport::ToolForcing
        );
    }

    #[test]
    fn vertex_is_native_for_gemini_models_and_tool_forcing_for_claude() {
        assert_eq!(
            support_for("vertex", "gemini-3-flash-preview"),
            StructuredOutputSupport::Native
        );
        assert_eq!(
            support_for("vertex", "claude-sonnet-4-5"),
            StructuredOutputSupport::ToolForcing
        );
    }

    #[test]
    fn unknown_provider_defaults_to_tool_forcing() {
        assert_eq!(
            support_for("some-new-provider", "x"),
            StructuredOutputSupport::ToolForcing
        );
    }

    #[test]
    fn absent_thinking_config_always_validates() {
        // With nothing requested there is nothing to check, whether or not
        // the provider/model appears in the capability table.
        for (provider, model) in [
            ("anthropic", "claude-sonnet-4-5"),
            ("some-new-provider", "x"),
        ] {
            let stub = Stub { provider, model };
            assert!(
                stub.validate_thinking_config(None).is_ok(),
                "{provider}/{model} should accept an absent thinking config"
            );
        }
    }

    #[test]
    fn resolve_without_request_or_configured_thinking_yields_none() {
        // `Stub` keeps the default `configured_thinking` (None), so nothing
        // can be resolved.
        let stub = Stub {
            provider: "some-new-provider",
            model: "x",
        };
        let resolved = stub
            .resolve_thinking_config(None)
            .expect("absent thinking config must resolve");
        assert!(resolved.is_none());
    }
}