// agentic_workflow/resilience/retry.rs
use std::collections::HashMap;

use uuid::Uuid;

use crate::types::{
    FailureClass, RetryBudget, RetryPattern, RetryPolicy, RetryProfile,
    RetryStats, RetryStrategy, WorkflowError, WorkflowResult,
};

10pub struct RetryEngine {
12 policies: HashMap<String, RetryPolicy>,
13 stats: HashMap<String, RetryStats>,
14}
15
16impl RetryEngine {
17 pub fn new() -> Self {
18 Self {
19 policies: HashMap::new(),
20 stats: HashMap::new(),
21 }
22 }
23
24 pub fn configure_policy(
26 &mut self,
27 name: &str,
28 profiles: Vec<RetryProfile>,
29 budget: Option<RetryBudget>,
30 ) -> WorkflowResult<String> {
31 let id = Uuid::new_v4().to_string();
32 let policy = RetryPolicy {
33 id: id.clone(),
34 name: name.to_string(),
35 profiles,
36 budget,
37 escalation: None,
38 };
39
40 self.policies.insert(id.clone(), policy);
41 Ok(id)
42 }
43
44 pub fn get_policy(&self, policy_id: &str) -> WorkflowResult<&RetryPolicy> {
46 self.policies
47 .get(policy_id)
48 .ok_or_else(|| WorkflowError::Internal(format!("Policy not found: {}", policy_id)))
49 }
50
51 pub fn get_profile_for_failure(
53 &self,
54 policy_id: &str,
55 failure_class: &FailureClass,
56 ) -> WorkflowResult<Option<&RetryProfile>> {
57 let policy = self.get_policy(policy_id)?;
58 Ok(policy
59 .profiles
60 .iter()
61 .find(|p| p.failure_class == *failure_class))
62 }
63
64 pub fn calculate_delay(
66 &self,
67 strategy: &RetryStrategy,
68 attempt: u32,
69 ) -> u64 {
70 match strategy {
71 RetryStrategy::Immediate => 0,
72 RetryStrategy::FixedDelay { delay_ms } => *delay_ms,
73 RetryStrategy::ExponentialBackoff {
74 initial_ms,
75 max_ms,
76 multiplier,
77 } => {
78 let delay = (*initial_ms as f64) * multiplier.powi(attempt as i32);
79 (delay as u64).min(*max_ms)
80 }
81 RetryStrategy::Linear {
82 delay_ms,
83 increment_ms,
84 } => delay_ms + (increment_ms * attempt as u64),
85 }
86 }
87
88 pub fn within_budget(
90 &self,
91 policy_id: &str,
92 step_id: &str,
93 ) -> WorkflowResult<bool> {
94 let policy = self.get_policy(policy_id)?;
95
96 if let Some(budget) = &policy.budget {
97 if let Some(stats) = self.stats.get(step_id) {
98 if let Some(max) = budget.max_total_attempts {
99 if stats.total_attempts >= max {
100 return Ok(false);
101 }
102 }
103 }
104 }
105
106 Ok(true)
107 }
108
109 pub fn record_attempt(
111 &mut self,
112 step_id: &str,
113 failure_class: FailureClass,
114 ) {
115 let stats = self.stats.entry(step_id.to_string()).or_insert_with(|| {
116 RetryStats {
117 step_id: step_id.to_string(),
118 total_attempts: 0,
119 successes_by_attempt: Vec::new(),
120 avg_delay_ms: 0.0,
121 last_failure_class: None,
122 last_retry_at: None,
123 }
124 });
125
126 stats.total_attempts += 1;
127 stats.last_failure_class = Some(failure_class);
128 stats.last_retry_at = Some(chrono::Utc::now());
129 }
130
131 pub fn get_stats(&self, step_id: &str) -> Option<&RetryStats> {
133 self.stats.get(step_id)
134 }
135
136 pub fn get_patterns(&self) -> Vec<RetryPattern> {
138 self.stats
139 .values()
140 .map(|s| RetryPattern {
141 step_id: s.step_id.clone(),
142 optimal_delay_ms: s.avg_delay_ms as u64,
143 success_rate_by_attempt: s
144 .successes_by_attempt
145 .iter()
146 .map(|&v| v as f64)
147 .collect(),
148 recommendation: if s.total_attempts > 10 {
149 "Consider optimizing retry strategy based on patterns".to_string()
150 } else {
151 "Insufficient data for recommendation".to_string()
152 },
153 })
154 .collect()
155 }
156
157 pub fn list_policies(&self) -> Vec<&RetryPolicy> {
159 self.policies.values().collect()
160 }
161}
162
163impl Default for RetryEngine {
164 fn default() -> Self {
165 Self::new()
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Exponential backoff doubles from the initial delay and is capped at
    /// the configured maximum.
    #[test]
    fn test_exponential_backoff() {
        let engine = RetryEngine::new();
        let strategy = RetryStrategy::ExponentialBackoff {
            initial_ms: 100,
            max_ms: 10000,
            multiplier: 2.0,
        };

        // (attempt, expected delay in milliseconds)
        let expectations: [(u32, u64); 4] = [(0, 100), (1, 200), (2, 400), (10, 10000)];

        for &(attempt, expected_ms) in expectations.iter() {
            assert_eq!(engine.calculate_delay(&strategy, attempt), expected_ms);
        }
    }
}