1use aes_gcm::aead::{Aead, KeyInit, Payload};
2use aes_gcm::{Aes256Gcm, Key, Nonce};
3use ring::{
4 hkdf::{self as ring_hkdf, KeyType},
5 hmac,
6};
7
8pub use argon2::Params as Argon2Params;
9
10const AES_GCM_KEY_LEN: usize = 32;
11const AES_GCM_NONCE_LEN: usize = 12;
12const AES_GCM_TAG_LEN: usize = 16;
13
14struct HkdfOutputLen(usize);
15
16impl KeyType for HkdfOutputLen {
17 fn len(&self) -> usize {
18 self.0
19 }
20}
21
/// Checks that a buffer has exactly the expected length.
///
/// `label` names the buffer in the error message, e.g. `"AES-GCM key"`.
///
/// # Errors
/// Returns `Err("<label> length invalid: expected <expected> bytes, got <actual>")`
/// when `actual != expected`.
fn validate_length(label: &str, actual: usize, expected: usize) -> Result<(), String> {
    if actual != expected {
        return Err(format!(
            "{} length invalid: expected {} bytes, got {}",
            label, expected, actual
        ));
    }
    Ok(())
}
32
33pub fn aes_gcm_encrypt(
34 key: &[u8],
35 nonce: &[u8],
36 plaintext: &[u8],
37 aad: &[u8],
38) -> Result<(Vec<u8>, Vec<u8>), String> {
39 validate_length("AES-GCM key", key.len(), AES_GCM_KEY_LEN)?;
40 validate_length("AES-GCM nonce", nonce.len(), AES_GCM_NONCE_LEN)?;
41
42 let key = Key::<Aes256Gcm>::from_slice(key);
43 let cipher = Aes256Gcm::new(key);
44 let nonce = Nonce::from_slice(nonce);
45 let payload = Payload {
46 msg: plaintext,
47 aad,
48 };
49
50 let ciphertext_with_tag = cipher
52 .encrypt(nonce, payload)
53 .map_err(|e| format!("encryption failure: {}", e))?;
54
55 if ciphertext_with_tag.len() < AES_GCM_TAG_LEN {
56 return Err("encryption failure: ciphertext too short".to_string());
57 }
58
59 let split_idx = ciphertext_with_tag.len() - AES_GCM_TAG_LEN;
61 let (cipher, tag) = ciphertext_with_tag.split_at(split_idx);
62
63 Ok((cipher.to_vec(), tag.to_vec()))
64}
65
66pub fn aes_gcm_decrypt(
67 key: &[u8],
68 nonce: &[u8],
69 ciphertext: &[u8],
70 aad: &[u8],
71 tag: &[u8],
72) -> Result<Vec<u8>, String> {
73 validate_length("AES-GCM key", key.len(), AES_GCM_KEY_LEN)?;
74 validate_length("AES-GCM nonce", nonce.len(), AES_GCM_NONCE_LEN)?;
75 validate_length("AES-GCM tag", tag.len(), AES_GCM_TAG_LEN)?;
76
77 let key = Key::<Aes256Gcm>::from_slice(key);
78 let cipher = Aes256Gcm::new(key);
79 let nonce = Nonce::from_slice(nonce);
80
81 let mut payload_vec = Vec::with_capacity(ciphertext.len() + tag.len());
84 payload_vec.extend_from_slice(ciphertext);
85 payload_vec.extend_from_slice(tag);
86
87 let payload = Payload {
88 msg: &payload_vec,
89 aad,
90 };
91
92 cipher
93 .decrypt(nonce, payload)
94 .map_err(|e| format!("decryption failed: {}", e))
95}
96
97pub fn argon2id_hash(
98 password: &[u8],
99 salt: &[u8],
100 params: &Argon2Params,
101) -> Result<Vec<u8>, String> {
102 let argon2 = argon2::Argon2::new(
103 argon2::Algorithm::Argon2id,
104 argon2::Version::V0x13,
105 params.clone(),
106 );
107
108 let mut output = vec![0u8; params.output_len().unwrap_or(32)];
109 argon2
110 .hash_password_into(password, salt, &mut output)
111 .map_err(|e| format!("argon2 hashing failed: {}", e))?;
112 Ok(output)
113}
114
115#[tracing::instrument(
123 name = "hkdf_extract_expand",
124 skip_all,
125 fields(
126 operation = "hkdf_extract_expand",
127 ikm_len = ikm.len(),
128 salt_len = salt.len(),
129 info_len = info.len(),
130 info_label,
131 output_len = len,
132 )
133)]
134pub fn hkdf_extract_expand(
135 ikm: &[u8],
136 salt: &[u8],
137 info: &[u8],
138 len: usize,
139) -> Result<Vec<u8>, String> {
140 let span = tracing::Span::current();
144 let label_safe = info.len() <= 64
145 && std::str::from_utf8(info)
146 .map(|s| {
147 s.chars()
148 .all(|c| c.is_ascii_graphic() || c == ' ' || c == '-' || c == '_' || c == '.')
149 })
150 .unwrap_or(false);
151 if label_safe {
152 if let Ok(s) = std::str::from_utf8(info) {
155 span.record("info_label", s);
156 }
157 } else {
158 span.record("info_label", "<binary or oversized; redacted>");
159 }
160 tracing::debug!(
161 target: "cass::encryption",
162 operation = "hkdf_extract_expand",
163 ikm_len = ikm.len(),
164 salt_len = salt.len(),
165 info_len = info.len(),
166 output_len = len,
167 "hkdf_extract_expand: entering"
168 );
169 let start = std::time::Instant::now();
170
171 let salt_obj = ring_hkdf::Salt::new(ring_hkdf::HKDF_SHA256, salt);
172 let prk = salt_obj.extract(ikm);
173 let info_components = [info];
174 let okm = prk
175 .expand(&info_components, HkdfOutputLen(len))
176 .map_err(|_| "hkdf expand failed: invalid output length".to_string())?;
177 let mut output = vec![0u8; len];
178 okm.fill(&mut output)
179 .map_err(|_| "hkdf expand failed: unable to fill output buffer".to_string())?;
180
181 let elapsed_us = start.elapsed().as_micros() as u64;
182 tracing::debug!(
183 target: "cass::encryption",
184 operation = "hkdf_extract_expand",
185 elapsed_us = elapsed_us,
186 "hkdf_extract_expand: ok"
187 );
188 Ok(output)
189}
190
191#[tracing::instrument(
194 name = "hkdf_extract",
195 skip_all,
196 fields(operation = "hkdf_extract", salt_len = salt.len(), ikm_len = ikm.len()),
197)]
198pub fn hkdf_extract(salt: &[u8], ikm: &[u8]) -> Vec<u8> {
199 let key = hmac::Key::new(hmac::HMAC_SHA256, salt);
200 let result = hmac::sign(&key, ikm).as_ref().to_vec();
201 tracing::debug!(
202 target: "cass::encryption",
203 operation = "hkdf_extract",
204 output_len = result.len(),
205 "hkdf_extract: ok"
206 );
207 result
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `result` is an `Err` whose message contains `expected`.
    fn assert_err_contains<T>(result: Result<T, String>, expected: &str) {
        let err = result.err().expect("operation should fail");
        assert!(
            err.contains(expected),
            "expected error containing {expected:?}, got {err:?}"
        );
    }

    /// Decodes a lowercase/uppercase hex string into bytes (test vectors only).
    fn hex(s: &str) -> Vec<u8> {
        assert!(s.len() % 2 == 0, "hex string must have even length");
        (0..s.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&s[i..i + 2], 16).expect("valid hex digit pair"))
            .collect()
    }

    // ---------------------------------------------------------------
    // AES-256-GCM
    // ---------------------------------------------------------------

    #[test]
    fn aes_gcm_encrypt_decrypt_round_trip() {
        let key = [0u8; 32];
        let nonce = [0u8; 12];
        let plaintext = b"Hello, world!";
        let aad = b"additional data";

        let (ciphertext, tag) = aes_gcm_encrypt(&key, &nonce, plaintext, aad).unwrap();

        let decrypted = aes_gcm_decrypt(&key, &nonce, &ciphertext, aad, &tag).unwrap();

        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn aes_gcm_round_trip_empty_plaintext() {
        let key = [0u8; 32];
        let nonce = [0u8; 12];
        let plaintext = b"";
        let aad = b"";

        let (ciphertext, tag) = aes_gcm_encrypt(&key, &nonce, plaintext, aad).unwrap();

        assert!(ciphertext.is_empty());
        assert_eq!(tag.len(), 16);

        let decrypted = aes_gcm_decrypt(&key, &nonce, &ciphertext, aad, &tag).unwrap();
        assert!(decrypted.is_empty());
    }

    #[test]
    fn aes_gcm_round_trip_large_data() {
        let key = [0xab; 32];
        let nonce = [0xcd; 12];
        let plaintext: Vec<u8> = (0..10000).map(|i| (i % 256) as u8).collect();
        let aad = b"large data test";

        let (ciphertext, tag) = aes_gcm_encrypt(&key, &nonce, &plaintext, aad).unwrap();

        assert_eq!(ciphertext.len(), plaintext.len());

        let decrypted = aes_gcm_decrypt(&key, &nonce, &ciphertext, aad, &tag).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn aes_gcm_round_trip_unicode_data() {
        let key = [0x42; 32];
        let nonce = [0x13; 12];
        let plaintext = "日本語テスト 🦀 Rust".as_bytes();
        let aad = "unicode AAD: émojis 🎉".as_bytes();

        let (ciphertext, tag) = aes_gcm_encrypt(&key, &nonce, plaintext, aad).unwrap();
        let decrypted = aes_gcm_decrypt(&key, &nonce, &ciphertext, aad, &tag).unwrap();

        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn aes_gcm_matches_gcm_spec_known_answers() {
        // AES-256 test cases 13 and 14 from the GCM specification
        // (McGrew & Viega): all-zero key and IV, no AAD.
        let key = [0u8; 32];
        let nonce = [0u8; 12];

        // Test case 13: empty plaintext.
        let (ciphertext, tag) = aes_gcm_encrypt(&key, &nonce, b"", b"").unwrap();
        assert!(ciphertext.is_empty());
        assert_eq!(tag, hex("530f8afbc74536b9a963b4f1c4cb738b"));

        // Test case 14: one all-zero 16-byte block.
        let plaintext = [0u8; 16];
        let (ciphertext, tag) = aes_gcm_encrypt(&key, &nonce, &plaintext, b"").unwrap();
        assert_eq!(ciphertext, hex("cea7403d4d606b6e074ec5d3baf39d18"));
        assert_eq!(tag, hex("d0d1c8a799996bf0265b98b5d48ab919"));

        let decrypted = aes_gcm_decrypt(&key, &nonce, &ciphertext, b"", &tag).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn aes_gcm_encrypt_invalid_key_length() {
        // 16-byte key: AES-128 sized, too short for AES-256.
        let key = [0u8; 16];
        let nonce = [0u8; 12];
        let plaintext = b"test";
        let aad = b"";

        let result = aes_gcm_encrypt(&key, &nonce, plaintext, aad);
        assert_err_contains(result, "key length invalid");
    }

    #[test]
    fn aes_gcm_encrypt_invalid_nonce_length() {
        let key = [0u8; 32];
        // 16-byte nonce: longer than the 12 bytes required.
        let nonce = [0u8; 16];
        let plaintext = b"test";
        let aad = b"";

        let result = aes_gcm_encrypt(&key, &nonce, plaintext, aad);
        assert_err_contains(result, "nonce length invalid");
    }

    #[test]
    fn aes_gcm_decrypt_invalid_key_length() {
        // One byte short of a valid AES-256 key.
        let key = [0u8; 31];
        let nonce = [0u8; 12];
        let ciphertext = b"ciphertext";
        let aad = b"";
        let tag = [0u8; 16];

        let result = aes_gcm_decrypt(&key, &nonce, ciphertext, aad, &tag);
        assert_err_contains(result, "key length invalid");
    }

    #[test]
    fn aes_gcm_decrypt_invalid_nonce_length() {
        let key = [0u8; 32];
        // 8-byte nonce: shorter than the 12 bytes required.
        let nonce = [0u8; 8];
        let ciphertext = b"ciphertext";
        let aad = b"";
        let tag = [0u8; 16];

        let result = aes_gcm_decrypt(&key, &nonce, ciphertext, aad, &tag);
        assert_err_contains(result, "nonce length invalid");
    }

    #[test]
    fn aes_gcm_decrypt_invalid_tag_length() {
        let key = [0u8; 32];
        let nonce = [0u8; 12];
        let ciphertext = b"ciphertext";
        let aad = b"";
        // Truncated 8-byte tag must be rejected, not treated as valid.
        let tag = [0u8; 8];

        let result = aes_gcm_decrypt(&key, &nonce, ciphertext, aad, &tag);
        assert_err_contains(result, "tag length invalid");
    }

    #[test]
    fn aes_gcm_decrypt_wrong_key_fails() {
        let key = [0u8; 32];
        let nonce = [0u8; 12];
        let plaintext = b"secret message";
        let aad = b"aad";

        let (ciphertext, tag) = aes_gcm_encrypt(&key, &nonce, plaintext, aad).unwrap();

        let wrong_key = [1u8; 32];
        let result = aes_gcm_decrypt(&wrong_key, &nonce, &ciphertext, aad, &tag);
        assert_err_contains(result, "decryption failed");
    }

    #[test]
    fn aes_gcm_decrypt_wrong_aad_fails() {
        let key = [0u8; 32];
        let nonce = [0u8; 12];
        let plaintext = b"secret message";
        let aad = b"correct aad";

        let (ciphertext, tag) = aes_gcm_encrypt(&key, &nonce, plaintext, aad).unwrap();

        let wrong_aad = b"wrong aad";
        let result = aes_gcm_decrypt(&key, &nonce, &ciphertext, wrong_aad, &tag);
        assert_err_contains(result, "decryption failed");
    }

    #[test]
    fn aes_gcm_decrypt_tampered_ciphertext_fails() {
        let key = [0u8; 32];
        let nonce = [0u8; 12];
        let plaintext = b"secret message";
        let aad = b"aad";

        let (mut ciphertext, tag) = aes_gcm_encrypt(&key, &nonce, plaintext, aad).unwrap();

        ciphertext[0] ^= 0xff;
        let result = aes_gcm_decrypt(&key, &nonce, &ciphertext, aad, &tag);
        assert_err_contains(result, "decryption failed");
    }

    #[test]
    fn aes_gcm_decrypt_tampered_tag_fails() {
        let key = [0u8; 32];
        let nonce = [0u8; 12];
        let plaintext = b"secret message";
        let aad = b"aad";

        let (ciphertext, mut tag) = aes_gcm_encrypt(&key, &nonce, plaintext, aad).unwrap();

        tag[0] ^= 0xff;
        let result = aes_gcm_decrypt(&key, &nonce, &ciphertext, aad, &tag);
        assert_err_contains(result, "decryption failed");
    }

    #[test]
    fn aes_gcm_tag_is_correct_size() {
        let key = [0u8; 32];
        let nonce = [0u8; 12];
        let plaintext = b"test";
        let aad = b"";

        let (_, tag) = aes_gcm_encrypt(&key, &nonce, plaintext, aad).unwrap();
        assert_eq!(tag.len(), 16);
    }

    #[test]
    fn aes_gcm_different_nonces_produce_different_ciphertext() {
        let key = [0u8; 32];
        let plaintext = b"same plaintext";
        let aad = b"same aad";

        let nonce1 = [0u8; 12];
        let nonce2 = [1u8; 12];

        let (ciphertext1, _) = aes_gcm_encrypt(&key, &nonce1, plaintext, aad).unwrap();
        let (ciphertext2, _) = aes_gcm_encrypt(&key, &nonce2, plaintext, aad).unwrap();

        assert_ne!(ciphertext1, ciphertext2);
    }

    // ---------------------------------------------------------------
    // Argon2id
    // ---------------------------------------------------------------

    #[test]
    fn argon2id_hash_produces_deterministic_output() {
        let password = b"password123";
        let salt = b"randomsalt123456";
        let params = Argon2Params::new(1024, 1, 1, Some(32)).unwrap();

        let hash1 = argon2id_hash(password, salt, &params).unwrap();
        let hash2 = argon2id_hash(password, salt, &params).unwrap();

        assert_eq!(hash1, hash2);
        assert_eq!(hash1.len(), 32);
    }

    #[test]
    fn argon2id_hash_different_passwords_produce_different_hashes() {
        let salt = b"randomsalt123456";
        let params = Argon2Params::new(1024, 1, 1, Some(32)).unwrap();

        let hash1 = argon2id_hash(b"password1", salt, &params).unwrap();
        let hash2 = argon2id_hash(b"password2", salt, &params).unwrap();

        assert_ne!(hash1, hash2);
    }

    #[test]
    fn argon2id_hash_different_salts_produce_different_hashes() {
        let password = b"samepassword";
        let params = Argon2Params::new(1024, 1, 1, Some(32)).unwrap();

        let hash1 = argon2id_hash(password, b"salt1234567890ab", &params).unwrap();
        let hash2 = argon2id_hash(password, b"salt0987654321xy", &params).unwrap();

        assert_ne!(hash1, hash2);
    }

    #[test]
    fn argon2id_hash_respects_output_length() {
        let password = b"password";
        let salt = b"salt1234567890ab";

        let params_32 = Argon2Params::new(1024, 1, 1, Some(32)).unwrap();
        let params_64 = Argon2Params::new(1024, 1, 1, Some(64)).unwrap();

        let hash_32 = argon2id_hash(password, salt, &params_32).unwrap();
        let hash_64 = argon2id_hash(password, salt, &params_64).unwrap();

        assert_eq!(hash_32.len(), 32);
        assert_eq!(hash_64.len(), 64);
    }

    #[test]
    fn argon2id_hash_empty_password() {
        let password = b"";
        let salt = b"randomsalt123456";
        let params = Argon2Params::new(1024, 1, 1, Some(32)).unwrap();

        let result = argon2id_hash(password, salt, &params);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 32);
    }

    #[test]
    fn argon2id_hash_unicode_password() {
        let password = "日本語パスワード🔐".as_bytes();
        let salt = b"randomsalt123456";
        let params = Argon2Params::new(1024, 1, 1, Some(32)).unwrap();

        let result = argon2id_hash(password, salt, &params);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 32);
    }

    // ---------------------------------------------------------------
    // HKDF
    // ---------------------------------------------------------------

    #[test]
    fn hkdf_matches_rfc5869_test_case_1() {
        // RFC 5869, Appendix A.1 (basic test case with SHA-256).
        let ikm = [0x0b_u8; 22];
        let salt = hex("000102030405060708090a0b0c");
        let info = hex("f0f1f2f3f4f5f6f7f8f9");

        let prk = hkdf_extract(&salt, &ikm);
        assert_eq!(
            prk,
            hex("077709362c2e32df0ddc3f0dc47bba6390b6c73bb50f9c3122ec844ad7c2b3e5")
        );

        let okm = hkdf_extract_expand(&ikm, &salt, &info, 42).unwrap();
        assert_eq!(
            okm,
            hex(concat!(
                "3cb25f25faacd57a90434f64d0362f2a",
                "2d2d0a90cf1a5a4c5db02d56ecc4c5bf",
                "34007208d5b887185865",
            ))
        );
    }

    #[test]
    fn hkdf_extract_expand_produces_deterministic_output() {
        let ikm = b"input key material";
        let salt = b"salt value";
        let info = b"context info";

        let okm1 = hkdf_extract_expand(ikm, salt, info, 32).unwrap();
        let okm2 = hkdf_extract_expand(ikm, salt, info, 32).unwrap();

        assert_eq!(okm1, okm2);
        assert_eq!(okm1.len(), 32);
    }

    #[test]
    fn hkdf_extract_expand_respects_output_length() {
        let ikm = b"input key material";
        let salt = b"salt value";
        let info = b"context info";

        let okm_16 = hkdf_extract_expand(ikm, salt, info, 16).unwrap();
        let okm_64 = hkdf_extract_expand(ikm, salt, info, 64).unwrap();

        assert_eq!(okm_16.len(), 16);
        assert_eq!(okm_64.len(), 64);
    }

    #[test]
    fn hkdf_extract_expand_different_info_produces_different_output() {
        let ikm = b"input key material";
        let salt = b"salt value";

        let okm1 = hkdf_extract_expand(ikm, salt, b"info1", 32).unwrap();
        let okm2 = hkdf_extract_expand(ikm, salt, b"info2", 32).unwrap();

        assert_ne!(okm1, okm2);
    }

    #[test]
    fn hkdf_extract_expand_different_salt_produces_different_output() {
        let ikm = b"input key material";
        let info = b"context info";

        let okm1 = hkdf_extract_expand(ikm, b"salt1", info, 32).unwrap();
        let okm2 = hkdf_extract_expand(ikm, b"salt2", info, 32).unwrap();

        assert_ne!(okm1, okm2);
    }

    #[test]
    fn hkdf_extract_expand_empty_inputs() {
        let ikm = b"input key material";

        // Empty salt.
        let okm1 = hkdf_extract_expand(ikm, b"", b"info", 32).unwrap();
        assert_eq!(okm1.len(), 32);

        // Empty info.
        let okm2 = hkdf_extract_expand(ikm, b"salt", b"", 32).unwrap();
        assert_eq!(okm2.len(), 32);
    }

    #[test]
    fn hkdf_extract_expand_max_output_succeeds() {
        // 255 * 32 = 8160 bytes is the largest HKDF-SHA256 output allowed.
        let okm = hkdf_extract_expand(b"input key material", b"salt", b"info", 8160).unwrap();
        assert_eq!(okm.len(), 8160);
    }

    #[test]
    fn hkdf_extract_expand_too_long_output_fails() {
        let ikm = b"input key material";
        let salt = b"salt";
        let info = b"info";

        // One byte past the 255 * 32 = 8160-byte HKDF-SHA256 limit.
        let result = hkdf_extract_expand(ikm, salt, info, 8161);
        assert!(result.is_err());
    }

    #[test]
    fn hkdf_extract_produces_deterministic_output() {
        let salt = b"salt value";
        let ikm = b"input key material";

        let prk1 = hkdf_extract(salt, ikm);
        let prk2 = hkdf_extract(salt, ikm);

        assert_eq!(prk1, prk2);
        // SHA-256 output size.
        assert_eq!(prk1.len(), 32);
    }

    #[test]
    fn hkdf_extract_different_ikm_produces_different_output() {
        let salt = b"salt value";

        let prk1 = hkdf_extract(salt, b"ikm1");
        let prk2 = hkdf_extract(salt, b"ikm2");

        assert_ne!(prk1, prk2);
    }

    #[test]
    fn hkdf_extract_different_salt_produces_different_output() {
        let ikm = b"input key material";

        let prk1 = hkdf_extract(b"salt1", ikm);
        let prk2 = hkdf_extract(b"salt2", ikm);

        assert_ne!(prk1, prk2);
    }

    #[test]
    fn hkdf_extract_empty_salt() {
        let ikm = b"input key material";

        let prk = hkdf_extract(b"", ikm);
        assert_eq!(prk.len(), 32);
    }

    #[test]
    fn hkdf_extract_empty_ikm() {
        let salt = b"salt value";

        let prk = hkdf_extract(salt, b"");
        assert_eq!(prk.len(), 32);
    }

    // ---------------------------------------------------------------
    // Integration: derived keys feeding AES-GCM
    // ---------------------------------------------------------------

    #[test]
    fn integration_argon2_derived_key_for_aes_gcm() {
        let password = b"user_password";
        let salt = b"randomsalt123456";
        let params = Argon2Params::new(1024, 1, 1, Some(32)).unwrap();

        let key = argon2id_hash(password, salt, &params).unwrap();
        assert_eq!(key.len(), 32);

        let nonce = [0u8; 12];
        let plaintext = b"sensitive data";
        let aad = b"";

        let (ciphertext, tag) = aes_gcm_encrypt(&key, &nonce, plaintext, aad).unwrap();
        let decrypted = aes_gcm_decrypt(&key, &nonce, &ciphertext, aad, &tag).unwrap();

        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn integration_hkdf_derived_key_for_aes_gcm() {
        let master_secret = b"master_secret";
        let salt = b"application_salt";
        let info = b"encryption_key";

        let key = hkdf_extract_expand(master_secret, salt, info, 32).unwrap();
        assert_eq!(key.len(), 32);

        let nonce = [0u8; 12];
        let plaintext = b"sensitive data";
        let aad = b"";

        let (ciphertext, tag) = aes_gcm_encrypt(&key, &nonce, plaintext, aad).unwrap();
        let decrypted = aes_gcm_decrypt(&key, &nonce, &ciphertext, aad, &tag).unwrap();

        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn integration_extract_then_expand() {
        let salt = b"random_salt";
        let ikm = b"initial_key_material";
        let info = b"derived_key";

        let prk = hkdf_extract(salt, ikm);
        let key = hkdf_extract_expand(&prk, b"", info, 32).unwrap();

        assert_eq!(key.len(), 32);

        let nonce = [0u8; 12];
        let plaintext = b"test data";
        let aad = b"";

        let (ciphertext, tag) = aes_gcm_encrypt(&key, &nonce, plaintext, aad).unwrap();
        let decrypted = aes_gcm_decrypt(&key, &nonce, &ciphertext, aad, &tag).unwrap();

        assert_eq!(decrypted, plaintext);
    }
}
701}