Skip to main content

jmap_base_client/ws/
mod.rs

1//! WebSocket transport for JMAP (RFC 8887).
2//!
3//! Provides [`connect_ws`] which establishes a WebSocket connection and
4//! returns a [`WsSession`] for sending and receiving frames.
5//!
6//! URL source: `Session::capabilities["urn:ietf:params:jmap:websocket"].url`
7//! (the session document advertises the WebSocket endpoint).
8
9use std::str::FromStr as _;
10
11use futures::SinkExt as _;
12use futures::StreamExt as _;
13use tokio_tungstenite::tungstenite::client::IntoClientRequest as _;
14use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
15use tokio_tungstenite::tungstenite::Message;
16
17use crate::push::StateChange;
18
19/// Wire frame sent from the client to the server over WebSocket (RFC 8887 §4.3.2).
20///
21/// Wraps a [`jmap_types::JmapRequest`] and injects the mandatory `@type: "Request"`
22/// field (and optional `id`) in a single `serde_json::to_string` pass, avoiding
23/// the `to_value` + mutation + `to_string` double-serialization that the naive
24/// approach requires.
25#[derive(serde::Serialize)]
26struct WsRequestFrame<'a> {
27    /// RFC 8887 §4.3.2 — every JMAP request frame MUST carry "@type": "Request".
28    #[serde(rename = "@type")]
29    ws_type: &'static str,
30    /// Optional correlation ID echoed back in the server's Response frame.
31    #[serde(skip_serializing_if = "Option::is_none")]
32    id: Option<&'a str>,
33    /// The JMAP request payload; flattened into the enclosing JSON object.
34    #[serde(flatten)]
35    inner: &'a jmap_types::JmapRequest,
36}
37
/// Upper bound on a single WebSocket message: 1 MiB, the same limit used for
/// SSE frames. Caps how much a misbehaving or hostile server can make the
/// client buffer over the event connection.
const MAX_WS_MESSAGE_BYTES: usize = 1024 * 1024;
42
43/// A parsed frame received from the JMAP WebSocket.
44///
45/// Marked `#[non_exhaustive]` because the spec may define additional
46/// `@type` values in future revisions.
47#[non_exhaustive]
48#[derive(Debug, Clone, PartialEq)]
49pub enum WsFrame {
50    /// RFC 8620 §7.1 StateChange — one or more object types have changed
51    /// state; client must re-fetch the affected data types.
52    StateChange(StateChange),
53    /// RFC 8887 Response — reply to a JMAP request sent on this connection.
54    Response(jmap_types::JmapResponse),
55    /// Unrecognized `@type` — silently ignored per forward-compatibility rules
56    /// (RFC 8887 §4.3.1: clients SHOULD ignore unknown message types).
57    ///
58    /// Also produced when a known type (`"Response"` or `"StateChange"`) fails
59    /// to deserialize — `type_name` will be `"Response"` or `"StateChange"` in
60    /// that case, which can signal server misbehavior or a schema version
61    /// mismatch. Callers that log unknown frames should check for these names.
62    Unknown { type_name: String },
63}
64
65type Inner =
66    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
67
68/// An established JMAP WebSocket session (RFC 8887).
69///
70/// Call [`next_frame`](WsSession::next_frame) in a loop to receive events.
71/// Use [`send_request`](WsSession::send_request) to transmit JMAP requests.
72///
73/// The caller is responsible for reconnecting after the stream ends or returns
74/// a transport error. Use exponential backoff.
75pub struct WsSession {
76    sink: futures::stream::SplitSink<Inner, Message>,
77    stream: futures::stream::SplitStream<Inner>,
78}
79
80impl WsSession {
81    /// Receive the next parsed frame from the server.
82    ///
83    /// Returns `None` when the server has cleanly closed the connection.
84    /// Returns `Some(Err(...))` on parse failure or transport error. After a
85    /// transport error the connection is broken; do not call `next_frame` again.
86    pub async fn next_frame(&mut self) -> Option<Result<WsFrame, crate::error::ClientError>> {
87        loop {
88            match self.stream.next().await? {
89                Ok(Message::Text(text)) => return Some(parse_ws_frame(&text)),
90                Ok(Message::Close(_)) => return None,
91                Ok(_) => continue, // Ping / Pong / Binary: silently skip
92                Err(e) => return Some(Err(crate::error::ClientError::WebSocket(e))),
93            }
94        }
95    }
96
97    /// Send a JMAP request over the WebSocket connection.
98    ///
99    /// Serializes `req` and injects `"@type": "Request"` into the outgoing
100    /// JSON object as required by RFC 8887 §4.3.2.  The optional `id` is
101    /// echoed back in the corresponding `Response` frame, enabling out-of-order
102    /// correlation.
103    ///
104    /// # Errors
105    ///
106    /// Returns `ClientError::Serialize` if `req` cannot be serialized, or
107    /// `ClientError::WebSocket` on a transport failure.
108    pub async fn send_request(
109        &mut self,
110        req: &jmap_types::JmapRequest,
111        id: Option<&str>,
112    ) -> Result<(), crate::error::ClientError> {
113        // Wrap req in WsRequestFrame to inject @type and optional id in one
114        // serialization pass (no intermediate serde_json::Value allocation).
115        let frame = WsRequestFrame {
116            ws_type: "Request",
117            id,
118            inner: req,
119        };
120        let text = serde_json::to_string(&frame).map_err(crate::error::ClientError::Serialize)?;
121        self.sink
122            .send(Message::Text(text.into()))
123            .await
124            .map_err(crate::error::ClientError::WebSocket)
125    }
126}
127
128/// Parse a raw WebSocket text frame into a `WsFrame`.
129fn parse_ws_frame(text: &str) -> Result<WsFrame, crate::error::ClientError> {
130    let val: serde_json::Value =
131        serde_json::from_str(text).map_err(crate::error::ClientError::Parse)?;
132
133    // Pre-extract type_name as owned String before moving val into from_value.
134    // The borrow checker prevents borrowing val (for @type) and moving val
135    // (into from_value) in the same expression, so ownership must be taken first.
136    let type_name = val
137        .get("@type")
138        .and_then(|v| v.as_str())
139        .unwrap_or("<no @type>")
140        .to_owned();
141
142    match type_name.as_str() {
143        // A malformed StateChange is degraded to Unknown rather than a
144        // transport error. A single bad server frame must not kill the entire
145        // WebSocket connection; only tungstenite transport errors warrant
146        // a reconnect.
147        "StateChange" => match serde_json::from_value::<StateChange>(val) {
148            Ok(sc) => Ok(WsFrame::StateChange(sc)),
149            Err(_) => Ok(WsFrame::Unknown { type_name }),
150        },
151        // Same degradation policy for malformed Response frames.
152        "Response" => match serde_json::from_value::<jmap_types::JmapResponse>(val) {
153            Ok(r) => Ok(WsFrame::Response(r)),
154            Err(_) => Ok(WsFrame::Unknown { type_name }),
155        },
156        _ => Ok(WsFrame::Unknown { type_name }),
157    }
158}
159
160/// Open a JMAP WebSocket connection (RFC 8887).
161///
162/// `ws_url` must come from the session document's WebSocket capability URL
163/// (a `wss://` endpoint in production; `ws://` is accepted in tests).
164///
165/// `auth_header` is an optional `(header-name, header-value)` pair injected
166/// into the WebSocket upgrade request. Pass `None` when the server does not
167/// require authentication headers on the WebSocket handshake.
168///
169/// Returns `ClientError::InvalidArgument` if the URL scheme is not
170/// `ws://` or `wss://`, preventing accidental use with untrusted URLs.
171///
172/// The returned [`WsSession`] provides [`WsSession::next_frame`] for receiving
173/// events. The caller is responsible for reconnecting after disconnect with
174/// exponential backoff.
175pub async fn connect_ws(
176    ws_url: &str,
177    auth_header: Option<(&str, &str)>,
178) -> Result<WsSession, crate::error::ClientError> {
179    // Validate scheme to prevent SSRF via a compromised or MITM'd session.
180    // Case-insensitive check per RFC 3986 §3.1: lowercase the URL before
181    // comparing so that `WS://` and `wss://` are both accepted.  The
182    // original (unmodified) URL is passed to tungstenite and kept in error
183    // messages for diagnostics.
184    let ws_url_lc = ws_url.to_ascii_lowercase();
185    if !ws_url_lc.starts_with("ws://") && !ws_url_lc.starts_with("wss://") {
186        return Err(crate::error::ClientError::InvalidArgument(format!(
187            "WebSocket URL must start with ws:// or wss://, got: {ws_url:?}"
188        )));
189    }
190
191    let mut request = ws_url
192        .into_client_request()
193        .map_err(crate::error::ClientError::WebSocket)?;
194
195    if let Some((name, value)) = auth_header {
196        let hdr_name = http::HeaderName::from_str(name).map_err(|e| {
197            crate::error::ClientError::InvalidArgument(format!("invalid auth header name: {e}"))
198        })?;
199        let hdr_value = http::HeaderValue::from_str(value).map_err(|_| {
200            crate::error::ClientError::InvalidArgument("invalid auth header value".to_owned())
201        })?;
202        request.headers_mut().insert(hdr_name, hdr_value);
203    }
204
205    // WebSocketConfig is #[non_exhaustive] in tungstenite; use Default + field assignment.
206    let mut config = WebSocketConfig::default();
207    config.max_message_size = Some(MAX_WS_MESSAGE_BYTES);
208    config.max_frame_size = Some(MAX_WS_MESSAGE_BYTES);
209
210    // Apply a 10-second connect timeout, consistent with the HTTP transport's
211    // connect_timeout in DefaultTransport/CustomCaTransport.  tungstenite does
212    // not expose a connect timeout parameter, so we wrap at the Future level.
213    // A stalled TCP or TLS handshake would otherwise block indefinitely.
214    let connect_result = tokio::time::timeout(
215        std::time::Duration::from_secs(10),
216        tokio_tungstenite::connect_async_with_config(request, Some(config), false),
217    )
218    .await
219    .map_err(|_elapsed| {
220        crate::error::ClientError::WebSocket(tokio_tungstenite::tungstenite::Error::Io(
221            std::io::Error::new(
222                std::io::ErrorKind::TimedOut,
223                "WebSocket connect timed out after 10 seconds",
224            ),
225        ))
226    })?;
227    let (ws_stream, _response) = connect_result.map_err(crate::error::ClientError::WebSocket)?;
228
229    let (sink, stream) = ws_stream.split();
230    Ok(WsSession { sink, stream })
231}
232
233impl std::fmt::Debug for WsSession {
234    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235        f.debug_struct("WsSession").finish_non_exhaustive()
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal request fixture shared by the `WsRequestFrame` tests: only the
    /// core capability, no method calls, no created-ids map.
    fn core_only_request() -> jmap_types::JmapRequest {
        jmap_types::JmapRequest::new(
            vec!["urn:ietf:params:jmap:core".to_owned()],
            vec![],
            None,
        )
    }

    /// Serialize a `WsRequestFrame` around the fixture request with the given
    /// correlation `id`, exactly as `send_request` builds it.
    fn serialize_request_frame(id: Option<&str>) -> String {
        let req = core_only_request();
        let frame = WsRequestFrame {
            ws_type: "Request",
            id,
            inner: &req,
        };
        serde_json::to_string(&frame).expect("WsRequestFrame must serialize")
    }

    /// Verify WsFrame does not contain ChatTyping or ChatPresence variants.
    /// This exhaustive match will fail to compile if either variant is reintroduced.
    #[test]
    fn ws_frame_has_no_chat_variants() {
        let frame = WsFrame::Unknown {
            type_name: "test".to_owned(),
        };
        match frame {
            WsFrame::StateChange(_) => {}
            WsFrame::Response(_) => {}
            WsFrame::Unknown { .. } => {}
        }
    }

    /// Oracle: parse_ws_frame dispatches on @type field and produces a typed StateChange.
    /// Wire format from RFC 8620 §7.1.1 example.
    #[test]
    fn parse_state_change() {
        let json = r#"{"@type":"StateChange","changed":{"account1":{"Mail":"s2"}}}"#;
        let frame = parse_ws_frame(json).expect("must parse");
        match frame {
            WsFrame::StateChange(sc) => {
                let account = sc
                    .changed
                    .get("account1")
                    .expect("account1 must be present");
                assert_eq!(account.get("Mail").map(|s| s.as_ref()), Some("s2"));
            }
            other => panic!("expected StateChange, got {other:?}"),
        }
    }

    /// Oracle: a StateChange with missing `changed` field degrades to Unknown.
    #[test]
    fn parse_malformed_state_change_degrades_to_unknown() {
        let json = r#"{"@type":"StateChange","unexpected_field":42}"#;
        let frame = parse_ws_frame(json).expect("must not error");
        match frame {
            WsFrame::Unknown { type_name } => assert_eq!(type_name, "StateChange"),
            other => panic!("expected Unknown, got {other:?}"),
        }
    }

    /// Oracle: a Response whose body has the wrong shape degrades to Unknown
    /// with `type_name == "Response"` — the same policy as for StateChange.
    #[test]
    fn parse_malformed_response_degrades_to_unknown() {
        let json = r#"{"@type":"Response","methodResponses":42}"#;
        let frame = parse_ws_frame(json).expect("must not error");
        match frame {
            WsFrame::Unknown { type_name } => assert_eq!(type_name, "Response"),
            other => panic!("expected Unknown, got {other:?}"),
        }
    }

    /// Oracle: parse_ws_frame returns Unknown for unrecognized @type.
    /// Derived from parse_unknown_type test in source ws/mod.rs.
    #[test]
    fn parse_unknown_type() {
        let json = r#"{"@type":"FutureEvent","foo":"bar"}"#;
        let frame = parse_ws_frame(json).expect("must parse");
        match frame {
            WsFrame::Unknown { type_name } => assert_eq!(type_name, "FutureEvent"),
            other => panic!("expected Unknown, got {other:?}"),
        }
    }

    /// Oracle: parse_ws_frame returns Unknown for missing @type, carrying the
    /// `"<no @type>"` placeholder as its name.
    /// Derived from parse_missing_type_field test in source ws/mod.rs.
    #[test]
    fn parse_missing_type_field() {
        let json = r#"{"foo":"bar"}"#;
        let frame = parse_ws_frame(json).expect("must parse");
        match frame {
            WsFrame::Unknown { type_name } => assert_eq!(type_name, "<no @type>"),
            other => panic!("expected Unknown, got {other:?}"),
        }
    }

    /// Oracle: an `@type` that is present but not a JSON string is treated
    /// exactly like a missing one.
    #[test]
    fn parse_non_string_type_field() {
        let json = r#"{"@type":42}"#;
        let frame = parse_ws_frame(json).expect("must parse");
        match frame {
            WsFrame::Unknown { type_name } => assert_eq!(type_name, "<no @type>"),
            other => panic!("expected Unknown, got {other:?}"),
        }
    }

    /// Oracle: parse_ws_frame returns Err(Parse) for invalid JSON.
    /// Derived from parse_invalid_json_returns_parse_error test in source ws/mod.rs.
    #[test]
    fn parse_invalid_json_returns_parse_error() {
        let err = parse_ws_frame("not json").expect_err("must fail");
        assert!(matches!(err, crate::error::ClientError::Parse(_)));
    }

    /// Oracle: RFC 8887 §4.3.2 — every JMAP request sent over WebSocket MUST
    /// include "@type": "Request".  Tests WsRequestFrame serde directly to
    /// verify the #[serde(rename = "@type")] attribute and flatten are correct.
    #[test]
    fn send_request_includes_at_type_request() {
        let serialized = serialize_request_frame(None);
        assert!(
            serialized.contains("\"@type\":\"Request\""),
            "RFC 8887 §4.3.2 requires @type:Request in outgoing WS frames; got: {serialized}"
        );
    }

    /// Oracle: RFC 8887 §4.3.2 — optional `id` field is echoed in the response.
    /// When an id is supplied, WsRequestFrame must include it in the serialized frame.
    #[test]
    fn send_request_includes_id_when_provided() {
        let serialized = serialize_request_frame(Some("req-42"));
        assert!(
            serialized.contains("\"id\":\"req-42\""),
            "RFC 8887 §4.3.2 optional id must be present when provided; got: {serialized}"
        );
    }

    /// Oracle: RFC 8887 §4.3.2 — when id is None, no `id` field appears in the frame.
    /// WsRequestFrame uses skip_serializing_if to omit the field entirely.
    #[test]
    fn send_request_omits_id_when_none() {
        let serialized = serialize_request_frame(None);
        assert!(
            !serialized.contains("\"id\":"),
            "RFC 8887 §4.3.2: no id field must appear when id is None; got: {serialized}"
        );
    }

    /// Oracle: connect_ws must reject http:// and https:// URLs with InvalidArgument.
    ///
    /// This is the documented SSRF prevention guard: a compromised or MITM'd session
    /// could send an http:// URL; we must not follow it as a WebSocket URL.
    /// The scheme check runs before any network I/O, so the empty string and
    /// a near-miss scheme (`wsx://`) are covered here as well.
    /// Derived from connect_ws_rejects_non_ws_schemes test in source ws/mod.rs.
    #[tokio::test]
    async fn connect_ws_rejects_non_ws_schemes() {
        for bad_url in &[
            "http://host/",
            "https://host/",
            "ftp://host/",
            "wsx://host/",
            "",
        ] {
            let result = connect_ws(bad_url, None).await.map(|_| ());
            match result {
                Err(crate::error::ClientError::InvalidArgument(_)) => {}
                other => panic!("expected InvalidArgument for {bad_url:?}, got {other:?}"),
            }
        }
    }
}