recoco_utils/
retryable.rs

1// ReCoco is a Rust-only fork of CocoIndex, by [CocoIndex](https://CocoIndex)
2// Original code from CocoIndex is copyrighted by CocoIndex
3// SPDX-FileCopyrightText: 2025-2026 CocoIndex (upstream)
4// SPDX-FileContributor: CocoIndex Contributors
5//
6// All modifications from the upstream for ReCoco are copyrighted by Knitli Inc.
7// SPDX-FileCopyrightText: 2026 Knitli Inc. (ReCoco)
8// SPDX-FileContributor: Adam Poulemanos <adam@knit.li>
9//
10// Both the upstream CocoIndex code and the ReCoco modifications are licensed under the Apache-2.0 License.
11// SPDX-License-Identifier: Apache-2.0
12
13use std::{
14    future::Future,
15    time::{Duration, Instant},
16};
17use tracing::trace;
18
/// Classifies an error as transient (worth another attempt) or permanent.
///
/// [`run`] consults this after every failed attempt to decide whether to back off and
/// retry or to return the error immediately.
pub trait IsRetryable {
    /// Returns `true` when repeating the failed operation might succeed.
    fn is_retryable(&self) -> bool;
}
22
23pub struct Error {
24    pub error: crate::error::Error,
25    pub is_retryable: bool,
26}
27
/// Default overall retry budget: ten minutes (600 seconds).
pub const DEFAULT_RETRY_TIMEOUT: Duration = Duration::from_secs(600);
29
30impl std::fmt::Display for Error {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        std::fmt::Display::fmt(&self.error, f)
33    }
34}
35
36impl std::fmt::Debug for Error {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        std::fmt::Debug::fmt(&self.error, f)
39    }
40}
41
42impl IsRetryable for Error {
43    fn is_retryable(&self) -> bool {
44        self.is_retryable
45    }
46}
47
/// A `reqwest` error counts as retryable only when the server answered with
/// `429 Too Many Requests`; every other failure is treated as permanent.
#[cfg(feature = "reqwest")]
impl IsRetryable for reqwest::Error {
    fn is_retryable(&self) -> bool {
        self.status()
            .is_some_and(|status| status == reqwest::StatusCode::TOO_MANY_REQUESTS)
    }
}
54
// OpenAI errors are currently never treated as retryable.
//
// `OpenAIError`'s variants depend on which `async_openai` features are compiled in, so this
// impl does not inspect them (e.g. a wrapped transport error, or an `ApiError` whose code
// signals rate limiting). Callers that want retries for OpenAI calls must opt in
// explicitly, e.g. by converting the failure with `Error::retryable(..)`.
#[cfg(feature = "openai")]
impl IsRetryable for async_openai::error::OpenAIError {
    fn is_retryable(&self) -> bool {
        // TODO(review): classify rate-limit / transient variants once the supported
        // `async_openai` feature set is pinned down.
        false
    }
}
64}
65
// Neo4j errors: connection failures and server errors of the `Transient` kind are retried;
// everything else is permanent.
#[cfg(feature = "neo4rs")]
impl IsRetryable for neo4rs::Error {
    fn is_retryable(&self) -> bool {
        if let neo4rs::Error::Neo4j(server_err) = self {
            return server_err.kind() == neo4rs::Neo4jErrorKind::Transient;
        }
        matches!(self, neo4rs::Error::ConnectionError)
    }
}
77
78impl Error {
79    pub fn retryable<E: Into<crate::error::Error>>(error: E) -> Self {
80        Self {
81            error: error.into(),
82            is_retryable: true,
83        }
84    }
85
86    pub fn not_retryable<E: Into<crate::error::Error>>(error: E) -> Self {
87        Self {
88            error: error.into(),
89            is_retryable: false,
90        }
91    }
92}
93
94impl From<crate::error::Error> for Error {
95    fn from(error: crate::error::Error) -> Self {
96        Self {
97            error,
98            is_retryable: false,
99        }
100    }
101}
102
103impl From<Error> for crate::error::Error {
104    fn from(val: Error) -> Self {
105        val.error
106    }
107}
108
109impl<E: IsRetryable + std::error::Error + Send + Sync + 'static> From<E> for Error {
110    fn from(error: E) -> Self {
111        Self {
112            is_retryable: error.is_retryable(),
113            error: anyhow::Error::from(error).into(),
114        }
115    }
116}
117
118pub type Result<T, E = Error> = std::result::Result<T, E>;
119
120#[allow(non_snake_case)]
121pub fn Ok<T>(value: T) -> Result<T> {
122    Result::Ok(value)
123}
124
/// Backoff configuration for [`run`].
///
/// Derives `Debug`/`Clone`/`PartialEq`/`Eq` so callers can log, copy and compare
/// configurations (public types should be `Debug`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryOptions {
    /// Overall time budget for retries, measured from the start of [`run`].
    /// `None` means retryable errors are retried indefinitely.
    pub retry_timeout: Option<Duration>,
    /// Sleep before the first retry; later sleeps grow from this value.
    pub initial_backoff: Duration,
    /// Upper bound on the growth of the sleep between attempts.
    pub max_backoff: Duration,
}
130
131impl Default for RetryOptions {
132    fn default() -> Self {
133        Self {
134            retry_timeout: Some(DEFAULT_RETRY_TIMEOUT),
135            initial_backoff: Duration::from_millis(100),
136            max_backoff: Duration::from_secs(10),
137        }
138    }
139}
140
141pub static HEAVY_LOADED_OPTIONS: RetryOptions = RetryOptions {
142    retry_timeout: Some(DEFAULT_RETRY_TIMEOUT),
143    initial_backoff: Duration::from_secs(1),
144    max_backoff: Duration::from_secs(60),
145};
146
147pub async fn run<
148    Ok,
149    Err: std::fmt::Display + IsRetryable,
150    Fut: Future<Output = Result<Ok, Err>>,
151    F: Fn() -> Fut,
152>(
153    f: F,
154    options: &RetryOptions,
155) -> Result<Ok, Err> {
156    let deadline = options
157        .retry_timeout
158        .map(|timeout| Instant::now() + timeout);
159    let mut backoff = options.initial_backoff;
160
161    loop {
162        match f().await {
163            Result::Ok(result) => return Result::Ok(result),
164            Result::Err(err) => {
165                if !err.is_retryable() {
166                    return Result::Err(err);
167                }
168                let mut sleep_duration = backoff;
169                if let Some(deadline) = deadline {
170                    let now = Instant::now();
171                    if now >= deadline {
172                        return Result::Err(err);
173                    }
174                    let remaining_time = deadline.saturating_duration_since(now);
175                    sleep_duration = std::cmp::min(sleep_duration, remaining_time);
176                }
177                trace!(
178                    "Will retry in {}ms for error: {}",
179                    sleep_duration.as_millis(),
180                    err
181                );
182                tokio::time::sleep(sleep_duration).await;
183                if backoff < options.max_backoff {
184                    backoff = std::cmp::min(
185                        Duration::from_micros(
186                            (backoff.as_micros() * rand::random_range(1618..=2000) / 1000) as u64,
187                        ),
188                        options.max_backoff,
189                    );
190                }
191            }
192        }
193    }
194}