ricecoder_providers/provider/manager.rs

1//! Provider manager for orchestrating provider operations
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use super::{ChatStream, Provider, ProviderRegistry};
7use crate::error::ProviderError;
8use crate::health_check::HealthCheckCache;
9use crate::models::{ChatRequest, ChatResponse};
10
11/// Central coordinator for provider operations
12pub struct ProviderManager {
13    registry: ProviderRegistry,
14    default_provider_id: String,
15    retry_count: usize,
16    timeout: Duration,
17    health_check_cache: Arc<HealthCheckCache>,
18}
19
20impl ProviderManager {
21    /// Create a new provider manager
22    pub fn new(registry: ProviderRegistry, default_provider_id: String) -> Self {
23        Self {
24            registry,
25            default_provider_id,
26            retry_count: 3,
27            timeout: Duration::from_secs(30),
28            health_check_cache: Arc::new(HealthCheckCache::default()),
29        }
30    }
31
32    /// Set the number of retries for failed requests
33    pub fn with_retry_count(mut self, count: usize) -> Self {
34        self.retry_count = count;
35        self
36    }
37
38    /// Set the request timeout
39    pub fn with_timeout(mut self, timeout: Duration) -> Self {
40        self.timeout = timeout;
41        self
42    }
43
44    /// Set the health check cache
45    pub fn with_health_check_cache(mut self, cache: Arc<HealthCheckCache>) -> Self {
46        self.health_check_cache = cache;
47        self
48    }
49
50    /// Get the default provider
51    pub fn default_provider(&self) -> Result<Arc<dyn Provider>, ProviderError> {
52        self.registry.get(&self.default_provider_id)
53    }
54
55    /// Get a specific provider
56    pub fn get_provider(&self, provider_id: &str) -> Result<Arc<dyn Provider>, ProviderError> {
57        self.registry.get(provider_id)
58    }
59
60    /// Send a chat request with retry logic
61    pub async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, ProviderError> {
62        let provider = self.default_provider()?;
63        self.chat_with_provider(&provider, request).await
64    }
65
66    /// Send a chat request to a specific provider with retry logic
67    pub async fn chat_with_provider(
68        &self,
69        provider: &Arc<dyn Provider>,
70        request: ChatRequest,
71    ) -> Result<ChatResponse, ProviderError> {
72        let mut last_error = None;
73
74        for attempt in 0..=self.retry_count {
75            match tokio::time::timeout(self.timeout, provider.chat(request.clone())).await {
76                Ok(Ok(response)) => return Ok(response),
77                Ok(Err(e)) => {
78                    last_error = Some(e);
79                    if attempt < self.retry_count {
80                        // Exponential backoff
81                        let backoff = Duration::from_millis(100 * 2_u64.pow(attempt as u32));
82                        tokio::time::sleep(backoff).await;
83                    }
84                }
85                Err(_) => {
86                    last_error = Some(ProviderError::ProviderError("Request timeout".to_string()));
87                    if attempt < self.retry_count {
88                        let backoff = Duration::from_millis(100 * 2_u64.pow(attempt as u32));
89                        tokio::time::sleep(backoff).await;
90                    }
91                }
92            }
93        }
94
95        Err(last_error
96            .unwrap_or_else(|| ProviderError::ProviderError("Failed after retries".to_string())))
97    }
98
99    /// Stream a chat response
100    pub async fn chat_stream(&self, request: ChatRequest) -> Result<ChatStream, ProviderError> {
101        let provider = self.default_provider()?;
102        provider.chat_stream(request).await
103    }
104
105    /// Stream a chat response from a specific provider
106    pub async fn chat_stream_with_provider(
107        &self,
108        provider: &Arc<dyn Provider>,
109        request: ChatRequest,
110    ) -> Result<ChatStream, ProviderError> {
111        provider.chat_stream(request).await
112    }
113
114    /// Check provider health with caching
115    pub async fn health_check(&self, provider_id: &str) -> Result<bool, ProviderError> {
116        let provider = self.registry.get(provider_id)?;
117        self.health_check_cache.check_health(&provider).await
118    }
119
120    /// Check health of all providers with caching
121    pub async fn health_check_all(&self) -> Vec<(String, Result<bool, ProviderError>)> {
122        let mut results = Vec::new();
123
124        for provider in self.registry.list_all() {
125            let id = provider.id().to_string();
126            let health = self.health_check_cache.check_health(&provider).await;
127            results.push((id, health));
128        }
129
130        results
131    }
132
133    /// Invalidate health check cache for a provider
134    pub async fn invalidate_health_check(&self, provider_id: &str) {
135        self.health_check_cache.invalidate(provider_id).await;
136    }
137
138    /// Invalidate all health check cache
139    pub async fn invalidate_all_health_checks(&self) {
140        self.health_check_cache.invalidate_all().await;
141    }
142
143    /// Get the health check cache
144    pub fn health_check_cache(&self) -> &Arc<HealthCheckCache> {
145        &self.health_check_cache
146    }
147
148    /// Get the registry
149    pub fn registry(&self) -> &ProviderRegistry {
150        &self.registry
151    }
152
153    /// Get mutable registry
154    pub fn registry_mut(&mut self) -> &mut ProviderRegistry {
155        &mut self.registry
156    }
157}
158
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;
    use crate::models::{ChatResponse, FinishReason, TokenUsage};

    /// Configurable test double for [`Provider`].
    ///
    /// `chat` counts every invocation in `calls`, optionally sleeps for
    /// `delay`, and fails with a `ProviderError::ProviderError` for the first
    /// `failures` calls before answering successfully.
    struct MockProvider {
        id: String,
        failures: usize,
        delay: Option<Duration>,
        calls: Arc<AtomicUsize>,
    }

    impl MockProvider {
        /// A provider that answers every `chat` call immediately and successfully.
        fn new(id: &str) -> Self {
            Self {
                id: id.to_string(),
                failures: 0,
                delay: None,
                calls: Arc::new(AtomicUsize::new(0)),
            }
        }
    }

    #[async_trait::async_trait]
    impl Provider for MockProvider {
        fn id(&self) -> &str {
            &self.id
        }

        fn name(&self) -> &str {
            "Mock"
        }

        fn models(&self) -> Vec<crate::models::ModelInfo> {
            vec![]
        }

        async fn chat(&self, _request: ChatRequest) -> Result<ChatResponse, ProviderError> {
            let call = self.calls.fetch_add(1, Ordering::SeqCst);
            if let Some(delay) = self.delay {
                tokio::time::sleep(delay).await;
            }
            if call < self.failures {
                return Err(ProviderError::ProviderError("transient failure".to_string()));
            }
            Ok(ChatResponse {
                content: "test response".to_string(),
                model: "test-model".to_string(),
                usage: TokenUsage {
                    prompt_tokens: 10,
                    completion_tokens: 5,
                    total_tokens: 15,
                },
                finish_reason: FinishReason::Stop,
            })
        }

        async fn chat_stream(&self, _request: ChatRequest) -> Result<ChatStream, ProviderError> {
            Err(ProviderError::NotFound("Not implemented".to_string()))
        }

        fn count_tokens(&self, _content: &str, _model: &str) -> Result<usize, ProviderError> {
            Ok(0)
        }

        async fn health_check(&self) -> Result<bool, ProviderError> {
            Ok(true)
        }
    }

    /// Build a registry containing only `provider`.
    fn registry_with(provider: MockProvider) -> ProviderRegistry {
        let mut registry = ProviderRegistry::new();
        registry.register(Arc::new(provider)).unwrap();
        registry
    }

    /// A minimal, valid chat request.
    fn sample_request() -> ChatRequest {
        ChatRequest {
            model: "test-model".to_string(),
            messages: vec![],
            temperature: None,
            max_tokens: None,
            stream: false,
        }
    }

    #[tokio::test]
    async fn test_manager_creation() {
        let manager =
            ProviderManager::new(registry_with(MockProvider::new("test")), "test".to_string());
        assert!(manager.default_provider().is_ok());
    }

    #[tokio::test]
    async fn test_unknown_default_provider() {
        let manager =
            ProviderManager::new(registry_with(MockProvider::new("test")), "missing".to_string());
        assert!(manager.default_provider().is_err());
        assert!(manager.chat(sample_request()).await.is_err());
    }

    #[tokio::test]
    async fn test_chat_request() {
        let manager =
            ProviderManager::new(registry_with(MockProvider::new("test")), "test".to_string());
        let response = manager.chat(sample_request()).await;
        assert!(response.is_ok());
    }

    #[tokio::test]
    async fn test_chat_retries_until_success() {
        let calls = Arc::new(AtomicUsize::new(0));
        let provider = MockProvider {
            failures: 1,
            calls: Arc::clone(&calls),
            ..MockProvider::new("test")
        };
        let manager =
            ProviderManager::new(registry_with(provider), "test".to_string()).with_retry_count(3);

        assert!(manager.chat(sample_request()).await.is_ok());
        // One failure, then success on the first retry.
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_chat_returns_last_error_after_retries() {
        let calls = Arc::new(AtomicUsize::new(0));
        let provider = MockProvider {
            failures: usize::MAX,
            calls: Arc::clone(&calls),
            ..MockProvider::new("test")
        };
        let manager =
            ProviderManager::new(registry_with(provider), "test".to_string()).with_retry_count(1);

        let result = manager.chat(sample_request()).await;
        assert!(matches!(
            result,
            Err(ProviderError::ProviderError(ref msg)) if msg == "transient failure"
        ));
        // The initial attempt plus exactly `retry_count` retries.
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_chat_times_out() {
        let provider = MockProvider {
            delay: Some(Duration::from_secs(5)),
            ..MockProvider::new("test")
        };
        let manager = ProviderManager::new(registry_with(provider), "test".to_string())
            .with_retry_count(0)
            .with_timeout(Duration::from_millis(20));

        let result = manager.chat(sample_request()).await;
        assert!(matches!(
            result,
            Err(ProviderError::ProviderError(ref msg)) if msg == "Request timeout"
        ));
    }

    #[tokio::test]
    async fn test_health_check() {
        let manager =
            ProviderManager::new(registry_with(MockProvider::new("test")), "test".to_string());
        let health = manager.health_check("test").await;
        assert!(health.is_ok());
        assert!(health.unwrap());
    }
}