1use serde::{Deserialize, Serialize};
2use std::time::Duration;
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct RetryPolicy {
7 pub max_attempts: u32,
9 pub backoff: BackoffStrategy,
11 #[serde(with = "duration_secs")]
13 pub initial_delay: Duration,
14 #[serde(with = "duration_secs")]
16 pub max_delay: Duration,
17 pub jitter: bool,
19 pub retryable_on: Vec<ErrorClass>,
21}
22
23impl RetryPolicy {
24 pub fn io_default() -> Self {
26 Self {
27 max_attempts: 3,
28 backoff: BackoffStrategy::Exponential,
29 initial_delay: Duration::from_secs(1),
30 max_delay: Duration::from_secs(30),
31 jitter: true,
32 retryable_on: vec![
33 ErrorClass::IoError,
34 ErrorClass::Timeout,
35 ErrorClass::ConnectionReset,
36 ],
37 }
38 }
39
40 pub fn llm_default() -> Self {
42 Self {
43 max_attempts: 3,
44 backoff: BackoffStrategy::Exponential,
45 initial_delay: Duration::from_secs(2),
46 max_delay: Duration::from_secs(60),
47 jitter: true,
48 retryable_on: vec![
49 ErrorClass::RateLimit,
50 ErrorClass::Timeout,
51 ErrorClass::ServerError,
52 ],
53 }
54 }
55
56 pub fn no_retry() -> Self {
58 Self {
59 max_attempts: 1,
60 backoff: BackoffStrategy::Fixed,
61 initial_delay: Duration::ZERO,
62 max_delay: Duration::ZERO,
63 jitter: false,
64 retryable_on: vec![],
65 }
66 }
67
68 pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
70 if attempt == 0 {
71 return Duration::ZERO;
72 }
73 let base = match self.backoff {
74 BackoffStrategy::Fixed => self.initial_delay,
75 BackoffStrategy::Linear => self.initial_delay * attempt,
76 BackoffStrategy::Exponential => {
77 let factor = 2u64.saturating_pow(attempt - 1);
78 self.initial_delay.saturating_mul(factor as u32)
79 }
80 };
81 let capped = base.min(self.max_delay);
82 if self.jitter {
83 capped
86 } else {
87 capped
88 }
89 }
90}
91
92#[derive(Debug, Clone, Serialize, Deserialize)]
94#[serde(rename_all = "snake_case")]
95pub enum BackoffStrategy {
96 Fixed,
98 Linear,
100 Exponential,
102}
103
104#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
106#[serde(rename_all = "snake_case")]
107pub enum ErrorClass {
108 IoError,
109 Timeout,
110 RateLimit,
111 ServerError,
112 ConnectionReset,
113 Custom(String),
114}
115
116mod duration_secs {
118 use serde::{Deserialize, Deserializer, Serializer};
119 use std::time::Duration;
120
121 pub fn serialize<S: Serializer>(d: &Duration, s: S) -> Result<S::Ok, S::Error> {
122 s.serialize_u64(d.as_secs())
123 }
124
125 pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Duration, D::Error> {
126 let secs = u64::deserialize(d)?;
127 Ok(Duration::from_secs(secs))
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a jitter-free policy so tests can assert exact durations.
    fn make_policy(backoff: BackoffStrategy, initial_secs: u64, max_secs: u64) -> RetryPolicy {
        RetryPolicy {
            max_attempts: 5,
            backoff,
            initial_delay: Duration::from_secs(initial_secs),
            max_delay: Duration::from_secs(max_secs),
            jitter: false,
            retryable_on: vec![],
        }
    }

    #[test]
    fn exponential_delay() {
        let policy = make_policy(BackoffStrategy::Exponential, 1, 30);
        assert_eq!(policy.delay_for_attempt(0), Duration::ZERO);
        assert_eq!(policy.delay_for_attempt(1), Duration::from_secs(1));
        assert_eq!(policy.delay_for_attempt(2), Duration::from_secs(2));
        assert_eq!(policy.delay_for_attempt(3), Duration::from_secs(4));
        // 2^5 = 32s exceeds max_delay, so the delay is capped at 30s.
        assert_eq!(policy.delay_for_attempt(6), Duration::from_secs(30));
    }

    #[test]
    fn linear_delay() {
        let policy = make_policy(BackoffStrategy::Linear, 2, 7);
        assert_eq!(policy.delay_for_attempt(0), Duration::ZERO);
        assert_eq!(policy.delay_for_attempt(1), Duration::from_secs(2));
        assert_eq!(policy.delay_for_attempt(3), Duration::from_secs(6));
        // 4 * 2s = 8s exceeds max_delay, so the delay is capped at 7s.
        assert_eq!(policy.delay_for_attempt(4), Duration::from_secs(7));
    }

    #[test]
    fn fixed_delay() {
        let policy = make_policy(BackoffStrategy::Fixed, 3, 30);
        assert_eq!(policy.delay_for_attempt(0), Duration::ZERO);
        for attempt in 1..=5 {
            assert_eq!(policy.delay_for_attempt(attempt), Duration::from_secs(3));
        }
        // max_delay caps even a fixed delay.
        let capped = make_policy(BackoffStrategy::Fixed, 10, 4);
        assert_eq!(capped.delay_for_attempt(1), Duration::from_secs(4));
    }

    #[test]
    fn no_retry_never_waits() {
        let policy = RetryPolicy::no_retry();
        assert_eq!(policy.max_attempts, 1);
        assert!(policy.retryable_on.is_empty());
        for attempt in 0..=3 {
            assert_eq!(policy.delay_for_attempt(attempt), Duration::ZERO);
        }
    }
}