rmcp-actix-web 0.6.1

actix-web transport implementations for RMCP (Rust Model Context Protocol)
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
//! Server-Sent Events (SSE) transport implementation for MCP.
//!
//! This module provides a unidirectional transport using the SSE protocol,
//! allowing servers to push real-time updates to clients over a standard HTTP connection.
//!
//! ## Architecture
//!
//! The SSE transport consists of two HTTP endpoints:
//! - **SSE endpoint** (`/sse` by default): Clients connect here to receive server-sent events
//! - **POST endpoint** (`/message` by default): Clients send JSON-RPC messages here
//!
//! ## Connection Flow
//!
//! 1. Client connects to the SSE endpoint; the server generates a unique session ID
//! 2. Server establishes the event stream and announces the POST endpoint (including the session ID) in an initial `endpoint` event
//! 3. Client sends requests to the POST endpoint, passing the session ID as the `sessionId` query parameter
//! 4. Server processes requests and sends responses via the SSE stream
//!
//! ## Features
//!
//! - Automatic keep-alive pings to maintain connections
//! - Session management for multiple concurrent clients
//! - Builder pattern for configuration
//! - Compatible with proxies and firewalls
//!
//! ## Example
//!
//! ```rust,no_run
//! use rmcp_actix_web::SseService;
//! use actix_web::{App, web};
//! use std::time::Duration;
//!
//! # struct MyService;
//! # use rmcp::{ServerHandler, model::ServerInfo};
//! # impl ServerHandler for MyService {
//! #     fn get_info(&self) -> ServerInfo { ServerInfo::default() }
//! # }
//! # impl MyService {
//! #     fn new() -> Self { Self }
//! # }
//! #[actix_web::main]
//! async fn main() -> std::io::Result<()> {
//!     let sse_service = SseService::builder()
//!         .service_factory(std::sync::Arc::new(|| Ok(MyService::new())))
//!         .sse_path("/events".to_string())
//!         .post_path("/messages".to_string())
//!         .sse_keep_alive(Duration::from_secs(30))
//!         .build();
//!     
//!     let app = App::new()
//!         .service(web::scope("/api").service(sse_service.scope()));
//!     
//!     Ok(())
//! }
//! ```

use std::{collections::HashMap, sync::Arc, time::Duration};

use actix_web::{
    HttpRequest, HttpResponse, Result, Scope,
    error::ErrorInternalServerError,
    http::header::{self, CACHE_CONTROL, CONTENT_TYPE},
    middleware,
    web::{self, Bytes, Data, Json, Query},
};
use futures::{Sink, SinkExt, Stream, StreamExt};
use tokio::sync::Mutex;
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::sync::PollSender;

use crate::transport::AuthorizationHeader;
use rmcp::{
    RoleServer,
    model::{ClientJsonRpcMessage, GetExtensions},
    service::{RxJsonRpcMessage, TxJsonRpcMessage, serve_directly_with_ct},
    transport::common::server_side_http::{DEFAULT_AUTO_PING_INTERVAL, SessionId, session_id},
};

/// `X-Accel-Buffering: no` tells nginx-style reverse proxies not to buffer the
/// response, so SSE events are flushed to the client immediately.
const HEADER_X_ACCEL_BUFFERING: &str = "X-Accel-Buffering";

/// Shared map from session ID to the sender that forwards that client's
/// POSTed JSON-RPC messages into its transport's receive stream.
type TxStore =
    Arc<tokio::sync::RwLock<HashMap<SessionId, tokio::sync::mpsc::Sender<ClientJsonRpcMessage>>>>;

/// Per-scope state shared with the HTTP handlers via `actix_web::web::Data`.
#[derive(Clone, Debug)]
struct AppData {
    /// Per-session senders; `post_event_handler` uses this to route client
    /// messages to the matching transport.
    txs: TxStore,
    /// Hands each newly created `SseServerTransport` to the background task
    /// spawned in `SseService::scope`.
    transport_tx: tokio::sync::mpsc::UnboundedSender<SseServerTransport>,
    /// Path of the POST endpoint, advertised to clients in the initial
    /// `endpoint` SSE event.
    post_path: Arc<str>,
    /// Path of the SSE endpoint; used to derive the mount prefix when building
    /// the advertised POST URL.
    sse_path: Arc<str>,
    /// Interval between `: ping` keep-alive comments on the SSE stream.
    sse_ping_interval: Duration,
}

// AppData::new is no longer used since we create AppData directly
// in the scope method with shared session storage

#[doc(hidden)]
/// Query parameters accepted by the POST message endpoint.
///
/// Because of `rename_all = "camelCase"`, the session ID is read from the
/// `sessionId` query parameter (e.g. `/message?sessionId=...`), matching the
/// URL emitted in the initial `endpoint` SSE event.
#[derive(Debug, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PostEventQuery {
    /// The session ID from the query string
    pub session_id: String,
}

async fn post_event_handler(
    app_data: Data<AppData>,
    query: Query<PostEventQuery>,
    req: HttpRequest,
    mut message: Json<ClientJsonRpcMessage>,
) -> Result<HttpResponse> {
    let session_id = &query.session_id;
    tracing::debug!(session_id, ?message, "new client message");

    // Extract and inject Authorization header if present (Bearer tokens only)
    if let ClientJsonRpcMessage::Request(request_msg) = &mut message.0
        && let Some(auth_value) = req.headers().get(header::AUTHORIZATION)
        && let Ok(auth_str) = auth_value.to_str()
        && auth_str.starts_with("Bearer ")
    {
        request_msg
            .request
            .extensions_mut()
            .insert(AuthorizationHeader(auth_str.to_string()));
        tracing::debug!("Forwarding Authorization header for MCP proxy scenario");
    }

    let tx = {
        let rg = app_data.txs.read().await;
        rg.get(session_id.as_str())
            .ok_or_else(|| actix_web::error::ErrorNotFound("Session not found"))?
            .clone()
    };

    if tx.send(message.0).await.is_err() {
        tracing::error!("send message error");
        return Err(actix_web::error::ErrorGone("Session closed"));
    }

    Ok(HttpResponse::Accepted().finish())
}

/// GET handler: establishes the SSE event stream for a new client session.
///
/// Flow: generate a session ID, register the client-to-server sender in the
/// shared session map, hand an `SseServerTransport` to the background service
/// task, then return a streaming `text/event-stream` response that emits an
/// initial `endpoint` event, `message` events, and periodic `: ping` comments.
async fn sse_handler(app_data: Data<AppData>, req: HttpRequest) -> Result<HttpResponse> {
    // Server-generated session ID for this connection (rmcp's `session_id()`).
    let session = session_id();
    tracing::info!(%session, "sse connection");

    // Two bounded channels: client->server (fed by the POST handler) and
    // server->client (drained by the SSE response stream below).
    let (from_client_tx, from_client_rx) = tokio::sync::mpsc::channel(64);
    let (to_client_tx, to_client_rx) = tokio::sync::mpsc::channel(64);
    let to_client_tx_clone = to_client_tx.clone();

    // Register this session so `post_event_handler` can route messages to it.
    app_data
        .txs
        .write()
        .await
        .insert(session.clone(), from_client_tx);

    let stream = ReceiverStream::new(from_client_rx);
    let sink = PollSender::new(to_client_tx);
    let transport = SseServerTransport {
        stream,
        sink,
        session_id: session.clone(),
        tx_store: app_data.txs.clone(),
    };

    // Hand the transport to the background task spawned in `SseService::scope`,
    // which creates and runs an MCP service instance for this session.
    let transport_send_result = app_data.transport_tx.send(transport);
    if transport_send_result.is_err() {
        tracing::warn!("send transport out error");
        return Err(ErrorInternalServerError(
            "Failed to send transport, server is closed",
        ));
    }

    let post_path = app_data.post_path.clone();
    let ping_interval = app_data.sse_ping_interval;
    let session_for_stream = session.clone();

    // Get the current path prefix from the request (remove the SSE endpoint part)
    // so the advertised POST URL stays correct when mounted under a scope prefix.
    let current_path = req.path();
    let sse_endpoint = &app_data.sse_path;
    let path_prefix = if current_path.ends_with(sse_endpoint.as_ref()) {
        &current_path[..current_path.len() - sse_endpoint.len()]
    } else {
        current_path
    };
    let relative_post_path = format!("{}{}", path_prefix, post_path);

    // Create SSE response stream
    let sse_stream = async_stream::stream! {
        // Send initial endpoint message telling the client where to POST,
        // with the session ID embedded as the `sessionId` query parameter.
        yield Ok::<_, actix_web::Error>(Bytes::from(format!(
            "event: endpoint\ndata: {}?sessionId={}\n\n", relative_post_path, session_for_stream
        )));

        // Set up ping interval; `Skip` avoids a burst of pings after a stall.
        let mut ping_interval = tokio::time::interval(ping_interval);
        ping_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

        let mut rx = ReceiverStream::new(to_client_rx);

        loop {
            tokio::select! {
                Some(message) = rx.next() => {
                    match serde_json::to_string(&message) {
                        Ok(json) => {
                            yield Ok(Bytes::from(format!("event: message\ndata: {json}\n\n")));
                        }
                        Err(e) => {
                            // Drop the unserializable message; keep the stream alive.
                            tracing::error!("Failed to serialize message: {}", e);
                        }
                    }
                }
                _ = ping_interval.tick() => {
                    // SSE comment line: keeps proxies/firewalls from idling out.
                    yield Ok(Bytes::from(": ping\n\n"));
                }
                else => break,
            }
        }
    };

    // Clean up on disconnect: `Sender::closed()` resolves once the receiver
    // half (owned by the SSE response stream) is dropped — e.g. client
    // disconnect — at which point the session is removed from the shared map.
    let app_data_clone = app_data.clone();
    let session_for_cleanup = session.clone();
    actix_rt::spawn(async move {
        to_client_tx_clone.closed().await;

        let mut txs = app_data_clone.txs.write().await;
        txs.remove(&session_for_cleanup);
        tracing::debug!(%session_for_cleanup, "Closed session and cleaned up resources");
    });

    Ok(HttpResponse::Ok()
        .insert_header((CONTENT_TYPE, "text/event-stream"))
        .insert_header((CACHE_CONTROL, "no-cache"))
        // Disable reverse-proxy buffering so events reach the client promptly.
        .insert_header((HEADER_X_ACCEL_BUFFERING, "no"))
        .streaming(sse_stream))
}

/// Transport handle for an individual SSE client connection.
///
/// Implements both `Sink` and `Stream` traits to provide bidirectional communication
/// for a single client session. This is created internally for each client connection
/// and passed to the MCP service.
pub struct SseServerTransport {
    /// Incoming client messages, forwarded from the POST handler.
    stream: ReceiverStream<RxJsonRpcMessage<RoleServer>>,
    /// Outgoing server messages, drained by the SSE response stream.
    sink: PollSender<TxJsonRpcMessage<RoleServer>>,
    /// Session this transport belongs to; used for cleanup in `poll_close`.
    session_id: SessionId,
    /// Shared session map so the session's sender can be deregistered on close.
    tx_store: TxStore,
}

impl Sink<TxJsonRpcMessage<RoleServer>> for SseServerTransport {
    type Error = std::io::Error;

    /// Delegates readiness to the underlying `PollSender`.
    fn poll_ready(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        std::pin::Pin::new(&mut self.sink)
            .poll_ready(cx)
            .map_err(std::io::Error::other)
    }

    /// Queues one outgoing message on the underlying `PollSender`.
    fn start_send(
        mut self: std::pin::Pin<&mut Self>,
        item: TxJsonRpcMessage<RoleServer>,
    ) -> Result<(), Self::Error> {
        std::pin::Pin::new(&mut self.sink)
            .start_send(item)
            .map_err(std::io::Error::other)
    }

    /// Delegates flushing to the underlying `PollSender`.
    fn poll_flush(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        std::pin::Pin::new(&mut self.sink)
            .poll_flush(cx)
            .map_err(std::io::Error::other)
    }

    /// Closes the sink; once closed, deregisters this session from the map.
    fn poll_close(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        let close_poll = std::pin::Pin::new(&mut self.sink)
            .poll_close(cx)
            .map_err(std::io::Error::other);
        if close_poll.is_ready() {
            // The RwLock cannot be awaited from this poll context, so the
            // removal runs on a spawned task.
            let session = self.session_id.clone();
            let store = self.tx_store.clone();
            tokio::spawn(async move {
                store.write().await.remove(&session);
            });
        }
        close_poll
    }
}

impl Stream for SseServerTransport {
    type Item = RxJsonRpcMessage<RoleServer>;

    /// Yields client-to-server messages forwarded by the POST handler.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        std::pin::Pin::new(&mut self.stream).poll_next(cx)
    }
}

/// Server-Sent Events transport service for MCP.
///
/// Provides a unidirectional streaming transport from server to client using the SSE protocol.
/// Clients connect to the SSE endpoint to receive events and send requests via a separate POST endpoint.
/// Uses a builder pattern for configuration.
///
/// # Architecture
///
/// The service manages two endpoints:
/// - SSE endpoint for server-to-client streaming
/// - POST endpoint for client-to-server messages
///
/// Each client connection is identified by a unique session ID that must be provided
/// in both the SSE connection and POST requests.
///
/// # Example
///
/// ```rust,no_run
/// use rmcp_actix_web::SseService;
/// use actix_web::{App, web};
/// use std::time::Duration;
///
/// # use rmcp::{ServerHandler, model::ServerInfo};
/// # struct MyService;
/// # impl ServerHandler for MyService {
/// #     fn get_info(&self) -> ServerInfo { ServerInfo::default() }
/// # }
/// # impl MyService { fn new() -> Self { Self } }
///
/// let sse_service = SseService::builder()
///     .service_factory(std::sync::Arc::new(|| Ok(MyService::new())))
///     .sse_path("/events".to_string())
///     .post_path("/messages".to_string())
///     .sse_keep_alive(Duration::from_secs(30))
///     .build();
///     
/// let app = App::new()
///     .service(web::scope("/api").service(sse_service.scope()));
/// ```
#[derive(Clone, bon::Builder)]
pub struct SseService<S> {
    /// The service factory function that creates new MCP service instances
    /// (invoked once per accepted SSE connection).
    service_factory: Arc<dyn Fn() -> Result<S, std::io::Error> + Send + Sync>,

    /// The path for the SSE endpoint
    #[builder(default = "/sse".to_string())]
    sse_path: String,

    /// The path for the POST message endpoint
    #[builder(default = "/message".to_string())]
    post_path: String,

    /// Optional keep-alive interval for SSE connections
    /// (falls back to `DEFAULT_AUTO_PING_INTERVAL` when unset)
    sse_keep_alive: Option<Duration>,

    /// Shared session storage across workers.
    /// Not settable via the builder; starts empty, and since `TxStore` is an
    /// `Arc`, clones of this service share the same session map.
    #[builder(skip = Default::default())]
    shared_txs: TxStore,
}

impl<S> SseService<S>
where
    S: rmcp::ServerHandler + Send + 'static,
{
    /// Creates a new scope configured with this service for framework-level composition.
    ///
    /// This method provides framework-level composition aligned with RMCP patterns,
    /// similar to how `StreamableHttpService::scope()` works. This allows mounting the
    /// SSE service at custom paths using actix-web's routing.
    ///
    /// # Returns
    ///
    /// Returns an actix-web `Scope` configured with the SSE routes
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use rmcp_actix_web::SseService;
    /// use actix_web::{App, HttpServer, web};
    /// use std::time::Duration;
    ///
    /// # use rmcp::{ServerHandler, model::ServerInfo};
    /// # struct MyService;
    /// # impl ServerHandler for MyService {
    /// #     fn get_info(&self) -> ServerInfo { ServerInfo::default() }
    /// # }
    /// # impl MyService { fn new() -> Self { Self } }
    /// let service = SseService::builder()
    ///     .service_factory(std::sync::Arc::new(|| Ok(MyService::new())))
    ///     .sse_path("/events".to_string())
    ///     .post_path("/messages".to_string())
    ///     .build();
    ///     
    /// // Mount into existing app at a custom path
    /// let app = App::new()
    ///     .service(web::scope("/api/v1/mcp").service(service.scope()));
    /// ```
    pub fn scope(
        self,
    ) -> Scope<
        impl actix_web::dev::ServiceFactory<
            actix_web::dev::ServiceRequest,
            Config = (),
            Response = actix_web::dev::ServiceResponse,
            Error = actix_web::Error,
            InitError = (),
        >,
    > {
        let transport_rx = Arc::new(Mutex::new(None));
        let (transport_tx, rx) = tokio::sync::mpsc::unbounded_channel();
        *transport_rx
            .try_lock()
            .expect("Failed to acquire transport_rx lock") = Some(rx);

        // Create AppData with shared session storage
        let app_data = AppData {
            txs: self.shared_txs.clone(),
            transport_tx,
            post_path: self.post_path.clone().into(),
            sse_path: self.sse_path.clone().into(),
            sse_ping_interval: self.sse_keep_alive.unwrap_or(DEFAULT_AUTO_PING_INTERVAL),
        };

        let sse_path = self.sse_path.clone();
        let post_path = self.post_path.clone();

        let app_data = Data::new(app_data);
        let service_factory = self.service_factory.clone();
        let transport_rx_clone = transport_rx.clone();

        // Start the service handler task
        actix_rt::spawn(async move {
            let mut transport_rx = transport_rx_clone.lock().await.take();
            if let Some(mut rx) = transport_rx.take() {
                while let Some(transport) = rx.recv().await {
                    let service = match service_factory() {
                        Ok(service) => service,
                        Err(e) => {
                            tracing::error!("Failed to create service: {}", e);
                            continue;
                        }
                    };

                    tokio::spawn(async move {
                        let server = serve_directly_with_ct(
                            service,
                            transport,
                            None,
                            tokio_util::sync::CancellationToken::new(),
                        );
                        if let Err(e) = server.waiting().await {
                            tracing::error!("Service error: {}", e);
                        }
                    });
                }
            }
        });

        web::scope("")
            .app_data(app_data.clone())
            .wrap(middleware::NormalizePath::trim())
            .route(&sse_path, web::get().to(sse_handler))
            .route(&post_path, web::post().to(post_event_handler))
    }
}