1pub mod attachments;
2pub mod router;
3pub mod streaming;
4pub mod types;
5
6pub use router::{ModelRouter, ModelTier, TaskComplexity};
7pub use streaming::{StreamAccumulator, StreamBox, StreamDelta};
8pub use types::*;
9
10use anyhow::Result;
11use async_trait::async_trait;
12use futures::StreamExt;
13
14use crate::model_capabilities::{
15 ModelCapabilities, default_max_output_tokens, get_model_capabilities,
16};
17
/// A chat-capable LLM backend (Anthropic, OpenAI, Vertex, ...).
///
/// Implementors must supply the non-streaming [`chat`](Self::chat) call plus
/// the `model`/`provider` identity accessors; streaming, capability lookup,
/// and thinking-config resolution all have default implementations layered
/// on top of those.
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Send a complete chat request and wait for the full response.
    async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome>;

    /// Stream deltas for `request`.
    ///
    /// Default implementation: runs the non-streaming [`chat`](Self::chat)
    /// once and replays the finished response as a synthetic delta stream,
    /// so providers without native streaming still satisfy the interface.
    fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
        Box::pin(async_stream::stream! {
            match self.chat(request).await {
                Ok(outcome) => match outcome {
                    ChatOutcome::Success(response) => {
                        // Replay each finished content block as one full-sized
                        // delta, preserving the block's index in the response.
                        for (idx, block) in response.content.iter().enumerate() {
                            match block {
                                ContentBlock::Text { text } => {
                                    yield Ok(StreamDelta::TextDelta {
                                        delta: text.clone(),
                                        block_index: idx,
                                    });
                                }
                                ContentBlock::Thinking { thinking, .. } => {
                                    yield Ok(StreamDelta::ThinkingDelta {
                                        delta: thinking.clone(),
                                        block_index: idx,
                                    });
                                }
                                // These block kinds have no delta equivalent in
                                // this synthetic stream; they are deliberately
                                // dropped (empty arm).
                                ContentBlock::RedactedThinking { .. }
                                | ContentBlock::ToolResult { .. }
                                | ContentBlock::Image { .. }
                                | ContentBlock::Document { .. } => {
                                }
                                // Tool calls are replayed as a start event
                                // followed by the complete input serialized as
                                // a single JSON delta.
                                ContentBlock::ToolUse { id, name, input, thought_signature } => {
                                    yield Ok(StreamDelta::ToolUseStart {
                                        id: id.clone(),
                                        name: name.clone(),
                                        block_index: idx,
                                        thought_signature: thought_signature.clone(),
                                    });
                                    yield Ok(StreamDelta::ToolInputDelta {
                                        id: id.clone(),
                                        // Serialization failure degrades to an
                                        // empty input string rather than erroring.
                                        delta: serde_json::to_string(input).unwrap_or_default(),
                                        block_index: idx,
                                    });
                                }
                            }
                        }
                        yield Ok(StreamDelta::Usage(response.usage));
                        yield Ok(StreamDelta::Done {
                            stop_reason: response.stop_reason,
                        });
                    }
                    // Non-success outcomes surface as in-band Error deltas;
                    // `recoverable` marks whether a retry could succeed.
                    ChatOutcome::RateLimited => {
                        yield Ok(StreamDelta::Error {
                            message: "Rate limited".to_string(),
                            recoverable: true,
                        });
                    }
                    ChatOutcome::InvalidRequest(msg) => {
                        yield Ok(StreamDelta::Error {
                            message: msg,
                            recoverable: false,
                        });
                    }
                    ChatOutcome::ServerError(msg) => {
                        yield Ok(StreamDelta::Error {
                            message: msg,
                            recoverable: true,
                        });
                    }
                },
                // Transport-level failure: propagate the error itself.
                Err(e) => yield Err(e),
            }
        })
    }

    /// Model identifier this provider instance is bound to.
    fn model(&self) -> &str;
    /// Stable provider name (e.g. the keys matched in `capabilities`).
    fn provider(&self) -> &'static str;

    /// Provider-level default thinking configuration, if any.
    /// Used as a fallback by [`resolve_thinking_config`](Self::resolve_thinking_config).
    fn configured_thinking(&self) -> Option<&ThinkingConfig> {
        None
    }

    /// Look up static capabilities for this provider/model pair.
    ///
    /// Falls back to the canonical capability tables for aliased providers:
    /// the OpenAI responses/codex surfaces reuse "openai" entries, and Vertex
    /// reuses "anthropic" for claude-* models and "gemini" otherwise.
    fn capabilities(&self) -> Option<&'static ModelCapabilities> {
        get_model_capabilities(self.provider(), self.model()).or_else(|| match self.provider() {
            "openai-responses" | "openai-codex" => get_model_capabilities("openai", self.model()),
            "vertex" if self.model().starts_with("claude-") => {
                get_model_capabilities("anthropic", self.model())
            }
            "vertex" => get_model_capabilities("gemini", self.model()),
            _ => None,
        })
    }

    /// Reject thinking configurations the current model cannot honor.
    ///
    /// Currently only checks adaptive mode: it is an error when capabilities
    /// are unknown or report no adaptive-thinking support. `None` is always
    /// valid.
    fn validate_thinking_config(&self, thinking: Option<&ThinkingConfig>) -> Result<()> {
        let Some(thinking) = thinking else {
            return Ok(());
        };

        if matches!(thinking.mode, ThinkingMode::Adaptive)
            && !self
                .capabilities()
                .is_some_and(|caps| caps.supports_adaptive_thinking)
        {
            return Err(anyhow::anyhow!(
                "adaptive thinking is not supported for provider={} model={}",
                self.provider(),
                self.model()
            ));
        }

        Ok(())
    }

    /// Pick the effective thinking config: the request's value wins, otherwise
    /// the provider's configured default; the result is validated before being
    /// returned (cloned).
    fn resolve_thinking_config(
        &self,
        request_thinking: Option<&ThinkingConfig>,
    ) -> Result<Option<ThinkingConfig>> {
        let thinking = request_thinking.or_else(|| self.configured_thinking());
        self.validate_thinking_config(thinking)?;
        Ok(thinking.cloned())
    }

    /// Default max output tokens: model capabilities first, then the
    /// per-provider default table, then a hard-coded 4096 floor.
    fn default_max_tokens(&self) -> u32 {
        self.capabilities()
            .and_then(|caps| caps.max_output_tokens)
            .or_else(|| default_max_output_tokens(self.provider(), self.model()))
            .unwrap_or(4096)
    }
}
171
172pub async fn collect_stream(mut stream: StreamBox<'_>, model: String) -> Result<ChatOutcome> {
181 let mut accumulator = StreamAccumulator::new();
182 let mut last_error: Option<(String, bool)> = None;
183
184 while let Some(result) = stream.next().await {
185 match result {
186 Ok(delta) => {
187 if let StreamDelta::Error {
188 message,
189 recoverable,
190 } = &delta
191 {
192 last_error = Some((message.clone(), *recoverable));
193 }
194 accumulator.apply(&delta);
195 }
196 Err(e) => return Err(e),
197 }
198 }
199
200 if let Some((message, recoverable)) = last_error {
202 if !recoverable {
203 return Ok(ChatOutcome::InvalidRequest(message));
204 }
205 if message.contains("Rate limited") || message.contains("rate limit") {
207 return Ok(ChatOutcome::RateLimited);
208 }
209 return Ok(ChatOutcome::ServerError(message));
210 }
211
212 let usage = accumulator.take_usage().unwrap_or(Usage {
214 input_tokens: 0,
215 output_tokens: 0,
216 });
217 let stop_reason = accumulator.take_stop_reason();
218 let content = accumulator.into_content_blocks();
219
220 log::debug!(
222 "Collected stream response: model={} stop_reason={:?} usage={{input_tokens={}, output_tokens={}}} content_blocks={}",
223 model,
224 stop_reason,
225 usage.input_tokens,
226 usage.output_tokens,
227 content.len()
228 );
229 for (i, block) in content.iter().enumerate() {
230 match block {
231 ContentBlock::Text { text } => {
232 log::debug!(" content_block[{}]: Text (len={})", i, text.len());
233 }
234 ContentBlock::Thinking { thinking, .. } => {
235 log::debug!(" content_block[{}]: Thinking (len={})", i, thinking.len());
236 }
237 ContentBlock::RedactedThinking { .. } => {
238 log::debug!(" content_block[{i}]: RedactedThinking");
239 }
240 ContentBlock::ToolUse {
241 id, name, input, ..
242 } => {
243 log::debug!(" content_block[{i}]: ToolUse id={id} name={name} input={input}");
244 }
245 ContentBlock::ToolResult {
246 tool_use_id,
247 content: result_content,
248 is_error,
249 } => {
250 log::debug!(
251 " content_block[{}]: ToolResult tool_use_id={} is_error={:?} content_len={}",
252 i,
253 tool_use_id,
254 is_error,
255 result_content.len()
256 );
257 }
258 ContentBlock::Image { source } => {
259 log::debug!(
260 " content_block[{i}]: Image media_type={}",
261 source.media_type
262 );
263 }
264 ContentBlock::Document { source } => {
265 log::debug!(
266 " content_block[{i}]: Document media_type={}",
267 source.media_type
268 );
269 }
270 }
271 }
272
273 Ok(ChatOutcome::Success(ChatResponse {
274 id: String::new(),
275 content,
276 model,
277 stop_reason,
278 usage,
279 }))
280}