agent_sdk_providers/provider.rs
1//! LLM provider trait and streaming helpers.
2//!
3//! This module defines the [`LlmProvider`] trait that all LLM backends implement,
4//! as well as the [`collect_stream`] helper for consuming a streaming response.
5
6use agent_sdk_foundation::llm::{
7 ChatOutcome, ChatRequest, ChatResponse, ContentBlock, ThinkingConfig, ThinkingMode, Usage,
8};
9use anyhow::Result;
10use async_trait::async_trait;
11use futures::StreamExt;
12use serde::{Deserialize, Serialize};
13
14use crate::model_capabilities::{
15 ModelCapabilities, default_max_output_tokens, get_model_capabilities,
16};
17use crate::streaming::{StreamAccumulator, StreamBox, StreamDelta, StreamErrorKind};
18
/// Strategy a provider uses to honor a
/// [`ResponseFormat`](agent_sdk_foundation::llm::ResponseFormat)
/// structured-output request.
///
/// The structured-output runner reads this value to decide both how the
/// outgoing request is shaped and which part of the response holds the final
/// structured value.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StructuredOutputSupport {
    /// Schema enforcement is done by the provider itself (JSON-mode /
    /// structured-outputs) once it sees `request.response_format`; the
    /// structured value is the JSON found in the assistant's text output.
    Native,
    /// No native JSON-schema mode is available. The runner instead injects one
    /// forced "respond" tool whose `input_schema` is the output schema and
    /// takes the structured value from that tool call's input.
    ToolForcing,
}
35
36/// A single model entry returned by a provider's live model-listing endpoint.
37///
38/// This is the *dynamic* counterpart to the static
39/// [`ModelCapabilities`] table: it
40/// is populated from the provider's own `/models` API at runtime, so newly
41/// shipped models appear without an SDK code change. Fields beyond `id` are
42/// optional because not every provider's listing endpoint reports them.
43#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
44pub struct ModelInfo {
45 /// The model identifier as the provider's chat endpoint expects it.
46 pub id: String,
47 /// Human-friendly display name, when the listing endpoint provides one.
48 pub display_name: Option<String>,
49 /// Maximum total context window in tokens, when reported.
50 pub context_window: Option<u32>,
51 /// Maximum output tokens per response, when reported.
52 pub max_output_tokens: Option<u32>,
53}
54
55#[async_trait]
56pub trait LlmProvider: Send + Sync {
57 /// Non-streaming chat completion.
58 async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome>;
59
60 /// List the models the provider currently exposes, queried live from the
61 /// provider's own model-listing endpoint.
62 ///
63 /// The default implementation returns an error: model discovery is an
64 /// additive capability, so a provider that has not implemented it stays
65 /// source-compatible while reporting that the operation is unsupported.
66 ///
67 /// # Errors
68 ///
69 /// Returns an error when the provider does not support live model listing,
70 /// or when the underlying HTTP request or response parsing fails.
71 async fn list_models(&self) -> Result<Vec<ModelInfo>> {
72 Err(anyhow::anyhow!(
73 "list_models is not supported for provider {}",
74 self.provider()
75 ))
76 }
77
78 /// Streaming chat completion.
79 ///
80 /// Returns a stream of [`StreamDelta`] events. The default implementation
81 /// calls [`chat()`](Self::chat) and converts the result to a single-chunk stream.
82 ///
83 /// Providers should override this method to provide true streaming support.
84 fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
85 Box::pin(async_stream::stream! {
86 match self.chat(request).await {
87 Ok(outcome) => match outcome {
88 ChatOutcome::Success(response) => {
89 // Emit content as deltas
90 for (idx, block) in response.content.iter().enumerate() {
91 match block {
92 ContentBlock::Text { text } => {
93 yield Ok(StreamDelta::TextDelta {
94 delta: text.clone(),
95 block_index: idx,
96 });
97 }
98 ContentBlock::Thinking { thinking, .. } => {
99 yield Ok(StreamDelta::ThinkingDelta {
100 delta: thinking.clone(),
101 block_index: idx,
102 });
103 }
104 ContentBlock::RedactedThinking { .. }
105 | ContentBlock::ToolResult { .. }
106 | ContentBlock::Image { .. }
107 | ContentBlock::Document { .. } => {
108 // Not streamed in the default implementation
109 }
110 ContentBlock::ToolUse { id, name, input, thought_signature } => {
111 yield Ok(StreamDelta::ToolUseStart {
112 id: id.clone(),
113 name: name.clone(),
114 block_index: idx,
115 thought_signature: thought_signature.clone(),
116 });
117 yield Ok(StreamDelta::ToolInputDelta {
118 id: id.clone(),
119 delta: serde_json::to_string(input).unwrap_or_default(),
120 block_index: idx,
121 });
122 }
123 // `ContentBlock` is `#[non_exhaustive]`; a future
124 // block kind we cannot stream is skipped rather than
125 // panicking the default fallback.
126 _ => {
127 log::warn!(
128 "chat_stream fallback skipping unrecognized content block at index {idx}"
129 );
130 }
131 }
132 }
133 yield Ok(StreamDelta::Usage(response.usage));
134 yield Ok(StreamDelta::Done {
135 stop_reason: response.stop_reason,
136 });
137 }
138 ChatOutcome::RateLimited(_) => {
139 yield Ok(StreamDelta::Error {
140 message: "Rate limited".to_string(),
141 kind: StreamErrorKind::RateLimited,
142 });
143 }
144 ChatOutcome::InvalidRequest(msg) => {
145 yield Ok(StreamDelta::Error {
146 message: msg,
147 kind: StreamErrorKind::InvalidRequest,
148 });
149 }
150 ChatOutcome::ServerError(msg) => {
151 yield Ok(StreamDelta::Error {
152 message: msg,
153 kind: StreamErrorKind::ServerError,
154 });
155 }
156 // `ChatOutcome` is `#[non_exhaustive]`; an outcome this SDK
157 // version does not model is surfaced as an unclassified
158 // (non-recoverable) stream error rather than dropped.
159 _ => {
160 yield Ok(StreamDelta::Error {
161 message: "Unrecognized chat outcome".to_string(),
162 kind: StreamErrorKind::Unknown,
163 });
164 }
165 },
166 Err(e) => yield Err(e),
167 }
168 })
169 }
170
171 fn model(&self) -> &str;
172 fn provider(&self) -> &'static str;
173
174 /// Provider-owned thinking configuration, if any.
175 fn configured_thinking(&self) -> Option<&ThinkingConfig> {
176 None
177 }
178
179 /// Canonical capability metadata for this provider/model, if known.
180 fn capabilities(&self) -> Option<&'static ModelCapabilities> {
181 get_model_capabilities(self.provider(), self.model()).or_else(|| match self.provider() {
182 "openai-responses" | "openai-codex" => get_model_capabilities("openai", self.model()),
183 "vertex" if self.model().starts_with("claude-") => {
184 get_model_capabilities("anthropic", self.model())
185 }
186 "vertex" => get_model_capabilities("gemini", self.model()),
187 _ => None,
188 })
189 }
190
191 /// Validate a thinking configuration against the provider/model capabilities.
192 ///
193 /// # Errors
194 ///
195 /// Returns an error when the requested thinking mode is not supported by
196 /// the active provider/model capability set.
197 fn validate_thinking_config(&self, thinking: Option<&ThinkingConfig>) -> Result<()> {
198 let Some(thinking) = thinking else {
199 return Ok(());
200 };
201
202 if self
203 .capabilities()
204 .is_some_and(|caps| !caps.supports_thinking)
205 {
206 return Err(anyhow::anyhow!(
207 "thinking is not supported for provider={} model={}",
208 self.provider(),
209 self.model()
210 ));
211 }
212
213 if matches!(thinking.mode, ThinkingMode::Adaptive)
214 && !self
215 .capabilities()
216 .is_some_and(|caps| caps.supports_adaptive_thinking)
217 {
218 return Err(anyhow::anyhow!(
219 "adaptive thinking is not supported for provider={} model={}",
220 self.provider(),
221 self.model()
222 ));
223 }
224
225 Ok(())
226 }
227
228 /// Resolve the effective thinking configuration for a request.
229 ///
230 /// Request-level thinking overrides provider-owned defaults when present.
231 ///
232 /// # Errors
233 ///
234 /// Returns an error when the resolved thinking configuration is not
235 /// supported by the active provider/model capability set.
236 fn resolve_thinking_config(
237 &self,
238 request_thinking: Option<&ThinkingConfig>,
239 ) -> Result<Option<ThinkingConfig>> {
240 let thinking = request_thinking.or_else(|| self.configured_thinking());
241 self.validate_thinking_config(thinking)?;
242 Ok(thinking.cloned())
243 }
244
245 /// Default maximum output tokens for this provider/model when the caller
246 /// does not explicitly override `AgentConfig.max_tokens`.
247 fn default_max_tokens(&self) -> u32 {
248 self.capabilities()
249 .and_then(|caps| caps.max_output_tokens)
250 .or_else(|| default_max_output_tokens(self.provider(), self.model()))
251 .unwrap_or(4096)
252 }
253
254 /// How this provider satisfies a structured-output
255 /// ([`ResponseFormat`](agent_sdk_foundation::llm::ResponseFormat)) request.
256 ///
257 /// Providers with a native JSON-schema / JSON-mode wire field
258 /// (OpenAI-family, Gemini, Vertex) report
259 /// [`StructuredOutputSupport::Native`] and consume
260 /// `request.response_format` directly. Providers without one (Anthropic)
261 /// report [`StructuredOutputSupport::ToolForcing`] so the runner forces a
262 /// single "respond" tool whose schema is the output schema. The default
263 /// is the conservative [`StructuredOutputSupport::ToolForcing`], which
264 /// works for any tool-capable provider.
265 fn structured_output_support(&self) -> StructuredOutputSupport {
266 match self.provider() {
267 "openai" | "openai-responses" | "openai-codex" | "gemini" => {
268 StructuredOutputSupport::Native
269 }
270 // Vertex multiplexes Anthropic and Gemini models. Only the Gemini
271 // side has a native structured-output field; Claude-on-Vertex uses
272 // the Messages API shape, which has no `response_format`.
273 "vertex" if !self.model().starts_with("claude-") => StructuredOutputSupport::Native,
274 _ => StructuredOutputSupport::ToolForcing,
275 }
276 }
277}
278
279/// Helper function to consume a stream and collect it into a `ChatResponse`.
280///
281/// This is useful for providers that want to test their streaming implementation
282/// or for cases where you need the full response after streaming.
283///
284/// # Errors
285///
286/// Returns an error if the stream yields an error result.
287pub async fn collect_stream(mut stream: StreamBox<'_>, model: String) -> Result<ChatOutcome> {
288 let mut accumulator = StreamAccumulator::new();
289 let mut last_error: Option<(String, StreamErrorKind)> = None;
290
291 while let Some(result) = stream.next().await {
292 match result {
293 Ok(delta) => {
294 if let StreamDelta::Error { message, kind } = &delta {
295 last_error = Some((message.clone(), *kind));
296 }
297 accumulator.apply(&delta);
298 }
299 Err(e) => return Err(e),
300 }
301 }
302
303 // If we encountered an error during streaming, map kind directly
304 // to the corresponding `ChatOutcome` variant. No string-matching
305 // heuristic is needed because the kind already records the
306 // category at the construction site.
307 if let Some((message, kind)) = last_error {
308 return Ok(match kind {
309 // The streaming error channel does not carry a `Retry-After`, so
310 // the reconstructed outcome reports no server-supplied delay.
311 StreamErrorKind::RateLimited => ChatOutcome::RateLimited(None),
312 StreamErrorKind::InvalidRequest => ChatOutcome::InvalidRequest(message),
313 // `StreamErrorKind::ServerError`, plus the `#[non_exhaustive]`
314 // catch-all (`Unknown` / future kinds): an unclassified error is
315 // treated as a (non-recoverable) server error so the caller still
316 // surfaces the failure rather than silently succeeding.
317 _ => ChatOutcome::ServerError(message),
318 });
319 }
320
321 // Extract usage and stop_reason before consuming the accumulator
322 let usage = accumulator.take_usage().unwrap_or(Usage {
323 input_tokens: 0,
324 output_tokens: 0,
325 cached_input_tokens: 0,
326 cache_creation_input_tokens: 0,
327 });
328 let stop_reason = accumulator.take_stop_reason();
329 let content = accumulator.into_content_blocks();
330
331 // Log accumulated response for debugging
332 log::debug!(
333 "Collected stream response: model={} stop_reason={:?} usage={{input_tokens={}, output_tokens={}}} content_blocks={}",
334 model,
335 stop_reason,
336 usage.input_tokens,
337 usage.output_tokens,
338 content.len()
339 );
340 for (i, block) in content.iter().enumerate() {
341 match block {
342 ContentBlock::Text { text } => {
343 log::debug!(" content_block[{}]: Text (len={})", i, text.len());
344 }
345 ContentBlock::Thinking { thinking, .. } => {
346 log::debug!(" content_block[{}]: Thinking (len={})", i, thinking.len());
347 }
348 ContentBlock::RedactedThinking { .. } => {
349 log::debug!(" content_block[{i}]: RedactedThinking");
350 }
351 ContentBlock::ToolUse {
352 id, name, input, ..
353 } => {
354 log::debug!(" content_block[{i}]: ToolUse id={id} name={name} input={input}");
355 }
356 ContentBlock::ToolResult {
357 tool_use_id,
358 content: result_content,
359 is_error,
360 } => {
361 log::debug!(
362 " content_block[{}]: ToolResult tool_use_id={} is_error={:?} content_len={}",
363 i,
364 tool_use_id,
365 is_error,
366 result_content.len()
367 );
368 }
369 ContentBlock::Image { source } => {
370 log::debug!(
371 " content_block[{i}]: Image media_type={}",
372 source.media_type
373 );
374 }
375 ContentBlock::Document { source } => {
376 log::debug!(
377 " content_block[{i}]: Document media_type={}",
378 source.media_type
379 );
380 }
381 // `ContentBlock` is `#[non_exhaustive]`; log unknown future block
382 // kinds generically so the debug dump stays exhaustive.
383 _ => {
384 log::debug!(" content_block[{i}]: <unrecognized block kind>");
385 }
386 }
387 }
388
389 Ok(ChatOutcome::Success(ChatResponse {
390 id: String::new(),
391 content,
392 model,
393 stop_reason,
394 usage,
395 }))
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use async_trait::async_trait;

    /// Bare-bones provider that only carries a provider key and a model id,
    /// which is all the default trait methods under test look at.
    struct Stub {
        provider: &'static str,
        model: &'static str,
    }

    #[async_trait]
    impl LlmProvider for Stub {
        // Never called by these tests; it exists only to satisfy the trait.
        async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
            Ok(ChatOutcome::ServerError("unused".to_owned()))
        }

        fn model(&self) -> &str {
            self.model
        }

        fn provider(&self) -> &'static str {
            self.provider
        }
    }

    /// Structured-output support reported for the given provider/model pair.
    fn support_for(provider: &'static str, model: &'static str) -> StructuredOutputSupport {
        let stub = Stub { provider, model };
        stub.structured_output_support()
    }

    #[test]
    fn native_providers_report_native_support() {
        const NATIVE_PROVIDERS: [&str; 4] =
            ["openai", "openai-responses", "openai-codex", "gemini"];

        for provider in NATIVE_PROVIDERS {
            assert_eq!(
                support_for(provider, "any-model"),
                StructuredOutputSupport::Native,
                "{provider} should be native"
            );
        }
    }

    #[test]
    fn anthropic_reports_tool_forcing() {
        let reported = support_for("anthropic", "claude-sonnet-4-5");
        assert_eq!(reported, StructuredOutputSupport::ToolForcing);
    }

    #[test]
    fn vertex_is_native_for_gemini_models_and_tool_forcing_for_claude() {
        // Vertex picks its mode from the model id rather than the provider key.
        let expectations = [
            ("gemini-3-flash-preview", StructuredOutputSupport::Native),
            ("claude-sonnet-4-5", StructuredOutputSupport::ToolForcing),
        ];

        for (model, expected) in expectations {
            assert_eq!(support_for("vertex", model), expected);
        }
    }

    #[test]
    fn unknown_provider_defaults_to_tool_forcing() {
        let reported = support_for("some-new-provider", "x");
        assert_eq!(reported, StructuredOutputSupport::ToolForcing);
    }
}