1#![allow(deprecated)]
25
26use std::collections::HashMap;
27use std::sync::{Mutex, RwLock};
28use std::time::Duration;
29
30use crate::backend::{self, BackendError, ModelResponse};
31use crate::backend_error::BackendErrorKind;
32use crate::circuit_breaker::CircuitBreaker;
33use crate::retry_policy::RetryPolicy;
34
35#[derive(Debug, Clone)]
37pub struct ProviderConfig {
38 pub name: String,
40 pub connect_timeout: Duration,
42 pub read_timeout: Duration,
44 pub retry_policy: RetryPolicy,
46}
47
48impl ProviderConfig {
49 pub fn new(name: &str) -> Self {
50 ProviderConfig {
51 name: name.to_string(),
52 connect_timeout: Duration::from_secs(10),
53 read_timeout: Duration::from_secs(120),
54 retry_policy: RetryPolicy::default(),
55 }
56 }
57}
58
59pub struct ResilientBackend {
64 circuit_breakers: RwLock<HashMap<(String, String), Mutex<CircuitBreaker>>>,
68 provider_configs: HashMap<String, ProviderConfig>,
70 fallback_chains: HashMap<String, Vec<String>>,
72}
73
74impl ResilientBackend {
75 pub fn new() -> Self {
78 let mut provider_configs = HashMap::new();
79 for &name in backend::SUPPORTED_BACKENDS {
80 provider_configs.insert(name.to_string(), ProviderConfig::new(name));
81 }
82 ResilientBackend {
83 circuit_breakers: RwLock::new(HashMap::new()),
84 provider_configs,
85 fallback_chains: HashMap::new(),
86 }
87 }
88
89 pub fn configure_provider(&mut self, config: ProviderConfig) {
91 self.provider_configs.insert(config.name.clone(), config);
92 }
93
94 pub fn set_fallback_chain(&mut self, primary: &str, fallbacks: Vec<String>) {
96 self.fallback_chains.insert(primary.to_string(), fallbacks);
97 }
98
99 pub fn call(
104 &self,
105 provider: &str,
106 api_key: &str,
107 system_prompt: &str,
108 user_prompt: &str,
109 max_tokens: Option<u32>,
110 ) -> Result<ModelResponse, BackendError> {
111 let tenant_id = crate::tenant::current_tenant_id();
112 match self.call_single_provider(&tenant_id, provider, api_key, system_prompt, user_prompt, max_tokens) {
113 Ok(resp) => Ok(resp),
114 Err(primary_err) => {
115 if let Some(fallbacks) = self.fallback_chains.get(provider) {
116 for fallback in fallbacks {
117 tracing::warn!(
118 tenant_id = %tenant_id,
119 primary = provider,
120 fallback = %fallback,
121 primary_error = %primary_err,
122 "resilient_backend_trying_fallback"
123 );
124 let fallback_key = match backend::get_api_key(fallback) {
125 Ok(k) => k,
126 Err(_) => continue,
127 };
128 match self.call_single_provider(
129 &tenant_id, fallback, &fallback_key, system_prompt, user_prompt, max_tokens,
130 ) {
131 Ok(resp) => {
132 tracing::info!(
133 tenant_id = %tenant_id,
134 primary = provider,
135 fallback = %fallback,
136 "resilient_backend_fallback_succeeded"
137 );
138 return Ok(resp);
139 }
140 Err(e) => {
141 tracing::warn!(
142 tenant_id = %tenant_id,
143 fallback = %fallback,
144 error = %e,
145 "resilient_backend_fallback_failed"
146 );
147 }
148 }
149 }
150 }
151 Err(primary_err)
152 }
153 }
154 }
155
156 fn call_single_provider(
158 &self,
159 tenant_id: &str,
160 provider: &str,
161 api_key: &str,
162 system_prompt: &str,
163 user_prompt: &str,
164 max_tokens: Option<u32>,
165 ) -> Result<ModelResponse, BackendError> {
166 let cb_key = (tenant_id.to_string(), provider.to_string());
167
168 {
171 let map = self.circuit_breakers.read().unwrap();
172 if !map.contains_key(&cb_key) {
173 drop(map);
174 let mut map = self.circuit_breakers.write().unwrap();
175 map.entry(cb_key.clone())
176 .or_insert_with(|| Mutex::new(CircuitBreaker::with_defaults(provider)));
177 }
178 }
179
180 {
182 let map = self.circuit_breakers.read().unwrap();
183 let cb_mutex = map.get(&cb_key).unwrap();
184 let mut cb = cb_mutex.lock().unwrap();
185 if !cb.can_execute() {
186 tracing::warn!(
187 tenant_id, provider,
188 state = %cb.state(),
189 "resilient_backend_circuit_open"
190 );
191 return Err(BackendError {
192 message: format!(
193 "Circuit breaker open for provider '{provider}' (tenant '{tenant_id}') — calls rejected"
194 ),
195 });
196 }
197 }
198
199 let retry_policy = self.provider_configs
200 .get(provider)
201 .map(|c| c.retry_policy.clone())
202 .unwrap_or_default();
203
204 let mut last_error = None;
205 for attempt in 0..=retry_policy.max_retries {
206 if attempt > 0 {
207 let error_kind = BackendErrorKind::Unknown;
208 let delay = retry_policy.effective_delay(attempt - 1, &error_kind);
209 tracing::info!(
210 tenant_id, provider, attempt,
211 delay_ms = delay.as_millis() as u64,
212 "resilient_backend_retrying"
213 );
214 std::thread::sleep(delay);
215 }
216
217 match backend::call(provider, api_key, system_prompt, user_prompt, max_tokens) {
218 Ok(resp) => {
219 let map = self.circuit_breakers.read().unwrap();
220 if let Some(cb_mutex) = map.get(&cb_key) {
221 cb_mutex.lock().unwrap().record_success();
222 }
223 return Ok(resp);
224 }
225 Err(e) => {
226 let error_kind = classify_backend_error(&e);
227 tracing::warn!(
228 tenant_id, provider, attempt,
229 error = %e,
230 error_kind = error_kind.category(),
231 retryable = error_kind.is_retryable(),
232 "resilient_backend_call_failed"
233 );
234 {
235 let map = self.circuit_breakers.read().unwrap();
236 if let Some(cb_mutex) = map.get(&cb_key) {
237 cb_mutex.lock().unwrap().record_failure();
238 }
239 }
240 if !retry_policy.should_retry(attempt, &error_kind) {
241 return Err(e);
242 }
243 last_error = Some(e);
244 }
245 }
246 }
247
248 Err(last_error.unwrap_or_else(|| BackendError {
249 message: format!(
250 "All {} retry attempts exhausted for provider '{provider}' (tenant '{tenant_id}')",
251 retry_policy.max_retries
252 ),
253 }))
254 }
255
256 pub fn circuit_state(
258 &self,
259 tenant_id: &str,
260 provider: &str,
261 ) -> Option<crate::circuit_breaker::CircuitState> {
262 let map = self.circuit_breakers.read().unwrap();
263 map.get(&(tenant_id.to_string(), provider.to_string())).map(|cb| {
264 cb.lock().unwrap().state()
265 })
266 }
267
268 pub fn reset_circuit(&self, tenant_id: &str, provider: &str) {
270 let map = self.circuit_breakers.read().unwrap();
271 if let Some(cb_mutex) = map.get(&(tenant_id.to_string(), provider.to_string())) {
272 cb_mutex.lock().unwrap().reset();
273 tracing::info!(tenant_id, provider, "circuit_breaker_manually_reset");
274 }
275 }
276
277 pub fn all_circuit_states(&self) -> Vec<(String, String, crate::circuit_breaker::CircuitState)> {
279 let map = self.circuit_breakers.read().unwrap();
280 map.iter().map(|((tid, prov), cb)| {
281 (tid.clone(), prov.clone(), cb.lock().unwrap().state())
282 }).collect()
283 }
284}
285
286fn classify_backend_error(e: &BackendError) -> BackendErrorKind {
288 let msg = e.message.to_lowercase();
289
290 if msg.contains("timeout") || msg.contains("timed out") {
291 BackendErrorKind::Timeout
292 } else if msg.contains("429") || msg.contains("rate limit") || msg.contains("too many requests") {
293 BackendErrorKind::RateLimit { retry_after: None }
294 } else if msg.contains("401") || msg.contains("403") || msg.contains("unauthorized") || msg.contains("forbidden") {
295 BackendErrorKind::AuthError
296 } else if msg.contains("api error (5") {
297 let status = msg.split("api error (")
299 .nth(1)
300 .and_then(|s| s.split(')').next())
301 .and_then(|s| s.parse::<u16>().ok())
302 .unwrap_or(500);
303 BackendErrorKind::ServerError { status }
304 } else if msg.contains("connection refused") || msg.contains("dns") || msg.contains("http request failed") {
305 BackendErrorKind::NetworkError
306 } else if msg.contains("stream") && (msg.contains("error") || msg.contains("dropped")) {
307 BackendErrorKind::StreamDropped
308 } else if msg.contains("parse") || msg.contains("json") {
309 BackendErrorKind::InvalidResponse
310 } else if msg.contains("unknown backend") {
311 BackendErrorKind::ProviderUnavailable
312 } else {
313 BackendErrorKind::Unknown
314 }
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use crate::circuit_breaker::CircuitState;

    /// Registers a default breaker for `(tenant, provider)` if none exists and
    /// records `failures` consecutive failures on it.
    fn seed_breaker(rb: &ResilientBackend, tenant: &str, provider: &str, failures: usize) {
        let mut breakers = rb.circuit_breakers.write().unwrap();
        let breaker = breakers
            .entry((tenant.to_string(), provider.to_string()))
            .or_insert_with(|| Mutex::new(CircuitBreaker::with_defaults(provider)));
        for _ in 0..failures {
            breaker.lock().unwrap().record_failure();
        }
    }

    /// Builds a `BackendError` carrying `message`.
    fn err(message: &str) -> BackendError {
        BackendError { message: message.into() }
    }

    // --- circuit breaker bookkeeping -------------------------------------

    #[test]
    fn test_new_starts_empty_no_circuits() {
        let rb = ResilientBackend::new();
        assert!(rb.circuit_breakers.read().unwrap().is_empty());
    }

    #[test]
    fn test_circuit_state_returns_none_before_first_call() {
        let rb = ResilientBackend::new();
        assert_eq!(rb.circuit_state("acme", "anthropic"), None);
    }

    #[test]
    fn test_circuit_state_closed_after_lazy_init() {
        let rb = ResilientBackend::new();
        seed_breaker(&rb, "acme", "anthropic", 0);
        assert_eq!(rb.circuit_state("acme", "anthropic"), Some(CircuitState::Closed));
    }

    #[test]
    fn test_reset_circuit_per_tenant() {
        let rb = ResilientBackend::new();
        // Five failures are expected to trip a default breaker.
        seed_breaker(&rb, "acme", "openai", 5);
        assert_eq!(rb.circuit_state("acme", "openai"), Some(CircuitState::Open));

        rb.reset_circuit("acme", "openai");
        assert_eq!(rb.circuit_state("acme", "openai"), Some(CircuitState::Closed));
    }

    #[test]
    fn test_tenant_isolation_circuits_independent() {
        let rb = ResilientBackend::new();
        seed_breaker(&rb, "tenant-a", "anthropic", 5);

        // Tripping tenant-a's breaker must not create or affect tenant-b's.
        assert_eq!(rb.circuit_state("tenant-b", "anthropic"), None);
        assert_eq!(rb.circuit_state("tenant-a", "anthropic"), Some(CircuitState::Open));
    }

    #[test]
    fn test_all_circuit_states() {
        let rb = ResilientBackend::new();
        seed_breaker(&rb, "t1", "anthropic", 0);
        seed_breaker(&rb, "t2", "openai", 0);

        let states = rb.all_circuit_states();
        assert_eq!(states.len(), 2);
    }

    // --- error classification --------------------------------------------

    #[test]
    fn test_classify_timeout() {
        let kind = classify_backend_error(&err("HTTP request failed: operation timed out"));
        assert!(matches!(kind, BackendErrorKind::Timeout));
    }

    #[test]
    fn test_classify_rate_limit() {
        let kind = classify_backend_error(&err("API error (429): Too Many Requests"));
        assert!(matches!(kind, BackendErrorKind::RateLimit { .. }));
    }

    #[test]
    fn test_classify_auth() {
        let kind = classify_backend_error(&err("API error (401): Unauthorized"));
        assert!(matches!(kind, BackendErrorKind::AuthError));
    }

    #[test]
    fn test_classify_server_error() {
        let kind = classify_backend_error(&err("API error (503): Service Unavailable"));
        assert!(matches!(kind, BackendErrorKind::ServerError { status: 503 }));
    }

    #[test]
    fn test_classify_network() {
        let kind = classify_backend_error(&err("HTTP request failed: connection refused"));
        assert!(matches!(kind, BackendErrorKind::NetworkError));
    }

    #[test]
    fn test_classify_unknown_backend() {
        let kind = classify_backend_error(&err("Unknown backend 'foo'"));
        assert!(matches!(kind, BackendErrorKind::ProviderUnavailable));
    }

    // --- configuration ----------------------------------------------------

    #[test]
    fn test_set_fallback_chain() {
        let mut rb = ResilientBackend::new();
        rb.set_fallback_chain("anthropic", vec!["openrouter".into(), "ollama".into()]);

        let expected = vec!["openrouter".to_string(), "ollama".to_string()];
        assert_eq!(rb.fallback_chains.get("anthropic"), Some(&expected));
    }
}