use dashmap::DashMap;
use indexmap::IndexMap;
use objectiveai_sdk::client_objectiveai_mcp::{server_request, server_response};
use objectiveai_sdk::http::McpHandler;
use std::sync::Arc;
use std::time::Duration;
/// Per-session state cached by the conduit: one live upstream MCP connection.
///
/// Instances are shared as `Arc<ConduitState>` in `Inner::connections`, keyed by session id.
struct ConduitState {
/// Upstream connection; `forward` reads its `session_id`, `url`, `http_client` and cached
/// `initialize_result`.
connection: objectiveai_sdk::mcp::Connection,
}
/// `McpHandler` that relays inbound MCP HTTP requests to a configured upstream MCP server.
///
/// All state lives behind one shared `Arc`, so clones are cheap and share the same
/// connection cache.
#[derive(Clone)]
pub struct ConduitMcpHandler {
inner: Arc<Inner>,
}
/// Shared state behind every clone of `ConduitMcpHandler`.
struct Inner {
/// Upstream MCP server URL; `None` means no server is configured and `handle` rejects
/// every request with a 501.
mcp_url: Option<String>,
/// SDK client used by `dial` to open or resume upstream connections.
client: objectiveai_sdk::mcp::Client,
/// Open upstream connections keyed by MCP session id.
/// NOTE(review): nothing visible here ever removes entries — confirm sessions are evicted
/// elsewhere, otherwise this map grows for the lifetime of the process.
connections: DashMap<String, Arc<ConduitState>>,
}
impl ConduitMcpHandler {
    /// Creates a handler that relays to `mcp_url`, or rejects every request when it is `None`.
    ///
    /// # Panics
    /// Panics if `reqwest` fails to build its HTTP client.
    pub fn new(mcp_url: Option<String>) -> Self {
        let http_client = reqwest::Client::builder()
            .build()
            .expect("reqwest::Client::build is infallible without rustls toggles");
        // The trailing arguments are positional tuning values (durations and two floats);
        // see `objectiveai_sdk::mcp::Client::new` for what each one controls.
        let sdk_client = objectiveai_sdk::mcp::Client::new(
            http_client,
            String::from("objectiveai-cli-conduit"),
            String::new(),
            String::new(),
            Duration::from_secs(30),
            Duration::from_secs(1),
            Duration::from_secs(1),
            0.5,
            2.0,
            Duration::from_secs(30),
            Duration::from_secs(30),
            Duration::from_secs(60),
        );
        let inner = Inner {
            mcp_url,
            client: sdk_client,
            connections: DashMap::new(),
        };
        Self {
            inner: Arc::new(inner),
        }
    }
}
impl McpHandler for ConduitMcpHandler {
async fn handle(&self, request: server_request::Request) -> server_response::Response {
let id_for_err = request.id.clone();
let server_request_id = request.id.clone();
let Some(mcp_url) = self.inner.mcp_url.as_ref() else {
return reject_no_mcp(id_for_err);
};
let incoming_session_id: Option<String> = request
.headers
.iter()
.find_map(|(k, v)| {
k.eq_ignore_ascii_case("mcp-session-id").then(|| v.clone())
});
let state = match &incoming_session_id {
Some(sid) => {
if let Some(existing) = self.inner.connections.get(sid) {
existing.clone()
} else {
let dial_result = self
.dial(mcp_url.clone(), Some(sid.clone()), &request.headers)
.await;
match dial_result {
Ok(st) => {
self.inner.connections.insert(sid.clone(), st.clone());
st
}
Err(e) => {
return conduit_error(id_for_err, format!("connect (resume): {e}"));
}
}
}
}
None => {
let dial_result = self.dial(mcp_url.clone(), None, &request.headers).await;
match dial_result {
Ok(st) => {
self.inner
.connections
.insert(st.connection.session_id.clone(), st.clone());
st
}
Err(e) => {
return conduit_error(id_for_err, format!("connect: {e}"));
}
}
}
};
let forward_result = forward(&state, request).await;
let resp = match forward_result {
Ok(resp) => resp,
Err(e) => conduit_error(id_for_err, e.to_string()),
};
resp
}
}
impl ConduitMcpHandler {
    /// Opens an upstream MCP connection to `url` through the shared SDK client.
    ///
    /// `session_id` is passed to `Client::connect` unchanged: `handle` supplies `Some` when
    /// the caller presented an `Mcp-Session-Id` it has no cached connection for, and `None`
    /// for a brand-new session. The caller's request headers are forwarded after
    /// `sanitize_connect_headers` has filtered them.
    ///
    /// # Errors
    /// Returns the error from `Client::connect` (propagated with `?`).
    async fn dial(
        &self,
        url: String,
        session_id: Option<String>,
        request_headers: &IndexMap<String, String>,
    ) -> Result<Arc<ConduitState>, objectiveai_sdk::mcp::Error> {
        let client = &self.inner.client;
        let extra_headers = sanitize_connect_headers(request_headers);
        let connection = client
            .connect(url, session_id, Some(extra_headers))
            .await?;
        let state = ConduitState { connection };
        Ok(Arc::new(state))
    }
}
/// Builds the extra headers handed to `Client::connect` from the caller's request headers.
///
/// Copies every header except `Host`, `Content-Length`, `Connection` and `Mcp-Session-Id`.
/// The first three describe the inbound hop rather than the upstream call. The session id
/// reaches `connect` through its own `session_id` argument in `dial`.
///
/// Names are compared ASCII-case-insensitively, as HTTP header names are, so spellings such
/// as `MCP-SESSION-ID` or `HOST` are removed as well. Remaining headers keep their order.
fn sanitize_connect_headers(
    headers: &IndexMap<String, String>,
) -> IndexMap<String, String> {
    const STRIPPED: [&str; 4] = ["host", "content-length", "connection", "mcp-session-id"];
    headers
        .iter()
        .filter(|(name, _)| !STRIPPED.iter().any(|s| name.eq_ignore_ascii_case(s)))
        .map(|(name, value)| (name.clone(), value.clone()))
        .collect()
}
/// Relays one JSON-RPC message from the caller to the upstream MCP server behind `state`.
///
/// - **No top-level `"id"`**: the message is answered locally with `202` and an empty body
///   and is NOT sent upstream. This covers notifications and a missing body. It also covers
///   JSON array batches, because `Value::get("id")` on an array yields `None`.
///   NOTE(review): confirm that dropping batches is intended.
/// - **`initialize`**: answered locally from the connection's cached `initialize_result`.
///   `listChanged` is removed from the `tools` and `resources` capabilities, and the
///   upstream session id is returned in `Mcp-Session-Id`.
///   NOTE(review): `listChanged` is presumably removed because this relay does not deliver
///   upstream list-changed notifications — confirm.
/// - **Anything else**: POSTed to the upstream URL.
///   - The caller's headers are sent, minus the ones set here.
///   - The upstream status and headers are relayed back, minus the ones in `NOT_RELAYED`.
///   - The body is decoded with `parse_json_or_sse`.
///
/// # Errors
/// - `ConduitError::Serialize` if the cached `InitializeResult` cannot be converted to JSON.
/// - `ConduitError::Request` if the upstream POST fails.
/// - `ConduitError::Body` if the upstream body cannot be read.
async fn forward(
    state: &ConduitState,
    request: server_request::Request,
) -> Result<server_response::Response, ConduitError> {
    // Caller headers that this function sets itself, or that only describe the inbound hop.
    const NOT_FORWARDED: [&str; 6] = [
        "host",
        "content-length",
        "connection",
        "accept",
        "content-type",
        "mcp-session-id",
    ];
    // Upstream headers not copied back: the framing/type headers no longer describe the
    // re-decoded body, and the session id header is left out as well.
    const NOT_RELAYED: [&str; 4] = [
        "mcp-session-id",
        "content-type",
        "transfer-encoding",
        "content-length",
    ];

    // Borrow the body rather than cloning it; only `request.id` is moved out below.
    let envelope = request.body.as_ref();
    let Some(rpc_id) = envelope.and_then(|v| v.get("id")).cloned() else {
        return Ok(server_response::Response {
            id: request.id,
            status: 202,
            headers: IndexMap::new(),
            body: None,
        });
    };
    let rpc_method = envelope
        .and_then(|v| v.get("method"))
        .and_then(serde_json::Value::as_str);

    if rpc_method == Some("initialize") {
        let mut init_value = serde_json::to_value(&state.connection.initialize_result)
            .map_err(ConduitError::Serialize)?;
        if let Some(caps) = init_value
            .pointer_mut("/capabilities")
            .and_then(serde_json::Value::as_object_mut)
        {
            for capability in ["tools", "resources"] {
                if let Some(cap) = caps
                    .get_mut(capability)
                    .and_then(serde_json::Value::as_object_mut)
                {
                    cap.remove("listChanged");
                }
            }
        }
        let body = serde_json::json!({
            "jsonrpc": "2.0",
            "id": rpc_id,
            "result": init_value,
        });
        let mut headers = IndexMap::new();
        headers.insert(
            "Mcp-Session-Id".to_string(),
            state.connection.session_id.clone(),
        );
        return Ok(server_response::Response {
            id: request.id,
            status: 200,
            headers,
            body: Some(body),
        });
    }

    let conn = &state.connection;
    let mut req = conn.http_client.post(&conn.url);
    for (k, v) in &request.headers {
        if NOT_FORWARDED.iter().any(|h| k.eq_ignore_ascii_case(h)) {
            continue;
        }
        req = req.header(k, v);
    }
    req = req
        .header("Content-Type", "application/json")
        .header("Accept", "application/json, text/event-stream")
        .header("Mcp-Session-Id", &conn.session_id);
    if let Some(body) = envelope {
        req = req.json(body);
    }

    let resp = req.send().await.map_err(ConduitError::Request)?;
    let status = resp.status().as_u16();
    let mut resp_headers = IndexMap::new();
    for (k, v) in resp.headers() {
        let name = k.as_str();
        if NOT_RELAYED.iter().any(|h| name.eq_ignore_ascii_case(h)) {
            continue;
        }
        // Header values that are not visible ASCII are skipped rather than failing the relay.
        if let Ok(value) = v.to_str() {
            resp_headers.insert(name.to_string(), value.to_string());
        }
    }
    let resp_text = resp.text().await.map_err(ConduitError::Body)?;
    Ok(server_response::Response {
        id: request.id,
        status,
        headers: resp_headers,
        body: parse_json_or_sse(&resp_text),
    })
}
/// Builds the 501 reply used when no upstream MCP server is configured.
///
/// The body is a JSON-RPC error (`-32601`) with a `null` id that tells the user how to
/// configure an address; `id` is the conduit-level request id echoed back to the transport.
fn reject_no_mcp(id: String) -> server_response::Response {
    let message = "this client has no MCP server configured (set `objectiveai mcp address`)";
    let body = serde_json::json!({
        "jsonrpc": "2.0",
        "id": serde_json::Value::Null,
        "error": {
            "code": -32601,
            "message": message,
        },
    });
    server_response::Response {
        id,
        status: 501,
        headers: IndexMap::new(),
        body: Some(body),
    }
}
fn conduit_error(id: String, message: impl Into<String>) -> server_response::Response {
let message = message.into();
server_response::Response {
id,
status: 502,
headers: IndexMap::new(),
body: Some(serde_json::json!({
"jsonrpc": "2.0",
"id": serde_json::Value::Null,
"error": {
"code": -32603,
"message": format!("conduit: {message}"),
},
})),
}
}
/// Decodes an upstream response body that is either plain JSON or a `text/event-stream`.
///
/// - An empty body yields `None`.
/// - A body that parses as JSON as a whole is returned unchanged.
/// - Otherwise the text is read as Server-Sent Events:
///   - Events are separated by blank lines.
///   - The `data:` lines of one event are joined with `\n`, per the SSE spec.
///   - Each event's data is parsed as JSON, and events that do not parse are ignored.
///   - The last event that looks like a JSON-RPC response (has `result` or `error`) is
///     returned. Failing that, the last event that parsed at all is returned.
///
/// This keeps the final reply when a server streams other messages (e.g. progress
/// notifications) ahead of it; those earlier messages are not relayed. Returns `None` when
/// no event parses.
fn parse_json_or_sse(text: &str) -> Option<serde_json::Value> {
    if text.is_empty() {
        return None;
    }
    if let Ok(v) = serde_json::from_str::<serde_json::Value>(text) {
        return Some(v);
    }
    let mut last_response = None;
    let mut last_other = None;
    let mut data = String::new();
    let mut has_data = false;
    // A synthetic trailing blank line dispatches the final event even when the stream does
    // not end with one.
    for line in text.lines().chain(std::iter::once("")) {
        if line.is_empty() {
            if has_data {
                if let Ok(v) = serde_json::from_str::<serde_json::Value>(&data) {
                    if v.get("result").is_some() || v.get("error").is_some() {
                        last_response = Some(v);
                    } else {
                        last_other = Some(v);
                    }
                }
            }
            data.clear();
            has_data = false;
        } else if let Some(payload) = line.strip_prefix("data:") {
            // SSE strips exactly one optional space after the field colon.
            let payload = payload.strip_prefix(' ').unwrap_or(payload);
            if has_data {
                data.push('\n');
            }
            data.push_str(payload);
            has_data = true;
        }
        // Other SSE fields (`event:`, `id:`, `retry:`) and `:` comments are ignored.
    }
    last_response.or(last_other)
}
/// Builds the conduit handler, resolving the upstream MCP URL from the environment or config.
///
/// Resolution order:
/// 1. If `OBJECTIVEAI_MCP_ADDRESS` is set (and valid Unicode), it is used verbatim.
/// 2. Otherwise the URL comes from `crate::api::client::compose_url`. It is given the
///    config's MCP address and a port: `OBJECTIVEAI_MCP_PORT` when that parses as a `u16`,
///    else the config's port.
///
/// If `compose_url` returns `None`, the handler is built without a URL and rejects every
/// request.
pub fn build_handler(
    config: &mut objectiveai_sdk::filesystem::config::Config,
) -> ConduitMcpHandler {
    let mcp_url = match std::env::var("OBJECTIVEAI_MCP_ADDRESS") {
        Ok(address) => Some(address),
        Err(_) => {
            let mcp = config.mcp();
            let env_port = std::env::var("OBJECTIVEAI_MCP_PORT")
                .ok()
                .and_then(|raw| raw.parse::<u16>().ok());
            let port = match env_port {
                Some(port) => Some(port),
                None => mcp.get_port(),
            };
            crate::api::client::compose_url(mcp.get_address(), port)
        }
    };
    ConduitMcpHandler::new(mcp_url)
}
/// Errors raised by `forward` while relaying a request upstream.
///
/// `handle` turns any of these into a 502 JSON-RPC error using the `Display` text.
/// NOTE(review): no field is marked `#[source]`, so `Error::source()` returns `None`; the
/// inner error is only visible through the message.
#[derive(Debug, thiserror::Error)]
enum ConduitError {
/// The upstream HTTP `send()` failed.
#[error("forwarding HTTP request failed: {0}")]
Request(reqwest::Error),
/// The upstream response body could not be read as text.
#[error("reading response body failed: {0}")]
Body(reqwest::Error),
/// The connection's cached `InitializeResult` could not be converted to JSON.
#[error("serializing InitializeResult failed: {0}")]
Serialize(serde_json::Error),
}