use axum::http::HeaderMap;
use futures::TryFutureExt;
use futures::future::try_join_all;
use indexmap::IndexMap;
use objectiveai_sdk::mcp::Client;
use crate::reverse_channel::{ReverseChannel, Upstream};
/// Init request header holding a JSON array of upstream MCP server URLs.
const SERVERS_HEADER: &str = "X-MCP-Servers";
/// Init request header holding a JSON object that maps an upstream URL to its header map.
const HEADERS_HEADER: &str = "X-MCP-Headers";
/// Name of the list-filter request header; exported for other modules (its value format is
/// not interpreted here).
pub const LIST_FILTER_HEADER: &str = "X-List-Filter";
/// A single upstream requested by the client when a session is initialised.
#[derive(Debug)]
struct UpstreamSpec {
    /// Upstream URL, exactly as listed in `X-MCP-Servers`.
    url: String,
    /// Headers supplied for this URL through `X-MCP-Headers`; empty when none were given.
    headers: IndexMap<String, String>,
}
/// Reasons a session could not be initialised or restored against its upstreams.
#[derive(Debug, thiserror::Error)]
pub enum BadInit {
    /// An init header's value could not be read as UTF-8.
    #[error("{header} is not valid UTF-8")]
    NotUtf8 { header: &'static str },
    /// An init header's value was UTF-8 but did not deserialize as the expected JSON.
    #[error("{header} is not valid JSON: {source}")]
    NotJson {
        header: &'static str,
        #[source]
        source: serde_json::Error,
    },
    /// Establishing the connection to an upstream failed.
    #[error("upstream connect failed for {url}: {source}")]
    UpstreamConnectFailed {
        url: String,
        #[source]
        source: objectiveai_sdk::mcp::Error,
    },
    /// Connected, but listing the upstream's tools or resources (`kind`) failed.
    #[error("upstream {kind} list failed for {url}: {source}")]
    UpstreamListFailed {
        url: String,
        kind: &'static str,
        // Explicit for consistency with the other variants; thiserror would also pick this
        // field up by its name.
        #[source]
        source: std::sync::Arc<objectiveai_sdk::mcp::Error>,
    },
}
/// Key under which an upstream's MCP session id is carried in a per-upstream header map.
/// It is taken out of the map before connecting, and the persisted payload headers store the
/// upstream's current session id under this key.
pub const MCP_SESSION_ID_KEY: &str = "Mcp-Session-Id";
pub async fn connect_all_fresh(
client: &Client,
reverse_channel: Option<&ReverseChannel>,
http_headers: &HeaderMap,
) -> Result<Vec<(Upstream, IndexMap<String, String>)>, BadInit> {
let specs = parse_init_headers(http_headers)?;
let transient: IndexMap<String, String> = crate::session::Session::TRANSIENT_HEADER_KEYS
.iter()
.filter_map(|key| {
let v = http_headers.get(*key)?.to_str().ok()?;
Some((key.to_string(), v.to_string()))
})
.collect();
let attempts = specs.into_iter().map(|spec| {
let url = spec.url.clone();
let headers_for_payload = spec.headers.clone();
let transient = transient.clone();
async move {
let mut headers = spec.headers;
let session_id = headers.shift_remove(MCP_SESSION_ID_KEY);
for (k, v) in transient {
headers.entry(k).or_insert(v);
}
let upstream =
connect_upstream(client, reverse_channel, &url, session_id, headers).await?;
tokio::try_join!(
upstream.list_tools().map_err(|source| BadInit::UpstreamListFailed {
url: url.clone(),
kind: "tools",
source,
}),
upstream.list_resources().map_err(|source| BadInit::UpstreamListFailed {
url: url.clone(),
kind: "resources",
source,
}),
)?;
let payload_headers =
build_canonical_headers(&headers_for_payload, upstream.session_id());
Ok::<_, BadInit>((upstream, payload_headers))
}
});
let connections = try_join_all(attempts).await?;
Ok(connections)
}
pub async fn reconnect_from_payload(
client: &Client,
reverse_channel: Option<&ReverseChannel>,
payload: &crate::session_manager::SessionPayload,
http_headers: &HeaderMap,
) -> Result<Vec<(Upstream, IndexMap<String, String>)>, BadInit> {
let transient: IndexMap<String, String> = crate::session::Session::TRANSIENT_HEADER_KEYS
.iter()
.filter_map(|key| {
let v = http_headers.get(*key)?.to_str().ok()?;
Some((key.to_string(), v.to_string()))
})
.collect();
let attempts = payload.connections.iter().map(|(url, headers)| {
let url = url.clone();
let mut headers = headers.clone();
let session_id = headers.shift_remove(MCP_SESSION_ID_KEY);
let payload_headers = headers.clone();
let transient = transient.clone();
async move {
for (k, v) in transient {
headers.entry(k).or_insert(v);
}
let upstream =
connect_upstream(client, reverse_channel, &url, session_id, headers).await?;
tokio::try_join!(
upstream.list_tools().map_err(|source| BadInit::UpstreamListFailed {
url: url.clone(),
kind: "tools",
source,
}),
upstream.list_resources().map_err(|source| BadInit::UpstreamListFailed {
url: url.clone(),
kind: "resources",
source,
}),
)?;
let canonical = build_canonical_headers(&payload_headers, upstream.session_id());
Ok::<_, BadInit>((upstream, canonical))
}
});
try_join_all(attempts).await
}
async fn connect_upstream(
client: &Client,
reverse_channel: Option<&ReverseChannel>,
url: &str,
session_id: Option<String>,
mut headers: IndexMap<String, String>,
) -> Result<Upstream, BadInit> {
if let Some(mcp_kind) = crate::reverse_channel::parse_ws_mcp_kind(url) {
let channel = reverse_channel.cloned().ok_or_else(|| {
BadInit::UpstreamConnectFailed {
url: url.to_string(),
source: objectiveai_sdk::mcp::Error::MalformedResponse {
url: url.to_string(),
message: "ws:// upstream requires a reverse channel".into(),
},
}
})?;
if let Some(sid) = session_id {
headers.insert(MCP_SESSION_ID_KEY.to_string(), sid);
}
let args: IndexMap<String, Option<String>> = headers
.iter()
.find(|(k, _)| k.eq_ignore_ascii_case("X-OBJECTIVEAI-ARGUMENTS"))
.and_then(|(_, v)| serde_json::from_str(v).ok())
.unwrap_or_default();
let upstream = crate::reverse_channel::connect_ws(
channel,
url.to_string(),
mcp_kind,
args,
headers,
)
.await
.map_err(|source| BadInit::UpstreamConnectFailed {
url: url.to_string(),
source,
})?;
Ok(Upstream::Ws(upstream))
} else {
let conn = client
.connect(url.to_string(), session_id, Some(headers))
.await
.map_err(|source| BadInit::UpstreamConnectFailed {
url: url.to_string(),
source,
})?;
Ok(Upstream::Http(conn))
}
}
/// Builds the header map to persist for an upstream.
///
/// The result is a copy of `headers` with [`MCP_SESSION_ID_KEY`] set to
/// `upstream_session_id`. An existing entry under that exact key keeps its position and has
/// its value replaced; otherwise the entry is appended at the end.
fn build_canonical_headers(
    headers: &IndexMap<String, String>,
    upstream_session_id: &str,
) -> IndexMap<String, String> {
    let session_entry = (MCP_SESSION_ID_KEY.to_owned(), upstream_session_id.to_owned());
    headers
        .iter()
        .map(|(name, value)| (name.clone(), value.clone()))
        .chain(std::iter::once(session_entry))
        .collect()
}
/// Parses the requested upstreams from the init request headers.
///
/// `X-MCP-Servers` is a JSON array of upstream URLs. `X-MCP-Headers` is a JSON object
/// mapping a URL to that upstream's header map. Either header may be absent, in which case it
/// is treated as empty.
///
/// Names listed in `Session::TRANSIENT_HEADER_KEYS` are stripped from every per-upstream map.
/// The comparison is case-insensitive because HTTP header names are. A repeated URL keeps only
/// its first occurrence, and header entries for URLs not listed in `X-MCP-Servers` are
/// ignored.
///
/// # Errors
///
/// [`BadInit::NotUtf8`] or [`BadInit::NotJson`] naming the offending header.
fn parse_init_headers(
    http_headers: &HeaderMap,
) -> Result<Vec<UpstreamSpec>, BadInit> {
    let servers: Vec<String> = match http_headers.get(SERVERS_HEADER) {
        Some(v) => {
            let s = v.to_str().map_err(|_| BadInit::NotUtf8 { header: SERVERS_HEADER })?;
            serde_json::from_str(s).map_err(|source| BadInit::NotJson {
                header: SERVERS_HEADER,
                source,
            })?
        }
        None => Vec::new(),
    };
    let mut per_url_headers: IndexMap<String, IndexMap<String, String>> =
        match http_headers.get(HEADERS_HEADER) {
            Some(v) => {
                let s = v.to_str().map_err(|_| BadInit::NotUtf8 { header: HEADERS_HEADER })?;
                serde_json::from_str(s).map_err(|source| BadInit::NotJson {
                    header: HEADERS_HEADER,
                    source,
                })?
            }
            None => IndexMap::new(),
        };
    // Per-upstream copies of transient headers are dropped, whatever their casing, so that a
    // differently-cased variant cannot slip through and be kept alongside the live value.
    for inner in per_url_headers.values_mut() {
        inner.retain(|name, _| {
            !crate::session::Session::TRANSIENT_HEADER_KEYS
                .iter()
                .any(|transient| transient.eq_ignore_ascii_case(name))
        });
    }
    let mut seen = std::collections::HashSet::with_capacity(servers.len());
    let mut specs = Vec::with_capacity(servers.len());
    for url in servers {
        if !seen.insert(url.clone()) {
            tracing::debug!(url = %url, "ignoring duplicate X-MCP-Servers entry");
            continue;
        }
        let headers = per_url_headers.shift_remove(&url).unwrap_or_default();
        specs.push(UpstreamSpec { url, headers });
    }
    Ok(specs)
}