jsonwebtoken-google-wasm 0.0.8

Parse and validate Google JWT tokens with jsonwebtoken, compatible with WebAssembly runtimes.
Documentation
/// Based upon [avkviring/jsonwebtoken-google](https://github.com/avkviring/jsonwebtoken-google), which can't target WASM because it depends on the `rand` crate.
use async_trait::async_trait;
use base64::DecodeError;
use headers::Header;
use jwt_simple::prelude::*;

use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use reqwest::header::{HeaderMap, CACHE_CONTROL};
use serde::Deserialize;
use std::collections::HashMap;
use std::time::SystemTime;
use thiserror::Error;

/// Source of RS256 public keys, looked up by key id (`kid`).
///
/// `#[async_trait(?Send)]` means the futures returned by `get_key` are not
/// required to be `Send`, while implementors themselves must still be `Send`
/// (the supertrait bound).
#[async_trait(?Send)]
pub trait KeyProvider: Send {
    /// Returns the public key whose id is `kid`.
    ///
    /// Takes `&mut self` so implementations can update internal state (such as
    /// a cached key set) during a lookup. `now` is supplied by the caller
    /// instead of being read from the clock inside the provider — presumably so
    /// the provider works where `SystemTime::now()` is unavailable (some WASM
    /// targets) and so tests can control time; NOTE(review): confirm intent.
    async fn get_key(
        &mut self,
        kid: &str,
        now: SystemTime,
    ) -> Result<RS256PublicKey, ProviderError>;
}

#[derive(Error, Debug)]
pub enum ProviderError {
    #[error("Key not found.")]
    KeyNotFound,
    #[error("Fetch error - {0}.")]
    FetchError(String),
    #[error("Parse error - {0}.")]
    ParseError(String),
    #[error("Unknown error.")]
    UnknownError,
    #[error("Decode error - {0}.")]
    DecodingError(DecodeError),
    #[error("Key creation error - {0}.")]
    CreateKeyError(jwt_simple::Error),
}

/// Key set document served at the public key URL: a JSON object of the form
/// `{"keys": [ ... ]}`.
///
/// Derives `Debug` for consistency with [`GoogleKey`], so a downloaded key set
/// can be logged or inspected in tests.
#[derive(Deserialize, Clone, Debug)]
pub struct GoogleKeys {
    /// Every key listed in the document.
    keys: Vec<GoogleKey>,
}

/// A single RSA key entry from the key set.
///
/// Only the fields needed to rebuild the public key are kept; any other JSON
/// fields in the entry (e.g. `kty`, `alg`, `use`) are ignored during
/// deserialization.
#[derive(Deserialize, Clone, Debug)]
pub struct GoogleKey {
    /// Key identifier, matched against the `kid` a caller asks for.
    kid: String,
    /// RSA modulus, base64url-encoded without padding.
    n: String,
    /// RSA public exponent, base64url-encoded without padding.
    e: String,
}

impl GoogleKey {
    /// Decodes the RSA modulus into raw bytes.
    ///
    /// # Errors
    ///
    /// [`ProviderError::DecodingError`] if `n` is not valid unpadded base64url.
    fn n(&self) -> Result<Vec<u8>, ProviderError> {
        Self::decode_component(&self.n)
    }

    /// Decodes the RSA public exponent into raw bytes.
    ///
    /// # Errors
    ///
    /// [`ProviderError::DecodingError`] if `e` is not valid unpadded base64url.
    fn e(&self) -> Result<Vec<u8>, ProviderError> {
        Self::decode_component(&self.e)
    }

    /// Decodes one URL-safe, unpadded base64 key component.
    ///
    /// Borrows the input: `Engine::decode` accepts any `AsRef<[u8]>`, so there
    /// is no need to clone the stored `String` first.
    fn decode_component(value: &str) -> Result<Vec<u8>, ProviderError> {
        URL_SAFE_NO_PAD
            .decode(value)
            .map_err(ProviderError::DecodingError)
    }
}

/// [`KeyProvider`] that downloads a key set from a URL and caches it until the
/// `Cache-Control: max-age` of the last successful response has elapsed.
#[derive(Debug)]
pub struct GooglePublicKeyProvider {
    /// URL the key set is fetched from.
    url: String,
    /// Cached keys, indexed by their `kid`.
    keys: HashMap<String, GoogleKey>,
    /// Instant after which the cache is stale; `None` means "reload on next use".
    expiration_time: Option<SystemTime>,
}

impl GooglePublicKeyProvider {
    /// Creates a provider for `public_key_url` with an empty cache.
    ///
    /// Nothing is fetched here; because no expiry is known yet, the first
    /// lookup triggers a download.
    pub fn new(public_key_url: &str) -> Self {
        Self {
            url: public_key_url.to_owned(),
            keys: HashMap::new(),
            expiration_time: None,
        }
    }

    /// Fetches the key set from the configured URL and replaces the cache.
    ///
    /// `now` is the caller-supplied current time; the new expiry is
    /// `now + max-age` taken from the response's `Cache-Control` header.
    ///
    /// On error, the previously cached keys and expiry are left untouched.
    ///
    /// # Errors
    ///
    /// * [`ProviderError::FetchError`] if the request fails or the server
    ///   answers with a 4xx/5xx status.
    /// * [`ProviderError::ParseError`] if the body is not a valid key set.
    pub async fn reload(&mut self, now: SystemTime) -> Result<(), ProviderError> {
        let response = reqwest::get(&self.url)
            .await
            .map_err(|e| ProviderError::FetchError(format!("{:?}", e)))?;

        // `reqwest` reports HTTP error statuses as successful responses.
        // Without this check, an error page (e.g. a 503) would surface as a
        // confusing ParseError from the JSON step below.
        let status = response.status();
        if status.is_client_error() || status.is_server_error() {
            return Err(ProviderError::FetchError(format!(
                "unexpected HTTP status {}",
                status
            )));
        }

        // Headers must be read before `json()` consumes the response.
        let expiration_time = Self::parse_expiration_time(response.headers(), now);
        let google_keys = response
            .json::<GoogleKeys>()
            .await
            .map_err(|e| ProviderError::ParseError(format!("{:?}", e)))?;

        // Later entries with a duplicate `kid` win, as with repeated inserts.
        self.keys = google_keys
            .keys
            .into_iter()
            .map(|key| (key.kid.clone(), key))
            .collect();
        self.expiration_time = expiration_time;
        Ok(())
    }

    /// Derives the cache expiry from the `Cache-Control` header(s).
    ///
    /// Returns `None` in these cases:
    /// * the header is absent, unparsable, or has no `max-age`;
    /// * `now + max-age` is not representable as a `SystemTime`.
    fn parse_expiration_time(header_map: &HeaderMap, now: SystemTime) -> Option<SystemTime> {
        headers::CacheControl::decode(&mut header_map.get_all(CACHE_CONTROL).iter())
            .ok()
            .and_then(|cache_control| cache_control.max_age())
            // `SystemTime + Duration` panics on overflow, and the server controls
            // max-age; treat an unrepresentable expiry as "no usable expiry".
            .and_then(|max_age| now.checked_add(max_age))
    }

    /// Returns `true` when the cache must be refreshed at `now`.
    ///
    /// That is the case when no expiry is known (never loaded, or the last
    /// response had no usable `max-age`), or when `now` is strictly past it.
    pub fn is_expire(&self, now: SystemTime) -> bool {
        self.expiration_time
            .map_or(true, |expires_at| now > expires_at)
    }
}

#[async_trait(?Send)]
impl KeyProvider for GooglePublicKeyProvider {
    /// Returns the RS256 public key for `kid`, first refreshing the cached key
    /// set if it is expired at `now`.
    ///
    /// # Errors
    ///
    /// * Any error from [`GooglePublicKeyProvider::reload`] is propagated.
    /// * [`ProviderError::KeyNotFound`] when `kid` is not in the (possibly just
    ///   refreshed) key set.
    /// * [`ProviderError::DecodingError`] or [`ProviderError::CreateKeyError`]
    ///   when the stored key components are invalid.
    async fn get_key(
        &mut self,
        kid: &str,
        now: SystemTime,
    ) -> Result<RS256PublicKey, ProviderError> {
        if self.is_expire(now) {
            self.reload(now).await?;
        }

        // `HashMap<String, _>` can be queried with `&str` directly (via
        // `Borrow<str>`), so no `String` needs to be allocated per lookup.
        let key_data = self.keys.get(kid).ok_or(ProviderError::KeyNotFound)?;

        RS256PublicKey::from_components(&key_data.n()?, &key_data.e()?)
            .map_err(ProviderError::CreateKeyError)
    }
}
#[cfg(test)]
mod tests {
    use crate::keys::{GooglePublicKeyProvider, KeyProvider, ProviderError};
    use httpmock::MockServer;
    use std::time::{Duration, SystemTime};

    /// RSA modulus (base64url, unpadded) that `RS256PublicKey::from_components`
    /// accepts together with [`E`].
    const N: &str = "3g46w4uRYBx8CXFauWh6c5yO4ax_VDu5y8ml_Jd4Gx711155PTdtLeRuwZOhJ6nRy8YvLFPXc_aXtHifnQsi9YuI_vo7LGG2v3CCxh6ndZBjIeFkxErMDg4ELt2DQ0PgJUQUAKCkl2_gkVV9vh3oxahv_BpIgv1kuYlyQQi5JWeF7zAIm0FaZ-LJT27NbsCugcZIDQg9sztTN18L3-P_kYwvAkKY2bGYNU19qLFM1gZkzccFEDZv3LzAz7qbdWkwCoK00TUUH8TNjqmK67bytYzgEgkfF9q9szEQ5TrRL0uFg9LxT3kSTLYqYOVaUIX3uaChwaa-bQvHuNmryu7i9w";
    /// RSA public exponent (base64url, unpadded).
    const E: &str = "AQAB";
    /// Key id used by the fixture key set.
    const KID: &str = "some-kid";

    /// Builds a key-set JSON body containing a single RSA key with id `kid`.
    fn key_set_body(kid: &str) -> String {
        format!("{{\"keys\": [{{\"kty\": \"RSA\",\"use\": \"sig\",\"e\": \"{}\",\"n\": \"{}\",\"alg\": \"RS256\",\"kid\": \"{}\"}}]}}", E, N, kid)
    }

    #[tokio::test]
    async fn should_parse_keys() {
        let server = MockServer::start();
        let server_mock = server.mock(|when, then| {
            when.method(httpmock::Method::GET).path("/");

            then.status(200)
                .header(
                    "cache-control",
                    "public, max-age=24920, must-revalidate, no-transform",
                )
                .header("Content-Type", "application/json; charset=UTF-8")
                .body(key_set_body(KID));
        });
        let mut provider = GooglePublicKeyProvider::new(server.url("/").as_str());
        let now = SystemTime::now();

        assert!(matches!(provider.get_key(KID, now).await, Ok(_)));
        assert!(matches!(
            provider.get_key("missing-key", now).await,
            Err(ProviderError::KeyNotFound)
        ));
        // The second lookup happens well within max-age, so it must be served
        // from the cache rather than triggering another request.
        server_mock.assert_hits(1);
    }

    #[tokio::test]
    async fn should_expire_and_reload() {
        let server = MockServer::start();
        let mut empty_mock = server.mock(|when, then| {
            when.method(httpmock::Method::GET).path("/");
            then.status(200)
                .header(
                    "cache-control",
                    "public, max-age=3, must-revalidate, no-transform",
                )
                .header("Content-Type", "application/json; charset=UTF-8")
                .body("{\"keys\":[]}");
        });

        let mut provider = GooglePublicKeyProvider::new(server.url("/").as_str());
        let start = SystemTime::now();
        assert!(matches!(
            provider.get_key(KID, start).await,
            Err(ProviderError::KeyNotFound)
        ));

        empty_mock.delete();
        let _populated_mock = server.mock(|when, then| {
            when.method(httpmock::Method::GET).path("/");
            then.status(200)
                .header(
                    "cache-control",
                    "public, max-age=3, must-revalidate, no-transform",
                )
                .header("Content-Type", "application/json; charset=UTF-8")
                .body(key_set_body(KID));
        });

        // `now` is injected, so simulate the 3 s max-age elapsing instead of
        // blocking the async runtime with `std::thread::sleep`.
        let later = start + Duration::from_secs(4);
        assert!(matches!(provider.get_key(KID, later).await, Ok(_)));
    }
}