mixtape_core/provider/mod.rs

//! Model providers for LLM interactions
//!
//! This module contains the `ModelProvider` trait and implementations for
//! different LLM backends (Bedrock, Anthropic, etc.)

#[cfg(feature = "anthropic")]
pub mod anthropic;
#[cfg(feature = "bedrock")]
pub mod bedrock;
pub mod retry;

use crate::events::TokenUsage;
use crate::types::{Message, StopReason, ToolDefinition, ToolUseBlock};
use futures::stream::BoxStream;
use std::error::Error;

// Re-export provider types at the provider level
#[cfg(feature = "anthropic")]
pub use anthropic::AnthropicProvider;
#[cfg(feature = "bedrock")]
pub use bedrock::{BedrockProvider, InferenceProfile};
pub use retry::{RetryCallback, RetryConfig, RetryInfo};

// Re-export ModelResponse from the model module
pub use crate::model::ModelResponse;

/// Events from streaming model responses
#[derive(Debug, Clone)]
pub enum StreamEvent {
    /// Incremental text delta
    TextDelta(String),
    /// Tool use detected
    ToolUse(ToolUseBlock),
    /// Incremental thinking delta (extended thinking)
    ThinkingDelta(String),
    /// Streaming stopped
    Stop {
        /// Why the model stopped
        stop_reason: StopReason,
        /// Token usage for this response (if available)
        usage: Option<TokenUsage>,
    },
}

/// Error types for model providers
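///
/// Callers often branch on the variant, for example to decide whether a
/// request is worth retrying. A minimal sketch (the crate's own retry
/// handling lives in the `retry` module and may classify differently):
///
/// ```ignore
/// fn is_retryable(err: &ProviderError) -> bool {
///     matches!(
///         err,
///         ProviderError::RateLimited(_)
///             | ProviderError::Network(_)
///             | ProviderError::ServiceUnavailable(_)
///     )
/// }
/// ```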
#[derive(Debug, thiserror::Error)]
pub enum ProviderError {
    /// Authentication or authorization failed (expired tokens, invalid credentials, etc.)
    #[error("Authentication failed: {0}")]
    Authentication(String),

    /// Rate limiting or throttling
    #[error("Rate limited: {0}")]
    RateLimited(String),

    /// Network or connectivity issues
    #[error("Network error: {0}")]
    Network(String),

    /// Model-specific errors (content filtered, context too long, etc.)
    #[error("Model error: {0}")]
    Model(String),

    /// Service unavailable or temporary issues
    #[error("Service unavailable: {0}")]
    ServiceUnavailable(String),

    /// Invalid configuration (bad model ID, missing parameters, etc.)
    #[error("Invalid configuration: {0}")]
    Configuration(String),

    /// Other provider-specific errors that don't fit the above categories
    #[error("{0}")]
    Other(String),

    /// Communication error (legacy, kept for compatibility)
    #[error("Communication error: {0}")]
    Communication(#[from] Box<dyn Error + Send + Sync>),
}

/// Trait for model providers
///
/// This trait abstracts over different LLM providers (Bedrock, Anthropic, etc.),
/// allowing the Agent to work with any provider implementation.
///
/// A provider combines model metadata (name, token limits) with API interaction
/// (generate, streaming). Use the builder to create agents:
///
/// ```ignore
/// let agent = Agent::builder()
///     .bedrock(ClaudeSonnet4_5)
///     .build()
///     .await?;
/// ```
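///
/// A custom backend can be plugged in by implementing the required methods.
/// A minimal skeleton (`MyProvider` is hypothetical; the optional
/// `generate_stream` falls back to the default implementation):
///
/// ```ignore
/// struct MyProvider;
///
/// #[async_trait::async_trait]
/// impl ModelProvider for MyProvider {
///     fn name(&self) -> &str {
///         "My Model"
///     }
///
///     fn max_context_tokens(&self) -> usize {
///         200_000
///     }
///
///     fn max_output_tokens(&self) -> usize {
///         8_192
///     }
///
///     async fn generate(
///         &self,
///         messages: Vec<Message>,
///         tools: Vec<ToolDefinition>,
///         system_prompt: Option<String>,
///     ) -> Result<ModelResponse, ProviderError> {
///         // Call the backend API here and map its response and errors
///         // into `ModelResponse` / `ProviderError`.
///         todo!()
///     }
/// }
/// ```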
#[async_trait::async_trait]
pub trait ModelProvider: Send + Sync {
    /// Get the model name for display (e.g., "Claude Sonnet 4.5")
    fn name(&self) -> &str;

    /// Maximum input context tokens for this model
    fn max_context_tokens(&self) -> usize;

    /// Maximum output tokens this model can generate
    fn max_output_tokens(&self) -> usize;

    /// Estimate token count for text
    ///
    /// Providers should implement this to match their model's tokenization.
    /// The default implementation uses a heuristic of ~4 characters per token.
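    ///
    /// For example, with a provider that keeps the default heuristic, an
    /// 11-character string rounds up to 3 tokens:
    ///
    /// ```ignore
    /// assert_eq!(provider.estimate_token_count("hello world"), 3);
    /// ```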
    fn estimate_token_count(&self, text: &str) -> usize {
        text.len().div_ceil(4)
    }

    /// Estimate token count for a conversation
    fn estimate_message_tokens(&self, messages: &[Message]) -> usize {
        let mut total = 0;
        for message in messages {
            total += 4; // Role overhead
            for block in &message.content {
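                // Approximate each block's size from its Debug representation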
                total += self.estimate_token_count(&format!("{:?}", block));
            }
        }
        total
    }

    /// Send a request to the model and get a response
    ///
    /// # Arguments
    /// * `messages` - The conversation history
    /// * `tools` - Available tools for the model to use
    /// * `system_prompt` - Optional system prompt
    async fn generate(
        &self,
        messages: Vec<Message>,
        tools: Vec<ToolDefinition>,
        system_prompt: Option<String>,
    ) -> Result<ModelResponse, ProviderError>;

    /// Send a request and stream the response token-by-token (optional)
    ///
    /// # Arguments
    /// * `messages` - The conversation history
    /// * `tools` - Available tools for the model to use
    /// * `system_prompt` - Optional system prompt
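    ///
    /// A sketch of consuming the stream (`provider`, `messages`, and `tools`
    /// are assumed to be in scope):
    ///
    /// ```ignore
    /// use futures::StreamExt;
    ///
    /// let mut stream = provider.generate_stream(messages, tools, None).await?;
    /// while let Some(event) = stream.next().await {
    ///     match event? {
    ///         StreamEvent::TextDelta(text) => print!("{text}"),
    ///         StreamEvent::ThinkingDelta(_) => {}
    ///         StreamEvent::ToolUse(tool_use) => { /* run the tool */ }
    ///         StreamEvent::Stop { stop_reason, usage } => { /* finish up */ }
    ///     }
    /// }
    /// ```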
    async fn generate_stream(
        &self,
        messages: Vec<Message>,
        tools: Vec<ToolDefinition>,
        system_prompt: Option<String>,
    ) -> Result<BoxStream<'static, Result<StreamEvent, ProviderError>>, ProviderError> {
        // Default implementation: call generate and emit the complete response as a few events
        let response = self.generate(messages, tools, system_prompt).await?;

        // Extract text content and tool uses from the response message
        let mut text_content = String::new();
        let mut tool_uses = Vec::new();

        for content in &response.message.content {
            match content {
                crate::types::ContentBlock::Text(text) => {
                    text_content.push_str(text);
                }
                crate::types::ContentBlock::ToolUse(tool_use) => {
                    tool_uses.push(tool_use.clone());
                }
                _ => {}
            }
        }

        // Create a stream with the complete response
        let mut events = Vec::new();
        if !text_content.is_empty() {
            events.push(Ok(StreamEvent::TextDelta(text_content)));
        }
        for tool_use in tool_uses {
            events.push(Ok(StreamEvent::ToolUse(tool_use)));
        }
        events.push(Ok(StreamEvent::Stop {
            stop_reason: response.stop_reason,
            usage: response.usage,
        }));

        Ok(Box::pin(futures::stream::iter(events)))
    }
}

// Implement ModelProvider for Arc<dyn ModelProvider> to support dynamic dispatch
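//
// A concrete provider can be type-erased behind `Arc` and still be passed
// anywhere an `impl ModelProvider` is expected. A sketch (`my_provider` is
// any concrete provider value):
//
//     let provider: std::sync::Arc<dyn ModelProvider> = std::sync::Arc::new(my_provider);
//     let tokens = provider.estimate_token_count("hello");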
#[async_trait::async_trait]
impl ModelProvider for std::sync::Arc<dyn ModelProvider> {
    fn name(&self) -> &str {
        (**self).name()
    }

    fn max_context_tokens(&self) -> usize {
        (**self).max_context_tokens()
    }

    fn max_output_tokens(&self) -> usize {
        (**self).max_output_tokens()
    }

    fn estimate_token_count(&self, text: &str) -> usize {
        (**self).estimate_token_count(text)
    }

    fn estimate_message_tokens(&self, messages: &[Message]) -> usize {
        (**self).estimate_message_tokens(messages)
    }

    async fn generate(
        &self,
        messages: Vec<Message>,
        tools: Vec<ToolDefinition>,
        system_prompt: Option<String>,
    ) -> Result<ModelResponse, ProviderError> {
        (**self).generate(messages, tools, system_prompt).await
    }

    async fn generate_stream(
        &self,
        messages: Vec<Message>,
        tools: Vec<ToolDefinition>,
        system_prompt: Option<String>,
    ) -> Result<BoxStream<'static, Result<StreamEvent, ProviderError>>, ProviderError> {
        (**self)
            .generate_stream(messages, tools, system_prompt)
            .await
    }
}