jmap_base_client/ws/mod.rs
1//! WebSocket transport for JMAP (RFC 8887).
2//!
3//! Provides [`connect_ws`] which establishes a WebSocket connection and
4//! returns a [`WsSession`] for sending and receiving frames.
5//!
6//! URL source: `Session::capabilities["urn:ietf:params:jmap:websocket"].url`
7//! (the session document advertises the WebSocket endpoint).
8
9use std::str::FromStr as _;
10
11use futures::SinkExt as _;
12use futures::StreamExt as _;
13use tokio_tungstenite::tungstenite::client::IntoClientRequest as _;
14use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
15use tokio_tungstenite::tungstenite::Message;
16
17use crate::push::StateChange;
18
/// Wire frame sent from the client to the server over WebSocket (RFC 8887 §4.3.2).
///
/// Serializes as the wrapped [`jmap_types::JmapRequest`] object with extra
/// top-level members injected alongside it:
/// - the mandatory `"@type": "Request"`;
/// - the client-chosen `"id"`, only when one is supplied.
///
/// Because the request is `#[serde(flatten)]`ed into this struct, a single
/// `serde_json::to_string` pass produces the final frame. The alternative is
/// `to_value`, then mutating the resulting `serde_json::Value`, then `to_string`,
/// which serializes twice and allocates an intermediate tree.
#[derive(serde::Serialize)]
struct WsRequestFrame<'a> {
    /// Frame discriminator, serialized under the key `@type`. Callers always
    /// set it to `"Request"`, which RFC 8887 §4.3.2 requires on every request frame.
    #[serde(rename = "@type")]
    ws_type: &'static str,
    /// Optional client-chosen correlation ID, omitted entirely when `None`.
    /// The server echoes it in the matching Response frame.
    /// NOTE(review): RFC 8887 names the echoed member `requestId`; confirm
    /// `jmap_types::JmapResponse` exposes it, otherwise correlation is lost.
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<&'a str>,
    /// The JMAP request payload; its members are flattened into the same
    /// JSON object as `@type` and `id`.
    #[serde(flatten)]
    inner: &'a jmap_types::JmapRequest,
}
37
/// Largest incoming WebSocket message the client will accept: 1 MiB.
///
/// `connect_ws` applies the same cap to both tungstenite's `max_message_size`
/// and `max_frame_size`. It bounds how much a misbehaving or hostile server
/// can force the client to buffer. The value is intended to match the SSE
/// transport's frame limit.
///
/// NOTE(review): this connection carries JMAP `Response` frames as well as
/// push notifications. An oversized response is rejected by tungstenite and
/// surfaces from `next_frame` as a transport error. Confirm that 1 MiB covers
/// the largest responses expected on this transport.
const MAX_WS_MESSAGE_BYTES: usize = 1024 * 1024;
42
/// A parsed frame received from the JMAP WebSocket, as returned by
/// [`WsSession::next_frame`].
///
/// Marked `#[non_exhaustive]` so that variants for further `@type` values
/// can be added later without breaking downstream `match` expressions.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub enum WsFrame {
    /// RFC 8620 §7.1 StateChange (`"@type": "StateChange"`). One or more
    /// object types have changed state, so the client should re-fetch the
    /// affected data types.
    StateChange(StateChange),
    /// RFC 8887 Response (`"@type": "Response"`). This is the reply to a JMAP
    /// request sent on this connection.
    Response(jmap_types::JmapResponse),
    /// Any frame not decoded into one of the variants above. It is handed to
    /// the caller to log or ignore, for forward compatibility (RFC 8887
    /// §4.3.1: clients SHOULD ignore unknown message types).
    ///
    /// `type_name` holds one of:
    /// - the frame's `@type` string;
    /// - `"<no @type>"`, when that member is missing or not a string;
    /// - `"Response"` or `"StateChange"`, when a frame of a known type failed
    ///   to deserialize.
    ///
    /// The last case can signal server misbehavior or a schema version
    /// mismatch, so callers that log unknown frames should check for these
    /// two names.
    ///
    /// NOTE(review): RFC 8887 request-level error frames
    /// (`"@type": "RequestError"`) have no dedicated variant and currently
    /// arrive here as well.
    Unknown { type_name: String },
}
64
/// Stream type returned by `tokio_tungstenite::connect_async_with_config`:
/// a WebSocket running over a TCP connection that is optionally wrapped in TLS.
type Inner =
    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;

/// An established JMAP WebSocket session (RFC 8887), created by [`connect_ws`].
///
/// Call [`next_frame`](WsSession::next_frame) in a loop to receive frames from
/// the server. Use [`send_request`](WsSession::send_request) to transmit JMAP
/// requests on the same connection.
///
/// Per RFC 8887, servers only push `StateChange` notifications after the
/// client has enabled push on the connection with a `WebSocketPushEnable`
/// message.
///
/// The caller is responsible for reconnecting after the stream ends or returns
/// a transport error. Use exponential backoff.
pub struct WsSession {
    /// Write half of the split connection; outgoing frames are sent here.
    sink: futures::stream::SplitSink<Inner, Message>,
    /// Read half of the split connection; incoming frames are read from here.
    stream: futures::stream::SplitStream<Inner>,
}
79
80impl WsSession {
81 /// Receive the next parsed frame from the server.
82 ///
83 /// Returns `None` when the server has cleanly closed the connection.
84 /// Returns `Some(Err(...))` on parse failure or transport error. After a
85 /// transport error the connection is broken; do not call `next_frame` again.
86 pub async fn next_frame(&mut self) -> Option<Result<WsFrame, crate::error::ClientError>> {
87 loop {
88 match self.stream.next().await? {
89 Ok(Message::Text(text)) => return Some(parse_ws_frame(&text)),
90 Ok(Message::Close(_)) => return None,
91 Ok(_) => continue, // Ping / Pong / Binary: silently skip
92 Err(e) => return Some(Err(crate::error::ClientError::WebSocket(e))),
93 }
94 }
95 }
96
97 /// Send a JMAP request over the WebSocket connection.
98 ///
99 /// Serializes `req` and injects `"@type": "Request"` into the outgoing
100 /// JSON object as required by RFC 8887 §4.3.2. The optional `id` is
101 /// echoed back in the corresponding `Response` frame, enabling out-of-order
102 /// correlation.
103 ///
104 /// # Errors
105 ///
106 /// Returns `ClientError::Serialize` if `req` cannot be serialized, or
107 /// `ClientError::WebSocket` on a transport failure.
108 pub async fn send_request(
109 &mut self,
110 req: &jmap_types::JmapRequest,
111 id: Option<&str>,
112 ) -> Result<(), crate::error::ClientError> {
113 // Wrap req in WsRequestFrame to inject @type and optional id in one
114 // serialization pass (no intermediate serde_json::Value allocation).
115 let frame = WsRequestFrame {
116 ws_type: "Request",
117 id,
118 inner: req,
119 };
120 let text = serde_json::to_string(&frame).map_err(crate::error::ClientError::Serialize)?;
121 self.sink
122 .send(Message::Text(text.into()))
123 .await
124 .map_err(crate::error::ClientError::WebSocket)
125 }
126}
127
128/// Parse a raw WebSocket text frame into a `WsFrame`.
129fn parse_ws_frame(text: &str) -> Result<WsFrame, crate::error::ClientError> {
130 let val: serde_json::Value =
131 serde_json::from_str(text).map_err(crate::error::ClientError::Parse)?;
132
133 // Pre-extract type_name as owned String before moving val into from_value.
134 // The borrow checker prevents borrowing val (for @type) and moving val
135 // (into from_value) in the same expression, so ownership must be taken first.
136 let type_name = val
137 .get("@type")
138 .and_then(|v| v.as_str())
139 .unwrap_or("<no @type>")
140 .to_owned();
141
142 match type_name.as_str() {
143 // A malformed StateChange is degraded to Unknown rather than a
144 // transport error. A single bad server frame must not kill the entire
145 // WebSocket connection; only tungstenite transport errors warrant
146 // a reconnect.
147 "StateChange" => match serde_json::from_value::<StateChange>(val) {
148 Ok(sc) => Ok(WsFrame::StateChange(sc)),
149 Err(_) => Ok(WsFrame::Unknown { type_name }),
150 },
151 // Same degradation policy for malformed Response frames.
152 "Response" => match serde_json::from_value::<jmap_types::JmapResponse>(val) {
153 Ok(r) => Ok(WsFrame::Response(r)),
154 Err(_) => Ok(WsFrame::Unknown { type_name }),
155 },
156 _ => Ok(WsFrame::Unknown { type_name }),
157 }
158}
159
160/// Open a JMAP WebSocket connection (RFC 8887).
161///
162/// `ws_url` must come from the session document's WebSocket capability URL
163/// (a `wss://` endpoint in production; `ws://` is accepted in tests).
164///
165/// `auth_header` is an optional `(header-name, header-value)` pair injected
166/// into the WebSocket upgrade request. Pass `None` when the server does not
167/// require authentication headers on the WebSocket handshake.
168///
169/// Returns `ClientError::InvalidArgument` if the URL scheme is not
170/// `ws://` or `wss://`, preventing accidental use with untrusted URLs.
171///
172/// The returned [`WsSession`] provides [`WsSession::next_frame`] for receiving
173/// events. The caller is responsible for reconnecting after disconnect with
174/// exponential backoff.
175pub async fn connect_ws(
176 ws_url: &str,
177 auth_header: Option<(&str, &str)>,
178) -> Result<WsSession, crate::error::ClientError> {
179 // Validate scheme to prevent SSRF via a compromised or MITM'd session.
180 // Case-insensitive check per RFC 3986 §3.1: lowercase the URL before
181 // comparing so that `WS://` and `wss://` are both accepted. The
182 // original (unmodified) URL is passed to tungstenite and kept in error
183 // messages for diagnostics.
184 let ws_url_lc = ws_url.to_ascii_lowercase();
185 if !ws_url_lc.starts_with("ws://") && !ws_url_lc.starts_with("wss://") {
186 return Err(crate::error::ClientError::InvalidArgument(format!(
187 "WebSocket URL must start with ws:// or wss://, got: {ws_url:?}"
188 )));
189 }
190
191 let mut request = ws_url
192 .into_client_request()
193 .map_err(crate::error::ClientError::WebSocket)?;
194
195 if let Some((name, value)) = auth_header {
196 let hdr_name = http::HeaderName::from_str(name).map_err(|e| {
197 crate::error::ClientError::InvalidArgument(format!("invalid auth header name: {e}"))
198 })?;
199 let hdr_value = http::HeaderValue::from_str(value).map_err(|_| {
200 crate::error::ClientError::InvalidArgument("invalid auth header value".to_owned())
201 })?;
202 request.headers_mut().insert(hdr_name, hdr_value);
203 }
204
205 // WebSocketConfig is #[non_exhaustive] in tungstenite; use Default + field assignment.
206 let mut config = WebSocketConfig::default();
207 config.max_message_size = Some(MAX_WS_MESSAGE_BYTES);
208 config.max_frame_size = Some(MAX_WS_MESSAGE_BYTES);
209
210 // Apply a 10-second connect timeout, consistent with the HTTP transport's
211 // connect_timeout in DefaultTransport/CustomCaTransport. tungstenite does
212 // not expose a connect timeout parameter, so we wrap at the Future level.
213 // A stalled TCP or TLS handshake would otherwise block indefinitely.
214 let connect_result = tokio::time::timeout(
215 std::time::Duration::from_secs(10),
216 tokio_tungstenite::connect_async_with_config(request, Some(config), false),
217 )
218 .await
219 .map_err(|_elapsed| {
220 crate::error::ClientError::WebSocket(tokio_tungstenite::tungstenite::Error::Io(
221 std::io::Error::new(
222 std::io::ErrorKind::TimedOut,
223 "WebSocket connect timed out after 10 seconds",
224 ),
225 ))
226 })?;
227 let (ws_stream, _response) = connect_result.map_err(crate::error::ClientError::WebSocket)?;
228
229 let (sink, stream) = ws_stream.split();
230 Ok(WsSession { sink, stream })
231}
232
233impl std::fmt::Debug for WsSession {
234 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235 f.debug_struct("WsSession").finish_non_exhaustive()
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal request using only the core capability, shared by the
    /// `WsRequestFrame` serialization tests.
    fn core_request() -> jmap_types::JmapRequest {
        jmap_types::JmapRequest::new(
            vec!["urn:ietf:params:jmap:core".to_owned()],
            vec![],
            None,
        )
    }

    /// Serializes `req` wrapped in a `WsRequestFrame`, the same way
    /// `send_request` builds its outgoing frame.
    fn serialize_request_frame(req: &jmap_types::JmapRequest, id: Option<&str>) -> String {
        let frame = WsRequestFrame {
            ws_type: "Request",
            id,
            inner: req,
        };
        serde_json::to_string(&frame).expect("WsRequestFrame must serialize")
    }

    /// Unwraps a `WsFrame::Unknown` and returns its `type_name`, panicking
    /// on any other variant.
    fn unknown_type_name(frame: WsFrame) -> String {
        match frame {
            WsFrame::Unknown { type_name } => type_name,
            other => panic!("expected Unknown, got {other:?}"),
        }
    }

    /// Compile-time guard: this match names every `WsFrame` variant, so it
    /// stops compiling if ChatTyping / ChatPresence (or any other variant)
    /// is ever added back.
    #[test]
    fn ws_frame_has_no_chat_variants() {
        let frame = WsFrame::Unknown {
            type_name: "test".to_owned(),
        };
        match frame {
            WsFrame::StateChange(_) | WsFrame::Response(_) | WsFrame::Unknown { .. } => {}
        }
    }

    /// `"@type": "StateChange"` is decoded into a typed StateChange
    /// (wire format from the RFC 8620 §7.1.1 example).
    #[test]
    fn parse_state_change() {
        let json = r#"{"@type":"StateChange","changed":{"account1":{"Mail":"s2"}}}"#;
        let sc = match parse_ws_frame(json).expect("must parse") {
            WsFrame::StateChange(sc) => sc,
            other => panic!("expected StateChange, got {other:?}"),
        };
        let account = sc
            .changed
            .get("account1")
            .expect("account1 must be present");
        assert_eq!(account.get("Mail").map(|s| s.as_ref()), Some("s2"));
    }

    /// A StateChange lacking its `changed` map degrades to Unknown
    /// instead of surfacing an error.
    #[test]
    fn parse_malformed_state_change_degrades_to_unknown() {
        let frame = parse_ws_frame(r#"{"@type":"StateChange","unexpected_field":42}"#)
            .expect("must not error");
        assert_eq!(unknown_type_name(frame), "StateChange");
    }

    /// An unrecognized `@type` is reported verbatim through Unknown.
    #[test]
    fn parse_unknown_type() {
        let frame = parse_ws_frame(r#"{"@type":"FutureEvent","foo":"bar"}"#).expect("must parse");
        assert_eq!(unknown_type_name(frame), "FutureEvent");
    }

    /// A frame with no `@type` member still parses, as Unknown.
    #[test]
    fn parse_missing_type_field() {
        let frame = parse_ws_frame(r#"{"foo":"bar"}"#).expect("must parse");
        assert!(matches!(frame, WsFrame::Unknown { .. }));
    }

    /// Text that is not JSON at all is the one case reported as an error.
    #[test]
    fn parse_invalid_json_returns_parse_error() {
        let err = parse_ws_frame("not json").expect_err("must fail");
        assert!(matches!(err, crate::error::ClientError::Parse(_)));
    }

    /// RFC 8887 §4.3.2: every request frame carries `"@type": "Request"`.
    /// Exercises the `rename` + `flatten` serde attributes directly.
    #[test]
    fn send_request_includes_at_type_request() {
        let serialized = serialize_request_frame(&core_request(), None);
        assert!(
            serialized.contains("\"@type\":\"Request\""),
            "RFC 8887 §4.3.2 requires @type:Request in outgoing WS frames; got: {serialized}"
        );
    }

    /// A supplied correlation id is serialized into the frame.
    #[test]
    fn send_request_includes_id_when_provided() {
        let serialized = serialize_request_frame(&core_request(), Some("req-42"));
        assert!(
            serialized.contains("\"id\":\"req-42\""),
            "RFC 8887 §4.3.2 optional id must be present when provided; got: {serialized}"
        );
    }

    /// With no id, `skip_serializing_if` omits the member entirely.
    #[test]
    fn send_request_omits_id_when_none() {
        let serialized = serialize_request_frame(&core_request(), None);
        assert!(
            !serialized.contains("\"id\":"),
            "RFC 8887 §4.3.2: no id field must appear when id is None; got: {serialized}"
        );
    }

    /// SSRF guard: non-WebSocket schemes are rejected with InvalidArgument
    /// before any network I/O happens. A compromised or MITM'd session could
    /// otherwise redirect the client to an arbitrary http:// endpoint.
    #[tokio::test]
    async fn connect_ws_rejects_non_ws_schemes() {
        const BAD_URLS: [&str; 3] = ["http://host/", "https://host/", "ftp://host/"];
        for bad_url in BAD_URLS {
            match connect_ws(bad_url, None).await.map(|_| ()) {
                Err(crate::error::ClientError::InvalidArgument(_)) => {}
                other => panic!("expected InvalidArgument for {bad_url:?}, got {other:?}"),
            }
        }
    }
}