Skip to main content

agenterra_rmcp/transport/streamable_http_server/
tower.rs

1use std::{convert::Infallible, fmt::Display, sync::Arc, time::Duration};
2
3use bytes::Bytes;
4use futures::{StreamExt, future::BoxFuture};
5use http::{Method, Request, Response, header::ALLOW};
6use http_body::Body;
7use http_body_util::{BodyExt, Full, combinators::UnsyncBoxBody};
8use tokio_stream::wrappers::ReceiverStream;
9
10use super::session::SessionManager;
11use crate::{
12    RoleServer,
13    model::{ClientJsonRpcMessage, ClientRequest, GetExtensions},
14    serve_server,
15    service::serve_directly,
16    transport::{
17        OneshotTransport, TransportAdapterIdentity,
18        common::{
19            http_header::{
20                EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_SESSION_ID, JSON_MIME_TYPE,
21            },
22            server_side_http::{
23                BoxResponse, ServerSseMessage, accepted_response, expect_json,
24                internal_error_response, sse_stream_response, unexpected_message_response,
25            },
26        },
27    },
28};
29
/// Tunables for the streamable HTTP server.
#[derive(Debug, Clone)]
pub struct StreamableHttpServerConfig {
    /// Keep-alive ping interval handed to every SSE response this server builds.
    /// NOTE(review): presumably `None` disables pings — confirm in `sse_stream_response`.
    pub sse_keep_alive: Option<Duration>,
    /// When `true`, a POST without a session-id header must be an `initialize`
    /// request; it creates a session whose id is returned in the session-id header,
    /// and later requests must carry that id. When `false`, every POSTed request is
    /// served by a fresh service instance and no session is kept.
    pub stateful_mode: bool,
}

impl Default for StreamableHttpServerConfig {
    /// Stateful mode with a 15-second SSE keep-alive interval.
    fn default() -> Self {
        const KEEP_ALIVE_SECS: u64 = 15;
        Self {
            stateful_mode: true,
            sse_keep_alive: Some(Duration::from_secs(KEEP_ALIVE_SECS)),
        }
    }
}
46
47pub struct StreamableHttpService<S, M = super::session::local::LocalSessionManager> {
48    pub config: StreamableHttpServerConfig,
49    session_manager: Arc<M>,
50    service_factory: Arc<dyn Fn() -> Result<S, std::io::Error> + Send + Sync>,
51}
52
53impl<S, M> Clone for StreamableHttpService<S, M> {
54    fn clone(&self) -> Self {
55        Self {
56            config: self.config.clone(),
57            session_manager: self.session_manager.clone(),
58            service_factory: self.service_factory.clone(),
59        }
60    }
61}
62
63impl<RequestBody, S, M> tower_service::Service<Request<RequestBody>> for StreamableHttpService<S, M>
64where
65    RequestBody: Body + Send + 'static,
66    S: crate::Service<RoleServer>,
67    M: SessionManager,
68    RequestBody::Error: Display,
69    RequestBody::Data: Send + 'static,
70{
71    type Response = BoxResponse;
72    type Error = Infallible;
73    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
74    fn call(&mut self, req: http::Request<RequestBody>) -> Self::Future {
75        let service = self.clone();
76        Box::pin(async move {
77            let response = service.handle(req).await;
78            Ok(response)
79        })
80    }
81    fn poll_ready(
82        &mut self,
83        _cx: &mut std::task::Context<'_>,
84    ) -> std::task::Poll<Result<(), Self::Error>> {
85        std::task::Poll::Ready(Ok(()))
86    }
87}
88
89impl<S, M> StreamableHttpService<S, M>
90where
91    S: crate::Service<RoleServer> + Send + 'static,
92    M: SessionManager,
93{
94    pub fn new(
95        service_factory: impl Fn() -> Result<S, std::io::Error> + Send + Sync + 'static,
96        session_manager: Arc<M>,
97        config: StreamableHttpServerConfig,
98    ) -> Self {
99        Self {
100            config,
101            session_manager,
102            service_factory: Arc::new(service_factory),
103        }
104    }
105    fn get_service(&self) -> Result<S, std::io::Error> {
106        (self.service_factory)()
107    }
108    pub async fn handle<B>(&self, request: Request<B>) -> Response<UnsyncBoxBody<Bytes, Infallible>>
109    where
110        B: Body + Send + 'static,
111        B::Error: Display,
112    {
113        let method = request.method().clone();
114        let result = match method {
115            Method::GET => self.handle_get(request).await,
116            Method::POST => self.handle_post(request).await,
117            Method::DELETE => self.handle_delete(request).await,
118            _ => {
119                // Handle other methods or return an error
120                let response = Response::builder()
121                    .status(http::StatusCode::METHOD_NOT_ALLOWED)
122                    .header(ALLOW, "GET, POST, DELETE")
123                    .body(Full::new(Bytes::from("Method Not Allowed")).boxed_unsync())
124                    .expect("valid response");
125                return response;
126            }
127        };
128        match result {
129            Ok(response) => response,
130            Err(response) => response,
131        }
132    }
133    async fn handle_get<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
134    where
135        B: Body + Send + 'static,
136        B::Error: Display,
137    {
138        // check accept header
139        if !request
140            .headers()
141            .get(http::header::ACCEPT)
142            .and_then(|header| header.to_str().ok())
143            .is_some_and(|header| header.contains(EVENT_STREAM_MIME_TYPE))
144        {
145            return Ok(Response::builder()
146                .status(http::StatusCode::NOT_ACCEPTABLE)
147                .body(
148                    Full::new(Bytes::from(
149                        "Not Acceptable: Client must accept text/event-stream",
150                    ))
151                    .boxed_unsync(),
152                )
153                .expect("valid response"));
154        }
155        // check session id
156        let session_id = request
157            .headers()
158            .get(HEADER_SESSION_ID)
159            .and_then(|v| v.to_str().ok())
160            .map(|s| s.to_owned().into());
161        let Some(session_id) = session_id else {
162            // unauthorized
163            return Ok(Response::builder()
164                .status(http::StatusCode::UNAUTHORIZED)
165                .body(Full::new(Bytes::from("Unauthorized: Session ID is required")).boxed_unsync())
166                .expect("valid response"));
167        };
168        // check if session exists
169        let has_session = self
170            .session_manager
171            .has_session(&session_id)
172            .await
173            .map_err(internal_error_response("check session"))?;
174        if !has_session {
175            // unauthorized
176            return Ok(Response::builder()
177                .status(http::StatusCode::UNAUTHORIZED)
178                .body(Full::new(Bytes::from("Unauthorized: Session not found")).boxed_unsync())
179                .expect("valid response"));
180        }
181        // check if last event id is provided
182        let last_event_id = request
183            .headers()
184            .get(HEADER_LAST_EVENT_ID)
185            .and_then(|v| v.to_str().ok())
186            .map(|s| s.to_owned());
187        if let Some(last_event_id) = last_event_id {
188            // check if session has this event id
189            let stream = self
190                .session_manager
191                .resume(&session_id, last_event_id)
192                .await
193                .map_err(internal_error_response("resume session"))?;
194            Ok(sse_stream_response(stream, self.config.sse_keep_alive))
195        } else {
196            // create standalone stream
197            let stream = self
198                .session_manager
199                .create_standalone_stream(&session_id)
200                .await
201                .map_err(internal_error_response("create standalone stream"))?;
202            Ok(sse_stream_response(stream, self.config.sse_keep_alive))
203        }
204    }
205
206    async fn handle_post<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
207    where
208        B: Body + Send + 'static,
209        B::Error: Display,
210    {
211        // check accept header
212        if !request
213            .headers()
214            .get(http::header::ACCEPT)
215            .and_then(|header| header.to_str().ok())
216            .is_some_and(|header| {
217                header.contains(JSON_MIME_TYPE) && header.contains(EVENT_STREAM_MIME_TYPE)
218            })
219        {
220            return Ok(Response::builder()
221                .status(http::StatusCode::NOT_ACCEPTABLE)
222                .body(Full::new(Bytes::from("Not Acceptable: Client must accept both application/json and text/event-stream")).boxed_unsync())
223                .expect("valid response"));
224        }
225
226        // check content type
227        if !request
228            .headers()
229            .get(http::header::CONTENT_TYPE)
230            .and_then(|header| header.to_str().ok())
231            .is_some_and(|header| header.starts_with(JSON_MIME_TYPE))
232        {
233            return Ok(Response::builder()
234                .status(http::StatusCode::UNSUPPORTED_MEDIA_TYPE)
235                .body(
236                    Full::new(Bytes::from(
237                        "Unsupported Media Type: Content-Type must be application/json",
238                    ))
239                    .boxed_unsync(),
240                )
241                .expect("valid response"));
242        }
243
244        // json deserialize request body
245        let (part, body) = request.into_parts();
246        let mut message = match expect_json(body).await {
247            Ok(message) => message,
248            Err(response) => return Ok(response),
249        };
250
251        if self.config.stateful_mode {
252            // do we have a session id?
253            let session_id = part
254                .headers
255                .get(HEADER_SESSION_ID)
256                .and_then(|v| v.to_str().ok());
257            if let Some(session_id) = session_id {
258                let session_id = session_id.to_owned().into();
259                let has_session = self
260                    .session_manager
261                    .has_session(&session_id)
262                    .await
263                    .map_err(internal_error_response("check session"))?;
264                if !has_session {
265                    // unauthorized
266                    return Ok(Response::builder()
267                        .status(http::StatusCode::UNAUTHORIZED)
268                        .body(
269                            Full::new(Bytes::from("Unauthorized: Session not found"))
270                                .boxed_unsync(),
271                        )
272                        .expect("valid response"));
273                }
274
275                // inject request part to extensions
276                match &mut message {
277                    ClientJsonRpcMessage::Request(req) => {
278                        req.request.extensions_mut().insert(part);
279                    }
280                    ClientJsonRpcMessage::Notification(not) => {
281                        not.notification.extensions_mut().insert(part);
282                    }
283                    _ => {
284                        // skip
285                    }
286                }
287
288                match message {
289                    ClientJsonRpcMessage::Request(_) => {
290                        let stream = self
291                            .session_manager
292                            .create_stream(&session_id, message)
293                            .await
294                            .map_err(internal_error_response("get session"))?;
295                        Ok(sse_stream_response(stream, self.config.sse_keep_alive))
296                    }
297                    ClientJsonRpcMessage::Notification(_)
298                    | ClientJsonRpcMessage::Response(_)
299                    | ClientJsonRpcMessage::Error(_) => {
300                        // handle notification
301                        self.session_manager
302                            .accept_message(&session_id, message)
303                            .await
304                            .map_err(internal_error_response("accept message"))?;
305                        Ok(accepted_response())
306                    }
307                    _ => Ok(Response::builder()
308                        .status(http::StatusCode::NOT_IMPLEMENTED)
309                        .body(
310                            Full::new(Bytes::from("Batch requests are not supported yet"))
311                                .boxed_unsync(),
312                        )
313                        .expect("valid response")),
314                }
315            } else {
316                let (session_id, transport) = self
317                    .session_manager
318                    .create_session()
319                    .await
320                    .map_err(internal_error_response("create session"))?;
321                if let ClientJsonRpcMessage::Request(req) = &mut message {
322                    if !matches!(req.request, ClientRequest::InitializeRequest(_)) {
323                        return Err(unexpected_message_response("initialize request"));
324                    }
325                    // inject request part to extensions
326                    req.request.extensions_mut().insert(part);
327                } else {
328                    return Err(unexpected_message_response("initialize request"));
329                }
330                let service = self
331                    .get_service()
332                    .map_err(internal_error_response("get service"))?;
333                // spawn a task to serve the session
334                tokio::spawn({
335                    let session_manager = self.session_manager.clone();
336                    let session_id = session_id.clone();
337                    async move {
338                        let service = serve_server::<S, M::Transport, _, TransportAdapterIdentity>(
339                            service, transport,
340                        )
341                        .await;
342                        match service {
343                            Ok(service) => {
344                                // on service created
345                                let _ = service.waiting().await;
346                            }
347                            Err(e) => {
348                                tracing::error!("Failed to create service: {e}");
349                            }
350                        }
351                        let _ = session_manager
352                            .close_session(&session_id)
353                            .await
354                            .inspect_err(|e| {
355                                tracing::error!("Failed to close session {session_id}: {e}");
356                            });
357                    }
358                });
359                // get initialize response
360                let response = self
361                    .session_manager
362                    .initialize_session(&session_id, message)
363                    .await
364                    .map_err(internal_error_response("create stream"))?;
365                let mut response = sse_stream_response(
366                    futures::stream::once({
367                        async move {
368                            ServerSseMessage {
369                                event_id: None,
370                                message: response.into(),
371                            }
372                        }
373                    }),
374                    self.config.sse_keep_alive,
375                );
376
377                response.headers_mut().insert(
378                    HEADER_SESSION_ID,
379                    session_id
380                        .parse()
381                        .map_err(internal_error_response("create session id header"))?,
382                );
383                Ok(response)
384            }
385        } else {
386            let service = self
387                .get_service()
388                .map_err(internal_error_response("get service"))?;
389            match message {
390                ClientJsonRpcMessage::Request(mut request) => {
391                    request.request.extensions_mut().insert(part);
392                    let (transport, receiver) =
393                        OneshotTransport::<RoleServer>::new(ClientJsonRpcMessage::Request(request));
394                    let service = serve_directly(service, transport, None);
395                    tokio::spawn(async move {
396                        // on service created
397                        let _ = service.waiting().await;
398                    });
399                    Ok(sse_stream_response(
400                        ReceiverStream::new(receiver).map(|message| {
401                            tracing::info!(?message);
402                            ServerSseMessage {
403                                event_id: None,
404                                message: message.into(),
405                            }
406                        }),
407                        self.config.sse_keep_alive,
408                    ))
409                }
410                ClientJsonRpcMessage::Notification(_notification) => {
411                    // ignore
412                    Ok(accepted_response())
413                }
414                ClientJsonRpcMessage::Response(_json_rpc_response) => Ok(accepted_response()),
415                ClientJsonRpcMessage::Error(_json_rpc_error) => Ok(accepted_response()),
416                _ => Ok(Response::builder()
417                    .status(http::StatusCode::NOT_IMPLEMENTED)
418                    .body(
419                        Full::new(Bytes::from("Batch requests are not supported yet"))
420                            .boxed_unsync(),
421                    )
422                    .expect("valid response")),
423            }
424        }
425    }
426
427    async fn handle_delete<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
428    where
429        B: Body + Send + 'static,
430        B::Error: Display,
431    {
432        // check session id
433        let session_id = request
434            .headers()
435            .get(HEADER_SESSION_ID)
436            .and_then(|v| v.to_str().ok())
437            .map(|s| s.to_owned().into());
438        let Some(session_id) = session_id else {
439            // unauthorized
440            return Ok(Response::builder()
441                .status(http::StatusCode::UNAUTHORIZED)
442                .body(Full::new(Bytes::from("Unauthorized: Session ID is required")).boxed_unsync())
443                .expect("valid response"));
444        };
445        // close session
446        self.session_manager
447            .close_session(&session_id)
448            .await
449            .map_err(internal_error_response("close session"))?;
450        Ok(accepted_response())
451    }
452}