use std::pin::Pin;
use std::sync::Arc;
use async_trait::async_trait;
use futures::Stream;
use crate::messages::Message;
/// Errors that a [`Provider`] implementation can report.
///
/// Every variant carries human-readable context; the `Display` text of each
/// variant is given by its `#[error(...)]` attribute.
#[derive(Debug, thiserror::Error)]
pub enum ProviderError {
    /// The HTTP request itself failed; the payload describes the failure.
    #[error("HTTP request failed: {0}")]
    Http(String),

    /// The remote API returned an error response.
    ///
    /// `status` is presumably the HTTP status code (TODO confirm against
    /// implementors); `body` is the response body as received.
    #[error("API error {status}: {body}")]
    Api { status: u16, body: String },

    /// A failure reported while producing or consuming a streamed response.
    #[error("Stream error: {0}")]
    Stream(String),

    /// An operation ran out of time; the payload describes what timed out.
    #[error("Timeout: {0}")]
    Timeout(String),

    /// Catch-all for any other failure; its message is displayed verbatim.
    #[error("{0}")]
    Other(String),
}
/// A chat-completion backend.
///
/// Implementors must be `Send + Sync` so that a provider can be shared across
/// tasks (for example behind an `Arc`).
#[async_trait]
pub trait Provider: Send + Sync {
    /// Sends `messages` and waits for the complete reply as a single [`Message`].
    ///
    /// * `tools` – optional tool definitions as raw JSON values (presumably
    ///   function/tool schemas; the exact shape is up to the implementor).
    /// * `tool_choice` – a tool-selection directive passed through as a string.
    /// * `max_tokens` – optional cap on the reply length; `None` leaves it to
    ///   the implementor.
    /// * `temperature` – sampling temperature.
    ///
    /// # Errors
    /// Returns a [`ProviderError`] if the request cannot be completed.
    async fn chat_completion(
        &self,
        messages: &[Message],
        tools: Option<&[serde_json::Value]>,
        tool_choice: &str,
        max_tokens: Option<u32>,
        temperature: f32,
    ) -> Result<Message, ProviderError>;

    /// Like [`Provider::chat_completion`], but yields the reply as a stream of
    /// text chunks.
    ///
    /// # Errors
    /// The outer `Result` reports failures that occur before a stream can be
    /// returned; each stream item carries its own `Result` for failures that
    /// happen mid-stream.
    async fn chat_completion_stream(
        &self,
        messages: &[Message],
        tools: Option<&[serde_json::Value]>,
        tool_choice: &str,
        max_tokens: Option<u32>,
        temperature: f32,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<String, ProviderError>> + Send>>, ProviderError>;

    /// Computes an embedding vector for `text`.
    ///
    /// The default implementation does not support embeddings and always
    /// returns `None`; providers that can embed should override it.
    async fn embed(&self, text: &str) -> Option<Vec<f32>> {
        let _ = text;
        None
    }

    /// Returns a message recorded by the implementor, presumably the one
    /// assembled from the most recent streamed reply (TODO confirm), or
    /// `None` if none is available.
    fn last_stream_message(&self) -> Option<Message>;

    /// Returns usage information recorded by the implementor, presumably for
    /// the most recent request (TODO confirm), or `None` if none is available.
    fn last_usage(&self) -> Option<crate::messages::Usage>;
}
/// Lets a boxed provider (including `Box<dyn Provider>`) be used wherever a
/// [`Provider`] is expected, by forwarding every method to the boxed value.
///
/// Every trait method, including the ones with default bodies, is forwarded
/// explicitly. Otherwise the wrapper would use the trait's default and hide
/// the inner provider's own implementation.
#[async_trait]
impl<T: Provider + ?Sized> Provider for Box<T> {
    async fn chat_completion(
        &self,
        messages: &[Message],
        tools: Option<&[serde_json::Value]>,
        tool_choice: &str,
        max_tokens: Option<u32>,
        temperature: f32,
    ) -> Result<Message, ProviderError> {
        (**self)
            .chat_completion(messages, tools, tool_choice, max_tokens, temperature)
            .await
    }

    async fn chat_completion_stream(
        &self,
        messages: &[Message],
        tools: Option<&[serde_json::Value]>,
        tool_choice: &str,
        max_tokens: Option<u32>,
        temperature: f32,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<String, ProviderError>> + Send>>, ProviderError> {
        (**self)
            .chat_completion_stream(messages, tools, tool_choice, max_tokens, temperature)
            .await
    }

    fn last_stream_message(&self) -> Option<Message> {
        (**self).last_stream_message()
    }

    fn last_usage(&self) -> Option<crate::messages::Usage> {
        (**self).last_usage()
    }

    // Must be forwarded explicitly: the trait's default body returns `None`,
    // which would silently disable embeddings for any boxed provider that
    // overrides `embed`.
    async fn embed(&self, text: &str) -> Option<Vec<f32>> {
        (**self).embed(text).await
    }
}
/// Lets a shared provider (`Arc<T>`, including `Arc<dyn Provider>`) be used
/// wherever a [`Provider`] is expected, by forwarding every method to the
/// pointed-to value.
///
/// Every trait method, including the ones with default bodies, is forwarded
/// explicitly. Otherwise the wrapper would use the trait's default and hide
/// the inner provider's own implementation.
#[async_trait]
impl<T: Provider + ?Sized> Provider for Arc<T> {
    async fn chat_completion(
        &self,
        messages: &[Message],
        tools: Option<&[serde_json::Value]>,
        tool_choice: &str,
        max_tokens: Option<u32>,
        temperature: f32,
    ) -> Result<Message, ProviderError> {
        (**self)
            .chat_completion(messages, tools, tool_choice, max_tokens, temperature)
            .await
    }

    async fn chat_completion_stream(
        &self,
        messages: &[Message],
        tools: Option<&[serde_json::Value]>,
        tool_choice: &str,
        max_tokens: Option<u32>,
        temperature: f32,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<String, ProviderError>> + Send>>, ProviderError> {
        (**self)
            .chat_completion_stream(messages, tools, tool_choice, max_tokens, temperature)
            .await
    }

    fn last_stream_message(&self) -> Option<Message> {
        (**self).last_stream_message()
    }

    fn last_usage(&self) -> Option<crate::messages::Usage> {
        (**self).last_usage()
    }

    // Must be forwarded explicitly: the trait's default body returns `None`,
    // which would silently disable embeddings for any shared provider that
    // overrides `embed`.
    async fn embed(&self, text: &str) -> Option<Vec<f32>> {
        (**self).embed(text).await
    }
}