google_oauth/
async_client.rs1#![allow(non_upper_case_globals)]
2
3use std::ops::Add;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6use lazy_static::lazy_static;
7use log::debug;
8use async_lock::RwLock;
9use crate::{DEFAULT_TIMEOUT, GOOGLE_OAUTH_V3_USER_INFO_API, GOOGLE_SA_CERTS_URL, GoogleAccessTokenPayload, GooglePayload, MyResult, utils};
10use crate::certs::{Cert, Certs};
11use crate::jwt_parser::JwtParser;
12use crate::validate::id_token;
13
14lazy_static! {
15 static ref ca: reqwest::Client = reqwest::Client::new();
16}
17
18#[derive(Debug, Clone)]
20pub struct AsyncClient {
21 client_ids: Arc<RwLock<Vec<String>>>,
22 timeout: Duration,
23 cached_certs: Arc<RwLock<Certs>>,
24}
25
26impl AsyncClient {
27 pub fn new<S: ToString>(client_id: S) -> Self {
29 let client_id = client_id.to_string();
30 Self::new_with_vec([client_id])
31 }
32
33 pub fn new_with_vec<T, V>(client_ids: T) -> Self
35 where
36 T: AsRef<[V]>,
37 V: AsRef<str>,
38 {
39 Self {
40 client_ids: Arc::new(RwLock::new(
41 client_ids
42 .as_ref()
43 .iter()
44 .map(|c| c.as_ref())
45 .filter(|c| !c.is_empty())
46 .map(|c| c.to_string())
47 .collect()
48 )),
49 timeout: Duration::from_secs(DEFAULT_TIMEOUT),
50 cached_certs: Arc::default(),
51 }
52 }
53
54 pub async fn add_client_id<T: ToString>(&mut self, client_id: T) {
58 let client_id = client_id.to_string();
59
60 if !client_id.is_empty() {
61 if self.client_ids.read().await.contains(&client_id) {
63 return
64 }
65
66 self.client_ids.write().await.push(client_id)
67 }
68 }
69
70 pub async fn remove_client_id<T: AsRef<str>>(&mut self, client_id: T) {
74 let to_delete = client_id.as_ref();
75
76 if !to_delete.is_empty() {
77 let mut client_ids = self.client_ids.write().await;
78 client_ids.retain(|id| id != to_delete)
79 }
80 }
81
82 pub fn timeout(mut self, d: Duration) -> Self {
85 if d.as_nanos() != 0 {
86 self.timeout = d;
87 }
88
89 self
90 }
91
92 pub async fn validate_id_token<S>(&self, token: S) -> MyResult<GooglePayload>
94 where S: AsRef<str>
95 {
96 let token = token.as_ref();
97 let client_ids = self.client_ids.read().await;
98
99 let parser: JwtParser<GooglePayload> = JwtParser::parse(token)?;
100
101 id_token::validate_info(&*client_ids, &parser)?;
102
103 let cert = self.get_cert(parser.header.alg.as_str(), parser.header.kid.as_str()).await?;
104
105 id_token::do_validate(&cert, &parser)?;
106
107 Ok(parser.payload)
108 }
109
110 async fn get_cert(&self, alg: &str, kid: &str) -> MyResult<Cert> {
111 {
112 let cached_certs = self.cached_certs.read().await;
113 if !cached_certs.need_refresh() {
114 debug!("certs: use cache");
115 return cached_certs.find_cert(alg, kid);
116 }
117 }
118
119 debug!("certs: try to fetch new certs");
120
121 let mut cached_certs = self.cached_certs.write().await;
122
123 let resp = ca.get(GOOGLE_SA_CERTS_URL)
125 .timeout(self.timeout)
126 .send()
127 .await?;
128
129 let max_age = utils::parse_max_age_from_async_resp(&resp);
131
132 let text = resp.text().await?;
133 *cached_certs = serde_json::from_str(&text)?;
134
135 cached_certs.set_cache_until(
136 Instant::now().add(Duration::from_secs(max_age))
137 );
138
139 cached_certs.find_cert(alg, kid)
140 }
141
142 pub async fn validate_access_token<S>(&self, token: S) -> MyResult<GoogleAccessTokenPayload>
144 where S: AsRef<str>
145 {
146 let token = token.as_ref();
147
148 let info = ca.get(format!("{}?access_token={}", GOOGLE_OAUTH_V3_USER_INFO_API, token))
149 .timeout(self.timeout)
150 .send()
151 .await?
152 .text()
153 .await?;
154
155 let payload = serde_json::from_str(&info)?;
156
157 Ok(payload)
158 }
159}
160
161impl Default for AsyncClient {
162 fn default() -> Self {
163 Self::new_with_vec::<&[_; 0], &'static str>(&[])
164 }
165}