use std::borrow::Cow;
use std::sync::Arc;
use bon::bon;
use google_cloud_kms_v1::{
client::KeyManagementService, model::crypto_key_version::CryptoKeyVersionAlgorithm,
};
use huskarl_core::crypto::KeyMatchStrength;
use huskarl_core::crypto::cipher::{
AeadCipherSelector, AeadDecryptor, AeadEncryptor, AeadOutput, CipherMatch, DecryptError,
MultiKeyDecryptor,
};
use huskarl_core::platform::MaybeSendBoxFuture;
use snafu::prelude::*;
use super::super::version::{self, VersionStrategy};
use super::setup;
use super::{
GetCryptoKeyVersionSnafu, ListCryptoKeyVersionsSnafu, NoEnabledCryptoKeyVersionsSnafu,
ResolveVersionSnafu, UnsupportedAlgorithmSnafu,
};
pub use super::{KeyError, SetupError};
/// Caller-supplied mapping from a KMS key version id to the `kid` advertised for it.
///
/// Wrapped in an `Arc` so a single mapper can be handed by reference to several resolution
/// helpers (e.g. the concurrent encryption/decryption lookups in `CipherKey::builder`).
type KidMapper = Arc<dyn Fn(&str) -> String + Sync + Send>;
#[derive(Debug, Snafu)]
#[non_exhaustive]
pub enum EncryptionError {
RawEncrypt {
source: google_cloud_kms_v1::Error,
},
MismatchedKeyInfo,
}
#[derive(Debug, Snafu)]
#[non_exhaustive]
pub enum DecryptionError {
RawDecrypt {
source: google_cloud_kms_v1::Error,
},
}
impl EncryptionError {
#[must_use]
pub fn is_retryable(&self) -> bool {
match self {
EncryptionError::RawEncrypt { source } => source.is_timeout() || source.is_exhausted(),
EncryptionError::MismatchedKeyInfo => false,
}
}
}
impl From<EncryptionError> for huskarl_core::Error {
    /// Wraps an [`EncryptionError`]: retryable failures become retryable transport errors,
    /// everything else is classified as a crypto error.
    fn from(err: EncryptionError) -> Self {
        if err.is_retryable() {
            return Self::new(huskarl_core::ErrorKind::Transport { retryable: true }, err);
        }
        Self::new(huskarl_core::ErrorKind::Crypto, err)
    }
}
impl DecryptionError {
#[must_use]
pub fn is_retryable(&self) -> bool {
match self {
DecryptionError::RawDecrypt { source } => source.is_timeout() || source.is_exhausted(),
}
}
}
impl From<DecryptionError> for huskarl_core::Error {
    /// Wraps a [`DecryptionError`]: retryable failures become retryable transport errors,
    /// everything else is classified as a crypto error.
    fn from(err: DecryptionError) -> Self {
        if err.is_retryable() {
            return Self::new(huskarl_core::ErrorKind::Transport { retryable: true }, err);
        }
        Self::new(huskarl_core::ErrorKind::Crypto, err)
    }
}
/// A single Cloud KMS crypto key version used directly as an AEAD content cipher.
///
/// Encryption and decryption go through the KMS `RawEncrypt` / `RawDecrypt` requests
/// addressed to `resource_name`. The builders in this module only produce versions whose
/// algorithm maps to `A128GCM` or `A256GCM`.
#[derive(Debug, Clone)]
pub struct KeyVersion {
    /// Client used for every KMS request made on behalf of this version.
    kms_client: KeyManagementService,
    /// Full resource name of the version (`.../cryptoKeyVersions/{id}`); also compared
    /// with the `name` echoed back by `RawEncrypt`.
    resource_name: String,
    /// JOSE `enc` identifier for the version's algorithm (e.g. `A256GCM`).
    enc_algorithm: String,
    /// Optional `kid` for this version, produced by the caller's kid mapper.
    key_id: Option<String>,
}
#[bon]
impl KeyVersion {
    /// Builds a [`KeyVersion`] for an explicit crypto key version resource name.
    ///
    /// Fetches the version's metadata once to learn its algorithm and canonical name.
    ///
    /// Parameters:
    /// - `resource_name`: full `.../cryptoKeyVersions/{id}` name of the version to use.
    /// - `kms_client`: client used for the metadata lookup and all later KMS requests.
    /// - `with_kid_from_key_version`: optional mapper from the version id to the `kid`
    ///   advertised for this version.
    ///
    /// # Errors
    ///
    /// Returns a [`SetupError`] if the metadata lookup fails or the version's algorithm is
    /// not one of the supported AES-GCM variants.
    #[builder(finish_fn = build)]
    pub async fn builder(
        #[builder(into)]
        resource_name: String,
        kms_client: KeyManagementService,
        #[builder(with = |f: impl Fn(&str) -> String + Send + Sync + 'static| Arc::new(f))]
        with_kid_from_key_version: Option<KidMapper>,
    ) -> Result<Self, SetupError> {
        build_key_version(resource_name, kms_client, with_kid_from_key_version).await
    }
}
impl AeadCipherSelector for KeyVersion {
    /// A single version is its own (and only) encryption candidate.
    fn select_cipher(&self) -> Arc<dyn AeadEncryptor> {
        let own_copy = Self::clone(self);
        Arc::new(own_copy)
    }
}
impl AeadEncryptor for KeyVersion {
    /// The JOSE `enc` identifier resolved for this version.
    fn enc_algorithm(&self) -> Cow<'_, str> {
        Cow::from(self.enc_algorithm.as_str())
    }
    /// The `kid` produced by the kid mapper, if one was configured.
    fn key_id(&self) -> Option<Cow<'_, str>> {
        let kid = self.key_id.as_ref()?;
        Some(Cow::Borrowed(kid.as_str()))
    }
    /// Encrypts `plaintext` with KMS `RawEncrypt`, binding `aad` as additional data.
    ///
    /// KMS returns the authentication tag appended to the ciphertext and reports its length
    /// in `tag_length`; the two are split apart here for the caller.
    ///
    /// NOTE(review): a negative `tag_length` is treated as 0 (empty tag), and one larger than
    /// the returned buffer makes the whole buffer the tag; neither case is rejected.
    ///
    /// # Errors
    ///
    /// Fails if the request fails, or if the response names a different key version.
    fn encrypt<'a>(
        &'a self,
        plaintext: &'a [u8],
        aad: &'a [u8],
    ) -> MaybeSendBoxFuture<'a, Result<AeadOutput, huskarl_core::Error>> {
        Box::pin(async move {
            let request = self
                .kms_client
                .raw_encrypt()
                .set_name(&self.resource_name)
                .set_plaintext(plaintext.to_vec())
                .set_additional_authenticated_data(aad.to_vec());
            let response = request.send().await.context(RawEncryptSnafu)?;
            // Refuse output that KMS attributes to a different key version.
            if response.name != self.resource_name {
                return Err(EncryptionError::MismatchedKeyInfo.into());
            }
            let tag_len = usize::try_from(response.tag_length).unwrap_or(0);
            let sealed: &[u8] = &response.ciphertext;
            // `saturating_sub` keeps the split point inside the buffer.
            let (ct_part, tag_part) = sealed.split_at(sealed.len().saturating_sub(tag_len));
            Ok(AeadOutput {
                nonce: response.initialization_vector.to_vec(),
                ciphertext: ct_part.to_vec(),
                tag: tag_part.to_vec(),
            })
        })
    }
}
impl AeadDecryptor for KeyVersion {
    /// Decides whether this version can handle a message described by `m`.
    ///
    /// - A requested `enc` that differs from ours rules this version out.
    /// - If both the message and this version carry a `kid`, they must be equal, giving a
    ///   key-id match.
    /// - Otherwise the match is by algorithm only.
    fn cipher_match(&self, m: &CipherMatch<'_>) -> Option<KeyMatchStrength> {
        if m.enc.is_some_and(|enc| enc != self.enc_algorithm) {
            return None;
        }
        match (m.kid, self.key_id.as_deref()) {
            (Some(requested), Some(ours)) => {
                (requested == ours).then_some(KeyMatchStrength::ByKeyId)
            }
            _ => Some(KeyMatchStrength::ByAlgorithm),
        }
    }
    /// Decrypts through KMS `RawDecrypt`.
    ///
    /// KMS takes the tag appended to the ciphertext, so the two are concatenated and the
    /// tag's length is passed separately (falling back to 16 if it does not fit an `i32`).
    ///
    /// NOTE(review): an empty `tag` sends `tag_length = 0`; confirm how KMS interprets that
    /// (it may apply the algorithm default and read the tag from the end of `ciphertext`).
    fn decrypt<'a>(
        &'a self,
        _cipher_match: Option<&'a CipherMatch<'a>>,
        nonce: &'a [u8],
        ciphertext: &'a [u8],
        tag: &'a [u8],
        aad: &'a [u8],
    ) -> MaybeSendBoxFuture<'a, Result<Vec<u8>, DecryptError>> {
        Box::pin(async move {
            let sealed = [ciphertext, tag].concat();
            let tag_length = i32::try_from(tag.len()).unwrap_or(16);
            let response = self
                .kms_client
                .raw_decrypt()
                .set_name(&self.resource_name)
                .set_ciphertext(sealed)
                .set_initialization_vector(nonce.to_vec())
                .set_additional_authenticated_data(aad.to_vec())
                .set_tag_length(tag_length)
                .send()
                .await
                .context(RawDecryptSnafu)
                .map_err(huskarl_core::Error::from)?;
            Ok(response.plaintext.to_vec())
        })
    }
}
/// Encryption handle pinned to the single key version chosen by a [`VersionStrategy`]
/// when the builder runs.
#[derive(Debug, Clone)]
pub struct EncryptionKey {
    /// The resolved version that performs every encryption.
    key_version: KeyVersion,
}
#[bon]
impl EncryptionKey {
    /// Resolves the version to encrypt with and builds an [`EncryptionKey`] around it.
    ///
    /// Parameters:
    /// - `key_name`: crypto key resource name (without the `/cryptoKeyVersions/...` suffix).
    /// - `kms_client`: client used for resolution and for every later `RawEncrypt`.
    /// - `strategy`: how to pick the version; defaults to `VersionStrategy::default()`.
    /// - `with_kid_from_key_version`: optional mapper from the version id to its `kid`.
    ///
    /// # Errors
    ///
    /// Returns a [`KeyError`] if the version cannot be resolved or fetched, or if its
    /// algorithm is unsupported.
    #[builder(finish_fn = build)]
    pub async fn builder(
        #[builder(into)]
        key_name: String,
        kms_client: KeyManagementService,
        #[builder(default)]
        strategy: VersionStrategy,
        #[builder(with = |f: impl Fn(&str) -> String + Send + Sync + 'static| Arc::new(f))]
        with_kid_from_key_version: Option<KidMapper>,
    ) -> Result<Self, KeyError> {
        let mapper = with_kid_from_key_version.as_ref();
        resolve_encryption_key_version(&key_name, &kms_client, &strategy, mapper)
            .await
            .map(|key_version| Self { key_version })
    }
}
impl AeadCipherSelector for EncryptionKey {
fn select_cipher(&self) -> Arc<dyn AeadEncryptor> {
Arc::new(self.key_version.clone())
}
}
impl AeadEncryptor for EncryptionKey {
fn enc_algorithm(&self) -> Cow<'_, str> {
self.key_version.enc_algorithm()
}
fn key_id(&self) -> Option<Cow<'_, str>> {
self.key_version.key_id()
}
fn encrypt<'a>(
&'a self,
plaintext: &'a [u8],
aad: &'a [u8],
) -> MaybeSendBoxFuture<'a, Result<AeadOutput, huskarl_core::Error>> {
self.key_version.encrypt(plaintext, aad)
}
}
/// Decryption handle over the key versions resolved when the builder runs.
///
/// `Debug` is implemented by hand (presumably because `MultiKeyDecryptor` does not
/// implement it — confirm).
#[derive(Clone)]
pub struct DecryptionKey {
    /// Shared multi-version decryptor built from one [`KeyVersion`] per resolved version.
    decryptor: Arc<MultiKeyDecryptor>,
}
impl std::fmt::Debug for DecryptionKey {
    /// Prints `DecryptionKey { .. }`; the decryptor set is intentionally elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("DecryptionKey");
        out.finish_non_exhaustive()
    }
}
#[bon]
impl DecryptionKey {
    /// Resolves the key's decryption versions and builds a [`DecryptionKey`] over them.
    ///
    /// Parameters:
    /// - `key_name`: crypto key resource name whose versions are listed.
    /// - `kms_client`: client used for listing and for every later `RawDecrypt`.
    /// - `with_kid_from_key_version`: optional mapper from a version id to its `kid`.
    /// - `max_versions`: optional cap forwarded to the version listing.
    ///
    /// # Errors
    ///
    /// Returns a [`KeyError`] if the versions cannot be resolved
    /// (see `resolve_decryption_key_versions`).
    #[builder(finish_fn = build)]
    pub async fn builder(
        #[builder(into)]
        key_name: String,
        kms_client: KeyManagementService,
        #[builder(with = |f: impl Fn(&str) -> String + Send + Sync + 'static| Arc::new(f))]
        with_kid_from_key_version: Option<KidMapper>,
        max_versions: Option<usize>,
    ) -> Result<Self, KeyError> {
        let mapper = with_kid_from_key_version.as_ref();
        let members =
            resolve_decryption_key_versions(&key_name, &kms_client, mapper, max_versions).await?;
        // Each resolved version becomes one candidate decryptor.
        let decryptor = MultiKeyDecryptor::new(
            members
                .into_iter()
                .map(|kv| -> Arc<dyn AeadDecryptor> { Arc::new(kv) })
                .collect(),
        );
        Ok(Self {
            decryptor: Arc::new(decryptor),
        })
    }
}
/// Both operations are forwarded unchanged to the shared multi-version decryptor.
impl AeadDecryptor for DecryptionKey {
    fn cipher_match(&self, m: &CipherMatch<'_>) -> Option<KeyMatchStrength> {
        self.decryptor.cipher_match(m)
    }
    fn decrypt<'a>(
        &'a self,
        cipher_match: Option<&'a CipherMatch<'a>>,
        nonce: &'a [u8],
        ciphertext: &'a [u8],
        tag: &'a [u8],
        aad: &'a [u8],
    ) -> MaybeSendBoxFuture<'a, Result<Vec<u8>, DecryptError>> {
        self.decryptor.decrypt(cipher_match, nonce, ciphertext, tag, aad)
    }
}
/// Combined handle: encrypts with one pinned version and decrypts with any resolved one.
#[derive(Clone)]
pub struct CipherKey {
    /// Encryption side, pinned to the version chosen by the [`VersionStrategy`].
    encryption: EncryptionKey,
    /// Decryption side, covering the versions resolved for decryption.
    decryption: DecryptionKey,
}
impl std::fmt::Debug for CipherKey {
    /// Shows the encryption side only; the decryption side is elided as `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CipherKey");
        out.field("encryption", &self.encryption);
        out.finish_non_exhaustive()
    }
}
#[bon]
impl CipherKey {
#[builder(finish_fn = build)]
pub async fn builder(
#[builder(into)]
key_name: String,
kms_client: KeyManagementService,
#[builder(default)]
strategy: VersionStrategy,
#[builder(with = |f: impl Fn(&str) -> String + Send + Sync + 'static| Arc::new(f))]
with_kid_from_key_version: Option<KidMapper>,
max_versions: Option<usize>,
) -> Result<Self, KeyError> {
let kid_mapper = with_kid_from_key_version.as_ref();
let (enc_kv, dec_kvs) = futures_util::try_join!(
resolve_encryption_key_version(&key_name, &kms_client, &strategy, kid_mapper),
resolve_decryption_key_versions(&key_name, &kms_client, kid_mapper, max_versions),
)?;
let encryption = EncryptionKey {
key_version: enc_kv,
};
let decryption = DecryptionKey {
decryptor: Arc::new(MultiKeyDecryptor::new(
dec_kvs
.into_iter()
.map(|kv| Arc::new(kv) as Arc<dyn AeadDecryptor>)
.collect(),
)),
};
Ok(Self {
encryption,
decryption,
})
}
}
impl AeadCipherSelector for CipherKey {
fn select_cipher(&self) -> Arc<dyn AeadEncryptor> {
Arc::new(self.encryption.key_version.clone())
}
}
impl AeadEncryptor for CipherKey {
fn enc_algorithm(&self) -> Cow<'_, str> {
self.encryption.enc_algorithm()
}
fn key_id(&self) -> Option<Cow<'_, str>> {
self.encryption.key_id()
}
fn encrypt<'a>(
&'a self,
plaintext: &'a [u8],
aad: &'a [u8],
) -> MaybeSendBoxFuture<'a, Result<AeadOutput, huskarl_core::Error>> {
self.encryption.encrypt(plaintext, aad)
}
}
/// Decryption is served by the multi-version decryptor behind the decryption side.
impl AeadDecryptor for CipherKey {
    fn cipher_match(&self, m: &CipherMatch<'_>) -> Option<KeyMatchStrength> {
        self.decryption.decryptor.cipher_match(m)
    }
    fn decrypt<'a>(
        &'a self,
        cipher_match: Option<&'a CipherMatch<'a>>,
        nonce: &'a [u8],
        ciphertext: &'a [u8],
        tag: &'a [u8],
        aad: &'a [u8],
    ) -> MaybeSendBoxFuture<'a, Result<Vec<u8>, DecryptError>> {
        self.decryption
            .decryptor
            .decrypt(cipher_match, nonce, ciphertext, tag, aad)
    }
}
/// Resolves the version to encrypt with and turns it into a [`KeyVersion`].
///
/// The version id comes from `strategy`. The version's metadata is then fetched to learn
/// its algorithm and, when reported, its canonical resource name.
///
/// # Errors
///
/// - `ResolveVersion` if the strategy cannot pick a version.
/// - `GetCryptoKeyVersion` if the metadata lookup fails.
/// - `UnsupportedAlgorithm` if the version is not AES-128-GCM / AES-256-GCM.
async fn resolve_encryption_key_version(
    key_name: &str,
    kms_client: &KeyManagementService,
    strategy: &VersionStrategy,
    kid_mapper: Option<&KidMapper>,
) -> Result<KeyVersion, KeyError> {
    let chosen = version::resolve_version(key_name, strategy, kms_client)
        .await
        .context(ResolveVersionSnafu)?;
    let requested_name = format!("{key_name}/cryptoKeyVersions/{chosen}");
    let metadata = kms_client
        .get_crypto_key_version()
        .set_name(&requested_name)
        .send()
        .await
        .context(GetCryptoKeyVersionSnafu)?;
    // Prefer the name the server reports; fall back to the one we asked for if it is blank.
    let resource_name = match metadata.name {
        reported if reported.is_empty() => requested_name,
        reported => reported,
    };
    // The kid mapper runs before the algorithm check, so it is invoked even for versions
    // that are then rejected.
    let version_id = version::version_id_from_resource_name(&resource_name);
    let key_id = kid_mapper.map(|mapper| mapper(version_id));
    let enc_algorithm = get_enc_algorithm(&metadata.algorithm).context(UnsupportedAlgorithmSnafu {
        algorithm: metadata.algorithm,
    })?;
    Ok(KeyVersion {
        kms_client: kms_client.clone(),
        resource_name,
        enc_algorithm: enc_algorithm.to_owned(),
        key_id,
    })
}
/// Lists the key's versions and turns every AES-GCM one into a decryption [`KeyVersion`].
///
/// Versions come from `version::list_enabled_kms_versions`, ordered by `"name desc"`.
/// `max_versions` is passed to that listing, so any cap applies *before* unsupported
/// algorithms are filtered out here; fewer than `max_versions` usable versions may result.
///
/// NOTE(review): if KMS orders `name` lexicographically, `.../cryptoKeyVersions/9` sorts
/// above `.../cryptoKeyVersions/10`. Combined with `max_versions`, that could drop the
/// newest versions — confirm the server's ordering semantics.
///
/// # Errors
///
/// - `ListCryptoKeyVersions` if the listing fails.
/// - `NoEnabledCryptoKeyVersions` if no listed version has a supported algorithm,
///   including when nothing is listed at all.
async fn resolve_decryption_key_versions(
    key_name: &str,
    kms_client: &KeyManagementService,
    kid_mapper: Option<&KidMapper>,
    max_versions: Option<usize>,
) -> Result<Vec<KeyVersion>, KeyError> {
    let raw =
        version::list_enabled_kms_versions(kms_client, key_name, max_versions, Some("name desc"))
            .await
            .context(ListCryptoKeyVersionsSnafu)?;
    let versions: Vec<KeyVersion> = raw
        .iter()
        .filter_map(|v| {
            // Versions whose algorithm has no JOSE `enc` mapping are skipped.
            let enc_algorithm = get_enc_algorithm(&v.algorithm)?;
            let vid = version::version_id_from_resource_name(&v.name);
            let key_id = kid_mapper.map(|f| f(vid));
            Some(KeyVersion {
                kms_client: kms_client.clone(),
                resource_name: v.name.clone(),
                enc_algorithm: enc_algorithm.to_string(),
                key_id,
            })
        })
        .collect();
    // Check *after* filtering. Otherwise a key whose enabled versions are all unsupported
    // would produce a decryptor with no members, one that can never decrypt anything.
    ensure!(!versions.is_empty(), NoEnabledCryptoKeyVersionsSnafu);
    Ok(versions)
}
/// Builds a [`KeyVersion`] from an explicit version resource name.
///
/// Fetches the version's metadata to learn its algorithm and, when reported, its canonical
/// name. The optional kid mapper is applied to the version id.
///
/// # Errors
///
/// Returns a [`SetupError`] if the metadata lookup fails or the algorithm is unsupported.
async fn build_key_version(
    resource_name: String,
    kms_client: KeyManagementService,
    with_kid_from_key_version: Option<KidMapper>,
) -> Result<KeyVersion, SetupError> {
    let metadata = kms_client
        .get_crypto_key_version()
        .set_name(&resource_name)
        .send()
        .await
        .context(setup::GetCryptoKeyVersionSnafu)?;
    // Prefer the name the server reports; fall back to the caller's if it is blank.
    let canonical_name = match metadata.name {
        reported if reported.is_empty() => resource_name,
        reported => reported,
    };
    let version_id = version::version_id_from_resource_name(&canonical_name);
    let key_id = with_kid_from_key_version.map(|mapper| mapper(version_id));
    let enc_algorithm = get_enc_algorithm(&metadata.algorithm).context(
        setup::UnsupportedAlgorithmSnafu {
            algorithm: metadata.algorithm,
        },
    )?;
    Ok(KeyVersion {
        kms_client,
        resource_name: canonical_name,
        enc_algorithm: enc_algorithm.to_owned(),
        key_id,
    })
}
fn get_enc_algorithm(algorithm: &CryptoKeyVersionAlgorithm) -> Option<&'static str> {
use CryptoKeyVersionAlgorithm::{Aes128Gcm, Aes256Gcm};
match algorithm {
Aes128Gcm => Some("A128GCM"),
Aes256Gcm => Some("A256GCM"),
_ => None,
}
}
#[cfg(test)]
// Tests unwrap freely: a panic is exactly the failure signal wanted here.
#[allow(clippy::unwrap_used)]
mod tests {
    use std::future::Future;
    use google_cloud_gax::Result as GaxResult;
    use google_cloud_gax::options::RequestOptions;
    use google_cloud_gax::response::Response;
    use google_cloud_kms_v1::model::{
        RawDecryptRequest, RawDecryptResponse, RawEncryptRequest, RawEncryptResponse,
    };
    use google_cloud_kms_v1::stub::KeyManagementService as KmsStub;
    use huskarl_core::ErrorKind;
    use rstest::rstest;
    use super::*;
    /// In-memory KMS stub. `RawEncrypt` answers with the canned fields below, and
    /// `RawDecrypt` echoes the request ciphertext back as the plaintext.
    #[derive(Debug, Clone, Default)]
    struct MockKms {
        /// Key version name reported by `RawEncrypt`.
        response_name: String,
        /// Ciphertext (tag appended) reported by `RawEncrypt`.
        ciphertext: Vec<u8>,
        /// Tag length reported by `RawEncrypt`.
        tag_length: i32,
        /// Initialization vector reported by `RawEncrypt`.
        iv: Vec<u8>,
    }
    impl KmsStub for MockKms {
        fn raw_encrypt(
            &self,
            _req: RawEncryptRequest,
            _options: RequestOptions,
        ) -> impl Future<Output = GaxResult<Response<RawEncryptResponse>>> + Send {
            let resp = RawEncryptResponse::default()
                .set_name(self.response_name.clone())
                .set_ciphertext(self.ciphertext.clone())
                .set_tag_length(self.tag_length)
                .set_initialization_vector(self.iv.clone());
            async move { Ok(Response::from(resp)) }
        }
        fn raw_decrypt(
            &self,
            req: RawDecryptRequest,
            _options: RequestOptions,
        ) -> impl Future<Output = GaxResult<Response<RawDecryptResponse>>> + Send {
            // Echo lets tests observe exactly what ciphertext was sent to KMS.
            let resp = RawDecryptResponse::default().set_plaintext(req.ciphertext);
            async move { Ok(Response::from(resp)) }
        }
    }
    /// Builds a `KeyVersion` backed by `mock`, named `.../cryptoKeyVersions/1`.
    fn key_version(mock: MockKms, enc_algorithm: &str, key_id: Option<&str>) -> KeyVersion {
        KeyVersion {
            kms_client: KeyManagementService::from_stub(mock),
            resource_name: "projects/p/.../cryptoKeyVersions/1".to_owned(),
            enc_algorithm: enc_algorithm.to_owned(),
            key_id: key_id.map(str::to_owned),
        }
    }
    #[rstest]
    #[case(CryptoKeyVersionAlgorithm::Aes128Gcm, Some("A128GCM"))]
    #[case(CryptoKeyVersionAlgorithm::Aes256Gcm, Some("A256GCM"))]
    #[case(CryptoKeyVersionAlgorithm::HmacSha256, None)]
    fn get_enc_algorithm_maps_aead_algorithms(
        #[case] algorithm: CryptoKeyVersionAlgorithm,
        #[case] expected: Option<&str>,
    ) {
        assert_eq!(get_enc_algorithm(&algorithm), expected);
    }
    #[test]
    fn encryption_error_classifies_as_crypto() {
        let err = EncryptionError::MismatchedKeyInfo;
        assert!(!err.is_retryable());
        assert_eq!(huskarl_core::Error::from(err).kind(), ErrorKind::Crypto);
    }
    #[rstest]
    #[case(Some("A256GCM"), Some("k1"), Some("k1"), Some(KeyMatchStrength::ByKeyId))]
    #[case(Some("A256GCM"), None, Some("k1"), Some(KeyMatchStrength::ByAlgorithm))]
    #[case(Some("A256GCM"), Some("k2"), Some("k1"), None)]
    #[case(Some("A128GCM"), None, None, None)]
    #[case(None, None, None, Some(KeyMatchStrength::ByAlgorithm))]
    // A message kid cannot be checked against a version without one: algorithm-only match.
    #[case(None, Some("k1"), None, Some(KeyMatchStrength::ByAlgorithm))]
    fn cipher_match_applies_alg_and_kid_rules(
        #[case] req_enc: Option<&str>,
        #[case] req_kid: Option<&str>,
        #[case] registered_kid: Option<&str>,
        #[case] expected: Option<KeyMatchStrength>,
    ) {
        let kv = key_version(MockKms::default(), "A256GCM", registered_kid);
        let m = CipherMatch::builder()
            .maybe_enc(req_enc)
            .maybe_kid(req_kid)
            .build();
        assert_eq!(kv.cipher_match(&m), expected);
    }
    #[tokio::test]
    async fn encrypt_splits_tag_off_the_ciphertext() {
        let mock = MockKms {
            response_name: "projects/p/.../cryptoKeyVersions/1".to_owned(),
            ciphertext: vec![0xC0, 0xC1, 0xC2, 0xD0, 0xD1, 0xD2, 0xD3],
            tag_length: 4,
            iv: vec![0x9A, 0x9B],
        };
        let kv = key_version(mock, "A256GCM", None);
        let out = kv.encrypt(b"plaintext", b"aad").await.unwrap();
        assert_eq!(out.nonce, vec![0x9A, 0x9B]);
        assert_eq!(out.ciphertext, vec![0xC0, 0xC1, 0xC2]);
        assert_eq!(out.tag, vec![0xD0, 0xD1, 0xD2, 0xD3]);
    }
    #[tokio::test]
    async fn encrypt_rejects_mismatched_key_name() {
        let mock = MockKms {
            response_name: "projects/p/.../cryptoKeyVersions/2".to_owned(),
            ciphertext: vec![1, 2, 3, 4],
            tag_length: 1,
            iv: vec![0],
        };
        let kv = key_version(mock, "A256GCM", None);
        let err = kv.encrypt(b"plaintext", b"aad").await.err().unwrap();
        assert_eq!(err.kind(), ErrorKind::Crypto);
    }
    #[tokio::test]
    async fn decrypt_appends_tag_to_ciphertext_for_kms() {
        let kv = key_version(MockKms::default(), "A256GCM", None);
        let echoed = kv
            .decrypt(None, b"nonce", &[0xAA, 0xBB], &[0xCC, 0xDD], b"aad")
            .await
            .unwrap();
        assert_eq!(echoed, vec![0xAA, 0xBB, 0xCC, 0xDD]);
    }
}