use std::time::Duration;
use backon::{ExponentialBuilder, Retryable};
use crate::error::NikaError;
/// Default for [`McpRetryConfig::max_retries`]: how many times a failed call is
/// retried after the initial attempt.
pub const DEFAULT_MAX_RETRIES: usize = 3;
/// Default for [`McpRetryConfig::initial_delay`]: the minimum backoff delay.
pub const DEFAULT_INITIAL_DELAY: Duration = Duration::from_millis(100);
/// Default for [`McpRetryConfig::max_delay`]: the maximum backoff delay.
pub const DEFAULT_MAX_DELAY: Duration = Duration::from_secs(5);
/// Retry policy applied to MCP calls by [`retry_mcp_call`].
///
/// The values are handed to an exponential backoff builder: `initial_delay`
/// becomes the minimum delay, `max_delay` the maximum delay and `max_retries`
/// the maximum number of retries.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct McpRetryConfig {
    /// Maximum number of retries performed after the initial attempt.
    pub max_retries: usize,
    /// Minimum delay of the backoff schedule.
    pub initial_delay: Duration,
    /// Maximum delay of the backoff schedule.
    pub max_delay: Duration,
    /// Whether jitter is enabled on the backoff delays.
    pub jitter: bool,
}
impl Default for McpRetryConfig {
    /// Builds the stock policy from [`DEFAULT_MAX_RETRIES`],
    /// [`DEFAULT_INITIAL_DELAY`] and [`DEFAULT_MAX_DELAY`], with jitter enabled.
    fn default() -> Self {
        // `new` always enables jitter, which is exactly the default we want.
        Self::new(DEFAULT_MAX_RETRIES, DEFAULT_INITIAL_DELAY, DEFAULT_MAX_DELAY)
    }
}
impl McpRetryConfig {
    /// Creates a retry policy with explicit limits and jitter enabled.
    ///
    /// * `max_retries` - maximum number of retries after the initial attempt.
    /// * `initial_delay` - minimum delay of the backoff schedule.
    /// * `max_delay` - maximum delay of the backoff schedule.
    ///
    /// The arguments are stored as given; no validation is performed here.
    #[must_use]
    pub fn new(max_retries: usize, initial_delay: Duration, max_delay: Duration) -> Self {
        Self {
            max_retries,
            initial_delay,
            max_delay,
            jitter: true,
        }
    }

    /// Consumes the configuration and returns it with jitter disabled.
    ///
    /// All other settings are left untouched.
    #[must_use = "`without_jitter` returns the updated config instead of modifying it in place"]
    pub fn without_jitter(mut self) -> Self {
        self.jitter = false;
        self
    }
}
pub fn is_retryable_mcp_error(error: &NikaError) -> bool {
match error {
NikaError::McpTimeout { .. }
| NikaError::McpNotConnected { .. }
| NikaError::McpToolCallFailed { .. }
| NikaError::McpProtocolError { .. } => true,
NikaError::McpToolError { error_code, .. } => {
error_code.as_ref().is_none_or(|code| code.is_retryable())
}
NikaError::McpStartError { .. } => false,
_ => false,
}
}
/// Runs `operation`, retrying it with exponential backoff according to `config`.
///
/// `operation` is invoked again only when the error it produced satisfies
/// [`is_retryable_mcp_error`]; any other error is returned without a retry.
/// The backoff uses a factor of 2 with `config.initial_delay` as the minimum
/// delay and `config.max_delay` as the maximum delay, allows at most
/// `config.max_retries` retries, and adds jitter when `config.jitter` is set.
///
/// # Errors
///
/// Returns the `NikaError` produced by `operation` when it is not retryable or
/// when the retry budget has been used up.
pub async fn retry_mcp_call<F, Fut, T>(config: McpRetryConfig, operation: F) -> Result<T, NikaError>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<T, NikaError>>,
{
    // The setters are independent of one another, so their order is irrelevant.
    let backoff = ExponentialBuilder::default()
        .with_factor(2.0)
        .with_max_times(config.max_retries)
        .with_min_delay(config.initial_delay)
        .with_max_delay(config.max_delay);

    let backoff = if config.jitter {
        backoff.with_jitter()
    } else {
        backoff
    };

    operation.retry(backoff).when(is_retryable_mcp_error).await
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Builds the `McpTimeout` error used wherever a retryable failure is needed.
    fn timeout_error(operation: &str) -> NikaError {
        NikaError::McpTimeout {
            name: String::from("test"),
            operation: String::from(operation),
            timeout_secs: 30,
        }
    }

    /// Jitter-free policy with millisecond delays so the async tests stay fast.
    fn fast_config(max_retries: usize) -> McpRetryConfig {
        McpRetryConfig::new(
            max_retries,
            Duration::from_millis(1),
            Duration::from_millis(10),
        )
        .without_jitter()
    }

    // --- error classification: retryable variants ---

    #[test]
    fn test_is_retryable_mcp_error_timeout() {
        assert!(is_retryable_mcp_error(&timeout_error("call_tool")));
    }

    #[test]
    fn test_is_retryable_mcp_error_not_connected() {
        assert!(is_retryable_mcp_error(&NikaError::McpNotConnected {
            name: String::from("test"),
        }));
    }

    #[test]
    fn test_is_retryable_mcp_error_tool_error() {
        // A tool error without an error code is treated as retryable.
        assert!(is_retryable_mcp_error(&NikaError::McpToolError {
            tool: String::from("test"),
            reason: String::from("transient failure"),
            error_code: None,
        }));
    }

    #[test]
    fn test_is_retryable_mcp_error_tool_call_failed() {
        assert!(is_retryable_mcp_error(&NikaError::McpToolCallFailed {
            tool: String::from("test"),
            reason: String::from("timeout"),
        }));
    }

    // --- error classification: non-retryable variants ---

    #[test]
    fn test_is_not_retryable_resource_not_found() {
        assert!(!is_retryable_mcp_error(&NikaError::McpResourceNotFound {
            uri: String::from("test://resource"),
        }));
    }

    #[test]
    fn test_is_not_retryable_validation_failed() {
        assert!(!is_retryable_mcp_error(&NikaError::McpValidationFailed {
            tool: String::from("test"),
            details: String::from("invalid params"),
            missing: Vec::new(),
            suggestions: Vec::new(),
        }));
    }

    #[test]
    fn test_is_not_retryable_schema_error() {
        assert!(!is_retryable_mcp_error(&NikaError::McpSchemaError {
            tool: String::from("test"),
            reason: String::from("invalid schema"),
        }));
    }

    #[test]
    fn test_is_not_retryable_not_configured() {
        assert!(!is_retryable_mcp_error(&NikaError::McpNotConfigured {
            name: String::from("test"),
        }));
    }

    // --- configuration ---

    #[test]
    fn test_default_config() {
        let McpRetryConfig {
            max_retries,
            initial_delay,
            max_delay,
            jitter,
        } = McpRetryConfig::default();

        assert_eq!(max_retries, 3);
        assert_eq!(initial_delay, Duration::from_millis(100));
        assert_eq!(max_delay, Duration::from_secs(5));
        assert!(jitter);
    }

    #[test]
    fn test_config_without_jitter() {
        let jitter = McpRetryConfig::default().without_jitter().jitter;
        assert!(!jitter);
    }

    // --- retry loop ---

    #[tokio::test]
    async fn test_retry_mcp_call_success_first_try() {
        let policy = McpRetryConfig::default().without_jitter();
        let outcome = retry_mcp_call(policy, || async { Ok::<i32, NikaError>(42) }).await;
        assert_eq!(outcome.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_retry_mcp_call_success_after_retries() {
        let attempts = Arc::new(AtomicU32::new(0));

        let outcome: Result<i32, NikaError> = retry_mcp_call(fast_config(3), || {
            let calls = Arc::clone(&attempts);
            async move {
                // `fetch_add` yields the previous count, so calls 0 and 1 fail
                // and the third call succeeds.
                match calls.fetch_add(1, Ordering::SeqCst) {
                    0 | 1 => Err(timeout_error("call")),
                    _ => Ok(42),
                }
            }
        })
        .await;

        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_mcp_call_exhausted() {
        let attempts = Arc::new(AtomicU32::new(0));

        let outcome: Result<i32, NikaError> = retry_mcp_call(fast_config(2), || {
            let calls = Arc::clone(&attempts);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err(timeout_error("call"))
            }
        })
        .await;

        assert!(outcome.is_err());
        // One initial attempt plus two retries.
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_mcp_call_non_retryable_error_no_retry() {
        let attempts = Arc::new(AtomicU32::new(0));

        let outcome: Result<i32, NikaError> = retry_mcp_call(fast_config(3), || {
            let calls = Arc::clone(&attempts);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err(NikaError::McpResourceNotFound {
                    uri: String::from("test://resource"),
                })
            }
        })
        .await;

        assert!(outcome.is_err());
        // A non-retryable error must stop after the very first attempt.
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }
}