use axum::http::HeaderMap;
use futures::future::try_join_all;
use indexmap::IndexMap;
use objectiveai_sdk::mcp::{Client, Connection};
/// Request header whose value is a JSON array of upstream MCP server URLs to connect to.
/// Duplicate URLs are ignored; a missing header means "no upstreams".
const SERVERS_HEADER: &str = "X-MCP-Servers";
/// Request header whose value is a JSON object mapping an upstream URL to the map of
/// headers to send to that upstream. Entries for URLs not listed in `X-MCP-Servers` are
/// ignored.
const HEADERS_HEADER: &str = "X-MCP-Headers";
/// Request header whose value is a JSON object of string keys to lists of strings
/// (tool allowlists), returned to the caller as-is.
/// NOTE(review): keys are presumably upstream URLs or server names — confirm with consumers.
const TOOLS_ALLOW_HEADER: &str = "X-MCP-Tools-Allow";
/// Name of the list-filter request header.
/// NOTE(review): not read anywhere in this chunk; its value format is defined by its users.
pub const LIST_FILTER_HEADER: &str = "X-List-Filter";
/// One upstream MCP server requested via `X-MCP-Servers`, paired with the headers that
/// `X-MCP-Headers` supplied for its URL (empty when none were given).
#[derive(Debug)]
struct UpstreamSpec {
    /// Upstream server URL, exactly as listed in `X-MCP-Servers`.
    url: String,
    /// Headers for this upstream. May contain an `Mcp-Session-Id` entry, which
    /// `connect_all_fresh` strips out and hands to the client as the session id.
    headers: IndexMap<String, String>,
}
/// Errors raised while setting up upstream MCP connections, either from request headers
/// or from a stored session payload.
#[derive(Debug, thiserror::Error)]
pub enum BadInit {
    /// A present `X-MCP-*` header's value could not be converted to a string.
    #[error("{header} is not valid UTF-8")]
    NotUtf8 { header: &'static str },
    /// A header's text did not deserialize into the JSON shape expected for it.
    #[error("{header} is not valid JSON: {source}")]
    NotJson {
        /// Name of the offending header.
        header: &'static str,
        /// Underlying deserialization error.
        #[source]
        source: serde_json::Error,
    },
    /// Connecting to one upstream MCP server failed.
    #[error("upstream connect failed for {url}: {source}")]
    UpstreamConnectFailed {
        /// URL of the upstream that could not be connected.
        url: String,
        /// Error returned by the MCP client's `connect`.
        #[source]
        source: objectiveai_sdk::mcp::Error,
    },
}
/// Key under which an upstream's MCP session id is stored in per-upstream header maps.
/// Lookups against those `IndexMap<String, String>` maps are exact-string (case-sensitive).
pub const MCP_SESSION_ID_KEY: &str = "Mcp-Session-Id";
pub async fn connect_all_fresh(
client: &Client,
http_headers: &HeaderMap,
agent_id: Option<&str>,
) -> Result<(Vec<(Connection, IndexMap<String, String>)>, IndexMap<String, Vec<String>>), BadInit> {
let (specs, tool_allowlists) = parse_init_headers(http_headers)?;
let agent_id_owned: Option<String> =
agent_id.filter(|s| !s.is_empty()).map(str::to_owned);
let attempts = specs.into_iter().map(|spec| {
let url = spec.url.clone();
let headers_for_payload = spec.headers.clone();
let agent_id_owned = agent_id_owned.clone();
async move {
let mut headers = spec.headers;
let session_id = headers.shift_remove(MCP_SESSION_ID_KEY);
if let Some(id) = &agent_id_owned {
headers.insert("X-OBJECTIVEAI-AGENT-ID".to_string(), id.clone());
}
let conn_result = client
.connect(spec.url, session_id, Some(headers))
.await;
let conn = conn_result.map_err(|source| BadInit::UpstreamConnectFailed {
url: url.clone(),
source,
})?;
let payload_headers =
build_canonical_headers(&headers_for_payload, &conn.session_id);
Ok::<_, BadInit>((conn, payload_headers))
}
});
let connections = try_join_all(attempts).await?;
Ok((connections, tool_allowlists))
}
/// Re-establishes every upstream connection recorded in a stored session payload.
///
/// Each stored `(url, headers)` entry is reconnected concurrently. Its stored
/// `Mcp-Session-Id`, if any, is removed from the outbound headers and passed to
/// [`Client::connect`] as the session id. When `payload.agent_id` is present and
/// non-empty, it is sent as `X-OBJECTIVEAI-AGENT-ID`.
///
/// The returned canonical headers are built exactly as in `connect_all_fresh`: the stored
/// headers, without the agent-id header, and with `Mcp-Session-Id` updated in place to
/// the session id the upstream reported. This keeps the canonical form stable across any
/// number of reconnects.
///
/// # Errors
///
/// [`BadInit::UpstreamConnectFailed`] for the first upstream that fails to connect; the
/// remaining attempts are abandoned.
pub async fn reconnect_from_payload(
    client: &Client,
    payload: &crate::session_manager::SessionPayload,
) -> Result<Vec<(Connection, IndexMap<String, String>)>, BadInit> {
    // An empty agent id is treated the same as no agent id.
    let agent_id_owned: Option<String> = payload
        .agent_id
        .as_deref()
        .filter(|s| !s.is_empty())
        .map(str::to_owned);
    let attempts = payload.connections.iter().map(|(url, stored_headers)| {
        let url = url.clone();
        // Snapshot the stored headers *before* stripping the session id or adding the
        // agent-id header (mirrors `connect_all_fresh`). Previously the snapshot was taken
        // after those mutations, so the agent-id header leaked into the canonical headers
        // and `Mcp-Session-Id` was moved to the end of the map.
        let payload_headers = stored_headers.clone();
        let mut headers = stored_headers.clone();
        let session_id = headers.shift_remove(MCP_SESSION_ID_KEY);
        if let Some(id) = &agent_id_owned {
            headers.insert("X-OBJECTIVEAI-AGENT-ID".to_string(), id.clone());
        }
        async move {
            let conn = client
                .connect(url.clone(), session_id, Some(headers))
                .await
                .map_err(|source| BadInit::UpstreamConnectFailed { url, source })?;
            let canonical = build_canonical_headers(&payload_headers, &conn.session_id);
            Ok::<_, BadInit>((conn, canonical))
        }
    });
    try_join_all(attempts).await
}
/// Returns a copy of `headers` whose `Mcp-Session-Id` entry holds `upstream_session_id`.
///
/// If the key is already present, its value is replaced and it keeps its position in the
/// map. Otherwise the entry is appended at the end.
fn build_canonical_headers(
    headers: &IndexMap<String, String>,
    upstream_session_id: &str,
) -> IndexMap<String, String> {
    let mut canonical = headers.clone();
    *canonical
        .entry(MCP_SESSION_ID_KEY.to_owned())
        .or_default() = upstream_session_id.to_owned();
    canonical
}
/// Parses the `X-MCP-*` request headers into upstream specs plus tool allowlists.
///
/// * `X-MCP-Servers`: JSON array of upstream URLs. Duplicate URLs are dropped (the first
///   occurrence wins). A missing header means no upstreams.
/// * `X-MCP-Headers`: JSON object mapping an upstream URL to the header map for that
///   upstream. Entries whose URL is not in `X-MCP-Servers` are ignored. A missing header
///   means no headers.
/// * `X-MCP-Tools-Allow`: JSON object of string → list of strings, returned unchanged.
///   A missing header means an empty map.
///
/// # Errors
///
/// [`BadInit::NotUtf8`] if a present header's bytes are not UTF-8, and
/// [`BadInit::NotJson`] if its text does not deserialize into the expected shape.
fn parse_init_headers(
    http_headers: &HeaderMap,
) -> Result<(Vec<UpstreamSpec>, IndexMap<String, Vec<String>>), BadInit> {
    let servers: Vec<String> =
        parse_json_header(http_headers, SERVERS_HEADER, |s| serde_json::from_str(s))?;
    let mut per_url_headers: IndexMap<String, IndexMap<String, String>> =
        parse_json_header(http_headers, HEADERS_HEADER, |s| serde_json::from_str(s))?;
    let tool_allowlists: IndexMap<String, Vec<String>> =
        parse_json_header(http_headers, TOOLS_ALLOW_HEADER, |s| serde_json::from_str(s))?;

    let mut seen = std::collections::HashSet::with_capacity(servers.len());
    let mut specs = Vec::with_capacity(servers.len());
    for url in servers {
        if !seen.insert(url.clone()) {
            tracing::debug!(url = %url, "ignoring duplicate X-MCP-Servers entry");
            continue;
        }
        // Each URL's headers are consumed at most once; duplicates were skipped above.
        let headers = per_url_headers.shift_remove(&url).unwrap_or_default();
        specs.push(UpstreamSpec { url, headers });
    }
    Ok((specs, tool_allowlists))
}

/// Reads `header` from `http_headers` and deserializes its text with `parse`.
///
/// A missing header yields `T::default()`. The raw bytes are checked with
/// `std::str::from_utf8` rather than `HeaderValue::to_str`. `to_str` also rejects valid
/// UTF-8 that is not visible ASCII, which would surface as a misleading "not valid UTF-8"
/// error for JSON containing non-ASCII text.
fn parse_json_header<T: Default>(
    http_headers: &HeaderMap,
    header: &'static str,
    parse: impl FnOnce(&str) -> serde_json::Result<T>,
) -> Result<T, BadInit> {
    match http_headers.get(header) {
        None => Ok(T::default()),
        Some(value) => {
            let text = std::str::from_utf8(value.as_bytes())
                .map_err(|_| BadInit::NotUtf8 { header })?;
            parse(text).map_err(|source| BadInit::NotJson { header, source })
        }
    }
}