cognis_core/wrappers/
retry.rs1use std::marker::PhantomData;
4use std::time::Duration;
5
6use async_trait::async_trait;
7
8use crate::runnable::{Runnable, RunnableConfig};
9use crate::{CognisError, Result};
10
/// Settings for [`Retry`]: how many times the wrapped runnable may be called
/// and how long to wait between those calls.
#[derive(Clone, Debug)]
pub struct RetryPolicy {
    /// Total number of calls allowed, counting the first one (not just retries).
    pub max_attempts: u32,
    /// Starting backoff delay, i.e. the wait before the first retry when the
    /// failing error does not suggest a delay of its own.
    pub initial_delay: Duration,
    /// Factor the backoff delay is multiplied by after each retried failure.
    pub backoff_multiplier: f64,
    /// Upper bound on any single wait, whether computed from the backoff or
    /// suggested by the error.
    pub max_delay: Duration,
}
23
24impl Default for RetryPolicy {
25 fn default() -> Self {
26 Self {
27 max_attempts: 3,
28 initial_delay: Duration::from_millis(100),
29 backoff_multiplier: 2.0,
30 max_delay: Duration::from_secs(30),
31 }
32 }
33}
34
35impl RetryPolicy {
36 pub fn new(max_attempts: u32) -> Self {
38 Self {
39 max_attempts,
40 ..Default::default()
41 }
42 }
43 pub fn with_initial_delay(mut self, d: Duration) -> Self {
45 self.initial_delay = d;
46 self
47 }
48 pub fn with_backoff(mut self, factor: f64) -> Self {
50 self.backoff_multiplier = factor;
51 self
52 }
53 pub fn with_max_delay(mut self, d: Duration) -> Self {
55 self.max_delay = d;
56 self
57 }
58}
59
60pub struct Retry<R, I, O> {
66 inner: R,
67 policy: RetryPolicy,
68 _phantom: PhantomData<fn(I) -> O>,
69}
70
71impl<R, I, O> Retry<R, I, O>
72where
73 R: Runnable<I, O>,
74 I: Clone + Send + 'static,
75 O: Send + 'static,
76{
77 pub fn new(inner: R, policy: RetryPolicy) -> Self {
79 Self {
80 inner,
81 policy,
82 _phantom: PhantomData,
83 }
84 }
85}
86
87#[async_trait]
88impl<R, I, O> Runnable<I, O> for Retry<R, I, O>
89where
90 R: Runnable<I, O>,
91 I: Clone + Send + 'static,
92 O: Send + 'static,
93{
94 async fn invoke(&self, input: I, config: RunnableConfig) -> Result<O> {
95 let mut delay = self.policy.initial_delay;
96 let mut last_err: Option<CognisError> = None;
97 for attempt in 0..self.policy.max_attempts {
98 match self.inner.invoke(input.clone(), config.clone()).await {
99 Ok(v) => return Ok(v),
100 Err(e) if !e.is_retryable() => return Err(e),
101 Err(e) => {
102 let suggested = e.retry_delay().unwrap_or(delay);
103 last_err = Some(e);
104 if attempt + 1 >= self.policy.max_attempts {
105 break;
106 }
107 let sleep_for = suggested.min(self.policy.max_delay);
108 tokio::time::sleep(sleep_for).await;
109 delay = Duration::from_secs_f64(
110 (delay.as_secs_f64() * self.policy.backoff_multiplier)
111 .min(self.policy.max_delay.as_secs_f64()),
112 );
113 }
114 }
115 }
116 Err(last_err.unwrap_or_else(|| {
117 CognisError::Internal("retry exhausted with no error captured".into())
118 }))
119 }
120 fn name(&self) -> &str {
121 "Retry"
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Fails with a retryable 503 on its first two calls, then echoes its input.
    struct FlakyTwice {
        attempts: Arc<AtomicU32>,
    }

    #[async_trait]
    impl Runnable<u32, u32> for FlakyTwice {
        async fn invoke(&self, input: u32, _: RunnableConfig) -> Result<u32> {
            let n = self.attempts.fetch_add(1, Ordering::SeqCst);
            if n < 2 {
                Err(CognisError::Network {
                    status_code: Some(503),
                    message: "boom".into(),
                })
            } else {
                Ok(input)
            }
        }
    }

    /// Always fails with a retryable 503, counting how often it is called.
    struct AlwaysUnavailable {
        attempts: Arc<AtomicU32>,
    }

    #[async_trait]
    impl Runnable<u32, u32> for AlwaysUnavailable {
        async fn invoke(&self, _: u32, _: RunnableConfig) -> Result<u32> {
            self.attempts.fetch_add(1, Ordering::SeqCst);
            Err(CognisError::Network {
                status_code: Some(503),
                message: "still down".into(),
            })
        }
    }

    /// Always fails with a (non-retryable) authentication error.
    struct AlwaysAuth;

    #[async_trait]
    impl Runnable<u32, u32> for AlwaysAuth {
        async fn invoke(&self, _: u32, _: RunnableConfig) -> Result<u32> {
            Err(CognisError::AuthenticationFailed("bad key".into()))
        }
    }

    /// Policy with tiny delays (and a tiny cap on error-suggested delays) so
    /// retry tests finish quickly.
    fn fast_policy(max_attempts: u32) -> RetryPolicy {
        RetryPolicy::new(max_attempts)
            .with_initial_delay(Duration::from_millis(1))
            .with_max_delay(Duration::from_millis(5))
    }

    #[tokio::test]
    async fn retries_until_success() {
        let attempts = Arc::new(AtomicU32::new(0));
        let r = Retry::new(
            FlakyTwice {
                attempts: attempts.clone(),
            },
            RetryPolicy::new(5).with_initial_delay(Duration::from_millis(1)),
        );
        let out = r.invoke(7, RunnableConfig::default()).await.unwrap();
        assert_eq!(out, 7);
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn non_retryable_short_circuits() {
        let r = Retry::new(
            AlwaysAuth,
            RetryPolicy::new(5).with_initial_delay(Duration::from_millis(1)),
        );
        let err = r.invoke(0, RunnableConfig::default()).await.unwrap_err();
        assert!(matches!(err, CognisError::AuthenticationFailed(_)));
    }

    #[tokio::test]
    async fn exhausted_retries_return_last_error() {
        let attempts = Arc::new(AtomicU32::new(0));
        let r = Retry::new(
            AlwaysUnavailable {
                attempts: attempts.clone(),
            },
            fast_policy(3),
        );
        let err = r.invoke(0, RunnableConfig::default()).await.unwrap_err();
        assert!(matches!(
            err,
            CognisError::Network {
                status_code: Some(503),
                ..
            }
        ));
        // Exactly `max_attempts` calls: the first one plus two retries.
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn new_keeps_default_timing() {
        let p = RetryPolicy::new(9);
        let d = RetryPolicy::default();
        assert_eq!(p.max_attempts, 9);
        assert_eq!(p.initial_delay, d.initial_delay);
        assert_eq!(p.backoff_multiplier, d.backoff_multiplier);
        assert_eq!(p.max_delay, d.max_delay);
    }

    #[test]
    fn builders_override_defaults() {
        let p = RetryPolicy::new(4)
            .with_initial_delay(Duration::from_millis(10))
            .with_backoff(3.0)
            .with_max_delay(Duration::from_secs(1));
        assert_eq!(p.max_attempts, 4);
        assert_eq!(p.initial_delay, Duration::from_millis(10));
        assert_eq!(p.backoff_multiplier, 3.0);
        assert_eq!(p.max_delay, Duration::from_secs(1));
    }
}