use anyhow::Result;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::RwLock;
/// Outgoing JSON-RPC request envelope, serialized as the body of an MCP POST.
#[derive(Debug, Serialize)]
struct JsonRpcRequest {
    /// JSON-RPC version tag; every construction site in this file passes `"2.0"`.
    jsonrpc: &'static str,
    /// Name of the remote method to invoke, e.g. `"tools/list"`.
    method: String,
    /// Method arguments. The `params` key is left out of the JSON when this is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    params: Option<Value>,
    /// Identifier of this request.
    id: u64,
}
/// Outgoing JSON-RPC notification envelope.
///
/// Unlike [`JsonRpcRequest`] it carries neither an `id` nor `params`.
#[derive(Debug, Serialize)]
struct JsonRpcNotification {
    /// JSON-RPC version tag; the construction site in this file passes `"2.0"`.
    jsonrpc: &'static str,
    /// Name of the notification, e.g. `"notifications/initialized"`.
    method: String,
}
/// Incoming JSON-RPC response envelope.
///
/// `jsonrpc` must be present in the payload; `result`, `error` and `id` fall back to `None`
/// when they are missing.
#[derive(Debug, Deserialize)]
struct JsonRpcResponse {
    // Deserialized but not read anywhere in this file, hence the lint allowance.
    #[allow(dead_code)]
    jsonrpc: String,
    /// Payload of a successful call.
    #[serde(default)]
    result: Option<Value>,
    /// Error object of a failed call.
    #[serde(default)]
    error: Option<JsonRpcError>,
    // Deserialized but not read anywhere in this file, hence the lint allowance.
    #[allow(dead_code)]
    id: Option<u64>,
}
/// Error object carried in the `error` member of a [`JsonRpcResponse`].
#[derive(Debug, Deserialize)]
struct JsonRpcError {
    // Deserialized but not read anywhere in this file, hence the lint allowance.
    #[allow(dead_code)]
    code: i32,
    /// Human-readable description; this is the text surfaced in the proxy's error messages.
    message: String,
    // Optional extra payload; defaults to `None` when absent and is not read in this file.
    #[allow(dead_code)]
    #[serde(default)]
    data: Option<Value>,
}
/// State recorded once the `initialize` handshake with the server has completed.
struct SessionInfo {
    /// Value of the `mcp-session-id` response header, if the server sent one.
    session_id: Option<String>,
    // `result` of the initialize response (`Value::Null` when absent). Stored but not read
    // anywhere in this file, hence the lint allowance.
    #[allow(dead_code)]
    server_info: Value,
}
/// Description of a single tool, as found in the `tools` array of a `tools/list` result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolInfo {
    /// Tool name; the only mandatory field.
    pub name: String,
    /// Optional description; `None` when the server omits it.
    #[serde(default)]
    pub description: Option<String>,
    /// Optional input schema, carried on the wire under the `inputSchema` key.
    #[serde(rename = "inputSchema", default)]
    pub input_schema: Option<Value>,
    /// Optional `_meta` object; omitted from serialized output when `None`.
    #[serde(default, rename = "_meta", skip_serializing_if = "Option::is_none")]
    pub meta: Option<Value>,
}
/// Outcome of a tool invocation as reported by the proxy.
///
/// Every optional member is dropped from the serialized form when it is `None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct ToolCallResult {
    /// Whether the call succeeded; when `false`, `error` describes the failure.
    pub success: bool,
    /// Content items returned by the tool.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<Vec<ContentItem>>,
    /// Failure description.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
    /// Value of the result's `structuredContent` member, passed through untouched.
    #[serde(skip_serializing_if = "Option::is_none", rename = "structuredContent")]
    pub structured_content: Option<Value>,
    /// Value of the result's `_meta` member, passed through untouched.
    #[serde(skip_serializing_if = "Option::is_none", rename = "_meta")]
    pub meta: Option<Value>,
}
/// One entry of a tool result's `content` array.
///
/// Only `type` is mandatory; every other member is optional and is dropped from the
/// serialized form when it is `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContentItem {
    /// Discriminator carried on the wire under the `type` key.
    #[serde(rename = "type")]
    pub content_type: String,
    /// Textual payload, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    /// MIME type, carried on the wire under the `mimeType` key.
    #[serde(skip_serializing_if = "Option::is_none", rename = "mimeType")]
    pub mime_type: Option<String>,
    /// String payload of the `data` key (presumably encoded binary content — confirm
    /// against the MCP specification).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<String>,
    /// Value of the item's `structuredContent` key, passed through untouched.
    #[serde(skip_serializing_if = "Option::is_none", rename = "structuredContent")]
    pub structured_content: Option<Value>,
}
/// Description of a single resource, as found in the `resources` array of a
/// `resources/list` result.
///
/// Field names are camelCase on the wire (`mime_type` <-> `mimeType`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourceInfo {
    /// URI identifying the resource.
    pub uri: String,
    /// Name of the resource.
    pub name: String,
    /// Optional description; `None` when the server omits it.
    #[serde(default)]
    pub description: Option<String>,
    /// Optional MIME type; `None` when the server omits it.
    #[serde(default)]
    pub mime_type: Option<String>,
    /// Optional `_meta` object; omitted from serialized output when `None`.
    #[serde(default, rename = "_meta", skip_serializing_if = "Option::is_none")]
    pub meta: Option<Value>,
}
/// One entry of the `contents` array of a `resources/read` result.
///
/// Field names are camelCase on the wire (`mime_type` <-> `mimeType`); every member is
/// optional and defaults to `None` when missing.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourceContentItem {
    /// URI this piece of content belongs to.
    #[serde(default)]
    pub uri: Option<String>,
    /// Textual payload.
    #[serde(default)]
    pub text: Option<String>,
    /// MIME type of the payload.
    #[serde(default)]
    pub mime_type: Option<String>,
    /// Optional `_meta` object; omitted from serialized output when `None`.
    #[serde(default, rename = "_meta", skip_serializing_if = "Option::is_none")]
    pub meta: Option<Value>,
}
/// Result of reading a resource: its content items plus optional metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceReadResult {
    /// Content items of the resource.
    pub contents: Vec<ResourceContentItem>,
    /// Optional `_meta` object; omitted from serialized output when `None`.
    #[serde(default, rename = "_meta", skip_serializing_if = "Option::is_none")]
    pub meta: Option<Value>,
}
/// Outcome of forwarding a raw request body to the MCP server.
///
/// Derives `Debug` and `Clone` so callers can log and duplicate results.
#[derive(Debug, Clone)]
pub struct RawForwardResult {
    /// Response body exactly as returned by the server.
    pub body: String,
    /// Value of the `mcp-session-id` response header, if present.
    pub session_id: Option<String>,
    /// Value of the `mcp-protocol-version` response header, if present.
    pub protocol_version: Option<String>,
}
/// Name of the HTTP header that carries the MCP session identifier.
pub(crate) const MCP_SESSION_ID: &str = "mcp-session-id";
/// Name of the HTTP header that carries the MCP protocol version.
pub(crate) const MCP_PROTOCOL_VERSION: &str = "mcp-protocol-version";
/// Returns the value of header `name` as an owned string.
///
/// Yields `None` when the header is absent or when `HeaderValue::to_str` rejects its value.
fn extract_header(headers: &reqwest::header::HeaderMap, name: &str) -> Option<String> {
    let raw = headers.get(name)?;
    match raw.to_str() {
        Ok(text) => Some(String::from(text)),
        Err(_) => None,
    }
}
/// Passes a successful HTTP response through and turns any other status into an error.
///
/// # Errors
///
/// Fails with the status code and response body when the status is not a success status.
async fn check_response(response: reqwest::Response) -> Result<reqwest::Response> {
    let status = response.status();
    if status.is_success() {
        return Ok(response);
    }
    // The body only enriches the error message, so a failed read degrades to an empty string.
    let detail = response.text().await.unwrap_or_default();
    Err(anyhow::anyhow!("MCP server returned {}: {}", status, detail))
}
async fn parse_rpc_response(response: reqwest::Response) -> Result<JsonRpcResponse> {
let is_sse = response
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok())
.is_some_and(|ct| ct.contains("text/event-stream"));
if is_sse {
let body = response.text().await?;
let json_str = body
.lines()
.find_map(|line| line.strip_prefix("data: "))
.unwrap_or(&body);
Ok(serde_json::from_str(json_str)?)
} else {
Ok(response.json().await?)
}
}
/// HTTP client for a single MCP server endpoint.
///
/// It initializes a session on first use and then issues JSON-RPC calls against `base_url`.
pub struct McpProxy {
    /// Endpoint URL with trailing slashes removed.
    base_url: String,
    /// HTTP client shared by all requests made through this proxy.
    client: reqwest::Client,
    /// Source of JSON-RPC request ids; starts at 1.
    request_id: AtomicU64,
    /// `None` until the `initialize` handshake has succeeded (or after a reset).
    session: RwLock<Option<SessionInfo>>,
    /// Value sent verbatim as the `Authorization` header, when configured.
    auth_header: Option<String>,
}
impl McpProxy {
#[allow(dead_code)]
pub fn new(base_url: &str) -> Self {
Self::new_with_auth(base_url, None)
}
pub fn new_with_auth(base_url: &str, auth_header: Option<String>) -> Self {
Self {
base_url: base_url.trim_end_matches('/').to_string(),
client: reqwest::Client::new(),
request_id: AtomicU64::new(1),
session: RwLock::new(None),
auth_header,
}
}
fn next_id(&self) -> u64 {
self.request_id.fetch_add(1, Ordering::Relaxed)
}
fn mcp_post(&self) -> reqwest::RequestBuilder {
let mut builder = self
.client
.post(&self.base_url)
.header("Accept", "application/json, text/event-stream")
.header("Content-Type", "application/json");
if let Some(ref auth) = self.auth_header {
builder = builder.header("Authorization", auth);
}
builder
}
async fn attach_session_id(
&self,
mut builder: reqwest::RequestBuilder,
) -> reqwest::RequestBuilder {
let guard = self.session.read().await;
if let Some(ref session) = *guard {
if let Some(ref sid) = session.session_id {
builder = builder.header(MCP_SESSION_ID, sid);
}
}
builder
}
async fn ensure_initialized(&self) -> Result<()> {
{
let guard = self.session.read().await;
if guard.is_some() {
return Ok(());
}
}
let mut guard = self.session.write().await;
if guard.is_some() {
return Ok(());
}
let params = json!({
"protocolVersion": "2024-11-05",
"capabilities": {
"roots": { "listChanged": false },
"sampling": {}
},
"clientInfo": {
"name": "mcp-preview",
"version": "0.1.0"
}
});
let request_body = JsonRpcRequest {
jsonrpc: "2.0",
method: "initialize".to_string(),
params: Some(params),
id: self.next_id(),
};
let response = check_response(self.mcp_post().json(&request_body).send().await?).await?;
let session_id = extract_header(response.headers(), MCP_SESSION_ID);
let rpc_response: JsonRpcResponse = parse_rpc_response(response).await?;
if let Some(error) = rpc_response.error {
anyhow::bail!("MCP initialize error: {}", error.message);
}
let server_info = rpc_response.result.unwrap_or(Value::Null);
*guard = Some(SessionInfo {
session_id,
server_info,
});
drop(guard);
let _ = self.send_notification("notifications/initialized").await;
Ok(())
}
async fn send_request(&self, method: &str, params: Option<Value>) -> Result<Value> {
let request = JsonRpcRequest {
jsonrpc: "2.0",
method: method.to_string(),
params,
id: self.next_id(),
};
let req_builder = self.attach_session_id(self.mcp_post().json(&request)).await;
let response = check_response(req_builder.send().await?).await?;
let rpc_response: JsonRpcResponse = parse_rpc_response(response).await?;
if let Some(error) = rpc_response.error {
anyhow::bail!("MCP error: {}", error.message);
}
Ok(rpc_response.result.unwrap_or(Value::Null))
}
async fn send_notification(&self, method: &str) -> Result<()> {
let notification = JsonRpcNotification {
jsonrpc: "2.0",
method: method.to_string(),
};
let req_builder = self
.attach_session_id(self.mcp_post().json(¬ification))
.await;
let _ = req_builder.send().await;
Ok(())
}
pub async fn reset_session(&self) {
let mut guard = self.session.write().await;
*guard = None;
}
pub async fn is_connected(&self) -> bool {
let guard = self.session.read().await;
guard.is_some()
}
pub async fn list_tools(&self) -> Result<Vec<ToolInfo>> {
self.ensure_initialized().await?;
let result = self.send_request("tools/list", None).await?;
let tools: Vec<ToolInfo> =
serde_json::from_value(result.get("tools").cloned().unwrap_or(Value::Array(vec![])))?;
Ok(tools)
}
pub async fn call_tool(&self, name: &str, arguments: Value) -> Result<ToolCallResult> {
self.ensure_initialized().await?;
let params = json!({
"name": name,
"arguments": arguments
});
let result = self.send_request("tools/call", Some(params)).await;
match result {
Ok(value) => {
let content: Vec<ContentItem> = serde_json::from_value(
value
.get("content")
.cloned()
.unwrap_or(Value::Array(vec![])),
)
.unwrap_or_default();
let meta = value.get("_meta").cloned();
let structured_content = value.get("structuredContent").cloned();
Ok(ToolCallResult {
success: true,
content: Some(content),
error: None,
structured_content,
meta,
})
},
Err(e) => Ok(ToolCallResult {
success: false,
content: None,
error: Some(e.to_string()),
structured_content: None,
meta: None,
}),
}
}
pub async fn list_resources(&self) -> Result<Vec<ResourceInfo>> {
self.ensure_initialized().await?;
let result = self.send_request("resources/list", None).await?;
let resources: Vec<ResourceInfo> = serde_json::from_value(
result
.get("resources")
.cloned()
.unwrap_or(Value::Array(vec![])),
)?;
Ok(resources)
}
pub async fn read_resource(&self, uri: &str) -> Result<ResourceReadResult> {
self.ensure_initialized().await?;
let params = json!({ "uri": uri });
let result = self.send_request("resources/read", Some(params)).await?;
let contents: Vec<ResourceContentItem> = serde_json::from_value(
result
.get("contents")
.cloned()
.unwrap_or(Value::Array(vec![])),
)?;
let meta = result.get("_meta").cloned();
Ok(ResourceReadResult { contents, meta })
}
pub async fn forward_raw(
&self,
body: String,
session_id: Option<&str>,
protocol_version: Option<&str>,
) -> Result<RawForwardResult> {
let mut req_builder = self.mcp_post().body(body);
if let Some(sid) = session_id {
req_builder = req_builder.header(MCP_SESSION_ID, sid);
}
if let Some(ver) = protocol_version {
req_builder = req_builder.header(MCP_PROTOCOL_VERSION, ver);
}
let response = check_response(req_builder.send().await?).await?;
let session_id = extract_header(response.headers(), MCP_SESSION_ID);
let protocol_version = extract_header(response.headers(), MCP_PROTOCOL_VERSION);
let body = response.text().await?;
Ok(RawForwardResult {
body,
session_id,
protocol_version,
})
}
}