use crate::error::AgentError;
use oxi_ai::{Context, Model, ProviderEvent, StreamOptions};
use std::time::Duration;
/// Number of retries `stream_with_retry_core` performs after its first attempt,
/// i.e. up to `MAX_RETRIES + 1` calls to the provider in total.
pub const MAX_RETRIES: usize = 3;
/// Base of the exponential backoff used by `stream_with_retry_core`: retry `k`
/// (1-based) waits `BACKOFF_BASE_SECS.pow(k)` seconds before any `max_delay` cap.
pub const BACKOFF_BASE_SECS: u64 = 2;
/// Observer notified when [`stream_with_retry_core`] is about to back off and
/// retry a failed stream request.
// NOTE(review): the trait requires `Send` but not `Sync`, so `&dyn RetryCallback`
// is itself not `Send`. `stream_with_retry_core` holds such a reference across
// `.await`, which makes its future `!Send`. Adding `Sync` as a supertrait would
// lift that restriction but breaks non-`Sync` implementors — confirm before changing.
pub trait RetryCallback: Send {
    /// Called just before the retry delay is slept.
    ///
    /// As invoked by `stream_with_retry_core`:
    /// * `attempt` — 1-based number of the upcoming retry.
    /// * `max_retries` — the total number of retries allowed (`MAX_RETRIES`).
    /// * `delay_secs` — the number of seconds that will be slept before retrying.
    /// * `reason` — the `Display` text of the error that triggered the retry.
    fn on_retry(&self, attempt: usize, max_retries: usize, delay_secs: u64, reason: String);
}
#[allow(clippy::too_many_arguments)]
pub async fn stream_with_retry_core(
provider: &dyn oxi_ai::Provider,
model: &Model,
context: &Context,
options: Option<StreamOptions>,
retry_cb: &dyn RetryCallback,
max_delay: Option<u64>,
on_success: impl Fn(),
on_failure: impl Fn(),
) -> Result<futures::stream::BoxStream<'static, ProviderEvent>, AgentError> {
let mut last_err: Option<String> = None;
for attempt in 0..=MAX_RETRIES {
match provider.stream(model, context, options.clone()).await {
Ok(stream) => {
on_success();
return Ok(stream as futures::stream::BoxStream<'static, ProviderEvent>);
}
Err(e) => {
on_failure();
let msg = e.to_string();
let is_rate_limit = matches!(e, oxi_ai::ProviderError::HttpError(429, _));
if !is_rate_limit && attempt == 0 {
return Err(AgentError::Stream(msg));
}
last_err = Some(msg.clone());
if attempt < MAX_RETRIES {
let mut delay = BACKOFF_BASE_SECS.pow(attempt as u32 + 1);
if let Some(cap) = max_delay {
delay = delay.min(cap);
}
retry_cb.on_retry(attempt + 1, MAX_RETRIES, delay, msg);
tokio::time::sleep(Duration::from_secs(delay)).await;
}
}
}
}
Err(AgentError::RetriesExhausted {
attempts: MAX_RETRIES,
last_error: last_err.unwrap_or_default(),
})
}