Skip to main content

reifydb_core/util/
retry.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::{
5	error, fmt,
6	panic::{AssertUnwindSafe, catch_unwind},
7};
8
/// Failure reported by [`retry`] once the closure has failed and no retries remain.
///
/// Only the outcome of the final attempt is kept; earlier failures are discarded.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RetryError<E> {
	/// The last attempt returned `Err`; this is the error value it returned.
	Error(E),
	/// The last attempt panicked; this is the message recovered from the panic payload.
	Panic(String),
}
15
16impl<E: fmt::Display> fmt::Display for RetryError<E> {
17	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
18		match self {
19			RetryError::Error(e) => write!(f, "{}", e),
20			RetryError::Panic(msg) => write!(f, "panic: {}", msg),
21		}
22	}
23}
24
25impl<E: fmt::Display + fmt::Debug> error::Error for RetryError<E> {}
26
27impl<E> From<E> for RetryError<E> {
28	fn from(err: E) -> Self {
29		RetryError::Error(err)
30	}
31}
32
33pub fn retry<R, E>(retries: usize, f: impl Fn() -> Result<R, E>) -> Result<R, RetryError<E>> {
34	let mut retries_left = retries;
35	loop {
36		match catch_unwind(AssertUnwindSafe(&f)) {
37			Ok(Ok(r)) => return Ok(r),
38			Ok(Err(err)) => {
39				if retries_left > 0 {
40					retries_left -= 1;
41				} else {
42					return Err(RetryError::Error(err));
43				}
44			}
45			Err(panic) => {
46				let msg = if let Some(s) = panic.downcast_ref::<String>() {
47					s.clone()
48				} else if let Some(s) = panic.downcast_ref::<&str>() {
49					s.to_string()
50				} else {
51					"Unknown panic".to_string()
52				};
53
54				if retries_left > 0 {
55					retries_left -= 1;
56				} else {
57					return Err(RetryError::Panic(msg));
58				}
59			}
60		}
61	}
62}
63
#[cfg(test)]
pub mod tests {
	//! Behavioural tests for `retry` (attempt counting, error vs. panic reporting, panic
	//! payload decoding) and for the `Display` output of `RetryError`.

	use std::cell::Cell;

	use crate::util::retry::{RetryError, retry};

	#[test]
	fn test_ok() {
		let result = retry::<i32, ()>(10, || Ok(23));
		assert_eq!(result, Ok(23));
	}

	#[test]
	fn test_success_after_some_retries() {
		let counter = Cell::new(0);
		let result = retry::<i32, &'static str>(5, || {
			if counter.get() < 3 {
				counter.set(counter.get() + 1);
				Err("fail")
			} else {
				Ok(42)
			}
		});
		assert_eq!(result, Ok(42));
		assert_eq!(counter.get(), 3);
	}

	#[test]
	fn test_failure_after_retries_exhausted() {
		let counter = Cell::new(0);
		let result = retry::<i32, &'static str>(3, || {
			counter.set(counter.get() + 1);
			Err("still failing")
		});
		assert_eq!(result, Err(RetryError::Error("still failing")));
		assert_eq!(counter.get(), 4); // initial + 3 retries
	}

	#[test]
	fn test_zero_retries_allowed() {
		let counter = Cell::new(0);
		let result = retry::<i32, &'static str>(0, || {
			counter.set(counter.get() + 1);
			Err("fail fast")
		});
		assert_eq!(result, Err(RetryError::Error("fail fast")));
		assert_eq!(counter.get(), 1); // only one try
	}

	#[test]
	fn test_retry_catches_panic() {
		let counter = Cell::new(0);
		let result = retry::<(), &'static str>(2, || {
			counter.set(counter.get() + 1);
			panic!("boom");
		});
		assert_eq!(result, Err(RetryError::Panic("boom".to_string())));
		assert_eq!(counter.get(), 3); // initial + 2 retries
	}

	#[test]
	fn test_retry_panic_with_string() {
		// A formatted `panic!` produces a `String` payload rather than a `&str` literal.
		let result = retry::<(), &'static str>(1, || {
			panic!("{}", String::from("custom panic message"));
		});
		assert_eq!(result, Err(RetryError::Panic("custom panic message".to_string())));
	}

	#[test]
	fn test_retry_panic_with_non_string_payload() {
		// Payloads that are neither `String` nor `&str` fall back to a fixed message.
		let result = retry::<(), &'static str>(0, || std::panic::panic_any(42_i32));
		assert_eq!(result, Err(RetryError::Panic("Unknown panic".to_string())));
	}

	#[test]
	fn test_retry_panic_then_success() {
		let counter = Cell::new(0);
		let result = retry::<i32, &'static str>(3, || {
			let count = counter.get();
			counter.set(count + 1);
			if count < 2 {
				panic!("panic #{}", count);
			} else {
				Ok(42)
			}
		});
		assert_eq!(result, Ok(42));
		assert_eq!(counter.get(), 3);
	}

	#[test]
	fn test_retry_mixed_errors_and_panics() {
		let counter = Cell::new(0);
		let result = retry::<i32, &'static str>(5, || {
			let count = counter.get();
			counter.set(count + 1);
			match count {
				0 => Err("error 1"),
				1 => panic!("panic 1"),
				2 => Err("error 2"),
				3 => panic!("panic 2"),
				_ => Ok(100),
			}
		});
		assert_eq!(result, Ok(100));
		assert_eq!(counter.get(), 5);
	}

	#[test]
	fn test_retry_reports_last_failure() {
		// Only the final attempt's failure is reported: an earlier `Err` is discarded
		// when the last attempt panics.
		let counter = Cell::new(0);
		let result = retry::<(), &'static str>(1, || {
			let count = counter.get();
			counter.set(count + 1);
			if count == 0 { Err("first") } else { panic!("second") }
		});
		assert_eq!(result, Err(RetryError::Panic("second".to_string())));
		assert_eq!(counter.get(), 2);
	}

	#[test]
	fn test_retry_panic_no_retries() {
		let result = retry::<(), &'static str>(0, || {
			panic!("immediate panic");
		});
		assert_eq!(result, Err(RetryError::Panic("immediate panic".to_string())));
	}

	#[test]
	fn test_retry_error_display() {
		let err: RetryError<&str> = RetryError::Error("test error");
		assert_eq!(format!("{}", err), "test error");

		let panic: RetryError<&str> = RetryError::Panic("test panic".to_string());
		assert_eq!(format!("{}", panic), "panic: test panic");
	}
}