Skip to main content

heliosdb_proxy/circuit_breaker/
agent.rs

1//! Agent Retry Strategies and Conversation Fallback
2//!
3//! Provides retry guidance for AI agents and maintains conversation context
4//! during circuit breaker outages.
5
6use std::collections::HashMap;
7use std::time::{Duration, Instant};
8
9use dashmap::DashMap;
10use parking_lot::RwLock;
11
/// Decision produced by a retry strategy for a failed request.
#[derive(Debug, Clone)]
pub enum RetryDecision {
    /// Retry after the specified delay
    Retry { delay: Duration, attempt: u32 },
    /// Don't retry, fail immediately
    Fail { reason: String },
    /// Use fallback/cached data
    Fallback { message: String },
}

impl RetryDecision {
    /// Returns `true` only for [`RetryDecision::Retry`].
    pub fn should_retry(&self) -> bool {
        matches!(self, RetryDecision::Retry { .. })
    }

    /// Returns the delay to wait before the next attempt, or `None` when the
    /// decision is not a retry.
    pub fn retry_delay(&self) -> Option<Duration> {
        match self {
            RetryDecision::Retry { delay, .. } => Some(*delay),
            // Listed explicitly so adding a variant forces a decision here.
            RetryDecision::Fail { .. } | RetryDecision::Fallback { .. } => None,
        }
    }

    /// Generate an LLM-friendly message: a single-line JSON object whose
    /// `action` field is `"retry"`, `"fail"` or `"fallback"`.
    ///
    /// Caller-supplied text (`reason`, `message`) is JSON-escaped before being
    /// embedded, so quotes, backslashes and control characters in an error
    /// string cannot break the surrounding object.
    pub fn to_llm_message(&self) -> String {
        match self {
            RetryDecision::Retry { delay, attempt } => {
                format!(
                    r#"{{"action":"retry","delay_ms":{},"attempt":{},"message":"Request failed temporarily. Retry in {} milliseconds."}}"#,
                    delay.as_millis(),
                    attempt,
                    delay.as_millis()
                )
            }
            RetryDecision::Fail { reason } => {
                let reason = Self::escape_json(reason);
                format!(
                    r#"{{"action":"fail","reason":"{}","message":"Request cannot be retried: {}"}}"#,
                    reason, reason
                )
            }
            RetryDecision::Fallback { message } => {
                format!(
                    r#"{{"action":"fallback","message":"{}","note":"Using cached or fallback data due to temporary outage."}}"#,
                    Self::escape_json(message)
                )
            }
        }
    }

    /// Escape `raw` for use inside a JSON string literal (without the
    /// surrounding quotes): `"` and `\` are backslash-escaped, and control
    /// characters below U+0020 use short escapes or `\uXXXX`.
    fn escape_json(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '"' => escaped.push_str("\\\""),
                '\\' => escaped.push_str("\\\\"),
                '\n' => escaped.push_str("\\n"),
                '\r' => escaped.push_str("\\r"),
                '\t' => escaped.push_str("\\t"),
                c if (c as u32) < 0x20 => escaped.push_str(&format!("\\u{:04x}", c as u32)),
                c => escaped.push(c),
            }
        }
        escaped
    }
}
63
64/// Retry strategy for AI agents
65#[derive(Debug, Clone)]
66pub struct AgentRetryStrategy {
67    /// Base delay for exponential backoff
68    base_delay: Duration,
69    /// Maximum delay
70    max_delay: Duration,
71    /// Maximum retry attempts
72    max_attempts: u32,
73    /// Jitter factor (0.0 - 1.0)
74    jitter_factor: f64,
75    /// Retryable error patterns
76    retryable_patterns: Vec<String>,
77    /// Non-retryable error patterns
78    non_retryable_patterns: Vec<String>,
79}
80
81impl Default for AgentRetryStrategy {
82    fn default() -> Self {
83        Self {
84            base_delay: Duration::from_millis(100),
85            max_delay: Duration::from_secs(30),
86            max_attempts: 5,
87            jitter_factor: 0.3,
88            retryable_patterns: vec![
89                "circuit_open".to_string(),
90                "rate_limit".to_string(),
91                "timeout".to_string(),
92                "connection".to_string(),
93                "temporary".to_string(),
94                "unavailable".to_string(),
95            ],
96            non_retryable_patterns: vec![
97                "invalid_query".to_string(),
98                "permission_denied".to_string(),
99                "authentication".to_string(),
100                "not_found".to_string(),
101                "constraint_violation".to_string(),
102            ],
103        }
104    }
105}
106
107impl AgentRetryStrategy {
108    /// Create a new retry strategy
109    pub fn new() -> Self {
110        Self::default()
111    }
112
113    /// Create with custom configuration
114    pub fn with_config(
115        base_delay: Duration,
116        max_delay: Duration,
117        max_attempts: u32,
118    ) -> Self {
119        Self {
120            base_delay,
121            max_delay,
122            max_attempts,
123            ..Default::default()
124        }
125    }
126
127    /// Set jitter factor
128    pub fn with_jitter(mut self, factor: f64) -> Self {
129        self.jitter_factor = factor.clamp(0.0, 1.0);
130        self
131    }
132
133    /// Add retryable pattern
134    pub fn with_retryable_pattern(mut self, pattern: impl Into<String>) -> Self {
135        self.retryable_patterns.push(pattern.into());
136        self
137    }
138
139    /// Add non-retryable pattern
140    pub fn with_non_retryable_pattern(mut self, pattern: impl Into<String>) -> Self {
141        self.non_retryable_patterns.push(pattern.into());
142        self
143    }
144
145    /// Calculate retry delay with exponential backoff and jitter
146    pub fn get_retry_delay(&self, attempt: u32) -> Duration {
147        // Exponential backoff: base * 2^attempt
148        let exp_delay = self.base_delay * 2u32.pow(attempt.min(10));
149
150        // Cap at max delay
151        let capped_delay = exp_delay.min(self.max_delay);
152
153        // Add jitter (random variation to prevent thundering herd)
154        let jitter = rand::random::<f64>() * self.jitter_factor;
155        let jittered = capped_delay.mul_f64(1.0 + jitter);
156
157        jittered.min(self.max_delay)
158    }
159
160    /// Determine if an error is retryable
161    pub fn is_retryable(&self, error: &str) -> bool {
162        let error_lower = error.to_lowercase();
163
164        // Check non-retryable patterns first
165        for pattern in &self.non_retryable_patterns {
166            if error_lower.contains(pattern) {
167                return false;
168            }
169        }
170
171        // Check retryable patterns
172        for pattern in &self.retryable_patterns {
173            if error_lower.contains(pattern) {
174                return true;
175            }
176        }
177
178        // Default: retry for unknown errors
179        true
180    }
181
182    /// Get retry decision for an error
183    pub fn should_retry(&self, error: &str, attempt: u32) -> RetryDecision {
184        if attempt >= self.max_attempts {
185            return RetryDecision::Fail {
186                reason: format!(
187                    "Maximum retry attempts ({}) exceeded",
188                    self.max_attempts
189                ),
190            };
191        }
192
193        if !self.is_retryable(error) {
194            return RetryDecision::Fail {
195                reason: format!("Error is not retryable: {}", error),
196            };
197        }
198
199        let delay = self.get_retry_delay(attempt);
200        RetryDecision::Retry {
201            delay,
202            attempt: attempt + 1,
203        }
204    }
205
206    /// Get recommended delay for specific error types
207    pub fn get_delay_for_error(&self, error: &str, attempt: u32) -> Option<Duration> {
208        let decision = self.should_retry(error, attempt);
209        decision.retry_delay()
210    }
211}
212
/// Snapshot of a conversation's latest successful exchange, kept so a
/// fallback answer can be served during an outage.
#[derive(Debug, Clone)]
pub struct ConversationContext {
    /// Identifier of the conversation this snapshot belongs to
    pub conversation_id: String,
    /// Most recent query recorded via `update`, if any
    pub last_query: Option<String>,
    /// Result recorded alongside `last_query` (serialized)
    pub last_result: Option<String>,
    /// Free-form key/value metadata
    pub metadata: HashMap<String, String>,
    /// When the snapshot was created or last updated
    pub cached_at: Instant,
    /// How long after `cached_at` the snapshot counts as valid
    pub ttl: Duration,
}

impl ConversationContext {
    /// Build an empty context for `conversation_id`, stamped with the current
    /// time and a one-hour TTL.
    pub fn new(conversation_id: impl Into<String>) -> Self {
        let conversation_id = conversation_id.into();
        let one_hour = Duration::from_secs(3600);
        Self {
            conversation_id,
            last_query: None,
            last_result: None,
            metadata: HashMap::new(),
            cached_at: Instant::now(),
            ttl: one_hour,
        }
    }

    /// Record the newest query/result pair and restart the TTL clock.
    pub fn update(&mut self, query: impl Into<String>, result: impl Into<String>) {
        let (query, result) = (query.into(), result.into());
        self.last_query = Some(query);
        self.last_result = Some(result);
        self.cached_at = Instant::now();
    }

    /// Builder-style: attach one metadata entry, replacing any previous value
    /// stored under the same key.
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (key, value) = (key.into(), value.into());
        self.metadata.insert(key, value);
        self
    }

    /// Builder-style: override the TTL.
    pub fn with_ttl(self, ttl: Duration) -> Self {
        Self { ttl, ..self }
    }

    /// `true` while the snapshot is younger than its TTL.
    pub fn is_valid(&self) -> bool {
        self.age() < self.ttl
    }

    /// Time elapsed since the snapshot was created or last updated.
    pub fn age(&self) -> Duration {
        self.cached_at.elapsed()
    }
}
272
273/// Conversation fallback manager
274///
275/// Maintains cached conversation contexts to provide fallback responses
276/// during circuit breaker outages.
277pub struct ConversationFallback {
278    /// Cached contexts per conversation
279    contexts: DashMap<String, ConversationContext>,
280
281    /// Default TTL for new contexts
282    default_ttl: RwLock<Duration>,
283
284    /// Maximum cached contexts
285    max_contexts: usize,
286}
287
288impl ConversationFallback {
289    /// Create a new conversation fallback manager
290    pub fn new() -> Self {
291        Self {
292            contexts: DashMap::new(),
293            default_ttl: RwLock::new(Duration::from_secs(3600)),
294            max_contexts: 10000,
295        }
296    }
297
298    /// Create with custom configuration
299    pub fn with_config(default_ttl: Duration, max_contexts: usize) -> Self {
300        Self {
301            contexts: DashMap::new(),
302            default_ttl: RwLock::new(default_ttl),
303            max_contexts,
304        }
305    }
306
307    /// Update context for a conversation
308    pub fn update_context(
309        &self,
310        conversation_id: &str,
311        query: impl Into<String>,
312        result: impl Into<String>,
313    ) {
314        let ttl = *self.default_ttl.read();
315
316        if let Some(mut ctx) = self.contexts.get_mut(conversation_id) {
317            ctx.update(query, result);
318        } else {
319            // Enforce max contexts
320            if self.contexts.len() >= self.max_contexts {
321                self.cleanup_expired();
322            }
323
324            let mut ctx = ConversationContext::new(conversation_id).with_ttl(ttl);
325            ctx.update(query, result);
326            self.contexts.insert(conversation_id.to_string(), ctx);
327        }
328    }
329
330    /// Get fallback for a conversation
331    pub fn get_fallback(&self, conversation_id: &str) -> Option<ConversationContext> {
332        self.contexts
333            .get(conversation_id)
334            .filter(|ctx| ctx.is_valid())
335            .map(|ctx| ctx.clone())
336    }
337
338    /// Execute with fallback on error
339    pub fn execute_with_fallback<T, E>(
340        &self,
341        conversation_id: &str,
342        execute: impl FnOnce() -> Result<T, E>,
343        fallback: impl FnOnce(&ConversationContext) -> T,
344    ) -> Result<T, E>
345    where
346        E: std::fmt::Display,
347    {
348        match execute() {
349            Ok(result) => Ok(result),
350            Err(e) => {
351                if let Some(ctx) = self.get_fallback(conversation_id) {
352                    Ok(fallback(&ctx))
353                } else {
354                    Err(e) // Return original error
355                }
356            }
357        }
358    }
359
360    /// Cleanup expired contexts
361    pub fn cleanup_expired(&self) {
362        self.contexts.retain(|_, ctx| ctx.is_valid());
363    }
364
365    /// Remove specific conversation
366    pub fn remove(&self, conversation_id: &str) -> Option<ConversationContext> {
367        self.contexts.remove(conversation_id).map(|(_, ctx)| ctx)
368    }
369
370    /// Get number of cached contexts
371    pub fn cached_count(&self) -> usize {
372        self.contexts.len()
373    }
374
375    /// Set default TTL
376    pub fn set_default_ttl(&self, ttl: Duration) {
377        *self.default_ttl.write() = ttl;
378    }
379
380    /// Check if conversation has cached context
381    pub fn has_context(&self, conversation_id: &str) -> bool {
382        self.contexts
383            .get(conversation_id)
384            .map(|ctx| ctx.is_valid())
385            .unwrap_or(false)
386    }
387}
388
389impl Default for ConversationFallback {
390    fn default() -> Self {
391        Self::new()
392    }
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    /// Each decision variant renders its action name and payload.
    #[test]
    fn test_retry_decision_messages() {
        let retry = RetryDecision::Retry {
            delay: Duration::from_millis(100),
            attempt: 1,
        };
        let msg = retry.to_llm_message();
        assert!(msg.contains("retry"));
        assert!(msg.contains("100"));

        let fail = RetryDecision::Fail {
            reason: "test error".to_string(),
        };
        let msg = fail.to_llm_message();
        assert!(msg.contains("fail"));
        assert!(msg.contains("test error"));

        let fallback = RetryDecision::Fallback {
            message: "cached answer".to_string(),
        };
        let msg = fallback.to_llm_message();
        assert!(msg.contains("fallback"));
        assert!(msg.contains("cached answer"));
        assert!(!fallback.should_retry());
        assert_eq!(fallback.retry_delay(), None);
    }

    /// Delays grow with the attempt number and never exceed the cap.
    #[test]
    fn test_retry_strategy_delay() {
        let strategy = AgentRetryStrategy::new();

        let delay0 = strategy.get_retry_delay(0);
        let delay1 = strategy.get_retry_delay(1);
        let delay2 = strategy.get_retry_delay(2);

        // Jitter is below one doubling step, so later attempts are never
        // shorter than earlier ones.
        assert!(delay1 >= delay0);
        assert!(delay2 >= delay1);
        assert!(delay2 <= strategy.max_delay);

        // Far past the exponent cap the delay is still bounded.
        assert!(strategy.get_retry_delay(u32::MAX) <= strategy.max_delay);
    }

    /// Built-in patterns classify errors; unknown errors are retried.
    #[test]
    fn test_retry_strategy_retryable() {
        let strategy = AgentRetryStrategy::new();

        assert!(strategy.is_retryable("circuit_open for node"));
        assert!(strategy.is_retryable("rate_limit exceeded"));
        assert!(strategy.is_retryable("connection timeout"));
        assert!(strategy.is_retryable("something nobody anticipated"));

        assert!(!strategy.is_retryable("permission_denied"));
        assert!(!strategy.is_retryable("authentication failed"));
        // Error text is matched case-insensitively.
        assert!(!strategy.is_retryable("PERMISSION_DENIED"));
    }

    /// `should_retry` honors the attempt limit and the error classification.
    #[test]
    fn test_retry_strategy_should_retry() {
        let strategy = AgentRetryStrategy::with_config(
            Duration::from_millis(100),
            Duration::from_secs(10),
            3,
        );

        // Should retry circuit_open
        let decision = strategy.should_retry("circuit_open", 0);
        assert!(decision.should_retry());
        assert!(strategy.get_delay_for_error("circuit_open", 0).is_some());

        // Should not retry after max attempts
        let decision = strategy.should_retry("circuit_open", 3);
        assert!(!decision.should_retry());
        assert!(strategy.get_delay_for_error("circuit_open", 3).is_none());

        // Should not retry non-retryable errors
        let decision = strategy.should_retry("permission_denied", 0);
        assert!(!decision.should_retry());
    }

    /// Builder methods and `update` populate the context.
    #[test]
    fn test_conversation_context() {
        let mut ctx = ConversationContext::new("conv-123")
            .with_metadata("user", "alice")
            .with_ttl(Duration::from_secs(60));

        assert_eq!(ctx.conversation_id, "conv-123");
        assert!(ctx.is_valid());

        ctx.update("SELECT * FROM users", r#"[{"id": 1}]"#);
        assert_eq!(ctx.last_query, Some("SELECT * FROM users".to_string()));
    }

    /// A stored context can be read back and removed.
    #[test]
    fn test_conversation_fallback() {
        let fallback = ConversationFallback::new();

        fallback.update_context("conv-1", "query1", "result1");
        assert!(fallback.has_context("conv-1"));
        assert_eq!(fallback.cached_count(), 1);

        let ctx = fallback.get_fallback("conv-1").unwrap();
        assert_eq!(ctx.last_query, Some("query1".to_string()));
        assert_eq!(ctx.last_result, Some("result1".to_string()));

        assert!(fallback.remove("conv-1").is_some());
        assert!(!fallback.has_context("conv-1"));
        assert_eq!(fallback.cached_count(), 0);
    }

    /// A context stops being served once its TTL has elapsed.
    #[test]
    fn test_conversation_fallback_expired() {
        let fallback = ConversationFallback::with_config(
            Duration::from_millis(10),
            100,
        );

        fallback.update_context("conv-1", "query1", "result1");
        assert!(fallback.has_context("conv-1"));

        std::thread::sleep(Duration::from_millis(20));
        assert!(!fallback.has_context("conv-1"));
        assert!(fallback.get_fallback("conv-1").is_none());
    }

    /// Covers all three outcomes: success, served fallback, original error.
    #[test]
    fn test_execute_with_fallback() {
        let fallback = ConversationFallback::new();
        fallback.update_context("conv-1", "query", "cached_result");

        // Successful execution
        let result: Result<String, &str> =
            fallback.execute_with_fallback("conv-1", || Ok("new_result".to_string()), |_| {
                "fallback".to_string()
            });
        assert_eq!(result.unwrap(), "new_result");

        // Failed execution with a cached context: the fallback value is served
        let result: Result<String, &str> = fallback.execute_with_fallback(
            "conv-1",
            || Err("backend down"),
            |ctx| ctx.last_result.clone().unwrap_or_default(),
        );
        assert_eq!(result.unwrap(), "cached_result");

        // Failed execution without a cached context: the original error is returned
        let result: Result<String, &str> = fallback.execute_with_fallback(
            "conv-unknown",
            || Err("backend down"),
            |_| "fallback".to_string(),
        );
        assert_eq!(result.unwrap_err(), "backend down");
    }
}