// fastmcp_rs/http/mod.rs

1use std::convert::Infallible;
2use std::net::SocketAddr;
3use std::sync::Arc;
4
5use axum::body::Body;
6use axum::extract::{Path, Query, State};
7use axum::http::{Response, StatusCode, header};
8use axum::response::IntoResponse;
9use axum::response::sse::{Event, KeepAlive, Sse};
10use axum::routing::{get, post};
11use axum::{Json, Router};
12use bytes::Bytes;
13use dashmap::DashMap;
14use futures::Stream;
15use parking_lot::Mutex;
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use tokio::sync::{broadcast, mpsc, oneshot};
19use tokio::task::JoinHandle;
20use tokio_stream::StreamExt;
21use tokio_stream::wrappers::{BroadcastStream, ReceiverStream};
22use uuid::Uuid;
23
24use crate::command::{CommandResult, ServerCommand, execute_command};
25use crate::error::FastMcpError;
26use crate::prompt::PromptMessage;
27use crate::resource::ResourceContent;
28use crate::server::FastMcpServer;
29use crate::tool::ToolResponse;
30
31#[derive(Clone)]
32struct AppState {
33    server: Arc<FastMcpServer>,
34    hub: SseHub,
35    stream_hub: StreamHub,
36}
37
38impl AppState {
39    fn broadcast(&self, result: &CommandResult) {
40        let payload = match serde_json::to_value(result) {
41            Ok(value) => value,
42            Err(err) => {
43                tracing::error!("failed to encode SSE payload: {}", err);
44                return;
45            }
46        };
47        self.hub.publish(result.event_kind(), payload);
48    }
49
50    async fn send_to_session(
51        &self,
52        session_id: Uuid,
53        result: &CommandResult,
54    ) -> Result<(), FastMcpError> {
55        self.stream_hub.send(&session_id, result).await
56    }
57}
58
59#[derive(Clone)]
60struct SseHub {
61    tx: broadcast::Sender<ServerEvent>,
62}
63
64impl SseHub {
65    fn new() -> Self {
66        let (tx, _) = broadcast::channel(256);
67        Self { tx }
68    }
69
70    fn publish(&self, kind: &str, payload: Value) {
71        let event = ServerEvent {
72            kind: kind.to_string(),
73            payload,
74        };
75        if let Err(err) = self.tx.send(event) {
76            tracing::debug!("no active SSE listeners to receive event: {err}");
77        }
78    }
79
80    fn subscribe(&self) -> broadcast::Receiver<ServerEvent> {
81        self.tx.subscribe()
82    }
83}
84
85#[derive(Clone, Debug)]
86struct ServerEvent {
87    kind: String,
88    payload: Value,
89}
90
91#[derive(Clone)]
92struct StreamHub {
93    sessions: Arc<DashMap<Uuid, Arc<StreamSession>>>,
94}
95
96struct StreamSession {
97    sender: mpsc::Sender<Value>,
98    receiver: Mutex<Option<mpsc::Receiver<Value>>>,
99}
100
101impl StreamHub {
102    fn new() -> Self {
103        Self {
104            sessions: Arc::new(DashMap::new()),
105        }
106    }
107
108    fn create_session(&self) -> Uuid {
109        let (sender, receiver) = mpsc::channel(64);
110        let entry = Arc::new(StreamSession {
111            sender,
112            receiver: Mutex::new(Some(receiver)),
113        });
114        let id = Uuid::new_v4();
115        self.sessions.insert(id, entry);
116        id
117    }
118
119    fn take_receiver(&self, id: &Uuid) -> Option<mpsc::Receiver<Value>> {
120        self.sessions
121            .get(id)
122            .and_then(|entry| entry.receiver.lock().take())
123    }
124
125    async fn send(&self, id: &Uuid, result: &CommandResult) -> Result<(), FastMcpError> {
126        let payload = serde_json::to_value(result)?;
127        match self.sessions.get(id) {
128            Some(entry) => {
129                if let Err(err) = entry.sender.send(payload).await {
130                    tracing::debug!("stream session {id} closed {:?}", err);
131                    self.sessions.remove(id);
132                    return Err(FastMcpError::InvalidInvocation(format!(
133                        "session {id} closed"
134                    )));
135                }
136                Ok(())
137            }
138            None => Err(FastMcpError::InvalidInvocation(format!(
139                "session {id} not found"
140            ))),
141        }
142    }
143
144    fn close(&self, id: &Uuid) {
145        self.sessions.remove(id);
146    }
147}
148
149pub struct HttpServerHandle {
150    addr: SocketAddr,
151    shutdown: Option<oneshot::Sender<()>>,
152    task: JoinHandle<()>,
153}
154
155impl HttpServerHandle {
156    pub fn addr(&self) -> SocketAddr {
157        self.addr
158    }
159
160    pub async fn shutdown(mut self) {
161        if let Some(tx) = self.shutdown.take() {
162            let _ = tx.send(());
163        }
164        let _ = self.task.await;
165    }
166}
167
168pub async fn start_http(
169    server: Arc<FastMcpServer>,
170    addr: SocketAddr,
171) -> std::io::Result<HttpServerHandle> {
172    let state = AppState {
173        server: Arc::clone(&server),
174        hub: SseHub::new(),
175        stream_hub: StreamHub::new(),
176    };
177    let router = Router::new()
178        .route("/healthz", get(health))
179        .route("/metadata", get(metadata))
180        .route("/tools", get(list_tools))
181        .route("/tools/:name/call", post(call_tool))
182        .route("/resources", get(list_resources))
183        .route("/resource", get(read_resource))
184        .route("/prompts", get(list_prompts))
185        .route("/prompts/:name/instantiate", post(instantiate_prompt))
186        .route("/sse", get(sse_stream))
187        .route("/streamable/session", post(create_stream_session))
188        .route("/streamable/session/:id", get(stream_session))
189        .route(
190            "/streamable/session/:id/messages",
191            post(stream_session_message),
192        )
193        .route("/messages", post(message_gateway))
194        .with_state(state);
195
196    let listener = tokio::net::TcpListener::bind(addr).await?;
197    let local_addr = listener.local_addr()?;
198    log_http_startup(&server, &local_addr);
199    let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
200
201    let task = tokio::spawn(async move {
202        let server = axum::serve(listener, router);
203        let graceful = server.with_graceful_shutdown(async {
204            let _ = shutdown_rx.await;
205        });
206        if let Err(err) = graceful.await {
207            tracing::error!("HTTP server error: {}", err);
208        }
209    });
210
211    Ok(HttpServerHandle {
212        addr: local_addr,
213        shutdown: Some(shutdown_tx),
214        task,
215    })
216}
217
218async fn health() -> impl IntoResponse {
219    StatusCode::OK
220}
221
222async fn metadata(State(state): State<AppState>) -> impl IntoResponse {
223    Json(state.server.metadata())
224}
225
226async fn list_tools(State(state): State<AppState>) -> impl IntoResponse {
227    Json(state.server.list_tools())
228}
229
230#[derive(Deserialize)]
231struct CallToolRequest {
232    #[serde(default)]
233    arguments: Value,
234}
235
236async fn call_tool(
237    State(state): State<AppState>,
238    Path(name): Path<String>,
239    Json(payload): Json<CallToolRequest>,
240) -> Result<Json<ToolResponse>, HttpError> {
241    let response = state
242        .server
243        .call_tool(&name, payload.arguments)
244        .await
245        .map_err(HttpError::from)?;
246
247    state.broadcast(&CommandResult::ToolInvocation {
248        data: response.clone(),
249    });
250
251    Ok(Json(response))
252}
253
254async fn list_resources(State(state): State<AppState>) -> impl IntoResponse {
255    Json(state.server.list_resources())
256}
257
258#[derive(Deserialize)]
259struct ResourceQuery {
260    uri: String,
261}
262
263async fn read_resource(
264    State(state): State<AppState>,
265    Query(query): Query<ResourceQuery>,
266) -> Result<Json<ResourceContent>, HttpError> {
267    let content = state
268        .server
269        .read_resource(&query.uri)
270        .await
271        .map_err(HttpError::from)?;
272
273    state.broadcast(&CommandResult::Resource {
274        data: content.clone(),
275    });
276
277    Ok(Json(content))
278}
279
280async fn list_prompts(State(state): State<AppState>) -> impl IntoResponse {
281    Json(state.server.list_prompts())
282}
283
284#[derive(Deserialize)]
285struct InstantiatePromptRequest {
286    #[serde(default)]
287    arguments: Option<Value>,
288}
289
290#[derive(Serialize)]
291struct InstantiatePromptResponse {
292    messages: Vec<PromptMessage>,
293}
294
295async fn instantiate_prompt(
296    State(state): State<AppState>,
297    Path(name): Path<String>,
298    Json(payload): Json<InstantiatePromptRequest>,
299) -> Result<Json<InstantiatePromptResponse>, HttpError> {
300    let messages = state
301        .server
302        .instantiate_prompt(&name, payload.arguments.as_ref())
303        .map_err(HttpError::from)?;
304
305    state.broadcast(&CommandResult::PromptInstantiation {
306        data: messages.clone(),
307    });
308
309    Ok(Json(InstantiatePromptResponse { messages }))
310}
311
312#[derive(Serialize)]
313struct CreateSessionResponse {
314    session_id: Uuid,
315}
316
317async fn create_stream_session(State(state): State<AppState>) -> impl IntoResponse {
318    let session_id = state.stream_hub.create_session();
319    (
320        StatusCode::CREATED,
321        Json(CreateSessionResponse { session_id }),
322    )
323}
324
325async fn stream_session(
326    State(state): State<AppState>,
327    Path(id): Path<String>,
328) -> Result<Response<Body>, HttpError> {
329    let session_id = parse_session_id(&id)?;
330    let receiver = state
331        .stream_hub
332        .take_receiver(&session_id)
333        .ok_or_else(|| HttpError::not_found("stream session not found"))?;
334
335    let stream = ReceiverStream::new(receiver).map(|value| {
336        let mut bytes =
337            serde_json::to_vec(&value).expect("serializing serde_json::Value should be infallible");
338        bytes.push(b'\n');
339        Ok::<Bytes, Infallible>(Bytes::from(bytes))
340    });
341
342    Response::builder()
343        .status(StatusCode::OK)
344        .header(header::CONTENT_TYPE, "application/jsonl")
345        .body(Body::from_stream(stream))
346        .map_err(|err| HttpError::internal(err.to_string()))
347}
348
349async fn stream_session_message(
350    State(state): State<AppState>,
351    Path(id): Path<String>,
352    Json(command): Json<ServerCommand>,
353) -> Result<Json<CommandResult>, HttpError> {
354    let session_id = parse_session_id(&id)?;
355
356    let (result, shutdown) = execute_command(&state.server, command)
357        .await
358        .map_err(HttpError::from)?;
359
360    state
361        .send_to_session(session_id, &result)
362        .await
363        .map_err(HttpError::from)?;
364    state.broadcast(&result);
365
366    if shutdown {
367        tracing::info!("stream session {session_id} requested shutdown");
368        state.stream_hub.close(&session_id);
369    }
370
371    Ok(Json(result))
372}
373
374fn parse_session_id(raw: &str) -> Result<Uuid, HttpError> {
375    Uuid::parse_str(raw).map_err(|_| HttpError::bad_request("invalid stream session id"))
376}
377
378async fn sse_stream(
379    State(state): State<AppState>,
380) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
381    let stream = BroadcastStream::new(state.hub.subscribe()).filter_map(|result| {
382        let event = match result {
383            Ok(event) => event,
384            Err(_) => return None,
385        };
386
387        match serde_json::to_string(&event.payload) {
388            Ok(data) => {
389                let mut sse_event = Event::default().event(event.kind);
390                sse_event = sse_event.data(data);
391                Some(Ok(sse_event))
392            }
393            Err(err) => {
394                tracing::error!("failed to serialize SSE event: {}", err);
395                None
396            }
397        }
398    });
399
400    Sse::new(stream).keep_alive(
401        KeepAlive::new()
402            .interval(std::time::Duration::from_secs(15))
403            .text("ping"),
404    )
405}
406
407async fn message_gateway(
408    State(state): State<AppState>,
409    Json(command): Json<ServerCommand>,
410) -> Result<Json<CommandResult>, HttpError> {
411    let (result, shutdown) = execute_command(&state.server, command)
412        .await
413        .map_err(HttpError::from)?;
414
415    state.broadcast(&result);
416
417    if shutdown {
418        tracing::info!("received shutdown command via message gateway");
419    }
420
421    Ok(Json(result))
422}
423
424#[derive(Debug, Serialize)]
425struct ErrorBody {
426    error: String,
427}
428
429pub struct HttpError {
430    status: StatusCode,
431    message: String,
432}
433
434impl HttpError {
435    fn new(status: StatusCode, message: impl Into<String>) -> Self {
436        Self {
437            status,
438            message: message.into(),
439        }
440    }
441
442    fn bad_request(message: impl Into<String>) -> Self {
443        Self::new(StatusCode::BAD_REQUEST, message)
444    }
445
446    fn not_found(message: impl Into<String>) -> Self {
447        Self::new(StatusCode::NOT_FOUND, message)
448    }
449
450    fn internal(message: impl Into<String>) -> Self {
451        Self::new(StatusCode::INTERNAL_SERVER_ERROR, message)
452    }
453}
454
455impl From<FastMcpError> for HttpError {
456    fn from(err: FastMcpError) -> Self {
457        match err {
458            FastMcpError::ToolNotFound(_)
459            | FastMcpError::ResourceNotFound(_)
460            | FastMcpError::PromptNotFound(_) => Self {
461                status: StatusCode::NOT_FOUND,
462                message: err.to_string(),
463            },
464            FastMcpError::DuplicateTool(_)
465            | FastMcpError::DuplicateResource(_)
466            | FastMcpError::DuplicatePrompt(_) => Self {
467                status: StatusCode::CONFLICT,
468                message: err.to_string(),
469            },
470            FastMcpError::InvalidInvocation(_) => Self {
471                status: StatusCode::UNPROCESSABLE_ENTITY,
472                message: err.to_string(),
473            },
474            FastMcpError::HandlerError(_) | FastMcpError::Serialization(_) => Self {
475                status: StatusCode::INTERNAL_SERVER_ERROR,
476                message: err.to_string(),
477            },
478            FastMcpError::Io(_) => Self {
479                status: StatusCode::INTERNAL_SERVER_ERROR,
480                message: err.to_string(),
481            },
482        }
483    }
484}
485
486impl IntoResponse for HttpError {
487    fn into_response(self) -> axum::response::Response {
488        let body = Json(ErrorBody {
489            error: self.message,
490        });
491        (self.status, body).into_response()
492    }
493}
494fn log_http_startup(server: &FastMcpServer, addr: &SocketAddr) {
495    let metadata = server.metadata();
496    let base_url = format!("http://{}", addr);
497    let host = addr.ip();
498    let port = addr.port();
499    let tools = server
500        .list_tools()
501        .into_iter()
502        .map(|tool| tool.name)
503        .collect::<Vec<_>>();
504    let resources = server
505        .list_resources()
506        .into_iter()
507        .map(|resource| resource.uri)
508        .collect::<Vec<_>>();
509    let prompts = server
510        .list_prompts()
511        .into_iter()
512        .map(|prompt| prompt.name)
513        .collect::<Vec<_>>();
514
515    let mut lines = Vec::new();
516    lines.push(format!(
517        "FastMCP '{}' (id: {}) listening on {}",
518        metadata.name, metadata.id, base_url
519    ));
520    lines.push(format!("  Host: {}", host));
521    lines.push(format!("  Port: {}", port));
522    lines.push(format!(
523        "  Instructions: {}",
524        metadata
525            .instructions
526            .as_deref()
527            .unwrap_or("No instructions configured")
528    ));
529    lines.push(format!(
530        "  Registered tools: {}",
531        if tools.is_empty() {
532            "none".into()
533        } else {
534            tools.join(", ")
535        }
536    ));
537    lines.push(format!(
538        "  Registered resources: {}",
539        if resources.is_empty() {
540            "none".into()
541        } else {
542            resources.join(", ")
543        }
544    ));
545    lines.push(format!(
546        "  Registered prompts: {}",
547        if prompts.is_empty() {
548            "none".into()
549        } else {
550            prompts.join(", ")
551        }
552    ));
553    lines.push(format!("  HTTP base URL: {}", base_url));
554    lines.push(format!("  HTTP base URI: mcp+http://{}", addr));
555    lines.push(format!("  SSE endpoint: {}/sse", base_url));
556    lines.push(format!("  SSE URI: mcp+sse://{}/sse", addr));
557    lines.push("  Streamable HTTP endpoints:".to_string());
558    lines.push(format!(
559        "    session: {}/streamable/session (URI: mcp+streamable-http://{}/streamable/session)",
560        base_url, addr
561    ));
562    lines.push(format!(
563        "    session/{{id}}: {}/streamable/session/{{id}} (URI: mcp+streamable-http://{}/streamable/session/{{id}})",
564        base_url, addr
565    ));
566    lines.push(format!(
567        "    session/{{id}}/messages: {}/streamable/session/{{id}}/messages (URI: mcp+streamable-http://{}/streamable/session/{{id}}/messages)",
568        base_url, addr
569    ));
570    lines.push(format!(
571        "  Message gateway: {}/messages (URI: mcp+http://{}/messages)",
572        base_url, addr
573    ));
574
575    emit_startup_lines(lines);
576}
577
578fn emit_startup_lines(lines: Vec<String>) {
579    for line in lines {
580        tracing::info!("{}", line);
581        println!("{line}");
582    }
583}