use dashmap::DashMap;
use futures::future::try_join_all;
use indexmap::IndexMap;
use std::sync::Arc;
use objectiveai_sdk::mcp::{
Connection, JsonRpcNotification,
resource::{ListResourcesResult, ReadResourceResult, Resource},
tool::{CallToolRequestParams, CallToolResult, ContentBlock, ListToolsResult, Tool},
};
use axum::http::HeaderMap;
use tokio::sync::{Mutex, RwLock, broadcast};
use tokio_util::sync::CancellationToken;
/// Capacity passed to `broadcast::channel` when `Session::new` creates the
/// session's outbound notification channel.
const OUTBOUND_CAPACITY: usize = 64;
/// Turns a JSON-RPC request id into the string key used by the in-flight map.
///
/// The key is the id's JSON serialization. If serialization fails, the empty
/// string is returned instead of an error.
fn request_id_key(id: &serde_json::Value) -> String {
    match serde_json::to_string(id) {
        Ok(encoded) => encoded,
        Err(_) => String::new(),
    }
}
/// One session fanned out over a set of upstream MCP connections.
///
/// Tool names and resource URIs exposed by the session are prefixed with the
/// key of the connection they came from (`<key>_<name>`); the same keys are
/// used to route prefixed names back to a connection.
#[derive(Debug)]
pub struct Session {
/// Upstream connections, keyed by the server name used as the prefix.
pub connections: IndexMap<String, Connection>,
/// Broadcast sender for outbound JSON-RPC notifications. `Session::new`
/// wires every connection's list-changed callbacks to send into it.
pub outbound: broadcast::Sender<JsonRpcNotification>,
/// Cancellation tokens of registered requests, keyed by the JSON
/// serialization of the request id (see `request_id_key`).
in_flight: DashMap<String, CancellationToken>,
/// Payload handed to `Session::new`; stored as-is and not read in this file.
pub payload: crate::session_manager::SessionPayload,
/// Header bag most recently stored by `apply_transient_headers`
/// (empty until the first call).
pub transient_headers: RwLock<IndexMap<String, String>>,
}
impl Session {
pub(crate) fn new(
connections: IndexMap<String, Connection>,
payload: crate::session_manager::SessionPayload,
) -> Self {
let (outbound, _) = broadcast::channel(OUTBOUND_CAPACITY);
for connection in connections.values() {
let tx = outbound.clone();
connection.set_on_tools_list_changed(move || {
let _ = tx.send(JsonRpcNotification {
jsonrpc: "2.0".into(),
method: "notifications/tools/list_changed".into(),
params: None,
});
});
let tx = outbound.clone();
connection.set_on_resources_list_changed(move || {
let _ = tx.send(JsonRpcNotification {
jsonrpc: "2.0".into(),
method: "notifications/resources/list_changed".into(),
params: None,
});
});
}
Self {
connections,
outbound,
in_flight: DashMap::new(),
payload,
transient_headers: RwLock::new(IndexMap::new()),
}
}
/// Names of the request headers that `apply_transient_headers` looks up in
/// the incoming header map. Each name is also used verbatim as the key under
/// which the header's value is stored and handed to the connections.
pub const TRANSIENT_HEADER_KEYS: [&'static str; 6] = [
"X-OBJECTIVEAI-RESPONSE-ID",
"X-OBJECTIVEAI-RESPONSE-IDS",
"X-OBJECTIVEAI-AGENT-INSTANCE-HIERARCHY",
"X-OBJECTIVEAI-AGENT-ID",
"X-OBJECTIVEAI-AGENT-FULL-ID",
"X-OBJECTIVEAI-AGENT-REMOTE",
];
pub async fn apply_transient_headers(&self, src: &HeaderMap) {
let mut bag = IndexMap::new();
for key in Self::TRANSIENT_HEADER_KEYS {
if let Some(v) = src.get(key).and_then(|v| v.to_str().ok()) {
bag.insert(key.to_string(), v.to_string());
}
}
let snapshot = bag.clone();
*self.transient_headers.write().await = bag;
for connection in self.connections.values() {
connection.set_extra_headers(snapshot.clone()).await;
}
}
/// Registers request `id` as in flight and returns its cancellation token.
///
/// A new token is created on every call; a clone is stored under the id's
/// key, replacing any token previously stored under the same key.
pub fn register_in_flight(&self, id: &serde_json::Value) -> CancellationToken {
    let handle = CancellationToken::new();
    let key = request_id_key(id);
    self.in_flight.insert(key, handle.clone());
    handle
}
/// Removes the in-flight entry for request `id`, if there is one.
///
/// The removed token is simply dropped; it is not cancelled here.
pub fn deregister_in_flight(&self, id: &serde_json::Value) {
    let key = request_id_key(id);
    let _ = self.in_flight.remove(&key);
}
/// Cancels the token registered for request `id`.
///
/// Returns `true` when an entry existed (its token is cancelled) and `false`
/// when no request is registered under that id. The entry itself stays in
/// the map; removal is left to `deregister_in_flight`.
pub fn cancel_in_flight(&self, id: &serde_json::Value) -> bool {
    let key = request_id_key(id);
    if let Some(entry) = self.in_flight.get(&key) {
        entry.value().cancel();
        return true;
    }
    false
}
/// Lists the tools of every connection.
///
/// Same as calling `list_tools_filtered` with no URL filter.
pub async fn list_tools(&self) -> Result<ListToolsResult, Arc<objectiveai_sdk::mcp::Error>> {
    let every_upstream: Option<&str> = None;
    self.list_tools_filtered(every_upstream).await
}
/// Lists tools from the connections selected by `filter_url`.
///
/// With `Some(url)` only connections whose `url` equals it are queried; with
/// `None` all of them are. The selected connections are queried together via
/// `try_join_all`, so an error from any one of them fails the whole call.
///
/// Each returned tool is a clone of the upstream tool with its name rewritten
/// to `<server key>_<tool name>`. The list is sorted by that prefixed name,
/// and `next_cursor` / `_meta` are always `None`.
pub async fn list_tools_filtered(
    &self,
    filter_url: Option<&str>,
) -> Result<ListToolsResult, Arc<objectiveai_sdk::mcp::Error>> {
    // No filter means every connection passes.
    let selected: Vec<(&String, &Connection)> = self
        .connections
        .iter()
        .filter(|(_, upstream)| filter_url.map_or(true, |url| upstream.url == url))
        .collect();

    let listings = try_join_all(
        selected
            .iter()
            .map(|(_, upstream)| async move { upstream.list_tools().await }),
    )
    .await?;

    let mut tools: Vec<Tool> = Vec::new();
    // `listings` is in the same order as `selected`, so zipping pairs each
    // listing with the server key it came from.
    for ((server_name, _), listing) in selected.into_iter().zip(listings) {
        tools.extend(listing.iter().map(|tool| {
            let mut renamed = tool.clone();
            renamed.name = prefix_name(server_name, &tool.name);
            renamed
        }));
    }
    tools.sort_by(|left, right| left.name.cmp(&right.name));

    Ok(ListToolsResult {
        tools,
        next_cursor: None,
        _meta: None,
    })
}
/// Lists the resources of every connection.
///
/// Same as calling `list_resources_filtered` with no URL filter.
pub async fn list_resources(&self) -> Result<ListResourcesResult, Arc<objectiveai_sdk::mcp::Error>> {
    let every_upstream: Option<&str> = None;
    self.list_resources_filtered(every_upstream).await
}
/// Lists resources from the connections selected by `filter_url`.
///
/// With `Some(url)` only connections whose `url` equals it are queried; with
/// `None` all of them are. The selected connections are queried together via
/// `try_join_all`, so an error from any one of them fails the whole call.
///
/// Each returned resource is a clone of the upstream resource with its URI
/// rewritten to `<server key>_<uri>`. The list is sorted by that prefixed
/// URI, and `next_cursor` / `_meta` are always `None`.
pub async fn list_resources_filtered(
    &self,
    filter_url: Option<&str>,
) -> Result<ListResourcesResult, Arc<objectiveai_sdk::mcp::Error>> {
    // No filter means every connection passes.
    let selected: Vec<(&String, &Connection)> = self
        .connections
        .iter()
        .filter(|(_, upstream)| filter_url.map_or(true, |url| upstream.url == url))
        .collect();

    let listings = try_join_all(
        selected
            .iter()
            .map(|(_, upstream)| async move { upstream.list_resources().await }),
    )
    .await?;

    let mut resources: Vec<Resource> = Vec::new();
    // `listings` is in the same order as `selected`, so zipping pairs each
    // listing with the server key it came from.
    for ((server_name, _), listing) in selected.into_iter().zip(listings) {
        resources.extend(listing.iter().map(|resource| {
            let mut renamed = resource.clone();
            renamed.uri = prefix_name(server_name, &resource.uri);
            renamed
        }));
    }
    resources.sort_by(|left, right| left.uri.cmp(&right.uri));

    Ok(ListResourcesResult {
        resources,
        next_cursor: None,
        _meta: None,
    })
}
/// Calls a tool on the upstream connection that owns it.
///
/// `params.name` is expected to be a prefixed name (`<server key>_<tool>`).
/// It is resolved with `route`, and the upstream is called with the
/// un-prefixed tool name; `arguments`, `task` and `_meta` are forwarded as
/// clones.
///
/// # Errors
///
/// * `CallToolError::ToolNotFound` when no connection matches the prefix.
/// * `CallToolError::Upstream` when the upstream call itself fails.
pub async fn call_tool(
    &self,
    params: &CallToolRequestParams,
) -> Result<CallToolResult, CallToolError> {
    // The borrow here had been corrupted into the non-Rust token `¶ms`
    // (an `&para` entity garble); it must be `&params.name`.
    let (connection, original_name) = self
        .route(&params.name)
        .ok_or_else(|| CallToolError::ToolNotFound(params.name.clone()))?;

    // Same request, but addressed by the name the upstream knows.
    let upstream_params = CallToolRequestParams {
        name: original_name,
        arguments: params.arguments.clone(),
        task: params.task.clone(),
        _meta: params._meta.clone(),
    };

    Ok(connection.call_tool(&upstream_params).await?)
}
/// Reads a resource from the upstream connection that owns it.
///
/// `uri` is expected to be a prefixed URI (`<server key>_<uri>`). It is
/// resolved with `route`, and the upstream is asked for the un-prefixed URI.
///
/// # Errors
///
/// * `ReadResourceError::ResourceNotFound` when no connection matches the
///   prefix.
/// * `ReadResourceError::Upstream` when the upstream read itself fails.
pub async fn read_resource(
    &self,
    uri: &str,
) -> Result<ReadResourceResult, ReadResourceError> {
    let (upstream, upstream_uri) = match self.route(uri) {
        Some(target) => target,
        None => return Err(ReadResourceError::ResourceNotFound(uri.to_string())),
    };

    let contents = upstream.read_resource(&upstream_uri).await?;
    Ok(contents)
}
/// Resolves a prefixed name to its connection and un-prefixed remainder.
///
/// A connection key matches when `prefixed` is exactly `<key>_<rest>` with a
/// non-empty `<rest>`. If several keys match, the longest key wins (i.e. the
/// shortest remainder). Returns `None` when nothing matches.
fn route<'a>(&'a self, prefixed: &str) -> Option<(&'a Connection, String)> {
    self.connections
        .iter()
        .filter_map(|(server, connection)| {
            // Peel off the key, then the `_` separator.
            let remainder = prefixed
                .strip_prefix(server.as_str())?
                .strip_prefix('_')?;
            if remainder.is_empty() {
                None
            } else {
                Some((connection, remainder))
            }
        })
        // Shortest remainder == longest key; on a tie the first match is kept.
        .min_by_key(|(_, remainder)| remainder.len())
        .map(|(connection, remainder)| (connection, remainder.to_string()))
}
}
/// Builds the prefixed form `<server_name>_<name>` used for tool names and
/// resource URIs exposed by a session.
fn prefix_name(server_name: &str, name: &str) -> String {
    // Exact final size: both parts plus the single `_` separator.
    let mut joined = String::with_capacity(server_name.len() + 1 + name.len());
    joined.push_str(server_name);
    joined.push('_');
    joined.push_str(name);
    joined
}
/// Errors returned by `Session::call_tool`.
#[derive(Debug, thiserror::Error)]
pub enum CallToolError {
/// The prefixed tool name did not route to any connection. Carries the
/// name exactly as it was requested.
#[error("tool not found on any upstream: {0}")]
ToolNotFound(String),
/// The routed upstream connection returned an error from `call_tool`.
#[error("upstream call_tool failed: {0}")]
Upstream(#[from] objectiveai_sdk::mcp::Error),
}
/// Errors returned by `Session::read_resource`.
#[derive(Debug, thiserror::Error)]
pub enum ReadResourceError {
/// The prefixed URI did not route to any connection. Carries the URI
/// exactly as it was requested.
#[error("resource not found on any upstream: {0}")]
ResourceNotFound(String),
/// The routed upstream connection returned an error from `read_resource`.
#[error("upstream read_resource failed: {0}")]
Upstream(#[from] objectiveai_sdk::mcp::Error),
}