1use std::{borrow::Cow, sync::Arc, time::Duration};
2
3use futures::{Stream, StreamExt, future::BoxFuture, stream::BoxStream};
4pub use sse_stream::Error as SseError;
5use sse_stream::Sse;
6use thiserror::Error;
7use tokio_util::sync::CancellationToken;
8
9use super::common::client_side_sse::{ExponentialBackoff, SseRetryPolicy, SseStreamReconnect};
10use crate::{
11 RoleClient,
12 model::{ClientJsonRpcMessage, ServerJsonRpcMessage},
13 transport::{
14 common::client_side_sse::SseAutoReconnectStream,
15 worker::{Worker, WorkerQuitReason, WorkerSendRequest, WorkerTransport},
16 },
17};
18
/// Boxed, `'static` stream of raw SSE events (or SSE decoding errors).
///
/// This is the same type that [`StreamableHttpClient::get_stream`] resolves to, and it is
/// the payload of [`StreamableHttpPostResponse::Sse`].
type BoxedSseStream = BoxStream<'static, Result<Sse, SseError>>;
20
/// Errors produced by the streamable-HTTP client transport.
///
/// `E` is the error type of the underlying HTTP client implementation
/// (see [`StreamableHttpClient::Error`]).
#[derive(Error, Debug)]
pub enum StreamableHttpError<E: std::error::Error + Send + Sync + 'static> {
    /// The SSE decoder reported an error while reading an event stream.
    #[error("SSE error: {0}")]
    Sse(#[from] SseError),
    /// An I/O error.
    #[error("Io error: {0}")]
    Io(#[from] std::io::Error),
    /// An error from the underlying HTTP client (e.g. `reqwest`, via the `From` impl below).
    #[error("Client error: {0}")]
    Client(E),
    /// A stream ended where more data was required; used as the terminal error of SSE
    /// streams that are never reconnected (no session id available).
    #[error("unexpected end of stream")]
    UnexpectedEndOfStream,
    /// The server replied with a response kind the caller did not expect; produced by the
    /// `StreamableHttpPostResponse::expect_*` helpers.
    #[error("unexpected server response: {0}")]
    UnexpectedServerResponse(Cow<'static, str>),
    /// NOTE(review): not constructed in this module; presumably raised by client
    /// implementations when the response `Content-Type` is neither JSON nor SSE — confirm.
    #[error("Unexpected content type: {0:?}")]
    UnexpectedContentType(Option<String>),
    /// The server does not offer a standalone SSE stream. The worker treats this as
    /// "no general-purpose stream" rather than a failure.
    ///
    /// (sic: "Sever" — the variant name is misspelt but public, so it is kept for
    /// compatibility.)
    #[error("Server does not support SSE")]
    SeverDoesNotSupportSse,
    /// The server does not support explicit session deletion. The worker only logs this.
    ///
    /// (sic: "Sever" — kept for compatibility, see above.)
    #[error("Server does not support delete session")]
    SeverDoesNotSupportDeleteSession,
    /// A spawned tokio task panicked or was cancelled.
    #[error("Tokio join error: {0}")]
    TokioJoinError(#[from] tokio::task::JoinError),
    /// A payload could not be decoded as JSON (e.g. an SSE `data` field).
    #[error("Deserialize error: {0}")]
    Deserialize(#[from] serde_json::Error),
    /// The channel between the worker and its handler was closed.
    #[error("Transport channel closed")]
    TransportChannelClosed,
    /// Authorization failure (only with the `auth` feature).
    #[cfg(feature = "auth")]
    #[cfg_attr(docsrs, doc(cfg(feature = "auth")))]
    #[error("Auth error: {0}")]
    Auth(#[from] crate::transport::auth::AuthError),
}
50
51impl From<reqwest::Error> for StreamableHttpError<reqwest::Error> {
52 fn from(e: reqwest::Error) -> Self {
53 StreamableHttpError::Client(e)
54 }
55}
56
/// Classified reply to a client message POSTed by [`StreamableHttpClient::post_message`].
pub enum StreamableHttpPostResponse {
    /// The server accepted the message without returning a JSON-RPC payload
    /// (presumably HTTP 202 Accepted — confirm in client implementations).
    Accepted,
    /// A single JSON-RPC message in the response body, plus the session id that
    /// accompanied the response, if any.
    Json(ServerJsonRpcMessage, Option<String>),
    /// An SSE event stream carrying the reply, plus the session id that accompanied
    /// the response, if any.
    Sse(BoxedSseStream, Option<String>),
}
62
63impl std::fmt::Debug for StreamableHttpPostResponse {
64 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65 match self {
66 Self::Accepted => write!(f, "Accepted"),
67 Self::Json(arg0, arg1) => f.debug_tuple("Json").field(arg0).field(arg1).finish(),
68 Self::Sse(_, arg1) => f.debug_tuple("Sse").field(arg1).finish(),
69 }
70 }
71}
72
73impl StreamableHttpPostResponse {
74 pub async fn expect_initialized<E>(
75 self,
76 ) -> Result<(ServerJsonRpcMessage, Option<String>), StreamableHttpError<E>>
77 where
78 E: std::error::Error + Send + Sync + 'static,
79 {
80 match self {
81 Self::Json(message, session_id) => Ok((message, session_id)),
82 Self::Sse(mut stream, session_id) => {
83 let event =
84 stream
85 .next()
86 .await
87 .ok_or(StreamableHttpError::UnexpectedServerResponse(
88 "empty sse stream".into(),
89 ))??;
90 let message: ServerJsonRpcMessage =
91 serde_json::from_str(&event.data.unwrap_or_default())?;
92 Ok((message, session_id))
93 }
94 _ => Err(StreamableHttpError::UnexpectedServerResponse(
95 "expect initialized, accepted".into(),
96 )),
97 }
98 }
99
100 pub fn expect_json<E>(self) -> Result<ServerJsonRpcMessage, StreamableHttpError<E>>
101 where
102 E: std::error::Error + Send + Sync + 'static,
103 {
104 match self {
105 Self::Json(message, ..) => Ok(message),
106 got => Err(StreamableHttpError::UnexpectedServerResponse(
107 format!("expect json, got {got:?}").into(),
108 )),
109 }
110 }
111
112 pub fn expect_accepted<E>(self) -> Result<(), StreamableHttpError<E>>
113 where
114 E: std::error::Error + Send + Sync + 'static,
115 {
116 match self {
117 Self::Accepted => Ok(()),
118 got => Err(StreamableHttpError::UnexpectedServerResponse(
119 format!("expect accepted, got {got:?}").into(),
120 )),
121 }
122 }
123}
124
/// HTTP client abstraction used by the streamable-HTTP client transport.
///
/// Implementations perform the raw HTTP exchanges; the worker in this module decides what
/// to send and how to interpret the results. Every method receives the endpoint `uri` and
/// an optional `auth_header`. The worker in this module always passes `None` for
/// `auth_header`; presumably it is the value of an `Authorization` header — confirm against
/// implementations.
pub trait StreamableHttpClient: Clone + Send + 'static {
    /// Error type of the underlying HTTP client; it is carried by
    /// [`StreamableHttpError::Client`].
    type Error: std::error::Error + Send + Sync + 'static;
    /// Sends one client JSON-RPC message and classifies the server's reply.
    ///
    /// The worker passes `session_id = None` for the initial `initialize` request and the
    /// server-assigned session id (if any) afterwards.
    fn post_message(
        &self,
        uri: Arc<str>,
        message: ClientJsonRpcMessage,
        session_id: Option<Arc<str>>,
        auth_header: Option<String>,
    ) -> impl Future<Output = Result<StreamableHttpPostResponse, StreamableHttpError<Self::Error>>>
    + Send
    + '_;
    /// Asks the server to terminate `session_id`.
    ///
    /// The worker treats [`StreamableHttpError::SeverDoesNotSupportDeleteSession`] as a
    /// benign outcome and only logs it.
    fn delete_session(
        &self,
        uri: Arc<str>,
        session_id: Arc<str>,
        auth_header: Option<String>,
    ) -> impl Future<Output = Result<(), StreamableHttpError<Self::Error>>> + Send + '_;
    /// Opens a server-to-client SSE stream for `session_id`.
    ///
    /// `last_event_id` is `Some` when the stream is being re-established after a
    /// disconnect (presumably forwarded as `Last-Event-ID` so the server can resume —
    /// confirm). The worker treats [`StreamableHttpError::SeverDoesNotSupportSse`] as "no
    /// stream available" rather than as a failure.
    fn get_stream(
        &self,
        uri: Arc<str>,
        session_id: Arc<str>,
        last_event_id: Option<String>,
        auth_header: Option<String>,
    ) -> impl Future<
        Output = Result<
            BoxStream<'static, Result<Sse, SseError>>,
            StreamableHttpError<Self::Error>,
        >,
    > + Send
    + '_;
}
156
/// Retry limits expressed as a plain value type.
///
/// NOTE(review): nothing in this module reads this type. Reconnection is driven by
/// `StreamableHttpClientTransportConfig::retry_config` (an `Arc<dyn SseRetryPolicy>`).
/// Confirm whether external code still depends on it before removing it.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Presumably the maximum number of retry attempts, `None` meaning unbounded — TODO confirm.
    pub max_times: Option<usize>,
    /// Presumably the minimum delay between attempts — TODO confirm.
    pub min_duration: Duration,
}
161
/// State needed to re-open a session's SSE stream after it drops.
///
/// Used as the reconnect strategy of `SseAutoReconnectStream`; each retry issues
/// `client.get_stream(uri, session_id, last_event_id, None)`.
struct StreamableHttpClientReconnect<C> {
    /// HTTP client, cloned into each reconnect attempt.
    pub client: C,
    /// Session whose stream is being resumed.
    pub session_id: Arc<str>,
    /// Endpoint the stream is requested from.
    pub uri: Arc<str>,
}
167
168impl<C: StreamableHttpClient> SseStreamReconnect for StreamableHttpClientReconnect<C> {
169 type Error = StreamableHttpError<C::Error>;
170 type Future = BoxFuture<'static, Result<BoxedSseStream, Self::Error>>;
171 fn retry_connection(&mut self, last_event_id: Option<&str>) -> Self::Future {
172 let client = self.client.clone();
173 let uri = self.uri.clone();
174 let session_id = self.session_id.clone();
175 let last_event_id = last_event_id.map(|s| s.to_owned());
176 Box::pin(async move {
177 client
178 .get_stream(uri, session_id, last_event_id, None)
179 .await
180 })
181 }
182}
183
/// [`Worker`] that speaks the streamable-HTTP transport through the HTTP client `C`.
#[derive(Debug, Clone, Default)]
pub struct StreamableHttpClientWorker<C: StreamableHttpClient> {
    /// HTTP client used for every request the worker makes.
    pub client: C,
    /// Endpoint, retry policy, channel capacity and statelessness settings.
    pub config: StreamableHttpClientTransportConfig,
}
189
190impl<C: StreamableHttpClient + Default> StreamableHttpClientWorker<C> {
191 pub fn new_simple(url: impl Into<Arc<str>>) -> Self {
192 Self {
193 client: C::default(),
194 config: StreamableHttpClientTransportConfig {
195 uri: url.into(),
196 ..Default::default()
197 },
198 }
199 }
200}
201
202impl<C: StreamableHttpClient> StreamableHttpClientWorker<C> {
203 pub fn new(client: C, config: StreamableHttpClientTransportConfig) -> Self {
204 Self { client, config }
205 }
206}
207
208impl<C: StreamableHttpClient> StreamableHttpClientWorker<C> {
209 async fn execute_sse_stream(
210 sse_stream: impl Stream<Item = Result<ServerJsonRpcMessage, StreamableHttpError<C::Error>>>
211 + Send
212 + 'static,
213 sse_worker_tx: tokio::sync::mpsc::Sender<ServerJsonRpcMessage>,
214 close_on_response: bool,
215 ct: CancellationToken,
216 ) -> Result<(), StreamableHttpError<C::Error>> {
217 let mut sse_stream = std::pin::pin!(sse_stream);
218 loop {
219 let message = tokio::select! {
220 event = sse_stream.next() => {
221 event
222 }
223 _ = ct.cancelled() => {
224 tracing::debug!("cancelled");
225 break;
226 }
227 };
228 let Some(message) = message.transpose()? else {
229 break;
230 };
231 let is_response = matches!(message, ServerJsonRpcMessage::Response(_));
232 let yield_result = sse_worker_tx.send(message).await;
233 if yield_result.is_err() {
234 tracing::trace!("streamable http transport worker dropped, exiting");
235 break;
236 }
237 if close_on_response && is_response {
238 tracing::debug!("got response, closing sse stream");
239 break;
240 }
241 }
242 Ok(())
243 }
244}
245
/// Drives a streamable-HTTP client session on behalf of the service handler.
///
/// Flow implemented by `run`:
/// 1. Take the `initialize` request from the handler, POST it without a session id, and
///    forward the server's reply to the handler.
/// 2. Record the session id from the initialize response. Without one, the worker aborts
///    unless `config.allow_stateless` is set.
/// 3. POST the handler's next message (the `initialized` notification) and require an
///    `Accepted` reply.
/// 4. If there is a session, open a general-purpose GET SSE stream for server-initiated
///    messages.
/// 5. Loop: POST client messages, and forward server messages (from JSON replies and from
///    SSE streams) to the handler.
///
/// Every HTTP call made here passes `None` as the auth header.
impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
    type Role = RoleClient;
    type Error = StreamableHttpError<C::Error>;
    fn err_closed() -> Self::Error {
        StreamableHttpError::TransportChannelClosed
    }
    fn err_join(e: tokio::task::JoinError) -> Self::Error {
        StreamableHttpError::TokioJoinError(e)
    }
    fn config(&self) -> super::worker::WorkerConfig {
        super::worker::WorkerConfig {
            name: Some("StreamableHttpClientWorker".into()),
            channel_buffer_capacity: self.config.channel_buffer_capacity,
        }
    }
    async fn run(
        self,
        mut context: super::worker::WorkerContext<Self>,
    ) -> Result<(), WorkerQuitReason> {
        let channel_buffer_capacity = self.config.channel_buffer_capacity;
        // Every SSE pump task forwards decoded server messages through this channel; the
        // main loop below drains it.
        let (sse_worker_tx, mut sse_worker_rx) =
            tokio::sync::mpsc::channel::<ServerJsonRpcMessage>(channel_buffer_capacity);
        let config = self.config.clone();
        let transport_task_ct = context.cancellation_token.clone();
        // Cancels the token when `run` returns on any path. This also stops the SSE pumps
        // (child tokens) and wakes the session-deletion task spawned below.
        let _drop_guard = transport_task_ct.clone().drop_guard();
        let WorkerSendRequest {
            responder,
            message: initialize_request,
        } = context.recv_from_handler().await?;
        // The handler's send is acknowledged *before* the POST is attempted. A failed POST
        // surfaces as a fatal worker exit, not through this responder.
        let _ = responder.send(Ok(()));
        let (message, session_id) = self
            .client
            .post_message(config.uri.clone(), initialize_request, None, None)
            .await
            .map_err(WorkerQuitReason::fatal_context("send initialize request"))?
            .expect_initialized::<Self::Error>()
            .await
            .map_err(WorkerQuitReason::fatal_context(
                "process initialize response",
            ))?;
        let session_id: Option<Arc<str>> = if let Some(session_id) = session_id {
            Some(session_id.into())
        } else {
            if !self.config.allow_stateless {
                return Err(WorkerQuitReason::fatal(
                    "missing session id in initialize response",
                    "process initialize response",
                ));
            }
            None
        };
        if let Some(session_id) = &session_id {
            // Best-effort cleanup: once the transport is cancelled, ask the server to
            // delete the session. The outcome is only logged.
            let ct = transport_task_ct.clone();
            let client = self.client.clone();
            let session_id = session_id.clone();
            let url = config.uri.clone();
            tokio::spawn(async move {
                ct.cancelled().await;
                let delete_session_result =
                    client.delete_session(url, session_id.clone(), None).await;
                match delete_session_result {
                    Ok(_) => {
                        tracing::info!(session_id = session_id.as_ref(), "delete session success")
                    }
                    Err(StreamableHttpError::SeverDoesNotSupportDeleteSession) => {
                        tracing::info!(
                            session_id = session_id.as_ref(),
                            "server doesn't support delete session"
                        )
                    }
                    Err(e) => {
                        tracing::error!(
                            session_id = session_id.as_ref(),
                            "fail to delete session: {e}"
                        );
                    }
                };
            });
        }

        // Hand the initialize reply to the handler, then send its follow-up notification.
        context.send_to_handler(message).await?;
        let initialized_notification = context.recv_from_handler().await?;
        self.client
            .post_message(
                config.uri.clone(),
                initialized_notification.message,
                session_id.clone(),
                None,
            )
            .await
            .map_err(WorkerQuitReason::fatal_context(
                "send initialized notification",
            ))?
            .expect_accepted::<Self::Error>()
            .map_err(WorkerQuitReason::fatal_context(
                "process initialized notification response",
            ))?;
        let _ = initialized_notification.responder.send(Ok(()));
        // Everything the main loop can wake up for.
        enum Event<W: Worker, E: std::error::Error + Send + Sync + 'static> {
            ClientMessage(WorkerSendRequest<W>),
            ServerMessage(ServerJsonRpcMessage),
            StreamResult(Result<(), StreamableHttpError<E>>),
        }
        // One task per active SSE stream, each running `execute_sse_stream`.
        let mut streams = tokio::task::JoinSet::new();
        if let Some(session_id) = &session_id {
            // General-purpose GET stream for server-initiated messages. It reconnects via
            // `StreamableHttpClientReconnect` and is never closed on a response. A server
            // that does not offer it is tolerated; any other error is fatal.
            match self
                .client
                .get_stream(config.uri.clone(), session_id.clone(), None, None)
                .await
            {
                Ok(stream) => {
                    let sse_stream = SseAutoReconnectStream::new(
                        stream,
                        StreamableHttpClientReconnect {
                            client: self.client.clone(),
                            session_id: session_id.clone(),
                            uri: config.uri.clone(),
                        },
                        self.config.retry_config.clone(),
                    );
                    streams.spawn(Self::execute_sse_stream(
                        sse_stream,
                        sse_worker_tx.clone(),
                        false,
                        transport_task_ct.child_token(),
                    ));
                    tracing::debug!("got common stream");
                }
                Err(StreamableHttpError::SeverDoesNotSupportSse) => {
                    tracing::debug!("server doesn't support sse, skip common stream");
                }
                Err(e) => {
                    tracing::error!("fail to get common stream: {e}");
                    return Err(WorkerQuitReason::fatal(
                        "fail to get general purpose event stream",
                        "get general purpose event stream",
                    ));
                }
            }
        }
        loop {
            let event = tokio::select! {
                _ = transport_task_ct.cancelled() => {
                    tracing::debug!("cancelled");
                    return Err(WorkerQuitReason::Cancelled);
                }
                message = context.recv_from_handler() => {
                    let message = message?;
                    Event::ClientMessage(message)
                },
                message = sse_worker_rx.recv() => {
                    // `sse_worker_tx` is still alive in this function, so `None` cannot
                    // currently occur here; kept as a defensive exit.
                    let Some(message) = message else {
                        tracing::trace!("transport dropped, exiting");
                        return Err(WorkerQuitReason::HandlerTerminated);
                    };
                    Event::ServerMessage(message)
                },
                terminated_stream = streams.join_next(), if !streams.is_empty() => {
                    match terminated_stream {
                        Some(result) => {
                            // Flatten the join error and the pump's own result into one `Result`.
                            Event::StreamResult(result.map_err(StreamableHttpError::TokioJoinError).and_then(std::convert::identity))
                        }
                        // Defensive: the `!streams.is_empty()` precondition should prevent this.
                        None => {
                            continue
                        }
                    }
                }
            };
            match event {
                Event::ClientMessage(send_request) => {
                    // NOTE(review): the POST is awaited inline. While it is outstanding,
                    // messages queued on `sse_worker_rx` are not forwarded, and pumps block
                    // once that channel fills. If a server needs a client reply before
                    // answering this POST, this could stall — confirm.
                    let WorkerSendRequest { message, responder } = send_request;
                    let response = self
                        .client
                        .post_message(config.uri.clone(), message, session_id.clone(), None)
                        .await;
                    let send_result = match response {
                        Err(e) => Err(e),
                        Ok(StreamableHttpPostResponse::Accepted) => {
                            tracing::trace!("client message accepted");
                            Ok(())
                        }
                        Ok(StreamableHttpPostResponse::Json(message, ..)) => {
                            context.send_to_handler(message).await?;
                            Ok(())
                        }
                        Ok(StreamableHttpPostResponse::Sse(stream, ..)) => {
                            // The reply arrives on this stream; `close_on_response = true`
                            // stops the pump once the response has been forwarded.
                            if let Some(session_id) = &session_id {
                                let sse_stream = SseAutoReconnectStream::new(
                                    stream,
                                    StreamableHttpClientReconnect {
                                        client: self.client.clone(),
                                        session_id: session_id.clone(),
                                        uri: config.uri.clone(),
                                    },
                                    self.config.retry_config.clone(),
                                );
                                streams.spawn(Self::execute_sse_stream(
                                    sse_stream,
                                    sse_worker_tx.clone(),
                                    true,
                                    transport_task_ct.child_token(),
                                ));
                            } else {
                                // Reconnecting needs a session id, so a stateless stream
                                // that ends early yields `UnexpectedEndOfStream`.
                                let sse_stream = SseAutoReconnectStream::never_reconnect(
                                    stream,
                                    StreamableHttpError::<C::Error>::UnexpectedEndOfStream,
                                );
                                streams.spawn(Self::execute_sse_stream(
                                    sse_stream,
                                    sse_worker_tx.clone(),
                                    true,
                                    transport_task_ct.child_token(),
                                ));
                            }
                            tracing::trace!("got new sse stream");
                            Ok(())
                        }
                    };
                    let _ = responder.send(send_result);
                }
                Event::ServerMessage(json_rpc_message) => {
                    context.send_to_handler(json_rpc_message).await?;
                }
                Event::StreamResult(result) => {
                    // A finished or failed stream is only logged; the worker keeps running.
                    if result.is_err() {
                        tracing::warn!(
                            "sse client event stream terminated with error: {:?}",
                            result
                        );
                    }
                }
            }
        }
    }
}
485
/// Streamable-HTTP client transport: a [`WorkerTransport`] running a
/// [`StreamableHttpClientWorker`] over the HTTP client `C`.
pub type StreamableHttpClientTransport<C> = WorkerTransport<StreamableHttpClientWorker<C>>;
487
488impl<C: StreamableHttpClient> StreamableHttpClientTransport<C> {
489 pub fn with_client(client: C, config: StreamableHttpClientTransportConfig) -> Self {
490 let worker = StreamableHttpClientWorker::new(client, config);
491 WorkerTransport::spawn(worker)
492 }
493}
/// Settings for [`StreamableHttpClientWorker`] / [`StreamableHttpClientTransport`].
#[derive(Debug, Clone)]
pub struct StreamableHttpClientTransportConfig {
    /// Endpoint used for every request: message POSTs, SSE stream requests and session
    /// deletion.
    pub uri: Arc<str>,
    /// Retry policy handed to the auto-reconnecting SSE streams.
    pub retry_config: Arc<dyn SseRetryPolicy>,
    /// Capacity of the worker's message channels: it is reported in the worker config and
    /// used for the internal SSE message channel.
    pub channel_buffer_capacity: usize,
    /// When `false`, an initialize response without a session id is a fatal error.
    /// When `true`, the worker continues without a session: no general-purpose GET stream,
    /// no session deletion, and SSE replies are never reconnected.
    pub allow_stateless: bool,
}
502
503impl StreamableHttpClientTransportConfig {
504 pub fn with_uri(uri: impl Into<Arc<str>>) -> Self {
505 Self {
506 uri: uri.into(),
507 ..Default::default()
508 }
509 }
510}
511
512impl Default for StreamableHttpClientTransportConfig {
513 fn default() -> Self {
514 Self {
515 uri: "localhost".into(),
516 retry_config: Arc::new(ExponentialBackoff::default()),
517 channel_buffer_capacity: 16,
518 allow_stateless: true,
519 }
520 }
521}