1use base64::prelude::BASE64_STANDARD;
6use base64::Engine;
7use chrono::{offset, DateTime, Duration};
8use serde::{Deserialize, Serialize};
9use serde_json;
10use std::collections::BTreeMap;
11use std::fmt;
12use std::fs::File;
13use std::io::BufReader;
14use std::sync::Arc;
15use tokio::sync::RwLock;
16
17use super::jwt::{create_jwt_encoded, download_google_jwks, verify_access_token, JWKSet, JWT_AUDIENCE_IDENTITY};
18use crate::{errors::FirebaseError, jwt::TokenValidationResult};
19
/// Error type returned by every fallible function in this module (an alias for `FirebaseError`).
type Error = super::errors::FirebaseError;
21
22#[derive(Default, Clone)]
24pub(crate) struct Keys {
25 pub pub_key: BTreeMap<String, Arc<biscuit::jws::Secret>>,
26 pub pub_key_expires_at: Option<DateTime<offset::Utc>>,
27 pub secret: Option<Arc<biscuit::jws::Secret>>,
28}
29
30impl fmt::Debug for Keys {
31 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32 f.debug_struct("Keys")
33 .field("pub_key_expires_at", &self.pub_key_expires_at)
34 .field("pub_key", &self.pub_key.keys().collect::<Vec<&String>>())
35 .field("secret", &self.secret.is_some())
36 .finish()
37 }
38}
39
40#[derive(Serialize, Deserialize, Default, Clone, Debug)]
52pub struct Credentials {
53 pub project_id: String,
54 pub private_key_id: String,
55 pub private_key: String,
56 pub client_email: String,
57 pub client_id: String,
58 pub api_key: String,
59 #[serde(default, skip)]
63 pub(crate) keys: Arc<RwLock<Keys>>,
64}
65
66pub fn pem_to_der(pem_file_contents: &str) -> Result<Vec<u8>, Error> {
68 let pem_file_contents = pem_file_contents
69 .find("-----BEGIN")
70 .and_then(|i| Some(&pem_file_contents[i + 10..]))
72 .and_then(|str| str.find("-----").and_then(|i| Some(&str[i + 5..])))
74 .and_then(|str| str.rfind("-----END").and_then(|i| Some(&str[..i])));
76 if pem_file_contents.is_none() {
77 return Err(FirebaseError::Generic(
78 "Invalid private key in credentials file. Must be valid PEM.",
79 ));
80 }
81
82 let base64_body = pem_file_contents.unwrap().replace("\n", "");
83 Ok(BASE64_STANDARD
84 .decode(&base64_body)
85 .map_err(|_| FirebaseError::Generic("Invalid private key in credentials file. Expected Base64 data."))?)
86}
87
/// Decodes a short PKCS#8 PEM fixture and compares the result byte-for-byte with the expected DER.
#[test]
fn pem_to_der_test() {
    const INPUT: &str = r#"-----BEGIN PRIVATE KEY-----
MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCTbt9Rs2niyIRE
FIdrhIN757eq/1Ry/VhZALBXAveg+lt+ui/9EHtYPJH1A9NyyAwChs0UCRWqkkEo
Amtz4dJQ1YlGi0/BGhK2lg==
-----END PRIVATE KEY-----
"#;
    const EXPECTED: [u8; 112] = [
        48, 130, 4, 188, 2, 1, 0, 48, 13, 6, 9, 42, 134, 72, 134, 247, 13, 1, 1, 1, 5, 0, 4, 130, 4, 166, 48, 130, 4,
        162, 2, 1, 0, 2, 130, 1, 1, 0, 147, 110, 223, 81, 179, 105, 226, 200, 132, 68, 20, 135, 107, 132, 131, 123,
        231, 183, 170, 255, 84, 114, 253, 88, 89, 0, 176, 87, 2, 247, 160, 250, 91, 126, 186, 47, 253, 16, 123, 88, 60,
        145, 245, 3, 211, 114, 200, 12, 2, 134, 205, 20, 9, 21, 170, 146, 65, 40, 2, 107, 115, 225, 210, 80, 213, 137,
        70, 139, 79, 193, 26, 18, 182, 150,
    ];

    let der = pem_to_der(INPUT).unwrap();
    assert_eq!(&EXPECTED[..], der.as_slice());
}
106
107impl Credentials {
108 pub async fn new(credentials_file_content: &str) -> Result<Credentials, Error> {
130 let mut credentials: Credentials = serde_json::from_str(credentials_file_content)?;
131 credentials.compute_secret().await?;
132 Ok(credentials)
133 }
134
135 pub async fn from_file(credential_file: &str) -> Result<Self, Error> {
140 let f = BufReader::new(File::open(credential_file)?);
141 let mut credentials: Credentials = serde_json::from_reader(f)?;
142 credentials.compute_secret().await?;
143 Ok(credentials)
144 }
145
146 pub async fn with_jwkset(self, jwks: &JWKSet) -> Result<Credentials, Error> {
151 self.add_jwks_public_keys(jwks).await;
152 self.verify().await?;
153 Ok(self)
154 }
155
156 pub async fn download_jwkset(self) -> Result<Credentials, Error> {
176 self.download_google_jwks().await?;
177 self.verify().await?;
178 Ok(self)
179 }
180
181 pub async fn verify(&self) -> Result<(), Error> {
184 let access_token = create_jwt_encoded(
185 &self,
186 Some(["admin"].iter()),
187 Duration::hours(1),
188 Some(self.client_id.clone()),
189 None,
190 JWT_AUDIENCE_IDENTITY,
191 )
192 .await?;
193 verify_access_token(&self, &access_token).await?;
194 Ok(())
195 }
196
197 pub async fn verify_token(&self, token: &str) -> Result<TokenValidationResult, Error> {
198 verify_access_token(&self, token).await
199 }
200
201 pub async fn decode_secret(&self, kid: &str) -> Result<Option<Arc<biscuit::jws::Secret>>, Error> {
204 let should_refresh = {
205 let keys = self.keys.read().await;
206 keys.pub_key_expires_at
207 .map(|expires_at| expires_at - offset::Utc::now() < Duration::minutes(10))
208 .unwrap_or(false)
209 };
210
211 if should_refresh {
212 self.download_google_jwks().await?;
213 }
214
215 Ok(self.keys.read().await.pub_key.get(kid).and_then(|f| Some(f.clone())))
216 }
217
218 pub async fn add_jwks_public_keys(&self, jwkset: &JWKSet) {
234 let key_lock = self.keys.write();
235 let keys = &mut key_lock.await.pub_key;
236
237 for entry in jwkset.keys.iter() {
238 if !entry.headers.key_id.is_some() {
239 continue;
240 }
241
242 let key_id = entry.headers.key_id.as_ref().unwrap().to_owned();
243 keys.insert(key_id, Arc::new(entry.ne.jws_public_key_secret()));
244 }
245 }
246
247 pub async fn download_google_jwks(&self) -> Result<(), Error> {
251 {
252 let mut keys = self.keys.write().await;
253 keys.pub_key = BTreeMap::new();
254 }
255
256 let (jwks, max_age_client) = download_google_jwks(&self.client_email).await?;
257 self.add_jwks_public_keys(&JWKSet::new(&jwks)?).await;
258 let (jwks, max_age_public) = download_google_jwks("securetoken@system.gserviceaccount.com").await?;
259 self.add_jwks_public_keys(&JWKSet::new(&jwks)?).await;
260
261 let default_expiration = Duration::hours(2);
262 let max_age_client = max_age_client.unwrap_or(default_expiration);
263 let max_age_public = max_age_public.unwrap_or(default_expiration);
264
265 let expires_at = if max_age_client < max_age_public {
266 max_age_client
267 } else {
268 max_age_public
269 };
270
271 {
272 let mut keys = self.keys.write().await;
273 keys.pub_key_expires_at = Some(offset::Utc::now() + expires_at);
274 }
275
276 Ok(())
277 }
278
279 pub async fn compute_secret(&mut self) -> Result<(), Error> {
284 use biscuit::jws::Secret;
285 use ring::signature;
286
287 let vec = pem_to_der(&self.private_key)?;
288 let key_pair = signature::RsaKeyPair::from_pkcs8(&vec)?;
289 self.keys.write().await.secret = Some(Arc::new(Secret::RsaKeyPair(Arc::new(key_pair))));
290 Ok(())
291 }
292}
293
294#[doc(hidden)]
295#[allow(dead_code)]
296pub async fn doctest_credentials() -> Credentials {
297 let jwk_list = JWKSet::new(include_str!("../tests/service-account-test.jwks")).unwrap();
298 Credentials::new(include_str!("../tests/service-account-test.json"))
299 .await
300 .expect("Failed to deserialize credentials")
301 .with_jwkset(&jwk_list)
302 .await
303 .expect("JWK public keys verification failed")
304}
305
306#[tokio::test]
307async fn deserialize_credentials() {
308 let jwk_list = JWKSet::new(include_str!("../tests/service-account-test.jwks")).unwrap();
309 let c: Credentials = Credentials::new(include_str!("../tests/service-account-test.json"))
310 .await
311 .expect("Failed to deserialize credentials")
312 .with_jwkset(&jwk_list)
313 .await
314 .expect("JWK public keys verification failed");
315 assert_eq!(c.api_key, "api_key");
316
317 use std::path::PathBuf;
318 let mut credential_file = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
319 credential_file.push("tests/service-account-test.json");
320
321 let c = Credentials::from_file(credential_file.to_str().unwrap())
322 .await
323 .expect("Failed to open credentials file")
324 .with_jwkset(&jwk_list)
325 .await
326 .expect("JWK public keys verification failed");
327 assert_eq!(c.api_key, "api_key");
328}