use axum::http::HeaderMap;
use futures::future::try_join_all;
use indexmap::IndexMap;
use objectiveai_sdk::mcp::{Client, Connection};
/// Request header carrying a JSON array of upstream MCP server URLs to connect to.
const SERVERS_HEADER: &str = "X-MCP-Servers";
/// Request header carrying a JSON object that maps an upstream URL to an object
/// of header name -> value to send to that upstream.
const HEADERS_HEADER: &str = "X-MCP-Headers";
/// Name of the `X-List-Filter` request header.
///
/// NOTE(review): not referenced in this part of the file; presumably consumed by
/// list-handling code elsewhere — confirm its expected format there.
pub const LIST_FILTER_HEADER: &str = "X-List-Filter";
/// One upstream MCP server to connect to, as parsed from the init request headers.
#[derive(Debug)]
struct UpstreamSpec {
/// Upstream server URL (one entry of `X-MCP-Servers`).
url: String,
/// Headers to send to this upstream: its `X-MCP-Headers` entry with the
/// transient header keys removed, or empty when no entry was given.
headers: IndexMap<String, String>,
}
/// Errors returned while parsing the MCP init headers or while connecting to
/// the upstream MCP servers they describe.
#[derive(Debug, thiserror::Error)]
pub enum BadInit {
/// The value of the header named by `header` is not valid UTF-8 text.
#[error("{header} is not valid UTF-8")]
NotUtf8 { header: &'static str },
/// The value of the header named by `header` could not be deserialized from
/// JSON into the expected shape.
#[error("{header} is not valid JSON: {source}")]
NotJson {
header: &'static str,
#[source]
source: serde_json::Error,
},
/// `Client::connect` returned an error for the upstream at `url`.
#[error("upstream connect failed for {url}: {source}")]
UpstreamConnectFailed {
url: String,
#[source]
source: objectiveai_sdk::mcp::Error,
},
}
/// Key under which an upstream's MCP session id is kept in its per-upstream
/// header map. The entry is removed before connecting (its value is passed to
/// `Client::connect` as a separate argument), and it is written back with the
/// upstream-assigned session id into the headers returned for each connection.
pub const MCP_SESSION_ID_KEY: &str = "Mcp-Session-Id";
pub async fn connect_all_fresh(
client: &Client,
http_headers: &HeaderMap,
) -> Result<Vec<(Connection, IndexMap<String, String>)>, BadInit> {
let specs = parse_init_headers(http_headers)?;
let transient: IndexMap<String, String> = crate::session::Session::TRANSIENT_HEADER_KEYS
.iter()
.filter_map(|key| {
let v = http_headers.get(*key)?.to_str().ok()?;
Some((key.to_string(), v.to_string()))
})
.collect();
let attempts = specs.into_iter().map(|spec| {
let url = spec.url.clone();
let headers_for_payload = spec.headers.clone();
let transient = transient.clone();
async move {
let mut headers = spec.headers;
let session_id = headers.shift_remove(MCP_SESSION_ID_KEY);
for (k, v) in transient {
headers.entry(k).or_insert(v);
}
let conn_result = client
.connect(spec.url, session_id, Some(headers))
.await;
let conn = conn_result.map_err(|source| BadInit::UpstreamConnectFailed {
url: url.clone(),
source,
})?;
let payload_headers =
build_canonical_headers(&headers_for_payload, &conn.session_id);
Ok::<_, BadInit>((conn, payload_headers))
}
});
let connections = try_join_all(attempts).await?;
Ok(connections)
}
/// Re-opens every upstream connection recorded in a stored session payload,
/// all attempts running concurrently.
///
/// For each `(url, headers)` entry, the `Mcp-Session-Id` entry (if any) is taken
/// out of the stored headers and passed to `Client::connect` as the session-id
/// argument; the remaining headers are sent to the upstream as-is.
///
/// Returns, in payload order, each connection paired with its refreshed
/// canonical headers: the stored headers minus the old session id, plus
/// `Mcp-Session-Id` set to the newly returned `conn.session_id`.
///
/// # Errors
///
/// [`BadInit::UpstreamConnectFailed`] for the first upstream whose connect fails.
pub async fn reconnect_from_payload(
    client: &Client,
    payload: &crate::session_manager::SessionPayload,
) -> Result<Vec<(Connection, IndexMap<String, String>)>, BadInit> {
    let reconnects = payload.connections.iter().map(|(stored_url, stored_headers)| {
        let mut forwarded = stored_headers.clone();
        let resume_id = forwarded.shift_remove(MCP_SESSION_ID_KEY);
        // What will be persisted again, minus the soon-to-be-replaced session id.
        let persisted = forwarded.clone();
        let target = stored_url.clone();
        async move {
            match client.connect(target.clone(), resume_id, Some(forwarded)).await {
                Ok(conn) => {
                    let canonical = build_canonical_headers(&persisted, &conn.session_id);
                    Ok((conn, canonical))
                }
                Err(source) => Err(BadInit::UpstreamConnectFailed {
                    url: target,
                    source,
                }),
            }
        }
    });
    try_join_all(reconnects).await
}
/// Returns a copy of `headers` whose `Mcp-Session-Id` entry holds
/// `upstream_session_id`.
///
/// An existing entry keeps its position and only has its value replaced; when
/// the key is absent it is appended after all other entries.
fn build_canonical_headers(
    headers: &IndexMap<String, String>,
    upstream_session_id: &str,
) -> IndexMap<String, String> {
    let mut canonical = headers.clone();
    match canonical.get_mut(MCP_SESSION_ID_KEY) {
        Some(slot) => *slot = upstream_session_id.to_owned(),
        None => {
            canonical.insert(MCP_SESSION_ID_KEY.to_owned(), upstream_session_id.to_owned());
        }
    }
    canonical
}
/// Parses the `X-MCP-Servers` / `X-MCP-Headers` request headers into the list of
/// upstreams to connect to, in `X-MCP-Servers` order.
///
/// * `X-MCP-Servers` is a JSON array of upstream URLs; when absent, no upstreams.
/// * `X-MCP-Headers` is a JSON object mapping an upstream URL to an object of
///   header name -> value for that upstream; when absent, no extra headers.
///
/// Repeated URLs in `X-MCP-Servers` are ignored after their first occurrence, and
/// `X-MCP-Headers` entries for URLs not listed in `X-MCP-Servers` are ignored
/// (both are logged at debug level). Per-URL headers whose name matches one of
/// `Session::TRANSIENT_HEADER_KEYS` are dropped, compared ASCII
/// case-insensitively because HTTP field names are case-insensitive. Transient
/// values can therefore only come from the request itself.
///
/// # Errors
///
/// [`BadInit::NotUtf8`] if either header is not valid UTF-8, and
/// [`BadInit::NotJson`] if either header does not deserialize into the shape above.
fn parse_init_headers(
    http_headers: &HeaderMap,
) -> Result<Vec<UpstreamSpec>, BadInit> {
    let servers: Vec<String> = match init_header_str(http_headers, SERVERS_HEADER)? {
        Some(s) => serde_json::from_str(s).map_err(|source| BadInit::NotJson {
            header: SERVERS_HEADER,
            source,
        })?,
        None => Vec::new(),
    };
    let mut per_url_headers: IndexMap<String, IndexMap<String, String>> =
        match init_header_str(http_headers, HEADERS_HEADER)? {
            Some(s) => serde_json::from_str(s).map_err(|source| BadInit::NotJson {
                header: HEADERS_HEADER,
                source,
            })?,
            None => IndexMap::new(),
        };

    // Case-insensitive so that e.g. a lowercased transient key cannot slip past
    // the strip, get persisted per upstream, and later be sent next to the
    // request's own value for the same header.
    let transient_keys = crate::session::Session::TRANSIENT_HEADER_KEYS;
    for inner in per_url_headers.values_mut() {
        inner.retain(|name, _| !transient_keys.iter().any(|key| key.eq_ignore_ascii_case(name)));
    }

    let mut seen = std::collections::HashSet::new();
    let mut specs = Vec::with_capacity(servers.len());
    for url in servers {
        if !seen.insert(url.clone()) {
            tracing::debug!(url = %url, "ignoring duplicate X-MCP-Servers entry");
            continue;
        }
        let headers = per_url_headers.shift_remove(&url).unwrap_or_default();
        specs.push(UpstreamSpec { url, headers });
    }
    // Whatever is left was configured for a server the client did not ask for.
    for url in per_url_headers.keys() {
        tracing::debug!(
            url = %url,
            "ignoring X-MCP-Headers entry for a server not listed in X-MCP-Servers"
        );
    }
    Ok(specs)
}

/// Returns the value of the request header `header` as text, or `None` when the
/// header is absent.
///
/// # Errors
///
/// [`BadInit::NotUtf8`] naming `header` if the value is not valid UTF-8.
fn init_header_str<'a>(
    http_headers: &'a HeaderMap,
    header: &'static str,
) -> Result<Option<&'a str>, BadInit> {
    http_headers
        .get(header)
        .map(|v| v.to_str().map_err(|_| BadInit::NotUtf8 { header }))
        .transpose()
}