ant_quic/crypto/pqc/
ml_dsa_impl.rs1use crate::crypto::pqc::MlDsaOperations;
14use crate::crypto::pqc::types::*;
15
16#[cfg(feature = "aws-lc-rs")]
17use aws_lc_rs::{
18 encoding::AsDer,
19 signature::{KeyPair, UnparsedPublicKey},
20 unstable::signature::{
21 ML_DSA_65, ML_DSA_65_SIGNING, PqdsaKeyPair, PqdsaSigningAlgorithm,
22 PqdsaVerificationAlgorithm,
23 },
24};
25
26#[cfg(feature = "aws-lc-rs")]
27use std::collections::HashMap;
28#[cfg(feature = "aws-lc-rs")]
29use std::sync::{Arc, Mutex};
30
/// One aws-lc-rs key pair held in the `MlDsa65Impl` key cache.
#[cfg(feature = "aws-lc-rs")]
struct CachedDsaKey {
    /// DER encoding of the public key (from `as_der()`); `verify` builds its
    /// `UnparsedPublicKey` from these bytes.
    public_key_der: Vec<u8>,
    /// Shared handle to the generated key pair; `sign` clones this `Arc` out of
    /// the cache so the lock is not held while signing.
    key_pair: Arc<PqdsaKeyPair>,
}
37
/// ML-DSA-65 signature provider.
///
/// With the `aws-lc-rs` feature, key generation, signing and verification are
/// delegated to aws-lc-rs, and generated key pairs live in an in-process cache
/// shared by every clone of this value. Without the feature the struct carries
/// no state.
pub struct MlDsa65Impl {
    /// Cache of generated key pairs, keyed by byte-string ids derived in
    /// `generate_keypair`; shared between clones through the `Arc`.
    #[cfg(feature = "aws-lc-rs")]
    key_cache: Arc<Mutex<HashMap<Vec<u8>, CachedDsaKey>>>,
    /// Algorithm used to check signatures (`ML_DSA_65`).
    #[cfg(feature = "aws-lc-rs")]
    verification_alg: &'static PqdsaVerificationAlgorithm,
    /// Algorithm used to generate key pairs and sign (`ML_DSA_65_SIGNING`).
    #[cfg(feature = "aws-lc-rs")]
    signing_alg: &'static PqdsaSigningAlgorithm,
}
49
50impl MlDsa65Impl {
51 pub fn new() -> Self {
53 Self {
54 #[cfg(feature = "aws-lc-rs")]
55 signing_alg: &ML_DSA_65_SIGNING,
56 #[cfg(feature = "aws-lc-rs")]
57 verification_alg: &ML_DSA_65,
58 #[cfg(feature = "aws-lc-rs")]
59 key_cache: Arc::new(Mutex::new(HashMap::new())),
60 }
61 }
62
63 #[cfg(feature = "aws-lc-rs")]
65 pub fn clear_cache(&self) {
66 if let Ok(mut cache) = self.key_cache.lock() {
67 cache.clear();
68 }
69 }
70}
71
72impl Clone for MlDsa65Impl {
73 fn clone(&self) -> Self {
74 Self {
75 #[cfg(feature = "aws-lc-rs")]
76 signing_alg: self.signing_alg,
77 #[cfg(feature = "aws-lc-rs")]
78 verification_alg: self.verification_alg,
79 #[cfg(feature = "aws-lc-rs")]
80 key_cache: Arc::clone(&self.key_cache),
81 }
82 }
83}
84
/// Number of random bytes written after the public-key id inside every
/// [`MlDsaSecretKey`] handle. They make the signing handle unguessable from
/// the public key alone.
#[cfg(feature = "aws-lc-rs")]
const SIGNING_HANDLE_TOKEN_LEN: usize = 32;

// The secret-key buffer must hold the public-key id followed by the random
// token. This also guarantees that signing-handle cache keys are strictly
// longer than public-key-id cache keys, so the two kinds can never collide.
#[cfg(feature = "aws-lc-rs")]
const _: () = assert!(
    ML_DSA_65_SECRET_KEY_SIZE >= ML_DSA_65_PUBLIC_KEY_SIZE + SIGNING_HANDLE_TOKEN_LEN
);

/// Derives the fixed-size public-key id from a DER-encoded public key.
///
/// The id is the DER bytes truncated or zero-padded to
/// `ML_DSA_65_PUBLIC_KEY_SIZE`. It serves both as the exported
/// [`MlDsaPublicKey`] bytes and as the cache key used by `verify`.
#[cfg(feature = "aws-lc-rs")]
fn ml_dsa_65_public_key_id(public_key_der: &[u8]) -> Vec<u8> {
    let copy_len = public_key_der.len().min(ML_DSA_65_PUBLIC_KEY_SIZE);
    let mut id = vec![0u8; ML_DSA_65_PUBLIC_KEY_SIZE];
    id[..copy_len].copy_from_slice(&public_key_der[..copy_len]);
    id
}

/// aws-lc-rs backed ML-DSA-65.
///
/// Private key material never leaves aws-lc-rs. Key pairs are kept in the
/// provider's in-process cache, and the returned keys are lookup handles:
/// - `MlDsaPublicKey` is the public-key id (see `ml_dsa_65_public_key_id`).
///   It is only verifiable by a provider that shares this cache.
/// - `MlDsaSecretKey` is the public-key id followed by a random token. Only a
///   holder of the full handle can sign.
///
/// NOTE(review): the DER public key is truncated when it is longer than
/// `ML_DSA_65_PUBLIC_KEY_SIZE`, so the exported public key cannot be parsed
/// outside this cache. Entries are also never evicted except by
/// `clear_cache`, so the cache grows with every generated key pair.
#[cfg(feature = "aws-lc-rs")]
impl MlDsaOperations for MlDsa65Impl {
    /// Generates a new ML-DSA-65 key pair and registers it in the key cache.
    ///
    /// Two cache entries share one `Arc<PqdsaKeyPair>`:
    /// - one keyed by the public-key id, used by `verify`;
    /// - one keyed by the complete secret handle, used by `sign`.
    ///
    /// # Errors
    /// Returns `PqcError::KeyGenerationFailed` if aws-lc-rs fails to generate
    /// or encode the key, if random-token generation fails, or if the cache
    /// lock is poisoned.
    fn generate_keypair(&self) -> PqcResult<(MlDsaPublicKey, MlDsaSecretKey)> {
        let key_pair = PqdsaKeyPair::generate(self.signing_alg)
            .map_err(|e| PqcError::KeyGenerationFailed(e.to_string()))?;

        let public_key_der = key_pair
            .public_key()
            .as_der()
            .map_err(|e| PqcError::KeyGenerationFailed(e.to_string()))?;
        let public_key_bytes = public_key_der.as_ref().to_vec();

        let key_id = ml_dsa_65_public_key_id(&public_key_bytes);

        let mut public_key = Box::new([0u8; ML_DSA_65_PUBLIC_KEY_SIZE]);
        public_key[..].copy_from_slice(&key_id);

        // Secret handle = public-key id || random token || zero padding.
        // Without the token the handle would equal the public key, letting any
        // holder of the public key sign with the cached key pair.
        let mut secret_key = Box::new([0u8; ML_DSA_65_SECRET_KEY_SIZE]);
        secret_key[..ML_DSA_65_PUBLIC_KEY_SIZE].copy_from_slice(&key_id);
        let token_end = ML_DSA_65_PUBLIC_KEY_SIZE + SIGNING_HANDLE_TOKEN_LEN;
        aws_lc_rs::rand::fill(&mut secret_key[ML_DSA_65_PUBLIC_KEY_SIZE..token_end])
            .map_err(|e| PqcError::KeyGenerationFailed(e.to_string()))?;

        {
            let mut cache = self.key_cache.lock().map_err(|_| {
                PqcError::KeyGenerationFailed("Failed to acquire key cache lock".to_string())
            })?;

            let key_pair = Arc::new(key_pair);
            // Verification entry: addressable by anyone holding the public key.
            cache.insert(
                key_id,
                CachedDsaKey {
                    key_pair: Arc::clone(&key_pair),
                    public_key_der: public_key_bytes.clone(),
                },
            );
            // Signing entry: addressable only by the full secret handle.
            cache.insert(
                secret_key[..].to_vec(),
                CachedDsaKey {
                    key_pair,
                    public_key_der: public_key_bytes,
                },
            );
        }

        Ok((MlDsaPublicKey(public_key), MlDsaSecretKey(secret_key)))
    }

    /// Signs `message` with the cached key pair registered under the complete
    /// `secret_key` handle.
    ///
    /// The signature is zero-padded into a `ML_DSA_65_SIGNATURE_SIZE` buffer.
    ///
    /// # Errors
    /// - `PqcError::InvalidSecretKey` if the handle is not in this provider's
    ///   cache (for example after `clear_cache`, or a handle forged from a
    ///   public key).
    /// - `PqcError::SigningFailed` if the lock is poisoned or aws-lc-rs fails.
    /// - `PqcError::InvalidSignature` if the produced signature is larger than
    ///   `ML_DSA_65_SIGNATURE_SIZE`.
    fn sign(&self, secret_key: &MlDsaSecretKey, message: &[u8]) -> PqcResult<MlDsaSignature> {
        let key_pair = {
            let cache = self.key_cache.lock().map_err(|_| {
                PqcError::SigningFailed("Failed to acquire key cache lock".to_string())
            })?;

            cache
                .get(&secret_key.0[..])
                .map(|entry| Arc::clone(&entry.key_pair))
                .ok_or(PqcError::InvalidSecretKey)?
        };

        let mut signature_bytes = vec![0u8; self.signing_alg.signature_len()];
        let written = key_pair
            .sign(message, &mut signature_bytes)
            .map_err(|e| PqcError::SigningFailed(e.to_string()))?;
        signature_bytes.truncate(written);

        if signature_bytes.len() > ML_DSA_65_SIGNATURE_SIZE {
            return Err(PqcError::InvalidSignature);
        }

        let mut sig = Box::new([0u8; ML_DSA_65_SIGNATURE_SIZE]);
        sig[..signature_bytes.len()].copy_from_slice(&signature_bytes);

        Ok(MlDsaSignature(sig))
    }

    /// Verifies `signature` over `message` against a public key registered in
    /// this provider's cache.
    ///
    /// Returns `Ok(true)` for a valid signature and `Ok(false)` otherwise.
    ///
    /// # Errors
    /// Returns `PqcError::VerificationFailed` if the cache lock is poisoned or
    /// the public key is not in the cache.
    fn verify(
        &self,
        public_key: &MlDsaPublicKey,
        message: &[u8],
        signature: &MlDsaSignature,
    ) -> PqcResult<bool> {
        let public_key_der = {
            let cache = self.key_cache.lock().map_err(|_| {
                PqcError::VerificationFailed("Failed to acquire key cache lock".to_string())
            })?;

            cache
                .get(&public_key.0[..ML_DSA_65_PUBLIC_KEY_SIZE])
                .map(|entry| entry.public_key_der.clone())
                .ok_or_else(|| {
                    PqcError::VerificationFailed("Public key not found in cache".to_string())
                })?
        };

        // ML-DSA-65 signatures have a fixed length, so pass exactly that many
        // bytes. Stripping trailing zeros (the previous approach) rejected
        // valid signatures whose final byte happened to be 0.
        let sig_len = self
            .signing_alg
            .signature_len()
            .min(ML_DSA_65_SIGNATURE_SIZE);

        let unparsed_public_key = UnparsedPublicKey::new(self.verification_alg, &public_key_der);
        Ok(unparsed_public_key
            .verify(message, &signature.0[..sig_len])
            .is_ok())
    }
}
220
221#[cfg(not(feature = "aws-lc-rs"))]
223impl MlDsaOperations for MlDsa65Impl {
224 fn generate_keypair(&self) -> PqcResult<(MlDsaPublicKey, MlDsaSecretKey)> {
225 use rand::RngCore;
228 let mut rng = rand::thread_rng();
229
230 let mut pub_key = Box::new([0u8; ML_DSA_65_PUBLIC_KEY_SIZE]);
231 let mut sec_key = Box::new([0u8; ML_DSA_65_SECRET_KEY_SIZE]);
232
233 rng.fill_bytes(&mut pub_key[..]);
234 rng.fill_bytes(&mut sec_key[..]);
235
236 Ok((MlDsaPublicKey(pub_key), MlDsaSecretKey(sec_key)))
237 }
238
239 fn sign(&self, _secret_key: &MlDsaSecretKey, _message: &[u8]) -> PqcResult<MlDsaSignature> {
240 Err(PqcError::FeatureNotAvailable)
242 }
243
244 fn verify(
245 &self,
246 _public_key: &MlDsaPublicKey,
247 _message: &[u8],
248 _signature: &MlDsaSignature,
249 ) -> PqcResult<bool> {
250 Err(PqcError::FeatureNotAvailable)
252 }
253}
254
#[cfg(all(test, feature = "pqc"))]
mod tests {
    use super::*;

    /// Generated keys have the advertised ML-DSA-65 buffer sizes.
    #[test]
    #[cfg(feature = "aws-lc-rs")]
    fn test_ml_dsa_65_key_generation() {
        let provider = MlDsa65Impl::new();
        let generated = provider.generate_keypair();
        assert!(generated.is_ok());

        let (pub_key, sec_key) = generated.unwrap();
        assert_eq!(pub_key.0.len(), ML_DSA_65_PUBLIC_KEY_SIZE);
        assert_eq!(sec_key.0.len(), ML_DSA_65_SECRET_KEY_SIZE);
    }

    /// A signature verifies for the signed message and not for another one.
    #[test]
    #[cfg(feature = "aws-lc-rs")]
    fn test_ml_dsa_65_sign_verify() {
        let provider = MlDsa65Impl::new();
        let (pub_key, sec_key) = provider.generate_keypair().unwrap();

        let message = b"Test message for ML-DSA-65";
        let signature = provider.sign(&sec_key, message).unwrap();
        assert_eq!(signature.0.len(), ML_DSA_65_SIGNATURE_SIZE);

        assert!(
            provider.verify(&pub_key, message, &signature).unwrap(),
            "Signature should be valid"
        );
        assert!(
            !provider
                .verify(&pub_key, b"Different message", &signature)
                .unwrap(),
            "Signature should be invalid for wrong message"
        );
    }

    /// A signature made with one key must not verify under another key.
    #[test]
    #[cfg(feature = "aws-lc-rs")]
    fn test_ml_dsa_65_verify_with_different_key() {
        let provider = MlDsa65Impl::new();
        let (signer_pub, signer_sec) = provider.generate_keypair().unwrap();
        let (other_pub, _other_sec) = provider.generate_keypair().unwrap();

        let message = b"Test message";
        let signature = provider.sign(&signer_sec, message).unwrap();

        assert!(provider.verify(&signer_pub, message, &signature).unwrap());
        assert!(!provider.verify(&other_pub, message, &signature).unwrap());
    }

    /// Without a backend, key generation works but sign/verify are unavailable.
    #[test]
    #[cfg(not(feature = "aws-lc-rs"))]
    fn test_ml_dsa_without_feature() {
        let provider = MlDsa65Impl::new();

        let generated = provider.generate_keypair();
        assert!(generated.is_ok());
        let (pub_key, sec_key) = generated.unwrap();

        let message = b"test";
        let sign_outcome = provider.sign(&sec_key, message);
        assert!(matches!(sign_outcome, Err(PqcError::FeatureNotAvailable)));

        let zero_signature = MlDsaSignature(Box::new([0u8; ML_DSA_65_SIGNATURE_SIZE]));
        let verify_outcome = provider.verify(&pub_key, message, &zero_signature);
        assert!(matches!(verify_outcome, Err(PqcError::FeatureNotAvailable)));
    }
}