use std::str::FromStr as _;
use futures::SinkExt as _;
use futures::StreamExt as _;
use tokio_tungstenite::tungstenite::client::IntoClientRequest as _;
use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
use tokio_tungstenite::tungstenite::Message;
use crate::push::StateChange;
/// Outgoing RFC 8887 request envelope.
///
/// Serializes as the wrapped `JmapRequest`'s own members (flattened into the
/// same JSON object) plus an `"@type"` discriminator and an optional `"id"`.
#[derive(serde::Serialize)]
struct WsRequestFrame<'a> {
    /// Value of the `"@type"` member; the senders in this file use `"Request"`.
    #[serde(rename = "@type")]
    ws_type: &'static str,
    /// Optional client-chosen request id; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<&'a str>,
    /// The JMAP request whose members are merged into the envelope object.
    #[serde(flatten)]
    inner: &'a jmap_types::JmapRequest,
}
/// Default inbound size cap (1 MiB) used by `connect_ws` for both a single
/// WebSocket frame and a reassembled message.
pub const DEFAULT_WS_MAX_MESSAGE_BYTES: usize = 1024 * 1024;
/// A decoded inbound message from a JMAP WebSocket session.
///
/// Built from a Text message by `parse_ws_frame`. `Debug` is implemented by
/// hand rather than derived so the `raw` payload of `Unknown` is not printed.
#[non_exhaustive]
#[derive(Clone, PartialEq)]
pub enum WsFrame {
    /// A push `StateChange` object (`"@type": "StateChange"`).
    StateChange(StateChange),
    /// A JMAP response object (`"@type": "Response"`).
    Response(jmap_types::JmapResponse),
    /// Any other frame type, or a `StateChange`/`Response` frame that failed
    /// typed decoding.
    Unknown {
        /// The frame's `"@type"` string, or `"<no @type>"` when it is absent
        /// or not a string.
        type_name: String,
        /// The complete JSON value as received. It may carry sensitive data,
        /// so it is excluded from `Debug` output.
        raw: serde_json::Value,
    },
}
impl std::fmt::Debug for WsFrame {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
WsFrame::StateChange(sc) => f.debug_tuple("StateChange").field(sc).finish(),
WsFrame::Response(r) => f.debug_tuple("Response").field(r).finish(),
WsFrame::Unknown { type_name, raw: _ } => f
.debug_struct("Unknown")
.field("type_name", type_name)
.field("raw", &"[REDACTED]")
.finish(),
}
}
}
/// The concrete stream returned by `connect_async_with_config`: a WebSocket
/// over a TCP connection that may or may not be TLS-wrapped.
type Inner =
    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
/// A connected JMAP WebSocket session, holding the sending and receiving
/// halves of one connection. Call [`WsSession::split`] to use the halves
/// independently.
pub struct WsSession {
    /// Outgoing half of the connection.
    sender: WsSender,
    /// Incoming half of the connection.
    receiver: WsReceiver,
}
/// Sending half of a JMAP WebSocket connection.
pub struct WsSender {
    /// Write side of the split WebSocket stream.
    sink: futures::stream::SplitSink<Inner, Message>,
}
/// Receiving half of a JMAP WebSocket connection.
pub struct WsReceiver {
    /// Read side of the split WebSocket stream.
    stream: futures::stream::SplitStream<Inner>,
}
/// Maximum number of skipped (non-Text, non-Binary, non-Close) messages,
/// such as Ping/Pong, that a single `WsReceiver::next_frame` call tolerates
/// before it returns an error.
const MAX_CONSECUTIVE_NON_TEXT_FRAMES: usize = 64;
/// Decides how `WsReceiver::next_frame` should treat a received message.
///
/// Text, Close and Binary messages each get their own disposition; every
/// other kind (e.g. Ping, Pong) is `Skip`.
fn classify_message(msg: &Message) -> MessageDisposition {
    if msg.is_text() {
        MessageDisposition::Text
    } else if msg.is_close() {
        MessageDisposition::Close
    } else if msg.is_binary() {
        MessageDisposition::Binary
    } else {
        MessageDisposition::Skip
    }
}
/// Action to take for one received WebSocket message; see `classify_message`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MessageDisposition {
    /// Decode the payload as a JMAP frame.
    Text,
    /// The peer is closing; end the frame stream.
    Close,
    /// Reject: RFC 8887 §4.1 requires text frames.
    Binary,
    /// Ignore and read the next message (counted against the skip cap).
    Skip,
}
impl WsReceiver {
pub async fn next_frame(&mut self) -> Option<Result<WsFrame, crate::error::ClientError>> {
let mut consecutive_skips = 0usize;
loop {
let msg = match self.stream.next().await? {
Ok(m) => m,
Err(e) => return Some(Err(crate::error::ClientError::from_ws(e))),
};
match classify_message(&msg) {
MessageDisposition::Text => {
let Message::Text(text) = msg else {
return Some(Err(crate::error::ClientError::UnexpectedResponse(
"WebSocket: classify_message returned Text for non-Text variant".into(),
)));
};
return Some(parse_ws_frame(&text));
}
MessageDisposition::Close => return None,
MessageDisposition::Binary => {
return Some(Err(crate::error::ClientError::UnexpectedResponse(
"WebSocket: server sent Binary frame; RFC 8887 §4.1 mandates text frames"
.into(),
)));
}
MessageDisposition::Skip => {
consecutive_skips = consecutive_skips.saturating_add(1);
if consecutive_skips > MAX_CONSECUTIVE_NON_TEXT_FRAMES {
return Some(Err(crate::error::ClientError::UnexpectedResponse(
format!(
"WebSocket: exceeded {MAX_CONSECUTIVE_NON_TEXT_FRAMES} consecutive non-text frames; possible server misbehaviour"
),
)));
}
}
}
}
}
}
impl WsSender {
    /// Sends `text` to the server as a single WebSocket Text message.
    ///
    /// The payload is sent unmodified and without validation; callers are
    /// responsible for it being a well-formed RFC 8887 JSON object.
    ///
    /// # Errors
    ///
    /// Transport failures are mapped through `ClientError::from_ws`.
    pub async fn send_text(&mut self, text: String) -> Result<(), crate::error::ClientError> {
        self.sink
            .send(Message::Text(text.into()))
            .await
            .map_err(crate::error::ClientError::from_ws)
    }

    /// Sends `req` wrapped in an RFC 8887 `"@type": "Request"` envelope.
    ///
    /// When `id` is `Some`, it is serialized as the envelope's `"id"` member;
    /// when `None`, the member is omitted.
    ///
    /// # Errors
    ///
    /// Serialization failures are mapped through
    /// `ClientError::from_serialize`; transport failures through
    /// `ClientError::from_ws` (via [`Self::send_text`]).
    pub async fn send_request(
        &mut self,
        req: &jmap_types::JmapRequest,
        id: Option<&str>,
    ) -> Result<(), crate::error::ClientError> {
        let frame = WsRequestFrame {
            ws_type: "Request",
            id,
            inner: req,
        };
        let text =
            serde_json::to_string(&frame).map_err(crate::error::ClientError::from_serialize)?;
        // Single send path: both public methods frame and map errors identically.
        self.send_text(text).await
    }
}
impl WsSession {
    /// Receives the next frame; forwards to [`WsReceiver::next_frame`].
    pub async fn next_frame(&mut self) -> Option<Result<WsFrame, crate::error::ClientError>> {
        let receiver = &mut self.receiver;
        receiver.next_frame().await
    }

    /// Sends raw text; forwards to [`WsSender::send_text`].
    pub async fn send_text(&mut self, text: String) -> Result<(), crate::error::ClientError> {
        let sender = &mut self.sender;
        sender.send_text(text).await
    }

    /// Sends a JMAP request envelope; forwards to [`WsSender::send_request`].
    pub async fn send_request(
        &mut self,
        req: &jmap_types::JmapRequest,
        id: Option<&str>,
    ) -> Result<(), crate::error::ClientError> {
        let sender = &mut self.sender;
        sender.send_request(req, id).await
    }

    /// Consumes the session and returns its sending and receiving halves.
    pub fn split(self) -> (WsSender, WsReceiver) {
        (self.sender, self.receiver)
    }
}
/// Decodes one inbound text message into a [`WsFrame`].
///
/// The text must be valid JSON; otherwise the error from
/// `ClientError::from_parse` is returned. The `"@type"` member selects the
/// target: `"StateChange"` and `"Response"` are decoded into their typed
/// variants. Everything else degrades to `WsFrame::Unknown` carrying the parsed
/// JSON rather than failing. This covers other types, a typed decode failure,
/// and a missing or non-string `"@type"` (reported as `"<no @type>"`).
fn parse_ws_frame(text: &str) -> Result<WsFrame, crate::error::ClientError> {
    let raw: serde_json::Value =
        serde_json::from_str(text).map_err(crate::error::ClientError::from_parse)?;
    let type_name: String = match raw.get("@type") {
        Some(serde_json::Value::String(name)) => name.clone(),
        _ => String::from("<no @type>"),
    };
    let typed: Option<WsFrame> = match type_name.as_str() {
        "StateChange" => serde_json::from_str::<StateChange>(text)
            .ok()
            .map(WsFrame::StateChange),
        "Response" => serde_json::from_str::<jmap_types::JmapResponse>(text)
            .ok()
            .map(WsFrame::Response),
        _ => None,
    };
    Ok(typed.unwrap_or_else(|| WsFrame::Unknown { type_name, raw }))
}
/// Opens a JMAP WebSocket session using [`DEFAULT_WS_MAX_MESSAGE_BYTES`] as
/// the inbound frame/message size limit.
///
/// See [`connect_ws_with_limit`] for validation rules and error cases.
pub async fn connect_ws(
    ws_url: &str,
    auth_header: Option<crate::auth::AuthHeader<'_>>,
) -> Result<WsSession, crate::error::ClientError> {
    let limit = DEFAULT_WS_MAX_MESSAGE_BYTES;
    connect_ws_with_limit(ws_url, auth_header, limit).await
}
/// Opens a JMAP WebSocket session (RFC 8887) with an explicit inbound size
/// limit.
///
/// `max_message_bytes` is applied to both `max_message_size` and
/// `max_frame_size` of the tungstenite [`WebSocketConfig`]. If `auth_header`
/// is given, it is added to the HTTP upgrade request. The header value is
/// flagged as sensitive so it is not shown when the request is
/// `Debug`-formatted.
///
/// # Errors
///
/// - `ClientError::InvalidArgument` in any of these cases:
///   - `max_message_bytes` is `0`;
///   - the URL scheme is not `ws` or `wss` (compared case-insensitively);
///   - the auth header name or value is not a valid HTTP header.
/// - An error built with `ClientError::from_ws` in any of these cases:
///   - the URL cannot be turned into a client request;
///   - the handshake fails;
///   - connecting takes longer than 10 seconds (reported as an
///     `io::ErrorKind::TimedOut` I/O error).
///
/// NOTE(review): the auth header is also attached for plain `ws://` URLs;
/// confirm callers only pass credentials together with `wss://`.
pub async fn connect_ws_with_limit(
    ws_url: &str,
    auth_header: Option<crate::auth::AuthHeader<'_>>,
    max_message_bytes: usize,
) -> Result<WsSession, crate::error::ClientError> {
    // Bound on the whole connect (TCP, optional TLS, upgrade handshake).
    // Keep in sync with the timeout error message below.
    const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

    if max_message_bytes == 0 {
        return Err(crate::error::ClientError::InvalidArgument(
            "connect_ws_with_limit: max_message_bytes must be > 0".to_owned(),
        ));
    }
    let scheme_ok = ws_url
        .split_once("://")
        .is_some_and(|(s, _)| s.eq_ignore_ascii_case("ws") || s.eq_ignore_ascii_case("wss"));
    if !scheme_ok {
        return Err(crate::error::ClientError::InvalidArgument(format!(
            "WebSocket URL must start with ws:// or wss://, got: {ws_url:?}"
        )));
    }
    let mut request = ws_url
        .into_client_request()
        .map_err(crate::error::ClientError::from_ws)?;
    if let Some(header) = auth_header {
        let hdr_name = http::HeaderName::from_str(header.name()).map_err(|_| {
            crate::error::ClientError::InvalidArgument("invalid auth header name".to_owned())
        })?;
        let mut hdr_value = http::HeaderValue::from_str(header.expose_value()).map_err(|_| {
            crate::error::ClientError::InvalidArgument("invalid auth header value".to_owned())
        })?;
        // The value is a credential: a sensitive HeaderValue renders as
        // `Sensitive` in Debug output instead of exposing the secret.
        hdr_value.set_sensitive(true);
        request.headers_mut().insert(hdr_name, hdr_value);
    }
    let mut config = WebSocketConfig::default();
    config.max_message_size = Some(max_message_bytes);
    config.max_frame_size = Some(max_message_bytes);
    let connect_result = tokio::time::timeout(
        CONNECT_TIMEOUT,
        tokio_tungstenite::connect_async_with_config(request, Some(config), false),
    )
    .await
    .map_err(|_elapsed| {
        crate::error::ClientError::from_ws(tokio_tungstenite::tungstenite::Error::Io(
            std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                "WebSocket connect timed out after 10 seconds",
            ),
        ))
    })?;
    let (ws_stream, _response) = connect_result.map_err(crate::error::ClientError::from_ws)?;
    let (sink, stream) = ws_stream.split();
    Ok(WsSession {
        sender: WsSender { sink },
        receiver: WsReceiver { stream },
    })
}
/// Opaque `Debug`: prints only the type name, never the connection internals.
impl std::fmt::Debug for WsSession {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("WsSession");
        builder.finish_non_exhaustive()
    }
}
/// Opaque `Debug`: prints only the type name, never the sink internals.
impl std::fmt::Debug for WsSender {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("WsSender");
        builder.finish_non_exhaustive()
    }
}
/// Opaque `Debug`: prints only the type name, never the stream internals.
impl std::fmt::Debug for WsReceiver {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("WsReceiver");
        builder.finish_non_exhaustive()
    }
}
// Unit tests for this module. They cover:
// - frame parsing and degradation to `Unknown`;
// - request-envelope serialization;
// - URL scheme validation;
// - message classification;
// - `Debug` redaction;
// - the Send/Sync bounds of the session handles.
// The scheme-validation test returns before any network I/O is attempted.
#[cfg(test)]
mod tests {
use super::*;
// Exhaustive match: stops compiling if a WsFrame variant is added or removed.
#[test]
fn ws_frame_has_no_chat_variants() {
let frame = WsFrame::Unknown {
type_name: "test".to_owned(),
raw: serde_json::Value::Null,
};
match frame {
WsFrame::StateChange(_) => {}
WsFrame::Response(_) => {}
WsFrame::Unknown { .. } => {}
}
}
#[test]
fn parse_state_change() {
let json = r#"{"@type":"StateChange","changed":{"account1":{"Mail":"s2"}}}"#;
let frame = parse_ws_frame(json).expect("must parse");
match frame {
WsFrame::StateChange(sc) => {
let account = sc
.changed
.get("account1")
.expect("account1 must be present");
assert_eq!(account.get("Mail").map(|s| s.as_ref()), Some("s2"));
}
other => panic!("expected StateChange, got {other:?}"),
}
}
// A StateChange that fails typed decoding must surface as Unknown, not Err.
#[test]
fn parse_malformed_state_change_degrades_to_unknown() {
let json = r#"{"@type":"StateChange","unexpected_field":42}"#;
let frame = parse_ws_frame(json).expect("must not error");
match frame {
WsFrame::Unknown { type_name, .. } => assert_eq!(type_name, "StateChange"),
other => panic!("expected Unknown, got {other:?}"),
}
}
#[test]
fn parse_unknown_type() {
let json = r#"{"@type":"FutureEvent","foo":"bar"}"#;
let frame = parse_ws_frame(json).expect("must parse");
match frame {
WsFrame::Unknown { type_name, .. } => assert_eq!(type_name, "FutureEvent"),
other => panic!("expected Unknown, got {other:?}"),
}
}
#[test]
fn parse_missing_type_field() {
let json = r#"{"foo":"bar"}"#;
let frame = parse_ws_frame(json).expect("must parse");
assert!(matches!(frame, WsFrame::Unknown { .. }));
}
// Only non-JSON input is a hard error; everything else degrades to Unknown.
#[test]
fn parse_invalid_json_returns_parse_error() {
let err = parse_ws_frame("not json").expect_err("must fail");
assert!(matches!(err, crate::error::ClientError::Parse(_)));
}
// The next three tests serialize WsRequestFrame directly, the same way
// WsSender::send_request does, without needing a live connection.
#[test]
fn send_request_includes_at_type_request() {
let req = jmap_types::JmapRequest::new(
vec!["urn:ietf:params:jmap:core".to_owned()],
vec![],
None,
);
let frame = WsRequestFrame {
ws_type: "Request",
id: None,
inner: &req,
};
let serialized = serde_json::to_string(&frame).expect("WsRequestFrame must serialize");
assert!(
serialized.contains("\"@type\":\"Request\""),
"RFC 8887 §4.3.2 requires @type:Request in outgoing WS frames; got: {serialized}"
);
}
#[test]
fn send_request_includes_id_when_provided() {
let req = jmap_types::JmapRequest::new(
vec!["urn:ietf:params:jmap:core".to_owned()],
vec![],
None,
);
let frame = WsRequestFrame {
ws_type: "Request",
id: Some("req-42"),
inner: &req,
};
let serialized = serde_json::to_string(&frame).expect("WsRequestFrame must serialize");
assert!(
serialized.contains("\"id\":\"req-42\""),
"RFC 8887 §4.3.2 optional id must be present when provided; got: {serialized}"
);
}
#[test]
fn send_request_omits_id_when_none() {
let req = jmap_types::JmapRequest::new(
vec!["urn:ietf:params:jmap:core".to_owned()],
vec![],
None,
);
let frame = WsRequestFrame {
ws_type: "Request",
id: None,
inner: &req,
};
let serialized = serde_json::to_string(&frame).expect("WsRequestFrame must serialize");
assert!(
!serialized.contains("\"id\":"),
"RFC 8887 §4.3.2: no id field must appear when id is None; got: {serialized}"
);
}
// Rejected by the scheme check before any connection attempt is made.
#[tokio::test]
async fn connect_ws_rejects_non_ws_schemes() {
for bad_url in &["http://host/", "https://host/", "ftp://host/"] {
let result = connect_ws(bad_url, None).await.map(|_| ());
match result {
Err(crate::error::ClientError::InvalidArgument(_)) => {}
other => panic!("expected InvalidArgument for {bad_url:?}, got {other:?}"),
}
}
}
#[test]
fn classify_text_message() {
let m = Message::Text("hi".into());
assert_eq!(classify_message(&m), MessageDisposition::Text);
}
#[test]
fn classify_close_message() {
let m = Message::Close(None);
assert_eq!(classify_message(&m), MessageDisposition::Close);
}
#[test]
fn classify_binary_message_is_not_skipped() {
let m = Message::Binary(vec![1, 2, 3].into());
assert_eq!(classify_message(&m), MessageDisposition::Binary);
assert_ne!(
classify_message(&m),
MessageDisposition::Skip,
"Binary must NOT be silently skipped (RFC 8887 §4.1)"
);
}
#[test]
fn classify_ping_pong_messages_are_skipped() {
let ping = Message::Ping(vec![].into());
let pong = Message::Pong(vec![].into());
assert_eq!(classify_message(&ping), MessageDisposition::Skip);
assert_eq!(classify_message(&pong), MessageDisposition::Skip);
}
// Pins the skip cap so a change to it is a deliberate, reviewed decision.
#[test]
fn consecutive_skip_cap_matches_documented_value() {
assert_eq!(MAX_CONSECUTIVE_NON_TEXT_FRAMES, 64);
}
// The canary string must never appear in Debug output of an Unknown frame.
#[test]
fn ws_frame_unknown_raw_is_redacted_in_debug_output() {
let canary = "redaction-canary-cred-WFTMr8FoYpfP-do-not-leak";
let frame = WsFrame::Unknown {
type_name: "PushVerification".to_owned(),
raw: serde_json::json!({
"verificationCode": canary,
}),
};
let rendered = format!("{frame:?}");
assert!(
!rendered.contains(canary),
"WsFrame::Unknown Debug must redact `raw`; the canary literal \
appeared in the rendered output, indicating either \
#[derive(Debug)] was restored or the manual Debug impl \
forgot to redact the raw field. Rendered output: {rendered}"
);
assert!(
rendered.contains("[REDACTED]"),
"WsFrame::Unknown Debug must render the redaction placeholder; \
rendered output: {rendered}"
);
assert!(
rendered.contains("PushVerification"),
"WsFrame::Unknown Debug must still surface type_name for \
diagnostics; rendered output: {rendered}"
);
}
#[test]
fn ws_frame_other_variants_remain_useful_in_debug_output() {
let response_frame = WsFrame::Response(jmap_types::JmapResponse::new(
vec![],
"test-session".into(),
None,
));
let rendered = format!("{response_frame:?}");
assert!(
rendered.starts_with("Response"),
"Response variant Debug must surface variant tag; got: {rendered}"
);
assert!(
rendered.contains("test-session"),
"Response variant Debug must surface session_state for \
diagnostics; got: {rendered}"
);
}
// Compile-time check: the handles satisfy Send + Sync.
#[test]
fn ws_sender_and_receiver_are_send_and_sync() {
fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<WsSender>();
assert_send_sync::<WsReceiver>();
assert_send_sync::<WsSession>();
}
}