use crate::extensions::mcp::types::ServerEntry;
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::time::Instant;
use tokio::sync::Mutex;
use yoagent::mcp::McpClient;
use yoagent::mcp::McpTransport;
use yoagent::mcp::types::*;
/// MCP transport that sends each JSON-RPC request as an HTTP POST and accepts
/// either a plain JSON body or a `text/event-stream` (SSE) body in reply.
struct SseHttpTransport {
/// HTTP client reused for every request made through this transport.
client: reqwest::Client,
/// Endpoint every request is POSTed to (stored without trailing `/`).
base_url: String,
/// Extra `(name, value)` headers attached to every outgoing request.
headers: Vec<(String, String)>,
/// Session id learned from a response's `mcp-session-id` header, if any.
/// Uses a std (blocking) mutex; the guard is never held across an `.await`.
session_id: StdMutex<Option<String>>,
}
impl SseHttpTransport {
    /// Creates a transport that POSTs to `url`.
    ///
    /// Trailing `/` characters are stripped from `url`. The transport starts
    /// with no extra headers and no session id.
    fn new(url: &str) -> Self {
        Self {
            client: reqwest::Client::new(),
            base_url: url.trim_end_matches('/').to_string(),
            headers: Vec::new(),
            session_id: StdMutex::new(None),
        }
    }

    /// Appends the given headers (if any) to the set sent with every request
    /// and returns the transport, builder-style.
    fn with_headers(mut self, headers: Option<&std::collections::HashMap<String, String>>) -> Self {
        if let Some(extra) = headers {
            self.headers
                .extend(extra.iter().map(|(name, value)| (name.clone(), value.clone())));
        }
        self
    }

    /// Extracts a JSON-RPC response from an HTTP response body.
    ///
    /// The body is first tried as plain JSON. Otherwise it is read as a
    /// server-sent-event stream: events are separated by blank lines, and the
    /// `data:` lines of each event are examined in order. The first event that
    /// yields a parseable response wins.
    ///
    /// # Errors
    ///
    /// Returns `McpError::Transport` (including up to the first 200 characters
    /// of the body) when no response could be parsed.
    fn parse_sse_response(body: &str) -> Result<JsonRpcResponse, McpError> {
        if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(body) {
            return Ok(response);
        }

        let mut data_lines: Vec<&str> = Vec::new();
        // The trailing empty line flushes a final event that is not followed
        // by a blank line in the body.
        for line in body.lines().chain(std::iter::once("")) {
            if line.is_empty() {
                if let Some(response) = Self::parse_event_data(&data_lines) {
                    return Ok(response);
                }
                data_lines.clear();
                continue;
            }
            // Leading whitespace is tolerated so that bodies accepted before
            // this parser became event-aware are still accepted.
            let line = line.trim_start();
            if let Some(data) = line
                .strip_prefix("data: ")
                .or_else(|| line.strip_prefix("data:"))
            {
                data_lines.push(data);
            }
        }

        Err(McpError::Transport(format!(
            "Cannot parse SSE response: {}",
            body.chars().take(200).collect::<String>()
        )))
    }

    /// Tries to parse one SSE event's `data:` payload as a JSON-RPC response.
    ///
    /// Each line is tried on its own first. If none parses and the event has
    /// several `data:` lines, they are joined with `\n` (how the SSE format
    /// reassembles a multi-line payload) and tried as one document.
    fn parse_event_data(data_lines: &[&str]) -> Option<JsonRpcResponse> {
        fn parse_object(text: &str) -> Option<JsonRpcResponse> {
            let text = text.trim();
            if text.starts_with('{') {
                serde_json::from_str::<JsonRpcResponse>(text).ok()
            } else {
                None
            }
        }

        data_lines
            .iter()
            .find_map(|line| parse_object(line))
            .or_else(|| {
                if data_lines.len() > 1 {
                    parse_object(&data_lines.join("\n"))
                } else {
                    None
                }
            })
    }
}
#[async_trait]
impl McpTransport for SseHttpTransport {
    /// POSTs `request` as JSON to the configured endpoint and parses the reply.
    ///
    /// The configured extra headers are attached, plus `Mcp-Session-Id` when a
    /// session id is already known. If the response carries a non-empty
    /// `mcp-session-id` header and no id is stored yet, it is remembered for
    /// later requests.
    ///
    /// # Errors
    ///
    /// Returns `McpError::Transport` when the request fails, the body cannot
    /// be read, the status is not a success status, or the body contains no
    /// parseable JSON-RPC response.
    async fn send(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse, McpError> {
        let base = self
            .client
            .post(&self.base_url)
            .header("Accept", "application/json, text/event-stream")
            .json(&request);
        let with_custom = self.headers.iter().fold(base, |builder, (name, value)| {
            builder.header(name.as_str(), value.as_str())
        });

        // Copy the id out so the std mutex guard is released before any await.
        // A poisoned lock is treated the same as "no session id".
        let known_session = match self.session_id.lock() {
            Ok(slot) => (*slot).clone(),
            Err(_) => None,
        };
        let outgoing = match known_session {
            Some(sid) => with_custom.header("Mcp-Session-Id", sid.as_str()),
            None => with_custom,
        };

        let resp = outgoing
            .send()
            .await
            .map_err(|e| McpError::Transport(format!("HTTP error: {}", e)))?;
        let status = resp.status();

        // Only the first session id seen is kept; later ones are ignored.
        if let Some(sid) = resp
            .headers()
            .get("mcp-session-id")
            .and_then(|v| v.to_str().ok())
            .filter(|s| !s.is_empty())
        {
            if let Ok(mut slot) = self.session_id.lock() {
                if slot.is_none() {
                    *slot = Some(sid.to_string());
                }
            }
        }

        let body = resp
            .text()
            .await
            .map_err(|e| McpError::Transport(format!("Failed to read response: {}", e)))?;

        if !status.is_success() && status != 202 {
            return Err(McpError::Transport(format!(
                "HTTP {} from server: {}",
                status,
                body.chars().take(200).collect::<String>()
            )));
        }
        Self::parse_sse_response(&body)
    }

    /// No-op: this transport performs no work on close and always returns `Ok`.
    async fn close(&self) -> Result<(), McpError> {
        Ok(())
    }
}
/// Connection state tracked for each registered MCP server.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ConnectionStatus {
    /// A client has been established for the server.
    Connected,
    /// No client is held: the initial state, and the state after a disconnect.
    Idle,
    /// The last connection attempt failed or the server was marked as failed.
    Failed,
}
/// Book-keeping the manager holds for one registered server.
struct ServerConnection {
/// Configuration used to establish connections (URL or command, args, env,
/// headers, lifecycle, idle timeout).
entry: ServerEntry,
/// Shared handle to the live client; `None` whenever no connection is held.
client: Option<Arc<Mutex<McpClient>>>,
/// Current connection state.
status: ConnectionStatus,
/// When the server was last used; compared against the idle timeout by the sweep.
last_used: Instant,
/// When the server was last marked failed; `None` if there is no recorded failure.
last_failure: Option<Instant>,
/// Caller-supplied value stored at registration and returned unchanged by
/// the manager (presumably a hash of the server's config — confirm with caller).
config_hash: u64,
}
/// Registry of MCP servers that tracks each server's connection state and
/// opens or closes connections on request.
pub struct ServerManager {
/// Registered servers, keyed by server name.
servers: HashMap<String, ServerConnection>,
/// Idle timeout applied to servers that do not determine their own.
global_idle_timeout: std::time::Duration,
}
impl ServerManager {
    /// How long a server stays `Failed` before [`ServerManager::touch`] resets
    /// it to `Idle`.
    const FAILURE_BACKOFF: std::time::Duration = std::time::Duration::from_secs(60);

    /// Creates an empty manager whose default idle timeout is
    /// `global_idle_timeout_minutes` minutes.
    ///
    /// The minutes-to-seconds conversion saturates, so very large values give
    /// an effectively unbounded timeout instead of overflowing.
    pub fn new(global_idle_timeout_minutes: u64) -> Self {
        Self {
            servers: HashMap::new(),
            global_idle_timeout: std::time::Duration::from_secs(
                global_idle_timeout_minutes.saturating_mul(60),
            ),
        }
    }

    /// Registers `name` with the given configuration, starting in the `Idle`
    /// state with no client.
    ///
    /// If `name` is already registered, the existing record (including its
    /// entry and config hash) is left untouched.
    pub fn register(&mut self, name: &str, entry: ServerEntry, config_hash: u64) {
        self.servers
            .entry(name.to_string())
            .or_insert_with(|| ServerConnection {
                entry,
                client: None,
                status: ConnectionStatus::Idle,
                last_used: Instant::now(),
                last_failure: None,
                config_hash,
            });
    }

    /// Makes sure `name` has a live client, connecting if necessary.
    ///
    /// Returns `true` when the server is connected afterwards (refreshing its
    /// last-used time), and `false` when the server is unknown or the
    /// connection attempt failed. A failed attempt is logged to stderr and the
    /// server is marked `Failed`.
    ///
    /// Entries with a URL are connected over HTTP and initialized; entries
    /// without one are spawned over stdio using the configured command
    /// (`"npx"` when none is set).
    pub async fn ensure_connected(&mut self, name: &str) -> bool {
        let entry = match self.servers.get_mut(name) {
            None => return false,
            Some(conn)
                if conn.status == ConnectionStatus::Connected && conn.client.is_some() =>
            {
                conn.last_used = Instant::now();
                return true;
            }
            Some(conn) => conn.entry.clone(),
        };

        let attempt = match &entry.url {
            Some(url) => {
                let transport =
                    Box::new(SseHttpTransport::new(url).with_headers(entry.headers.as_ref()));
                let mut client = McpClient::from_transport(transport);
                client.initialize().await.map(|_| client)
            }
            None => {
                let env = entry.env.as_ref().cloned();
                let cmd = entry.command.as_deref().unwrap_or("npx");
                McpClient::connect_stdio(cmd, &to_str_slice(&entry.args), env).await
            }
        };

        match attempt {
            Ok(client) => {
                let client = Arc::new(Mutex::new(client));
                if let Some(conn) = self.servers.get_mut(name) {
                    conn.client = Some(client);
                    conn.status = ConnectionStatus::Connected;
                    conn.last_used = Instant::now();
                    conn.last_failure = None;
                }
                true
            }
            Err(e) => {
                eprintln!("MCP: failed to connect to '{}': {}", name, e);
                self.mark_failed(name);
                false
            }
        }
    }

    /// Returns a shared handle to the server's client, if one is currently held.
    pub fn get_client(&self, name: &str) -> Option<Arc<Mutex<McpClient>>> {
        self.servers.get(name).and_then(|c| c.client.clone())
    }

    /// Returns the server's current status, or `None` if it is not registered.
    pub fn status(&self, name: &str) -> Option<ConnectionStatus> {
        self.servers.get(name).map(|c| c.status.clone())
    }

    /// Marks the server as `Failed`, records the failure time and drops its
    /// client. Unknown names are ignored.
    pub fn mark_failed(&mut self, name: &str) {
        if let Some(conn) = self.servers.get_mut(name) {
            conn.status = ConnectionStatus::Failed;
            conn.last_failure = Some(Instant::now());
            conn.client = None;
        }
    }

    /// Refreshes the server's last-used time.
    ///
    /// If the server is `Failed` and its recorded failure is older than the
    /// backoff period, it is reset to `Idle` and the failure is cleared.
    /// Unknown names are ignored.
    pub fn touch(&mut self, name: &str) {
        let Some(conn) = self.servers.get_mut(name) else {
            return;
        };
        conn.last_used = Instant::now();
        let backoff_elapsed = conn
            .last_failure
            .is_some_and(|failed_at| failed_at.elapsed() > Self::FAILURE_BACKOFF);
        if conn.status == ConnectionStatus::Failed && backoff_elapsed {
            conn.status = ConnectionStatus::Idle;
            conn.last_failure = None;
        }
    }

    /// Closes the server's client (ignoring any close error), drops it and
    /// sets the status to `Idle`. Unknown names are ignored.
    pub async fn disconnect(&mut self, name: &str) {
        if let Some(conn) = self.servers.get_mut(name) {
            if let Some(ref client) = conn.client {
                // Best effort: the connection is discarded either way.
                let _ = client.lock().await.close().await;
            }
            conn.client = None;
            conn.status = ConnectionStatus::Idle;
        }
    }

    /// Disconnects every registered server, one after another.
    pub async fn close_all(&mut self) {
        let names: Vec<String> = self.servers.keys().cloned().collect();
        for name in &names {
            self.disconnect(name).await;
        }
    }

    /// Returns the idle timeout that applies to `name`, or the global timeout
    /// when the server is not registered.
    pub fn idle_timeout(&self, name: &str) -> std::time::Duration {
        match self.servers.get(name) {
            Some(conn) => idle_timeout_for(conn, self.global_idle_timeout),
            None => self.global_idle_timeout,
        }
    }

    /// Disconnects every `Connected` server that has been unused for longer
    /// than its idle timeout.
    pub async fn sweep_idle(&mut self) {
        // One timestamp for the whole sweep so all servers are judged alike.
        let now = Instant::now();
        let idle_names: Vec<String> = self
            .servers
            .iter()
            .filter(|(_name, conn)| {
                conn.status == ConnectionStatus::Connected
                    && now.duration_since(conn.last_used)
                        > idle_timeout_for(conn, self.global_idle_timeout)
            })
            .map(|(name, _)| name.clone())
            .collect();
        for name in &idle_names {
            self.disconnect(name).await;
        }
    }

    /// Returns the names of all registered servers (in no particular order).
    pub fn server_names(&self) -> Vec<String> {
        self.servers.keys().cloned().collect()
    }

    /// Returns `true` when the server is registered and its lifecycle is
    /// `"eager"` or `"keep-alive"`.
    pub fn should_connect_eagerly(&self, name: &str) -> bool {
        self.servers
            .get(name)
            .is_some_and(|c| matches!(c.entry.lifecycle.as_deref(), Some("eager" | "keep-alive")))
    }

    /// Returns the config hash stored at registration, or `None` if the server
    /// is not registered.
    pub fn config_hash(&self, name: &str) -> Option<u64> {
        self.servers.get(name).map(|c| c.config_hash)
    }
}
/// Borrows every owned string in `args` as a `&str`, preserving order.
fn to_str_slice(args: &[String]) -> Vec<&str> {
    let mut borrowed = Vec::with_capacity(args.len());
    for arg in args {
        borrowed.push(arg.as_str());
    }
    borrowed
}
/// Resolves the idle timeout that applies to one server.
///
/// Precedence:
/// 1. the entry's own `idle_timeout` (in minutes), if set;
/// 2. no timeout (`Duration::MAX`) when the lifecycle is `"keep-alive"`;
/// 3. otherwise the `global` timeout.
///
/// The minutes-to-seconds conversion saturates rather than overflowing, so an
/// extremely large configured value simply means "practically never".
fn idle_timeout_for(conn: &ServerConnection, global: std::time::Duration) -> std::time::Duration {
    if let Some(minutes) = conn.entry.idle_timeout {
        return std::time::Duration::from_secs(minutes.saturating_mul(60));
    }
    if conn.entry.lifecycle.as_deref() == Some("keep-alive") {
        return std::time::Duration::MAX;
    }
    global
}