1use std::{borrow::Cow, collections::HashMap, sync::Arc, time::Duration};
2
3use futures::{Stream, StreamExt, future::BoxFuture, stream::BoxStream};
4use http::{HeaderName, HeaderValue};
5pub use sse_stream::Error as SseError;
6use sse_stream::Sse;
7use thiserror::Error;
8use tokio_util::sync::CancellationToken;
9use tracing::debug;
10
11use super::common::client_side_sse::{ExponentialBackoff, SseRetryPolicy, SseStreamReconnect};
12use crate::{
13 RoleClient,
14 model::{ClientJsonRpcMessage, ServerJsonRpcMessage, ServerResult},
15 transport::{
16 common::client_side_sse::SseAutoReconnectStream,
17 worker::{Worker, WorkerQuitReason, WorkerSendRequest, WorkerTransport},
18 },
19};
20
21type BoxedSseStream = BoxStream<'static, Result<Sse, SseError>>;
22
/// Payload carried by `StreamableHttpError::AuthRequired`.
#[derive(Debug)]
pub struct AuthRequiredError {
    /// Raw `WWW-Authenticate` header value, going by the field name.
    /// NOTE(review): it is populated outside the code visible here — confirm the source.
    pub www_authenticate_header: String,
}
27
/// Payload carried by `StreamableHttpError::InsufficientScope`.
#[derive(Debug)]
pub struct InsufficientScopeError {
    /// Raw `WWW-Authenticate` header value, going by the field name.
    /// NOTE(review): it is populated outside the code visible here — confirm the source.
    pub www_authenticate_header: String,
    /// The scope that is required, when one is known.
    pub required_scope: Option<String>,
}

impl InsufficientScopeError {
    /// Returns `true` exactly when `required_scope` is `Some`.
    pub fn can_upgrade(&self) -> bool {
        self.get_required_scope().is_some()
    }

    /// Borrows the required scope as `&str`; `None` when no scope is recorded.
    pub fn get_required_scope(&self) -> Option<&str> {
        self.required_scope.as_deref()
    }
}
45
46#[derive(Error, Debug)]
47#[non_exhaustive]
48pub enum StreamableHttpError<E: std::error::Error + Send + Sync + 'static> {
49 #[error("SSE error: {0}")]
50 Sse(#[from] SseError),
51 #[error("Io error: {0}")]
52 Io(#[from] std::io::Error),
53 #[error("Client error: {0}")]
54 Client(E),
55 #[error("unexpected end of stream")]
56 UnexpectedEndOfStream,
57 #[error("unexpected server response: {0}")]
58 UnexpectedServerResponse(Cow<'static, str>),
59 #[error("Unexpected content type: {0:?}")]
60 UnexpectedContentType(Option<String>),
61 #[error("Server does not support SSE")]
62 ServerDoesNotSupportSse,
63 #[error("Server does not support delete session")]
64 ServerDoesNotSupportDeleteSession,
65 #[error("Tokio join error: {0}")]
66 TokioJoinError(#[from] tokio::task::JoinError),
67 #[error("Deserialize error: {0}")]
68 Deserialize(#[from] serde_json::Error),
69 #[error("Transport channel closed")]
70 TransportChannelClosed,
71 #[error("Missing session id in HTTP response")]
72 MissingSessionIdInResponse,
73 #[cfg(feature = "auth")]
74 #[error("Auth error: {0}")]
75 Auth(#[from] crate::transport::auth::AuthError),
76 #[error("Auth required")]
77 AuthRequired(AuthRequiredError),
78 #[error("Insufficient scope")]
79 InsufficientScope(InsufficientScopeError),
80 #[error("Header name '{0}' is reserved and conflicts with default headers")]
81 ReservedHeaderConflict(String),
82}
83
84#[derive(Debug, Clone, Error)]
85#[non_exhaustive]
86pub enum StreamableHttpProtocolError {
87 #[error("Missing session id in response")]
88 MissingSessionIdInResponse,
89}
90
91#[allow(clippy::large_enum_variant)]
92#[non_exhaustive]
93pub enum StreamableHttpPostResponse {
94 Accepted,
95 Json(ServerJsonRpcMessage, Option<String>),
96 Sse(BoxedSseStream, Option<String>),
97}
98
99impl std::fmt::Debug for StreamableHttpPostResponse {
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 match self {
102 Self::Accepted => write!(f, "Accepted"),
103 Self::Json(arg0, arg1) => f.debug_tuple("Json").field(arg0).field(arg1).finish(),
104 Self::Sse(_, arg1) => f.debug_tuple("Sse").field(arg1).finish(),
105 }
106 }
107}
108
109impl StreamableHttpPostResponse {
110 pub async fn expect_initialized<E>(
111 self,
112 ) -> Result<(ServerJsonRpcMessage, Option<String>), StreamableHttpError<E>>
113 where
114 E: std::error::Error + Send + Sync + 'static,
115 {
116 match self {
117 Self::Json(message, session_id) => Ok((message, session_id)),
118 Self::Sse(mut stream, session_id) => {
119 while let Some(event) = stream.next().await {
120 let event = event?;
121 let payload = event.data.unwrap_or_default();
122 if payload.trim().is_empty() {
123 continue;
124 }
125
126 let message: ServerJsonRpcMessage = serde_json::from_str(&payload)?;
127
128 if matches!(message, ServerJsonRpcMessage::Response(_)) {
129 return Ok((message, session_id));
130 }
131
132 debug!(
133 ?message,
134 "received message before initialize response; continuing to drain stream"
135 );
136 }
137
138 Err(StreamableHttpError::UnexpectedServerResponse(
139 "empty sse stream".into(),
140 ))
141 }
142 _ => Err(StreamableHttpError::UnexpectedServerResponse(
143 "expect initialized, accepted".into(),
144 )),
145 }
146 }
147
148 pub fn expect_json<E>(self) -> Result<ServerJsonRpcMessage, StreamableHttpError<E>>
149 where
150 E: std::error::Error + Send + Sync + 'static,
151 {
152 match self {
153 Self::Json(message, ..) => Ok(message),
154 got => Err(StreamableHttpError::UnexpectedServerResponse(
155 format!("expect json, got {got:?}").into(),
156 )),
157 }
158 }
159
160 pub fn expect_accepted_or_json<E>(self) -> Result<(), StreamableHttpError<E>>
161 where
162 E: std::error::Error + Send + Sync + 'static,
163 {
164 match self {
165 Self::Accepted => Ok(()),
166 Self::Json(..) => Ok(()),
168 got => Err(StreamableHttpError::UnexpectedServerResponse(
169 format!("expect accepted or json, got {got:?}").into(),
170 )),
171 }
172 }
173}
174
175pub trait StreamableHttpClient: Clone + Send + 'static {
176 type Error: std::error::Error + Send + Sync + 'static;
177 fn post_message(
178 &self,
179 uri: Arc<str>,
180 message: ClientJsonRpcMessage,
181 session_id: Option<Arc<str>>,
182 auth_header: Option<String>,
183 custom_headers: HashMap<HeaderName, HeaderValue>,
184 ) -> impl Future<Output = Result<StreamableHttpPostResponse, StreamableHttpError<Self::Error>>>
185 + Send
186 + '_;
187 fn delete_session(
188 &self,
189 uri: Arc<str>,
190 session_id: Arc<str>,
191 auth_header: Option<String>,
192 custom_headers: HashMap<HeaderName, HeaderValue>,
193 ) -> impl Future<Output = Result<(), StreamableHttpError<Self::Error>>> + Send + '_;
194 fn get_stream(
195 &self,
196 uri: Arc<str>,
197 session_id: Arc<str>,
198 last_event_id: Option<String>,
199 auth_header: Option<String>,
200 custom_headers: HashMap<HeaderName, HeaderValue>,
201 ) -> impl Future<
202 Output = Result<
203 BoxStream<'static, Result<Sse, SseError>>,
204 StreamableHttpError<Self::Error>,
205 >,
206 > + Send
207 + '_;
208}
209
/// Plain retry settings.
///
/// NOTE(review): nothing in the visible part of this file reads this type; the field
/// descriptions below are inferred from their names — confirm against the consumer.
pub struct RetryConfig {
    /// Presumably the maximum number of retries, with `None` meaning no limit — confirm.
    pub max_times: Option<usize>,
    /// Presumably the minimum delay between retries — confirm.
    pub min_duration: Duration,
}
214
215struct StreamableHttpClientReconnect<C> {
216 pub client: C,
217 pub session_id: Arc<str>,
218 pub uri: Arc<str>,
219 pub auth_header: Option<String>,
220 pub custom_headers: HashMap<HeaderName, HeaderValue>,
221}
222
223impl<C: StreamableHttpClient> SseStreamReconnect for StreamableHttpClientReconnect<C> {
224 type Error = StreamableHttpError<C::Error>;
225 type Future = BoxFuture<'static, Result<BoxedSseStream, Self::Error>>;
226 fn retry_connection(&mut self, last_event_id: Option<&str>) -> Self::Future {
227 let client = self.client.clone();
228 let uri = self.uri.clone();
229 let session_id = self.session_id.clone();
230 let auth_header = self.auth_header.clone();
231 let custom_headers = self.custom_headers.clone();
232 let last_event_id = last_event_id.map(|s| s.to_owned());
233 Box::pin(async move {
234 client
235 .get_stream(uri, session_id, last_event_id, auth_header, custom_headers)
236 .await
237 })
238 }
239}
240
241struct SessionCleanupInfo<C> {
243 client: C,
244 uri: Arc<str>,
245 session_id: Arc<str>,
246 auth_header: Option<String>,
247 protocol_headers: HashMap<HeaderName, HeaderValue>,
248}
249
250#[derive(Debug, Clone, Default)]
251pub struct StreamableHttpClientWorker<C: StreamableHttpClient> {
252 pub client: C,
253 pub config: StreamableHttpClientTransportConfig,
254}
255
256impl<C: StreamableHttpClient + Default> StreamableHttpClientWorker<C> {
257 pub fn new_simple(url: impl Into<Arc<str>>) -> Self {
258 Self {
259 client: C::default(),
260 config: StreamableHttpClientTransportConfig {
261 uri: url.into(),
262 ..Default::default()
263 },
264 }
265 }
266}
267
268impl<C: StreamableHttpClient> StreamableHttpClientWorker<C> {
269 pub fn new(client: C, config: StreamableHttpClientTransportConfig) -> Self {
270 Self { client, config }
271 }
272}
273
274impl<C: StreamableHttpClient> StreamableHttpClientWorker<C> {
275 async fn execute_sse_stream(
276 sse_stream: impl Stream<Item = Result<ServerJsonRpcMessage, StreamableHttpError<C::Error>>>
277 + Send
278 + 'static,
279 sse_worker_tx: tokio::sync::mpsc::Sender<ServerJsonRpcMessage>,
280 close_on_response: bool,
281 ct: CancellationToken,
282 ) -> Result<(), StreamableHttpError<C::Error>> {
283 let mut sse_stream = std::pin::pin!(sse_stream);
284 loop {
285 let message = tokio::select! {
286 event = sse_stream.next() => {
287 event
288 }
289 _ = ct.cancelled() => {
290 tracing::debug!("cancelled");
291 break;
292 }
293 };
294 let Some(message) = message.transpose()? else {
295 break;
296 };
297 let is_response = matches!(message, ServerJsonRpcMessage::Response(_));
298 let yield_result = sse_worker_tx.send(message).await;
299 if yield_result.is_err() {
300 tracing::trace!("streamable http transport worker dropped, exiting");
301 break;
302 }
303 if close_on_response && is_response {
304 tracing::debug!("got response, closing sse stream");
305 break;
306 }
307 }
308 Ok(())
309 }
310}
311
312impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
313 type Role = RoleClient;
314 type Error = StreamableHttpError<C::Error>;
315 fn err_closed() -> Self::Error {
316 StreamableHttpError::TransportChannelClosed
317 }
318 fn err_join(e: tokio::task::JoinError) -> Self::Error {
319 StreamableHttpError::TokioJoinError(e)
320 }
321 fn config(&self) -> super::worker::WorkerConfig {
322 super::worker::WorkerConfig {
323 name: Some("StreamableHttpClientWorker".into()),
324 channel_buffer_capacity: self.config.channel_buffer_capacity,
325 }
326 }
327 async fn run(
328 self,
329 mut context: super::worker::WorkerContext<Self>,
330 ) -> Result<(), WorkerQuitReason<Self::Error>> {
331 let channel_buffer_capacity = self.config.channel_buffer_capacity;
332 let (sse_worker_tx, mut sse_worker_rx) =
333 tokio::sync::mpsc::channel::<ServerJsonRpcMessage>(channel_buffer_capacity);
334 let config = self.config.clone();
335 let transport_task_ct = context.cancellation_token.clone();
336 let _drop_guard = transport_task_ct.clone().drop_guard();
337 let WorkerSendRequest {
338 responder,
339 message: initialize_request,
340 } = context.recv_from_handler().await?;
341 let (message, session_id) = match self
342 .client
343 .post_message(
344 config.uri.clone(),
345 initialize_request,
346 None,
347 self.config.auth_header,
348 self.config.custom_headers,
349 )
350 .await
351 {
352 Ok(res) => {
353 let _ = responder.send(Ok(()));
354 res.expect_initialized::<C::Error>().await.map_err(
355 WorkerQuitReason::fatal_context("process initialize response"),
356 )?
357 }
358 Err(err) => {
359 let msg = format!("{:?}", err);
360 let _ = responder.send(Err(err));
361 return Err(WorkerQuitReason::fatal(
362 StreamableHttpError::TransportChannelClosed,
363 msg,
364 ));
365 }
366 };
367 let session_id: Option<Arc<str>> = if let Some(session_id) = session_id {
368 Some(session_id.into())
369 } else {
370 if !self.config.allow_stateless {
371 return Err(WorkerQuitReason::fatal(
372 StreamableHttpError::<C::Error>::MissingSessionIdInResponse,
373 "process initialize response",
374 ));
375 }
376 None
377 };
378 let protocol_headers = {
382 let mut headers = config.custom_headers.clone();
383 if let ServerJsonRpcMessage::Response(response) = &message {
384 if let ServerResult::InitializeResult(init_result) = &response.result {
385 if let Ok(hv) = HeaderValue::from_str(init_result.protocol_version.as_str()) {
386 headers.insert(HeaderName::from_static("mcp-protocol-version"), hv);
388 }
389 }
390 }
391 headers
392 };
393
394 let session_cleanup_info = session_id.as_ref().map(|sid| SessionCleanupInfo {
396 client: self.client.clone(),
397 uri: config.uri.clone(),
398 session_id: sid.clone(),
399 auth_header: config.auth_header.clone(),
400 protocol_headers: protocol_headers.clone(),
401 });
402
403 context.send_to_handler(message).await?;
404 let initialized_notification = context.recv_from_handler().await?;
405 self.client
407 .post_message(
408 config.uri.clone(),
409 initialized_notification.message,
410 session_id.clone(),
411 config.auth_header.clone(),
412 protocol_headers.clone(),
413 )
414 .await
415 .map_err(WorkerQuitReason::fatal_context(
416 "send initialized notification",
417 ))?
418 .expect_accepted_or_json::<C::Error>()
419 .map_err(WorkerQuitReason::fatal_context(
420 "process initialized notification response",
421 ))?;
422 let _ = initialized_notification.responder.send(Ok(()));
423 #[allow(clippy::large_enum_variant)]
424 enum Event<W: Worker, E: std::error::Error + Send + Sync + 'static> {
425 ClientMessage(WorkerSendRequest<W>),
426 ServerMessage(ServerJsonRpcMessage),
427 StreamResult(Result<(), StreamableHttpError<E>>),
428 }
429 let mut streams = tokio::task::JoinSet::new();
430 if let Some(session_id) = &session_id {
431 let client = self.client.clone();
432 let uri = config.uri.clone();
433 let session_id = session_id.clone();
434 let auth_header = config.auth_header.clone();
435 let retry_config = self.config.retry_config.clone();
436 let sse_worker_tx = sse_worker_tx.clone();
437 let transport_task_ct = transport_task_ct.clone();
438 let config_uri = config.uri.clone();
439 let config_auth_header = config.auth_header.clone();
440 let spawn_headers = protocol_headers.clone();
441
442 streams.spawn(async move {
443 match client
444 .get_stream(
445 uri.clone(),
446 session_id.clone(),
447 None,
448 auth_header.clone(),
449 spawn_headers.clone(),
450 )
451 .await
452 {
453 Ok(stream) => {
454 let sse_stream = SseAutoReconnectStream::new(
455 stream,
456 StreamableHttpClientReconnect {
457 client: client.clone(),
458 session_id: session_id.clone(),
459 uri: config_uri,
460 auth_header: config_auth_header,
461 custom_headers: spawn_headers,
462 },
463 retry_config,
464 );
465 Self::execute_sse_stream(
466 sse_stream,
467 sse_worker_tx,
468 false,
469 transport_task_ct.child_token(),
470 )
471 .await
472 }
473 Err(StreamableHttpError::ServerDoesNotSupportSse) => {
474 tracing::debug!("server doesn't support sse, skip common stream");
475 Ok(())
476 }
477 Err(e) => {
478 tracing::error!("fail to get common stream: {e}");
480 Err(e)
481 }
482 }
483 });
484 }
485 let loop_result: Result<(), WorkerQuitReason<Self::Error>> = 'main_loop: loop {
487 let event = tokio::select! {
488 _ = transport_task_ct.cancelled() => {
489 tracing::debug!("cancelled");
490 break 'main_loop Err(WorkerQuitReason::Cancelled);
491 }
492 message = context.recv_from_handler() => {
493 match message {
494 Ok(msg) => Event::ClientMessage(msg),
495 Err(e) => break 'main_loop Err(e),
496 }
497 },
498 message = sse_worker_rx.recv() => {
499 let Some(message) = message else {
500 tracing::trace!("transport dropped, exiting");
501 break 'main_loop Err(WorkerQuitReason::HandlerTerminated);
502 };
503 Event::ServerMessage(message)
504 },
505 terminated_stream = streams.join_next(), if !streams.is_empty() => {
506 match terminated_stream {
507 Some(result) => {
508 Event::StreamResult(result.map_err(StreamableHttpError::TokioJoinError).and_then(std::convert::identity))
509 }
510 None => {
511 continue
512 }
513 }
514 }
515 };
516 match event {
517 Event::ClientMessage(send_request) => {
518 let WorkerSendRequest { message, responder } = send_request;
519 let response = self
520 .client
521 .post_message(
522 config.uri.clone(),
523 message,
524 session_id.clone(),
525 config.auth_header.clone(),
526 protocol_headers.clone(),
527 )
528 .await;
529 let send_result = match response {
530 Err(e) => Err(e),
531 Ok(StreamableHttpPostResponse::Accepted) => {
532 tracing::trace!("client message accepted");
533 Ok(())
534 }
535 Ok(StreamableHttpPostResponse::Json(message, ..)) => {
536 context.send_to_handler(message).await?;
537 Ok(())
538 }
539 Ok(StreamableHttpPostResponse::Sse(stream, ..)) => {
540 if let Some(session_id) = &session_id {
541 let sse_stream = SseAutoReconnectStream::new(
542 stream,
543 StreamableHttpClientReconnect {
544 client: self.client.clone(),
545 session_id: session_id.clone(),
546 uri: config.uri.clone(),
547 auth_header: config.auth_header.clone(),
548 custom_headers: protocol_headers.clone(),
549 },
550 self.config.retry_config.clone(),
551 );
552 streams.spawn(Self::execute_sse_stream(
553 sse_stream,
554 sse_worker_tx.clone(),
555 true,
556 transport_task_ct.child_token(),
557 ));
558 } else {
559 let sse_stream = SseAutoReconnectStream::never_reconnect(
560 stream,
561 StreamableHttpError::<C::Error>::UnexpectedEndOfStream,
562 );
563 streams.spawn(Self::execute_sse_stream(
564 sse_stream,
565 sse_worker_tx.clone(),
566 true,
567 transport_task_ct.child_token(),
568 ));
569 }
570 tracing::trace!("got new sse stream");
571 Ok(())
572 }
573 };
574 let _ = responder.send(send_result);
575 }
576 Event::ServerMessage(json_rpc_message) => {
577 if let Err(e) = context.send_to_handler(json_rpc_message).await {
579 break 'main_loop Err(e);
580 }
581 }
582 Event::StreamResult(result) => {
583 if result.is_err() {
584 tracing::warn!(
585 "sse client event stream terminated with error: {:?}",
586 result
587 );
588 }
589 }
590 }
591 };
592
593 if let Some(cleanup) = session_cleanup_info {
596 const SESSION_CLEANUP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);
597 let cleanup_session_id = cleanup.session_id.clone();
598 match tokio::time::timeout(
599 SESSION_CLEANUP_TIMEOUT,
600 cleanup.client.delete_session(
601 cleanup.uri,
602 cleanup.session_id,
603 cleanup.auth_header,
604 cleanup.protocol_headers,
605 ),
606 )
607 .await
608 {
609 Ok(Ok(_)) => {
610 tracing::info!(
611 session_id = cleanup_session_id.as_ref(),
612 "delete session success"
613 )
614 }
615 Ok(Err(StreamableHttpError::ServerDoesNotSupportDeleteSession)) => {
616 tracing::info!(
617 session_id = cleanup_session_id.as_ref(),
618 "server doesn't support delete session"
619 )
620 }
621 Ok(Err(e)) => {
622 tracing::error!(
623 session_id = cleanup_session_id.as_ref(),
624 "fail to delete session: {e}"
625 );
626 }
627 Err(_elapsed) => {
628 tracing::warn!(
629 session_id = cleanup_session_id.as_ref(),
630 "session cleanup timed out after {:?}",
631 SESSION_CLEANUP_TIMEOUT
632 );
633 }
634 }
635 }
636
637 loop_result
638 }
639}
640
641pub type StreamableHttpClientTransport<C> = WorkerTransport<StreamableHttpClientWorker<C>>;
734
735impl<C: StreamableHttpClient> StreamableHttpClientTransport<C> {
736 pub fn with_client(client: C, config: StreamableHttpClientTransportConfig) -> Self {
816 let worker = StreamableHttpClientWorker::new(client, config);
817 WorkerTransport::spawn(worker)
818 }
819}
820#[derive(Debug, Clone)]
821pub struct StreamableHttpClientTransportConfig {
822 pub uri: Arc<str>,
823 pub retry_config: Arc<dyn SseRetryPolicy>,
824 pub channel_buffer_capacity: usize,
825 pub allow_stateless: bool,
827 pub auth_header: Option<String>,
829 pub custom_headers: HashMap<HeaderName, HeaderValue>,
831}
832
833impl StreamableHttpClientTransportConfig {
834 pub fn with_uri(uri: impl Into<Arc<str>>) -> Self {
835 Self {
836 uri: uri.into(),
837 ..Default::default()
838 }
839 }
840
841 pub fn auth_header<T: Into<String>>(mut self, value: T) -> Self {
847 self.auth_header = Some(value.into());
849 self
850 }
851
852 pub fn custom_headers(mut self, custom_headers: HashMap<HeaderName, HeaderValue>) -> Self {
875 self.custom_headers = custom_headers;
876 self
877 }
878}
879
880impl Default for StreamableHttpClientTransportConfig {
881 fn default() -> Self {
882 Self {
883 uri: "localhost".into(),
884 retry_config: Arc::new(ExponentialBackoff::default()),
885 channel_buffer_capacity: 16,
886 allow_stateless: true,
887 auth_header: None,
888 custom_headers: HashMap::new(),
889 }
890 }
891}