// Source file: agenterra_rmcp/transport/streamable_http_client.rs

1use std::{borrow::Cow, sync::Arc, time::Duration};
2
3use futures::{Stream, StreamExt, future::BoxFuture, stream::BoxStream};
4pub use sse_stream::Error as SseError;
5use sse_stream::Sse;
6use thiserror::Error;
7use tokio_util::sync::CancellationToken;
8
9use super::common::client_side_sse::{ExponentialBackoff, SseRetryPolicy, SseStreamReconnect};
10use crate::{
11    RoleClient,
12    model::{ClientJsonRpcMessage, ServerJsonRpcMessage},
13    transport::{
14        common::client_side_sse::SseAutoReconnectStream,
15        worker::{Worker, WorkerQuitReason, WorkerSendRequest, WorkerTransport},
16    },
17};
18
19type BoxedSseStream = BoxStream<'static, Result<Sse, SseError>>;
20
21#[derive(Error, Debug)]
22pub enum StreamableHttpError<E: std::error::Error + Send + Sync + 'static> {
23    #[error("SSE error: {0}")]
24    Sse(#[from] SseError),
25    #[error("Io error: {0}")]
26    Io(#[from] std::io::Error),
27    #[error("Client error: {0}")]
28    Client(E),
29    #[error("unexpected end of stream")]
30    UnexpectedEndOfStream,
31    #[error("unexpected server response: {0}")]
32    UnexpectedServerResponse(Cow<'static, str>),
33    #[error("Unexpected content type: {0:?}")]
34    UnexpectedContentType(Option<String>),
35    #[error("Server does not support SSE")]
36    SeverDoesNotSupportSse,
37    #[error("Server does not support delete session")]
38    SeverDoesNotSupportDeleteSession,
39    #[error("Tokio join error: {0}")]
40    TokioJoinError(#[from] tokio::task::JoinError),
41    #[error("Deserialize error: {0}")]
42    Deserialize(#[from] serde_json::Error),
43    #[error("Transport channel closed")]
44    TransportChannelClosed,
45    #[cfg(feature = "auth")]
46    #[cfg_attr(docsrs, doc(cfg(feature = "auth")))]
47    #[error("Auth error: {0}")]
48    Auth(#[from] crate::transport::auth::AuthError),
49}
50
51impl From<reqwest::Error> for StreamableHttpError<reqwest::Error> {
52    fn from(e: reqwest::Error) -> Self {
53        StreamableHttpError::Client(e)
54    }
55}
56
57pub enum StreamableHttpPostResponse {
58    Accepted,
59    Json(ServerJsonRpcMessage, Option<String>),
60    Sse(BoxedSseStream, Option<String>),
61}
62
63impl std::fmt::Debug for StreamableHttpPostResponse {
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        match self {
66            Self::Accepted => write!(f, "Accepted"),
67            Self::Json(arg0, arg1) => f.debug_tuple("Json").field(arg0).field(arg1).finish(),
68            Self::Sse(_, arg1) => f.debug_tuple("Sse").field(arg1).finish(),
69        }
70    }
71}
72
73impl StreamableHttpPostResponse {
74    pub async fn expect_initialized<E>(
75        self,
76    ) -> Result<(ServerJsonRpcMessage, Option<String>), StreamableHttpError<E>>
77    where
78        E: std::error::Error + Send + Sync + 'static,
79    {
80        match self {
81            Self::Json(message, session_id) => Ok((message, session_id)),
82            Self::Sse(mut stream, session_id) => {
83                let event =
84                    stream
85                        .next()
86                        .await
87                        .ok_or(StreamableHttpError::UnexpectedServerResponse(
88                            "empty sse stream".into(),
89                        ))??;
90                let message: ServerJsonRpcMessage =
91                    serde_json::from_str(&event.data.unwrap_or_default())?;
92                Ok((message, session_id))
93            }
94            _ => Err(StreamableHttpError::UnexpectedServerResponse(
95                "expect initialized, accepted".into(),
96            )),
97        }
98    }
99
100    pub fn expect_json<E>(self) -> Result<ServerJsonRpcMessage, StreamableHttpError<E>>
101    where
102        E: std::error::Error + Send + Sync + 'static,
103    {
104        match self {
105            Self::Json(message, ..) => Ok(message),
106            got => Err(StreamableHttpError::UnexpectedServerResponse(
107                format!("expect json, got {got:?}").into(),
108            )),
109        }
110    }
111
112    pub fn expect_accepted<E>(self) -> Result<(), StreamableHttpError<E>>
113    where
114        E: std::error::Error + Send + Sync + 'static,
115    {
116        match self {
117            Self::Accepted => Ok(()),
118            got => Err(StreamableHttpError::UnexpectedServerResponse(
119                format!("expect accepted, got {got:?}").into(),
120            )),
121        }
122    }
123}
124
125pub trait StreamableHttpClient: Clone + Send + 'static {
126    type Error: std::error::Error + Send + Sync + 'static;
127    fn post_message(
128        &self,
129        uri: Arc<str>,
130        message: ClientJsonRpcMessage,
131        session_id: Option<Arc<str>>,
132        auth_header: Option<String>,
133    ) -> impl Future<Output = Result<StreamableHttpPostResponse, StreamableHttpError<Self::Error>>>
134    + Send
135    + '_;
136    fn delete_session(
137        &self,
138        uri: Arc<str>,
139        session_id: Arc<str>,
140        auth_header: Option<String>,
141    ) -> impl Future<Output = Result<(), StreamableHttpError<Self::Error>>> + Send + '_;
142    fn get_stream(
143        &self,
144        uri: Arc<str>,
145        session_id: Arc<str>,
146        last_event_id: Option<String>,
147        auth_header: Option<String>,
148    ) -> impl Future<
149        Output = Result<
150            BoxStream<'static, Result<Sse, SseError>>,
151            StreamableHttpError<Self::Error>,
152        >,
153    > + Send
154    + '_;
155}
156
/// Retry settings: an optional cap on the number of attempts and a minimum duration.
///
/// NOTE(review): nothing in this file reads this type; the exact meaning of `min_duration`
/// (presumably the minimum delay between attempts) should be confirmed against its users.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of retries; `None` presumably means "no limit" — TODO confirm.
    pub max_times: Option<usize>,
    /// Minimum duration associated with a retry.
    pub min_duration: Duration,
}
161
/// Everything needed to re-open an SSE stream for an existing session: the client, the
/// session id and the endpoint URI.
struct StreamableHttpClientReconnect<C> {
    /// Client used to issue the reconnecting request.
    pub client: C,
    /// Session the stream belongs to.
    pub session_id: Arc<str>,
    /// Endpoint the stream is fetched from.
    pub uri: Arc<str>,
}
167
168impl<C: StreamableHttpClient> SseStreamReconnect for StreamableHttpClientReconnect<C> {
169    type Error = StreamableHttpError<C::Error>;
170    type Future = BoxFuture<'static, Result<BoxedSseStream, Self::Error>>;
171    fn retry_connection(&mut self, last_event_id: Option<&str>) -> Self::Future {
172        let client = self.client.clone();
173        let uri = self.uri.clone();
174        let session_id = self.session_id.clone();
175        let last_event_id = last_event_id.map(|s| s.to_owned());
176        Box::pin(async move {
177            client
178                .get_stream(uri, session_id, last_event_id, None)
179                .await
180        })
181    }
182}
183
184#[derive(Debug, Clone, Default)]
185pub struct StreamableHttpClientWorker<C: StreamableHttpClient> {
186    pub client: C,
187    pub config: StreamableHttpClientTransportConfig,
188}
189
190impl<C: StreamableHttpClient + Default> StreamableHttpClientWorker<C> {
191    pub fn new_simple(url: impl Into<Arc<str>>) -> Self {
192        Self {
193            client: C::default(),
194            config: StreamableHttpClientTransportConfig {
195                uri: url.into(),
196                ..Default::default()
197            },
198        }
199    }
200}
201
202impl<C: StreamableHttpClient> StreamableHttpClientWorker<C> {
203    pub fn new(client: C, config: StreamableHttpClientTransportConfig) -> Self {
204        Self { client, config }
205    }
206}
207
208impl<C: StreamableHttpClient> StreamableHttpClientWorker<C> {
209    async fn execute_sse_stream(
210        sse_stream: impl Stream<Item = Result<ServerJsonRpcMessage, StreamableHttpError<C::Error>>>
211        + Send
212        + 'static,
213        sse_worker_tx: tokio::sync::mpsc::Sender<ServerJsonRpcMessage>,
214        close_on_response: bool,
215        ct: CancellationToken,
216    ) -> Result<(), StreamableHttpError<C::Error>> {
217        let mut sse_stream = std::pin::pin!(sse_stream);
218        loop {
219            let message = tokio::select! {
220                event = sse_stream.next() => {
221                    event
222                }
223                _ = ct.cancelled() => {
224                    tracing::debug!("cancelled");
225                    break;
226                }
227            };
228            let Some(message) = message.transpose()? else {
229                break;
230            };
231            let is_response = matches!(message, ServerJsonRpcMessage::Response(_));
232            let yield_result = sse_worker_tx.send(message).await;
233            if yield_result.is_err() {
234                tracing::trace!("streamable http transport worker dropped, exiting");
235                break;
236            }
237            if close_on_response && is_response {
238                tracing::debug!("got response, closing sse stream");
239                break;
240            }
241        }
242        Ok(())
243    }
244}
245
246impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
247    type Role = RoleClient;
248    type Error = StreamableHttpError<C::Error>;
249    fn err_closed() -> Self::Error {
250        StreamableHttpError::TransportChannelClosed
251    }
252    fn err_join(e: tokio::task::JoinError) -> Self::Error {
253        StreamableHttpError::TokioJoinError(e)
254    }
255    fn config(&self) -> super::worker::WorkerConfig {
256        super::worker::WorkerConfig {
257            name: Some("StreamableHttpClientWorker".into()),
258            channel_buffer_capacity: self.config.channel_buffer_capacity,
259        }
260    }
261    async fn run(
262        self,
263        mut context: super::worker::WorkerContext<Self>,
264    ) -> Result<(), WorkerQuitReason> {
265        let channel_buffer_capacity = self.config.channel_buffer_capacity;
266        let (sse_worker_tx, mut sse_worker_rx) =
267            tokio::sync::mpsc::channel::<ServerJsonRpcMessage>(channel_buffer_capacity);
268        let config = self.config.clone();
269        let transport_task_ct = context.cancellation_token.clone();
270        let _drop_guard = transport_task_ct.clone().drop_guard();
271        let WorkerSendRequest {
272            responder,
273            message: initialize_request,
274        } = context.recv_from_handler().await?;
275        let _ = responder.send(Ok(()));
276        let (message, session_id) = self
277            .client
278            .post_message(config.uri.clone(), initialize_request, None, None)
279            .await
280            .map_err(WorkerQuitReason::fatal_context("send initialize request"))?
281            .expect_initialized::<Self::Error>()
282            .await
283            .map_err(WorkerQuitReason::fatal_context(
284                "process initialize response",
285            ))?;
286        let session_id: Option<Arc<str>> = if let Some(session_id) = session_id {
287            Some(session_id.into())
288        } else {
289            if !self.config.allow_stateless {
290                return Err(WorkerQuitReason::fatal(
291                    "missing session id in initialize response",
292                    "process initialize response",
293                ));
294            }
295            None
296        };
297        // delete session when drop guard is dropped
298        if let Some(session_id) = &session_id {
299            let ct = transport_task_ct.clone();
300            let client = self.client.clone();
301            let session_id = session_id.clone();
302            let url = config.uri.clone();
303            tokio::spawn(async move {
304                ct.cancelled().await;
305                let delete_session_result =
306                    client.delete_session(url, session_id.clone(), None).await;
307                match delete_session_result {
308                    Ok(_) => {
309                        tracing::info!(session_id = session_id.as_ref(), "delete session success")
310                    }
311                    Err(StreamableHttpError::SeverDoesNotSupportDeleteSession) => {
312                        tracing::info!(
313                            session_id = session_id.as_ref(),
314                            "server doesn't support delete session"
315                        )
316                    }
317                    Err(e) => {
318                        tracing::error!(
319                            session_id = session_id.as_ref(),
320                            "fail to delete session: {e}"
321                        );
322                    }
323                };
324            });
325        }
326
327        context.send_to_handler(message).await?;
328        let initialized_notification = context.recv_from_handler().await?;
329        // expect a initialized response
330        self.client
331            .post_message(
332                config.uri.clone(),
333                initialized_notification.message,
334                session_id.clone(),
335                None,
336            )
337            .await
338            .map_err(WorkerQuitReason::fatal_context(
339                "send initialized notification",
340            ))?
341            .expect_accepted::<Self::Error>()
342            .map_err(WorkerQuitReason::fatal_context(
343                "process initialized notification response",
344            ))?;
345        let _ = initialized_notification.responder.send(Ok(()));
346        enum Event<W: Worker, E: std::error::Error + Send + Sync + 'static> {
347            ClientMessage(WorkerSendRequest<W>),
348            ServerMessage(ServerJsonRpcMessage),
349            StreamResult(Result<(), StreamableHttpError<E>>),
350        }
351        let mut streams = tokio::task::JoinSet::new();
352        if let Some(session_id) = &session_id {
353            match self
354                .client
355                .get_stream(config.uri.clone(), session_id.clone(), None, None)
356                .await
357            {
358                Ok(stream) => {
359                    let sse_stream = SseAutoReconnectStream::new(
360                        stream,
361                        StreamableHttpClientReconnect {
362                            client: self.client.clone(),
363                            session_id: session_id.clone(),
364                            uri: config.uri.clone(),
365                        },
366                        self.config.retry_config.clone(),
367                    );
368                    streams.spawn(Self::execute_sse_stream(
369                        sse_stream,
370                        sse_worker_tx.clone(),
371                        false,
372                        transport_task_ct.child_token(),
373                    ));
374                    tracing::debug!("got common stream");
375                }
376                Err(StreamableHttpError::SeverDoesNotSupportSse) => {
377                    tracing::debug!("server doesn't support sse, skip common stream");
378                }
379                Err(e) => {
380                    // fail to get common stream
381                    tracing::error!("fail to get common stream: {e}");
382                    return Err(WorkerQuitReason::fatal(
383                        "fail to get general purpose event stream",
384                        "get general purpose event stream",
385                    ));
386                }
387            }
388        }
389        loop {
390            let event = tokio::select! {
391                _ = transport_task_ct.cancelled() => {
392                    tracing::debug!("cancelled");
393                    return Err(WorkerQuitReason::Cancelled);
394                }
395                message = context.recv_from_handler() => {
396                    let message = message?;
397                    Event::ClientMessage(message)
398                },
399                message = sse_worker_rx.recv() => {
400                    let Some(message) = message else {
401                        tracing::trace!("transport dropped, exiting");
402                        return Err(WorkerQuitReason::HandlerTerminated);
403                    };
404                    Event::ServerMessage(message)
405                },
406                terminated_stream = streams.join_next(), if !streams.is_empty() => {
407                    match terminated_stream {
408                        Some(result) => {
409                            Event::StreamResult(result.map_err(StreamableHttpError::TokioJoinError).and_then(std::convert::identity))
410                        }
411                        None => {
412                            continue
413                        }
414                    }
415                }
416            };
417            match event {
418                Event::ClientMessage(send_request) => {
419                    let WorkerSendRequest { message, responder } = send_request;
420                    let response = self
421                        .client
422                        .post_message(config.uri.clone(), message, session_id.clone(), None)
423                        .await;
424                    let send_result = match response {
425                        Err(e) => Err(e),
426                        Ok(StreamableHttpPostResponse::Accepted) => {
427                            tracing::trace!("client message accepted");
428                            Ok(())
429                        }
430                        Ok(StreamableHttpPostResponse::Json(message, ..)) => {
431                            context.send_to_handler(message).await?;
432                            Ok(())
433                        }
434                        Ok(StreamableHttpPostResponse::Sse(stream, ..)) => {
435                            if let Some(session_id) = &session_id {
436                                let sse_stream = SseAutoReconnectStream::new(
437                                    stream,
438                                    StreamableHttpClientReconnect {
439                                        client: self.client.clone(),
440                                        session_id: session_id.clone(),
441                                        uri: config.uri.clone(),
442                                    },
443                                    self.config.retry_config.clone(),
444                                );
445                                streams.spawn(Self::execute_sse_stream(
446                                    sse_stream,
447                                    sse_worker_tx.clone(),
448                                    true,
449                                    transport_task_ct.child_token(),
450                                ));
451                            } else {
452                                let sse_stream = SseAutoReconnectStream::never_reconnect(
453                                    stream,
454                                    StreamableHttpError::<C::Error>::UnexpectedEndOfStream,
455                                );
456                                streams.spawn(Self::execute_sse_stream(
457                                    sse_stream,
458                                    sse_worker_tx.clone(),
459                                    true,
460                                    transport_task_ct.child_token(),
461                                ));
462                            }
463                            tracing::trace!("got new sse stream");
464                            Ok(())
465                        }
466                    };
467                    let _ = responder.send(send_result);
468                }
469                Event::ServerMessage(json_rpc_message) => {
470                    // send the message to the handler
471                    context.send_to_handler(json_rpc_message).await?;
472                }
473                Event::StreamResult(result) => {
474                    if result.is_err() {
475                        tracing::warn!(
476                            "sse client event stream terminated with error: {:?}",
477                            result
478                        );
479                    }
480                }
481            }
482        }
483    }
484}
485
486pub type StreamableHttpClientTransport<C> = WorkerTransport<StreamableHttpClientWorker<C>>;
487
488impl<C: StreamableHttpClient> StreamableHttpClientTransport<C> {
489    pub fn with_client(client: C, config: StreamableHttpClientTransportConfig) -> Self {
490        let worker = StreamableHttpClientWorker::new(client, config);
491        WorkerTransport::spawn(worker)
492    }
493}
494#[derive(Debug, Clone)]
495pub struct StreamableHttpClientTransportConfig {
496    pub uri: Arc<str>,
497    pub retry_config: Arc<dyn SseRetryPolicy>,
498    pub channel_buffer_capacity: usize,
499    /// if true, the transport will not require a session to be established
500    pub allow_stateless: bool,
501}
502
503impl StreamableHttpClientTransportConfig {
504    pub fn with_uri(uri: impl Into<Arc<str>>) -> Self {
505        Self {
506            uri: uri.into(),
507            ..Default::default()
508        }
509    }
510}
511
512impl Default for StreamableHttpClientTransportConfig {
513    fn default() -> Self {
514        Self {
515            uri: "localhost".into(),
516            retry_config: Arc::new(ExponentialBackoff::default()),
517            channel_buffer_capacity: 16,
518            allow_stateless: true,
519        }
520    }
521}