1use crate::{
13 blob,
14 core::{
15 request::{self, Request},
16 response,
17 session::{Session, URLPart},
18 },
19 Error,
20};
21use ahash::AHashSet;
22use base64::{engine::general_purpose, Engine};
23#[cfg(feature = "blocking")]
24use reqwest::blocking::{Client as HttpClient, Response};
25use reqwest::{
26 header::{self},
27 redirect,
28};
29#[cfg(feature = "async")]
30use reqwest::{Client as HttpClient, Response};
31use serde::de::DeserializeOwned;
32use std::{
33 net::IpAddr,
34 sync::{
35 atomic::{AtomicBool, Ordering},
36 Arc,
37 },
38 time::Duration,
39};
40
/// Default timeout for HTTP requests, in milliseconds (ten seconds).
const DEFAULT_TIMEOUT_MS: u64 = 10_000;
42static USER_AGENT: &str = concat!("jmap-client/", env!("CARGO_PKG_VERSION"));
43
/// Credentials used to build the `Authorization` header of every request.
///
/// `Debug` is implemented by hand instead of derived so that secrets (bearer
/// tokens, or the trivially decodable base64 `username:password`) never end
/// up in logs or panic messages; only the variant name is printed.
#[derive(PartialEq, Eq)]
pub enum Credentials {
    /// Value placed after `Basic ` in the header. [`Credentials::basic`] fills
    /// it with the base64 encoding of `username:password`.
    Basic(String),
    /// Token placed verbatim after `Bearer ` in the header.
    Bearer(String),
}

impl std::fmt::Debug for Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Credentials::Basic(_) => f.write_str("Basic(<redacted>)"),
            Credentials::Bearer(_) => f.write_str("Bearer(<redacted>)"),
        }
    }
}
49
50pub struct Client {
51 session: parking_lot::Mutex<Arc<Session>>,
52 session_url: String,
53 api_url: String,
54 session_updated: AtomicBool,
55 trusted_hosts: Arc<AHashSet<String>>,
56
57 upload_url: Vec<URLPart<blob::URLParameter>>,
58 download_url: Vec<URLPart<blob::URLParameter>>,
59 #[cfg(feature = "async")]
60 event_source_url: Vec<URLPart<crate::event_source::URLParameter>>,
61
62 headers: header::HeaderMap,
63 default_account_id: String,
64 timeout: Duration,
65 pub(crate) accept_invalid_certs: bool,
66
67 #[cfg(feature = "websockets")]
68 pub(crate) authorization: String,
69 #[cfg(feature = "websockets")]
70 pub(crate) ws: tokio::sync::Mutex<Option<crate::client_ws::WsStream>>,
71}
72
73pub struct ClientBuilder {
74 credentials: Option<Credentials>,
75 trusted_hosts: AHashSet<String>,
76 forwarded_for: Option<String>,
77 accept_invalid_certs: bool,
78 timeout: Duration,
79}
80
81impl Default for ClientBuilder {
82 fn default() -> Self {
83 Self::new()
84 }
85}
86
87impl ClientBuilder {
88 pub fn new() -> Self {
92 Self {
93 credentials: None,
94 trusted_hosts: AHashSet::new(),
95 timeout: Duration::from_millis(DEFAULT_TIMEOUT_MS),
96 forwarded_for: None,
97 accept_invalid_certs: false,
98 }
99 }
100
101 pub fn credentials(mut self, credentials: impl Into<Credentials>) -> Self {
133 self.credentials = Some(credentials.into());
134 self
135 }
136
137 pub fn timeout(mut self, timeout: Duration) -> Self {
143 self.timeout = timeout;
144 self
145 }
146
147 pub fn accept_invalid_certs(mut self, accept_invalid_certs: bool) -> Self {
156 self.accept_invalid_certs = accept_invalid_certs;
157 self
158 }
159
160 pub fn follow_redirects(
166 mut self,
167 trusted_hosts: impl IntoIterator<Item = impl Into<String>>,
168 ) -> Self {
169 self.trusted_hosts = trusted_hosts.into_iter().map(|h| h.into()).collect();
170 self
171 }
172
173 pub fn forwarded_for(mut self, forwarded_for: IpAddr) -> Self {
175 self.forwarded_for = Some(match forwarded_for {
176 IpAddr::V4(addr) => format!("for={}", addr),
177 IpAddr::V6(addr) => format!("for=\"{}\"", addr),
178 });
179 self
180 }
181
182 #[maybe_async::maybe_async]
186 pub async fn connect(self, url: &str) -> crate::Result<Client> {
187 let authorization = match self.credentials.expect("Missing credentials") {
188 Credentials::Basic(s) => format!("Basic {}", s),
189 Credentials::Bearer(s) => format!("Bearer {}", s),
190 };
191 let mut headers = header::HeaderMap::new();
192 headers.insert(
193 header::USER_AGENT,
194 header::HeaderValue::from_static(USER_AGENT),
195 );
196 headers.insert(
197 header::AUTHORIZATION,
198 header::HeaderValue::from_str(&authorization).unwrap(),
199 );
200 if let Some(forwarded_for) = self.forwarded_for {
201 headers.insert(
202 header::FORWARDED,
203 header::HeaderValue::from_str(&forwarded_for).unwrap(),
204 );
205 }
206
207 let trusted_hosts = Arc::new(self.trusted_hosts);
208
209 let trusted_hosts_ = trusted_hosts.clone();
210 let session_url = format!("{}/.well-known/jmap", url);
211 let session: Session = serde_json::from_slice(
212 &Client::handle_error(
213 HttpClient::builder()
214 .timeout(self.timeout)
215 .danger_accept_invalid_certs(self.accept_invalid_certs)
216 .redirect(redirect::Policy::custom(move |attempt| {
217 if attempt.previous().len() > 5 {
218 attempt.error("Too many redirects.")
219 } else if matches!( attempt.url().host_str(), Some(host) if trusted_hosts_.contains(host) )
220 {
221 attempt.follow()
222 } else {
223 let message = format!(
224 "Aborting redirect request to unknown host '{}'.",
225 attempt.url().host_str().unwrap_or("")
226 );
227 attempt.error(message)
228 }
229 }))
230 .default_headers(headers.clone())
231 .build()?
232 .get(&session_url)
233 .send()
234 .await?,
235 )
236 .await?
237 .bytes()
238 .await?,
239 )?;
240
241 let default_account_id = session
242 .primary_accounts()
243 .next()
244 .map(|a| a.1.to_string())
245 .unwrap_or_default();
246
247 headers.insert(
248 header::CONTENT_TYPE,
249 header::HeaderValue::from_static("application/json"),
250 );
251
252 Ok(Client {
253 download_url: URLPart::parse(session.download_url())?,
254 upload_url: URLPart::parse(session.upload_url())?,
255 #[cfg(feature = "async")]
256 event_source_url: URLPart::parse(session.event_source_url())?,
257 api_url: session.api_url().to_string(),
258 session: parking_lot::Mutex::new(Arc::new(session)),
259 session_url,
260 session_updated: true.into(),
261 accept_invalid_certs: self.accept_invalid_certs,
262 trusted_hosts,
263 #[cfg(feature = "websockets")]
264 authorization,
265 timeout: self.timeout,
266 headers,
267 default_account_id,
268 #[cfg(feature = "websockets")]
269 ws: None.into(),
270 })
271 }
272}
273
274impl Client {
275 #[allow(clippy::new_ret_no_self)]
276 pub fn new() -> ClientBuilder {
277 ClientBuilder::new()
278 }
279
280 pub fn set_timeout(&mut self, timeout: Duration) -> &mut Self {
281 self.timeout = timeout;
282 self
283 }
284
285 pub fn set_follow_redirects(
286 &mut self,
287 trusted_hosts: impl IntoIterator<Item = impl Into<String>>,
288 ) -> &mut Self {
289 self.trusted_hosts = Arc::new(trusted_hosts.into_iter().map(|h| h.into()).collect());
290 self
291 }
292
293 pub fn timeout(&self) -> Duration {
294 self.timeout
295 }
296
297 pub fn session(&self) -> Arc<Session> {
298 self.session.lock().clone()
299 }
300
301 pub fn session_url(&self) -> &str {
302 &self.session_url
303 }
304
305 pub fn headers(&self) -> &header::HeaderMap {
306 &self.headers
307 }
308
309 pub(crate) fn redirect_policy(&self) -> redirect::Policy {
310 let trusted_hosts = self.trusted_hosts.clone();
311 redirect::Policy::custom(move |attempt| {
312 if attempt.previous().len() > 5 {
313 attempt.error("Too many redirects.")
314 } else if matches!( attempt.url().host_str(), Some(host) if trusted_hosts.contains(host) )
315 {
316 attempt.follow()
317 } else {
318 let message = format!(
319 "Aborting redirect request to unknown host '{}'.",
320 attempt.url().host_str().unwrap_or("")
321 );
322 attempt.error(message)
323 }
324 })
325 }
326
327 #[maybe_async::maybe_async]
328 pub async fn send<R>(
329 &self,
330 request: &request::Request<'_>,
331 ) -> crate::Result<response::Response<R>>
332 where
333 R: DeserializeOwned,
334 {
335 let response: response::Response<R> = serde_json::from_slice(
336 &Client::handle_error(
337 HttpClient::builder()
338 .redirect(self.redirect_policy())
339 .danger_accept_invalid_certs(self.accept_invalid_certs)
340 .timeout(self.timeout)
341 .default_headers(self.headers.clone())
342 .build()?
343 .post(&self.api_url)
344 .body(serde_json::to_string(&request)?)
345 .send()
346 .await?,
347 )
348 .await?
349 .bytes()
350 .await?,
351 )?;
352
353 if response.session_state() != self.session.lock().state() {
354 self.session_updated.store(false, Ordering::Relaxed);
355 }
356
357 Ok(response)
358 }
359
360 #[maybe_async::maybe_async]
361 pub async fn refresh_session(&self) -> crate::Result<()> {
362 let session: Session = serde_json::from_slice(
363 &Client::handle_error(
364 HttpClient::builder()
365 .timeout(Duration::from_millis(DEFAULT_TIMEOUT_MS))
366 .danger_accept_invalid_certs(self.accept_invalid_certs)
367 .redirect(self.redirect_policy())
368 .default_headers(self.headers.clone())
369 .build()?
370 .get(&self.session_url)
371 .send()
372 .await?,
373 )
374 .await?
375 .bytes()
376 .await?,
377 )?;
378 *self.session.lock() = Arc::new(session);
379 self.session_updated.store(true, Ordering::Relaxed);
380 Ok(())
381 }
382
383 pub fn is_session_updated(&self) -> bool {
384 self.session_updated.load(Ordering::Relaxed)
385 }
386
387 pub fn set_default_account_id(&mut self, defaul_account_id: impl Into<String>) -> &mut Self {
388 self.default_account_id = defaul_account_id.into();
389 self
390 }
391
392 pub fn default_account_id(&self) -> &str {
393 &self.default_account_id
394 }
395
396 pub fn build(&self) -> Request<'_> {
397 Request::new(self)
398 }
399
400 pub fn download_url(&self) -> &[URLPart<blob::URLParameter>] {
401 &self.download_url
402 }
403
404 pub fn upload_url(&self) -> &[URLPart<blob::URLParameter>] {
405 &self.upload_url
406 }
407
408 #[cfg(feature = "async")]
409 pub fn event_source_url(&self) -> &[URLPart<crate::event_source::URLParameter>] {
410 &self.event_source_url
411 }
412
413 #[maybe_async::maybe_async]
414 pub async fn handle_error(response: Response) -> crate::Result<Response> {
415 if response.status().is_success() {
416 Ok(response)
417 } else if let Some(b"application/problem+json") = response
418 .headers()
419 .get(header::CONTENT_TYPE)
420 .map(|h| h.as_bytes())
421 {
422 Err(Error::Problem(serde_json::from_slice(
423 &response.bytes().await?,
424 )?))
425 } else {
426 Err(Error::Server(format!("{}", response.status())))
427 }
428 }
429}
430
431impl Credentials {
432 pub fn basic(username: &str, password: &str) -> Self {
433 Credentials::Basic(general_purpose::STANDARD.encode(format!("{}:{}", username, password)))
434 }
435
436 pub fn bearer(token: impl Into<String>) -> Self {
437 Credentials::Bearer(token.into())
438 }
439}
440
441impl From<&str> for Credentials {
442 fn from(s: &str) -> Self {
443 Credentials::bearer(s.to_string())
444 }
445}
446
447impl From<String> for Credentials {
448 fn from(s: String) -> Self {
449 Credentials::bearer(s)
450 }
451}
452
453impl From<(&str, &str)> for Credentials {
454 fn from((username, password): (&str, &str)) -> Self {
455 Credentials::basic(username, password)
456 }
457}
458
459impl From<(String, String)> for Credentials {
460 fn from((username, password): (String, String)) -> Self {
461 Credentials::basic(&username, &password)
462 }
463}
464
#[cfg(test)]
mod tests {
    use crate::core::response::{Response, TaggedMethodResponse};

    /// A response carrying several tagged method responses
    /// (`Email/query`, `Email/get`, `Thread/get`) must deserialize.
    #[test]
    fn test_deserialize() {
        let payload: &[u8] = br#"{"sessionState": "123", "methodResponses": [[ "Email/query", {
            "accountId": "A1",
            "queryState": "abcdefg",
            "canCalculateChanges": true,
            "position": 0,
            "total": 101,
            "ids": [ "msg1023", "msg223", "msg110", "msg93", "msg91",
                "msg38", "msg36", "msg33", "msg11", "msg1" ]
        }, "t0" ],
        [ "Email/get", {
            "accountId": "A1",
            "state": "123456",
            "list": [{
                "id": "msg1023",
                "threadId": "trd194"
            }, {
                "id": "msg223",
                "threadId": "trd114"
            }
            ],
            "notFound": []
        }, "t1" ],
        [ "Thread/get", {
            "accountId": "A1",
            "state": "123456",
            "list": [{
                "id": "trd194",
                "emailIds": [ "msg1020", "msg1021", "msg1023" ]
            }, {
                "id": "trd114",
                "emailIds": [ "msg201", "msg223" ]
            }
            ],
            "notFound": []
        }, "t2" ]]}"#;

        let _response =
            serde_json::from_slice::<Response<TaggedMethodResponse>>(payload).unwrap();
    }
}
510}