use std::borrow::Cow;
use std::sync::Arc;
use dashmap::DashMap;
use futures::FutureExt;
use rmcp::serve_server;
use rmcp::transport::TransportAdapterIdentity;
use rmcp::transport::WorkerTransport;
use rmcp::transport::streamable_http_server::session::SessionId;
use rmcp::transport::streamable_http_server::session::local::{
LocalSessionManager, create_local_session,
};
use rmcp::{
Peer, RoleServer, ServerHandler,
handler::server::router::tool::{ToolRoute, ToolRouter},
handler::server::tool::ToolCallContext,
model::{
CallToolRequestParams, CallToolResult, ClientCapabilities, ClientJsonRpcMessage,
ClientNotification, ClientRequest, Content, Implementation, InitializeRequest,
InitializeRequestParams, InitializeResult, InitializedNotification, NumberOrString,
ProtocolVersion, RequestId, ServerCapabilities, ServerInfo, ServerNotification, Tool,
ToolListChangedNotification,
},
service::RequestContext,
transport::streamable_http_server::{StreamableHttpServerConfig, StreamableHttpService},
};
use serde_json::Value;
use tokio::net::TcpListener;
use tokio::sync::OnceCell;
use tokio::sync::{Mutex, RwLock};
use tokio_util::sync::{CancellationToken, DropGuard};
use objectiveai_sdk::functions::inventions::InventionTool;
/// Per-session state kept in the server's session map.
///
/// Both fields sit behind `Arc`, so cloning a `SessionState` yields another
/// handle onto the same router and the same peer list.
#[derive(Clone)]
struct SessionState {
    /// Tool routes currently exposed to this session; replaced wholesale by
    /// `InventionSession::set_tools`.
    tool_router: Arc<RwLock<ToolRouter<InventionMcp>>>,
    /// Peers pushed by `InventionMcp::initialize`; these receive the
    /// tool-list-changed notification when the tool set is replaced.
    peers: Arc<Mutex<Vec<Peer<RoleServer>>>>,
}
impl SessionState {
    /// Creates the state for a fresh session: a router built from `tools`
    /// and an empty peer list.
    fn new(tools: Vec<InventionTool>) -> Self {
        let router = build_router(tools);
        let no_peers = Vec::new();
        Self {
            tool_router: Arc::new(RwLock::new(router)),
            peers: Arc::new(Mutex::new(no_peers)),
        }
    }
}
/// Lazily starts a single invention MCP server and hands out shared handles
/// to it.
pub struct InventionServerSpawner {
    /// Holds the server handle once the first successful `get` has run.
    cell: OnceCell<Arc<InventionServerHandle>>,
    /// Optional runtime on which to spawn the server's tasks; when `None`,
    /// tasks are spawned with `tokio::spawn`.
    handle: Option<tokio::runtime::Handle>,
}
impl Default for InventionServerSpawner {
fn default() -> Self {
Self::new()
}
}
impl InventionServerSpawner {
    /// Creates a spawner with no explicit runtime handle.
    pub fn new() -> Self {
        let cell = OnceCell::new();
        Self { cell, handle: None }
    }

    /// Creates a spawner whose server tasks are spawned on `handle`.
    pub fn new_with_handle(handle: tokio::runtime::Handle) -> Self {
        let cell = OnceCell::new();
        Self {
            cell,
            handle: Some(handle),
        }
    }

    /// Returns the shared server handle, starting the server on first use.
    ///
    /// # Errors
    ///
    /// Returns the error from `InventionServerHandle::spawn`. On failure the
    /// cell is left empty, so a later call attempts the spawn again.
    pub async fn get(&self) -> std::io::Result<Arc<InventionServerHandle>> {
        // The closure only runs (and only clones the runtime handle) when the
        // cell is still empty.
        let shared = self
            .cell
            .get_or_try_init(|| InventionServerHandle::spawn(self.handle.clone()))
            .await?;
        Ok(Arc::clone(shared))
    }
}
/// Handle onto a running invention MCP server.
///
/// Field order is significant for drop order and is kept as-is.
pub struct InventionServerHandle {
    /// Endpoint URL in the form `http://127.0.0.1:<port>/mcp`.
    url: String,
    /// Per-session state, keyed by MCP session id. Shared with every
    /// `InventionMcp` handler instance.
    sessions: Arc<DashMap<SessionId, SessionState>>,
    /// rmcp's session manager; `register` inserts pre-created sessions here.
    rmcp_session_manager: Arc<LocalSessionManager>,
    /// Runtime used for spawning session tasks; `None` means `tokio::spawn`.
    runtime_handle: Option<tokio::runtime::Handle>,
    /// Cancels the server's cancellation token when this handle is dropped.
    _shutdown: DropGuard,
    /// Abort handle for the spawned HTTP server task.
    // NOTE(review): nothing visible in this file calls `abort()` on it, and
    // tokio's `AbortHandle` is documented not to abort on drop — confirm the
    // intended shutdown path.
    _server_handle: tokio::task::AbortHandle,
}
impl InventionServerHandle {
    /// Starts the HTTP server task and waits for it to report its bound port.
    ///
    /// # Errors
    ///
    /// Returns an `io::Error` of kind `Other` if the server task drops the
    /// port channel without sending a port.
    async fn spawn(
        runtime_handle: Option<tokio::runtime::Handle>,
    ) -> std::io::Result<Arc<Self>> {
        let ct = CancellationToken::new();
        let sessions: Arc<DashMap<SessionId, SessionState>> = Arc::new(DashMap::new());
        let rmcp_session_manager = Arc::new(LocalSessionManager::default());

        let (port_rx, server_handle) = build_and_spawn_server(
            Arc::clone(&sessions),
            Arc::clone(&rmcp_session_manager),
            ct.clone(),
            runtime_handle.clone(),
        );

        let port = port_rx.await.map_err(std::io::Error::other)?;
        let url = format!("http://127.0.0.1:{}/mcp", port);

        Ok(Arc::new(Self {
            url,
            sessions,
            rmcp_session_manager,
            runtime_handle,
            _shutdown: ct.drop_guard(),
            _server_handle: server_handle,
        }))
    }

    /// Registers a new session exposing `initial_tools` and returns a guard
    /// that identifies it.
    ///
    /// The session is created server-side: its state is inserted into the
    /// session map, an rmcp local session is created and stored in the
    /// session manager, a service is spawned on the session's worker
    /// transport, and an `initialize` request followed by an `initialized`
    /// notification is sent through the session handle. The results of both
    /// handshake messages are ignored.
    pub async fn register(
        self: &Arc<Self>,
        initial_tools: Vec<InventionTool>,
    ) -> InventionSession {
        let id: SessionId = rmcp::transport::common::server_side_http::session_id();

        // Our own per-session state must exist before any request can arrive.
        self.sessions
            .insert(id.clone(), SessionState::new(initial_tools));

        let session_config = self.rmcp_session_manager.session_config.clone();
        let (handle, worker) = create_local_session(id.clone(), session_config);
        {
            let mut known_sessions = self.rmcp_session_manager.sessions.write().await;
            known_sessions.insert(id.clone(), handle.clone());
        }

        let mcp = InventionMcp {
            sessions: Arc::clone(&self.sessions),
        };
        let transport = WorkerTransport::spawn(worker);
        let serve = async move {
            if let Ok(running) =
                serve_server::<_, _, _, TransportAdapterIdentity>(mcp, transport).await
            {
                let _ = running.waiting().await;
            }
        };
        if let Some(runtime) = &self.runtime_handle {
            runtime.spawn(serve);
        } else {
            tokio::spawn(serve);
        }

        // Handshake, step 1: `initialize` request.
        let client_info = Implementation {
            name: "objectiveai-invention-preseed".into(),
            title: None,
            version: env!("CARGO_PKG_VERSION").into(),
            description: None,
            icons: None,
            website_url: None,
        };
        let params = InitializeRequestParams {
            meta: None,
            protocol_version: ProtocolVersion::V_2025_06_18,
            capabilities: ClientCapabilities::default(),
            client_info,
        };
        let init_req = ClientJsonRpcMessage::request(
            ClientRequest::InitializeRequest(InitializeRequest {
                method: Default::default(),
                params,
                extensions: Default::default(),
            }),
            RequestId::Number(0),
        );
        let _ = handle.initialize(init_req).await;

        // Handshake, step 2: `initialized` notification.
        let initialized = ClientJsonRpcMessage::notification(
            ClientNotification::InitializedNotification(InitializedNotification {
                method: Default::default(),
                extensions: Default::default(),
            }),
        );
        let _ = handle.push_message(initialized, None).await;

        InventionSession {
            id,
            handle: Arc::clone(self),
        }
    }
}
/// Guard for one registered session; dropping it unregisters the session.
pub struct InventionSession {
    /// MCP session id under which the session's state is stored.
    id: SessionId,
    /// Keeps the owning server alive for as long as the session exists.
    handle: Arc<InventionServerHandle>,
}
impl InventionSession {
    /// Returns a copy of the server's endpoint URL.
    pub fn url(&self) -> String {
        self.handle.url.clone()
    }

    /// Returns the MCP session id as a string slice.
    pub fn id(&self) -> &str {
        &self.id
    }

    /// Replaces the session's tools and notifies every known peer that the
    /// tool list changed.
    ///
    /// Does nothing if the session is no longer present in the session map.
    /// Peers whose notification fails are removed from the peer list.
    pub async fn set_tools(&self, tools: Vec<InventionTool>) {
        // Clone the state out so the map entry is released before any await.
        let Some(state) = self
            .handle
            .sessions
            .get(&self.id)
            .map(|entry| entry.value().clone())
        else {
            return;
        };

        // Build the new router first, then take the write lock to swap it in.
        let fresh_router = build_router(tools);
        *state.tool_router.write().await = fresh_router;

        // Notify from a snapshot so the peer lock is not held while sending.
        let snapshot: Vec<Peer<RoleServer>> = state.peers.lock().await.clone();
        let sends = snapshot.iter().map(|peer| {
            peer.send_notification(ServerNotification::ToolListChangedNotification(
                ToolListChangedNotification::default(),
            ))
        });
        let outcomes = futures::future::join_all(sends).await;

        let mut survivors: Vec<Peer<RoleServer>> = Vec::with_capacity(snapshot.len());
        for (peer, outcome) in snapshot.into_iter().zip(outcomes) {
            if outcome.is_ok() {
                survivors.push(peer);
            }
        }

        // NOTE(review): this overwrites the list, so a peer pushed by
        // `initialize` between the snapshot and this line would be lost —
        // confirm whether that window matters in practice.
        *state.peers.lock().await = survivors;
    }
}
impl Drop for InventionSession {
    /// Unregisters the session.
    ///
    /// The session's state is removed from the session map synchronously.
    /// Closing the rmcp session is asynchronous, so it is spawned as a
    /// best-effort background task on the configured runtime handle, or on
    /// the ambient Tokio runtime when none was configured.
    ///
    /// If neither runtime is available (the session is dropped outside any
    /// Tokio context), the asynchronous close is skipped rather than
    /// panicking: `tokio::spawn` panics without a runtime, and a panic inside
    /// `drop` during unwinding aborts the process.
    fn drop(&mut self) {
        self.handle.sessions.remove(&self.id);

        let mgr = Arc::clone(&self.handle.rmcp_session_manager);
        let id = self.id.clone();
        let close = async move {
            use rmcp::transport::streamable_http_server::session::SessionManager;
            let _ = mgr.close_session(&id).await;
        };

        let runtime = self
            .handle
            .runtime_handle
            .clone()
            .or_else(|| tokio::runtime::Handle::try_current().ok());
        if let Some(runtime) = runtime {
            runtime.spawn(close);
        }
    }
}
/// MCP server handler. It holds no per-session data itself; each request is
/// resolved to a session through the shared session map.
#[derive(Clone)]
struct InventionMcp {
    /// Session map shared with the owning `InventionServerHandle`.
    sessions: Arc<DashMap<SessionId, SessionState>>,
}
impl InventionMcp {
    /// Looks up the session a request belongs to.
    ///
    /// Reads the `mcp-session-id` header from the HTTP request parts stored
    /// in the request context's extensions. Returns `None` when the parts are
    /// absent, the header is missing or not valid as a string, or no session
    /// with that id is registered.
    fn session_state_for(&self, context: &RequestContext<RoleServer>) -> Option<SessionState> {
        let parts = context.extensions.get::<axum::http::request::Parts>()?;
        let header = parts.headers.get("mcp-session-id")?;
        let raw_id = header.to_str().ok()?;
        let id: SessionId = raw_id.into();
        let entry = self.sessions.get(&id)?;
        Some(entry.value().clone())
    }
}
#[inline(never)]
fn build_router(tools: Vec<InventionTool>) -> ToolRouter<InventionMcp> {
let mut tool_router = ToolRouter::<InventionMcp>::new();
for t in tools {
let input_schema: serde_json::Map<String, Value> = t.parameters.into_iter().collect();
let tool_def = Tool {
name: Cow::Owned(t.name.to_string()),
title: None,
description: Some(Cow::Owned(t.description.to_string())),
input_schema: Arc::new(input_schema),
output_schema: None,
annotations: None,
execution: None,
icons: None,
meta: None,
};
let call_fn = t.call.clone();
tool_router.add_route(ToolRoute::new_dyn(
tool_def,
move |ctx: ToolCallContext<'_, InventionMcp>| {
let call_fn = call_fn.clone();
let arguments = ctx
.arguments
.clone()
.map(Value::Object)
.unwrap_or(Value::Object(Default::default()));
async move {
let result = call_fn(arguments).await;
match result {
Ok(text) => Ok(CallToolResult::success(vec![Content::text(text)])),
Err(text) => Ok(CallToolResult::error(vec![Content::text(text)])),
}
}
.boxed()
},
));
}
tool_router
}
impl ServerHandler for InventionMcp {
    /// Describes this server: protocol version 2025-06-18, with the tools and
    /// tool-list-changed capabilities enabled.
    fn get_info(&self) -> ServerInfo {
        let capabilities = ServerCapabilities::builder()
            .enable_tools()
            .enable_tool_list_changed()
            .build();
        let server_info = Implementation {
            name: "oaifi".into(),
            title: None,
            version: env!("CARGO_PKG_VERSION").into(),
            description: None,
            icons: None,
            website_url: None,
        };
        ServerInfo {
            protocol_version: ProtocolVersion::V_2025_06_18,
            capabilities,
            server_info,
            instructions: None,
        }
    }

    /// Handles an `initialize` request.
    ///
    /// Stores the client's info on the peer if none is recorded yet. If the
    /// request resolves to a registered session, the peer is appended to that
    /// session's peer list. Always answers with [`Self::get_info`].
    fn initialize(
        &self,
        request: InitializeRequestParams,
        context: RequestContext<RoleServer>,
    ) -> impl std::future::Future<Output = Result<InitializeResult, rmcp::ErrorData>> + Send + '_
    {
        async move {
            if context.peer.peer_info().is_none() {
                context.peer.set_peer_info(request);
            }
            if let Some(state) = self.session_state_for(&context) {
                let mut peers = state.peers.lock().await;
                peers.push(context.peer.clone());
            }
            Ok(self.get_info())
        }
    }

    /// Lists the tools of the request's session, or an empty list when the
    /// request does not resolve to a registered session. No pagination: the
    /// request parameter is ignored and `next_cursor` is always `None`.
    async fn list_tools(
        &self,
        _request: Option<rmcp::model::PaginatedRequestParams>,
        context: RequestContext<RoleServer>,
    ) -> Result<rmcp::model::ListToolsResult, rmcp::ErrorData> {
        let tools = match self.session_state_for(&context) {
            None => Vec::new(),
            Some(state) => {
                let router = state.tool_router.read().await;
                router.list_all()
            }
        };
        Ok(rmcp::model::ListToolsResult {
            tools,
            meta: None,
            next_cursor: None,
        })
    }

    /// Dispatches a tool call through the router of the request's session.
    ///
    /// # Errors
    ///
    /// Returns an invalid-params error when the request does not resolve to a
    /// registered session; otherwise returns whatever the router returns.
    async fn call_tool(
        &self,
        request: CallToolRequestParams,
        context: RequestContext<RoleServer>,
    ) -> Result<CallToolResult, rmcp::ErrorData> {
        let Some(state) = self.session_state_for(&context) else {
            return Err(rmcp::ErrorData::invalid_params(
                "no Mcp-Session-Id matches a registered invention session",
                None,
            ));
        };
        // Clone the router so the read lock is released before the call runs.
        let router = state.tool_router.read().await.clone();
        router.call(ToolCallContext::new(self, request, context)).await
    }

    /// Always returns `None`; this method receives no request context from
    /// which a session could be resolved.
    fn get_tool(&self, _name: &str) -> Option<Tool> {
        None
    }
}
/// Spawns the HTTP server task and returns a receiver for its port together
/// with the task's abort handle.
///
/// The task binds an ephemeral port on `127.0.0.1`, sends the port number
/// through the returned one-shot channel, and then serves a stateful
/// streamable-HTTP MCP service as the router's fallback service.
///
/// * `sessions` – session map handed to every `InventionMcp` the service
///   factory creates.
/// * `rmcp_session_manager` – session manager shared with the caller.
/// * `ct` – cancellation token. A child token is given to the MCP service,
///   and the token itself drives graceful shutdown of the HTTP server, so
///   cancelling it stops the accept loop as well. Without this the listener
///   task would keep running after the owner cancelled the token.
/// * `runtime_handle` – runtime to spawn on; `None` uses `tokio::spawn`.
///
/// # Panics
///
/// The spawned task (not this function) panics if the listener cannot be
/// bound or its local address cannot be read. That drops the port sender, so
/// the caller sees a receive error on the returned channel.
#[inline(never)]
fn build_and_spawn_server(
    sessions: Arc<DashMap<SessionId, SessionState>>,
    rmcp_session_manager: Arc<LocalSessionManager>,
    ct: CancellationToken,
    runtime_handle: Option<tokio::runtime::Handle>,
) -> (tokio::sync::oneshot::Receiver<u16>, tokio::task::AbortHandle) {
    let (port_tx, port_rx) = tokio::sync::oneshot::channel();
    let ct_child = ct.child_token();
    let task = async move {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("failed to bind invention MCP server to 127.0.0.1:0");
        let port = listener
            .local_addr()
            .expect("bound listener has no local address")
            .port();
        // The receiver may already be gone; the server still runs in that case.
        let _ = port_tx.send(port);

        let factory_sessions = Arc::clone(&sessions);
        let service: StreamableHttpService<InventionMcp, LocalSessionManager> =
            StreamableHttpService::new(
                move || {
                    Ok(InventionMcp {
                        sessions: Arc::clone(&factory_sessions),
                    })
                },
                Arc::clone(&rmcp_session_manager),
                StreamableHttpServerConfig {
                    stateful_mode: true,
                    cancellation_token: ct_child,
                    ..Default::default()
                },
            );
        let router = axum::Router::new().fallback_service(service);

        // Stop accepting connections once the token is cancelled.
        axum::serve(listener, router)
            .with_graceful_shutdown(async move { ct.cancelled().await })
            .await
            .ok();
    };
    let handle = match runtime_handle {
        Some(h) => h.spawn(task).abort_handle(),
        None => tokio::spawn(task).abort_handle(),
    };
    (port_rx, handle)
}
#[cfg(test)]
#[path = "invention_server_tests.rs"]
mod tests;