use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use chrono::Utc;
use intent::BreakerCheck;
use resilience::{
retry, timeout, BreakerRegistry, DeadLetterQueue, DlqEntry, LoopDetector, LoopDetectorError,
RateLimitRegistry, RetryConfig, RetryOutcome, TimeoutError,
};
use tracing::{debug, warn};
use uuid::Uuid;
use crate::error::McpHostError;
use crate::types::{CallOutcome, ServerConfig, ServerStatus, ToolDescriptor};
use crate::MCPHost;
/// Tunables for the per-call layers of [`ResilientMcpHost`].
///
/// Both layers are opt-in: the `Default` value leaves every field `None`,
/// which switches the corresponding layer off.
#[derive(Clone, Debug, Default)]
pub struct ResilienceConfig {
    /// Deadline for the *whole* call, covering every attempt the retry layer
    /// makes rather than each attempt individually. `None` means this wrapper
    /// imposes no deadline of its own.
    pub timeout: Option<Duration>,
    /// Policy handed to `resilience::retry`. `None` means every call makes
    /// exactly one attempt against the inner host.
    pub retry: Option<RetryConfig>,
}
/// An [`MCPHost`] decorator that routes tool calls through optional
/// resilience layers before delegating to a wrapped host.
///
/// Only `call` is guarded. It goes through loop detection, circuit breaking,
/// rate limiting, timeout, retry and dead-lettering. `mount`, `unmount` and
/// the listing methods pass straight through to the inner host. Every layer
/// is optional and starts disabled; attach layers with the `with_*` builder
/// methods.
pub struct ResilientMcpHost {
    /// Host that actually performs mounts, listings and tool calls.
    inner: Arc<dyn MCPHost>,
    /// Identity handed to the loop detector; empty unless set explicitly.
    principal: String,
    /// Timeout and retry settings applied to each `call`.
    config: ResilienceConfig,
    /// Per-tool circuit breakers, keyed by `mcp:{server}:{tool}`.
    breakers: Option<Arc<BreakerRegistry>>,
    /// Per-tool rate limiters, keyed the same way as the breakers.
    rate_limits: Option<Arc<RateLimitRegistry>>,
    /// Consulted with (principal, tool id, args) before each call so that
    /// repetitive call loops can be rejected.
    loop_detector: Option<Arc<LoopDetector>>,
    /// Receives calls that end in a timeout, exhausted retries, a breaker
    /// abort, or an `is_error` outcome.
    dlq: Option<Arc<dyn DeadLetterQueue>>,
}
impl ResilientMcpHost {
    /// Wraps `inner` with every resilience layer switched off.
    ///
    /// The principal starts empty and the config is
    /// [`ResilienceConfig::default`] (no timeout, no retry). No breaker
    /// registry, rate limiter, loop detector or dead-letter queue is
    /// attached; enable them with the `with_*` methods.
    #[must_use]
    pub fn new(inner: Arc<dyn MCPHost>) -> Self {
        Self {
            inner,
            principal: String::new(),
            config: ResilienceConfig::default(),
            breakers: None,
            rate_limits: None,
            loop_detector: None,
            dlq: None,
        }
    }

    /// Sets the principal passed to the loop detector on every call.
    // `#[must_use]` on every builder step: each one consumes `self`, so
    // ignoring the return value silently throws the configuration away.
    #[must_use]
    pub fn with_principal(mut self, principal: impl Into<String>) -> Self {
        self.principal = principal.into();
        self
    }

    /// Replaces the timeout/retry configuration.
    #[must_use]
    pub fn with_config(mut self, config: ResilienceConfig) -> Self {
        self.config = config;
        self
    }

    /// Attaches per-tool circuit breakers.
    ///
    /// They are checked before each call and updated after each attempt.
    #[must_use]
    pub fn with_breakers(mut self, breakers: Arc<BreakerRegistry>) -> Self {
        self.breakers = Some(breakers);
        self
    }

    /// Attaches per-tool rate limiters that each call must acquire from.
    #[must_use]
    pub fn with_rate_limits(mut self, rl: Arc<RateLimitRegistry>) -> Self {
        self.rate_limits = Some(rl);
        self
    }

    /// Attaches a loop detector that is consulted before each call.
    #[must_use]
    pub fn with_loop_detector(mut self, ld: Arc<LoopDetector>) -> Self {
        self.loop_detector = Some(ld);
        self
    }

    /// Attaches a dead-letter queue that receives calls that ultimately fail.
    #[must_use]
    pub fn with_dlq(mut self, dlq: Arc<dyn DeadLetterQueue>) -> Self {
        self.dlq = Some(dlq);
        self
    }

    /// Performs a single call on the inner host and reports the result to
    /// the breaker registry, when one is attached.
    ///
    /// Only an `Ok` outcome with `is_error == false` counts as a success. A
    /// tool-level `is_error` outcome and a transport error both count as
    /// breaker failures. The result itself is returned unchanged.
    async fn one_attempt(
        &self,
        server: &str,
        tool: &str,
        tool_id: &str,
        args: serde_json::Value,
    ) -> Result<CallOutcome, McpHostError> {
        let result = self.inner.call(server, tool, args).await;
        if let Some(breakers) = &self.breakers {
            match &result {
                Ok(outcome) if !outcome.is_error => breakers.record_success(tool_id).await,
                _ => breakers.record_failure(tool_id).await,
            }
        }
        result
    }

    /// Records a failed call in the dead-letter queue, if one is attached.
    ///
    /// This is best-effort. Without a DLQ it is a no-op. An enqueue failure
    /// is logged at `warn` rather than returned, so the caller's own error is
    /// what propagates.
    async fn enqueue_dlq(
        &self,
        tool_id: &str,
        request_json: &str,
        error_message: String,
        attempts: u32,
    ) {
        let Some(dlq) = &self.dlq else { return };
        let entry = DlqEntry {
            id: Uuid::new_v4().to_string(),
            tool_id: tool_id.to_string(),
            request_json: request_json.to_string(),
            error_message,
            attempts,
            dlq_at: Utc::now(),
        };
        if let Err(e) = dlq.enqueue(entry).await {
            warn!(tool_id, error = %e, "failed to enqueue DLQ entry");
        }
    }
}
#[async_trait]
impl MCPHost for ResilientMcpHost {
async fn mount(&self, name: String, cfg: ServerConfig) -> Result<(), McpHostError> {
self.inner.mount(name, cfg).await
}
async fn unmount(&self, name: &str) -> Result<(), McpHostError> {
self.inner.unmount(name).await
}
async fn list_servers(&self) -> Vec<ServerStatus> {
self.inner.list_servers().await
}
async fn list_all_tools(&self) -> Vec<ToolDescriptor> {
self.inner.list_all_tools().await
}
async fn call(
&self,
server: &str,
tool: &str,
args: serde_json::Value,
) -> Result<CallOutcome, McpHostError> {
let tool_id = format!("mcp:{server}:{tool}");
let request_json = serde_json::to_string(&args)
.unwrap_or_else(|_| String::from(r#"{"error":"unserializable"}"#));
if let Some(ld) = &self.loop_detector {
if let Err(LoopDetectorError::LoopDetected { count, window, .. }) =
ld.check(&self.principal, &tool_id, &args).await
{
return Err(McpHostError::Transport(format!(
"loop detected: {tool_id} repeated {count} times in window {window}"
)));
}
}
if let Some(breakers) = &self.breakers {
if breakers.is_open(&tool_id).await {
debug!(tool_id, "breaker open — short-circuiting");
return Err(McpHostError::Transport(format!(
"breaker open for {tool_id}"
)));
}
}
if let Some(rl) = &self.rate_limits {
rl.acquire(&tool_id).await;
}
let attempts_taken = std::sync::atomic::AtomicU32::new(0);
let call_fut = async {
if let Some(retry_cfg) = &self.config.retry {
let breaker_check = self
.breakers
.as_ref()
.map(|b| (b.clone() as Arc<dyn intent::BreakerCheck>, tool_id.as_str()));
retry(retry_cfg, breaker_check, || {
attempts_taken.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
self.one_attempt(server, tool, &tool_id, args.clone())
})
.await
} else {
attempts_taken.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
self.one_attempt(server, tool, &tool_id, args.clone())
.await
.map_err(RetryOutcome::Exhausted)
}
};
let final_result: Result<CallOutcome, RetryOutcome<McpHostError>> =
match self.config.timeout {
Some(d) => match timeout(d, call_fut).await {
Ok(outcome) => Ok(outcome),
Err(TimeoutError::Elapsed) => {
if let Some(breakers) = &self.breakers {
breakers.record_failure(&tool_id).await;
}
let attempts = attempts_taken.load(std::sync::atomic::Ordering::SeqCst);
self.enqueue_dlq(
&tool_id,
&request_json,
format!("timeout after {}ms", d.as_millis()),
attempts,
)
.await;
return Err(McpHostError::Transport(format!(
"timeout after {}ms",
d.as_millis()
)));
}
Err(TimeoutError::Inner(retry_outcome)) => Err(retry_outcome),
},
None => call_fut.await,
};
let attempts = attempts_taken.load(std::sync::atomic::Ordering::SeqCst);
match final_result {
Ok(outcome) => {
if outcome.is_error {
self.enqueue_dlq(
&tool_id,
&request_json,
format!(
"tool returned is_error after {attempts} attempt(s): {}",
outcome.content
),
attempts,
)
.await;
}
Ok(outcome)
}
Err(RetryOutcome::Exhausted(e)) => {
self.enqueue_dlq(
&tool_id,
&request_json,
format!("exhausted after {attempts} attempt(s): {e}"),
attempts,
)
.await;
Err(e)
}
Err(RetryOutcome::BreakerOpenAbort(abort)) => {
self.enqueue_dlq(
&tool_id,
&request_json,
format!(
"breaker opened mid-retry after {attempts} attempt(s): {}",
abort.last_error
),
attempts,
)
.await;
Err(abort.last_error)
}
}
}
}