rs_firebase_admin_sdk/jwt/
mod.rs

1use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
2use core::future::Future;
3use error_stack::{Report, ResultExt};
4use jsonwebtoken::{DecodingKey, Validation, decode, decode_header};
5use jsonwebtoken_jwks_cache::{CachedJWKS, TimeoutSpec};
6use serde_json::{Value, from_slice};
7use std::collections::HashMap;
8use std::time::Duration;
9use thiserror::Error;
10
/// Key-set endpoint used by `LiveValidator::new_jwt_validator`.
const GOOGLE_JWKS_URI: &str =
    "https://www.googleapis.com/service_accounts/v1/jwk/securetoken@system.gserviceaccount.com";
/// Public-key endpoint used by `LiveValidator::new_cookie_validator`.
const GOOGLE_PKEYS_URI: &str =
    "https://www.googleapis.com/identitytoolkit/v3/relyingparty/publicKeys";
/// Issuer prefix for ID tokens; the project ID is appended to form the expected issuer.
const GOOGLE_ID_TOKEN_ISSUER_PREFIX: &str = "https://securetoken.google.com/";
/// Issuer prefix for session cookies; the project ID is appended to form the expected issuer.
const GOOGLE_COOKIE_ISSUER_PREFIX: &str = "https://session.firebase.google.com/";
17
18#[derive(Error, Debug, Clone)]
19pub enum TokenVerificationError {
20    #[error("Token's key is missing")]
21    MissingKey,
22    #[error("Invalid token")]
23    Invalid,
24    #[error("Unexpected error")]
25    Internal,
26}
27
/// Common interface implemented by the JWT validators in this module.
pub trait TokenValidator {
    /// Validate JWT returning all claims on success
    ///
    /// `token` is the encoded JWT string. The returned future resolves to the token's
    /// claims keyed by claim name, or to a [`TokenVerificationError`] report on failure.
    /// Implementations must return a future that is both `Send` and `Sync`.
    fn validate(
        &self,
        token: &str,
    ) -> impl Future<Output = Result<HashMap<String, Value>, Report<TokenVerificationError>>> + Send + Sync;
}
35
/// Validator that checks tokens against a cached set of Google-published keys.
pub struct LiveValidator {
    /// Project ID; used as the expected audience when validating a token.
    project_id: String,
    /// Expected issuer: one of the issuer prefixes followed by the project ID.
    issuer: String,
    /// Cached key set in which the token's key is looked up by its key ID.
    jwks: CachedJWKS,
}
41
42impl LiveValidator {
43    pub fn new_jwt_validator(project_id: String) -> Result<Self, reqwest::Error> {
44        Ok(Self {
45            issuer: format!("{GOOGLE_ID_TOKEN_ISSUER_PREFIX}{project_id}"),
46            project_id,
47            jwks: CachedJWKS::new(
48                // should always succeed
49                GOOGLE_JWKS_URI.parse().unwrap(),
50                Duration::from_secs(60),
51                TimeoutSpec::default(),
52            )?,
53        })
54    }
55
56    pub fn new_cookie_validator(project_id: String) -> Result<Self, reqwest::Error> {
57        Ok(Self {
58            issuer: format!("{GOOGLE_COOKIE_ISSUER_PREFIX}{project_id}"),
59            project_id,
60            jwks: CachedJWKS::new_rsa_pkeys(
61                // should always succeed
62                GOOGLE_PKEYS_URI.parse().unwrap(),
63                Duration::from_secs(60),
64                TimeoutSpec::default(),
65            )?,
66        })
67    }
68}
69
70impl TokenValidator for LiveValidator {
71    async fn validate(
72        &self,
73        token: &str,
74    ) -> Result<HashMap<String, Value>, Report<TokenVerificationError>> {
75        let jwks = self
76            .jwks
77            .get()
78            .await
79            .change_context(TokenVerificationError::Internal)?;
80        let jwt_header = decode_header(token).change_context(TokenVerificationError::Invalid)?;
81
82        let jwk: DecodingKey = jwks
83            .find(&jwt_header.kid.ok_or(TokenVerificationError::MissingKey)?)
84            .ok_or(TokenVerificationError::MissingKey)?
85            .try_into()
86            .change_context(TokenVerificationError::Internal)?;
87
88        let mut validator = Validation::new(jwt_header.alg);
89        validator.set_audience(&[&self.project_id]);
90        validator.set_issuer(&[&self.issuer]);
91
92        decode::<HashMap<String, Value>>(token, &jwk, &validator)
93            .change_context(TokenVerificationError::Invalid)
94            .map(|t| t.claims)
95    }
96}
97
/// Validator that only decodes a token's claims.
///
/// NOTE: its `TokenValidator` implementation performs no signature, issuer, audience
/// or expiry checks. Judging by the name it is meant for use against an emulator;
/// confirm before using it anywhere tokens must be trusted.
#[derive(Default)]
pub struct EmulatorValidator;
100
101impl TokenValidator for EmulatorValidator {
102    async fn validate(
103        &self,
104        token: &str,
105    ) -> Result<HashMap<String, Value>, Report<TokenVerificationError>> {
106        let header = token
107            .split(".")
108            .nth(1)
109            .ok_or(TokenVerificationError::Invalid)?;
110
111        let header = URL_SAFE_NO_PAD
112            .decode(header)
113            .change_context(TokenVerificationError::Invalid)?;
114
115        from_slice(&header).change_context(TokenVerificationError::Invalid)
116    }
117}