use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
use axum::body::Body;
use axum::extract::{Path, Query, State};
use axum::http::{Response, StatusCode, header};
use axum::response::IntoResponse;
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::routing::{get, post};
use axum::{Json, Router};
use bytes::Bytes;
use dashmap::DashMap;
use futures::Stream;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio::task::JoinHandle;
use tokio_stream::StreamExt;
use tokio_stream::wrappers::{BroadcastStream, ReceiverStream};
use uuid::Uuid;
use crate::command::{CommandResult, ServerCommand, execute_command};
use crate::error::FastMcpError;
use crate::prompt::PromptMessage;
use crate::resource::ResourceContent;
use crate::server::FastMcpServer;
use crate::tool::ToolResponse;
#[derive(Clone)]
struct AppState {
server: Arc<FastMcpServer>,
hub: SseHub,
stream_hub: StreamHub,
}
impl AppState {
fn broadcast(&self, result: &CommandResult) {
let payload = match serde_json::to_value(result) {
Ok(value) => value,
Err(err) => {
tracing::error!("failed to encode SSE payload: {}", err);
return;
}
};
self.hub.publish(result.event_kind(), payload);
}
async fn send_to_session(
&self,
session_id: Uuid,
result: &CommandResult,
) -> Result<(), FastMcpError> {
self.stream_hub.send(&session_id, result).await
}
}
/// Broadcast fan-out for `/sse` subscribers; clones share one channel.
#[derive(Clone)]
struct SseHub {
    tx: broadcast::Sender<ServerEvent>,
}

impl SseHub {
    /// Creates a hub backed by a broadcast channel with capacity 256.
    ///
    /// The initial receiver is discarded; listeners attach via [`SseHub::subscribe`].
    fn new() -> Self {
        Self {
            tx: broadcast::channel(256).0,
        }
    }

    /// Sends an event named `kind` carrying `payload` to every current subscriber.
    ///
    /// Having no subscribers is not an error; it is only logged at debug level.
    fn publish(&self, kind: &str, payload: Value) {
        let event = ServerEvent {
            kind: kind.to_owned(),
            payload,
        };
        match self.tx.send(event) {
            Ok(_) => {}
            Err(err) => tracing::debug!("no active SSE listeners to receive event: {err}"),
        }
    }

    /// Returns a new receiver that observes events published from now on.
    fn subscribe(&self) -> broadcast::Receiver<ServerEvent> {
        self.tx.subscribe()
    }
}
/// One notification published on the SSE hub.
#[derive(Debug, Clone)]
struct ServerEvent {
    /// Emitted as the SSE `event:` name.
    kind: String,
    /// Serialized to JSON and emitted as the SSE `data:` field.
    payload: Value,
}
/// Registry of streamable-HTTP sessions keyed by session id.
///
/// Cloning is cheap and shares the same underlying map.
#[derive(Clone)]
struct StreamHub {
    sessions: Arc<DashMap<Uuid, Arc<StreamSession>>>,
}

/// Per-session state: a bounded channel of JSON values.
///
/// The receiving half is parked in `receiver` until a client attaches and takes
/// it via [`StreamHub::take_receiver`]; afterwards the slot is `None`.
struct StreamSession {
    sender: mpsc::Sender<Value>,
    receiver: Mutex<Option<mpsc::Receiver<Value>>>,
}

impl StreamHub {
    /// Number of undelivered messages a session buffers before `send` has to wait.
    const SESSION_BUFFER: usize = 64;

    /// Creates an empty registry.
    fn new() -> Self {
        Self {
            sessions: Arc::new(DashMap::new()),
        }
    }

    /// Registers a new session under a random v4 UUID and returns that id.
    fn create_session(&self) -> Uuid {
        let (sender, receiver) = mpsc::channel(Self::SESSION_BUFFER);
        let entry = Arc::new(StreamSession {
            sender,
            receiver: Mutex::new(Some(receiver)),
        });
        let id = Uuid::new_v4();
        self.sessions.insert(id, entry);
        id
    }

    /// Moves the receiving half out of session `id`.
    ///
    /// Returns `None` if the session is unknown or its receiver was already taken.
    fn take_receiver(&self, id: &Uuid) -> Option<mpsc::Receiver<Value>> {
        self.sessions
            .get(id)
            .and_then(|entry| entry.receiver.lock().take())
    }

    /// Serializes `result` and queues it on session `id`, waiting while the
    /// session's buffer is full.
    ///
    /// # Errors
    /// * Serialization failures, converted into [`FastMcpError`] via `?`.
    /// * `InvalidInvocation("session {id} not found")` for an unknown id.
    /// * `InvalidInvocation("session {id} closed")` when the receiving side is
    ///   gone; the session is removed from the registry in that case.
    async fn send(&self, id: &Uuid, result: &CommandResult) -> Result<(), FastMcpError> {
        let payload = serde_json::to_value(result)?;
        // Clone the session handle out of the map so the DashMap shard guard is
        // released *before* awaiting: holding it across the await would block
        // writers on that shard, and calling `remove` below while it is still
        // held deadlocks.
        let session = self
            .sessions
            .get(id)
            .map(|entry| Arc::clone(entry.value()))
            .ok_or_else(|| FastMcpError::InvalidInvocation(format!("session {id} not found")))?;
        if let Err(err) = session.sender.send(payload).await {
            tracing::debug!("stream session {id} closed {:?}", err);
            self.sessions.remove(id);
            return Err(FastMcpError::InvalidInvocation(format!(
                "session {id} closed"
            )));
        }
        Ok(())
    }

    /// Removes session `id` from the registry; a no-op if it does not exist.
    fn close(&self, id: &Uuid) {
        self.sessions.remove(id);
    }
}
pub struct HttpServerHandle {
addr: SocketAddr,
shutdown: Option<oneshot::Sender<()>>,
task: JoinHandle<()>,
}
impl HttpServerHandle {
pub fn addr(&self) -> SocketAddr {
self.addr
}
pub async fn shutdown(mut self) {
if let Some(tx) = self.shutdown.take() {
let _ = tx.send(());
}
let _ = self.task.await;
}
}
/// Binds `addr`, prints the startup banner, and serves the MCP HTTP API on a
/// spawned Tokio task.
///
/// # Errors
/// Returns an I/O error if binding the listener or querying its local address fails.
pub async fn start_http(
    server: Arc<FastMcpServer>,
    addr: SocketAddr,
) -> std::io::Result<HttpServerHandle> {
    let app_state = AppState {
        server: Arc::clone(&server),
        hub: SseHub::new(),
        stream_hub: StreamHub::new(),
    };

    let app = Router::new()
        // Introspection and direct REST-style calls.
        .route("/healthz", get(health))
        .route("/metadata", get(metadata))
        .route("/tools", get(list_tools))
        .route("/tools/:name/call", post(call_tool))
        .route("/resources", get(list_resources))
        .route("/resource", get(read_resource))
        .route("/prompts", get(list_prompts))
        .route("/prompts/:name/instantiate", post(instantiate_prompt))
        // Server-sent events fan-out.
        .route("/sse", get(sse_stream))
        // Streamable HTTP sessions.
        .route("/streamable/session", post(create_stream_session))
        .route("/streamable/session/:id", get(stream_session))
        .route(
            "/streamable/session/:id/messages",
            post(stream_session_message),
        )
        // Stateless command gateway.
        .route("/messages", post(message_gateway))
        .with_state(app_state);

    let listener = tokio::net::TcpListener::bind(addr).await?;
    let bound_addr = listener.local_addr()?;
    log_http_startup(&server, &bound_addr);

    let (stop_tx, stop_rx) = oneshot::channel::<()>();
    let task = tokio::spawn(async move {
        // The signal fires on an explicit send *or* when the sender is dropped.
        let serve_future = axum::serve(listener, app).with_graceful_shutdown(async move {
            let _ = stop_rx.await;
        });
        if let Err(err) = serve_future.await {
            tracing::error!("HTTP server error: {}", err);
        }
    });

    Ok(HttpServerHandle {
        addr: bound_addr,
        shutdown: Some(stop_tx),
        task,
    })
}
async fn health() -> impl IntoResponse {
StatusCode::OK
}
async fn metadata(State(state): State<AppState>) -> impl IntoResponse {
Json(state.server.metadata())
}
async fn list_tools(State(state): State<AppState>) -> impl IntoResponse {
Json(state.server.list_tools())
}
/// JSON body of `POST /tools/:name/call`.
#[derive(Deserialize)]
struct CallToolRequest {
    /// Tool arguments; `Value::Null` when the field is omitted.
    #[serde(default)]
    arguments: Value,
}

/// `POST /tools/:name/call`: invokes tool `name`, broadcasts the response to SSE
/// subscribers, and returns it.
///
/// # Errors
/// Any error from the tool call, converted through `HttpError::from`.
async fn call_tool(
    State(state): State<AppState>,
    Path(name): Path<String>,
    Json(CallToolRequest { arguments }): Json<CallToolRequest>,
) -> Result<Json<ToolResponse>, HttpError> {
    let response = state.server.call_tool(&name, arguments).await?;
    let announcement = CommandResult::ToolInvocation {
        data: response.clone(),
    };
    state.broadcast(&announcement);
    Ok(Json(response))
}
async fn list_resources(State(state): State<AppState>) -> impl IntoResponse {
Json(state.server.list_resources())
}
#[derive(Deserialize)]
struct ResourceQuery {
uri: String,
}
async fn read_resource(
State(state): State<AppState>,
Query(query): Query<ResourceQuery>,
) -> Result<Json<ResourceContent>, HttpError> {
let content = state
.server
.read_resource(&query.uri)
.await
.map_err(HttpError::from)?;
state.broadcast(&CommandResult::Resource {
data: content.clone(),
});
Ok(Json(content))
}
async fn list_prompts(State(state): State<AppState>) -> impl IntoResponse {
Json(state.server.list_prompts())
}
#[derive(Deserialize)]
struct InstantiatePromptRequest {
#[serde(default)]
arguments: Option<Value>,
}
#[derive(Serialize)]
struct InstantiatePromptResponse {
messages: Vec<PromptMessage>,
}
async fn instantiate_prompt(
State(state): State<AppState>,
Path(name): Path<String>,
Json(payload): Json<InstantiatePromptRequest>,
) -> Result<Json<InstantiatePromptResponse>, HttpError> {
let messages = state
.server
.instantiate_prompt(&name, payload.arguments.as_ref())
.map_err(HttpError::from)?;
state.broadcast(&CommandResult::PromptInstantiation {
data: messages.clone(),
});
Ok(Json(InstantiatePromptResponse { messages }))
}
/// Response of `POST /streamable/session`.
#[derive(Serialize)]
struct CreateSessionResponse {
    session_id: Uuid,
}

/// `POST /streamable/session`: registers a new stream session and replies
/// `201 Created` with its id.
async fn create_stream_session(State(state): State<AppState>) -> impl IntoResponse {
    let body = CreateSessionResponse {
        session_id: state.stream_hub.create_session(),
    };
    (StatusCode::CREATED, Json(body))
}
/// `GET /streamable/session/:id`: attaches to a session and streams its
/// messages as newline-delimited JSON (`application/jsonl`).
///
/// Only one client can attach, because the session's receiver is taken on first use.
///
/// # Errors
/// * 400 if `id` is not a valid UUID.
/// * 404 if the session is unknown or its receiver was already taken.
/// * 500 if the response could not be built.
async fn stream_session(
    State(state): State<AppState>,
    Path(id): Path<String>,
) -> Result<Response<Body>, HttpError> {
    let session_id = parse_session_id(&id)?;
    let receiver = match state.stream_hub.take_receiver(&session_id) {
        Some(receiver) => receiver,
        None => return Err(HttpError::not_found("stream session not found")),
    };

    // One JSON document per line.
    let lines = ReceiverStream::new(receiver).map(|value| {
        let mut line =
            serde_json::to_vec(&value).expect("serializing serde_json::Value should be infallible");
        line.push(b'\n');
        Ok::<Bytes, Infallible>(Bytes::from(line))
    });

    let response = Response::builder()
        .status(StatusCode::OK)
        .header(header::CONTENT_TYPE, "application/jsonl")
        .body(Body::from_stream(lines));
    response.map_err(|err| HttpError::internal(err.to_string()))
}
/// `POST /streamable/session/:id/messages`: runs `command` for session `id`.
///
/// The result is pushed onto that session's stream, broadcast to SSE
/// subscribers, and echoed back. If the command requests shutdown, the session
/// is closed afterwards.
///
/// # Errors
/// * 400 if `id` is not a valid UUID.
/// * 404 if no such session exists. This is checked *before* the command runs,
///   so a bad id never triggers the command's side effects.
/// * Any error from executing the command or delivering its result, converted
///   through `HttpError::from`.
async fn stream_session_message(
    State(state): State<AppState>,
    Path(id): Path<String>,
    Json(command): Json<ServerCommand>,
) -> Result<Json<CommandResult>, HttpError> {
    let session_id = parse_session_id(&id)?;
    // Validate the session up front. Otherwise the command would execute and
    // its result would be discarded once delivery failed.
    if !state.stream_hub.sessions.contains_key(&session_id) {
        return Err(HttpError::not_found("stream session not found"));
    }
    let (result, shutdown) = execute_command(&state.server, command)
        .await
        .map_err(HttpError::from)?;
    state
        .send_to_session(session_id, &result)
        .await
        .map_err(HttpError::from)?;
    state.broadcast(&result);
    if shutdown {
        tracing::info!("stream session {session_id} requested shutdown");
        state.stream_hub.close(&session_id);
    }
    Ok(Json(result))
}
/// Parses a path segment as a session UUID, mapping any parse failure to a 400 error.
fn parse_session_id(raw: &str) -> Result<Uuid, HttpError> {
    match Uuid::parse_str(raw) {
        Ok(session_id) => Ok(session_id),
        Err(_) => Err(HttpError::bad_request("invalid stream session id")),
    }
}
/// `GET /sse`: subscribes to the broadcast hub and relays each event as an SSE
/// message.
///
/// Each relayed message uses the event's `kind` as its `event:` name and the
/// JSON-encoded payload as its `data:`. A `ping` keep-alive is sent every 15 seconds.
async fn sse_stream(
    State(state): State<AppState>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    let stream = BroadcastStream::new(state.hub.subscribe()).filter_map(|received| {
        let event = match received {
            Ok(event) => event,
            Err(err) => {
                // This subscriber fell behind the broadcast buffer and the missed
                // events are gone. Record it instead of dropping them silently,
                // then keep streaming.
                tracing::warn!("SSE subscriber skipped events: {err}");
                return None;
            }
        };
        match serde_json::to_string(&event.payload) {
            Ok(data) => Some(Ok(Event::default().event(event.kind).data(data))),
            Err(err) => {
                tracing::error!("failed to serialize SSE event: {}", err);
                None
            }
        }
    });
    Sse::new(stream).keep_alive(
        KeepAlive::new()
            .interval(std::time::Duration::from_secs(15))
            .text("ping"),
    )
}
/// `POST /messages`: executes a command without a session, broadcasts the
/// result to SSE subscribers, and returns it.
///
/// # Errors
/// Any error from executing the command, converted through `HttpError::from`.
async fn message_gateway(
    State(state): State<AppState>,
    Json(command): Json<ServerCommand>,
) -> Result<Json<CommandResult>, HttpError> {
    let (result, shutdown_requested) = execute_command(&state.server, command).await?;
    state.broadcast(&result);
    if shutdown_requested {
        // Only logged: `AppState` carries no shutdown channel to act on it.
        tracing::info!("received shutdown command via message gateway");
    }
    Ok(Json(result))
}
/// JSON body of every error response: `{"error": "<message>"}`.
#[derive(Debug, Serialize)]
struct ErrorBody {
    error: String,
}

/// Error returned by the HTTP handlers: a status code plus a human-readable message.
#[derive(Debug)]
pub struct HttpError {
    status: StatusCode,
    message: String,
}

impl HttpError {
    /// Builds an error with an arbitrary status.
    fn new(status: StatusCode, message: impl Into<String>) -> Self {
        Self {
            status,
            message: message.into(),
        }
    }

    /// `400 Bad Request`.
    fn bad_request(message: impl Into<String>) -> Self {
        Self::new(StatusCode::BAD_REQUEST, message)
    }

    /// `404 Not Found`.
    fn not_found(message: impl Into<String>) -> Self {
        Self::new(StatusCode::NOT_FOUND, message)
    }

    /// `500 Internal Server Error`.
    fn internal(message: impl Into<String>) -> Self {
        Self::new(StatusCode::INTERNAL_SERVER_ERROR, message)
    }
}
impl From<FastMcpError> for HttpError {
    /// Maps each domain error onto an HTTP status; the error's `Display` text
    /// becomes the response message.
    fn from(err: FastMcpError) -> Self {
        let status = match &err {
            FastMcpError::ToolNotFound(_)
            | FastMcpError::ResourceNotFound(_)
            | FastMcpError::PromptNotFound(_) => StatusCode::NOT_FOUND,
            FastMcpError::DuplicateTool(_)
            | FastMcpError::DuplicateResource(_)
            | FastMcpError::DuplicatePrompt(_) => StatusCode::CONFLICT,
            FastMcpError::InvalidInvocation(_) => StatusCode::UNPROCESSABLE_ENTITY,
            FastMcpError::HandlerError(_)
            | FastMcpError::Serialization(_)
            | FastMcpError::Io(_) => StatusCode::INTERNAL_SERVER_ERROR,
        };
        Self::new(status, err.to_string())
    }
}
impl IntoResponse for HttpError {
fn into_response(self) -> axum::response::Response {
let body = Json(ErrorBody {
error: self.message,
});
(self.status, body).into_response()
}
}
/// Builds the human-readable startup banner: server identity, bind address,
/// registered capabilities, and every endpoint URL. Hands the lines to
/// `emit_startup_lines`.
fn log_http_startup(server: &FastMcpServer, addr: &SocketAddr) {
    /// Comma-separated names, or the literal `none` when the list is empty.
    fn listing<S: std::borrow::Borrow<str>>(names: &[S]) -> String {
        if names.is_empty() {
            "none".into()
        } else {
            names.join(", ")
        }
    }

    let metadata = server.metadata();
    let base_url = format!("http://{}", addr);
    let tool_names: Vec<_> = server
        .list_tools()
        .into_iter()
        .map(|tool| tool.name)
        .collect();
    let resource_uris: Vec<_> = server
        .list_resources()
        .into_iter()
        .map(|resource| resource.uri)
        .collect();
    let prompt_names: Vec<_> = server
        .list_prompts()
        .into_iter()
        .map(|prompt| prompt.name)
        .collect();
    let instructions = metadata
        .instructions
        .as_deref()
        .unwrap_or("No instructions configured");

    let lines = vec![
        format!(
            "FastMCP '{}' (id: {}) listening on {}",
            metadata.name, metadata.id, base_url
        ),
        format!(" Host: {}", addr.ip()),
        format!(" Port: {}", addr.port()),
        format!(" Instructions: {}", instructions),
        format!(" Registered tools: {}", listing(&tool_names)),
        format!(" Registered resources: {}", listing(&resource_uris)),
        format!(" Registered prompts: {}", listing(&prompt_names)),
        format!(" HTTP base URL: {}", base_url),
        format!(" HTTP base URI: mcp+http://{}", addr),
        format!(" SSE endpoint: {}/sse", base_url),
        format!(" SSE URI: mcp+sse://{}/sse", addr),
        " Streamable HTTP endpoints:".to_string(),
        format!(
            " session: {}/streamable/session (URI: mcp+streamable-http://{}/streamable/session)",
            base_url, addr
        ),
        format!(
            " session/{{id}}: {}/streamable/session/{{id}} (URI: mcp+streamable-http://{}/streamable/session/{{id}})",
            base_url, addr
        ),
        format!(
            " session/{{id}}/messages: {}/streamable/session/{{id}}/messages (URI: mcp+streamable-http://{}/streamable/session/{{id}}/messages)",
            base_url, addr
        ),
        format!(
            " Message gateway: {}/messages (URI: mcp+http://{}/messages)",
            base_url, addr
        ),
    ];
    emit_startup_lines(lines);
}
/// Writes each banner line both to the `tracing` subscriber (info level) and to stdout.
fn emit_startup_lines(lines: Vec<String>) {
    lines.iter().for_each(|line| {
        tracing::info!("{}", line);
        println!("{line}");
    });
}