use crate::error::ExtensionError;
use crate::session::{EventStore, StoredEvent};
use crate::state::{HasServerInfo, McpState, OAuthState};
use crate::{SUPPORTED_VERSIONS, is_supported_version};
use axum::Json;
use axum::extract::State;
use axum::http::{HeaderMap, StatusCode};
use axum::response::IntoResponse;
use axum::response::sse::{Event, KeepAlive, Sse};
use futures::stream::Stream;
use mcpkit_core::capability::ClientCapabilities;
use mcpkit_core::protocol::Message;
use mcpkit_core::protocol_version::ProtocolVersion;
use mcpkit_server::context::{Context, NoOpPeer};
use mcpkit_server::{
PromptHandler, ResourceHandler, ServerHandler, ToolHandler, route_prompts, route_resources,
route_tools,
};
use std::convert::Infallible;
use std::sync::Arc;
use tracing::{debug, info, warn};
/// Handles a JSON-RPC message POSTed to the MCP endpoint.
///
/// Processing order:
///
/// 1. The `mcp-protocol-version` header is checked with [`is_supported_version`]; if it is
///    rejected, an [`ExtensionError::UnsupportedVersion`] listing [`SUPPORTED_VERSIONS`] is
///    returned.
/// 2. The body is parsed as a JSON-RPC [`Message`]. A parse failure, or a message that is
///    neither a request nor a notification, yields [`ExtensionError::InvalidMessage`].
/// 3. Only for messages that are actually processed is the session resolved: an existing
///    `mcp-session-id` is touched in the session store, otherwise a new session is created.
/// 4. Requests are dispatched via `create_response_for_request` and answered with `200 OK`
///    and a JSON body; notifications are acknowledged with `202 Accepted`. Both responses
///    echo the session id in the `mcp-session-id` header.
///
/// Resolving the session after the body has been validated ensures malformed or unexpected
/// messages do not create (or refresh) sessions that are never used.
pub async fn handle_mcp_post<H>(
    State(state): State<McpState<H>>,
    headers: HeaderMap,
    body: String,
) -> impl IntoResponse
where
    H: ServerHandler + ToolHandler + ResourceHandler + PromptHandler + Send + Sync + 'static,
{
    let version = headers
        .get("mcp-protocol-version")
        .and_then(|v| v.to_str().ok());
    if !is_supported_version(version) {
        let provided = version.unwrap_or("none");
        warn!(version = provided, "Unsupported protocol version");
        return ExtensionError::UnsupportedVersion(format!(
            "{} (supported: {})",
            provided,
            SUPPORTED_VERSIONS.join(", ")
        ))
        .into_response();
    }

    // Parse before touching session state so invalid bodies have no side effects.
    let msg: Message = match serde_json::from_str(&body) {
        Ok(m) => m,
        Err(e) => {
            warn!(error = %e, "Failed to parse JSON-RPC message");
            return ExtensionError::InvalidMessage(e.to_string()).into_response();
        }
    };

    // Deferred: only invoked from the arms that actually process the message, so rejected
    // message kinds never create or refresh a session.
    let resolve_session = || {
        let session_id = match headers
            .get("mcp-session-id")
            .and_then(|v| v.to_str().ok())
            .map(String::from)
        {
            Some(id) => {
                state.sessions.touch(&id);
                id
            }
            None => state.sessions.create(),
        };
        debug!(session_id = %session_id, "Processing MCP request");
        session_id
    };

    match msg {
        Message::Request(request) => {
            let session_id = resolve_session();
            info!(
                method = %request.method,
                id = ?request.id,
                session_id = %session_id,
                "Handling MCP request"
            );
            let response = create_response_for_request(&state, &request).await;
            match serde_json::to_string(&Message::Response(response)) {
                Ok(body) => (
                    StatusCode::OK,
                    [
                        ("content-type", "application/json"),
                        ("mcp-session-id", session_id.as_str()),
                    ],
                    body,
                )
                    .into_response(),
                Err(e) => ExtensionError::Serialization(e).into_response(),
            }
        }
        Message::Notification(notification) => {
            let session_id = resolve_session();
            debug!(
                method = %notification.method,
                session_id = %session_id,
                "Received notification"
            );
            (
                StatusCode::ACCEPTED,
                [("mcp-session-id", session_id.as_str())],
            )
                .into_response()
        }
        _ => {
            // Clients may only POST requests or notifications to this endpoint.
            warn!("Unexpected message type received");
            ExtensionError::InvalidMessage("Expected request or notification".to_string())
                .into_response()
        }
    }
}
/// Dispatches one JSON-RPC request to the server handler and builds its response.
///
/// `ping` is answered with an empty object and `initialize` with the protocol version,
/// `state.server_info` and the handler's capabilities. Every other method is offered to
/// the tool, resource and prompt routers in that order; the first router returning `Some`
/// decides the response (success value or converted error). A method no router claims
/// produces a JSON-RPC "method not found" error.
///
/// Handlers run with a [`Context`] built from default [`ClientCapabilities`], the handler's
/// own capabilities, [`ProtocolVersion::LATEST`] and a [`NoOpPeer`].
///
/// NOTE(review): `initialize` always reports `ProtocolVersion::LATEST` rather than echoing a
/// supported version requested by the client — confirm this is intended.
async fn create_response_for_request<H>(
    state: &McpState<H>,
    request: &mcpkit_core::protocol::Request,
) -> mcpkit_core::protocol::Response
where
    H: ServerHandler + ToolHandler + ResourceHandler + PromptHandler + Send + Sync + 'static,
{
    use mcpkit_core::error::JsonRpcError;
    use mcpkit_core::protocol::Response;

    // Turns a router outcome into a response addressed to this request's id.
    macro_rules! respond {
        ($outcome:expr) => {
            match $outcome {
                Ok(value) => Response::success(request.id.clone(), value),
                Err(err) => Response::error(request.id.clone(), err.into()),
            }
        };
    }

    let method: &str = request.method.as_ref();
    let params = request.params.as_ref();

    let request_id = request.id.clone();
    let client_capabilities = ClientCapabilities::default();
    let server_capabilities = state.handler.capabilities();
    let noop_peer = NoOpPeer;
    let ctx = Context::new(
        &request_id,
        None,
        &client_capabilities,
        &server_capabilities,
        ProtocolVersion::LATEST,
        &noop_peer,
    );

    if method == "ping" {
        return Response::success(request.id.clone(), serde_json::json!({}));
    }

    if method == "initialize" {
        let handshake = serde_json::json!({
            "protocolVersion": ProtocolVersion::LATEST.as_str(),
            "serverInfo": state.server_info,
            "capabilities": state.handler.capabilities(),
        });
        return Response::success(request.id.clone(), handshake);
    }

    if let Some(outcome) = route_tools(state.handler.as_ref(), method, params, &ctx).await {
        return respond!(outcome);
    }

    if let Some(outcome) = route_resources(state.handler.as_ref(), method, params, &ctx).await {
        return respond!(outcome);
    }

    if let Some(outcome) = route_prompts(state.handler.as_ref(), method, params, &ctx).await {
        return respond!(outcome);
    }

    Response::error(
        request.id.clone(),
        JsonRpcError::method_not_found(format!("Method '{method}' not found")),
    )
}
/// Opens or resumes a server-sent-events stream for an MCP session.
///
/// When the `mcp-session-id` header names a session whose receiver is still available from
/// `state.sse_sessions`, that session is resumed. If a `last-event-id` header is also sent,
/// the events returned by `get_events_for_replay` are replayed first; a failed lookup
/// replays nothing. In every other case a fresh SSE session is created.
///
/// The resulting stream uses axum's default keep-alive settings.
pub async fn handle_sse<H>(
    State(state): State<McpState<H>>,
    headers: HeaderMap,
) -> impl IntoResponse
where
    H: HasServerInfo + Send + Sync + 'static,
{
    // Reads a header as an owned string, ignoring values that are not valid visible ASCII.
    let header_string = |name: &str| {
        headers
            .get(name)
            .and_then(|value| value.to_str().ok())
            .map(String::from)
    };

    let requested_session = header_string("mcp-session-id");
    let last_event_id = header_string("last-event-id");

    let (session_id, rx, replay_events) = match requested_session {
        Some(id) => match state.sse_sessions.get_receiver(&id) {
            Some(rx) => {
                let replay = match &last_event_id {
                    Some(last_id) => {
                        info!(session_id = %id, last_event_id = %last_id, "Reconnecting with Last-Event-ID");
                        state
                            .sse_sessions
                            .get_events_for_replay(&id, last_id)
                            .await
                            .unwrap_or_default()
                    }
                    None => Vec::new(),
                };
                info!(
                    session_id = %id,
                    replay_count = replay.len(),
                    "Reconnected to SSE session"
                );
                (id, rx, replay)
            }
            None => {
                let (fresh_id, fresh_rx) = state.sse_sessions.create_session();
                info!(session_id = %fresh_id, "Created new SSE session (requested not found)");
                (fresh_id, fresh_rx, Vec::new())
            }
        },
        None => {
            let (fresh_id, fresh_rx) = state.sse_sessions.create_session();
            info!(session_id = %fresh_id, "Created new SSE session");
            (fresh_id, fresh_rx, Vec::new())
        }
    };

    let store = state.sse_sessions.get_event_store(&session_id);
    Sse::new(create_sse_stream_with_replay(
        session_id,
        rx,
        replay_events,
        store,
    ))
    .keep_alive(KeepAlive::default())
}
fn create_sse_stream_with_replay(
session_id: String,
mut rx: tokio::sync::broadcast::Receiver<String>,
replay_events: Vec<StoredEvent>,
event_store: Option<Arc<EventStore>>,
) -> impl Stream<Item = Result<Event, Infallible>> {
async_stream::stream! {
for stored in replay_events {
debug!(event_id = %stored.id, "Replaying missed event");
yield Ok(Event::default()
.id(&stored.id)
.event(&stored.event_type)
.data(&stored.data));
}
let connected_event_id = event_store
.as_ref()
.map_or_else(|| "evt-connected".to_string(), |store| store.next_event_id());
yield Ok(Event::default()
.id(&connected_event_id)
.event("connected")
.data(&session_id));
loop {
match rx.recv().await {
Ok(msg) => {
let event_id = event_store
.as_ref()
.map_or_else(|| format!("evt-{}", uuid::Uuid::new_v4()), |store| store.next_event_id());
yield Ok(Event::default()
.id(&event_id)
.event("message")
.data(msg));
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
warn!(skipped = n, "SSE client lagged, skipped messages");
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
debug!("SSE channel closed");
break;
}
}
}
}
}
/// Serves the OAuth protected resource metadata stored in [`OAuthState`] as JSON.
///
/// `Json` already answers `200 OK` with `content-type: application/json` on success, so it
/// is returned on its own. Wrapping it in a `(StatusCode::OK, [("content-type", ..)], Json(..))`
/// tuple made axum overwrite the status and content type after the body was rendered. That
/// turned a metadata serialization failure (which `Json` reports as `500` with a plain-text
/// body) into a `200 OK` labelled as JSON.
pub async fn handle_oauth_protected_resource(State(state): State<OAuthState>) -> impl IntoResponse {
    debug!("Serving OAuth protected resource metadata");
    Json(state.metadata)
}