1use agent_sdk_foundation::llm::{
7 ChatOutcome, ChatRequest, ChatResponse, ContentBlock, ThinkingConfig, ThinkingMode, Usage,
8};
9use anyhow::Result;
10use async_trait::async_trait;
11use futures::StreamExt;
12
13use crate::model_capabilities::{
14 ModelCapabilities, default_max_output_tokens, get_model_capabilities,
15};
16use crate::streaming::{StreamAccumulator, StreamBox, StreamDelta, StreamErrorKind};
17
/// How a provider is expected to produce structured (schema-shaped) output.
///
/// Returned by `LlmProvider::structured_output_support`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StructuredOutputSupport {
    /// Reported for the `openai`, `openai-responses`, `openai-codex` and
    /// `gemini` providers, and for `vertex` when the model is not a
    /// `claude-` model. Presumably the backend accepts an output schema
    /// directly (confirm against the provider implementations).
    Native,
    /// Reported for every other provider/model combination. Presumably
    /// structured output is obtained by forcing a tool call (confirm against
    /// the provider implementations).
    ToolForcing,
}
34
35#[async_trait]
36pub trait LlmProvider: Send + Sync {
37 async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome>;
39
40 fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
47 Box::pin(async_stream::stream! {
48 match self.chat(request).await {
49 Ok(outcome) => match outcome {
50 ChatOutcome::Success(response) => {
51 for (idx, block) in response.content.iter().enumerate() {
53 match block {
54 ContentBlock::Text { text } => {
55 yield Ok(StreamDelta::TextDelta {
56 delta: text.clone(),
57 block_index: idx,
58 });
59 }
60 ContentBlock::Thinking { thinking, .. } => {
61 yield Ok(StreamDelta::ThinkingDelta {
62 delta: thinking.clone(),
63 block_index: idx,
64 });
65 }
66 ContentBlock::RedactedThinking { .. }
67 | ContentBlock::ToolResult { .. }
68 | ContentBlock::Image { .. }
69 | ContentBlock::Document { .. } => {
70 }
72 ContentBlock::ToolUse { id, name, input, thought_signature } => {
73 yield Ok(StreamDelta::ToolUseStart {
74 id: id.clone(),
75 name: name.clone(),
76 block_index: idx,
77 thought_signature: thought_signature.clone(),
78 });
79 yield Ok(StreamDelta::ToolInputDelta {
80 id: id.clone(),
81 delta: serde_json::to_string(input).unwrap_or_default(),
82 block_index: idx,
83 });
84 }
85 _ => {
89 log::warn!(
90 "chat_stream fallback skipping unrecognized content block at index {idx}"
91 );
92 }
93 }
94 }
95 yield Ok(StreamDelta::Usage(response.usage));
96 yield Ok(StreamDelta::Done {
97 stop_reason: response.stop_reason,
98 });
99 }
100 ChatOutcome::RateLimited => {
101 yield Ok(StreamDelta::Error {
102 message: "Rate limited".to_string(),
103 kind: StreamErrorKind::RateLimited,
104 });
105 }
106 ChatOutcome::InvalidRequest(msg) => {
107 yield Ok(StreamDelta::Error {
108 message: msg,
109 kind: StreamErrorKind::InvalidRequest,
110 });
111 }
112 ChatOutcome::ServerError(msg) => {
113 yield Ok(StreamDelta::Error {
114 message: msg,
115 kind: StreamErrorKind::ServerError,
116 });
117 }
118 _ => {
122 yield Ok(StreamDelta::Error {
123 message: "Unrecognized chat outcome".to_string(),
124 kind: StreamErrorKind::Unknown,
125 });
126 }
127 },
128 Err(e) => yield Err(e),
129 }
130 })
131 }
132
133 fn model(&self) -> &str;
134 fn provider(&self) -> &'static str;
135
136 fn configured_thinking(&self) -> Option<&ThinkingConfig> {
138 None
139 }
140
141 fn capabilities(&self) -> Option<&'static ModelCapabilities> {
143 get_model_capabilities(self.provider(), self.model()).or_else(|| match self.provider() {
144 "openai-responses" | "openai-codex" => get_model_capabilities("openai", self.model()),
145 "vertex" if self.model().starts_with("claude-") => {
146 get_model_capabilities("anthropic", self.model())
147 }
148 "vertex" => get_model_capabilities("gemini", self.model()),
149 _ => None,
150 })
151 }
152
153 fn validate_thinking_config(&self, thinking: Option<&ThinkingConfig>) -> Result<()> {
160 let Some(thinking) = thinking else {
161 return Ok(());
162 };
163
164 if self
165 .capabilities()
166 .is_some_and(|caps| !caps.supports_thinking)
167 {
168 return Err(anyhow::anyhow!(
169 "thinking is not supported for provider={} model={}",
170 self.provider(),
171 self.model()
172 ));
173 }
174
175 if matches!(thinking.mode, ThinkingMode::Adaptive)
176 && !self
177 .capabilities()
178 .is_some_and(|caps| caps.supports_adaptive_thinking)
179 {
180 return Err(anyhow::anyhow!(
181 "adaptive thinking is not supported for provider={} model={}",
182 self.provider(),
183 self.model()
184 ));
185 }
186
187 Ok(())
188 }
189
190 fn resolve_thinking_config(
199 &self,
200 request_thinking: Option<&ThinkingConfig>,
201 ) -> Result<Option<ThinkingConfig>> {
202 let thinking = request_thinking.or_else(|| self.configured_thinking());
203 self.validate_thinking_config(thinking)?;
204 Ok(thinking.cloned())
205 }
206
207 fn default_max_tokens(&self) -> u32 {
210 self.capabilities()
211 .and_then(|caps| caps.max_output_tokens)
212 .or_else(|| default_max_output_tokens(self.provider(), self.model()))
213 .unwrap_or(4096)
214 }
215
216 fn structured_output_support(&self) -> StructuredOutputSupport {
228 match self.provider() {
229 "openai" | "openai-responses" | "openai-codex" | "gemini" => {
230 StructuredOutputSupport::Native
231 }
232 "vertex" if !self.model().starts_with("claude-") => StructuredOutputSupport::Native,
236 _ => StructuredOutputSupport::ToolForcing,
237 }
238 }
239}
240
241pub async fn collect_stream(mut stream: StreamBox<'_>, model: String) -> Result<ChatOutcome> {
250 let mut accumulator = StreamAccumulator::new();
251 let mut last_error: Option<(String, StreamErrorKind)> = None;
252
253 while let Some(result) = stream.next().await {
254 match result {
255 Ok(delta) => {
256 if let StreamDelta::Error { message, kind } = &delta {
257 last_error = Some((message.clone(), *kind));
258 }
259 accumulator.apply(&delta);
260 }
261 Err(e) => return Err(e),
262 }
263 }
264
265 if let Some((message, kind)) = last_error {
270 return Ok(match kind {
271 StreamErrorKind::RateLimited => ChatOutcome::RateLimited,
272 StreamErrorKind::InvalidRequest => ChatOutcome::InvalidRequest(message),
273 _ => ChatOutcome::ServerError(message),
278 });
279 }
280
281 let usage = accumulator.take_usage().unwrap_or(Usage {
283 input_tokens: 0,
284 output_tokens: 0,
285 cached_input_tokens: 0,
286 cache_creation_input_tokens: 0,
287 });
288 let stop_reason = accumulator.take_stop_reason();
289 let content = accumulator.into_content_blocks();
290
291 log::debug!(
293 "Collected stream response: model={} stop_reason={:?} usage={{input_tokens={}, output_tokens={}}} content_blocks={}",
294 model,
295 stop_reason,
296 usage.input_tokens,
297 usage.output_tokens,
298 content.len()
299 );
300 for (i, block) in content.iter().enumerate() {
301 match block {
302 ContentBlock::Text { text } => {
303 log::debug!(" content_block[{}]: Text (len={})", i, text.len());
304 }
305 ContentBlock::Thinking { thinking, .. } => {
306 log::debug!(" content_block[{}]: Thinking (len={})", i, thinking.len());
307 }
308 ContentBlock::RedactedThinking { .. } => {
309 log::debug!(" content_block[{i}]: RedactedThinking");
310 }
311 ContentBlock::ToolUse {
312 id, name, input, ..
313 } => {
314 log::debug!(" content_block[{i}]: ToolUse id={id} name={name} input={input}");
315 }
316 ContentBlock::ToolResult {
317 tool_use_id,
318 content: result_content,
319 is_error,
320 } => {
321 log::debug!(
322 " content_block[{}]: ToolResult tool_use_id={} is_error={:?} content_len={}",
323 i,
324 tool_use_id,
325 is_error,
326 result_content.len()
327 );
328 }
329 ContentBlock::Image { source } => {
330 log::debug!(
331 " content_block[{i}]: Image media_type={}",
332 source.media_type
333 );
334 }
335 ContentBlock::Document { source } => {
336 log::debug!(
337 " content_block[{i}]: Document media_type={}",
338 source.media_type
339 );
340 }
341 _ => {
344 log::debug!(" content_block[{i}]: <unrecognized block kind>");
345 }
346 }
347 }
348
349 Ok(ChatOutcome::Success(ChatResponse {
350 id: String::new(),
351 content,
352 model,
353 stop_reason,
354 usage,
355 }))
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use async_trait::async_trait;

    /// Provider with a configurable identity; these tests never call `chat`.
    struct FakeProvider {
        provider: &'static str,
        model: &'static str,
    }

    #[async_trait]
    impl LlmProvider for FakeProvider {
        async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
            Ok(ChatOutcome::ServerError(String::from("unused")))
        }

        fn model(&self) -> &str {
            self.model
        }

        fn provider(&self) -> &'static str {
            self.provider
        }
    }

    /// Structured-output support reported for the given provider/model pair.
    fn support_for(provider: &'static str, model: &'static str) -> StructuredOutputSupport {
        let fake = FakeProvider { provider, model };
        fake.structured_output_support()
    }

    #[test]
    fn native_providers_report_native_support() {
        let native_providers = ["openai", "openai-responses", "openai-codex", "gemini"];
        for provider in native_providers {
            let reported = support_for(provider, "any-model");
            assert_eq!(
                reported,
                StructuredOutputSupport::Native,
                "{provider} should be native"
            );
        }
    }

    #[test]
    fn anthropic_reports_tool_forcing() {
        let reported = support_for("anthropic", "claude-sonnet-4-5");
        assert_eq!(reported, StructuredOutputSupport::ToolForcing);
    }

    #[test]
    fn vertex_is_native_for_gemini_models_and_tool_forcing_for_claude() {
        let expectations = [
            ("gemini-3-flash-preview", StructuredOutputSupport::Native),
            ("claude-sonnet-4-5", StructuredOutputSupport::ToolForcing),
        ];
        for (model, expected) in expectations {
            assert_eq!(support_for("vertex", model), expected);
        }
    }

    #[test]
    fn unknown_provider_defaults_to_tool_forcing() {
        let reported = support_for("some-new-provider", "x");
        assert_eq!(reported, StructuredOutputSupport::ToolForcing);
    }
}