use crate::{
bare_key::{SerializedKey, SerializedKeys},
key::*,
Any, KeyError, OpaqueValidationFailureReason, SignatureError, SigningContext,
ValidationContext, ValidationError,
};
use std::{
collections::{BTreeSet, HashMap, HashSet},
fmt::Debug,
};
/// Crate-internal abstraction over the key flavours a [`KeyRegistry`] can hold.
///
/// Every key is viewed as an optional key id (`kid`) plus an `Inner` payload.
/// The registry stores the payload as a hash-map key, hence the bounds on
/// `Inner`.
pub(crate) trait IsKey {
/// Payload of the key without its `kid`.
type Inner: std::hash::Hash + Eq + Debug + Clone;
/// Borrows the payload of this key.
fn inner(&self) -> &Self::Inner;
/// Returns the [`KeyType`] of a payload.
fn key_type(inner: &Self::Inner) -> KeyType;
/// Builds a key from an optional `kid` and a payload.
fn from_inner(kid: Option<String>, inner: Self::Inner) -> Self;
/// Consumes the key, yielding its optional `kid` and its payload.
fn into_inner(self) -> (Option<String>, Self::Inner);
/// Converts a deserialized key into this key flavour.
///
/// `KeyRegistry::add_from_jwkset` silently skips entries for which this
/// returns `None`.
fn get_serialized_key(key: SerializedKey) -> Option<Self>
where
Self: Sized;
/// Produces the serializable form of a payload with the given optional `kid`
/// (used by `KeyRegistry::to_json`).
fn to_serialized_key(kid: Option<&str>, inner: &Self::Inner) -> SerializedKey;
/// Parses PEM text into zero or more keys.
///
/// The outer `Result` is propagated by `KeyRegistry::add_from_pem` as a
/// failure of the whole load; per-entry `Err` values are skipped there.
fn from_pem(pem: &str) -> Result<Vec<Result<Self, KeyError>>, KeyError>
where
Self: Sized;
/// Renders a payload as PEM text; `KeyRegistry::to_pem` concatenates these
/// strings.
fn to_pem(inner: &Self::Inner) -> String;
/// Returns the key used for signing, if the payload has one.
///
/// `KeyRegistry::sign` reports `SignatureError::NoAppropriateKey` on `None`.
fn encoding_key(inner: &Self::Inner) -> Option<&jsonwebtoken::EncodingKey>;
/// Returns the key used by `KeyRegistry::validate` to check tokens.
fn decoding_key(inner: &Self::Inner) -> &jsonwebtoken::DecodingKey;
}
/// A collection of keys of flavour `K`, tracked in insertion order.
///
/// Each added key receives a fresh, monotonically increasing ordinal; the key
/// with the highest ordinal is the one used for signing.
// `IsKey` is `pub(crate)` while this struct is `pub`, so the bound on `K` is a
// private bound and the lint has to be silenced.
#[allow(private_bounds)]
pub struct KeyRegistry<K: IsKey> {
/// Maps each `kid` to the ordinal of the key registered under it.
named_keys: HashMap<String, usize>,
/// Ordinals of the keys that were registered without a `kid`.
unnamed_keys: HashSet<usize>,
/// Maps each key payload to its ordinal and optional `kid`.
key_to_ordinal: HashMap<K::Inner, (usize, Option<String>)>,
/// Ordinals of all currently registered keys; the largest entry identifies
/// the key used for signing.
active_keys: BTreeSet<usize>,
/// Ordinal that will be handed to the next added key.
next: usize,
}
impl<K: IsKey> Default for KeyRegistry<K> {
fn default() -> Self {
Self {
named_keys: HashMap::default(),
unnamed_keys: HashSet::default(),
key_to_ordinal: HashMap::default(),
active_keys: BTreeSet::default(),
next: 0,
}
}
}
impl KeyRegistry<PrivateKey> {
    /// Creates an empty registry that stores [`PrivateKey`]s.
    pub fn private() -> Self {
        Default::default()
    }
}
impl KeyRegistry<PublicKey> {
    /// Creates an empty registry that stores [`PublicKey`]s.
    pub fn public() -> Self {
        Default::default()
    }
}
impl KeyRegistry<Key> {
    /// Creates an empty registry that stores [`Key`]s.
    pub fn new() -> Self {
        Default::default()
    }
}
#[allow(private_bounds)]
impl<K: IsKey> KeyRegistry<K> {
pub fn clear(&mut self) {
*self = Self::default();
}
pub fn into_keys(self) -> impl Iterator<Item = K> {
self.key_to_ordinal
.into_iter()
.map(|(key, (_, kid))| K::from_inner(kid, key))
}
/// Adds `key` to the registry and makes it the newest key.
///
/// Any existing entry with the same payload is removed first, and so is
/// any existing entry registered under the same `kid`.
pub fn add_key(&mut self, key: K) {
    // Evict an entry that holds the same payload (possibly under another kid).
    self.remove_key(&key);
    let (kid, inner) = key.into_inner();
    // Evict whichever key currently owns this kid; a no-op when none does.
    if let Some(name) = kid.as_deref() {
        self.remove_kid(name);
    }

    let ordinal = self.next;
    self.next = ordinal + 1;

    match &kid {
        Some(name) => {
            self.named_keys.insert(name.clone(), ordinal);
        }
        None => {
            self.unnamed_keys.insert(ordinal);
        }
    }
    self.active_keys.insert(ordinal);
    self.key_to_ordinal.insert(inner, (ordinal, kid));
}
/// Removes the entry whose payload equals that of `key`, if there is one.
///
/// The `kid` carried by `key` is ignored; the `kid` stored with the entry
/// is the one that gets unregistered.
pub fn remove_key(&mut self, key: &K) {
    let Some((ordinal, kid)) = self.key_to_ordinal.remove(key.inner()) else {
        return;
    };
    match kid {
        Some(name) => {
            self.named_keys.remove(&name);
        }
        None => {
            self.unnamed_keys.remove(&ordinal);
        }
    }
    self.active_keys.remove(&ordinal);
}
/// Removes the key registered under `kid`.
///
/// Returns `true` if such a key existed and was removed, `false` otherwise.
pub fn remove_kid(&mut self, kid: &str) -> bool {
    let Some(ordinal) = self.named_keys.remove(kid) else {
        return false;
    };
    self.active_keys.remove(&ordinal);
    // The payload map is keyed by payload, so locate the entry by ordinal.
    self.key_to_ordinal
        .retain(|_, (candidate, _)| *candidate != ordinal);
    true
}
pub fn len(&self) -> usize {
self.key_to_ordinal.len()
}
pub fn is_empty(&self) -> bool {
self.key_to_ordinal.is_empty()
}
/// Parses `jwkset` as a JSON key set and adds every key this registry can
/// hold.
///
/// Entries that `K::get_serialized_key` rejects (returns `None` for) are
/// skipped on a best-effort basis; compare the returned count with the
/// number of entries you expected if that matters.
///
/// # Errors
///
/// Returns [`KeyError::InvalidJson`] when `jwkset` cannot be deserialized.
pub fn add_from_jwkset(&mut self, jwkset: &str) -> Result<usize, KeyError> {
    let loaded: SerializedKeys =
        serde_json::from_str(jwkset).map_err(|_| KeyError::InvalidJson)?;
    let mut added = 0;
    for key in loaded.keys.into_iter().filter_map(K::get_serialized_key) {
        self.add_key(key);
        added += 1;
    }
    Ok(added)
}
/// Parses `pem` and adds every key that was decoded successfully.
///
/// Individual entries that fail to decode are skipped on a best-effort
/// basis; the returned count only covers the keys that were added.
///
/// # Errors
///
/// Propagates the error from `K::from_pem` when the input as a whole is
/// rejected.
pub fn add_from_pem(&mut self, pem: &str) -> Result<usize, KeyError> {
    let mut added = 0;
    // `flatten` over `Result` items keeps the `Ok` values and drops the `Err`s.
    for key in K::from_pem(pem)?.into_iter().flatten() {
        self.add_key(key);
        added += 1;
    }
    Ok(added)
}
/// Adds keys from `source`, picking the format from its first
/// non-whitespace character: `{` is treated as a JSON key set, `-` as PEM.
///
/// Blank input adds nothing and returns `Ok(0)`.
///
/// # Errors
///
/// Returns [`KeyError::UnsupportedKeyType`] for any other leading
/// character, or whatever the chosen loader reports.
pub fn add_from_any(&mut self, source: &str) -> Result<usize, KeyError> {
    let trimmed = source.trim();
    match trimmed.chars().next() {
        None => Ok(0),
        Some('{') => self.add_from_jwkset(trimmed),
        Some('-') => self.add_from_pem(trimmed),
        Some(first_char) => Err(KeyError::UnsupportedKeyType(format!(
            "Expected JWK set or PEM file, got {first_char}"
        ))),
    }
}
/// Concatenates the PEM rendering of every stored key into one string.
///
/// Keys appear in the hash map's iteration order.
pub fn to_pem(&self) -> String {
    self.key_to_ordinal.keys().map(K::to_pem).collect()
}
/// Serializes every stored key, together with its `kid`, as a JSON key set.
///
/// # Errors
///
/// Returns [`KeyError::EncodeError`] when JSON serialization fails.
pub fn to_json(&self) -> Result<String, KeyError> {
    let serialized = self
        .key_to_ordinal
        .iter()
        .map(|(inner, (_, kid))| K::to_serialized_key(kid.as_deref(), inner));
    let document = SerializedKeys {
        keys: serialized.collect(),
    };
    serde_json::to_string(&document).map_err(|_| KeyError::EncodeError)
}
fn active_key(&self) -> Option<(Option<&str>, &K::Inner)> {
if let Some(&i) = self.active_keys.last() {
for (k, &(v, ref kid)) in &self.key_to_ordinal {
if v == i {
if let Some(kid) = kid {
return Some((Some(kid.as_str()), k));
} else {
return Some((None, k));
}
}
}
}
None
}
/// Decodes the claims of `token` WITHOUT verifying it.
///
/// Signature validation is disabled, no claims are required, and the
/// `exp`, `nbf` and `aud` checks are switched off. None of the registry's
/// keys are consulted. The result must not be trusted.
///
/// # Errors
///
/// Returns [`ValidationError::Invalid`] carrying the debug rendering of the
/// decoder's error kind when the token cannot be decoded.
pub fn unsafely_decode_without_validation(
    &self,
    token: &str,
) -> Result<HashMap<String, Any>, ValidationError> {
    let mut lax = jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::default());
    lax.insecure_disable_signature_validation();
    lax.required_spec_claims.clear();
    lax.validate_exp = false;
    lax.validate_nbf = false;
    lax.validate_aud = false;

    // The decoder still needs a key argument; an empty secret is a placeholder.
    let placeholder_key = jsonwebtoken::DecodingKey::from_secret(b"");
    match jsonwebtoken::decode::<HashMap<String, Any>>(token, &placeholder_key, &lax) {
        Ok(decoded) => Ok(decoded.claims),
        Err(e) => Err(ValidationError::Invalid(
            OpaqueValidationFailureReason::Failure(format!("{:?}", e.kind())),
        )),
    }
}
/// Validates `token` against the registry and returns its claims.
///
/// If the registry holds named keys and the token header carries a `kid`
/// that matches one of them, only that key is tried and its result is
/// returned as is. Otherwise every key is tried in hash-map iteration
/// order and the first success wins.
///
/// # Errors
///
/// Returns the error of the last key tried, or
/// [`OpaqueValidationFailureReason::NoAppropriateKey`] when the registry
/// is empty.
pub fn validate(
    &self,
    token: &str,
    ctx: &ValidationContext,
) -> Result<HashMap<String, Any>, ValidationError> {
    if !self.named_keys.is_empty() {
        // An undecodable header simply falls through to the try-all path.
        let header_kid = jsonwebtoken::decode_header(token)
            .ok()
            .and_then(|header| header.kid);
        if let Some(header_kid) = header_kid {
            let matching = self
                .key_to_ordinal
                .iter()
                .find(|(_, (_, kid))| kid.as_deref() == Some(header_kid.as_str()));
            if let Some((key, _)) = matching {
                return validate_token(K::key_type(key), K::decoding_key(key), None, token, ctx);
            }
        }
    }

    let mut last_error = None;
    for key in self.key_to_ordinal.keys() {
        match validate_token(K::key_type(key), K::decoding_key(key), None, token, ctx) {
            Ok(claims) => return Ok(claims),
            Err(e) => last_error = Some(e),
        }
    }
    match last_error {
        Some(e) => Err(e),
        None => Err(OpaqueValidationFailureReason::NoAppropriateKey.into()),
    }
}
/// Signs `claims` with the most recently added key, passing that key's
/// `kid` (if any) along to the signer.
///
/// Only the newest key is considered: if it has no encoding key the call
/// fails without falling back to older keys.
/// NOTE(review): confirm that not falling back is intended, in particular
/// for registries that mix public and private keys.
///
/// # Errors
///
/// Returns [`SignatureError::NoAppropriateKey`] when the registry is empty
/// or the newest key cannot sign, or whatever `sign_token` reports.
pub fn sign(
    &self,
    claims: HashMap<String, Any>,
    ctx: &SigningContext,
) -> Result<String, SignatureError> {
    let Some((kid, key)) = self.active_key() else {
        return Err(SignatureError::NoAppropriateKey);
    };
    match K::encoding_key(key) {
        Some(encoding_key) => sign_token(K::key_type(key), encoding_key, kid, claims, ctx),
        None => Err(SignatureError::NoAppropriateKey),
    }
}
}
/// Capability queries and key generation for registries of [`PrivateKey`]s.
impl KeyRegistry<PrivateKey> {
    /// `true` when the registry reports private or symmetric keys.
    pub fn can_sign(&self) -> bool {
        self.has_private_keys() || self.has_symmetric_keys()
    }

    /// `true` when the registry reports public or symmetric keys.
    pub fn can_validate(&self) -> bool {
        self.has_public_keys() || self.has_symmetric_keys()
    }

    /// `true` as soon as the registry holds any key at all.
    pub fn has_private_keys(&self) -> bool {
        !self.is_empty()
    }

    /// `true` when at least one stored key is not of type `HS256`.
    pub fn has_public_keys(&self) -> bool {
        self.key_to_ordinal
            .keys()
            .any(|inner| inner.bare_key.key_type() != KeyType::HS256)
    }

    /// `true` when at least one stored key is of type `HS256`.
    pub fn has_symmetric_keys(&self) -> bool {
        self.key_to_ordinal
            .keys()
            .any(|inner| inner.bare_key.key_type() == KeyType::HS256)
    }

    /// Generates a new private key of `key_type` with the given optional
    /// `kid` and adds it to the registry.
    ///
    /// # Errors
    ///
    /// Propagates the error from `PrivateKey::generate`.
    #[cfg(feature = "keygen")]
    pub fn generate_key(&mut self, kid: Option<String>, key_type: KeyType) -> Result<(), KeyError> {
        self.add_key(PrivateKey::generate(kid, key_type)?);
        Ok(())
    }
}
/// Capability queries for registries of [`PublicKey`]s.
impl KeyRegistry<PublicKey> {
    /// Always `false`: this registry reports neither private nor symmetric keys.
    pub fn has_private_keys(&self) -> bool {
        false
    }

    /// Always `false`.
    pub fn has_symmetric_keys(&self) -> bool {
        false
    }

    /// `true` as soon as the registry holds any key at all.
    pub fn has_public_keys(&self) -> bool {
        !self.is_empty()
    }

    /// `true` when the registry reports private or symmetric keys.
    pub fn can_sign(&self) -> bool {
        self.has_private_keys() || self.has_symmetric_keys()
    }

    /// `true` when the registry reports public or symmetric keys.
    pub fn can_validate(&self) -> bool {
        self.has_public_keys() || self.has_symmetric_keys()
    }
}
/// Capability queries, public-only export and key generation for registries
/// that may mix private and public keys.
impl KeyRegistry<Key> {
    /// `true` when the registry reports private or symmetric keys.
    pub fn can_sign(&self) -> bool {
        self.has_private_keys() || self.has_symmetric_keys()
    }

    /// `true` when the registry reports public or symmetric keys.
    pub fn can_validate(&self) -> bool {
        self.has_public_keys() || self.has_symmetric_keys()
    }

    /// `true` when at least one stored key is a `KeyInner::Private`.
    pub fn has_private_keys(&self) -> bool {
        self.key_to_ordinal
            .keys()
            .any(|inner| matches!(inner, KeyInner::Private(_)))
    }

    /// `true` when at least one stored key is a `KeyInner::Public`, or a
    /// `KeyInner::Private` whose type is not `HS256`.
    pub fn has_public_keys(&self) -> bool {
        self.key_to_ordinal.keys().any(|inner| match inner {
            KeyInner::Public(_) => true,
            KeyInner::Private(private) => private.bare_key.key_type() != KeyType::HS256,
        })
    }

    /// `true` when at least one stored key is a `KeyInner::Private` of type
    /// `HS256`.
    pub fn has_symmetric_keys(&self) -> bool {
        self.key_to_ordinal.keys().any(|inner| match inner {
            KeyInner::Private(private) => private.bare_key.key_type() == KeyType::HS256,
            KeyInner::Public(_) => false,
        })
    }

    /// Concatenates the public PEM rendering of every stored key.
    ///
    /// Keys appear in the hash map's iteration order.
    ///
    /// # Errors
    ///
    /// Fails as soon as a private key cannot produce a public PEM.
    pub fn to_pem_public(&self) -> Result<String, KeyError> {
        let mut pem = String::new();
        for inner in self.key_to_ordinal.keys() {
            match inner {
                KeyInner::Private(private) => pem.push_str(&private.bare_key.to_pem_public()?),
                KeyInner::Public(public) => pem.push_str(&public.bare_key.to_pem()),
            }
        }
        Ok(pem)
    }

    /// Serializes the public form of every stored key, with its `kid`, as a
    /// JSON key set.
    ///
    /// # Errors
    ///
    /// Fails as soon as a private key cannot be converted to a public key,
    /// or with [`KeyError::EncodeError`] when JSON serialization fails.
    pub fn to_json_public(&self) -> Result<String, KeyError> {
        let mut keys = Vec::new();
        for (inner, (_, kid)) in &self.key_to_ordinal {
            let entry = match inner {
                KeyInner::Private(private) => {
                    SerializedKey::Public(kid.clone(), private.bare_key.to_public()?.clone_key())
                }
                KeyInner::Public(public) => {
                    SerializedKey::Public(kid.clone(), public.bare_key.clone_key())
                }
            };
            keys.push(entry);
        }
        let document = SerializedKeys { keys };
        serde_json::to_string(&document).map_err(|_| KeyError::EncodeError)
    }

    /// Generates a new private key of `key_type` with the given optional
    /// `kid` and adds it to the registry.
    ///
    /// # Errors
    ///
    /// Propagates the error from `PrivateKey::generate`.
    #[cfg(feature = "keygen")]
    pub fn generate_key(&mut self, kid: Option<String>, key_type: KeyType) -> Result<(), KeyError> {
        let generated = PrivateKey::generate(kid, key_type)?;
        self.add_key(generated.into());
        Ok(())
    }
}