#[cfg(feature = "anthropic")]
pub mod anthropic;
#[cfg(feature = "bedrock")]
pub mod bedrock;
pub mod retry;
use crate::events::TokenUsage;
use crate::types::{Message, StopReason, ToolDefinition, ToolUseBlock};
use futures::stream::BoxStream;
use std::error::Error;
#[cfg(feature = "anthropic")]
pub use anthropic::AnthropicProvider;
#[cfg(feature = "bedrock")]
pub use bedrock::{BedrockProvider, InferenceProfile};
pub use retry::{RetryCallback, RetryConfig, RetryInfo};
pub use crate::model::ModelResponse;
/// A single incremental event emitted while streaming a model response.
///
/// Providers yield a sequence of these over a `BoxStream`; the sequence is
/// terminated by exactly one [`StreamEvent::Stop`].
#[derive(Debug, Clone)]
pub enum StreamEvent {
    /// A fragment of assistant text output (may arrive in many small pieces).
    TextDelta(String),
    /// A complete tool invocation requested by the model.
    ToolUse(ToolUseBlock),
    /// A fragment of the model's thinking/reasoning output.
    ThinkingDelta(String),
    /// Final event of a stream: why generation stopped, plus token usage
    /// when the provider reports it.
    Stop {
        stop_reason: StopReason,
        usage: Option<TokenUsage>,
    },
}
/// Errors surfaced by a [`ModelProvider`] implementation.
///
/// Display strings come from the `thiserror` `#[error(...)]` attributes.
#[derive(Debug, thiserror::Error)]
pub enum ProviderError {
    /// Credentials were rejected or missing.
    #[error("Authentication failed: {0}")]
    Authentication(String),
    /// The upstream service throttled the request.
    #[error("Rate limited: {0}")]
    RateLimited(String),
    /// A transport-level failure (connection, DNS, timeout, ...).
    #[error("Network error: {0}")]
    Network(String),
    /// The model itself returned an error for the request.
    #[error("Model error: {0}")]
    Model(String),
    /// The upstream service is temporarily unavailable.
    #[error("Service unavailable: {0}")]
    ServiceUnavailable(String),
    /// The provider was configured incorrectly (bad region, model id, ...).
    #[error("Invalid configuration: {0}")]
    Configuration(String),
    /// Catch-all for errors that fit no other variant; message passed through.
    #[error("{0}")]
    Other(String),
    /// Wraps an arbitrary boxed error (note: `#[from]` gives a blanket
    /// conversion from any `Box<dyn Error + Send + Sync>`).
    #[error("Communication error: {0}")]
    Communication(#[from] Box<dyn Error + Send + Sync>),
}
/// Abstraction over an LLM backend that can generate (and optionally stream)
/// responses to a conversation, with optional tool definitions and a system
/// prompt.
#[async_trait::async_trait]
pub trait ModelProvider: Send + Sync {
    /// Human-readable name of this provider.
    fn name(&self) -> &str;

    /// Maximum number of context tokens the model accepts.
    fn max_context_tokens(&self) -> usize;

    /// Maximum number of tokens the model may emit in one response.
    fn max_output_tokens(&self) -> usize;

    /// Rough token estimate: roughly one token per 4 bytes, rounded up.
    fn estimate_token_count(&self, text: &str) -> usize {
        text.len().div_ceil(4)
    }

    /// Estimate total tokens for a message list: a flat 4-token overhead per
    /// message, plus an estimate over the `Debug` rendering of each content
    /// block (a coarse heuristic, not an exact tokenizer).
    fn estimate_message_tokens(&self, messages: &[Message]) -> usize {
        messages
            .iter()
            .map(|msg| {
                let blocks: usize = msg
                    .content
                    .iter()
                    .map(|block| self.estimate_token_count(&format!("{:?}", block)))
                    .sum();
                4 + blocks
            })
            .sum()
    }

    /// Produce a complete response for the given conversation.
    async fn generate(
        &self,
        messages: Vec<Message>,
        tools: Vec<ToolDefinition>,
        system_prompt: Option<String>,
    ) -> Result<ModelResponse, ProviderError>;

    /// Default streaming implementation for providers without native
    /// streaming: performs a blocking `generate`, then replays the result as
    /// a short synthetic event stream — one `TextDelta` holding all text (if
    /// any), one `ToolUse` per tool call, then a final `Stop`.
    async fn generate_stream(
        &self,
        messages: Vec<Message>,
        tools: Vec<ToolDefinition>,
        system_prompt: Option<String>,
    ) -> Result<BoxStream<'static, Result<StreamEvent, ProviderError>>, ProviderError> {
        let response = self.generate(messages, tools, system_prompt).await?;

        // Single pass over the content: concatenate text, collect tool uses.
        let mut combined_text = String::new();
        let mut tool_events: Vec<Result<StreamEvent, ProviderError>> = Vec::new();
        for block in &response.message.content {
            match block {
                crate::types::ContentBlock::Text(t) => combined_text.push_str(t),
                crate::types::ContentBlock::ToolUse(call) => {
                    tool_events.push(Ok(StreamEvent::ToolUse(call.clone())));
                }
                // Other block kinds are not representable as stream events here.
                _ => {}
            }
        }

        // Assemble in the fixed order: text (if any), tool uses, stop.
        let mut events = Vec::with_capacity(tool_events.len() + 2);
        if !combined_text.is_empty() {
            events.push(Ok(StreamEvent::TextDelta(combined_text)));
        }
        events.extend(tool_events);
        events.push(Ok(StreamEvent::Stop {
            stop_reason: response.stop_reason,
            usage: response.usage,
        }));

        Ok(Box::pin(futures::stream::iter(events)))
    }
}
/// Blanket delegation so an `Arc<dyn ModelProvider>` can itself be used
/// wherever a `ModelProvider` is expected; every method forwards to the
/// shared inner provider.
#[async_trait::async_trait]
impl ModelProvider for std::sync::Arc<dyn ModelProvider> {
    fn name(&self) -> &str {
        self.as_ref().name()
    }

    fn max_context_tokens(&self) -> usize {
        self.as_ref().max_context_tokens()
    }

    fn max_output_tokens(&self) -> usize {
        self.as_ref().max_output_tokens()
    }

    fn estimate_token_count(&self, text: &str) -> usize {
        self.as_ref().estimate_token_count(text)
    }

    fn estimate_message_tokens(&self, messages: &[Message]) -> usize {
        self.as_ref().estimate_message_tokens(messages)
    }

    async fn generate(
        &self,
        messages: Vec<Message>,
        tools: Vec<ToolDefinition>,
        system_prompt: Option<String>,
    ) -> Result<ModelResponse, ProviderError> {
        self.as_ref().generate(messages, tools, system_prompt).await
    }

    async fn generate_stream(
        &self,
        messages: Vec<Message>,
        tools: Vec<ToolDefinition>,
        system_prompt: Option<String>,
    ) -> Result<BoxStream<'static, Result<StreamEvent, ProviderError>>, ProviderError> {
        self.as_ref()
            .generate_stream(messages, tools, system_prompt)
            .await
    }
}