google_oauth/
async_client.rs

1#![allow(non_upper_case_globals)]
2
3use std::ops::Add;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6use lazy_static::lazy_static;
7use log::debug;
8use async_lock::RwLock;
9use crate::{DEFAULT_TIMEOUT, GOOGLE_OAUTH_V3_USER_INFO_API, GOOGLE_SA_CERTS_URL, GoogleAccessTokenPayload, GooglePayload, MyResult, utils};
10use crate::certs::{Cert, Certs};
11use crate::jwt_parser::JwtParser;
12use crate::validate::id_token;
13
14lazy_static! {
15    static ref ca: reqwest::Client = reqwest::Client::new();
16}
17
18/// AsyncClient is an async client to do verification.
19#[derive(Debug, Clone)]
20pub struct AsyncClient {
21    client_ids: Arc<RwLock<Vec<String>>>,
22    timeout: Duration,
23    cached_certs: Arc<RwLock<Certs>>,
24}
25
26impl AsyncClient {
27    /// Create a new async client.
28    pub fn new<S: ToString>(client_id: S) -> Self {
29        let client_id = client_id.to_string();
30        Self::new_with_vec([client_id])
31    }
32
33    /// Create a new async client, with multiple client ids.
34    pub fn new_with_vec<T, V>(client_ids: T) -> Self
35        where
36            T: AsRef<[V]>,
37            V: AsRef<str>,
38    {
39        Self {
40            client_ids: Arc::new(RwLock::new(
41                client_ids
42                    .as_ref()
43                    .iter()
44                    .map(|c| c.as_ref())
45                    .filter(|c| !c.is_empty())
46                    .map(|c| c.to_string())
47                    .collect()
48            )),
49            timeout: Duration::from_secs(DEFAULT_TIMEOUT),
50            cached_certs: Arc::default(),
51        }
52    }
53
54    /// Add a new client_id for future validating.
55    ///
56    /// Note: this function is thread safe.
57    pub async fn add_client_id<T: ToString>(&mut self, client_id: T) {
58        let client_id = client_id.to_string();
59
60        if !client_id.is_empty() {
61            // check if client_id exists?
62            if self.client_ids.read().await.contains(&client_id) {
63                return
64            }
65
66            self.client_ids.write().await.push(client_id)
67        }
68    }
69
70    /// Remove a client_id, if it exists.
71    ///
72    /// Note: this function is thread safe.
73    pub async fn remove_client_id<T: AsRef<str>>(&mut self, client_id: T) {
74        let to_delete = client_id.as_ref();
75
76        if !to_delete.is_empty() {
77            let mut client_ids = self.client_ids.write().await;
78            client_ids.retain(|id| id != to_delete)
79        }
80    }
81
82    /// Set the timeout (used in fetching google certs).
83    /// Default timeout is 5 seconds. Zero timeout will be ignored.
84    pub fn timeout(mut self, d: Duration) -> Self {
85        if d.as_nanos() != 0 {
86            self.timeout = d;
87        }
88
89        self
90    }
91
92    /// Do verification with `id_token`. If succeed, return the user data.
93    pub async fn validate_id_token<S>(&self, token: S) -> MyResult<GooglePayload>
94        where S: AsRef<str>
95    {
96        let token = token.as_ref();
97        let client_ids = self.client_ids.read().await;
98
99        let parser: JwtParser<GooglePayload> = JwtParser::parse(token)?;
100
101        id_token::validate_info(&*client_ids, &parser)?;
102
103        let cert = self.get_cert(parser.header.alg.as_str(), parser.header.kid.as_str()).await?;
104
105        id_token::do_validate(&cert, &parser)?;
106
107        Ok(parser.payload)
108    }
109
110    async fn get_cert(&self, alg: &str, kid: &str) -> MyResult<Cert> {
111        {
112            let cached_certs = self.cached_certs.read().await;
113            if !cached_certs.need_refresh() {
114                debug!("certs: use cache");
115                return cached_certs.find_cert(alg, kid);
116            }
117        }
118
119        debug!("certs: try to fetch new certs");
120
121        let mut cached_certs = self.cached_certs.write().await;
122
123        // refresh certs here...
124        let resp = ca.get(GOOGLE_SA_CERTS_URL)
125            .timeout(self.timeout)
126            .send()
127            .await?;
128
129        // parse the response header `age` and `max-age`.
130        let max_age = utils::parse_max_age_from_async_resp(&resp);
131
132        let text = resp.text().await?;
133        *cached_certs = serde_json::from_str(&text)?;
134
135        cached_certs.set_cache_until(
136            Instant::now().add(Duration::from_secs(max_age))
137        );
138
139        cached_certs.find_cert(alg, kid)
140    }
141
142    /// Try to validate access token. If succeed, return the user info.
143    pub async fn validate_access_token<S>(&self, token: S) -> MyResult<GoogleAccessTokenPayload>
144        where S: AsRef<str>
145    {
146        let token = token.as_ref();
147
148        let info = ca.get(format!("{}?access_token={}", GOOGLE_OAUTH_V3_USER_INFO_API, token))
149            .timeout(self.timeout)
150            .send()
151            .await?
152            .text()
153            .await?;
154
155        let payload = serde_json::from_str(&info)?;
156
157        Ok(payload)
158    }
159}
160
161impl Default for AsyncClient {
162    fn default() -> Self {
163        Self::new_with_vec::<&[_; 0], &'static str>(&[])
164    }
165}