Skip to main content

jamjet_core/
retry.rs

1use serde::{Deserialize, Serialize};
2use std::time::Duration;
3
4/// Defines how a node retries on failure.
5#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct RetryPolicy {
7    /// Maximum number of attempts (including the first). Must be >= 1.
8    pub max_attempts: u32,
9    /// Backoff strategy between attempts.
10    pub backoff: BackoffStrategy,
11    /// Initial delay before the first retry.
12    #[serde(with = "duration_secs")]
13    pub initial_delay: Duration,
14    /// Maximum delay cap (exponential backoff will not exceed this).
15    #[serde(with = "duration_secs")]
16    pub max_delay: Duration,
17    /// Whether to add random jitter to delays (prevents thundering herd).
18    pub jitter: bool,
19    /// Which error classes are retryable. Empty = retry on any error.
20    pub retryable_on: Vec<ErrorClass>,
21}
22
23impl RetryPolicy {
24    /// A sensible default for I/O-bound tool calls.
25    pub fn io_default() -> Self {
26        Self {
27            max_attempts: 3,
28            backoff: BackoffStrategy::Exponential,
29            initial_delay: Duration::from_secs(1),
30            max_delay: Duration::from_secs(30),
31            jitter: true,
32            retryable_on: vec![
33                ErrorClass::IoError,
34                ErrorClass::Timeout,
35                ErrorClass::ConnectionReset,
36            ],
37        }
38    }
39
40    /// A sensible default for LLM calls.
41    pub fn llm_default() -> Self {
42        Self {
43            max_attempts: 3,
44            backoff: BackoffStrategy::Exponential,
45            initial_delay: Duration::from_secs(2),
46            max_delay: Duration::from_secs(60),
47            jitter: true,
48            retryable_on: vec![
49                ErrorClass::RateLimit,
50                ErrorClass::Timeout,
51                ErrorClass::ServerError,
52            ],
53        }
54    }
55
56    /// No retries — fail immediately on any error.
57    pub fn no_retry() -> Self {
58        Self {
59            max_attempts: 1,
60            backoff: BackoffStrategy::Fixed,
61            initial_delay: Duration::ZERO,
62            max_delay: Duration::ZERO,
63            jitter: false,
64            retryable_on: vec![],
65        }
66    }
67
68    /// Compute the delay before the nth retry (0-indexed).
69    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
70        if attempt == 0 {
71            return Duration::ZERO;
72        }
73        let base = match self.backoff {
74            BackoffStrategy::Fixed => self.initial_delay,
75            BackoffStrategy::Linear => self.initial_delay * attempt,
76            BackoffStrategy::Exponential => {
77                let factor = 2u64.saturating_pow(attempt - 1);
78                self.initial_delay.saturating_mul(factor as u32)
79            }
80        };
81        let capped = base.min(self.max_delay);
82        if self.jitter {
83            // Simple jitter: randomize between 50% and 100% of the delay.
84            // In production code, use a proper RNG passed in.
85            capped
86        } else {
87            capped
88        }
89    }
90}
91
92/// Backoff strategy between retry attempts.
93#[derive(Debug, Clone, Serialize, Deserialize)]
94#[serde(rename_all = "snake_case")]
95pub enum BackoffStrategy {
96    /// Same delay every time.
97    Fixed,
98    /// Linearly increasing delay.
99    Linear,
100    /// Exponentially increasing delay (2^n * initial_delay).
101    Exponential,
102}
103
104/// Categories of errors that can be retried.
105#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
106#[serde(rename_all = "snake_case")]
107pub enum ErrorClass {
108    IoError,
109    Timeout,
110    RateLimit,
111    ServerError,
112    ConnectionReset,
113    Custom(String),
114}
115
116// Serialize Duration as integer seconds for YAML/JSON friendliness.
117mod duration_secs {
118    use serde::{Deserialize, Deserializer, Serializer};
119    use std::time::Duration;
120
121    pub fn serialize<S: Serializer>(d: &Duration, s: S) -> Result<S::Ok, S::Error> {
122        s.serialize_u64(d.as_secs())
123    }
124
125    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Duration, D::Error> {
126        let secs = u64::deserialize(d)?;
127        Ok(Duration::from_secs(secs))
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a policy with the given strategy and delays (in whole seconds),
    /// with jitter off and no error-class filter.
    fn make_policy(backoff: BackoffStrategy, initial_secs: u64, max_secs: u64) -> RetryPolicy {
        RetryPolicy {
            max_attempts: 5,
            backoff,
            initial_delay: Duration::from_secs(initial_secs),
            max_delay: Duration::from_secs(max_secs),
            jitter: false,
            retryable_on: vec![],
        }
    }

    #[test]
    fn exponential_delay() {
        let policy = make_policy(BackoffStrategy::Exponential, 1, 30);
        assert_eq!(policy.delay_for_attempt(0), Duration::ZERO);
        assert_eq!(policy.delay_for_attempt(1), Duration::from_secs(1));
        assert_eq!(policy.delay_for_attempt(2), Duration::from_secs(2));
        assert_eq!(policy.delay_for_attempt(3), Duration::from_secs(4));
        // Capped at max_delay
        assert_eq!(policy.delay_for_attempt(6), Duration::from_secs(30));
    }

    #[test]
    fn linear_delay() {
        let policy = make_policy(BackoffStrategy::Linear, 2, 5);
        assert_eq!(policy.delay_for_attempt(0), Duration::ZERO);
        assert_eq!(policy.delay_for_attempt(1), Duration::from_secs(2));
        assert_eq!(policy.delay_for_attempt(2), Duration::from_secs(4));
        // 3 * 2s = 6s, capped at max_delay
        assert_eq!(policy.delay_for_attempt(3), Duration::from_secs(5));
    }

    #[test]
    fn fixed_delay() {
        let policy = make_policy(BackoffStrategy::Fixed, 3, 10);
        assert_eq!(policy.delay_for_attempt(0), Duration::ZERO);
        assert_eq!(policy.delay_for_attempt(1), Duration::from_secs(3));
        assert_eq!(policy.delay_for_attempt(7), Duration::from_secs(3));
    }

    #[test]
    fn no_retry_never_waits() {
        let policy = RetryPolicy::no_retry();
        assert_eq!(policy.max_attempts, 1);
        assert_eq!(policy.delay_for_attempt(1), Duration::ZERO);
    }
}