Skip to main content

rmcp_soddygo/transport/
streamable_http_client.rs

1use std::{borrow::Cow, collections::HashMap, sync::Arc, time::Duration};
2
3use futures::{Stream, StreamExt, future::BoxFuture, stream::BoxStream};
4use http::{HeaderName, HeaderValue};
5pub use sse_stream::Error as SseError;
6use sse_stream::Sse;
7use thiserror::Error;
8use tokio_util::sync::CancellationToken;
9use tracing::debug;
10
11use super::common::client_side_sse::{ExponentialBackoff, SseRetryPolicy, SseStreamReconnect};
12use crate::{
13    RoleClient,
14    model::{ClientJsonRpcMessage, ServerJsonRpcMessage, ServerResult},
15    transport::{
16        common::client_side_sse::SseAutoReconnectStream,
17        worker::{Worker, WorkerQuitReason, WorkerSendRequest, WorkerTransport},
18    },
19};
20
21type BoxedSseStream = BoxStream<'static, Result<Sse, SseError>>;
22
/// Payload of [`StreamableHttpError::AuthRequired`].
///
/// NOTE(review): presumably built from the `WWW-Authenticate` header of a 401
/// response by the client implementation — confirm against the HTTP client.
#[derive(Debug)]
pub struct AuthRequiredError {
    /// Raw `WWW-Authenticate` header value as received from the server.
    pub www_authenticate_header: String,
}

/// Payload of [`StreamableHttpError::InsufficientScope`].
#[derive(Debug)]
pub struct InsufficientScopeError {
    /// Raw `WWW-Authenticate` header value as received from the server.
    pub www_authenticate_header: String,
    /// Scope the server asked for, when it named one.
    pub required_scope: Option<String>,
}

impl InsufficientScopeError {
    /// Returns `true` when a scope upgrade can be attempted, i.e. the server
    /// told us which scope it requires.
    pub fn can_upgrade(&self) -> bool {
        self.get_required_scope().is_some()
    }

    /// Borrow the scope required for an upgrade, if the server named one.
    pub fn get_required_scope(&self) -> Option<&str> {
        self.required_scope.as_ref().map(String::as_str)
    }
}
45
/// Errors produced by the streamable HTTP client transport.
///
/// Generic over `E`, the error type of the underlying [`StreamableHttpClient`]
/// implementation, which is carried by the [`StreamableHttpError::Client`] variant.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum StreamableHttpError<E: std::error::Error + Send + Sync + 'static> {
    /// Decoding an SSE event stream failed.
    #[error("SSE error: {0}")]
    Sse(#[from] SseError),
    /// An I/O error.
    #[error("Io error: {0}")]
    Io(#[from] std::io::Error),
    /// Error reported by the HTTP client implementation itself.
    #[error("Client error: {0}")]
    Client(E),
    /// A stream ended where more data was expected. In this file it is the
    /// terminal error for SSE response streams that cannot be resumed because
    /// the transport runs without a session id.
    #[error("unexpected end of stream")]
    UnexpectedEndOfStream,
    /// The server's response did not have the shape the caller expected
    /// (see the `expect_*` helpers on [`StreamableHttpPostResponse`]).
    #[error("unexpected server response: {0}")]
    UnexpectedServerResponse(Cow<'static, str>),
    /// The response carried a content type the client cannot handle; holds the
    /// received value, if any.
    #[error("Unexpected content type: {0:?}")]
    UnexpectedContentType(Option<String>),
    /// The server does not offer the standalone GET SSE stream. The worker
    /// treats this as non-fatal and simply skips that stream.
    #[error("Server does not support SSE")]
    ServerDoesNotSupportSse,
    /// The server does not support session deletion. The worker logs this at
    /// info level during cleanup instead of treating it as a failure.
    #[error("Server does not support delete session")]
    ServerDoesNotSupportDeleteSession,
    /// A spawned tokio task could not be joined.
    #[error("Tokio join error: {0}")]
    TokioJoinError(#[from] tokio::task::JoinError),
    /// A JSON-RPC payload could not be deserialized.
    #[error("Deserialize error: {0}")]
    Deserialize(#[from] serde_json::Error),
    /// An internal transport channel was closed. The worker also reports this
    /// as its quit reason when the initialize POST fails (the original error
    /// is delivered to the caller through the request's responder).
    #[error("Transport channel closed")]
    TransportChannelClosed,
    /// The initialize response had no session id and stateless mode
    /// (`allow_stateless`) is disabled.
    #[error("Missing session id in HTTP response")]
    MissingSessionIdInResponse,
    /// Error from the `auth` module (only with the `auth` feature).
    #[cfg(feature = "auth")]
    #[error("Auth error: {0}")]
    Auth(#[from] crate::transport::auth::AuthError),
    /// The server requires authentication; see [`AuthRequiredError`].
    #[error("Auth required")]
    AuthRequired(AuthRequiredError),
    /// The presented credentials lack a required scope; see [`InsufficientScopeError`].
    #[error("Insufficient scope")]
    InsufficientScope(InsufficientScopeError),
    /// A user-supplied custom header collides with a header the transport sets itself.
    #[error("Header name '{0}' is reserved and conflicts with default headers")]
    ReservedHeaderConflict(String),
}
83
/// Cloneable protocol-level error for the streamable HTTP transport
/// (independent of the HTTP client's error type).
#[derive(Debug, Clone, Error)]
#[non_exhaustive]
pub enum StreamableHttpProtocolError {
    /// The server's response did not include a session id.
    #[error("Missing session id in response")]
    MissingSessionIdInResponse,
}
90
/// Outcome of POSTing one JSON-RPC message to the server.
///
/// The `Option<String>` in the `Json` and `Sse` variants is the session id
/// reported alongside the response, if any (presumably read from the
/// `Mcp-Session-Id` response header by the client implementation — confirm).
// `Json` inlines a full server message and is much larger than the other
// variants; this is accepted rather than boxing the message.
#[allow(clippy::large_enum_variant)]
#[non_exhaustive]
pub enum StreamableHttpPostResponse {
    /// The server accepted the message without returning a body.
    Accepted,
    /// The server answered with a single JSON-RPC message in the body.
    Json(ServerJsonRpcMessage, Option<String>),
    /// The server answered with an SSE stream of JSON-RPC messages.
    Sse(BoxedSseStream, Option<String>),
}
98
99impl std::fmt::Debug for StreamableHttpPostResponse {
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        match self {
102            Self::Accepted => write!(f, "Accepted"),
103            Self::Json(arg0, arg1) => f.debug_tuple("Json").field(arg0).field(arg1).finish(),
104            Self::Sse(_, arg1) => f.debug_tuple("Sse").field(arg1).finish(),
105        }
106    }
107}
108
109impl StreamableHttpPostResponse {
110    pub async fn expect_initialized<E>(
111        self,
112    ) -> Result<(ServerJsonRpcMessage, Option<String>), StreamableHttpError<E>>
113    where
114        E: std::error::Error + Send + Sync + 'static,
115    {
116        match self {
117            Self::Json(message, session_id) => Ok((message, session_id)),
118            Self::Sse(mut stream, session_id) => {
119                while let Some(event) = stream.next().await {
120                    let event = event?;
121                    let payload = event.data.unwrap_or_default();
122                    if payload.trim().is_empty() {
123                        continue;
124                    }
125
126                    let message: ServerJsonRpcMessage = serde_json::from_str(&payload)?;
127
128                    if matches!(message, ServerJsonRpcMessage::Response(_)) {
129                        return Ok((message, session_id));
130                    }
131
132                    debug!(
133                        ?message,
134                        "received message before initialize response; continuing to drain stream"
135                    );
136                }
137
138                Err(StreamableHttpError::UnexpectedServerResponse(
139                    "empty sse stream".into(),
140                ))
141            }
142            _ => Err(StreamableHttpError::UnexpectedServerResponse(
143                "expect initialized, accepted".into(),
144            )),
145        }
146    }
147
148    pub fn expect_json<E>(self) -> Result<ServerJsonRpcMessage, StreamableHttpError<E>>
149    where
150        E: std::error::Error + Send + Sync + 'static,
151    {
152        match self {
153            Self::Json(message, ..) => Ok(message),
154            got => Err(StreamableHttpError::UnexpectedServerResponse(
155                format!("expect json, got {got:?}").into(),
156            )),
157        }
158    }
159
160    pub fn expect_accepted_or_json<E>(self) -> Result<(), StreamableHttpError<E>>
161    where
162        E: std::error::Error + Send + Sync + 'static,
163    {
164        match self {
165            Self::Accepted => Ok(()),
166            // Tolerate servers that return 200 with JSON for notifications
167            Self::Json(..) => Ok(()),
168            got => Err(StreamableHttpError::UnexpectedServerResponse(
169                format!("expect accepted or json, got {got:?}").into(),
170            )),
171        }
172    }
173}
174
/// HTTP operations needed by the streamable HTTP client transport.
///
/// Implement this trait to plug an arbitrary HTTP client into
/// [`StreamableHttpClientTransport`]. All returned futures must be `Send`.
///
/// The `custom_headers` map passed to each method is the user-configured header
/// set. After initialization, the worker adds an `mcp-protocol-version` entry
/// taken from the server's initialize result.
pub trait StreamableHttpClient: Clone + Send + 'static {
    /// Error type of the underlying HTTP client, carried by
    /// [`StreamableHttpError::Client`].
    type Error: std::error::Error + Send + Sync + 'static;
    /// POST one JSON-RPC `message` to `uri` and classify the reply as
    /// [`StreamableHttpPostResponse`].
    ///
    /// The worker passes `session_id = None` for the `initialize` request. In
    /// stateless mode (no session issued) it also passes `None` for every
    /// later request. `auth_header` is the configured authorization value, if
    /// any; how it is encoded on the wire is up to the implementation.
    fn post_message(
        &self,
        uri: Arc<str>,
        message: ClientJsonRpcMessage,
        session_id: Option<Arc<str>>,
        auth_header: Option<String>,
        custom_headers: HashMap<HeaderName, HeaderValue>,
    ) -> impl Future<Output = Result<StreamableHttpPostResponse, StreamableHttpError<Self::Error>>>
    + Send
    + '_;
    /// Terminate the server-side session `session_id`.
    ///
    /// Return [`StreamableHttpError::ServerDoesNotSupportDeleteSession`] when
    /// the server does not support this. The worker logs that case at info
    /// level rather than as an error.
    fn delete_session(
        &self,
        uri: Arc<str>,
        session_id: Arc<str>,
        auth_header: Option<String>,
        custom_headers: HashMap<HeaderName, HeaderValue>,
    ) -> impl Future<Output = Result<(), StreamableHttpError<Self::Error>>> + Send + '_;
    /// Open the standalone GET SSE stream for `session_id`.
    ///
    /// `last_event_id` is `None` on the first connection and carries the last
    /// seen event id when the auto-reconnecting stream retries. Return
    /// [`StreamableHttpError::ServerDoesNotSupportSse`] if the server offers no
    /// such stream. The worker then runs without it.
    fn get_stream(
        &self,
        uri: Arc<str>,
        session_id: Arc<str>,
        last_event_id: Option<String>,
        auth_header: Option<String>,
        custom_headers: HashMap<HeaderName, HeaderValue>,
    ) -> impl Future<
        Output = Result<
            BoxStream<'static, Result<Sse, SseError>>,
            StreamableHttpError<Self::Error>,
        >,
    > + Send
    + '_;
}
209
/// Simple retry settings: an optional cap on attempts and a minimum delay.
///
/// Derives the standard value traits so configurations can be logged,
/// duplicated and compared (previously the type implemented none of them).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryConfig {
    /// Maximum number of retries; `None` presumably means "no limit" — confirm
    /// against the consumer of this type.
    pub max_times: Option<usize>,
    /// Minimum delay between attempts.
    pub min_duration: Duration,
}
214
215struct StreamableHttpClientReconnect<C> {
216    pub client: C,
217    pub session_id: Arc<str>,
218    pub uri: Arc<str>,
219    pub auth_header: Option<String>,
220    pub custom_headers: HashMap<HeaderName, HeaderValue>,
221}
222
223impl<C: StreamableHttpClient> SseStreamReconnect for StreamableHttpClientReconnect<C> {
224    type Error = StreamableHttpError<C::Error>;
225    type Future = BoxFuture<'static, Result<BoxedSseStream, Self::Error>>;
226    fn retry_connection(&mut self, last_event_id: Option<&str>) -> Self::Future {
227        let client = self.client.clone();
228        let uri = self.uri.clone();
229        let session_id = self.session_id.clone();
230        let auth_header = self.auth_header.clone();
231        let custom_headers = self.custom_headers.clone();
232        let last_event_id = last_event_id.map(|s| s.to_owned());
233        Box::pin(async move {
234            client
235                .get_stream(uri, session_id, last_event_id, auth_header, custom_headers)
236                .await
237        })
238    }
239}
240
241/// Info retained for cleaning up the session when the worker exits.
242struct SessionCleanupInfo<C> {
243    client: C,
244    uri: Arc<str>,
245    session_id: Arc<str>,
246    auth_header: Option<String>,
247    protocol_headers: HashMap<HeaderName, HeaderValue>,
248}
249
250#[derive(Debug, Clone, Default)]
251pub struct StreamableHttpClientWorker<C: StreamableHttpClient> {
252    pub client: C,
253    pub config: StreamableHttpClientTransportConfig,
254}
255
256impl<C: StreamableHttpClient + Default> StreamableHttpClientWorker<C> {
257    pub fn new_simple(url: impl Into<Arc<str>>) -> Self {
258        Self {
259            client: C::default(),
260            config: StreamableHttpClientTransportConfig {
261                uri: url.into(),
262                ..Default::default()
263            },
264        }
265    }
266}
267
268impl<C: StreamableHttpClient> StreamableHttpClientWorker<C> {
269    pub fn new(client: C, config: StreamableHttpClientTransportConfig) -> Self {
270        Self { client, config }
271    }
272}
273
274impl<C: StreamableHttpClient> StreamableHttpClientWorker<C> {
275    async fn execute_sse_stream(
276        sse_stream: impl Stream<Item = Result<ServerJsonRpcMessage, StreamableHttpError<C::Error>>>
277        + Send
278        + 'static,
279        sse_worker_tx: tokio::sync::mpsc::Sender<ServerJsonRpcMessage>,
280        close_on_response: bool,
281        ct: CancellationToken,
282    ) -> Result<(), StreamableHttpError<C::Error>> {
283        let mut sse_stream = std::pin::pin!(sse_stream);
284        loop {
285            let message = tokio::select! {
286                event = sse_stream.next() => {
287                    event
288                }
289                _ = ct.cancelled() => {
290                    tracing::debug!("cancelled");
291                    break;
292                }
293            };
294            let Some(message) = message.transpose()? else {
295                break;
296            };
297            let is_response = matches!(message, ServerJsonRpcMessage::Response(_));
298            let yield_result = sse_worker_tx.send(message).await;
299            if yield_result.is_err() {
300                tracing::trace!("streamable http transport worker dropped, exiting");
301                break;
302            }
303            if close_on_response && is_response {
304                tracing::debug!("got response, closing sse stream");
305                break;
306            }
307        }
308        Ok(())
309    }
310}
311
312impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
313    type Role = RoleClient;
314    type Error = StreamableHttpError<C::Error>;
315    fn err_closed() -> Self::Error {
316        StreamableHttpError::TransportChannelClosed
317    }
318    fn err_join(e: tokio::task::JoinError) -> Self::Error {
319        StreamableHttpError::TokioJoinError(e)
320    }
321    fn config(&self) -> super::worker::WorkerConfig {
322        super::worker::WorkerConfig {
323            name: Some("StreamableHttpClientWorker".into()),
324            channel_buffer_capacity: self.config.channel_buffer_capacity,
325        }
326    }
327    async fn run(
328        self,
329        mut context: super::worker::WorkerContext<Self>,
330    ) -> Result<(), WorkerQuitReason<Self::Error>> {
331        let channel_buffer_capacity = self.config.channel_buffer_capacity;
332        let (sse_worker_tx, mut sse_worker_rx) =
333            tokio::sync::mpsc::channel::<ServerJsonRpcMessage>(channel_buffer_capacity);
334        let config = self.config.clone();
335        let transport_task_ct = context.cancellation_token.clone();
336        let _drop_guard = transport_task_ct.clone().drop_guard();
337        let WorkerSendRequest {
338            responder,
339            message: initialize_request,
340        } = context.recv_from_handler().await?;
341        let (message, session_id) = match self
342            .client
343            .post_message(
344                config.uri.clone(),
345                initialize_request,
346                None,
347                self.config.auth_header,
348                self.config.custom_headers,
349            )
350            .await
351        {
352            Ok(res) => {
353                let _ = responder.send(Ok(()));
354                res.expect_initialized::<C::Error>().await.map_err(
355                    WorkerQuitReason::fatal_context("process initialize response"),
356                )?
357            }
358            Err(err) => {
359                let msg = format!("{:?}", err);
360                let _ = responder.send(Err(err));
361                return Err(WorkerQuitReason::fatal(
362                    StreamableHttpError::TransportChannelClosed,
363                    msg,
364                ));
365            }
366        };
367        let session_id: Option<Arc<str>> = if let Some(session_id) = session_id {
368            Some(session_id.into())
369        } else {
370            if !self.config.allow_stateless {
371                return Err(WorkerQuitReason::fatal(
372                    StreamableHttpError::<C::Error>::MissingSessionIdInResponse,
373                    "process initialize response",
374                ));
375            }
376            None
377        };
378        // Extract the negotiated protocol version from the init response
379        // and build a custom headers map that includes MCP-Protocol-Version
380        // for all subsequent HTTP requests (per MCP 2025-06-18 spec).
381        let protocol_headers = {
382            let mut headers = config.custom_headers.clone();
383            if let ServerJsonRpcMessage::Response(response) = &message {
384                if let ServerResult::InitializeResult(init_result) = &response.result {
385                    if let Ok(hv) = HeaderValue::from_str(init_result.protocol_version.as_str()) {
386                        // HeaderName::from_static requires lowercase
387                        headers.insert(HeaderName::from_static("mcp-protocol-version"), hv);
388                    }
389                }
390            }
391            headers
392        };
393
394        // Store session info for cleanup when run() exits (not spawned, so cleanup completes before close() returns)
395        let session_cleanup_info = session_id.as_ref().map(|sid| SessionCleanupInfo {
396            client: self.client.clone(),
397            uri: config.uri.clone(),
398            session_id: sid.clone(),
399            auth_header: config.auth_header.clone(),
400            protocol_headers: protocol_headers.clone(),
401        });
402
403        context.send_to_handler(message).await?;
404        let initialized_notification = context.recv_from_handler().await?;
405        // expect a initialized response
406        self.client
407            .post_message(
408                config.uri.clone(),
409                initialized_notification.message,
410                session_id.clone(),
411                config.auth_header.clone(),
412                protocol_headers.clone(),
413            )
414            .await
415            .map_err(WorkerQuitReason::fatal_context(
416                "send initialized notification",
417            ))?
418            .expect_accepted_or_json::<C::Error>()
419            .map_err(WorkerQuitReason::fatal_context(
420                "process initialized notification response",
421            ))?;
422        let _ = initialized_notification.responder.send(Ok(()));
423        #[allow(clippy::large_enum_variant)]
424        enum Event<W: Worker, E: std::error::Error + Send + Sync + 'static> {
425            ClientMessage(WorkerSendRequest<W>),
426            ServerMessage(ServerJsonRpcMessage),
427            StreamResult(Result<(), StreamableHttpError<E>>),
428        }
429        let mut streams = tokio::task::JoinSet::new();
430        if let Some(session_id) = &session_id {
431            let client = self.client.clone();
432            let uri = config.uri.clone();
433            let session_id = session_id.clone();
434            let auth_header = config.auth_header.clone();
435            let retry_config = self.config.retry_config.clone();
436            let sse_worker_tx = sse_worker_tx.clone();
437            let transport_task_ct = transport_task_ct.clone();
438            let config_uri = config.uri.clone();
439            let config_auth_header = config.auth_header.clone();
440            let spawn_headers = protocol_headers.clone();
441
442            streams.spawn(async move {
443                match client
444                    .get_stream(
445                        uri.clone(),
446                        session_id.clone(),
447                        None,
448                        auth_header.clone(),
449                        spawn_headers.clone(),
450                    )
451                    .await
452                {
453                    Ok(stream) => {
454                        let sse_stream = SseAutoReconnectStream::new(
455                            stream,
456                            StreamableHttpClientReconnect {
457                                client: client.clone(),
458                                session_id: session_id.clone(),
459                                uri: config_uri,
460                                auth_header: config_auth_header,
461                                custom_headers: spawn_headers,
462                            },
463                            retry_config,
464                        );
465                        Self::execute_sse_stream(
466                            sse_stream,
467                            sse_worker_tx,
468                            false,
469                            transport_task_ct.child_token(),
470                        )
471                        .await
472                    }
473                    Err(StreamableHttpError::ServerDoesNotSupportSse) => {
474                        tracing::debug!("server doesn't support sse, skip common stream");
475                        Ok(())
476                    }
477                    Err(e) => {
478                        // fail to get common stream
479                        tracing::error!("fail to get common stream: {e}");
480                        Err(e)
481                    }
482                }
483            });
484        }
485        // Main event loop - capture exit reason so we can do cleanup before returning
486        let loop_result: Result<(), WorkerQuitReason<Self::Error>> = 'main_loop: loop {
487            let event = tokio::select! {
488                _ = transport_task_ct.cancelled() => {
489                    tracing::debug!("cancelled");
490                    break 'main_loop Err(WorkerQuitReason::Cancelled);
491                }
492                message = context.recv_from_handler() => {
493                    match message {
494                        Ok(msg) => Event::ClientMessage(msg),
495                        Err(e) => break 'main_loop Err(e),
496                    }
497                },
498                message = sse_worker_rx.recv() => {
499                    let Some(message) = message else {
500                        tracing::trace!("transport dropped, exiting");
501                        break 'main_loop Err(WorkerQuitReason::HandlerTerminated);
502                    };
503                    Event::ServerMessage(message)
504                },
505                terminated_stream = streams.join_next(), if !streams.is_empty() => {
506                    match terminated_stream {
507                        Some(result) => {
508                            Event::StreamResult(result.map_err(StreamableHttpError::TokioJoinError).and_then(std::convert::identity))
509                        }
510                        None => {
511                            continue
512                        }
513                    }
514                }
515            };
516            match event {
517                Event::ClientMessage(send_request) => {
518                    let WorkerSendRequest { message, responder } = send_request;
519                    let response = self
520                        .client
521                        .post_message(
522                            config.uri.clone(),
523                            message,
524                            session_id.clone(),
525                            config.auth_header.clone(),
526                            protocol_headers.clone(),
527                        )
528                        .await;
529                    let send_result = match response {
530                        Err(e) => Err(e),
531                        Ok(StreamableHttpPostResponse::Accepted) => {
532                            tracing::trace!("client message accepted");
533                            Ok(())
534                        }
535                        Ok(StreamableHttpPostResponse::Json(message, ..)) => {
536                            context.send_to_handler(message).await?;
537                            Ok(())
538                        }
539                        Ok(StreamableHttpPostResponse::Sse(stream, ..)) => {
540                            if let Some(session_id) = &session_id {
541                                let sse_stream = SseAutoReconnectStream::new(
542                                    stream,
543                                    StreamableHttpClientReconnect {
544                                        client: self.client.clone(),
545                                        session_id: session_id.clone(),
546                                        uri: config.uri.clone(),
547                                        auth_header: config.auth_header.clone(),
548                                        custom_headers: protocol_headers.clone(),
549                                    },
550                                    self.config.retry_config.clone(),
551                                );
552                                streams.spawn(Self::execute_sse_stream(
553                                    sse_stream,
554                                    sse_worker_tx.clone(),
555                                    true,
556                                    transport_task_ct.child_token(),
557                                ));
558                            } else {
559                                let sse_stream = SseAutoReconnectStream::never_reconnect(
560                                    stream,
561                                    StreamableHttpError::<C::Error>::UnexpectedEndOfStream,
562                                );
563                                streams.spawn(Self::execute_sse_stream(
564                                    sse_stream,
565                                    sse_worker_tx.clone(),
566                                    true,
567                                    transport_task_ct.child_token(),
568                                ));
569                            }
570                            tracing::trace!("got new sse stream");
571                            Ok(())
572                        }
573                    };
574                    let _ = responder.send(send_result);
575                }
576                Event::ServerMessage(json_rpc_message) => {
577                    // send the message to the handler
578                    if let Err(e) = context.send_to_handler(json_rpc_message).await {
579                        break 'main_loop Err(e);
580                    }
581                }
582                Event::StreamResult(result) => {
583                    if result.is_err() {
584                        tracing::warn!(
585                            "sse client event stream terminated with error: {:?}",
586                            result
587                        );
588                    }
589                }
590            }
591        };
592
593        // Cleanup session before returning (ensures close() waits for session deletion)
594        // Use a timeout to prevent indefinite hangs if the server is unresponsive
595        if let Some(cleanup) = session_cleanup_info {
596            const SESSION_CLEANUP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);
597            let cleanup_session_id = cleanup.session_id.clone();
598            match tokio::time::timeout(
599                SESSION_CLEANUP_TIMEOUT,
600                cleanup.client.delete_session(
601                    cleanup.uri,
602                    cleanup.session_id,
603                    cleanup.auth_header,
604                    cleanup.protocol_headers,
605                ),
606            )
607            .await
608            {
609                Ok(Ok(_)) => {
610                    tracing::info!(
611                        session_id = cleanup_session_id.as_ref(),
612                        "delete session success"
613                    )
614                }
615                Ok(Err(StreamableHttpError::ServerDoesNotSupportDeleteSession)) => {
616                    tracing::info!(
617                        session_id = cleanup_session_id.as_ref(),
618                        "server doesn't support delete session"
619                    )
620                }
621                Ok(Err(e)) => {
622                    tracing::error!(
623                        session_id = cleanup_session_id.as_ref(),
624                        "fail to delete session: {e}"
625                    );
626                }
627                Err(_elapsed) => {
628                    tracing::warn!(
629                        session_id = cleanup_session_id.as_ref(),
630                        "session cleanup timed out after {:?}",
631                        SESSION_CLEANUP_TIMEOUT
632                    );
633                }
634            }
635        }
636
637        loop_result
638    }
639}
640
/// A client-agnostic HTTP transport for RMCP that supports streaming responses.
///
/// This transport allows you to choose your preferred HTTP client implementation
/// by implementing the [`StreamableHttpClient`] trait. The transport handles
/// session management, SSE streaming, and automatic reconnection.
///
/// # Lifecycle
///
/// The transport is a [`WorkerTransport`] driving a [`StreamableHttpClientWorker`]:
///
/// 1. The first outgoing message (the `initialize` request) is POSTed without a
///    session id. The server's answer may carry one. A missing session id is an
///    error unless stateless mode (`allow_stateless`) is enabled.
/// 2. Later requests carry the session id (if any), the configured custom
///    headers, and an `MCP-Protocol-Version` header taken from the initialize result.
/// 3. With a session, a standalone GET SSE stream is opened for server-initiated
///    messages and reconnected according to the configured retry policy.
/// 4. On shutdown the worker attempts to DELETE the session, waiting at most
///    5 seconds for the server.
///
/// # Usage
///
/// ## Using reqwest
///
/// ```rust,no_run
/// use rmcp::transport::StreamableHttpClientTransport;
///
/// // Enable the reqwest feature in Cargo.toml:
/// // rmcp = { version = "0.5", features = ["transport-streamable-http-client-reqwest"] }
///
/// let transport = StreamableHttpClientTransport::from_uri("http://localhost:8000/mcp");
/// ```
///
/// ## Using a custom HTTP client
///
/// ```rust,no_run
/// use rmcp::transport::streamable_http_client::{
///     StreamableHttpClient,
///     StreamableHttpClientTransport,
///     StreamableHttpClientTransportConfig
/// };
/// use std::sync::Arc;
/// use std::collections::HashMap;
/// use futures::stream::BoxStream;
/// use rmcp::model::ClientJsonRpcMessage;
/// use http::{HeaderName, HeaderValue};
/// use sse_stream::{Sse, Error as SseError};
///
/// #[derive(Clone)]
/// struct MyHttpClient;
///
/// #[derive(Debug, thiserror::Error)]
/// struct MyError;
///
/// impl std::fmt::Display for MyError {
///     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
///         write!(f, "MyError")
///     }
/// }
///
/// impl StreamableHttpClient for MyHttpClient {
///     type Error = MyError;
///
///     async fn post_message(
///         &self,
///         _uri: Arc<str>,
///         _message: ClientJsonRpcMessage,
///         _session_id: Option<Arc<str>>,
///         _auth_header: Option<String>,
///         _custom_headers: HashMap<HeaderName, HeaderValue>,
///     ) -> Result<rmcp::transport::streamable_http_client::StreamableHttpPostResponse, rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
///         todo!()
///     }
///
///     async fn delete_session(
///         &self,
///         _uri: Arc<str>,
///         _session_id: Arc<str>,
///         _auth_header: Option<String>,
///         _custom_headers: HashMap<HeaderName, HeaderValue>,
///     ) -> Result<(), rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
///         todo!()
///     }
///
///     async fn get_stream(
///         &self,
///         _uri: Arc<str>,
///         _session_id: Arc<str>,
///         _last_event_id: Option<String>,
///         _auth_header: Option<String>,
///         _custom_headers: HashMap<HeaderName, HeaderValue>,
///     ) -> Result<BoxStream<'static, Result<Sse, SseError>>, rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
///         todo!()
///     }
/// }
///
/// let transport = StreamableHttpClientTransport::with_client(
///     MyHttpClient,
///     StreamableHttpClientTransportConfig::with_uri("http://localhost:8000/mcp")
/// );
/// ```
///
/// # Feature Flags
///
/// - `transport-streamable-http-client`: Base feature providing the generic transport infrastructure
/// - `transport-streamable-http-client-reqwest`: Includes reqwest HTTP client support with convenience methods
pub type StreamableHttpClientTransport<C> = WorkerTransport<StreamableHttpClientWorker<C>>;
734
735impl<C: StreamableHttpClient> StreamableHttpClientTransport<C> {
736    /// Creates a new transport with a custom HTTP client implementation.
737    ///
738    /// This method allows you to use any HTTP client that implements the [`StreamableHttpClient`] trait.
739    /// Use this when you want to use a custom HTTP client or when the reqwest feature is not enabled.
740    ///
741    /// # Arguments
742    ///
743    /// * `client` - Your HTTP client implementation
744    /// * `config` - Transport configuration including the server URI
745    ///
746    /// # Example
747    ///
748    /// ```rust,no_run
749    /// use rmcp::transport::streamable_http_client::{
750    ///     StreamableHttpClient,
751    ///     StreamableHttpClientTransport,
752    ///     StreamableHttpClientTransportConfig
753    /// };
754    /// use std::sync::Arc;
755    /// use std::collections::HashMap;
756    /// use futures::stream::BoxStream;
757    /// use rmcp::model::ClientJsonRpcMessage;
758    /// use http::{HeaderName, HeaderValue};
759    /// use sse_stream::{Sse, Error as SseError};
760    ///
761    /// // Define your custom client
762    /// #[derive(Clone)]
763    /// struct MyHttpClient;
764    ///
765    /// #[derive(Debug, thiserror::Error)]
766    /// struct MyError;
767    ///
768    /// impl std::fmt::Display for MyError {
769    ///     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
770    ///         write!(f, "MyError")
771    ///     }
772    /// }
773    ///
774    /// impl StreamableHttpClient for MyHttpClient {
775    ///     type Error = MyError;
776    ///
777    ///     async fn post_message(
778    ///         &self,
779    ///         _uri: Arc<str>,
780    ///         _message: ClientJsonRpcMessage,
781    ///         _session_id: Option<Arc<str>>,
782    ///         _auth_header: Option<String>,
783    ///         _custom_headers: HashMap<HeaderName, HeaderValue>,
784    ///     ) -> Result<rmcp::transport::streamable_http_client::StreamableHttpPostResponse, rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
785    ///         todo!()
786    ///     }
787    ///
788    ///     async fn delete_session(
789    ///         &self,
790    ///         _uri: Arc<str>,
791    ///         _session_id: Arc<str>,
792    ///         _auth_header: Option<String>,
793    ///         _custom_headers: HashMap<HeaderName, HeaderValue>,
794    ///     ) -> Result<(), rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
795    ///         todo!()
796    ///     }
797    ///
798    ///     async fn get_stream(
799    ///         &self,
800    ///         _uri: Arc<str>,
801    ///         _session_id: Arc<str>,
802    ///         _last_event_id: Option<String>,
803    ///         _auth_header: Option<String>,
804    ///         _custom_headers: HashMap<HeaderName, HeaderValue>,
805    ///     ) -> Result<BoxStream<'static, Result<Sse, SseError>>, rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
806    ///         todo!()
807    ///     }
808    /// }
809    ///
810    /// let transport = StreamableHttpClientTransport::with_client(
811    ///     MyHttpClient,
812    ///     StreamableHttpClientTransportConfig::with_uri("http://localhost:8000/mcp")
813    /// );
814    /// ```
815    pub fn with_client(client: C, config: StreamableHttpClientTransportConfig) -> Self {
816        let worker = StreamableHttpClientWorker::new(client, config);
817        WorkerTransport::spawn(worker)
818    }
819}
820#[derive(Debug, Clone)]
821pub struct StreamableHttpClientTransportConfig {
822    pub uri: Arc<str>,
823    pub retry_config: Arc<dyn SseRetryPolicy>,
824    pub channel_buffer_capacity: usize,
825    /// if true, the transport will not require a session to be established
826    pub allow_stateless: bool,
827    /// The value to send in the authorization header
828    pub auth_header: Option<String>,
829    /// Custom HTTP headers to include with every request
830    pub custom_headers: HashMap<HeaderName, HeaderValue>,
831}
832
833impl StreamableHttpClientTransportConfig {
834    pub fn with_uri(uri: impl Into<Arc<str>>) -> Self {
835        Self {
836            uri: uri.into(),
837            ..Default::default()
838        }
839    }
840
841    /// Set the authorization header to send with requests
842    ///
843    /// # Arguments
844    ///
845    /// * `value` - A bearer token without the `Bearer ` prefix
846    pub fn auth_header<T: Into<String>>(mut self, value: T) -> Self {
847        // set our authorization header
848        self.auth_header = Some(value.into());
849        self
850    }
851
852    /// Set custom HTTP headers to include with every request
853    ///
854    /// # Arguments
855    ///
856    /// * `custom_headers` - A HashMap of header names to header values
857    ///
858    /// # Example
859    ///
860    /// ```rust,no_run
861    /// use std::collections::HashMap;
862    /// use http::{HeaderName, HeaderValue};
863    /// use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
864    ///
865    /// let mut headers = HashMap::new();
866    /// headers.insert(
867    ///     HeaderName::from_static("x-custom-header"),
868    ///     HeaderValue::from_static("custom-value")
869    /// );
870    ///
871    /// let config = StreamableHttpClientTransportConfig::with_uri("http://localhost:8000")
872    ///     .custom_headers(headers);
873    /// ```
874    pub fn custom_headers(mut self, custom_headers: HashMap<HeaderName, HeaderValue>) -> Self {
875        self.custom_headers = custom_headers;
876        self
877    }
878}
879
880impl Default for StreamableHttpClientTransportConfig {
881    fn default() -> Self {
882        Self {
883            uri: "localhost".into(),
884            retry_config: Arc::new(ExponentialBackoff::default()),
885            channel_buffer_capacity: 16,
886            allow_stateless: true,
887            auth_header: None,
888            custom_headers: HashMap::new(),
889        }
890    }
891}