1use std::{
13 net::IpAddr,
14 sync::{
15 atomic::{AtomicBool, Ordering},
16 Arc,
17 },
18 time::Duration,
19};
20
21use ahash::AHashSet;
22#[cfg(feature = "blocking")]
23use reqwest::blocking::{Client as HttpClient, Response};
24use reqwest::{
25 header::{self},
26 redirect,
27};
28#[cfg(feature = "async")]
29use reqwest::{Client as HttpClient, Response};
30
31use serde::de::DeserializeOwned;
32
33use crate::{
34 blob,
35 core::{
36 request::{self, Request},
37 response,
38 session::{Session, URLPart},
39 },
40 Error,
41};
42
43const DEFAULT_TIMEOUT_MS: u64 = 10 * 1000;
44static USER_AGENT: &str = concat!("jmap-client/", env!("CARGO_PKG_VERSION"));
45
/// Authentication material turned into the `Authorization` header when connecting.
#[derive(PartialEq, Eq, Debug)]
pub enum Credentials {
    /// Sent as `Basic <value>`; `Credentials::basic` fills it with the
    /// base64 encoding of `username:password`.
    Basic(String),
    /// Sent as `Bearer <value>`.
    Bearer(String),
}
51
52pub struct Client {
53 session: parking_lot::Mutex<Arc<Session>>,
54 session_url: String,
55 api_url: String,
56 session_updated: AtomicBool,
57 trusted_hosts: Arc<AHashSet<String>>,
58
59 upload_url: Vec<URLPart<blob::URLParameter>>,
60 download_url: Vec<URLPart<blob::URLParameter>>,
61 #[cfg(feature = "async")]
62 event_source_url: Vec<URLPart<crate::event_source::URLParameter>>,
63
64 headers: header::HeaderMap,
65 default_account_id: String,
66 timeout: Duration,
67 pub(crate) accept_invalid_certs: bool,
68
69 #[cfg(feature = "websockets")]
70 pub(crate) authorization: String,
71 #[cfg(feature = "websockets")]
72 pub(crate) ws: tokio::sync::Mutex<Option<crate::client_ws::WsStream>>,
73}
74
75pub struct ClientBuilder {
76 credentials: Option<Credentials>,
77 trusted_hosts: AHashSet<String>,
78 forwarded_for: Option<String>,
79 accept_invalid_certs: bool,
80 timeout: Duration,
81}
82
83impl Default for ClientBuilder {
84 fn default() -> Self {
85 Self::new()
86 }
87}
88
89impl ClientBuilder {
90 pub fn new() -> Self {
94 Self {
95 credentials: None,
96 trusted_hosts: AHashSet::new(),
97 timeout: Duration::from_millis(DEFAULT_TIMEOUT_MS),
98 forwarded_for: None,
99 accept_invalid_certs: false,
100 }
101 }
102
103 pub fn credentials(mut self, credentials: impl Into<Credentials>) -> Self {
135 self.credentials = Some(credentials.into());
136 self
137 }
138
139 pub fn timeout(mut self, timeout: Duration) -> Self {
145 self.timeout = timeout;
146 self
147 }
148
149 pub fn accept_invalid_certs(mut self, accept_invalid_certs: bool) -> Self {
158 self.accept_invalid_certs = accept_invalid_certs;
159 self
160 }
161
162 pub fn follow_redirects(
168 mut self,
169 trusted_hosts: impl IntoIterator<Item = impl Into<String>>,
170 ) -> Self {
171 self.trusted_hosts = trusted_hosts.into_iter().map(|h| h.into()).collect();
172 self
173 }
174
175 pub fn forwarded_for(mut self, forwarded_for: IpAddr) -> Self {
177 self.forwarded_for = Some(match forwarded_for {
178 IpAddr::V4(addr) => format!("for={}", addr),
179 IpAddr::V6(addr) => format!("for=\"{}\"", addr),
180 });
181 self
182 }
183
184 #[maybe_async::maybe_async]
188 pub async fn connect(self, url: &str) -> crate::Result<Client> {
189 let authorization = match self.credentials.expect("Missing credentials") {
190 Credentials::Basic(s) => format!("Basic {}", s),
191 Credentials::Bearer(s) => format!("Bearer {}", s),
192 };
193 let mut headers = header::HeaderMap::new();
194 headers.insert(
195 header::USER_AGENT,
196 header::HeaderValue::from_static(USER_AGENT),
197 );
198 headers.insert(
199 header::AUTHORIZATION,
200 header::HeaderValue::from_str(&authorization).unwrap(),
201 );
202 if let Some(forwarded_for) = self.forwarded_for {
203 headers.insert(
204 header::FORWARDED,
205 header::HeaderValue::from_str(&forwarded_for).unwrap(),
206 );
207 }
208
209 let trusted_hosts = Arc::new(self.trusted_hosts);
210
211 let trusted_hosts_ = trusted_hosts.clone();
212 let session_url = format!("{}/.well-known/jmap", url);
213 let session: Session = serde_json::from_slice(
214 &Client::handle_error(
215 HttpClient::builder()
216 .timeout(self.timeout)
217 .danger_accept_invalid_certs(self.accept_invalid_certs)
218 .redirect(redirect::Policy::custom(move |attempt| {
219 if attempt.previous().len() > 5 {
220 attempt.error("Too many redirects.")
221 } else if matches!( attempt.url().host_str(), Some(host) if trusted_hosts_.contains(host) )
222 {
223 attempt.follow()
224 } else {
225 let message = format!(
226 "Aborting redirect request to unknown host '{}'.",
227 attempt.url().host_str().unwrap_or("")
228 );
229 attempt.error(message)
230 }
231 }))
232 .default_headers(headers.clone())
233 .build()?
234 .get(&session_url)
235 .send()
236 .await?,
237 )
238 .await?
239 .bytes()
240 .await?,
241 )?;
242
243 let default_account_id = session
244 .primary_accounts()
245 .next()
246 .map(|a| a.1.to_string())
247 .unwrap_or_default();
248
249 headers.insert(
250 header::CONTENT_TYPE,
251 header::HeaderValue::from_static("application/json"),
252 );
253
254 Ok(Client {
255 download_url: URLPart::parse(session.download_url())?,
256 upload_url: URLPart::parse(session.upload_url())?,
257 #[cfg(feature = "async")]
258 event_source_url: URLPart::parse(session.event_source_url())?,
259 api_url: session.api_url().to_string(),
260 session: parking_lot::Mutex::new(Arc::new(session)),
261 session_url,
262 session_updated: true.into(),
263 accept_invalid_certs: self.accept_invalid_certs,
264 trusted_hosts,
265 #[cfg(feature = "websockets")]
266 authorization,
267 timeout: self.timeout,
268 headers,
269 default_account_id,
270 #[cfg(feature = "websockets")]
271 ws: None.into(),
272 })
273 }
274}
275
276impl Client {
277 #[allow(clippy::new_ret_no_self)]
278 pub fn new() -> ClientBuilder {
279 ClientBuilder::new()
280 }
281
282 pub fn set_timeout(&mut self, timeout: Duration) -> &mut Self {
283 self.timeout = timeout;
284 self
285 }
286
287 pub fn set_follow_redirects(
288 &mut self,
289 trusted_hosts: impl IntoIterator<Item = impl Into<String>>,
290 ) -> &mut Self {
291 self.trusted_hosts = Arc::new(trusted_hosts.into_iter().map(|h| h.into()).collect());
292 self
293 }
294
295 pub fn timeout(&self) -> Duration {
296 self.timeout
297 }
298
299 pub fn session(&self) -> Arc<Session> {
300 self.session.lock().clone()
301 }
302
303 pub fn session_url(&self) -> &str {
304 &self.session_url
305 }
306
307 pub fn headers(&self) -> &header::HeaderMap {
308 &self.headers
309 }
310
311 pub(crate) fn redirect_policy(&self) -> redirect::Policy {
312 let trusted_hosts = self.trusted_hosts.clone();
313 redirect::Policy::custom(move |attempt| {
314 if attempt.previous().len() > 5 {
315 attempt.error("Too many redirects.")
316 } else if matches!( attempt.url().host_str(), Some(host) if trusted_hosts.contains(host) )
317 {
318 attempt.follow()
319 } else {
320 let message = format!(
321 "Aborting redirect request to unknown host '{}'.",
322 attempt.url().host_str().unwrap_or("")
323 );
324 attempt.error(message)
325 }
326 })
327 }
328
329 #[maybe_async::maybe_async]
330 pub async fn send<R>(
331 &self,
332 request: &request::Request<'_>,
333 ) -> crate::Result<response::Response<R>>
334 where
335 R: DeserializeOwned,
336 {
337 let response: response::Response<R> = serde_json::from_slice(
338 &Client::handle_error(
339 HttpClient::builder()
340 .redirect(self.redirect_policy())
341 .danger_accept_invalid_certs(self.accept_invalid_certs)
342 .timeout(self.timeout)
343 .default_headers(self.headers.clone())
344 .build()?
345 .post(&self.api_url)
346 .body(serde_json::to_string(&request)?)
347 .send()
348 .await?,
349 )
350 .await?
351 .bytes()
352 .await?,
353 )?;
354
355 if response.session_state() != self.session.lock().state() {
356 self.session_updated.store(false, Ordering::Relaxed);
357 }
358
359 Ok(response)
360 }
361
362 #[maybe_async::maybe_async]
363 pub async fn refresh_session(&self) -> crate::Result<()> {
364 let session: Session = serde_json::from_slice(
365 &Client::handle_error(
366 HttpClient::builder()
367 .timeout(Duration::from_millis(DEFAULT_TIMEOUT_MS))
368 .danger_accept_invalid_certs(self.accept_invalid_certs)
369 .redirect(self.redirect_policy())
370 .default_headers(self.headers.clone())
371 .build()?
372 .get(&self.session_url)
373 .send()
374 .await?,
375 )
376 .await?
377 .bytes()
378 .await?,
379 )?;
380 *self.session.lock() = Arc::new(session);
381 self.session_updated.store(true, Ordering::Relaxed);
382 Ok(())
383 }
384
385 pub fn is_session_updated(&self) -> bool {
386 self.session_updated.load(Ordering::Relaxed)
387 }
388
389 pub fn set_default_account_id(&mut self, defaul_account_id: impl Into<String>) -> &mut Self {
390 self.default_account_id = defaul_account_id.into();
391 self
392 }
393
394 pub fn default_account_id(&self) -> &str {
395 &self.default_account_id
396 }
397
398 pub fn build(&self) -> Request<'_> {
399 Request::new(self)
400 }
401
402 pub fn download_url(&self) -> &[URLPart<blob::URLParameter>] {
403 &self.download_url
404 }
405
406 pub fn upload_url(&self) -> &[URLPart<blob::URLParameter>] {
407 &self.upload_url
408 }
409
410 #[cfg(feature = "async")]
411 pub fn event_source_url(&self) -> &[URLPart<crate::event_source::URLParameter>] {
412 &self.event_source_url
413 }
414
415 #[maybe_async::maybe_async]
416 pub async fn handle_error(response: Response) -> crate::Result<Response> {
417 if response.status().is_success() {
418 Ok(response)
419 } else if let Some(b"application/problem+json") = response
420 .headers()
421 .get(header::CONTENT_TYPE)
422 .map(|h| h.as_bytes())
423 {
424 Err(Error::Problem(serde_json::from_slice(
425 &response.bytes().await?,
426 )?))
427 } else {
428 Err(Error::Server(format!("{}", response.status())))
429 }
430 }
431}
432
433impl Credentials {
434 pub fn basic(username: &str, password: &str) -> Self {
435 Credentials::Basic(base64::encode(format!("{}:{}", username, password)))
436 }
437
438 pub fn bearer(token: impl Into<String>) -> Self {
439 Credentials::Bearer(token.into())
440 }
441}
442
443impl From<&str> for Credentials {
444 fn from(s: &str) -> Self {
445 Credentials::bearer(s.to_string())
446 }
447}
448
449impl From<String> for Credentials {
450 fn from(s: String) -> Self {
451 Credentials::bearer(s)
452 }
453}
454
455impl From<(&str, &str)> for Credentials {
456 fn from((username, password): (&str, &str)) -> Self {
457 Credentials::basic(username, password)
458 }
459}
460
461impl From<(String, String)> for Credentials {
462 fn from((username, password): (String, String)) -> Self {
463 Credentials::basic(&username, &password)
464 }
465}
466
#[cfg(test)]
mod tests {
    use crate::core::response::{Response, TaggedMethodResponse};

    /// Response body with three tagged method responses
    /// (`Email/query`, `Email/get`, `Thread/get`).
    const SAMPLE_RESPONSE: &[u8] = br#"{"sessionState": "123", "methodResponses": [[ "Email/query", {
        "accountId": "A1",
        "queryState": "abcdefg",
        "canCalculateChanges": true,
        "position": 0,
        "total": 101,
        "ids": [ "msg1023", "msg223", "msg110", "msg93", "msg91",
            "msg38", "msg36", "msg33", "msg11", "msg1" ]
    }, "t0" ],
    [ "Email/get", {
        "accountId": "A1",
        "state": "123456",
        "list": [{
            "id": "msg1023",
            "threadId": "trd194"
        }, {
            "id": "msg223",
            "threadId": "trd114"
        }
        ],
        "notFound": []
    }, "t1" ],
    [ "Thread/get", {
        "accountId": "A1",
        "state": "123456",
        "list": [{
            "id": "trd194",
            "emailIds": [ "msg1020", "msg1021", "msg1023" ]
        }, {
            "id": "trd114",
            "emailIds": [ "msg201", "msg223" ]
        }
        ],
        "notFound": []
    }, "t2" ]]}"#;

    /// A multi-method response must deserialize into
    /// `Response<TaggedMethodResponse>` without error.
    #[test]
    fn test_deserialize() {
        let parsed: serde_json::Result<Response<TaggedMethodResponse>> =
            serde_json::from_slice(SAMPLE_RESPONSE);
        let _r = parsed.unwrap();
    }
}
512}