use super::McpUpstream;
use async_trait::async_trait;
use reqwest::{Client, ClientBuilder};
use serde_json::{Value, json};
use std::{
sync::Arc,
sync::atomic::{AtomicUsize, Ordering},
time::{Duration, Instant},
};
use tokio::sync::Mutex;
/// State of a [`CircuitBreaker`].
///
/// `Closed` lets requests through, `Open` rejects them until its deadline, and
/// `HalfOpen` is entered once that deadline has passed so traffic can probe
/// whether the upstream has recovered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CbState {
    /// Requests flow to the upstream normally.
    Closed,
    /// Requests are rejected without contacting the upstream until `until`.
    Open { until: Instant },
    /// The recovery deadline has elapsed; requests are let through as probes.
    HalfOpen,
}
/// Failure-counting circuit breaker guarding calls to a single upstream.
///
/// The failure counter is an atomic, so it can be updated without taking the
/// async lock that protects the breaker's [`CbState`].
struct CircuitBreaker {
    /// Number of counted failures at which the breaker opens.
    threshold: usize,
    /// How long, in seconds, the breaker stays open before letting requests
    /// through again.
    recovery_secs: u64,
    /// Failures recorded since the last reset (a success, or the breaker opening).
    failure_count: AtomicUsize,
    /// Current breaker state, behind an async (tokio) mutex.
    state: Mutex<CbState>,
}
impl CircuitBreaker {
    /// Creates a closed breaker that opens after `threshold` failures and
    /// stays open for `recovery_secs` seconds before letting requests through.
    fn new(threshold: usize, recovery_secs: u64) -> Self {
        Self {
            state: Mutex::new(CbState::Closed),
            failure_count: AtomicUsize::new(0),
            threshold,
            recovery_secs,
        }
    }

    /// Returns `true` while requests should be rejected.
    ///
    /// Side effect: once an `Open` breaker's deadline has passed, this call
    /// moves it to `HalfOpen` and returns `false`, so the caller's request acts
    /// as a probe. Every caller that observes `HalfOpen` is let through; there
    /// is no single-probe limit here.
    async fn is_open(&self) -> bool {
        let mut state = self.state.lock().await;
        match &*state {
            CbState::Closed | CbState::HalfOpen => false,
            CbState::Open { until } => {
                if Instant::now() >= *until {
                    *state = CbState::HalfOpen;
                    tracing::info!("circuit entering half-open, probing upstream");
                    false
                } else {
                    true
                }
            }
        }
    }

    /// Records a successful upstream call: resets the failure count and closes
    /// the breaker if it was open or half-open.
    async fn on_success(&self) {
        let prev = self.failure_count.swap(0, Ordering::Relaxed);
        let mut state = self.state.lock().await;
        if !matches!(*state, CbState::Closed) {
            tracing::info!(
                previous_failures = prev,
                "upstream recovered, circuit closed"
            );
            *state = CbState::Closed;
        }
    }

    /// Records a failed upstream call.
    ///
    /// The breaker (re)opens for `recovery_secs` seconds when either
    /// * the failure count reaches `threshold`, or
    /// * the breaker is `HalfOpen`, i.e. a recovery probe failed.
    ///
    /// Opening resets the failure count.
    async fn on_failure(&self) {
        let count = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
        let mut state = self.state.lock().await;
        // A failed probe means the upstream has not recovered. Go straight back
        // to Open instead of letting `threshold` more requests hit it first.
        let probe_failed = matches!(*state, CbState::HalfOpen);
        if !probe_failed && count < self.threshold {
            return;
        }
        let until = Instant::now() + Duration::from_secs(self.recovery_secs);
        *state = CbState::Open { until };
        self.failure_count.store(0, Ordering::Relaxed);
        tracing::warn!(
            failures = count,
            probe_failed = probe_failed,
            recovery_secs = self.recovery_secs,
            "circuit opened"
        );
    }
}
/// [`McpUpstream`] implementation that POSTs JSON-RPC messages to a remote
/// HTTP endpoint, guarded by a [`CircuitBreaker`].
pub struct HttpUpstream {
    /// Endpoint every message is POSTed to.
    url: String,
    /// Breaker tracking upstream failures; consulted before every request.
    cb: Arc<CircuitBreaker>,
    /// Pooled HTTP client used for all requests.
    client: Client,
}
impl HttpUpstream {
    /// Failure count at which the default breaker opens.
    const DEFAULT_FAILURE_THRESHOLD: usize = 5;
    /// Seconds the default breaker stays open before letting requests through.
    const DEFAULT_RECOVERY_SECS: u64 = 30;
    /// Timeout applied by the HTTP client to each request.
    const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
    /// Maximum idle pooled connections kept per host.
    const MAX_IDLE_PER_HOST: usize = 10;

    /// Creates an upstream for `url` with the default breaker settings: it
    /// opens after 5 failures and stays open for 30 seconds.
    ///
    /// # Panics
    ///
    /// Panics if the underlying HTTP client cannot be built.
    pub fn new(url: impl Into<String>) -> Self {
        Self::with_circuit_breaker(
            url,
            Self::DEFAULT_FAILURE_THRESHOLD,
            Self::DEFAULT_RECOVERY_SECS,
        )
    }

    /// Creates an upstream for `url` whose breaker opens after `threshold`
    /// failures and stays open for `recovery_secs` seconds.
    ///
    /// # Panics
    ///
    /// Panics if the underlying HTTP client cannot be built.
    pub fn with_circuit_breaker(
        url: impl Into<String>,
        threshold: usize,
        recovery_secs: u64,
    ) -> Self {
        let breaker = CircuitBreaker::new(threshold, recovery_secs);
        let client = ClientBuilder::new()
            .pool_max_idle_per_host(Self::MAX_IDLE_PER_HOST)
            .timeout(Self::REQUEST_TIMEOUT)
            .build()
            .expect("failed to build HTTP client");
        Self {
            url: url.into(),
            client,
            cb: Arc::new(breaker),
        }
    }
}
#[async_trait]
impl McpUpstream for HttpUpstream {
async fn forward(&self, msg: &Value) -> Option<Value> {
if self.cb.is_open().await {
tracing::warn!("circuit open, rejecting request");
return Some(json!({
"jsonrpc": "2.0",
"error": { "code": -32603, "message": "service unavailable (circuit open)" }
}));
}
match self.client.post(&self.url).json(msg).send().await {
Ok(resp) => {
self.cb.on_success().await;
if resp.status() == reqwest::StatusCode::ACCEPTED {
return None; }
match resp.json::<Value>().await {
Ok(body) => Some(body),
Err(e) => {
tracing::warn!(error = %e, "failed to parse upstream response");
Some(json!({
"jsonrpc": "2.0",
"error": { "code": -32603, "message": "internal error" }
}))
}
}
}
Err(e) => {
tracing::error!(error = %e, "upstream request failed");
self.cb.on_failure().await;
Some(json!({
"jsonrpc": "2.0",
"error": { "code": -32603, "message": "service unavailable" }
}))
}
}
}
fn base_url(&self) -> &str {
&self.url
}
async fn is_healthy(&self) -> bool {
!self.cb.is_open().await
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `ping` message used by the forwarding tests.
    fn ping() -> serde_json::Value {
        serde_json::json!({"method": "ping"})
    }

    /// Records `failures` failures on `breaker`, then sleeps just past a
    /// zero-second recovery window.
    async fn fail_then_wait(breaker: &CircuitBreaker, failures: usize) {
        for _ in 0..failures {
            breaker.on_failure().await;
        }
        tokio::time::sleep(Duration::from_millis(1)).await;
    }

    /// Upstream pointing at a port with nothing listening, so every request fails.
    fn failing_upstream(threshold: usize) -> HttpUpstream {
        HttpUpstream::with_circuit_breaker("http://127.0.0.1:1", threshold, 60)
    }

    // ---- CircuitBreaker state machine ----

    #[tokio::test]
    async fn circuit_starts_closed() {
        let breaker = CircuitBreaker::new(3, 30);
        assert!(!breaker.is_open().await);
    }

    #[tokio::test]
    async fn failures_below_threshold_keep_circuit_closed() {
        let breaker = CircuitBreaker::new(3, 30);
        for _ in 0..2 {
            breaker.on_failure().await;
        }
        assert!(!breaker.is_open().await);
    }

    #[tokio::test]
    async fn threshold_failures_open_circuit() {
        let breaker = CircuitBreaker::new(3, 60);
        for _ in 0..3 {
            breaker.on_failure().await;
        }
        assert!(breaker.is_open().await);
    }

    #[tokio::test]
    async fn success_resets_failure_count() {
        let breaker = CircuitBreaker::new(3, 60);
        breaker.on_failure().await;
        breaker.on_failure().await;
        // Without this reset, the next two failures would reach the threshold.
        breaker.on_success().await;
        breaker.on_failure().await;
        breaker.on_failure().await;
        assert!(!breaker.is_open().await);
    }

    #[tokio::test]
    async fn open_circuit_transitions_to_halfopen_after_recovery() {
        let breaker = CircuitBreaker::new(1, 0);
        fail_then_wait(&breaker, 1).await;
        assert!(
            !breaker.is_open().await,
            "circuit should be HalfOpen (not Open) after recovery window elapsed"
        );
    }

    #[tokio::test]
    async fn halfopen_success_closes_circuit() {
        let breaker = CircuitBreaker::new(1, 0);
        fail_then_wait(&breaker, 1).await;
        assert!(!breaker.is_open().await);
        breaker.on_success().await;
        assert!(
            !breaker.is_open().await,
            "circuit should be Closed after success in HalfOpen"
        );
    }

    #[tokio::test]
    async fn halfopen_failure_resets_failure_count_for_next_probe() {
        let breaker = CircuitBreaker::new(2, 0);
        fail_then_wait(&breaker, 2).await;
        assert!(!breaker.is_open().await);
        breaker.on_failure().await;
        breaker.on_success().await;
        assert!(!breaker.is_open().await, "should be Closed after success");
        breaker.on_failure().await;
        assert!(
            !breaker.is_open().await,
            "one failure below threshold keeps circuit Closed"
        );
    }

    #[tokio::test]
    async fn success_on_closed_circuit_is_noop() {
        let breaker = CircuitBreaker::new(3, 60);
        breaker.on_success().await;
        assert!(!breaker.is_open().await);
    }

    // ---- HttpUpstream forwarding ----

    #[tokio::test]
    async fn forward_to_unreachable_upstream_returns_error_response() {
        let up = failing_upstream(5);
        let resp = up.forward(&ping()).await;
        assert!(resp.is_some());
        let resp = resp.unwrap();
        assert!(
            resp["error"].is_object(),
            "expected error JSON-RPC response, got: {resp}"
        );
        assert_eq!(resp["error"]["code"], -32603);
    }

    #[tokio::test]
    async fn forward_opens_circuit_after_threshold_failures() {
        let up = failing_upstream(2);
        let msg = ping();
        for _ in 0..2 {
            up.forward(&msg).await;
        }
        let resp = up.forward(&msg).await.unwrap();
        let message = resp["error"]["message"].as_str().unwrap_or("");
        assert!(
            message.contains("circuit open"),
            "expected circuit open error, got: {resp}"
        );
    }

    #[tokio::test]
    async fn is_healthy_false_when_circuit_open() {
        let up = failing_upstream(1);
        assert!(up.is_healthy().await);
        up.forward(&serde_json::json!({})).await;
        assert!(!up.is_healthy().await);
    }

    #[tokio::test]
    async fn is_healthy_true_when_circuit_closed() {
        let up = failing_upstream(5);
        assert!(up.is_healthy().await);
    }

    #[tokio::test]
    async fn notification_202_returns_none() {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};
        use tokio::net::TcpListener;

        // One-shot fake server: read the request, answer 202 with no body.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            let mut buf = [0u8; 512];
            stream.read(&mut buf).await.ok();
            let reply: &[u8] =
                b"HTTP/1.1 202 Accepted\r\nContent-Length: 0\r\nConnection: close\r\n\r\n";
            stream.write_all(reply).await.ok();
        });

        let up = HttpUpstream::new(format!("http://127.0.0.1:{port}"));
        let resp = up.forward(&ping()).await;
        assert!(
            resp.is_none(),
            "202 response should return None (notification)"
        );
    }
}