use std::sync::Arc;
use bon::bon;
use google_cloud_kms_v1::client::KeyManagementService;
use huskarl_core::jwk::PublicJwks;
use snafu::prelude::*;
use huskarl_core::jwk;
use super::signer::{
PublicKeyParseError, get_jwe_algorithm, get_jws_algorithm, parse_public_key_pem,
};
/// Shared callback that turns a crypto key version id into the `kid` used for its JWK.
///
/// Named so the field type stays readable without suppressing `clippy::type_complexity`.
type KidFromKeyVersionFn = Arc<dyn Fn(&str) -> String + Send + Sync>;

/// Source of the public keys held in a Google Cloud KMS crypto key, exposed as a JWK set.
///
/// Construct it with [`Jwks::builder`] and call [`Jwks::fetch`] to retrieve the keys.
#[derive(Clone)]
pub struct Jwks {
    /// Client used for every KMS request.
    kms_client: KeyManagementService,
    /// Resource name of the crypto key whose versions are listed.
    key_name: String,
    /// Optional mapping from a version id to the `kid` attached to that version's JWK.
    with_kid_from_key_version: Option<KidFromKeyVersionFn>,
    /// Optional value forwarded to the enabled-version listing.
    max_versions: Option<usize>,
}
/// Debug output shows only `key_name`; the remaining fields are elided and rendered as `..`.
impl std::fmt::Debug for Jwks {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut debug = formatter.debug_struct("Jwks");
        debug.field("key_name", &self.key_name);
        debug.finish_non_exhaustive()
    }
}
#[bon]
impl Jwks {
#[builder(finish_fn = build)]
#[allow(clippy::type_complexity)]
pub fn builder(
#[builder(into)]
key_name: String,
kms_client: KeyManagementService,
#[builder(with = |f: impl Fn(&str) -> String + Send + Sync + 'static| Arc::new(f))]
with_kid_from_key_version: Option<Arc<dyn Fn(&str) -> String + Send + Sync>>,
max_versions: Option<usize>,
) -> Self {
Self {
kms_client,
key_name,
with_kid_from_key_version,
max_versions,
}
}
pub async fn fetch(&self) -> Result<PublicJwks, JwksError> {
let versions = self.list_enabled_versions().await?;
ensure!(!versions.is_empty(), NoEnabledCryptoKeyVersionsSnafu);
let futures: Vec<_> = versions
.iter()
.filter_map(|version| {
let (algorithm, key_use) = if let Some(alg) = get_jws_algorithm(&version.algorithm)
{
(alg, jwk::KeyUse::Sign)
} else {
let alg = get_jwe_algorithm(&version.algorithm)?;
(alg, jwk::KeyUse::Encrypt)
};
let version_id =
super::super::version::version_id_from_resource_name(&version.name);
let kid = self
.with_kid_from_key_version
.as_ref()
.map(|f| f(version_id));
let name = &version.name;
let kms_client = &self.kms_client;
Some(async move {
let public_key_response = kms_client
.get_public_key()
.set_name(name)
.send()
.await
.context(GetPublicKeySnafu)?;
parse_public_key_pem(
&public_key_response.pem,
algorithm,
kid.as_deref(),
key_use,
)
.context(PublicKeyParseSnafu)
})
})
.collect();
let keys = futures_util::future::try_join_all(futures).await?;
Ok(PublicJwks::new(keys))
}
async fn list_enabled_versions(
&self,
) -> Result<Vec<google_cloud_kms_v1::model::CryptoKeyVersion>, JwksError> {
super::super::version::list_enabled_kms_versions(
&self.kms_client,
&self.key_name,
self.max_versions,
None,
)
.await
.context(ListCryptoKeyVersionsSnafu)
}
}
#[derive(Debug, Snafu)]
#[non_exhaustive]
pub enum JwksError {
ListCryptoKeyVersions {
source: google_cloud_kms_v1::Error,
},
GetPublicKey {
source: google_cloud_kms_v1::Error,
},
PublicKeyParse {
source: PublicKeyParseError,
},
NoEnabledCryptoKeyVersions,
}
impl JwksError {
#[must_use]
pub fn is_retryable(&self) -> bool {
match self {
JwksError::ListCryptoKeyVersions { source } => {
source.is_timeout() || source.is_exhausted()
}
JwksError::GetPublicKey { source } => source.is_timeout() || source.is_exhausted(),
JwksError::PublicKeyParse { .. } | JwksError::NoEnabledCryptoKeyVersions => false,
}
}
}