agenterra_rmcp/transport/streamable_http_server/
tower.rs1use std::{convert::Infallible, fmt::Display, sync::Arc, time::Duration};
2
3use bytes::Bytes;
4use futures::{StreamExt, future::BoxFuture};
5use http::{Method, Request, Response, header::ALLOW};
6use http_body::Body;
7use http_body_util::{BodyExt, Full, combinators::UnsyncBoxBody};
8use tokio_stream::wrappers::ReceiverStream;
9
10use super::session::SessionManager;
11use crate::{
12 RoleServer,
13 model::{ClientJsonRpcMessage, ClientRequest, GetExtensions},
14 serve_server,
15 service::serve_directly,
16 transport::{
17 OneshotTransport, TransportAdapterIdentity,
18 common::{
19 http_header::{
20 EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_SESSION_ID, JSON_MIME_TYPE,
21 },
22 server_side_http::{
23 BoxResponse, ServerSseMessage, accepted_response, expect_json,
24 internal_error_response, sse_stream_response, unexpected_message_response,
25 },
26 },
27 },
28};
29
/// Tunables for [`StreamableHttpService`].
#[derive(Debug, Clone)]
pub struct StreamableHttpServerConfig {
    /// Keep-alive setting handed to every SSE response this service builds.
    ///
    /// NOTE(review): `None` presumably disables keep-alive pings; confirm against
    /// `sse_stream_response`.
    pub sse_keep_alive: Option<Duration>,
    /// When `true`, POST requests are routed through the session manager: a request
    /// without a session-id header must be an initialize request and opens a new
    /// session, and later requests are looked up by their session-id header.
    /// When `false`, every POST request is served by a fresh one-shot service
    /// instance and no session is created.
    pub stateful_mode: bool,
}

impl Default for StreamableHttpServerConfig {
    /// Stateful mode with a 15-second SSE keep-alive.
    fn default() -> Self {
        Self {
            sse_keep_alive: Some(Duration::from_secs(15)),
            stateful_mode: true,
        }
    }
}
46
47pub struct StreamableHttpService<S, M = super::session::local::LocalSessionManager> {
48 pub config: StreamableHttpServerConfig,
49 session_manager: Arc<M>,
50 service_factory: Arc<dyn Fn() -> Result<S, std::io::Error> + Send + Sync>,
51}
52
53impl<S, M> Clone for StreamableHttpService<S, M> {
54 fn clone(&self) -> Self {
55 Self {
56 config: self.config.clone(),
57 session_manager: self.session_manager.clone(),
58 service_factory: self.service_factory.clone(),
59 }
60 }
61}
62
63impl<RequestBody, S, M> tower_service::Service<Request<RequestBody>> for StreamableHttpService<S, M>
64where
65 RequestBody: Body + Send + 'static,
66 S: crate::Service<RoleServer>,
67 M: SessionManager,
68 RequestBody::Error: Display,
69 RequestBody::Data: Send + 'static,
70{
71 type Response = BoxResponse;
72 type Error = Infallible;
73 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
74 fn call(&mut self, req: http::Request<RequestBody>) -> Self::Future {
75 let service = self.clone();
76 Box::pin(async move {
77 let response = service.handle(req).await;
78 Ok(response)
79 })
80 }
81 fn poll_ready(
82 &mut self,
83 _cx: &mut std::task::Context<'_>,
84 ) -> std::task::Poll<Result<(), Self::Error>> {
85 std::task::Poll::Ready(Ok(()))
86 }
87}
88
89impl<S, M> StreamableHttpService<S, M>
90where
91 S: crate::Service<RoleServer> + Send + 'static,
92 M: SessionManager,
93{
94 pub fn new(
95 service_factory: impl Fn() -> Result<S, std::io::Error> + Send + Sync + 'static,
96 session_manager: Arc<M>,
97 config: StreamableHttpServerConfig,
98 ) -> Self {
99 Self {
100 config,
101 session_manager,
102 service_factory: Arc::new(service_factory),
103 }
104 }
105 fn get_service(&self) -> Result<S, std::io::Error> {
106 (self.service_factory)()
107 }
108 pub async fn handle<B>(&self, request: Request<B>) -> Response<UnsyncBoxBody<Bytes, Infallible>>
109 where
110 B: Body + Send + 'static,
111 B::Error: Display,
112 {
113 let method = request.method().clone();
114 let result = match method {
115 Method::GET => self.handle_get(request).await,
116 Method::POST => self.handle_post(request).await,
117 Method::DELETE => self.handle_delete(request).await,
118 _ => {
119 let response = Response::builder()
121 .status(http::StatusCode::METHOD_NOT_ALLOWED)
122 .header(ALLOW, "GET, POST, DELETE")
123 .body(Full::new(Bytes::from("Method Not Allowed")).boxed_unsync())
124 .expect("valid response");
125 return response;
126 }
127 };
128 match result {
129 Ok(response) => response,
130 Err(response) => response,
131 }
132 }
133 async fn handle_get<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
134 where
135 B: Body + Send + 'static,
136 B::Error: Display,
137 {
138 if !request
140 .headers()
141 .get(http::header::ACCEPT)
142 .and_then(|header| header.to_str().ok())
143 .is_some_and(|header| header.contains(EVENT_STREAM_MIME_TYPE))
144 {
145 return Ok(Response::builder()
146 .status(http::StatusCode::NOT_ACCEPTABLE)
147 .body(
148 Full::new(Bytes::from(
149 "Not Acceptable: Client must accept text/event-stream",
150 ))
151 .boxed_unsync(),
152 )
153 .expect("valid response"));
154 }
155 let session_id = request
157 .headers()
158 .get(HEADER_SESSION_ID)
159 .and_then(|v| v.to_str().ok())
160 .map(|s| s.to_owned().into());
161 let Some(session_id) = session_id else {
162 return Ok(Response::builder()
164 .status(http::StatusCode::UNAUTHORIZED)
165 .body(Full::new(Bytes::from("Unauthorized: Session ID is required")).boxed_unsync())
166 .expect("valid response"));
167 };
168 let has_session = self
170 .session_manager
171 .has_session(&session_id)
172 .await
173 .map_err(internal_error_response("check session"))?;
174 if !has_session {
175 return Ok(Response::builder()
177 .status(http::StatusCode::UNAUTHORIZED)
178 .body(Full::new(Bytes::from("Unauthorized: Session not found")).boxed_unsync())
179 .expect("valid response"));
180 }
181 let last_event_id = request
183 .headers()
184 .get(HEADER_LAST_EVENT_ID)
185 .and_then(|v| v.to_str().ok())
186 .map(|s| s.to_owned());
187 if let Some(last_event_id) = last_event_id {
188 let stream = self
190 .session_manager
191 .resume(&session_id, last_event_id)
192 .await
193 .map_err(internal_error_response("resume session"))?;
194 Ok(sse_stream_response(stream, self.config.sse_keep_alive))
195 } else {
196 let stream = self
198 .session_manager
199 .create_standalone_stream(&session_id)
200 .await
201 .map_err(internal_error_response("create standalone stream"))?;
202 Ok(sse_stream_response(stream, self.config.sse_keep_alive))
203 }
204 }
205
206 async fn handle_post<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
207 where
208 B: Body + Send + 'static,
209 B::Error: Display,
210 {
211 if !request
213 .headers()
214 .get(http::header::ACCEPT)
215 .and_then(|header| header.to_str().ok())
216 .is_some_and(|header| {
217 header.contains(JSON_MIME_TYPE) && header.contains(EVENT_STREAM_MIME_TYPE)
218 })
219 {
220 return Ok(Response::builder()
221 .status(http::StatusCode::NOT_ACCEPTABLE)
222 .body(Full::new(Bytes::from("Not Acceptable: Client must accept both application/json and text/event-stream")).boxed_unsync())
223 .expect("valid response"));
224 }
225
226 if !request
228 .headers()
229 .get(http::header::CONTENT_TYPE)
230 .and_then(|header| header.to_str().ok())
231 .is_some_and(|header| header.starts_with(JSON_MIME_TYPE))
232 {
233 return Ok(Response::builder()
234 .status(http::StatusCode::UNSUPPORTED_MEDIA_TYPE)
235 .body(
236 Full::new(Bytes::from(
237 "Unsupported Media Type: Content-Type must be application/json",
238 ))
239 .boxed_unsync(),
240 )
241 .expect("valid response"));
242 }
243
244 let (part, body) = request.into_parts();
246 let mut message = match expect_json(body).await {
247 Ok(message) => message,
248 Err(response) => return Ok(response),
249 };
250
251 if self.config.stateful_mode {
252 let session_id = part
254 .headers
255 .get(HEADER_SESSION_ID)
256 .and_then(|v| v.to_str().ok());
257 if let Some(session_id) = session_id {
258 let session_id = session_id.to_owned().into();
259 let has_session = self
260 .session_manager
261 .has_session(&session_id)
262 .await
263 .map_err(internal_error_response("check session"))?;
264 if !has_session {
265 return Ok(Response::builder()
267 .status(http::StatusCode::UNAUTHORIZED)
268 .body(
269 Full::new(Bytes::from("Unauthorized: Session not found"))
270 .boxed_unsync(),
271 )
272 .expect("valid response"));
273 }
274
275 match &mut message {
277 ClientJsonRpcMessage::Request(req) => {
278 req.request.extensions_mut().insert(part);
279 }
280 ClientJsonRpcMessage::Notification(not) => {
281 not.notification.extensions_mut().insert(part);
282 }
283 _ => {
284 }
286 }
287
288 match message {
289 ClientJsonRpcMessage::Request(_) => {
290 let stream = self
291 .session_manager
292 .create_stream(&session_id, message)
293 .await
294 .map_err(internal_error_response("get session"))?;
295 Ok(sse_stream_response(stream, self.config.sse_keep_alive))
296 }
297 ClientJsonRpcMessage::Notification(_)
298 | ClientJsonRpcMessage::Response(_)
299 | ClientJsonRpcMessage::Error(_) => {
300 self.session_manager
302 .accept_message(&session_id, message)
303 .await
304 .map_err(internal_error_response("accept message"))?;
305 Ok(accepted_response())
306 }
307 _ => Ok(Response::builder()
308 .status(http::StatusCode::NOT_IMPLEMENTED)
309 .body(
310 Full::new(Bytes::from("Batch requests are not supported yet"))
311 .boxed_unsync(),
312 )
313 .expect("valid response")),
314 }
315 } else {
316 let (session_id, transport) = self
317 .session_manager
318 .create_session()
319 .await
320 .map_err(internal_error_response("create session"))?;
321 if let ClientJsonRpcMessage::Request(req) = &mut message {
322 if !matches!(req.request, ClientRequest::InitializeRequest(_)) {
323 return Err(unexpected_message_response("initialize request"));
324 }
325 req.request.extensions_mut().insert(part);
327 } else {
328 return Err(unexpected_message_response("initialize request"));
329 }
330 let service = self
331 .get_service()
332 .map_err(internal_error_response("get service"))?;
333 tokio::spawn({
335 let session_manager = self.session_manager.clone();
336 let session_id = session_id.clone();
337 async move {
338 let service = serve_server::<S, M::Transport, _, TransportAdapterIdentity>(
339 service, transport,
340 )
341 .await;
342 match service {
343 Ok(service) => {
344 let _ = service.waiting().await;
346 }
347 Err(e) => {
348 tracing::error!("Failed to create service: {e}");
349 }
350 }
351 let _ = session_manager
352 .close_session(&session_id)
353 .await
354 .inspect_err(|e| {
355 tracing::error!("Failed to close session {session_id}: {e}");
356 });
357 }
358 });
359 let response = self
361 .session_manager
362 .initialize_session(&session_id, message)
363 .await
364 .map_err(internal_error_response("create stream"))?;
365 let mut response = sse_stream_response(
366 futures::stream::once({
367 async move {
368 ServerSseMessage {
369 event_id: None,
370 message: response.into(),
371 }
372 }
373 }),
374 self.config.sse_keep_alive,
375 );
376
377 response.headers_mut().insert(
378 HEADER_SESSION_ID,
379 session_id
380 .parse()
381 .map_err(internal_error_response("create session id header"))?,
382 );
383 Ok(response)
384 }
385 } else {
386 let service = self
387 .get_service()
388 .map_err(internal_error_response("get service"))?;
389 match message {
390 ClientJsonRpcMessage::Request(mut request) => {
391 request.request.extensions_mut().insert(part);
392 let (transport, receiver) =
393 OneshotTransport::<RoleServer>::new(ClientJsonRpcMessage::Request(request));
394 let service = serve_directly(service, transport, None);
395 tokio::spawn(async move {
396 let _ = service.waiting().await;
398 });
399 Ok(sse_stream_response(
400 ReceiverStream::new(receiver).map(|message| {
401 tracing::info!(?message);
402 ServerSseMessage {
403 event_id: None,
404 message: message.into(),
405 }
406 }),
407 self.config.sse_keep_alive,
408 ))
409 }
410 ClientJsonRpcMessage::Notification(_notification) => {
411 Ok(accepted_response())
413 }
414 ClientJsonRpcMessage::Response(_json_rpc_response) => Ok(accepted_response()),
415 ClientJsonRpcMessage::Error(_json_rpc_error) => Ok(accepted_response()),
416 _ => Ok(Response::builder()
417 .status(http::StatusCode::NOT_IMPLEMENTED)
418 .body(
419 Full::new(Bytes::from("Batch requests are not supported yet"))
420 .boxed_unsync(),
421 )
422 .expect("valid response")),
423 }
424 }
425 }
426
427 async fn handle_delete<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
428 where
429 B: Body + Send + 'static,
430 B::Error: Display,
431 {
432 let session_id = request
434 .headers()
435 .get(HEADER_SESSION_ID)
436 .and_then(|v| v.to_str().ok())
437 .map(|s| s.to_owned().into());
438 let Some(session_id) = session_id else {
439 return Ok(Response::builder()
441 .status(http::StatusCode::UNAUTHORIZED)
442 .body(Full::new(Bytes::from("Unauthorized: Session ID is required")).boxed_unsync())
443 .expect("valid response"));
444 };
445 self.session_manager
447 .close_session(&session_id)
448 .await
449 .map_err(internal_error_response("close session"))?;
450 Ok(accepted_response())
451 }
452}