ricecoder_providers/provider/
manager.rs1use std::sync::Arc;
4use std::time::Duration;
5
6use super::{ChatStream, Provider, ProviderRegistry};
7use crate::error::ProviderError;
8use crate::health_check::HealthCheckCache;
9use crate::models::{ChatRequest, ChatResponse};
10
11pub struct ProviderManager {
13 registry: ProviderRegistry,
14 default_provider_id: String,
15 retry_count: usize,
16 timeout: Duration,
17 health_check_cache: Arc<HealthCheckCache>,
18}
19
20impl ProviderManager {
21 pub fn new(registry: ProviderRegistry, default_provider_id: String) -> Self {
23 Self {
24 registry,
25 default_provider_id,
26 retry_count: 3,
27 timeout: Duration::from_secs(30),
28 health_check_cache: Arc::new(HealthCheckCache::default()),
29 }
30 }
31
32 pub fn with_retry_count(mut self, count: usize) -> Self {
34 self.retry_count = count;
35 self
36 }
37
38 pub fn with_timeout(mut self, timeout: Duration) -> Self {
40 self.timeout = timeout;
41 self
42 }
43
44 pub fn with_health_check_cache(mut self, cache: Arc<HealthCheckCache>) -> Self {
46 self.health_check_cache = cache;
47 self
48 }
49
50 pub fn default_provider(&self) -> Result<Arc<dyn Provider>, ProviderError> {
52 self.registry.get(&self.default_provider_id)
53 }
54
55 pub fn get_provider(&self, provider_id: &str) -> Result<Arc<dyn Provider>, ProviderError> {
57 self.registry.get(provider_id)
58 }
59
60 pub async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, ProviderError> {
62 let provider = self.default_provider()?;
63 self.chat_with_provider(&provider, request).await
64 }
65
66 pub async fn chat_with_provider(
68 &self,
69 provider: &Arc<dyn Provider>,
70 request: ChatRequest,
71 ) -> Result<ChatResponse, ProviderError> {
72 let mut last_error = None;
73
74 for attempt in 0..=self.retry_count {
75 match tokio::time::timeout(self.timeout, provider.chat(request.clone())).await {
76 Ok(Ok(response)) => return Ok(response),
77 Ok(Err(e)) => {
78 last_error = Some(e);
79 if attempt < self.retry_count {
80 let backoff = Duration::from_millis(100 * 2_u64.pow(attempt as u32));
82 tokio::time::sleep(backoff).await;
83 }
84 }
85 Err(_) => {
86 last_error = Some(ProviderError::ProviderError("Request timeout".to_string()));
87 if attempt < self.retry_count {
88 let backoff = Duration::from_millis(100 * 2_u64.pow(attempt as u32));
89 tokio::time::sleep(backoff).await;
90 }
91 }
92 }
93 }
94
95 Err(last_error
96 .unwrap_or_else(|| ProviderError::ProviderError("Failed after retries".to_string())))
97 }
98
99 pub async fn chat_stream(&self, request: ChatRequest) -> Result<ChatStream, ProviderError> {
101 let provider = self.default_provider()?;
102 provider.chat_stream(request).await
103 }
104
105 pub async fn chat_stream_with_provider(
107 &self,
108 provider: &Arc<dyn Provider>,
109 request: ChatRequest,
110 ) -> Result<ChatStream, ProviderError> {
111 provider.chat_stream(request).await
112 }
113
114 pub async fn health_check(&self, provider_id: &str) -> Result<bool, ProviderError> {
116 let provider = self.registry.get(provider_id)?;
117 self.health_check_cache.check_health(&provider).await
118 }
119
120 pub async fn health_check_all(&self) -> Vec<(String, Result<bool, ProviderError>)> {
122 let mut results = Vec::new();
123
124 for provider in self.registry.list_all() {
125 let id = provider.id().to_string();
126 let health = self.health_check_cache.check_health(&provider).await;
127 results.push((id, health));
128 }
129
130 results
131 }
132
133 pub async fn invalidate_health_check(&self, provider_id: &str) {
135 self.health_check_cache.invalidate(provider_id).await;
136 }
137
138 pub async fn invalidate_all_health_checks(&self) {
140 self.health_check_cache.invalidate_all().await;
141 }
142
143 pub fn health_check_cache(&self) -> &Arc<HealthCheckCache> {
145 &self.health_check_cache
146 }
147
148 pub fn registry(&self) -> &ProviderRegistry {
150 &self.registry
151 }
152
153 pub fn registry_mut(&mut self) -> &mut ProviderRegistry {
155 &mut self.registry
156 }
157}
158
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;
    use crate::models::{FinishReason, TokenUsage};

    /// Configurable [`Provider`] test double.
    ///
    /// `chat` counts every call and optionally sleeps for `delay` first. For the
    /// first `failures` calls it returns `ProviderError::ProviderError("failure N")`,
    /// where N is the zero-based call index, and it succeeds afterwards.
    struct MockProvider {
        id: String,
        failures: usize,
        delay: Option<Duration>,
        calls: AtomicUsize,
    }

    impl MockProvider {
        /// A provider with the given id whose `chat` always succeeds immediately.
        fn new(id: &str) -> Self {
            Self {
                id: id.to_string(),
                failures: 0,
                delay: None,
                calls: AtomicUsize::new(0),
            }
        }

        /// Makes the first `failures` chat calls return an error.
        fn failing(mut self, failures: usize) -> Self {
            self.failures = failures;
            self
        }

        /// Makes every chat call sleep for `delay` before answering.
        fn delayed(mut self, delay: Duration) -> Self {
            self.delay = Some(delay);
            self
        }

        /// Number of `chat` calls started so far.
        fn calls(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait::async_trait]
    impl Provider for MockProvider {
        fn id(&self) -> &str {
            &self.id
        }

        fn name(&self) -> &str {
            "Mock"
        }

        fn models(&self) -> Vec<crate::models::ModelInfo> {
            vec![]
        }

        async fn chat(&self, _request: ChatRequest) -> Result<ChatResponse, ProviderError> {
            // Count before sleeping so attempts cut short by a timeout are still seen.
            let call = self.calls.fetch_add(1, Ordering::SeqCst);
            if let Some(delay) = self.delay {
                tokio::time::sleep(delay).await;
            }
            if call < self.failures {
                return Err(ProviderError::ProviderError(format!("failure {}", call)));
            }
            Ok(ChatResponse {
                content: "test response".to_string(),
                model: "test-model".to_string(),
                usage: TokenUsage {
                    prompt_tokens: 10,
                    completion_tokens: 5,
                    total_tokens: 15,
                },
                finish_reason: FinishReason::Stop,
            })
        }

        async fn chat_stream(&self, _request: ChatRequest) -> Result<ChatStream, ProviderError> {
            Err(ProviderError::NotFound("Not implemented".to_string()))
        }

        fn count_tokens(&self, _content: &str, _model: &str) -> Result<usize, ProviderError> {
            Ok(0)
        }

        async fn health_check(&self) -> Result<bool, ProviderError> {
            Ok(true)
        }
    }

    /// Registers `provider` in a fresh registry and makes it the default provider.
    fn manager_for(provider: &Arc<MockProvider>) -> ProviderManager {
        let mut registry = ProviderRegistry::new();
        // Bind with the concrete type so the coercion to the trait object happens
        // at the call site rather than inside `Arc::clone`'s inference.
        let handle: Arc<MockProvider> = Arc::clone(provider);
        registry.register(handle).unwrap();
        ProviderManager::new(registry, provider.id.clone())
    }

    /// A minimal request accepted by the mock provider.
    fn test_request() -> ChatRequest {
        ChatRequest {
            model: "test-model".to_string(),
            messages: vec![],
            temperature: None,
            max_tokens: None,
            stream: false,
        }
    }

    #[tokio::test]
    async fn test_manager_creation() {
        let provider = Arc::new(MockProvider::new("test"));
        let manager = manager_for(&provider);
        assert!(manager.default_provider().is_ok());
    }

    #[tokio::test]
    async fn test_chat_request() {
        let provider = Arc::new(MockProvider::new("test"));
        let manager = manager_for(&provider);

        let response = manager.chat(test_request()).await;
        assert!(response.is_ok());
        assert_eq!(provider.calls(), 1);
    }

    #[tokio::test]
    async fn test_chat_retries_until_success() {
        // First call fails, second succeeds; the remaining retries must go unused.
        let provider = Arc::new(MockProvider::new("test").failing(1));
        let manager = manager_for(&provider).with_retry_count(3);

        assert!(manager.chat(test_request()).await.is_ok());
        assert_eq!(provider.calls(), 2);
    }

    #[tokio::test]
    async fn test_chat_returns_last_error_after_retries() {
        let provider = Arc::new(MockProvider::new("test").failing(usize::MAX));
        let manager = manager_for(&provider).with_retry_count(1);

        let result = manager.chat(test_request()).await;
        assert!(matches!(
            &result,
            Err(ProviderError::ProviderError(msg)) if msg == "failure 1"
        ));
        assert_eq!(provider.calls(), 2);
    }

    #[tokio::test]
    async fn test_chat_times_out() {
        let provider = Arc::new(MockProvider::new("test").delayed(Duration::from_secs(5)));
        let manager = manager_for(&provider)
            .with_retry_count(0)
            .with_timeout(Duration::from_millis(10));

        let result = manager.chat(test_request()).await;
        assert!(matches!(
            &result,
            Err(ProviderError::ProviderError(msg)) if msg == "Request timeout"
        ));
        assert_eq!(provider.calls(), 1);
    }

    #[tokio::test]
    async fn test_health_check() {
        let provider = Arc::new(MockProvider::new("test"));
        let manager = manager_for(&provider);

        let health = manager.health_check("test").await;
        assert!(matches!(health, Ok(true)));
    }
}