// rs_firebase_admin_sdk/jwt/mod.rs

1use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
2use core::future::Future;
3use core::pin::Pin;
4use error_stack::{Report, ResultExt};
5use jsonwebtoken::{DecodingKey, Validation, decode, decode_header};
6use jsonwebtoken_jwks_cache::{CachedJWKS, TimeoutSpec};
7use serde_json::{Value, from_slice};
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Duration;
11use thiserror::Error;
12
/// JWKS document used by `LiveValidator::new_jwt_validator` to look up signing keys.
const GOOGLE_JWKS_URI: &str = concat!(
    "https://www.googleapis.com/service_accounts/v1/jwk/",
    "securetoken@system.gserviceaccount.com"
);
/// RSA public-key document used by `LiveValidator::new_cookie_validator`.
const GOOGLE_PKEYS_URI: &str = concat!(
    "https://www.googleapis.com/identitytoolkit/v3/relyingparty/",
    "publicKeys"
);
/// Expected `iss` prefix for ID tokens; the project id is appended to it.
const GOOGLE_ID_TOKEN_ISSUER_PREFIX: &str = "https://securetoken.google.com/";
/// Expected `iss` prefix for session cookies; the project id is appended to it.
const GOOGLE_COOKIE_ISSUER_PREFIX: &str = "https://session.firebase.google.com/";
19
20#[derive(Error, Debug, Clone)]
21pub enum TokenVerificationError {
22    #[error("Token's key is missing")]
23    MissingKey,
24    #[error("Invalid token")]
25    Invalid,
26    #[error("Unexpected error")]
27    Internal,
28}
29
30pub type ClaimsResult = Pin<
31    Box<dyn Future<Output = Result<HashMap<String, Value>, Report<TokenVerificationError>>> + Send>,
32>;
33
34pub trait TokenValidator {
35    /// Validate JWT returning all claims on success
36    fn validate(self: Arc<Self>, token: String) -> ClaimsResult;
37}
38
39#[derive(Clone)]
40pub struct LiveValidator {
41    project_id: String,
42    issuer: String,
43    jwks: CachedJWKS,
44}
45
46impl LiveValidator {
47    pub fn new_jwt_validator(project_id: String) -> Result<Self, reqwest::Error> {
48        Ok(Self {
49            issuer: format!("{GOOGLE_ID_TOKEN_ISSUER_PREFIX}{project_id}"),
50            project_id,
51            jwks: CachedJWKS::new(
52                // should always succeed
53                GOOGLE_JWKS_URI.parse().unwrap(),
54                Duration::from_secs(60),
55                TimeoutSpec::default(),
56            )?,
57        })
58    }
59
60    pub fn new_cookie_validator(project_id: String) -> Result<Self, reqwest::Error> {
61        Ok(Self {
62            issuer: format!("{GOOGLE_COOKIE_ISSUER_PREFIX}{project_id}"),
63            project_id,
64            jwks: CachedJWKS::new_rsa_pkeys(
65                // should always succeed
66                GOOGLE_PKEYS_URI.parse().unwrap(),
67                Duration::from_secs(60),
68                TimeoutSpec::default(),
69            )?,
70        })
71    }
72}
73
74impl TokenValidator for LiveValidator {
75    fn validate(self: Arc<Self>, token: String) -> ClaimsResult {
76        Box::pin(async move {
77            let jwks = self
78                .jwks
79                .get()
80                .await
81                .change_context(TokenVerificationError::Internal)?;
82            let jwt_header =
83                decode_header(&token).change_context(TokenVerificationError::Invalid)?;
84
85            let jwk: DecodingKey = jwks
86                .find(&jwt_header.kid.ok_or(TokenVerificationError::MissingKey)?)
87                .ok_or(TokenVerificationError::MissingKey)?
88                .try_into()
89                .change_context(TokenVerificationError::Internal)?;
90
91            let mut validator = Validation::new(jwt_header.alg);
92            validator.set_audience(&[&self.project_id]);
93            validator.set_issuer(&[&self.issuer]);
94
95            decode::<HashMap<String, Value>>(&token, &jwk, &validator)
96                .change_context(TokenVerificationError::Invalid)
97                .map(|t| t.claims)
98        })
99    }
100}
101
/// Validator intended for tokens issued by the Firebase Auth emulator.
///
/// It performs no signature or claim verification and only decodes the token
/// payload, so it must not be used for tokens from a real project.
#[derive(Debug, Clone, Copy, Default)]
pub struct EmulatorValidator;
104
105impl TokenValidator for EmulatorValidator {
106    fn validate(self: Arc<Self>, token: String) -> ClaimsResult {
107        Box::pin(async move {
108            let header = token
109                .split(".")
110                .nth(1)
111                .ok_or(TokenVerificationError::Invalid)?;
112
113            let header = URL_SAFE_NO_PAD
114                .decode(header)
115                .change_context(TokenVerificationError::Invalid)?;
116
117            from_slice(&header).change_context(TokenVerificationError::Invalid)
118        })
119    }
120}