jmap_client/
client.rs

1/*
2 * Copyright Stalwart Labs LLC See the COPYING
3 * file at the top-level directory of this distribution.
4 *
5 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 * https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 * <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
8 * option. This file may not be copied, modified, or distributed
9 * except according to those terms.
10 */
11
12use std::{
13    net::IpAddr,
14    sync::{
15        atomic::{AtomicBool, Ordering},
16        Arc,
17    },
18    time::Duration,
19};
20
21use ahash::AHashSet;
22#[cfg(feature = "blocking")]
23use reqwest::blocking::{Client as HttpClient, Response};
24use reqwest::{
25    header::{self},
26    redirect,
27};
28#[cfg(feature = "async")]
29use reqwest::{Client as HttpClient, Response};
30
31use serde::de::DeserializeOwned;
32
33use crate::{
34    blob,
35    core::{
36        request::{self, Request},
37        response,
38        session::{Session, URLPart},
39    },
40    Error,
41};
42
43const DEFAULT_TIMEOUT_MS: u64 = 10 * 1000;
44static USER_AGENT: &str = concat!("jmap-client/", env!("CARGO_PKG_VERSION"));
45
/// Authentication credentials used to build the `Authorization` header.
///
/// `Debug` is implemented by hand so that secrets are never written to logs:
/// only the variant name is shown, the payload is redacted.
#[derive(PartialEq, Eq)]
pub enum Credentials {
    /// Sent as `Authorization: Basic <value>`; [`Credentials::basic`] stores the
    /// base64 encoding of `username:password` here.
    Basic(String),
    /// Sent as `Authorization: Bearer <value>`; holds the raw API token.
    Bearer(String),
}

impl std::fmt::Debug for Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print the secret itself, only which kind of credential this is.
        let kind = match self {
            Credentials::Basic(_) => "Basic",
            Credentials::Bearer(_) => "Bearer",
        };
        f.debug_tuple(kind).field(&"<redacted>").finish()
    }
}
51
/// A connected JMAP client, created by [`ClientBuilder::connect`].
pub struct Client {
    /// Most recently fetched session; replaced wholesale by `refresh_session`.
    session: parking_lot::Mutex<Arc<Session>>,
    /// URL the session resource is fetched from (`<url>/.well-known/jmap`).
    session_url: String,
    /// API endpoint that `send` POSTs requests to, taken from the session at connect time.
    api_url: String,
    /// Set to `false` by `send` when a response reports a session state that
    /// differs from the cached session; set back to `true` by `refresh_session`.
    session_updated: AtomicBool,
    /// Hosts that redirects are allowed to target (see `redirect_policy`).
    trusted_hosts: Arc<AHashSet<String>>,

    // URL templates parsed from the session at connect time.
    upload_url: Vec<URLPart<blob::URLParameter>>,
    download_url: Vec<URLPart<blob::URLParameter>>,
    #[cfg(feature = "async")]
    event_source_url: Vec<URLPart<crate::event_source::URLParameter>>,

    /// Default headers attached to HTTP requests: User-Agent, Authorization,
    /// Content-Type and, if configured, Forwarded.
    headers: header::HeaderMap,
    /// Default account id; initially the first primary account listed by the
    /// session, or empty when there is none.
    default_account_id: String,
    /// Timeout applied to requests made by `send`.
    timeout: Duration,
    /// Whether TLS certificate validation is disabled.
    pub(crate) accept_invalid_certs: bool,

    /// Full `Authorization` header value (e.g. `Bearer <token>`).
    #[cfg(feature = "websockets")]
    pub(crate) authorization: String,
    /// Optional WebSocket stream, initialized to `None` by `ClientBuilder::connect`.
    #[cfg(feature = "websockets")]
    pub(crate) ws: tokio::sync::Mutex<Option<crate::client_ws::WsStream>>,
}
74
/// Builder used to configure and connect a [`Client`].
pub struct ClientBuilder {
    /// Credentials for the `Authorization` header; must be set before `connect`.
    credentials: Option<Credentials>,
    /// Hosts that redirects are allowed to target.
    trusted_hosts: AHashSet<String>,
    /// Pre-formatted `Forwarded` header value (`for=...`), if configured.
    forwarded_for: Option<String>,
    /// Whether invalid TLS certificates are accepted.
    accept_invalid_certs: bool,
    /// Timeout for HTTP requests.
    timeout: Duration,
}
82
83impl Default for ClientBuilder {
84    fn default() -> Self {
85        Self::new()
86    }
87}
88
89impl ClientBuilder {
90    /// Creates a new `ClientBuilder`.
91    ///
92    /// Setting the credentials is required to connect to the JMAP API.
93    pub fn new() -> Self {
94        Self {
95            credentials: None,
96            trusted_hosts: AHashSet::new(),
97            timeout: Duration::from_millis(DEFAULT_TIMEOUT_MS),
98            forwarded_for: None,
99            accept_invalid_certs: false,
100        }
101    }
102
103    /// Set up client credentials to connect to the JMAP API.
104    ///
105    /// The JMAP API URL is set using the [ClientBuilder.connect()](struct.ClientBuilder.html#method.connect) method.
106    ///
107    /// # Bearer authentication
108    /// Pass a `&str` with the API Token.
109    ///
110    /// ```rust
111    /// Client::new().credentials("some-api-token");
112    /// ```
113    ///
114    /// Or use the longer form by using [Credentials::bearer()](enum.Credentials.html#method.bearer).
115    /// ```rust
116    /// let credentials = Credentials::bearer("some-api-token");
117    /// Client::new().credentials(credentials);
118    /// ```
119    ///
120    /// # Basic authentication
121    /// Pass a `(&str, &str)` tuple, with the first position containing a username and the second containing a password.
122    ///
123    /// **It is not suggested to use this approach in production;** instead, if possible, use [Bearer authentication](struct.ClientBuilder.html#bearer-authentication).
124    ///
125    /// ```rust
126    /// Client::new().credentials(("user@domain.com", "password"));
127    /// ```
128    ///
129    /// Or use the longer form by using [Credentials::basic()](enum.Credentials.html#method.basic).
130    /// ```rust
131    /// let credentials = Credentials::basic("user@domain.com", "password");
132    /// Client::new().credentials(credentials);
133    /// ```
134    pub fn credentials(mut self, credentials: impl Into<Credentials>) -> Self {
135        self.credentials = Some(credentials.into());
136        self
137    }
138
139    /// Set a timeout for all the requests to the JMAP API.
140    ///
141    /// The timeout can be changed after the `Client` has been created by using [Client.set_timeout()](struct.Client.html#method.set_timeout).
142    ///
143    /// By default the timeout is 10 seconds.
144    pub fn timeout(mut self, timeout: Duration) -> Self {
145        self.timeout = timeout;
146        self
147    }
148
149    /// Accepts invalid certificates for all the requests to the JMAP API.
150    ///
151    /// By default certificates are validated.
152    ///
153    /// # Warning
154    /// **It is not suggested to use this approach in production;** this method should be used only for testing and as a last resort.
155    ///
156    /// [Read more in the reqwest docs](https://docs.rs/reqwest/latest/reqwest/struct.ClientBuilder.html#method.danger_accept_invalid_certs)
157    pub fn accept_invalid_certs(mut self, accept_invalid_certs: bool) -> Self {
158        self.accept_invalid_certs = accept_invalid_certs;
159        self
160    }
161
162    /// Set a list of trusted hosts that will be checked when a redirect is required.
163    ///
164    /// The list can be changed after the `Client` has been created by using [Client.set_follow_redirects()](struct.Client.html#method.set_follow_redirects).
165    ///
166    /// The client will follow at most 5 redirects.
167    pub fn follow_redirects(
168        mut self,
169        trusted_hosts: impl IntoIterator<Item = impl Into<String>>,
170    ) -> Self {
171        self.trusted_hosts = trusted_hosts.into_iter().map(|h| h.into()).collect();
172        self
173    }
174
175    /// Set the originating IP address of the client connecting to the JMAP API.
176    pub fn forwarded_for(mut self, forwarded_for: IpAddr) -> Self {
177        self.forwarded_for = Some(match forwarded_for {
178            IpAddr::V4(addr) => format!("for={}", addr),
179            IpAddr::V6(addr) => format!("for=\"{}\"", addr),
180        });
181        self
182    }
183
184    /// Connects to the JMAP API Session URL.
185    ///
186    /// Setting up [Credentials](struct.ClientBuilder.html#method.credentials) must be done before calling this function.
187    #[maybe_async::maybe_async]
188    pub async fn connect(self, url: &str) -> crate::Result<Client> {
189        let authorization = match self.credentials.expect("Missing credentials") {
190            Credentials::Basic(s) => format!("Basic {}", s),
191            Credentials::Bearer(s) => format!("Bearer {}", s),
192        };
193        let mut headers = header::HeaderMap::new();
194        headers.insert(
195            header::USER_AGENT,
196            header::HeaderValue::from_static(USER_AGENT),
197        );
198        headers.insert(
199            header::AUTHORIZATION,
200            header::HeaderValue::from_str(&authorization).unwrap(),
201        );
202        if let Some(forwarded_for) = self.forwarded_for {
203            headers.insert(
204                header::FORWARDED,
205                header::HeaderValue::from_str(&forwarded_for).unwrap(),
206            );
207        }
208
209        let trusted_hosts = Arc::new(self.trusted_hosts);
210
211        let trusted_hosts_ = trusted_hosts.clone();
212        let session_url = format!("{}/.well-known/jmap", url);
213        let session: Session = serde_json::from_slice(
214            &Client::handle_error(
215                HttpClient::builder()
216                    .timeout(self.timeout)
217                    .danger_accept_invalid_certs(self.accept_invalid_certs)
218                    .redirect(redirect::Policy::custom(move |attempt| {
219                        if attempt.previous().len() > 5 {
220                            attempt.error("Too many redirects.")
221                        } else if matches!( attempt.url().host_str(), Some(host) if trusted_hosts_.contains(host) )
222                        {
223                                attempt.follow()
224                        } else {
225                            let message = format!(
226                                "Aborting redirect request to unknown host '{}'.",
227                                attempt.url().host_str().unwrap_or("")
228                            );
229                            attempt.error(message)
230                        }
231                    }))
232                    .default_headers(headers.clone())
233                    .build()?
234                    .get(&session_url)
235                    .send()
236                    .await?,
237            )
238            .await?
239            .bytes()
240            .await?,
241        )?;
242
243        let default_account_id = session
244            .primary_accounts()
245            .next()
246            .map(|a| a.1.to_string())
247            .unwrap_or_default();
248
249        headers.insert(
250            header::CONTENT_TYPE,
251            header::HeaderValue::from_static("application/json"),
252        );
253
254        Ok(Client {
255            download_url: URLPart::parse(session.download_url())?,
256            upload_url: URLPart::parse(session.upload_url())?,
257            #[cfg(feature = "async")]
258            event_source_url: URLPart::parse(session.event_source_url())?,
259            api_url: session.api_url().to_string(),
260            session: parking_lot::Mutex::new(Arc::new(session)),
261            session_url,
262            session_updated: true.into(),
263            accept_invalid_certs: self.accept_invalid_certs,
264            trusted_hosts,
265            #[cfg(feature = "websockets")]
266            authorization,
267            timeout: self.timeout,
268            headers,
269            default_account_id,
270            #[cfg(feature = "websockets")]
271            ws: None.into(),
272        })
273    }
274}
275
276impl Client {
277    #[allow(clippy::new_ret_no_self)]
278    pub fn new() -> ClientBuilder {
279        ClientBuilder::new()
280    }
281
282    pub fn set_timeout(&mut self, timeout: Duration) -> &mut Self {
283        self.timeout = timeout;
284        self
285    }
286
287    pub fn set_follow_redirects(
288        &mut self,
289        trusted_hosts: impl IntoIterator<Item = impl Into<String>>,
290    ) -> &mut Self {
291        self.trusted_hosts = Arc::new(trusted_hosts.into_iter().map(|h| h.into()).collect());
292        self
293    }
294
295    pub fn timeout(&self) -> Duration {
296        self.timeout
297    }
298
299    pub fn session(&self) -> Arc<Session> {
300        self.session.lock().clone()
301    }
302
303    pub fn session_url(&self) -> &str {
304        &self.session_url
305    }
306
307    pub fn headers(&self) -> &header::HeaderMap {
308        &self.headers
309    }
310
311    pub(crate) fn redirect_policy(&self) -> redirect::Policy {
312        let trusted_hosts = self.trusted_hosts.clone();
313        redirect::Policy::custom(move |attempt| {
314            if attempt.previous().len() > 5 {
315                attempt.error("Too many redirects.")
316            } else if matches!( attempt.url().host_str(), Some(host) if trusted_hosts.contains(host) )
317            {
318                attempt.follow()
319            } else {
320                let message = format!(
321                    "Aborting redirect request to unknown host '{}'.",
322                    attempt.url().host_str().unwrap_or("")
323                );
324                attempt.error(message)
325            }
326        })
327    }
328
329    #[maybe_async::maybe_async]
330    pub async fn send<R>(
331        &self,
332        request: &request::Request<'_>,
333    ) -> crate::Result<response::Response<R>>
334    where
335        R: DeserializeOwned,
336    {
337        let response: response::Response<R> = serde_json::from_slice(
338            &Client::handle_error(
339                HttpClient::builder()
340                    .redirect(self.redirect_policy())
341                    .danger_accept_invalid_certs(self.accept_invalid_certs)
342                    .timeout(self.timeout)
343                    .default_headers(self.headers.clone())
344                    .build()?
345                    .post(&self.api_url)
346                    .body(serde_json::to_string(&request)?)
347                    .send()
348                    .await?,
349            )
350            .await?
351            .bytes()
352            .await?,
353        )?;
354
355        if response.session_state() != self.session.lock().state() {
356            self.session_updated.store(false, Ordering::Relaxed);
357        }
358
359        Ok(response)
360    }
361
362    #[maybe_async::maybe_async]
363    pub async fn refresh_session(&self) -> crate::Result<()> {
364        let session: Session = serde_json::from_slice(
365            &Client::handle_error(
366                HttpClient::builder()
367                    .timeout(Duration::from_millis(DEFAULT_TIMEOUT_MS))
368                    .danger_accept_invalid_certs(self.accept_invalid_certs)
369                    .redirect(self.redirect_policy())
370                    .default_headers(self.headers.clone())
371                    .build()?
372                    .get(&self.session_url)
373                    .send()
374                    .await?,
375            )
376            .await?
377            .bytes()
378            .await?,
379        )?;
380        *self.session.lock() = Arc::new(session);
381        self.session_updated.store(true, Ordering::Relaxed);
382        Ok(())
383    }
384
385    pub fn is_session_updated(&self) -> bool {
386        self.session_updated.load(Ordering::Relaxed)
387    }
388
389    pub fn set_default_account_id(&mut self, defaul_account_id: impl Into<String>) -> &mut Self {
390        self.default_account_id = defaul_account_id.into();
391        self
392    }
393
394    pub fn default_account_id(&self) -> &str {
395        &self.default_account_id
396    }
397
398    pub fn build(&self) -> Request<'_> {
399        Request::new(self)
400    }
401
402    pub fn download_url(&self) -> &[URLPart<blob::URLParameter>] {
403        &self.download_url
404    }
405
406    pub fn upload_url(&self) -> &[URLPart<blob::URLParameter>] {
407        &self.upload_url
408    }
409
410    #[cfg(feature = "async")]
411    pub fn event_source_url(&self) -> &[URLPart<crate::event_source::URLParameter>] {
412        &self.event_source_url
413    }
414
415    #[maybe_async::maybe_async]
416    pub async fn handle_error(response: Response) -> crate::Result<Response> {
417        if response.status().is_success() {
418            Ok(response)
419        } else if let Some(b"application/problem+json") = response
420            .headers()
421            .get(header::CONTENT_TYPE)
422            .map(|h| h.as_bytes())
423        {
424            Err(Error::Problem(serde_json::from_slice(
425                &response.bytes().await?,
426            )?))
427        } else {
428            Err(Error::Server(format!("{}", response.status())))
429        }
430    }
431}
432
433impl Credentials {
434    pub fn basic(username: &str, password: &str) -> Self {
435        Credentials::Basic(base64::encode(format!("{}:{}", username, password)))
436    }
437
438    pub fn bearer(token: impl Into<String>) -> Self {
439        Credentials::Bearer(token.into())
440    }
441}
442
443impl From<&str> for Credentials {
444    fn from(s: &str) -> Self {
445        Credentials::bearer(s.to_string())
446    }
447}
448
449impl From<String> for Credentials {
450    fn from(s: String) -> Self {
451        Credentials::bearer(s)
452    }
453}
454
455impl From<(&str, &str)> for Credentials {
456    fn from((username, password): (&str, &str)) -> Self {
457        Credentials::basic(username, password)
458    }
459}
460
461impl From<(String, String)> for Credentials {
462    fn from((username, password): (String, String)) -> Self {
463        Credentials::basic(&username, &password)
464    }
465}
466
#[cfg(test)]
mod tests {
    use crate::core::response::{Response, TaggedMethodResponse};

    /// Sample JMAP response containing `Email/query`, `Email/get` and
    /// `Thread/get` method results.
    const SAMPLE_RESPONSE: &[u8] = br#"{"sessionState": "123", "methodResponses": [[ "Email/query", {
                "accountId": "A1",
                "queryState": "abcdefg",
                "canCalculateChanges": true,
                "position": 0,
                "total": 101,
                "ids": [ "msg1023", "msg223", "msg110", "msg93", "msg91",
                    "msg38", "msg36", "msg33", "msg11", "msg1" ]
            }, "t0" ],
            [ "Email/get", {
                "accountId": "A1",
                "state": "123456",
                "list": [{
                    "id": "msg1023",
                    "threadId": "trd194"
                }, {
                    "id": "msg223",
                    "threadId": "trd114"
                }
                ],
                "notFound": []
            }, "t1" ],
            [ "Thread/get", {
                "accountId": "A1",
                "state": "123456",
                "list": [{
                    "id": "trd194",
                    "emailIds": [ "msg1020", "msg1021", "msg1023" ]
                }, {
                    "id": "trd114",
                    "emailIds": [ "msg201", "msg223" ]
                }
                ],
                "notFound": []
            }, "t2" ]]}"#;

    /// A multi-method JMAP response must deserialize into tagged method responses.
    #[test]
    fn test_deserialize() {
        let parsed = serde_json::from_slice::<Response<TaggedMethodResponse>>(SAMPLE_RESPONSE);
        let _response = parsed.unwrap();
    }
}