Skip to main content

rmcp_soddygo/transport/
streamable_http_client.rs

1use std::{borrow::Cow, collections::HashMap, sync::Arc, time::Duration};
2
3use futures::{Stream, StreamExt, future::BoxFuture, stream::BoxStream};
4use http::{HeaderName, HeaderValue};
5pub use sse_stream::Error as SseError;
6use sse_stream::Sse;
7use thiserror::Error;
8use tokio_util::sync::CancellationToken;
9use tracing::debug;
10
11use super::common::client_side_sse::{ExponentialBackoff, SseRetryPolicy, SseStreamReconnect};
12use crate::{
13    RoleClient,
14    model::{
15        ClientJsonRpcMessage, ClientNotification, InitializedNotification, ServerJsonRpcMessage,
16        ServerResult,
17    },
18    transport::{
19        common::client_side_sse::SseAutoReconnectStream,
20        worker::{Worker, WorkerQuitReason, WorkerSendRequest, WorkerTransport},
21    },
22};
23
/// Boxed, `'static` stream of SSE events; each item is either a parsed
/// [`Sse`] event or an [`SseError`].
type BoxedSseStream = BoxStream<'static, Result<Sse, SseError>>;
25
/// Payload of [`StreamableHttpError::AuthRequired`].
///
/// Holds the `WWW-Authenticate` header string exactly as it was handed to
/// [`AuthRequiredError::new`]; this type neither parses nor validates it.
#[derive(Debug)]
#[non_exhaustive]
pub struct AuthRequiredError {
    /// Stored `WWW-Authenticate` header string.
    pub www_authenticate_header: String,
}

impl AuthRequiredError {
    /// Builds an `AuthRequiredError` around the given header string.
    pub fn new(www_authenticate_header: String) -> Self {
        Self { www_authenticate_header }
    }
}
40
/// Payload of [`StreamableHttpError::InsufficientScope`].
///
/// Holds the `WWW-Authenticate` header string as given to
/// [`InsufficientScopeError::new`] plus, when known, the scope that is
/// required. Neither value is parsed or validated here.
#[derive(Debug)]
#[non_exhaustive]
pub struct InsufficientScopeError {
    /// Stored `WWW-Authenticate` header string.
    pub www_authenticate_header: String,
    /// Scope that is required, or `None` when it is not known.
    pub required_scope: Option<String>,
}

impl InsufficientScopeError {
    /// Builds an `InsufficientScopeError` from its two parts.
    pub fn new(www_authenticate_header: String, required_scope: Option<String>) -> Self {
        Self {
            www_authenticate_header,
            required_scope,
        }
    }

    /// Returns `true` when a scope upgrade is possible, i.e. when the
    /// required scope is known.
    pub fn can_upgrade(&self) -> bool {
        self.get_required_scope().is_some()
    }

    /// Borrows the scope required for an upgrade, if one was recorded.
    pub fn get_required_scope(&self) -> Option<&str> {
        let scope = self.required_scope.as_ref()?;
        Some(scope.as_str())
    }
}
67
/// Error type of the streamable HTTP client transport.
///
/// `E` is the error type of the underlying HTTP client implementation
/// (`StreamableHttpClient::Error`); it is carried by the
/// [`StreamableHttpError::Client`] variant. Variants marked `#[from]` are
/// produced automatically by `?` from their source error type.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum StreamableHttpError<E: std::error::Error + Send + Sync + 'static> {
    /// Error produced by the SSE stream.
    #[error("SSE error: {0}")]
    Sse(#[from] SseError),
    /// I/O error.
    #[error("Io error: {0}")]
    Io(#[from] std::io::Error),
    /// Error of the underlying HTTP client (type parameter `E`).
    #[error("Client error: {0}")]
    Client(E),
    /// The stream ended unexpectedly.
    #[error("unexpected end of stream")]
    UnexpectedEndOfStream,
    /// The server answered with something other than what was expected;
    /// the payload is a human-readable description.
    #[error("unexpected server response: {0}")]
    UnexpectedServerResponse(Cow<'static, str>),
    /// Unexpected content type; the payload is the optional content-type
    /// string and is rendered with `{:?}`.
    #[error("Unexpected content type: {0:?}")]
    UnexpectedContentType(Option<String>),
    /// The server does not support SSE. The worker treats this result of
    /// `get_stream` as non-fatal and simply skips the standalone stream.
    #[error("Server does not support SSE")]
    ServerDoesNotSupportSse,
    /// The server does not support deleting a session.
    #[error("Server does not support delete session")]
    ServerDoesNotSupportDeleteSession,
    /// A spawned tokio task failed to join.
    #[error("Tokio join error: {0}")]
    TokioJoinError(#[from] tokio::task::JoinError),
    /// JSON deserialization failed.
    #[error("Deserialize error: {0}")]
    Deserialize(#[from] serde_json::Error),
    /// The transport channel is closed (returned by `Worker::err_closed`).
    #[error("Transport channel closed")]
    TransportChannelClosed,
    /// The initialize response carried no session id while the transport
    /// configuration does not allow stateless operation.
    #[error("Missing session id in HTTP response")]
    MissingSessionIdInResponse,
    /// Authentication error; only available with the `auth` feature.
    #[cfg(feature = "auth")]
    #[error("Auth error: {0}")]
    Auth(#[from] crate::transport::auth::AuthError),
    /// Authentication is required; see [`AuthRequiredError`] for the payload.
    #[error("Auth required")]
    AuthRequired(AuthRequiredError),
    /// The granted scope is insufficient; see [`InsufficientScopeError`].
    #[error("Insufficient scope")]
    InsufficientScope(InsufficientScopeError),
    /// A header name is reserved and conflicts with the default headers;
    /// the payload is the offending header name.
    #[error("Header name '{0}' is reserved and conflicts with default headers")]
    ReservedHeaderConflict(String),
    /// The session expired (HTTP 404). When re-initialization is enabled in
    /// the configuration, the worker reacts to this by performing a fresh
    /// handshake and replaying the message.
    #[error("Session expired (HTTP 404)")]
    SessionExpired,
}
107
/// Protocol-level errors of the streamable HTTP transport that do not depend
/// on the HTTP client's error type.
// NOTE(review): not referenced in the visible part of this file — confirm
// where it is consumed before changing it.
#[derive(Debug, Clone, Error)]
#[non_exhaustive]
pub enum StreamableHttpProtocolError {
    /// A response did not include a session id.
    #[error("Missing session id in response")]
    MissingSessionIdInResponse,
}
114
115#[allow(clippy::large_enum_variant)]
116#[non_exhaustive]
117pub enum StreamableHttpPostResponse {
118    Accepted,
119    Json(ServerJsonRpcMessage, Option<String>),
120    Sse(BoxedSseStream, Option<String>),
121}
122
123impl std::fmt::Debug for StreamableHttpPostResponse {
124    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
125        match self {
126            Self::Accepted => write!(f, "Accepted"),
127            Self::Json(arg0, arg1) => f.debug_tuple("Json").field(arg0).field(arg1).finish(),
128            Self::Sse(_, arg1) => f.debug_tuple("Sse").field(arg1).finish(),
129        }
130    }
131}
132
133impl StreamableHttpPostResponse {
134    pub async fn expect_initialized<E>(
135        self,
136    ) -> Result<(ServerJsonRpcMessage, Option<String>), StreamableHttpError<E>>
137    where
138        E: std::error::Error + Send + Sync + 'static,
139    {
140        match self {
141            Self::Json(message, session_id) => Ok((message, session_id)),
142            Self::Sse(mut stream, session_id) => {
143                while let Some(event) = stream.next().await {
144                    let event = event?;
145                    let payload = event.data.unwrap_or_default();
146                    if payload.trim().is_empty() {
147                        continue;
148                    }
149
150                    let message: ServerJsonRpcMessage = serde_json::from_str(&payload)?;
151
152                    if matches!(message, ServerJsonRpcMessage::Response(_)) {
153                        return Ok((message, session_id));
154                    }
155
156                    debug!(
157                        ?message,
158                        "received message before initialize response; continuing to drain stream"
159                    );
160                }
161
162                Err(StreamableHttpError::UnexpectedServerResponse(
163                    "empty sse stream".into(),
164                ))
165            }
166            _ => Err(StreamableHttpError::UnexpectedServerResponse(
167                "expect initialized, accepted".into(),
168            )),
169        }
170    }
171
172    pub fn expect_json<E>(self) -> Result<ServerJsonRpcMessage, StreamableHttpError<E>>
173    where
174        E: std::error::Error + Send + Sync + 'static,
175    {
176        match self {
177            Self::Json(message, ..) => Ok(message),
178            got => Err(StreamableHttpError::UnexpectedServerResponse(
179                format!("expect json, got {got:?}").into(),
180            )),
181        }
182    }
183
184    pub fn expect_accepted_or_json<E>(self) -> Result<(), StreamableHttpError<E>>
185    where
186        E: std::error::Error + Send + Sync + 'static,
187    {
188        match self {
189            Self::Accepted => Ok(()),
190            // Tolerate servers that return 200 with JSON for notifications
191            Self::Json(..) => Ok(()),
192            got => Err(StreamableHttpError::UnexpectedServerResponse(
193                format!("expect accepted or json, got {got:?}").into(),
194            )),
195        }
196    }
197}
198
/// HTTP client abstraction used by the streamable HTTP transport worker.
///
/// Implementations must be `Clone + Send + 'static`: the worker clones the
/// client into spawned tasks and into reconnect handles. Every method returns
/// a `Send` future that borrows `self`.
pub trait StreamableHttpClient: Clone + Send + 'static {
    /// Error type of the underlying HTTP client; surfaced through
    /// `StreamableHttpError::Client`.
    type Error: std::error::Error + Send + Sync + 'static;
    /// Posts one client JSON-RPC `message` to `uri`.
    ///
    /// `session_id` is `None` for the initial `initialize` request and for
    /// stateless operation. `auth_header` and `custom_headers` are supplied
    /// by the transport configuration; after initialization `custom_headers`
    /// also contains the negotiated `mcp-protocol-version` entry.
    ///
    /// The worker treats `Err(StreamableHttpError::SessionExpired)` from this
    /// method as the trigger for transparent re-initialization.
    fn post_message(
        &self,
        uri: Arc<str>,
        message: ClientJsonRpcMessage,
        session_id: Option<Arc<str>>,
        auth_header: Option<String>,
        custom_headers: HashMap<HeaderName, HeaderValue>,
    ) -> impl Future<Output = Result<StreamableHttpPostResponse, StreamableHttpError<Self::Error>>>
    + Send
    + '_;
    /// Deletes the session identified by `session_id` at `uri`.
    // NOTE(review): no caller is visible in this part of the file; presumably
    // used for session cleanup when the worker exits — confirm.
    fn delete_session(
        &self,
        uri: Arc<str>,
        session_id: Arc<str>,
        auth_header: Option<String>,
        custom_headers: HashMap<HeaderName, HeaderValue>,
    ) -> impl Future<Output = Result<(), StreamableHttpError<Self::Error>>> + Send + '_;
    /// Opens the SSE stream for `session_id` at `uri`.
    ///
    /// `last_event_id` is `None` on the first connection; on reconnects the
    /// worker passes the id of the last event it saw. The worker treats
    /// `Err(StreamableHttpError::ServerDoesNotSupportSse)` as non-fatal.
    fn get_stream(
        &self,
        uri: Arc<str>,
        session_id: Arc<str>,
        last_event_id: Option<String>,
        auth_header: Option<String>,
        custom_headers: HashMap<HeaderName, HeaderValue>,
    ) -> impl Future<
        Output = Result<
            BoxStream<'static, Result<Sse, SseError>>,
            StreamableHttpError<Self::Error>,
        >,
    > + Send
    + '_;
}
233
/// Retry settings.
// NOTE(review): this type is not referenced in the visible part of the file;
// the field notes below are inferred from the names only — confirm.
#[non_exhaustive]
pub struct RetryConfig {
    /// Presumably the maximum number of retries, with `None` meaning no limit — TODO confirm.
    pub max_times: Option<usize>,
    /// Presumably the minimum delay between retries — TODO confirm.
    pub min_duration: Duration,
}
239
240struct StreamableHttpClientReconnect<C> {
241    pub client: C,
242    pub session_id: Arc<str>,
243    pub uri: Arc<str>,
244    pub auth_header: Option<String>,
245    pub custom_headers: HashMap<HeaderName, HeaderValue>,
246}
247
248impl<C: StreamableHttpClient> SseStreamReconnect for StreamableHttpClientReconnect<C> {
249    type Error = StreamableHttpError<C::Error>;
250    type Future = BoxFuture<'static, Result<BoxedSseStream, Self::Error>>;
251    fn retry_connection(&mut self, last_event_id: Option<&str>) -> Self::Future {
252        let client = self.client.clone();
253        let uri = self.uri.clone();
254        let session_id = self.session_id.clone();
255        let auth_header = self.auth_header.clone();
256        let custom_headers = self.custom_headers.clone();
257        let last_event_id = last_event_id.map(|s| s.to_owned());
258        Box::pin(async move {
259            client
260                .get_stream(uri, session_id, last_event_id, auth_header, custom_headers)
261                .await
262        })
263    }
264}
265
/// Info retained for cleaning up the session when the worker exits.
///
/// Built from the current session id after initialization and rebuilt after
/// a successful re-initialization, so it always describes the live session.
struct SessionCleanupInfo<C> {
    /// Clone of the worker's HTTP client.
    client: C,
    /// Endpoint the session belongs to.
    uri: Arc<str>,
    /// Id of the session to clean up.
    session_id: Arc<str>,
    /// Optional auth header taken from the transport configuration.
    auth_header: Option<String>,
    /// Configured custom headers plus the negotiated `mcp-protocol-version`
    /// entry, when one was obtained from the initialize result.
    protocol_headers: HashMap<HeaderName, HeaderValue>,
}
274
275#[derive(Debug, Clone, Default)]
276#[non_exhaustive]
277pub struct StreamableHttpClientWorker<C: StreamableHttpClient> {
278    pub client: C,
279    pub config: StreamableHttpClientTransportConfig,
280}
281
282impl<C: StreamableHttpClient + Default> StreamableHttpClientWorker<C> {
283    pub fn new_simple(url: impl Into<Arc<str>>) -> Self {
284        Self {
285            client: C::default(),
286            config: StreamableHttpClientTransportConfig {
287                uri: url.into(),
288                ..Default::default()
289            },
290        }
291    }
292}
293
294impl<C: StreamableHttpClient> StreamableHttpClientWorker<C> {
295    pub fn new(client: C, config: StreamableHttpClientTransportConfig) -> Self {
296        Self { client, config }
297    }
298}
299
300impl<C: StreamableHttpClient> StreamableHttpClientWorker<C> {
301    /// Convert a raw SSE stream into a JSON-RPC message stream without
302    /// reconnection logic.
303    fn raw_sse_to_jsonrpc(
304        stream: BoxedSseStream,
305    ) -> impl Stream<Item = Result<ServerJsonRpcMessage, StreamableHttpError<C::Error>>> + Send + 'static
306    {
307        stream.filter_map(|event| async {
308            match event {
309                Err(e) => Some(Err(StreamableHttpError::Sse(e))),
310                Ok(sse) => {
311                    let is_message =
312                        matches!(sse.event.as_deref(), None | Some("") | Some("message"));
313                    if !is_message {
314                        return None;
315                    }
316                    let data = sse.data?;
317                    if data.trim().is_empty() {
318                        return None;
319                    }
320                    match serde_json::from_str::<ServerJsonRpcMessage>(&data) {
321                        Ok(msg) => Some(Ok(msg)),
322                        Err(e) => {
323                            tracing::debug!("failed to deserialize server message: {e}");
324                            None
325                        }
326                    }
327                }
328            }
329        })
330    }
331
332    async fn execute_sse_stream(
333        sse_stream: impl Stream<Item = Result<ServerJsonRpcMessage, StreamableHttpError<C::Error>>>
334        + Send
335        + 'static,
336        sse_worker_tx: tokio::sync::mpsc::Sender<ServerJsonRpcMessage>,
337        close_on_response: bool,
338        ct: CancellationToken,
339    ) -> Result<(), StreamableHttpError<C::Error>> {
340        let mut sse_stream = std::pin::pin!(sse_stream);
341        loop {
342            let message = tokio::select! {
343                event = sse_stream.next() => {
344                    event
345                }
346                _ = ct.cancelled() => {
347                    tracing::debug!("cancelled");
348                    break;
349                }
350            };
351            let Some(message) = message.transpose()? else {
352                break;
353            };
354            let is_response = matches!(
355                message,
356                ServerJsonRpcMessage::Response(_) | ServerJsonRpcMessage::Error(_)
357            );
358            let yield_result = sse_worker_tx.send(message).await;
359            if yield_result.is_err() {
360                tracing::trace!("streamable http transport worker dropped, exiting");
361                break;
362            }
363            if close_on_response && is_response {
364                tracing::debug!("got response, draining sse stream for connection reuse");
365                // Consume the remaining stream so the HTTP/1.1 connection
366                // returns to the pool cleanly.
367                let _ = tokio::time::timeout(std::time::Duration::from_millis(50), async {
368                    while sse_stream.next().await.is_some() {}
369                })
370                .await;
371                break;
372            }
373        }
374        Ok(())
375    }
376
377    /// Performs a transparent re-initialization handshake after a session-expired 404.
378    ///
379    /// Takes an owned clone of the client (avoiding `&self` across `.await` so the
380    /// future remains `Send` without requiring `C: Sync`).  POSTs the saved
381    /// initialize request without a session ID, extracts the new session ID and
382    /// protocol version, sends `notifications/initialized`, and returns the new
383    /// `(session_id, protocol_headers)` pair.  The init result message is **not**
384    /// forwarded to the handler because the handler already processed the original
385    /// initialization.
386    async fn perform_reinitialization(
387        client: C,
388        saved_init_request: ClientJsonRpcMessage,
389        uri: Arc<str>,
390        auth_header: Option<String>,
391        custom_headers: HashMap<HeaderName, HeaderValue>,
392    ) -> Result<(Option<Arc<str>>, HashMap<HeaderName, HeaderValue>), StreamableHttpError<C::Error>>
393    {
394        let (init_msg, new_session_id_str) = client
395            .post_message(
396                uri.clone(),
397                saved_init_request,
398                None,
399                auth_header.clone(),
400                custom_headers.clone(),
401            )
402            .await?
403            .expect_initialized::<C::Error>()
404            .await?;
405
406        let new_session_id: Option<Arc<str>> = new_session_id_str.map(|s| Arc::from(s.as_str()));
407
408        // Start from custom_headers, then inject the negotiated MCP-Protocol-Version
409        // so all subsequent requests carry the right version (MCP 2025-06-18 spec).
410        let mut new_protocol_headers = custom_headers;
411        if let ServerJsonRpcMessage::Response(response) = &init_msg {
412            if let ServerResult::InitializeResult(init_result) = &response.result {
413                if let Ok(hv) = HeaderValue::from_str(init_result.protocol_version.as_str()) {
414                    new_protocol_headers
415                        .insert(HeaderName::from_static("mcp-protocol-version"), hv);
416                }
417            }
418        }
419
420        let initialized_notification = ClientJsonRpcMessage::notification(
421            ClientNotification::InitializedNotification(InitializedNotification {
422                method: Default::default(),
423                extensions: Default::default(),
424            }),
425        );
426        client
427            .post_message(
428                uri,
429                initialized_notification,
430                new_session_id.clone(),
431                auth_header,
432                new_protocol_headers.clone(),
433            )
434            .await?
435            .expect_accepted_or_json::<C::Error>()?;
436
437        Ok((new_session_id, new_protocol_headers))
438    }
439}
440
441impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
442    type Role = RoleClient;
443    type Error = StreamableHttpError<C::Error>;
444    fn err_closed() -> Self::Error {
445        StreamableHttpError::TransportChannelClosed
446    }
447    fn err_join(e: tokio::task::JoinError) -> Self::Error {
448        StreamableHttpError::TokioJoinError(e)
449    }
450    fn config(&self) -> super::worker::WorkerConfig {
451        super::worker::WorkerConfig {
452            name: Some("StreamableHttpClientWorker".into()),
453            channel_buffer_capacity: self.config.channel_buffer_capacity,
454        }
455    }
456    async fn run(
457        self,
458        mut context: super::worker::WorkerContext<Self>,
459    ) -> Result<(), WorkerQuitReason<Self::Error>> {
460        let channel_buffer_capacity = self.config.channel_buffer_capacity;
461        let (sse_worker_tx, mut sse_worker_rx) =
462            tokio::sync::mpsc::channel::<ServerJsonRpcMessage>(channel_buffer_capacity);
463        let config = self.config.clone();
464        let transport_task_ct = context.cancellation_token.clone();
465        let _drop_guard = transport_task_ct.clone().drop_guard();
466        let WorkerSendRequest {
467            responder,
468            message: initialize_request,
469        } = context.recv_from_handler().await?;
470        let saved_init_request = initialize_request.clone();
471        let (message, session_id) = match self
472            .client
473            .post_message(
474                config.uri.clone(),
475                initialize_request,
476                None,
477                config.auth_header.clone(),
478                config.custom_headers.clone(),
479            )
480            .await
481        {
482            Ok(res) => {
483                let _ = responder.send(Ok(()));
484                res.expect_initialized::<C::Error>().await.map_err(
485                    WorkerQuitReason::fatal_context("process initialize response"),
486                )?
487            }
488            Err(err) => {
489                let msg = format!("{:?}", err);
490                let _ = responder.send(Err(err));
491                return Err(WorkerQuitReason::fatal(
492                    StreamableHttpError::TransportChannelClosed,
493                    msg,
494                ));
495            }
496        };
497        let mut session_id: Option<Arc<str>> = if let Some(session_id) = session_id {
498            Some(session_id.into())
499        } else {
500            if !self.config.allow_stateless {
501                return Err(WorkerQuitReason::fatal(
502                    StreamableHttpError::<C::Error>::MissingSessionIdInResponse,
503                    "process initialize response",
504                ));
505            }
506            None
507        };
508        // Extract the negotiated protocol version from the init response
509        // and build a custom headers map that includes MCP-Protocol-Version
510        // for all subsequent HTTP requests (per MCP 2025-06-18 spec).
511        let mut protocol_headers = {
512            let mut headers = config.custom_headers.clone();
513            if let ServerJsonRpcMessage::Response(response) = &message {
514                if let ServerResult::InitializeResult(init_result) = &response.result {
515                    if let Ok(hv) = HeaderValue::from_str(init_result.protocol_version.as_str()) {
516                        // HeaderName::from_static requires lowercase
517                        headers.insert(HeaderName::from_static("mcp-protocol-version"), hv);
518                    }
519                }
520            }
521            headers
522        };
523
524        // Store session info for cleanup when run() exits (not spawned, so cleanup completes before close() returns)
525        let mut session_cleanup_info = session_id.as_ref().map(|sid| SessionCleanupInfo {
526            client: self.client.clone(),
527            uri: config.uri.clone(),
528            session_id: sid.clone(),
529            auth_header: config.auth_header.clone(),
530            protocol_headers: protocol_headers.clone(),
531        });
532
533        context.send_to_handler(message).await?;
534        let initialized_notification = context.recv_from_handler().await?;
535        // expect a initialized response
536        self.client
537            .post_message(
538                config.uri.clone(),
539                initialized_notification.message,
540                session_id.clone(),
541                config.auth_header.clone(),
542                protocol_headers.clone(),
543            )
544            .await
545            .map_err(WorkerQuitReason::fatal_context(
546                "send initialized notification",
547            ))?
548            .expect_accepted_or_json::<C::Error>()
549            .map_err(WorkerQuitReason::fatal_context(
550                "process initialized notification response",
551            ))?;
552        let _ = initialized_notification.responder.send(Ok(()));
553        #[allow(clippy::large_enum_variant)]
554        enum Event<W: Worker, E: std::error::Error + Send + Sync + 'static> {
555            ClientMessage(WorkerSendRequest<W>),
556            ServerMessage(ServerJsonRpcMessage),
557            StreamResult(Result<(), StreamableHttpError<E>>),
558        }
559        let mut streams = tokio::task::JoinSet::new();
560        if let Some(session_id) = &session_id {
561            let client = self.client.clone();
562            let uri = config.uri.clone();
563            let session_id = session_id.clone();
564            let auth_header = config.auth_header.clone();
565            let retry_config = self.config.retry_config.clone();
566            let sse_worker_tx = sse_worker_tx.clone();
567            let transport_task_ct = transport_task_ct.clone();
568            let config_uri = config.uri.clone();
569            let config_auth_header = config.auth_header.clone();
570            let spawn_headers = protocol_headers.clone();
571
572            streams.spawn(async move {
573                match client
574                    .get_stream(
575                        uri.clone(),
576                        session_id.clone(),
577                        None,
578                        auth_header.clone(),
579                        spawn_headers.clone(),
580                    )
581                    .await
582                {
583                    Ok(stream) => {
584                        let sse_stream = SseAutoReconnectStream::new(
585                            stream,
586                            StreamableHttpClientReconnect {
587                                client: client.clone(),
588                                session_id: session_id.clone(),
589                                uri: config_uri,
590                                auth_header: config_auth_header,
591                                custom_headers: spawn_headers,
592                            },
593                            retry_config,
594                        );
595                        Self::execute_sse_stream(
596                            sse_stream,
597                            sse_worker_tx,
598                            false,
599                            transport_task_ct.child_token(),
600                        )
601                        .await
602                    }
603                    Err(StreamableHttpError::ServerDoesNotSupportSse) => {
604                        tracing::debug!("server doesn't support sse, skip common stream");
605                        Ok(())
606                    }
607                    Err(e) => {
608                        // fail to get common stream
609                        tracing::error!("fail to get common stream: {e}");
610                        Err(e)
611                    }
612                }
613            });
614        }
615        // Main event loop - capture exit reason so we can do cleanup before returning
616        let loop_result: Result<(), WorkerQuitReason<Self::Error>> = 'main_loop: loop {
617            let event = tokio::select! {
618                _ = transport_task_ct.cancelled() => {
619                    tracing::debug!("cancelled");
620                    break 'main_loop Err(WorkerQuitReason::Cancelled);
621                }
622                message = context.recv_from_handler() => {
623                    match message {
624                        Ok(msg) => Event::ClientMessage(msg),
625                        Err(e) => break 'main_loop Err(e),
626                    }
627                },
628                message = sse_worker_rx.recv() => {
629                    let Some(message) = message else {
630                        tracing::trace!("transport dropped, exiting");
631                        break 'main_loop Err(WorkerQuitReason::HandlerTerminated);
632                    };
633                    Event::ServerMessage(message)
634                },
635                terminated_stream = streams.join_next(), if !streams.is_empty() => {
636                    match terminated_stream {
637                        Some(result) => {
638                            Event::StreamResult(result.map_err(StreamableHttpError::TokioJoinError).and_then(std::convert::identity))
639                        }
640                        None => {
641                            continue
642                        }
643                    }
644                }
645            };
646            match event {
647                Event::ClientMessage(send_request) => {
648                    let WorkerSendRequest { message, responder } = send_request;
649                    // Pass a clone to the first attempt so `message` is retained for a
650                    // potential re-init retry. `post_message` takes ownership and the
651                    // trait cannot be changed, so the clone is unavoidable.
652                    let response = self
653                        .client
654                        .post_message(
655                            config.uri.clone(),
656                            message.clone(),
657                            session_id.clone(),
658                            config.auth_header.clone(),
659                            protocol_headers.clone(),
660                        )
661                        .await;
662                    let send_result = match response {
663                        Err(StreamableHttpError::SessionExpired) => {
664                            if !config.reinit_on_expired_session {
665                                Err(StreamableHttpError::SessionExpired)
666                            } else {
667                                // The server discarded the session (HTTP 404). Perform a
668                                // fresh handshake once and replay the original message.
669                                tracing::info!(
670                                    "session expired (HTTP 404), attempting transparent re-initialization"
671                                );
672                                match Self::perform_reinitialization(
673                                    self.client.clone(),
674                                    saved_init_request.clone(),
675                                    config.uri.clone(),
676                                    config.auth_header.clone(),
677                                    config.custom_headers.clone(),
678                                )
679                                .await
680                                {
681                                    Ok((new_session_id, new_protocol_headers)) => {
682                                        // Old streams hold the stale session ID; abort them
683                                        // so the new standalone SSE stream takes over.
684                                        streams.abort_all();
685
686                                        session_id = new_session_id;
687                                        protocol_headers = new_protocol_headers;
688                                        session_cleanup_info =
689                                            session_id.as_ref().map(|sid| SessionCleanupInfo {
690                                                client: self.client.clone(),
691                                                uri: config.uri.clone(),
692                                                session_id: sid.clone(),
693                                                auth_header: config.auth_header.clone(),
694                                                protocol_headers: protocol_headers.clone(),
695                                            });
696
697                                        if let Some(new_sid) = &session_id {
698                                            let client = self.client.clone();
699                                            let uri = config.uri.clone();
700                                            let new_sid = new_sid.clone();
701                                            let auth_header = config.auth_header.clone();
702                                            let retry_config = self.config.retry_config.clone();
703                                            let sse_tx = sse_worker_tx.clone();
704                                            let task_ct = transport_task_ct.clone();
705                                            let config_uri = config.uri.clone();
706                                            let config_auth = config.auth_header.clone();
707                                            let spawn_headers = protocol_headers.clone();
708                                            streams.spawn(async move {
709                                            match client
710                                                .get_stream(
711                                                    uri,
712                                                    new_sid.clone(),
713                                                    None,
714                                                    auth_header.clone(),
715                                                    spawn_headers.clone(),
716                                                )
717                                                .await
718                                            {
719                                                Ok(stream) => {
720                                                    let sse_stream = SseAutoReconnectStream::new(
721                                                        stream,
722                                                        StreamableHttpClientReconnect {
723                                                            client: client.clone(),
724                                                            session_id: new_sid,
725                                                            uri: config_uri,
726                                                            auth_header: config_auth,
727                                                            custom_headers: spawn_headers,
728                                                        },
729                                                        retry_config,
730                                                    );
731                                                    Self::execute_sse_stream(
732                                                        sse_stream,
733                                                        sse_tx,
734                                                        false,
735                                                        task_ct.child_token(),
736                                                    )
737                                                    .await
738                                                }
739                                                Err(StreamableHttpError::ServerDoesNotSupportSse) => {
740                                                    tracing::debug!(
741                                                        "server doesn't support sse after re-init"
742                                                    );
743                                                    Ok(())
744                                                }
745                                                Err(e) => {
746                                                    tracing::error!(
747                                                        "fail to get common stream after re-init: {e}"
748                                                    );
749                                                    Err(e)
750                                                }
751                                            }
752                                        });
753                                        }
754
755                                        let retry_response = self
756                                            .client
757                                            .post_message(
758                                                config.uri.clone(),
759                                                message,
760                                                session_id.clone(),
761                                                config.auth_header.clone(),
762                                                protocol_headers.clone(),
763                                            )
764                                            .await;
765                                        match retry_response {
766                                            Err(e) => Err(e),
767                                            Ok(StreamableHttpPostResponse::Accepted) => {
768                                                tracing::trace!(
769                                                    "client message accepted after re-init"
770                                                );
771                                                Ok(())
772                                            }
773                                            Ok(StreamableHttpPostResponse::Json(msg, ..)) => {
774                                                context.send_to_handler(msg).await?;
775                                                Ok(())
776                                            }
777                                            Ok(StreamableHttpPostResponse::Sse(stream, ..)) => {
778                                                streams.spawn(Self::execute_sse_stream(
779                                                    Self::raw_sse_to_jsonrpc(stream),
780                                                    sse_worker_tx.clone(),
781                                                    true,
782                                                    transport_task_ct.child_token(),
783                                                ));
784                                                tracing::trace!("got new sse stream after re-init");
785                                                Ok(())
786                                            }
787                                        }
788                                    }
789                                    Err(reinit_err) => Err(reinit_err),
790                                }
791                            } // else enable_reinit_on_expired_session
792                        }
793                        Err(e) => Err(e),
794                        Ok(StreamableHttpPostResponse::Accepted) => {
795                            tracing::trace!("client message accepted");
796                            Ok(())
797                        }
798                        Ok(StreamableHttpPostResponse::Json(message, ..)) => {
799                            context.send_to_handler(message).await?;
800                            Ok(())
801                        }
802                        Ok(StreamableHttpPostResponse::Sse(stream, ..)) => {
803                            streams.spawn(Self::execute_sse_stream(
804                                Self::raw_sse_to_jsonrpc(stream),
805                                sse_worker_tx.clone(),
806                                true,
807                                transport_task_ct.child_token(),
808                            ));
809                            tracing::trace!("got new sse stream");
810                            Ok(())
811                        }
812                    };
813                    let _ = responder.send(send_result);
814                }
815                Event::ServerMessage(json_rpc_message) => {
816                    // send the message to the handler
817                    if let Err(e) = context.send_to_handler(json_rpc_message).await {
818                        break 'main_loop Err(e);
819                    }
820                }
821                Event::StreamResult(result) => {
822                    if result.is_err() {
823                        tracing::warn!(
824                            "sse client event stream terminated with error: {:?}",
825                            result
826                        );
827                    }
828                }
829            }
830        };
831
832        // Cleanup session before returning (ensures close() waits for session deletion)
833        // Use a timeout to prevent indefinite hangs if the server is unresponsive
834        if let Some(cleanup) = session_cleanup_info {
835            const SESSION_CLEANUP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);
836            let cleanup_session_id = cleanup.session_id.clone();
837            match tokio::time::timeout(
838                SESSION_CLEANUP_TIMEOUT,
839                cleanup.client.delete_session(
840                    cleanup.uri,
841                    cleanup.session_id,
842                    cleanup.auth_header,
843                    cleanup.protocol_headers,
844                ),
845            )
846            .await
847            {
848                Ok(Ok(_)) => {
849                    tracing::info!(
850                        session_id = cleanup_session_id.as_ref(),
851                        "delete session success"
852                    )
853                }
854                Ok(Err(StreamableHttpError::ServerDoesNotSupportDeleteSession)) => {
855                    tracing::info!(
856                        session_id = cleanup_session_id.as_ref(),
857                        "server doesn't support delete session"
858                    )
859                }
860                Ok(Err(e)) => {
861                    tracing::error!(
862                        session_id = cleanup_session_id.as_ref(),
863                        "fail to delete session: {e}"
864                    );
865                }
866                Err(_elapsed) => {
867                    tracing::warn!(
868                        session_id = cleanup_session_id.as_ref(),
869                        "session cleanup timed out after {:?}",
870                        SESSION_CLEANUP_TIMEOUT
871                    );
872                }
873            }
874        }
875
876        loop_result
877    }
878}
879
/// A client-agnostic HTTP transport for RMCP that supports streaming responses.
///
/// This is a type alias for a [`WorkerTransport`] driven by a
/// [`StreamableHttpClientWorker`], where `C` is the HTTP client implementation
/// used to talk to the server.
///
/// This transport allows you to choose your preferred HTTP client implementation
/// by implementing the [`StreamableHttpClient`] trait. The transport handles
/// session management, SSE streaming, and automatic reconnection.
///
/// # Usage
///
/// ## Using reqwest
///
/// ```rust,no_run
/// use rmcp::transport::StreamableHttpClientTransport;
///
/// // Enable the reqwest feature in Cargo.toml:
/// // rmcp = { version = "0.5", features = ["transport-streamable-http-client-reqwest"] }
///
/// let transport = StreamableHttpClientTransport::from_uri("http://localhost:8000/mcp");
/// ```
///
/// ## Using a custom HTTP client
///
/// Implement [`StreamableHttpClient`] for your client type and pass it to
/// `with_client` together with a [`StreamableHttpClientTransportConfig`].
///
/// ```rust,no_run
/// use rmcp::transport::streamable_http_client::{
///     StreamableHttpClient,
///     StreamableHttpClientTransport,
///     StreamableHttpClientTransportConfig
/// };
/// use std::sync::Arc;
/// use std::collections::HashMap;
/// use futures::stream::BoxStream;
/// use rmcp::model::ClientJsonRpcMessage;
/// use http::{HeaderName, HeaderValue};
/// use sse_stream::{Sse, Error as SseError};
///
/// #[derive(Clone)]
/// struct MyHttpClient;
///
/// #[derive(Debug, thiserror::Error)]
/// struct MyError;
///
/// impl std::fmt::Display for MyError {
///     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
///         write!(f, "MyError")
///     }
/// }
///
/// impl StreamableHttpClient for MyHttpClient {
///     type Error = MyError;
///
///     async fn post_message(
///         &self,
///         _uri: Arc<str>,
///         _message: ClientJsonRpcMessage,
///         _session_id: Option<Arc<str>>,
///         _auth_header: Option<String>,
///         _custom_headers: HashMap<HeaderName, HeaderValue>,
///     ) -> Result<rmcp::transport::streamable_http_client::StreamableHttpPostResponse, rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
///         todo!()
///     }
///
///     async fn delete_session(
///         &self,
///         _uri: Arc<str>,
///         _session_id: Arc<str>,
///         _auth_header: Option<String>,
///         _custom_headers: HashMap<HeaderName, HeaderValue>,
///     ) -> Result<(), rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
///         todo!()
///     }
///
///     async fn get_stream(
///         &self,
///         _uri: Arc<str>,
///         _session_id: Arc<str>,
///         _last_event_id: Option<String>,
///         _auth_header: Option<String>,
///         _custom_headers: HashMap<HeaderName, HeaderValue>,
///     ) -> Result<BoxStream<'static, Result<Sse, SseError>>, rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
///         todo!()
///     }
/// }
///
/// let transport = StreamableHttpClientTransport::with_client(
///     MyHttpClient,
///     StreamableHttpClientTransportConfig::with_uri("http://localhost:8000/mcp")
/// );
/// ```
///
/// # Feature Flags
///
/// - `transport-streamable-http-client`: Base feature providing the generic transport infrastructure
/// - `transport-streamable-http-client-reqwest`: Includes reqwest HTTP client support with convenience methods
pub type StreamableHttpClientTransport<C> = WorkerTransport<StreamableHttpClientWorker<C>>;
973
974impl<C: StreamableHttpClient> StreamableHttpClientTransport<C> {
975    /// Creates a new transport with a custom HTTP client implementation.
976    ///
977    /// This method allows you to use any HTTP client that implements the [`StreamableHttpClient`] trait.
978    /// Use this when you want to use a custom HTTP client or when the reqwest feature is not enabled.
979    ///
980    /// # Arguments
981    ///
982    /// * `client` - Your HTTP client implementation
983    /// * `config` - Transport configuration including the server URI
984    ///
985    /// # Example
986    ///
987    /// ```rust,no_run
988    /// use rmcp::transport::streamable_http_client::{
989    ///     StreamableHttpClient,
990    ///     StreamableHttpClientTransport,
991    ///     StreamableHttpClientTransportConfig
992    /// };
993    /// use std::sync::Arc;
994    /// use std::collections::HashMap;
995    /// use futures::stream::BoxStream;
996    /// use rmcp::model::ClientJsonRpcMessage;
997    /// use http::{HeaderName, HeaderValue};
998    /// use sse_stream::{Sse, Error as SseError};
999    ///
1000    /// // Define your custom client
1001    /// #[derive(Clone)]
1002    /// struct MyHttpClient;
1003    ///
1004    /// #[derive(Debug, thiserror::Error)]
1005    /// struct MyError;
1006    ///
1007    /// impl std::fmt::Display for MyError {
1008    ///     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1009    ///         write!(f, "MyError")
1010    ///     }
1011    /// }
1012    ///
1013    /// impl StreamableHttpClient for MyHttpClient {
1014    ///     type Error = MyError;
1015    ///
1016    ///     async fn post_message(
1017    ///         &self,
1018    ///         _uri: Arc<str>,
1019    ///         _message: ClientJsonRpcMessage,
1020    ///         _session_id: Option<Arc<str>>,
1021    ///         _auth_header: Option<String>,
1022    ///         _custom_headers: HashMap<HeaderName, HeaderValue>,
1023    ///     ) -> Result<rmcp::transport::streamable_http_client::StreamableHttpPostResponse, rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
1024    ///         todo!()
1025    ///     }
1026    ///
1027    ///     async fn delete_session(
1028    ///         &self,
1029    ///         _uri: Arc<str>,
1030    ///         _session_id: Arc<str>,
1031    ///         _auth_header: Option<String>,
1032    ///         _custom_headers: HashMap<HeaderName, HeaderValue>,
1033    ///     ) -> Result<(), rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
1034    ///         todo!()
1035    ///     }
1036    ///
1037    ///     async fn get_stream(
1038    ///         &self,
1039    ///         _uri: Arc<str>,
1040    ///         _session_id: Arc<str>,
1041    ///         _last_event_id: Option<String>,
1042    ///         _auth_header: Option<String>,
1043    ///         _custom_headers: HashMap<HeaderName, HeaderValue>,
1044    ///     ) -> Result<BoxStream<'static, Result<Sse, SseError>>, rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
1045    ///         todo!()
1046    ///     }
1047    /// }
1048    ///
1049    /// let transport = StreamableHttpClientTransport::with_client(
1050    ///     MyHttpClient,
1051    ///     StreamableHttpClientTransportConfig::with_uri("http://localhost:8000/mcp")
1052    /// );
1053    /// ```
1054    pub fn with_client(client: C, config: StreamableHttpClientTransportConfig) -> Self {
1055        let worker = StreamableHttpClientWorker::new(client, config);
1056        WorkerTransport::spawn(worker)
1057    }
1058}
/// Configuration for a [`StreamableHttpClientTransport`].
///
/// Start from [`StreamableHttpClientTransportConfig::with_uri`] (or `Default`) and
/// adjust individual settings with the builder-style methods or by assigning the
/// public fields directly.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct StreamableHttpClientTransportConfig {
    /// URI of the server endpoint; passed as the `uri` argument to the HTTP client's
    /// request methods. Defaults to `"localhost"`.
    pub uri: Arc<str>,
    /// Retry policy, defaulting to `ExponentialBackoff::default()`.
    // NOTE(review): presumably governs reconnection of SSE streams — confirm against
    // the worker, which is where this value is consumed.
    pub retry_config: Arc<dyn SseRetryPolicy>,
    /// Buffer capacity used for the transport's internal channel(s). Defaults to `16`.
    // NOTE(review): which channels are sized by this is not visible here — confirm.
    pub channel_buffer_capacity: usize,
    /// if true, the transport will not require a session to be established
    /// (defaults to `true`)
    pub allow_stateless: bool,
    /// The value to send in the authorization header
    pub auth_header: Option<String>,
    /// Custom HTTP headers to include with every request
    pub custom_headers: HashMap<HeaderName, HeaderValue>,
    /// Enables transparent recovery when the server reports an expired session (`HTTP 404`).
    ///
    /// When enabled, the transport performs one automatic recovery attempt:
    /// 1. Replays the original `initialize` handshake to create a new session.
    /// 2. Re-establishes streaming state for that session.
    /// 3. Retries the in-flight request that failed with `SessionExpired`.
    ///
    /// This recovery is best-effort and bounded to a single attempt. If recovery fails,
    /// the original failure path is preserved and the error is returned to the caller.
    ///
    /// Defaults to `true`.
    pub reinit_on_expired_session: bool,
}
1082
1083impl StreamableHttpClientTransportConfig {
1084    pub fn with_uri(uri: impl Into<Arc<str>>) -> Self {
1085        Self {
1086            uri: uri.into(),
1087            ..Default::default()
1088        }
1089    }
1090
1091    /// Set the authorization header to send with requests
1092    ///
1093    /// # Arguments
1094    ///
1095    /// * `value` - A bearer token without the `Bearer ` prefix
1096    pub fn auth_header<T: Into<String>>(mut self, value: T) -> Self {
1097        // set our authorization header
1098        self.auth_header = Some(value.into());
1099        self
1100    }
1101
1102    /// Set custom HTTP headers to include with every request
1103    ///
1104    /// # Arguments
1105    ///
1106    /// * `custom_headers` - A HashMap of header names to header values
1107    ///
1108    /// # Example
1109    ///
1110    /// ```rust,no_run
1111    /// use std::collections::HashMap;
1112    /// use http::{HeaderName, HeaderValue};
1113    /// use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
1114    ///
1115    /// let mut headers = HashMap::new();
1116    /// headers.insert(
1117    ///     HeaderName::from_static("x-custom-header"),
1118    ///     HeaderValue::from_static("custom-value")
1119    /// );
1120    ///
1121    /// let config = StreamableHttpClientTransportConfig::with_uri("http://localhost:8000")
1122    ///     .custom_headers(headers);
1123    /// ```
1124    pub fn custom_headers(mut self, custom_headers: HashMap<HeaderName, HeaderValue>) -> Self {
1125        self.custom_headers = custom_headers;
1126        self
1127    }
1128
1129    /// Set whether the transport should attempt transparent re-initialization on session expiration
1130    /// See [`Self::reinit_on_expired_session`] for details.
1131    /// # Example
1132    /// ```rust,no_run
1133    /// use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
1134    /// let config = StreamableHttpClientTransportConfig::with_uri("http://localhost:8000")
1135    ///     .reinit_on_expired_session(true);
1136    /// ```
1137    pub fn reinit_on_expired_session(mut self, enable: bool) -> Self {
1138        self.reinit_on_expired_session = enable;
1139        self
1140    }
1141}
1142
1143impl Default for StreamableHttpClientTransportConfig {
1144    fn default() -> Self {
1145        Self {
1146            uri: "localhost".into(),
1147            retry_config: Arc::new(ExponentialBackoff::default()),
1148            channel_buffer_capacity: 16,
1149            allow_stateless: true,
1150            auth_header: None,
1151            custom_headers: HashMap::new(),
1152            reinit_on_expired_session: true,
1153        }
1154    }
1155}