use std::pin::Pin;
use futures_core::Stream;
use tokio_util::sync::CancellationToken;
use crate::message::{InputContent, Message, ToolDef};
use crate::stream::AssistantStreamEvent;
/// Interface implemented by every model backend.
///
/// Implementors must be `Send + Sync`.
pub trait Provider: Send + Sync {
    /// Returns this provider's identifier.
    ///
    /// [`validate_request_capabilities`] compares it with the text before the
    /// first `:` of a request's model string.
    fn id(&self) -> &str;

    /// Returns the [`ModelInfo`] entries this provider reports.
    fn models(&self) -> &[ModelInfo];

    /// Takes ownership of `request` and returns an [`EventStream`] of
    /// assistant events or provider errors.
    fn stream(&self, request: Request) -> EventStream;
}
/// Pinned, boxed, `Send` stream whose items are
/// `Result<AssistantStreamEvent, ProviderError>`.
///
/// This is the return type of [`Provider::stream`].
pub type EventStream =
    Pin<Box<dyn Stream<Item = Result<AssistantStreamEvent, ProviderError>> + Send>>;
/// Everything handed to [`Provider::stream`] for a single call.
pub struct Request {
    /// Model identifier.
    ///
    /// May be written as `<provider id>:<model id>`;
    /// [`validate_request_capabilities`] strips the prefix when it equals the
    /// provider's id.
    pub model: String,
    /// Optional system text.
    ///
    /// NOTE(review): presumably the system prompt; confirm against the
    /// provider implementations.
    pub system: Option<String>,
    /// Messages that make up the request.
    ///
    /// `contains_image_input` scans the `Message::User` entries of this list.
    pub messages: Vec<Message>,
    /// Tool definitions accompanying the request.
    pub tools: Vec<ToolDef>,
    /// Optional token limit.
    ///
    /// NOTE(review): presumably caps generated output tokens; confirm.
    pub max_tokens: Option<u64>,
    /// Optional temperature setting.
    pub temperature: Option<f64>,
    /// Thinking settings; see [`ThinkingConfig`].
    pub thinking: ThinkingConfig,
    /// Stop sequences.
    ///
    /// NOTE(review): presumably strings that end generation; confirm.
    pub stop_sequences: Vec<String>,
    /// Optional free-form JSON value attached to the request.
    pub metadata: Option<serde_json::Value>,
    /// Cancellation token carried with the request.
    ///
    /// NOTE(review): presumably observed by providers to abort streaming;
    /// confirm against the provider implementations.
    pub cancel: CancellationToken,
}
impl Request {
    /// Reports whether any user message carries image content.
    ///
    /// Returns `true` as soon as one `Message::User` entry has an
    /// `InputContent::Image` item in its `content`. Messages of any other
    /// variant are never inspected, so an empty message list yields `false`.
    pub fn contains_image_input(&self) -> bool {
        self.messages.iter().any(|message| {
            // Only user messages are considered; every other variant counts as
            // "no image".
            let Message::User(user) = message else {
                return false;
            };
            user.content
                .iter()
                .any(|part| matches!(part, InputContent::Image { .. }))
        })
    }
}
/// Thinking settings attached to a request.
///
/// The derived `Default` yields `enabled: false` and `budget_tokens: None`.
#[derive(Clone, Debug, Default)]
pub struct ThinkingConfig {
    /// Whether thinking is switched on.
    pub enabled: bool,
    /// Optional token budget.
    ///
    /// NOTE(review): presumably only meaningful when `enabled` is `true`;
    /// confirm against the provider implementations.
    pub budget_tokens: Option<u64>,
}
/// Descriptor for one model, as returned by `Provider::models`.
#[derive(Clone, Debug)]
pub struct ModelInfo {
    /// Model identifier; `validate_request_capabilities` matches request model
    /// strings against this field.
    pub id: String,
    /// Name of the model for display purposes.
    pub display_name: String,
    /// Context window size.
    ///
    /// NOTE(review): presumably measured in tokens; confirm.
    pub context_window: u64,
    /// Maximum output size.
    ///
    /// NOTE(review): presumably measured in tokens; confirm.
    pub max_output_tokens: u64,
    /// Whether the model accepts image input; consulted by
    /// `validate_request_capabilities`.
    pub supports_images: bool,
    /// Whether the model supports streaming.
    pub supports_streaming: bool,
    /// Whether the model supports thinking.
    pub supports_thinking: bool,
}
/// Checks that the model named by `request` can accept the request's input.
///
/// Only image input is validated. The check passes (`Ok(())`) when:
/// - the request carries no image input,
/// - the model id is not found in `provider.models()` (unknown models are not
///   rejected here), or
/// - the matching model has `supports_images` set.
///
/// The model id is taken from `request.model`. If that string has the form
/// `<prefix>:<rest>` (split at the first `:`) and `<prefix>` equals
/// `provider.id()`, the lookup uses `<rest>`; otherwise the whole string is
/// used unchanged.
///
/// # Errors
///
/// Returns [`ProviderError::RequestFailed`] when the request contains image
/// input and the matched model does not support images.
pub fn validate_request_capabilities(
    provider: &dyn Provider,
    request: &Request,
) -> Result<(), ProviderError> {
    if !request.contains_image_input() {
        return Ok(());
    }

    // Drop a leading "<provider id>:" qualifier, but leave any other
    // colon-containing model string intact.
    let requested = request.model.as_str();
    let model_id = match requested.split_once(':') {
        Some((prefix, rest)) if prefix == provider.id() => rest,
        _ => requested,
    };

    match provider.models().iter().find(|info| info.id == model_id) {
        Some(info) if !info.supports_images => Err(ProviderError::RequestFailed(format!(
            "model '{}' for provider '{}' does not support image input",
            info.id,
            provider.id()
        ))),
        // Either the model is unknown to this provider or it accepts images.
        _ => Ok(()),
    }
}
/// Errors produced by providers and by request validation.
///
/// `is_retryable` treats `RateLimited` and `Timeout` as retryable.
#[derive(Debug, thiserror::Error)]
pub enum ProviderError {
    /// The request was rate limited.
    ///
    /// `retry_after_ms` optionally carries a delay.
    /// NOTE(review): presumably milliseconds to wait before retrying, per the
    /// field name; confirm where the value is produced.
    #[error("rate limited")]
    RateLimited { retry_after_ms: Option<u64> },

    /// The request timed out.
    #[error("request timed out")]
    Timeout,

    /// The request failed; the payload is the failure description.
    #[error("request failed: {0}")]
    RequestFailed(String),

    /// The stream reported an error; the payload is the description.
    #[error("stream error: {0}")]
    StreamError(String),

    /// Authentication failed; the payload is the description.
    #[error("authentication failed: {0}")]
    AuthFailed(String),
}
impl ProviderError {
pub fn is_retryable(&self) -> bool {
matches!(
self,
ProviderError::RateLimited { .. } | ProviderError::Timeout
)
}
}
/// Enumerates the provider kinds known to this module.
///
/// The type is `Copy` and comparable with `==`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProviderKind {
    /// OpenAI.
    OpenAI,
    /// Anthropic.
    Anthropic,
    /// Google.
    Google,
    /// Mistral.
    Mistral,
    /// Bedrock.
    Bedrock,
    /// Azure.
    Azure,
}