1use std::{convert::Infallible, fmt::Display, sync::Arc, time::Duration};
2
3use bytes::Bytes;
4use futures::{StreamExt, future::BoxFuture};
5use http::{Method, Request, Response, header::ALLOW};
6use http_body::Body;
7use http_body_util::{BodyExt, Full, combinators::BoxBody};
8use tokio_stream::wrappers::ReceiverStream;
9use tokio_util::sync::CancellationToken;
10
11use super::session::SessionManager;
12use crate::{
13 RoleServer,
14 model::{ClientJsonRpcMessage, ClientRequest, GetExtensions},
15 serve_server,
16 service::serve_directly,
17 transport::{
18 OneshotTransport, TransportAdapterIdentity,
19 common::{
20 http_header::{
21 EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_SESSION_ID, JSON_MIME_TYPE,
22 },
23 server_side_http::{
24 BoxResponse, ServerSseMessage, accepted_response, expect_json,
25 internal_error_response, sse_stream_response, unexpected_message_response,
26 },
27 },
28 },
29};
30
/// Settings for [`StreamableHttpService`].
#[derive(Debug, Clone)]
pub struct StreamableHttpServerConfig {
    /// Keep-alive interval handed to `sse_stream_response` for every SSE response.
    /// `None` is passed through as-is (presumably disables keep-alive frames — confirm in
    /// `sse_stream_response`).
    pub sse_keep_alive: Option<Duration>,
    /// When `Some`, freshly created SSE streams (request streams, standalone GET streams and
    /// the initialize response stream) start with a priming event — event id `"0"`, no
    /// payload — that carries this value as the SSE `retry` hint. Streams resumed via the
    /// last-event-id header are not primed.
    pub sse_retry: Option<Duration>,
    /// `true`: sessions are tracked by the session manager and GET/DELETE are served.
    /// `false`: only POST (and OPTIONS) are served, and each JSON-RPC request is answered by
    /// a freshly built service over a one-shot transport.
    pub stateful_mode: bool,
    /// Parent token; every SSE response is given a child token derived from it.
    pub cancellation_token: CancellationToken,
}
46
47impl Default for StreamableHttpServerConfig {
48 fn default() -> Self {
49 Self {
50 sse_keep_alive: Some(Duration::from_secs(15)),
51 sse_retry: Some(Duration::from_secs(3)),
52 stateful_mode: true,
53 cancellation_token: CancellationToken::new(),
54 }
55 }
56}
57
/// Streamable-HTTP transport endpoint: accepts JSON-RPC messages over HTTP POST and answers
/// with SSE streams, optionally tracking sessions through a [`SessionManager`].
///
/// `S` is the service type produced by the factory; `M` is the session manager and defaults
/// to `LocalSessionManager`. Use it directly through `handle` or as a
/// `tower_service::Service`.
pub struct StreamableHttpService<S, M = super::session::local::LocalSessionManager> {
    /// Transport settings (keep-alive, retry hint, stateful/stateless mode, cancellation).
    pub config: StreamableHttpServerConfig,
    /// Shared session manager used to create, look up, stream to and close sessions.
    session_manager: Arc<M>,
    /// Factory invoked by `get_service` whenever a fresh `S` is needed.
    service_factory: Arc<dyn Fn() -> Result<S, std::io::Error> + Send + Sync>,
}
76
77impl<S, M> Clone for StreamableHttpService<S, M> {
78 fn clone(&self) -> Self {
79 Self {
80 config: self.config.clone(),
81 session_manager: self.session_manager.clone(),
82 service_factory: self.service_factory.clone(),
83 }
84 }
85}
86
87impl<RequestBody, S, M> tower_service::Service<Request<RequestBody>> for StreamableHttpService<S, M>
88where
89 RequestBody: Body + Send + 'static,
90 S: crate::Service<RoleServer>,
91 M: SessionManager,
92 RequestBody::Error: Display,
93 RequestBody::Data: Send + 'static,
94{
95 type Response = BoxResponse;
96 type Error = Infallible;
97 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
98 fn call(&mut self, req: http::Request<RequestBody>) -> Self::Future {
99 let service = self.clone();
100 Box::pin(async move {
101 let response = service.handle(req).await;
102 Ok(response)
103 })
104 }
105 fn poll_ready(
106 &mut self,
107 _cx: &mut std::task::Context<'_>,
108 ) -> std::task::Poll<Result<(), Self::Error>> {
109 std::task::Poll::Ready(Ok(()))
110 }
111}
112
113impl<S, M> StreamableHttpService<S, M>
114where
115 S: crate::Service<RoleServer> + Send + 'static,
116 M: SessionManager,
117{
118 pub fn new(
119 service_factory: impl Fn() -> Result<S, std::io::Error> + Send + Sync + 'static,
120 session_manager: Arc<M>,
121 config: StreamableHttpServerConfig,
122 ) -> Self {
123 Self {
124 config,
125 session_manager,
126 service_factory: Arc::new(service_factory),
127 }
128 }
129 fn get_service(&self) -> Result<S, std::io::Error> {
130 (self.service_factory)()
131 }
132 pub async fn handle<B>(&self, request: Request<B>) -> Response<BoxBody<Bytes, Infallible>>
133 where
134 B: Body + Send + 'static,
135 B::Error: Display,
136 {
137 let method = request.method().clone();
138 let allowed_methods = match self.config.stateful_mode {
139 true => "GET, POST, DELETE, OPTIONS",
140 false => "POST, OPTIONS",
141 };
142
143 if method == Method::OPTIONS {
145 let response = Response::builder()
146 .status(http::StatusCode::NO_CONTENT)
147 .header(ALLOW, allowed_methods)
148 .header("Access-Control-Allow-Origin", "*")
149 .header("Access-Control-Allow-Methods", allowed_methods)
150 .header(
151 "Access-Control-Allow-Headers",
152 "Content-Type, Accept, X-Session-Id, X-Last-Event-Id",
153 )
154 .header("Access-Control-Max-Age", "3600")
155 .body(Full::new(Bytes::new()).boxed())
156 .expect("valid response");
157 return response;
158 }
159
160 let result = match (method, self.config.stateful_mode) {
161 (Method::POST, _) => self.handle_post(request).await,
162 (Method::GET, true) => self.handle_get(request).await,
164 (Method::DELETE, true) => self.handle_delete(request).await,
165 _ => {
166 let response = Response::builder()
168 .status(http::StatusCode::METHOD_NOT_ALLOWED)
169 .header(ALLOW, allowed_methods)
170 .body(Full::new(Bytes::from("Method Not Allowed")).boxed())
171 .expect("valid response");
172 return response;
173 }
174 };
175 match result {
176 Ok(response) => response,
177 Err(response) => response,
178 }
179 }
180 async fn handle_get<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
181 where
182 B: Body + Send + 'static,
183 B::Error: Display,
184 {
185 if !request
187 .headers()
188 .get(http::header::ACCEPT)
189 .and_then(|header| header.to_str().ok())
190 .is_some_and(|header| header.contains(EVENT_STREAM_MIME_TYPE))
191 {
192 return Ok(Response::builder()
193 .status(http::StatusCode::NOT_ACCEPTABLE)
194 .body(
195 Full::new(Bytes::from(
196 "Not Acceptable: Client must accept text/event-stream",
197 ))
198 .boxed(),
199 )
200 .expect("valid response"));
201 }
202 let session_id = request
204 .headers()
205 .get(HEADER_SESSION_ID)
206 .and_then(|v| v.to_str().ok())
207 .map(|s| s.to_owned().into());
208 let Some(session_id) = session_id else {
209 return Ok(Response::builder()
211 .status(http::StatusCode::UNAUTHORIZED)
212 .body(Full::new(Bytes::from("Unauthorized: Session ID is required")).boxed())
213 .expect("valid response"));
214 };
215 let has_session = self
217 .session_manager
218 .has_session(&session_id)
219 .await
220 .map_err(internal_error_response("check session"))?;
221 if !has_session {
222 return Ok(Response::builder()
224 .status(http::StatusCode::UNAUTHORIZED)
225 .body(Full::new(Bytes::from("Unauthorized: Session not found")).boxed())
226 .expect("valid response"));
227 }
228 let last_event_id = request
230 .headers()
231 .get(HEADER_LAST_EVENT_ID)
232 .and_then(|v| v.to_str().ok())
233 .map(|s| s.to_owned());
234 if let Some(last_event_id) = last_event_id {
235 let stream = self
237 .session_manager
238 .resume(&session_id, last_event_id)
239 .await
240 .map_err(internal_error_response("resume session"))?;
241 Ok(sse_stream_response(
243 stream,
244 self.config.sse_keep_alive,
245 self.config.cancellation_token.child_token(),
246 ))
247 } else {
248 let stream = self
250 .session_manager
251 .create_standalone_stream(&session_id)
252 .await
253 .map_err(internal_error_response("create standalone stream"))?;
254 let stream = if let Some(retry) = self.config.sse_retry {
256 let priming = ServerSseMessage {
257 event_id: Some("0".into()),
258 message: None,
259 retry: Some(retry),
260 };
261 futures::stream::once(async move { priming })
262 .chain(stream)
263 .left_stream()
264 } else {
265 stream.right_stream()
266 };
267 Ok(sse_stream_response(
268 stream,
269 self.config.sse_keep_alive,
270 self.config.cancellation_token.child_token(),
271 ))
272 }
273 }
274
275 async fn handle_post<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
276 where
277 B: Body + Send + 'static,
278 B::Error: Display,
279 {
280 if !request
283 .headers()
284 .get(http::header::ACCEPT)
285 .and_then(|header| header.to_str().ok())
286 .is_some_and(|header| {
287 header.contains(EVENT_STREAM_MIME_TYPE)
288 || header.contains(JSON_MIME_TYPE)
289 || header.contains("*/*") })
291 {
292 return Ok(Response::builder()
293 .status(http::StatusCode::NOT_ACCEPTABLE)
294 .body(
295 Full::new(Bytes::from(
296 "Not Acceptable: Client must accept text/event-stream",
297 ))
298 .boxed(),
299 )
300 .expect("valid response"));
301 }
302
303 if !request
305 .headers()
306 .get(http::header::CONTENT_TYPE)
307 .and_then(|header| header.to_str().ok())
308 .is_some_and(|header| header.starts_with(JSON_MIME_TYPE))
309 {
310 return Ok(Response::builder()
311 .status(http::StatusCode::UNSUPPORTED_MEDIA_TYPE)
312 .body(
313 Full::new(Bytes::from(
314 "Unsupported Media Type: Content-Type must be application/json",
315 ))
316 .boxed(),
317 )
318 .expect("valid response"));
319 }
320
321 let (part, body) = request.into_parts();
323 let mut message = match expect_json(body).await {
324 Ok(message) => message,
325 Err(response) => return Ok(response),
326 };
327
328 if self.config.stateful_mode {
329 let session_id = part
331 .headers
332 .get(HEADER_SESSION_ID)
333 .and_then(|v| v.to_str().ok());
334 if let Some(session_id) = session_id {
335 let session_id = session_id.to_owned().into();
336 let has_session = self
337 .session_manager
338 .has_session(&session_id)
339 .await
340 .map_err(internal_error_response("check session"))?;
341 if !has_session {
342 return Ok(Response::builder()
344 .status(http::StatusCode::UNAUTHORIZED)
345 .body(Full::new(Bytes::from("Unauthorized: Session not found")).boxed())
346 .expect("valid response"));
347 }
348
349 match &mut message {
351 ClientJsonRpcMessage::Request(req) => {
352 req.request.extensions_mut().insert(part);
353 }
354 ClientJsonRpcMessage::Notification(not) => {
355 not.notification.extensions_mut().insert(part);
356 }
357 _ => {
358 }
360 }
361
362 match message {
363 ClientJsonRpcMessage::Request(_) => {
364 let stream = self
365 .session_manager
366 .create_stream(&session_id, message)
367 .await
368 .map_err(internal_error_response("get session"))?;
369 let stream = if let Some(retry) = self.config.sse_retry {
371 let priming = ServerSseMessage {
372 event_id: Some("0".into()),
373 message: None,
374 retry: Some(retry),
375 };
376 futures::stream::once(async move { priming })
377 .chain(stream)
378 .left_stream()
379 } else {
380 stream.right_stream()
381 };
382 Ok(sse_stream_response(
383 stream,
384 self.config.sse_keep_alive,
385 self.config.cancellation_token.child_token(),
386 ))
387 }
388 ClientJsonRpcMessage::Notification(_)
389 | ClientJsonRpcMessage::Response(_)
390 | ClientJsonRpcMessage::Error(_) => {
391 self.session_manager
393 .accept_message(&session_id, message)
394 .await
395 .map_err(internal_error_response("accept message"))?;
396 Ok(accepted_response())
397 }
398 }
399 } else {
400 let (session_id, transport) = self
401 .session_manager
402 .create_session()
403 .await
404 .map_err(internal_error_response("create session"))?;
405 if let ClientJsonRpcMessage::Request(req) = &mut message {
406 if !matches!(req.request, ClientRequest::InitializeRequest(_)) {
407 return Err(unexpected_message_response("initialize request"));
408 }
409 req.request.extensions_mut().insert(part);
411 } else {
412 return Err(unexpected_message_response("initialize request"));
413 }
414 let service = self
415 .get_service()
416 .map_err(internal_error_response("get service"))?;
417 tokio::spawn({
419 let session_manager = self.session_manager.clone();
420 let session_id = session_id.clone();
421 async move {
422 let service = serve_server::<S, M::Transport, _, TransportAdapterIdentity>(
423 service, transport,
424 )
425 .await;
426 match service {
427 Ok(service) => {
428 let _ = service.waiting().await;
430 }
431 Err(e) => {
432 tracing::error!("Failed to create service: {e}");
433 }
434 }
435 let _ = session_manager
436 .close_session(&session_id)
437 .await
438 .inspect_err(|e| {
439 tracing::error!("Failed to close session {session_id}: {e}");
440 });
441 }
442 });
443 let response = self
445 .session_manager
446 .initialize_session(&session_id, message)
447 .await
448 .map_err(internal_error_response("create stream"))?;
449 let stream = futures::stream::once(async move {
450 ServerSseMessage {
451 event_id: None,
452 message: Some(Arc::new(response)),
453 retry: None,
454 }
455 });
456 let stream = if let Some(retry) = self.config.sse_retry {
458 let priming = ServerSseMessage {
459 event_id: Some("0".into()),
460 message: None,
461 retry: Some(retry),
462 };
463 futures::stream::once(async move { priming })
464 .chain(stream)
465 .left_stream()
466 } else {
467 stream.right_stream()
468 };
469 let mut response = sse_stream_response(
470 stream,
471 self.config.sse_keep_alive,
472 self.config.cancellation_token.child_token(),
473 );
474
475 response.headers_mut().insert(
476 HEADER_SESSION_ID,
477 session_id
478 .parse()
479 .map_err(internal_error_response("create session id header"))?,
480 );
481 Ok(response)
482 }
483 } else {
484 let service = self
485 .get_service()
486 .map_err(internal_error_response("get service"))?;
487 match message {
488 ClientJsonRpcMessage::Request(mut request) => {
489 request.request.extensions_mut().insert(part);
490 let (transport, receiver) =
491 OneshotTransport::<RoleServer>::new(ClientJsonRpcMessage::Request(request));
492 let service = serve_directly(service, transport, None);
493 tokio::spawn(async move {
494 let _ = service.waiting().await;
496 });
497 let stream = ReceiverStream::new(receiver).map(|message| {
499 tracing::info!(?message);
500 ServerSseMessage {
501 event_id: None,
502 message: Some(Arc::new(message)),
503 retry: None,
504 }
505 });
506 Ok(sse_stream_response(
507 stream,
508 self.config.sse_keep_alive,
509 self.config.cancellation_token.child_token(),
510 ))
511 }
512 ClientJsonRpcMessage::Notification(_notification) => {
513 Ok(accepted_response())
515 }
516 ClientJsonRpcMessage::Response(_json_rpc_response) => Ok(accepted_response()),
517 ClientJsonRpcMessage::Error(_json_rpc_error) => Ok(accepted_response()),
518 }
519 }
520 }
521
522 async fn handle_delete<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
523 where
524 B: Body + Send + 'static,
525 B::Error: Display,
526 {
527 let session_id = request
529 .headers()
530 .get(HEADER_SESSION_ID)
531 .and_then(|v| v.to_str().ok())
532 .map(|s| s.to_owned().into());
533 let Some(session_id) = session_id else {
534 return Ok(Response::builder()
536 .status(http::StatusCode::UNAUTHORIZED)
537 .body(Full::new(Bytes::from("Unauthorized: Session ID is required")).boxed())
538 .expect("valid response"));
539 };
540 self.session_manager
542 .close_session(&session_id)
543 .await
544 .map_err(internal_error_response("close session"))?;
545 Ok(accepted_response())
546 }
547}