Skip to main content

heliosdb_proxy/circuit_breaker/
agent.rs

1//! Agent Retry Strategies and Conversation Fallback
2//!
3//! Provides retry guidance for AI agents and maintains conversation context
4//! during circuit breaker outages.
5
6use std::collections::HashMap;
7use std::time::{Duration, Instant};
8
9use dashmap::DashMap;
10use parking_lot::RwLock;
11
/// Decision from retry strategy
#[derive(Debug, Clone)]
pub enum RetryDecision {
    /// Retry after the specified delay
    Retry { delay: Duration, attempt: u32 },
    /// Don't retry, fail immediately
    Fail { reason: String },
    /// Use fallback/cached data
    Fallback { message: String },
}

impl RetryDecision {
    /// Returns `true` only for [`RetryDecision::Retry`].
    pub fn should_retry(&self) -> bool {
        matches!(self, RetryDecision::Retry { .. })
    }

    /// Returns the backoff delay of a [`RetryDecision::Retry`], `None` for every other variant.
    pub fn retry_delay(&self) -> Option<Duration> {
        match self {
            RetryDecision::Retry { delay, .. } => Some(*delay),
            RetryDecision::Fail { .. } | RetryDecision::Fallback { .. } => None,
        }
    }

    /// Renders the decision as a single-line JSON object intended for an LLM agent.
    ///
    /// The `action` key is `"retry"`, `"fail"` or `"fallback"`. Free-text fields
    /// (`reason`, `message`) are JSON-escaped: they often carry raw error text
    /// (e.g. the error passed to `AgentRetryStrategy::should_retry`) that may
    /// contain quotes, backslashes or newlines, which would otherwise yield
    /// invalid JSON.
    pub fn to_llm_message(&self) -> String {
        match self {
            RetryDecision::Retry { delay, attempt } => {
                let delay_ms = delay.as_millis();
                format!(
                    r#"{{"action":"retry","delay_ms":{},"attempt":{},"message":"Request failed temporarily. Retry in {} milliseconds."}}"#,
                    delay_ms, attempt, delay_ms
                )
            }
            RetryDecision::Fail { reason } => {
                let reason = Self::escape_json(reason);
                format!(
                    r#"{{"action":"fail","reason":"{}","message":"Request cannot be retried: {}"}}"#,
                    reason, reason
                )
            }
            RetryDecision::Fallback { message } => format!(
                r#"{{"action":"fallback","message":"{}","note":"Using cached or fallback data due to temporary outage."}}"#,
                Self::escape_json(message)
            ),
        }
    }

    /// Escapes `raw` so it can be embedded inside a JSON string literal (RFC 8259 §7).
    fn escape_json(raw: &str) -> String {
        let mut out = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '"' => out.push_str("\\\""),
                '\\' => out.push_str("\\\\"),
                '\n' => out.push_str("\\n"),
                '\r' => out.push_str("\\r"),
                '\t' => out.push_str("\\t"),
                // Any other C0 control character must use the \uXXXX form.
                c if u32::from(c) < 0x20 => out.push_str(&format!("\\u{:04x}", u32::from(c))),
                c => out.push(c),
            }
        }
        out
    }
}
63
64/// Retry strategy for AI agents
65#[derive(Debug, Clone)]
66pub struct AgentRetryStrategy {
67    /// Base delay for exponential backoff
68    base_delay: Duration,
69    /// Maximum delay
70    max_delay: Duration,
71    /// Maximum retry attempts
72    max_attempts: u32,
73    /// Jitter factor (0.0 - 1.0)
74    jitter_factor: f64,
75    /// Retryable error patterns
76    retryable_patterns: Vec<String>,
77    /// Non-retryable error patterns
78    non_retryable_patterns: Vec<String>,
79}
80
81impl Default for AgentRetryStrategy {
82    fn default() -> Self {
83        Self {
84            base_delay: Duration::from_millis(100),
85            max_delay: Duration::from_secs(30),
86            max_attempts: 5,
87            jitter_factor: 0.3,
88            retryable_patterns: vec![
89                "circuit_open".to_string(),
90                "rate_limit".to_string(),
91                "timeout".to_string(),
92                "connection".to_string(),
93                "temporary".to_string(),
94                "unavailable".to_string(),
95            ],
96            non_retryable_patterns: vec![
97                "invalid_query".to_string(),
98                "permission_denied".to_string(),
99                "authentication".to_string(),
100                "not_found".to_string(),
101                "constraint_violation".to_string(),
102            ],
103        }
104    }
105}
106
107impl AgentRetryStrategy {
108    /// Create a new retry strategy
109    pub fn new() -> Self {
110        Self::default()
111    }
112
113    /// Create with custom configuration
114    pub fn with_config(base_delay: Duration, max_delay: Duration, max_attempts: u32) -> Self {
115        Self {
116            base_delay,
117            max_delay,
118            max_attempts,
119            ..Default::default()
120        }
121    }
122
123    /// Set jitter factor
124    pub fn with_jitter(mut self, factor: f64) -> Self {
125        self.jitter_factor = factor.clamp(0.0, 1.0);
126        self
127    }
128
129    /// Add retryable pattern
130    pub fn with_retryable_pattern(mut self, pattern: impl Into<String>) -> Self {
131        self.retryable_patterns.push(pattern.into());
132        self
133    }
134
135    /// Add non-retryable pattern
136    pub fn with_non_retryable_pattern(mut self, pattern: impl Into<String>) -> Self {
137        self.non_retryable_patterns.push(pattern.into());
138        self
139    }
140
141    /// Calculate retry delay with exponential backoff and jitter
142    pub fn get_retry_delay(&self, attempt: u32) -> Duration {
143        // Exponential backoff: base * 2^attempt
144        let exp_delay = self.base_delay * 2u32.pow(attempt.min(10));
145
146        // Cap at max delay
147        let capped_delay = exp_delay.min(self.max_delay);
148
149        // Add jitter (random variation to prevent thundering herd)
150        let jitter = rand::random::<f64>() * self.jitter_factor;
151        let jittered = capped_delay.mul_f64(1.0 + jitter);
152
153        jittered.min(self.max_delay)
154    }
155
156    /// Determine if an error is retryable
157    pub fn is_retryable(&self, error: &str) -> bool {
158        let error_lower = error.to_lowercase();
159
160        // Check non-retryable patterns first
161        for pattern in &self.non_retryable_patterns {
162            if error_lower.contains(pattern) {
163                return false;
164            }
165        }
166
167        // Check retryable patterns
168        for pattern in &self.retryable_patterns {
169            if error_lower.contains(pattern) {
170                return true;
171            }
172        }
173
174        // Default: retry for unknown errors
175        true
176    }
177
178    /// Get retry decision for an error
179    pub fn should_retry(&self, error: &str, attempt: u32) -> RetryDecision {
180        if attempt >= self.max_attempts {
181            return RetryDecision::Fail {
182                reason: format!("Maximum retry attempts ({}) exceeded", self.max_attempts),
183            };
184        }
185
186        if !self.is_retryable(error) {
187            return RetryDecision::Fail {
188                reason: format!("Error is not retryable: {}", error),
189            };
190        }
191
192        let delay = self.get_retry_delay(attempt);
193        RetryDecision::Retry {
194            delay,
195            attempt: attempt + 1,
196        }
197    }
198
199    /// Get recommended delay for specific error types
200    pub fn get_delay_for_error(&self, error: &str, attempt: u32) -> Option<Duration> {
201        let decision = self.should_retry(error, attempt);
202        decision.retry_delay()
203    }
204}
205
/// Cached conversation context for fallback during outages
///
/// Stores the latest successful query/result pair of a single conversation,
/// the instant it was cached, and how long it remains usable.
#[derive(Debug, Clone)]
pub struct ConversationContext {
    /// Conversation ID
    pub conversation_id: String,
    /// Last successful query
    pub last_query: Option<String>,
    /// Last successful result (serialized)
    pub last_result: Option<String>,
    /// Conversation metadata
    pub metadata: HashMap<String, String>,
    /// Cache timestamp
    pub cached_at: Instant,
    /// Cache TTL
    pub ttl: Duration,
}

impl ConversationContext {
    /// TTL assigned by [`ConversationContext::new`]: one hour.
    const DEFAULT_TTL: Duration = Duration::from_secs(60 * 60);

    /// Builds an empty context for `conversation_id`, timestamped now, with a one-hour TTL.
    pub fn new(conversation_id: impl Into<String>) -> Self {
        ConversationContext {
            conversation_id: conversation_id.into(),
            last_query: None,
            last_result: None,
            metadata: HashMap::new(),
            cached_at: Instant::now(),
            ttl: Self::DEFAULT_TTL,
        }
    }

    /// Records a new query/result pair and restarts the TTL clock.
    pub fn update(&mut self, query: impl Into<String>, result: impl Into<String>) {
        let (query, result) = (query.into(), result.into());
        self.last_query = Some(query);
        self.last_result = Some(result);
        self.cached_at = Instant::now();
    }

    /// Builder-style insertion of a metadata entry; an existing key is overwritten.
    pub fn with_metadata(self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let mut ctx = self;
        ctx.metadata.insert(key.into(), value.into());
        ctx
    }

    /// Builder-style replacement of the TTL.
    pub fn with_ttl(self, ttl: Duration) -> Self {
        Self { ttl, ..self }
    }

    /// `true` while less than `ttl` has elapsed since `cached_at`.
    pub fn is_valid(&self) -> bool {
        self.age() < self.ttl
    }

    /// Time elapsed since the context was created or last updated.
    pub fn age(&self) -> Duration {
        self.cached_at.elapsed()
    }
}
265
266/// Conversation fallback manager
267///
268/// Maintains cached conversation contexts to provide fallback responses
269/// during circuit breaker outages.
270pub struct ConversationFallback {
271    /// Cached contexts per conversation
272    contexts: DashMap<String, ConversationContext>,
273
274    /// Default TTL for new contexts
275    default_ttl: RwLock<Duration>,
276
277    /// Maximum cached contexts
278    max_contexts: usize,
279}
280
281impl ConversationFallback {
282    /// Create a new conversation fallback manager
283    pub fn new() -> Self {
284        Self {
285            contexts: DashMap::new(),
286            default_ttl: RwLock::new(Duration::from_secs(3600)),
287            max_contexts: 10000,
288        }
289    }
290
291    /// Create with custom configuration
292    pub fn with_config(default_ttl: Duration, max_contexts: usize) -> Self {
293        Self {
294            contexts: DashMap::new(),
295            default_ttl: RwLock::new(default_ttl),
296            max_contexts,
297        }
298    }
299
300    /// Update context for a conversation
301    pub fn update_context(
302        &self,
303        conversation_id: &str,
304        query: impl Into<String>,
305        result: impl Into<String>,
306    ) {
307        let ttl = *self.default_ttl.read();
308
309        if let Some(mut ctx) = self.contexts.get_mut(conversation_id) {
310            ctx.update(query, result);
311        } else {
312            // Enforce max contexts
313            if self.contexts.len() >= self.max_contexts {
314                self.cleanup_expired();
315            }
316
317            let mut ctx = ConversationContext::new(conversation_id).with_ttl(ttl);
318            ctx.update(query, result);
319            self.contexts.insert(conversation_id.to_string(), ctx);
320        }
321    }
322
323    /// Get fallback for a conversation
324    pub fn get_fallback(&self, conversation_id: &str) -> Option<ConversationContext> {
325        self.contexts
326            .get(conversation_id)
327            .filter(|ctx| ctx.is_valid())
328            .map(|ctx| ctx.clone())
329    }
330
331    /// Execute with fallback on error
332    pub fn execute_with_fallback<T, E>(
333        &self,
334        conversation_id: &str,
335        execute: impl FnOnce() -> Result<T, E>,
336        fallback: impl FnOnce(&ConversationContext) -> T,
337    ) -> Result<T, E>
338    where
339        E: std::fmt::Display,
340    {
341        match execute() {
342            Ok(result) => Ok(result),
343            Err(e) => {
344                if let Some(ctx) = self.get_fallback(conversation_id) {
345                    Ok(fallback(&ctx))
346                } else {
347                    Err(e) // Return original error
348                }
349            }
350        }
351    }
352
353    /// Cleanup expired contexts
354    pub fn cleanup_expired(&self) {
355        self.contexts.retain(|_, ctx| ctx.is_valid());
356    }
357
358    /// Remove specific conversation
359    pub fn remove(&self, conversation_id: &str) -> Option<ConversationContext> {
360        self.contexts.remove(conversation_id).map(|(_, ctx)| ctx)
361    }
362
363    /// Get number of cached contexts
364    pub fn cached_count(&self) -> usize {
365        self.contexts.len()
366    }
367
368    /// Set default TTL
369    pub fn set_default_ttl(&self, ttl: Duration) {
370        *self.default_ttl.write() = ttl;
371    }
372
373    /// Check if conversation has cached context
374    pub fn has_context(&self, conversation_id: &str) -> bool {
375        self.contexts
376            .get(conversation_id)
377            .map(|ctx| ctx.is_valid())
378            .unwrap_or(false)
379    }
380}
381
382impl Default for ConversationFallback {
383    fn default() -> Self {
384        Self::new()
385    }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retry_decision_messages() {
        let retry = RetryDecision::Retry {
            delay: Duration::from_millis(100),
            attempt: 1,
        };
        assert!(retry.should_retry());
        assert_eq!(retry.retry_delay(), Some(Duration::from_millis(100)));
        let msg = retry.to_llm_message();
        assert!(msg.contains("retry"));
        assert!(msg.contains("100"));

        let fail = RetryDecision::Fail {
            reason: "test error".to_string(),
        };
        assert!(!fail.should_retry());
        let msg = fail.to_llm_message();
        assert!(msg.contains("fail"));
        assert!(msg.contains("test error"));

        let fb = RetryDecision::Fallback {
            message: "stale data".to_string(),
        };
        assert!(!fb.should_retry());
        assert!(fb.retry_delay().is_none());
        let msg = fb.to_llm_message();
        assert!(msg.contains("fallback"));
        assert!(msg.contains("stale data"));
    }

    #[test]
    fn test_retry_strategy_delay() {
        let strategy = AgentRetryStrategy::new();

        let delay0 = strategy.get_retry_delay(0);
        let delay1 = strategy.get_retry_delay(1);
        let delay2 = strategy.get_retry_delay(2);

        // Delays grow monotonically (jitter never outweighs the 2x step) and stay capped.
        assert!(delay1 >= delay0);
        assert!(delay2 >= delay1);
        assert!(delay2 <= strategy.max_delay);
    }

    #[test]
    fn test_retry_strategy_retryable() {
        let strategy = AgentRetryStrategy::new();

        assert!(strategy.is_retryable("circuit_open for node"));
        assert!(strategy.is_retryable("rate_limit exceeded"));
        assert!(strategy.is_retryable("connection timeout"));

        assert!(!strategy.is_retryable("permission_denied"));
        assert!(!strategy.is_retryable("authentication failed"));
    }

    #[test]
    fn test_retry_strategy_should_retry() {
        let strategy =
            AgentRetryStrategy::with_config(Duration::from_millis(100), Duration::from_secs(10), 3);

        // Should retry circuit_open
        let decision = strategy.should_retry("circuit_open", 0);
        assert!(decision.should_retry());

        // Should not retry after max attempts
        let decision = strategy.should_retry("circuit_open", 3);
        assert!(!decision.should_retry());

        // Should not retry non-retryable errors
        let decision = strategy.should_retry("permission_denied", 0);
        assert!(!decision.should_retry());
    }

    #[test]
    fn test_get_delay_for_error() {
        let strategy =
            AgentRetryStrategy::with_config(Duration::from_millis(100), Duration::from_secs(10), 3);

        let delay = strategy
            .get_delay_for_error("timeout", 0)
            .expect("timeout is retryable");
        assert!(delay >= Duration::from_millis(100));
        assert!(delay <= Duration::from_secs(10));

        assert!(strategy.get_delay_for_error("timeout", 3).is_none());
        assert!(strategy.get_delay_for_error("permission_denied", 0).is_none());
    }

    #[test]
    fn test_conversation_context() {
        let mut ctx = ConversationContext::new("conv-123")
            .with_metadata("user", "alice")
            .with_ttl(Duration::from_secs(60));

        assert_eq!(ctx.conversation_id, "conv-123");
        assert!(ctx.is_valid());

        ctx.update("SELECT * FROM users", r#"[{"id": 1}]"#);
        assert_eq!(ctx.last_query, Some("SELECT * FROM users".to_string()));
    }

    #[test]
    fn test_conversation_fallback() {
        let fallback = ConversationFallback::new();

        fallback.update_context("conv-1", "query1", "result1");
        assert!(fallback.has_context("conv-1"));
        assert_eq!(fallback.cached_count(), 1);

        let ctx = fallback.get_fallback("conv-1").unwrap();
        assert_eq!(ctx.last_query, Some("query1".to_string()));
        assert_eq!(ctx.last_result, Some("result1".to_string()));
    }

    #[test]
    fn test_conversation_fallback_update_and_remove() {
        let fallback = ConversationFallback::new();

        // A second update for the same conversation replaces the cached pair in place.
        fallback.update_context("conv-1", "query1", "result1");
        fallback.update_context("conv-1", "query2", "result2");
        assert_eq!(fallback.cached_count(), 1);
        assert_eq!(
            fallback.get_fallback("conv-1").unwrap().last_query,
            Some("query2".to_string())
        );

        let removed = fallback.remove("conv-1").expect("context was cached");
        assert_eq!(removed.last_result, Some("result2".to_string()));
        assert!(!fallback.has_context("conv-1"));
        assert!(fallback.get_fallback("conv-1").is_none());
        assert_eq!(fallback.cached_count(), 0);
    }

    #[test]
    fn test_conversation_fallback_expired() {
        let fallback = ConversationFallback::with_config(Duration::from_millis(10), 100);

        fallback.update_context("conv-1", "query1", "result1");
        assert!(fallback.has_context("conv-1"));

        std::thread::sleep(Duration::from_millis(20));
        assert!(!fallback.has_context("conv-1"));
    }

    #[test]
    fn test_execute_with_fallback() {
        let fallback = ConversationFallback::new();
        fallback.update_context("conv-1", "query", "cached_result");

        // Successful execution bypasses the cache entirely.
        let result: Result<String, &str> = fallback.execute_with_fallback(
            "conv-1",
            || Ok("new_result".to_string()),
            |_| "fallback".to_string(),
        );
        assert_eq!(result.unwrap(), "new_result");

        // A failure with a cached context is answered from the cache.
        let result: Result<String, &str> = fallback.execute_with_fallback(
            "conv-1",
            || Err("circuit_open"),
            |ctx| ctx.last_result.clone().unwrap_or_default(),
        );
        assert_eq!(result.unwrap(), "cached_result");

        // A failure without a cached context propagates the original error.
        let result: Result<String, &str> = fallback.execute_with_fallback(
            "conv-unknown",
            || Err("circuit_open"),
            |_| "fallback".to_string(),
        );
        assert_eq!(result.unwrap_err(), "circuit_open");
    }
}