1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
use crate::keys::{GooglePublicKeyProvider, KeyProvider, ProviderError};
/// Based upon [avkviring/jsonwebtoken-google](https://github.com/avkviring/jsonwebtoken-google) which can't target WASM due to a `rand` crate dependency.
use jsonwebtoken::{Algorithm, Validation};
use serde::de::DeserializeOwned;
use thiserror::Error;

mod keys;
#[cfg(any(test, feature = "test-helper"))]
pub mod test_helper;

/// Parses and validates JWTs for a single client id.
///
/// Construct one with [`Parser::new`] or [`Parser::new_with_custom_cert_url`],
/// then call [`Parser::parse`] for each token.
pub struct Parser {
    /// Client id that `parse` sets as the only accepted audience (`aud`).
    client_id: String,
    /// Source of the decoding key, looked up by the token header's `kid`.
    key_provider: Box<dyn KeyProvider>,
}

/// Errors returned by [`Parser::parse`].
#[derive(Error, Debug)]
pub enum ParserError {
    /// The token header could not be decoded (`jsonwebtoken::decode_header` failed).
    #[error("Wrong header.")]
    WrongHeader,
    /// The key provider failed to supply a key; wraps the provider's error.
    /// Note that the display message does not include the wrapped error.
    #[error("Wrong provider.")]
    KeyProvider(ProviderError),
    /// The token header carries no `kid` field.
    #[error("Unknown kid.")]
    UnknownKid,
    /// `jsonwebtoken::decode` rejected the token; wraps the underlying error.
    #[error("Wrong token format - {0}.")]
    WrongToken(jsonwebtoken::errors::Error),
}

impl Parser {
    pub const GOOGLE_CERT_URL: &'static str = "https://www.googleapis.com/oauth2/v3/certs";

    pub fn new(client_id: &str) -> Self {
        Parser::new_with_custom_cert_url(client_id, Parser::GOOGLE_CERT_URL)
    }

    pub fn new_with_custom_cert_url(client_id: &str, public_key_url: &str) -> Self {
        Self {
            client_id: client_id.to_owned(),
            key_provider: Box::new(GooglePublicKeyProvider::new(public_key_url)),
        }
    }
    /// Parses and validates the provided token. Validation is done against the associated DecodingKey for the kid of the token, fetched from the key_provider (google).
    pub async fn parse<T: DeserializeOwned>(&mut self, token: &str) -> Result<T, ParserError> {
        let header = jsonwebtoken::decode_header(token).map_err(|_| ParserError::WrongHeader)?;
        let kid = header.kid.ok_or(ParserError::UnknownKid)?;
        let key = self
            .key_provider
            .as_mut()
            .get_key(kid.as_str())
            .await
            .map_err(ParserError::KeyProvider)?;

        let aud = vec![self.client_id.to_owned()];
        let mut validation = Validation::new(Algorithm::RS256);
        validation.set_audience(&aud);
        validation.set_issuer(&[
            "https://accounts.google.com".to_string(),
            "accounts.google.com".to_string(),
        ]);
        validation.validate_exp = true;
        validation.validate_nbf = false;
        let data =
            jsonwebtoken::decode::<T>(token, &key, &validation).map_err(ParserError::WrongToken)?;
        Ok(data.claims)
    }
}

#[cfg(test)]
mod tests {
    use jsonwebtoken::errors::ErrorKind;

    use crate::test_helper::{setup, TokenClaims};
    use crate::ParserError;

    /// Returns the `jsonwebtoken` error kind carried by a `WrongToken` error,
    /// or `None` for any other variant.
    fn wrong_token_kind(error: ParserError) -> Option<ErrorKind> {
        match error {
            ParserError::WrongToken(inner) => Some(inner.into_kind()),
            _ => None,
        }
    }

    /// A valid token parses successfully, twice in a row with the same parser.
    #[tokio::test]
    async fn should_email_parsed_correct() {
        let claims = TokenClaims::new();
        let (token, mut parser, _server) = setup(&claims);

        for _ in 0..2 {
            let parsed = parser
                .parse::<TokenClaims>(token.as_str())
                .await
                .unwrap();
            assert_eq!(parsed.email, claims.email);
        }
    }

    /// An expired token is rejected with `ExpiredSignature`.
    #[tokio::test]
    async fn should_validate_exp() {
        let claims = TokenClaims::new_expired();
        let (token, mut validator, _server) = setup(&claims);

        let outcome = validator.parse::<TokenClaims>(token.as_str()).await;
        let kind = wrong_token_kind(outcome.err().unwrap());

        assert!(matches!(kind, Some(ErrorKind::ExpiredSignature)));
    }

    /// A token from a foreign issuer is rejected with `InvalidIssuer`.
    #[tokio::test]
    async fn should_validate_iss() {
        let mut claims = TokenClaims::new();
        claims.iss = "https://some.com".to_owned();
        let (token, mut validator, _server) = setup(&claims);

        let outcome = validator.parse::<TokenClaims>(token.as_str()).await;
        let kind = wrong_token_kind(outcome.err().unwrap());

        assert!(matches!(kind, Some(ErrorKind::InvalidIssuer)));
    }

    /// A token issued for another audience is rejected with `InvalidAudience`.
    #[tokio::test]
    async fn should_validate_aud() {
        let mut claims = TokenClaims::new();
        claims.aud = "other-id".to_owned();
        let (token, mut validator, _server) = setup(&claims);

        let outcome = validator.parse::<TokenClaims>(token.as_str()).await;
        let kind = wrong_token_kind(outcome.err().unwrap());

        assert!(matches!(kind, Some(ErrorKind::InvalidAudience)));
    }
}