1use std::{pin::Pin, sync::Arc};
13
14use ahash::AHashMap;
15use futures_util::{stream::SplitSink, SinkExt, Stream, StreamExt};
16use reqwest::header::SEC_WEBSOCKET_PROTOCOL;
17use rustls::{
18 client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
19 ClientConfig, SignatureScheme,
20};
21use serde::{Deserialize, Serialize};
22use tokio::net::TcpStream;
23use tokio_tungstenite::{
24 tungstenite::{client::IntoClientRequest, Message},
25 Connector, MaybeTlsStream, WebSocketStream,
26};
27
28use crate::{
29 client::Client,
30 core::{
31 error::{ProblemDetails, ProblemType},
32 request::{Arguments, Request},
33 response::{Response, TaggedMethodResponse},
34 },
35 DataType, Method, PushObject, URI,
36};
37
38#[derive(Debug, Serialize)]
39struct WebSocketRequest {
40 #[serde(rename = "@type")]
41 pub _type: WebSocketRequestType,
42
43 #[serde(skip_serializing_if = "Option::is_none")]
44 pub id: Option<String>,
45
46 using: Vec<URI>,
47
48 #[serde(rename = "methodCalls")]
49 method_calls: Vec<(Method, Arguments, String)>,
50
51 #[serde(rename = "createdIds")]
52 #[serde(skip_serializing_if = "Option::is_none")]
53 created_ids: Option<AHashMap<String, String>>,
54}
55
56#[derive(Debug, Deserialize)]
57pub struct WebSocketResponse {
58 #[serde(rename = "@type")]
59 _type: WebSocketResponseType,
60
61 #[serde(rename = "requestId")]
62 request_id: Option<String>,
63
64 #[serde(rename = "methodResponses")]
65 method_responses: Vec<TaggedMethodResponse>,
66
67 #[serde(rename = "createdIds")]
68 created_ids: Option<AHashMap<String, String>>,
69
70 #[serde(rename = "sessionState")]
71 session_state: String,
72}
73
74#[derive(Debug, Serialize, Deserialize)]
75enum WebSocketResponseType {
76 Response,
77}
78
79#[derive(Debug, Serialize)]
80struct WebSocketPushEnable {
81 #[serde(rename = "@type")]
82 _type: WebSocketPushEnableType,
83
84 #[serde(rename = "dataTypes")]
85 data_types: Option<Vec<DataType>>,
86
87 #[serde(rename = "pushState")]
88 #[serde(skip_serializing_if = "Option::is_none")]
89 push_state: Option<String>,
90}
91
92#[derive(Debug, Serialize)]
93struct WebSocketPushDisable {
94 #[serde(rename = "@type")]
95 _type: WebSocketPushDisableType,
96}
97
98#[derive(Debug, Serialize)]
99enum WebSocketRequestType {
100 Request,
101}
102
103#[derive(Debug, Serialize)]
104enum WebSocketPushEnableType {
105 WebSocketPushEnable,
106}
107
108#[derive(Debug, Serialize)]
109enum WebSocketPushDisableType {
110 WebSocketPushDisable,
111}
112
113#[derive(Deserialize, Debug)]
114pub struct WebSocketPushObject {
115 #[serde(flatten)]
116 pub push: PushObject,
117
118 #[serde(rename = "pushState")]
119 #[serde(skip_serializing_if = "Option::is_none")]
120 pub push_state: Option<String>,
121}
122
123#[derive(Debug, Deserialize)]
124pub struct WebSocketError {
125 #[serde(rename = "@type")]
126 pub type_: WebSocketErrorType,
127
128 #[serde(rename = "requestId")]
129 pub request_id: Option<String>,
130
131 #[serde(rename = "type")]
132 p_type: ProblemType,
133 status: Option<u32>,
134 title: Option<String>,
135 detail: Option<String>,
136 limit: Option<String>,
137}
138
139#[derive(Serialize, Deserialize, Debug)]
140pub enum WebSocketErrorType {
141 RequestError,
142}
143
144#[derive(Debug, Deserialize)]
145#[serde(untagged)]
146enum WebSocketMessage_ {
147 Response(WebSocketResponse),
148 PushNotification(WebSocketPushObject),
149 Error(WebSocketError),
150}
151
152#[derive(Debug)]
153pub enum WebSocketMessage {
154 Response(Response<TaggedMethodResponse>),
155 PushNotification(PushObject),
156}
157
158pub struct WsStream {
159 tx: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
160 req_id: usize,
161}
162
163#[doc(hidden)]
164#[derive(Debug)]
165struct DummyVerifier;
166
167impl ServerCertVerifier for DummyVerifier {
168 fn verify_server_cert(
169 &self,
170 _end_entity: &rustls_pki_types::CertificateDer<'_>,
171 _intermediates: &[rustls_pki_types::CertificateDer<'_>],
172 _server_name: &rustls_pki_types::ServerName<'_>,
173 _ocsp_response: &[u8],
174 _now: rustls_pki_types::UnixTime,
175 ) -> Result<ServerCertVerified, rustls::Error> {
176 Ok(ServerCertVerified::assertion())
177 }
178
179 fn verify_tls12_signature(
180 &self,
181 _message: &[u8],
182 _cert: &rustls_pki_types::CertificateDer<'_>,
183 _dss: &rustls::DigitallySignedStruct,
184 ) -> Result<HandshakeSignatureValid, rustls::Error> {
185 Ok(HandshakeSignatureValid::assertion())
186 }
187
188 fn verify_tls13_signature(
189 &self,
190 _message: &[u8],
191 _cert: &rustls_pki_types::CertificateDer<'_>,
192 _dss: &rustls::DigitallySignedStruct,
193 ) -> Result<HandshakeSignatureValid, rustls::Error> {
194 Ok(HandshakeSignatureValid::assertion())
195 }
196
197 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
198 vec![
199 SignatureScheme::RSA_PKCS1_SHA1,
200 SignatureScheme::ECDSA_SHA1_Legacy,
201 SignatureScheme::RSA_PKCS1_SHA256,
202 SignatureScheme::ECDSA_NISTP256_SHA256,
203 SignatureScheme::RSA_PKCS1_SHA384,
204 SignatureScheme::ECDSA_NISTP384_SHA384,
205 SignatureScheme::RSA_PKCS1_SHA512,
206 SignatureScheme::ECDSA_NISTP521_SHA512,
207 SignatureScheme::RSA_PSS_SHA256,
208 SignatureScheme::RSA_PSS_SHA384,
209 SignatureScheme::RSA_PSS_SHA512,
210 SignatureScheme::ED25519,
211 SignatureScheme::ED448,
212 ]
213 }
214}
215
216impl Client {
217 pub async fn connect_ws(
218 &self,
219 ) -> crate::Result<Pin<Box<impl Stream<Item = crate::Result<WebSocketMessage>>>>> {
220 let session = self.session();
221 let capabilities = session.websocket_capabilities().ok_or_else(|| {
222 crate::Error::Internal(
223 "JMAP server does not advertise any websocket capabilities.".to_string(),
224 )
225 })?;
226
227 let mut request = capabilities.url().into_client_request()?;
228 request
229 .headers_mut()
230 .insert("Authorization", self.authorization.parse().unwrap());
231 request
232 .headers_mut()
233 .insert(SEC_WEBSOCKET_PROTOCOL, "jmap".parse().unwrap());
234
235 let (stream, _) = if self.accept_invalid_certs & capabilities.url().starts_with("wss") {
236 tokio_tungstenite::connect_async_tls_with_config(
237 request,
238 None,
239 false,
240 Connector::Rustls(Arc::new(
241 ClientConfig::builder()
242 .dangerous()
243 .with_custom_certificate_verifier(Arc::new(DummyVerifier {}))
244 .with_no_client_auth(),
245 ))
246 .into(),
247 )
248 .await?
249 } else {
250 tokio_tungstenite::connect_async(request).await?
251 };
252 let (tx, mut rx) = stream.split();
253
254 *self.ws.lock().await = WsStream { tx, req_id: 0 }.into();
255
256 Ok(Box::pin(async_stream::stream! {
257 while let Some(message) = rx.next().await {
258 match message {
259 Ok(message) if message.is_text() => {
260 match serde_json::from_slice::<WebSocketMessage_>(&message.into_data()) {
261 Ok(message) => match message {
262 WebSocketMessage_::Response(response) => {
263 yield Ok(WebSocketMessage::Response(Response::new(
264 response.method_responses,
265 response.created_ids,
266 response.session_state,
267 response.request_id,
268 )))
269 }
270 WebSocketMessage_::PushNotification(push) => {
271 yield Ok(WebSocketMessage::PushNotification(push.push))
272 }
273 WebSocketMessage_::Error(err) => yield Err(ProblemDetails::from(err).into()),
274 },
275 Err(err) => yield Err(err.into()),
276 }
277 }
278 Ok(_) => (),
279 Err(err) => yield Err(err.into()),
280 }
281 }
282 }))
283 }
284
285 pub async fn send_ws(&self, request: Request<'_>) -> crate::Result<String> {
286 let mut _ws = self.ws.lock().await;
287 let ws = _ws
288 .as_mut()
289 .ok_or_else(|| crate::Error::Internal("Websocket stream not set.".to_string()))?;
290
291 let request_id = ws.req_id.to_string();
293 ws.req_id += 1;
294
295 ws.tx
296 .send(Message::text(
297 serde_json::to_string(&WebSocketRequest {
298 _type: WebSocketRequestType::Request,
299 id: request_id.clone().into(),
300 using: request.using,
301 method_calls: request.method_calls,
302 created_ids: request.created_ids,
303 })
304 .unwrap_or_default(),
305 ))
306 .await?;
307
308 Ok(request_id)
309 }
310
311 pub async fn enable_push_ws(
312 &self,
313 data_types: Option<impl IntoIterator<Item = DataType>>,
314 push_state: Option<impl Into<String>>,
315 ) -> crate::Result<()> {
316 self.ws
317 .lock()
318 .await
319 .as_mut()
320 .ok_or_else(|| crate::Error::Internal("Websocket stream not set.".to_string()))?
321 .tx
322 .send(Message::text(
323 serde_json::to_string(&WebSocketPushEnable {
324 _type: WebSocketPushEnableType::WebSocketPushEnable,
325 data_types: data_types.map(|it| it.into_iter().collect()),
326 push_state: push_state.map(|it| it.into()),
327 })
328 .unwrap_or_default(),
329 ))
330 .await
331 .map_err(|err| err.into())
332 }
333
334 pub async fn disable_push_ws(&self) -> crate::Result<()> {
335 self.ws
336 .lock()
337 .await
338 .as_mut()
339 .ok_or_else(|| crate::Error::Internal("Websocket stream not set.".to_string()))?
340 .tx
341 .send(Message::text(
342 serde_json::to_string(&WebSocketPushDisable {
343 _type: WebSocketPushDisableType::WebSocketPushDisable,
344 })
345 .unwrap_or_default(),
346 ))
347 .await
348 .map_err(|err| err.into())
349 }
350
351 pub async fn ws_ping(&self) -> crate::Result<()> {
352 self.ws
353 .lock()
354 .await
355 .as_mut()
356 .ok_or_else(|| crate::Error::Internal("Websocket stream not set.".to_string()))?
357 .tx
358 .send(Message::Ping(vec![].into()))
359 .await
360 .map_err(|err| err.into())
361 }
362}
363
364impl From<WebSocketError> for ProblemDetails {
365 fn from(problem: WebSocketError) -> Self {
366 ProblemDetails::new(
367 problem.p_type,
368 problem.status,
369 problem.title,
370 problem.detail,
371 problem.limit,
372 problem.request_id,
373 )
374 }
375}