use axum::body::{Body, Bytes};
use axum::extract::State;
use axum::http::{header, HeaderValue, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::{IntoResponse, Response};
use axum::routing::post;
use axum::{Json, Router};
use futures::stream::{self, Stream};
use serde_json::{json, Value};
use std::convert::Infallible;
use std::time::Duration;
use crate::dispatch::dispatch_request;
use crate::initialize::ServerInfo;
use crate::jsonrpc::{
parse_envelope, ErrorObject, Request as RpcRequest, RequestEnvelope, Response as RpcResponse,
};
use crate::tools::ToolDispatch;
/// Shared state for the `/mcp/v1` handlers, installed on the router via `with_state`
/// and read back in handlers through axum's `State` extractor.
///
/// `Clone` is derived because axum router state must be cloneable.
#[derive(Clone)]
pub struct McpState {
    /// Passed as the first argument to `dispatch_request` for every JSON-RPC request.
    pub dispatch: ToolDispatch,
    /// Passed as the second argument to `dispatch_request` alongside `dispatch`.
    pub server_info: ServerInfo,
}
/// Builds the MCP HTTP router with `state` attached.
///
/// A single endpoint, `/mcp/v1`, is served:
/// - `POST` carries JSON-RPC messages (single or batch) to `handle_post`;
/// - `GET` opens the server-sent-events stream from `handle_get_sse`.
pub fn router(state: McpState) -> Router {
    let mcp_endpoint = post(handle_post).get(handle_get_sse);
    Router::new()
        .route("/mcp/v1", mcp_endpoint)
        .with_state(state)
}
/// Handles `POST /mcp/v1`: parses the body as a JSON-RPC envelope and dispatches it.
///
/// Outcomes:
/// - body rejected by `parse_envelope` -> the error response from
///   `parse_error_response_with_obj`;
/// - single request with a reply -> that reply serialized as JSON;
/// - batch with at least one reply -> a JSON array of the replies, in input order;
/// - no reply at all (single or batch) -> `202 Accepted` with an empty body.
///
/// Batch entries are dispatched one after another, never concurrently.
async fn handle_post(State(state): State<McpState>, body: Bytes) -> Response {
    let envelope = match parse_envelope(&body) {
        Ok(parsed) => parsed,
        Err(err) => return parse_error_response_with_obj(err),
    };

    let requests = match envelope {
        RequestEnvelope::Single(req) => {
            return dispatch_one(&state, req).await.map_or_else(
                || StatusCode::ACCEPTED.into_response(),
                |reply| Json(reply).into_response(),
            );
        }
        RequestEnvelope::Batch(requests) => requests,
    };

    let mut replies: Vec<RpcResponse> = Vec::with_capacity(requests.len());
    for req in requests {
        // `None` (no reply for this entry) contributes nothing to the batch result.
        replies.extend(dispatch_one(&state, req).await);
    }

    if replies.is_empty() {
        StatusCode::ACCEPTED.into_response()
    } else {
        Json(replies).into_response()
    }
}
/// Runs one JSON-RPC request through `dispatch_request` inside an `mcp.dispatch`
/// tracing span whose `method` field is the request's method name.
///
/// Returns whatever `dispatch_request` yields; `None` means no reply is produced
/// (presumably for notifications — confirm against `dispatch_request`).
async fn dispatch_one(state: &McpState, req: RpcRequest) -> Option<RpcResponse> {
    // The span must be created before `req` is moved into `dispatch_request`.
    let span = tracing::info_span!("mcp.dispatch", method = %req.method);
    let pending = dispatch_request(&state.dispatch, &state.server_info, req);
    tracing::Instrument::instrument(pending, span).await
}
/// Handles `GET /mcp/v1` by opening a server-sent-events stream that never yields
/// an event of its own.
///
/// The connection is held open indefinitely; the only traffic is axum's keep-alive
/// message, sent every 15 seconds. The shared state is extracted but unused.
async fn handle_get_sse(
    State(_state): State<McpState>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    let heartbeat = KeepAlive::new().interval(Duration::from_secs(15));
    let never_emits = stream::pending::<Result<Event, Infallible>>();
    Sse::new(never_emits).keep_alive(heartbeat)
}
/// Builds the HTTP response for a request body that `parse_envelope` rejected.
///
/// The body is a JSON-RPC 2.0 error response carrying `err`, with `id` set to
/// `null` because no request id could be recovered. It is sent with
/// `Content-Type: application/json` and the default status of `Response::new` (200 OK).
///
/// NOTE(review): the MCP Streamable HTTP transport suggests an HTTP error status
/// (e.g. 400) for input the server cannot accept — confirm whether clients expect 200.
fn parse_error_response_with_obj(err: ErrorObject) -> Response {
    let payload = json!({
        "jsonrpc": "2.0",
        "id": Value::Null,
        "error": err,
    })
    .to_string();
    let json_content_type = HeaderValue::from_static("application/json");

    let mut response = Response::new(Body::from(payload));
    response
        .headers_mut()
        .insert(header::CONTENT_TYPE, json_content_type);
    response
}