use dashmap::DashMap;
use indexmap::IndexMap;
use objectiveai_sdk::Notifier;
use objectiveai_sdk::cli::command::plugins::run::Mcp as PluginMcp;
use objectiveai_sdk::client_objectiveai_mcp::McpKind;
use objectiveai_sdk::client_objectiveai_mcp::client_request::{McpListChanged, McpListChangedKind};
use objectiveai_sdk::client_objectiveai_mcp::server_response::{InitializeReply, JsonRpcResult};
use objectiveai_sdk::client_objectiveai_mcp::{server_request, server_response};
use objectiveai_sdk::http::McpHandler;
use objectiveai_sdk::mcp::resource::{
ListResourcesRequest, ListResourcesResult, ReadResourceRequestParams, ReadResourceResult,
};
use objectiveai_sdk::mcp::tool::{
CallToolRequestParams, CallToolResult, ListToolsRequest, ListToolsResult,
};
use std::sync::{Arc, OnceLock};
use std::time::Duration;
/// Cached state for one upstream MCP session proxied by the conduit.
///
/// Stored in `Inner::connections`, keyed by the session id, by
/// `dispatch_initialize` and `resolve_connection`.
struct ConduitState {
    /// Live connection to the upstream MCP server (in-process objectiveai-mcp or a plugin).
    connection: objectiveai_sdk::mcp::Connection,
    /// MCP this session belongs to; used as the `mcp_kind` of tools/resources response payloads.
    mcp_kind: McpKind,
    /// Agent instance hierarchy taken from the transient headers when the session was cached.
    /// NOTE(review): not read anywhere in this file.
    agent_instance_hierarchy: String,
}
/// `McpHandler` that proxies MCP requests to upstream servers, caching one
/// upstream connection per `Mcp-Session-Id`.
///
/// Cloning is cheap: every clone shares the same `Inner` through an `Arc`.
#[derive(Clone)]
pub struct ConduitMcpHandler {
    inner: Arc<Inner>,
}
/// Shared state behind every clone of `ConduitMcpHandler`.
struct Inner {
    /// Handle to the in-process objectiveai-mcp server; its `port` is awaited by `objectiveai_mcp_url`.
    mcp_server: crate::websockets::mcp_server::McpServerHandle,
    /// MCP client used to dial every upstream connection.
    client: objectiveai_sdk::mcp::Client,
    /// Upstream connections keyed by `Mcp-Session-Id`.
    connections: DashMap<String, Arc<ConduitState>>,
    /// Set at most once by `install_notifier`; list-changed callbacks do nothing until it is set.
    notifier: OnceLock<Notifier>,
    /// CLI context: supplies `db` for message-queue reads and the base config cloned for plugin spawns.
    ctx: crate::context::Context,
    /// Optional agent tag passed to `read_pending_and_upgrade_tag`.
    agent_tag: Option<String>,
}
impl ConduitMcpHandler {
    /// Creates a handler with an empty session cache and no notifier.
    ///
    /// `mcp_server` is the in-process objectiveai-mcp server dialed for
    /// `McpKind::ObjectiveAi` sessions, `ctx` supplies the database and the
    /// base config for plugin spawns, and `agent_tag` is passed to
    /// message-queue reads.
    ///
    /// # Panics
    /// If the underlying `reqwest` client cannot be built; the `expect`
    /// message records why this is not expected to happen.
    pub fn new(
        mcp_server: crate::websockets::mcp_server::McpServerHandle,
        ctx: crate::context::Context,
        agent_tag: Option<String>,
    ) -> Self {
        // NOTE(review): the positional Duration/f64 arguments are the SDK
        // client's timing settings (likely timeouts and a backoff schedule);
        // confirm their meaning against `objectiveai_sdk::mcp::Client::new`.
        let client = objectiveai_sdk::mcp::Client::new(
            reqwest::Client::builder()
                .build()
                .expect("reqwest::Client::build is infallible without rustls toggles"),
            "objectiveai-cli-stream-conduit".to_string(),
            String::new(),
            String::new(),
            Duration::from_secs(30),
            Duration::from_secs(1),
            Duration::from_secs(1),
            0.5,
            2.0,
            Duration::from_secs(30),
            Duration::from_secs(30),
            Duration::from_secs(60),
        );
        let shared = Inner {
            mcp_server,
            client,
            connections: DashMap::new(),
            notifier: OnceLock::new(),
            ctx,
            agent_tag,
        };
        Self {
            inner: Arc::new(shared),
        }
    }

    /// Installs the notifier that upstream list-changed events are forwarded to.
    ///
    /// Only the first call has an effect; a notifier passed to any later call
    /// is dropped.
    pub fn install_notifier(&self, notifier: Notifier) {
        // `OnceLock::set` returns the rejected value when already initialized;
        // discarding it keeps the first notifier in place.
        let _ = self.inner.notifier.set(notifier);
    }
}
impl McpHandler for ConduitMcpHandler {
async fn handle(&self, request: server_request::Request) -> server_response::Response {
let id = request.id.clone();
let payload = match request.payload {
server_request::Payload::Initialize { mcp_kind, params } => {
dispatch_initialize(&self.inner, mcp_kind, params, &request.headers).await
}
server_request::Payload::SessionTerminate { mcp_kind } => {
dispatch_session_terminate(&self.inner, mcp_kind, &request.headers).await
}
server_request::Payload::ToolsList { mcp_kind, params } => {
match resolve_connection(self, &mcp_kind, &request.headers).await {
Ok(state) => dispatch_tools_list(&state, &request.headers, params).await,
Err(payload) => payload,
}
}
server_request::Payload::ToolsCall { mcp_kind, params } => {
match resolve_connection(self, &mcp_kind, &request.headers).await {
Ok(state) => dispatch_tools_call(&state, &request.headers, params).await,
Err(payload) => payload,
}
}
server_request::Payload::ResourcesList { mcp_kind, params } => {
match resolve_connection(self, &mcp_kind, &request.headers).await {
Ok(state) => dispatch_resources_list(&state, &request.headers, params).await,
Err(payload) => payload,
}
}
server_request::Payload::ResourcesRead { mcp_kind, params } => {
match resolve_connection(self, &mcp_kind, &request.headers).await {
Ok(state) => dispatch_resources_read(&state, &request.headers, params).await,
Err(payload) => payload,
}
}
server_request::Payload::ReadMessageQueue(req) => {
dispatch_read_message_queue(&self.inner, req).await
}
};
server_response::Response { id, payload }
}
}
async fn resolve_connection(
handler: &ConduitMcpHandler,
mcp_kind: &McpKind,
headers: &IndexMap<String, String>,
) -> Result<Arc<ConduitState>, server_response::Payload> {
let Some(session_id) = mcp_session_id_from_headers(headers) else {
return Err(error_for(
mcp_kind,
-32600,
"missing Mcp-Session-Id header".to_string(),
));
};
if let Some(existing) = handler.inner.connections.get(&session_id) {
return Ok(existing.clone());
}
if !matches!(mcp_kind, McpKind::ObjectiveAi) {
return Err(error_for(
mcp_kind,
-32001,
format!("no cached connection for Mcp-Session-Id {session_id:?}"),
));
}
let mcp_url = match objectiveai_mcp_url(&handler.inner).await {
Ok(u) => u,
Err(message) => return Err(error_for(mcp_kind, -32603, message)),
};
let transient = match require_transient(headers) {
Ok(t) => t,
Err(message) => {
return Err(error_for(
mcp_kind,
-32600,
format!("conduit: {message}"),
));
}
};
let connect_headers = sanitize_connect_headers(headers);
let connection = match handler
.inner
.client
.connect(mcp_url, Some(session_id.clone()), Some(connect_headers))
.await
{
Ok(c) => c,
Err(e) => {
return Err(error_for(
mcp_kind,
-32603,
format!("conduit: connect (resume): {e}"),
));
}
};
install_list_changed_pump(&connection, handler.inner.clone(), mcp_kind.clone());
let state = Arc::new(ConduitState {
connection,
mcp_kind: mcp_kind.clone(),
agent_instance_hierarchy: transient.agent_instance_hierarchy,
});
handler.inner.connections.insert(session_id, state.clone());
Ok(state)
}
async fn objectiveai_mcp_url(inner: &Arc<Inner>) -> Result<String, String> {
let port = inner
.mcp_server
.port
.clone()
.await
.map_err(|_| "in-process objectiveai-mcp failed to bind".to_string())?;
Ok(format!("http://127.0.0.1:{port}"))
}
fn error_for(mcp_kind: &McpKind, code: i64, message: String) -> server_response::Payload {
server_response::Payload::ToolsList {
mcp_kind: mcp_kind.clone(),
result: JsonRpcResult::Err {
code,
message,
data: None,
},
}
}
/// Handles `initialize`: dials a fresh upstream connection for `mcp_kind`,
/// caches it under the session id the connection reports, and replies with
/// that id plus the upstream's initialize result.
///
/// The `X-OBJECTIVEAI-*` transient headers are required (`-32600` otherwise).
/// An `Mcp-Session-Id` header, when present, is passed to the dial unchanged.
/// `McpKind::ObjectiveAi` connects to the in-process server with sanitized
/// headers. `McpKind::Other` spawns the plugin's MCP via `dial_plugin_upstream`.
/// Dial failures are reported as `-32603`.
async fn dispatch_initialize(
    inner: &Arc<Inner>,
    mcp_kind: McpKind,
    init: server_request::InitializeRequest,
    headers: &IndexMap<String, String>,
) -> server_response::Payload {
    let reject = |code: i64, message: String| server_response::Payload::Initialize {
        mcp_kind: mcp_kind.clone(),
        result: JsonRpcResult::Err {
            code,
            message,
            data: None,
        },
    };

    let transient = match require_transient(headers) {
        Ok(parsed) => parsed,
        Err(reason) => return reject(-32600, format!("conduit: {reason}")),
    };
    let resume_session_id = mcp_session_id_from_headers(headers);

    let dialed = match &mcp_kind {
        McpKind::ObjectiveAi => {
            let url = match objectiveai_mcp_url(inner).await {
                Ok(url) => url,
                Err(reason) => return reject(-32603, reason),
            };
            inner
                .client
                .connect(url, resume_session_id, Some(sanitize_connect_headers(headers)))
                .await
                .map_err(|e| format!("connect: {e}"))
        }
        McpKind::Other {
            owner,
            name,
            version,
            mcp,
        } => dial_plugin_upstream(
            inner,
            owner.clone(),
            name.clone(),
            version.clone(),
            mcp.clone(),
            init.args,
            &transient,
            resume_session_id,
        )
        .await
        .map_err(|e| e.to_string()),
    };
    let connection = match dialed {
        Ok(connection) => connection,
        Err(reason) => return reject(-32603, format!("conduit: {reason}")),
    };

    // Forward this session's upstream list-changed notifications.
    install_list_changed_pump(&connection, Arc::clone(inner), mcp_kind.clone());

    let mcp_session_id = connection.session_id.clone();
    let result = connection.initialize_result.clone();
    let cached = ConduitState {
        connection,
        mcp_kind: mcp_kind.clone(),
        agent_instance_hierarchy: transient.agent_instance_hierarchy,
    };
    inner
        .connections
        .insert(mcp_session_id.clone(), Arc::new(cached));

    server_response::Payload::Initialize {
        mcp_kind,
        result: JsonRpcResult::Ok {
            result: InitializeReply {
                mcp_session_id,
                result,
            },
        },
    }
}
/// Handles session termination.
///
/// It deletes the upstream session cached under the request's
/// `Mcp-Session-Id`, then evicts it from the cache.
///
/// Succeeds without doing anything when the header is missing or the session
/// is not cached. If the upstream delete fails, the session stays cached and
/// a `-32603` error is returned.
async fn dispatch_session_terminate(
    inner: &Arc<Inner>,
    mcp_kind: McpKind,
    headers: &IndexMap<String, String>,
) -> server_response::Payload {
    // Resolve (session id, cached state) up front; the map guard is released
    // inside the closure, before any await.
    let cached = mcp_session_id_from_headers(headers).and_then(|session_id| {
        let state = inner
            .connections
            .get(&session_id)
            .map(|entry| entry.value().clone())?;
        Some((session_id, state))
    });
    let result = match cached {
        None => JsonRpcResult::Ok { result: () },
        Some((session_id, state)) => match state.connection.delete().await {
            Ok(()) => {
                inner.connections.remove(&session_id);
                JsonRpcResult::Ok { result: () }
            }
            Err(e) => JsonRpcResult::Err {
                code: -32603,
                message: format!("conduit: upstream delete: {e}"),
                data: None,
            },
        },
    };
    server_response::Payload::SessionTerminate { mcp_kind, result }
}
/// Forwards `tools/list` over the session's upstream connection.
///
/// The reply is wrapped in a `ToolsList` payload. Upstream JSON-RPC errors
/// pass through unchanged; conduit-side failures are reported as `-32603`.
async fn dispatch_tools_list(
    state: &ConduitState,
    headers: &IndexMap<String, String>,
    params: ListToolsRequest,
) -> server_response::Payload {
    let result = upstream_call::<ListToolsRequest, ListToolsResult>(
        &state.connection,
        headers,
        "tools/list",
        &params,
    )
    .await;
    server_response::Payload::ToolsList {
        mcp_kind: state.mcp_kind.clone(),
        result: into_rpc_result(result),
    }
}
/// Forwards `tools/call` over the session's upstream connection.
///
/// The reply is wrapped in a `ToolsCall` payload. Upstream JSON-RPC errors
/// pass through unchanged; conduit-side failures are reported as `-32603`.
async fn dispatch_tools_call(
    state: &ConduitState,
    headers: &IndexMap<String, String>,
    params: CallToolRequestParams,
) -> server_response::Payload {
    let result = upstream_call::<CallToolRequestParams, CallToolResult>(
        &state.connection,
        headers,
        "tools/call",
        &params,
    )
    .await;
    server_response::Payload::ToolsCall {
        mcp_kind: state.mcp_kind.clone(),
        result: into_rpc_result(result),
    }
}
/// Forwards `resources/list` over the session's upstream connection.
///
/// The reply is wrapped in a `ResourcesList` payload. Upstream JSON-RPC errors
/// pass through unchanged; conduit-side failures are reported as `-32603`.
async fn dispatch_resources_list(
    state: &ConduitState,
    headers: &IndexMap<String, String>,
    params: ListResourcesRequest,
) -> server_response::Payload {
    let result = upstream_call::<ListResourcesRequest, ListResourcesResult>(
        &state.connection,
        headers,
        "resources/list",
        &params,
    )
    .await;
    server_response::Payload::ResourcesList {
        mcp_kind: state.mcp_kind.clone(),
        result: into_rpc_result(result),
    }
}
/// Forwards `resources/read` over the session's upstream connection.
///
/// The reply is wrapped in a `ResourcesRead` payload. Upstream JSON-RPC errors
/// pass through unchanged; conduit-side failures are reported as `-32603`.
async fn dispatch_resources_read(
    state: &ConduitState,
    headers: &IndexMap<String, String>,
    params: ReadResourceRequestParams,
) -> server_response::Payload {
    let result = upstream_call::<ReadResourceRequestParams, ReadResourceResult>(
        &state.connection,
        headers,
        "resources/read",
        &params,
    )
    .await;
    server_response::Payload::ResourcesRead {
        mcp_kind: state.mcp_kind.clone(),
        result: into_rpc_result(result),
    }
}
/// Reads the pending message queue for `req.agent_instance_hierarchy`.
///
/// It calls `read_pending_and_upgrade_tag` with the handler's optional agent
/// tag. Database errors are reported as `-32603`.
async fn dispatch_read_message_queue(
    inner: &Arc<Inner>,
    req: server_request::ReadMessageQueueRequest,
) -> server_response::Payload {
    let db = &inner.ctx.db;
    let tag = inner.agent_tag.as_deref();
    let outcome = crate::db::message_queue::read_pending_and_upgrade_tag(
        db,
        tag,
        &req.agent_instance_hierarchy,
    )
    .await;
    let result = match outcome {
        Ok(result) => JsonRpcResult::Ok { result },
        Err(e) => JsonRpcResult::Err {
            code: -32603,
            message: format!("conduit: read_message_queue: {e}"),
            data: None,
        },
    };
    server_response::Payload::ReadMessageQueue(result)
}
/// Flattens the outcome of `upstream_call` into a JSON-RPC result.
///
/// Upstream replies (success or error envelope) pass through unchanged.
/// Conduit failures become `-32603` errors whose message is prefixed with
/// `conduit: `.
fn into_rpc_result<R>(
    result: Result<JsonRpcResult<R>, ConduitError>,
) -> JsonRpcResult<R> {
    result.unwrap_or_else(|failure| JsonRpcResult::Err {
        code: -32603,
        message: format!("conduit: {failure}"),
        data: None,
    })
}
/// Sends one JSON-RPC request (`method` with `params`) over the upstream
/// session `conn` and decodes the reply as `R`.
///
/// Inbound `headers` are forwarded except for those in `SKIPPED_HEADERS`.
/// `Content-Type`, `Accept` and `Mcp-Session-Id` are then set explicitly.
/// The reply body may be plain JSON or an SSE stream.
///
/// # Returns
/// - `Ok(JsonRpcResult::Ok)` for an upstream `result`.
/// - `Ok(JsonRpcResult::Err)` for an upstream `error` envelope. The code
///   defaults to `-32603` and the message to a placeholder when absent.
///
/// # Errors
/// - `ConduitError::Request` / `ConduitError::Body` for transport failures.
/// - `ConduitError::MalformedUpstream` when the body is unparseable, when
///   `result` does not decode as `R`, or when neither `result` nor `error`
///   is present.
async fn upstream_call<P, R>(
    conn: &objectiveai_sdk::mcp::Connection,
    headers: &IndexMap<String, String>,
    method: &str,
    params: &P,
) -> Result<JsonRpcResult<R>, ConduitError>
where
    P: serde::Serialize,
    R: serde::de::DeserializeOwned,
{
    // Inbound headers never copied onto the rebuilt request (compared ASCII
    // case-insensitively):
    // - set explicitly below: accept, content-type, mcp-session-id;
    // - describe the inbound message/body, not the re-serialized JSON body:
    //   host, content-length, content-encoding;
    // - hop-by-hop, which an intermediary must not forward (RFC 9110 §7.6.1):
    //   connection, keep-alive, proxy-connection, te, trailer,
    //   transfer-encoding, upgrade;
    // - accept-encoding: the inbound client's codecs say nothing about what
    //   this HTTP client can decode before `text()` below.
    const SKIPPED_HEADERS: &[&str] = &[
        "accept",
        "content-type",
        "mcp-session-id",
        "host",
        "content-length",
        "content-encoding",
        "connection",
        "keep-alive",
        "proxy-connection",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
        "accept-encoding",
    ];

    let rpc_id = uuid::Uuid::new_v4().to_string();
    let envelope = serde_json::json!({
        "jsonrpc": "2.0",
        "id": rpc_id,
        "method": method,
        "params": params,
    });
    let mut req = conn.http_client.post(&conn.url);
    for (k, v) in headers {
        if SKIPPED_HEADERS.iter().any(|skip| k.eq_ignore_ascii_case(skip)) {
            continue;
        }
        req = req.header(k, v);
    }
    let resp = req
        .header("Content-Type", "application/json")
        .header("Accept", "application/json, text/event-stream")
        .header("Mcp-Session-Id", &conn.session_id)
        .json(&envelope)
        .send()
        .await
        .map_err(ConduitError::Request)?;
    let resp_text = resp.text().await.map_err(ConduitError::Body)?;
    let Some(mut body) = parse_json_or_sse(&resp_text) else {
        return Err(ConduitError::MalformedUpstream(
            "empty or unparseable upstream response".into(),
        ));
    };
    // Move `result` out of the body instead of cloning it.
    if let Some(result) = body.get_mut("result").map(serde_json::Value::take) {
        let typed: R = serde_json::from_value(result)
            .map_err(|e| ConduitError::MalformedUpstream(format!("decode upstream result: {e}")))?;
        return Ok(JsonRpcResult::Ok { result: typed });
    }
    if let Some(err) = body.get("error") {
        let code = err.get("code").and_then(|c| c.as_i64()).unwrap_or(-32603);
        let message = err
            .get("message")
            .and_then(|m| m.as_str())
            .unwrap_or("upstream returned an error envelope without a message")
            .to_string();
        let data = err.get("data").cloned();
        return Ok(JsonRpcResult::Err {
            code,
            message,
            data,
        });
    }
    Err(ConduitError::MalformedUpstream(
        "upstream response missing both `result` and `error`".into(),
    ))
}
/// Spawns a plugin's MCP server and connects to it.
///
/// The plugin `plugin_owner/plugin_name@plugin_version` is run with the
/// arguments `mcp <mcp_name> begin` followed by `--key [value]` for each
/// entry of `args`. The function waits for the plugin to announce its MCP
/// URL, then connects to that URL.
///
/// The plugin runs under a clone of the shared context whose config is
/// overlaid with the caller's transient identity headers. After the URL is
/// announced, a background task keeps draining the plugin's output stream.
/// `stored_session_id`, if any, is passed to `connect` unchanged.
///
/// # Errors
/// `ConduitError::PluginDialFailed` when the spawn fails, when the stream ends
/// without an `Mcp` item, or when the connect fails.
// The parameters mirror the fields of `McpKind::Other` plus the per-request context.
#[allow(clippy::too_many_arguments)]
async fn dial_plugin_upstream(
    inner: &Arc<Inner>,
    plugin_owner: String,
    plugin_name: String,
    plugin_version: String,
    mcp_name: String,
    args: IndexMap<String, Option<String>>,
    transient: &TransientHeaders,
    stored_session_id: Option<String>,
) -> Result<objectiveai_sdk::mcp::Connection, ConduitError> {
    use futures::StreamExt;
    use objectiveai_sdk::cli::command::plugins::run::ResponseItem;

    let dial_failed = |reason: String| ConduitError::PluginDialFailed {
        plugin_owner: plugin_owner.clone(),
        plugin_name: plugin_name.clone(),
        plugin_version: plugin_version.clone(),
        mcp_name: mcp_name.clone(),
        reason,
    };

    // Overlay the caller's identity on a private copy of the shared config.
    let mut spawn_ctx = inner.ctx.clone();
    {
        let config = &mut spawn_ctx.config;
        config.agent_instance_hierarchy = transient.agent_instance_hierarchy.clone();
        config.agent_id = Some(transient.agent_id.clone());
        config.agent_full_id = Some(transient.agent_full_id.clone());
        config.agent_remote = transient.agent_remote.clone();
        config.response_id = Some(transient.response_id.clone());
        config.response_ids = Some(transient.response_ids.clone());
    }

    // Emit `--key` per arg, followed by its value when one is given.
    let mut argv = vec!["mcp".to_string(), mcp_name.clone(), "begin".to_string()];
    argv.extend(args.iter().flat_map(|(key, value)| {
        std::iter::once(format!("--{key}")).chain(value.iter().cloned())
    }));

    let request = objectiveai_sdk::cli::command::plugins::run::Request {
        path_type: objectiveai_sdk::cli::command::plugins::run::Path::PluginsRun,
        owner: plugin_owner.clone(),
        name: plugin_name.clone(),
        version: plugin_version.clone(),
        args: argv,
        jq: None,
    };
    let mut output = crate::command::plugins::run::execute(&spawn_ctx, request)
        .await
        .map_err(|e| dial_failed(format!("plugin spawn failed: {e}")))?;

    // Hand the first announced MCP endpoint back, but keep consuming the stream.
    let (announce_tx, announce_rx) = tokio::sync::oneshot::channel::<PluginMcp>();
    tokio::spawn(async move {
        let mut announce_tx = Some(announce_tx);
        while let Some(item) = output.next().await {
            let Ok(ResponseItem::Mcp(announced)) = item else {
                continue;
            };
            if let Some(tx) = announce_tx.take() {
                let _ = tx.send(announced);
            }
        }
    });

    let announced = announce_rx
        .await
        .map_err(|_| dial_failed("plugin exited without emitting mcp{url}".into()))?;
    inner
        .client
        .connect(announced.url, stored_session_id, None)
        .await
        .map_err(|e| dial_failed(format!("connect: {e}")))
}
fn install_list_changed_pump(
connection: &objectiveai_sdk::mcp::Connection,
inner: Arc<Inner>,
mcp_kind: McpKind,
) {
let inner_tools = inner.clone();
let kind_tools = mcp_kind.clone();
connection.set_on_tools_list_changed(move || {
let Some(notifier) = inner_tools.notifier.get().cloned() else {
return;
};
let mcp_kind = kind_tools.clone();
tokio::spawn(async move {
let _ = notifier
.notify_list_changed(McpListChanged {
mcp_kind,
kind: McpListChangedKind::Tools,
})
.await;
});
});
connection.set_on_resources_list_changed(move || {
let Some(notifier) = inner.notifier.get().cloned() else {
return;
};
let mcp_kind = mcp_kind.clone();
tokio::spawn(async move {
let _ = notifier
.notify_list_changed(McpListChanged {
mcp_kind,
kind: McpListChangedKind::Resources,
})
.await;
});
});
}
/// Copies `headers` for forwarding to `Client::connect`.
///
/// `Host`, `Content-Length` and `Mcp-Session-Id` are dropped in any letter
/// case; the remaining headers keep their original order. Header field names
/// are case-insensitive (RFC 9110 §5.1), and the same ASCII case-insensitive
/// matching is used by `mcp_session_id_from_headers`. The session id reaches
/// `connect` through its own argument instead.
fn sanitize_connect_headers(headers: &IndexMap<String, String>) -> IndexMap<String, String> {
    const STRIPPED: [&str; 3] = ["host", "content-length", "mcp-session-id"];
    let mut out = headers.clone();
    // `retain` preserves the relative order of the kept entries.
    out.retain(|name, _| !STRIPPED.iter().any(|s| name.eq_ignore_ascii_case(s)));
    out
}
/// Returns the value of the first `Mcp-Session-Id` header, with the name
/// matched ASCII case-insensitively. Returns `None` when the header is absent.
fn mcp_session_id_from_headers(headers: &IndexMap<String, String>) -> Option<String> {
    for (name, value) in headers {
        if name.eq_ignore_ascii_case("Mcp-Session-Id") {
            return Some(value.clone());
        }
    }
    None
}
/// Headers that `require_transient` insists on (present and non-empty),
/// matched ASCII case-insensitively.
///
/// The order must match the destructuring in `require_transient`, which maps
/// these onto the `TransientHeaders` fields.
const REQUIRED_TRANSIENT_HEADERS: [&str; 5] = [
    "X-OBJECTIVEAI-AGENT-INSTANCE-HIERARCHY",
    "X-OBJECTIVEAI-AGENT-ID",
    "X-OBJECTIVEAI-AGENT-FULL-ID",
    "X-OBJECTIVEAI-RESPONSE-ID",
    "X-OBJECTIVEAI-RESPONSE-IDS",
];
/// Header mapped to `TransientHeaders::agent_remote`.
///
/// It may be absent, but `require_transient` rejects it when it is present
/// with an empty value.
const OPTIONAL_AGENT_REMOTE_HEADER: &str = "X-OBJECTIVEAI-AGENT-REMOTE";
/// Agent/response identity extracted from the `X-OBJECTIVEAI-*` request
/// headers by `require_transient`.
///
/// Used to overlay the config for plugin spawns and to record the agent
/// instance hierarchy on cached sessions.
struct TransientHeaders {
    /// `X-OBJECTIVEAI-AGENT-INSTANCE-HIERARCHY`.
    agent_instance_hierarchy: String,
    /// `X-OBJECTIVEAI-AGENT-ID`.
    agent_id: String,
    /// `X-OBJECTIVEAI-AGENT-FULL-ID`.
    agent_full_id: String,
    /// `X-OBJECTIVEAI-AGENT-REMOTE`; `None` when the header is absent.
    agent_remote: Option<String>,
    /// `X-OBJECTIVEAI-RESPONSE-ID`.
    response_id: String,
    /// `X-OBJECTIVEAI-RESPONSE-IDS`.
    response_ids: String,
}
/// Extracts the transient agent/response headers from `headers`.
///
/// Names are matched ASCII case-insensitively and the first match wins. The
/// headers in `REQUIRED_TRANSIENT_HEADERS` must be present and non-empty. They
/// are checked in declaration order, so the first offending header is the one
/// reported. `OPTIONAL_AGENT_REMOTE_HEADER` may be absent but not empty.
///
/// # Errors
/// A human-readable message naming the missing or empty header.
fn require_transient(
    headers: &IndexMap<String, String>,
) -> Result<TransientHeaders, String> {
    let lookup = |wanted: &str| {
        headers
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case(wanted))
            .map(|(_, value)| value.clone())
    };
    let required = |wanted: &str| match lookup(wanted) {
        None => Err(format!("missing required header {wanted:?}")),
        Some(value) if value.is_empty() => Err(format!("empty required header {wanted:?}")),
        Some(value) => Ok(value),
    };

    let [hierarchy_key, id_key, full_id_key, response_id_key, response_ids_key] =
        REQUIRED_TRANSIENT_HEADERS;
    let agent_instance_hierarchy = required(hierarchy_key)?;
    let agent_id = required(id_key)?;
    let agent_full_id = required(full_id_key)?;
    let response_id = required(response_id_key)?;
    let response_ids = required(response_ids_key)?;

    let agent_remote = match lookup(OPTIONAL_AGENT_REMOTE_HEADER) {
        Some(value) if value.is_empty() => {
            return Err(format!(
                "empty optional header {OPTIONAL_AGENT_REMOTE_HEADER:?} (absent header is fine; empty value is not)"
            ));
        }
        present_or_absent => present_or_absent,
    };

    Ok(TransientHeaders {
        agent_instance_hierarchy,
        agent_id,
        agent_full_id,
        agent_remote,
        response_id,
        response_ids,
    })
}
/// Decodes an upstream MCP response body that is either a bare JSON document
/// or a `text/event-stream` (SSE) body.
///
/// An SSE body may carry several events. Under the MCP Streamable HTTP
/// transport, a server may send notifications or requests before the JSON-RPC
/// response. Each event (terminated by a blank line) is therefore parsed on its
/// own, with its `data:` lines joined by `\n` as the SSE spec requires. The
/// first event that is a JSON-RPC response (has a `result` or `error` member)
/// is returned. If there is none, the last event that parsed as JSON is returned.
///
/// Returns `None` for an empty body or when nothing parses as JSON.
fn parse_json_or_sse(text: &str) -> Option<serde_json::Value> {
    if text.is_empty() {
        return None;
    }
    if let Ok(v) = serde_json::from_str::<serde_json::Value>(text) {
        return Some(v);
    }

    let mut last_parsed = None;
    // Data of the event being read; `None` until its first `data:` line.
    let mut event_data: Option<String> = None;
    // The chained empty line flushes a final event that lacks a trailing blank line.
    for line in text.lines().chain(std::iter::once("")) {
        if line.is_empty() {
            let Some(data) = event_data.take() else {
                continue;
            };
            if let Ok(v) = serde_json::from_str::<serde_json::Value>(&data) {
                if v.get("result").is_some() || v.get("error").is_some() {
                    return Some(v);
                }
                last_parsed = Some(v);
            }
        } else if let Some(field) = line.strip_prefix("data:") {
            // SSE strips at most one space after the colon.
            let field = field.strip_prefix(' ').unwrap_or(field);
            match event_data.as_mut() {
                Some(data) => {
                    data.push('\n');
                    data.push_str(field);
                }
                None => event_data = Some(field.to_string()),
            }
        }
        // Other fields (`event:`, `id:`, `retry:`) and `:` comments are ignored.
    }
    last_parsed
}
/// Failures of the conduit itself.
///
/// JSON-RPC errors returned by an upstream are not represented here; those are
/// passed through as `JsonRpcResult::Err`. `into_rpc_result` reports these
/// failures as `-32603` errors, and `dispatch_initialize` formats them into its
/// error message.
#[derive(Debug, thiserror::Error)]
enum ConduitError {
    /// Sending the forwarded request to the upstream failed.
    #[error("forwarding HTTP request failed: {0}")]
    Request(reqwest::Error),
    /// The upstream responded, but its body could not be read.
    #[error("reading response body failed: {0}")]
    Body(reqwest::Error),
    /// The upstream body was empty or unparseable, its `result` did not decode,
    /// or it had neither `result` nor `error`.
    #[error("upstream response was malformed: {0}")]
    MalformedUpstream(String),
    /// Spawning a plugin, waiting for its MCP URL, or connecting to it failed.
    #[error("plugin upstream {plugin_owner:?}/{plugin_name:?}@{plugin_version:?}/{mcp_name:?} dial failed: {reason}")]
    PluginDialFailed {
        plugin_owner: String,
        plugin_name: String,
        plugin_version: String,
        mcp_name: String,
        reason: String,
    },
}