// construct/gateway/ws_mcp_events.rs

//! `GET /ws/mcp/events` — WebSocket proxy onto the in-process MCP server's
//! session-wide progress SSE stream (`GET /session/<id>/events`).
//!
//! The MCP server runs as a tokio task inside the same daemon process as the
//! gateway. This proxy nevertheless connects to it over HTTP on loopback,
//! reading the server's URL from `~/.construct/mcp.json`, so it works without
//! knowing the server's ephemeral port ahead of time.
//!
//! The V2 Code tab opens this WebSocket while a CLI coding agent is running in
//! the PTY. Every `ProgressEvent` published by any in-flight `tools/call` on
//! that MCP session is forwarded here as a single JSON text frame, matching
//! the server's `ProgressEvent` serialization:
//!
//! ```json
//! { "token": 7, "progress": 4, "total": 10, "message": "...",
//!   "tool": "notion", "timestamp": "2026-04-17T10:20:33+00:00" }
//! ```
//!
//! ## Why a gateway proxy (not direct SSE from the browser)?
//!
//! - Keeps all external traffic funneled through the gateway — a single auth
//!   surface, no CORS friction, and consistent with `/ws/terminal`.
//! - The gateway independently verifies its own bearer token (`?token=<zc_…>`,
//!   the `Authorization:` header, or a `bearer.<token>` WebSocket subprotocol)
//!   using `PairingGuard`; the MCP session token (`?mcp_token=<…>`) is used
//!   only to talk to the in-process MCP server.

26use super::AppState;
27use super::mcp_discovery::read_construct_mcp;
28use axum::{
29    extract::{
30        Query, State, WebSocketUpgrade,
31        ws::{Message, WebSocket},
32    },
33    http::{HeaderMap, StatusCode, header},
34    response::IntoResponse,
35};
36use futures_util::{SinkExt, StreamExt, stream::BoxStream};
37use serde::Deserialize;
38use std::time::Duration;
39
/// Prefix of the `Sec-WebSocket-Protocol` entry (`bearer.<token>`) that can
/// carry the gateway bearer token as an alternative to `Authorization`.
const BEARER_SUBPROTO_PREFIX: &str = "bearer.";
/// Application subprotocol selected on upgrade when the client offers it.
const WS_PROTOCOL: &str = "construct.v1";

43fn extract_ws_token<'a>(headers: &'a HeaderMap, query_token: Option<&'a str>) -> Option<&'a str> {
44    if let Some(t) = headers
45        .get(header::AUTHORIZATION)
46        .and_then(|v| v.to_str().ok())
47        .and_then(|auth| auth.strip_prefix("Bearer "))
48    {
49        if !t.is_empty() {
50            return Some(t);
51        }
52    }
53    if let Some(t) = headers
54        .get("sec-websocket-protocol")
55        .and_then(|v| v.to_str().ok())
56        .and_then(|protos| {
57            protos
58                .split(',')
59                .map(|p| p.trim())
60                .find_map(|p| p.strip_prefix(BEARER_SUBPROTO_PREFIX))
61        })
62    {
63        if !t.is_empty() {
64            return Some(t);
65        }
66    }
67    if let Some(t) = query_token {
68        if !t.is_empty() {
69            return Some(t);
70        }
71    }
72    None
73}
74
75#[derive(Deserialize, Default)]
76pub struct McpEventsQuery {
77    /// Gateway bearer token (same as `/ws/terminal`).
78    pub token: Option<String>,
79    /// MCP daemon session id (returned by `POST /session`).
80    pub session_id: Option<String>,
81    /// MCP daemon bearer token (returned by `POST /session`).
82    pub mcp_token: Option<String>,
83}
84
/// Map the discovery URL (which usually ends in `/mcp`) to the session-events
/// URL for the given session id.
///
/// Trailing slashes and then a single trailing `/mcp` segment are stripped
/// from `discovery_url` before `/session/<id>/events` is appended.
///
/// `session_id` arrives from the client's query string, so it is
/// percent-encoded as one path segment: every byte outside the RFC 3986
/// unreserved set (`A-Z a-z 0-9 - . _ ~`) becomes `%XX`. This stops `/`, `?`
/// or `#` in a hostile id from rewriting the path or adding a query string on
/// the upstream request. The ids `.` and `..` are still dot segments to a URL
/// parser, so callers should reject those themselves.
pub fn daemon_events_url_from_discovery(discovery_url: &str, session_id: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let trimmed = discovery_url.trim_end_matches('/');
    let base = trimmed.strip_suffix("/mcp").unwrap_or(trimmed);

    let mut segment = String::with_capacity(session_id.len());
    for byte in session_id.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
            segment.push(char::from(byte));
        } else {
            segment.push('%');
            segment.push(char::from(HEX[usize::from(byte >> 4)]));
            segment.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    format!("{base}/session/{segment}/events")
}

93/// Abstraction over "open the session-events SSE and yield each event's data
94/// payload as a String". Real impl uses `reqwest`; tests inject a scripted
95/// mock.
96#[async_trait::async_trait]
97pub trait McpEventsSource: Send + Sync {
98    async fn open(
99        &self,
100        url: &str,
101        mcp_token: &str,
102    ) -> Result<BoxStream<'static, Result<String, String>>, String>;
103}
104
105/// Default source — opens the SSE stream via `reqwest`, parses the server
106/// frames, and yields each `data: …` line's payload as a String.
107pub struct ReqwestEventsSource;
108
109#[async_trait::async_trait]
110impl McpEventsSource for ReqwestEventsSource {
111    async fn open(
112        &self,
113        url: &str,
114        mcp_token: &str,
115    ) -> Result<BoxStream<'static, Result<String, String>>, String> {
116        let client = reqwest::Client::builder()
117            .connect_timeout(Duration::from_secs(5))
118            .build()
119            .map_err(|e| e.to_string())?;
120        let resp = client
121            .get(url)
122            .header(header::AUTHORIZATION, format!("Bearer {mcp_token}"))
123            .header(header::ACCEPT, "text/event-stream")
124            .send()
125            .await
126            .map_err(|e| e.to_string())?;
127        if !resp.status().is_success() {
128            return Err(format!("daemon responded {}", resp.status()));
129        }
130        // Convert reqwest's byte stream (Result<Bytes, reqwest::Error>) into
131        // Result<Vec<u8>, String>, then run the SSE framer.
132        let byte_stream = resp
133            .bytes_stream()
134            .map(|r| r.map(|b| b.to_vec()).map_err(|e| e.to_string()));
135        Ok(parse_sse_stream(byte_stream).boxed())
136    }
137}
138
139/// Parse a stream of byte chunks into a stream of `data:` payload Strings.
140///
141/// Simplified SSE parser: events are blank-line terminated, only the `data:`
142/// field is surfaced (`event:`, `id:`, `:comment`, etc. are ignored).
143/// Multiple `data:` lines within one event are joined with `\n` per spec.
144pub fn parse_sse_stream<S>(
145    byte_stream: S,
146) -> impl futures_util::Stream<Item = Result<String, String>> + Send + 'static
147where
148    S: futures_util::Stream<Item = Result<Vec<u8>, String>> + Send + 'static,
149{
150    use futures_util::stream::unfold;
151
152    struct St {
153        inner: BoxStream<'static, Result<Vec<u8>, String>>,
154        buffer: String,
155        data_accum: String,
156        pending: std::collections::VecDeque<String>,
157        done: bool,
158    }
159
160    let state = St {
161        inner: byte_stream.boxed(),
162        buffer: String::new(),
163        data_accum: String::new(),
164        pending: std::collections::VecDeque::new(),
165        done: false,
166    };
167
168    unfold(state, |mut st| async move {
169        // Flush already-queued events first (one per yield).
170        if let Some(next) = st.pending.pop_front() {
171            return Some((Ok(next), st));
172        }
173        if st.done {
174            // Drain any trailing data accumulated without a blank line.
175            if !st.data_accum.is_empty() {
176                let out = std::mem::take(&mut st.data_accum);
177                return Some((Ok(out), st));
178            }
179            return None;
180        }
181        // Pull more bytes until at least one event is flushed or EOF.
182        loop {
183            match st.inner.next().await {
184                None => {
185                    st.done = true;
186                    if !st.data_accum.is_empty() {
187                        let out = std::mem::take(&mut st.data_accum);
188                        return Some((Ok(out), st));
189                    }
190                    return None;
191                }
192                Some(Err(e)) => {
193                    st.done = true;
194                    return Some((Err(e), st));
195                }
196                Some(Ok(bytes)) => {
197                    st.buffer.push_str(&String::from_utf8_lossy(&bytes));
198                    while let Some(idx) = st.buffer.find('\n') {
199                        let line = st.buffer[..idx].trim_end_matches('\r').to_string();
200                        st.buffer.drain(..=idx);
201                        if line.is_empty() {
202                            if !st.data_accum.is_empty() {
203                                st.pending.push_back(std::mem::take(&mut st.data_accum));
204                            }
205                            continue;
206                        }
207                        if let Some(rest) = line.strip_prefix("data:") {
208                            let payload = rest.strip_prefix(' ').unwrap_or(rest);
209                            if !st.data_accum.is_empty() {
210                                st.data_accum.push('\n');
211                            }
212                            st.data_accum.push_str(payload);
213                        }
214                        // Other fields ignored (event:, id:, retry:, :comment).
215                    }
216                    if let Some(next) = st.pending.pop_front() {
217                        return Some((Ok(next), st));
218                    }
219                    // Keep pulling more bytes.
220                }
221            }
222        }
223    })
224}
225
226/// GET /ws/mcp/events — WebSocket upgrade for session-wide MCP progress.
227pub async fn handle_ws_mcp_events(
228    State(state): State<AppState>,
229    Query(params): Query<McpEventsQuery>,
230    headers: HeaderMap,
231    ws: WebSocketUpgrade,
232) -> axum::response::Response {
233    if state.pairing.require_pairing() {
234        let token = extract_ws_token(&headers, params.token.as_deref()).unwrap_or("");
235        if !state.pairing.is_authenticated(token) {
236            return (StatusCode::UNAUTHORIZED, "Unauthorized").into_response();
237        }
238    }
239
240    let Some(session_id) = params.session_id.clone().filter(|s| !s.is_empty()) else {
241        return (StatusCode::BAD_REQUEST, "missing session_id").into_response();
242    };
243    let Some(mcp_token) = params.mcp_token.clone().filter(|s| !s.is_empty()) else {
244        return (StatusCode::BAD_REQUEST, "missing mcp_token").into_response();
245    };
246
247    let discovery = match read_construct_mcp() {
248        Ok(d) => d,
249        Err(_) => {
250            return (StatusCode::SERVICE_UNAVAILABLE, "mcp daemon not discovered").into_response();
251        }
252    };
253    let events_url = daemon_events_url_from_discovery(&discovery.url, &session_id);
254
255    let ws = if headers
256        .get("sec-websocket-protocol")
257        .and_then(|v| v.to_str().ok())
258        .is_some_and(|protos| protos.split(',').any(|p| p.trim() == WS_PROTOCOL))
259    {
260        ws.protocols([WS_PROTOCOL])
261    } else {
262        ws
263    };
264
265    ws.on_upgrade(move |socket| async move {
266        run_proxy(socket, events_url, mcp_token, Box::new(ReqwestEventsSource)).await;
267    })
268    .into_response()
269}
270
271/// Pipe every SSE `data:` payload from `source` into the WebSocket as a text
272/// frame. Returns when either side closes.
273pub async fn run_proxy(
274    mut ws: WebSocket,
275    events_url: String,
276    mcp_token: String,
277    source: Box<dyn McpEventsSource>,
278) {
279    let mut stream = match source.open(&events_url, &mcp_token).await {
280        Ok(s) => s,
281        Err(e) => {
282            let _ = ws
283                .send(Message::Text(
284                    serde_json::json!({ "error": "daemon-unreachable", "detail": e })
285                        .to_string()
286                        .into(),
287                ))
288                .await;
289            let _ = ws.close().await;
290            return;
291        }
292    };
293
294    loop {
295        tokio::select! {
296            incoming = ws.recv() => {
297                match incoming {
298                    Some(Ok(Message::Close(_))) | None => break,
299                    Some(Err(_)) => break,
300                    _ => { /* ignore client->server frames; this channel is server-push */ }
301                }
302            }
303            next = stream.next() => {
304                match next {
305                    Some(Ok(payload)) => {
306                        if ws.send(Message::Text(payload.into())).await.is_err() {
307                            break;
308                        }
309                    }
310                    Some(Err(_)) | None => {
311                        let _ = ws.close().await;
312                        break;
313                    }
314                }
315            }
316        }
317    }
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;
    use futures_util::stream;

    #[test]
    fn builds_events_url_from_mcp_discovery() {
        assert_eq!(
            daemon_events_url_from_discovery("http://127.0.0.1:54500/mcp", "sid-1"),
            "http://127.0.0.1:54500/session/sid-1/events"
        );
        assert_eq!(
            daemon_events_url_from_discovery("http://127.0.0.1:54500/mcp/", "sid-2"),
            "http://127.0.0.1:54500/session/sid-2/events"
        );
        assert_eq!(
            daemon_events_url_from_discovery("http://127.0.0.1:54500", "sid-3"),
            "http://127.0.0.1:54500/session/sid-3/events"
        );
    }

    #[test]
    fn ws_token_sources_are_checked_in_priority_order() {
        let mut headers = HeaderMap::new();
        assert_eq!(extract_ws_token(&headers, Some("q")), Some("q"));
        assert_eq!(extract_ws_token(&headers, Some("")), None);
        assert_eq!(extract_ws_token(&headers, None), None);

        headers.insert(
            header::SEC_WEBSOCKET_PROTOCOL,
            HeaderValue::from_static("construct.v1, bearer.sub"),
        );
        assert_eq!(extract_ws_token(&headers, Some("q")), Some("sub"));

        headers.insert(header::AUTHORIZATION, HeaderValue::from_static("Bearer auth"));
        assert_eq!(extract_ws_token(&headers, Some("q")), Some("auth"));
    }

    #[tokio::test]
    async fn sse_parser_extracts_data_frames() {
        let chunks: Vec<Result<Vec<u8>, String>> = vec![
            Ok(b"data: {\"a\":1}\n\n".to_vec()),
            Ok(b"data: {\"b\":2}\n\n".to_vec()),
        ];
        let byte_stream = stream::iter(chunks);
        let parsed = parse_sse_stream(byte_stream);
        futures_util::pin_mut!(parsed);
        let first = parsed.next().await.unwrap().unwrap();
        let second = parsed.next().await.unwrap().unwrap();
        assert_eq!(first, "{\"a\":1}");
        assert_eq!(second, "{\"b\":2}");
    }

    #[tokio::test]
    async fn sse_parser_joins_multi_data_lines() {
        let chunks: Vec<Result<Vec<u8>, String>> =
            vec![Ok(b"data: line1\ndata: line2\n\n".to_vec())];
        let byte_stream = stream::iter(chunks);
        let parsed = parse_sse_stream(byte_stream);
        futures_util::pin_mut!(parsed);
        let joined = parsed.next().await.unwrap().unwrap();
        assert_eq!(joined, "line1\nline2");
    }

    #[tokio::test]
    async fn sse_parser_ignores_non_data_fields() {
        let chunks: Vec<Result<Vec<u8>, String>> = vec![Ok(
            b": heartbeat\nevent: progress\ndata: {\"k\":\"v\"}\n\n".to_vec(),
        )];
        let byte_stream = stream::iter(chunks);
        let parsed = parse_sse_stream(byte_stream);
        futures_util::pin_mut!(parsed);
        let payload = parsed.next().await.unwrap().unwrap();
        assert_eq!(payload, "{\"k\":\"v\"}");
    }

    #[tokio::test]
    async fn sse_parser_handles_chunk_boundaries_midline() {
        let chunks: Vec<Result<Vec<u8>, String>> =
            vec![Ok(b"data: {\"tok".to_vec()), Ok(b"en\":42}\n\n".to_vec())];
        let byte_stream = stream::iter(chunks);
        let parsed = parse_sse_stream(byte_stream);
        futures_util::pin_mut!(parsed);
        let payload = parsed.next().await.unwrap().unwrap();
        assert_eq!(payload, "{\"token\":42}");
    }

    #[tokio::test]
    async fn sse_parser_handles_crlf_line_endings() {
        let chunks: Vec<Result<Vec<u8>, String>> =
            vec![Ok(b"data: {\"c\":3}\r\n\r\n".to_vec())];
        let parsed = parse_sse_stream(stream::iter(chunks));
        futures_util::pin_mut!(parsed);
        assert_eq!(parsed.next().await.unwrap().unwrap(), "{\"c\":3}");
        assert!(parsed.next().await.is_none());
    }

    #[tokio::test]
    async fn sse_parser_flushes_unterminated_event_at_eof() {
        let chunks: Vec<Result<Vec<u8>, String>> = vec![Ok(b"data: tail\n".to_vec())];
        let parsed = parse_sse_stream(stream::iter(chunks));
        futures_util::pin_mut!(parsed);
        assert_eq!(parsed.next().await.unwrap().unwrap(), "tail");
        assert!(parsed.next().await.is_none());
    }

    #[tokio::test]
    async fn sse_parser_surfaces_byte_stream_errors() {
        let chunks: Vec<Result<Vec<u8>, String>> =
            vec![Ok(b"data: a\n\n".to_vec()), Err("boom".to_string())];
        let parsed = parse_sse_stream(stream::iter(chunks));
        futures_util::pin_mut!(parsed);
        assert_eq!(parsed.next().await.unwrap().unwrap(), "a");
        assert_eq!(parsed.next().await.unwrap().unwrap_err(), "boom");
        assert!(parsed.next().await.is_none());
    }

    // ── Source mock used by the proxy handler ─────────────────────────────

    struct ScriptedSource(Vec<Result<String, String>>);

    #[async_trait::async_trait]
    impl McpEventsSource for ScriptedSource {
        async fn open(
            &self,
            _url: &str,
            _mcp_token: &str,
        ) -> Result<BoxStream<'static, Result<String, String>>, String> {
            let items = self.0.clone();
            Ok(stream::iter(items).boxed())
        }
    }

    #[tokio::test]
    async fn source_abstraction_is_mockable_and_yields_frames() {
        let source = ScriptedSource(vec![
            Ok(r#"{"token":1,"progress":1,"timestamp":"t1"}"#.into()),
            Ok(r#"{"token":1,"progress":2,"timestamp":"t2"}"#.into()),
        ]);
        let mut stream = source
            .open("http://example/session/x/events", "token")
            .await
            .expect("open ok");
        let first = stream.next().await.unwrap().unwrap();
        let second = stream.next().await.unwrap().unwrap();
        assert!(first.contains("\"progress\":1"));
        assert!(second.contains("\"progress\":2"));
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn source_open_error_surfaces_to_caller() {
        struct FailingSource;
        #[async_trait::async_trait]
        impl McpEventsSource for FailingSource {
            async fn open(
                &self,
                _url: &str,
                _mcp_token: &str,
            ) -> Result<BoxStream<'static, Result<String, String>>, String> {
                Err("connection refused".into())
            }
        }
        let source = FailingSource;
        let err = match source.open("http://x", "t").await {
            Ok(_) => panic!("expected error"),
            Err(e) => e,
        };
        assert!(err.contains("connection refused"));
    }
}