mixtape_core/provider/
mod.rs1#[cfg(feature = "anthropic")]
7pub mod anthropic;
8#[cfg(feature = "bedrock")]
9pub mod bedrock;
10pub mod retry;
11
12use crate::events::TokenUsage;
13use crate::types::{Message, StopReason, ToolDefinition, ToolUseBlock};
14use futures::stream::BoxStream;
15use std::error::Error;
16
17#[cfg(feature = "anthropic")]
19pub use anthropic::AnthropicProvider;
20#[cfg(feature = "bedrock")]
21pub use bedrock::{BedrockProvider, InferenceProfile};
22pub use retry::{RetryCallback, RetryConfig, RetryInfo};
23
24pub use crate::model::ModelResponse;
26
/// An incremental event produced while streaming a model response.
#[derive(Debug, Clone)]
pub enum StreamEvent {
    /// A chunk of assistant text output.
    TextDelta(String),
    /// A complete tool-use request emitted by the model.
    ToolUse(ToolUseBlock),
    /// A chunk of the model's reasoning ("thinking") output.
    ThinkingDelta(String),
    /// Terminal event: the stream has finished.
    Stop {
        /// Why the model stopped generating.
        stop_reason: StopReason,
        /// Token accounting for the request, when the provider reports it.
        usage: Option<TokenUsage>,
    },
}
44
/// Errors that can occur while talking to a model provider backend.
#[derive(Debug, thiserror::Error)]
pub enum ProviderError {
    /// The provider rejected the supplied credentials.
    #[error("Authentication failed: {0}")]
    Authentication(String),

    /// The provider throttled the request (rate limit exceeded).
    #[error("Rate limited: {0}")]
    RateLimited(String),

    /// A transport-level failure while reaching the provider.
    #[error("Network error: {0}")]
    Network(String),

    /// The model itself reported an error for the request.
    #[error("Model error: {0}")]
    Model(String),

    /// The provider's service is temporarily unavailable.
    #[error("Service unavailable: {0}")]
    ServiceUnavailable(String),

    /// The provider was configured incorrectly (bad region, model id, etc.).
    #[error("Invalid configuration: {0}")]
    Configuration(String),

    /// Catch-all for errors that fit no other variant.
    #[error("{0}")]
    Other(String),

    /// Wraps an arbitrary boxed error from lower layers (`?`-convertible
    /// via the `#[from]` impl).
    #[error("Communication error: {0}")]
    Communication(#[from] Box<dyn Error + Send + Sync>),
}
80
81#[async_trait::async_trait]
96pub trait ModelProvider: Send + Sync {
97 fn name(&self) -> &str;
99
100 fn max_context_tokens(&self) -> usize;
102
103 fn max_output_tokens(&self) -> usize;
105
106 fn estimate_token_count(&self, text: &str) -> usize {
111 text.len().div_ceil(4)
112 }
113
114 fn estimate_message_tokens(&self, messages: &[Message]) -> usize {
116 let mut total = 0;
117 for message in messages {
118 total += 4; for block in &message.content {
120 total += self.estimate_token_count(&format!("{:?}", block));
121 }
122 }
123 total
124 }
125
126 async fn generate(
133 &self,
134 messages: Vec<Message>,
135 tools: Vec<ToolDefinition>,
136 system_prompt: Option<String>,
137 ) -> Result<ModelResponse, ProviderError>;
138
139 async fn generate_stream(
146 &self,
147 messages: Vec<Message>,
148 tools: Vec<ToolDefinition>,
149 system_prompt: Option<String>,
150 ) -> Result<BoxStream<'static, Result<StreamEvent, ProviderError>>, ProviderError> {
151 let response = self.generate(messages, tools, system_prompt).await?;
153
154 let mut text_content = String::new();
156 let mut tool_uses = Vec::new();
157
158 for content in &response.message.content {
159 match content {
160 crate::types::ContentBlock::Text(text) => {
161 text_content.push_str(text);
162 }
163 crate::types::ContentBlock::ToolUse(tool_use) => {
164 tool_uses.push(tool_use.clone());
165 }
166 _ => {}
167 }
168 }
169
170 let mut events = Vec::new();
172 if !text_content.is_empty() {
173 events.push(Ok(StreamEvent::TextDelta(text_content)));
174 }
175 for tool_use in tool_uses {
176 events.push(Ok(StreamEvent::ToolUse(tool_use)));
177 }
178 events.push(Ok(StreamEvent::Stop {
179 stop_reason: response.stop_reason,
180 usage: response.usage,
181 }));
182
183 Ok(Box::pin(futures::stream::iter(events)))
184 }
185}
186
187#[async_trait::async_trait]
189impl ModelProvider for std::sync::Arc<dyn ModelProvider> {
190 fn name(&self) -> &str {
191 (**self).name()
192 }
193
194 fn max_context_tokens(&self) -> usize {
195 (**self).max_context_tokens()
196 }
197
198 fn max_output_tokens(&self) -> usize {
199 (**self).max_output_tokens()
200 }
201
202 fn estimate_token_count(&self, text: &str) -> usize {
203 (**self).estimate_token_count(text)
204 }
205
206 fn estimate_message_tokens(&self, messages: &[Message]) -> usize {
207 (**self).estimate_message_tokens(messages)
208 }
209
210 async fn generate(
211 &self,
212 messages: Vec<Message>,
213 tools: Vec<ToolDefinition>,
214 system_prompt: Option<String>,
215 ) -> Result<ModelResponse, ProviderError> {
216 (**self).generate(messages, tools, system_prompt).await
217 }
218
219 async fn generate_stream(
220 &self,
221 messages: Vec<Message>,
222 tools: Vec<ToolDefinition>,
223 system_prompt: Option<String>,
224 ) -> Result<BoxStream<'static, Result<StreamEvent, ProviderError>>, ProviderError> {
225 (**self)
226 .generate_stream(messages, tools, system_prompt)
227 .await
228 }
229}