#[cfg(feature = "ml-dsa")]
use std::sync::OnceLock;
use affinidi_secrets_resolver::secrets::KeyType;
use async_trait::async_trait;
use crate::DataIntegrityError;
use crate::crypto_suites::CryptoSuite;
use crate::signer::Signer;
/// A [`Signer`] wrapper that memoises the expanded ML-DSA signing key.
///
/// When the `ml-dsa` feature is enabled, the first ML-DSA signature expands the
/// wrapped signer's private seed and keeps the result, so later signatures can
/// skip that step. Key-type, verification-method and cryptosuite queries are
/// forwarded to the wrapped signer unchanged.
pub struct CachingSigner<S: Signer> {
    /// The wrapped signer that owns the key material.
    inner: S,
    /// Expanded ML-DSA key, filled in lazily on the first ML-DSA `sign` call.
    #[cfg(feature = "ml-dsa")]
    cached_expanded: OnceLock<affinidi_crypto::ml_dsa::MlDsaExpandedKey>,
    /// Zero-sized placeholder field that exists only when `ml-dsa` is disabled.
    #[cfg(not(feature = "ml-dsa"))]
    _phantom: core::marker::PhantomData<()>,
}
impl<S: Signer> CachingSigner<S> {
    /// Wraps `inner` in a caching signer whose ML-DSA key cache starts empty.
    #[must_use]
    pub fn new(inner: S) -> Self {
        Self {
            inner,
            #[cfg(feature = "ml-dsa")]
            cached_expanded: OnceLock::default(),
            #[cfg(not(feature = "ml-dsa"))]
            _phantom: Default::default(),
        }
    }

    /// Consumes the wrapper and hands back the original signer.
    ///
    /// Any cached expanded key is dropped along with the wrapper.
    pub fn into_inner(self) -> S {
        let Self { inner, .. } = self;
        inner
    }
}
#[async_trait]
impl<S> Signer for CachingSigner<S>
where
    S: Signer + GetPrivateBytes,
{
    fn key_type(&self) -> KeyType {
        self.inner.key_type()
    }

    fn verification_method(&self) -> &str {
        self.inner.verification_method()
    }

    fn cryptosuite(&self) -> CryptoSuite {
        self.inner.cryptosuite()
    }

    /// Signs `data`, reusing a cached expanded key for ML-DSA signers.
    ///
    /// With the `ml-dsa` feature enabled and an ML-DSA key type (44/65/87), the
    /// wrapped signer's private bytes are expanded once and cached. Later calls
    /// sign with the cached key directly.
    ///
    /// If the expansion fails, for example because the private bytes are not a
    /// valid seed, nothing is cached and the call falls back to the wrapped
    /// signer. The outcome is then whatever that signer's own `sign` produces,
    /// rather than a panic.
    ///
    /// All other key types, and builds without the feature, delegate straight
    /// to the wrapped signer.
    ///
    /// # Errors
    ///
    /// Returns whatever error the wrapped signer's `sign` returns on the
    /// delegating path. The cached ML-DSA path itself cannot fail.
    async fn sign(&self, data: &[u8]) -> Result<Vec<u8>, DataIntegrityError> {
        #[cfg(feature = "ml-dsa")]
        {
            let key_type = self.inner.key_type();
            if matches!(
                key_type,
                KeyType::MlDsa44 | KeyType::MlDsa65 | KeyType::MlDsa87
            ) {
                // `OnceLock::get_or_try_init` is not stable, so the fallible
                // initialisation is spelled out: reuse a cached key if present;
                // otherwise expand the seed and publish it only on success.
                // If another caller races us, `get_or_init` keeps the first
                // stored key and drops ours. Both are derived from the same
                // wrapped signer's seed.
                let expanded = match self.cached_expanded.get() {
                    Some(key) => Some(key),
                    None => affinidi_crypto::ml_dsa::MlDsaExpandedKey::from_seed(
                        key_type,
                        self.inner.private_bytes(),
                    )
                    .ok()
                    .map(|key| self.cached_expanded.get_or_init(|| key)),
                };
                if let Some(expanded) = expanded {
                    return Ok(expanded.sign(data));
                }
                // Expansion failed: fall through to the wrapped signer instead
                // of panicking inside library code.
            }
        }
        self.inner.sign(data).await
    }
}
/// Read-only access to a signer's raw private key bytes.
///
/// `CachingSigner` uses this to obtain the seed it expands and caches for
/// ML-DSA key types.
pub trait GetPrivateBytes {
    /// Borrows the raw private key bytes held by `self`.
    fn private_bytes(&self) -> &[u8];
}
/// Exposes a [`Secret`](affinidi_secrets_resolver::secrets::Secret)'s private
/// bytes through the [`GetPrivateBytes`] trait by forwarding to its inherent
/// `get_private_bytes` accessor.
impl GetPrivateBytes for affinidi_secrets_resolver::secrets::Secret {
    fn private_bytes(&self) -> &[u8] {
        Self::get_private_bytes(self)
    }
}
#[cfg(test)]
#[cfg(feature = "ml-dsa")]
mod tests {
    use super::*;
    use crate::{DataIntegrityProof, SignOptions, VerifyOptions};
    use affinidi_secrets_resolver::secrets::Secret;
    use serde_json::json;

    /// Proofs from the caching wrapper must be byte-identical to the plain
    /// signer's and must verify against the same public key.
    #[tokio::test]
    async fn caching_signer_produces_same_signature_as_plain() {
        let secret = Secret::generate_ml_dsa_44(None, Some(&[7u8; 32]));
        let caching = CachingSigner::new(secret.clone());
        let doc = json!({ "cache": "test" });
        let created = chrono::Utc::now();
        let opts = || SignOptions::new().with_created(created);
        let plain_proof = DataIntegrityProof::sign(&doc, &secret, opts())
            .await
            .unwrap();
        let cached_proof = DataIntegrityProof::sign(&doc, &caching, opts())
            .await
            .unwrap();
        assert_eq!(plain_proof.proof_value, cached_proof.proof_value);
        plain_proof
            .verify_with_public_key(&doc, secret.get_public_bytes(), VerifyOptions::new())
            .unwrap();
        cached_proof
            .verify_with_public_key(&doc, secret.get_public_bytes(), VerifyOptions::new())
            .unwrap();
    }

    /// Deterministic check of the caching behaviour: the cache is empty until
    /// the first ML-DSA signature, populated afterwards, and the cached path
    /// keeps producing the same signature as the plain signer.
    #[tokio::test]
    async fn caching_signer_populates_cache_on_first_ml_dsa_sign() {
        let secret = Secret::generate_ml_dsa_44(None, Some(&[5u8; 32]));
        let caching = CachingSigner::new(secret.clone());
        assert!(caching.cached_expanded.get().is_none());

        let first = caching.sign(b"x").await.unwrap();
        assert!(caching.cached_expanded.get().is_some());

        let second = caching.sign(b"x").await.unwrap();
        let plain = secret.sign(b"x").await.unwrap();
        assert_eq!(first, second);
        assert_eq!(first, plain);
    }

    /// Wall-clock comparison of cached vs. uncached signing.
    ///
    /// Ignored by default: timing assertions are flaky when tests run in
    /// parallel threads, on loaded CI machines, or in unoptimised builds. Run
    /// explicitly with `cargo test --release -- --ignored`.
    #[tokio::test]
    #[ignore = "timing-sensitive benchmark; run with --release -- --ignored"]
    async fn caching_signer_second_call_is_faster() {
        let secret = Secret::generate_ml_dsa_44(None, Some(&[3u8; 32]));
        let caching = CachingSigner::new(secret.clone());
        // Warm the cache so the timed loop measures only the cached path.
        let _ = caching.sign(b"warm").await.unwrap();
        let n = 50;
        let t0 = std::time::Instant::now();
        for _ in 0..n {
            let _ = secret.sign(b"x").await.unwrap();
        }
        let without_cache = t0.elapsed();
        let t0 = std::time::Instant::now();
        for _ in 0..n {
            let _ = caching.sign(b"x").await.unwrap();
        }
        let with_cache = t0.elapsed();
        assert!(
            with_cache < without_cache.saturating_sub(without_cache / 7),
            "caching did not speed up: without={without_cache:?}, with={with_cache:?}"
        );
    }
}