use std::borrow::Cow;
use std::pin::Pin;
use std::sync::Arc;
use bon::Builder;
use snafu::prelude::*;
use crate::crypto::signer::error::{MismatchedKeyMetadataSnafu, UnderlyingSnafu};
use crate::error::BoxedError;
use crate::jwk::PublicJwk;
use crate::platform::{MaybeSend, MaybeSendSync};
use crate::{Error, platform::MaybeSendFuture};
/// A type-erased, cheaply clonable [`JwsSigningKey`].
///
/// The wrapped key is stored behind an [`Arc`], so clones share the same underlying key.
/// Signing failures from the wrapped key are surfaced as [`BoxedError`].
#[derive(Debug, Clone)]
pub struct BoxedJwsSigningKey {
    inner: Arc<dyn DynJwsSigningKey>,
}

impl BoxedJwsSigningKey {
    /// Wraps `signer`, erasing its concrete type.
    pub fn new<Sgn: JwsSigningKey + 'static>(signer: Sgn) -> Self {
        // Every `JwsSigningKey` implements the object-safe `DynJwsSigningKey` via a blanket
        // impl, which is what makes this unsized coercion possible.
        let inner: Arc<dyn DynJwsSigningKey> = Arc::new(signer);
        Self { inner }
    }
}
/// Object-safe counterpart of [`JwsSigningKey`], used as the trait object inside
/// [`BoxedJwsSigningKey`].
///
/// `JwsSigningKey` cannot be used as `dyn` directly: it has a `Clone` supertrait and returns
/// `impl Future`. This trait instead returns a boxed future and a type-erased [`BoxedError`].
trait DynJwsSigningKey: std::fmt::Debug + MaybeSendSync {
    /// Returns the wrapped key's metadata.
    fn key_metadata(&self) -> Cow<'_, SigningKeyMetadata>;
    /// Signs `input` with the wrapped key without any metadata check.
    ///
    /// The returned future borrows both `self` and `input` for `'a`.
    fn sign_unchecked<'a>(
        &'a self,
        input: &'a [u8],
    ) -> Pin<Box<dyn MaybeSendFuture<Output = Result<Vec<u8>, BoxedError>> + 'a>>;
}
/// Blanket adapter that lets every concrete [`JwsSigningKey`] be stored as
/// `dyn DynJwsSigningKey`.
///
/// Both traits declare `key_metadata` and `sign_unchecked`. The calls below are therefore
/// written in fully qualified form, so they unambiguously delegate to [`JwsSigningKey`] and
/// can never resolve back into this impl (which would recurse forever).
impl<Sgn: JwsSigningKey> DynJwsSigningKey for Sgn {
    fn key_metadata(&self) -> Cow<'_, SigningKeyMetadata> {
        <Sgn as JwsSigningKey>::key_metadata(self)
    }

    fn sign_unchecked<'a>(
        &'a self,
        input: &'a [u8],
    ) -> Pin<Box<dyn MaybeSendFuture<Output = Result<Vec<u8>, BoxedError>> + 'a>> {
        // `async move` takes the two `'a` references by value, so the future does not borrow
        // this function's locals. The concrete error is erased into `BoxedError` so every key
        // shares the same object-safe signature.
        Box::pin(async move {
            <Sgn as JwsSigningKey>::sign_unchecked(self, input)
                .await
                .map_err(BoxedError::from_err)
        })
    }
}
impl JwsSigningKey for BoxedJwsSigningKey {
type Error = BoxedError;
fn key_metadata(&self) -> Cow<'_, SigningKeyMetadata> {
self.inner.key_metadata()
}
async fn sign_unchecked(&self, input: &[u8]) -> Result<Vec<u8>, Self::Error> {
self.inner.sign_unchecked(input).await
}
}
/// A type-erased, cheaply clonable [`JwsSigningKey`] that also exposes its public key.
///
/// The wrapped key is stored behind an [`Arc`], so clones share the same underlying key.
/// Signing failures from the wrapped key are surfaced as [`BoxedError`].
#[derive(Debug, Clone)]
pub struct BoxedAsymmetricJwsSigningKey {
    inner: Arc<dyn DynAsymmetricJwsSigningKey>,
}

impl BoxedAsymmetricJwsSigningKey {
    /// Wraps `signer`, erasing its concrete type while keeping access to its public key.
    // `Debug` is deliberately not repeated in the bounds: it is already a supertrait of
    // `JwsSigningKey`.
    pub fn new<Sgn: JwsSigningKey + HasPublicKey + 'static>(signer: Sgn) -> Self {
        Self {
            inner: Arc::new(signer),
        }
    }
}
/// Object-safe counterpart of [`JwsSigningKey`] + [`HasPublicKey`], used as the trait object
/// inside [`BoxedAsymmetricJwsSigningKey`].
///
/// Signing returns a boxed future with a type-erased [`BoxedError`], so keys with different
/// concrete error types can share one `dyn` type.
trait DynAsymmetricJwsSigningKey: std::fmt::Debug + MaybeSendSync {
    /// Returns the wrapped key's metadata.
    fn key_metadata(&self) -> Cow<'_, SigningKeyMetadata>;
    /// Signs `input` with the wrapped key without any metadata check.
    ///
    /// The returned future borrows both `self` and `input` for `'a`.
    fn sign_unchecked<'a>(
        &'a self,
        input: &'a [u8],
    ) -> Pin<Box<dyn MaybeSendFuture<Output = Result<Vec<u8>, BoxedError>> + 'a>>;
    /// Returns the wrapped key's public key as a [`PublicJwk`].
    fn public_key_jwk(&self) -> &PublicJwk;
}
/// Blanket adapter that lets every [`JwsSigningKey`] that also implements [`HasPublicKey`]
/// be stored as `dyn DynAsymmetricJwsSigningKey`.
///
/// All three methods share their names with methods of `JwsSigningKey` / `HasPublicKey`.
/// They delegate through fully qualified paths so the call can never resolve back into this
/// impl. `Debug` is not repeated in the bounds because `JwsSigningKey` already requires it.
impl<Sgn: JwsSigningKey + HasPublicKey> DynAsymmetricJwsSigningKey for Sgn {
    fn key_metadata(&self) -> Cow<'_, SigningKeyMetadata> {
        <Sgn as JwsSigningKey>::key_metadata(self)
    }

    fn sign_unchecked<'a>(
        &'a self,
        input: &'a [u8],
    ) -> Pin<Box<dyn MaybeSendFuture<Output = Result<Vec<u8>, BoxedError>> + 'a>> {
        // Take the `'a` references by value and erase the concrete error type.
        Box::pin(async move {
            <Sgn as JwsSigningKey>::sign_unchecked(self, input)
                .await
                .map_err(BoxedError::from_err)
        })
    }

    fn public_key_jwk(&self) -> &PublicJwk {
        <Sgn as HasPublicKey>::public_key_jwk(self)
    }
}
impl JwsSigningKey for BoxedAsymmetricJwsSigningKey {
type Error = BoxedError;
fn key_metadata(&self) -> Cow<'_, SigningKeyMetadata> {
self.inner.key_metadata()
}
async fn sign_unchecked(&self, input: &[u8]) -> Result<Vec<u8>, Self::Error> {
self.inner.sign_unchecked(input).await
}
}
impl HasPublicKey for BoxedAsymmetricJwsSigningKey {
fn public_key_jwk(&self) -> &PublicJwk {
self.inner.public_key_jwk()
}
}
/// Describes which key a signature is produced with: the JWS algorithm and an optional key id.
///
/// [`JwsSigningKey::sign`] compares a caller-supplied value against the key's own metadata
/// with `==` and refuses to sign on any difference.
///
/// Equality is total (both fields are plain strings), so `Eq` is derived alongside `PartialEq`.
#[derive(Debug, Clone, Builder, PartialEq, Eq)]
pub struct SigningKeyMetadata {
    /// Name of the JWS algorithm. The builder setter accepts anything `Into<String>`.
    #[builder(into)]
    pub jws_algorithm: String,
    /// Optional key identifier. It may be omitted from the builder, in which case it is `None`.
    #[builder(into)]
    pub key_id: Option<String>,
}
/// A key that can produce JWS signatures asynchronously.
///
/// Implementors provide the key's [`SigningKeyMetadata`] and a raw signing operation.
/// The provided [`sign`](JwsSigningKey::sign) method puts a metadata check in front of that
/// operation.
pub trait JwsSigningKey: std::fmt::Debug + Clone + MaybeSendSync {
    /// Error returned by [`sign_unchecked`](JwsSigningKey::sign_unchecked).
    type Error: Error + 'static;

    /// Returns the metadata (algorithm and optional key id) describing this key.
    fn key_metadata(&self) -> Cow<'_, SigningKeyMetadata>;

    /// Signs `input` without comparing any caller-supplied metadata against this key.
    fn sign_unchecked(
        &self,
        input: &[u8],
    ) -> impl Future<Output = Result<Vec<u8>, Self::Error>> + MaybeSend;

    /// Signs `input`, but only if `key_metadata` equals this key's own
    /// [`key_metadata`](JwsSigningKey::key_metadata).
    ///
    /// # Errors
    ///
    /// - `MismatchedKeyMetadata` when the supplied metadata differs from the key's metadata.
    /// - `Underlying`, wrapping the error from
    ///   [`sign_unchecked`](JwsSigningKey::sign_unchecked), when signing itself fails.
    fn sign(
        &self,
        input: &[u8],
        key_metadata: &SigningKeyMetadata,
    ) -> impl Future<Output = Result<Vec<u8>, super::JwsSignerError<Self::Error>>> + MaybeSend {
        async move {
            // Guard: never sign with an algorithm or key id the caller did not ask for.
            if *self.key_metadata() != *key_metadata {
                return MismatchedKeyMetadataSnafu.fail();
            }
            self.sign_unchecked(input).await.context(UnderlyingSnafu)
        }
    }
}
/// A key that can expose its public key as a JWK.
pub trait HasPublicKey: MaybeSendSync {
    /// Returns this key's public key as a [`PublicJwk`].
    fn public_key_jwk(&self) -> &PublicJwk;
}
#[cfg(all(
    test,
    not(all(target_arch = "wasm32", any(target_os = "unknown", target_os = "none")))
))]
mod tests {
    use std::{borrow::Cow, convert::Infallible};

    use super::*;
    use crate::crypto::signer::JwsSignerError;

    /// Test key with metadata `jws_algorithm = "ALG"` and no key id.
    /// It always produces an empty signature.
    #[derive(Debug, Clone)]
    struct MockSigningKey {
        key_metadata: SigningKeyMetadata,
    }

    impl MockSigningKey {
        pub fn new() -> Self {
            Self {
                key_metadata: SigningKeyMetadata::builder().jws_algorithm("ALG").build(),
            }
        }
    }

    impl JwsSigningKey for MockSigningKey {
        type Error = Infallible;

        fn key_metadata(&self) -> Cow<'_, SigningKeyMetadata> {
            Cow::Borrowed(&self.key_metadata)
        }

        async fn sign_unchecked(&self, _input: &[u8]) -> Result<Vec<u8>, Self::Error> {
            Ok(vec![])
        }
    }

    #[tokio::test]
    async fn test_metadata_no_mismatch_succeeds() {
        MockSigningKey::new()
            .sign(
                &[],
                &SigningKeyMetadata::builder().jws_algorithm("ALG").build(),
            )
            .await
            .expect("no mismatch");
    }

    #[tokio::test]
    async fn test_metadata_different_alg_fails() {
        let result = MockSigningKey::new()
            .sign(
                &[],
                &SigningKeyMetadata::builder().jws_algorithm("ALG2").build(),
            )
            .await;
        assert!(matches!(result, Err(JwsSignerError::MismatchedKeyMetadata)));
    }

    #[tokio::test]
    async fn test_metadata_different_kid_fails() {
        let result = MockSigningKey::new()
            .sign(
                &[],
                &SigningKeyMetadata::builder()
                    .jws_algorithm("ALG")
                    .key_id("key-id")
                    .build(),
            )
            .await;
        assert!(matches!(result, Err(JwsSignerError::MismatchedKeyMetadata)));
    }

    #[tokio::test]
    async fn test_boxed_key_delegates_to_inner() {
        let inner = MockSigningKey::new();
        let expected = inner.key_metadata.clone();
        let boxed = BoxedJwsSigningKey::new(inner);

        // Fully qualified: `BoxedJwsSigningKey` also implements the private
        // `DynJwsSigningKey` (via its blanket impl), so `boxed.key_metadata()` would be
        // ambiguous here.
        assert_eq!(JwsSigningKey::key_metadata(&boxed).into_owned(), expected);

        let signature = boxed
            .sign(&[], &expected)
            .await
            .expect("metadata matches the wrapped key");
        assert!(signature.is_empty());
    }

    #[tokio::test]
    async fn test_boxed_key_rejects_mismatched_metadata() {
        let boxed = BoxedJwsSigningKey::new(MockSigningKey::new());
        let result = boxed
            .sign(
                &[],
                &SigningKeyMetadata::builder().jws_algorithm("ALG2").build(),
            )
            .await;
        assert!(matches!(result, Err(JwsSignerError::MismatchedKeyMetadata)));
    }
}