// reifydb_core/util/retry.rs
use std::{
    any::Any,
    fmt,
    panic::{AssertUnwindSafe, catch_unwind},
};

/// Failure reported by [`retry`] once every allowed attempt has failed.
///
/// The variant describes how the *final* attempt failed.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RetryError<E> {
    /// The operation returned this error value.
    Error(E),
    /// The operation panicked; holds the panic message, or `"Unknown panic"` when the
    /// panic payload was not a string.
    Panic(String),
}

18impl<E: fmt::Display> fmt::Display for RetryError<E> {
19 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
20 match self {
21 RetryError::Error(e) => write!(f, "{}", e),
22 RetryError::Panic(msg) => write!(f, "panic: {}", msg),
23 }
24 }
25}
26
27impl<E: fmt::Display + fmt::Debug> std::error::Error for RetryError<E> {}
28
29impl<E> From<E> for RetryError<E> {
30 fn from(err: E) -> Self {
31 RetryError::Error(err)
32 }
33}
34
35pub fn retry<R, E>(retries: usize, f: impl Fn() -> Result<R, E>) -> Result<R, RetryError<E>> {
36 let mut retries_left = retries;
37 loop {
38 match catch_unwind(AssertUnwindSafe(&f)) {
39 Ok(Ok(r)) => return Ok(r),
40 Ok(Err(err)) => {
41 if retries_left > 0 {
42 retries_left -= 1;
43 } else {
44 return Err(RetryError::Error(err));
45 }
46 }
47 Err(panic) => {
48 let msg = if let Some(s) = panic.downcast_ref::<String>() {
49 s.clone()
50 } else if let Some(s) = panic.downcast_ref::<&str>() {
51 s.to_string()
52 } else {
53 "Unknown panic".to_string()
54 };
55
56 if retries_left > 0 {
57 retries_left -= 1;
58 } else {
59 return Err(RetryError::Panic(msg));
60 }
61 }
62 }
63 }
64}
65
#[cfg(test)]
mod tests {
    use std::cell::Cell;

    use crate::util::{RetryError, retry};

    /// A first-try success is returned unchanged.
    #[test]
    fn test_ok() {
        let result = retry::<i32, ()>(10, || Ok(23));
        assert_eq!(result, Ok(23));
    }

    /// Stops calling as soon as an attempt succeeds.
    #[test]
    fn test_success_after_some_retries() {
        let counter = Cell::new(0);
        let result = retry::<i32, &'static str>(5, || {
            if counter.get() < 3 {
                counter.set(counter.get() + 1);
                Err("fail")
            } else {
                Ok(42)
            }
        });
        assert_eq!(result, Ok(42));
        assert_eq!(counter.get(), 3);
    }

    /// `retries` extra attempts means `retries + 1` calls in total.
    #[test]
    fn test_failure_after_retries_exhausted() {
        let counter = Cell::new(0);
        let result = retry::<i32, &'static str>(3, || {
            counter.set(counter.get() + 1);
            Err("still failing")
        });
        assert_eq!(result, Err(RetryError::Error("still failing")));
        assert_eq!(counter.get(), 4);
    }

    /// Zero retries still performs exactly one attempt.
    #[test]
    fn test_zero_retries_allowed() {
        let counter = Cell::new(0);
        let result = retry::<i32, &'static str>(0, || {
            counter.set(counter.get() + 1);
            Err("fail fast")
        });
        assert_eq!(result, Err(RetryError::Error("fail fast")));
        assert_eq!(counter.get(), 1);
    }

    /// Panics consume the retry budget just like errors.
    #[test]
    fn test_retry_catches_panic() {
        let counter = Cell::new(0);
        let result = retry::<(), &'static str>(2, || {
            counter.set(counter.get() + 1);
            panic!("boom");
        });
        assert_eq!(result, Err(RetryError::Panic("boom".to_string())));
        assert_eq!(counter.get(), 3);
    }

    /// A formatted panic (a `String` payload) keeps its message.
    #[test]
    fn test_retry_panic_with_string() {
        let result = retry::<(), &'static str>(1, || {
            panic!("{}", String::from("custom panic message"));
        });
        assert_eq!(result, Err(RetryError::Panic("custom panic message".to_string())));
    }

    /// A payload that is neither `&str` nor `String` falls back to a fixed message.
    #[test]
    fn test_retry_panic_with_unknown_payload() {
        let counter = Cell::new(0);
        let result = retry::<(), &'static str>(1, || {
            counter.set(counter.get() + 1);
            std::panic::panic_any(42_i32)
        });
        assert_eq!(result, Err(RetryError::Panic("Unknown panic".to_string())));
        assert_eq!(counter.get(), 2);
    }

    #[test]
    fn test_retry_panic_then_success() {
        let counter = Cell::new(0);
        let result = retry::<i32, &'static str>(3, || {
            let count = counter.get();
            counter.set(count + 1);
            if count < 2 {
                panic!("panic #{}", count);
            } else {
                Ok(42)
            }
        });
        assert_eq!(result, Ok(42));
        assert_eq!(counter.get(), 3);
    }

    #[test]
    fn test_retry_mixed_errors_and_panics() {
        let counter = Cell::new(0);
        let result = retry::<i32, &'static str>(5, || {
            let count = counter.get();
            counter.set(count + 1);
            match count {
                0 => Err("error 1"),
                1 => panic!("panic 1"),
                2 => Err("error 2"),
                3 => panic!("panic 2"),
                _ => Ok(100),
            }
        });
        assert_eq!(result, Ok(100));
        assert_eq!(counter.get(), 5);
    }

    #[test]
    fn test_retry_panic_no_retries() {
        let result = retry::<(), &'static str>(0, || {
            panic!("immediate panic");
        });
        assert_eq!(result, Err(RetryError::Panic("immediate panic".to_string())));
    }

    #[test]
    fn test_retry_error_display() {
        let err: RetryError<&str> = RetryError::Error("test error");
        assert_eq!(format!("{}", err), "test error");

        let panic: RetryError<&str> = RetryError::Panic("test panic".to_string());
        assert_eq!(format!("{}", panic), "panic: test panic");
    }

    /// `From` wraps plain errors, and the type works as a boxed `std::error::Error`.
    #[test]
    fn test_retry_error_from_and_std_error() {
        let converted: RetryError<&str> = "converted".into();
        assert_eq!(converted, RetryError::Error("converted"));

        let boxed: Box<dyn std::error::Error> =
            Box::new(RetryError::<&str>::Panic("p".to_string()));
        assert_eq!(boxed.to_string(), "panic: p");
    }
}