use dashmap::DashMap;
use futures::future::try_join_all;
use indexmap::IndexMap;
use std::sync::Arc;
use objectiveai_sdk::mcp::{
Connection, JsonRpcNotification,
resource::{ListResourcesResult, ReadResourceResult, Resource},
tool::{CallToolRequestParams, CallToolResult, ContentBlock, ListToolsResult, Tool},
};
use tokio::sync::{Mutex, broadcast};
use tokio_util::sync::CancellationToken;
/// Capacity of the `broadcast` channel each [`Session`] creates for outbound
/// notifications (passed to `broadcast::channel` in `Session::new`).
const OUTBOUND_CAPACITY: usize = 64;

/// Converts a JSON-RPC request id into the string key used by the session's
/// in-flight cancellation map.
///
/// The key is the id's JSON serialization, so ids of different JSON types
/// (e.g. the number `1` and the string `"1"`) produce different keys. If
/// serialization ever failed, the empty string would be used as the key.
fn request_id_key(id: &serde_json::Value) -> String {
    match serde_json::to_string(id) {
        Ok(key) => key,
        Err(_) => String::new(),
    }
}
/// Per-client session state: a set of named upstream MCP connections plus
/// bookkeeping for in-flight request cancellation and queued notification
/// content.
#[derive(Debug)]
pub struct Session {
    /// Upstream connections keyed by server name. The server name becomes the
    /// `{server}_` prefix on the tool names and resource URIs this session
    /// exposes, and is used to route prefixed names back to a connection.
    pub connections: IndexMap<String, Connection>,
    /// Broadcast sender on which upstream `notifications/tools/list_changed`
    /// and `notifications/resources/list_changed` are re-published.
    /// NOTE(review): receivers are presumably obtained via `subscribe()` by
    /// code outside this file — confirm.
    pub outbound: broadcast::Sender<JsonRpcNotification>,
    /// Cancellation tokens for in-flight requests, keyed by the JSON
    /// serialization of the request id (see `request_id_key`).
    in_flight: DashMap<String, CancellationToken>,
    /// Content blocks appended by `enqueue_notifications` and taken by
    /// `drain_notifications`.
    pending_notifications: Mutex<Vec<ContentBlock>>,
    /// Session configuration; `agent_id` and `tool_allowlists` are read from it.
    pub payload: crate::session_manager::SessionPayload,
    /// `payload.tool_allowlists` (keyed by connection URL) resolved to server
    /// names. A server with no entry has no tool filtering applied when tools
    /// are listed.
    tool_allowlists_by_server: IndexMap<String, Vec<String>>,
}
impl Session {
pub(crate) fn new(
connections: IndexMap<String, Connection>,
payload: crate::session_manager::SessionPayload,
) -> Self {
let (outbound, _) = broadcast::channel(OUTBOUND_CAPACITY);
for connection in connections.values() {
let tx = outbound.clone();
connection.set_on_tools_list_changed(move || {
let _ = tx.send(JsonRpcNotification {
jsonrpc: "2.0".into(),
method: "notifications/tools/list_changed".into(),
params: None,
});
});
let tx = outbound.clone();
connection.set_on_resources_list_changed(move || {
let _ = tx.send(JsonRpcNotification {
jsonrpc: "2.0".into(),
method: "notifications/resources/list_changed".into(),
params: None,
});
});
}
let mut tool_allowlists_by_server: IndexMap<String, Vec<String>> = IndexMap::new();
for (server_name, connection) in &connections {
if let Some(names) = payload.tool_allowlists.get(&connection.url) {
tool_allowlists_by_server
.insert(server_name.clone(), names.clone());
}
}
Self {
connections,
outbound,
in_flight: DashMap::new(),
pending_notifications: Mutex::new(Vec::new()),
payload,
tool_allowlists_by_server,
}
}
pub fn agent_id(&self) -> Option<&str> {
self.payload.agent_id.as_deref()
}
pub async fn enqueue_notifications(&self, blocks: Vec<ContentBlock>) {
if blocks.is_empty() {
return;
}
self.pending_notifications.lock().await.extend(blocks);
}
pub async fn drain_notifications(&self) -> Vec<ContentBlock> {
std::mem::take(&mut *self.pending_notifications.lock().await)
}
pub async fn has_pending_notifications(&self) -> bool {
!self.pending_notifications.lock().await.is_empty()
}
pub fn register_in_flight(&self, id: &serde_json::Value) -> CancellationToken {
let token = CancellationToken::new();
self.in_flight.insert(request_id_key(id), token.clone());
token
}
pub fn deregister_in_flight(&self, id: &serde_json::Value) {
self.in_flight.remove(&request_id_key(id));
}
pub fn cancel_in_flight(&self, id: &serde_json::Value) -> bool {
match self.in_flight.get(&request_id_key(id)) {
Some(entry) => {
entry.value().cancel();
true
}
None => false,
}
}
pub async fn list_tools(&self) -> Result<ListToolsResult, Arc<objectiveai_sdk::mcp::Error>> {
self.list_tools_filtered(None).await
}
pub async fn list_tools_filtered(
&self,
filter_url: Option<&str>,
) -> Result<ListToolsResult, Arc<objectiveai_sdk::mcp::Error>> {
let pairs: Vec<(&String, &Connection)> = match filter_url {
Some(url) => self
.connections
.iter()
.filter(|(_, c)| c.url == url)
.collect(),
None => self.connections.iter().collect(),
};
let results = try_join_all(
pairs
.iter()
.map(|(_, c)| async move {
let r = c.list_tools().await;
r
}),
)
.await?;
let mut tools: Vec<Tool> = Vec::new();
for ((server_name, _), arc) in pairs.into_iter().zip(results) {
let allowlist = self.tool_allowlists_by_server.get(server_name);
for tool in arc.iter() {
if let Some(names) = allowlist {
if !names.iter().any(|n| n == &tool.name) {
continue;
}
}
let mut prefixed = tool.clone();
prefixed.name = prefix_name(server_name, &tool.name);
tools.push(prefixed);
}
}
tools.sort_by(|a, b| a.name.cmp(&b.name));
Ok(ListToolsResult {
tools,
next_cursor: None,
_meta: None,
})
}
pub async fn list_resources(&self) -> Result<ListResourcesResult, Arc<objectiveai_sdk::mcp::Error>> {
self.list_resources_filtered(None).await
}
pub async fn list_resources_filtered(
&self,
filter_url: Option<&str>,
) -> Result<ListResourcesResult, Arc<objectiveai_sdk::mcp::Error>> {
let pairs: Vec<(&String, &Connection)> = match filter_url {
Some(url) => self
.connections
.iter()
.filter(|(_, c)| c.url == url)
.collect(),
None => self.connections.iter().collect(),
};
let results = try_join_all(
pairs
.iter()
.map(|(_, c)| async move {
let r = c.list_resources().await;
r
}),
)
.await?;
let mut resources: Vec<Resource> = Vec::new();
for ((server_name, _), arc) in pairs.into_iter().zip(results) {
for resource in arc.iter() {
let mut prefixed = resource.clone();
prefixed.uri = prefix_name(server_name, &resource.uri);
resources.push(prefixed);
}
}
resources.sort_by(|a, b| a.uri.cmp(&b.uri));
Ok(ListResourcesResult {
resources,
next_cursor: None,
_meta: None,
})
}
pub async fn call_tool(
&self,
params: &CallToolRequestParams,
) -> Result<CallToolResult, CallToolError> {
let (connection, original_name) = self
.route(¶ms.name)
.ok_or_else(|| CallToolError::ToolNotFound(params.name.clone()))?;
let upstream_params = CallToolRequestParams {
name: original_name,
arguments: params.arguments.clone(),
task: params.task.clone(),
_meta: params._meta.clone(),
};
let r = connection.call_tool(&upstream_params).await;
Ok(r?)
}
pub async fn read_resource(
&self,
uri: &str,
) -> Result<ReadResourceResult, ReadResourceError> {
let (connection, original_uri) = self
.route(uri)
.ok_or_else(|| ReadResourceError::ResourceNotFound(uri.to_string()))?;
let r = connection.read_resource(&original_uri).await;
Ok(r?)
}
fn route<'a>(&'a self, prefixed: &str) -> Option<(&'a Connection, String)> {
let mut best: Option<(&'a str, &'a Connection)> = None;
for (name, conn) in &self.connections {
if prefixed.len() > name.len() + 1
&& prefixed.as_bytes()[name.len()] == b'_'
&& prefixed.starts_with(name.as_str())
{
if best.map(|(b, _)| name.len() > b.len()).unwrap_or(true) {
best = Some((name.as_str(), conn));
}
}
}
best.map(|(name, conn)| {
let original = prefixed[name.len() + 1..].to_string();
(conn, original)
})
}
}
/// Joins a server name and an upstream item name (a tool name or resource
/// URI) into the `{server_name}_{name}` form exposed to clients.
fn prefix_name(server_name: &str, name: &str) -> String {
    let mut joined = String::with_capacity(server_name.len() + 1 + name.len());
    joined.push_str(server_name);
    joined.push('_');
    joined.push_str(name);
    joined
}
/// Error returned by `Session::call_tool`.
#[derive(Debug, thiserror::Error)]
pub enum CallToolError {
    /// The requested tool name could not be resolved to an upstream tool.
    /// Carries the name exactly as the caller supplied it (server prefix included).
    #[error("tool not found on any upstream: {0}")]
    ToolNotFound(String),
    /// The routed upstream connection returned an error from `call_tool`.
    #[error("upstream call_tool failed: {0}")]
    Upstream(#[from] objectiveai_sdk::mcp::Error),
}
/// Error returned by `Session::read_resource`.
#[derive(Debug, thiserror::Error)]
pub enum ReadResourceError {
    /// No upstream server prefix matched the requested URI. Carries the URI
    /// exactly as the caller supplied it (server prefix included).
    #[error("resource not found on any upstream: {0}")]
    ResourceNotFound(String),
    /// The routed upstream connection returned an error from `read_resource`.
    #[error("upstream read_resource failed: {0}")]
    Upstream(#[from] objectiveai_sdk::mcp::Error),
}