// reifydb_core/util/retry.rs
use std::{
	any::Any,
	error, fmt,
	panic::{AssertUnwindSafe, catch_unwind},
};

/// Failure outcome of [`retry`] once every allowed attempt has been used up.
///
/// Only the outcome of the *last* attempt is kept; earlier failures are
/// discarded.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RetryError<E> {
	/// The final attempt returned `Err(E)`.
	Error(E),

	/// The final attempt panicked. Holds the panic message, or
	/// `"Unknown panic"` when the payload was neither a `String` nor a `&str`.
	Panic(String),
}

impl<E: fmt::Display> fmt::Display for RetryError<E> {
	/// Renders errors with their own `Display` text and panics as `panic: <message>`.
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		match self {
			RetryError::Error(e) => write!(f, "{}", e),
			RetryError::Panic(msg) => write!(f, "panic: {}", msg),
		}
	}
}

impl<E: fmt::Display + fmt::Debug> error::Error for RetryError<E> {}

/// Wraps a plain error as [`RetryError::Error`], so `?` can lift `E` into `RetryError<E>`.
impl<E> From<E> for RetryError<E> {
	fn from(err: E) -> Self {
		RetryError::Error(err)
	}
}

/// Calls `f` until it returns `Ok`, allowing up to `retries` additional
/// attempts after the first one (at most `retries + 1` calls in total).
///
/// Both an `Err` return and a panic inside `f` count as a failed attempt.
/// Attempts are made back-to-back with no delay. `f` may be any `FnMut`,
/// so it can carry state from one attempt to the next.
///
/// # Errors
///
/// When every attempt has failed, returns the failure of the last attempt:
/// [`RetryError::Error`] if it returned `Err`, or [`RetryError::Panic`] if it
/// panicked.
///
/// # Panic handling
///
/// Panics are caught with [`catch_unwind`], which only works when the binary
/// is built with `panic = "unwind"`; with `panic = "abort"` the process still
/// terminates on the first panic. The panic hook runs for every caught panic
/// (the default hook prints the message to stderr). `f` is wrapped in
/// [`AssertUnwindSafe`]: if it panics halfway through mutating state, the next
/// attempt observes that partially updated state.
pub fn retry<R, E>(
	retries: usize,
	mut f: impl FnMut() -> Result<R, E>,
) -> Result<R, RetryError<E>> {
	let mut retries_left = retries;
	loop {
		match catch_unwind(AssertUnwindSafe(&mut f)) {
			Ok(Ok(value)) => return Ok(value),
			// While budget remains, any failure is dropped and we try again;
			// only the last failure is ever reported, so nothing is extracted here.
			Ok(Err(_)) | Err(_) if retries_left > 0 => retries_left -= 1,
			Ok(Err(err)) => return Err(RetryError::Error(err)),
			Err(payload) => return Err(RetryError::Panic(panic_message(&*payload))),
		}
	}
}

/// Turns a panic payload into a message. `panic!("literal")` yields a `&str`
/// payload and formatted panics yield a `String`; anything else (for example
/// `panic_any(42)`) becomes `"Unknown panic"`.
fn panic_message(payload: &(dyn Any + Send)) -> String {
	if let Some(s) = payload.downcast_ref::<String>() {
		s.clone()
	} else if let Some(s) = payload.downcast_ref::<&str>() {
		(*s).to_string()
	} else {
		"Unknown panic".to_string()
	}
}
#[cfg(test)]
pub mod tests {
	use std::{cell::Cell, panic::panic_any};

	use crate::util::retry::{RetryError, retry};

	#[test]
	fn test_ok() {
		let result = retry::<i32, ()>(10, || Ok(23));
		assert_eq!(result, Ok(23));
	}

	#[test]
	fn test_success_after_some_retries() {
		let counter = Cell::new(0);
		let result = retry::<i32, &'static str>(5, || {
			if counter.get() < 3 {
				counter.set(counter.get() + 1);
				Err("fail")
			} else {
				Ok(42)
			}
		});
		assert_eq!(result, Ok(42));
		assert_eq!(counter.get(), 3);
	}

	#[test]
	fn test_failure_after_retries_exhausted() {
		let counter = Cell::new(0);
		let result = retry::<i32, &'static str>(3, || {
			counter.set(counter.get() + 1);
			Err("still failing")
		});
		assert_eq!(result, Err(RetryError::Error("still failing")));
		// One initial attempt plus three retries.
		assert_eq!(counter.get(), 4);
	}

	#[test]
	fn test_zero_retries_allowed() {
		let counter = Cell::new(0);
		let result = retry::<i32, &'static str>(0, || {
			counter.set(counter.get() + 1);
			Err("fail fast")
		});
		assert_eq!(result, Err(RetryError::Error("fail fast")));
		// Zero retries still means exactly one attempt.
		assert_eq!(counter.get(), 1);
	}

	#[test]
	fn test_retry_catches_panic() {
		let counter = Cell::new(0);
		let result = retry::<(), &'static str>(2, || {
			counter.set(counter.get() + 1);
			panic!("boom");
		});
		assert_eq!(result, Err(RetryError::Panic("boom".to_string())));
		// One initial attempt plus two retries, each of which panicked.
		assert_eq!(counter.get(), 3);
	}

	#[test]
	fn test_retry_panic_with_string() {
		// A formatted panic carries a `String` payload rather than a `&str`.
		let result = retry::<(), &'static str>(1, || {
			panic!("{}", String::from("custom panic message"));
		});
		assert_eq!(result, Err(RetryError::Panic("custom panic message".to_string())));
	}

	#[test]
	fn test_retry_panic_with_non_string_payload() {
		// Payloads that are neither `String` nor `&str` fall back to a fixed message.
		let result = retry::<(), &'static str>(0, || {
			panic_any(42_i32);
		});
		assert_eq!(result, Err(RetryError::Panic("Unknown panic".to_string())));
	}

	#[test]
	fn test_retry_panic_then_success() {
		let counter = Cell::new(0);
		let result = retry::<i32, &'static str>(3, || {
			let count = counter.get();
			counter.set(count + 1);
			if count < 2 {
				panic!("panic #{}", count);
			} else {
				Ok(42)
			}
		});
		assert_eq!(result, Ok(42));
		assert_eq!(counter.get(), 3);
	}

	#[test]
	fn test_retry_mixed_errors_and_panics() {
		let counter = Cell::new(0);
		let result = retry::<i32, &'static str>(5, || {
			let count = counter.get();
			counter.set(count + 1);
			match count {
				0 => Err("error 1"),
				1 => panic!("panic 1"),
				2 => Err("error 2"),
				3 => panic!("panic 2"),
				_ => Ok(100),
			}
		});
		assert_eq!(result, Ok(100));
		assert_eq!(counter.get(), 5);
	}

	#[test]
	fn test_retry_panic_no_retries() {
		let result = retry::<(), &'static str>(0, || {
			panic!("immediate panic");
		});
		assert_eq!(result, Err(RetryError::Panic("immediate panic".to_string())));
	}

	#[test]
	fn test_retry_error_display() {
		let err: RetryError<&str> = RetryError::Error("test error");
		assert_eq!(format!("{}", err), "test error");

		let panic: RetryError<&str> = RetryError::Panic("test panic".to_string());
		assert_eq!(format!("{}", panic), "panic: test panic");
	}

	#[test]
	fn test_retry_error_from_and_error_trait() {
		let converted: RetryError<&str> = RetryError::from("wrapped");
		assert_eq!(converted, RetryError::Error("wrapped"));

		let boxed: Box<dyn std::error::Error> = Box::new(RetryError::<&str>::Panic("p".to_string()));
		assert_eq!(boxed.to_string(), "panic: p");
	}
}