use futures_util::future::join_all;
use snafu::prelude::*;
use crate::{
BoxedError,
crypto::{
KeyMatchStrength,
verifier::{
BoxedJwsVerifier, CreateVerifierError, JwsVerifier, JwsVerifierPlatform, KeyMatch,
VerifyError,
},
},
error::Error as _,
jwk::PublicJwks,
};
/// A [`JwsVerifier`] that delegates to a collection of inner verifiers,
/// selecting which one(s) to use according to how each of them matches the
/// signature's [`KeyMatch`].
#[derive(Debug)]
pub struct MultiKeyVerifier {
    /// Candidate verifiers, scanned in order when looking for a match.
    verifiers: Vec<BoxedJwsVerifier>,
    /// When `true`, an ambiguous match (several verifiers matching only by
    /// algorithm) is resolved by trying each candidate in turn; when `false`,
    /// verification fails with [`VerifyError::AmbiguousKeyMatch`].
    try_all_on_ambiguous_match: bool,
}
/// Errors returned while constructing a [`MultiKeyVerifier`].
#[derive(Debug, Snafu)]
pub enum MultiKeyVerifierError {
    /// The platform failed to build a verifier for one of the JWKs.
    /// `MultiKeyVerifier::from_jwks` skips keys reported as unsupported, so
    /// this carries only the other kinds of failure.
    #[snafu(display("Failed to create verifier from JWK"))]
    CreateVerifier {
        /// The underlying platform error.
        source: CreateVerifierError,
    },
}
impl crate::Error for MultiKeyVerifierError {
    /// Retryability is inherited from the wrapped [`CreateVerifierError`]:
    /// construction is worth retrying exactly when the platform says so.
    fn is_retryable(&self) -> bool {
        // The enum has a single variant, so this destructure is irrefutable;
        // adding a variant will turn it into a compile error, as a `match` would.
        let Self::CreateVerifier { source } = self;
        source.is_retryable()
    }
}
/// Outcome of searching the inner verifiers for ones matching a [`KeyMatch`].
enum GetVerifierResult {
    /// A verifier matched by key id; it is the only one used.
    ByKeyId(BoxedJwsVerifier),
    /// No key-id match, but at least one verifier matched by algorithm.
    /// The vector is never empty.
    ByAlgorithm(Vec<BoxedJwsVerifier>),
    /// No inner verifier matched at all.
    None,
}
impl GetVerifierResult {
    /// Collapses the search outcome to the strength of the match it represents,
    /// or `None` when nothing matched.
    fn key_match_strength(&self) -> Option<KeyMatchStrength> {
        let strength = match self {
            Self::None => return None,
            Self::ByKeyId(_) => KeyMatchStrength::ByKeyId,
            Self::ByAlgorithm(_) => KeyMatchStrength::ByAlgorithm,
        };
        Some(strength)
    }
}
impl MultiKeyVerifier {
    /// Builds a verifier over an explicit list of inner verifiers.
    ///
    /// Ambiguous algorithm-only matches are rejected until enabled with
    /// [`MultiKeyVerifier::try_all_on_ambiguous_match`].
    #[must_use]
    pub fn new(verifiers: Vec<BoxedJwsVerifier>) -> Self {
        Self {
            try_all_on_ambiguous_match: false,
            verifiers,
        }
    }

    /// Builds a verifier from every key in `jwks` that `platform` can handle.
    ///
    /// Verifier creation for all keys is awaited together via `join_all`.
    /// Keys the platform rejects with [`CreateVerifierError::UnsupportedKey`]
    /// are skipped rather than treated as errors.
    ///
    /// # Errors
    ///
    /// Returns [`MultiKeyVerifierError::CreateVerifier`] wrapping the first
    /// (in key order) creation failure other than `UnsupportedKey`.
    pub async fn from_jwks(
        jwks: &PublicJwks,
        platform: &dyn JwsVerifierPlatform,
    ) -> Result<Self, MultiKeyVerifierError> {
        let pending = jwks
            .keys
            .iter()
            .map(|jwk| platform.create_verifier_from_jwk(jwk.clone()));
        let created = join_all(pending).await;

        let verifiers = created
            .into_iter()
            // An unsupported key simply contributes no verifier to the set.
            .filter(|outcome| !matches!(outcome, Err(CreateVerifierError::UnsupportedKey)))
            .collect::<Result<Vec<BoxedJwsVerifier>, _>>()
            .context(CreateVerifierSnafu)?;

        Ok(Self::new(verifiers))
    }

    /// Sets what happens when several verifiers match only by algorithm.
    ///
    /// With `true`, each such candidate is tried in turn and the first
    /// successful verification wins. With `false` (what both constructors
    /// set), verification fails with [`VerifyError::AmbiguousKeyMatch`].
    #[must_use]
    pub fn try_all_on_ambiguous_match(self, value: bool) -> Self {
        Self {
            try_all_on_ambiguous_match: value,
            ..self
        }
    }

    /// Selects the verifier(s) for `key_match` and runs verification.
    ///
    /// * A key-id match is authoritative: its result is returned unchanged.
    /// * Algorithm-only matches are tried in order when there is exactly one,
    ///   or when `try_all_on_ambiguous_match` is set. Otherwise the call fails
    ///   with [`VerifyError::AmbiguousKeyMatch`].
    /// * With no match at all, [`VerifyError::NoMatchingKey`] is returned.
    ///
    /// If every algorithm-only candidate fails, the error returned is chosen
    /// in this order:
    /// 1. the last non-retryable failure, if any;
    /// 2. otherwise the last retryable failure;
    /// 3. otherwise [`VerifyError::NoMatchingKey`].
    ///
    /// Candidates that themselves report `NoMatchingKey` are ignored.
    async fn dispatch_verify(
        &self,
        input: &[u8],
        signature: &[u8],
        key_match: &KeyMatch<'_>,
    ) -> Result<(), VerifyError<BoxedError>> {
        let candidates = match self.get_verifier(key_match) {
            GetVerifierResult::None => return Err(VerifyError::NoMatchingKey),
            GetVerifierResult::ByKeyId(exact) => {
                return exact.verify(input, signature, key_match).await;
            }
            GetVerifierResult::ByAlgorithm(candidates) => candidates,
        };

        let ambiguous = candidates.len() > 1;
        if ambiguous && !self.try_all_on_ambiguous_match {
            return Err(VerifyError::AmbiguousKeyMatch);
        }

        // Track the most recent failure of each kind separately so that a
        // non-retryable failure can take precedence in the final error.
        let mut retryable_failure = None;
        let mut terminal_failure = None;
        for candidate in candidates {
            match candidate.verify(input, signature, key_match).await {
                Ok(()) => return Ok(()),
                Err(VerifyError::NoMatchingKey) => continue,
                Err(err) if err.is_retryable() => retryable_failure = Some(err),
                Err(err) => terminal_failure = Some(err),
            }
        }

        Err(terminal_failure
            .or(retryable_failure)
            .unwrap_or(VerifyError::NoMatchingKey))
    }

    /// Scans the inner verifiers, in order, for ones matching `key_match`.
    ///
    /// Stops at the first key-id match. Otherwise collects every
    /// algorithm-only match, or reports that nothing matched.
    fn get_verifier(&self, key_match: &KeyMatch) -> GetVerifierResult {
        let mut algorithm_matches: Vec<BoxedJwsVerifier> = Vec::new();
        for candidate in &self.verifiers {
            match candidate.key_match(key_match) {
                Some(KeyMatchStrength::ByKeyId) => {
                    return GetVerifierResult::ByKeyId(candidate.clone());
                }
                Some(KeyMatchStrength::ByAlgorithm) => algorithm_matches.push(candidate.clone()),
                None => {}
            }
        }

        if algorithm_matches.is_empty() {
            return GetVerifierResult::None;
        }
        GetVerifierResult::ByAlgorithm(algorithm_matches)
    }
}
impl JwsVerifier for MultiKeyVerifier {
    type Error = BoxedError;

    /// Reports the strongest match among the inner verifiers.
    ///
    /// A key-id match outranks an algorithm match, and `None` means no inner
    /// verifier matched. Several algorithm-only matches still report
    /// `ByAlgorithm`, even though `verify` may then reject them as ambiguous.
    fn key_match(&self, key_match: &KeyMatch<'_>) -> Option<KeyMatchStrength> {
        let lookup = self.get_verifier(key_match);
        lookup.key_match_strength()
    }

    /// Verifies `signature` over `input` with the inner verifier(s) selected
    /// for `key_match`.
    async fn verify(
        &self,
        input: &[u8],
        signature: &[u8],
        key_match: &KeyMatch<'_>,
    ) -> Result<(), VerifyError<Self::Error>> {
        self.dispatch_verify(input, signature, key_match)
            .await
    }

    /// Calls `try_refresh` on every inner verifier and awaits all of them.
    ///
    /// Returns `true` if any inner call returned `true`.
    async fn try_refresh(&self) -> bool {
        let outcomes: Vec<bool> =
            join_all(self.verifiers.iter().map(JwsVerifier::try_refresh)).await;
        outcomes.contains(&true)
    }
}