1pub mod attachments;
2pub mod router;
3pub mod streaming;
4pub mod types;
5
6pub use router::{ModelRouter, ModelTier, TaskComplexity};
7pub use streaming::{StreamAccumulator, StreamBox, StreamDelta};
8pub use types::*;
9
10use anyhow::Result;
11use async_trait::async_trait;
12use futures::StreamExt;
13
14use crate::model_capabilities::{
15 ModelCapabilities, default_max_output_tokens, get_model_capabilities,
16};
17
/// Common interface implemented by every LLM backend.
///
/// Only [`chat`](Self::chat), [`model`](Self::model) and
/// [`provider`](Self::provider) are required; the remaining methods have
/// default implementations built on top of them.
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Send a single, non-streaming chat request and return its outcome.
    async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome>;

    /// Default streaming implementation: performs one blocking `chat` call
    /// and replays the finished response as a synthetic stream of deltas.
    /// Providers with native streaming support should override this.
    fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
        Box::pin(async_stream::stream! {
            match self.chat(request).await {
                Ok(outcome) => match outcome {
                    ChatOutcome::Success(response) => {
                        // Re-emit each content block as if it had streamed in,
                        // preserving the original block indices.
                        for (idx, block) in response.content.iter().enumerate() {
                            match block {
                                ContentBlock::Text { text } => {
                                    yield Ok(StreamDelta::TextDelta {
                                        delta: text.clone(),
                                        block_index: idx,
                                    });
                                }
                                ContentBlock::Thinking { thinking, .. } => {
                                    yield Ok(StreamDelta::ThinkingDelta {
                                        delta: thinking.clone(),
                                        block_index: idx,
                                    });
                                }
                                // These block kinds have no streaming-delta
                                // representation here, so they are skipped.
                                ContentBlock::RedactedThinking { .. }
                                | ContentBlock::ToolResult { .. }
                                | ContentBlock::Image { .. }
                                | ContentBlock::Document { .. } => {
                                }
                                ContentBlock::ToolUse { id, name, input, thought_signature } => {
                                    // A tool call is replayed as a start event
                                    // followed by the complete JSON input as a
                                    // single input delta.
                                    yield Ok(StreamDelta::ToolUseStart {
                                        id: id.clone(),
                                        name: name.clone(),
                                        block_index: idx,
                                        thought_signature: thought_signature.clone(),
                                    });
                                    yield Ok(StreamDelta::ToolInputDelta {
                                        id: id.clone(),
                                        delta: serde_json::to_string(input).unwrap_or_default(),
                                        block_index: idx,
                                    });
                                }
                            }
                        }
                        // Usage and terminal Done are emitted last, matching
                        // the order a real stream would produce them in.
                        yield Ok(StreamDelta::Usage(response.usage));
                        yield Ok(StreamDelta::Done {
                            stop_reason: response.stop_reason,
                        });
                    }
                    ChatOutcome::RateLimited => {
                        yield Ok(StreamDelta::Error {
                            message: "Rate limited".to_string(),
                            recoverable: true,
                        });
                    }
                    ChatOutcome::InvalidRequest(msg) => {
                        yield Ok(StreamDelta::Error {
                            message: msg,
                            recoverable: false,
                        });
                    }
                    ChatOutcome::ServerError(msg) => {
                        yield Ok(StreamDelta::Error {
                            message: msg,
                            recoverable: true,
                        });
                    }
                },
                Err(e) => yield Err(e),
            }
        })
    }

    /// Model identifier as sent to the backend.
    fn model(&self) -> &str;
    /// Short static provider name (e.g. "vertex", "openai-responses").
    fn provider(&self) -> &'static str;

    /// Thinking configuration baked into the provider at construction time,
    /// if any. Used as a fallback when the request itself carries none.
    fn configured_thinking(&self) -> Option<&ThinkingConfig> {
        None
    }

    /// Look up static capabilities for this provider/model pair.
    ///
    /// Falls back through aliased provider namespaces: the OpenAI
    /// Responses/Codex providers share the "openai" capability table, and
    /// Vertex routes to "anthropic" for `claude-*` models and to "gemini"
    /// for everything else.
    fn capabilities(&self) -> Option<&'static ModelCapabilities> {
        get_model_capabilities(self.provider(), self.model()).or_else(|| match self.provider() {
            "openai-responses" | "openai-codex" => get_model_capabilities("openai", self.model()),
            "vertex" if self.model().starts_with("claude-") => {
                get_model_capabilities("anthropic", self.model())
            }
            "vertex" => get_model_capabilities("gemini", self.model()),
            _ => None,
        })
    }

    /// Validate a thinking configuration against this model's capabilities.
    ///
    /// # Errors
    /// Fails when thinking is requested but the model's capabilities mark it
    /// unsupported, or when adaptive mode is requested without adaptive
    /// support. Note the deliberate asymmetry: unknown capabilities permit
    /// plain thinking but reject adaptive mode (the conservative choice).
    fn validate_thinking_config(&self, thinking: Option<&ThinkingConfig>) -> Result<()> {
        // No thinking requested — nothing to validate.
        let Some(thinking) = thinking else {
            return Ok(());
        };

        if self
            .capabilities()
            .is_some_and(|caps| !caps.supports_thinking)
        {
            return Err(anyhow::anyhow!(
                "thinking is not supported for provider={} model={}",
                self.provider(),
                self.model()
            ));
        }

        if matches!(thinking.mode, ThinkingMode::Adaptive)
            && !self
                .capabilities()
                .is_some_and(|caps| caps.supports_adaptive_thinking)
        {
            return Err(anyhow::anyhow!(
                "adaptive thinking is not supported for provider={} model={}",
                self.provider(),
                self.model()
            ));
        }

        Ok(())
    }

    /// Resolve the effective thinking configuration: the request-level config
    /// wins, otherwise the provider's configured default applies. The chosen
    /// config is validated before being returned.
    ///
    /// # Errors
    /// Propagates validation failures from
    /// [`validate_thinking_config`](Self::validate_thinking_config).
    fn resolve_thinking_config(
        &self,
        request_thinking: Option<&ThinkingConfig>,
    ) -> Result<Option<ThinkingConfig>> {
        let thinking = request_thinking.or_else(|| self.configured_thinking());
        self.validate_thinking_config(thinking)?;
        Ok(thinking.cloned())
    }

    /// Max output tokens to use when the request does not specify one:
    /// the capability table first, then the per-provider default table,
    /// then a hard-coded floor of 4096.
    fn default_max_tokens(&self) -> u32 {
        self.capabilities()
            .and_then(|caps| caps.max_output_tokens)
            .or_else(|| default_max_output_tokens(self.provider(), self.model()))
            .unwrap_or(4096)
    }
}
182
183pub async fn collect_stream(mut stream: StreamBox<'_>, model: String) -> Result<ChatOutcome> {
192 let mut accumulator = StreamAccumulator::new();
193 let mut last_error: Option<(String, bool)> = None;
194
195 while let Some(result) = stream.next().await {
196 match result {
197 Ok(delta) => {
198 if let StreamDelta::Error {
199 message,
200 recoverable,
201 } = &delta
202 {
203 last_error = Some((message.clone(), *recoverable));
204 }
205 accumulator.apply(&delta);
206 }
207 Err(e) => return Err(e),
208 }
209 }
210
211 if let Some((message, recoverable)) = last_error {
213 if !recoverable {
214 return Ok(ChatOutcome::InvalidRequest(message));
215 }
216 if message.contains("Rate limited") || message.contains("rate limit") {
218 return Ok(ChatOutcome::RateLimited);
219 }
220 return Ok(ChatOutcome::ServerError(message));
221 }
222
223 let usage = accumulator.take_usage().unwrap_or(Usage {
225 input_tokens: 0,
226 output_tokens: 0,
227 cached_input_tokens: 0,
228 });
229 let stop_reason = accumulator.take_stop_reason();
230 let content = accumulator.into_content_blocks();
231
232 log::debug!(
234 "Collected stream response: model={} stop_reason={:?} usage={{input_tokens={}, output_tokens={}}} content_blocks={}",
235 model,
236 stop_reason,
237 usage.input_tokens,
238 usage.output_tokens,
239 content.len()
240 );
241 for (i, block) in content.iter().enumerate() {
242 match block {
243 ContentBlock::Text { text } => {
244 log::debug!(" content_block[{}]: Text (len={})", i, text.len());
245 }
246 ContentBlock::Thinking { thinking, .. } => {
247 log::debug!(" content_block[{}]: Thinking (len={})", i, thinking.len());
248 }
249 ContentBlock::RedactedThinking { .. } => {
250 log::debug!(" content_block[{i}]: RedactedThinking");
251 }
252 ContentBlock::ToolUse {
253 id, name, input, ..
254 } => {
255 log::debug!(" content_block[{i}]: ToolUse id={id} name={name} input={input}");
256 }
257 ContentBlock::ToolResult {
258 tool_use_id,
259 content: result_content,
260 is_error,
261 } => {
262 log::debug!(
263 " content_block[{}]: ToolResult tool_use_id={} is_error={:?} content_len={}",
264 i,
265 tool_use_id,
266 is_error,
267 result_content.len()
268 );
269 }
270 ContentBlock::Image { source } => {
271 log::debug!(
272 " content_block[{i}]: Image media_type={}",
273 source.media_type
274 );
275 }
276 ContentBlock::Document { source } => {
277 log::debug!(
278 " content_block[{i}]: Document media_type={}",
279 source.media_type
280 );
281 }
282 }
283 }
284
285 Ok(ChatOutcome::Success(ChatResponse {
286 id: String::new(),
287 content,
288 model,
289 stop_reason,
290 usage,
291 }))
292}