use std::pin::Pin;
use std::time::Duration;
use async_trait::async_trait;
use futures_core::Stream;
use serde::{Deserialize, Serialize};
use crate::auth::{ApiKey, AuthStore};
use crate::error::Result;
use crate::message::Message;
use crate::model::{Model, ModelMeta};
use crate::stream::StreamEvent;
/// A model backend that can authenticate and stream responses.
///
/// The `Send + Sync` supertraits let one provider instance be shared across
/// threads and tasks.
#[async_trait]
pub trait Provider: Send + Sync {
    /// Issues a request for `model` and returns a pinned, boxed stream of
    /// [`StreamEvent`]s, each wrapped in this crate's [`Result`].
    ///
    /// `context` and `options` are passed by value; `api_key` is borrowed.
    /// Because the returned `Box<dyn Stream + Send>` has no explicit lifetime,
    /// the trait object is implicitly `'static`. The stream therefore cannot
    /// borrow from `self`, `model` or `api_key`, so implementors must copy or
    /// own anything they need.
    fn stream(
        &self,
        model: &Model,
        context: Context,
        options: RequestOptions,
        api_key: &str,
    ) -> Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>;

    /// Produces the [`ApiKey`] this provider should use, looked up via `auth`.
    async fn resolve_auth(&self, auth: &AuthStore) -> Result<ApiKey>;

    /// Identifier string for this provider.
    fn id(&self) -> &str;

    /// Metadata for the models this provider exposes.
    fn models(&self) -> &[ModelMeta];

    /// Transport features this provider supports.
    ///
    /// Unless overridden, this is `TransportCapabilities::default()`, the
    /// stateless streaming-HTTP profile.
    fn transport_capabilities(&self) -> TransportCapabilities {
        Default::default()
    }
}
/// Transport-level features a [`Provider`] advertises via
/// [`Provider::transport_capabilities`].
///
/// Field order here is also the serde serialization order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct TransportCapabilities {
    /// Whether a request/response exchange is supported.
    pub request_response: bool,
    /// Whether incremental streaming of the response is supported.
    pub streaming: bool,
    /// Continuation support; see [`ContinuationMode`].
    pub continuation: ContinuationMode,
    /// Persistent-session support; see [`PersistentSessionMode`].
    pub persistent_session: PersistentSessionMode,
    /// Cancellation behavior; see [`CancellationMode`].
    pub cancellation: CancellationMode,
    /// Resumability behavior; see [`ResumabilityMode`].
    pub resumability: ResumabilityMode,
}
impl TransportCapabilities {
    /// Conservative profile for a plain streaming HTTP transport.
    ///
    /// Both request/response and streaming are enabled. There is no
    /// continuation and no persistent session. Cancellation only drops the
    /// local stream, and resuming means restarting the request.
    pub const fn stateless_streaming_http() -> Self {
        let continuation = ContinuationMode::None;
        let persistent_session = PersistentSessionMode::None;
        let cancellation = CancellationMode::DropLocalStream;
        let resumability = ResumabilityMode::RestartRequest;

        Self {
            request_response: true,
            streaming: true,
            continuation,
            persistent_session,
            cancellation,
            resumability,
        }
    }
}
impl Default for TransportCapabilities {
fn default() -> Self {
Self::stateless_streaming_http()
}
}
/// Continuation support reported in [`TransportCapabilities::continuation`].
///
/// Serialized in `snake_case`: `"none"`, `"provider_managed_id"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ContinuationMode {
    /// No continuation; this is what the stateless streaming-HTTP profile uses.
    None,
    /// NOTE(review): presumably the provider issues an id that a later request
    /// can reference to continue server-side state — confirm with implementors.
    ProviderManagedId,
}
/// Persistent-session support reported in
/// [`TransportCapabilities::persistent_session`].
///
/// Serialized in `snake_case`: `"none"` and `"web_socket"`. Note the
/// underscore: serde splits `WebSocket` at the capital `S`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PersistentSessionMode {
    /// No persistent session; this is what the stateless streaming-HTTP profile uses.
    None,
    /// NOTE(review): presumably a long-lived WebSocket connection — confirm.
    WebSocket,
}
/// Cancellation behavior reported in [`TransportCapabilities::cancellation`].
///
/// Serialized in `snake_case`: `"drop_local_stream"`, `"provider_abort"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CancellationMode {
    /// Cancellation drops the local stream; this is what the stateless
    /// streaming-HTTP profile uses.
    DropLocalStream,
    /// NOTE(review): presumably an explicit abort is sent to the provider — confirm.
    ProviderAbort,
}
/// Resumability behavior reported in [`TransportCapabilities::resumability`].
///
/// Serialized in `snake_case`: `"restart_request"`, `"resume_provider_state"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ResumabilityMode {
    /// Resuming means re-issuing the request; this is what the stateless
    /// streaming-HTTP profile uses.
    RestartRequest,
    /// NOTE(review): presumably resumes from provider-held state — confirm.
    ResumeProviderState,
}
/// Conversation input passed by value to [`Provider::stream`].
///
/// `Default` yields an empty message list.
#[derive(Debug, Clone, Default)]
pub struct Context {
    /// Conversation messages. NOTE(review): presumably oldest first — confirm.
    pub messages: Vec<Message>,
}
#[derive(Debug, Clone)]
pub struct RequestOptions {
pub thinking_level: ThinkingLevel,
pub max_tokens: Option<u32>,
pub temperature: Option<f32>,
pub system_prompt: String,
pub tools: Vec<ToolDefinition>,
pub cache_options: CacheOptions,
pub effort: Option<EffortLevel>,
}
impl Default for RequestOptions {
fn default() -> Self {
Self {
thinking_level: ThinkingLevel::Off,
max_tokens: None,
temperature: None,
system_prompt: String::new(),
tools: Vec::new(),
cache_options: CacheOptions::default(),
effort: None,
}
}
}
/// Effort hint carried in [`RequestOptions::effort`].
///
/// Serialized in lowercase: `"low"`, `"medium"`, `"high"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum EffortLevel {
    Low,
    Medium,
    High,
}
/// Thinking level carried in [`RequestOptions::thinking_level`].
///
/// Serialized in lowercase, so `XHigh` becomes `"xhigh"`.
/// The default is `Off`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ThinkingLevel {
    /// Default level.
    #[default]
    Off,
    Minimal,
    Low,
    Medium,
    High,
    /// Serialized as `"xhigh"`.
    XHigh,
}
/// Prompt-caching settings carried in [`RequestOptions::cache_options`].
///
/// The derived `Default` sets every flag to `false` and
/// `cache_recent_turns` to `0`.
#[derive(Debug, Clone, Default)]
pub struct CacheOptions {
    /// Whether to request caching of the system prompt.
    pub cache_system_prompt: bool,
    /// Whether to request caching of tool definitions.
    pub cache_tools: bool,
    /// NOTE(review): presumably how many of the most recent turns to mark
    /// cacheable — confirm with provider implementations.
    pub cache_recent_turns: usize,
    /// NOTE(review): presumably requests a longer cache TTL where supported — confirm.
    pub extended_ttl: bool,
    /// NOTE(review): presumably widens cache sharing scope — confirm semantics.
    pub global_scope: bool,
}
/// A tool offered to the model via [`RequestOptions::tools`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool name.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// Parameter description as arbitrary JSON.
    /// NOTE(review): presumably a JSON Schema object — confirm.
    pub parameters: serde_json::Value,
}
/// Retry configuration for provider requests.
///
/// This type only holds the settings; the code that applies them lives
/// elsewhere.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Maximum number of retries.
    /// NOTE(review): presumably excludes the initial attempt — confirm at call site.
    pub max_retries: u32,
    /// NOTE(review): presumably the initial backoff delay — confirm.
    pub base_delay: Duration,
    /// NOTE(review): presumably the upper bound on any single backoff delay — confirm.
    pub max_delay: Duration,
    /// Failure categories that are eligible for a retry.
    pub retry_on: Vec<RetryCondition>,
}

impl Default for RetryPolicy {
    /// Default policy:
    /// - 3 retries
    /// - 1 s base delay
    /// - 30 s max delay
    /// - retries on rate limits, server errors, timeouts and connection errors
    fn default() -> Self {
        Self {
            max_retries: 3,
            base_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(30),
            retry_on: vec![
                RetryCondition::RateLimit,
                RetryCondition::ServerError,
                RetryCondition::Timeout,
                RetryCondition::ConnectionError,
            ],
        }
    }
}

/// A category of failure a [`RetryPolicy`] may retry on.
///
/// `Copy` and `Hash` are derived for consistency with the file's other
/// fieldless enums, so conditions can be passed by value and stored in sets
/// or maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RetryCondition {
    RateLimit,
    ServerError,
    Timeout,
    ConnectionError,
}
#[cfg(test)]
mod transport_capability_tests {
    use super::*;

    /// The default capabilities must equal the conservative stateless
    /// streaming-HTTP profile in every field.
    #[test]
    fn default_transport_capabilities_are_conservative_streaming_http() {
        let expected = TransportCapabilities {
            request_response: true,
            streaming: true,
            continuation: ContinuationMode::None,
            persistent_session: PersistentSessionMode::None,
            cancellation: CancellationMode::DropLocalStream,
            resumability: ResumabilityMode::RestartRequest,
        };
        assert_eq!(TransportCapabilities::default(), expected);
    }
}