1pub mod router;
2pub mod streaming;
3pub mod types;
4
5pub use router::{ModelRouter, ModelTier, TaskComplexity};
6pub use streaming::{StreamAccumulator, StreamBox, StreamDelta};
7pub use types::*;
8
9use anyhow::Result;
10use async_trait::async_trait;
11use futures::StreamExt;
12
13use crate::model_capabilities::{
14 ModelCapabilities, default_max_output_tokens, get_model_capabilities,
15};
16
/// Common interface for chat-completion model backends.
///
/// Implementors must supply [`LlmProvider::chat`], [`LlmProvider::model`],
/// and [`LlmProvider::provider`]; every other method has a default built on
/// top of those.
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Sends one non-streaming chat request and returns its outcome.
    async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome>;

    /// Default streaming adapter: awaits a full `chat` response, then replays
    /// it as a sequence of [`StreamDelta`] events.
    ///
    /// Providers with native streaming should override this. The emulation
    /// yields, in order: one delta per content block (text / thinking / tool
    /// use), then a `Usage` delta, then `Done`. Failure outcomes become a
    /// single `StreamDelta::Error` (rate limits and server errors are marked
    /// recoverable; invalid requests are not).
    fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
        Box::pin(async_stream::stream! {
            match self.chat(request).await {
                Ok(outcome) => match outcome {
                    ChatOutcome::Success(response) => {
                        for (idx, block) in response.content.iter().enumerate() {
                            match block {
                                ContentBlock::Text { text } => {
                                    // Whole text block emitted as one delta.
                                    yield Ok(StreamDelta::TextDelta {
                                        delta: text.clone(),
                                        block_index: idx,
                                    });
                                }
                                ContentBlock::Thinking { thinking, .. } => {
                                    yield Ok(StreamDelta::ThinkingDelta {
                                        delta: thinking.clone(),
                                        block_index: idx,
                                    });
                                }
                                // These block kinds have no streaming-delta
                                // representation here; they are skipped (the
                                // index `idx` still advances with enumerate).
                                ContentBlock::RedactedThinking { .. }
                                | ContentBlock::ToolResult { .. }
                                | ContentBlock::Image { .. }
                                | ContentBlock::Document { .. } => {
                                }
                                ContentBlock::ToolUse { id, name, input, thought_signature } => {
                                    // A tool call is replayed as a start event
                                    // followed by its full input as one delta.
                                    yield Ok(StreamDelta::ToolUseStart {
                                        id: id.clone(),
                                        name: name.clone(),
                                        block_index: idx,
                                        thought_signature: thought_signature.clone(),
                                    });
                                    yield Ok(StreamDelta::ToolInputDelta {
                                        id: id.clone(),
                                        // Serialization failure degrades to an
                                        // empty string rather than erroring.
                                        delta: serde_json::to_string(input).unwrap_or_default(),
                                        block_index: idx,
                                    });
                                }
                            }
                        }
                        yield Ok(StreamDelta::Usage(response.usage));
                        yield Ok(StreamDelta::Done {
                            stop_reason: response.stop_reason,
                        });
                    }
                    ChatOutcome::RateLimited => {
                        yield Ok(StreamDelta::Error {
                            message: "Rate limited".to_string(),
                            recoverable: true,
                        });
                    }
                    ChatOutcome::InvalidRequest(msg) => {
                        yield Ok(StreamDelta::Error {
                            message: msg,
                            recoverable: false,
                        });
                    }
                    ChatOutcome::ServerError(msg) => {
                        yield Ok(StreamDelta::Error {
                            message: msg,
                            recoverable: true,
                        });
                    }
                },
                // Transport-level failure is surfaced as a stream error item.
                Err(e) => yield Err(e),
            }
        })
    }

    /// Model identifier this provider instance targets.
    fn model(&self) -> &str;
    /// Provider key, e.g. used for capability lookup.
    fn provider(&self) -> &'static str;

    /// Provider-level thinking configuration, if any. Defaults to `None`;
    /// implementors may override.
    fn configured_thinking(&self) -> Option<&ThinkingConfig> {
        None
    }

    /// Looks up static capabilities for this provider/model pair, with
    /// fallbacks for aliased providers: "openai-responses" falls back to the
    /// "openai" table, and "vertex" falls back to "anthropic" for claude-*
    /// models or "gemini" otherwise.
    fn capabilities(&self) -> Option<&'static ModelCapabilities> {
        get_model_capabilities(self.provider(), self.model()).or_else(|| match self.provider() {
            "openai-responses" => get_model_capabilities("openai", self.model()),
            "vertex" if self.model().starts_with("claude-") => {
                get_model_capabilities("anthropic", self.model())
            }
            "vertex" => get_model_capabilities("gemini", self.model()),
            _ => None,
        })
    }

    /// Validates a thinking config against this model's capabilities.
    ///
    /// `None` is always valid. Currently the only enforced rule is that
    /// `ThinkingMode::Adaptive` requires `supports_adaptive_thinking`; unknown
    /// capabilities (no entry) are treated as unsupported.
    ///
    /// # Errors
    /// Returns an error naming the provider and model when adaptive thinking
    /// is requested but not supported.
    fn validate_thinking_config(&self, thinking: Option<&ThinkingConfig>) -> Result<()> {
        let Some(thinking) = thinking else {
            return Ok(());
        };

        if matches!(thinking.mode, ThinkingMode::Adaptive)
            && !self
                .capabilities()
                .is_some_and(|caps| caps.supports_adaptive_thinking)
        {
            return Err(anyhow::anyhow!(
                "adaptive thinking is not supported for provider={} model={}",
                self.provider(),
                self.model()
            ));
        }

        Ok(())
    }

    /// Resolves the effective thinking config: the per-request value wins,
    /// falling back to [`LlmProvider::configured_thinking`], then validates it.
    ///
    /// # Errors
    /// Propagates [`LlmProvider::validate_thinking_config`] failures.
    fn resolve_thinking_config(
        &self,
        request_thinking: Option<&ThinkingConfig>,
    ) -> Result<Option<ThinkingConfig>> {
        let thinking = request_thinking.or_else(|| self.configured_thinking());
        self.validate_thinking_config(thinking)?;
        Ok(thinking.cloned())
    }

    /// Default output-token budget: capability table first, then the
    /// per-provider default table, then a hard-coded 4096 floor.
    fn default_max_tokens(&self) -> u32 {
        self.capabilities()
            .and_then(|caps| caps.max_output_tokens)
            .or_else(|| default_max_output_tokens(self.provider(), self.model()))
            .unwrap_or(4096)
    }
}
170
171pub async fn collect_stream(mut stream: StreamBox<'_>, model: String) -> Result<ChatOutcome> {
180 let mut accumulator = StreamAccumulator::new();
181 let mut last_error: Option<(String, bool)> = None;
182
183 while let Some(result) = stream.next().await {
184 match result {
185 Ok(delta) => {
186 if let StreamDelta::Error {
187 message,
188 recoverable,
189 } = &delta
190 {
191 last_error = Some((message.clone(), *recoverable));
192 }
193 accumulator.apply(&delta);
194 }
195 Err(e) => return Err(e),
196 }
197 }
198
199 if let Some((message, recoverable)) = last_error {
201 if !recoverable {
202 return Ok(ChatOutcome::InvalidRequest(message));
203 }
204 if message.contains("Rate limited") || message.contains("rate limit") {
206 return Ok(ChatOutcome::RateLimited);
207 }
208 return Ok(ChatOutcome::ServerError(message));
209 }
210
211 let usage = accumulator.take_usage().unwrap_or(Usage {
213 input_tokens: 0,
214 output_tokens: 0,
215 });
216 let stop_reason = accumulator.take_stop_reason();
217 let content = accumulator.into_content_blocks();
218
219 log::debug!(
221 "Collected stream response: model={} stop_reason={:?} usage={{input_tokens={}, output_tokens={}}} content_blocks={}",
222 model,
223 stop_reason,
224 usage.input_tokens,
225 usage.output_tokens,
226 content.len()
227 );
228 for (i, block) in content.iter().enumerate() {
229 match block {
230 ContentBlock::Text { text } => {
231 log::debug!(" content_block[{}]: Text (len={})", i, text.len());
232 }
233 ContentBlock::Thinking { thinking, .. } => {
234 log::debug!(" content_block[{}]: Thinking (len={})", i, thinking.len());
235 }
236 ContentBlock::RedactedThinking { .. } => {
237 log::debug!(" content_block[{i}]: RedactedThinking");
238 }
239 ContentBlock::ToolUse {
240 id, name, input, ..
241 } => {
242 log::debug!(" content_block[{i}]: ToolUse id={id} name={name} input={input}");
243 }
244 ContentBlock::ToolResult {
245 tool_use_id,
246 content: result_content,
247 is_error,
248 } => {
249 log::debug!(
250 " content_block[{}]: ToolResult tool_use_id={} is_error={:?} content_len={}",
251 i,
252 tool_use_id,
253 is_error,
254 result_content.len()
255 );
256 }
257 ContentBlock::Image { source } => {
258 log::debug!(
259 " content_block[{i}]: Image media_type={}",
260 source.media_type
261 );
262 }
263 ContentBlock::Document { source } => {
264 log::debug!(
265 " content_block[{i}]: Document media_type={}",
266 source.media_type
267 );
268 }
269 }
270 }
271
272 Ok(ChatOutcome::Success(ChatResponse {
273 id: String::new(),
274 content,
275 model,
276 stop_reason,
277 usage,
278 }))
279}