mcp_client/transport/
sse.rs1use crate::transport::{Error, PendingRequests, TransportMessage};
2use async_trait::async_trait;
3use eventsource_client::{Client, SSE};
4use futures::TryStreamExt;
5use mcp_spec::protocol::{JsonRpcMessage, JsonRpcRequest};
6use reqwest::Client as HttpClient;
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::{mpsc, RwLock};
10use tokio::time::{timeout, Duration};
11use tracing::warn;
12use url::Url;
13
14use super::{send_message, Transport, TransportHandle};
15
/// Upper bound, in seconds, on how long `SseTransport::start` waits for the server to
/// announce its POST endpoint over the SSE stream.
const ENDPOINT_TIMEOUT_SECS: u64 = 5;

19pub struct SseActor {
23 receiver: mpsc::Receiver<TransportMessage>,
25 pending_requests: Arc<PendingRequests>,
27 sse_url: String,
29 http_client: HttpClient,
31 post_endpoint: Arc<RwLock<Option<String>>>,
33}
34
35impl SseActor {
36 pub fn new(
37 receiver: mpsc::Receiver<TransportMessage>,
38 pending_requests: Arc<PendingRequests>,
39 sse_url: String,
40 post_endpoint: Arc<RwLock<Option<String>>>,
41 ) -> Self {
42 Self {
43 receiver,
44 pending_requests,
45 sse_url,
46 post_endpoint,
47 http_client: HttpClient::new(),
48 }
49 }
50
51 pub async fn run(self) {
55 tokio::join!(
56 Self::handle_incoming_messages(
57 self.sse_url.clone(),
58 Arc::clone(&self.pending_requests),
59 Arc::clone(&self.post_endpoint)
60 ),
61 Self::handle_outgoing_messages(
62 self.receiver,
63 self.http_client.clone(),
64 Arc::clone(&self.post_endpoint),
65 Arc::clone(&self.pending_requests),
66 )
67 );
68 }
69
70 async fn handle_incoming_messages(
75 sse_url: String,
76 pending_requests: Arc<PendingRequests>,
77 post_endpoint: Arc<RwLock<Option<String>>>,
78 ) {
79 let client = match eventsource_client::ClientBuilder::for_url(&sse_url) {
80 Ok(builder) => builder.build(),
81 Err(e) => {
82 pending_requests.clear().await;
83 warn!("Failed to connect SSE client: {}", e);
84 return;
85 }
86 };
87 let mut stream = client.stream();
88
89 while let Ok(Some(event)) = stream.try_next().await {
91 match event {
92 SSE::Event(e) if e.event_type == "endpoint" => {
93 let base_url = Url::parse(&sse_url).expect("Invalid base URL");
95 let post_url = base_url
96 .join(&e.data)
97 .expect("Failed to resolve endpoint URL");
98
99 tracing::debug!("Discovered SSE POST endpoint: {}", post_url);
100 *post_endpoint.write().await = Some(post_url.to_string());
101 break;
102 }
103 _ => continue,
104 }
105 }
106
107 while let Ok(Some(event)) = stream.try_next().await {
109 match event {
110 SSE::Event(e) if e.event_type == "message" => {
111 match serde_json::from_str::<JsonRpcMessage>(&e.data) {
113 Ok(message) => {
114 match &message {
115 JsonRpcMessage::Response(response) => {
116 if let Some(id) = &response.id {
117 pending_requests
118 .respond(&id.to_string(), Ok(message))
119 .await;
120 }
121 }
122 JsonRpcMessage::Error(error) => {
123 if let Some(id) = &error.id {
124 pending_requests
125 .respond(&id.to_string(), Ok(message))
126 .await;
127 }
128 }
129 _ => {} }
131 }
132 Err(err) => {
133 warn!("Failed to parse SSE message: {err}");
134 }
135 }
136 }
137 _ => { }
138 }
139 }
140
141 tracing::error!("SSE stream ended or encountered an error; clearing pending requests.");
143 pending_requests.clear().await;
144 }
145
146 async fn handle_outgoing_messages(
150 mut receiver: mpsc::Receiver<TransportMessage>,
151 http_client: HttpClient,
152 post_endpoint: Arc<RwLock<Option<String>>>,
153 pending_requests: Arc<PendingRequests>,
154 ) {
155 while let Some(transport_msg) = receiver.recv().await {
156 let post_url = match post_endpoint.read().await.as_ref() {
157 Some(url) => url.clone(),
158 None => {
159 if let Some(response_tx) = transport_msg.response_tx {
160 let _ = response_tx.send(Err(Error::NotConnected));
161 }
162 continue;
163 }
164 };
165
166 let message_str = match serde_json::to_string(&transport_msg.message) {
168 Ok(s) => s,
169 Err(e) => {
170 if let Some(tx) = transport_msg.response_tx {
171 let _ = tx.send(Err(Error::Serialization(e)));
172 }
173 continue;
174 }
175 };
176
177 if let Some(response_tx) = transport_msg.response_tx {
179 if let JsonRpcMessage::Request(JsonRpcRequest { id: Some(id), .. }) =
180 &transport_msg.message
181 {
182 pending_requests.insert(id.to_string(), response_tx).await;
183 }
184 }
185
186 match http_client
188 .post(&post_url)
189 .header("Content-Type", "application/json")
190 .body(message_str)
191 .send()
192 .await
193 {
194 Ok(resp) => {
195 if !resp.status().is_success() {
196 let err = Error::HttpError {
197 status: resp.status().as_u16(),
198 message: resp.status().to_string(),
199 };
200 warn!("HTTP request returned error: {err}");
201 }
204 }
205 Err(e) => {
206 warn!("HTTP POST failed: {e}");
207 }
209 }
210 }
211
212 tracing::error!("SseActor: outgoing message loop ended. Clearing pending requests.");
214 pending_requests.clear().await;
215 }
216}
217
218#[derive(Clone)]
219pub struct SseTransportHandle {
220 sender: mpsc::Sender<TransportMessage>,
221}
222
223#[async_trait::async_trait]
224impl TransportHandle for SseTransportHandle {
225 async fn send(&self, message: JsonRpcMessage) -> Result<JsonRpcMessage, Error> {
226 send_message(&self.sender, message).await
227 }
228}
229
/// Client-side MCP transport: responses arrive over a Server-Sent Events stream, and
/// requests are sent with HTTP POST.
#[derive(Clone)]
pub struct SseTransport {
    /// URL of the server's SSE stream.
    sse_url: String,
    /// Environment variables to set when the transport starts.
    env: HashMap<String, String>,
}

236impl SseTransport {
238 pub fn new<S: Into<String>>(sse_url: S, env: HashMap<String, String>) -> Self {
239 Self {
240 sse_url: sse_url.into(),
241 env,
242 }
243 }
244
245 async fn wait_for_endpoint(
247 post_endpoint: Arc<RwLock<Option<String>>>,
248 ) -> Result<String, Error> {
249 let check_interval = Duration::from_millis(100);
251 let mut attempts = 0;
252 let max_attempts = 10;
253
254 while attempts < max_attempts {
255 if let Some(url) = post_endpoint.read().await.clone() {
256 return Ok(url);
257 }
258 tokio::time::sleep(check_interval).await;
259 attempts += 1;
260 }
261 Err(Error::SseConnection("No endpoint discovered".to_string()))
262 }
263}
264
265#[async_trait]
266impl Transport for SseTransport {
267 type Handle = SseTransportHandle;
268
269 async fn start(&self) -> Result<Self::Handle, Error> {
270 for (key, value) in &self.env {
272 std::env::set_var(key, value);
273 }
274
275 let (tx, rx) = mpsc::channel(32);
277
278 let post_endpoint: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
279 let post_endpoint_clone = Arc::clone(&post_endpoint);
280
281 let actor = SseActor::new(
283 rx,
284 Arc::new(PendingRequests::new()),
285 self.sse_url.clone(),
286 post_endpoint,
287 );
288
289 tokio::spawn(actor.run());
291
292 match timeout(
294 Duration::from_secs(ENDPOINT_TIMEOUT_SECS),
295 Self::wait_for_endpoint(post_endpoint_clone),
296 )
297 .await
298 {
299 Ok(_) => Ok(SseTransportHandle { sender: tx }),
300 Err(e) => Err(Error::SseConnection(e.to_string())),
301 }
302 }
303
304 async fn close(&self) -> Result<(), Error> {
305 Ok(())
308 }
309}