google_oauth/
async_client.rs1#![allow(non_upper_case_globals)]
2
3use std::ops::Add;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6use lazy_static::lazy_static;
7use log::debug;
8use async_lock::RwLock;
9use crate::{DEFAULT_TIMEOUT, GOOGLE_OAUTH_V3_USER_INFO_API, GOOGLE_SA_CERTS_URL, GoogleAccessTokenPayload, GooglePayload, MyResult, utils};
10use crate::certs::{Cert, Certs};
11use crate::jwt_parser::JwtParser;
12use crate::validate::id_token;
13
14lazy_static! {
15 static ref ca: reqwest::Client = reqwest::Client::new();
16}
17
18#[derive(Debug, Clone)]
20pub struct AsyncClient {
21 client_ids: Arc<RwLock<Vec<String>>>,
22 timeout: Duration,
23 cached_certs: Arc<RwLock<Certs>>,
24}
25
26impl AsyncClient {
27 pub fn new<S: ToString>(client_id: S) -> Self {
29 let client_id = client_id.to_string();
30 Self::new_with_vec([client_id])
31 }
32
33 pub fn new_with_vec<T, V>(client_ids: T) -> Self
35 where
36 T: AsRef<[V]>,
37 V: AsRef<str>,
38 {
39 Self {
40 client_ids: Arc::new(RwLock::new(
41 client_ids
42 .as_ref()
43 .iter()
44 .map(|c| c.as_ref())
45 .filter(|c| !c.is_empty())
46 .map(|c| c.to_string())
47 .collect()
48 )),
49 timeout: Duration::from_secs(DEFAULT_TIMEOUT),
50 cached_certs: Arc::default(),
51 }
52 }
53
54 pub async fn add_client_id<T: ToString>(&mut self, client_id: T) {
58 let client_id = client_id.to_string();
59
60 if !client_id.is_empty() {
61 if self.client_ids.read().await.contains(&client_id) {
63 return
64 }
65
66 self.client_ids.write().await.push(client_id)
67 }
68 }
69
70 pub async fn remove_client_id<T: AsRef<str>>(&mut self, client_id: T) {
74 let to_delete = client_id.as_ref();
75
76 if !to_delete.is_empty() {
77 let mut client_ids = self.client_ids.write().await;
78 client_ids.retain(|id| id != to_delete)
79 }
80 }
81
82 pub fn timeout(mut self, d: Duration) -> Self {
85 if !d.is_zero() {
86 self.timeout = d;
87 }
88
89 self
90 }
91
92 pub async fn validate_id_token<S>(&self, token: S) -> MyResult<GooglePayload>
94 where S: AsRef<str>
95 {
96 let token = token.as_ref();
97 let client_ids = self.client_ids.read().await;
98
99 let parser = JwtParser::parse(token)?;
100 id_token::validate_info(&*client_ids, &parser)?;
101
102 let cert = self.get_cert(&parser.header.alg, &parser.header.kid).await?;
103 id_token::do_validate(&cert, &parser)?;
104
105 Ok(parser.payload)
106 }
107
108 async fn get_cert(&self, alg: &str, kid: &str) -> MyResult<Cert> {
109 {
110 let cached_certs = self.cached_certs.read().await;
111 if !cached_certs.need_refresh() {
112 debug!("certs: use cache");
113 return cached_certs.find_cert(alg, kid);
114 }
115 }
116
117 debug!("certs: try to fetch new certs");
118
119 let mut cached_certs = self.cached_certs.write().await;
120
121 let resp = ca.get(GOOGLE_SA_CERTS_URL)
123 .timeout(self.timeout)
124 .send()
125 .await?;
126
127 let max_age = utils::parse_max_age_from_async_resp(&resp);
129
130 let info = resp.bytes().await?;
131 *cached_certs = serde_json::from_slice(&info)?;
132
133 cached_certs.set_cache_until(Instant::now().add(Duration::from_secs(max_age)));
134 cached_certs.find_cert(alg, kid)
135 }
136
137 pub async fn validate_access_token<S>(&self, token: S) -> MyResult<GoogleAccessTokenPayload>
139 where S: AsRef<str>
140 {
141 let token = token.as_ref();
142
143 let info = ca.get(format!("{}?access_token={}", GOOGLE_OAUTH_V3_USER_INFO_API, token))
144 .timeout(self.timeout)
145 .send()
146 .await?
147 .bytes()
148 .await?;
149
150 Ok(serde_json::from_slice(&info)?)
151 }
152}
153
154impl Default for AsyncClient {
155 fn default() -> Self {
156 Self::new_with_vec::<&[_; 0], &'static str>(&[])
157 }
158}