Skip to main content

jmap_client/
client.rs

/*
 * Copyright Stalwart Labs LLC See the COPYING
 * file at the top-level directory of this distribution.
 *
 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
 * https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
 * <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
 * option. This file may not be copied, modified, or distributed
 * except according to those terms.
 */
11
12use crate::{
13    blob,
14    core::{
15        request::{self, Request},
16        response,
17        session::{Session, URLPart},
18    },
19    Error,
20};
21use ahash::AHashSet;
22use base64::{engine::general_purpose, Engine};
23#[cfg(feature = "blocking")]
24use reqwest::blocking::{Client as HttpClient, Response};
25use reqwest::{
26    header::{self},
27    redirect,
28};
29#[cfg(feature = "async")]
30use reqwest::{Client as HttpClient, Response};
31use serde::de::DeserializeOwned;
32use std::{
33    net::IpAddr,
34    sync::{
35        atomic::{AtomicBool, Ordering},
36        Arc,
37    },
38    time::Duration,
39};
40
/// Request timeout applied when the caller does not configure one: 10 seconds,
/// expressed in milliseconds.
const DEFAULT_TIMEOUT_MS: u64 = 10_000;
/// Value of the `User-Agent` header: `jmap-client/` followed by the crate version
/// that Cargo exposes at compile time through `CARGO_PKG_VERSION`.
static USER_AGENT: &str = concat!("jmap-client/", env!("CARGO_PKG_VERSION"));
43
/// Credentials used to build the `Authorization` header sent to the JMAP server.
///
/// Note that the derived `Debug` implementation prints the wrapped secret verbatim,
/// so values of this type should not be written to logs.
#[derive(Debug, PartialEq, Eq)]
pub enum Credentials {
    /// HTTP Basic credentials, stored as the already base64-encoded
    /// `username:password` pair (see `Credentials::basic`).
    Basic(String),
    /// A bearer token, stored exactly as supplied.
    Bearer(String),
}
49
/// A JMAP client bound to one server session.
///
/// Instances are produced by `ClientBuilder::connect`, which fetches the session
/// object and caches the URLs and headers used by later requests.
pub struct Client {
    /// Most recently fetched session object; replaced as a whole by `refresh_session`.
    session: parking_lot::Mutex<Arc<Session>>,
    /// Session resource URL (`<base>/.well-known/jmap`) queried by `connect` and
    /// `refresh_session`.
    session_url: String,
    /// API endpoint reported by the session at connect time; `send` posts requests here.
    api_url: String,
    /// `true` while the cached session matches the state reported by the last response.
    /// Cleared by `send` when the states differ, set again by `refresh_session`.
    session_updated: AtomicBool,
    /// Hosts that HTTP redirects are allowed to target.
    trusted_hosts: Arc<AHashSet<String>>,

    /// Upload URL template, parsed from the session at connect time.
    upload_url: Vec<URLPart<blob::URLParameter>>,
    /// Download URL template, parsed from the session at connect time.
    download_url: Vec<URLPart<blob::URLParameter>>,
    /// Event source URL template, parsed from the session at connect time
    /// (only present with the `async` feature).
    #[cfg(feature = "async")]
    event_source_url: Vec<URLPart<crate::event_source::URLParameter>>,

    /// Default headers attached to every request: `User-Agent`, `Authorization`,
    /// `Content-Type: application/json` and, when configured, `Forwarded`.
    headers: header::HeaderMap,
    /// Account id used by default; initialised from the first primary account of the
    /// session, or an empty string when the session lists none.
    default_account_id: String,
    /// Timeout applied to requests.
    timeout: Duration,
    /// When `true`, TLS certificate validation is disabled for every request.
    pub(crate) accept_invalid_certs: bool,

    /// Full `Authorization` header value (`Basic …` / `Bearer …`), kept for the
    /// WebSocket code (only present with the `websockets` feature).
    #[cfg(feature = "websockets")]
    pub(crate) authorization: String,
    /// WebSocket stream slot; starts out as `None`
    /// (only present with the `websockets` feature).
    #[cfg(feature = "websockets")]
    pub(crate) ws: tokio::sync::Mutex<Option<crate::client_ws::WsStream>>,
}
72
/// Builder that collects connection settings and produces a `Client` through
/// `ClientBuilder::connect`.
pub struct ClientBuilder {
    /// Credentials for the `Authorization` header; `connect` panics when this is `None`.
    credentials: Option<Credentials>,
    /// Hosts that HTTP redirects are allowed to target (empty by default).
    trusted_hosts: AHashSet<String>,
    /// Pre-formatted `Forwarded` header value, if an originating address was supplied.
    forwarded_for: Option<String>,
    /// When `true`, TLS certificate validation is disabled.
    accept_invalid_certs: bool,
    /// Timeout used for the initial session request and stored in the resulting `Client`.
    timeout: Duration,
}
80
81impl Default for ClientBuilder {
82    fn default() -> Self {
83        Self::new()
84    }
85}
86
87impl ClientBuilder {
88    /// Creates a new `ClientBuilder`.
89    ///
90    /// Setting the credentials is required to connect to the JMAP API.
91    pub fn new() -> Self {
92        Self {
93            credentials: None,
94            trusted_hosts: AHashSet::new(),
95            timeout: Duration::from_millis(DEFAULT_TIMEOUT_MS),
96            forwarded_for: None,
97            accept_invalid_certs: false,
98        }
99    }
100
101    /// Set up client credentials to connect to the JMAP API.
102    ///
103    /// The JMAP API URL is set using the [ClientBuilder.connect()](struct.ClientBuilder.html#method.connect) method.
104    ///
105    /// # Bearer authentication
106    /// Pass a `&str` with the API Token.
107    ///
108    /// ```rust
109    /// Client::new().credentials("some-api-token");
110    /// ```
111    ///
112    /// Or use the longer form by using [Credentials::bearer()](enum.Credentials.html#method.bearer).
113    /// ```rust
114    /// let credentials = Credentials::bearer("some-api-token");
115    /// Client::new().credentials(credentials);
116    /// ```
117    ///
118    /// # Basic authentication
119    /// Pass a `(&str, &str)` tuple, with the first position containing a username and the second containing a password.
120    ///
121    /// **It is not suggested to use this approach in production;** instead, if possible, use [Bearer authentication](struct.ClientBuilder.html#bearer-authentication).
122    ///
123    /// ```rust
124    /// Client::new().credentials(("user@domain.com", "password"));
125    /// ```
126    ///
127    /// Or use the longer form by using [Credentials::basic()](enum.Credentials.html#method.basic).
128    /// ```rust
129    /// let credentials = Credentials::basic("user@domain.com", "password");
130    /// Client::new().credentials(credentials);
131    /// ```
132    pub fn credentials(mut self, credentials: impl Into<Credentials>) -> Self {
133        self.credentials = Some(credentials.into());
134        self
135    }
136
137    /// Set a timeout for all the requests to the JMAP API.
138    ///
139    /// The timeout can be changed after the `Client` has been created by using [Client.set_timeout()](struct.Client.html#method.set_timeout).
140    ///
141    /// By default the timeout is 10 seconds.
142    pub fn timeout(mut self, timeout: Duration) -> Self {
143        self.timeout = timeout;
144        self
145    }
146
147    /// Accepts invalid certificates for all the requests to the JMAP API.
148    ///
149    /// By default certificates are validated.
150    ///
151    /// # Warning
152    /// **It is not suggested to use this approach in production;** this method should be used only for testing and as a last resort.
153    ///
154    /// [Read more in the reqwest docs](https://docs.rs/reqwest/latest/reqwest/struct.ClientBuilder.html#method.danger_accept_invalid_certs)
155    pub fn accept_invalid_certs(mut self, accept_invalid_certs: bool) -> Self {
156        self.accept_invalid_certs = accept_invalid_certs;
157        self
158    }
159
160    /// Set a list of trusted hosts that will be checked when a redirect is required.
161    ///
162    /// The list can be changed after the `Client` has been created by using [Client.set_follow_redirects()](struct.Client.html#method.set_follow_redirects).
163    ///
164    /// The client will follow at most 5 redirects.
165    pub fn follow_redirects(
166        mut self,
167        trusted_hosts: impl IntoIterator<Item = impl Into<String>>,
168    ) -> Self {
169        self.trusted_hosts = trusted_hosts.into_iter().map(|h| h.into()).collect();
170        self
171    }
172
173    /// Set the originating IP address of the client connecting to the JMAP API.
174    pub fn forwarded_for(mut self, forwarded_for: IpAddr) -> Self {
175        self.forwarded_for = Some(match forwarded_for {
176            IpAddr::V4(addr) => format!("for={}", addr),
177            IpAddr::V6(addr) => format!("for=\"{}\"", addr),
178        });
179        self
180    }
181
182    /// Connects to the JMAP API Session URL.
183    ///
184    /// Setting up [Credentials](struct.ClientBuilder.html#method.credentials) must be done before calling this function.
185    #[maybe_async::maybe_async]
186    pub async fn connect(self, url: &str) -> crate::Result<Client> {
187        let authorization = match self.credentials.expect("Missing credentials") {
188            Credentials::Basic(s) => format!("Basic {}", s),
189            Credentials::Bearer(s) => format!("Bearer {}", s),
190        };
191        let mut headers = header::HeaderMap::new();
192        headers.insert(
193            header::USER_AGENT,
194            header::HeaderValue::from_static(USER_AGENT),
195        );
196        headers.insert(
197            header::AUTHORIZATION,
198            header::HeaderValue::from_str(&authorization).unwrap(),
199        );
200        if let Some(forwarded_for) = self.forwarded_for {
201            headers.insert(
202                header::FORWARDED,
203                header::HeaderValue::from_str(&forwarded_for).unwrap(),
204            );
205        }
206
207        let trusted_hosts = Arc::new(self.trusted_hosts);
208
209        let trusted_hosts_ = trusted_hosts.clone();
210        let session_url = format!("{}/.well-known/jmap", url);
211        let session: Session = serde_json::from_slice(
212            &Client::handle_error(
213                HttpClient::builder()
214                    .timeout(self.timeout)
215                    .danger_accept_invalid_certs(self.accept_invalid_certs)
216                    .redirect(redirect::Policy::custom(move |attempt| {
217                        if attempt.previous().len() > 5 {
218                            attempt.error("Too many redirects.")
219                        } else if matches!( attempt.url().host_str(), Some(host) if trusted_hosts_.contains(host) )
220                        {
221                                attempt.follow()
222                        } else {
223                            let message = format!(
224                                "Aborting redirect request to unknown host '{}'.",
225                                attempt.url().host_str().unwrap_or("")
226                            );
227                            attempt.error(message)
228                        }
229                    }))
230                    .default_headers(headers.clone())
231                    .build()?
232                    .get(&session_url)
233                    .send()
234                    .await?,
235            )
236            .await?
237            .bytes()
238            .await?,
239        )?;
240
241        let default_account_id = session
242            .primary_accounts()
243            .next()
244            .map(|a| a.1.to_string())
245            .unwrap_or_default();
246
247        headers.insert(
248            header::CONTENT_TYPE,
249            header::HeaderValue::from_static("application/json"),
250        );
251
252        Ok(Client {
253            download_url: URLPart::parse(session.download_url())?,
254            upload_url: URLPart::parse(session.upload_url())?,
255            #[cfg(feature = "async")]
256            event_source_url: URLPart::parse(session.event_source_url())?,
257            api_url: session.api_url().to_string(),
258            session: parking_lot::Mutex::new(Arc::new(session)),
259            session_url,
260            session_updated: true.into(),
261            accept_invalid_certs: self.accept_invalid_certs,
262            trusted_hosts,
263            #[cfg(feature = "websockets")]
264            authorization,
265            timeout: self.timeout,
266            headers,
267            default_account_id,
268            #[cfg(feature = "websockets")]
269            ws: None.into(),
270        })
271    }
272}
273
274impl Client {
    /// Starts configuring a client; returns a fresh `ClientBuilder`.
    ///
    /// A `Client` itself can only be obtained from `ClientBuilder::connect`.
    // The lint is silenced on purpose: `new` is the builder entry point and
    // therefore returns `ClientBuilder` rather than `Self`.
    #[allow(clippy::new_ret_no_self)]
    pub fn new() -> ClientBuilder {
        ClientBuilder::new()
    }
279
    /// Replaces the stored request timeout and returns `self` for chaining.
    pub fn set_timeout(&mut self, timeout: Duration) -> &mut Self {
        self.timeout = timeout;
        self
    }
284
285    pub fn set_follow_redirects(
286        &mut self,
287        trusted_hosts: impl IntoIterator<Item = impl Into<String>>,
288    ) -> &mut Self {
289        self.trusted_hosts = Arc::new(trusted_hosts.into_iter().map(|h| h.into()).collect());
290        self
291    }
292
    /// Returns the currently configured request timeout.
    pub fn timeout(&self) -> Duration {
        self.timeout
    }
296
297    pub fn session(&self) -> Arc<Session> {
298        self.session.lock().clone()
299    }
300
301    pub fn session_url(&self) -> &str {
302        &self.session_url
303    }
304
    /// Returns the default headers attached to every request.
    ///
    /// The map includes the `Authorization` header, so treat it as sensitive.
    pub fn headers(&self) -> &header::HeaderMap {
        &self.headers
    }
308
309    pub(crate) fn redirect_policy(&self) -> redirect::Policy {
310        let trusted_hosts = self.trusted_hosts.clone();
311        redirect::Policy::custom(move |attempt| {
312            if attempt.previous().len() > 5 {
313                attempt.error("Too many redirects.")
314            } else if matches!( attempt.url().host_str(), Some(host) if trusted_hosts.contains(host) )
315            {
316                attempt.follow()
317            } else {
318                let message = format!(
319                    "Aborting redirect request to unknown host '{}'.",
320                    attempt.url().host_str().unwrap_or("")
321                );
322                attempt.error(message)
323            }
324        })
325    }
326
327    #[maybe_async::maybe_async]
328    pub async fn send<R>(
329        &self,
330        request: &request::Request<'_>,
331    ) -> crate::Result<response::Response<R>>
332    where
333        R: DeserializeOwned,
334    {
335        let response: response::Response<R> = serde_json::from_slice(
336            &Client::handle_error(
337                HttpClient::builder()
338                    .redirect(self.redirect_policy())
339                    .danger_accept_invalid_certs(self.accept_invalid_certs)
340                    .timeout(self.timeout)
341                    .default_headers(self.headers.clone())
342                    .build()?
343                    .post(&self.api_url)
344                    .body(serde_json::to_string(&request)?)
345                    .send()
346                    .await?,
347            )
348            .await?
349            .bytes()
350            .await?,
351        )?;
352
353        if response.session_state() != self.session.lock().state() {
354            self.session_updated.store(false, Ordering::Relaxed);
355        }
356
357        Ok(response)
358    }
359
360    #[maybe_async::maybe_async]
361    pub async fn refresh_session(&self) -> crate::Result<()> {
362        let session: Session = serde_json::from_slice(
363            &Client::handle_error(
364                HttpClient::builder()
365                    .timeout(Duration::from_millis(DEFAULT_TIMEOUT_MS))
366                    .danger_accept_invalid_certs(self.accept_invalid_certs)
367                    .redirect(self.redirect_policy())
368                    .default_headers(self.headers.clone())
369                    .build()?
370                    .get(&self.session_url)
371                    .send()
372                    .await?,
373            )
374            .await?
375            .bytes()
376            .await?,
377        )?;
378        *self.session.lock() = Arc::new(session);
379        self.session_updated.store(true, Ordering::Relaxed);
380        Ok(())
381    }
382
    /// Returns `false` once a response has reported a session state different from
    /// the cached session's state, and `true` again after `refresh_session` succeeds.
    pub fn is_session_updated(&self) -> bool {
        self.session_updated.load(Ordering::Relaxed)
    }
386
387    pub fn set_default_account_id(&mut self, defaul_account_id: impl Into<String>) -> &mut Self {
388        self.default_account_id = defaul_account_id.into();
389        self
390    }
391
392    pub fn default_account_id(&self) -> &str {
393        &self.default_account_id
394    }
395
    /// Creates a new, empty `Request` tied to this client.
    pub fn build(&self) -> Request<'_> {
        Request::new(self)
    }
399
400    pub fn download_url(&self) -> &[URLPart<blob::URLParameter>] {
401        &self.download_url
402    }
403
404    pub fn upload_url(&self) -> &[URLPart<blob::URLParameter>] {
405        &self.upload_url
406    }
407
408    #[cfg(feature = "async")]
409    pub fn event_source_url(&self) -> &[URLPart<crate::event_source::URLParameter>] {
410        &self.event_source_url
411    }
412
413    #[maybe_async::maybe_async]
414    pub async fn handle_error(response: Response) -> crate::Result<Response> {
415        if response.status().is_success() {
416            Ok(response)
417        } else if let Some(b"application/problem+json") = response
418            .headers()
419            .get(header::CONTENT_TYPE)
420            .map(|h| h.as_bytes())
421        {
422            Err(Error::Problem(serde_json::from_slice(
423                &response.bytes().await?,
424            )?))
425        } else {
426            Err(Error::Server(format!("{}", response.status())))
427        }
428    }
429}
430
431impl Credentials {
432    pub fn basic(username: &str, password: &str) -> Self {
433        Credentials::Basic(general_purpose::STANDARD.encode(format!("{}:{}", username, password)))
434    }
435
436    pub fn bearer(token: impl Into<String>) -> Self {
437        Credentials::Bearer(token.into())
438    }
439}
440
441impl From<&str> for Credentials {
442    fn from(s: &str) -> Self {
443        Credentials::bearer(s.to_string())
444    }
445}
446
447impl From<String> for Credentials {
448    fn from(s: String) -> Self {
449        Credentials::bearer(s)
450    }
451}
452
453impl From<(&str, &str)> for Credentials {
454    fn from((username, password): (&str, &str)) -> Self {
455        Credentials::basic(username, password)
456    }
457}
458
459impl From<(String, String)> for Credentials {
460    fn from((username, password): (String, String)) -> Self {
461        Credentials::basic(&username, &password)
462    }
463}
464
#[cfg(test)]
mod tests {
    use crate::core::response::{Response, TaggedMethodResponse};

    /// Checks that a representative response holding three method responses
    /// (`Email/query`, `Email/get`, `Thread/get`) deserializes without error.
    /// The parsed value itself is not inspected.
    #[test]
    fn test_deserialize() {
        let payload: &[u8] = br#"{"sessionState": "123", "methodResponses": [[ "Email/query", {
                "accountId": "A1",
                "queryState": "abcdefg",
                "canCalculateChanges": true,
                "position": 0,
                "total": 101,
                "ids": [ "msg1023", "msg223", "msg110", "msg93", "msg91",
                    "msg38", "msg36", "msg33", "msg11", "msg1" ]
            }, "t0" ],
            [ "Email/get", {
                "accountId": "A1",
                "state": "123456",
                "list": [{
                    "id": "msg1023",
                    "threadId": "trd194"
                }, {
                    "id": "msg223",
                    "threadId": "trd114"
                }
                ],
                "notFound": []
            }, "t1" ],
            [ "Thread/get", {
                "accountId": "A1",
                "state": "123456",
                "list": [{
                    "id": "trd194",
                    "emailIds": [ "msg1020", "msg1021", "msg1023" ]
                }, {
                    "id": "trd114",
                    "emailIds": [ "msg201", "msg223" ]
                }
                ],
                "notFound": []
            }, "t2" ]]}"#;

        let _parsed =
            serde_json::from_slice::<Response<TaggedMethodResponse>>(payload).unwrap();
    }
}
510}