ant_quic/crypto/pqc/
ml_dsa_impl.rs

1// Copyright 2024 Saorsa Labs Ltd.
2//
3// This Saorsa Network Software is licensed under the General Public License (GPL), version 3.
4// Please see the file LICENSE-GPL, or visit <http://www.gnu.org/licenses/> for the full text.
5//
6// Full details available at https://saorsalabs.com/licenses
7
8//! ML-DSA-65 implementation using aws-lc-rs
9//!
10//! This module provides the implementation of Module Lattice-based Digital Signature
11//! Algorithm (ML-DSA) as specified in FIPS 204, using aws-lc-rs.
12
13use crate::crypto::pqc::MlDsaOperations;
14use crate::crypto::pqc::types::*;
15
16#[cfg(feature = "aws-lc-rs")]
17use aws_lc_rs::{
18    encoding::AsDer,
19    signature::{KeyPair, UnparsedPublicKey},
20    unstable::signature::{
21        ML_DSA_65, ML_DSA_65_SIGNING, PqdsaKeyPair, PqdsaSigningAlgorithm,
22        PqdsaVerificationAlgorithm,
23    },
24};
25
26#[cfg(feature = "aws-lc-rs")]
27use std::collections::HashMap;
28#[cfg(feature = "aws-lc-rs")]
29use std::sync::{Arc, Mutex};
30
/// One entry of the ML-DSA key cache.
///
/// An entry is created each time a key pair is generated and keeps everything
/// that signing and verification later read back from the cache.
#[cfg(feature = "aws-lc-rs")]
struct CachedDsaKey {
    /// The generated key pair, reference-counted so it can be handed out of
    /// the cache without holding the cache lock while signing.
    key_pair: Arc<PqdsaKeyPair>,
    /// Complete DER encoding of the public key, exactly as produced by
    /// `as_der()` when the key pair was generated.
    public_key_der: Vec<u8>,
}
37
/// ML-DSA-65 signer/verifier backed by aws-lc-rs.
///
/// All fields exist only when the `aws-lc-rs` feature is enabled; without it
/// the struct carries no state.
pub struct MlDsa65Impl {
    /// Algorithm descriptor used for key generation and signing.
    #[cfg(feature = "aws-lc-rs")]
    signing_alg: &'static PqdsaSigningAlgorithm,
    /// Algorithm descriptor used for signature verification.
    #[cfg(feature = "aws-lc-rs")]
    verification_alg: &'static PqdsaVerificationAlgorithm,
    /// Key cache, keyed by public key bytes and holding the full key pair.
    ///
    /// The cache is needed because aws-lc-rs doesn't expose private key
    /// serialization. It sits behind `Arc<Mutex<..>>` and is shared with
    /// every clone of this value.
    #[cfg(feature = "aws-lc-rs")]
    key_cache: Arc<Mutex<HashMap<Vec<u8>, CachedDsaKey>>>,
}
49
50impl MlDsa65Impl {
51    /// Create a new ML-DSA-65 implementation
52    pub fn new() -> Self {
53        Self {
54            #[cfg(feature = "aws-lc-rs")]
55            signing_alg: &ML_DSA_65_SIGNING,
56            #[cfg(feature = "aws-lc-rs")]
57            verification_alg: &ML_DSA_65,
58            #[cfg(feature = "aws-lc-rs")]
59            key_cache: Arc::new(Mutex::new(HashMap::new())),
60        }
61    }
62
63    /// Clear the key cache
64    #[cfg(feature = "aws-lc-rs")]
65    pub fn clear_cache(&self) {
66        if let Ok(mut cache) = self.key_cache.lock() {
67            cache.clear();
68        }
69    }
70}
71
72impl Clone for MlDsa65Impl {
73    fn clone(&self) -> Self {
74        Self {
75            #[cfg(feature = "aws-lc-rs")]
76            signing_alg: self.signing_alg,
77            #[cfg(feature = "aws-lc-rs")]
78            verification_alg: self.verification_alg,
79            #[cfg(feature = "aws-lc-rs")]
80            key_cache: Arc::clone(&self.key_cache),
81        }
82    }
83}
84
#[cfg(feature = "aws-lc-rs")]
impl MlDsaOperations for MlDsa65Impl {
    /// Generate a new ML-DSA-65 key pair and register it in the key cache.
    ///
    /// The returned [`MlDsaPublicKey`] holds the first
    /// `ML_DSA_65_PUBLIC_KEY_SIZE` bytes of the DER-encoded public key
    /// (zero-padded if the encoding is shorter). The returned
    /// [`MlDsaSecretKey`] does **not** contain private key material: it carries
    /// the same identifier bytes, which `sign` uses to look the real key pair
    /// up in the cache.
    ///
    /// # Errors
    ///
    /// Returns `PqcError::KeyGenerationFailed` if key generation or DER
    /// encoding fails, or if the key cache lock cannot be acquired.
    fn generate_keypair(&self) -> PqcResult<(MlDsaPublicKey, MlDsaSecretKey)> {
        // Generate a new key pair
        let key_pair = PqdsaKeyPair::generate(self.signing_alg)
            .map_err(|e| PqcError::KeyGenerationFailed(e.to_string()))?;

        // Extract the DER-encoded public key bytes
        let public_key_der = key_pair
            .public_key()
            .as_der()
            .map_err(|e| PqcError::KeyGenerationFailed(e.to_string()))?;
        let public_key_bytes = public_key_der.as_ref().to_vec();

        // Both returned buffers start with the same identifier: the DER bytes
        // truncated (or zero-padded) to ML_DSA_65_PUBLIC_KEY_SIZE.
        // NOTE(review): if the DER encoding is longer than the buffer, the
        // public key buffer holds a truncated encoding — confirm that callers
        // only treat it as an opaque handle.
        let id_len = public_key_bytes.len().min(ML_DSA_65_PUBLIC_KEY_SIZE);
        let mut public_key = Box::new([0u8; ML_DSA_65_PUBLIC_KEY_SIZE]);
        public_key[..id_len].copy_from_slice(&public_key_bytes[..id_len]);
        let mut secret_key = Box::new([0u8; ML_DSA_65_SECRET_KEY_SIZE]);
        secret_key[..id_len].copy_from_slice(&public_key_bytes[..id_len]);

        // The cache key is exactly the content of the public key buffer, so
        // `sign` and `verify` can derive it from the values returned here.
        let key_id = public_key.to_vec();

        // Store the key pair in cache; the guard is released at the end of
        // this block.
        {
            let mut cache = self.key_cache.lock().map_err(|_| {
                PqcError::KeyGenerationFailed("Failed to acquire key cache lock".to_string())
            })?;

            cache.insert(
                key_id,
                CachedDsaKey {
                    key_pair: Arc::new(key_pair),
                    public_key_der: public_key_bytes,
                },
            );
        }

        Ok((MlDsaPublicKey(public_key), MlDsaSecretKey(secret_key)))
    }

    /// Sign `message` with the key pair identified by `secret_key`.
    ///
    /// `secret_key` is a handle produced by `generate_keypair`; the actual
    /// key pair is fetched from the cache shared by this instance and its
    /// clones. The signature is returned in a fixed
    /// `ML_DSA_65_SIGNATURE_SIZE` buffer, zero-padded if the backend produced
    /// fewer bytes.
    ///
    /// # Errors
    ///
    /// * `PqcError::SigningFailed` if the cache lock cannot be acquired or
    ///   the backend fails to sign.
    /// * `PqcError::InvalidSecretKey` if no cached key pair matches
    ///   `secret_key` (for example after `clear_cache`).
    /// * `PqcError::InvalidSignature` if the backend produced more than
    ///   `ML_DSA_65_SIGNATURE_SIZE` bytes.
    fn sign(&self, secret_key: &MlDsaSecretKey, message: &[u8]) -> PqcResult<MlDsaSignature> {
        // The leading bytes of the secret key buffer are the cache identifier.
        let key_id = &secret_key.0[..ML_DSA_65_PUBLIC_KEY_SIZE];

        // Retrieve the key pair from cache; only the `Arc` leaves the block,
        // so the lock is not held while signing.
        let key_pair = {
            let cache = self.key_cache.lock().map_err(|_| {
                PqcError::SigningFailed("Failed to acquire key cache lock".to_string())
            })?;

            cache
                .get(key_id)
                .map(|entry| Arc::clone(&entry.key_pair))
                .ok_or(PqcError::InvalidSecretKey)?
        };

        // Sign the message
        let mut signature_bytes = vec![0u8; self.signing_alg.signature_len()];
        let sig_len = key_pair
            .sign(message, &mut signature_bytes)
            .map_err(|e| PqcError::SigningFailed(e.to_string()))?;

        signature_bytes.truncate(sig_len);

        // Ensure the signature fits the fixed-size output buffer
        if signature_bytes.len() > ML_DSA_65_SIGNATURE_SIZE {
            return Err(PqcError::InvalidSignature);
        }

        // Create our signature type
        let mut sig = Box::new([0u8; ML_DSA_65_SIGNATURE_SIZE]);
        sig[..signature_bytes.len()].copy_from_slice(&signature_bytes);

        Ok(MlDsaSignature(sig))
    }

    /// Verify `signature` over `message` against `public_key`.
    ///
    /// The full DER-encoded public key is looked up in the key cache, so only
    /// keys generated through this instance (or a clone sharing its cache)
    /// can be verified.
    ///
    /// Returns `Ok(true)` when the signature is valid and `Ok(false)` when the
    /// backend rejects it.
    ///
    /// # Errors
    ///
    /// Returns `PqcError::VerificationFailed` if the cache lock cannot be
    /// acquired or the public key is not present in the cache.
    fn verify(
        &self,
        public_key: &MlDsaPublicKey,
        message: &[u8],
        signature: &MlDsaSignature,
    ) -> PqcResult<bool> {
        // The public key buffer is the cache identifier.
        let key_id = &public_key.0[..ML_DSA_65_PUBLIC_KEY_SIZE];

        // Find the cached entry to get the original DER-encoded public key
        let public_key_der = {
            let cache = self.key_cache.lock().map_err(|_| {
                PqcError::VerificationFailed("Failed to acquire key cache lock".to_string())
            })?;

            cache
                .get(key_id)
                .map(|entry| entry.public_key_der.clone())
                .ok_or_else(|| {
                    PqcError::VerificationFailed("Public key not found in cache".to_string())
                })?
        };

        // Create unparsed public key for verification
        let unparsed_public_key = UnparsedPublicKey::new(self.verification_alg, &public_key_der);

        // ML-DSA signatures have a fixed length, so take exactly the length
        // the algorithm reports. Guessing the length by stripping trailing
        // zero bytes would truncate, and therefore reject, any valid
        // signature whose final byte happens to be 0x00. The `min` keeps the
        // slice in bounds should the buffer ever be smaller than the
        // algorithm's length.
        let sig_len = self
            .signing_alg
            .signature_len()
            .min(ML_DSA_65_SIGNATURE_SIZE);

        // Any backend rejection (including malformed input) is reported as an
        // invalid signature rather than an error.
        Ok(unparsed_public_key
            .verify(message, &signature.0[..sig_len])
            .is_ok())
    }
}
220
221// Fallback implementation when aws-lc-rs is not available
222#[cfg(not(feature = "aws-lc-rs"))]
223impl MlDsaOperations for MlDsa65Impl {
224    fn generate_keypair(&self) -> PqcResult<(MlDsaPublicKey, MlDsaSecretKey)> {
225        // Without aws-lc-rs, we can't provide real ML-DSA
226        // This is just a placeholder that generates random bytes
227        use rand::RngCore;
228        let mut rng = rand::thread_rng();
229
230        let mut pub_key = Box::new([0u8; ML_DSA_65_PUBLIC_KEY_SIZE]);
231        let mut sec_key = Box::new([0u8; ML_DSA_65_SECRET_KEY_SIZE]);
232
233        rng.fill_bytes(&mut pub_key[..]);
234        rng.fill_bytes(&mut sec_key[..]);
235
236        Ok((MlDsaPublicKey(pub_key), MlDsaSecretKey(sec_key)))
237    }
238
239    fn sign(&self, _secret_key: &MlDsaSecretKey, _message: &[u8]) -> PqcResult<MlDsaSignature> {
240        // Without aws-lc-rs, we can't provide real ML-DSA
241        Err(PqcError::FeatureNotAvailable)
242    }
243
244    fn verify(
245        &self,
246        _public_key: &MlDsaPublicKey,
247        _message: &[u8],
248        _signature: &MlDsaSignature,
249    ) -> PqcResult<bool> {
250        // Without aws-lc-rs, we can't provide real ML-DSA
251        Err(PqcError::FeatureNotAvailable)
252    }
253}
254
#[cfg(all(test, feature = "pqc"))]
mod tests {
    use super::*;

    /// Key generation succeeds and returns buffers of the declared sizes.
    #[test]
    #[cfg(feature = "aws-lc-rs")]
    fn test_ml_dsa_65_key_generation() {
        let signer = MlDsa65Impl::new();
        let generated = signer.generate_keypair();
        assert!(generated.is_ok());

        let (public, secret) = generated.unwrap();
        assert_eq!(public.0.len(), ML_DSA_65_PUBLIC_KEY_SIZE);
        assert_eq!(secret.0.len(), ML_DSA_65_SECRET_KEY_SIZE);
    }

    /// A signature verifies for the signed message and fails for another one.
    #[test]
    #[cfg(feature = "aws-lc-rs")]
    fn test_ml_dsa_65_sign_verify() {
        let signer = MlDsa65Impl::new();
        let (public, secret) = signer.generate_keypair().unwrap();

        let signed_msg = b"Test message for ML-DSA-65";
        let sig = signer.sign(&secret, signed_msg).unwrap();
        assert_eq!(sig.0.len(), ML_DSA_65_SIGNATURE_SIZE);

        // Round trip with the original message.
        let accepted = signer.verify(&public, signed_msg, &sig).unwrap();
        assert!(accepted, "Signature should be valid");

        // Same signature, different message.
        let other_msg = b"Different message";
        let accepted_other = signer.verify(&public, other_msg, &sig).unwrap();
        assert!(
            !accepted_other,
            "Signature should be invalid for wrong message"
        );
    }

    /// A signature made with one key is rejected under another key.
    #[test]
    #[cfg(feature = "aws-lc-rs")]
    fn test_ml_dsa_65_verify_with_different_key() {
        let signer = MlDsa65Impl::new();

        // Two independent key pairs from the same instance.
        let (signing_public, signing_secret) = signer.generate_keypair().unwrap();
        let (unrelated_public, _unrelated_secret) = signer.generate_keypair().unwrap();

        let msg = b"Test message";
        let sig = signer.sign(&signing_secret, msg).unwrap();

        // Matching key accepts.
        let accepted = signer.verify(&signing_public, msg, &sig).unwrap();
        assert!(accepted);

        // Unrelated key rejects.
        let accepted_unrelated = signer.verify(&unrelated_public, msg, &sig).unwrap();
        assert!(!accepted_unrelated);
    }

    /// Without aws-lc-rs, key generation still returns bytes but signing and
    /// verification report that the feature is unavailable.
    #[test]
    #[cfg(not(feature = "aws-lc-rs"))]
    fn test_ml_dsa_without_feature() {
        let signer = MlDsa65Impl::new();

        // Key generation should work (returns random bytes)
        let generated = signer.generate_keypair();
        assert!(generated.is_ok());
        let (public, secret) = generated.unwrap();

        let msg = b"test";

        // Signing should fail without the feature
        let signed = signer.sign(&secret, msg);
        assert!(signed.is_err());
        assert!(matches!(signed, Err(PqcError::FeatureNotAvailable)));

        // Verification should also fail
        let zero_sig = MlDsaSignature(Box::new([0u8; ML_DSA_65_SIGNATURE_SIZE]));
        let verified = signer.verify(&public, msg, &zero_sig);
        assert!(verified.is_err());
        assert!(matches!(verified, Err(PqcError::FeatureNotAvailable)));
    }
}