firebase_auth/
firebase_auth.rs1use base64::{prelude::BASE64_STANDARD_NO_PAD, Engine};
2use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
3use serde::de::DeserializeOwned;
4use std::{
5 env,
6 sync::{Arc, Mutex, RwLock},
7 time::Duration,
8};
9use tokio::{task::JoinHandle, time::sleep};
10use tracing::*;
11
12use crate::structs::{JwkConfiguration, JwkKeys, KeyResponse, PublicKeysError};
13
/// Key lifetime used when the JWK endpoint's `Cache-Control` header carries no
/// parsable `max-age`; the background refresh then waits this long.
const FALLBACK_TIMEOUT: Duration = Duration::from_secs(60);

/// Google endpoint serving the JWK public keys for Firebase ID tokens
/// (`securetoken@system.gserviceaccount.com`).
const JWK_URL: &str = concat!(
    "https://www.googleapis.com/service_accounts/v1/jwk/",
    "securetoken@system.gserviceaccount.com"
);
17
18pub fn get_configuration(project_id: &str) -> JwkConfiguration {
19 JwkConfiguration {
20 jwk_url: JWK_URL.to_owned(),
21 audience: project_id.to_owned(),
22 issuer: format!("https://securetoken.google.com/{}", project_id),
23 }
24}
25
26fn parse_max_age_value(cache_control_value: &str) -> Result<Duration, PublicKeysError> {
27 let tokens: Vec<(&str, &str)> = cache_control_value
28 .split(',')
29 .map(|s| s.split('=').map(|ss| ss.trim()).collect::<Vec<&str>>())
30 .map(|ss| {
31 let key = ss.first().unwrap_or(&"");
32 let val = ss.get(1).unwrap_or(&"");
33 (*key, *val)
34 })
35 .collect();
36 match tokens
37 .iter()
38 .find(|(key, _)| key.to_lowercase() == *"max-age")
39 {
40 None => Err(PublicKeysError::NoMaxAgeSpecified),
41 Some((_, str_val)) => Ok(Duration::from_secs(
42 str_val
43 .parse()
44 .map_err(|_| PublicKeysError::NonNumericMaxAge)?,
45 )),
46 }
47}
48
49async fn get_public_keys() -> Result<JwkKeys, PublicKeysError> {
50 let response = reqwest::get(JWK_URL)
51 .await
52 .map_err(PublicKeysError::CouldntFetchPublicKeys)?;
53
54 let cache_control = match response.headers().get("Cache-Control") {
55 Some(header_value) => header_value.to_str(),
56 None => return Err(PublicKeysError::NoCacheControlHeader),
57 };
58
59 let max_age = match cache_control {
60 Ok(v) => parse_max_age_value(v),
61 Err(_) => return Err(PublicKeysError::MaxAgeValueEmpty),
62 };
63
64 let public_keys = response
65 .json::<KeyResponse>()
66 .await
67 .map_err(|e| {
68 PublicKeysError::CannotParsePublicKey(e)
69 })?;
70
71 Ok(JwkKeys {
72 keys: public_keys.keys,
73 max_age: max_age.unwrap_or(FALLBACK_TIMEOUT),
74 })
75}
76
/// Reasons an ID token can be rejected by the verifier.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum VerificationError {
    /// The token's signature or header could not be validated.
    InvalidSignature,
    /// The token header names an algorithm other than RS256.
    InvalidKeyAlgorithm,
    /// The token is malformed or its contents failed validation.
    InvalidToken,
    /// The token header carries no `kid` (key id).
    NoKidHeader,
    /// None of the currently cached public keys has the token's `kid`.
    NotfoundMatchKid,
    /// The matching public key's RSA components could not be turned into a
    /// decoding key.
    CannotDecodePublicKeys,
}

impl std::fmt::Display for VerificationError {
    /// Renders the variant name (same text as `Debug`).
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "{:?}", self)
    }
}

// Lets callers box this as `dyn Error` or propagate it through `?` into
// error-erasing result types.
impl std::error::Error for VerificationError {}
92
93fn extract_claims_from_unsigned_token<T: DeserializeOwned>(token: &str) -> Result<T, VerificationError> {
94 let parts: Vec<&str> = token.split('.').collect();
95 if parts.len() != 3 {
96 return Err(VerificationError::InvalidToken);
97 }
98 let decoded_payload = BASE64_STANDARD_NO_PAD.decode(parts[1].trim()).unwrap();
99 let claims: T = serde_json::from_slice(&decoded_payload).map_err(|_| VerificationError::InvalidToken)?;
100 Ok(claims)
101}
102
103fn verify_id_token_with_project_id<T: DeserializeOwned>(
104 config: &JwkConfiguration,
105 public_keys: &JwkKeys,
106 token: &str,
107) -> Result<T, VerificationError> {
108 if env::var("FIREBASE_AUTH_EMULATOR_HOST").is_ok() {
109 return extract_claims_from_unsigned_token(token);
110 }
111
112 let header = decode_header(token).map_err(|_| VerificationError::InvalidSignature)?;
113
114 if header.alg != Algorithm::RS256 {
115 return Err(VerificationError::InvalidKeyAlgorithm);
116 }
117
118 let kid = match header.kid {
119 Some(v) => v,
120 None => return Err(VerificationError::NoKidHeader),
121 };
122
123 let public_key = match public_keys.keys.iter().find(|v| v.kid == kid) {
124 Some(v) => v,
125 None => return Err(VerificationError::NotfoundMatchKid),
126 };
127
128 let decoding_key = DecodingKey::from_rsa_components(&public_key.n, &public_key.e)
129 .map_err(|_| VerificationError::CannotDecodePublicKeys)?;
130
131 let mut validation = Validation::new(Algorithm::RS256);
132 validation.set_audience(&[config.audience.to_owned()]);
133 validation.set_issuer(&[config.issuer.to_owned()]);
134
135 let user = decode::<T>(token, &decoding_key, &validation)
136 .map_err(|_| VerificationError::InvalidToken)?
137 .claims;
138 Ok(user)
139}
140
141#[derive(Debug)]
142struct JwkVerifier {
143 keys: JwkKeys,
144 config: JwkConfiguration,
145}
146
147impl JwkVerifier {
148 fn new(project_id: &str, keys: JwkKeys) -> JwkVerifier {
149 JwkVerifier {
150 keys,
151 config: get_configuration(project_id),
152 }
153 }
154
155 fn verify<T: DeserializeOwned>(&self, token: &str) -> Result<T, VerificationError> {
156 verify_id_token_with_project_id(&self.config, &self.keys, token)
157 }
158
159 fn set_keys(&mut self, keys: JwkKeys) {
160 self.keys = keys;
161 }
162}
163
164#[derive(Clone)]
168pub struct FirebaseAuth {
169 verifier: Arc<RwLock<JwkVerifier>>,
170 handler: Arc<Mutex<Box<JoinHandle<()>>>>,
171}
172
173impl Drop for FirebaseAuth {
174 fn drop(&mut self) {
175 let handler = self.handler.lock().unwrap();
177 handler.abort();
178 }
179}
180
181impl FirebaseAuth {
182 pub async fn new(project_id: &str) -> FirebaseAuth {
183 let jwk_keys: JwkKeys = match get_public_keys().await {
184 Ok(keys) => keys,
185 Err(e) => {
186 eprintln!("Error getting public jwk keys {:?}", e);
187 panic!("Unable to get public jwk keys! Cannot verify user tokens! Shutting down...")
188 }
189 };
190 let verifier = Arc::new(RwLock::new(JwkVerifier::new(project_id, jwk_keys)));
191
192 let mut instance = FirebaseAuth {
193 verifier,
194 handler: Arc::new(Mutex::new(Box::new(tokio::spawn(async {})))),
195 };
196
197 instance.start_key_update();
198 instance
199 }
200
201 pub fn verify<T: DeserializeOwned>(&self, token: &str) -> Result<T, VerificationError> {
202 let verifier = self.verifier.read().unwrap();
203 verifier.verify(token)
204 }
205
206 fn start_key_update(&mut self) {
207 let verifier_ref = Arc::clone(&self.verifier);
208
209 let task = tokio::spawn(async move {
210 loop {
211 let delay = match get_public_keys().await {
212 Ok(jwk_keys) => {
213 let mut verifier = verifier_ref.write().unwrap();
214 verifier.set_keys(jwk_keys.clone());
215 debug!(
216 "Updated JWK keys. Next refresh will be in {:?}",
217 jwk_keys.max_age
218 );
219 jwk_keys.max_age
220 }
221 Err(err) => {
222 warn!("Error getting public jwk keys {:?}", err);
223 warn!("Re-try getting public keys in 10 seconds");
224 Duration::from_secs(10)
225 }
226 };
227 sleep(delay).await;
228 }
229 });
230
231 let mut handler = self.handler.lock().unwrap();
232 *handler = Box::new(task);
233 }
234}