heliosdb_proxy/circuit_breaker/
agent.rs1use std::collections::HashMap;
7use std::time::{Duration, Instant};
8
9use dashmap::DashMap;
10use parking_lot::RwLock;
11
/// Outcome of evaluating a failed request: what the caller should do next.
#[derive(Debug, Clone)]
pub enum RetryDecision {
    /// Retry the request after waiting `delay`; `attempt` is the attempt
    /// number carried along with the decision.
    Retry { delay: Duration, attempt: u32 },
    /// Do not retry; `reason` explains why.
    Fail { reason: String },
    /// Serve fallback content described by `message`.
    Fallback { message: String },
}

impl RetryDecision {
    /// Returns `true` only for [`RetryDecision::Retry`].
    pub fn should_retry(&self) -> bool {
        matches!(self, RetryDecision::Retry { .. })
    }

    /// Returns the delay to wait before retrying, or `None` when the decision
    /// is not a retry.
    pub fn retry_delay(&self) -> Option<Duration> {
        match self {
            RetryDecision::Retry { delay, .. } => Some(*delay),
            RetryDecision::Fail { .. } | RetryDecision::Fallback { .. } => None,
        }
    }

    /// Renders the decision as a single-line JSON object.
    ///
    /// Every variant emits an `"action"` key (`"retry"`, `"fail"` or
    /// `"fallback"`) plus variant-specific fields. Caller-supplied text
    /// (`reason`, `message`) is JSON-escaped so that quotes, backslashes and
    /// control characters cannot break the document or inject extra keys.
    pub fn to_llm_message(&self) -> String {
        match self {
            RetryDecision::Retry { delay, attempt } => {
                let millis = delay.as_millis();
                format!(
                    r#"{{"action":"retry","delay_ms":{},"attempt":{},"message":"Request failed temporarily. Retry in {} milliseconds."}}"#,
                    millis, attempt, millis
                )
            }
            RetryDecision::Fail { reason } => {
                let reason = Self::escape_json(reason);
                format!(
                    r#"{{"action":"fail","reason":"{}","message":"Request cannot be retried: {}"}}"#,
                    reason, reason
                )
            }
            RetryDecision::Fallback { message } => {
                let message = Self::escape_json(message);
                format!(
                    r#"{{"action":"fallback","message":"{}","note":"Using cached or fallback data due to temporary outage."}}"#,
                    message
                )
            }
        }
    }

    /// Escapes `raw` for embedding between the double quotes of a JSON string.
    fn escape_json(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '"' => escaped.push_str("\\\""),
                '\\' => escaped.push_str("\\\\"),
                '\n' => escaped.push_str("\\n"),
                '\r' => escaped.push_str("\\r"),
                '\t' => escaped.push_str("\\t"),
                // Remaining control characters must use the \uXXXX form.
                c if u32::from(c) < 0x20 => {
                    escaped.push_str(&format!("\\u{:04x}", u32::from(c)));
                }
                c => escaped.push(c),
            }
        }
        escaped
    }
}
63
/// Retry policy: exponential backoff parameters plus substring patterns used
/// to classify error messages as retryable or not.
#[derive(Debug, Clone)]
pub struct AgentRetryStrategy {
    /// Delay for attempt 0; doubled for each subsequent attempt.
    base_delay: Duration,
    /// Upper bound applied to every computed delay.
    max_delay: Duration,
    /// Number of attempts after which retrying is refused.
    max_attempts: u32,
    /// Fraction of the delay that may be added as random jitter.
    jitter_factor: f64,
    /// Substrings that mark an error message as retryable.
    retryable_patterns: Vec<String>,
    /// Substrings that mark an error message as non-retryable.
    non_retryable_patterns: Vec<String>,
}

impl Default for AgentRetryStrategy {
    /// Builds the stock policy: 100 ms base delay, 30 s cap, 5 attempts,
    /// 0.3 jitter factor and the built-in pattern lists.
    fn default() -> Self {
        let retryable = [
            "circuit_open",
            "rate_limit",
            "timeout",
            "connection",
            "temporary",
            "unavailable",
        ];
        let non_retryable = [
            "invalid_query",
            "permission_denied",
            "authentication",
            "not_found",
            "constraint_violation",
        ];

        AgentRetryStrategy {
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(30),
            max_attempts: 5,
            jitter_factor: 0.3,
            retryable_patterns: retryable.iter().map(|&p| p.to_owned()).collect(),
            non_retryable_patterns: non_retryable.iter().map(|&p| p.to_owned()).collect(),
        }
    }
}
106
107impl AgentRetryStrategy {
108 pub fn new() -> Self {
110 Self::default()
111 }
112
113 pub fn with_config(base_delay: Duration, max_delay: Duration, max_attempts: u32) -> Self {
115 Self {
116 base_delay,
117 max_delay,
118 max_attempts,
119 ..Default::default()
120 }
121 }
122
123 pub fn with_jitter(mut self, factor: f64) -> Self {
125 self.jitter_factor = factor.clamp(0.0, 1.0);
126 self
127 }
128
129 pub fn with_retryable_pattern(mut self, pattern: impl Into<String>) -> Self {
131 self.retryable_patterns.push(pattern.into());
132 self
133 }
134
135 pub fn with_non_retryable_pattern(mut self, pattern: impl Into<String>) -> Self {
137 self.non_retryable_patterns.push(pattern.into());
138 self
139 }
140
141 pub fn get_retry_delay(&self, attempt: u32) -> Duration {
143 let exp_delay = self.base_delay * 2u32.pow(attempt.min(10));
145
146 let capped_delay = exp_delay.min(self.max_delay);
148
149 let jitter = rand::random::<f64>() * self.jitter_factor;
151 let jittered = capped_delay.mul_f64(1.0 + jitter);
152
153 jittered.min(self.max_delay)
154 }
155
156 pub fn is_retryable(&self, error: &str) -> bool {
158 let error_lower = error.to_lowercase();
159
160 for pattern in &self.non_retryable_patterns {
162 if error_lower.contains(pattern) {
163 return false;
164 }
165 }
166
167 for pattern in &self.retryable_patterns {
169 if error_lower.contains(pattern) {
170 return true;
171 }
172 }
173
174 true
176 }
177
178 pub fn should_retry(&self, error: &str, attempt: u32) -> RetryDecision {
180 if attempt >= self.max_attempts {
181 return RetryDecision::Fail {
182 reason: format!("Maximum retry attempts ({}) exceeded", self.max_attempts),
183 };
184 }
185
186 if !self.is_retryable(error) {
187 return RetryDecision::Fail {
188 reason: format!("Error is not retryable: {}", error),
189 };
190 }
191
192 let delay = self.get_retry_delay(attempt);
193 RetryDecision::Retry {
194 delay,
195 attempt: attempt + 1,
196 }
197 }
198
199 pub fn get_delay_for_error(&self, error: &str, attempt: u32) -> Option<Duration> {
201 let decision = self.should_retry(error, attempt);
202 decision.retry_delay()
203 }
204}
205
/// Cached snapshot of a conversation's most recent query and result, together
/// with free-form metadata and a time-to-live.
#[derive(Debug, Clone)]
pub struct ConversationContext {
    /// Identifier of the conversation this context belongs to.
    pub conversation_id: String,
    /// Most recent query recorded via `update`, if any.
    pub last_query: Option<String>,
    /// Most recent result recorded via `update`, if any.
    pub last_result: Option<String>,
    /// Arbitrary key/value annotations.
    pub metadata: HashMap<String, String>,
    /// Moment the context was created or last updated.
    pub cached_at: Instant,
    /// How long after `cached_at` the context counts as valid.
    pub ttl: Duration,
}

impl ConversationContext {
    /// Creates an empty context for `conversation_id` with a one-hour TTL,
    /// timestamped with the current instant.
    pub fn new(conversation_id: impl Into<String>) -> Self {
        let conversation_id = conversation_id.into();
        ConversationContext {
            conversation_id,
            last_query: None,
            last_result: None,
            metadata: HashMap::new(),
            cached_at: Instant::now(),
            ttl: Duration::from_secs(3600),
        }
    }

    /// Records the latest query/result pair and resets the cache timestamp.
    pub fn update(&mut self, query: impl Into<String>, result: impl Into<String>) {
        let latest_query: String = query.into();
        let latest_result: String = result.into();
        self.last_query = Some(latest_query);
        self.last_result = Some(latest_result);
        self.cached_at = Instant::now();
    }

    /// Builder-style: inserts (or overwrites) one metadata entry.
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (k, v) = (key.into(), value.into());
        self.metadata.insert(k, v);
        self
    }

    /// Builder-style: replaces the time-to-live.
    pub fn with_ttl(self, ttl: Duration) -> Self {
        ConversationContext { ttl, ..self }
    }

    /// Returns `true` while the context is younger than its TTL.
    pub fn is_valid(&self) -> bool {
        let elapsed = self.cached_at.elapsed();
        elapsed < self.ttl
    }

    /// Time elapsed since the context was created or last updated.
    pub fn age(&self) -> Duration {
        self.cached_at.elapsed()
    }
}
265
266pub struct ConversationFallback {
271 contexts: DashMap<String, ConversationContext>,
273
274 default_ttl: RwLock<Duration>,
276
277 max_contexts: usize,
279}
280
281impl ConversationFallback {
282 pub fn new() -> Self {
284 Self {
285 contexts: DashMap::new(),
286 default_ttl: RwLock::new(Duration::from_secs(3600)),
287 max_contexts: 10000,
288 }
289 }
290
291 pub fn with_config(default_ttl: Duration, max_contexts: usize) -> Self {
293 Self {
294 contexts: DashMap::new(),
295 default_ttl: RwLock::new(default_ttl),
296 max_contexts,
297 }
298 }
299
300 pub fn update_context(
302 &self,
303 conversation_id: &str,
304 query: impl Into<String>,
305 result: impl Into<String>,
306 ) {
307 let ttl = *self.default_ttl.read();
308
309 if let Some(mut ctx) = self.contexts.get_mut(conversation_id) {
310 ctx.update(query, result);
311 } else {
312 if self.contexts.len() >= self.max_contexts {
314 self.cleanup_expired();
315 }
316
317 let mut ctx = ConversationContext::new(conversation_id).with_ttl(ttl);
318 ctx.update(query, result);
319 self.contexts.insert(conversation_id.to_string(), ctx);
320 }
321 }
322
323 pub fn get_fallback(&self, conversation_id: &str) -> Option<ConversationContext> {
325 self.contexts
326 .get(conversation_id)
327 .filter(|ctx| ctx.is_valid())
328 .map(|ctx| ctx.clone())
329 }
330
331 pub fn execute_with_fallback<T, E>(
333 &self,
334 conversation_id: &str,
335 execute: impl FnOnce() -> Result<T, E>,
336 fallback: impl FnOnce(&ConversationContext) -> T,
337 ) -> Result<T, E>
338 where
339 E: std::fmt::Display,
340 {
341 match execute() {
342 Ok(result) => Ok(result),
343 Err(e) => {
344 if let Some(ctx) = self.get_fallback(conversation_id) {
345 Ok(fallback(&ctx))
346 } else {
347 Err(e) }
349 }
350 }
351 }
352
353 pub fn cleanup_expired(&self) {
355 self.contexts.retain(|_, ctx| ctx.is_valid());
356 }
357
358 pub fn remove(&self, conversation_id: &str) -> Option<ConversationContext> {
360 self.contexts.remove(conversation_id).map(|(_, ctx)| ctx)
361 }
362
363 pub fn cached_count(&self) -> usize {
365 self.contexts.len()
366 }
367
368 pub fn set_default_ttl(&self, ttl: Duration) {
370 *self.default_ttl.write() = ttl;
371 }
372
373 pub fn has_context(&self, conversation_id: &str) -> bool {
375 self.contexts
376 .get(conversation_id)
377 .map(|ctx| ctx.is_valid())
378 .unwrap_or(false)
379 }
380}
381
382impl Default for ConversationFallback {
383 fn default() -> Self {
384 Self::new()
385 }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// Rendered messages mention the action and the relevant payload.
    #[test]
    fn test_retry_decision_messages() {
        let retry = RetryDecision::Retry {
            delay: Duration::from_millis(100),
            attempt: 1,
        };
        let msg = retry.to_llm_message();
        assert!(msg.contains("retry"));
        assert!(msg.contains("100"));

        let fail = RetryDecision::Fail {
            reason: "test error".to_string(),
        };
        let msg = fail.to_llm_message();
        assert!(msg.contains("fail"));
        assert!(msg.contains("test error"));
    }

    /// Delays grow with the attempt number and never exceed the cap.
    #[test]
    fn test_retry_strategy_delay() {
        let strategy = AgentRetryStrategy::new();

        let delay0 = strategy.get_retry_delay(0);
        let delay1 = strategy.get_retry_delay(1);
        let delay2 = strategy.get_retry_delay(2);

        // With the default 0.3 jitter factor the per-attempt ranges do not
        // overlap, so these comparisons are deterministic.
        assert!(delay1 >= delay0);
        assert!(delay2 >= delay1);
        assert!(delay2 <= strategy.max_delay);
    }

    /// Built-in patterns classify errors as expected.
    #[test]
    fn test_retry_strategy_retryable() {
        let strategy = AgentRetryStrategy::new();

        assert!(strategy.is_retryable("circuit_open for node"));
        assert!(strategy.is_retryable("rate_limit exceeded"));
        assert!(strategy.is_retryable("connection timeout"));

        assert!(!strategy.is_retryable("permission_denied"));
        assert!(!strategy.is_retryable("authentication failed"));
    }

    /// Retries stop at the attempt limit and for non-retryable errors.
    #[test]
    fn test_retry_strategy_should_retry() {
        let strategy =
            AgentRetryStrategy::with_config(Duration::from_millis(100), Duration::from_secs(10), 3);

        let decision = strategy.should_retry("circuit_open", 0);
        assert!(decision.should_retry());

        let decision = strategy.should_retry("circuit_open", 3);
        assert!(!decision.should_retry());

        let decision = strategy.should_retry("permission_denied", 0);
        assert!(!decision.should_retry());
    }

    /// Builder methods and `update` populate the context.
    #[test]
    fn test_conversation_context() {
        let mut ctx = ConversationContext::new("conv-123")
            .with_metadata("user", "alice")
            .with_ttl(Duration::from_secs(60));

        assert_eq!(ctx.conversation_id, "conv-123");
        assert!(ctx.is_valid());

        ctx.update("SELECT * FROM users", r#"[{"id": 1}]"#);
        assert_eq!(ctx.last_query, Some("SELECT * FROM users".to_string()));
    }

    /// A stored context can be read back.
    #[test]
    fn test_conversation_fallback() {
        let fallback = ConversationFallback::new();

        fallback.update_context("conv-1", "query1", "result1");
        assert!(fallback.has_context("conv-1"));
        assert_eq!(fallback.cached_count(), 1);

        let ctx = fallback.get_fallback("conv-1").unwrap();
        assert_eq!(ctx.last_query, Some("query1".to_string()));
        assert_eq!(ctx.last_result, Some("result1".to_string()));
    }

    /// A context stops being reported once its TTL has elapsed.
    #[test]
    fn test_conversation_fallback_expired() {
        let fallback = ConversationFallback::with_config(Duration::from_millis(10), 100);

        fallback.update_context("conv-1", "query1", "result1");
        assert!(fallback.has_context("conv-1"));

        std::thread::sleep(Duration::from_millis(20));
        assert!(!fallback.has_context("conv-1"));
    }

    /// A successful execution bypasses the fallback.
    #[test]
    fn test_execute_with_fallback() {
        let fallback = ConversationFallback::new();
        fallback.update_context("conv-1", "query", "cached_result");

        let result: Result<String, &str> = fallback.execute_with_fallback(
            "conv-1",
            || Ok("new_result".to_string()),
            |_| "fallback".to_string(),
        );
        assert_eq!(result.unwrap(), "new_result");
    }

    /// A failed execution is answered from the cached context.
    #[test]
    fn test_execute_with_fallback_uses_cached_context() {
        let fallback = ConversationFallback::new();
        fallback.update_context("conv-1", "query", "cached_result");

        let result: Result<String, &str> = fallback.execute_with_fallback(
            "conv-1",
            || Err("backend down"),
            |ctx| ctx.last_result.clone().unwrap_or_default(),
        );
        assert_eq!(result.unwrap(), "cached_result");
    }

    /// Without a cached context the original error is returned.
    #[test]
    fn test_execute_with_fallback_propagates_error() {
        let fallback = ConversationFallback::new();

        let result: Result<String, &str> = fallback.execute_with_fallback(
            "conv-unknown",
            || Err("backend down"),
            |_| "fallback".to_string(),
        );
        assert_eq!(result.unwrap_err(), "backend down");
    }
}
504}