use super::send::send_server_request;
use crate::objectiveai_mcp::context::McpRequestContext;
use axum::http::HeaderMap;
use indexmap::IndexMap;
use objectiveai_sdk::client_objectiveai_mcp::McpKind;
use objectiveai_sdk::client_objectiveai_mcp::server_request::InitializeRequest;
use objectiveai_sdk::client_objectiveai_mcp::{server_request, server_response};
use objectiveai_sdk::mcp::initialize_result::InitializeResult;
use objectiveai_sdk::mcp::resource::{
ListResourcesRequest, ListResourcesResult, ReadResourceRequestParams,
ReadResourceResult,
};
use objectiveai_sdk::mcp::tool::{
CallToolRequestParams, CallToolResult, ListToolsRequest, ListToolsResult,
};
use std::time::Duration;
/// Upper bound on how long `forward` waits for the reverse channel to deliver
/// the response once the request has been handed to `send_server_request`.
const FORWARD_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Error returned by the handlers in this module, shaped like a JSON-RPC
/// error object (`code`, `message`, optional `data`).
///
/// Values are either built locally for reverse-channel failures (via the
/// associated constructors) or copied field-for-field from a
/// `JsonRpcResult::Err` reported by the remote side (see `unwrap_rpc`).
#[derive(Debug)]
pub struct McpError {
    /// JSON-RPC error code.
    pub code: i64,
    /// Human-readable description of the failure.
    pub message: String,
    /// Optional structured error details; the local constructors always leave
    /// this as `None`.
    pub data: Option<serde_json::Value>,
}
impl McpError {
    /// Builds an error carrying `code` and `message` and no `data` payload.
    fn bare(code: i64, message: String) -> Self {
        Self {
            code,
            message,
            data: None,
        }
    }

    /// No reverse channel is registered under `id` (code `-32001`).
    pub fn no_session(id: &str) -> Self {
        Self::bare(-32001, format!("no reverse channel for response_id {id:?}"))
    }

    /// Sending the request into the reverse channel failed (code `-32002`).
    ///
    /// Shares its code with [`McpError::reverse_channel_dropped`]; only the
    /// message tells the two apart.
    pub fn reverse_channel_closed() -> Self {
        Self::bare(
            -32002,
            String::from("reverse channel closed before request could be sent"),
        )
    }

    /// The response receiver was closed before a reply arrived (code `-32002`).
    pub fn reverse_channel_dropped() -> Self {
        Self::bare(
            -32002,
            String::from("reverse channel dropped before response arrived"),
        )
    }

    /// No reply arrived within the forwarding deadline (code `-32003`).
    pub fn reverse_channel_timeout() -> Self {
        Self::bare(
            -32003,
            String::from("reverse channel timed out waiting for response"),
        )
    }

    /// The remote side answered with a payload variant other than `expected`.
    ///
    /// Uses `-32603`, the JSON-RPC 2.0 "Internal error" code.
    pub fn variant_mismatch(expected: &str, got: &server_response::Payload) -> Self {
        let got_name = payload_variant_name(got);
        Self::bare(
            -32603,
            format!(
                "reverse channel returned wrong payload variant: expected {expected}, got {}",
                got_name,
            ),
        )
    }
}
fn payload_variant_name(p: &server_response::Payload) -> &'static str {
use server_response::Payload as P;
match p {
P::Initialize { .. } => "initialize",
P::ToolsList { .. } => "tools_list",
P::ToolsCall { .. } => "tools_call",
P::ResourcesList { .. } => "resources_list",
P::ResourcesRead { .. } => "resources_read",
P::SessionTerminate { .. } => "session_terminate",
P::ReadMessageQueue(_) => "read_message_queue",
}
}
/// Forwards an initialize request for `mcp_kind` (with the caller's `args`)
/// over the reverse channel.
///
/// On success returns the remote `InitializeResult` together with the
/// `mcp_session_id` reported in the reply.
///
/// # Errors
/// Any error from `forward`, a JSON-RPC error reported by the remote side, or
/// a payload-variant mismatch.
pub async fn handle_initialize(
    ctx: McpRequestContext,
    mcp_kind: McpKind,
    args: IndexMap<String, Option<String>>,
) -> Result<(InitializeResult, String), McpError> {
    let request = server_request::Payload::Initialize {
        mcp_kind,
        params: InitializeRequest { args },
    };
    let response = forward(&ctx, request).await?;
    match response.payload {
        server_response::Payload::Initialize { result, .. } => {
            unwrap_rpc(result).map(|reply| (reply.result, reply.mcp_session_id))
        }
        unexpected => Err(McpError::variant_mismatch("initialize", &unexpected)),
    }
}
/// Answers a ping locally.
///
/// Nothing is sent over the reverse channel and this handler never fails.
pub async fn handle_ping(_context: McpRequestContext) -> Result<(), McpError> {
    Ok(())
}
/// Forwards a tools-list request for `mcp_kind` over the reverse channel and
/// returns the remote listing.
///
/// # Errors
/// Any error from `forward`, a JSON-RPC error reported by the remote side, or
/// a payload-variant mismatch.
pub async fn handle_tools_list(
    ctx: McpRequestContext,
    mcp_kind: McpKind,
    params: ListToolsRequest,
) -> Result<ListToolsResult, McpError> {
    let request = server_request::Payload::ToolsList { mcp_kind, params };
    let response = forward(&ctx, request).await?;
    match response.payload {
        server_response::Payload::ToolsList { result, .. } => unwrap_rpc(result),
        unexpected => Err(McpError::variant_mismatch("tools_list", &unexpected)),
    }
}
/// Forwards a tool-call request for `mcp_kind` over the reverse channel and
/// returns the remote tool's result.
///
/// # Errors
/// Any error from `forward`, a JSON-RPC error reported by the remote side, or
/// a payload-variant mismatch.
pub async fn handle_tools_call(
    ctx: McpRequestContext,
    mcp_kind: McpKind,
    params: CallToolRequestParams,
) -> Result<CallToolResult, McpError> {
    let request = server_request::Payload::ToolsCall { mcp_kind, params };
    let response = forward(&ctx, request).await?;
    match response.payload {
        server_response::Payload::ToolsCall { result, .. } => unwrap_rpc(result),
        unexpected => Err(McpError::variant_mismatch("tools_call", &unexpected)),
    }
}
/// Forwards a resources-list request for `mcp_kind` over the reverse channel
/// and returns the remote listing.
///
/// # Errors
/// Any error from `forward`, a JSON-RPC error reported by the remote side, or
/// a payload-variant mismatch.
pub async fn handle_resources_list(
    ctx: McpRequestContext,
    mcp_kind: McpKind,
    params: ListResourcesRequest,
) -> Result<ListResourcesResult, McpError> {
    let request = server_request::Payload::ResourcesList { mcp_kind, params };
    let response = forward(&ctx, request).await?;
    match response.payload {
        server_response::Payload::ResourcesList { result, .. } => unwrap_rpc(result),
        unexpected => Err(McpError::variant_mismatch("resources_list", &unexpected)),
    }
}
/// Forwards a resource-read request for `mcp_kind` over the reverse channel
/// and returns the remote resource contents.
///
/// # Errors
/// Any error from `forward`, a JSON-RPC error reported by the remote side, or
/// a payload-variant mismatch.
pub async fn handle_resources_read(
    ctx: McpRequestContext,
    mcp_kind: McpKind,
    params: ReadResourceRequestParams,
) -> Result<ReadResourceResult, McpError> {
    let request = server_request::Payload::ResourcesRead { mcp_kind, params };
    let response = forward(&ctx, request).await?;
    match response.payload {
        server_response::Payload::ResourcesRead { result, .. } => unwrap_rpc(result),
        unexpected => Err(McpError::variant_mismatch("resources_read", &unexpected)),
    }
}
/// Forwards a session-terminate request for `mcp_kind` over the reverse
/// channel.
///
/// # Errors
/// Any error from `forward`, a JSON-RPC error reported by the remote side, or
/// a payload-variant mismatch.
pub async fn handle_session_terminate(
    ctx: McpRequestContext,
    mcp_kind: McpKind,
) -> Result<(), McpError> {
    let request = server_request::Payload::SessionTerminate { mcp_kind };
    let response = forward(&ctx, request).await?;
    match response.payload {
        server_response::Payload::SessionTerminate { result, .. } => unwrap_rpc(result),
        unexpected => Err(McpError::variant_mismatch("session_terminate", &unexpected)),
    }
}
/// Sends `payload` over the reverse channel registered for `ctx.response_id`
/// and waits up to `FORWARD_TIMEOUT` for the matching response.
///
/// The outgoing request gets a fresh UUIDv4 id and the filtered inbound
/// headers (see `forward_headers`).
///
/// # Errors
/// * `no_session` when no channel is registered for `ctx.response_id`;
/// * `reverse_channel_closed` when `send_server_request` fails;
/// * `reverse_channel_timeout` when no reply arrives within the deadline;
/// * `reverse_channel_dropped` when the reply receiver resolves to an error.
async fn forward(
    ctx: &McpRequestContext,
    payload: server_request::Payload,
) -> Result<server_response::Response, McpError> {
    // Copy the channel handle out of the registry; whatever `get` returns
    // (possibly a lock guard — TODO confirm registry type) is released at the
    // end of this statement, before any `.await` below.
    let channel = match ctx.registry.get(&ctx.response_id) {
        Some(entry) => entry.clone(),
        None => return Err(McpError::no_session(&ctx.response_id)),
    };

    let request = server_request::Request {
        id: uuid::Uuid::new_v4().to_string(),
        headers: forward_headers(&ctx.headers),
        payload,
    };

    let reply_rx = send_server_request(&channel.sink, &channel.pending, request)
        .await
        .map_err(|_| McpError::reverse_channel_closed())?;

    // NOTE(review): when the deadline fires, nothing here removes this
    // request's entry from `channel.pending`; confirm that
    // `send_server_request` or the reverse-channel reader cleans it up,
    // otherwise timed-out entries may accumulate.
    tokio::time::timeout(FORWARD_TIMEOUT, reply_rx)
        .await
        .map_err(|_| McpError::reverse_channel_timeout())?
        .map_err(|_| McpError::reverse_channel_dropped())
}
fn unwrap_rpc<R>(
r: server_response::JsonRpcResult<R>,
) -> Result<R, McpError> {
match r {
server_response::JsonRpcResult::Ok { result } => Ok(result),
server_response::JsonRpcResult::Err {
code,
message,
data,
} => Err(McpError {
code,
message,
data,
}),
}
}
fn forward_headers(headers: &HeaderMap) -> IndexMap<String, String> {
headers
.iter()
.filter_map(|(k, v)| {
let name = k.as_str();
let drop = matches!(
name.to_ascii_lowercase().as_str(),
"host"
| "content-length"
| "connection"
| "accept"
| "content-type"
);
if drop {
return None;
}
Some((name.to_string(), v.to_str().ok()?.to_string()))
})
.collect()
}