1use std::convert::Infallible;
2use std::net::SocketAddr;
3use std::sync::Arc;
4
5use axum::body::Body;
6use axum::extract::{Path, Query, State};
7use axum::http::{Response, StatusCode, header};
8use axum::response::IntoResponse;
9use axum::response::sse::{Event, KeepAlive, Sse};
10use axum::routing::{get, post};
11use axum::{Json, Router};
12use bytes::Bytes;
13use dashmap::DashMap;
14use futures::Stream;
15use parking_lot::Mutex;
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use tokio::sync::{broadcast, mpsc, oneshot};
19use tokio::task::JoinHandle;
20use tokio_stream::StreamExt;
21use tokio_stream::wrappers::{BroadcastStream, ReceiverStream};
22use uuid::Uuid;
23
24use crate::command::{CommandResult, ServerCommand, execute_command};
25use crate::error::FastMcpError;
26use crate::prompt::PromptMessage;
27use crate::resource::ResourceContent;
28use crate::server::FastMcpServer;
29use crate::tool::ToolResponse;
30
31#[derive(Clone)]
32struct AppState {
33 server: Arc<FastMcpServer>,
34 hub: SseHub,
35 stream_hub: StreamHub,
36}
37
38impl AppState {
39 fn broadcast(&self, result: &CommandResult) {
40 let payload = match serde_json::to_value(result) {
41 Ok(value) => value,
42 Err(err) => {
43 tracing::error!("failed to encode SSE payload: {}", err);
44 return;
45 }
46 };
47 self.hub.publish(result.event_kind(), payload);
48 }
49
50 async fn send_to_session(
51 &self,
52 session_id: Uuid,
53 result: &CommandResult,
54 ) -> Result<(), FastMcpError> {
55 self.stream_hub.send(&session_id, result).await
56 }
57}
58
59#[derive(Clone)]
60struct SseHub {
61 tx: broadcast::Sender<ServerEvent>,
62}
63
64impl SseHub {
65 fn new() -> Self {
66 let (tx, _) = broadcast::channel(256);
67 Self { tx }
68 }
69
70 fn publish(&self, kind: &str, payload: Value) {
71 let event = ServerEvent {
72 kind: kind.to_string(),
73 payload,
74 };
75 if let Err(err) = self.tx.send(event) {
76 tracing::debug!("no active SSE listeners to receive event: {err}");
77 }
78 }
79
80 fn subscribe(&self) -> broadcast::Receiver<ServerEvent> {
81 self.tx.subscribe()
82 }
83}
84
85#[derive(Clone, Debug)]
86struct ServerEvent {
87 kind: String,
88 payload: Value,
89}
90
91#[derive(Clone)]
92struct StreamHub {
93 sessions: Arc<DashMap<Uuid, Arc<StreamSession>>>,
94}
95
96struct StreamSession {
97 sender: mpsc::Sender<Value>,
98 receiver: Mutex<Option<mpsc::Receiver<Value>>>,
99}
100
101impl StreamHub {
102 fn new() -> Self {
103 Self {
104 sessions: Arc::new(DashMap::new()),
105 }
106 }
107
108 fn create_session(&self) -> Uuid {
109 let (sender, receiver) = mpsc::channel(64);
110 let entry = Arc::new(StreamSession {
111 sender,
112 receiver: Mutex::new(Some(receiver)),
113 });
114 let id = Uuid::new_v4();
115 self.sessions.insert(id, entry);
116 id
117 }
118
119 fn take_receiver(&self, id: &Uuid) -> Option<mpsc::Receiver<Value>> {
120 self.sessions
121 .get(id)
122 .and_then(|entry| entry.receiver.lock().take())
123 }
124
125 async fn send(&self, id: &Uuid, result: &CommandResult) -> Result<(), FastMcpError> {
126 let payload = serde_json::to_value(result)?;
127 match self.sessions.get(id) {
128 Some(entry) => {
129 if let Err(err) = entry.sender.send(payload).await {
130 tracing::debug!("stream session {id} closed {:?}", err);
131 self.sessions.remove(id);
132 return Err(FastMcpError::InvalidInvocation(format!(
133 "session {id} closed"
134 )));
135 }
136 Ok(())
137 }
138 None => Err(FastMcpError::InvalidInvocation(format!(
139 "session {id} not found"
140 ))),
141 }
142 }
143
144 fn close(&self, id: &Uuid) {
145 self.sessions.remove(id);
146 }
147}
148
149pub struct HttpServerHandle {
150 addr: SocketAddr,
151 shutdown: Option<oneshot::Sender<()>>,
152 task: JoinHandle<()>,
153}
154
155impl HttpServerHandle {
156 pub fn addr(&self) -> SocketAddr {
157 self.addr
158 }
159
160 pub async fn shutdown(mut self) {
161 if let Some(tx) = self.shutdown.take() {
162 let _ = tx.send(());
163 }
164 let _ = self.task.await;
165 }
166}
167
168pub async fn start_http(
169 server: Arc<FastMcpServer>,
170 addr: SocketAddr,
171) -> std::io::Result<HttpServerHandle> {
172 let state = AppState {
173 server: Arc::clone(&server),
174 hub: SseHub::new(),
175 stream_hub: StreamHub::new(),
176 };
177 let router = Router::new()
178 .route("/healthz", get(health))
179 .route("/metadata", get(metadata))
180 .route("/tools", get(list_tools))
181 .route("/tools/:name/call", post(call_tool))
182 .route("/resources", get(list_resources))
183 .route("/resource", get(read_resource))
184 .route("/prompts", get(list_prompts))
185 .route("/prompts/:name/instantiate", post(instantiate_prompt))
186 .route("/sse", get(sse_stream))
187 .route("/streamable/session", post(create_stream_session))
188 .route("/streamable/session/:id", get(stream_session))
189 .route(
190 "/streamable/session/:id/messages",
191 post(stream_session_message),
192 )
193 .route("/messages", post(message_gateway))
194 .with_state(state);
195
196 let listener = tokio::net::TcpListener::bind(addr).await?;
197 let local_addr = listener.local_addr()?;
198 log_http_startup(&server, &local_addr);
199 let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
200
201 let task = tokio::spawn(async move {
202 let server = axum::serve(listener, router);
203 let graceful = server.with_graceful_shutdown(async {
204 let _ = shutdown_rx.await;
205 });
206 if let Err(err) = graceful.await {
207 tracing::error!("HTTP server error: {}", err);
208 }
209 });
210
211 Ok(HttpServerHandle {
212 addr: local_addr,
213 shutdown: Some(shutdown_tx),
214 task,
215 })
216}
217
218async fn health() -> impl IntoResponse {
219 StatusCode::OK
220}
221
222async fn metadata(State(state): State<AppState>) -> impl IntoResponse {
223 Json(state.server.metadata())
224}
225
226async fn list_tools(State(state): State<AppState>) -> impl IntoResponse {
227 Json(state.server.list_tools())
228}
229
230#[derive(Deserialize)]
231struct CallToolRequest {
232 #[serde(default)]
233 arguments: Value,
234}
235
236async fn call_tool(
237 State(state): State<AppState>,
238 Path(name): Path<String>,
239 Json(payload): Json<CallToolRequest>,
240) -> Result<Json<ToolResponse>, HttpError> {
241 let response = state
242 .server
243 .call_tool(&name, payload.arguments)
244 .await
245 .map_err(HttpError::from)?;
246
247 state.broadcast(&CommandResult::ToolInvocation {
248 data: response.clone(),
249 });
250
251 Ok(Json(response))
252}
253
254async fn list_resources(State(state): State<AppState>) -> impl IntoResponse {
255 Json(state.server.list_resources())
256}
257
258#[derive(Deserialize)]
259struct ResourceQuery {
260 uri: String,
261}
262
263async fn read_resource(
264 State(state): State<AppState>,
265 Query(query): Query<ResourceQuery>,
266) -> Result<Json<ResourceContent>, HttpError> {
267 let content = state
268 .server
269 .read_resource(&query.uri)
270 .await
271 .map_err(HttpError::from)?;
272
273 state.broadcast(&CommandResult::Resource {
274 data: content.clone(),
275 });
276
277 Ok(Json(content))
278}
279
280async fn list_prompts(State(state): State<AppState>) -> impl IntoResponse {
281 Json(state.server.list_prompts())
282}
283
284#[derive(Deserialize)]
285struct InstantiatePromptRequest {
286 #[serde(default)]
287 arguments: Option<Value>,
288}
289
290#[derive(Serialize)]
291struct InstantiatePromptResponse {
292 messages: Vec<PromptMessage>,
293}
294
295async fn instantiate_prompt(
296 State(state): State<AppState>,
297 Path(name): Path<String>,
298 Json(payload): Json<InstantiatePromptRequest>,
299) -> Result<Json<InstantiatePromptResponse>, HttpError> {
300 let messages = state
301 .server
302 .instantiate_prompt(&name, payload.arguments.as_ref())
303 .map_err(HttpError::from)?;
304
305 state.broadcast(&CommandResult::PromptInstantiation {
306 data: messages.clone(),
307 });
308
309 Ok(Json(InstantiatePromptResponse { messages }))
310}
311
312#[derive(Serialize)]
313struct CreateSessionResponse {
314 session_id: Uuid,
315}
316
317async fn create_stream_session(State(state): State<AppState>) -> impl IntoResponse {
318 let session_id = state.stream_hub.create_session();
319 (
320 StatusCode::CREATED,
321 Json(CreateSessionResponse { session_id }),
322 )
323}
324
325async fn stream_session(
326 State(state): State<AppState>,
327 Path(id): Path<String>,
328) -> Result<Response<Body>, HttpError> {
329 let session_id = parse_session_id(&id)?;
330 let receiver = state
331 .stream_hub
332 .take_receiver(&session_id)
333 .ok_or_else(|| HttpError::not_found("stream session not found"))?;
334
335 let stream = ReceiverStream::new(receiver).map(|value| {
336 let mut bytes =
337 serde_json::to_vec(&value).expect("serializing serde_json::Value should be infallible");
338 bytes.push(b'\n');
339 Ok::<Bytes, Infallible>(Bytes::from(bytes))
340 });
341
342 Response::builder()
343 .status(StatusCode::OK)
344 .header(header::CONTENT_TYPE, "application/jsonl")
345 .body(Body::from_stream(stream))
346 .map_err(|err| HttpError::internal(err.to_string()))
347}
348
349async fn stream_session_message(
350 State(state): State<AppState>,
351 Path(id): Path<String>,
352 Json(command): Json<ServerCommand>,
353) -> Result<Json<CommandResult>, HttpError> {
354 let session_id = parse_session_id(&id)?;
355
356 let (result, shutdown) = execute_command(&state.server, command)
357 .await
358 .map_err(HttpError::from)?;
359
360 state
361 .send_to_session(session_id, &result)
362 .await
363 .map_err(HttpError::from)?;
364 state.broadcast(&result);
365
366 if shutdown {
367 tracing::info!("stream session {session_id} requested shutdown");
368 state.stream_hub.close(&session_id);
369 }
370
371 Ok(Json(result))
372}
373
374fn parse_session_id(raw: &str) -> Result<Uuid, HttpError> {
375 Uuid::parse_str(raw).map_err(|_| HttpError::bad_request("invalid stream session id"))
376}
377
378async fn sse_stream(
379 State(state): State<AppState>,
380) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
381 let stream = BroadcastStream::new(state.hub.subscribe()).filter_map(|result| {
382 let event = match result {
383 Ok(event) => event,
384 Err(_) => return None,
385 };
386
387 match serde_json::to_string(&event.payload) {
388 Ok(data) => {
389 let mut sse_event = Event::default().event(event.kind);
390 sse_event = sse_event.data(data);
391 Some(Ok(sse_event))
392 }
393 Err(err) => {
394 tracing::error!("failed to serialize SSE event: {}", err);
395 None
396 }
397 }
398 });
399
400 Sse::new(stream).keep_alive(
401 KeepAlive::new()
402 .interval(std::time::Duration::from_secs(15))
403 .text("ping"),
404 )
405}
406
407async fn message_gateway(
408 State(state): State<AppState>,
409 Json(command): Json<ServerCommand>,
410) -> Result<Json<CommandResult>, HttpError> {
411 let (result, shutdown) = execute_command(&state.server, command)
412 .await
413 .map_err(HttpError::from)?;
414
415 state.broadcast(&result);
416
417 if shutdown {
418 tracing::info!("received shutdown command via message gateway");
419 }
420
421 Ok(Json(result))
422}
423
424#[derive(Debug, Serialize)]
425struct ErrorBody {
426 error: String,
427}
428
429pub struct HttpError {
430 status: StatusCode,
431 message: String,
432}
433
434impl HttpError {
435 fn new(status: StatusCode, message: impl Into<String>) -> Self {
436 Self {
437 status,
438 message: message.into(),
439 }
440 }
441
442 fn bad_request(message: impl Into<String>) -> Self {
443 Self::new(StatusCode::BAD_REQUEST, message)
444 }
445
446 fn not_found(message: impl Into<String>) -> Self {
447 Self::new(StatusCode::NOT_FOUND, message)
448 }
449
450 fn internal(message: impl Into<String>) -> Self {
451 Self::new(StatusCode::INTERNAL_SERVER_ERROR, message)
452 }
453}
454
455impl From<FastMcpError> for HttpError {
456 fn from(err: FastMcpError) -> Self {
457 match err {
458 FastMcpError::ToolNotFound(_)
459 | FastMcpError::ResourceNotFound(_)
460 | FastMcpError::PromptNotFound(_) => Self {
461 status: StatusCode::NOT_FOUND,
462 message: err.to_string(),
463 },
464 FastMcpError::DuplicateTool(_)
465 | FastMcpError::DuplicateResource(_)
466 | FastMcpError::DuplicatePrompt(_) => Self {
467 status: StatusCode::CONFLICT,
468 message: err.to_string(),
469 },
470 FastMcpError::InvalidInvocation(_) => Self {
471 status: StatusCode::UNPROCESSABLE_ENTITY,
472 message: err.to_string(),
473 },
474 FastMcpError::HandlerError(_) | FastMcpError::Serialization(_) => Self {
475 status: StatusCode::INTERNAL_SERVER_ERROR,
476 message: err.to_string(),
477 },
478 FastMcpError::Io(_) => Self {
479 status: StatusCode::INTERNAL_SERVER_ERROR,
480 message: err.to_string(),
481 },
482 }
483 }
484}
485
486impl IntoResponse for HttpError {
487 fn into_response(self) -> axum::response::Response {
488 let body = Json(ErrorBody {
489 error: self.message,
490 });
491 (self.status, body).into_response()
492 }
493}
494fn log_http_startup(server: &FastMcpServer, addr: &SocketAddr) {
495 let metadata = server.metadata();
496 let base_url = format!("http://{}", addr);
497 let host = addr.ip();
498 let port = addr.port();
499 let tools = server
500 .list_tools()
501 .into_iter()
502 .map(|tool| tool.name)
503 .collect::<Vec<_>>();
504 let resources = server
505 .list_resources()
506 .into_iter()
507 .map(|resource| resource.uri)
508 .collect::<Vec<_>>();
509 let prompts = server
510 .list_prompts()
511 .into_iter()
512 .map(|prompt| prompt.name)
513 .collect::<Vec<_>>();
514
515 let mut lines = Vec::new();
516 lines.push(format!(
517 "FastMCP '{}' (id: {}) listening on {}",
518 metadata.name, metadata.id, base_url
519 ));
520 lines.push(format!(" Host: {}", host));
521 lines.push(format!(" Port: {}", port));
522 lines.push(format!(
523 " Instructions: {}",
524 metadata
525 .instructions
526 .as_deref()
527 .unwrap_or("No instructions configured")
528 ));
529 lines.push(format!(
530 " Registered tools: {}",
531 if tools.is_empty() {
532 "none".into()
533 } else {
534 tools.join(", ")
535 }
536 ));
537 lines.push(format!(
538 " Registered resources: {}",
539 if resources.is_empty() {
540 "none".into()
541 } else {
542 resources.join(", ")
543 }
544 ));
545 lines.push(format!(
546 " Registered prompts: {}",
547 if prompts.is_empty() {
548 "none".into()
549 } else {
550 prompts.join(", ")
551 }
552 ));
553 lines.push(format!(" HTTP base URL: {}", base_url));
554 lines.push(format!(" HTTP base URI: mcp+http://{}", addr));
555 lines.push(format!(" SSE endpoint: {}/sse", base_url));
556 lines.push(format!(" SSE URI: mcp+sse://{}/sse", addr));
557 lines.push(" Streamable HTTP endpoints:".to_string());
558 lines.push(format!(
559 " session: {}/streamable/session (URI: mcp+streamable-http://{}/streamable/session)",
560 base_url, addr
561 ));
562 lines.push(format!(
563 " session/{{id}}: {}/streamable/session/{{id}} (URI: mcp+streamable-http://{}/streamable/session/{{id}})",
564 base_url, addr
565 ));
566 lines.push(format!(
567 " session/{{id}}/messages: {}/streamable/session/{{id}}/messages (URI: mcp+streamable-http://{}/streamable/session/{{id}}/messages)",
568 base_url, addr
569 ));
570 lines.push(format!(
571 " Message gateway: {}/messages (URI: mcp+http://{}/messages)",
572 base_url, addr
573 ));
574
575 emit_startup_lines(lines);
576}
577
578fn emit_startup_lines(lines: Vec<String>) {
579 for line in lines {
580 tracing::info!("{}", line);
581 println!("{line}");
582 }
583}