use std::collections::HashMap;
use std::pin::Pin;
use std::sync::{Arc, OnceLock};

use async_trait::async_trait;
use futures::Stream;
use tokio::sync::RwLock;

use super::types::*;

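/// A pinned, boxed stream of chat completion chunks, as produced by
/// [`LLMProvider::chat_stream`].
///
/// A minimal consumption sketch (assumes a `provider` and `request` are
/// already in scope; `StreamExt` comes from the `futures` crate):
///
/// ```ignore
/// use futures::StreamExt;
///
/// let mut stream = provider.chat_stream(request).await?;
/// while let Some(chunk) = stream.next().await {
///     let chunk = chunk?; // each item is an LLMResult<ChatCompletionChunk>
///     // handle the streamed chunk here
/// }
/// ```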
pub type ChatStream = Pin<Box<dyn Stream<Item = LLMResult<ChatCompletionChunk>> + Send>>;

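/// The common interface implemented by every LLM backend.
///
/// A minimal implementation sketch (the `EchoProvider` type is hypothetical,
/// and the response construction is elided because it depends on the field
/// layout of `ChatCompletionResponse` in `super::types`):
///
/// ```ignore
/// struct EchoProvider;
///
/// #[async_trait]
/// impl LLMProvider for EchoProvider {
///     fn name(&self) -> &str {
///         "echo"
///     }
///
///     async fn chat(&self, request: ChatCompletionRequest) -> LLMResult<ChatCompletionResponse> {
///         todo!("build a ChatCompletionResponse from `request`")
///     }
/// }
/// ```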
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Returns the provider's name.
    fn name(&self) -> &str;

    /// Returns the provider's default model identifier; empty if none.
    fn default_model(&self) -> &str {
        ""
    }

    /// Lists the model identifiers this provider advertises.
    fn supported_models(&self) -> Vec<&str> {
        vec![]
    }

    /// Returns `true` if `model` appears in [`supported_models`](Self::supported_models).
    fn supports_model(&self, model: &str) -> bool {
        self.supported_models().contains(&model)
    }

    /// Whether the provider supports streaming responses. Implementations
    /// that keep the default `chat_stream` (which errors) should override
    /// this to return `false`.
    fn supports_streaming(&self) -> bool {
        true
    }

    /// Whether the provider supports tool/function calling.
    fn supports_tools(&self) -> bool {
        true
    }

    /// Whether the provider supports image inputs.
    fn supports_vision(&self) -> bool {
        false
    }

    /// Whether the provider supports embedding generation.
    fn supports_embedding(&self) -> bool {
        false
    }

    /// Performs a (non-streaming) chat completion.
    async fn chat(&self, request: ChatCompletionRequest) -> LLMResult<ChatCompletionResponse>;

    /// Performs a streaming chat completion. The default implementation
    /// returns [`LLMError::ProviderNotSupported`].
    async fn chat_stream(&self, _request: ChatCompletionRequest) -> LLMResult<ChatStream> {
        Err(LLMError::ProviderNotSupported(format!(
            "Provider {} does not support streaming",
            self.name()
        )))
    }

    /// Generates embeddings. The default implementation returns
    /// [`LLMError::ProviderNotSupported`].
    async fn embedding(&self, _request: EmbeddingRequest) -> LLMResult<EmbeddingResponse> {
        Err(LLMError::ProviderNotSupported(format!(
            "Provider {} does not support embedding",
            self.name()
        )))
    }

    /// Checks that the provider is reachable. The default implementation
    /// always reports healthy.
    async fn health_check(&self) -> LLMResult<bool> {
        Ok(true)
    }

    /// Returns metadata about `model`. The default implementation returns
    /// [`LLMError::ProviderNotSupported`].
    async fn get_model_info(&self, _model: &str) -> LLMResult<ModelInfo> {
        Err(LLMError::ProviderNotSupported(format!(
            "Provider {} does not support model info",
            self.name()
        )))
    }
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ModelInfo {
    /// Model identifier (e.g. `"gpt-4"`).
    pub id: String,
    /// Human-readable model name.
    pub name: String,
    /// Optional free-form description.
    pub description: Option<String>,
    /// Context window size, in tokens.
    pub context_window: Option<u32>,
    /// Maximum number of output tokens.
    pub max_output_tokens: Option<u32>,
    /// Training-data cutoff, as reported by the provider.
    pub training_cutoff: Option<String>,
    /// What the model can do.
    pub capabilities: ModelCapabilities,
}

/// Feature flags for a single model.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct ModelCapabilities {
    pub streaming: bool,
    pub tools: bool,
    pub vision: bool,
    pub json_mode: bool,
    pub json_schema: bool,
}

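/// Configuration for constructing an [`LLMProvider`].
///
/// Unknown keys are collected into `extra` via `#[serde(flatten)]`. A sketch
/// of that behavior (the `organization` key is illustrative):
///
/// ```ignore
/// let config: LLMConfig = serde_json::from_str(
///     r#"{ "provider": "openai", "organization": "org-123" }"#,
/// )?;
/// assert_eq!(config.extra["organization"], "org-123");
/// ```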
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LLMConfig {
    pub provider: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub base_url: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub default_model: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub default_temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub default_max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timeout_secs: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_retries: Option<u32>,
    /// Provider-specific options not covered by the fields above.
    #[serde(flatten)]
    pub extra: HashMap<String, serde_json::Value>,
}

impl Default for LLMConfig {
    fn default() -> Self {
        Self {
            provider: "openai".to_string(),
            api_key: None,
            base_url: None,
            default_model: None,
            default_temperature: Some(0.7),
            default_max_tokens: Some(4096),
            timeout_secs: Some(60),
            max_retries: Some(3),
            extra: HashMap::new(),
        }
    }
}

impl LLMConfig {
    /// Configuration preset for the OpenAI API.
    pub fn openai(api_key: impl Into<String>) -> Self {
        Self {
            provider: "openai".to_string(),
            api_key: Some(api_key.into()),
            base_url: Some("https://api.openai.com/v1".to_string()),
            default_model: Some("gpt-4".to_string()),
            ..Default::default()
        }
    }

    /// Configuration preset for the Anthropic API.
    pub fn anthropic(api_key: impl Into<String>) -> Self {
        Self {
            provider: "anthropic".to_string(),
            api_key: Some(api_key.into()),
            base_url: Some("https://api.anthropic.com".to_string()),
            default_model: Some("claude-3-sonnet-20240229".to_string()),
            ..Default::default()
        }
    }

    /// Configuration preset for a local Ollama instance.
    pub fn ollama(model: impl Into<String>) -> Self {
        Self {
            provider: "ollama".to_string(),
            api_key: None,
            base_url: Some("http://localhost:11434".to_string()),
            default_model: Some(model.into()),
            ..Default::default()
        }
    }

    /// Configuration preset for any OpenAI-compatible endpoint.
    pub fn openai_compatible(
        base_url: impl Into<String>,
        api_key: impl Into<String>,
        model: impl Into<String>,
    ) -> Self {
        Self {
            provider: "openai-compatible".to_string(),
            api_key: Some(api_key.into()),
            base_url: Some(base_url.into()),
            default_model: Some(model.into()),
            ..Default::default()
        }
    }

    /// Sets the default model (builder style).
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.default_model = Some(model.into());
        self
    }

    /// Sets the default sampling temperature (builder style).
    pub fn temperature(mut self, temp: f32) -> Self {
        self.default_temperature = Some(temp);
        self
    }

    /// Sets the default maximum output tokens (builder style).
    pub fn max_tokens(mut self, tokens: u32) -> Self {
        self.default_max_tokens = Some(tokens);
        self
    }
}

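// A small check of the builder-style setters above; the model name "llama3"
// is illustrative only.
#[cfg(test)]
mod config_builder_tests {
    use super::*;

    #[test]
    fn builder_setters_override_defaults() {
        let config = LLMConfig::ollama("llama3").temperature(0.2).max_tokens(1024);
        assert_eq!(config.provider, "ollama");
        assert_eq!(config.default_model.as_deref(), Some("llama3"));
        assert_eq!(config.default_temperature, Some(0.2));
        assert_eq!(config.default_max_tokens, Some(1024));
    }
}
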
/// Factory that builds a boxed provider from an [`LLMConfig`].
pub type ProviderFactory = Box<dyn Fn(LLMConfig) -> LLMResult<Box<dyn LLMProvider>> + Send + Sync>;

/// Registry holding provider factories and ready-made provider instances.
pub struct LLMRegistry {
    factories: RwLock<HashMap<String, ProviderFactory>>,
    providers: RwLock<HashMap<String, Arc<dyn LLMProvider>>>,
}

impl LLMRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            factories: RwLock::new(HashMap::new()),
            providers: RwLock::new(HashMap::new()),
        }
    }

    /// Registers a factory under `name`, replacing any existing one.
    pub async fn register_factory<F>(&self, name: &str, factory: F)
    where
        F: Fn(LLMConfig) -> LLMResult<Box<dyn LLMProvider>> + Send + Sync + 'static,
    {
        let mut factories = self.factories.write().await;
        factories.insert(name.to_string(), Box::new(factory));
    }

    /// Builds a provider using the factory registered for `config.provider`.
    pub async fn create(&self, config: LLMConfig) -> LLMResult<Arc<dyn LLMProvider>> {
        let factories = self.factories.read().await;
        let factory = factories
            .get(&config.provider)
            .ok_or_else(|| LLMError::ProviderNotSupported(config.provider.clone()))?;

        let provider = factory(config)?;
        Ok(Arc::from(provider))
    }

    /// Stores a ready-made provider instance under `name`.
    pub async fn register(&self, name: &str, provider: Arc<dyn LLMProvider>) {
        let mut providers = self.providers.write().await;
        providers.insert(name.to_string(), provider);
    }

    /// Looks up a previously registered provider instance.
    pub async fn get(&self, name: &str) -> Option<Arc<dyn LLMProvider>> {
        let providers = self.providers.read().await;
        providers.get(name).cloned()
    }

    /// Lists the names of all registered provider instances.
    pub async fn list_providers(&self) -> Vec<String> {
        let providers = self.providers.read().await;
        providers.keys().cloned().collect()
    }

    /// Lists the names of all registered factories.
    pub async fn list_factories(&self) -> Vec<String> {
        let factories = self.factories.read().await;
        factories.keys().cloned().collect()
    }
}

impl Default for LLMRegistry {
    fn default() -> Self {
        Self::new()
    }
}

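// A hedged sketch of wiring the registry together (the `MyProvider` type and
// its `try_new` constructor are hypothetical):
//
//     let registry = LLMRegistry::new();
//     registry
//         .register_factory("openai", |config| {
//             Ok(Box::new(MyProvider::try_new(config)?) as Box<dyn LLMProvider>)
//         })
//         .await;
//     let provider = registry.create(LLMConfig::openai("sk-...")).await?;
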
static GLOBAL_REGISTRY: OnceLock<LLMRegistry> = OnceLock::new();

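/// Returns the process-wide registry, initializing it on first access.
///
/// A minimal usage sketch (`my_openai_factory` is a hypothetical factory
/// matching [`ProviderFactory`]'s signature):
///
/// ```ignore
/// let registry = global_registry();
/// registry.register_factory("openai", my_openai_factory).await;
/// ```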
pub fn global_registry() -> &'static LLMRegistry {
    GLOBAL_REGISTRY.get_or_init(LLMRegistry::new)
}