1use super::AppState;
27use super::mcp_discovery::read_construct_mcp;
28use axum::{
29 extract::{
30 Query, State, WebSocketUpgrade,
31 ws::{Message, WebSocket},
32 },
33 http::{HeaderMap, StatusCode, header},
34 response::IntoResponse,
35};
36use futures_util::{SinkExt, StreamExt, stream::BoxStream};
37use serde::Deserialize;
38use std::time::Duration;
39
/// WebSocket subprotocol echoed back to clients that offer it during the upgrade.
const WS_PROTOCOL: &str = "construct.v1";

/// Prefix marking a `Sec-WebSocket-Protocol` entry that carries the pairing token
/// as `bearer.<token>` (presumably for browser clients, which cannot set an
/// `Authorization` header on a WebSocket handshake — confirm with client code).
const BEARER_SUBPROTO_PREFIX: &str = "bearer.";
42
43fn extract_ws_token<'a>(headers: &'a HeaderMap, query_token: Option<&'a str>) -> Option<&'a str> {
44 if let Some(t) = headers
45 .get(header::AUTHORIZATION)
46 .and_then(|v| v.to_str().ok())
47 .and_then(|auth| auth.strip_prefix("Bearer "))
48 {
49 if !t.is_empty() {
50 return Some(t);
51 }
52 }
53 if let Some(t) = headers
54 .get("sec-websocket-protocol")
55 .and_then(|v| v.to_str().ok())
56 .and_then(|protos| {
57 protos
58 .split(',')
59 .map(|p| p.trim())
60 .find_map(|p| p.strip_prefix(BEARER_SUBPROTO_PREFIX))
61 })
62 {
63 if !t.is_empty() {
64 return Some(t);
65 }
66 }
67 if let Some(t) = query_token {
68 if !t.is_empty() {
69 return Some(t);
70 }
71 }
72 None
73}
74
75#[derive(Deserialize, Default)]
76pub struct McpEventsQuery {
77 pub token: Option<String>,
79 pub session_id: Option<String>,
81 pub mcp_token: Option<String>,
83}
84
/// Builds the daemon's per-session event-stream URL from the discovered MCP URL.
///
/// Trailing slashes and one trailing `/mcp` segment are removed from
/// `discovery_url`, then `/session/{session_id}/events` is appended. For example,
/// `http://127.0.0.1:54500/mcp` with `sid-1` gives
/// `http://127.0.0.1:54500/session/sid-1/events`.
///
/// `session_id` is caller-controlled, so it is percent-encoded as a single path
/// segment. That way `/`, `?`, `#`, `\` or `%` in it cannot change which daemon
/// path or query the request targets. A session id of exactly `.` or `..` is still
/// resolved as a dot segment by URL parsers (even when encoded), so callers should
/// reject those values.
pub fn daemon_events_url_from_discovery(discovery_url: &str, session_id: &str) -> String {
    let trimmed = discovery_url.trim_end_matches('/');
    let base = trimmed.strip_suffix("/mcp").unwrap_or(trimmed);
    let segment = encode_path_segment(session_id);
    format!("{base}/session/{segment}/events")
}

/// Percent-encodes `raw` for use as one URL path segment (RFC 3986 `pchar`).
///
/// Unreserved characters, sub-delimiters, `:` and `@` pass through unchanged.
/// Every other byte becomes `%XX`, including `/`, `?`, `#`, `%`, `\`, whitespace
/// and each byte of non-ASCII UTF-8.
fn encode_path_segment(raw: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut out = String::with_capacity(raw.len());
    for byte in raw.bytes() {
        let passthrough = byte.is_ascii_alphanumeric() || b"-._~!$&'()*+,;=:@".contains(&byte);
        if passthrough {
            out.push(char::from(byte));
        } else {
            out.push('%');
            out.push(char::from(HEX[usize::from(byte >> 4)]));
            out.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    out
}
92
93#[async_trait::async_trait]
97pub trait McpEventsSource: Send + Sync {
98 async fn open(
99 &self,
100 url: &str,
101 mcp_token: &str,
102 ) -> Result<BoxStream<'static, Result<String, String>>, String>;
103}
104
105pub struct ReqwestEventsSource;
108
109#[async_trait::async_trait]
110impl McpEventsSource for ReqwestEventsSource {
111 async fn open(
112 &self,
113 url: &str,
114 mcp_token: &str,
115 ) -> Result<BoxStream<'static, Result<String, String>>, String> {
116 let client = reqwest::Client::builder()
117 .connect_timeout(Duration::from_secs(5))
118 .build()
119 .map_err(|e| e.to_string())?;
120 let resp = client
121 .get(url)
122 .header(header::AUTHORIZATION, format!("Bearer {mcp_token}"))
123 .header(header::ACCEPT, "text/event-stream")
124 .send()
125 .await
126 .map_err(|e| e.to_string())?;
127 if !resp.status().is_success() {
128 return Err(format!("daemon responded {}", resp.status()));
129 }
130 let byte_stream = resp
133 .bytes_stream()
134 .map(|r| r.map(|b| b.to_vec()).map_err(|e| e.to_string()));
135 Ok(parse_sse_stream(byte_stream).boxed())
136 }
137}
138
139pub fn parse_sse_stream<S>(
145 byte_stream: S,
146) -> impl futures_util::Stream<Item = Result<String, String>> + Send + 'static
147where
148 S: futures_util::Stream<Item = Result<Vec<u8>, String>> + Send + 'static,
149{
150 use futures_util::stream::unfold;
151
152 struct St {
153 inner: BoxStream<'static, Result<Vec<u8>, String>>,
154 buffer: String,
155 data_accum: String,
156 pending: std::collections::VecDeque<String>,
157 done: bool,
158 }
159
160 let state = St {
161 inner: byte_stream.boxed(),
162 buffer: String::new(),
163 data_accum: String::new(),
164 pending: std::collections::VecDeque::new(),
165 done: false,
166 };
167
168 unfold(state, |mut st| async move {
169 if let Some(next) = st.pending.pop_front() {
171 return Some((Ok(next), st));
172 }
173 if st.done {
174 if !st.data_accum.is_empty() {
176 let out = std::mem::take(&mut st.data_accum);
177 return Some((Ok(out), st));
178 }
179 return None;
180 }
181 loop {
183 match st.inner.next().await {
184 None => {
185 st.done = true;
186 if !st.data_accum.is_empty() {
187 let out = std::mem::take(&mut st.data_accum);
188 return Some((Ok(out), st));
189 }
190 return None;
191 }
192 Some(Err(e)) => {
193 st.done = true;
194 return Some((Err(e), st));
195 }
196 Some(Ok(bytes)) => {
197 st.buffer.push_str(&String::from_utf8_lossy(&bytes));
198 while let Some(idx) = st.buffer.find('\n') {
199 let line = st.buffer[..idx].trim_end_matches('\r').to_string();
200 st.buffer.drain(..=idx);
201 if line.is_empty() {
202 if !st.data_accum.is_empty() {
203 st.pending.push_back(std::mem::take(&mut st.data_accum));
204 }
205 continue;
206 }
207 if let Some(rest) = line.strip_prefix("data:") {
208 let payload = rest.strip_prefix(' ').unwrap_or(rest);
209 if !st.data_accum.is_empty() {
210 st.data_accum.push('\n');
211 }
212 st.data_accum.push_str(payload);
213 }
214 }
216 if let Some(next) = st.pending.pop_front() {
217 return Some((Ok(next), st));
218 }
219 }
221 }
222 }
223 })
224}
225
226pub async fn handle_ws_mcp_events(
228 State(state): State<AppState>,
229 Query(params): Query<McpEventsQuery>,
230 headers: HeaderMap,
231 ws: WebSocketUpgrade,
232) -> axum::response::Response {
233 if state.pairing.require_pairing() {
234 let token = extract_ws_token(&headers, params.token.as_deref()).unwrap_or("");
235 if !state.pairing.is_authenticated(token) {
236 return (StatusCode::UNAUTHORIZED, "Unauthorized").into_response();
237 }
238 }
239
240 let Some(session_id) = params.session_id.clone().filter(|s| !s.is_empty()) else {
241 return (StatusCode::BAD_REQUEST, "missing session_id").into_response();
242 };
243 let Some(mcp_token) = params.mcp_token.clone().filter(|s| !s.is_empty()) else {
244 return (StatusCode::BAD_REQUEST, "missing mcp_token").into_response();
245 };
246
247 let discovery = match read_construct_mcp() {
248 Ok(d) => d,
249 Err(_) => {
250 return (StatusCode::SERVICE_UNAVAILABLE, "mcp daemon not discovered").into_response();
251 }
252 };
253 let events_url = daemon_events_url_from_discovery(&discovery.url, &session_id);
254
255 let ws = if headers
256 .get("sec-websocket-protocol")
257 .and_then(|v| v.to_str().ok())
258 .is_some_and(|protos| protos.split(',').any(|p| p.trim() == WS_PROTOCOL))
259 {
260 ws.protocols([WS_PROTOCOL])
261 } else {
262 ws
263 };
264
265 ws.on_upgrade(move |socket| async move {
266 run_proxy(socket, events_url, mcp_token, Box::new(ReqwestEventsSource)).await;
267 })
268 .into_response()
269}
270
271pub async fn run_proxy(
274 mut ws: WebSocket,
275 events_url: String,
276 mcp_token: String,
277 source: Box<dyn McpEventsSource>,
278) {
279 let mut stream = match source.open(&events_url, &mcp_token).await {
280 Ok(s) => s,
281 Err(e) => {
282 let _ = ws
283 .send(Message::Text(
284 serde_json::json!({ "error": "daemon-unreachable", "detail": e })
285 .to_string()
286 .into(),
287 ))
288 .await;
289 let _ = ws.close().await;
290 return;
291 }
292 };
293
294 loop {
295 tokio::select! {
296 incoming = ws.recv() => {
297 match incoming {
298 Some(Ok(Message::Close(_))) | None => break,
299 Some(Err(_)) => break,
300 _ => { }
301 }
302 }
303 next = stream.next() => {
304 match next {
305 Some(Ok(payload)) => {
306 if ws.send(Message::Text(payload.into())).await.is_err() {
307 break;
308 }
309 }
310 Some(Err(_)) | None => {
311 let _ = ws.close().await;
312 break;
313 }
314 }
315 }
316 }
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::stream;

    // Feeds `chunks` through `parse_sse_stream` and returns its first `count`
    // frames, panicking if the parser errors or ends before producing them.
    async fn first_frames(chunks: &[&str], count: usize) -> Vec<String> {
        let input: Vec<Result<Vec<u8>, String>> = chunks
            .iter()
            .map(|chunk| Ok(chunk.as_bytes().to_vec()))
            .collect();
        let parsed = parse_sse_stream(stream::iter(input));
        futures_util::pin_mut!(parsed);
        let mut frames = Vec::with_capacity(count);
        while frames.len() < count {
            let frame = parsed
                .next()
                .await
                .expect("parser ended early")
                .expect("parser yielded an error");
            frames.push(frame);
        }
        frames
    }

    #[test]
    fn builds_events_url_from_mcp_discovery() {
        let cases = [
            (
                "http://127.0.0.1:54500/mcp",
                "sid-1",
                "http://127.0.0.1:54500/session/sid-1/events",
            ),
            (
                "http://127.0.0.1:54500/mcp/",
                "sid-2",
                "http://127.0.0.1:54500/session/sid-2/events",
            ),
            (
                "http://127.0.0.1:54500",
                "sid-3",
                "http://127.0.0.1:54500/session/sid-3/events",
            ),
        ];
        for (discovery, session, expected) in cases {
            assert_eq!(daemon_events_url_from_discovery(discovery, session), expected);
        }
    }

    #[tokio::test]
    async fn sse_parser_extracts_data_frames() {
        let frames = first_frames(&["data: {\"a\":1}\n\n", "data: {\"b\":2}\n\n"], 2).await;
        assert_eq!(frames, ["{\"a\":1}", "{\"b\":2}"]);
    }

    #[tokio::test]
    async fn sse_parser_joins_multi_data_lines() {
        let frames = first_frames(&["data: line1\ndata: line2\n\n"], 1).await;
        assert_eq!(frames, ["line1\nline2"]);
    }

    #[tokio::test]
    async fn sse_parser_ignores_non_data_fields() {
        let frames = first_frames(
            &[": heartbeat\nevent: progress\ndata: {\"k\":\"v\"}\n\n"],
            1,
        )
        .await;
        assert_eq!(frames, ["{\"k\":\"v\"}"]);
    }

    #[tokio::test]
    async fn sse_parser_handles_chunk_boundaries_midline() {
        let frames = first_frames(&["data: {\"tok", "en\":42}\n\n"], 1).await;
        assert_eq!(frames, ["{\"token\":42}"]);
    }

    // Source that replays a fixed list of frames on every `open`.
    struct ScriptedSource(Vec<Result<String, String>>);

    #[async_trait::async_trait]
    impl McpEventsSource for ScriptedSource {
        async fn open(
            &self,
            _url: &str,
            _mcp_token: &str,
        ) -> Result<BoxStream<'static, Result<String, String>>, String> {
            Ok(stream::iter(self.0.clone()).boxed())
        }
    }

    #[tokio::test]
    async fn source_abstraction_is_mockable_and_yields_frames() {
        let source = ScriptedSource(
            [
                r#"{"token":1,"progress":1,"timestamp":"t1"}"#,
                r#"{"token":1,"progress":2,"timestamp":"t2"}"#,
            ]
            .into_iter()
            .map(|frame| Ok(frame.to_owned()))
            .collect(),
        );
        let mut frames = source
            .open("http://example/session/x/events", "token")
            .await
            .expect("open ok");
        for expected in ["\"progress\":1", "\"progress\":2"] {
            let frame = frames.next().await.unwrap().unwrap();
            assert!(frame.contains(expected));
        }
        assert!(frames.next().await.is_none());
    }

    #[tokio::test]
    async fn source_open_error_surfaces_to_caller() {
        struct FailingSource;

        #[async_trait::async_trait]
        impl McpEventsSource for FailingSource {
            async fn open(
                &self,
                _url: &str,
                _mcp_token: &str,
            ) -> Result<BoxStream<'static, Result<String, String>>, String> {
                Err("connection refused".into())
            }
        }

        let Err(err) = FailingSource.open("http://x", "t").await else {
            panic!("expected error");
        };
        assert!(err.contains("connection refused"));
    }
}