heliosdb_proxy/circuit_breaker/
agent.rs1use std::collections::HashMap;
7use std::time::{Duration, Instant};
8
9use dashmap::DashMap;
10use parking_lot::RwLock;
11
/// Outcome of evaluating a failed request, rendered for an LLM agent via
/// [`RetryDecision::to_llm_message`].
#[derive(Debug, Clone)]
pub enum RetryDecision {
    /// Retry after `delay`. When produced by `AgentRetryStrategy::should_retry`,
    /// `attempt` is the incoming attempt index plus one.
    Retry { delay: Duration, attempt: u32 },
    /// Do not retry; `reason` explains why.
    Fail { reason: String },
    /// Answer with cached/fallback data; `message` is forwarded in the payload.
    Fallback { message: String },
}

impl RetryDecision {
    /// Returns `true` only for [`RetryDecision::Retry`].
    pub fn should_retry(&self) -> bool {
        matches!(self, RetryDecision::Retry { .. })
    }

    /// Returns the back-off delay of a [`RetryDecision::Retry`], `None` otherwise.
    pub fn retry_delay(&self) -> Option<Duration> {
        match self {
            RetryDecision::Retry { delay, .. } => Some(*delay),
            // Listed explicitly so adding a variant forces a decision here.
            RetryDecision::Fail { .. } | RetryDecision::Fallback { .. } => None,
        }
    }

    /// Renders the decision as a single-line JSON object for an LLM agent.
    ///
    /// String payloads (`reason`, `message`) are JSON-escaped, so error text that
    /// contains quotes, backslashes or control characters (e.g. raw database
    /// errors) still yields valid JSON and cannot inject extra fields.
    pub fn to_llm_message(&self) -> String {
        match self {
            RetryDecision::Retry { delay, attempt } => {
                let ms = delay.as_millis();
                format!(
                    r#"{{"action":"retry","delay_ms":{},"attempt":{},"message":"Request failed temporarily. Retry in {} milliseconds."}}"#,
                    ms, attempt, ms
                )
            }
            RetryDecision::Fail { reason } => {
                let reason = Self::escape_json(reason);
                format!(
                    r#"{{"action":"fail","reason":"{}","message":"Request cannot be retried: {}"}}"#,
                    reason, reason
                )
            }
            RetryDecision::Fallback { message } => format!(
                r#"{{"action":"fallback","message":"{}","note":"Using cached or fallback data due to temporary outage."}}"#,
                Self::escape_json(message)
            ),
        }
    }

    /// Escapes `raw` for embedding inside a JSON string literal (RFC 8259 §7):
    /// quote and backslash are escaped, control characters below U+0020 use
    /// their short form or `\uXXXX`.
    fn escape_json(raw: &str) -> String {
        let mut out = String::with_capacity(raw.len());
        for c in raw.chars() {
            match c {
                '"' => out.push_str("\\\""),
                '\\' => out.push_str("\\\\"),
                '\n' => out.push_str("\\n"),
                '\r' => out.push_str("\\r"),
                '\t' => out.push_str("\\t"),
                '\u{08}' => out.push_str("\\b"),
                '\u{0C}' => out.push_str("\\f"),
                c if (c as u32) < 0x20 => out.push_str(&format!("\\u{:04x}", c as u32)),
                c => out.push(c),
            }
        }
        out
    }
}
63
/// Exponential back-off policy with jitter that classifies errors by
/// substring patterns.
#[derive(Debug, Clone)]
pub struct AgentRetryStrategy {
    /// Delay used for attempt 0; the strategy doubles it per attempt.
    base_delay: Duration,
    /// Ceiling applied to every computed delay.
    max_delay: Duration,
    /// Attempt index at which the strategy stops retrying.
    max_attempts: u32,
    /// Upper bound of the random fraction added to a delay.
    jitter_factor: f64,
    /// Substrings (matched against the lowercased error) marking transient errors.
    retryable_patterns: Vec<String>,
    /// Substrings marking permanent errors; these are checked first.
    non_retryable_patterns: Vec<String>,
}

impl Default for AgentRetryStrategy {
    /// 100 ms initial back-off with a 30 s ceiling, at most 5 attempts, up to
    /// 30% jitter, plus built-in transient and permanent error patterns.
    fn default() -> Self {
        const TRANSIENT: [&str; 6] = [
            "circuit_open",
            "rate_limit",
            "timeout",
            "connection",
            "temporary",
            "unavailable",
        ];
        const PERMANENT: [&str; 5] = [
            "invalid_query",
            "permission_denied",
            "authentication",
            "not_found",
            "constraint_violation",
        ];

        Self {
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(30),
            max_attempts: 5,
            jitter_factor: 0.3,
            retryable_patterns: TRANSIENT.iter().map(|p| p.to_string()).collect(),
            non_retryable_patterns: PERMANENT.iter().map(|p| p.to_string()).collect(),
        }
    }
}
106
107impl AgentRetryStrategy {
108 pub fn new() -> Self {
110 Self::default()
111 }
112
113 pub fn with_config(
115 base_delay: Duration,
116 max_delay: Duration,
117 max_attempts: u32,
118 ) -> Self {
119 Self {
120 base_delay,
121 max_delay,
122 max_attempts,
123 ..Default::default()
124 }
125 }
126
127 pub fn with_jitter(mut self, factor: f64) -> Self {
129 self.jitter_factor = factor.clamp(0.0, 1.0);
130 self
131 }
132
133 pub fn with_retryable_pattern(mut self, pattern: impl Into<String>) -> Self {
135 self.retryable_patterns.push(pattern.into());
136 self
137 }
138
139 pub fn with_non_retryable_pattern(mut self, pattern: impl Into<String>) -> Self {
141 self.non_retryable_patterns.push(pattern.into());
142 self
143 }
144
145 pub fn get_retry_delay(&self, attempt: u32) -> Duration {
147 let exp_delay = self.base_delay * 2u32.pow(attempt.min(10));
149
150 let capped_delay = exp_delay.min(self.max_delay);
152
153 let jitter = rand::random::<f64>() * self.jitter_factor;
155 let jittered = capped_delay.mul_f64(1.0 + jitter);
156
157 jittered.min(self.max_delay)
158 }
159
160 pub fn is_retryable(&self, error: &str) -> bool {
162 let error_lower = error.to_lowercase();
163
164 for pattern in &self.non_retryable_patterns {
166 if error_lower.contains(pattern) {
167 return false;
168 }
169 }
170
171 for pattern in &self.retryable_patterns {
173 if error_lower.contains(pattern) {
174 return true;
175 }
176 }
177
178 true
180 }
181
182 pub fn should_retry(&self, error: &str, attempt: u32) -> RetryDecision {
184 if attempt >= self.max_attempts {
185 return RetryDecision::Fail {
186 reason: format!(
187 "Maximum retry attempts ({}) exceeded",
188 self.max_attempts
189 ),
190 };
191 }
192
193 if !self.is_retryable(error) {
194 return RetryDecision::Fail {
195 reason: format!("Error is not retryable: {}", error),
196 };
197 }
198
199 let delay = self.get_retry_delay(attempt);
200 RetryDecision::Retry {
201 delay,
202 attempt: attempt + 1,
203 }
204 }
205
206 pub fn get_delay_for_error(&self, error: &str, attempt: u32) -> Option<Duration> {
208 let decision = self.should_retry(error, attempt);
209 decision.retry_delay()
210 }
211}
212
/// The most recent query/result recorded for a conversation, plus metadata and
/// a freshness window.
#[derive(Debug, Clone)]
pub struct ConversationContext {
    /// Identifier of the conversation this context belongs to.
    pub conversation_id: String,
    /// Query stored by the latest [`ConversationContext::update`], if any.
    pub last_query: Option<String>,
    /// Result stored alongside `last_query`.
    pub last_result: Option<String>,
    /// Caller-supplied key/value annotations.
    pub metadata: HashMap<String, String>,
    /// Timestamp of creation or of the latest update.
    pub cached_at: Instant,
    /// How long after `cached_at` the context counts as valid.
    pub ttl: Duration,
}

impl ConversationContext {
    /// Creates an empty context stamped with the current time and a one-hour TTL.
    pub fn new(conversation_id: impl Into<String>) -> Self {
        let one_hour = Duration::from_secs(60 * 60);
        Self {
            conversation_id: conversation_id.into(),
            last_query: None,
            last_result: None,
            metadata: HashMap::new(),
            cached_at: Instant::now(),
            ttl: one_hour,
        }
    }

    /// Records the latest query/result pair and restarts the TTL window.
    pub fn update(&mut self, query: impl Into<String>, result: impl Into<String>) {
        let (query, result) = (query.into(), result.into());
        self.last_query = Some(query);
        self.last_result = Some(result);
        self.cached_at = Instant::now();
    }

    /// Builder: adds (or overwrites) one metadata entry.
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (key, value) = (key.into(), value.into());
        self.metadata.insert(key, value);
        self
    }

    /// Builder: replaces the TTL.
    pub fn with_ttl(self, ttl: Duration) -> Self {
        Self { ttl, ..self }
    }

    /// `true` while less than `ttl` has elapsed since `cached_at`.
    pub fn is_valid(&self) -> bool {
        self.age() < self.ttl
    }

    /// Time elapsed since `cached_at`.
    pub fn age(&self) -> Duration {
        self.cached_at.elapsed()
    }
}
272
273pub struct ConversationFallback {
278 contexts: DashMap<String, ConversationContext>,
280
281 default_ttl: RwLock<Duration>,
283
284 max_contexts: usize,
286}
287
288impl ConversationFallback {
289 pub fn new() -> Self {
291 Self {
292 contexts: DashMap::new(),
293 default_ttl: RwLock::new(Duration::from_secs(3600)),
294 max_contexts: 10000,
295 }
296 }
297
298 pub fn with_config(default_ttl: Duration, max_contexts: usize) -> Self {
300 Self {
301 contexts: DashMap::new(),
302 default_ttl: RwLock::new(default_ttl),
303 max_contexts,
304 }
305 }
306
307 pub fn update_context(
309 &self,
310 conversation_id: &str,
311 query: impl Into<String>,
312 result: impl Into<String>,
313 ) {
314 let ttl = *self.default_ttl.read();
315
316 if let Some(mut ctx) = self.contexts.get_mut(conversation_id) {
317 ctx.update(query, result);
318 } else {
319 if self.contexts.len() >= self.max_contexts {
321 self.cleanup_expired();
322 }
323
324 let mut ctx = ConversationContext::new(conversation_id).with_ttl(ttl);
325 ctx.update(query, result);
326 self.contexts.insert(conversation_id.to_string(), ctx);
327 }
328 }
329
330 pub fn get_fallback(&self, conversation_id: &str) -> Option<ConversationContext> {
332 self.contexts
333 .get(conversation_id)
334 .filter(|ctx| ctx.is_valid())
335 .map(|ctx| ctx.clone())
336 }
337
338 pub fn execute_with_fallback<T, E>(
340 &self,
341 conversation_id: &str,
342 execute: impl FnOnce() -> Result<T, E>,
343 fallback: impl FnOnce(&ConversationContext) -> T,
344 ) -> Result<T, E>
345 where
346 E: std::fmt::Display,
347 {
348 match execute() {
349 Ok(result) => Ok(result),
350 Err(e) => {
351 if let Some(ctx) = self.get_fallback(conversation_id) {
352 Ok(fallback(&ctx))
353 } else {
354 Err(e) }
356 }
357 }
358 }
359
360 pub fn cleanup_expired(&self) {
362 self.contexts.retain(|_, ctx| ctx.is_valid());
363 }
364
365 pub fn remove(&self, conversation_id: &str) -> Option<ConversationContext> {
367 self.contexts.remove(conversation_id).map(|(_, ctx)| ctx)
368 }
369
370 pub fn cached_count(&self) -> usize {
372 self.contexts.len()
373 }
374
375 pub fn set_default_ttl(&self, ttl: Duration) {
377 *self.default_ttl.write() = ttl;
378 }
379
380 pub fn has_context(&self, conversation_id: &str) -> bool {
382 self.contexts
383 .get(conversation_id)
384 .map(|ctx| ctx.is_valid())
385 .unwrap_or(false)
386 }
387}
388
389impl Default for ConversationFallback {
390 fn default() -> Self {
391 Self::new()
392 }
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    /// Each variant's LLM message names its action and carries its payload.
    #[test]
    fn test_retry_decision_messages() {
        let retry_msg = RetryDecision::Retry {
            delay: Duration::from_millis(100),
            attempt: 1,
        }
        .to_llm_message();
        assert!(retry_msg.contains("retry"));
        assert!(retry_msg.contains("100"));

        let fail_msg = RetryDecision::Fail {
            reason: String::from("test error"),
        }
        .to_llm_message();
        assert!(fail_msg.contains("fail"));
        assert!(fail_msg.contains("test error"));
    }

    /// Back-off never shrinks over the first attempts and respects the cap.
    #[test]
    fn test_retry_strategy_delay() {
        let strategy = AgentRetryStrategy::new();
        let delays: Vec<Duration> = (0..3)
            .map(|attempt| strategy.get_retry_delay(attempt))
            .collect();

        assert!(delays[1] >= delays[0]);
        assert!(delays[2] >= delays[1]);
        assert!(delays[2] <= strategy.max_delay);
    }

    /// Default patterns separate transient from permanent errors.
    #[test]
    fn test_retry_strategy_retryable() {
        let strategy = AgentRetryStrategy::new();

        for transient in ["circuit_open for node", "rate_limit exceeded", "connection timeout"] {
            assert!(strategy.is_retryable(transient), "expected retryable: {}", transient);
        }
        for permanent in ["permission_denied", "authentication failed"] {
            assert!(!strategy.is_retryable(permanent), "expected non-retryable: {}", permanent);
        }
    }

    /// Retries stop at `max_attempts` and never happen for permanent errors.
    #[test]
    fn test_retry_strategy_should_retry() {
        let strategy = AgentRetryStrategy::with_config(
            Duration::from_millis(100),
            Duration::from_secs(10),
            3,
        );

        assert!(strategy.should_retry("circuit_open", 0).should_retry());
        assert!(!strategy.should_retry("circuit_open", 3).should_retry());
        assert!(!strategy.should_retry("permission_denied", 0).should_retry());
    }

    /// Builders populate the context and `update` records the last exchange.
    #[test]
    fn test_conversation_context() {
        let mut ctx = ConversationContext::new("conv-123")
            .with_metadata("user", "alice")
            .with_ttl(Duration::from_secs(60));
        assert_eq!(ctx.conversation_id, "conv-123");
        assert!(ctx.is_valid());

        ctx.update("SELECT * FROM users", r#"[{"id": 1}]"#);
        assert_eq!(ctx.last_query.as_deref(), Some("SELECT * FROM users"));
    }

    /// A freshly cached exchange is retrievable as fallback data.
    #[test]
    fn test_conversation_fallback() {
        let cache = ConversationFallback::new();
        cache.update_context("conv-1", "query1", "result1");
        assert!(cache.has_context("conv-1"));
        assert_eq!(cache.cached_count(), 1);

        let ctx = cache
            .get_fallback("conv-1")
            .expect("context was just cached");
        assert_eq!(ctx.last_query.as_deref(), Some("query1"));
        assert_eq!(ctx.last_result.as_deref(), Some("result1"));
    }

    /// Contexts stop counting as present once their TTL has elapsed.
    #[test]
    fn test_conversation_fallback_expired() {
        let cache = ConversationFallback::with_config(Duration::from_millis(10), 100);
        cache.update_context("conv-1", "query1", "result1");
        assert!(cache.has_context("conv-1"));

        // Sleep past the 10 ms TTL.
        std::thread::sleep(Duration::from_millis(20));
        assert!(!cache.has_context("conv-1"));
    }

    /// A successful execution wins over cached fallback data.
    #[test]
    fn test_execute_with_fallback() {
        let cache = ConversationFallback::new();
        cache.update_context("conv-1", "query", "cached_result");

        let outcome: Result<String, &str> = cache.execute_with_fallback(
            "conv-1",
            || Ok("new_result".to_string()),
            |_| "fallback".to_string(),
        );
        assert_eq!(outcome.unwrap(), "new_result");
    }
}
516}