Skip to main content

axon/
resilient_backend.rs

1//! Resilient Backend — production-grade LLM call wrapper with retry, circuit breaker, and fallback.
2//!
3//! Composes multiple resilience patterns into a single call path:
4//!   1. **Circuit Breaker**: Fail fast if provider is known to be down (isolated per tenant)
5//!   2. **Timeout**: Configurable connect + read timeout per provider
6//!   3. **Retry with Backoff**: Exponential backoff with jitter for transient errors
7//!   4. **Fallback Chain**: If primary provider fails, try secondary/tertiary
8//!
9//! Circuit breakers are keyed by `(tenant_id, provider)` so one tenant's failures
10//! cannot open another tenant's circuit — complete blast-radius isolation (M4).
11//!
12//! Designed for production SaaS workloads where LLM availability is critical.
13//! All state transitions and retry attempts are logged via `tracing`.
14//!
15//! # §Fase 33.x.i — `crate::backend` deprecation
16//!
17//! This file wraps the deprecated synchronous `crate::backend`
18//! surface with resilience patterns. The `#![allow(deprecated)]`
19//! below silences the deprecation warnings while the deeper async
20//! migration progresses under followup sub-fase Fase 33.x.i.2.
21//! The async equivalent (`backends::Backend::complete()` with
22//! retry/circuit-breaker layers) is the migration target.
23
24#![allow(deprecated)]
25
26use std::collections::HashMap;
27use std::sync::{Mutex, RwLock};
28use std::time::Duration;
29
30use crate::backend::{self, BackendError, ModelResponse};
31use crate::backend_error::BackendErrorKind;
32use crate::circuit_breaker::CircuitBreaker;
33use crate::retry_policy::RetryPolicy;
34
35/// Per-provider resilience configuration.
36#[derive(Debug, Clone)]
37pub struct ProviderConfig {
38    /// Provider name (e.g., "anthropic", "openai").
39    pub name: String,
40    /// Connection timeout.
41    pub connect_timeout: Duration,
42    /// Read/response timeout (includes LLM thinking time).
43    pub read_timeout: Duration,
44    /// Retry policy for this provider.
45    pub retry_policy: RetryPolicy,
46}
47
48impl ProviderConfig {
49    pub fn new(name: &str) -> Self {
50        ProviderConfig {
51            name: name.to_string(),
52            connect_timeout: Duration::from_secs(10),
53            read_timeout: Duration::from_secs(120),
54            retry_policy: RetryPolicy::default(),
55        }
56    }
57}
58
59/// Resilient backend wrapping the raw LLM call layer with production hardening.
60///
61/// Circuit breakers are keyed by `(tenant_id, provider)` and created lazily on
62/// first access — one tenant's failures cannot trip another tenant's circuit.
63pub struct ResilientBackend {
64    /// Per-(tenant_id, provider) circuit breakers — created lazily on demand.
65    /// RwLock: concurrent reads (per-tenant lookups), exclusive writes only on
66    /// first encounter of a new (tenant, provider) pair.
67    circuit_breakers: RwLock<HashMap<(String, String), Mutex<CircuitBreaker>>>,
68    /// Per-provider configuration (shared across all tenants).
69    provider_configs: HashMap<String, ProviderConfig>,
70    /// Fallback chains: primary_provider → [fallback_1, fallback_2, ...]
71    fallback_chains: HashMap<String, Vec<String>>,
72}
73
74impl ResilientBackend {
75    /// Create a new resilient backend with default configs for all supported providers.
76    /// Circuit breakers are created lazily on first (tenant, provider) access.
77    pub fn new() -> Self {
78        let mut provider_configs = HashMap::new();
79        for &name in backend::SUPPORTED_BACKENDS {
80            provider_configs.insert(name.to_string(), ProviderConfig::new(name));
81        }
82        ResilientBackend {
83            circuit_breakers: RwLock::new(HashMap::new()),
84            provider_configs,
85            fallback_chains: HashMap::new(),
86        }
87    }
88
89    /// Configure a specific provider's resilience settings.
90    pub fn configure_provider(&mut self, config: ProviderConfig) {
91        self.provider_configs.insert(config.name.clone(), config);
92    }
93
94    /// Set a fallback chain for a provider.
95    pub fn set_fallback_chain(&mut self, primary: &str, fallbacks: Vec<String>) {
96        self.fallback_chains.insert(primary.to_string(), fallbacks);
97    }
98
99    /// Make a resilient LLM call with retry, per-tenant circuit breaker, and fallback.
100    ///
101    /// Tenant is derived automatically from `current_tenant_id()` — the active
102    /// Axum request's task-local set by `tenant_extractor_middleware`.
103    pub fn call(
104        &self,
105        provider: &str,
106        api_key: &str,
107        system_prompt: &str,
108        user_prompt: &str,
109        max_tokens: Option<u32>,
110    ) -> Result<ModelResponse, BackendError> {
111        let tenant_id = crate::tenant::current_tenant_id();
112        match self.call_single_provider(&tenant_id, provider, api_key, system_prompt, user_prompt, max_tokens) {
113            Ok(resp) => Ok(resp),
114            Err(primary_err) => {
115                if let Some(fallbacks) = self.fallback_chains.get(provider) {
116                    for fallback in fallbacks {
117                        tracing::warn!(
118                            tenant_id = %tenant_id,
119                            primary = provider,
120                            fallback = %fallback,
121                            primary_error = %primary_err,
122                            "resilient_backend_trying_fallback"
123                        );
124                        let fallback_key = match backend::get_api_key(fallback) {
125                            Ok(k) => k,
126                            Err(_) => continue,
127                        };
128                        match self.call_single_provider(
129                            &tenant_id, fallback, &fallback_key, system_prompt, user_prompt, max_tokens,
130                        ) {
131                            Ok(resp) => {
132                                tracing::info!(
133                                    tenant_id = %tenant_id,
134                                    primary = provider,
135                                    fallback = %fallback,
136                                    "resilient_backend_fallback_succeeded"
137                                );
138                                return Ok(resp);
139                            }
140                            Err(e) => {
141                                tracing::warn!(
142                                    tenant_id = %tenant_id,
143                                    fallback = %fallback,
144                                    error = %e,
145                                    "resilient_backend_fallback_failed"
146                                );
147                            }
148                        }
149                    }
150                }
151                Err(primary_err)
152            }
153        }
154    }
155
156    /// Call a single provider with the tenant-isolated circuit breaker + retry.
157    fn call_single_provider(
158        &self,
159        tenant_id: &str,
160        provider: &str,
161        api_key: &str,
162        system_prompt: &str,
163        user_prompt: &str,
164        max_tokens: Option<u32>,
165    ) -> Result<ModelResponse, BackendError> {
166        let cb_key = (tenant_id.to_string(), provider.to_string());
167
168        // Lazy-init: insert a new CB only on first encounter of this (tenant, provider) pair.
169        // Check with a read lock first; upgrade to write only when needed.
170        {
171            let map = self.circuit_breakers.read().unwrap();
172            if !map.contains_key(&cb_key) {
173                drop(map);
174                let mut map = self.circuit_breakers.write().unwrap();
175                map.entry(cb_key.clone())
176                    .or_insert_with(|| Mutex::new(CircuitBreaker::with_defaults(provider)));
177            }
178        }
179
180        // Check circuit breaker state
181        {
182            let map = self.circuit_breakers.read().unwrap();
183            let cb_mutex = map.get(&cb_key).unwrap();
184            let mut cb = cb_mutex.lock().unwrap();
185            if !cb.can_execute() {
186                tracing::warn!(
187                    tenant_id, provider,
188                    state = %cb.state(),
189                    "resilient_backend_circuit_open"
190                );
191                return Err(BackendError {
192                    message: format!(
193                        "Circuit breaker open for provider '{provider}' (tenant '{tenant_id}') — calls rejected"
194                    ),
195                });
196            }
197        }
198
199        let retry_policy = self.provider_configs
200            .get(provider)
201            .map(|c| c.retry_policy.clone())
202            .unwrap_or_default();
203
204        let mut last_error = None;
205        for attempt in 0..=retry_policy.max_retries {
206            if attempt > 0 {
207                let error_kind = BackendErrorKind::Unknown;
208                let delay = retry_policy.effective_delay(attempt - 1, &error_kind);
209                tracing::info!(
210                    tenant_id, provider, attempt,
211                    delay_ms = delay.as_millis() as u64,
212                    "resilient_backend_retrying"
213                );
214                std::thread::sleep(delay);
215            }
216
217            match backend::call(provider, api_key, system_prompt, user_prompt, max_tokens) {
218                Ok(resp) => {
219                    let map = self.circuit_breakers.read().unwrap();
220                    if let Some(cb_mutex) = map.get(&cb_key) {
221                        cb_mutex.lock().unwrap().record_success();
222                    }
223                    return Ok(resp);
224                }
225                Err(e) => {
226                    let error_kind = classify_backend_error(&e);
227                    tracing::warn!(
228                        tenant_id, provider, attempt,
229                        error = %e,
230                        error_kind = error_kind.category(),
231                        retryable = error_kind.is_retryable(),
232                        "resilient_backend_call_failed"
233                    );
234                    {
235                        let map = self.circuit_breakers.read().unwrap();
236                        if let Some(cb_mutex) = map.get(&cb_key) {
237                            cb_mutex.lock().unwrap().record_failure();
238                        }
239                    }
240                    if !retry_policy.should_retry(attempt, &error_kind) {
241                        return Err(e);
242                    }
243                    last_error = Some(e);
244                }
245            }
246        }
247
248        Err(last_error.unwrap_or_else(|| BackendError {
249            message: format!(
250                "All {} retry attempts exhausted for provider '{provider}' (tenant '{tenant_id}')",
251                retry_policy.max_retries
252            ),
253        }))
254    }
255
256    /// Get the circuit breaker state for a (tenant, provider) pair.
257    pub fn circuit_state(
258        &self,
259        tenant_id: &str,
260        provider: &str,
261    ) -> Option<crate::circuit_breaker::CircuitState> {
262        let map = self.circuit_breakers.read().unwrap();
263        map.get(&(tenant_id.to_string(), provider.to_string())).map(|cb| {
264            cb.lock().unwrap().state()
265        })
266    }
267
268    /// Reset the circuit breaker for a (tenant, provider) pair.
269    pub fn reset_circuit(&self, tenant_id: &str, provider: &str) {
270        let map = self.circuit_breakers.read().unwrap();
271        if let Some(cb_mutex) = map.get(&(tenant_id.to_string(), provider.to_string())) {
272            cb_mutex.lock().unwrap().reset();
273            tracing::info!(tenant_id, provider, "circuit_breaker_manually_reset");
274        }
275    }
276
277    /// List all active (tenant, provider) circuit breaker states.
278    pub fn all_circuit_states(&self) -> Vec<(String, String, crate::circuit_breaker::CircuitState)> {
279        let map = self.circuit_breakers.read().unwrap();
280        map.iter().map(|((tid, prov), cb)| {
281            (tid.clone(), prov.clone(), cb.lock().unwrap().state())
282        }).collect()
283    }
284}
285
286/// Classify a BackendError into a BackendErrorKind by inspecting the message.
287fn classify_backend_error(e: &BackendError) -> BackendErrorKind {
288    let msg = e.message.to_lowercase();
289
290    if msg.contains("timeout") || msg.contains("timed out") {
291        BackendErrorKind::Timeout
292    } else if msg.contains("429") || msg.contains("rate limit") || msg.contains("too many requests") {
293        BackendErrorKind::RateLimit { retry_after: None }
294    } else if msg.contains("401") || msg.contains("403") || msg.contains("unauthorized") || msg.contains("forbidden") {
295        BackendErrorKind::AuthError
296    } else if msg.contains("api error (5") {
297        // Match "API error (500)", "API error (502)", etc.
298        let status = msg.split("api error (")
299            .nth(1)
300            .and_then(|s| s.split(')').next())
301            .and_then(|s| s.parse::<u16>().ok())
302            .unwrap_or(500);
303        BackendErrorKind::ServerError { status }
304    } else if msg.contains("connection refused") || msg.contains("dns") || msg.contains("http request failed") {
305        BackendErrorKind::NetworkError
306    } else if msg.contains("stream") && (msg.contains("error") || msg.contains("dropped")) {
307        BackendErrorKind::StreamDropped
308    } else if msg.contains("parse") || msg.contains("json") {
309        BackendErrorKind::InvalidResponse
310    } else if msg.contains("unknown backend") {
311        BackendErrorKind::ProviderUnavailable
312    } else {
313        BackendErrorKind::Unknown
314    }
315}
316
// ── Tests ──────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use crate::circuit_breaker::CircuitState;

    /// Create (if absent) the breaker for `(tenant, provider)` and record
    /// `failures` consecutive failures on it, without any real backend call.
    fn seed_breaker(rb: &ResilientBackend, tenant: &str, provider: &str, failures: usize) {
        let mut breakers = rb.circuit_breakers.write().unwrap();
        let slot = breakers
            .entry((tenant.to_string(), provider.to_string()))
            .or_insert_with(|| Mutex::new(CircuitBreaker::with_defaults(provider)));
        let mut breaker = slot.lock().unwrap();
        for _ in 0..failures {
            breaker.record_failure();
        }
    }

    /// Shorthand for a `BackendError` carrying `message`.
    fn error_with(message: &str) -> BackendError {
        BackendError { message: message.into() }
    }

    #[test]
    fn test_new_starts_empty_no_circuits() {
        let rb = ResilientBackend::new();
        // Breakers are created lazily, so nothing exists before the first call.
        let breakers = rb.circuit_breakers.read().unwrap();
        assert!(breakers.is_empty());
    }

    #[test]
    fn test_circuit_state_returns_none_before_first_call() {
        let rb = ResilientBackend::new();
        assert!(rb.circuit_state("acme", "anthropic").is_none());
    }

    #[test]
    fn test_circuit_state_closed_after_lazy_init() {
        let rb = ResilientBackend::new();
        seed_breaker(&rb, "acme", "anthropic", 0);
        assert_eq!(rb.circuit_state("acme", "anthropic"), Some(CircuitState::Closed));
    }

    #[test]
    fn test_reset_circuit_per_tenant() {
        let rb = ResilientBackend::new();
        seed_breaker(&rb, "acme", "openai", 5);
        assert_eq!(rb.circuit_state("acme", "openai"), Some(CircuitState::Open));

        rb.reset_circuit("acme", "openai");
        assert_eq!(rb.circuit_state("acme", "openai"), Some(CircuitState::Closed));
    }

    #[test]
    fn test_tenant_isolation_circuits_independent() {
        let rb = ResilientBackend::new();
        seed_breaker(&rb, "tenant-a", "anthropic", 5);

        // The failing tenant's breaker is open...
        assert_eq!(rb.circuit_state("tenant-a", "anthropic"), Some(CircuitState::Open));
        // ...while the other tenant has no breaker at all.
        assert_eq!(rb.circuit_state("tenant-b", "anthropic"), None);
    }

    #[test]
    fn test_all_circuit_states() {
        let rb = ResilientBackend::new();
        seed_breaker(&rb, "t1", "anthropic", 0);
        seed_breaker(&rb, "t2", "openai", 0);
        assert_eq!(rb.all_circuit_states().len(), 2);
    }

    #[test]
    fn test_classify_timeout() {
        let kind = classify_backend_error(&error_with("HTTP request failed: operation timed out"));
        assert!(matches!(kind, BackendErrorKind::Timeout));
    }

    #[test]
    fn test_classify_rate_limit() {
        let kind = classify_backend_error(&error_with("API error (429): Too Many Requests"));
        assert!(matches!(kind, BackendErrorKind::RateLimit { .. }));
    }

    #[test]
    fn test_classify_auth() {
        let kind = classify_backend_error(&error_with("API error (401): Unauthorized"));
        assert!(matches!(kind, BackendErrorKind::AuthError));
    }

    #[test]
    fn test_classify_server_error() {
        let kind = classify_backend_error(&error_with("API error (503): Service Unavailable"));
        assert!(matches!(kind, BackendErrorKind::ServerError { status: 503 }));
    }

    #[test]
    fn test_classify_network() {
        let kind = classify_backend_error(&error_with("HTTP request failed: connection refused"));
        assert!(matches!(kind, BackendErrorKind::NetworkError));
    }

    #[test]
    fn test_classify_unknown_backend() {
        let kind = classify_backend_error(&error_with("Unknown backend 'foo'"));
        assert!(matches!(kind, BackendErrorKind::ProviderUnavailable));
    }

    #[test]
    fn test_set_fallback_chain() {
        let mut rb = ResilientBackend::new();
        let chain: Vec<String> = ["openrouter", "ollama"].iter().map(|s| s.to_string()).collect();
        rb.set_fallback_chain("anthropic", chain.clone());
        assert_eq!(rb.fallback_chains.get("anthropic"), Some(&chain));
    }
}