//! Core provider abstractions: the [`Provider`] trait, [`LlmError`] error
//! type, a [`ProviderRegistry`], routing/fallback composite providers, and
//! request [`Credentials`].
1use crate::streaming::ChatCompletionChunk;
2use crate::types::{
3    ChatCompletionRequest, ChatCompletionResponse, EmbeddingRequest, EmbeddingResponse,
4};
5use futures::Stream;
6use std::fmt::Debug;
7use std::future::Future;
8use std::pin::Pin;
9
/// Boxed, pinned stream of chat-completion chunks, as produced by
/// [`Provider::chat_completion_stream`]. `Send` so it can be moved across
/// task boundaries.
pub type ChatCompletionStream =
    Pin<Box<dyn Stream<Item = Result<ChatCompletionChunk, LlmError>> + Send>>;
12
/// Core trait that all LLM providers must implement.
#[async_trait::async_trait]
pub trait Provider: Send + Sync + Debug {
    /// Generate a chat completion for `request`.
    async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, LlmError>;

    /// Stream a chat completion as incremental chunks.
    ///
    /// The default implementation returns [`LlmError::UnsupportedFeature`];
    /// providers that support streaming override it. Written as a boxed
    /// future (rather than `async fn`) so a default body can live in the
    /// trait itself.
    fn chat_completion_stream(
        &self,
        _request: ChatCompletionRequest,
    ) -> Pin<Box<dyn Future<Output = Result<ChatCompletionStream, LlmError>> + Send + '_>> {
        Box::pin(async { Err(LlmError::UnsupportedFeature) })
    }

    /// Generate embeddings for text.
    async fn embeddings(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse, LlmError>;

    /// Static list of models this provider supports. Defaults to empty;
    /// used by registry-level model lookup (`ProviderRegistry::find_by_model`).
    fn supported_models(&self) -> Vec<String> {
        vec![]
    }

    /// Dynamically list available models from the provider API.
    /// Defaults to [`LlmError::UnsupportedFeature`].
    async fn list_models(&self) -> Result<Vec<String>, LlmError> {
        Err(LlmError::UnsupportedFeature)
    }

    /// Stable, unique provider name; used as the `provider/` prefix when
    /// routing and as the lookup key in [`ProviderRegistry`].
    fn provider_name(&self) -> &'static str;
}
46
/// Error types for LLM operations.
///
/// `Display` impls come from `thiserror`'s `#[error]` attributes;
/// `SerializationError` additionally converts from `serde_json::Error`
/// via `#[from]`, enabling `?` on serde calls.
#[derive(Debug, thiserror::Error)]
pub enum LlmError {
    /// Transport-level failure (no usable API response).
    #[error("HTTP request failed: {0}")]
    HttpError(String),

    /// Error response returned by the provider's API.
    #[error("API error: {status} - {message}")]
    ApiError { status: u16, message: String },

    #[error("Authentication failed")]
    AuthError,

    #[error("Rate limit exceeded")]
    RateLimitError,

    /// The request itself is malformed; retrying elsewhere will not help.
    #[error("Invalid request: {0}")]
    InvalidRequest(String),

    #[error("Provider error: {0}")]
    ProviderError(String),

    #[error("Serialization error: {0}")]
    SerializationError(#[from] serde_json::Error),

    #[error("Unknown error: {0}")]
    Unknown(String),

    /// Returned by default trait methods a provider has not implemented.
    #[error("Feature not supported by this provider")]
    UnsupportedFeature,

    #[error("Resource not found")]
    NotFound,

    #[error("Internal provider error: {0}")]
    InternalError(String),

    #[error("Request timed out")]
    Timeout,
}
86
87use std::sync::Arc;
88
/// Registry for managing multiple providers.
#[derive(Debug)]
pub struct ProviderRegistry {
    // Registration order is preserved; lookups are linear scans.
    providers: Vec<Arc<dyn Provider>>,
}
94
95impl ProviderRegistry {
96    /// Create a new empty registry
97    pub fn new() -> Self {
98        Self {
99            providers: Vec::new(),
100        }
101    }
102
103    /// Register a provider
104    pub fn register(&mut self, provider: Arc<dyn Provider>) {
105        self.providers.push(provider);
106    }
107
108    /// Get a provider by name
109    pub fn get(&self, name: &str) -> Option<Arc<dyn Provider>> {
110        self.providers
111            .iter()
112            .find(|p| p.provider_name() == name)
113            .cloned()
114    }
115
116    /// List all registered providers
117    pub fn list(&self) -> Vec<&'static str> {
118        self.providers.iter().map(|p| p.provider_name()).collect()
119    }
120
121    /// Find provider that supports a specific model.
122    pub fn find_by_model(&self, model: &str) -> Option<Arc<dyn Provider>> {
123        self.providers
124            .iter()
125            .find(|p| p.supported_models().contains(&model.to_string()))
126            .cloned()
127    }
128}
129
/// Parse a model identifier in the format `provider/model`.
/// Supports nested routing like `openrouter/openai/gpt-4`: only the first
/// `/` splits, and the remainder (which may itself contain `/`) is the
/// model name.
///
/// Returns: `(provider, model_name)`.
///
/// # Errors
/// Returns a human-readable message when the id has no `/` separator, or
/// when either side of the first `/` is empty.
pub fn parse_model_id(model_id: &str) -> Result<(&str, String), String> {
    // `split_once` splits on the first '/' without allocating a Vec.
    let (provider, model_name) = model_id
        .split_once('/')
        .ok_or_else(|| "Model must be in format 'provider/model'".to_string())?;

    if provider.is_empty() || model_name.is_empty() {
        return Err("Provider and model name cannot be empty".to_string());
    }

    Ok((provider, model_name.to_string()))
}
150
/// A provider that acts as a router, delegating to other providers based on the model name.
/// Model names must be in the format `provider/model` (e.g., `openai/gpt-4`).
#[derive(Debug, Clone)]
pub struct RoutingProvider {
    // Shared registry; `Arc` keeps `Clone` cheap and lets the streaming
    // path move a handle into its async block.
    registry: Arc<ProviderRegistry>,
}
157
158impl RoutingProvider {
159    /// Create a new routing provider with the given registry
160    pub fn new(registry: ProviderRegistry) -> Self {
161        Self {
162            registry: Arc::new(registry),
163        }
164    }
165}
166
167#[async_trait::async_trait]
168impl Provider for RoutingProvider {
169    async fn chat_completion(
170        &self,
171        mut request: ChatCompletionRequest,
172    ) -> Result<ChatCompletionResponse, LlmError> {
173        // 1. Parse provider from model name
174        let (provider_name, actual_model) =
175            parse_model_id(&request.model).map_err(LlmError::InvalidRequest)?;
176
177        // 2. Find provider
178        let provider = self.registry.get(provider_name).ok_or_else(|| {
179            LlmError::ProviderError(format!("Unknown provider: {}", provider_name))
180        })?;
181
182        // 3. Update request with actual model name (stripped of prefix)
183        request.model = actual_model;
184
185        // 4. Delegate
186        provider.chat_completion(request).await
187    }
188
189    fn chat_completion_stream(
190        &self,
191        mut request: ChatCompletionRequest,
192    ) -> Pin<Box<dyn Future<Output = Result<ChatCompletionStream, LlmError>> + Send + '_>> {
193        // We need to capture the registry for the async block
194        let registry = self.registry.clone();
195
196        Box::pin(async move {
197            let (provider_name, actual_model) =
198                parse_model_id(&request.model).map_err(LlmError::InvalidRequest)?;
199
200            let provider = registry.get(provider_name).ok_or_else(|| {
201                LlmError::ProviderError(format!("Unknown provider: {}", provider_name))
202            })?;
203
204            request.model = actual_model;
205
206            provider.chat_completion_stream(request).await
207        })
208    }
209
210    async fn embeddings(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse, LlmError> {
211        // Embeddings might not always have "model" in the same way, but usually do.
212        // If the request format supports provider/model, we can route it.
213        // However, EmbeddingRequest structure isn't shown fully here in context,
214        // assuming it has a public `model` field like ChatCompletionRequest.
215
216        let (provider_name, actual_model) =
217            parse_model_id(&request.model).map_err(LlmError::InvalidRequest)?;
218
219        let provider = self.registry.get(provider_name).ok_or_else(|| {
220            LlmError::ProviderError(format!("Unknown provider: {}", provider_name))
221        })?;
222
223        // We need to clone the request to modify it, but EmbeddingRequest fields aren't fully visible.
224        // Assuming we can create a new request or modify it.
225        // The trait definition shows `request: EmbeddingRequest` (consumes it).
226        // Let's rely on struct update syntax if fields are public.
227
228        let mut new_request = request;
229        new_request.model = actual_model;
230
231        provider.embeddings(new_request).await
232    }
233
234    fn supported_models(&self) -> Vec<String> {
235        // Return all models from all providers, prefixed with provider name
236        let mut models = Vec::new();
237        for provider in &self.registry.providers {
238            let name = provider.provider_name();
239            for model in provider.supported_models() {
240                models.push(format!("{}/{}", name, model));
241            }
242        }
243        models
244    }
245
246    fn provider_name(&self) -> &'static str {
247        "router"
248    }
249}
250
/// A provider that tries multiple other providers in sequence (fallbacks).
#[derive(Debug)]
pub struct FallbackProvider {
    // Tried in order; the first success wins.
    providers: Vec<Box<dyn Provider>>,
}
256
impl FallbackProvider {
    /// Build a fallback chain; `providers` are tried in the given order.
    pub fn new(providers: Vec<Box<dyn Provider>>) -> Self {
        Self { providers }
    }
}
262
263#[async_trait::async_trait]
264impl Provider for FallbackProvider {
265    async fn chat_completion(
266        &self,
267        request: ChatCompletionRequest,
268    ) -> Result<ChatCompletionResponse, LlmError> {
269        let mut last_error = LlmError::ProviderError("No providers configured".to_string());
270
271        for provider in &self.providers {
272            match provider.chat_completion(request.clone()).await {
273                Ok(response) => return Ok(response),
274                Err(e) => {
275                    tracing::warn!("Provider {} failed: {}", provider.provider_name(), e);
276                    last_error = e;
277                    // Only fallback on certain errors (e.g. RateLimit or ApiErrors)
278                    // If it's a 400 Bad Request (InvalidRequest), we should probably stop.
279                    if matches!(last_error, LlmError::InvalidRequest(_)) {
280                        break;
281                    }
282                }
283            }
284        }
285
286        Err(last_error)
287    }
288
289    async fn embeddings(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse, LlmError> {
290        for provider in &self.providers {
291            if let Ok(res) = provider.embeddings(request.clone()).await {
292                return Ok(res);
293            }
294        }
295        Err(LlmError::ProviderError(
296            "All embedding providers failed".to_string(),
297        ))
298    }
299
300    fn supported_models(&self) -> Vec<String> {
301        self.providers
302            .iter()
303            .flat_map(|p| p.supported_models())
304            .collect()
305    }
306
307    fn provider_name(&self) -> &'static str {
308        "fallback"
309    }
310}
311
/// `Default` mirrors [`ProviderRegistry::new`]: an empty registry.
impl Default for ProviderRegistry {
    fn default() -> Self {
        Self::new()
    }
}
317
/// Credentials trait for authentication.
pub trait Credentials: Send + Sync + Debug {
    /// Apply authentication to an outgoing request (e.g. by inserting
    /// headers). Implementations may fail if the credential material
    /// cannot be encoded into the request.
    fn apply(&self, request: &mut reqwest::Request) -> Result<(), LlmError>;
}
323
/// Simple API key authentication.
#[derive(Debug, Clone)]
pub struct ApiKeyCredentials {
    // Raw header value to send; may already carry a "Bearer " prefix
    // (see `ApiKeyCredentials::bearer`).
    key: String,
    // Header that carries the key; "Authorization" by default.
    header_name: String,
}
330
331impl ApiKeyCredentials {
332    /// Create new API key credentials
333    pub fn new(key: impl Into<String>) -> Self {
334        Self {
335            key: key.into(),
336            header_name: "Authorization".to_string(),
337        }
338    }
339
340    /// Create with bearer token format
341    pub fn bearer(key: impl Into<String>) -> Self {
342        Self {
343            key: format!("Bearer {}", key.into()),
344            header_name: "Authorization".to_string(),
345        }
346    }
347
348    /// Create with custom header
349    pub fn with_header(key: impl Into<String>, header: impl Into<String>) -> Self {
350        Self {
351            key: key.into(),
352            header_name: header.into(),
353        }
354    }
355}
356
357impl Credentials for ApiKeyCredentials {
358    fn apply(&self, request: &mut reqwest::Request) -> Result<(), LlmError> {
359        request.headers_mut().insert(
360            reqwest::header::HeaderName::from_bytes(self.header_name.as_bytes())
361                .map_err(|e| LlmError::InvalidRequest(format!("Invalid header name: {}", e)))?,
362            reqwest::header::HeaderValue::from_str(&self.key)
363                .map_err(|e| LlmError::InvalidRequest(format!("Invalid header value: {}", e)))?,
364        );
365        Ok(())
366    }
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal provider used to exercise registry lookups.
    #[derive(Debug)]
    struct MockProvider;

    #[async_trait::async_trait]
    impl Provider for MockProvider {
        async fn chat_completion(
            &self,
            _request: ChatCompletionRequest,
        ) -> Result<ChatCompletionResponse, LlmError> {
            unimplemented!()
        }

        async fn embeddings(
            &self,
            _request: EmbeddingRequest,
        ) -> Result<EmbeddingResponse, LlmError> {
            unimplemented!()
        }

        fn supported_models(&self) -> Vec<String> {
            vec![]
        }

        fn provider_name(&self) -> &'static str {
            "mock"
        }
    }

    #[test]
    fn test_parse_model_id_simple() {
        let result = parse_model_id("openai/gpt-4").unwrap();
        assert_eq!(result.0, "openai");
        assert_eq!(result.1, "gpt-4");
    }

    #[test]
    fn test_parse_model_id_nested() {
        let result = parse_model_id("openrouter/openai/gpt-4").unwrap();
        assert_eq!(result.0, "openrouter");
        assert_eq!(result.1, "openai/gpt-4");
    }

    #[test]
    fn test_parse_model_id_invalid() {
        assert!(parse_model_id("invalid").is_err());
        assert!(parse_model_id("/model").is_err());
        assert!(parse_model_id("provider/").is_err());
        // Previously uncovered edge cases: empty input and a lone separator.
        assert!(parse_model_id("").is_err());
        assert!(parse_model_id("/").is_err());
    }

    #[test]
    fn test_provider_registry() {
        let mut registry = ProviderRegistry::new();
        registry.register(Arc::new(MockProvider));

        assert_eq!(registry.list(), vec!["mock"]);
        assert!(registry.get("mock").is_some());
        assert!(registry.get("nonexistent").is_none());
    }
}