use std::borrow::Cow;
use bon::bon;
use google_cloud_kms_v1::{
client::KeyManagementService, model::crypto_key_version::CryptoKeyVersionAlgorithm,
};
use huskarl_core::crypto::signer::{
AsymmetricJwsSigner, AsymmetricJwsSignerSelector, JwsSigner, JwsSignerSelector,
};
use huskarl_core::jwk::{self, PublicJwk};
use p256::ecdsa::signature;
use p256::elliptic_curve::pkcs8::DecodePublicKey as _;
use p256::elliptic_curve::sec1::ToSec1Point as _;
use snafu::prelude::*;
/// Errors returned while building an [`AsymmetricJwsKey`] or [`AsymmetricJwsKeyPair`].
#[derive(Debug, Snafu)]
#[non_exhaustive]
pub enum SetupError {
    /// Fetching the crypto key version (to read its algorithm) from Cloud KMS failed.
    #[snafu(display("Failed to fetch crypto key version from Cloud KMS"))]
    GetCryptoKey {
        /// The underlying Cloud KMS client error.
        source: google_cloud_kms_v1::Error,
    },
    /// Listing the enabled crypto key versions from Cloud KMS failed.
    #[snafu(display("Failed to list crypto key versions from Cloud KMS"))]
    ListCryptoKeyVersions {
        /// The underlying Cloud KMS client error.
        source: google_cloud_kms_v1::Error,
    },
    /// No key version was supplied and Cloud KMS returned no enabled version for the key.
    #[snafu(display("No enabled crypto key versions found for the key"))]
    NoEnabledCryptoKeyVersions,
    /// The key version's algorithm is not supported by this signer.
    #[snafu(display("Unsupported Cloud KMS key algorithm: {algorithm:?}"))]
    UnsupportedAlgorithm {
        /// The algorithm reported by Cloud KMS.
        algorithm: CryptoKeyVersionAlgorithm,
    },
    /// A crypto key version name returned by Cloud KMS did not contain a usable version ID.
    #[snafu(display("Failed to extract a version ID from the crypto key version name"))]
    InvalidKeyVersionName,
    /// Fetching the key version's public key from Cloud KMS failed.
    #[snafu(display("Failed to fetch public key from Cloud KMS"))]
    GetPublicKey {
        /// The underlying Cloud KMS client error.
        source: google_cloud_kms_v1::Error,
    },
    /// The public key PEM could not be converted into a JWK, or its thumbprint computed.
    #[snafu(display("Failed to parse public key PEM into JWK"))]
    PublicKeyParse,
}
#[derive(Debug, Snafu)]
#[non_exhaustive]
pub enum SigningError {
AsymmetricSign {
source: google_cloud_kms_v1::Error,
},
SignatureConversion {
source: signature::Error,
},
MismatchedAlgorithmInfo,
}
impl huskarl_core::Error for SigningError {
    /// A KMS sign call that timed out or hit resource exhaustion is treated as transient;
    /// every other signing failure is reported as non-retryable.
    fn is_retryable(&self) -> bool {
        match self {
            Self::AsymmetricSign { source } => source.is_timeout() || source.is_exhausted(),
            Self::SignatureConversion { .. } => false,
            Self::MismatchedAlgorithmInfo => false,
        }
    }
}
/// A JWS signer backed by a single Cloud KMS asymmetric crypto key version.
///
/// This type holds no key material itself; every signature is requested from Cloud KMS
/// through `kms_client`.
#[derive(Debug, Clone)]
pub struct AsymmetricJwsKey {
    // Client used for the `AsymmetricSign` calls.
    kms_client: KeyManagementService,
    // Full key version resource name: `{key_name}/cryptoKeyVersions/{version}`.
    resource_name: String,
    // JWS `alg` value derived from the key version's KMS algorithm (e.g. "ES256").
    jws_algorithm: String,
    // Optional JWS `kid`, produced from the resolved version ID when configured.
    key_id: Option<String>,
}
#[bon]
impl AsymmetricJwsKey {
    /// Resolves the crypto key version ID to sign with.
    ///
    /// A caller-supplied `key_version` is returned unchanged. Otherwise Cloud KMS is asked for the
    /// `ENABLED` versions of `key_name` ordered by `name desc` (one result), and the version ID is
    /// taken from the trailing `/cryptoKeyVersions/{id}` segment of that result's name.
    ///
    /// # Errors
    ///
    /// - [`SetupError::ListCryptoKeyVersions`] if the list request fails.
    /// - [`SetupError::NoEnabledCryptoKeyVersions`] if no enabled version is returned.
    /// - [`SetupError::InvalidKeyVersionName`] if the returned name has no non-empty
    ///   `/cryptoKeyVersions/{id}` suffix.
    async fn resolve_resource_name(
        key_name: &str,
        key_version: Option<String>,
        kms_client: &KeyManagementService,
    ) -> Result<String, SetupError> {
        if let Some(supplied_version) = key_version {
            return Ok(supplied_version);
        }

        // NOTE(review): if KMS sorts `name` lexicographically, `.../cryptoKeyVersions/9` sorts
        // ahead of `.../cryptoKeyVersions/10`, so this may not return the newest version once a
        // key has ten or more versions. Confirm against the KMS list ordering semantics.
        let latest = kms_client
            .list_crypto_key_versions()
            .set_parent(key_name)
            .set_page_size(1)
            .set_filter("state=ENABLED")
            .set_order_by("name desc")
            .send()
            .await
            .context(ListCryptoKeyVersionsSnafu)?
            .crypto_key_versions
            .into_iter()
            .next()
            .context(NoEnabledCryptoKeyVersionsSnafu)?;

        Self::version_id_from_name(&latest.name)
            .map(str::to_owned)
            .context(InvalidKeyVersionNameSnafu)
    }

    /// Extracts `{id}` from a `.../cryptoKeyVersions/{id}` resource name.
    ///
    /// Returns `None` when the `/cryptoKeyVersions/` marker is missing or the trailing ID is
    /// empty or contains a `/`. An explicit shape check is needed because
    /// `rsplit('/').next()` always yields `Some`, even for an empty or malformed name.
    fn version_id_from_name(name: &str) -> Option<&str> {
        name.rsplit_once("/cryptoKeyVersions/")
            .map(|(_, id)| id)
            .filter(|id| !id.is_empty() && !id.contains('/'))
    }

    /// Builds a signer for a Cloud KMS asymmetric signing key.
    ///
    /// `key_name` is used as the parent of the key versions. The signing version is
    /// `{key_name}/cryptoKeyVersions/{version}`, where `version` is `key_version` if given, or
    /// otherwise the latest enabled version reported by Cloud KMS. When
    /// `with_kid_from_key_version` is set, it is called with the resolved version ID to produce
    /// the JWS `kid`. The version's algorithm is then fetched from KMS to choose the JWS `alg`.
    ///
    /// NOTE(review): the boxed `with_kid_from_key_version` closure has no `Send` bound and, as an
    /// argument of this `async fn`, lives inside the returned future, so that future is not
    /// `Send` (it cannot be passed to e.g. `tokio::spawn`).
    ///
    /// # Errors
    ///
    /// - Any error from resolving the latest enabled key version when `key_version` is `None`
    ///   ([`SetupError::ListCryptoKeyVersions`], [`SetupError::NoEnabledCryptoKeyVersions`],
    ///   [`SetupError::InvalidKeyVersionName`]).
    /// - [`SetupError::GetCryptoKey`] if the key version cannot be fetched.
    /// - [`SetupError::UnsupportedAlgorithm`] if the version's algorithm has no JWS mapping.
    #[builder(finish_fn = build)]
    // The boxed kid-derivation closure type trips clippy's `type_complexity` lint.
    #[allow(clippy::type_complexity)]
    pub async fn builder(
        #[builder(into)]
        key_name: String,
        #[builder(into)]
        key_version: Option<String>,
        kms_client: KeyManagementService,
        #[builder(with = |f: impl Fn(&str) -> String + 'static| Box::new(f))]
        with_kid_from_key_version: Option<Box<dyn FnOnce(&str) -> String>>,
    ) -> Result<Self, SetupError> {
        let resolved_key_version =
            Self::resolve_resource_name(&key_name, key_version, &kms_client).await?;
        let resolved_key_version_name =
            format!("{key_name}/cryptoKeyVersions/{resolved_key_version}");
        let key_id = with_kid_from_key_version.map(|f| f(&resolved_key_version));
        let jws_algorithm =
            get_jws_algorithm_for_resource(&kms_client, &resolved_key_version_name).await?;
        Ok(Self {
            kms_client,
            resource_name: resolved_key_version_name,
            jws_algorithm,
            key_id,
        })
    }
}
impl JwsSignerSelector for AsymmetricJwsKey {
    type Signer = Self;

    /// There is only one key version to choose from, so the selected signer is a clone of
    /// `self`.
    fn select_signer(&self) -> Self::Signer {
        Clone::clone(self)
    }
}
impl JwsSigner for AsymmetricJwsKey {
type Error = SigningError;
fn jws_algorithm(&self) -> Cow<'_, str> {
Cow::Borrowed(&self.jws_algorithm)
}
fn key_id(&self) -> Option<Cow<'_, str>> {
self.key_id.as_deref().map(Cow::Borrowed)
}
async fn sign(&self, input: &[u8]) -> Result<Vec<u8>, Self::Error> {
let response = self
.kms_client
.asymmetric_sign()
.set_name(&self.resource_name)
.set_data(input.to_vec())
.send()
.await
.context(AsymmetricSignSnafu)?;
ensure!(
response.name == self.resource_name,
MismatchedAlgorithmInfoSnafu
);
let signature = response.signature.to_vec();
match self.jws_algorithm.as_str() {
"ES256" => convert_ecdsa_der_to_fixed(&signature, EcDsaVariant::P256)
.context(SignatureConversionSnafu),
"ES384" => convert_ecdsa_der_to_fixed(&signature, EcDsaVariant::P384)
.context(SignatureConversionSnafu),
_ => Ok(signature),
}
}
}
/// A Cloud KMS-backed JWS signer that also carries the public key of its key version as a JWK.
#[derive(Debug, Clone)]
pub struct AsymmetricJwsKeyPair {
    // The KMS key used for signing; the `JwsSigner` methods delegate to it.
    inner: AsymmetricJwsKey,
    // Public key of the same key version, fetched from KMS and converted when built.
    public_key_jwk: PublicJwk,
    // Thumbprint of `public_key_jwk`, compared in `select_asymmetric_signer_by_thumbprint`.
    thumbprint: String,
}
#[bon]
impl AsymmetricJwsKeyPair {
    /// Builds a KMS-backed key pair for an EC (`ES256` / `ES384`) crypto key version.
    ///
    /// The key version is resolved the same way as for [`AsymmetricJwsKey`]: `key_version` if
    /// given, otherwise the latest enabled version of `key_name`. The version's public key is
    /// then fetched from Cloud KMS, converted to a public JWK (carrying the optional `kid` from
    /// `with_kid_from_key_version`), and its thumbprint computed.
    ///
    /// NOTE(review): the boxed `with_kid_from_key_version` closure has no `Send` bound and, as an
    /// argument of this `async fn`, lives inside the returned future, so that future is not
    /// `Send`.
    ///
    /// # Errors
    ///
    /// - Any error from resolving the key version when `key_version` is `None`.
    /// - [`SetupError::GetPublicKey`] if the public key cannot be fetched.
    /// - [`SetupError::UnsupportedAlgorithm`] if the key's algorithm is anything other than one
    ///   mapping to `ES256` or `ES384` (the only algorithms whose public key is exported here).
    /// - [`SetupError::PublicKeyParse`] if the PEM cannot be converted to a JWK or the
    ///   thumbprint cannot be computed.
    #[builder(finish_fn = build)]
    // The boxed kid-derivation closure type trips clippy's `type_complexity` lint.
    #[allow(clippy::type_complexity)]
    pub async fn builder(
        #[builder(into)]
        key_name: String,
        #[builder(into)]
        key_version: Option<String>,
        kms_client: KeyManagementService,
        #[builder(with = |f: impl Fn(&str) -> String + 'static| Box::new(f))]
        with_kid_from_key_version: Option<Box<dyn FnOnce(&str) -> String>>,
    ) -> Result<Self, SetupError> {
        let resolved_key_version =
            AsymmetricJwsKey::resolve_resource_name(&key_name, key_version, &kms_client).await?;
        let resolved_key_version_name =
            format!("{key_name}/cryptoKeyVersions/{resolved_key_version}");
        let key_id = with_kid_from_key_version.map(|f| f(&resolved_key_version));

        let public_key_response = kms_client
            .get_public_key()
            .set_name(&resolved_key_version_name)
            .send()
            .await
            .context(GetPublicKeySnafu)?;

        // `parse_ec_public_key_pem` can only export ES256/ES384 keys. Reject every other
        // algorithm here so that, for example, an RSA or Ed25519 key is reported as
        // `UnsupportedAlgorithm` rather than as a misleading PEM parse failure.
        let jws_algorithm = get_jws_algorithm(&public_key_response.algorithm)
            .filter(|alg| matches!(*alg, "ES256" | "ES384"))
            .with_context(|| UnsupportedAlgorithmSnafu {
                algorithm: public_key_response.algorithm,
            })?;

        let public_key_jwk =
            parse_ec_public_key_pem(&public_key_response.pem, jws_algorithm, key_id.as_deref())
                .context(PublicKeyParseSnafu)?;
        let thumbprint = public_key_jwk.thumbprint().context(PublicKeyParseSnafu)?;

        Ok(Self {
            inner: AsymmetricJwsKey {
                kms_client,
                resource_name: resolved_key_version_name,
                jws_algorithm: jws_algorithm.to_string(),
                key_id,
            },
            public_key_jwk,
            thumbprint,
        })
    }
}
impl AsymmetricJwsSignerSelector for AsymmetricJwsKeyPair {
    type AsymmetricSigner = Self;

    /// Always a clone of `self`: this pair wraps exactly one key version.
    fn select_asymmetric_signer(&self) -> Self::AsymmetricSigner {
        Clone::clone(self)
    }

    /// Returns a clone of `self` only when `thumbprint` equals this key's JWK thumbprint.
    fn select_asymmetric_signer_by_thumbprint(
        &self,
        thumbprint: &str,
    ) -> Option<Self::AsymmetricSigner> {
        let matches_key = self.thumbprint.as_str() == thumbprint;
        matches_key.then(|| self.clone())
    }
}
impl AsymmetricJwsSigner for AsymmetricJwsKeyPair {
    /// Borrows the public JWK computed when the pair was built.
    fn public_key_jwk(&self) -> Cow<'_, PublicJwk> {
        let public_jwk: &PublicJwk = &self.public_key_jwk;
        Cow::Borrowed(public_jwk)
    }
}
impl JwsSignerSelector for AsymmetricJwsKeyPair {
    type Signer = Self;

    /// The pair is its own (and only) signer; hand out a clone.
    fn select_signer(&self) -> Self::Signer {
        Clone::clone(self)
    }
}
impl JwsSigner for AsymmetricJwsKeyPair {
type Error = SigningError;
fn jws_algorithm(&self) -> Cow<'_, str> {
self.inner.jws_algorithm()
}
fn key_id(&self) -> Option<Cow<'_, str>> {
self.inner.key_id()
}
async fn sign(&self, input: &[u8]) -> Result<Vec<u8>, Self::Error> {
self.inner.sign(input).await
}
}
/// Fetches the crypto key version `resource_name` from Cloud KMS and returns the JWS `alg`
/// identifier matching its algorithm.
///
/// # Errors
///
/// - [`SetupError::GetCryptoKey`] if the `GetCryptoKeyVersion` request fails.
/// - [`SetupError::UnsupportedAlgorithm`] if the version's algorithm has no JWS mapping.
async fn get_jws_algorithm_for_resource(
    kms_client: &KeyManagementService,
    resource_name: &str,
) -> Result<String, SetupError> {
    let algorithm = kms_client
        .get_crypto_key_version()
        .set_name(resource_name)
        .send()
        .await
        .context(GetCryptoKeySnafu)?
        .algorithm;

    match get_jws_algorithm(&algorithm) {
        Some(jws_alg) => Ok(jws_alg.to_owned()),
        None => UnsupportedAlgorithmSnafu { algorithm }.fail(),
    }
}
fn get_jws_algorithm(algorithm: &CryptoKeyVersionAlgorithm) -> Option<&'static str> {
use CryptoKeyVersionAlgorithm::{
EcSignEd25519, EcSignP256Sha256, EcSignP384Sha384, RsaSignPkcs12048Sha256,
RsaSignPkcs13072Sha256, RsaSignPkcs14096Sha256, RsaSignPkcs14096Sha512,
RsaSignPss2048Sha256, RsaSignPss3072Sha256, RsaSignPss4096Sha256, RsaSignPss4096Sha512,
};
match algorithm {
RsaSignPss2048Sha256 | RsaSignPss3072Sha256 | RsaSignPss4096Sha256 => Some("PS256"),
RsaSignPss4096Sha512 => Some("PS512"),
RsaSignPkcs12048Sha256 | RsaSignPkcs13072Sha256 | RsaSignPkcs14096Sha256 => Some("RS256"),
RsaSignPkcs14096Sha512 => Some("RS512"),
EcSignP256Sha256 => Some("ES256"),
EcSignP384Sha384 => Some("ES384"),
EcSignEd25519 => Some("Ed25519"),
_ => None,
}
}
/// ECDSA curve selecting how [`convert_ecdsa_der_to_fixed`] decodes a DER signature.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum EcDsaVariant {
    /// P-256, used for the `ES256` JWS algorithm.
    P256,
    /// P-384, used for the `ES384` JWS algorithm.
    P384,
}
/// Converts a PEM-encoded public key into a signing [`PublicJwk`] for `ES256` or `ES384`.
///
/// The PEM is decoded for the curve implied by `jws_algorithm` (P-256 or P-384), and the JWK's
/// `x`/`y` coordinates are taken from the uncompressed SEC1 encoding of the point. Returns
/// `None` for any other `jws_algorithm`, for a PEM that fails to decode, or when a coordinate
/// is unavailable.
fn parse_ec_public_key_pem(pem: &str, jws_algorithm: &str, kid: Option<&str>) -> Option<PublicJwk> {
    // Pick the curve and extract the coordinates; the JWK itself is assembled once below.
    let (alg, crv, x, y) = match jws_algorithm {
        "ES256" => {
            let encoded = p256::PublicKey::from_public_key_pem(pem)
                .ok()?
                .to_sec1_point(false);
            ("ES256", "P-256", encoded.x()?.to_vec(), encoded.y()?.to_vec())
        }
        "ES384" => {
            let encoded = p384::PublicKey::from_public_key_pem(pem)
                .ok()?
                .to_sec1_point(false);
            ("ES384", "P-384", encoded.x()?.to_vec(), encoded.y()?.to_vec())
        }
        _ => return None,
    };

    let ec_key = jwk::EcPublicKey::builder().crv(crv).x(x).y(y);
    Some(
        PublicJwk::builder()
            .algorithm(alg)
            .maybe_kid(kid)
            .key_use(jwk::KeyUse::Sign)
            .key(ec_key)
            .build(),
    )
}
/// Re-encodes a DER ECDSA signature as the fixed-size `r || s` byte string used by JWS.
///
/// # Errors
///
/// Returns the curve crate's [`signature::Error`] if `der_sig` is not a valid DER signature
/// for the selected curve.
fn convert_ecdsa_der_to_fixed(
    der_sig: &[u8],
    variant: EcDsaVariant,
) -> Result<Vec<u8>, signature::Error> {
    let fixed = match variant {
        EcDsaVariant::P256 => p256::ecdsa::Signature::from_der(der_sig)?
            .to_bytes()
            .to_vec(),
        EcDsaVariant::P384 => p384::ecdsa::Signature::from_der(der_sig)?
            .to_bytes()
            .to_vec(),
    };
    Ok(fixed)
}