jmap-base-client 0.1.0

RFC 8620 JMAP base client — auth-agnostic, session fetch, blob, SSE, WebSocket
Documentation
//! WebSocket transport for JMAP (RFC 8887).
//!
//! Provides [`connect_ws`] which establishes a WebSocket connection and
//! returns a [`WsSession`] for sending and receiving frames.
//!
//! URL source: `Session::capabilities["urn:ietf:params:jmap:websocket"].url`
//! (the session document advertises the WebSocket endpoint).

use std::str::FromStr as _;

use futures::SinkExt as _;
use futures::StreamExt as _;
use tokio_tungstenite::tungstenite::client::IntoClientRequest as _;
use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
use tokio_tungstenite::tungstenite::Message;

use crate::push::StateChange;

/// Wire frame sent from the client to the server over WebSocket (RFC 8887 §4.3.2).
///
/// Wraps a [`jmap_types::JmapRequest`] and injects the mandatory `@type: "Request"`
/// member (and the optional `id`) in a single `serde_json::to_string` pass. This
/// avoids the `to_value` + mutation + `to_string` double serialization that the
/// naive approach requires.
///
/// Only constructed by `WsSession::send_request` (and the unit tests), always with
/// `ws_type: "Request"`.
#[derive(serde::Serialize)]
struct WsRequestFrame<'a> {
    /// RFC 8887 §4.3.2: every JMAP request frame MUST carry "@type": "Request".
    /// Serialized under the JSON key `"@type"`.
    #[serde(rename = "@type")]
    ws_type: &'static str,
    /// Optional client-chosen correlation ID. It is omitted from the JSON entirely
    /// when `None`.
    /// NOTE(review): RFC 8887 has the server echo this value as `requestId` in the
    /// Response frame. Confirm that `jmap_types::JmapResponse` exposes it to callers.
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<&'a str>,
    /// The JMAP request payload. Its members are flattened into the enclosing JSON
    /// object rather than nested under an `inner` key.
    #[serde(flatten)]
    inner: &'a jmap_types::JmapRequest,
}

/// Upper bound, in bytes, on a single WebSocket message: 1 MiB (1024 × 1024).
///
/// `connect_ws` uses it as both the maximum message size and the maximum frame
/// size. A misbehaving or hostile server therefore cannot make the client buffer
/// arbitrarily large payloads on the event connection. The value matches the SSE
/// frame limit.
const MAX_WS_MESSAGE_BYTES: usize = 1024 * 1024;

/// A parsed frame received from the JMAP WebSocket.
///
/// Produced by `WsSession::next_frame` for every text message the server sends.
///
/// Marked `#[non_exhaustive]` because the spec may define additional
/// `@type` values in future revisions. Code outside this crate must therefore
/// include a wildcard arm when matching.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub enum WsFrame {
    /// RFC 8620 §7.1 StateChange. One or more object types have changed state,
    /// and the client must re-fetch the affected data types.
    StateChange(StateChange),
    /// RFC 8887 Response: the reply to a JMAP request sent on this connection
    /// via `WsSession::send_request`.
    Response(jmap_types::JmapResponse),
    /// Any frame that did not decode as a `StateChange` or `Response`.
    ///
    /// `type_name` is the frame's top-level `"@type"` string. It is the literal
    /// `"<no @type>"` when that member is missing, is not a string, or the payload
    /// is not a JSON object. Such frames are ignored rather than treated as
    /// connection errors. NOTE(review): RFC 8887 §4.3.1 lets a client either
    /// ignore unrecognized messages or close the connection; this crate chooses
    /// to ignore them.
    ///
    /// This variant is also produced when a known type (`"Response"` or
    /// `"StateChange"`) fails to deserialize. In that case `type_name` is
    /// `"Response"` or `"StateChange"`, which can signal server misbehavior or a
    /// schema version mismatch. Callers that log unknown frames should check for
    /// these names.
    ///
    /// NOTE(review): RFC 8887 request-level errors carry their own `@type`
    /// (believed to be `"RequestError"`; confirm), so they also arrive here.
    Unknown { type_name: String },
}

/// Concrete stream returned by `tokio_tungstenite::connect_async_with_config`:
/// a WebSocket over a TCP stream that may or may not be wrapped in TLS.
type Inner =
    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;

/// An established JMAP WebSocket session (RFC 8887).
///
/// Call [`next_frame`](WsSession::next_frame) in a loop to receive events.
/// Use [`send_request`](WsSession::send_request) to transmit JMAP requests.
///
/// The caller is responsible for reconnecting after the stream ends or returns
/// a transport error, using exponential backoff.
pub struct WsSession {
    /// Write half of the split connection; `send_request` sends through it.
    sink: futures::stream::SplitSink<Inner, Message>,
    /// Read half of the split connection; `next_frame` polls it.
    stream: futures::stream::SplitStream<Inner>,
}

impl WsSession {
    /// Receive the next parsed frame from the server.
    ///
    /// Returns `None` when the server has cleanly closed the connection.
    /// Returns `Some(Err(...))` on parse failure or transport error. After a
    /// transport error the connection is broken; do not call `next_frame` again.
    pub async fn next_frame(&mut self) -> Option<Result<WsFrame, crate::error::ClientError>> {
        loop {
            match self.stream.next().await? {
                Ok(Message::Text(text)) => return Some(parse_ws_frame(&text)),
                Ok(Message::Close(_)) => return None,
                Ok(_) => continue, // Ping / Pong / Binary: silently skip
                Err(e) => return Some(Err(crate::error::ClientError::WebSocket(e))),
            }
        }
    }

    /// Send a JMAP request over the WebSocket connection.
    ///
    /// Serializes `req` and injects `"@type": "Request"` into the outgoing
    /// JSON object as required by RFC 8887 §4.3.2.  The optional `id` is
    /// echoed back in the corresponding `Response` frame, enabling out-of-order
    /// correlation.
    ///
    /// # Errors
    ///
    /// Returns `ClientError::Serialize` if `req` cannot be serialized, or
    /// `ClientError::WebSocket` on a transport failure.
    pub async fn send_request(
        &mut self,
        req: &jmap_types::JmapRequest,
        id: Option<&str>,
    ) -> Result<(), crate::error::ClientError> {
        // Wrap req in WsRequestFrame to inject @type and optional id in one
        // serialization pass (no intermediate serde_json::Value allocation).
        let frame = WsRequestFrame {
            ws_type: "Request",
            id,
            inner: req,
        };
        let text = serde_json::to_string(&frame).map_err(crate::error::ClientError::Serialize)?;
        self.sink
            .send(Message::Text(text.into()))
            .await
            .map_err(crate::error::ClientError::WebSocket)
    }
}

/// Parse a raw WebSocket text frame into a `WsFrame`.
fn parse_ws_frame(text: &str) -> Result<WsFrame, crate::error::ClientError> {
    let val: serde_json::Value =
        serde_json::from_str(text).map_err(crate::error::ClientError::Parse)?;

    // Pre-extract type_name as owned String before moving val into from_value.
    // The borrow checker prevents borrowing val (for @type) and moving val
    // (into from_value) in the same expression, so ownership must be taken first.
    let type_name = val
        .get("@type")
        .and_then(|v| v.as_str())
        .unwrap_or("<no @type>")
        .to_owned();

    match type_name.as_str() {
        // A malformed StateChange is degraded to Unknown rather than a
        // transport error. A single bad server frame must not kill the entire
        // WebSocket connection; only tungstenite transport errors warrant
        // a reconnect.
        "StateChange" => match serde_json::from_value::<StateChange>(val) {
            Ok(sc) => Ok(WsFrame::StateChange(sc)),
            Err(_) => Ok(WsFrame::Unknown { type_name }),
        },
        // Same degradation policy for malformed Response frames.
        "Response" => match serde_json::from_value::<jmap_types::JmapResponse>(val) {
            Ok(r) => Ok(WsFrame::Response(r)),
            Err(_) => Ok(WsFrame::Unknown { type_name }),
        },
        _ => Ok(WsFrame::Unknown { type_name }),
    }
}

/// Open a JMAP WebSocket connection (RFC 8887).
///
/// `ws_url` should come from the session document's WebSocket capability URL.
/// The scheme must be `wss://` or `ws://`, matched case-insensitively. Plain
/// `ws://` is unencrypted, so any `auth_header` then travels in cleartext. It is
/// meant for tests and local development only, but this function does not
/// enforce that.
///
/// `auth_header` is an optional `(header-name, header-value)` pair inserted
/// into the WebSocket upgrade request, replacing any header with the same name.
/// The value is flagged as sensitive so that `http`'s `Debug` output redacts it.
/// Pass `None` when the server does not require authentication headers on the
/// WebSocket handshake.
///
/// Incoming messages and frames are capped at `MAX_WS_MESSAGE_BYTES`. The whole
/// connect (TCP, optional TLS, HTTP upgrade) must complete within 10 seconds.
///
/// The returned [`WsSession`] provides [`WsSession::next_frame`] for receiving
/// events. The caller is responsible for reconnecting after disconnect with
/// exponential backoff.
///
/// # Errors
///
/// - `ClientError::InvalidArgument` if the URL scheme is not `ws://`/`wss://`
///   (checked before any network I/O), or if the auth header name or value is
///   not a valid HTTP header.
/// - `ClientError::WebSocket` in any of these cases:
///   - the URL cannot be converted into a request;
///   - the connection or handshake fails;
///   - the connect times out (reported as an `Io` error of kind `TimedOut`).
pub async fn connect_ws(
    ws_url: &str,
    auth_header: Option<(&str, &str)>,
) -> Result<WsSession, crate::error::ClientError> {
    // Same value as the HTTP transport's connect_timeout in
    // DefaultTransport/CustomCaTransport. tungstenite exposes no connect-timeout
    // option, so the whole connect future is wrapped instead. Otherwise a stalled
    // TCP, TLS, or upgrade handshake would block indefinitely.
    const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

    // Only WebSocket schemes are allowed, so a compromised or MITM'd session
    // document cannot point this call at another protocol. This check does NOT
    // restrict the host.
    //
    // Schemes are case-insensitive (RFC 3986 §3.1). The prefix is therefore
    // compared with `eq_ignore_ascii_case` instead of lowercasing a copy of the
    // whole URL. The original URL is passed to tungstenite and kept in the error
    // message for diagnostics.
    let has_ws_scheme = ["ws://", "wss://"].iter().any(|&prefix| {
        matches!(ws_url.get(..prefix.len()), Some(head) if head.eq_ignore_ascii_case(prefix))
    });
    if !has_ws_scheme {
        return Err(crate::error::ClientError::InvalidArgument(format!(
            "WebSocket URL must start with ws:// or wss://, got: {ws_url:?}"
        )));
    }

    let mut request = ws_url
        .into_client_request()
        .map_err(crate::error::ClientError::WebSocket)?;

    if let Some((name, value)) = auth_header {
        let hdr_name = http::HeaderName::from_str(name).map_err(|e| {
            crate::error::ClientError::InvalidArgument(format!("invalid auth header name: {e}"))
        })?;
        // The value is a credential, so it is deliberately not echoed in the error.
        let mut hdr_value = http::HeaderValue::from_str(value).map_err(|_| {
            crate::error::ClientError::InvalidArgument("invalid auth header value".to_owned())
        })?;
        // Redact the credential from `Debug` output in case the upgrade request
        // is ever logged.
        hdr_value.set_sensitive(true);
        request.headers_mut().insert(hdr_name, hdr_value);
    }

    // WebSocketConfig is #[non_exhaustive] in tungstenite; use Default + field assignment.
    let mut config = WebSocketConfig::default();
    config.max_message_size = Some(MAX_WS_MESSAGE_BYTES);
    config.max_frame_size = Some(MAX_WS_MESSAGE_BYTES);

    let connect_result = tokio::time::timeout(
        CONNECT_TIMEOUT,
        // The third argument is `disable_nagle`; `false` keeps Nagle's algorithm enabled.
        tokio_tungstenite::connect_async_with_config(request, Some(config), false),
    )
    .await
    .map_err(|_elapsed| {
        crate::error::ClientError::WebSocket(tokio_tungstenite::tungstenite::Error::Io(
            std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                format!(
                    "WebSocket connect timed out after {} seconds",
                    CONNECT_TIMEOUT.as_secs()
                ),
            ),
        ))
    })?;
    let (ws_stream, _response) = connect_result.map_err(crate::error::ClientError::WebSocket)?;

    let (sink, stream) = ws_stream.split();
    Ok(WsSession { sink, stream })
}

/// Opaque `Debug` output: `WsSession { .. }`.
///
/// The split sink/stream halves hold no useful diagnostic state, so they are
/// left out.
impl std::fmt::Debug for WsSession {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("WsSession");
        builder.finish_non_exhaustive()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal core-only request shared by the `WsRequestFrame` serialization tests.
    fn core_request() -> jmap_types::JmapRequest {
        jmap_types::JmapRequest::new(
            vec!["urn:ietf:params:jmap:core".to_owned()],
            vec![],
            None,
        )
    }

    /// Serialize a `WsRequestFrame` exactly as `send_request` does, then parse the
    /// text back into JSON. Tests can then check top-level members precisely; a
    /// substring search could be fooled by a nested key.
    fn serialize_request_frame(
        req: &jmap_types::JmapRequest,
        id: Option<&str>,
    ) -> serde_json::Value {
        let frame = WsRequestFrame {
            ws_type: "Request",
            id,
            inner: req,
        };
        let serialized = serde_json::to_string(&frame).expect("WsRequestFrame must serialize");
        serde_json::from_str(&serialized).expect("serialized WsRequestFrame must be valid JSON")
    }

    /// Verify WsFrame does not contain ChatTyping or ChatPresence variants.
    /// This exhaustive match will fail to compile if either variant is reintroduced.
    #[test]
    fn ws_frame_has_no_chat_variants() {
        let frame = WsFrame::Unknown {
            type_name: "test".to_owned(),
        };
        match frame {
            WsFrame::StateChange(_) => {}
            WsFrame::Response(_) => {}
            WsFrame::Unknown { .. } => {}
        }
    }

    /// Oracle: parse_ws_frame dispatches on the @type field and produces a typed StateChange.
    /// Wire format from the RFC 8620 §7.1.1 example.
    #[test]
    fn parse_state_change() {
        let json = r#"{"@type":"StateChange","changed":{"account1":{"Mail":"s2"}}}"#;
        let frame = parse_ws_frame(json).expect("must parse");
        match frame {
            WsFrame::StateChange(sc) => {
                let account = sc
                    .changed
                    .get("account1")
                    .expect("account1 must be present");
                assert_eq!(account.get("Mail").map(|s| s.as_ref()), Some("s2"));
            }
            other => panic!("expected StateChange, got {other:?}"),
        }
    }

    /// Oracle: a StateChange with a missing `changed` field degrades to Unknown.
    #[test]
    fn parse_malformed_state_change_degrades_to_unknown() {
        let json = r#"{"@type":"StateChange","unexpected_field":42}"#;
        let frame = parse_ws_frame(json).expect("must not error");
        match frame {
            WsFrame::Unknown { type_name } => assert_eq!(type_name, "StateChange"),
            other => panic!("expected Unknown, got {other:?}"),
        }
    }

    /// Oracle: parse_ws_frame returns Unknown for an unrecognized @type.
    /// Derived from the parse_unknown_type test in source ws/mod.rs.
    #[test]
    fn parse_unknown_type() {
        let json = r#"{"@type":"FutureEvent","foo":"bar"}"#;
        let frame = parse_ws_frame(json).expect("must parse");
        match frame {
            WsFrame::Unknown { type_name } => assert_eq!(type_name, "FutureEvent"),
            other => panic!("expected Unknown, got {other:?}"),
        }
    }

    /// Oracle: parse_ws_frame returns Unknown with the documented "<no @type>"
    /// sentinel when @type is missing.
    /// Derived from the parse_missing_type_field test in source ws/mod.rs.
    #[test]
    fn parse_missing_type_field() {
        let json = r#"{"foo":"bar"}"#;
        let frame = parse_ws_frame(json).expect("must parse");
        match frame {
            WsFrame::Unknown { type_name } => assert_eq!(type_name, "<no @type>"),
            other => panic!("expected Unknown, got {other:?}"),
        }
    }

    /// Oracle: a non-string @type is treated the same as a missing one.
    #[test]
    fn parse_non_string_type_field() {
        let frame = parse_ws_frame(r#"{"@type":42}"#).expect("must parse");
        assert_eq!(
            frame,
            WsFrame::Unknown {
                type_name: "<no @type>".to_owned()
            }
        );
    }

    /// Oracle: valid JSON that is not an object has no @type, so it is
    /// reported as Unknown rather than as a parse error.
    #[test]
    fn parse_non_object_json() {
        let frame = parse_ws_frame("[1,2,3]").expect("valid JSON must not error");
        assert_eq!(
            frame,
            WsFrame::Unknown {
                type_name: "<no @type>".to_owned()
            }
        );
    }

    /// Oracle: parse_ws_frame returns Err(Parse) for invalid JSON.
    /// Derived from the parse_invalid_json_returns_parse_error test in source ws/mod.rs.
    #[test]
    fn parse_invalid_json_returns_parse_error() {
        let err = parse_ws_frame("not json").expect_err("must fail");
        assert!(matches!(err, crate::error::ClientError::Parse(_)));
    }

    /// Oracle: RFC 8887 §4.3.2 says every JMAP request sent over WebSocket MUST
    /// include "@type": "Request". Tests the WsRequestFrame serde attributes directly.
    #[test]
    fn send_request_includes_at_type_request() {
        let req = core_request();
        let value = serialize_request_frame(&req, None);
        assert_eq!(
            value.get("@type").and_then(serde_json::Value::as_str),
            Some("Request"),
            "RFC 8887 §4.3.2 requires @type:Request in outgoing WS frames; got: {value}"
        );
    }

    /// Oracle: `#[serde(flatten)]` puts the request members at the top level of
    /// the frame instead of nesting them under an `inner` key.
    #[test]
    fn request_frame_flattens_inner_request() {
        let req = core_request();
        let value = serialize_request_frame(&req, None);
        assert!(
            value.get("inner").is_none(),
            "request must not be nested under `inner`; got: {value}"
        );
        assert_eq!(
            value.get("using"),
            Some(&serde_json::json!(["urn:ietf:params:jmap:core"])),
            "RFC 8620 §3.3 `using` must appear at the top level of the frame; got: {value}"
        );
    }

    /// Oracle: RFC 8887 §4.3.2 allows an optional `id` for correlating the response.
    /// When an id is supplied, WsRequestFrame must include it as a top-level member.
    #[test]
    fn send_request_includes_id_when_provided() {
        let req = core_request();
        let value = serialize_request_frame(&req, Some("req-42"));
        assert_eq!(
            value.get("id").and_then(serde_json::Value::as_str),
            Some("req-42"),
            "RFC 8887 §4.3.2 optional id must be present when provided; got: {value}"
        );
    }

    /// Oracle: RFC 8887 §4.3.2 — when id is None, no top-level `id` member appears.
    /// WsRequestFrame uses skip_serializing_if to omit the field entirely.
    #[test]
    fn send_request_omits_id_when_none() {
        let req = core_request();
        let value = serialize_request_frame(&req, None);
        assert!(
            value.get("id").is_none(),
            "RFC 8887 §4.3.2: no id field must appear when id is None; got: {value}"
        );
    }

    /// Oracle: connect_ws must reject http://, https:// and other non-WebSocket
    /// URLs with InvalidArgument.
    ///
    /// A compromised or MITM'd session could advertise such a URL, and it must not
    /// be followed as a WebSocket endpoint. The scheme check runs before any
    /// network I/O.
    /// Derived from the connect_ws_rejects_non_ws_schemes test in source ws/mod.rs.
    #[tokio::test]
    async fn connect_ws_rejects_non_ws_schemes() {
        for bad_url in &["http://host/", "https://host/", "ftp://host/"] {
            let result = connect_ws(bad_url, None).await.map(|_| ());
            match result {
                Err(crate::error::ClientError::InvalidArgument(_)) => {}
                other => panic!("expected InvalidArgument for {bad_url:?}, got {other:?}"),
            }
        }
    }
}