jmap_client/
client_ws.rs

1/*
2 * Copyright Stalwart Labs LLC See the COPYING
3 * file at the top-level directory of this distribution.
4 *
5 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 * https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 * <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
8 * option. This file may not be copied, modified, or distributed
9 * except according to those terms.
10 */
11
12use std::{pin::Pin, sync::Arc};
13
14use ahash::AHashMap;
15use futures_util::{stream::SplitSink, SinkExt, Stream, StreamExt};
16use reqwest::header::SEC_WEBSOCKET_PROTOCOL;
17use rustls::{
18    client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
19    ClientConfig, SignatureScheme,
20};
21use serde::{Deserialize, Serialize};
22use tokio::net::TcpStream;
23use tokio_tungstenite::{
24    tungstenite::{client::IntoClientRequest, Message},
25    Connector, MaybeTlsStream, WebSocketStream,
26};
27
28use crate::{
29    client::Client,
30    core::{
31        error::{ProblemDetails, ProblemType},
32        request::{Arguments, Request},
33        response::{Response, TaggedMethodResponse},
34    },
35    DataType, Method, PushObject, URI,
36};
37
38#[derive(Debug, Serialize)]
39struct WebSocketRequest {
40    #[serde(rename = "@type")]
41    pub _type: WebSocketRequestType,
42
43    #[serde(skip_serializing_if = "Option::is_none")]
44    pub id: Option<String>,
45
46    using: Vec<URI>,
47
48    #[serde(rename = "methodCalls")]
49    method_calls: Vec<(Method, Arguments, String)>,
50
51    #[serde(rename = "createdIds")]
52    #[serde(skip_serializing_if = "Option::is_none")]
53    created_ids: Option<AHashMap<String, String>>,
54}
55
56#[derive(Debug, Deserialize)]
57pub struct WebSocketResponse {
58    #[serde(rename = "@type")]
59    _type: WebSocketResponseType,
60
61    #[serde(rename = "requestId")]
62    request_id: Option<String>,
63
64    #[serde(rename = "methodResponses")]
65    method_responses: Vec<TaggedMethodResponse>,
66
67    #[serde(rename = "createdIds")]
68    created_ids: Option<AHashMap<String, String>>,
69
70    #[serde(rename = "sessionState")]
71    session_state: String,
72}
73
74#[derive(Debug, Serialize, Deserialize)]
75enum WebSocketResponseType {
76    Response,
77}
78
79#[derive(Debug, Serialize)]
80struct WebSocketPushEnable {
81    #[serde(rename = "@type")]
82    _type: WebSocketPushEnableType,
83
84    #[serde(rename = "dataTypes")]
85    data_types: Option<Vec<DataType>>,
86
87    #[serde(rename = "pushState")]
88    #[serde(skip_serializing_if = "Option::is_none")]
89    push_state: Option<String>,
90}
91
92#[derive(Debug, Serialize)]
93struct WebSocketPushDisable {
94    #[serde(rename = "@type")]
95    _type: WebSocketPushDisableType,
96}
97
98#[derive(Debug, Serialize)]
99enum WebSocketRequestType {
100    Request,
101}
102
103#[derive(Debug, Serialize)]
104enum WebSocketPushEnableType {
105    WebSocketPushEnable,
106}
107
108#[derive(Debug, Serialize)]
109enum WebSocketPushDisableType {
110    WebSocketPushDisable,
111}
112
113#[derive(Deserialize, Debug)]
114pub struct WebSocketPushObject {
115    #[serde(flatten)]
116    pub push: PushObject,
117
118    #[serde(rename = "pushState")]
119    #[serde(skip_serializing_if = "Option::is_none")]
120    pub push_state: Option<String>,
121}
122
123#[derive(Debug, Deserialize)]
124pub struct WebSocketError {
125    #[serde(rename = "@type")]
126    pub type_: WebSocketErrorType,
127
128    #[serde(rename = "requestId")]
129    pub request_id: Option<String>,
130
131    #[serde(rename = "type")]
132    p_type: ProblemType,
133    status: Option<u32>,
134    title: Option<String>,
135    detail: Option<String>,
136    limit: Option<String>,
137}
138
139#[derive(Serialize, Deserialize, Debug)]
140pub enum WebSocketErrorType {
141    RequestError,
142}
143
144#[derive(Debug, Deserialize)]
145#[serde(untagged)]
146enum WebSocketMessage_ {
147    Response(WebSocketResponse),
148    PushNotification(WebSocketPushObject),
149    Error(WebSocketError),
150}
151
152#[derive(Debug)]
153pub enum WebSocketMessage {
154    Response(Response<TaggedMethodResponse>),
155    PushNotification(PushObject),
156}
157
158pub struct WsStream {
159    tx: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
160    req_id: usize,
161}
162
/// Certificate verifier that accepts everything it is shown.
///
/// It is only plugged into the TLS configuration when the client has been
/// told to accept invalid certificates.
#[derive(Debug)]
#[doc(hidden)]
struct DummyVerifier;
166
167impl ServerCertVerifier for DummyVerifier {
168    fn verify_server_cert(
169        &self,
170        _end_entity: &rustls_pki_types::CertificateDer<'_>,
171        _intermediates: &[rustls_pki_types::CertificateDer<'_>],
172        _server_name: &rustls_pki_types::ServerName<'_>,
173        _ocsp_response: &[u8],
174        _now: rustls_pki_types::UnixTime,
175    ) -> Result<ServerCertVerified, rustls::Error> {
176        Ok(ServerCertVerified::assertion())
177    }
178
179    fn verify_tls12_signature(
180        &self,
181        _message: &[u8],
182        _cert: &rustls_pki_types::CertificateDer<'_>,
183        _dss: &rustls::DigitallySignedStruct,
184    ) -> Result<HandshakeSignatureValid, rustls::Error> {
185        Ok(HandshakeSignatureValid::assertion())
186    }
187
188    fn verify_tls13_signature(
189        &self,
190        _message: &[u8],
191        _cert: &rustls_pki_types::CertificateDer<'_>,
192        _dss: &rustls::DigitallySignedStruct,
193    ) -> Result<HandshakeSignatureValid, rustls::Error> {
194        Ok(HandshakeSignatureValid::assertion())
195    }
196
197    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
198        vec![
199            SignatureScheme::RSA_PKCS1_SHA1,
200            SignatureScheme::ECDSA_SHA1_Legacy,
201            SignatureScheme::RSA_PKCS1_SHA256,
202            SignatureScheme::ECDSA_NISTP256_SHA256,
203            SignatureScheme::RSA_PKCS1_SHA384,
204            SignatureScheme::ECDSA_NISTP384_SHA384,
205            SignatureScheme::RSA_PKCS1_SHA512,
206            SignatureScheme::ECDSA_NISTP521_SHA512,
207            SignatureScheme::RSA_PSS_SHA256,
208            SignatureScheme::RSA_PSS_SHA384,
209            SignatureScheme::RSA_PSS_SHA512,
210            SignatureScheme::ED25519,
211            SignatureScheme::ED448,
212        ]
213    }
214}
215
216impl Client {
217    pub async fn connect_ws(
218        &self,
219    ) -> crate::Result<Pin<Box<impl Stream<Item = crate::Result<WebSocketMessage>>>>> {
220        let session = self.session();
221        let capabilities = session.websocket_capabilities().ok_or_else(|| {
222            crate::Error::Internal(
223                "JMAP server does not advertise any websocket capabilities.".to_string(),
224            )
225        })?;
226
227        let mut request = capabilities.url().into_client_request()?;
228        request
229            .headers_mut()
230            .insert("Authorization", self.authorization.parse().unwrap());
231        request
232            .headers_mut()
233            .insert(SEC_WEBSOCKET_PROTOCOL, "jmap".parse().unwrap());
234
235        let (stream, _) = if self.accept_invalid_certs & capabilities.url().starts_with("wss") {
236            tokio_tungstenite::connect_async_tls_with_config(
237                request,
238                None,
239                false,
240                Connector::Rustls(Arc::new(
241                    ClientConfig::builder()
242                        .dangerous()
243                        .with_custom_certificate_verifier(Arc::new(DummyVerifier {}))
244                        .with_no_client_auth(),
245                ))
246                .into(),
247            )
248            .await?
249        } else {
250            tokio_tungstenite::connect_async(request).await?
251        };
252        let (tx, mut rx) = stream.split();
253
254        *self.ws.lock().await = WsStream { tx, req_id: 0 }.into();
255
256        Ok(Box::pin(async_stream::stream! {
257            while let Some(message) = rx.next().await {
258                match message {
259                    Ok(message) if message.is_text() => {
260                        match serde_json::from_slice::<WebSocketMessage_>(&message.into_data()) {
261                            Ok(message) => match message {
262                                WebSocketMessage_::Response(response) => {
263                                    yield Ok(WebSocketMessage::Response(Response::new(
264                                        response.method_responses,
265                                        response.created_ids,
266                                        response.session_state,
267                                        response.request_id,
268                                    )))
269                                }
270                                WebSocketMessage_::PushNotification(push) => {
271                                    yield Ok(WebSocketMessage::PushNotification(push.push))
272                                }
273                                WebSocketMessage_::Error(err) => yield Err(ProblemDetails::from(err).into()),
274                            },
275                            Err(err) => yield Err(err.into()),
276                        }
277                    }
278                    Ok(_) => (),
279                    Err(err) => yield Err(err.into()),
280                }
281            }
282        }))
283    }
284
285    pub async fn send_ws(&self, request: Request<'_>) -> crate::Result<String> {
286        let mut _ws = self.ws.lock().await;
287        let ws = _ws
288            .as_mut()
289            .ok_or_else(|| crate::Error::Internal("Websocket stream not set.".to_string()))?;
290
291        // Assign request id
292        let request_id = ws.req_id.to_string();
293        ws.req_id += 1;
294
295        ws.tx
296            .send(Message::text(
297                serde_json::to_string(&WebSocketRequest {
298                    _type: WebSocketRequestType::Request,
299                    id: request_id.clone().into(),
300                    using: request.using,
301                    method_calls: request.method_calls,
302                    created_ids: request.created_ids,
303                })
304                .unwrap_or_default(),
305            ))
306            .await?;
307
308        Ok(request_id)
309    }
310
311    pub async fn enable_push_ws(
312        &self,
313        data_types: Option<impl IntoIterator<Item = DataType>>,
314        push_state: Option<impl Into<String>>,
315    ) -> crate::Result<()> {
316        self.ws
317            .lock()
318            .await
319            .as_mut()
320            .ok_or_else(|| crate::Error::Internal("Websocket stream not set.".to_string()))?
321            .tx
322            .send(Message::text(
323                serde_json::to_string(&WebSocketPushEnable {
324                    _type: WebSocketPushEnableType::WebSocketPushEnable,
325                    data_types: data_types.map(|it| it.into_iter().collect()),
326                    push_state: push_state.map(|it| it.into()),
327                })
328                .unwrap_or_default(),
329            ))
330            .await
331            .map_err(|err| err.into())
332    }
333
334    pub async fn disable_push_ws(&self) -> crate::Result<()> {
335        self.ws
336            .lock()
337            .await
338            .as_mut()
339            .ok_or_else(|| crate::Error::Internal("Websocket stream not set.".to_string()))?
340            .tx
341            .send(Message::text(
342                serde_json::to_string(&WebSocketPushDisable {
343                    _type: WebSocketPushDisableType::WebSocketPushDisable,
344                })
345                .unwrap_or_default(),
346            ))
347            .await
348            .map_err(|err| err.into())
349    }
350
351    pub async fn ws_ping(&self) -> crate::Result<()> {
352        self.ws
353            .lock()
354            .await
355            .as_mut()
356            .ok_or_else(|| crate::Error::Internal("Websocket stream not set.".to_string()))?
357            .tx
358            .send(Message::Ping(vec![].into()))
359            .await
360            .map_err(|err| err.into())
361    }
362}
363
364impl From<WebSocketError> for ProblemDetails {
365    fn from(problem: WebSocketError) -> Self {
366        ProblemDetails::new(
367            problem.p_type,
368            problem.status,
369            problem.title,
370            problem.detail,
371            problem.limit,
372            problem.request_id,
373        )
374    }
375}