// mcp_client/transport/sse.rs

1use crate::transport::{Error, PendingRequests, TransportMessage};
2use async_trait::async_trait;
3use eventsource_client::{Client, SSE};
4use futures::TryStreamExt;
5use mcp_spec::protocol::{JsonRpcMessage, JsonRpcRequest};
6use reqwest::Client as HttpClient;
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::{mpsc, RwLock};
10use tokio::time::{timeout, Duration};
11use tracing::warn;
12use url::Url;
13
14use super::{send_message, Transport, TransportHandle};
15
/// Upper bound, in seconds, on how long `start()` waits for the server's
/// `endpoint` SSE event before giving up.
const ENDPOINT_TIMEOUT_SECS: u64 = 5;
18
19/// The SSE-based actor that continuously:
20/// - Reads incoming events from the SSE stream.
21/// - Sends outgoing messages via HTTP POST (once the post endpoint is known).
22pub struct SseActor {
23    /// Receives messages (requests/notifications) from the handle
24    receiver: mpsc::Receiver<TransportMessage>,
25    /// Map of request-id -> oneshot sender
26    pending_requests: Arc<PendingRequests>,
27    /// Base SSE URL
28    sse_url: String,
29    /// For sending HTTP POST requests
30    http_client: HttpClient,
31    /// The discovered endpoint for POST requests (once "endpoint" SSE event arrives)
32    post_endpoint: Arc<RwLock<Option<String>>>,
33}
34
35impl SseActor {
36    pub fn new(
37        receiver: mpsc::Receiver<TransportMessage>,
38        pending_requests: Arc<PendingRequests>,
39        sse_url: String,
40        post_endpoint: Arc<RwLock<Option<String>>>,
41    ) -> Self {
42        Self {
43            receiver,
44            pending_requests,
45            sse_url,
46            post_endpoint,
47            http_client: HttpClient::new(),
48        }
49    }
50
51    /// The main entry point for the actor. Spawns two concurrent loops:
52    /// 1) handle_incoming_messages (SSE events)
53    /// 2) handle_outgoing_messages (sending messages via POST)
54    pub async fn run(self) {
55        tokio::join!(
56            Self::handle_incoming_messages(
57                self.sse_url.clone(),
58                Arc::clone(&self.pending_requests),
59                Arc::clone(&self.post_endpoint)
60            ),
61            Self::handle_outgoing_messages(
62                self.receiver,
63                self.http_client.clone(),
64                Arc::clone(&self.post_endpoint),
65                Arc::clone(&self.pending_requests),
66            )
67        );
68    }
69
70    /// Continuously reads SSE events from `sse_url`.
71    /// - If an `endpoint` event is received, store it in `post_endpoint`.
72    /// - If a `message` event is received, parse it as `JsonRpcMessage`
73    ///   and respond to pending requests if it's a `Response`.
74    async fn handle_incoming_messages(
75        sse_url: String,
76        pending_requests: Arc<PendingRequests>,
77        post_endpoint: Arc<RwLock<Option<String>>>,
78    ) {
79        let client = match eventsource_client::ClientBuilder::for_url(&sse_url) {
80            Ok(builder) => builder.build(),
81            Err(e) => {
82                pending_requests.clear().await;
83                warn!("Failed to connect SSE client: {}", e);
84                return;
85            }
86        };
87        let mut stream = client.stream();
88
89        // First, wait for the "endpoint" event
90        while let Ok(Some(event)) = stream.try_next().await {
91            match event {
92                SSE::Event(e) if e.event_type == "endpoint" => {
93                    // SSE server uses the "endpoint" event to tell us the POST URL
94                    let base_url = Url::parse(&sse_url).expect("Invalid base URL");
95                    let post_url = base_url
96                        .join(&e.data)
97                        .expect("Failed to resolve endpoint URL");
98
99                    tracing::debug!("Discovered SSE POST endpoint: {}", post_url);
100                    *post_endpoint.write().await = Some(post_url.to_string());
101                    break;
102                }
103                _ => continue,
104            }
105        }
106
107        // Now handle subsequent events
108        while let Ok(Some(event)) = stream.try_next().await {
109            match event {
110                SSE::Event(e) if e.event_type == "message" => {
111                    // Attempt to parse the SSE data as a JsonRpcMessage
112                    match serde_json::from_str::<JsonRpcMessage>(&e.data) {
113                        Ok(message) => {
114                            match &message {
115                                JsonRpcMessage::Response(response) => {
116                                    if let Some(id) = &response.id {
117                                        pending_requests
118                                            .respond(&id.to_string(), Ok(message))
119                                            .await;
120                                    }
121                                }
122                                JsonRpcMessage::Error(error) => {
123                                    if let Some(id) = &error.id {
124                                        pending_requests
125                                            .respond(&id.to_string(), Ok(message))
126                                            .await;
127                                    }
128                                }
129                                _ => {} // TODO: Handle other variants (Request, etc.)
130                            }
131                        }
132                        Err(err) => {
133                            warn!("Failed to parse SSE message: {err}");
134                        }
135                    }
136                }
137                _ => { /* ignore other events */ }
138            }
139        }
140
141        // SSE stream ended or errored; signal any pending requests
142        tracing::error!("SSE stream ended or encountered an error; clearing pending requests.");
143        pending_requests.clear().await;
144    }
145
146    /// Continuously receives messages from the `mpsc::Receiver`.
147    /// - If it's a request, store the oneshot in `pending_requests`.
148    /// - POST the message to the discovered endpoint (once known).
149    async fn handle_outgoing_messages(
150        mut receiver: mpsc::Receiver<TransportMessage>,
151        http_client: HttpClient,
152        post_endpoint: Arc<RwLock<Option<String>>>,
153        pending_requests: Arc<PendingRequests>,
154    ) {
155        while let Some(transport_msg) = receiver.recv().await {
156            let post_url = match post_endpoint.read().await.as_ref() {
157                Some(url) => url.clone(),
158                None => {
159                    if let Some(response_tx) = transport_msg.response_tx {
160                        let _ = response_tx.send(Err(Error::NotConnected));
161                    }
162                    continue;
163                }
164            };
165
166            // Serialize the JSON-RPC message
167            let message_str = match serde_json::to_string(&transport_msg.message) {
168                Ok(s) => s,
169                Err(e) => {
170                    if let Some(tx) = transport_msg.response_tx {
171                        let _ = tx.send(Err(Error::Serialization(e)));
172                    }
173                    continue;
174                }
175            };
176
177            // If it's a request, store the channel so we can respond later
178            if let Some(response_tx) = transport_msg.response_tx {
179                if let JsonRpcMessage::Request(JsonRpcRequest { id: Some(id), .. }) =
180                    &transport_msg.message
181                {
182                    pending_requests.insert(id.to_string(), response_tx).await;
183                }
184            }
185
186            // Perform the HTTP POST
187            match http_client
188                .post(&post_url)
189                .header("Content-Type", "application/json")
190                .body(message_str)
191                .send()
192                .await
193            {
194                Ok(resp) => {
195                    if !resp.status().is_success() {
196                        let err = Error::HttpError {
197                            status: resp.status().as_u16(),
198                            message: resp.status().to_string(),
199                        };
200                        warn!("HTTP request returned error: {err}");
201                        // This doesn't directly fail the request,
202                        // because we rely on SSE to deliver the error response
203                    }
204                }
205                Err(e) => {
206                    warn!("HTTP POST failed: {e}");
207                    // Similarly, SSE might eventually reveal the error
208                }
209            }
210        }
211
212        // mpsc channel closed => no more outgoing messages
213        tracing::error!("SseActor: outgoing message loop ended. Clearing pending requests.");
214        pending_requests.clear().await;
215    }
216}
217
218#[derive(Clone)]
219pub struct SseTransportHandle {
220    sender: mpsc::Sender<TransportMessage>,
221}
222
223#[async_trait::async_trait]
224impl TransportHandle for SseTransportHandle {
225    async fn send(&self, message: JsonRpcMessage) -> Result<JsonRpcMessage, Error> {
226        send_message(&self.sender, message).await
227    }
228}
229
/// Client transport that reaches a server over Server-Sent Events, with
/// outgoing messages sent as HTTP POSTs.
#[derive(Clone)]
pub struct SseTransport {
    /// URL of the server's SSE stream.
    sse_url: String,
    /// Environment variables exported by `start()` before connecting.
    env: HashMap<String, String>,
}
235
236/// The SSE transport spawns an `SseActor` on `start()`.
237impl SseTransport {
238    pub fn new<S: Into<String>>(sse_url: S, env: HashMap<String, String>) -> Self {
239        Self {
240            sse_url: sse_url.into(),
241            env,
242        }
243    }
244
245    /// Waits for the endpoint to be set, up to 10 attempts.
246    async fn wait_for_endpoint(
247        post_endpoint: Arc<RwLock<Option<String>>>,
248    ) -> Result<String, Error> {
249        // Check every 100ms for the endpoint, for up to 10 attempts
250        let check_interval = Duration::from_millis(100);
251        let mut attempts = 0;
252        let max_attempts = 10;
253
254        while attempts < max_attempts {
255            if let Some(url) = post_endpoint.read().await.clone() {
256                return Ok(url);
257            }
258            tokio::time::sleep(check_interval).await;
259            attempts += 1;
260        }
261        Err(Error::SseConnection("No endpoint discovered".to_string()))
262    }
263}
264
265#[async_trait]
266impl Transport for SseTransport {
267    type Handle = SseTransportHandle;
268
269    async fn start(&self) -> Result<Self::Handle, Error> {
270        // Set environment variables
271        for (key, value) in &self.env {
272            std::env::set_var(key, value);
273        }
274
275        // Create a channel for outgoing TransportMessages
276        let (tx, rx) = mpsc::channel(32);
277
278        let post_endpoint: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
279        let post_endpoint_clone = Arc::clone(&post_endpoint);
280
281        // Build the actor
282        let actor = SseActor::new(
283            rx,
284            Arc::new(PendingRequests::new()),
285            self.sse_url.clone(),
286            post_endpoint,
287        );
288
289        // Spawn the actor task
290        tokio::spawn(actor.run());
291
292        // Wait for the endpoint to be discovered before returning the handle
293        match timeout(
294            Duration::from_secs(ENDPOINT_TIMEOUT_SECS),
295            Self::wait_for_endpoint(post_endpoint_clone),
296        )
297        .await
298        {
299            Ok(_) => Ok(SseTransportHandle { sender: tx }),
300            Err(e) => Err(Error::SseConnection(e.to_string())),
301        }
302    }
303
304    async fn close(&self) -> Result<(), Error> {
305        // For SSE, you might close the stream or send a shutdown signal to the actor.
306        // Here, we do nothing special.
307        Ok(())
308    }
309}