use std::str::FromStr as _;
use futures::SinkExt as _;
use futures::StreamExt as _;
use tokio_tungstenite::tungstenite::client::IntoClientRequest as _;
use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
use tokio_tungstenite::tungstenite::Message;
use crate::push::StateChange;
/// Wire shape of an outgoing JMAP-over-WebSocket request frame.
///
/// Serializes as the wrapped request's own members (`inner` is flattened) plus an
/// `"@type"` member and, when present, an `"id"` member.
#[derive(serde::Serialize)]
struct WsRequestFrame<'a> {
    /// Value of the `@type` member; `send_request` sets it to `"Request"`.
    #[serde(rename = "@type")]
    ws_type: &'static str,
    /// Optional client-chosen request id; the member is omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<&'a str>,
    /// The JMAP request whose members are inlined into this frame.
    #[serde(flatten)]
    inner: &'a jmap_types::JmapRequest,
}
/// Size limit in bytes (1 MiB) for a single WebSocket message and a single frame;
/// applied through `WebSocketConfig` in `connect_ws`.
const MAX_WS_MESSAGE_BYTES: usize = 1 << 20;
/// A decoded server-to-client frame, as produced by `WsSession::next_frame`.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub enum WsFrame {
    /// A push `StateChange` object (`"@type": "StateChange"`).
    StateChange(StateChange),
    /// A JMAP `Response` object (`"@type": "Response"`).
    Response(jmap_types::JmapResponse),
    /// Any other frame: an unrecognized `@type`, a missing or non-string `@type`
    /// (reported as `"<no @type>"`), or a `StateChange`/`Response` frame whose body
    /// failed to deserialize into the expected type.
    Unknown { type_name: String },
}
/// Concrete WebSocket stream type over a TCP connection that may or may not be
/// wrapped in TLS.
type Inner =
    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
/// A connected JMAP WebSocket session, created by `connect_ws`.
///
/// The socket is split into a write half (`sink`) and a read half (`stream`).
pub struct WsSession {
    /// Outgoing half; used by `send_request`.
    sink: futures::stream::SplitSink<Inner, Message>,
    /// Incoming half; used by `next_frame`.
    stream: futures::stream::SplitStream<Inner>,
}
impl WsSession {
pub async fn next_frame(&mut self) -> Option<Result<WsFrame, crate::error::ClientError>> {
loop {
match self.stream.next().await? {
Ok(Message::Text(text)) => return Some(parse_ws_frame(&text)),
Ok(Message::Close(_)) => return None,
Ok(_) => continue, Err(e) => return Some(Err(crate::error::ClientError::WebSocket(e))),
}
}
}
pub async fn send_request(
&mut self,
req: &jmap_types::JmapRequest,
id: Option<&str>,
) -> Result<(), crate::error::ClientError> {
let frame = WsRequestFrame {
ws_type: "Request",
id,
inner: req,
};
let text = serde_json::to_string(&frame).map_err(crate::error::ClientError::Serialize)?;
self.sink
.send(Message::Text(text.into()))
.await
.map_err(crate::error::ClientError::WebSocket)
}
}
/// Decodes the text of one WebSocket message into a [`WsFrame`].
///
/// Dispatches on the top-level `"@type"` member. `"StateChange"` and `"Response"`
/// are deserialized into their typed variants. When that typed deserialization
/// fails, or when the type is anything else (or missing / non-string, reported as
/// `"<no @type>"`), the result is `WsFrame::Unknown` rather than an error.
///
/// # Errors
///
/// Returns `ClientError::Parse` only when `text` is not valid JSON.
fn parse_ws_frame(text: &str) -> Result<WsFrame, crate::error::ClientError> {
    let value: serde_json::Value =
        serde_json::from_str(text).map_err(crate::error::ClientError::Parse)?;

    let type_name: String = match value.get("@type").and_then(serde_json::Value::as_str) {
        Some(name) => name.to_owned(),
        None => "<no @type>".to_owned(),
    };

    // `None` here means "recognized nothing usable"; it falls back to `Unknown` below.
    let typed = match type_name.as_str() {
        "StateChange" => serde_json::from_value::<StateChange>(value)
            .ok()
            .map(WsFrame::StateChange),
        "Response" => serde_json::from_value::<jmap_types::JmapResponse>(value)
            .ok()
            .map(WsFrame::Response),
        _ => None,
    };

    Ok(typed.unwrap_or_else(|| WsFrame::Unknown { type_name }))
}
/// Opens a JMAP-over-WebSocket session (RFC 8887) to `ws_url`.
///
/// The opening handshake:
/// - offers the `jmap` subprotocol;
/// - optionally adds one caller-supplied authentication header;
/// - limits messages and frames to [`MAX_WS_MESSAGE_BYTES`];
/// - must complete within 10 seconds.
///
/// # Errors
///
/// - `ClientError::InvalidArgument` in either of these cases:
///   - `ws_url` does not start with `ws://` or `wss://` (case-insensitive);
///   - the auth header name or value is not a valid HTTP header.
/// - `ClientError::WebSocket` in any of these cases:
///   - the URL cannot be turned into a handshake request;
///   - the connection or handshake fails;
///   - the handshake does not finish within 10 seconds (reported as an I/O error of
///     kind `TimedOut`).
pub async fn connect_ws(
    ws_url: &str,
    auth_header: Option<(&str, &str)>,
) -> Result<WsSession, crate::error::ClientError> {
    // Single deadline covering TCP connect, optional TLS, and the HTTP upgrade.
    const CONNECT_TIMEOUT_SECS: u64 = 10;

    let ws_url_lc = ws_url.to_ascii_lowercase();
    if !ws_url_lc.starts_with("ws://") && !ws_url_lc.starts_with("wss://") {
        return Err(crate::error::ClientError::InvalidArgument(format!(
            "WebSocket URL must start with ws:// or wss://, got: {ws_url:?}"
        )));
    }

    let mut request = ws_url
        .into_client_request()
        .map_err(crate::error::ClientError::WebSocket)?;

    // RFC 8887 §4.2: the client MUST include "jmap" in Sec-WebSocket-Protocol, and
    // the server's reply must select it for a JMAP connection to be established.
    request.headers_mut().insert(
        http::header::SEC_WEBSOCKET_PROTOCOL,
        http::HeaderValue::from_static("jmap"),
    );

    if let Some((name, value)) = auth_header {
        let hdr_name = http::HeaderName::from_str(name).map_err(|e| {
            crate::error::ClientError::InvalidArgument(format!("invalid auth header name: {e}"))
        })?;
        // The parse error is not included, so the header value never lands in the message.
        let mut hdr_value = http::HeaderValue::from_str(value).map_err(|_| {
            crate::error::ClientError::InvalidArgument("invalid auth header value".to_owned())
        })?;
        // Keep the credential out of `Debug` output of the request / header map.
        hdr_value.set_sensitive(true);
        request.headers_mut().insert(hdr_name, hdr_value);
    }

    let mut config = WebSocketConfig::default();
    config.max_message_size = Some(MAX_WS_MESSAGE_BYTES);
    config.max_frame_size = Some(MAX_WS_MESSAGE_BYTES);

    let connect_result = tokio::time::timeout(
        std::time::Duration::from_secs(CONNECT_TIMEOUT_SECS),
        tokio_tungstenite::connect_async_with_config(request, Some(config), false),
    )
    .await
    .map_err(|_elapsed| {
        crate::error::ClientError::WebSocket(tokio_tungstenite::tungstenite::Error::Io(
            std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                format!("WebSocket connect timed out after {CONNECT_TIMEOUT_SECS} seconds"),
            ),
        ))
    })?;

    let (ws_stream, _response) = connect_result.map_err(crate::error::ClientError::WebSocket)?;
    let (sink, stream) = ws_stream.split();
    Ok(WsSession { sink, stream })
}
impl std::fmt::Debug for WsSession {
    /// Renders as `WsSession { .. }`; the socket halves are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("WsSession");
        builder.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal request declaring only the core capability, with no method calls.
    fn core_only_request() -> jmap_types::JmapRequest {
        jmap_types::JmapRequest::new(
            vec!["urn:ietf:params:jmap:core".to_owned()],
            vec![],
            None,
        )
    }

    /// Serializes an outgoing request frame wrapping `req` with the given `id`.
    fn serialize_frame(req: &jmap_types::JmapRequest, id: Option<&str>) -> String {
        let frame = WsRequestFrame {
            ws_type: "Request",
            id,
            inner: req,
        };
        serde_json::to_string(&frame).expect("WsRequestFrame must serialize")
    }

    #[test]
    fn ws_frame_has_no_chat_variants() {
        // Exhaustive match: adding a variant to `WsFrame` makes this fail to compile.
        let frame = WsFrame::Unknown {
            type_name: "test".to_owned(),
        };
        match frame {
            WsFrame::StateChange(_) | WsFrame::Response(_) | WsFrame::Unknown { .. } => {}
        }
    }

    #[test]
    fn parse_state_change() {
        let json = r#"{"@type":"StateChange","changed":{"account1":{"Mail":"s2"}}}"#;
        match parse_ws_frame(json).expect("must parse") {
            WsFrame::StateChange(sc) => {
                let account = sc
                    .changed
                    .get("account1")
                    .expect("account1 must be present");
                assert_eq!(account.get("Mail").map(|s| s.as_ref()), Some("s2"));
            }
            other => panic!("expected StateChange, got {other:?}"),
        }
    }

    #[test]
    fn parse_malformed_state_change_degrades_to_unknown() {
        let json = r#"{"@type":"StateChange","unexpected_field":42}"#;
        match parse_ws_frame(json).expect("must not error") {
            WsFrame::Unknown { type_name } => assert_eq!(type_name, "StateChange"),
            other => panic!("expected Unknown, got {other:?}"),
        }
    }

    #[test]
    fn parse_unknown_type() {
        let json = r#"{"@type":"FutureEvent","foo":"bar"}"#;
        match parse_ws_frame(json).expect("must parse") {
            WsFrame::Unknown { type_name } => assert_eq!(type_name, "FutureEvent"),
            other => panic!("expected Unknown, got {other:?}"),
        }
    }

    #[test]
    fn parse_missing_type_field() {
        let frame = parse_ws_frame(r#"{"foo":"bar"}"#).expect("must parse");
        assert!(matches!(frame, WsFrame::Unknown { .. }));
    }

    #[test]
    fn parse_invalid_json_returns_parse_error() {
        let err = parse_ws_frame("not json").expect_err("must fail");
        assert!(matches!(err, crate::error::ClientError::Parse(_)));
    }

    #[test]
    fn send_request_includes_at_type_request() {
        let serialized = serialize_frame(&core_only_request(), None);
        assert!(
            serialized.contains("\"@type\":\"Request\""),
            "RFC 8887 §4.3.2 requires @type:Request in outgoing WS frames; got: {serialized}"
        );
    }

    #[test]
    fn send_request_includes_id_when_provided() {
        let serialized = serialize_frame(&core_only_request(), Some("req-42"));
        assert!(
            serialized.contains("\"id\":\"req-42\""),
            "RFC 8887 §4.3.2 optional id must be present when provided; got: {serialized}"
        );
    }

    #[test]
    fn send_request_omits_id_when_none() {
        let serialized = serialize_frame(&core_only_request(), None);
        assert!(
            !serialized.contains("\"id\":"),
            "RFC 8887 §4.3.2: no id field must appear when id is None; got: {serialized}"
        );
    }

    #[tokio::test]
    async fn connect_ws_rejects_non_ws_schemes() {
        for bad_url in &["http://host/", "https://host/", "ftp://host/"] {
            let outcome = connect_ws(bad_url, None).await.map(|_| ());
            match outcome {
                Err(crate::error::ClientError::InvalidArgument(_)) => {}
                other => panic!("expected InvalidArgument for {bad_url:?}, got {other:?}"),
            }
        }
    }
}