use std::time::Duration;

use tokio_retry::{
    Retry, RetryIf,
    strategy::{ExponentialBackoff, jitter},
};
use tracing::warn;

use crate::{error::ProviderError, types::GenerateRequest};
/// Builds the delay schedule consumed by [`with_retry`].
///
/// Yields at most three delays (one per retry). Each delay comes from an
/// `ExponentialBackoff` capped at four seconds and is then passed through
/// `jitter`.
///
/// NOTE(review): per the tokio-retry docs, `from_millis` sets the exponential
/// *base* (delay = base^n * factor ms), not the first delay. The pre-jitter
/// schedule here is therefore presumably 1s, 4s (capped), 4s rather than
/// 250ms, 1s, 4s. Confirm this is the intended backoff.
fn retry_strategy() -> impl Iterator<Item = Duration> {
    const MAX_RETRIES: usize = 3;
    let ceiling = Duration::from_secs(4);

    let backoff = ExponentialBackoff::from_millis(250)
        .factor(4)
        .max_delay(ceiling);

    backoff.map(jitter).take(MAX_RETRIES)
}
pub async fn with_retry<F, Fut>(
provider_id: &'static str,
req: &GenerateRequest,
f: F,
) -> Result<crate::types::GenerateResponse, ProviderError>
where
F: Fn() -> Fut,
Fut: std::future::Future<Output = Result<crate::types::GenerateResponse, ProviderError>>,
{
let strategy = retry_strategy();
Retry::spawn(strategy, || {
let fut = f();
async move {
match fut.await {
Ok(resp) => Ok(resp),
Err(e) if e.is_retryable() => {
warn!(
provider = provider_id,
model = req.model,
error = %e,
"transient provider error - will retry"
);
Err(e)
}
Err(e) => {
Err(e)
}
}
}
})
.await
}