1mod backoff;
47mod config;
48mod events;
49mod layer;
50mod policy;
51
52pub use backoff::{
53 ExponentialBackoff, ExponentialRandomBackoff, FixedInterval, FnInterval, IntervalFunction,
54};
55pub use config::{RetryConfig, RetryConfigBuilder};
56pub use events::RetryEvent;
57pub use layer::RetryLayer;
58pub use policy::{RetryPolicy, RetryPredicate};
59
60use futures::future::BoxFuture;
61use std::sync::Arc;
62use std::task::{Context, Poll};
63use std::time::Instant;
64use tower::Service;
65
66pub struct Retry<S, E> {
71 inner: S,
72 config: Arc<RetryConfig<E>>,
73}
74
75impl<S, E> Retry<S, E> {
76 pub fn new(inner: S, config: Arc<RetryConfig<E>>) -> Self {
78 Self { inner, config }
79 }
80}
81
82impl<S, E> Clone for Retry<S, E>
83where
84 S: Clone,
85{
86 fn clone(&self) -> Self {
87 Self {
88 inner: self.inner.clone(),
89 config: Arc::clone(&self.config),
90 }
91 }
92}
93
94impl<S, Req, E> Service<Req> for Retry<S, E>
95where
96 S: Service<Req, Error = E> + Clone + Send + 'static,
97 S::Future: Send + 'static,
98 Req: Clone + Send + 'static,
99 E: Clone + Send + 'static,
100 S::Response: Send + 'static,
101{
102 type Response = S::Response;
103 type Error = E;
104 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
105
106 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
107 self.inner.poll_ready(cx)
108 }
109
110 fn call(&mut self, req: Req) -> Self::Future {
111 let mut service = self.inner.clone();
112 let config = Arc::clone(&self.config);
113
114 Box::pin(async move {
115 let mut attempt = 0;
116
117 loop {
118 let result = service.call(req.clone()).await;
119
120 match result {
121 Ok(response) => {
122 let event = RetryEvent::Success {
124 pattern_name: config.name.clone(),
125 timestamp: Instant::now(),
126 attempts: attempt + 1,
127 };
128 config.event_listeners.emit(&event);
129 return Ok(response);
130 }
131 Err(error) => {
132 if !config.policy.should_retry(&error) {
134 let event = RetryEvent::IgnoredError {
135 pattern_name: config.name.clone(),
136 timestamp: Instant::now(),
137 };
138 config.event_listeners.emit(&event);
139 return Err(error);
140 }
141
142 if attempt + 1 >= config.policy.max_attempts {
144 let event = RetryEvent::Error {
145 pattern_name: config.name.clone(),
146 timestamp: Instant::now(),
147 attempts: attempt + 1,
148 };
149 config.event_listeners.emit(&event);
150 return Err(error);
151 }
152
153 let delay = config.policy.next_backoff(attempt);
155 let event = RetryEvent::Retry {
156 pattern_name: config.name.clone(),
157 timestamp: Instant::now(),
158 attempt,
159 delay,
160 };
161 config.event_listeners.emit(&event);
162
163 tokio::time::sleep(delay).await;
164 attempt += 1;
165 }
166 }
167 }
168 })
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;
    use tower::service_fn;
    use tower::{Layer, ServiceExt};

    /// Error type returned by the test services.
    #[derive(Debug, Clone)]
    struct TestError {
        // Only read through the derived `Debug` impl, which the `dead_code`
        // lint does not count as a use.
        #[allow(dead_code)]
        message: String,
    }

    impl TestError {
        fn new(message: &str) -> Self {
            Self {
                message: message.to_string(),
            }
        }
    }

    /// Drives `service` to readiness and sends it a single `"test"` request.
    async fn call_once<S>(service: &mut S) -> Result<S::Response, S::Error>
    where
        S: Service<String>,
    {
        service.ready().await?.call("test".to_string()).await
    }

    #[tokio::test]
    async fn successful_request_no_retry() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, TestError>(format!("Response: {}", req))
            }
        });

        let config: RetryConfig<TestError> = RetryConfig::builder()
            .max_attempts(3)
            .fixed_backoff(Duration::from_millis(10))
            .build();

        let layer = config.layer();
        let mut service = layer.layer(service);

        let response = call_once(&mut service).await.unwrap();

        assert_eq!(response, "Response: test");
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn retries_on_failure() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        // Fails twice, then succeeds on the third call.
        let service = service_fn(move |_req: String| {
            let cc = Arc::clone(&cc);
            async move {
                let count = cc.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    Err(TestError::new("temporary failure"))
                } else {
                    Ok::<_, TestError>("success".to_string())
                }
            }
        });

        let config: RetryConfig<TestError> = RetryConfig::builder()
            .max_attempts(3)
            .fixed_backoff(Duration::from_millis(10))
            .build();

        let layer = config.layer();
        let mut service = layer.layer(service);

        let response = call_once(&mut service).await.unwrap();

        assert_eq!(response, "success");
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn exhausts_retries() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |_req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>(TestError::new("permanent failure"))
            }
        });

        let config: RetryConfig<TestError> = RetryConfig::builder()
            .max_attempts(3)
            .fixed_backoff(Duration::from_millis(10))
            .build();

        let layer = config.layer();
        let mut service = layer.layer(service);

        let result = call_once(&mut service).await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn single_attempt_does_not_retry() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |_req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>(TestError::new("permanent failure"))
            }
        });

        // A budget of one attempt means the first failure is final.
        let config: RetryConfig<TestError> = RetryConfig::builder()
            .max_attempts(1)
            .fixed_backoff(Duration::from_millis(10))
            .build();

        let layer = config.layer();
        let mut service = layer.layer(service);

        let result = call_once(&mut service).await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn retry_predicate_filters_errors() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |_req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>(TestError::new("non-retryable"))
            }
        });

        // The predicate rejects every error, so no retry may be attempted.
        let config: RetryConfig<TestError> = RetryConfig::builder()
            .max_attempts(3)
            .fixed_backoff(Duration::from_millis(10))
            .retry_on(|_: &TestError| false)
            .build();

        let layer = config.layer();
        let mut service = layer.layer(service);

        let result = call_once(&mut service).await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn event_listeners_called() {
        let retry_count = Arc::new(AtomicUsize::new(0));
        let success_count = Arc::new(AtomicUsize::new(0));

        let rc = Arc::clone(&retry_count);
        let sc = Arc::clone(&success_count);

        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        // Fails twice, then succeeds on the third call.
        let service = service_fn(move |_req: String| {
            let cc = Arc::clone(&cc);
            async move {
                let count = cc.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    Err(TestError::new("temporary"))
                } else {
                    Ok::<_, TestError>("success".to_string())
                }
            }
        });

        let config: RetryConfig<TestError> = RetryConfig::builder()
            .max_attempts(3)
            .fixed_backoff(Duration::from_millis(10))
            .on_retry(move |_, _| {
                rc.fetch_add(1, Ordering::SeqCst);
            })
            .on_success(move |_| {
                sc.fetch_add(1, Ordering::SeqCst);
            })
            .build();

        let layer = config.layer();
        let mut service = layer.layer(service);

        // Check the outcome, not just the listener counters, so that a failed
        // request cannot pass this test silently.
        let response = call_once(&mut service).await.unwrap();

        assert_eq!(response, "success");
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
        // One retry event per failed attempt, one success event at the end.
        assert_eq!(retry_count.load(Ordering::SeqCst), 2);
        assert_eq!(success_count.load(Ordering::SeqCst), 1);
    }

    // The backoff tests only inspect the policy, so they need no async runtime.
    #[test]
    fn exponential_backoff_increases_delay() {
        let config: RetryConfig<TestError> = RetryConfig::builder()
            .max_attempts(5)
            .backoff(ExponentialBackoff::new(Duration::from_millis(100)))
            .build();

        assert_eq!(config.policy.next_backoff(0), Duration::from_millis(100));
        assert_eq!(config.policy.next_backoff(1), Duration::from_millis(200));
        assert_eq!(config.policy.next_backoff(2), Duration::from_millis(400));
    }

    #[test]
    fn custom_interval_function() {
        let config: RetryConfig<TestError> = RetryConfig::builder()
            .max_attempts(3)
            .backoff(FnInterval::new(|attempt| {
                Duration::from_secs((attempt + 1) as u64)
            }))
            .build();

        assert_eq!(config.policy.next_backoff(0), Duration::from_secs(1));
        assert_eq!(config.policy.next_backoff(1), Duration::from_secs(2));
        assert_eq!(config.policy.next_backoff(2), Duration::from_secs(3));
    }
}