use axum::http::HeaderMap;
use futures::future::try_join_all;
use indexmap::IndexMap;
use objectiveai_sdk::mcp::{Client, Connection};
/// Request header whose value is a JSON array of upstream MCP server URLs.
const SERVERS_HEADER: &str = "X-MCP-Servers";
/// Request header whose value is a JSON object mapping each upstream URL to
/// the headers that should be forwarded to that upstream.
const HEADERS_HEADER: &str = "X-MCP-Headers";
/// One upstream MCP server requested by the client during initialization.
#[derive(Debug)]
struct UpstreamSpec {
    /// Upstream server URL, taken from an `X-MCP-Servers` entry.
    url: String,
    /// Headers to forward to this upstream, taken from the matching
    /// `X-MCP-Headers` entry (empty when no entry was given). It may carry
    /// the session id key, which is split out before connecting.
    headers: IndexMap<String, String>,
}
/// Errors raised while setting up the upstream MCP connections for a session.
#[derive(Debug, thiserror::Error)]
pub enum BadInit {
    /// A configuration request header's value could not be read as text.
    #[error("{header} is not valid UTF-8")]
    NotUtf8 {
        /// Name of the offending request header.
        header: &'static str,
    },
    /// A configuration request header's value did not deserialize into the
    /// expected JSON shape.
    #[error("{header} is not valid JSON: {source}")]
    NotJson {
        /// Name of the offending request header.
        header: &'static str,
        /// Underlying JSON decoding error.
        #[source]
        source: serde_json::Error,
    },
    /// The MCP client could not establish a connection to an upstream.
    #[error("upstream connect failed for {url}: {source}")]
    UpstreamConnectFailed {
        /// URL of the upstream that failed.
        url: String,
        /// Error reported by the MCP client.
        #[source]
        source: objectiveai_sdk::mcp::Error,
    },
}
/// Key under which an upstream's MCP session id lives in a per-upstream header
/// map. It is taken out of the map and handed to `Client::connect` as the
/// dedicated session argument. It is then written back with the connection's
/// session id when the canonical headers are built.
pub const MCP_SESSION_ID_KEY: &str = "Mcp-Session-Id";
pub async fn connect_all_fresh(
client: &Client,
http_headers: &HeaderMap,
) -> Result<Vec<(Connection, IndexMap<String, String>)>, BadInit> {
let specs = parse_init_headers(http_headers)?;
let attempts = specs.into_iter().map(|spec| {
let url = spec.url.clone();
let headers_for_payload = spec.headers.clone();
async move {
let mut headers = spec.headers;
let session_id = headers.shift_remove(MCP_SESSION_ID_KEY);
let conn = client
.connect(spec.url, session_id, Some(headers))
.await
.map_err(|source| BadInit::UpstreamConnectFailed {
url: url.clone(),
source,
})?;
let payload_headers =
build_canonical_headers(&headers_for_payload, &conn.session_id);
Ok::<_, BadInit>((conn, payload_headers))
}
});
try_join_all(attempts).await
}
/// Re-opens upstream MCP connections from a previously stored session payload.
///
/// For each `(url, headers)` entry of `payload`, the stored session id is
/// taken out of the headers and passed to `Client::connect` as the session
/// argument. The exact `Mcp-Session-Id` key is preferred, and any other
/// spelling is dropped. The remaining headers are forwarded. All reconnects
/// run concurrently. Results come back in payload iteration order, each paired
/// with freshly built canonical headers carrying the connection's
/// `session_id`.
///
/// NOTE(review): assumes `SessionPayload::iter` yields `(url, header map)`
/// pairs shaped like the canonical headers built at connect time — confirm
/// against `session_manager`.
///
/// # Errors
///
/// Returns [`BadInit::UpstreamConnectFailed`] for the first upstream that
/// fails to reconnect; the remaining in-flight attempts are dropped.
pub async fn reconnect_from_payload(
    client: &Client,
    payload: &crate::session_manager::SessionPayload,
) -> Result<Vec<(Connection, IndexMap<String, String>)>, BadInit> {
    let attempts = payload.iter().map(|(url, headers)| {
        let url = url.clone();
        let mut headers = headers.clone();
        // Header names are case-insensitive (RFC 9110 §5.1). Take the exact
        // canonical key first, then strip any other spelling (for example a
        // stale user-supplied `mcp-session-id`). That way it is not forwarded
        // upstream as a second, conflicting session header.
        let mut session_id = headers.shift_remove(MCP_SESSION_ID_KEY);
        headers.retain(|name, value| {
            if !name.eq_ignore_ascii_case(MCP_SESSION_ID_KEY) {
                return true;
            }
            session_id.get_or_insert_with(|| std::mem::take(value));
            false
        });
        // Snapshot without the session key; the original map is moved into
        // `connect`.
        let payload_headers = headers.clone();
        async move {
            let conn = client
                .connect(url.clone(), session_id, Some(headers))
                .await
                .map_err(|source| BadInit::UpstreamConnectFailed { url, source })?;
            let canonical = build_canonical_headers(&payload_headers, &conn.session_id);
            Ok::<_, BadInit>((conn, canonical))
        }
    });
    try_join_all(attempts).await
}
/// Builds the header map recorded for one upstream connection.
///
/// Copies `headers` in their original order and sets `Mcp-Session-Id` to
/// `upstream_session_id`. If that exact key is already present, its value is
/// overwritten in place; otherwise the key is appended at the end.
fn build_canonical_headers(
    headers: &IndexMap<String, String>,
    upstream_session_id: &str,
) -> IndexMap<String, String> {
    let session_entry = (
        MCP_SESSION_ID_KEY.to_owned(),
        upstream_session_id.to_owned(),
    );
    // Collecting into an IndexMap inserts in order, so a repeated key keeps its
    // first position and takes the last value.
    headers
        .iter()
        .map(|(name, value)| (name.clone(), value.clone()))
        .chain(std::iter::once(session_entry))
        .collect()
}
/// Parses the per-request upstream configuration from the init request headers.
///
/// * `X-MCP-Servers` — JSON array of upstream URLs. An absent header means no
///   upstreams.
/// * `X-MCP-Headers` — JSON object mapping an upstream URL to the headers to
///   forward to it. An absent header means no extra headers.
///
/// Returns one [`UpstreamSpec`] per distinct URL, in `X-MCP-Servers` order.
/// Two kinds of input are ignored and logged at debug level: repeated URLs
/// after their first occurrence, and `X-MCP-Headers` entries whose URL is not
/// listed in `X-MCP-Servers`.
///
/// # Errors
///
/// * [`BadInit::NotUtf8`] if a header value is not valid UTF-8.
/// * [`BadInit::NotJson`] if a header value does not parse as the expected
///   JSON shape.
fn parse_init_headers(http_headers: &HeaderMap) -> Result<Vec<UpstreamSpec>, BadInit> {
    let servers: Vec<String> = match init_header_str(http_headers, SERVERS_HEADER)? {
        Some(raw) => serde_json::from_str(raw).map_err(|source| BadInit::NotJson {
            header: SERVERS_HEADER,
            source,
        })?,
        None => Vec::new(),
    };
    let mut per_url_headers: IndexMap<String, IndexMap<String, String>> =
        match init_header_str(http_headers, HEADERS_HEADER)? {
            Some(raw) => serde_json::from_str(raw).map_err(|source| BadInit::NotJson {
                header: HEADERS_HEADER,
                source,
            })?,
            None => IndexMap::new(),
        };
    let mut seen = std::collections::HashSet::with_capacity(servers.len());
    let mut specs = Vec::with_capacity(servers.len());
    for url in servers {
        if !seen.insert(url.clone()) {
            tracing::debug!(url = %url, "ignoring duplicate X-MCP-Servers entry");
            continue;
        }
        let headers = per_url_headers.shift_remove(&url).unwrap_or_default();
        specs.push(UpstreamSpec { url, headers });
    }
    // Whatever is left was configured for a URL the client never asked to
    // connect to. This is most likely a typo, so surface it instead of
    // dropping it silently.
    for url in per_url_headers.keys() {
        tracing::debug!(
            url = %url,
            "ignoring X-MCP-Headers entry with no matching X-MCP-Servers url"
        );
    }
    Ok(specs)
}

/// Returns the value of request header `name` as a string, or `None` if the
/// header is absent.
///
/// The raw bytes are decoded as UTF-8 rather than read with
/// `HeaderValue::to_str`. `to_str` rejects any non-ASCII byte, so it would
/// report perfectly valid UTF-8 JSON (for example non-ASCII header values) as
/// "not valid UTF-8".
fn init_header_str<'a>(
    http_headers: &'a HeaderMap,
    name: &'static str,
) -> Result<Option<&'a str>, BadInit> {
    http_headers
        .get(name)
        .map(|value| {
            std::str::from_utf8(value.as_bytes()).map_err(|_| BadInit::NotUtf8 { header: name })
        })
        .transpose()
}