// File: firebase_auth/firebase_auth.rs

use base64::{
    prelude::{BASE64_STANDARD_NO_PAD, BASE64_URL_SAFE_NO_PAD},
    Engine,
};
use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
use serde::de::DeserializeOwned;
use std::{
    env,
    sync::{Arc, Mutex, RwLock},
    time::Duration,
};
use tokio::{task::JoinHandle, time::sleep};
use tracing::*;

use crate::structs::{JwkConfiguration, JwkKeys, KeyResponse, PublicKeysError};

/// Cache lifetime assumed for freshly fetched keys when the response's `Cache-Control`
/// header is present but carries no usable `max-age` directive.
const FALLBACK_TIMEOUT: Duration = Duration::from_secs(60);

/// Google endpoint serving the public JWKs used to check Firebase ID token signatures.
const JWK_URL: &str =
    "https://www.googleapis.com/service_accounts/v1/jwk/securetoken@system.gserviceaccount.com";

18pub fn get_configuration(project_id: &str) -> JwkConfiguration {
19    JwkConfiguration {
20        jwk_url: JWK_URL.to_owned(),
21        audience: project_id.to_owned(),
22        issuer: format!("https://securetoken.google.com/{}", project_id),
23    }
24}
25
26fn parse_max_age_value(cache_control_value: &str) -> Result<Duration, PublicKeysError> {
27    let tokens: Vec<(&str, &str)> = cache_control_value
28        .split(',')
29        .map(|s| s.split('=').map(|ss| ss.trim()).collect::<Vec<&str>>())
30        .map(|ss| {
31            let key = ss.first().unwrap_or(&"");
32            let val = ss.get(1).unwrap_or(&"");
33            (*key, *val)
34        })
35        .collect();
36    match tokens
37        .iter()
38        .find(|(key, _)| key.to_lowercase() == *"max-age")
39    {
40        None => Err(PublicKeysError::NoMaxAgeSpecified),
41        Some((_, str_val)) => Ok(Duration::from_secs(
42            str_val
43                .parse()
44                .map_err(|_| PublicKeysError::NonNumericMaxAge)?,
45        )),
46    }
47}
48
49async fn get_public_keys() -> Result<JwkKeys, PublicKeysError> {
50    let response = reqwest::get(JWK_URL)
51        .await
52        .map_err(PublicKeysError::CouldntFetchPublicKeys)?;
53
54    let cache_control = match response.headers().get("Cache-Control") {
55        Some(header_value) => header_value.to_str(),
56        None => return Err(PublicKeysError::NoCacheControlHeader),
57    };
58
59    let max_age = match cache_control {
60        Ok(v) => parse_max_age_value(v),
61        Err(_) => return Err(PublicKeysError::MaxAgeValueEmpty),
62    };
63
64    let public_keys = response
65        .json::<KeyResponse>()
66        .await
67        .map_err(|e| {
68            PublicKeysError::CannotParsePublicKey(e)
69        })?;
70
71    Ok(JwkKeys {
72        keys: public_keys.keys,
73        max_age: max_age.unwrap_or(FALLBACK_TIMEOUT),
74    })
75}
76
/// Reasons a Firebase ID token can be rejected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VerificationError {
    /// The JWT header could not be decoded.
    InvalidSignature,
    /// The token header names an algorithm other than RS256.
    InvalidKeyAlgorithm,
    /// The token is malformed, failed signature or claim validation, or its payload could not
    /// be deserialized into the requested claims type.
    InvalidToken,
    /// The JWT header has no `kid` (key id).
    NoKidHeader,
    /// None of the cached public keys has the token's `kid`.
    NotfoundMatchKid,
    /// The matching public key's RSA components could not be turned into a decoding key.
    CannotDecodePublicKeys,
}

impl std::fmt::Display for VerificationError {
    /// Formats the error as its variant name (the same text as `Debug`).
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "{:?}", self)
    }
}

/// Lets `VerificationError` be propagated with `?` into `Box<dyn Error>` or similar types.
impl std::error::Error for VerificationError {}

93fn extract_claims_from_unsigned_token<T: DeserializeOwned>(token: &str) ->  Result<T, VerificationError> {
94    let parts: Vec<&str> = token.split('.').collect();
95    if parts.len() != 3 {
96        return Err(VerificationError::InvalidToken);
97    }
98    let decoded_payload = BASE64_STANDARD_NO_PAD.decode(parts[1].trim()).unwrap();
99    let claims: T = serde_json::from_slice(&decoded_payload).map_err(|_| VerificationError::InvalidToken)?;
100    Ok(claims)
101}
102
103fn verify_id_token_with_project_id<T: DeserializeOwned>(
104    config: &JwkConfiguration,
105    public_keys: &JwkKeys,
106    token: &str,
107) -> Result<T, VerificationError> {
108    if env::var("FIREBASE_AUTH_EMULATOR_HOST").is_ok() {
109        return extract_claims_from_unsigned_token(token);
110    }
111    
112    let header = decode_header(token).map_err(|_| VerificationError::InvalidSignature)?;
113
114    if header.alg != Algorithm::RS256 {
115        return Err(VerificationError::InvalidKeyAlgorithm);
116    }
117
118    let kid = match header.kid {
119        Some(v) => v,
120        None => return Err(VerificationError::NoKidHeader),
121    };
122
123    let public_key = match public_keys.keys.iter().find(|v| v.kid == kid) {
124        Some(v) => v,
125        None => return Err(VerificationError::NotfoundMatchKid),
126    };
127
128    let decoding_key = DecodingKey::from_rsa_components(&public_key.n, &public_key.e)
129        .map_err(|_| VerificationError::CannotDecodePublicKeys)?;
130
131    let mut validation = Validation::new(Algorithm::RS256);
132    validation.set_audience(&[config.audience.to_owned()]);
133    validation.set_issuer(&[config.issuer.to_owned()]);
134
135    let user = decode::<T>(token, &decoding_key, &validation)
136        .map_err(|_| VerificationError::InvalidToken)?
137        .claims;
138    Ok(user)
139}
140
141#[derive(Debug)]
142struct JwkVerifier {
143    keys: JwkKeys,
144    config: JwkConfiguration,
145}
146
147impl JwkVerifier {
148    fn new(project_id: &str, keys: JwkKeys) -> JwkVerifier {
149        JwkVerifier {
150            keys,
151            config: get_configuration(project_id),
152        }
153    }
154
155    fn verify<T: DeserializeOwned>(&self, token: &str) -> Result<T, VerificationError> {
156        verify_id_token_with_project_id(&self.config, &self.keys, token)
157    }
158
159    fn set_keys(&mut self, keys: JwkKeys) {
160        self.keys = keys;
161    }
162}
163
164/// Provide a service to automatically pull the new google public key based on the Cache-Control
165/// header.
166/// If there is an error during refreshing, automatically retry indefinitely every 10 seconds.
167#[derive(Clone)]
168pub struct FirebaseAuth {
169    verifier: Arc<RwLock<JwkVerifier>>,
170    handler: Arc<Mutex<Box<JoinHandle<()>>>>,
171}
172
173impl Drop for FirebaseAuth {
174    fn drop(&mut self) {
175        // Stop the update thread when the updater is destructed
176        let handler = self.handler.lock().unwrap();
177        handler.abort();
178    }
179}
180
181impl FirebaseAuth {
182    pub async fn new(project_id: &str) -> FirebaseAuth {
183        let jwk_keys: JwkKeys = match get_public_keys().await {
184            Ok(keys) => keys,
185            Err(e) => {
186                eprintln!("Error getting public jwk keys {:?}", e);
187                panic!("Unable to get public jwk keys! Cannot verify user tokens! Shutting down...")
188            }
189        };
190        let verifier = Arc::new(RwLock::new(JwkVerifier::new(project_id, jwk_keys)));
191
192        let mut instance = FirebaseAuth {
193            verifier,
194            handler: Arc::new(Mutex::new(Box::new(tokio::spawn(async {})))),
195        };
196
197        instance.start_key_update();
198        instance
199    }
200
201    pub fn verify<T: DeserializeOwned>(&self, token: &str) -> Result<T, VerificationError> {
202        let verifier = self.verifier.read().unwrap();
203        verifier.verify(token)
204    }
205
206    fn start_key_update(&mut self) {
207        let verifier_ref = Arc::clone(&self.verifier);
208
209        let task = tokio::spawn(async move {
210            loop {
211                let delay = match get_public_keys().await {
212                    Ok(jwk_keys) => {
213                        let mut verifier = verifier_ref.write().unwrap();
214                        verifier.set_keys(jwk_keys.clone());
215                        debug!(
216                            "Updated JWK keys. Next refresh will be in {:?}",
217                            jwk_keys.max_age
218                        );
219                        jwk_keys.max_age
220                    }
221                    Err(err) => {
222                        warn!("Error getting public jwk keys {:?}", err);
223                        warn!("Re-try getting public keys in 10 seconds");
224                        Duration::from_secs(10)
225                    }
226                };
227                sleep(delay).await;
228            }
229        });
230
231        let mut handler = self.handler.lock().unwrap();
232        *handler = Box::new(task);
233    }
234}