1use crate::errors::CoreError;
17#[cfg(feature = "fips")]
18use aws_lc_rs::aead::{self, Aad, LessSafeKey, Nonce, UnboundKey, AES_256_GCM, CHACHA20_POLY1305};
19#[cfg(not(feature = "fips"))]
20use ring::aead::{self, Aad, LessSafeKey, Nonce, UnboundKey, AES_256_GCM, CHACHA20_POLY1305};
21use std::sync::atomic::{AtomicU64, Ordering};
22use std::sync::Arc;
23
/// Number of bytes sealing adds to a message: the 16-byte AEAD authentication tag
/// (AES-256-GCM and ChaCha20-Poly1305 both use 16-byte tags).
pub const AEAD_OVERHEAD: usize = 16;

/// Per-direction cap (2^48) on AEAD invocations for one session. Once a direction's
/// counter reaches this value, its encrypt/decrypt calls fail with a nonce-exhausted error.
pub const AEAD_MAX_INVOCATIONS: u64 = 1 << 48;
43
/// AEAD cipher suites a session can use.
///
/// The explicit discriminants are the single-byte encoding produced by
/// `CipherSuite::to_byte` and accepted by `CipherSuite::from_byte`.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CipherSuite {
    /// AES-256 in Galois/Counter Mode; the only suite accepted when built with `fips`.
    Aes256Gcm = 1,
    /// ChaCha20 stream cipher with the Poly1305 authenticator.
    ChaCha20Poly1305 = 2,
}
70
71impl CipherSuite {
72 pub fn to_byte(self) -> u8 {
74 self as u8
75 }
76
77 pub fn from_byte(b: u8) -> Option<Self> {
79 match b {
80 1 => Some(Self::Aes256Gcm),
81 2 => Some(Self::ChaCha20Poly1305),
82 _ => None,
83 }
84 }
85
86 fn algorithm(&self) -> &'static aead::Algorithm {
88 match self {
89 Self::Aes256Gcm => &AES_256_GCM,
90 Self::ChaCha20Poly1305 => &CHACHA20_POLY1305,
91 }
92 }
93}
94
/// CPU capabilities that influence cipher-suite selection.
#[derive(Clone, Copy, Debug)]
pub struct HwCaps {
    /// Whether the CPU reports hardware AES instructions (filled in by `HwCaps::detect`).
    pub has_hw_aes: bool,
}
100
101impl HwCaps {
102 pub fn detect() -> Self {
104 Self {
105 has_hw_aes: Self::detect_hw_aes(),
106 }
107 }
108
109 #[cfg(target_arch = "aarch64")]
110 fn detect_hw_aes() -> bool {
111 std::arch::is_aarch64_feature_detected!("aes")
112 }
113
114 #[cfg(target_arch = "x86_64")]
115 fn detect_hw_aes() -> bool {
116 std::is_x86_feature_detected!("aes")
117 }
118
119 #[cfg(target_arch = "x86")]
120 fn detect_hw_aes() -> bool {
121 std::is_x86_feature_detected!("aes")
122 }
123
124 #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64", target_arch = "x86")))]
125 fn detect_hw_aes() -> bool {
126 false }
128
129 pub fn recommended_cipher(&self) -> CipherSuite {
136 #[cfg(feature = "fips")]
137 {
138 let _ = self.has_hw_aes;
139 CipherSuite::Aes256Gcm
140 }
141 #[cfg(not(feature = "fips"))]
142 {
143 if self.has_hw_aes {
144 CipherSuite::Aes256Gcm
145 } else {
146 CipherSuite::ChaCha20Poly1305
147 }
148 }
149 }
150}
151
152pub fn negotiate_cipher(
159 client_preferred: &[CipherSuite],
160 server_caps: &HwCaps,
161) -> Result<CipherSuite, CoreError> {
162 #[cfg(feature = "fips")]
163 {
164 let _ = server_caps;
165 if client_preferred.contains(&CipherSuite::Aes256Gcm) {
166 Ok(CipherSuite::Aes256Gcm)
167 } else {
168 Err(CoreError::CipherSuiteUnavailable(
169 "no FIPS-approved cipher suite in client offer (only AES-256-GCM is approved under fips)"
170 .into(),
171 ))
172 }
173 }
174 #[cfg(not(feature = "fips"))]
175 {
176 let server_pref = server_caps.recommended_cipher();
177 if client_preferred.contains(&server_pref) {
179 return Ok(server_pref);
180 }
181 Ok(client_preferred
183 .first()
184 .copied()
185 .unwrap_or(CipherSuite::ChaCha20Poly1305))
186 }
187}
188
189#[derive(Clone)]
196pub struct CryptoSession {
197 inner: Arc<CryptoSessionInner>,
198}
199
200struct CryptoSessionInner {
201 suite: CipherSuite,
202 send_key: LessSafeKey,
203 recv_key: LessSafeKey,
204 send_counter: AtomicU64,
205 recv_counter: AtomicU64,
206 nonce_prefix: [u8; 4],
207}
208
209impl CryptoSession {
210 pub fn from_shared_secret(shared_secret: &[u8; 32]) -> Result<Self, CoreError> {
213 let suite = HwCaps::detect().recommended_cipher();
214 Self::build(shared_secret, suite, false)
215 }
216
217 pub fn from_shared_secret_peer(shared_secret: &[u8; 32]) -> Result<Self, CoreError> {
219 let suite = HwCaps::detect().recommended_cipher();
220 Self::build(shared_secret, suite, true)
221 }
222
223 pub fn with_suite(shared_secret: &[u8; 32], suite: CipherSuite) -> Result<Self, CoreError> {
231 Self::guard_suite_under_fips(suite)?;
232 Self::build(shared_secret, suite, false)
233 }
234
235 pub fn with_suite_peer(
239 shared_secret: &[u8; 32],
240 suite: CipherSuite,
241 ) -> Result<Self, CoreError> {
242 Self::guard_suite_under_fips(suite)?;
243 Self::build(shared_secret, suite, true)
244 }
245
246 #[inline]
247 fn guard_suite_under_fips(suite: CipherSuite) -> Result<(), CoreError> {
248 #[cfg(feature = "fips")]
249 {
250 if suite == CipherSuite::ChaCha20Poly1305 {
251 return Err(CoreError::CipherSuiteUnavailable(
252 "ChaCha20-Poly1305 is not FIPS-approved; only AES-256-GCM is permitted under --features fips"
253 .into(),
254 ));
255 }
256 }
257 #[cfg(not(feature = "fips"))]
258 {
259 let _ = suite;
260 }
261 Ok(())
262 }
263
264 fn build(shared_secret: &[u8; 32], suite: CipherSuite, swap: bool) -> Result<Self, CoreError> {
265 let ctx = match suite {
266 CipherSuite::Aes256Gcm => "phantom-aes-",
267 CipherSuite::ChaCha20Poly1305 => "phantom-cc20-",
268 };
269 let send_label = format!("{}send-v1", ctx);
270 let recv_label = format!("{}recv-v1", ctx);
271
272 let key_a = zeroize::Zeroizing::new(crate::crypto::kdf::derive_key_32(
279 &send_label,
280 shared_secret,
281 ));
282 let key_b = zeroize::Zeroizing::new(crate::crypto::kdf::derive_key_32(
283 &recv_label,
284 shared_secret,
285 ));
286
287 let (send_bytes, recv_bytes) = if swap { (key_b, key_a) } else { (key_a, key_b) };
288
289 let algo = suite.algorithm();
290 let send_unbound = UnboundKey::new(algo, &*send_bytes)
291 .map_err(|_| CoreError::CryptoError("Failed to create send key".into()))?;
292 let recv_unbound = UnboundKey::new(algo, &*recv_bytes)
293 .map_err(|_| CoreError::CryptoError("Failed to create recv key".into()))?;
294
295 let prefix_bytes = crate::crypto::kdf::derive_key_32("phantom-nonce-pfx-v1", shared_secret);
296 let mut nonce_prefix = [0u8; 4];
297 nonce_prefix.copy_from_slice(&prefix_bytes[..4]);
298
299 Ok(Self {
300 inner: Arc::new(CryptoSessionInner {
301 suite,
302 send_key: LessSafeKey::new(send_unbound),
303 recv_key: LessSafeKey::new(recv_unbound),
304 send_counter: AtomicU64::new(0),
305 recv_counter: AtomicU64::new(0),
306 nonce_prefix,
307 }),
308 })
309 }
310
311 #[inline]
313 pub fn cipher_suite(&self) -> CipherSuite {
314 self.inner.suite
315 }
316
317 #[inline]
319 pub fn encrypt_in_place(&self, aad: &[u8], buf: &mut Vec<u8>) -> Result<(), CryptoError> {
320 let counter = self.inner.send_counter.fetch_add(1, Ordering::Relaxed);
321 if counter >= AEAD_MAX_INVOCATIONS {
322 return Err(CryptoError::NonceExhausted);
323 }
324 let nonce = self.make_nonce(counter);
325 self.inner
326 .send_key
327 .seal_in_place_append_tag(nonce, Aad::from(aad), buf)
328 .map_err(|_| CryptoError::EncryptionFailed)?;
329 Ok(())
330 }
331
332 #[inline]
336 pub fn encrypt_in_place_offset(
337 &self,
338 aad: &[u8],
339 buf: &mut Vec<u8>,
340 offset: usize,
341 ) -> Result<usize, CryptoError> {
342 let counter = self.inner.send_counter.fetch_add(1, Ordering::Relaxed);
343 if counter >= AEAD_MAX_INVOCATIONS {
344 return Err(CryptoError::NonceExhausted);
345 }
346 let nonce = self.make_nonce(counter);
347 let tag = self
349 .inner
350 .send_key
351 .seal_in_place_separate_tag(nonce, Aad::from(aad), &mut buf[offset..])
352 .map_err(|_| CryptoError::EncryptionFailed)?;
353 buf.extend_from_slice(tag.as_ref());
355 Ok(buf.len() - offset)
356 }
357
358 #[inline]
360 pub fn encrypt(&self, aad: &[u8], plaintext: &[u8]) -> Result<Vec<u8>, CryptoError> {
361 let mut buf = Vec::with_capacity(plaintext.len() + AEAD_OVERHEAD);
362 buf.extend_from_slice(plaintext);
363 self.encrypt_in_place(aad, &mut buf)?;
364 Ok(buf)
365 }
366
367 #[inline]
369 pub fn decrypt_in_place<'a>(
370 &self,
371 aad: &[u8],
372 buf: &'a mut [u8],
373 ) -> Result<&'a mut [u8], CryptoError> {
374 let counter = self.inner.recv_counter.fetch_add(1, Ordering::Relaxed);
375 if counter >= AEAD_MAX_INVOCATIONS {
376 return Err(CryptoError::NonceExhausted);
377 }
378 let nonce = self.make_nonce(counter);
379 self.inner
380 .recv_key
381 .open_in_place(nonce, Aad::from(aad), buf)
382 .map_err(|_| CryptoError::DecryptionFailed)
383 }
384
385 #[inline]
389 pub fn send_invocations(&self) -> u64 {
390 self.inner.send_counter.load(Ordering::Relaxed)
391 }
392
393 #[inline]
395 pub fn recv_invocations(&self) -> u64 {
396 self.inner.recv_counter.load(Ordering::Relaxed)
397 }
398
399 #[inline]
401 pub fn decrypt(&self, aad: &[u8], ciphertext: &[u8]) -> Result<Vec<u8>, CryptoError> {
402 let mut buf = ciphertext.to_vec();
403 let plaintext = self.decrypt_in_place(aad, &mut buf)?;
404 let len = plaintext.len();
405 buf.truncate(len);
406 Ok(buf)
407 }
408
409 #[inline]
427 pub fn encrypt_with_nonce(
428 &self,
429 nonce_bytes: [u8; 12],
430 aad: &[u8],
431 plaintext: &[u8],
432 ) -> Result<Vec<u8>, CryptoError> {
433 let counter = self.inner.send_counter.fetch_add(1, Ordering::Relaxed);
434 if counter >= AEAD_MAX_INVOCATIONS {
435 return Err(CryptoError::NonceExhausted);
436 }
437 let nonce = Nonce::assume_unique_for_key(nonce_bytes);
438 let mut buf = Vec::with_capacity(plaintext.len() + AEAD_OVERHEAD);
439 buf.extend_from_slice(plaintext);
440 self.inner
441 .send_key
442 .seal_in_place_append_tag(nonce, Aad::from(aad), &mut buf)
443 .map_err(|_| CryptoError::EncryptionFailed)?;
444 Ok(buf)
445 }
446
447 #[inline]
451 pub fn decrypt_with_nonce(
452 &self,
453 nonce_bytes: [u8; 12],
454 aad: &[u8],
455 ciphertext: &[u8],
456 ) -> Result<Vec<u8>, CryptoError> {
457 let counter = self.inner.recv_counter.fetch_add(1, Ordering::Relaxed);
458 if counter >= AEAD_MAX_INVOCATIONS {
459 return Err(CryptoError::NonceExhausted);
460 }
461 let nonce = Nonce::assume_unique_for_key(nonce_bytes);
462 let mut buf = ciphertext.to_vec();
463 let plaintext_slice = self
464 .inner
465 .recv_key
466 .open_in_place(nonce, Aad::from(aad), &mut buf)
467 .map_err(|_| CryptoError::DecryptionFailed)?;
468 let len = plaintext_slice.len();
469 buf.truncate(len);
470 Ok(buf)
471 }
472
473 #[inline]
476 pub fn nonce_prefix(&self) -> [u8; 4] {
477 self.inner.nonce_prefix
478 }
479
480 #[inline(always)]
481 fn make_nonce(&self, counter: u64) -> Nonce {
482 let mut n = [0u8; 12];
483 n[..4].copy_from_slice(&self.inner.nonce_prefix);
484 n[4..12].copy_from_slice(&counter.to_be_bytes());
485 Nonce::assume_unique_for_key(n)
486 }
487}
488
/// Errors returned by `CryptoSession` encrypt/decrypt operations.
///
/// Variants carry no payload, so the type is `Copy`. It also derives `PartialEq`/`Eq`, so
/// callers and tests can compare errors directly (e.g. `assert_eq!`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CryptoError {
    /// Sealing failed.
    EncryptionFailed,
    /// Opening failed: the tag did not verify for the given key, nonce and AAD.
    DecryptionFailed,
    /// The per-direction invocation counter reached `AEAD_MAX_INVOCATIONS`; the session
    /// must not be used further in that direction.
    NonceExhausted,
}
498
499impl std::fmt::Display for CryptoError {
500 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
501 match self {
502 Self::EncryptionFailed => write!(f, "Encryption failed"),
503 Self::DecryptionFailed => write!(f, "Decryption / authentication failed"),
504 Self::NonceExhausted => write!(
505 f,
506 "AEAD nonce exhausted: per-direction counter exceeded {} invocations \
507 (rotate keys before reusing this session)",
508 AEAD_MAX_INVOCATIONS
509 ),
510 }
511 }
512}
513
514impl std::error::Error for CryptoError {}
515
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test for CPU feature detection. It also checks that the recommendation
    /// follows the policy for the active feature set.
    #[test]
    fn hw_detection() {
        let caps = HwCaps::detect();
        let suite = caps.recommended_cipher();
        eprintln!("HW AES: {}, Recommended: {:?}", caps.has_hw_aes, suite);
        if cfg!(feature = "fips") {
            // FIPS builds always pick AES-256-GCM, regardless of hardware.
            assert_eq!(suite, CipherSuite::Aes256Gcm);
        } else {
            // Otherwise AES-256-GCM is chosen exactly when hardware AES is present.
            assert_eq!(suite == CipherSuite::Aes256Gcm, caps.has_hw_aes);
        }
    }

    #[test]
    fn cipher_suite_byte_round_trip() {
        for suite in [CipherSuite::Aes256Gcm, CipherSuite::ChaCha20Poly1305] {
            assert_eq!(CipherSuite::from_byte(suite.to_byte()), Some(suite));
        }
        assert_eq!(CipherSuite::from_byte(0), None);
        assert_eq!(CipherSuite::from_byte(3), None);
    }

    #[test]
    fn round_trip_aes() {
        let secret = [0xABu8; 32];
        let a = CryptoSession::with_suite(&secret, CipherSuite::Aes256Gcm).unwrap();
        let b = CryptoSession::with_suite_peer(&secret, CipherSuite::Aes256Gcm).unwrap();

        let msg = b"Hello, PQ AES world!";
        let ct = a.encrypt(&[], msg).unwrap();
        let pt = b.decrypt(&[], &ct).unwrap();
        assert_eq!(&pt, msg);
    }

    #[cfg(feature = "fips")]
    #[test]
    fn round_trip_aes_aws_lc_rs() {
        let secret = [0xCEu8; 32];
        let a = CryptoSession::with_suite(&secret, CipherSuite::Aes256Gcm).unwrap();
        let b = CryptoSession::with_suite_peer(&secret, CipherSuite::Aes256Gcm).unwrap();

        let msg = b"Hello, FIPS-mode AES world!";
        let ct = a.encrypt(&[], msg).unwrap();
        let pt = b.decrypt(&[], &ct).unwrap();
        assert_eq!(&pt, msg);
    }

    #[cfg(not(feature = "fips"))]
    #[test]
    fn round_trip_chacha() {
        let secret = [0xCDu8; 32];
        let a = CryptoSession::with_suite(&secret, CipherSuite::ChaCha20Poly1305).unwrap();
        let b = CryptoSession::with_suite_peer(&secret, CipherSuite::ChaCha20Poly1305).unwrap();

        let msg = b"Hello, PQ ChaCha world!";
        let ct = a.encrypt(&[], msg).unwrap();
        let pt = b.decrypt(&[], &ct).unwrap();
        assert_eq!(&pt, msg);
    }

    #[cfg(feature = "fips")]
    #[test]
    fn chacha_rejected_under_fips() {
        let secret = [0xCDu8; 32];

        match CryptoSession::with_suite(&secret, CipherSuite::ChaCha20Poly1305) {
            Err(CoreError::CipherSuiteUnavailable(_)) => {}
            Err(e) => panic!("expected CipherSuiteUnavailable, got {e:?}"),
            Ok(_) => panic!("expected error, got ok"),
        }

        match CryptoSession::with_suite_peer(&secret, CipherSuite::ChaCha20Poly1305) {
            Err(CoreError::CipherSuiteUnavailable(_)) => {}
            Err(e) => panic!("expected CipherSuiteUnavailable, got {e:?}"),
            Ok(_) => panic!("expected error, got ok"),
        }
    }

    #[test]
    fn round_trip_auto() {
        let secret = [0xEFu8; 32];
        let a = CryptoSession::from_shared_secret(&secret).unwrap();
        let b = CryptoSession::from_shared_secret_peer(&secret).unwrap();

        assert_eq!(a.cipher_suite(), b.cipher_suite());
        let msg = b"Auto-detected cipher!";
        let ct = a.encrypt(&[], msg).unwrap();
        let pt = b.decrypt(&[], &ct).unwrap();
        assert_eq!(&pt, msg);
    }

    #[test]
    fn in_place_with_offset() {
        let secret = [0xAB; 32];
        let session = CryptoSession::with_suite(&secret, CipherSuite::Aes256Gcm).unwrap();
        let peer = CryptoSession::with_suite_peer(&secret, CipherSuite::Aes256Gcm).unwrap();

        let data = b"Payload after header";
        let header_len = 4usize;
        let mut buf = Vec::with_capacity(header_len + data.len() + AEAD_OVERHEAD);
        buf.extend_from_slice(&[0u8; 4]);
        buf.extend_from_slice(data);

        let ct_len = session
            .encrypt_in_place_offset(&[0u8; 4], &mut buf, header_len)
            .unwrap();

        buf[..4].copy_from_slice(&(ct_len as u32).to_be_bytes());

        let len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
        let (_header, payload) = buf.split_at_mut(4);
        let pt = peer
            .decrypt_in_place(&[0u8; 4], &mut payload[..len])
            .unwrap();
        assert_eq!(pt, data);
    }

    /// Each successful encrypt/decrypt reserves exactly one counter value per direction.
    #[test]
    fn invocation_counters_track_messages() {
        let secret = [0x11u8; 32];
        let a = CryptoSession::with_suite(&secret, CipherSuite::Aes256Gcm).unwrap();
        let b = CryptoSession::with_suite_peer(&secret, CipherSuite::Aes256Gcm).unwrap();
        assert_eq!((a.send_invocations(), b.recv_invocations()), (0, 0));

        for i in 0..3u8 {
            let ct = a.encrypt(b"hdr", &[i; 5]).unwrap();
            assert_eq!(ct.len(), 5 + AEAD_OVERHEAD);
            assert_eq!(b.decrypt(b"hdr", &ct).unwrap(), vec![i; 5]);
        }
        assert_eq!(a.send_invocations(), 3);
        assert_eq!(b.recv_invocations(), 3);
    }

    /// Modified ciphertext, mismatched AAD and the wrong key direction must all fail.
    #[test]
    fn tampering_is_detected() {
        let secret = [0x22u8; 32];
        // Fresh sessions per case so the counters line up for the first message.
        let pair = || {
            (
                CryptoSession::with_suite(&secret, CipherSuite::Aes256Gcm).unwrap(),
                CryptoSession::with_suite_peer(&secret, CipherSuite::Aes256Gcm).unwrap(),
            )
        };

        // Flipped ciphertext bit.
        let (a, b) = pair();
        let mut ct = a.encrypt(b"aad", b"secret message").unwrap();
        ct[0] ^= 0x01;
        assert!(matches!(b.decrypt(b"aad", &ct), Err(CryptoError::DecryptionFailed)));

        // Mismatched associated data.
        let (a, b) = pair();
        let ct = a.encrypt(b"aad", b"secret message").unwrap();
        assert!(matches!(b.decrypt(b"other", &ct), Err(CryptoError::DecryptionFailed)));

        // A non-peer session's receive key is not the sender's send key.
        let (a, _) = pair();
        let ct = a.encrypt(&[], b"x").unwrap();
        let (same_role, _) = pair();
        assert!(same_role.decrypt(&[], &ct).is_err());
    }

    #[cfg(not(feature = "fips"))]
    #[test]
    fn negotiation() {
        let server_aes = HwCaps { has_hw_aes: true };
        let server_no_aes = HwCaps { has_hw_aes: false };

        let result = negotiate_cipher(
            &[CipherSuite::Aes256Gcm, CipherSuite::ChaCha20Poly1305],
            &server_aes,
        )
        .unwrap();
        assert_eq!(result, CipherSuite::Aes256Gcm);

        let result = negotiate_cipher(
            &[CipherSuite::Aes256Gcm, CipherSuite::ChaCha20Poly1305],
            &server_no_aes,
        )
        .unwrap();
        assert_eq!(result, CipherSuite::ChaCha20Poly1305);

        let result = negotiate_cipher(&[CipherSuite::ChaCha20Poly1305], &server_aes).unwrap();
        assert_eq!(result, CipherSuite::ChaCha20Poly1305);
    }

    #[cfg(feature = "fips")]
    #[test]
    fn negotiation_rejects_chacha_only_under_fips() {
        let server_aes = HwCaps { has_hw_aes: true };

        let suite = negotiate_cipher(
            &[CipherSuite::ChaCha20Poly1305, CipherSuite::Aes256Gcm],
            &server_aes,
        )
        .unwrap();
        assert_eq!(suite, CipherSuite::Aes256Gcm);

        let err = negotiate_cipher(&[CipherSuite::ChaCha20Poly1305], &server_aes).unwrap_err();
        assert!(
            matches!(err, CoreError::CipherSuiteUnavailable(_)),
            "expected CipherSuiteUnavailable, got {err:?}"
        );
    }

    /// Benchmark, not a correctness test: it encrypts ~800 MiB per suite, which is far too
    /// slow for the default (debug) test run.
    #[cfg(not(feature = "fips"))]
    #[test]
    #[ignore = "benchmark; run with `cargo test --release -- --ignored throughput`"]
    fn throughput_comparison() {
        use std::time::Instant;

        let secret = [0xAB; 32];
        let data = vec![0u8; 16 * 1024];
        let iters = 50_000;

        for suite in [CipherSuite::Aes256Gcm, CipherSuite::ChaCha20Poly1305] {
            let session = CryptoSession::with_suite(&secret, suite).unwrap();
            let start = Instant::now();
            for _ in 0..iters {
                let e = session.encrypt(&[], &data).unwrap();
                std::hint::black_box(e);
            }
            let elapsed = start.elapsed();
            let tput = (data.len() * iters) as f64 / 1_048_576.0 / elapsed.as_secs_f64();
            eprintln!("{:?}: {:.0} MiB/s", suite, tput);
        }
    }
}