use crate::error::ExtensionError;
use crate::state::{HasServerInfo, McpState};
use crate::{SUPPORTED_VERSIONS, is_supported_version};
use axum::extract::State;
use axum::http::{HeaderMap, StatusCode};
use axum::response::IntoResponse;
use axum::response::sse::{Event, KeepAlive, Sse};
use futures::stream::Stream;
use mcpkit_core::protocol::Message;
use mcpkit_server::ServerHandler;
use std::convert::Infallible;
use tracing::{debug, info, warn};
/// Handles an MCP JSON-RPC message delivered via HTTP POST.
///
/// Processing order:
/// 1. The `mcp-protocol-version` header is checked with `is_supported_version`.
///    If that check fails, the handler returns `ExtensionError::UnsupportedVersion`,
///    listing `SUPPORTED_VERSIONS`. Whether a missing header passes depends on
///    `is_supported_version(None)`.
/// 2. The body is parsed as a JSON-RPC [`Message`]. A parse failure yields
///    `ExtensionError::InvalidMessage`.
/// 3. Messages that are neither requests nor notifications yield
///    `ExtensionError::InvalidMessage`.
/// 4. Only accepted messages resolve a session:
///    - an `mcp-session-id` header value is touched in `state.sessions`;
///    - otherwise a new session is created.
/// 5. Requests are answered with `200 OK` and a JSON body. Notifications are
///    acknowledged with `202 Accepted`. Both carry the `mcp-session-id` header.
pub async fn handle_mcp_post<H>(
    State(state): State<McpState<H>>,
    headers: HeaderMap,
    body: String,
) -> impl IntoResponse
where
    H: ServerHandler + Send + Sync + 'static,
{
    let version = headers
        .get("mcp-protocol-version")
        .and_then(|v| v.to_str().ok());
    if !is_supported_version(version) {
        let provided = version.unwrap_or("none");
        warn!(version = provided, "Unsupported protocol version");
        return ExtensionError::UnsupportedVersion(format!(
            "{} (supported: {})",
            provided,
            SUPPORTED_VERSIONS.join(", ")
        ))
        .into_response();
    }

    // Parse before touching the session store. Error responses carry no
    // `mcp-session-id`, so a session created for a rejected payload could
    // never be reused by the client and would only accumulate.
    let msg: Message = match serde_json::from_str(&body) {
        Ok(m) => m,
        Err(e) => {
            warn!(error = %e, "Failed to parse JSON-RPC message");
            return ExtensionError::InvalidMessage(e.to_string()).into_response();
        }
    };

    // Resolved lazily so only message kinds we accept create/touch a session.
    let resolve_session = || {
        let session_id = match headers
            .get("mcp-session-id")
            .and_then(|v| v.to_str().ok())
            .map(String::from)
        {
            Some(id) => {
                state.sessions.touch(&id);
                id
            }
            None => state.sessions.create(),
        };
        debug!(session_id = %session_id, "Processing MCP request");
        session_id
    };

    match msg {
        Message::Request(request) => {
            let session_id = resolve_session();
            info!(
                method = %request.method,
                id = ?request.id,
                session_id = %session_id,
                "Handling MCP request"
            );
            let response = create_response_for_request(&state, &request).await;
            match serde_json::to_string(&Message::Response(response)) {
                Ok(body) => (
                    StatusCode::OK,
                    [
                        ("content-type", "application/json"),
                        ("mcp-session-id", session_id.as_str()),
                    ],
                    body,
                )
                    .into_response(),
                Err(e) => ExtensionError::Serialization(e).into_response(),
            }
        }
        Message::Notification(notification) => {
            let session_id = resolve_session();
            debug!(
                method = %notification.method,
                session_id = %session_id,
                "Received notification"
            );
            (
                StatusCode::ACCEPTED,
                [("mcp-session-id", session_id.as_str())],
            )
                .into_response()
        }
        _ => {
            warn!("Unexpected message type received");
            ExtensionError::InvalidMessage("Expected request or notification".to_string())
                .into_response()
        }
    }
}
async fn create_response_for_request<H>(
state: &McpState<H>,
request: &mcpkit_core::protocol::Request,
) -> mcpkit_core::protocol::Response
where
H: ServerHandler + Send + Sync + 'static,
{
use mcpkit_core::error::JsonRpcError;
use mcpkit_core::protocol::Response;
let method = request.method.as_ref();
match method {
"ping" => Response::success(request.id.clone(), serde_json::json!({})),
"initialize" => {
let init_result = serde_json::json!({
"protocolVersion": "2025-06-18",
"serverInfo": state.server_info,
"capabilities": state.handler.capabilities(),
});
Response::success(request.id.clone(), init_result)
}
_ => {
Response::error(
request.id.clone(),
JsonRpcError::method_not_found(format!("Method '{method}' not found")),
)
}
}
}
/// Opens, or re-attaches to, a server-sent-events stream for an MCP session.
///
/// Session selection:
/// - If the `mcp-session-id` header names a session for which
///   `state.sse_sessions.get_receiver` returns a receiver, that session is
///   resumed.
/// - Otherwise a fresh session is created. This covers both a missing header
///   and an unknown id.
///
/// The resulting stream sends periodic keep-alives (`KeepAlive::default()`).
pub async fn handle_sse<H>(
    State(state): State<McpState<H>>,
    headers: HeaderMap,
) -> impl IntoResponse
where
    H: HasServerInfo + Send + Sync + 'static,
{
    let requested: Option<String> = headers
        .get("mcp-session-id")
        .and_then(|raw| raw.to_str().ok())
        .map(String::from);

    let (id, rx) = match requested {
        Some(existing) => match state.sse_sessions.get_receiver(&existing) {
            Some(receiver) => {
                info!(session_id = %existing, "Reconnected to SSE session");
                (existing, receiver)
            }
            None => {
                let (fresh, receiver) = state.sse_sessions.create_session();
                info!(session_id = %fresh, "Created new SSE session (requested not found)");
                (fresh, receiver)
            }
        },
        None => {
            let (fresh, receiver) = state.sse_sessions.create_session();
            info!(session_id = %fresh, "Created new SSE session");
            (fresh, receiver)
        }
    };

    Sse::new(create_sse_stream(id, rx)).keep_alive(KeepAlive::default())
}
fn create_sse_stream(
session_id: String,
mut rx: tokio::sync::broadcast::Receiver<String>,
) -> impl Stream<Item = Result<Event, Infallible>> {
async_stream::stream! {
yield Ok(Event::default()
.event("connected")
.data(&session_id));
loop {
match rx.recv().await {
Ok(msg) => {
yield Ok(Event::default()
.event("message")
.data(msg));
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
warn!(skipped = n, "SSE client lagged, skipped messages");
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
debug!("SSE channel closed");
break;
}
}
}
}
}