jwt_authorizer/jwks/
mod.rs1use std::{str::FromStr, sync::Arc};
2
3use jsonwebtoken::{
4 jwk::{AlgorithmParameters, Jwk},
5 Algorithm, DecodingKey, Header,
6};
7
8use crate::error::AuthError;
9
10use self::key_store_manager::KeyStoreManager;
11
12pub mod key_store_manager;
13
14#[derive(Clone)]
15pub enum KeySource {
16 KeyStoreSource(KeyStoreManager),
18 MultiKeySource(KeySet),
20 SingleKeySource(Arc<KeyData>),
22}
23
24#[derive(Clone)]
25pub struct KeyData {
26 pub kid: Option<String>,
27 pub algs: Vec<Algorithm>,
29 pub key: DecodingKey,
30}
31
32fn get_valid_algs(key: &Jwk) -> Vec<Algorithm> {
33 if let Some(key_alg) = key.common.key_algorithm {
34 Algorithm::from_str(key_alg.to_string().as_str()).map_or(vec![], |a| vec![a])
36 } else {
37 match key.algorithm {
39 AlgorithmParameters::EllipticCurve(_) => {
40 vec![Algorithm::ES256, Algorithm::ES384]
41 }
42 AlgorithmParameters::RSA(_) => vec![
43 Algorithm::RS256,
44 Algorithm::RS384,
45 Algorithm::RS512,
46 Algorithm::PS256,
47 Algorithm::PS384,
48 Algorithm::PS512,
49 ],
50 AlgorithmParameters::OctetKey(_) => vec![Algorithm::EdDSA],
51 AlgorithmParameters::OctetKeyPair(_) => vec![Algorithm::HS256, Algorithm::HS384, Algorithm::HS512],
52 }
53 }
54}
55
56impl KeyData {
57 pub fn from_jwk(key: &Jwk) -> Result<KeyData, jsonwebtoken::errors::Error> {
58 Ok(KeyData {
59 kid: key.common.key_id.clone(),
60 algs: get_valid_algs(key),
61 key: DecodingKey::from_jwk(key)?,
62 })
63 }
64}
65
66#[derive(Clone, Default)]
67pub struct KeySet(Vec<Arc<KeyData>>);
68
69impl From<Vec<Arc<KeyData>>> for KeySet {
70 fn from(value: Vec<Arc<KeyData>>) -> Self {
71 KeySet(value)
72 }
73}
74
75impl KeySet {
76 pub fn find_kid(&self, kid: &str) -> Option<&Arc<KeyData>> {
78 self.0.iter().find(|k| match &k.kid {
79 Some(k) => k == kid,
80 None => false,
81 })
82 }
83
84 pub fn find_alg(&self, alg: &Algorithm) -> Option<&Arc<KeyData>> {
86 self.0.iter().find(|k| k.algs.contains(alg))
87 }
88
89 pub fn first(&self) -> Option<&Arc<KeyData>> {
91 self.0.first()
92 }
93
94 pub(crate) fn get_key(&self, header: &Header) -> Result<&Arc<KeyData>, AuthError> {
95 let key = if let Some(ref kid) = header.kid {
96 self.find_kid(kid).ok_or_else(|| AuthError::InvalidKid(kid.to_owned()))?
97 } else {
98 self.find_alg(&header.alg).ok_or(AuthError::InvalidKeyAlg(header.alg))?
99 };
100 Ok(key)
101 }
102}
103
104impl KeySource {
105 pub async fn get_key(&self, header: Header) -> Result<Arc<KeyData>, AuthError> {
106 match self {
107 KeySource::KeyStoreSource(kstore) => kstore.get_key(&header).await,
108 KeySource::MultiKeySource(keys) => keys.get_key(&header).cloned(),
109 KeySource::SingleKeySource(key) => Ok(key.clone()),
110 }
111 }
112}