1use crate::streaming::ChatCompletionChunk;
2use crate::types::{
3 ChatCompletionRequest, ChatCompletionResponse, EmbeddingRequest, EmbeddingResponse,
4};
5use futures::Stream;
6use std::fmt::Debug;
7use std::future::Future;
8use std::pin::Pin;
9
/// Heap-allocated, pinned, `Send` stream whose items are either a
/// [`ChatCompletionChunk`] or an [`LlmError`].
///
/// This is the success type produced by [`Provider::chat_completion_stream`].
pub type ChatCompletionStream =
    Pin<Box<dyn Stream<Item = Result<ChatCompletionChunk, LlmError>> + Send>>;
12
13#[async_trait::async_trait]
15pub trait Provider: Send + Sync + Debug {
16 async fn chat_completion(
18 &self,
19 request: ChatCompletionRequest,
20 ) -> Result<ChatCompletionResponse, LlmError>;
21
22 fn chat_completion_stream(
24 &self,
25 _request: ChatCompletionRequest,
26 ) -> Pin<Box<dyn Future<Output = Result<ChatCompletionStream, LlmError>> + Send + '_>> {
27 Box::pin(async { Err(LlmError::UnsupportedFeature) })
28 }
29
30 async fn embeddings(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse, LlmError>;
32
33 fn supported_models(&self) -> Vec<String> {
35 vec![]
36 }
37
38 async fn list_models(&self) -> Result<Vec<String>, LlmError> {
40 Err(LlmError::UnsupportedFeature)
41 }
42
43 fn provider_name(&self) -> &'static str;
45}
46
/// Error type shared by every [`Provider`] method and by the credential
/// helpers in this module.
///
/// `Display` output for each variant is defined by its `#[error(...)]` string.
#[derive(Debug, thiserror::Error)]
pub enum LlmError {
    /// An HTTP request failed; the payload is a textual description.
    #[error("HTTP request failed: {0}")]
    HttpError(String),

    /// An API reported an error with a numeric status and a message.
    // NOTE(review): `status` is presumably the HTTP status code — confirm with the producers.
    #[error("API error: {status} - {message}")]
    ApiError { status: u16, message: String },

    /// Authentication failed.
    #[error("Authentication failed")]
    AuthError,

    /// A rate limit was exceeded.
    #[error("Rate limit exceeded")]
    RateLimitError,

    /// The request was invalid; the payload explains why.
    #[error("Invalid request: {0}")]
    InvalidRequest(String),

    /// A provider-level failure; the payload is a textual description.
    #[error("Provider error: {0}")]
    ProviderError(String),

    /// JSON (de)serialization failed. Built automatically from
    /// `serde_json::Error` through the `#[from]` conversion.
    #[error("Serialization error: {0}")]
    SerializationError(#[from] serde_json::Error),

    /// A failure that fits no other variant; the payload is a textual description.
    #[error("Unknown error: {0}")]
    Unknown(String),

    /// The provider does not support the requested feature.
    #[error("Feature not supported by this provider")]
    UnsupportedFeature,

    /// A requested resource was not found.
    #[error("Resource not found")]
    NotFound,

    /// An internal provider error; the payload is a textual description.
    #[error("Internal provider error: {0}")]
    InternalError(String),

    /// The request timed out.
    #[error("Request timed out")]
    Timeout,
}
86
87use std::sync::Arc;
88
/// Collection of shared [`Provider`] handles that can be looked up by
/// provider name or by supported model.
#[derive(Debug)]
pub struct ProviderRegistry {
    /// Registered providers, kept in registration order.
    providers: Vec<Arc<dyn Provider>>,
}
94
95impl ProviderRegistry {
96 pub fn new() -> Self {
98 Self {
99 providers: Vec::new(),
100 }
101 }
102
103 pub fn register(&mut self, provider: Arc<dyn Provider>) {
105 self.providers.push(provider);
106 }
107
108 pub fn get(&self, name: &str) -> Option<Arc<dyn Provider>> {
110 self.providers
111 .iter()
112 .find(|p| p.provider_name() == name)
113 .cloned()
114 }
115
116 pub fn list(&self) -> Vec<&'static str> {
118 self.providers.iter().map(|p| p.provider_name()).collect()
119 }
120
121 pub fn find_by_model(&self, model: &str) -> Option<Arc<dyn Provider>> {
123 self.providers
124 .iter()
125 .find(|p| p.supported_models().contains(&model.to_string()))
126 .cloned()
127 }
128}
129
/// Splits a `provider/model` identifier at its first `/`.
///
/// Everything before the first `/` is the provider; everything after it
/// (including any further `/` characters) is the model name, so
/// `"openrouter/openai/gpt-4"` yields `("openrouter", "openai/gpt-4")`.
///
/// # Errors
///
/// Returns a descriptive message when `model_id` contains no `/`, or when
/// either the provider or the model part is empty.
pub fn parse_model_id(model_id: &str) -> Result<(&str, String), String> {
    // `split_once` splits at the first separator only, so no intermediate
    // `Vec` and re-join are needed to keep nested model names intact.
    let (provider, model_name) = model_id
        .split_once('/')
        .ok_or_else(|| "Model must be in format 'provider/model'".to_string())?;

    if provider.is_empty() || model_name.is_empty() {
        return Err("Provider and model name cannot be empty".to_string());
    }

    Ok((provider, model_name.to_string()))
}
150
/// [`Provider`] that forwards each request to a provider from a
/// [`ProviderRegistry`], chosen by the `provider/` prefix of the request's
/// model identifier.
#[derive(Debug, Clone)]
pub struct RoutingProvider {
    /// Shared registry used to resolve provider names.
    registry: Arc<ProviderRegistry>,
}
157
158impl RoutingProvider {
159 pub fn new(registry: ProviderRegistry) -> Self {
161 Self {
162 registry: Arc::new(registry),
163 }
164 }
165}
166
167#[async_trait::async_trait]
168impl Provider for RoutingProvider {
169 async fn chat_completion(
170 &self,
171 mut request: ChatCompletionRequest,
172 ) -> Result<ChatCompletionResponse, LlmError> {
173 let (provider_name, actual_model) =
175 parse_model_id(&request.model).map_err(LlmError::InvalidRequest)?;
176
177 let provider = self.registry.get(provider_name).ok_or_else(|| {
179 LlmError::ProviderError(format!("Unknown provider: {}", provider_name))
180 })?;
181
182 request.model = actual_model;
184
185 provider.chat_completion(request).await
187 }
188
189 fn chat_completion_stream(
190 &self,
191 mut request: ChatCompletionRequest,
192 ) -> Pin<Box<dyn Future<Output = Result<ChatCompletionStream, LlmError>> + Send + '_>> {
193 let registry = self.registry.clone();
195
196 Box::pin(async move {
197 let (provider_name, actual_model) =
198 parse_model_id(&request.model).map_err(LlmError::InvalidRequest)?;
199
200 let provider = registry.get(provider_name).ok_or_else(|| {
201 LlmError::ProviderError(format!("Unknown provider: {}", provider_name))
202 })?;
203
204 request.model = actual_model;
205
206 provider.chat_completion_stream(request).await
207 })
208 }
209
210 async fn embeddings(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse, LlmError> {
211 let (provider_name, actual_model) =
217 parse_model_id(&request.model).map_err(LlmError::InvalidRequest)?;
218
219 let provider = self.registry.get(provider_name).ok_or_else(|| {
220 LlmError::ProviderError(format!("Unknown provider: {}", provider_name))
221 })?;
222
223 let mut new_request = request;
229 new_request.model = actual_model;
230
231 provider.embeddings(new_request).await
232 }
233
234 fn supported_models(&self) -> Vec<String> {
235 let mut models = Vec::new();
237 for provider in &self.registry.providers {
238 let name = provider.provider_name();
239 for model in provider.supported_models() {
240 models.push(format!("{}/{}", name, model));
241 }
242 }
243 models
244 }
245
246 fn provider_name(&self) -> &'static str {
247 "router"
248 }
249}
250
/// [`Provider`] that wraps an ordered list of providers and tries them one
/// after another until one succeeds.
#[derive(Debug)]
pub struct FallbackProvider {
    /// Wrapped providers, in the order they are attempted.
    providers: Vec<Box<dyn Provider>>,
}

impl FallbackProvider {
    /// Builds a fallback chain over `providers`; the vector order is the
    /// order in which providers are attempted.
    pub fn new(providers: Vec<Box<dyn Provider>>) -> Self {
        Self { providers }
    }
}
262
263#[async_trait::async_trait]
264impl Provider for FallbackProvider {
265 async fn chat_completion(
266 &self,
267 request: ChatCompletionRequest,
268 ) -> Result<ChatCompletionResponse, LlmError> {
269 let mut last_error = LlmError::ProviderError("No providers configured".to_string());
270
271 for provider in &self.providers {
272 match provider.chat_completion(request.clone()).await {
273 Ok(response) => return Ok(response),
274 Err(e) => {
275 tracing::warn!("Provider {} failed: {}", provider.provider_name(), e);
276 last_error = e;
277 if matches!(last_error, LlmError::InvalidRequest(_)) {
280 break;
281 }
282 }
283 }
284 }
285
286 Err(last_error)
287 }
288
289 async fn embeddings(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse, LlmError> {
290 for provider in &self.providers {
291 if let Ok(res) = provider.embeddings(request.clone()).await {
292 return Ok(res);
293 }
294 }
295 Err(LlmError::ProviderError(
296 "All embedding providers failed".to_string(),
297 ))
298 }
299
300 fn supported_models(&self) -> Vec<String> {
301 self.providers
302 .iter()
303 .flat_map(|p| p.supported_models())
304 .collect()
305 }
306
307 fn provider_name(&self) -> &'static str {
308 "fallback"
309 }
310}
311
312impl Default for ProviderRegistry {
313 fn default() -> Self {
314 Self::new()
315 }
316}
317
/// Something that can attach authentication data to an outgoing
/// `reqwest::Request`.
pub trait Credentials: Send + Sync + Debug {
    /// Mutates `request` in place to carry these credentials.
    ///
    /// # Errors
    ///
    /// Returns an [`LlmError`] if the credentials cannot be applied.
    fn apply(&self, request: &mut reqwest::Request) -> Result<(), LlmError>;
}
323
/// API-key credentials sent as a single HTTP header.
///
/// `Debug` is implemented by hand so that the secret header value is never
/// printed; only the header name is shown.
#[derive(Clone)]
pub struct ApiKeyCredentials {
    /// Full header value to send (already includes any scheme prefix such as
    /// `Bearer `).
    key: String,
    /// Name of the header that carries `key`.
    header_name: String,
}

impl std::fmt::Debug for ApiKeyCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The key is a secret: show a placeholder instead of its value so it
        // cannot leak through `{:?}` formatting or logging.
        f.debug_struct("ApiKeyCredentials")
            .field("key", &"<redacted>")
            .field("header_name", &self.header_name)
            .finish()
    }
}

impl ApiKeyCredentials {
    /// Sends `key` verbatim as the value of the `Authorization` header.
    pub fn new(key: impl Into<String>) -> Self {
        Self {
            key: key.into(),
            header_name: "Authorization".to_string(),
        }
    }

    /// Sends `Bearer <key>` as the value of the `Authorization` header.
    pub fn bearer(key: impl Into<String>) -> Self {
        Self {
            key: format!("Bearer {}", key.into()),
            header_name: "Authorization".to_string(),
        }
    }

    /// Sends `key` verbatim as the value of the custom header `header`.
    pub fn with_header(key: impl Into<String>, header: impl Into<String>) -> Self {
        Self {
            key: key.into(),
            header_name: header.into(),
        }
    }
}
356
357impl Credentials for ApiKeyCredentials {
358 fn apply(&self, request: &mut reqwest::Request) -> Result<(), LlmError> {
359 request.headers_mut().insert(
360 reqwest::header::HeaderName::from_bytes(self.header_name.as_bytes())
361 .map_err(|e| LlmError::InvalidRequest(format!("Invalid header name: {}", e)))?,
362 reqwest::header::HeaderValue::from_str(&self.key)
363 .map_err(|e| LlmError::InvalidRequest(format!("Invalid header value: {}", e)))?,
364 );
365 Ok(())
366 }
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Provider stub named `"mock"`; its async methods are never expected to
    /// be called by these tests.
    #[derive(Debug)]
    struct MockProvider;

    #[async_trait::async_trait]
    impl Provider for MockProvider {
        async fn chat_completion(
            &self,
            _request: ChatCompletionRequest,
        ) -> Result<ChatCompletionResponse, LlmError> {
            unimplemented!()
        }

        async fn embeddings(
            &self,
            _request: EmbeddingRequest,
        ) -> Result<EmbeddingResponse, LlmError> {
            unimplemented!()
        }

        fn supported_models(&self) -> Vec<String> {
            Vec::new()
        }

        fn provider_name(&self) -> &'static str {
            "mock"
        }
    }

    /// A single-slash id splits into provider and model.
    #[test]
    fn test_parse_model_id_simple() {
        let (provider, model) = parse_model_id("openai/gpt-4").unwrap();
        assert_eq!(provider, "openai");
        assert_eq!(model, "gpt-4");
    }

    /// Only the first slash separates the provider; the rest is the model.
    #[test]
    fn test_parse_model_id_nested() {
        let (provider, model) = parse_model_id("openrouter/openai/gpt-4").unwrap();
        assert_eq!(provider, "openrouter");
        assert_eq!(model, "openai/gpt-4");
    }

    /// Ids without a slash, or with an empty side, are rejected.
    #[test]
    fn test_parse_model_id_invalid() {
        let no_separator = parse_model_id("invalid");
        assert!(no_separator.is_err());

        let empty_provider = parse_model_id("/model");
        assert!(empty_provider.is_err());

        let empty_model = parse_model_id("provider/");
        assert!(empty_model.is_err());
    }

    /// Registered providers are listed and can be fetched by name.
    #[test]
    fn test_provider_registry() {
        let mut registry = ProviderRegistry::new();
        registry.register(Arc::new(MockProvider));

        let names = registry.list();
        assert_eq!(names, vec!["mock"]);

        assert!(registry.get("mock").is_some());
        assert!(registry.get("nonexistent").is_none());
    }
}