recoco_utils/
retryable.rs1use std::{
14 future::Future,
15 time::{Duration, Instant},
16};
17use tracing::trace;
18
/// Classifies a failure as transient (worth attempting again) or permanent.
///
/// The retry loop in this module stops immediately on any error whose
/// `is_retryable` returns `false`.
pub trait IsRetryable {
    /// Returns `true` when the failed operation may succeed if it is attempted again.
    fn is_retryable(&self) -> bool;
}
22
23pub struct Error {
24 pub error: crate::error::Error,
25 pub is_retryable: bool,
26}
27
/// Overall retry budget of ten minutes, shared by `RetryOptions::default()` and
/// `HEAVY_LOADED_OPTIONS`.
pub const DEFAULT_RETRY_TIMEOUT: Duration = Duration::from_secs(600);
29
30impl std::fmt::Display for Error {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 std::fmt::Display::fmt(&self.error, f)
33 }
34}
35
36impl std::fmt::Debug for Error {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 std::fmt::Debug::fmt(&self.error, f)
39 }
40}
41
42impl IsRetryable for Error {
43 fn is_retryable(&self) -> bool {
44 self.is_retryable
45 }
46}
47
#[cfg(feature = "reqwest")]
impl IsRetryable for reqwest::Error {
    /// Retries only HTTP 429 (Too Many Requests). Errors that carry no HTTP status,
    /// and every other status code, are treated as permanent.
    fn is_retryable(&self) -> bool {
        let Some(status) = self.status() else {
            return false;
        };
        status == reqwest::StatusCode::TOO_MANY_REQUESTS
    }
}
54
#[cfg(feature = "openai")]
impl IsRetryable for async_openai::error::OpenAIError {
    /// Treats every OpenAI client error as permanent, regardless of variant.
    ///
    /// NOTE(review): this presumably also covers wrapped HTTP/transport failures, so the
    /// retry loop here never retries them — confirm this is intended (e.g. because the
    /// client already retries internally).
    fn is_retryable(&self) -> bool {
        // No variant is considered transient.
        const RETRYABLE: bool = false;
        RETRYABLE
    }
}
65
#[cfg(feature = "neo4rs")]
impl IsRetryable for neo4rs::Error {
    /// Retries connection failures and server errors whose kind is `Transient`;
    /// every other error is permanent.
    fn is_retryable(&self) -> bool {
        if let neo4rs::Error::Neo4j(inner) = self {
            return inner.kind() == neo4rs::Neo4jErrorKind::Transient;
        }
        matches!(self, neo4rs::Error::ConnectionError)
    }
}
77
78impl Error {
79 pub fn retryable<E: Into<crate::error::Error>>(error: E) -> Self {
80 Self {
81 error: error.into(),
82 is_retryable: true,
83 }
84 }
85
86 pub fn not_retryable<E: Into<crate::error::Error>>(error: E) -> Self {
87 Self {
88 error: error.into(),
89 is_retryable: false,
90 }
91 }
92}
93
94impl From<crate::error::Error> for Error {
95 fn from(error: crate::error::Error) -> Self {
96 Self {
97 error,
98 is_retryable: false,
99 }
100 }
101}
102
103impl From<Error> for crate::error::Error {
104 fn from(val: Error) -> Self {
105 val.error
106 }
107}
108
109impl<E: IsRetryable + std::error::Error + Send + Sync + 'static> From<E> for Error {
110 fn from(error: E) -> Self {
111 Self {
112 is_retryable: error.is_retryable(),
113 error: anyhow::Error::from(error).into(),
114 }
115 }
116}
117
118pub type Result<T, E = Error> = std::result::Result<T, E>;
119
120#[allow(non_snake_case)]
121pub fn Ok<T>(value: T) -> Result<T> {
122 Result::Ok(value)
123}
124
/// Tuning knobs for the `run` retry loop.
#[derive(Debug, Clone)]
pub struct RetryOptions {
    /// Overall time budget for retries, measured from when `run` starts.
    /// `None` means retryable errors are retried without any time limit.
    pub retry_timeout: Option<Duration>,
    /// Delay before the first retry; later delays grow from this value.
    pub initial_backoff: Duration,
    /// Upper bound on the delay between attempts.
    pub max_backoff: Duration,
}
130
131impl Default for RetryOptions {
132 fn default() -> Self {
133 Self {
134 retry_timeout: Some(DEFAULT_RETRY_TIMEOUT),
135 initial_backoff: Duration::from_millis(100),
136 max_backoff: Duration::from_secs(10),
137 }
138 }
139}
140
141pub static HEAVY_LOADED_OPTIONS: RetryOptions = RetryOptions {
142 retry_timeout: Some(DEFAULT_RETRY_TIMEOUT),
143 initial_backoff: Duration::from_secs(1),
144 max_backoff: Duration::from_secs(60),
145};
146
147pub async fn run<
148 Ok,
149 Err: std::fmt::Display + IsRetryable,
150 Fut: Future<Output = Result<Ok, Err>>,
151 F: Fn() -> Fut,
152>(
153 f: F,
154 options: &RetryOptions,
155) -> Result<Ok, Err> {
156 let deadline = options
157 .retry_timeout
158 .map(|timeout| Instant::now() + timeout);
159 let mut backoff = options.initial_backoff;
160
161 loop {
162 match f().await {
163 Result::Ok(result) => return Result::Ok(result),
164 Result::Err(err) => {
165 if !err.is_retryable() {
166 return Result::Err(err);
167 }
168 let mut sleep_duration = backoff;
169 if let Some(deadline) = deadline {
170 let now = Instant::now();
171 if now >= deadline {
172 return Result::Err(err);
173 }
174 let remaining_time = deadline.saturating_duration_since(now);
175 sleep_duration = std::cmp::min(sleep_duration, remaining_time);
176 }
177 trace!(
178 "Will retry in {}ms for error: {}",
179 sleep_duration.as_millis(),
180 err
181 );
182 tokio::time::sleep(sleep_duration).await;
183 if backoff < options.max_backoff {
184 backoff = std::cmp::min(
185 Duration::from_micros(
186 (backoff.as_micros() * rand::random_range(1618..=2000) / 1000) as u64,
187 ),
188 options.max_backoff,
189 );
190 }
191 }
192 }
193 }
194}