use dashmap::DashMap;
use futures::future::try_join_all;
use indexmap::IndexMap;
use std::sync::Arc;
use objectiveai_sdk::mcp::{
Connection, JsonRpcNotification,
resource::{ListResourcesResult, ReadResourceResult, Resource},
tool::{CallToolRequestParams, CallToolResult, ContentBlock, ListToolsResult, Tool},
};
use tokio::sync::{Mutex, broadcast};
use tokio_util::sync::CancellationToken;
/// Capacity passed to `broadcast::channel` when `Session::new` creates the
/// session's outbound notification channel.
const OUTBOUND_CAPACITY: usize = 64;
/// Turns a JSON-RPC request id into the string key used by the in-flight map.
///
/// The key is the id's JSON serialization; if serialization fails, the empty
/// string is used instead.
fn request_id_key(id: &serde_json::Value) -> String {
    match serde_json::to_string(id) {
        Ok(serialized) => serialized,
        Err(_) => String::new(),
    }
}
/// State for one proxied session spanning a set of upstream connections.
#[derive(Debug)]
pub struct Session {
/// Upstream connections keyed by server name; the key is the prefix used
/// when listing and routing tools and resources.
pub connections: IndexMap<String, Connection>,
/// Broadcast sender for notifications emitted by this session (the
/// list-changed callbacks installed in `new` publish here).
pub outbound: broadcast::Sender<JsonRpcNotification>,
/// Cancellation tokens for registered requests, keyed by the serialized
/// request id (see `request_id_key`).
in_flight: DashMap<String, CancellationToken>,
/// Content blocks appended by `enqueue_notifications` and taken out by
/// `drain_notifications`.
pending_notifications: Mutex<Vec<ContentBlock>>,
/// Payload supplied by the caller at construction; not read by the
/// methods visible in this file.
pub payload: crate::session_manager::SessionPayload,
}
impl Session {
pub(crate) fn new(
connections: IndexMap<String, Connection>,
payload: crate::session_manager::SessionPayload,
) -> Self {
let (outbound, _) = broadcast::channel(OUTBOUND_CAPACITY);
for connection in connections.values() {
let tx = outbound.clone();
connection.set_on_tools_list_changed(move || {
let _ = tx.send(JsonRpcNotification {
jsonrpc: "2.0".into(),
method: "notifications/tools/list_changed".into(),
params: None,
});
});
let tx = outbound.clone();
connection.set_on_resources_list_changed(move || {
let _ = tx.send(JsonRpcNotification {
jsonrpc: "2.0".into(),
method: "notifications/resources/list_changed".into(),
params: None,
});
});
}
Self {
connections,
outbound,
in_flight: DashMap::new(),
pending_notifications: Mutex::new(Vec::new()),
payload,
}
}
/// Appends `blocks` to the pending-notification buffer.
///
/// An empty `blocks` is a no-op and does not take the lock.
pub async fn enqueue_notifications(&self, blocks: Vec<ContentBlock>) {
    if !blocks.is_empty() {
        let mut pending = self.pending_notifications.lock().await;
        pending.extend(blocks);
    }
}
/// Takes every buffered content block, leaving the buffer empty.
pub async fn drain_notifications(&self) -> Vec<ContentBlock> {
    let mut pending = self.pending_notifications.lock().await;
    std::mem::take(&mut *pending)
}
/// Creates a fresh cancellation token for request `id`, stores one handle in
/// the in-flight map under the serialized id, and returns the other handle to
/// the caller.
pub fn register_in_flight(&self, id: &serde_json::Value) -> CancellationToken {
    let key = request_id_key(id);
    let token = CancellationToken::new();
    let handle = token.clone();
    self.in_flight.insert(key, token);
    handle
}
/// Removes the in-flight entry for request `id`, if one exists.
pub fn deregister_in_flight(&self, id: &serde_json::Value) {
    let key = request_id_key(id);
    let _ = self.in_flight.remove(&key);
}
/// Cancels the token registered for request `id`.
///
/// Returns `true` when an entry was found (and its token cancelled), `false`
/// when no request with that id is registered. The entry itself stays in the
/// map; removal is done by `deregister_in_flight`.
pub fn cancel_in_flight(&self, id: &serde_json::Value) -> bool {
    let key = request_id_key(id);
    let Some(entry) = self.in_flight.get(&key) else {
        return false;
    };
    entry.value().cancel();
    true
}
/// Lists the tools of every upstream connection as one merged, sorted list.
///
/// All connections are queried concurrently. Each returned tool is cloned and
/// renamed to `<server name>_<tool name>`, and the merged list is sorted by
/// that prefixed name. `next_cursor` and `_meta` are always `None`.
///
/// # Errors
///
/// If any upstream listing fails, the whole call returns that error.
pub async fn list_tools(&self) -> Result<ListToolsResult, Arc<objectiveai_sdk::mcp::Error>> {
    let per_server = try_join_all(
        self.connections
            .values()
            .map(|connection| async move { connection.list_tools().await }),
    )
    .await?;

    // `keys()` and `values()` iterate in the same order, and `try_join_all`
    // yields results in input order, so zipping pairs each listing with the
    // server it came from.
    let mut tools: Vec<Tool> = Vec::new();
    for (server_name, listed) in self.connections.keys().zip(per_server) {
        tools.extend(listed.iter().map(|tool| {
            let mut renamed = tool.clone();
            renamed.name = prefix_name(server_name, &tool.name);
            renamed
        }));
    }
    tools.sort_by(|left, right| left.name.cmp(&right.name));

    Ok(ListToolsResult {
        tools,
        next_cursor: None,
        _meta: None,
    })
}
/// Lists the resources of every upstream connection as one merged, sorted list.
///
/// All connections are queried concurrently. Each returned resource is cloned
/// and its URI rewritten to `<server name>_<uri>`, and the merged list is
/// sorted by that prefixed URI. `next_cursor` and `_meta` are always `None`.
///
/// # Errors
///
/// If any upstream listing fails, the whole call returns that error.
pub async fn list_resources(&self) -> Result<ListResourcesResult, Arc<objectiveai_sdk::mcp::Error>> {
    let per_server = try_join_all(
        self.connections
            .values()
            .map(|connection| async move { connection.list_resources().await }),
    )
    .await?;

    // `keys()` and `values()` iterate in the same order, and `try_join_all`
    // yields results in input order, so zipping pairs each listing with the
    // server it came from.
    let mut resources: Vec<Resource> = Vec::new();
    for (server_name, listed) in self.connections.keys().zip(per_server) {
        resources.extend(listed.iter().map(|resource| {
            let mut renamed = resource.clone();
            renamed.uri = prefix_name(server_name, &resource.uri);
            renamed
        }));
    }
    resources.sort_by(|left, right| left.uri.cmp(&right.uri));

    Ok(ListResourcesResult {
        resources,
        next_cursor: None,
        _meta: None,
    })
}
/// Calls a tool on the upstream connection that owns it.
///
/// `params.name` is expected in the prefixed form `<server name>_<tool name>`.
/// The prefix selects the connection (see `route`); the request forwarded
/// upstream carries the unprefixed tool name, with `arguments`, `task` and
/// `_meta` cloned unchanged from `params`.
///
/// # Errors
///
/// * `CallToolError::ToolNotFound` when no connection name matches the prefix.
/// * `CallToolError::Upstream` when the upstream call itself fails.
pub async fn call_tool(
    &self,
    params: &CallToolRequestParams,
) -> Result<CallToolResult, CallToolError> {
    let (connection, original_name) = self
        .route(&params.name)
        .ok_or_else(|| CallToolError::ToolNotFound(params.name.clone()))?;
    // Rebuild the request so the upstream sees its own, unprefixed tool name.
    let upstream_params = CallToolRequestParams {
        name: original_name,
        arguments: params.arguments.clone(),
        task: params.task.clone(),
        _meta: params._meta.clone(),
    };
    Ok(connection.call_tool(&upstream_params).await?)
}
/// Reads a resource from the upstream connection that owns it.
///
/// `uri` is expected in the prefixed form `<server name>_<uri>`; the prefix
/// selects the connection and the remainder is passed upstream.
///
/// # Errors
///
/// * `ReadResourceError::ResourceNotFound` when no connection name matches.
/// * `ReadResourceError::Upstream` when the upstream read itself fails.
pub async fn read_resource(
    &self,
    uri: &str,
) -> Result<ReadResourceResult, ReadResourceError> {
    let Some((connection, original_uri)) = self.route(uri) else {
        return Err(ReadResourceError::ResourceNotFound(uri.to_string()));
    };
    let contents = connection.read_resource(&original_uri).await?;
    Ok(contents)
}
/// Resolves a prefixed name to its connection and the unprefixed remainder.
///
/// A connection matches when `prefixed` is `<connection name>_<rest>` with a
/// non-empty `<rest>`. When several names match, the longest name wins, so a
/// server named `a_b` takes precedence over `a` for `a_b_tool`. Returns `None`
/// when nothing matches.
fn route<'a>(&'a self, prefixed: &str) -> Option<(&'a Connection, String)> {
    // Best candidate so far: (matched name length, connection, remainder).
    let mut winner: Option<(usize, &'a Connection, &str)> = None;

    for (name, conn) in &self.connections {
        let remainder = prefixed
            .strip_prefix(name.as_str())
            .and_then(|rest| rest.strip_prefix('_'))
            .filter(|rest| !rest.is_empty());
        let Some(remainder) = remainder else {
            continue;
        };

        // Strictly longer only: on equal length the earlier match is kept.
        let is_longer = match winner {
            Some((best_len, _, _)) => name.len() > best_len,
            None => true,
        };
        if is_longer {
            winner = Some((name.len(), conn, remainder));
        }
    }

    winner.map(|(_, conn, remainder)| (conn, remainder.to_string()))
}
}
/// Joins a server name and an upstream identifier as `<server_name>_<name>`.
fn prefix_name(server_name: &str, name: &str) -> String {
    [server_name, name].join("_")
}
/// Errors returned by `Session::call_tool`.
#[derive(Debug, thiserror::Error)]
pub enum CallToolError {
/// No connection name matched the prefix of the requested tool name; carries
/// the requested (prefixed) name.
#[error("tool not found on any upstream: {0}")]
ToolNotFound(String),
/// The upstream connection returned an error for the call.
#[error("upstream call_tool failed: {0}")]
Upstream(#[from] objectiveai_sdk::mcp::Error),
}
/// Errors returned by `Session::read_resource`.
#[derive(Debug, thiserror::Error)]
pub enum ReadResourceError {
/// No connection name matched the prefix of the requested URI; carries the
/// requested (prefixed) URI.
#[error("resource not found on any upstream: {0}")]
ResourceNotFound(String),
/// The upstream connection returned an error for the read.
#[error("upstream read_resource failed: {0}")]
Upstream(#[from] objectiveai_sdk::mcp::Error),
}