1use std::{borrow::Cow, collections::HashMap, sync::Arc, time::Duration};
2
3use futures::{Stream, StreamExt, future::BoxFuture, stream::BoxStream};
4use http::{HeaderName, HeaderValue};
5pub use sse_stream::Error as SseError;
6use sse_stream::Sse;
7use thiserror::Error;
8use tokio_util::sync::CancellationToken;
9use tracing::debug;
10
11use super::common::client_side_sse::{ExponentialBackoff, SseRetryPolicy, SseStreamReconnect};
12use crate::{
13 RoleClient,
14 model::{
15 ClientJsonRpcMessage, ClientNotification, InitializedNotification, ServerJsonRpcMessage,
16 ServerResult,
17 },
18 transport::{
19 common::client_side_sse::SseAutoReconnectStream,
20 worker::{Worker, WorkerQuitReason, WorkerSendRequest, WorkerTransport},
21 },
22};
23
/// Boxed, `'static` stream of raw SSE events (or SSE decoding errors); the same
/// type [`StreamableHttpClient::get_stream`] resolves to.
type BoxedSseStream = BoxStream<'static, Result<Sse, SseError>>;
25
/// Payload of [`StreamableHttpError::AuthRequired`].
#[derive(Debug)]
#[non_exhaustive]
pub struct AuthRequiredError {
    /// Raw `WWW-Authenticate` header value that accompanied the failure.
    pub www_authenticate_header: String,
}

impl AuthRequiredError {
    /// Wraps the raw `WWW-Authenticate` header value.
    pub fn new(www_authenticate_header: String) -> Self {
        Self { www_authenticate_header }
    }
}
40
/// Payload of [`StreamableHttpError::InsufficientScope`].
#[derive(Debug)]
#[non_exhaustive]
pub struct InsufficientScopeError {
    /// Raw `WWW-Authenticate` header value that accompanied the failure.
    pub www_authenticate_header: String,
    /// Scope reported as required, when one is known.
    pub required_scope: Option<String>,
}

impl InsufficientScopeError {
    /// Bundles the raw header value with the optional required scope.
    pub fn new(www_authenticate_header: String, required_scope: Option<String>) -> Self {
        Self {
            required_scope,
            www_authenticate_header,
        }
    }

    /// `true` exactly when a required scope is known, i.e. when
    /// [`get_required_scope`](Self::get_required_scope) returns `Some`.
    pub fn can_upgrade(&self) -> bool {
        self.get_required_scope().is_some()
    }

    /// Borrows the required scope, if any.
    pub fn get_required_scope(&self) -> Option<&str> {
        self.required_scope.as_deref()
    }
}
67
/// Errors produced by the streamable HTTP client transport.
///
/// `E` is the error type of the [`StreamableHttpClient`] implementation in use
/// (every client method returns `StreamableHttpError<Self::Error>`).
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum StreamableHttpError<E: std::error::Error + Send + Sync + 'static> {
    /// Error while reading or decoding an SSE event stream.
    #[error("SSE error: {0}")]
    Sse(#[from] SseError),
    /// Underlying I/O error.
    #[error("Io error: {0}")]
    Io(#[from] std::io::Error),
    /// Wraps an error of the client's own error type `E`.
    #[error("Client error: {0}")]
    Client(E),
    /// A stream ended earlier than expected.
    #[error("unexpected end of stream")]
    UnexpectedEndOfStream,
    /// The server's reply was not of the kind the caller expected; the payload
    /// is a short human-readable description (e.g. "empty sse stream").
    #[error("unexpected server response: {0}")]
    UnexpectedServerResponse(Cow<'static, str>),
    /// Presumably the response `Content-Type` was not one the client handles;
    /// `None` when the header was absent — TODO confirm with implementations.
    #[error("Unexpected content type: {0:?}")]
    UnexpectedContentType(Option<String>),
    /// The server offers no standalone SSE stream. The worker treats this as
    /// "no common stream" rather than as a failure.
    #[error("Server does not support SSE")]
    ServerDoesNotSupportSse,
    /// The server does not allow explicit session deletion. The worker logs
    /// this at info level during shutdown cleanup.
    #[error("Server does not support delete session")]
    ServerDoesNotSupportDeleteSession,
    /// A spawned SSE task failed to join (panicked or was aborted).
    #[error("Tokio join error: {0}")]
    TokioJoinError(#[from] tokio::task::JoinError),
    /// A JSON-RPC payload could not be deserialized.
    #[error("Deserialize error: {0}")]
    Deserialize(#[from] serde_json::Error),
    /// The worker's channel is closed; also used as the fatal reason when the
    /// initialize POST fails.
    #[error("Transport channel closed")]
    TransportChannelClosed,
    /// The initialize response carried no session id and stateless operation
    /// is disabled (`allow_stateless == false`).
    #[error("Missing session id in HTTP response")]
    MissingSessionIdInResponse,
    /// Authorization error (only with the `auth` feature).
    #[cfg(feature = "auth")]
    #[error("Auth error: {0}")]
    Auth(#[from] crate::transport::auth::AuthError),
    /// Authentication is required; the payload keeps the server's challenge.
    /// Note that the `Display` text does not include the header.
    #[error("Auth required")]
    AuthRequired(AuthRequiredError),
    /// The credentials lack a required scope; the payload keeps the challenge
    /// and the required scope, if known.
    #[error("Insufficient scope")]
    InsufficientScope(InsufficientScopeError),
    /// A user-supplied header name collides with a header the transport sets itself.
    #[error("Header name '{0}' is reserved and conflicts with default headers")]
    ReservedHeaderConflict(String),
    /// The server no longer recognises the session. The worker may
    /// transparently re-initialize (see `reinit_on_expired_session`).
    #[error("Session expired (HTTP 404)")]
    SessionExpired,
}
107
/// Protocol-level errors of the streamable HTTP transport.
///
/// NOTE(review): not referenced in this part of the file; the worker reports a
/// missing session id through `StreamableHttpError::MissingSessionIdInResponse`
/// instead. Confirm whether this type is still used elsewhere.
#[derive(Debug, Clone, Error)]
#[non_exhaustive]
pub enum StreamableHttpProtocolError {
    /// The server response carried no session id.
    #[error("Missing session id in response")]
    MissingSessionIdInResponse,
}
114
/// Result of [`StreamableHttpClient::post_message`].
///
/// The `Option<String>` in `Json` and `Sse` is the session id reported with the
/// response, if any.
// `Json` carries a full message inline while the others are small; the size
// difference is accepted rather than boxing the message.
#[allow(clippy::large_enum_variant)]
#[non_exhaustive]
pub enum StreamableHttpPostResponse {
    /// The server accepted the message and returned nothing to forward
    /// (presumably HTTP 202 — TODO confirm with client implementations).
    Accepted,
    /// A single JSON-RPC message plus the session id, if any.
    Json(ServerJsonRpcMessage, Option<String>),
    /// An SSE event stream carrying JSON-RPC messages plus the session id, if any.
    Sse(BoxedSseStream, Option<String>),
}
122
123impl std::fmt::Debug for StreamableHttpPostResponse {
124 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
125 match self {
126 Self::Accepted => write!(f, "Accepted"),
127 Self::Json(arg0, arg1) => f.debug_tuple("Json").field(arg0).field(arg1).finish(),
128 Self::Sse(_, arg1) => f.debug_tuple("Sse").field(arg1).finish(),
129 }
130 }
131}
132
133impl StreamableHttpPostResponse {
134 pub async fn expect_initialized<E>(
135 self,
136 ) -> Result<(ServerJsonRpcMessage, Option<String>), StreamableHttpError<E>>
137 where
138 E: std::error::Error + Send + Sync + 'static,
139 {
140 match self {
141 Self::Json(message, session_id) => Ok((message, session_id)),
142 Self::Sse(mut stream, session_id) => {
143 while let Some(event) = stream.next().await {
144 let event = event?;
145 let payload = event.data.unwrap_or_default();
146 if payload.trim().is_empty() {
147 continue;
148 }
149
150 let message: ServerJsonRpcMessage = serde_json::from_str(&payload)?;
151
152 if matches!(message, ServerJsonRpcMessage::Response(_)) {
153 return Ok((message, session_id));
154 }
155
156 debug!(
157 ?message,
158 "received message before initialize response; continuing to drain stream"
159 );
160 }
161
162 Err(StreamableHttpError::UnexpectedServerResponse(
163 "empty sse stream".into(),
164 ))
165 }
166 _ => Err(StreamableHttpError::UnexpectedServerResponse(
167 "expect initialized, accepted".into(),
168 )),
169 }
170 }
171
172 pub fn expect_json<E>(self) -> Result<ServerJsonRpcMessage, StreamableHttpError<E>>
173 where
174 E: std::error::Error + Send + Sync + 'static,
175 {
176 match self {
177 Self::Json(message, ..) => Ok(message),
178 got => Err(StreamableHttpError::UnexpectedServerResponse(
179 format!("expect json, got {got:?}").into(),
180 )),
181 }
182 }
183
184 pub fn expect_accepted_or_json<E>(self) -> Result<(), StreamableHttpError<E>>
185 where
186 E: std::error::Error + Send + Sync + 'static,
187 {
188 match self {
189 Self::Accepted => Ok(()),
190 Self::Json(..) => Ok(()),
192 got => Err(StreamableHttpError::UnexpectedServerResponse(
193 format!("expect accepted or json, got {got:?}").into(),
194 )),
195 }
196 }
197}
198
/// HTTP client abstraction driven by [`StreamableHttpClientWorker`].
///
/// Every method receives:
/// * the endpoint `uri`;
/// * an optional `auth_header` value;
/// * a map of extra headers to send (`custom_headers`). The worker fills this
///   from its configured custom headers plus `mcp-protocol-version` once the
///   initialize result is known.
pub trait StreamableHttpClient: Clone + Send + 'static {
    /// Error type of the implementation, carried inside [`StreamableHttpError`].
    type Error: std::error::Error + Send + Sync + 'static;
    /// Sends one client JSON-RPC message.
    ///
    /// `session_id` is `None` for the initialize request; the session id is then
    /// read from the returned [`StreamableHttpPostResponse`]. Afterwards it is the
    /// current session id, or `None` in stateless mode. The worker reacts to
    /// `Err(StreamableHttpError::SessionExpired)` by optionally re-initializing.
    fn post_message(
        &self,
        uri: Arc<str>,
        message: ClientJsonRpcMessage,
        session_id: Option<Arc<str>>,
        auth_header: Option<String>,
        custom_headers: HashMap<HeaderName, HeaderValue>,
    ) -> impl Future<Output = Result<StreamableHttpPostResponse, StreamableHttpError<Self::Error>>>
    + Send
    + '_;
    /// Asks the server to end `session_id`.
    ///
    /// The worker calls this once on shutdown, bounded by a timeout.
    /// `Err(StreamableHttpError::ServerDoesNotSupportDeleteSession)` is treated
    /// as benign.
    fn delete_session(
        &self,
        uri: Arc<str>,
        session_id: Arc<str>,
        auth_header: Option<String>,
        custom_headers: HashMap<HeaderName, HeaderValue>,
    ) -> impl Future<Output = Result<(), StreamableHttpError<Self::Error>>> + Send + '_;
    /// Opens an SSE event stream for `session_id`.
    ///
    /// The worker's first call passes `last_event_id = None`. The reconnect logic
    /// passes the last event id it saw, presumably so the server can resume.
    /// `Err(StreamableHttpError::ServerDoesNotSupportSse)` means no standalone
    /// stream is available and is not treated as a failure.
    fn get_stream(
        &self,
        uri: Arc<str>,
        session_id: Arc<str>,
        last_event_id: Option<String>,
        auth_header: Option<String>,
        custom_headers: HashMap<HeaderName, HeaderValue>,
    ) -> impl Future<
        Output = Result<
            BoxStream<'static, Result<Sse, SseError>>,
            StreamableHttpError<Self::Error>,
        >,
    > + Send
    + '_;
}
233
/// Retry settings: an optional attempt limit and a minimum duration.
///
/// NOTE(review): nothing in this part of the file reads `RetryConfig`; the
/// transport config uses `retry_config: Arc<dyn SseRetryPolicy>` instead.
/// Because it is `#[non_exhaustive]` and has no constructor, it cannot be built
/// outside this crate. Confirm whether it is still needed.
#[non_exhaustive]
pub struct RetryConfig {
    /// Maximum number of retries; presumably `None` means unlimited — TODO confirm.
    pub max_times: Option<usize>,
    /// Presumably the minimum delay between retries — TODO confirm.
    pub min_duration: Duration,
}
239
240struct StreamableHttpClientReconnect<C> {
241 pub client: C,
242 pub session_id: Arc<str>,
243 pub uri: Arc<str>,
244 pub auth_header: Option<String>,
245 pub custom_headers: HashMap<HeaderName, HeaderValue>,
246}
247
248impl<C: StreamableHttpClient> SseStreamReconnect for StreamableHttpClientReconnect<C> {
249 type Error = StreamableHttpError<C::Error>;
250 type Future = BoxFuture<'static, Result<BoxedSseStream, Self::Error>>;
251 fn retry_connection(&mut self, last_event_id: Option<&str>) -> Self::Future {
252 let client = self.client.clone();
253 let uri = self.uri.clone();
254 let session_id = self.session_id.clone();
255 let auth_header = self.auth_header.clone();
256 let custom_headers = self.custom_headers.clone();
257 let last_event_id = last_event_id.map(|s| s.to_owned());
258 Box::pin(async move {
259 client
260 .get_stream(uri, session_id, last_event_id, auth_header, custom_headers)
261 .await
262 })
263 }
264}
265
/// Everything the worker needs to issue `delete_session` after its main loop ends.
///
/// Captured after a successful initialize, and again after each transparent
/// re-initialization, so cleanup targets the most recent session.
struct SessionCleanupInfo<C> {
    // Clone of the worker's HTTP client.
    client: C,
    // Endpoint URI from the transport config.
    uri: Arc<str>,
    // Session to delete.
    session_id: Arc<str>,
    // Auth header value from the transport config.
    auth_header: Option<String>,
    // Custom headers plus `mcp-protocol-version`, as used for the session's requests.
    protocol_headers: HashMap<HeaderName, HeaderValue>,
}
274
/// [`Worker`] that runs a client session over a [`StreamableHttpClient`].
///
/// Construct it with [`StreamableHttpClientWorker::new`], or with `new_simple`
/// when `C: Default`.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct StreamableHttpClientWorker<C: StreamableHttpClient> {
    /// HTTP client used for every request the worker makes.
    pub client: C,
    /// Endpoint, headers, retry policy and session behaviour.
    pub config: StreamableHttpClientTransportConfig,
}
281
282impl<C: StreamableHttpClient + Default> StreamableHttpClientWorker<C> {
283 pub fn new_simple(url: impl Into<Arc<str>>) -> Self {
284 Self {
285 client: C::default(),
286 config: StreamableHttpClientTransportConfig {
287 uri: url.into(),
288 ..Default::default()
289 },
290 }
291 }
292}
293
294impl<C: StreamableHttpClient> StreamableHttpClientWorker<C> {
295 pub fn new(client: C, config: StreamableHttpClientTransportConfig) -> Self {
296 Self { client, config }
297 }
298}
299
300impl<C: StreamableHttpClient> StreamableHttpClientWorker<C> {
301 fn raw_sse_to_jsonrpc(
304 stream: BoxedSseStream,
305 ) -> impl Stream<Item = Result<ServerJsonRpcMessage, StreamableHttpError<C::Error>>> + Send + 'static
306 {
307 stream.filter_map(|event| async {
308 match event {
309 Err(e) => Some(Err(StreamableHttpError::Sse(e))),
310 Ok(sse) => {
311 let is_message =
312 matches!(sse.event.as_deref(), None | Some("") | Some("message"));
313 if !is_message {
314 return None;
315 }
316 let data = sse.data?;
317 if data.trim().is_empty() {
318 return None;
319 }
320 match serde_json::from_str::<ServerJsonRpcMessage>(&data) {
321 Ok(msg) => Some(Ok(msg)),
322 Err(e) => {
323 tracing::debug!("failed to deserialize server message: {e}");
324 None
325 }
326 }
327 }
328 }
329 })
330 }
331
332 async fn execute_sse_stream(
333 sse_stream: impl Stream<Item = Result<ServerJsonRpcMessage, StreamableHttpError<C::Error>>>
334 + Send
335 + 'static,
336 sse_worker_tx: tokio::sync::mpsc::Sender<ServerJsonRpcMessage>,
337 close_on_response: bool,
338 ct: CancellationToken,
339 ) -> Result<(), StreamableHttpError<C::Error>> {
340 let mut sse_stream = std::pin::pin!(sse_stream);
341 loop {
342 let message = tokio::select! {
343 event = sse_stream.next() => {
344 event
345 }
346 _ = ct.cancelled() => {
347 tracing::debug!("cancelled");
348 break;
349 }
350 };
351 let Some(message) = message.transpose()? else {
352 break;
353 };
354 let is_response = matches!(
355 message,
356 ServerJsonRpcMessage::Response(_) | ServerJsonRpcMessage::Error(_)
357 );
358 let yield_result = sse_worker_tx.send(message).await;
359 if yield_result.is_err() {
360 tracing::trace!("streamable http transport worker dropped, exiting");
361 break;
362 }
363 if close_on_response && is_response {
364 tracing::debug!("got response, draining sse stream for connection reuse");
365 let _ = tokio::time::timeout(std::time::Duration::from_millis(50), async {
368 while sse_stream.next().await.is_some() {}
369 })
370 .await;
371 break;
372 }
373 }
374 Ok(())
375 }
376
377 async fn perform_reinitialization(
387 client: C,
388 saved_init_request: ClientJsonRpcMessage,
389 uri: Arc<str>,
390 auth_header: Option<String>,
391 custom_headers: HashMap<HeaderName, HeaderValue>,
392 ) -> Result<(Option<Arc<str>>, HashMap<HeaderName, HeaderValue>), StreamableHttpError<C::Error>>
393 {
394 let (init_msg, new_session_id_str) = client
395 .post_message(
396 uri.clone(),
397 saved_init_request,
398 None,
399 auth_header.clone(),
400 custom_headers.clone(),
401 )
402 .await?
403 .expect_initialized::<C::Error>()
404 .await?;
405
406 let new_session_id: Option<Arc<str>> = new_session_id_str.map(|s| Arc::from(s.as_str()));
407
408 let mut new_protocol_headers = custom_headers;
411 if let ServerJsonRpcMessage::Response(response) = &init_msg {
412 if let ServerResult::InitializeResult(init_result) = &response.result {
413 if let Ok(hv) = HeaderValue::from_str(init_result.protocol_version.as_str()) {
414 new_protocol_headers
415 .insert(HeaderName::from_static("mcp-protocol-version"), hv);
416 }
417 }
418 }
419
420 let initialized_notification = ClientJsonRpcMessage::notification(
421 ClientNotification::InitializedNotification(InitializedNotification {
422 method: Default::default(),
423 extensions: Default::default(),
424 }),
425 );
426 client
427 .post_message(
428 uri,
429 initialized_notification,
430 new_session_id.clone(),
431 auth_header,
432 new_protocol_headers.clone(),
433 )
434 .await?
435 .expect_accepted_or_json::<C::Error>()?;
436
437 Ok((new_session_id, new_protocol_headers))
438 }
439}
440
441impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
442 type Role = RoleClient;
443 type Error = StreamableHttpError<C::Error>;
444 fn err_closed() -> Self::Error {
445 StreamableHttpError::TransportChannelClosed
446 }
447 fn err_join(e: tokio::task::JoinError) -> Self::Error {
448 StreamableHttpError::TokioJoinError(e)
449 }
450 fn config(&self) -> super::worker::WorkerConfig {
451 super::worker::WorkerConfig {
452 name: Some("StreamableHttpClientWorker".into()),
453 channel_buffer_capacity: self.config.channel_buffer_capacity,
454 }
455 }
456 async fn run(
457 self,
458 mut context: super::worker::WorkerContext<Self>,
459 ) -> Result<(), WorkerQuitReason<Self::Error>> {
460 let channel_buffer_capacity = self.config.channel_buffer_capacity;
461 let (sse_worker_tx, mut sse_worker_rx) =
462 tokio::sync::mpsc::channel::<ServerJsonRpcMessage>(channel_buffer_capacity);
463 let config = self.config.clone();
464 let transport_task_ct = context.cancellation_token.clone();
465 let _drop_guard = transport_task_ct.clone().drop_guard();
466 let WorkerSendRequest {
467 responder,
468 message: initialize_request,
469 } = context.recv_from_handler().await?;
470 let saved_init_request = initialize_request.clone();
471 let (message, session_id) = match self
472 .client
473 .post_message(
474 config.uri.clone(),
475 initialize_request,
476 None,
477 config.auth_header.clone(),
478 config.custom_headers.clone(),
479 )
480 .await
481 {
482 Ok(res) => {
483 let _ = responder.send(Ok(()));
484 res.expect_initialized::<C::Error>().await.map_err(
485 WorkerQuitReason::fatal_context("process initialize response"),
486 )?
487 }
488 Err(err) => {
489 let msg = format!("{:?}", err);
490 let _ = responder.send(Err(err));
491 return Err(WorkerQuitReason::fatal(
492 StreamableHttpError::TransportChannelClosed,
493 msg,
494 ));
495 }
496 };
497 let mut session_id: Option<Arc<str>> = if let Some(session_id) = session_id {
498 Some(session_id.into())
499 } else {
500 if !self.config.allow_stateless {
501 return Err(WorkerQuitReason::fatal(
502 StreamableHttpError::<C::Error>::MissingSessionIdInResponse,
503 "process initialize response",
504 ));
505 }
506 None
507 };
508 let mut protocol_headers = {
512 let mut headers = config.custom_headers.clone();
513 if let ServerJsonRpcMessage::Response(response) = &message {
514 if let ServerResult::InitializeResult(init_result) = &response.result {
515 if let Ok(hv) = HeaderValue::from_str(init_result.protocol_version.as_str()) {
516 headers.insert(HeaderName::from_static("mcp-protocol-version"), hv);
518 }
519 }
520 }
521 headers
522 };
523
524 let mut session_cleanup_info = session_id.as_ref().map(|sid| SessionCleanupInfo {
526 client: self.client.clone(),
527 uri: config.uri.clone(),
528 session_id: sid.clone(),
529 auth_header: config.auth_header.clone(),
530 protocol_headers: protocol_headers.clone(),
531 });
532
533 context.send_to_handler(message).await?;
534 let initialized_notification = context.recv_from_handler().await?;
535 self.client
537 .post_message(
538 config.uri.clone(),
539 initialized_notification.message,
540 session_id.clone(),
541 config.auth_header.clone(),
542 protocol_headers.clone(),
543 )
544 .await
545 .map_err(WorkerQuitReason::fatal_context(
546 "send initialized notification",
547 ))?
548 .expect_accepted_or_json::<C::Error>()
549 .map_err(WorkerQuitReason::fatal_context(
550 "process initialized notification response",
551 ))?;
552 let _ = initialized_notification.responder.send(Ok(()));
553 #[allow(clippy::large_enum_variant)]
554 enum Event<W: Worker, E: std::error::Error + Send + Sync + 'static> {
555 ClientMessage(WorkerSendRequest<W>),
556 ServerMessage(ServerJsonRpcMessage),
557 StreamResult(Result<(), StreamableHttpError<E>>),
558 }
559 let mut streams = tokio::task::JoinSet::new();
560 if let Some(session_id) = &session_id {
561 let client = self.client.clone();
562 let uri = config.uri.clone();
563 let session_id = session_id.clone();
564 let auth_header = config.auth_header.clone();
565 let retry_config = self.config.retry_config.clone();
566 let sse_worker_tx = sse_worker_tx.clone();
567 let transport_task_ct = transport_task_ct.clone();
568 let config_uri = config.uri.clone();
569 let config_auth_header = config.auth_header.clone();
570 let spawn_headers = protocol_headers.clone();
571
572 streams.spawn(async move {
573 match client
574 .get_stream(
575 uri.clone(),
576 session_id.clone(),
577 None,
578 auth_header.clone(),
579 spawn_headers.clone(),
580 )
581 .await
582 {
583 Ok(stream) => {
584 let sse_stream = SseAutoReconnectStream::new(
585 stream,
586 StreamableHttpClientReconnect {
587 client: client.clone(),
588 session_id: session_id.clone(),
589 uri: config_uri,
590 auth_header: config_auth_header,
591 custom_headers: spawn_headers,
592 },
593 retry_config,
594 );
595 Self::execute_sse_stream(
596 sse_stream,
597 sse_worker_tx,
598 false,
599 transport_task_ct.child_token(),
600 )
601 .await
602 }
603 Err(StreamableHttpError::ServerDoesNotSupportSse) => {
604 tracing::debug!("server doesn't support sse, skip common stream");
605 Ok(())
606 }
607 Err(e) => {
608 tracing::error!("fail to get common stream: {e}");
610 Err(e)
611 }
612 }
613 });
614 }
615 let loop_result: Result<(), WorkerQuitReason<Self::Error>> = 'main_loop: loop {
617 let event = tokio::select! {
618 _ = transport_task_ct.cancelled() => {
619 tracing::debug!("cancelled");
620 break 'main_loop Err(WorkerQuitReason::Cancelled);
621 }
622 message = context.recv_from_handler() => {
623 match message {
624 Ok(msg) => Event::ClientMessage(msg),
625 Err(e) => break 'main_loop Err(e),
626 }
627 },
628 message = sse_worker_rx.recv() => {
629 let Some(message) = message else {
630 tracing::trace!("transport dropped, exiting");
631 break 'main_loop Err(WorkerQuitReason::HandlerTerminated);
632 };
633 Event::ServerMessage(message)
634 },
635 terminated_stream = streams.join_next(), if !streams.is_empty() => {
636 match terminated_stream {
637 Some(result) => {
638 Event::StreamResult(result.map_err(StreamableHttpError::TokioJoinError).and_then(std::convert::identity))
639 }
640 None => {
641 continue
642 }
643 }
644 }
645 };
646 match event {
647 Event::ClientMessage(send_request) => {
648 let WorkerSendRequest { message, responder } = send_request;
649 let response = self
653 .client
654 .post_message(
655 config.uri.clone(),
656 message.clone(),
657 session_id.clone(),
658 config.auth_header.clone(),
659 protocol_headers.clone(),
660 )
661 .await;
662 let send_result = match response {
663 Err(StreamableHttpError::SessionExpired) => {
664 if !config.reinit_on_expired_session {
665 Err(StreamableHttpError::SessionExpired)
666 } else {
667 tracing::info!(
670 "session expired (HTTP 404), attempting transparent re-initialization"
671 );
672 match Self::perform_reinitialization(
673 self.client.clone(),
674 saved_init_request.clone(),
675 config.uri.clone(),
676 config.auth_header.clone(),
677 config.custom_headers.clone(),
678 )
679 .await
680 {
681 Ok((new_session_id, new_protocol_headers)) => {
682 streams.abort_all();
685
686 session_id = new_session_id;
687 protocol_headers = new_protocol_headers;
688 session_cleanup_info =
689 session_id.as_ref().map(|sid| SessionCleanupInfo {
690 client: self.client.clone(),
691 uri: config.uri.clone(),
692 session_id: sid.clone(),
693 auth_header: config.auth_header.clone(),
694 protocol_headers: protocol_headers.clone(),
695 });
696
697 if let Some(new_sid) = &session_id {
698 let client = self.client.clone();
699 let uri = config.uri.clone();
700 let new_sid = new_sid.clone();
701 let auth_header = config.auth_header.clone();
702 let retry_config = self.config.retry_config.clone();
703 let sse_tx = sse_worker_tx.clone();
704 let task_ct = transport_task_ct.clone();
705 let config_uri = config.uri.clone();
706 let config_auth = config.auth_header.clone();
707 let spawn_headers = protocol_headers.clone();
708 streams.spawn(async move {
709 match client
710 .get_stream(
711 uri,
712 new_sid.clone(),
713 None,
714 auth_header.clone(),
715 spawn_headers.clone(),
716 )
717 .await
718 {
719 Ok(stream) => {
720 let sse_stream = SseAutoReconnectStream::new(
721 stream,
722 StreamableHttpClientReconnect {
723 client: client.clone(),
724 session_id: new_sid,
725 uri: config_uri,
726 auth_header: config_auth,
727 custom_headers: spawn_headers,
728 },
729 retry_config,
730 );
731 Self::execute_sse_stream(
732 sse_stream,
733 sse_tx,
734 false,
735 task_ct.child_token(),
736 )
737 .await
738 }
739 Err(StreamableHttpError::ServerDoesNotSupportSse) => {
740 tracing::debug!(
741 "server doesn't support sse after re-init"
742 );
743 Ok(())
744 }
745 Err(e) => {
746 tracing::error!(
747 "fail to get common stream after re-init: {e}"
748 );
749 Err(e)
750 }
751 }
752 });
753 }
754
755 let retry_response = self
756 .client
757 .post_message(
758 config.uri.clone(),
759 message,
760 session_id.clone(),
761 config.auth_header.clone(),
762 protocol_headers.clone(),
763 )
764 .await;
765 match retry_response {
766 Err(e) => Err(e),
767 Ok(StreamableHttpPostResponse::Accepted) => {
768 tracing::trace!(
769 "client message accepted after re-init"
770 );
771 Ok(())
772 }
773 Ok(StreamableHttpPostResponse::Json(msg, ..)) => {
774 context.send_to_handler(msg).await?;
775 Ok(())
776 }
777 Ok(StreamableHttpPostResponse::Sse(stream, ..)) => {
778 streams.spawn(Self::execute_sse_stream(
779 Self::raw_sse_to_jsonrpc(stream),
780 sse_worker_tx.clone(),
781 true,
782 transport_task_ct.child_token(),
783 ));
784 tracing::trace!("got new sse stream after re-init");
785 Ok(())
786 }
787 }
788 }
789 Err(reinit_err) => Err(reinit_err),
790 }
791 } }
793 Err(e) => Err(e),
794 Ok(StreamableHttpPostResponse::Accepted) => {
795 tracing::trace!("client message accepted");
796 Ok(())
797 }
798 Ok(StreamableHttpPostResponse::Json(message, ..)) => {
799 context.send_to_handler(message).await?;
800 Ok(())
801 }
802 Ok(StreamableHttpPostResponse::Sse(stream, ..)) => {
803 streams.spawn(Self::execute_sse_stream(
804 Self::raw_sse_to_jsonrpc(stream),
805 sse_worker_tx.clone(),
806 true,
807 transport_task_ct.child_token(),
808 ));
809 tracing::trace!("got new sse stream");
810 Ok(())
811 }
812 };
813 let _ = responder.send(send_result);
814 }
815 Event::ServerMessage(json_rpc_message) => {
816 if let Err(e) = context.send_to_handler(json_rpc_message).await {
818 break 'main_loop Err(e);
819 }
820 }
821 Event::StreamResult(result) => {
822 if result.is_err() {
823 tracing::warn!(
824 "sse client event stream terminated with error: {:?}",
825 result
826 );
827 }
828 }
829 }
830 };
831
832 if let Some(cleanup) = session_cleanup_info {
835 const SESSION_CLEANUP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);
836 let cleanup_session_id = cleanup.session_id.clone();
837 match tokio::time::timeout(
838 SESSION_CLEANUP_TIMEOUT,
839 cleanup.client.delete_session(
840 cleanup.uri,
841 cleanup.session_id,
842 cleanup.auth_header,
843 cleanup.protocol_headers,
844 ),
845 )
846 .await
847 {
848 Ok(Ok(_)) => {
849 tracing::info!(
850 session_id = cleanup_session_id.as_ref(),
851 "delete session success"
852 )
853 }
854 Ok(Err(StreamableHttpError::ServerDoesNotSupportDeleteSession)) => {
855 tracing::info!(
856 session_id = cleanup_session_id.as_ref(),
857 "server doesn't support delete session"
858 )
859 }
860 Ok(Err(e)) => {
861 tracing::error!(
862 session_id = cleanup_session_id.as_ref(),
863 "fail to delete session: {e}"
864 );
865 }
866 Err(_elapsed) => {
867 tracing::warn!(
868 session_id = cleanup_session_id.as_ref(),
869 "session cleanup timed out after {:?}",
870 SESSION_CLEANUP_TIMEOUT
871 );
872 }
873 }
874 }
875
876 loop_result
877 }
878}
879
880pub type StreamableHttpClientTransport<C> = WorkerTransport<StreamableHttpClientWorker<C>>;
973
974impl<C: StreamableHttpClient> StreamableHttpClientTransport<C> {
975 pub fn with_client(client: C, config: StreamableHttpClientTransportConfig) -> Self {
1055 let worker = StreamableHttpClientWorker::new(client, config);
1056 WorkerTransport::spawn(worker)
1057 }
1058}
/// Configuration for [`StreamableHttpClientWorker`] / [`StreamableHttpClientTransport`].
///
/// Defaults (`Default` impl):
/// * `uri = "localhost"`
/// * exponential-backoff retries
/// * `channel_buffer_capacity = 16`
/// * `allow_stateless = true`
/// * no auth header and no custom headers
/// * `reinit_on_expired_session = true`
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct StreamableHttpClientTransportConfig {
    /// Endpoint URI passed to every [`StreamableHttpClient`] call.
    pub uri: Arc<str>,
    /// Reconnect policy for the session's standalone SSE stream.
    pub retry_config: Arc<dyn SseRetryPolicy>,
    /// Capacity of the worker's internal SSE message channel. It is also
    /// reported through `WorkerConfig::channel_buffer_capacity`.
    pub channel_buffer_capacity: usize,
    /// Governs an initialize response without a session id:
    /// * `false`: that response is a fatal error.
    /// * `true`: the worker continues without a session and opens no
    ///   standalone SSE stream.
    pub allow_stateless: bool,
    /// Value handed to the client as `auth_header` on every call. How it maps
    /// to HTTP headers is up to the client implementation.
    pub auth_header: Option<String>,
    /// Extra headers sent with every request. After initialize the worker adds
    /// `mcp-protocol-version` on top of these.
    pub custom_headers: HashMap<HeaderName, HeaderValue>,
    /// When `true`, a [`StreamableHttpError::SessionExpired`] reply triggers a
    /// replay of the initialize handshake, then one retry of the failed message.
    pub reinit_on_expired_session: bool,
}
1082
1083impl StreamableHttpClientTransportConfig {
1084 pub fn with_uri(uri: impl Into<Arc<str>>) -> Self {
1085 Self {
1086 uri: uri.into(),
1087 ..Default::default()
1088 }
1089 }
1090
1091 pub fn auth_header<T: Into<String>>(mut self, value: T) -> Self {
1097 self.auth_header = Some(value.into());
1099 self
1100 }
1101
1102 pub fn custom_headers(mut self, custom_headers: HashMap<HeaderName, HeaderValue>) -> Self {
1125 self.custom_headers = custom_headers;
1126 self
1127 }
1128
1129 pub fn reinit_on_expired_session(mut self, enable: bool) -> Self {
1138 self.reinit_on_expired_session = enable;
1139 self
1140 }
1141}
1142
1143impl Default for StreamableHttpClientTransportConfig {
1144 fn default() -> Self {
1145 Self {
1146 uri: "localhost".into(),
1147 retry_config: Arc::new(ExponentialBackoff::default()),
1148 channel_buffer_capacity: 16,
1149 allow_stateless: true,
1150 auth_header: None,
1151 custom_headers: HashMap::new(),
1152 reinit_on_expired_session: true,
1153 }
1154 }
1155}