amqp_api_server/api/input/
token_validator.rs

1use std::collections::HashMap;
2
3use crate::api::input::token::Token;
4use crate::config::openid_connect_config::OpenIdConnectConfig;
5use jsonwebtoken::{decode, decode_header, jwk::AlgorithmParameters, DecodingKey, Validation};
6use serde_json::Value;
7
8use crate::config::token_validator_config;
9use crate::config::token_validator_config::TokenValidatorConfig;
10use crate::error::{Error, ErrorKind};
11
12pub struct TokenValidator {
13    config: TokenValidatorConfig,
14}
15
16impl TokenValidator {
17    pub fn new(config: TokenValidatorConfig) -> TokenValidator {
18        TokenValidator { config }
19    }
20
21    pub fn validate(&self, token: &str) -> Result<Token, Error> {
22        let header = match decode_header(token) {
23            Ok(header) => header,
24            Err(error) => {
25                return Err(Error::new(
26                    ErrorKind::TokenDecodingFailure,
27                    format!("failure to decode token's header: {}", error),
28                ));
29            }
30        };
31
32        let kid = match header.kid {
33            Some(k) => k,
34            None => {
35                return Err(Error::new(
36                    ErrorKind::MalformedToken,
37                    "failed to find token header's kid",
38                ));
39            }
40        };
41
42        let jwk = match self.config.jwks().find(&kid) {
43            Some(jwk) => jwk,
44            None => {
45                return Err(Error::new(
46                    ErrorKind::MalformedToken,
47                    format!("failed to find jwk for kid '{}'", kid),
48                ));
49            }
50        };
51
52        let rsa = match jwk.algorithm {
53            AlgorithmParameters::RSA(ref rsa) => rsa,
54            _ => {
55                return Err(Error::new(
56                    ErrorKind::MalformedToken,
57                    format!("expected 'RSA' algorithm got '{:?}'", jwk.algorithm),
58                ));
59            }
60        };
61
62        let decoding_key = match DecodingKey::from_rsa_components(&rsa.n, &rsa.e) {
63            Ok(decoding_key) => decoding_key,
64            Err(error) => {
65                return Err(Error::new(
66                    ErrorKind::TokenDecodingFailure,
67                    format!("failed to get decoding key: {}", error),
68                ));
69            }
70        };
71
72        let algorithm = match jwk.common.algorithm {
73            Some(algorithm) => algorithm,
74            None => {
75                return Err(Error::new(
76                    ErrorKind::TokenDecodingFailure,
77                    "jwk is missing algorithm",
78                ));
79            }
80        };
81
82        let mut validation = Validation::new(algorithm);
83        validation.validate_exp = true;
84        validation.set_audience(self.config.open_id_connect().audience());
85        validation.set_issuer(self.config.open_id_connect().issuers());
86
87        let decoded_token =
88            match decode::<HashMap<String, Value>>(token, &decoding_key, &validation) {
89                Ok(decoded_token) => decoded_token,
90                Err(error) => {
91                    return Err(Error::new(
92                        ErrorKind::InvalidToken,
93                        format!("invalid token detected: {}", error),
94                    ));
95                }
96            };
97
98        let wrapped_token = Token::try_new(decoded_token)?;
99        Ok(wrapped_token)
100    }
101}
102
103pub async fn try_generate_token_validator(
104    openid_connect: OpenIdConnectConfig,
105) -> Result<TokenValidator, Error> {
106    let token_validator_config =
107        token_validator_config::try_generate_config(openid_connect).await?;
108
109    Ok(TokenValidator::new(token_validator_config))
110}