use crate::{
error::SdkResult,
mcp_server::{
error::{TransportServerError, TransportServerResult},
ServerRuntime,
},
session_store::SessionStore,
task_store::{ClientTaskStore, ServerTaskStore, TaskStatusPoller},
};
use crate::{
mcp_http::McpAppState,
mcp_server::HyperServer,
schema::{
schema_utils::{NotificationFromServer, RequestFromServer, ResultFromClient},
CreateMessageRequestParams, CreateMessageResult, InitializeRequestParams, ListRootsResult,
LoggingMessageNotificationParams, NotificationParams, RequestParams,
ResourceUpdatedNotificationParams,
},
McpServer,
};
use axum_server::Handle;
use futures::StreamExt;
use rust_mcp_schema::{
schema_utils::{ClientTaskResult, CustomNotification, CustomRequest},
CancelTaskParams, CancelTaskResult, CancelledNotificationParams, CreateTaskResult,
ElicitCompleteParams, ElicitRequestParams, ElicitResult, GenericResult, GetTaskParams,
GetTaskPayloadParams, GetTaskResult, ProgressNotificationParams, RpcError,
TaskStatusNotificationParams,
};
use rust_mcp_transport::SessionId;
use std::{sync::Arc, time::Duration};
use tokio::task::JoinHandle;
/// Handle to a server started by `HyperRuntime::create`.
///
/// Bundles the shared application state with the two handles needed to
/// control the spawned server: its join handle and its `axum_server` handle.
pub struct HyperRuntime {
/// Shared application state. The methods on this type read its
/// `session_store`, `task_store` and `client_task_store` fields.
pub(crate) state: Arc<McpAppState>,
/// Join handle of the Tokio task spawned in `create` to run the server;
/// awaited (and consumed) by `await_server`.
pub(crate) server_task: JoinHandle<Result<(), TransportServerError>>,
/// `axum_server` handle obtained from the server in `create`; used by
/// `graceful_shutdown`.
pub(crate) server_handle: Handle,
}
impl HyperRuntime {
/// Builds the [`TaskStatusPoller`] callback that `create` passes to
/// `start_task_polling` on the client task store.
///
/// Every invocation of the returned callback receives a task id and an
/// optional session id and produces a boxed future that:
///
/// * fails with an invalid-request [`RpcError`] ("No session id provided!")
///   when the session id is `None`;
/// * fails with an invalid-request [`RpcError`] ("Invalid or broken session!")
///   when `session_store` has no entry for the session;
/// * otherwise awaits `poll_task_status` on the runtime stored for that
///   session and returns its result.
///
/// Both `Arc`s are taken by value, so they are moved directly into the
/// closure; only the per-invocation clones below are required.
fn task_poller_callback(
    client_task_store: Arc<ClientTaskStore>,
    session_store: Arc<dyn SessionStore>,
) -> TaskStatusPoller {
    Box::new(move |task_id, session_id| {
        // The returned future is owned by the caller, so it needs its own
        // handles; the closure keeps the originals for later invocations.
        let sessions = Arc::clone(&session_store);
        let client_tasks = Arc::clone(&client_task_store);
        Box::pin(async move {
            let Some(session) = session_id.as_ref() else {
                return Err(RpcError::invalid_request()
                    .with_message("No session id provided!".to_string())
                    .into());
            };
            let Some(runtime) = sessions.get(session).await else {
                return Err(RpcError::invalid_request()
                    .with_message("Invalid or broken session!".to_string())
                    .into());
            };
            runtime
                .poll_task_status(task_id, session_id, client_tasks)
                .await
        })
    })
}
pub async fn create(server: HyperServer) -> SdkResult<Self> {
let addr = server.options.resolve_server_address().await?;
let state = server.state();
let server_handle = server.server_handle();
let server_task = tokio::spawn(async move {
#[cfg(feature = "ssl")]
if server.options.enable_ssl {
server.start_ssl(addr).await
} else {
server.start_http(addr).await
}
#[cfg(not(feature = "ssl"))]
if server.options.enable_ssl {
panic!("SSL requested but the 'ssl' feature is not enabled");
} else {
server.start_http(addr).await
}
});
let state_clone = state.clone();
if let Some(task_store) = state_clone.task_store.clone() {
if let Some(mut stream) = task_store.subscribe() {
tokio::spawn(async move {
while let Some((params, session_id_opt)) = stream.next().await {
if let Some(session_id) = session_id_opt.as_ref() {
if let Some(transport) = state_clone.session_store.get(session_id).await
{
let _ = transport.notify_task_status(params).await;
}
}
}
});
}
}
if let Some(client_task_store) = state.client_task_store.clone() {
let session_store = state.session_store.clone();
let callback: TaskStatusPoller =
Self::task_poller_callback(Arc::clone(&client_task_store), session_store);
client_task_store.start_task_polling(callback)?;
}
Ok(Self {
state,
server_task,
server_handle,
})
}
pub fn graceful_shutdown(&self, timeout: Option<Duration>) {
self.server_handle.graceful_shutdown(timeout);
}
/// Consumes the runtime and waits for the spawned server task to finish.
///
/// # Errors
///
/// Returns an error if joining the task fails (propagated by `?`), or if
/// the server task itself finished with an error, which is converted into
/// the SDK error type.
pub async fn await_server(self) -> SdkResult<()> {
    match self.server_task.await? {
        Ok(()) => Ok(()),
        Err(server_err) => Err(server_err.into()),
    }
}
/// Returns the keys currently reported by the session store
/// (presumably the ids of the known sessions).
pub async fn sessions(&self) -> Vec<String> {
    let store = &self.state.session_store;
    store.keys().await
}
/// Looks up the [`ServerRuntime`] stored for `session_id`.
///
/// # Errors
///
/// Returns [`TransportServerError::SessionIdInvalid`], carrying the session
/// id as a string, when the session store has no entry for `session_id`.
pub async fn runtime_by_session(
    &self,
    session_id: &SessionId,
) -> TransportServerResult<Arc<ServerRuntime>> {
    self.state
        .session_store
        .get(session_id)
        .await
        // Build the error lazily so a successful lookup does not pay for
        // the `to_string()` allocation.
        .ok_or_else(|| TransportServerError::SessionIdInvalid(session_id.to_string()))
}
/// Passes `request` and `timeout` to `request` on the runtime stored for
/// `session_id` and returns the client's result.
///
/// # Errors
///
/// Fails if `session_id` has no stored runtime, or if the runtime call
/// itself returns an error.
pub async fn send_request(
    &self,
    session_id: &SessionId,
    request: RequestFromServer,
    timeout: Option<Duration>,
) -> SdkResult<ResultFromClient> {
    self.runtime_by_session(session_id)
        .await?
        .request(request, timeout)
        .await
}
/// Passes `notification` to `send_notification` on the runtime stored for
/// `session_id`.
///
/// # Errors
///
/// Fails if `session_id` has no stored runtime, or if the runtime call
/// itself returns an error.
pub async fn send_notification(
    &self,
    session_id: &SessionId,
    notification: NotificationFromServer,
) -> SdkResult<()> {
    self.runtime_by_session(session_id)
        .await?
        .send_notification(notification)
        .await
}
/// Returns whatever `client_info()` reports on the runtime stored for
/// `session_id`; the value is optional.
///
/// # Errors
///
/// Fails if `session_id` has no stored runtime.
pub async fn client_info(
    &self,
    session_id: &SessionId,
) -> SdkResult<Option<InitializeRequestParams>> {
    let info = self.runtime_by_session(session_id).await?.client_info();
    Ok(info)
}
/// Passes `params` to `request_elicitation` on the runtime stored for
/// `session_id` and returns its result.
///
/// # Errors
///
/// Fails if `session_id` has no stored runtime, or if the runtime call
/// itself returns an error.
pub async fn request_elicitation(
    &self,
    session_id: &SessionId,
    params: ElicitRequestParams,
) -> SdkResult<ElicitResult> {
    self.runtime_by_session(session_id)
        .await?
        .request_elicitation(params)
        .await
}
/// Passes `params` to `request_elicitation_task` on the runtime stored for
/// `session_id` and returns the resulting [`CreateTaskResult`].
///
/// # Errors
///
/// Fails if `session_id` has no stored runtime, or if the runtime call
/// itself returns an error.
pub async fn request_elicitation_task(
    &self,
    session_id: &SessionId,
    params: ElicitRequestParams,
) -> SdkResult<CreateTaskResult> {
    self.runtime_by_session(session_id)
        .await?
        .request_elicitation_task(params)
        .await
}
/// Passes `params` to `request_root_list` on the runtime stored for
/// `session_id` and returns the resulting [`ListRootsResult`].
///
/// # Errors
///
/// Fails if `session_id` has no stored runtime, or if the runtime call
/// itself returns an error.
pub async fn request_root_list(
    &self,
    session_id: &SessionId,
    params: Option<RequestParams>,
) -> SdkResult<ListRootsResult> {
    self.runtime_by_session(session_id)
        .await?
        .request_root_list(params)
        .await
}
/// Passes `params` and `timeout` to `ping` on the runtime stored for
/// `session_id` and returns its result.
///
/// # Errors
///
/// Fails if `session_id` has no stored runtime, or if the runtime call
/// itself returns an error.
pub async fn ping(
    &self,
    session_id: &SessionId,
    params: Option<RequestParams>,
    timeout: Option<Duration>,
) -> SdkResult<crate::schema::Result> {
    self.runtime_by_session(session_id)
        .await?
        .ping(params, timeout)
        .await
}
/// Passes `params` to `request_message_creation` on the runtime stored for
/// `session_id` and returns the resulting [`CreateMessageResult`].
///
/// # Errors
///
/// Fails if `session_id` has no stored runtime, or if the runtime call
/// itself returns an error.
pub async fn request_message_creation(
    &self,
    session_id: &SessionId,
    params: CreateMessageRequestParams,
) -> SdkResult<CreateMessageResult> {
    self.runtime_by_session(session_id)
        .await?
        .request_message_creation(params)
        .await
}
/// Passes `params` to `request_get_task` on the runtime stored for
/// `session_id` and returns the resulting [`GetTaskResult`].
///
/// # Errors
///
/// Fails if `session_id` has no stored runtime, or if the runtime call
/// itself returns an error.
pub async fn request_get_task(
    &self,
    session_id: &SessionId,
    params: GetTaskParams,
) -> SdkResult<GetTaskResult> {
    self.runtime_by_session(session_id)
        .await?
        .request_get_task(params)
        .await
}
/// Passes `params` to `request_get_task_payload` on the runtime stored for
/// `session_id` and returns the resulting [`ClientTaskResult`].
///
/// # Errors
///
/// Fails if `session_id` has no stored runtime, or if the runtime call
/// itself returns an error.
pub async fn request_get_task_payload(
    &self,
    session_id: &SessionId,
    params: GetTaskPayloadParams,
) -> SdkResult<ClientTaskResult> {
    self.runtime_by_session(session_id)
        .await?
        .request_get_task_payload(params)
        .await
}
/// Passes `params` to `request_task_cancellation` on the runtime stored for
/// `session_id` and returns the resulting [`CancelTaskResult`].
///
/// # Errors
///
/// Fails if `session_id` has no stored runtime, or if the runtime call
/// itself returns an error.
pub async fn request_task_cancellation(
    &self,
    session_id: &SessionId,
    params: CancelTaskParams,
) -> SdkResult<CancelTaskResult> {
    self.runtime_by_session(session_id)
        .await?
        .request_task_cancellation(params)
        .await
}
/// Passes the custom request `params` to `request_custom` on the runtime
/// stored for `session_id` and returns the resulting [`GenericResult`].
///
/// # Errors
///
/// Fails if `session_id` has no stored runtime, or if the runtime call
/// itself returns an error.
pub async fn request_custom(
    &self,
    session_id: &SessionId,
    params: CustomRequest,
) -> SdkResult<GenericResult> {
    self.runtime_by_session(session_id)
        .await?
        .request_custom(params)
        .await
}
/// Passes `params` to `notify_log_message` on the runtime stored for
/// `session_id`.
///
/// # Errors
///
/// Fails if `session_id` has no stored runtime, or if the runtime call
/// itself returns an error.
pub async fn notify_log_message(
    &self,
    session_id: &SessionId,
    params: LoggingMessageNotificationParams,
) -> SdkResult<()> {
    self.runtime_by_session(session_id)
        .await?
        .notify_log_message(params)
        .await
}
/// Passes the optional `params` to `notify_prompt_list_changed` on the
/// runtime stored for `session_id`.
///
/// # Errors
///
/// Fails if `session_id` has no stored runtime, or if the runtime call
/// itself returns an error.
pub async fn notify_prompt_list_changed(
    &self,
    session_id: &SessionId,
    params: Option<NotificationParams>,
) -> SdkResult<()> {
    self.runtime_by_session(session_id)
        .await?
        .notify_prompt_list_changed(params)
        .await
}
/// Passes the optional `params` to `notify_resource_list_changed` on the
/// runtime stored for `session_id`.
///
/// # Errors
///
/// Fails if `session_id` has no stored runtime, or if the runtime call
/// itself returns an error.
pub async fn notify_resource_list_changed(
    &self,
    session_id: &SessionId,
    params: Option<NotificationParams>,
) -> SdkResult<()> {
    self.runtime_by_session(session_id)
        .await?
        .notify_resource_list_changed(params)
        .await
}
/// Passes `params` to `notify_resource_updated` on the runtime stored for
/// `session_id`.
///
/// # Errors
///
/// Fails if `session_id` has no stored runtime, or if the runtime call
/// itself returns an error.
pub async fn notify_resource_updated(
    &self,
    session_id: &SessionId,
    params: ResourceUpdatedNotificationParams,
) -> SdkResult<()> {
    self.runtime_by_session(session_id)
        .await?
        .notify_resource_updated(params)
        .await
}
/// Passes the optional `params` to `notify_tool_list_changed` on the
/// runtime stored for `session_id`.
///
/// # Errors
///
/// Fails if `session_id` has no stored runtime, or if the runtime call
/// itself returns an error.
pub async fn notify_tool_list_changed(
    &self,
    session_id: &SessionId,
    params: Option<NotificationParams>,
) -> SdkResult<()> {
    self.runtime_by_session(session_id)
        .await?
        .notify_tool_list_changed(params)
        .await
}
/// Passes `params` to `notify_cancellation` on the runtime stored for
/// `session_id`.
///
/// # Errors
///
/// Fails if `session_id` has no stored runtime, or if the runtime call
/// itself returns an error.
pub async fn notify_cancellation(
    &self,
    session_id: &SessionId,
    params: CancelledNotificationParams,
) -> SdkResult<()> {
    self.runtime_by_session(session_id)
        .await?
        .notify_cancellation(params)
        .await
}
/// Passes `params` to `notify_progress` on the runtime stored for
/// `session_id`.
///
/// # Errors
///
/// Fails if `session_id` has no stored runtime, or if the runtime call
/// itself returns an error.
pub async fn notify_progress(
    &self,
    session_id: &SessionId,
    params: ProgressNotificationParams,
) -> SdkResult<()> {
    self.runtime_by_session(session_id)
        .await?
        .notify_progress(params)
        .await
}
/// Passes `params` to `notify_task_status` on the runtime stored for
/// `session_id`.
///
/// # Errors
///
/// Fails if `session_id` has no stored runtime, or if the runtime call
/// itself returns an error.
pub async fn notify_task_status(
    &self,
    session_id: &SessionId,
    params: TaskStatusNotificationParams,
) -> SdkResult<()> {
    self.runtime_by_session(session_id)
        .await?
        .notify_task_status(params)
        .await
}
/// Passes `params` to `notify_elicitation_completed` on the runtime stored
/// for `session_id`.
///
/// # Errors
///
/// Fails if `session_id` has no stored runtime, or if the runtime call
/// itself returns an error.
pub async fn notify_elicitation_completed(
    &self,
    session_id: &SessionId,
    params: ElicitCompleteParams,
) -> SdkResult<()> {
    self.runtime_by_session(session_id)
        .await?
        .notify_elicitation_completed(params)
        .await
}
/// Passes the custom notification `params` to `notify_custom` on the
/// runtime stored for `session_id`.
///
/// # Errors
///
/// Fails if `session_id` has no stored runtime, or if the runtime call
/// itself returns an error.
pub async fn notify_custom(
    &self,
    session_id: &SessionId,
    params: CustomNotification,
) -> SdkResult<()> {
    self.runtime_by_session(session_id)
        .await?
        .notify_custom(params)
        .await
}
/// Deprecated alias kept for backward compatibility; delegates to
/// [`Self::request_root_list`] with the same arguments.
///
/// # Errors
///
/// Same as `request_root_list`.
#[deprecated(since = "0.8.0", note = "Use `request_root_list()` instead.")]
pub async fn list_roots(
    &self,
    session_id: &SessionId,
    params: Option<RequestParams>,
) -> SdkResult<ListRootsResult> {
    self.request_root_list(session_id, params).await
}
/// Deprecated alias kept for backward compatibility; delegates to
/// [`Self::request_elicitation`] with the same arguments.
///
/// # Errors
///
/// Same as `request_elicitation`.
#[deprecated(since = "0.8.0", note = "Use `request_elicitation()` instead.")]
pub async fn elicit_input(
    &self,
    session_id: &SessionId,
    params: ElicitRequestParams,
) -> SdkResult<ElicitResult> {
    self.request_elicitation(session_id, params).await
}
/// Deprecated alias kept for backward compatibility; delegates to
/// [`Self::request_message_creation`] with the same arguments.
///
/// # Errors
///
/// Same as `request_message_creation`.
#[deprecated(since = "0.8.0", note = "Use `request_message_creation()` instead.")]
pub async fn create_message(
    &self,
    session_id: &SessionId,
    params: CreateMessageRequestParams,
) -> SdkResult<CreateMessageResult> {
    self.request_message_creation(session_id, params).await
}
/// Deprecated alias kept for backward compatibility; delegates to
/// [`Self::notify_tool_list_changed`] with the same arguments.
///
/// # Errors
///
/// Same as `notify_tool_list_changed`.
#[deprecated(since = "0.8.0", note = "Use `notify_tool_list_changed()` instead.")]
pub async fn send_tool_list_changed(
    &self,
    session_id: &SessionId,
    params: Option<NotificationParams>,
) -> SdkResult<()> {
    self.notify_tool_list_changed(session_id, params).await
}
/// Deprecated alias kept for backward compatibility; delegates to
/// [`Self::notify_resource_updated`] with the same arguments.
///
/// # Errors
///
/// Same as `notify_resource_updated`.
#[deprecated(since = "0.8.0", note = "Use `notify_resource_updated()` instead.")]
pub async fn send_resource_updated(
    &self,
    session_id: &SessionId,
    params: ResourceUpdatedNotificationParams,
) -> SdkResult<()> {
    self.notify_resource_updated(session_id, params).await
}
/// Deprecated alias kept for backward compatibility; delegates to
/// [`Self::notify_resource_list_changed`] with the same arguments.
///
/// # Errors
///
/// Same as `notify_resource_list_changed`.
#[deprecated(
    since = "0.8.0",
    note = "Use `notify_resource_list_changed()` instead."
)]
pub async fn send_resource_list_changed(
    &self,
    session_id: &SessionId,
    params: Option<NotificationParams>,
) -> SdkResult<()> {
    self.notify_resource_list_changed(session_id, params).await
}
/// Deprecated alias kept for backward compatibility; delegates to
/// [`Self::notify_prompt_list_changed`] with the same arguments.
///
/// # Errors
///
/// Same as `notify_prompt_list_changed`.
#[deprecated(since = "0.8.0", note = "Use `notify_prompt_list_changed()` instead.")]
pub async fn send_prompt_list_changed(
    &self,
    session_id: &SessionId,
    params: Option<NotificationParams>,
) -> SdkResult<()> {
    self.notify_prompt_list_changed(session_id, params).await
}
/// Deprecated alias kept for backward compatibility; delegates to
/// [`Self::notify_log_message`] with the same arguments.
///
/// # Errors
///
/// Same as `notify_log_message`.
#[deprecated(since = "0.8.0", note = "Use `notify_log_message()` instead.")]
pub async fn send_logging_message(
    &self,
    session_id: &SessionId,
    params: LoggingMessageNotificationParams,
) -> SdkResult<()> {
    self.notify_log_message(session_id, params).await
}
pub fn task_store(&self) -> Option<Arc<ServerTaskStore>> {
self.state.task_store.clone()
}
pub fn client_task_store(&self) -> Option<Arc<ClientTaskStore>> {
self.state.client_task_store.clone()
}
}