use std::time::Duration;
use lellm_core::ToolResult;
use super::event::AgentEvent;
use super::tools::ToolFn;
use tokio::sync::mpsc::Sender;
/// Strategy used to pick the wait time before a retry.
#[derive(Debug, Clone)]
pub enum BackoffStrategy {
    /// Always wait the same duration, regardless of the attempt number.
    Fixed(Duration),
    /// Wait `base * 2^attempt`, never exceeding `max`.
    Exponential { base: Duration, max: Duration },
}

impl BackoffStrategy {
    /// Returns the delay associated with the given `attempt` number.
    ///
    /// * `Fixed(d)` yields `d` for every attempt.
    /// * `Exponential { base, max }` yields `min(base * 2^attempt, max)`.
    ///
    /// The exponential computation saturates instead of overflowing, so any
    /// `attempt` value (including `u32::MAX`) is safe to pass.
    pub fn delay(&self, attempt: u32) -> Duration {
        match self {
            BackoffStrategy::Fixed(d) => *d,
            BackoffStrategy::Exponential { base, max } => {
                // A zero base never grows; returning early also keeps the
                // doubling loop below from spinning `attempt` times.
                if base.is_zero() {
                    return Duration::ZERO;
                }
                // Double step by step rather than computing `2_u32.pow(attempt)`,
                // which overflows for `attempt >= 32` (panic in debug builds,
                // wrap-around to a zero factor in release builds). The loop
                // stops as soon as the cap is reached, so it runs at most a
                // few dozen iterations.
                let mut delay = *base;
                for _ in 0..attempt {
                    if delay >= *max {
                        break;
                    }
                    delay = delay.saturating_mul(2);
                }
                delay.min(*max)
            }
        }
    }
}
/// Retry configuration: how many times a tool call may be made and how long
/// to wait between calls.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
/// Total number of calls allowed, counting the initial one. The retry loops
/// iterate over `1..max_attempts`, so `0` and `1` both result in a single call.
max_attempts: u32,
/// Strategy that supplies the sleep duration before each retry.
backoff: BackoffStrategy,
}
impl Default for RetryPolicy {
    /// Builds the default policy: 3 attempts with exponential backoff using a
    /// 500 ms base and a 30 s cap.
    fn default() -> Self {
        let backoff = BackoffStrategy::Exponential {
            base: Duration::from_millis(500),
            max: Duration::from_secs(30),
        };
        Self {
            max_attempts: 3,
            backoff,
        }
    }
}
impl RetryPolicy {
pub fn new(max_attempts: u32, backoff: BackoffStrategy) -> Self {
Self {
max_attempts,
backoff,
}
}
pub async fn execute_with_retry(
&self,
tool_fn: &ToolFn,
args: &serde_json::Value,
) -> ToolResult {
let mut last_result = tool_fn(args).await;
if last_result.is_ok() {
return last_result;
}
for attempt in 1..self.max_attempts {
match &last_result {
Err(e) if e.kind.is_retryable() => {}
_ => return last_result,
}
let delay = self.backoff.delay(attempt);
tracing::warn!(
attempt,
max = self.max_attempts,
delay_ms = delay.as_millis(),
"tool execution failed, retrying"
);
tokio::time::sleep(delay).await;
last_result = tool_fn(args).await;
if last_result.is_ok() {
return last_result;
}
}
last_result
}
pub async fn execute_with_retry_and_emission(
&self,
tool_fn: &ToolFn,
args: &serde_json::Value,
tx: &Sender<AgentEvent>,
tool_call_id: &str,
) -> ToolResult {
let mut last_result = tool_fn(args).await;
if last_result.is_ok() {
return last_result;
}
for attempt in 1..self.max_attempts {
let reason = match &last_result {
Err(e) if e.kind.is_retryable() => format!("[{}] {}", e.kind, e.message),
_ => return last_result,
};
let _ = tx
.send(AgentEvent::Retry {
tool_call_id: tool_call_id.to_string(),
attempt: (attempt + 1) as usize,
max_attempts: self.max_attempts as usize,
reason: reason.clone(),
})
.await;
let delay = self.backoff.delay(attempt);
tracing::warn!(
attempt,
max = self.max_attempts,
delay_ms = delay.as_millis(),
"tool execution failed, retrying"
);
tokio::time::sleep(delay).await;
last_result = tool_fn(args).await;
if last_result.is_ok() {
return last_result;
}
}
last_result
}
}