use std::borrow::Cow;
use snafu::prelude::*;
use crate::{
BoxedError,
crypto::{
KeyMatchStrength,
cipher::{AeadDecryptor, AeadEncryptor, AeadOutput, BoxedAeadDecryptor, CipherMatch},
},
};
/// An [`AeadDecryptor`] backed by several boxed decryptors.
///
/// The decryptors are kept in the order given to [`MultiKeyDecryptor::new`];
/// that order is the order in which they are consulted when matching and
/// when falling back to trial decryption.
pub struct MultiKeyDecryptor {
    decryptors: Vec<BoxedAeadDecryptor>,
}
impl MultiKeyDecryptor {
    /// Wraps `decryptors` without inspecting or reordering them.
    #[must_use]
    pub fn new(decryptors: Vec<BoxedAeadDecryptor>) -> Self {
        MultiKeyDecryptor { decryptors }
    }
}
/// Errors reported by [`MultiKeyDecryptor`].
#[derive(Debug, Snafu)]
pub enum MultiKeyDecryptorError {
    /// No usable key was found for the ciphertext.
    #[snafu(display("no matching key"))]
    NoMatchingKey,
    /// A wrapped decryptor failed; its error is preserved as `source`.
    #[snafu(display("decryption failed"))]
    DecryptionFailed { source: BoxedError },
}
impl crate::Error for MultiKeyDecryptorError {
fn is_retryable(&self) -> bool {
match self {
Self::NoMatchingKey => false,
Self::DecryptionFailed { source } => source.is_retryable(),
}
}
}
/// Result of [`MultiKeyDecryptor::select`]: which wrapped decryptor(s) to try.
enum SelectedDecryptor<'a> {
/// A decryptor matched on key id; it is used on its own.
ByKeyId(&'a BoxedAeadDecryptor),
/// No key-id match; these decryptors matched on algorithm only, in their
/// original order. `select` never produces this variant with an empty list.
ByAlgorithm(Vec<&'a BoxedAeadDecryptor>),
/// No wrapped decryptor matched at all.
None,
}
impl MultiKeyDecryptor {
fn select<'a>(&'a self, m: &CipherMatch<'_>) -> SelectedDecryptor<'a> {
let mut by_algorithm: Vec<&'a BoxedAeadDecryptor> = Vec::new();
for decryptor in &self.decryptors {
match decryptor.cipher_match(m) {
Some(KeyMatchStrength::ByKeyId) => {
return SelectedDecryptor::ByKeyId(decryptor);
}
Some(KeyMatchStrength::ByAlgorithm) => {
by_algorithm.push(decryptor);
}
None => {}
}
}
if by_algorithm.is_empty() {
SelectedDecryptor::None
} else {
SelectedDecryptor::ByAlgorithm(by_algorithm)
}
}
async fn try_decrypt(
decryptors: impl Iterator<Item = &BoxedAeadDecryptor>,
count: usize,
cipher_match: Option<&CipherMatch<'_>>,
nonce: &[u8],
ciphertext: &[u8],
tag: &[u8],
aad: &[u8],
) -> Result<Vec<u8>, MultiKeyDecryptorError> {
let mut last_error = None;
for decryptor in decryptors {
match decryptor
.decrypt(cipher_match, nonce, ciphertext, tag, aad)
.await
{
Ok(plaintext) => return Ok(plaintext),
Err(e) => last_error = Some(e),
}
}
match last_error {
Some(source) if count == 1 => Err(MultiKeyDecryptorError::DecryptionFailed { source }),
_ => NoMatchingKeySnafu.fail(),
}
}
}
impl AeadDecryptor for MultiKeyDecryptor {
    type Error = MultiKeyDecryptorError;

    /// Reports the strongest match offered by any wrapped decryptor.
    ///
    /// Stops at the first key-id match. Otherwise reports an algorithm-only
    /// match if at least one decryptor has one, and `None` when none match.
    fn cipher_match(&self, m: &CipherMatch<'_>) -> Option<KeyMatchStrength> {
        let mut strongest = None;
        for strength in self.decryptors.iter().filter_map(|d| d.cipher_match(m)) {
            match strength {
                KeyMatchStrength::ByKeyId => return Some(KeyMatchStrength::ByKeyId),
                KeyMatchStrength::ByAlgorithm => strongest = Some(KeyMatchStrength::ByAlgorithm),
            }
        }
        strongest
    }

    /// Decrypts with the decryptor(s) chosen for `cipher_match`.
    ///
    /// With a `cipher_match`:
    /// * a key-id match is used alone, and its error is wrapped as `DecryptionFailed`;
    /// * algorithm-only matches are tried in order;
    /// * no match fails with `NoMatchingKey`.
    ///
    /// Without a `cipher_match`, every wrapped decryptor is tried in order.
    async fn decrypt(
        &self,
        cipher_match: Option<&CipherMatch<'_>>,
        nonce: &[u8],
        ciphertext: &[u8],
        tag: &[u8],
        aad: &[u8],
    ) -> Result<Vec<u8>, Self::Error> {
        let Some(m) = cipher_match else {
            // Nothing to select on: fall back to trying every key.
            let all = &self.decryptors;
            return Self::try_decrypt(all.iter(), all.len(), None, nonce, ciphertext, tag, aad)
                .await;
        };
        match self.select(m) {
            SelectedDecryptor::None => NoMatchingKeySnafu.fail(),
            SelectedDecryptor::ByKeyId(exact) => exact
                .decrypt(cipher_match, nonce, ciphertext, tag, aad)
                .await
                .map_err(|source| MultiKeyDecryptorError::DecryptionFailed { source }),
            SelectedDecryptor::ByAlgorithm(candidates) => {
                let count = candidates.len();
                Self::try_decrypt(
                    candidates.into_iter(),
                    count,
                    cipher_match,
                    nonce,
                    ciphertext,
                    tag,
                    aad,
                )
                .await
            }
        }
    }
}
/// Pairs a single encryptor with a [`MultiKeyDecryptor`].
///
/// Encryption is forwarded to `encryptor`, and decryption to `decryptor`.
pub struct MultiKeyCipher<E> {
    encryptor: E,
    decryptor: MultiKeyDecryptor,
}
impl<E> MultiKeyCipher<E> {
    /// Creates a cipher that encrypts with `encryptor` and decrypts with `decryptor`.
    // `#[must_use]` matches `MultiKeyDecryptor::new`: discarding the built value is a bug.
    #[must_use]
    pub fn new(encryptor: E, decryptor: MultiKeyDecryptor) -> Self {
        Self {
            encryptor,
            decryptor,
        }
    }
}
impl<E: AeadEncryptor> AeadEncryptor for MultiKeyCipher<E> {
type Error = E::Error;
fn enc_algorithm(&self) -> Cow<'_, str> {
self.encryptor.enc_algorithm()
}
fn key_id(&self) -> Option<Cow<'_, str>> {
self.encryptor.key_id()
}
async fn encrypt(&self, plaintext: &[u8], aad: &[u8]) -> Result<AeadOutput, Self::Error> {
self.encryptor.encrypt(plaintext, aad).await
}
}
impl<E: AeadEncryptor> AeadDecryptor for MultiKeyCipher<E> {
type Error = MultiKeyDecryptorError;
fn cipher_match(&self, m: &CipherMatch<'_>) -> Option<KeyMatchStrength> {
self.decryptor.cipher_match(m)
}
async fn decrypt(
&self,
cipher_match: Option<&CipherMatch<'_>>,
nonce: &[u8],
ciphertext: &[u8],
tag: &[u8],
aad: &[u8],
) -> Result<Vec<u8>, Self::Error> {
self.decryptor
.decrypt(cipher_match, nonce, ciphertext, tag, aad)
.await
}
}