jwt_authorizer/jwks/mod.rs

1use std::{str::FromStr, sync::Arc};
2
3use jsonwebtoken::{
4    jwk::{AlgorithmParameters, Jwk},
5    Algorithm, DecodingKey, Header,
6};
7
8use crate::error::AuthError;
9
10use self::key_store_manager::KeyStoreManager;
11
12pub mod key_store_manager;
13
14#[derive(Clone)]
15pub enum KeySource {
16    /// KeyDataSource managing a refreshable key sets
17    KeyStoreSource(KeyStoreManager),
18    /// Manages public key sets, initialized on startup
19    MultiKeySource(KeySet),
20    /// Manages one public key, initialized on startup
21    SingleKeySource(Arc<KeyData>),
22}
23
24#[derive(Clone)]
25pub struct KeyData {
26    pub kid: Option<String>,
27    /// valid algorithms
28    pub algs: Vec<Algorithm>,
29    pub key: DecodingKey,
30}
31
32fn get_valid_algs(key: &Jwk) -> Vec<Algorithm> {
33    if let Some(key_alg) = key.common.key_algorithm {
34        // if alg is not correct => no valid algs => empty array
35        Algorithm::from_str(key_alg.to_string().as_str()).map_or(vec![], |a| vec![a])
36    } else {
37        // guessing valid algs from key structure
38        match key.algorithm {
39            AlgorithmParameters::EllipticCurve(_) => {
40                vec![Algorithm::ES256, Algorithm::ES384]
41            }
42            AlgorithmParameters::RSA(_) => vec![
43                Algorithm::RS256,
44                Algorithm::RS384,
45                Algorithm::RS512,
46                Algorithm::PS256,
47                Algorithm::PS384,
48                Algorithm::PS512,
49            ],
50            AlgorithmParameters::OctetKey(_) => vec![Algorithm::EdDSA],
51            AlgorithmParameters::OctetKeyPair(_) => vec![Algorithm::HS256, Algorithm::HS384, Algorithm::HS512],
52        }
53    }
54}
55
56impl KeyData {
57    pub fn from_jwk(key: &Jwk) -> Result<KeyData, jsonwebtoken::errors::Error> {
58        Ok(KeyData {
59            kid: key.common.key_id.clone(),
60            algs: get_valid_algs(key),
61            key: DecodingKey::from_jwk(key)?,
62        })
63    }
64}
65
66#[derive(Clone, Default)]
67pub struct KeySet(Vec<Arc<KeyData>>);
68
69impl From<Vec<Arc<KeyData>>> for KeySet {
70    fn from(value: Vec<Arc<KeyData>>) -> Self {
71        KeySet(value)
72    }
73}
74
75impl KeySet {
76    /// Find the key in the set that matches the given key id, if any.
77    pub fn find_kid(&self, kid: &str) -> Option<&Arc<KeyData>> {
78        self.0.iter().find(|k| match &k.kid {
79            Some(k) => k == kid,
80            None => false,
81        })
82    }
83
84    /// Find the key in the set that matches the given key id, if any.
85    pub fn find_alg(&self, alg: &Algorithm) -> Option<&Arc<KeyData>> {
86        self.0.iter().find(|k| k.algs.contains(alg))
87    }
88
89    /// Find first key.
90    pub fn first(&self) -> Option<&Arc<KeyData>> {
91        self.0.first()
92    }
93
94    pub(crate) fn get_key(&self, header: &Header) -> Result<&Arc<KeyData>, AuthError> {
95        let key = if let Some(ref kid) = header.kid {
96            self.find_kid(kid).ok_or_else(|| AuthError::InvalidKid(kid.to_owned()))?
97        } else {
98            self.find_alg(&header.alg).ok_or(AuthError::InvalidKeyAlg(header.alg))?
99        };
100        Ok(key)
101    }
102}
103
104impl KeySource {
105    pub async fn get_key(&self, header: Header) -> Result<Arc<KeyData>, AuthError> {
106        match self {
107            KeySource::KeyStoreSource(kstore) => kstore.get_key(&header).await,
108            KeySource::MultiKeySource(keys) => keys.get_key(&header).cloned(),
109            KeySource::SingleKeySource(key) => Ok(key.clone()),
110        }
111    }
112}