use std::{
collections::{BTreeMap, HashMap},
path::PathBuf,
sync::Arc,
};
use async_trait::async_trait;
use chrono::Utc;
use tokio::sync::RwLock;
pub mod aud_check;
pub mod capability_index;
pub mod error;
pub mod oauth;
pub mod resilient;
pub mod rmcp_host;
pub mod types;
pub use aud_check::{validate_token_aud, AudCheckOutcome};
pub use capability_index::{InMemoryToolCapabilityIndex, ToolCapabilityIndex};
pub use error::McpHostError;
pub use oauth::{manager_from_vault, VaultCredentialStore};
pub use resilient::{ResilienceConfig, ResilientMcpHost};
pub use rmcp_host::RmcpHost;
pub use types::{
CallOutcome, MountedServer, OAuthConfig, ServerConfig, ServerInfo, ServerStatus, ToolDescriptor,
};
/// Protocol revision string exposed by this crate, written as a `YYYY-MM-DD` date.
///
/// NOTE(review): presumably the MCP specification revision this crate targets —
/// confirm where the value is actually sent or compared.
pub const MCP_PROTOCOL_VERSION: &str = "2025-11-25";
/// Interface for a host that keeps a set of named servers and routes tool
/// calls to them.
///
/// The `Send + Sync` bound allows implementations to be shared across threads,
/// for example behind an `Arc<dyn MCPHost>`.
#[async_trait]
pub trait MCPHost: Send + Sync {
/// Registers a server under `name` using the configuration `cfg`.
///
/// The in-memory implementation in this file fails with
/// `McpHostError::AlreadyMounted` when `name` is already in use.
async fn mount(&self, name: String, cfg: ServerConfig) -> Result<(), McpHostError>;
/// Removes the server registered under `name`.
///
/// The in-memory implementation in this file fails with
/// `McpHostError::NotMounted` when no server has that name.
async fn unmount(&self, name: &str) -> Result<(), McpHostError>;
/// Returns one status entry per currently mounted server.
async fn list_servers(&self) -> Vec<ServerStatus>;
/// Returns the tool descriptors of every mounted server as a single list.
async fn list_all_tools(&self) -> Vec<ToolDescriptor>;
/// Invokes `tool` on the mounted server named `server`, passing `args` as a
/// JSON value.
///
/// NOTE(review): the expected shape of `args` is not visible here —
/// presumably it is specific to each tool; confirm against the callers.
async fn call(
&self,
server: &str,
tool: &str,
args: serde_json::Value,
) -> Result<CallOutcome, McpHostError>;
}
/// Interface for a client that talks to a single server.
///
/// NOTE(review): presumably one `MCPClient` backs each mounted server — no
/// implementation is visible in this file, so confirm against the submodules.
#[async_trait]
pub trait MCPClient: Send + Sync {
/// Initializes the client and returns the server's `ServerInfo` on success.
///
/// NOTE(review): presumably this performs the protocol handshake — confirm.
async fn initialize(&self) -> Result<ServerInfo, McpHostError>;
/// Returns the descriptors of the tools available through this client.
async fn list_tools(&self) -> Result<Vec<ToolDescriptor>, McpHostError>;
/// Invokes the tool called `name`, passing `args` as a JSON value.
async fn call_tool(
&self,
name: &str,
args: serde_json::Value,
) -> Result<CallOutcome, McpHostError>;
/// Shuts the client down.
///
/// NOTE(review): whether the client may be used again afterwards is not
/// visible here — confirm with the implementations.
async fn shutdown(&self) -> Result<(), McpHostError>;
/// Returns the server information if it is available; this accessor is
/// synchronous and infallible.
///
/// NOTE(review): presumably `None` until `initialize` has succeeded — confirm.
fn server_info(&self) -> Option<ServerInfo>;
}
/// `MCPHost` implementation that only tracks mounted servers in an in-process
/// map.
///
/// No transport is attached: its `call` method always returns an error, so the
/// type records mount state but never reaches a real server.
#[derive(Default)]
pub struct InMemoryMcpHost {
// Mounted servers keyed by server name, behind an async read/write lock.
mounted: RwLock<HashMap<String, MountedServer>>,
}
impl InMemoryMcpHost {
pub fn new() -> Self {
Self {
mounted: RwLock::new(HashMap::new()),
}
}
pub fn shared() -> Arc<dyn MCPHost> {
Arc::new(Self::new())
}
}
#[async_trait]
impl MCPHost for InMemoryMcpHost {
    /// Records a new server under `name` with no server info and no tools.
    ///
    /// # Errors
    /// Returns `McpHostError::AlreadyMounted` when `name` is already taken;
    /// the existing entry is left untouched.
    async fn mount(&self, name: String, cfg: ServerConfig) -> Result<(), McpHostError> {
        let mut servers = self.mounted.write().await;
        // A single entry lookup decides between "reject" and "insert".
        match servers.entry(name) {
            std::collections::hash_map::Entry::Occupied(taken) => {
                return Err(McpHostError::AlreadyMounted(taken.key().clone()));
            }
            std::collections::hash_map::Entry::Vacant(slot) => {
                let name = slot.key().clone();
                slot.insert(MountedServer {
                    name,
                    config: cfg,
                    mounted_at: Utc::now(),
                    info: None,
                    tools: Vec::new(),
                });
            }
        }
        Ok(())
    }

    /// Forgets the server registered under `name`.
    ///
    /// # Errors
    /// Returns `McpHostError::NotMounted` when no server has that name.
    async fn unmount(&self, name: &str) -> Result<(), McpHostError> {
        let removed = self.mounted.write().await.remove(name);
        match removed {
            Some(_) => Ok(()),
            None => Err(McpHostError::NotMounted(name.to_string())),
        }
    }

    /// Builds a status snapshot for every mounted server, in the map's
    /// iteration order.
    async fn list_servers(&self) -> Vec<ServerStatus> {
        let servers = self.mounted.read().await;
        let mut statuses = Vec::with_capacity(servers.len());
        for entry in servers.values() {
            statuses.push(ServerStatus {
                name: entry.name.clone(),
                mounted_at: entry.mounted_at,
                tool_count: entry.tools.len(),
                info: entry.info.clone(),
            });
        }
        statuses
    }

    /// Collects copies of the tool descriptors of all mounted servers into
    /// one list.
    async fn list_all_tools(&self) -> Vec<ToolDescriptor> {
        let servers = self.mounted.read().await;
        servers
            .values()
            .flat_map(|entry| entry.tools.iter().cloned())
            .collect()
    }

    /// Always fails: this host has no transport to reach a server with.
    ///
    /// # Errors
    /// Returns `McpHostError::NotMounted` when `server` is unknown, and
    /// `McpHostError::Transport` otherwise.
    async fn call(
        &self,
        server: &str,
        _tool: &str,
        _args: serde_json::Value,
    ) -> Result<CallOutcome, McpHostError> {
        let is_mounted = self.mounted.read().await.contains_key(server);
        if is_mounted {
            Err(McpHostError::Transport(
                "no transport configured for in-memory host".to_string(),
            ))
        } else {
            Err(McpHostError::NotMounted(server.to_string()))
        }
    }
}
/// Returns a new map with no entries, suitable as an empty `env` value for a
/// server configuration.
pub fn empty_env() -> BTreeMap<String, String> {
    BTreeMap::default()
}
/// Builds a `ServerConfig::Stdio` for `command` with the given `args`, an
/// empty environment map, and no working directory.
pub fn stdio_cfg(command: impl Into<String>, args: Vec<String>) -> ServerConfig {
    let command: String = command.into();
    let env = empty_env();
    let cwd: Option<PathBuf> = None;
    ServerConfig::Stdio {
        command,
        args,
        env,
        cwd,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Configuration shared by the tests: the `mcp-fs` command with no
    /// arguments.
    fn fs_cfg() -> ServerConfig {
        stdio_cfg("mcp-fs", vec![])
    }

    /// Returns a fresh host that already has one server mounted as `fs`.
    async fn host_with_fs() -> InMemoryMcpHost {
        let host = InMemoryMcpHost::new();
        host.mount(String::from("fs"), fs_cfg()).await.unwrap();
        host
    }

    /// A mounted server shows up in the listing with zero tools.
    #[tokio::test]
    async fn mount_and_list() {
        let host = host_with_fs().await;
        let servers = host.list_servers().await;
        assert_eq!(servers.len(), 1);
        let only = &servers[0];
        assert_eq!(only.name, "fs");
        assert_eq!(only.tool_count, 0);
    }

    /// Mounting the same name twice is refused.
    #[tokio::test]
    async fn double_mount_rejected() {
        let host = host_with_fs().await;
        let second_attempt = host.mount(String::from("fs"), fs_cfg()).await;
        let err = second_attempt.unwrap_err();
        assert!(matches!(err, McpHostError::AlreadyMounted(_)));
    }

    /// Unmounting an unknown name reports `NotMounted`.
    #[tokio::test]
    async fn unmount_missing_errors() {
        let host = InMemoryMcpHost::new();
        let outcome = host.unmount("nope").await;
        let err = outcome.unwrap_err();
        assert!(matches!(err, McpHostError::NotMounted(_)));
    }

    /// Calling a tool on a mounted server fails with a transport error.
    #[tokio::test]
    async fn call_without_transport_errors() {
        let host = host_with_fs().await;
        let outcome = host
            .call("fs", "read_text_file", serde_json::json!({}))
            .await;
        let err = outcome.unwrap_err();
        assert!(matches!(err, McpHostError::Transport(_)));
    }

    /// Pins the protocol version constant to its expected value.
    #[test]
    fn protocol_version_matches_spec() {
        assert_eq!(MCP_PROTOCOL_VERSION, "2025-11-25");
    }
}