Skip to main content

agentic_workflow/resilience/
retry.rs

1use std::collections::HashMap;
2
3use uuid::Uuid;
4
5use crate::types::{
6    FailureClass, RetryBudget, RetryPattern, RetryPolicy, RetryProfile,
7    RetryStats, RetryStrategy, WorkflowError, WorkflowResult,
8};
9
10/// Failure-classified retry engine.
11pub struct RetryEngine {
12    policies: HashMap<String, RetryPolicy>,
13    stats: HashMap<String, RetryStats>,
14}
15
16impl RetryEngine {
17    pub fn new() -> Self {
18        Self {
19            policies: HashMap::new(),
20            stats: HashMap::new(),
21        }
22    }
23
24    /// Configure a retry policy.
25    pub fn configure_policy(
26        &mut self,
27        name: &str,
28        profiles: Vec<RetryProfile>,
29        budget: Option<RetryBudget>,
30    ) -> WorkflowResult<String> {
31        let id = Uuid::new_v4().to_string();
32        let policy = RetryPolicy {
33            id: id.clone(),
34            name: name.to_string(),
35            profiles,
36            budget,
37            escalation: None,
38        };
39
40        self.policies.insert(id.clone(), policy);
41        Ok(id)
42    }
43
44    /// Get a retry policy.
45    pub fn get_policy(&self, policy_id: &str) -> WorkflowResult<&RetryPolicy> {
46        self.policies
47            .get(policy_id)
48            .ok_or_else(|| WorkflowError::Internal(format!("Policy not found: {}", policy_id)))
49    }
50
51    /// Get retry profile for a specific failure class.
52    pub fn get_profile_for_failure(
53        &self,
54        policy_id: &str,
55        failure_class: &FailureClass,
56    ) -> WorkflowResult<Option<&RetryProfile>> {
57        let policy = self.get_policy(policy_id)?;
58        Ok(policy
59            .profiles
60            .iter()
61            .find(|p| p.failure_class == *failure_class))
62    }
63
64    /// Calculate delay for next retry attempt.
65    pub fn calculate_delay(
66        &self,
67        strategy: &RetryStrategy,
68        attempt: u32,
69    ) -> u64 {
70        match strategy {
71            RetryStrategy::Immediate => 0,
72            RetryStrategy::FixedDelay { delay_ms } => *delay_ms,
73            RetryStrategy::ExponentialBackoff {
74                initial_ms,
75                max_ms,
76                multiplier,
77            } => {
78                let delay = (*initial_ms as f64) * multiplier.powi(attempt as i32);
79                (delay as u64).min(*max_ms)
80            }
81            RetryStrategy::Linear {
82                delay_ms,
83                increment_ms,
84            } => delay_ms + (increment_ms * attempt as u64),
85        }
86    }
87
88    /// Check if retry is within budget.
89    pub fn within_budget(
90        &self,
91        policy_id: &str,
92        step_id: &str,
93    ) -> WorkflowResult<bool> {
94        let policy = self.get_policy(policy_id)?;
95
96        if let Some(budget) = &policy.budget {
97            if let Some(stats) = self.stats.get(step_id) {
98                if let Some(max) = budget.max_total_attempts {
99                    if stats.total_attempts >= max {
100                        return Ok(false);
101                    }
102                }
103            }
104        }
105
106        Ok(true)
107    }
108
109    /// Record a retry attempt.
110    pub fn record_attempt(
111        &mut self,
112        step_id: &str,
113        failure_class: FailureClass,
114    ) {
115        let stats = self.stats.entry(step_id.to_string()).or_insert_with(|| {
116            RetryStats {
117                step_id: step_id.to_string(),
118                total_attempts: 0,
119                successes_by_attempt: Vec::new(),
120                avg_delay_ms: 0.0,
121                last_failure_class: None,
122                last_retry_at: None,
123            }
124        });
125
126        stats.total_attempts += 1;
127        stats.last_failure_class = Some(failure_class);
128        stats.last_retry_at = Some(chrono::Utc::now());
129    }
130
131    /// Get retry stats for a step.
132    pub fn get_stats(&self, step_id: &str) -> Option<&RetryStats> {
133        self.stats.get(step_id)
134    }
135
136    /// Get learned retry patterns.
137    pub fn get_patterns(&self) -> Vec<RetryPattern> {
138        self.stats
139            .values()
140            .map(|s| RetryPattern {
141                step_id: s.step_id.clone(),
142                optimal_delay_ms: s.avg_delay_ms as u64,
143                success_rate_by_attempt: s
144                    .successes_by_attempt
145                    .iter()
146                    .map(|&v| v as f64)
147                    .collect(),
148                recommendation: if s.total_attempts > 10 {
149                    "Consider optimizing retry strategy based on patterns".to_string()
150                } else {
151                    "Insufficient data for recommendation".to_string()
152                },
153            })
154            .collect()
155    }
156
157    /// List all policies.
158    pub fn list_policies(&self) -> Vec<&RetryPolicy> {
159        self.policies.values().collect()
160    }
161}
162
163impl Default for RetryEngine {
164    fn default() -> Self {
165        Self::new()
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_exponential_backoff() {
        let engine = RetryEngine::new();
        let strategy = RetryStrategy::ExponentialBackoff {
            initial_ms: 100,
            max_ms: 10000,
            multiplier: 2.0,
        };

        assert_eq!(engine.calculate_delay(&strategy, 0), 100);
        assert_eq!(engine.calculate_delay(&strategy, 1), 200);
        assert_eq!(engine.calculate_delay(&strategy, 2), 400);
        assert_eq!(engine.calculate_delay(&strategy, 10), 10000); // capped
    }

    #[test]
    fn test_immediate_and_fixed_delay_ignore_attempt() {
        let engine = RetryEngine::new();
        let fixed = RetryStrategy::FixedDelay { delay_ms: 250 };

        for &attempt in &[0, 1, 7] {
            assert_eq!(engine.calculate_delay(&RetryStrategy::Immediate, attempt), 0);
            assert_eq!(engine.calculate_delay(&fixed, attempt), 250);
        }
    }

    #[test]
    fn test_linear_backoff() {
        let engine = RetryEngine::new();
        let strategy = RetryStrategy::Linear {
            delay_ms: 100,
            increment_ms: 50,
        };

        assert_eq!(engine.calculate_delay(&strategy, 0), 100);
        assert_eq!(engine.calculate_delay(&strategy, 1), 150);
        assert_eq!(engine.calculate_delay(&strategy, 4), 300);
    }

    #[test]
    fn test_policy_lookup_and_unbounded_budget() {
        let mut engine = RetryEngine::new();
        let id = match engine.configure_policy("default", Vec::new(), None) {
            Ok(id) => id,
            Err(_) => panic!("configure_policy should not fail"),
        };

        let policy = match engine.get_policy(&id) {
            Ok(policy) => policy,
            Err(_) => panic!("policy was just configured"),
        };
        assert_eq!(policy.name, "default");
        assert!(policy.profiles.is_empty());
        assert!(policy.budget.is_none());
        assert_eq!(engine.list_policies().len(), 1);

        assert!(engine.get_policy("missing").is_err());
        // Without a budget every step stays within budget.
        assert_eq!(engine.within_budget(&id, "step-1").ok(), Some(true));
        assert!(engine.get_stats("step-1").is_none());
        assert!(engine.get_patterns().is_empty());
    }
}