1use std::pin::Pin;
2use std::time::Duration;
3
4use async_trait::async_trait;
5use futures_core::Stream;
6use serde::{Deserialize, Serialize};
7
8use crate::auth::{ApiKey, AuthStore};
9use crate::error::Result;
10use crate::message::Message;
11use crate::model::{Model, ModelMeta};
12use crate::stream::StreamEvent;
13
14#[async_trait]
19pub trait Provider: Send + Sync {
20 fn stream(
22 &self,
23 model: &Model,
24 context: Context,
25 options: RequestOptions,
26 api_key: &str,
27 ) -> Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>;
28
29 async fn resolve_auth(&self, auth: &AuthStore) -> Result<ApiKey>;
31
32 fn id(&self) -> &str;
34
35 fn models(&self) -> &[ModelMeta];
37
38 fn transport_capabilities(&self) -> TransportCapabilities {
42 TransportCapabilities::default()
43 }
44}
45
46#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
52pub struct TransportCapabilities {
53 pub request_response: bool,
54 pub streaming: bool,
55 pub continuation: ContinuationMode,
56 pub persistent_session: PersistentSessionMode,
57 pub cancellation: CancellationMode,
58 pub resumability: ResumabilityMode,
59}
60
61impl TransportCapabilities {
62 pub const fn stateless_streaming_http() -> Self {
63 Self {
64 request_response: true,
65 streaming: true,
66 continuation: ContinuationMode::None,
67 persistent_session: PersistentSessionMode::None,
68 cancellation: CancellationMode::DropLocalStream,
69 resumability: ResumabilityMode::RestartRequest,
70 }
71 }
72}
73
74impl Default for TransportCapabilities {
75 fn default() -> Self {
76 Self::stateless_streaming_http()
77 }
78}
79
80#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
81#[serde(rename_all = "snake_case")]
82pub enum ContinuationMode {
83 None,
84 ProviderManagedId,
85}
86
87#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
88#[serde(rename_all = "snake_case")]
89pub enum PersistentSessionMode {
90 None,
91 WebSocket,
92}
93
94#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
95#[serde(rename_all = "snake_case")]
96pub enum CancellationMode {
97 DropLocalStream,
98 ProviderAbort,
99}
100
101#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
102#[serde(rename_all = "snake_case")]
103pub enum ResumabilityMode {
104 RestartRequest,
105 ResumeProviderState,
106}
107
108#[derive(Debug, Clone, Default)]
110pub struct Context {
111 pub messages: Vec<Message>,
112}
113
114#[derive(Debug, Clone)]
116pub struct RequestOptions {
117 pub thinking_level: ThinkingLevel,
118 pub max_tokens: Option<u32>,
119 pub temperature: Option<f32>,
120 pub system_prompt: String,
121 pub tools: Vec<ToolDefinition>,
122 pub cache_options: CacheOptions,
123 pub effort: Option<EffortLevel>,
125}
126
127impl Default for RequestOptions {
128 fn default() -> Self {
129 Self {
130 thinking_level: ThinkingLevel::Off,
131 max_tokens: None,
132 temperature: None,
133 system_prompt: String::new(),
134 tools: Vec::new(),
135 cache_options: CacheOptions::default(),
136 effort: None,
137 }
138 }
139}
140
141#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
145#[serde(rename_all = "lowercase")]
146pub enum EffortLevel {
147 Low,
148 Medium,
149 High,
150}
151
152#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
154#[serde(rename_all = "lowercase")]
155pub enum ThinkingLevel {
156 #[default]
158 Off,
159 Minimal,
161 Low,
163 Medium,
165 High,
167 XHigh,
169}
170
/// Caching preferences attached to a request.
///
/// `Default` turns every flag off and sets `cache_recent_turns` to `0`.
/// Exact provider behavior for each flag is implementation-specific. Confirm
/// it in the provider implementations.
#[derive(Debug, Clone, Default)]
pub struct CacheOptions {
    /// Whether the system prompt should be marked for caching.
    pub cache_system_prompt: bool,
    /// Whether tool definitions should be marked for caching.
    pub cache_tools: bool,
    /// Presumably the number of most-recent turns to mark cacheable
    /// (TODO confirm).
    pub cache_recent_turns: usize,
    /// Whether to request an extended cache lifetime (duration
    /// provider-defined).
    pub extended_ttl: bool,
    /// Whether to request a wider cache scope (semantics provider-defined;
    /// confirm).
    pub global_scope: bool,
}
185
186#[derive(Debug, Clone, Serialize, Deserialize)]
188pub struct ToolDefinition {
189 pub name: String,
190 pub description: String,
191 pub parameters: serde_json::Value,
192}
193
194#[derive(Debug, Clone)]
196pub struct RetryPolicy {
197 pub max_retries: u32,
198 pub base_delay: Duration,
199 pub max_delay: Duration,
200 pub retry_on: Vec<RetryCondition>,
201}
202
203impl Default for RetryPolicy {
204 fn default() -> Self {
205 Self {
206 max_retries: 3,
207 base_delay: Duration::from_secs(1),
208 max_delay: Duration::from_secs(30),
209 retry_on: vec![
210 RetryCondition::RateLimit,
211 RetryCondition::ServerError,
212 RetryCondition::Timeout,
213 RetryCondition::ConnectionError,
214 ],
215 }
216 }
217}
218
/// A class of failure that a retry policy may treat as retryable.
///
/// This is a fieldless enum, so it is `Copy` (cheap to pass by value) and
/// `Hash` (usable in sets/maps). Both additions are backward-compatible.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RetryCondition {
    /// Rate limiting was reported.
    RateLimit,
    /// A server-side error was reported.
    ServerError,
    /// The request timed out.
    Timeout,
    /// The connection failed.
    ConnectionError,
}
227
#[cfg(test)]
mod transport_capability_tests {
    use super::*;

    /// The default must describe plain streaming HTTP with no provider-side
    /// continuation, session, abort, or resume support.
    #[test]
    fn default_transport_capabilities_are_conservative_streaming_http() {
        let expected = TransportCapabilities {
            request_response: true,
            streaming: true,
            continuation: ContinuationMode::None,
            persistent_session: PersistentSessionMode::None,
            cancellation: CancellationMode::DropLocalStream,
            resumability: ResumabilityMode::RestartRequest,
        };

        assert_eq!(TransportCapabilities::default(), expected);
    }
}