use crate::error::{MrvbError, MrvbResult};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use pqcrypto_mldsa::mldsa65;
use pqcrypto_traits::sign::{PublicKey as PqPublicKey, SecretKey as PqSecretKey, SignedMessage};
use serde::{Deserialize, Serialize};
/// Public key used by the Dilithium3 wrappers in this module, held as raw bytes.
///
/// The `from_bytes` constructor only accepts input whose length equals
/// `mldsa65::public_key_bytes()`.
// NOTE(review): `Deserialize` is derived, so a value produced by serde presumably
// fills `bytes` without going through that length check — confirm whether callers
// depend on the check for deserialized keys.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Dilithium3PublicKey {
/// Raw key bytes.
bytes: Vec<u8>,
}
impl Dilithium3PublicKey {
    /// Wraps `bytes` as a public key after checking their length.
    ///
    /// # Errors
    ///
    /// Returns [`MrvbError::InvalidKeyLength`] when `bytes.len()` differs from
    /// `mldsa65::public_key_bytes()`.
    pub fn from_bytes(bytes: &[u8]) -> MrvbResult<Self> {
        let expected = mldsa65::public_key_bytes();
        let actual = bytes.len();
        if actual == expected {
            Ok(Self {
                bytes: Vec::from(bytes),
            })
        } else {
            Err(MrvbError::InvalidKeyLength { expected, actual })
        }
    }

    /// Borrows the raw key bytes.
    pub fn as_bytes(&self) -> &[u8] {
        self.bytes.as_slice()
    }

    /// Encodes the raw key bytes with the standard base64 engine.
    pub fn to_base64(&self) -> String {
        BASE64.encode(self.as_bytes())
    }

    /// Decodes a standard-base64 string and validates the result with
    /// [`Dilithium3PublicKey::from_bytes`].
    ///
    /// # Errors
    ///
    /// Returns [`MrvbError::EncodingError`] when `s` is not valid base64, or the
    /// length error from `from_bytes` when the decoded size is wrong.
    pub fn from_base64(s: &str) -> MrvbResult<Self> {
        match BASE64.decode(s) {
            Ok(raw) => Self::from_bytes(&raw),
            Err(e) => Err(MrvbError::EncodingError(e.to_string())),
        }
    }
}
/// Secret key used by the Dilithium3 wrappers in this module, held as raw bytes.
///
/// Only `Clone` is derived here; `Debug` is implemented by hand elsewhere in this
/// file and prints the byte length instead of the key material. No serde traits
/// are derived for this type.
// NOTE(review): no `Drop` impl is visible in this file, so the key bytes are
// presumably left in memory when a value is dropped — confirm whether
// zeroization is required.
#[derive(Clone)]
pub struct Dilithium3SecretKey {
/// Raw key bytes.
bytes: Vec<u8>,
}
impl Dilithium3SecretKey {
    /// Wraps `bytes` as a secret key after checking their length.
    ///
    /// # Errors
    ///
    /// Returns [`MrvbError::InvalidKeyLength`] when `bytes.len()` differs from
    /// `mldsa65::secret_key_bytes()`.
    pub fn from_bytes(bytes: &[u8]) -> MrvbResult<Self> {
        let expected = mldsa65::secret_key_bytes();
        let actual = bytes.len();
        if actual == expected {
            Ok(Self {
                bytes: Vec::from(bytes),
            })
        } else {
            Err(MrvbError::InvalidKeyLength { expected, actual })
        }
    }

    /// Borrows the raw key bytes.
    pub fn as_bytes(&self) -> &[u8] {
        self.bytes.as_slice()
    }

    /// Encodes the raw key bytes with the standard base64 engine.
    ///
    /// The returned string contains the full secret key.
    pub fn to_base64(&self) -> String {
        BASE64.encode(self.as_bytes())
    }

    /// Decodes a standard-base64 string and validates the result with
    /// [`Dilithium3SecretKey::from_bytes`].
    ///
    /// # Errors
    ///
    /// Returns [`MrvbError::EncodingError`] when `s` is not valid base64, or the
    /// length error from `from_bytes` when the decoded size is wrong.
    pub fn from_base64(s: &str) -> MrvbResult<Self> {
        match BASE64.decode(s) {
            Ok(raw) => Self::from_bytes(&raw),
            Err(e) => Err(MrvbError::EncodingError(e.to_string())),
        }
    }
}
impl std::fmt::Debug for Dilithium3SecretKey {
    /// Formats the key as `Dilithium3SecretKey { len: N }`, where `N` is the number
    /// of key bytes; the bytes themselves are never written.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let key_len = self.bytes.len();
        let mut out = f.debug_struct("Dilithium3SecretKey");
        out.field("len", &key_len);
        out.finish()
    }
}
/// Signature bytes exchanged between `Dilithium3Signer` and `Dilithium3Verifier`.
///
/// In this file the bytes are those of the signed message returned by
/// `mldsa65::sign`, and verification feeds them to `mldsa65::open`, which hands
/// back the message. The constructors on this type perform no length validation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Dilithium3Signature {
/// Raw signature bytes.
bytes: Vec<u8>,
}
impl Dilithium3Signature {
    /// Copies `bytes` into a new signature value without any validation.
    pub fn from_bytes(bytes: &[u8]) -> Self {
        let bytes = Vec::from(bytes);
        Self { bytes }
    }

    /// Borrows the raw signature bytes.
    pub fn as_bytes(&self) -> &[u8] {
        self.bytes.as_slice()
    }

    /// Encodes the raw signature bytes with the standard base64 engine.
    pub fn to_base64(&self) -> String {
        BASE64.encode(self.as_bytes())
    }

    /// Decodes a standard-base64 string into a signature; the decoded bytes are
    /// stored as-is.
    ///
    /// # Errors
    ///
    /// Returns [`MrvbError::EncodingError`] when `s` is not valid base64.
    pub fn from_base64(s: &str) -> MrvbResult<Self> {
        BASE64
            .decode(s)
            .map(|bytes| Self { bytes })
            .map_err(|e| MrvbError::EncodingError(e.to_string()))
    }
}
/// A public key together with a secret key.
///
/// Both fields are public, so a value can be assembled from any two keys; only
/// `generate` (defined in this file) is shown to produce both halves from a
/// single `mldsa65::keypair()` call.
#[derive(Debug, Clone)]
pub struct Dilithium3KeyPair {
/// Public half of the pair.
pub public_key: Dilithium3PublicKey,
/// Secret half of the pair.
pub secret_key: Dilithium3SecretKey,
}
impl Dilithium3KeyPair {
    /// Generates a fresh key pair with `mldsa65::keypair()` and copies both keys
    /// into this module's wrapper types.
    pub fn generate() -> Self {
        let (raw_public, raw_secret) = mldsa65::keypair();

        let public_key = Dilithium3PublicKey {
            bytes: Vec::from(raw_public.as_bytes()),
        };
        let secret_key = Dilithium3SecretKey {
            bytes: Vec::from(raw_secret.as_bytes()),
        };

        Self {
            public_key,
            secret_key,
        }
    }
}
/// Holds a secret key for signing together with a public key that is exposed
/// through `public_key()`.
///
/// `Debug` is implemented by hand elsewhere in this file and prints only the
/// public key.
pub struct Dilithium3Signer {
/// Key used by `sign`.
secret_key: Dilithium3SecretKey,
/// Key returned by `public_key()`.
public_key: Dilithium3PublicKey,
}
impl Dilithium3Signer {
pub fn new(keypair: Dilithium3KeyPair) -> Self {
Self {
secret_key: keypair.secret_key,
public_key: keypair.public_key,
}
}
pub fn from_secret_key_bytes(secret_key_bytes: &[u8]) -> MrvbResult<Self> {
let secret_key = Dilithium3SecretKey::from_bytes(secret_key_bytes)?;
let _sk = mldsa65::SecretKey::from_bytes(secret_key_bytes)
.map_err(|e| MrvbError::CryptoError(format!("Invalid secret key: {:?}", e)))?;
let (pk, _) = mldsa65::keypair();
Ok(Self {
secret_key,
public_key: Dilithium3PublicKey {
bytes: pk.as_bytes().to_vec(),
},
})
}
pub fn public_key(&self) -> &Dilithium3PublicKey {
&self.public_key
}
pub fn sign(&self, message: &[u8]) -> MrvbResult<Dilithium3Signature> {
let sk = mldsa65::SecretKey::from_bytes(&self.secret_key.bytes)
.map_err(|e| MrvbError::CryptoError(format!("Invalid secret key: {:?}", e)))?;
let signed_msg = mldsa65::sign(message, &sk);
let sig_bytes = signed_msg.as_bytes();
Ok(Dilithium3Signature::from_bytes(sig_bytes))
}
pub fn sign_base64(&self, message: &[u8]) -> MrvbResult<String> {
let sig = self.sign(message)?;
Ok(sig.to_base64())
}
}
impl std::fmt::Debug for Dilithium3Signer {
    /// Formats the signer showing only its `public_key` field; the secret key is
    /// left out of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { public_key, .. } = self;
        let mut out = f.debug_struct("Dilithium3Signer");
        out.field("public_key", public_key);
        out.finish()
    }
}
/// Checks signatures against a single stored public key.
#[derive(Debug, Clone)]
pub struct Dilithium3Verifier {
/// Key that `verify` checks signatures against.
public_key: Dilithium3PublicKey,
}
impl Dilithium3Verifier {
    /// Creates a verifier for the given public key.
    pub fn new(public_key: Dilithium3PublicKey) -> Self {
        Self { public_key }
    }

    /// Creates a verifier from raw public-key bytes.
    ///
    /// # Errors
    ///
    /// Propagates the length error from `Dilithium3PublicKey::from_bytes`.
    pub fn from_public_key_bytes(public_key_bytes: &[u8]) -> MrvbResult<Self> {
        Dilithium3PublicKey::from_bytes(public_key_bytes).map(Self::new)
    }

    /// Borrows the public key this verifier checks against.
    pub fn public_key(&self) -> &Dilithium3PublicKey {
        &self.public_key
    }

    /// Checks that `signature` is a signed message for `message` under the stored
    /// public key.
    ///
    /// Returns `Ok(true)` only when `mldsa65::open` accepts the signature bytes
    /// and the message it returns equals `message`. A failed `open` or a different
    /// message yields `Ok(false)`.
    ///
    /// # Errors
    ///
    /// Returns [`MrvbError::CryptoError`] when the stored public-key bytes or the
    /// signature bytes cannot be parsed by `mldsa65`.
    pub fn verify(&self, message: &[u8], signature: &Dilithium3Signature) -> MrvbResult<bool> {
        let pk = mldsa65::PublicKey::from_bytes(self.public_key.as_bytes())
            .map_err(|e| MrvbError::CryptoError(format!("Invalid public key: {:?}", e)))?;
        let signed = mldsa65::SignedMessage::from_bytes(signature.as_bytes())
            .map_err(|e| MrvbError::CryptoError(format!("Invalid signature format: {:?}", e)))?;

        // A failed `open` is reported as "not valid", not as an error.
        let is_valid = mldsa65::open(&signed, &pk)
            .map(|recovered| recovered == message)
            .unwrap_or(false);
        Ok(is_valid)
    }

    /// Decodes a base64 signature and checks it with [`Dilithium3Verifier::verify`].
    ///
    /// # Errors
    ///
    /// Returns [`MrvbError::EncodingError`] when `signature_b64` is not valid
    /// base64, or any error from `verify`.
    pub fn verify_base64(&self, message: &[u8], signature_b64: &str) -> MrvbResult<bool> {
        Dilithium3Signature::from_base64(signature_b64)
            .and_then(|signature| self.verify(message, &signature))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Generated keys have the lengths reported by `mldsa65`.
    #[test]
    fn test_key_generation() {
        let Dilithium3KeyPair {
            public_key,
            secret_key,
        } = Dilithium3KeyPair::generate();

        assert_eq!(public_key.as_bytes().len(), mldsa65::public_key_bytes());
        assert_eq!(secret_key.as_bytes().len(), mldsa65::secret_key_bytes());
    }

    /// A signature verifies for the signed message and not for a different one.
    #[test]
    fn test_sign_and_verify() {
        let keypair = Dilithium3KeyPair::generate();
        let verifier = Dilithium3Verifier::new(keypair.public_key.clone());
        let signer = Dilithium3Signer::new(keypair);

        let message = b"Hello, quantum-resistant world!";
        let signature = signer.sign(message).unwrap();

        let accepted = verifier.verify(message, &signature).unwrap();
        assert!(accepted);

        let wrong_message = b"Wrong message";
        let accepted_wrong = verifier.verify(wrong_message, &signature).unwrap();
        assert!(!accepted_wrong);
    }

    /// A base64-encoded signature verifies through the base64 entry point.
    #[test]
    fn test_base64_encoding() {
        let keypair = Dilithium3KeyPair::generate();
        let verifier = Dilithium3Verifier::new(keypair.public_key.clone());
        let signer = Dilithium3Signer::new(keypair);

        let message = b"Test message";
        let sig_b64 = signer.sign_base64(message).unwrap();

        assert!(verifier.verify_base64(message, &sig_b64).unwrap());
    }

    /// Both key types survive a base64 encode/decode cycle unchanged.
    #[test]
    fn test_key_serialization() {
        let keypair = Dilithium3KeyPair::generate();

        let pk_restored =
            Dilithium3PublicKey::from_base64(&keypair.public_key.to_base64()).unwrap();
        assert_eq!(keypair.public_key.as_bytes(), pk_restored.as_bytes());

        let sk_restored =
            Dilithium3SecretKey::from_base64(&keypair.secret_key.to_base64()).unwrap();
        assert_eq!(keypair.secret_key.as_bytes(), sk_restored.as_bytes());
    }

    /// A public key of the wrong size is rejected with `InvalidKeyLength`.
    #[test]
    fn test_invalid_key_length() {
        let invalid_bytes = [0u8; 10];
        let result = Dilithium3PublicKey::from_bytes(&invalid_bytes);
        assert!(result.is_err());

        if let Err(MrvbError::InvalidKeyLength { expected, actual }) = result {
            assert_eq!(expected, mldsa65::public_key_bytes());
            assert_eq!(actual, 10);
        } else {
            panic!("Expected InvalidKeyLength error");
        }
    }
}