// reifydb_core/util/retry.rs

// Copyright (c) reifydb.com 2025
// This file is licensed under the AGPL-3.0-or-later, see license.md file

use std::{
	any::Any,
	fmt,
	panic::{AssertUnwindSafe, catch_unwind},
};

/// Failure returned by [`retry`] once every attempt has been used up.
///
/// Keeps apart the two ways an attempt can fail: the closure returned `Err`,
/// or the closure panicked and the panic was caught.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RetryError<E> {
	/// The closure returned `Err`; holds the error value it produced.
	Error(E),
	/// The closure panicked; holds the panic message recovered from the
	/// panic payload.
	Panic(String),
}

18impl<E: fmt::Display> fmt::Display for RetryError<E> {
19	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
20		match self {
21			RetryError::Error(e) => write!(f, "{}", e),
22			RetryError::Panic(msg) => write!(f, "panic: {}", msg),
23		}
24	}
25}
26
27impl<E: fmt::Display + fmt::Debug> std::error::Error for RetryError<E> {}
28
29impl<E> From<E> for RetryError<E> {
30	fn from(err: E) -> Self {
31		RetryError::Error(err)
32	}
33}
34
35pub fn retry<R, E>(retries: usize, f: impl Fn() -> Result<R, E>) -> Result<R, RetryError<E>> {
36	let mut retries_left = retries;
37	loop {
38		match catch_unwind(AssertUnwindSafe(&f)) {
39			Ok(Ok(r)) => return Ok(r),
40			Ok(Err(err)) => {
41				if retries_left > 0 {
42					retries_left -= 1;
43				} else {
44					return Err(RetryError::Error(err));
45				}
46			}
47			Err(panic) => {
48				let msg = if let Some(s) = panic.downcast_ref::<String>() {
49					s.clone()
50				} else if let Some(s) = panic.downcast_ref::<&str>() {
51					s.to_string()
52				} else {
53					"Unknown panic".to_string()
54				};
55
56				if retries_left > 0 {
57					retries_left -= 1;
58				} else {
59					return Err(RetryError::Panic(msg));
60				}
61			}
62		}
63	}
64}
65
#[cfg(test)]
mod tests {
	// Tests that panic on purpose still print the panic message to stderr:
	// `catch_unwind` runs the panic hook before the panic is caught.

	use std::cell::Cell;

	use super::{RetryError, retry};

	#[test]
	fn test_ok() {
		let result = retry::<i32, ()>(10, || Ok(23));
		assert_eq!(result, Ok(23));
	}

	#[test]
	fn test_success_after_some_retries() {
		let counter = Cell::new(0);
		let result = retry::<i32, &'static str>(5, || {
			if counter.get() < 3 {
				counter.set(counter.get() + 1);
				Err("fail")
			} else {
				Ok(42)
			}
		});
		assert_eq!(result, Ok(42));
		assert_eq!(counter.get(), 3);
	}

	#[test]
	fn test_failure_after_retries_exhausted() {
		let counter = Cell::new(0);
		let result = retry::<i32, &'static str>(3, || {
			counter.set(counter.get() + 1);
			Err("still failing")
		});
		assert_eq!(result, Err(RetryError::Error("still failing")));
		assert_eq!(counter.get(), 4); // initial + 3 retries
	}

	#[test]
	fn test_zero_retries_allowed() {
		let counter = Cell::new(0);
		let result = retry::<i32, &'static str>(0, || {
			counter.set(counter.get() + 1);
			Err("fail fast")
		});
		assert_eq!(result, Err(RetryError::Error("fail fast")));
		assert_eq!(counter.get(), 1); // only one try
	}

	#[test]
	fn test_retry_catches_panic() {
		let counter = Cell::new(0);
		let result = retry::<(), &'static str>(2, || {
			counter.set(counter.get() + 1);
			panic!("boom");
		});
		assert_eq!(result, Err(RetryError::Panic("boom".to_string())));
		assert_eq!(counter.get(), 3); // initial + 2 retries
	}

	#[test]
	fn test_retry_panic_with_string() {
		let result = retry::<(), &'static str>(1, || {
			panic!("{}", String::from("custom panic message"));
		});
		assert_eq!(result, Err(RetryError::Panic("custom panic message".to_string())));
	}

	#[test]
	fn test_retry_panic_with_non_string_payload() {
		// Payloads that are neither `String` nor `&str` fall back to a fixed message.
		let result = retry::<(), &'static str>(0, || std::panic::panic_any(42_i32));
		assert_eq!(result, Err(RetryError::Panic("Unknown panic".to_string())));
	}

	#[test]
	fn test_retry_panic_then_success() {
		let counter = Cell::new(0);
		let result = retry::<i32, &'static str>(3, || {
			let count = counter.get();
			counter.set(count + 1);
			if count < 2 {
				panic!("panic #{}", count);
			} else {
				Ok(42)
			}
		});
		assert_eq!(result, Ok(42));
		assert_eq!(counter.get(), 3);
	}

	#[test]
	fn test_retry_mixed_errors_and_panics() {
		let counter = Cell::new(0);
		let result = retry::<i32, &'static str>(5, || {
			let count = counter.get();
			counter.set(count + 1);
			match count {
				0 => Err("error 1"),
				1 => panic!("panic 1"),
				2 => Err("error 2"),
				3 => panic!("panic 2"),
				_ => Ok(100),
			}
		});
		assert_eq!(result, Ok(100));
		assert_eq!(counter.get(), 5);
	}

	#[test]
	fn test_retry_reports_last_failure_kind() {
		// An earlier `Err` is discarded; the final attempt's panic is what is returned.
		let counter = Cell::new(0);
		let result = retry::<i32, &'static str>(1, || {
			let count = counter.get();
			counter.set(count + 1);
			if count == 0 { Err("first") } else { panic!("last") }
		});
		assert_eq!(result, Err(RetryError::Panic("last".to_string())));
		assert_eq!(counter.get(), 2);
	}

	#[test]
	fn test_retry_panic_no_retries() {
		let result = retry::<(), &'static str>(0, || {
			panic!("immediate panic");
		});
		assert_eq!(result, Err(RetryError::Panic("immediate panic".to_string())));
	}

	#[test]
	fn test_retry_error_display() {
		let err: RetryError<&str> = RetryError::Error("test error");
		assert_eq!(format!("{}", err), "test error");

		let panic: RetryError<&str> = RetryError::Panic("test panic".to_string());
		assert_eq!(format!("{}", panic), "panic: test panic");
	}

	#[test]
	fn test_from_wraps_error_variant() {
		let err: RetryError<&str> = "oops".into();
		assert_eq!(err, RetryError::Error("oops"));
	}
}