use std::borrow::Cow;
use std::sync::Arc;
use bon::bon;
use google_cloud_kms_v1::{
client::KeyManagementService, model::crypto_key_version::CryptoKeyVersionAlgorithm,
};
use huskarl_core::crypto::KeyMatchStrength;
use huskarl_core::crypto::cipher::{
AeadCipherSelector, AeadDecryptor, AeadEncryptor, AeadOutput, BoxedAeadDecryptor, CipherMatch,
MultiKeyDecryptor, MultiKeyDecryptorError,
};
use snafu::prelude::*;
use super::super::version::{self, VersionStrategy};
use super::setup;
use super::{
GetCryptoKeyVersionSnafu, ListCryptoKeyVersionsSnafu, NoEnabledCryptoKeyVersionsSnafu,
ResolveVersionSnafu, UnsupportedAlgorithmSnafu,
};
pub use super::{KeyError, SetupError};
/// Caller-supplied function that turns a KMS crypto key version ID into a key ID (`kid`).
///
/// It is applied to each version ID while key versions are being resolved; the resulting
/// string is stored on the key version. Held behind an [`Arc`] with `Send + Sync` bounds so
/// one mapper can be shared freely.
type KidMapper = Arc<dyn Fn(&str) -> String + Sync + Send>;
/// Errors produced when encrypting through a KMS key version.
///
/// Returned by the [`AeadEncryptor`] impls of `KeyVersion`, `EncryptionKey` and `CipherKey`.
// NOTE: variants use `//` comments on purpose. snafu turns a variant's `///` doc comment
// into its `Display` text, so adding one would change the error message.
#[derive(Debug, Snafu)]
#[non_exhaustive]
pub enum EncryptionError {
// The KMS `RawEncrypt` call failed; the client error is kept as the source.
RawEncrypt {
source: google_cloud_kms_v1::Error,
},
// KMS answered with a key version name different from the one the request named.
MismatchedKeyInfo,
}
/// Errors produced when decrypting through a single KMS key version.
// NOTE: variants use `//` comments on purpose. snafu turns a variant's `///` doc comment
// into its `Display` text, so adding one would change the error message.
#[derive(Debug, Snafu)]
#[non_exhaustive]
pub enum DecryptionError {
// The KMS `RawDecrypt` call failed; the client error is kept as the source.
RawDecrypt {
source: google_cloud_kms_v1::Error,
},
}
impl huskarl_core::Error for EncryptionError {
    /// Reports whether retrying the failed encryption could help.
    ///
    /// A key-name mismatch is deterministic and never retryable. A failed KMS call is
    /// retryable only when the client error reports a timeout or exhaustion.
    fn is_retryable(&self) -> bool {
        match self {
            Self::MismatchedKeyInfo => false,
            Self::RawEncrypt { source: kms_err } => kms_err.is_timeout() || kms_err.is_exhausted(),
        }
    }
}
impl huskarl_core::Error for DecryptionError {
    /// Reports whether retrying the failed decryption could help: true only when the KMS
    /// client error reports a timeout or exhaustion.
    fn is_retryable(&self) -> bool {
        // The enum has a single variant, so this pattern is irrefutable inside the crate.
        // Adding a variant turns it into a compile error and forces a real `match`.
        let Self::RawDecrypt { source: kms_err } = self;
        kms_err.is_timeout() || kms_err.is_exhausted()
    }
}
/// One specific Cloud KMS crypto key version, used for raw AES-GCM encryption and decryption.
///
/// Created by `KeyVersion::builder` or by this module's key resolvers. Implements
/// [`AeadEncryptor`] / [`AeadDecryptor`] by calling the KMS `RawEncrypt` / `RawDecrypt` RPCs.
#[derive(Debug, Clone)]
pub struct KeyVersion {
/// Client used for every KMS call made through this version.
kms_client: KeyManagementService,
/// KMS resource name of the crypto key version; sent with each request and compared against
/// the name KMS reports back when encrypting.
resource_name: String,
/// JWE `enc` identifier for the version's algorithm (`"A128GCM"` or `"A256GCM"`).
enc_algorithm: String,
/// Key ID produced by the optional kid mapper; `None` when no mapper was configured.
key_id: Option<String>,
}
#[bon]
impl KeyVersion {
    /// Builds a [`KeyVersion`] for one explicitly named KMS crypto key version.
    ///
    /// `resource_name` is fetched from KMS to learn the version's algorithm. When
    /// `with_kid_from_key_version` is set, it is applied to the version ID to produce the key
    /// ID this version reports.
    ///
    /// # Errors
    ///
    /// Returns a [`SetupError`] if the version cannot be fetched, or if its algorithm is not
    /// AES-128-GCM / AES-256-GCM.
    #[builder(finish_fn = build)]
    pub async fn builder(
        #[builder(into)] resource_name: String,
        kms_client: KeyManagementService,
        #[builder(with = |f: impl Fn(&str) -> String + Send + Sync + 'static| Arc::new(f))]
        with_kid_from_key_version: Option<KidMapper>,
    ) -> Result<Self, SetupError> {
        let key_version =
            build_key_version(resource_name, kms_client, with_kid_from_key_version).await?;
        Ok(key_version)
    }
}
impl AeadCipherSelector for KeyVersion {
type Encryptor = Self;
fn select_cipher(&self) -> Self::Encryptor {
self.clone()
}
}
impl AeadEncryptor for KeyVersion {
type Error = EncryptionError;
fn enc_algorithm(&self) -> Cow<'_, str> {
Cow::Borrowed(&self.enc_algorithm)
}
fn key_id(&self) -> Option<Cow<'_, str>> {
self.key_id.as_deref().map(Cow::Borrowed)
}
async fn encrypt(&self, plaintext: &[u8], aad: &[u8]) -> Result<AeadOutput, Self::Error> {
let response = self
.kms_client
.raw_encrypt()
.set_name(&self.resource_name)
.set_plaintext(plaintext.to_vec())
.set_additional_authenticated_data(aad.to_vec())
.send()
.await
.context(RawEncryptSnafu)?;
ensure!(response.name == self.resource_name, MismatchedKeyInfoSnafu);
let tag_length = usize::try_from(response.tag_length).unwrap_or(0);
let ct_with_tag = response.ciphertext;
let split_at = ct_with_tag.len().saturating_sub(tag_length);
let ciphertext = ct_with_tag[..split_at].to_vec();
let tag = ct_with_tag[split_at..].to_vec();
let nonce = response.initialization_vector.to_vec();
Ok(AeadOutput {
nonce,
ciphertext,
tag,
})
}
}
impl AeadDecryptor for KeyVersion {
    type Error = DecryptionError;

    /// Decides whether this version can decrypt the message described by `m`.
    ///
    /// Returns `None` in two cases:
    /// - the message names an `enc` different from this version's algorithm;
    /// - both the message and this version carry key IDs and they disagree.
    ///
    /// Matching key IDs yield [`KeyMatchStrength::ByKeyId`]. When either side has no key ID,
    /// the result is [`KeyMatchStrength::ByAlgorithm`].
    fn cipher_match(&self, m: &CipherMatch<'_>) -> Option<KeyMatchStrength> {
        if m.enc.is_some_and(|enc| enc != self.enc_algorithm) {
            return None;
        }
        match (m.kid, self.key_id.as_deref()) {
            (Some(theirs), Some(ours)) => (theirs == ours).then_some(KeyMatchStrength::ByKeyId),
            _ => Some(KeyMatchStrength::ByAlgorithm),
        }
    }

    /// Decrypts with the KMS `RawDecrypt` RPC.
    ///
    /// KMS expects the tag appended to the ciphertext, so the two are concatenated and the
    /// tag length is sent alongside. `_cipher_match` is unused because this is a single
    /// version.
    ///
    /// NOTE(review): an empty `tag` is sent as `tag_length = 0`. Confirm how KMS interprets
    /// that value (it may fall back to a default tag length).
    ///
    /// # Errors
    ///
    /// `RawDecrypt` if the RPC fails.
    async fn decrypt(
        &self,
        _cipher_match: Option<&CipherMatch<'_>>,
        nonce: &[u8],
        ciphertext: &[u8],
        tag: &[u8],
        aad: &[u8],
    ) -> Result<Vec<u8>, Self::Error> {
        // The `16` fallback is only reached for tags longer than `i32::MAX` bytes.
        let tag_length = i32::try_from(tag.len()).unwrap_or(16);
        let sealed = [ciphertext, tag].concat();

        self.kms_client
            .raw_decrypt()
            .set_name(&self.resource_name)
            .set_ciphertext(sealed)
            .set_initialization_vector(nonce.to_vec())
            .set_additional_authenticated_data(aad.to_vec())
            .set_tag_length(tag_length)
            .send()
            .await
            .map(|reply| reply.plaintext.to_vec())
            .context(RawDecryptSnafu)
    }
}
/// Encryption-only key bound to one version of a KMS crypto key, chosen at build time by a
/// [`VersionStrategy`].
#[derive(Debug, Clone)]
pub struct EncryptionKey {
/// The resolved version that all encryption is delegated to.
key_version: KeyVersion,
}
#[bon]
impl EncryptionKey {
    /// Builds an [`EncryptionKey`] by resolving which version of `key_name` to encrypt with.
    ///
    /// The version is chosen by `strategy` (its `Default` when not set). When
    /// `with_kid_from_key_version` is set, it maps that version's ID to the reported key ID.
    ///
    /// # Errors
    ///
    /// Returns a [`KeyError`] if no version can be resolved, the version cannot be fetched,
    /// or its algorithm is not AES-128-GCM / AES-256-GCM.
    #[builder(finish_fn = build)]
    pub async fn builder(
        #[builder(into)] key_name: String,
        kms_client: KeyManagementService,
        #[builder(default)] strategy: VersionStrategy,
        #[builder(with = |f: impl Fn(&str) -> String + Send + Sync + 'static| Arc::new(f))]
        with_kid_from_key_version: Option<KidMapper>,
    ) -> Result<Self, KeyError> {
        let kid_mapper = with_kid_from_key_version.as_ref();
        resolve_encryption_key_version(&key_name, &kms_client, &strategy, kid_mapper)
            .await
            .map(|key_version| Self { key_version })
    }
}
impl AeadCipherSelector for EncryptionKey {
type Encryptor = KeyVersion;
fn select_cipher(&self) -> KeyVersion {
self.key_version.clone()
}
}
impl AeadEncryptor for EncryptionKey {
type Error = EncryptionError;
fn enc_algorithm(&self) -> Cow<'_, str> {
self.key_version.enc_algorithm()
}
fn key_id(&self) -> Option<Cow<'_, str>> {
self.key_version.key_id()
}
async fn encrypt(&self, plaintext: &[u8], aad: &[u8]) -> Result<AeadOutput, Self::Error> {
self.key_version.encrypt(plaintext, aad).await
}
}
/// Decryption-only key spanning the enabled, supported versions of a KMS crypto key.
///
/// The decryptor sits behind an [`Arc`], so clones share it.
#[derive(Clone)]
pub struct DecryptionKey {
/// Multi-version decryptor built from the resolved key versions.
decryptor: Arc<MultiKeyDecryptor>,
}
impl std::fmt::Debug for DecryptionKey {
    /// Prints only the type name with a `..` marker.
    ///
    /// NOTE(review): presumably hand-written because `MultiKeyDecryptor` has no `Debug`
    /// impl; confirm.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("DecryptionKey");
        out.finish_non_exhaustive()
    }
}
#[bon]
impl DecryptionKey {
    /// Builds a [`DecryptionKey`] over the enabled versions of `key_name`.
    ///
    /// Versions are listed by name in descending order, capped at `max_versions` when set.
    /// Each one with a supported algorithm becomes a decryptor. When
    /// `with_kid_from_key_version` is set, it assigns each version's key ID.
    ///
    /// # Errors
    ///
    /// Returns a [`KeyError`] if listing fails or the key has no enabled versions.
    #[builder(finish_fn = build)]
    pub async fn builder(
        #[builder(into)] key_name: String,
        kms_client: KeyManagementService,
        #[builder(with = |f: impl Fn(&str) -> String + Send + Sync + 'static| Arc::new(f))]
        with_kid_from_key_version: Option<KidMapper>,
        max_versions: Option<usize>,
    ) -> Result<Self, KeyError> {
        let kid_mapper = with_kid_from_key_version.as_ref();
        let key_versions =
            resolve_decryption_key_versions(&key_name, &kms_client, kid_mapper, max_versions)
                .await?;
        let boxed = key_versions
            .into_iter()
            .map(BoxedAeadDecryptor::new)
            .collect();
        Ok(Self {
            decryptor: Arc::new(MultiKeyDecryptor::new(boxed)),
        })
    }
}
impl AeadDecryptor for DecryptionKey {
type Error = MultiKeyDecryptorError;
fn cipher_match(&self, m: &CipherMatch<'_>) -> Option<KeyMatchStrength> {
self.decryptor.cipher_match(m)
}
async fn decrypt(
&self,
cipher_match: Option<&CipherMatch<'_>>,
nonce: &[u8],
ciphertext: &[u8],
tag: &[u8],
aad: &[u8],
) -> Result<Vec<u8>, Self::Error> {
self.decryptor
.decrypt(cipher_match, nonce, ciphertext, tag, aad)
.await
}
}
/// Key that both encrypts (with one resolved version) and decrypts (with the enabled versions)
/// of a single KMS crypto key.
#[derive(Clone)]
pub struct CipherKey {
/// Encrypting half: the version chosen by the build-time strategy.
encryption: EncryptionKey,
/// Decrypting half: the listed enabled versions.
decryption: DecryptionKey,
}
impl std::fmt::Debug for CipherKey {
    /// Shows the encryption half; the decryption half is elided as `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CipherKey");
        out.field("encryption", &self.encryption);
        out.finish_non_exhaustive()
    }
}
#[bon]
impl CipherKey {
    /// Builds a [`CipherKey`] that encrypts with one version of `key_name` and decrypts with
    /// its enabled versions.
    ///
    /// Two lookups run concurrently:
    /// - the encryption version, chosen by `strategy`;
    /// - the decryption versions, listed by name in descending order and capped at
    ///   `max_versions` when set.
    ///
    /// If the encryption version is missing from the decryption list (for example because
    /// `max_versions` cut the list short), it is appended. The resulting key can therefore
    /// always decrypt what it encrypts.
    ///
    /// # Errors
    ///
    /// Returns the first [`KeyError`] produced by either resolution.
    #[builder(finish_fn = build)]
    pub async fn builder(
        #[builder(into)]
        key_name: String,
        kms_client: KeyManagementService,
        #[builder(default)]
        strategy: VersionStrategy,
        #[builder(with = |f: impl Fn(&str) -> String + Send + Sync + 'static| Arc::new(f))]
        with_kid_from_key_version: Option<KidMapper>,
        max_versions: Option<usize>,
    ) -> Result<Self, KeyError> {
        let kid_mapper = with_kid_from_key_version.as_ref();
        let (enc_kv, mut dec_kvs) = futures_util::try_join!(
            resolve_encryption_key_version(&key_name, &kms_client, &strategy, kid_mapper),
            resolve_decryption_key_versions(&key_name, &kms_client, kid_mapper, max_versions),
        )?;

        // The two resolutions are independent. Without this step a CipherKey could produce
        // ciphertexts that its own decryption half has no key for.
        if !dec_kvs
            .iter()
            .any(|kv| kv.resource_name == enc_kv.resource_name)
        {
            dec_kvs.push(enc_kv.clone());
        }

        let encryption = EncryptionKey {
            key_version: enc_kv,
        };
        let decryption = DecryptionKey {
            decryptor: Arc::new(MultiKeyDecryptor::new(
                dec_kvs.into_iter().map(BoxedAeadDecryptor::new).collect(),
            )),
        };
        Ok(Self {
            encryption,
            decryption,
        })
    }
}
impl AeadCipherSelector for CipherKey {
type Encryptor = KeyVersion;
fn select_cipher(&self) -> KeyVersion {
self.encryption.key_version.clone()
}
}
impl AeadEncryptor for CipherKey {
type Error = EncryptionError;
fn enc_algorithm(&self) -> Cow<'_, str> {
self.encryption.enc_algorithm()
}
fn key_id(&self) -> Option<Cow<'_, str>> {
self.encryption.key_id()
}
async fn encrypt(&self, plaintext: &[u8], aad: &[u8]) -> Result<AeadOutput, Self::Error> {
self.encryption.encrypt(plaintext, aad).await
}
}
impl AeadDecryptor for CipherKey {
type Error = MultiKeyDecryptorError;
fn cipher_match(&self, m: &CipherMatch<'_>) -> Option<KeyMatchStrength> {
self.decryption.cipher_match(m)
}
async fn decrypt(
&self,
cipher_match: Option<&CipherMatch<'_>>,
nonce: &[u8],
ciphertext: &[u8],
tag: &[u8],
aad: &[u8],
) -> Result<Vec<u8>, Self::Error> {
self.decryption
.decrypt(cipher_match, nonce, ciphertext, tag, aad)
.await
}
}
/// Resolves the version of `key_name` selected by `strategy` into an encrypting [`KeyVersion`].
///
/// The chosen version is fetched from KMS to learn its algorithm. If KMS reports an empty
/// version name, the locally built `{key_name}/cryptoKeyVersions/{id}` name is kept.
/// `kid_mapper`, if present, derives the key ID from the version ID.
///
/// # Errors
///
/// Fails through the `ResolveVersion`, `GetCryptoKeyVersion` or `UnsupportedAlgorithm`
/// contexts when:
/// - the strategy cannot pick a version;
/// - the version cannot be fetched;
/// - its algorithm is not AES-128-GCM / AES-256-GCM.
async fn resolve_encryption_key_version(
    key_name: &str,
    kms_client: &KeyManagementService,
    strategy: &VersionStrategy,
    kid_mapper: Option<&KidMapper>,
) -> Result<KeyVersion, KeyError> {
    let version_id = version::resolve_version(key_name, strategy, kms_client)
        .await
        .context(ResolveVersionSnafu)?;
    let requested_name = format!("{key_name}/cryptoKeyVersions/{version_id}");

    let fetched = kms_client
        .get_crypto_key_version()
        .set_name(&requested_name)
        .send()
        .await
        .context(GetCryptoKeyVersionSnafu)?;

    let resource_name = Some(fetched.name)
        .filter(|name| !name.is_empty())
        .unwrap_or(requested_name);

    // The mapper runs before the algorithm check, exactly as the key ID is derived for every
    // fetched version.
    let short_id = version::version_id_from_resource_name(&resource_name);
    let key_id = kid_mapper.map(|mapper| mapper(short_id));

    let enc_algorithm = get_enc_algorithm(&fetched.algorithm).context(UnsupportedAlgorithmSnafu {
        algorithm: fetched.algorithm,
    })?;

    Ok(KeyVersion {
        kms_client: kms_client.clone(),
        resource_name,
        enc_algorithm: enc_algorithm.to_string(),
        key_id,
    })
}
/// Lists the enabled versions of `key_name` and wraps each one with a supported algorithm
/// (AES-128-GCM or AES-256-GCM) as a decrypting [`KeyVersion`].
///
/// Versions are requested in descending name order and capped at `max_versions` when set.
/// `kid_mapper`, if present, derives each version's key ID from its version ID.
///
/// # Errors
///
/// - `ListCryptoKeyVersions` context if listing the versions fails.
/// - `NoEnabledCryptoKeyVersions` context if the key has no enabled versions.
/// - `UnsupportedAlgorithm` context, reporting the first listed version's algorithm, if
///   enabled versions exist but none of them uses a supported algorithm.
async fn resolve_decryption_key_versions(
    key_name: &str,
    kms_client: &KeyManagementService,
    kid_mapper: Option<&KidMapper>,
    max_versions: Option<usize>,
) -> Result<Vec<KeyVersion>, KeyError> {
    let raw =
        version::list_enabled_kms_versions(kms_client, key_name, max_versions, Some("name desc"))
            .await
            .context(ListCryptoKeyVersionsSnafu)?;
    ensure!(!raw.is_empty(), NoEnabledCryptoKeyVersionsSnafu);

    let versions: Vec<KeyVersion> = raw
        .iter()
        .filter_map(|v| {
            // Versions whose algorithm has no JWE `enc` mapping are skipped.
            let enc_algorithm = get_enc_algorithm(&v.algorithm)?;
            let vid = version::version_id_from_resource_name(&v.name);
            let key_id = kid_mapper.map(|f| f(vid));
            Some(KeyVersion {
                kms_client: kms_client.clone(),
                resource_name: v.name.clone(),
                enc_algorithm: enc_algorithm.to_string(),
                key_id,
            })
        })
        .collect();

    // Checking `raw` alone is not enough. If every enabled version was filtered out above,
    // callers would otherwise build a decryptor that holds no keys at all.
    ensure!(
        !versions.is_empty(),
        UnsupportedAlgorithmSnafu {
            algorithm: raw[0].algorithm.clone(),
        }
    );
    Ok(versions)
}
/// Fetches one explicitly named crypto key version and wraps it as a [`KeyVersion`].
///
/// The name KMS reports is preferred, and the caller's `resource_name` is the fallback when
/// that name is empty. `with_kid_from_key_version`, if present, is applied to the version ID
/// to produce the key ID.
///
/// # Errors
///
/// Fails through the `setup::GetCryptoKeyVersion` context if the fetch fails. Fails through
/// `setup::UnsupportedAlgorithm` if the algorithm is not AES-128-GCM / AES-256-GCM.
async fn build_key_version(
    resource_name: String,
    kms_client: KeyManagementService,
    with_kid_from_key_version: Option<KidMapper>,
) -> Result<KeyVersion, SetupError> {
    let fetched = kms_client
        .get_crypto_key_version()
        .set_name(&resource_name)
        .send()
        .await
        .context(setup::GetCryptoKeyVersionSnafu)?;

    let resource_name = Some(fetched.name)
        .filter(|reported| !reported.is_empty())
        .unwrap_or(resource_name);

    let short_id = version::version_id_from_resource_name(&resource_name);
    let key_id = with_kid_from_key_version.map(|mapper| mapper(short_id));

    let enc_algorithm =
        get_enc_algorithm(&fetched.algorithm).context(setup::UnsupportedAlgorithmSnafu {
            algorithm: fetched.algorithm,
        })?;

    Ok(KeyVersion {
        kms_client,
        resource_name,
        enc_algorithm: enc_algorithm.to_string(),
        key_id,
    })
}
fn get_enc_algorithm(algorithm: &CryptoKeyVersionAlgorithm) -> Option<&'static str> {
use CryptoKeyVersionAlgorithm::{Aes128Gcm, Aes256Gcm};
match algorithm {
Aes128Gcm => Some("A128GCM"),
Aes256Gcm => Some("A256GCM"),
_ => None,
}
}