pub mod mcp_server_runtime;
pub mod mcp_server_runtime_core;
use crate::auth::AuthInfo;
use crate::error::SdkResult;
use crate::mcp_traits::{
McpObserver, McpServer, McpServerHandler, RequestIdGen, RequestIdGenNumeric,
};
use crate::schema::{
schema_utils::{
ClientMessage, ClientMessages, FromMessage, MessageFromServer, SdkError, ServerMessage,
ServerMessages,
},
InitializeRequestParams, InitializeResult, RequestId, RpcError,
};
use crate::task_store::{ClientTaskStore, ServerTaskStore, TaskStatusPoller, TaskStatusUpdate};
use crate::utils::AbortTaskOnDrop;
use async_trait::async_trait;
use futures::future::try_join_all;
use futures::{StreamExt, TryFutureExt};
use rust_mcp_schema::{GetTaskParams, GetTaskPayloadParams};
#[cfg(feature = "hyper-server")]
use rust_mcp_transport::SessionId;
use rust_mcp_transport::{IoStream, TaskId, TransportDispatcher};
use std::panic;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::AsyncWriteExt;
use tokio::sync::{mpsc, oneshot, watch, RwLock, RwLockReadGuard};
// Stream id used when the server runs over a single standalone transport
// (e.g. stdio) rather than per-session HTTP streams.
pub const DEFAULT_STREAM_ID: &str = "STANDALONE-STREAM";
// Bounded capacity of the channel spawned handler tasks use to report their
// send results back to the receive loop.
const TASK_CHANNEL_CAPACITY: usize = 500;
// Type-erased, shared transport handle: receives client messages, sends
// server messages.
type TransportType = Arc<
    dyn TransportDispatcher<
        ClientMessages,
        MessageFromServer,
        ClientMessage,
        ServerMessages,
        ServerMessage,
    >,
>;
/// Core MCP server runtime: owns the active transport, the application
/// handler, and per-session state (client details, auth info, task stores).
pub struct ServerRuntime {
    /// Application-provided handler invoked for every incoming request,
    /// notification, and error.
    handler: Arc<dyn McpServerHandler>,
    /// This server's `initialize` result (info/capabilities) sent to clients.
    server_details: Arc<InitializeResult>,
    #[cfg(feature = "hyper-server")]
    /// HTTP session this runtime instance is bound to (hyper-server only).
    session_id: Option<SessionId>,
    /// Slot for the single active transport; `None` when closed. Despite the
    /// name it holds at most one transport (the default/standalone stream).
    transport_map: tokio::sync::RwLock<Option<TransportType>>,
    /// Generates outgoing request ids for server-originated requests.
    request_id_gen: Box<dyn RequestIdGen>,
    /// Sender/receiver pair for the client's `initialize` params; a `Some`
    /// value signals that initialization completed (see
    /// `wait_for_initialization`).
    client_details_tx: watch::Sender<Option<InitializeRequestParams>>,
    client_details_rx: watch::Receiver<Option<InitializeRequestParams>>,
    /// Auth info for the current session, updated via `update_auth_info`.
    auth_info: tokio::sync::RwLock<Option<AuthInfo>>,
    /// Store for tasks this server executes on behalf of clients.
    task_store: Option<Arc<ServerTaskStore>>,
    /// Store for tasks the server has delegated to the client (polled via
    /// `poll_task_status`).
    client_task_store: Option<Arc<ClientTaskStore>>,
    /// Optional observer notified of every sent/received message.
    message_observer: Option<Arc<dyn McpObserver<ClientMessage, ServerMessage>>>,
}
/// Construction options for [`ServerRuntime::new`], generic over the concrete
/// transport implementation.
pub struct McpServerOptions<T>
where
    T: TransportDispatcher<
        ClientMessages,
        MessageFromServer,
        ClientMessage,
        ServerMessages,
        ServerMessage,
    >,
{
    /// Server info/capabilities returned during the `initialize` handshake.
    pub server_details: InitializeResult,
    /// Concrete transport the runtime will take ownership of.
    pub transport: T,
    /// Application handler for incoming messages.
    pub handler: Arc<dyn McpServerHandler>,
    /// Optional store for server-side tasks.
    pub task_store: Option<Arc<ServerTaskStore>>,
    /// Optional store for client-delegated tasks.
    pub client_task_store: Option<Arc<ClientTaskStore>>,
    /// Optional observer notified of every sent/received message.
    pub message_observer: Option<Arc<dyn McpObserver<ClientMessage, ServerMessage>>>,
}
#[async_trait]
impl McpServer for ServerRuntime {
/// Returns a handle to the server-side task store, if one was configured.
fn task_store(&self) -> Option<Arc<ServerTaskStore>> {
    self.task_store.as_ref().map(Arc::clone)
}
/// Returns a handle to the client task store, if one was configured.
fn client_task_store(&self) -> Option<Arc<ClientTaskStore>> {
    self.client_task_store.as_ref().map(Arc::clone)
}
/// Records the client's `initialize` parameters, waking any task blocked in
/// `wait_for_initialization`.
async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()> {
    match self.client_details_tx.send(Some(client_details)) {
        Ok(()) => Ok(()),
        // The watch channel only errors when every receiver is gone.
        Err(_) => Err(RpcError::internal_error()
            .with_message("Failed to set client details".to_string())
            .into()),
    }
}
/// Replaces the stored auth info, but only when it actually changed
/// (presence toggled, or the token's unique id differs).
async fn update_auth_info(&self, new_auth_info: Option<AuthInfo>) {
    // Decide under a cheap read lock first; take the write lock only when a
    // real change is detected.
    let needs_write = {
        let guard = self.auth_info.read().await;
        match (guard.as_ref(), new_auth_info.as_ref()) {
            (None, None) => false,
            (Some(old), Some(new)) => old.token_unique_id != new.token_unique_id,
            // Presence changed in either direction.
            _ => true,
        }
    };
    if needs_write {
        *self.auth_info.write().await = new_auth_info;
    }
}
/// Borrows the current auth info under a read lock; the guard keeps the
/// lock held for its lifetime.
async fn auth_info(&self) -> RwLockReadGuard<'_, Option<AuthInfo>> {
    let guard = self.auth_info.read().await;
    guard
}
/// Returns an owned copy of the current auth info, releasing the read lock
/// before returning.
async fn auth_info_cloned(&self) -> Option<AuthInfo> {
    self.auth_info.read().await.clone()
}
/// Blocks until the client's `initialize` parameters have been recorded via
/// `set_client_details`.
///
/// Fix: the previous check-then-clone order was racy — a value sent between
/// the `borrow()` check and the `Receiver::clone()` was marked as already
/// seen by the fresh receiver, so `changed()` could wait forever even though
/// the details were set. Cloning once up front and using
/// `borrow_and_update()` closes that window.
async fn wait_for_initialization(&self) {
    let mut rx = self.client_details_rx.clone();
    // `borrow_and_update` marks the current value as seen, so `changed()`
    // below wakes only for a strictly newer value — no lost updates.
    while rx.borrow_and_update().is_none() {
        // The sender is a field of `self`, so it cannot be dropped while we
        // hold `&self`; an error here is unreachable in practice, but bail
        // out rather than busy-loop if it ever happens.
        if rx.changed().await.is_err() {
            return;
        }
    }
}
/// Sends a single message to the client over the stored transport.
///
/// A request id is assigned (or reused from `request_id`) for request-type
/// messages; for requests the client's response is returned, otherwise
/// `None`.
///
/// # Errors
/// Returns an internal error when no transport is stored (stream closed),
/// or propagates transport/serialization failures.
async fn send(
    &self,
    message: MessageFromServer,
    request_id: Option<RequestId>,
    request_timeout: Option<Duration>,
) -> SdkResult<Option<ClientMessage>> {
    let transport_map = self.transport_map.read().await;
    // ok_or_else: build the error (and its String) lazily, only on the
    // failure path (clippy::or_fun_call).
    let transport = transport_map.as_ref().ok_or_else(|| {
        RpcError::internal_error()
            .with_message("transport stream does not exists or is closed!".to_string())
    })?;
    let outgoing_request_id = self
        .request_id_gen
        .request_id_for_message(&message, request_id);
    let mcp_message = ServerMessage::from_message(message, outgoing_request_id)?;
    // Let the observer inspect the outgoing message before it hits the wire.
    if let Some(observer) = self.message_observer.as_ref() {
        observer.on_send(&mcp_message);
    }
    // A single outgoing message may yield a single response; `as_single`
    // rejects an unexpected batch reply.
    let response = transport
        .send_message(ServerMessages::Single(mcp_message), request_timeout)
        .await?
        .map(|res| res.as_single())
        .transpose()?;
    Ok(response)
}
/// Sends a batch of messages to the client, returning the client's batched
/// responses if any.
///
/// # Errors
/// Returns an internal error when no transport is stored (stream closed),
/// or propagates transport failures.
async fn send_batch(
    &self,
    messages: Vec<ServerMessage>,
    request_timeout: Option<Duration>,
) -> SdkResult<Option<Vec<ClientMessage>>> {
    let transport_map = self.transport_map.read().await;
    // ok_or_else: construct the error lazily instead of allocating its
    // message String on every call (clippy::or_fun_call).
    let transport = transport_map.as_ref().ok_or_else(|| {
        RpcError::internal_error()
            .with_message("transport stream does not exists or is closed!".to_string())
    })?;
    // Let the observer see every outgoing message before it hits the wire.
    if let Some(observer) = self.message_observer.as_ref() {
        messages.iter().for_each(|msg| observer.on_send(msg));
    }
    transport
        .send_batch(messages, request_timeout)
        .map_err(|err| err.into())
        .await
}
/// This server's own `initialize` result (name, version, capabilities).
fn server_info(&self) -> &InitializeResult {
    self.server_details.as_ref()
}
/// The client's `initialize` parameters, or `None` before initialization.
fn client_info(&self) -> Option<InitializeRequestParams> {
    let details = self.client_details_rx.borrow();
    details.clone()
}
async fn start(self: Arc<Self>) -> SdkResult<()> {
let self_clone = self.clone();
let transport_map = self_clone.transport_map.read().await;
let transport = transport_map.as_ref().ok_or(
RpcError::internal_error()
.with_message("transport stream does not exists or is closed!".to_string()),
)?;
let mut stream = transport.start().await?;
let (tx, mut rx) = mpsc::channel(TASK_CHANNEL_CAPACITY);
while let Some(mcp_messages) = stream.next().await {
match mcp_messages {
ClientMessages::Single(client_message) => {
let transport = transport.clone();
let self = self.clone();
let tx = tx.clone();
tokio::spawn(async move {
let result = self.handle_message(client_message, &transport).await;
let send_result: SdkResult<_> = match result {
Ok(result) => {
if let Some(result) = result {
transport
.send_message(ServerMessages::Single(result), None)
.map_err(|e| e.into())
.await
} else {
Ok(None)
}
}
Err(error) => {
tracing::error!("Error handling message : {}", error);
Ok(None)
}
};
if let Err(error) = tx.send(send_result).await {
tracing::error!("Failed to send result to channel: {}", error);
}
});
}
ClientMessages::Batch(client_messages) => {
let transport = transport.clone();
let self = self_clone.clone();
let tx = tx.clone();
tokio::spawn(async move {
let handling_tasks: Vec<_> = client_messages
.into_iter()
.map(|client_message| self.handle_message(client_message, &transport))
.collect();
let send_result = match try_join_all(handling_tasks).await {
Ok(results) => {
let results: Vec<_> = results.into_iter().flatten().collect();
if !results.is_empty() {
transport
.send_message(ServerMessages::Batch(results), None)
.map_err(|e| e.into())
.await
} else {
Ok(None)
}
}
Err(error) => Err(error),
};
if let Err(error) = tx.send(send_result).await {
tracing::error!("Failed to send batch result to channel: {}", error);
}
});
}
}
while let Ok(result) = rx.try_recv() {
result?; }
}
drop(tx);
while let Some(result) = rx.recv().await {
result?; }
return Ok(());
}
/// Writes `message` (plus a newline) to the transport's error stream, if it
/// exposes a writable one; otherwise this is silently a no-op (best effort).
///
/// # Errors
/// Returns an internal error when no transport is stored, or propagates I/O
/// failures while writing.
async fn stderr_message(&self, message: String) -> SdkResult<()> {
    let transport_map = self.transport_map.read().await;
    // ok_or_else: construct the error lazily (clippy::or_fun_call).
    let transport = transport_map.as_ref().ok_or_else(|| {
        RpcError::internal_error()
            .with_message("transport stream does not exists or is closed!".to_string())
    })?;
    let mut lock = transport.error_stream().write().await;
    if let Some(IoStream::Writable(stderr)) = lock.as_mut() {
        stderr.write_all(message.as_bytes()).await?;
        stderr.write_all(b"\n").await?;
        stderr.flush().await?;
    }
    Ok(())
}
#[cfg(feature = "hyper-server")]
/// The HTTP session this runtime is bound to, if any (hyper-server only).
fn session_id(&self) -> Option<SessionId> {
    self.session_id.clone()
}
}
impl ServerRuntime {
/// Feeds a raw payload string into the active transport for parsing and
/// dispatch.
///
/// # Errors
/// Returns an internal error when no transport is stored (stream closed),
/// or propagates the transport's consumption error.
pub(crate) async fn consume_payload_string(&self, payload: &str) -> SdkResult<()> {
    let transport_map = self.transport_map.read().await;
    // ok_or_else: construct the error lazily (clippy::or_fun_call).
    let transport = transport_map.as_ref().ok_or_else(|| {
        RpcError::internal_error()
            .with_message("stream id does not exists or is closed!".to_string())
    })?;
    transport.consume_string_payload(payload).await?;
    Ok(())
}
/// Dispatches one client message and returns the server's reply, if any.
///
/// Requests produce a response message (or an error response once the
/// session is initialized); notifications yield `None`; client responses and
/// errors are routed to their pending request's oneshot channel on the
/// transport and also yield `None`.
///
/// # Errors
/// Propagates handler errors for notifications, pre-initialization request
/// failures, and failures delivering to a pending-request channel.
pub(crate) async fn handle_message(
    self: &Arc<Self>,
    message: ClientMessage,
    transport: &Arc<
        dyn TransportDispatcher<
            ClientMessages,
            MessageFromServer,
            ClientMessage,
            ServerMessages,
            ServerMessage,
        >,
    >,
) -> SdkResult<Option<ServerMessage>> {
    // Give the observer a chance to inspect every inbound message.
    if let Some(observer) = self.message_observer.as_ref() {
        observer.on_receive(&message);
    }
    let response = match message {
        ClientMessage::Request(client_jsonrpc_request) => {
            let request_id = client_jsonrpc_request.request_id().clone();
            let result = self
                .handler
                .handle_request(client_jsonrpc_request, self.clone())
                .await;
            let response: MessageFromServer = match result {
                Ok(success_value) => success_value.into(),
                Err(error_value) => {
                    // Before initialization completes, propagate the error to
                    // the caller instead of replying on the wire.
                    if !self.is_initialized() {
                        return Err(error_value.into());
                    }
                    MessageFromServer::Error(error_value)
                }
            };
            // fix: local was misspelled `mpc_message`.
            let mcp_message: ServerMessage =
                ServerMessage::from_message(response, Some(request_id))?;
            Some(mcp_message)
        }
        ClientMessage::Notification(client_jsonrpc_notification) => {
            // Notifications never produce a reply.
            self.handler
                .handle_notification(client_jsonrpc_notification, self.clone())
                .await?;
            None
        }
        ClientMessage::Error(jsonrpc_error) => {
            self.handler
                .handle_error(&jsonrpc_error.error, self.clone())
                .await?;
            // If the error answers one of our outstanding requests, wake the
            // waiter; otherwise just log it.
            if let Some(request_id) = jsonrpc_error.id.as_ref() {
                if let Some(tx_response) = transport.pending_request_tx(request_id).await {
                    tx_response
                        .send(ClientMessage::Error(jsonrpc_error))
                        .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
                } else {
                    tracing::warn!(
                        "Received an error response with no corresponding request {:?}",
                        &jsonrpc_error.id
                    );
                }
            }
            None
        }
        ClientMessage::Response(response) => {
            // Route the response to the task awaiting it, if any.
            if let Some(tx_response) = transport.pending_request_tx(&response.id).await {
                tx_response
                    .send(ClientMessage::Response(response))
                    .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?;
            } else {
                tracing::warn!(
                    "Received a response with no corresponding request: {:?}",
                    &response.id
                );
            }
            None
        }
    };
    Ok(response)
}
/// Stores `transport` as the active transport, but only for the
/// default/standalone stream; other stream ids are ignored.
pub(crate) async fn store_transport(
    &self,
    stream_id: &str,
    transport: Arc<
        dyn TransportDispatcher<
            ClientMessages,
            MessageFromServer,
            ClientMessage,
            ServerMessages,
            ServerMessage,
        >,
    >,
) -> SdkResult<()> {
    if stream_id == DEFAULT_STREAM_ID {
        let mut slot = self.transport_map.write().await;
        tracing::trace!("save transport for stream id : {}", stream_id);
        *slot = Some(transport);
    }
    Ok(())
}
/// Shuts down the stored transport when `stream_id` is the default stream;
/// other stream ids are not stored here, so there is nothing to do.
///
/// NOTE(review): despite the name this only shuts the transport down — the
/// slot itself is cleared by `shutdown()`.
pub(crate) async fn remove_transport(&self, stream_id: &str) -> SdkResult<()> {
    if stream_id == DEFAULT_STREAM_ID {
        let slot = self.transport_map.read().await;
        tracing::trace!("removing transport for stream id : {}", stream_id);
        match slot.as_ref() {
            Some(transport) => transport.shut_down().await?,
            None => {}
        }
    }
    Ok(())
}
/// Takes the transport out of its slot (releasing the lock immediately) and
/// shuts it down, ignoring shutdown errors.
pub(crate) async fn shutdown(&self) {
    // The temporary write guard is dropped at the end of this statement, so
    // the lock is not held across the shutdown await below.
    let taken = self.transport_map.write().await.take();
    if let Some(transport) = taken {
        let _ = transport.shut_down().await;
    }
}
/// Whether a live (stored and not shut down) default-stream transport exists.
pub(crate) async fn default_stream_exists(&self) -> bool {
    match self.transport_map.read().await.as_ref() {
        Some(transport) => !transport.is_shut_down().await,
        None => false,
    }
}
/// Runs the receive loop for a single stream on `transport`, with keep-alive
/// pings and optional initial payload injection.
///
/// For the default stream the transport is also stored on the runtime and
/// the loop runs until the keep-alive watchdog signals disconnection; for
/// any other stream the function returns after the first drained batch of
/// handler results.
///
/// # Errors
/// Propagates transport start/keep-alive failures, payload consumption
/// failures, handler send errors, and returns `connection_closed` when the
/// keep-alive watchdog fires.
pub(crate) async fn start_stream(
    self: Arc<Self>,
    transport: Arc<
        dyn TransportDispatcher<
            ClientMessages,
            MessageFromServer,
            ClientMessage,
            ServerMessages,
            ServerMessage,
        >,
    >,
    stream_id: &str,
    ping_interval: Duration,
    payload: Option<String>,
) -> SdkResult<()> {
    let mut stream = transport.start().await?;
    // Only the default/standalone stream is kept in `transport_map`.
    if stream_id == DEFAULT_STREAM_ID {
        self.store_transport(stream_id, transport.clone()).await?;
    }
    let self_clone = self.clone();
    // `disconnect_tx` fires when the keep-alive pings detect a dead peer.
    let (disconnect_tx, mut disconnect_rx) = oneshot::channel::<()>();
    let abort_alive_task = transport
        .keep_alive(ping_interval, disconnect_tx)
        .await?
        .abort_handle();
    // RAII guard: aborts the keep-alive task on any exit path.
    let _abort_guard = AbortTaskOnDrop {
        handle: abort_alive_task,
    };
    // Inject the initial payload (e.g. an HTTP request body) before looping;
    // on failure, tear the stream down and bail.
    if let Some(payload) = payload {
        if let Err(err) = transport.consume_string_payload(&payload).await {
            let _ = self.remove_transport(stream_id).await;
            return Err(err.into());
        }
    }
    // Spawned handler tasks report their send results here so errors can be
    // surfaced from this function.
    let (tx, mut rx) = mpsc::channel(TASK_CHANNEL_CAPACITY);
    loop {
        tokio::select! {
            Some(mcp_messages) = stream.next() =>{
                match mcp_messages {
                    ClientMessages::Single(client_message) => {
                        let transport = transport.clone();
                        let self_clone = self.clone();
                        let tx = tx.clone();
                        // One task per message; the response (if any) is sent
                        // back on the same transport.
                        tokio::spawn(async move {
                            let result = self_clone.handle_message(client_message, &transport).await;
                            let send_result: SdkResult<_> = match result {
                                Ok(result) => {
                                    if let Some(result) = result {
                                        transport
                                            .send_message(ServerMessages::Single(result), None)
                                            .map_err(|e| e.into())
                                            .await
                                    } else {
                                        Ok(None)
                                    }
                                }
                                // Handler errors are logged, not fatal.
                                Err(error) => {
                                    tracing::error!("Error handling message : {}", error);
                                    Ok(None)
                                }
                            };
                            if let Err(error) = tx.send(send_result).await {
                                tracing::error!("Failed to send batch result to channel: {}", error);
                            }
                        });
                    }
                    ClientMessages::Batch(client_messages) => {
                        let transport = transport.clone();
                        let self_clone = self_clone.clone();
                        let tx = tx.clone();
                        // Handle all batch members concurrently; abort on the
                        // first handler error.
                        tokio::spawn(async move {
                            let handling_tasks: Vec<_> = client_messages
                                .into_iter()
                                .map(|client_message| self_clone.handle_message(client_message, &transport))
                                .collect();
                            let send_result = match try_join_all(handling_tasks).await {
                                Ok(results) => {
                                    // Only reply when something produced a response.
                                    let results: Vec<_> = results.into_iter().flatten().collect();
                                    if !results.is_empty() {
                                        transport.send_message(ServerMessages::Batch(results), None)
                                            .map_err(|e| e.into())
                                            .await
                                    }else {
                                        Ok(None)
                                    }
                                },
                                Err(error) => Err(error),
                            };
                            if let Err(error) = tx.send(send_result).await {
                                tracing::error!("Failed to send batch result to channel: {}", error);
                            }
                        });
                    }
                }
                // Drain results already reported, propagating the first error.
                while let Ok(result) = rx.try_recv() {
                    result?; }
                // Non-default streams are short-lived: close our sender, wait
                // for all spawned tasks to report, then finish.
                if !stream_id.eq(DEFAULT_STREAM_ID){
                    drop(tx);
                    while let Some(result) = rx.recv().await {
                        result?; }
                    return Ok(());
                }
            }
            // Keep-alive watchdog fired: drain outstanding results, tear the
            // transport down, and report the closed connection.
            _ = &mut disconnect_rx => {
                drop(tx);
                while let Some(result) = rx.recv().await {
                    result?; }
                self.remove_transport(stream_id).await?;
                return Err(SdkError::connection_closed().into());
            }
        }
    }
}
#[cfg(feature = "hyper-server")]
/// Builds a session-bound runtime for the hyper server. No transport is
/// attached yet; it is stored later via `store_transport`.
pub(crate) fn new_instance(
    server_details: Arc<InitializeResult>,
    handler: Arc<dyn McpServerHandler>,
    session_id: SessionId,
    auth_info: Option<AuthInfo>,
    task_store: Option<Arc<ServerTaskStore>>,
    client_task_store: Option<Arc<ClientTaskStore>>,
    message_observer: Option<Arc<dyn McpObserver<ClientMessage, ServerMessage>>>,
) -> Arc<Self> {
    // Client details start empty; `set_client_details` fills them in after
    // the initialize handshake.
    let (client_details_tx, client_details_rx) =
        watch::channel::<Option<InitializeRequestParams>>(None);
    Arc::new(Self {
        server_details,
        handler,
        session_id: Some(session_id),
        transport_map: tokio::sync::RwLock::new(None),
        request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
        client_details_tx,
        client_details_rx,
        auth_info: tokio::sync::RwLock::new(auth_info),
        task_store,
        client_task_store,
        message_observer,
    })
}
/// Polls the client once for the status of `task_id`; when the task reached
/// a terminal state, also fetches its payload and persists it in the client
/// task store.
///
/// Returns the current status together with the suggested poll interval.
/// NOTE(review): assumes `TaskStatusUpdate` is a `(status, poll_interval)`
/// tuple alias — confirm against `crate::task_store`.
pub(crate) async fn poll_task_status(
    self: Arc<ServerRuntime>,
    task_id: TaskId,
    session_id: Option<String>,
    task_store: Arc<ClientTaskStore>,
) -> SdkResult<TaskStatusUpdate> {
    // Ask the client for the task's current status.
    let result = self
        .request_get_task(GetTaskParams {
            task_id: task_id.to_string(),
        })
        .await?;
    // Terminal tasks (exact set defined by `is_terminal`) carry a payload;
    // fetch it and record the final result.
    if result.is_terminal() {
        let task_payload = self
            .request_get_task_payload(GetTaskPayloadParams {
                task_id: task_id.clone(),
            })
            .await?;
        task_store
            .store_task_result(
                task_id.as_str(),
                result.status,
                task_payload.into(),
                session_id.as_ref(),
            )
            .await;
    }
    Ok((result.status, result.poll_interval))
}
/// Builds a runtime from `options`, erasing the concrete transport type and
/// wiring up both task stores:
/// - server task-store status updates are forwarded to the client as
///   notifications via a background task;
/// - the client task store is given a polling callback backed by
///   `poll_task_status`.
pub(crate) fn new<T>(options: McpServerOptions<T>) -> Arc<Self>
where
    T: TransportDispatcher<
        ClientMessages,
        MessageFromServer,
        ClientMessage,
        ServerMessages,
        ServerMessage,
    >,
{
    // Client details start empty; filled in after the initialize handshake.
    let (client_details_tx, client_details_rx) =
        watch::channel::<Option<InitializeRequestParams>>(None);
    let runtime = Arc::new(Self {
        server_details: Arc::new(options.server_details),
        handler: options.handler,
        #[cfg(feature = "hyper-server")]
        session_id: None,
        // The transport is stored immediately (default stream semantics).
        transport_map: tokio::sync::RwLock::new(Some(Arc::new(options.transport))),
        client_details_tx,
        client_details_rx,
        request_id_gen: Box::new(RequestIdGenNumeric::new(None)),
        auth_info: RwLock::new(None),
        task_store: options.task_store,
        client_task_store: options.client_task_store,
        message_observer: options.message_observer,
    });
    // Forward server-side task status updates to the client; send failures
    // are deliberately ignored (best effort).
    let runtime_clone = runtime.clone();
    if let Some(task_store) = runtime_clone.task_store() {
        if let Some(mut stream) = task_store.subscribe() {
            tokio::spawn(async move {
                while let Some((params, _)) = stream.next().await {
                    let _ = runtime_clone.notify_task_status(params).await;
                }
            });
        }
    }
    // Hand the client task store a poller that asks the client for task
    // status via this runtime.
    if let Some(client_task_store) = runtime.client_task_store.clone() {
        let task_store_clone = client_task_store.clone();
        let runtime_clone = runtime.clone();
        let callback: TaskStatusPoller = Box::new(move |task_id, session_id| {
            let task_store_clone = client_task_store.clone();
            let runtime_clone = runtime_clone.clone();
            Box::pin(async move {
                runtime_clone
                    .poll_task_status(task_id, session_id, task_store_clone)
                    .await
            })
        });
        if let Err(error) = task_store_clone.start_task_polling(callback) {
            tracing::error!("Failed to start task polling: {error}");
        }
    }
    runtime
}
}