use crate::errors::CoreError;
#[cfg(feature = "fips")]
use aws_lc_rs::aead::{self, Aad, LessSafeKey, Nonce, UnboundKey, AES_256_GCM, CHACHA20_POLY1305};
#[cfg(not(feature = "fips"))]
use ring::aead::{self, Aad, LessSafeKey, Nonce, UnboundKey, AES_256_GCM, CHACHA20_POLY1305};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
/// Extra bytes reserved per sealed message for the authentication tag that the
/// encrypt routines append after the ciphertext.
pub const AEAD_OVERHEAD: usize = 16;
/// Upper bound (2^48) on AEAD invocations per direction. Once a session's send or
/// receive counter reaches this value, further operations in that direction fail
/// with `CryptoError::NonceExhausted`.
pub const AEAD_MAX_INVOCATIONS: u64 = 1u64 << 48;
/// AEAD cipher suites a session can be built with.
///
/// The explicit `u8` discriminants are the values produced by `to_byte` and
/// accepted by `from_byte`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum CipherSuite {
/// AES-256-GCM (byte value 1).
Aes256Gcm = 1,
/// ChaCha20-Poly1305 (byte value 2).
ChaCha20Poly1305 = 2,
}
impl CipherSuite {
pub fn to_byte(self) -> u8 {
self as u8
}
pub fn from_byte(b: u8) -> Option<Self> {
match b {
1 => Some(Self::Aes256Gcm),
2 => Some(Self::ChaCha20Poly1305),
_ => None,
}
}
fn algorithm(&self) -> &'static aead::Algorithm {
match self {
Self::Aes256Gcm => &AES_256_GCM,
Self::ChaCha20Poly1305 => &CHACHA20_POLY1305,
}
}
}
/// Hardware capabilities that influence which cipher suite is recommended.
#[derive(Debug, Clone, Copy)]
pub struct HwCaps {
/// `true` when the CPU reports the "aes" feature at runtime (always `false` on
/// architectures other than aarch64, x86_64 and x86).
pub has_hw_aes: bool,
}
impl HwCaps {
    /// Probes the running CPU and records whether it advertises AES instructions.
    pub fn detect() -> Self {
        let has_hw_aes = Self::detect_hw_aes();
        Self { has_hw_aes }
    }

    /// Runtime check for the "aes" CPU feature on aarch64.
    #[cfg(target_arch = "aarch64")]
    fn detect_hw_aes() -> bool {
        std::arch::is_aarch64_feature_detected!("aes")
    }

    /// Runtime check for the "aes" CPU feature on x86 and x86_64.
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    fn detect_hw_aes() -> bool {
        std::is_x86_feature_detected!("aes")
    }

    /// Fallback for every other architecture: report no hardware AES.
    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64", target_arch = "x86")))]
    fn detect_hw_aes() -> bool {
        false
    }

    /// Picks the suite this host should prefer.
    ///
    /// With the `fips` feature enabled the answer is always AES-256-GCM.
    /// Otherwise AES-256-GCM is chosen when hardware AES is present and
    /// ChaCha20-Poly1305 when it is not.
    pub fn recommended_cipher(&self) -> CipherSuite {
        // `cfg!` is resolved at compile time, so the fips build never consults
        // the hardware flag for its decision.
        if cfg!(feature = "fips") || self.has_hw_aes {
            CipherSuite::Aes256Gcm
        } else {
            CipherSuite::ChaCha20Poly1305
        }
    }
}
/// Chooses the cipher suite for a session from the client's offer and the
/// server's hardware capabilities.
///
/// * With the `fips` feature: AES-256-GCM is selected if the client offered it;
///   otherwise `CoreError::CipherSuiteUnavailable` is returned. `server_caps`
///   is ignored.
/// * Without `fips`: the server's recommended suite wins when the client
///   offered it; otherwise the client's first offer is used. An empty offer
///   falls back to ChaCha20-Poly1305, so this branch never returns `Err`.
pub fn negotiate_cipher(
    client_preferred: &[CipherSuite],
    server_caps: &HwCaps,
) -> Result<CipherSuite, CoreError> {
    #[cfg(feature = "fips")]
    {
        let _ = server_caps;
        let offers_aes = client_preferred
            .iter()
            .any(|offered| *offered == CipherSuite::Aes256Gcm);
        if !offers_aes {
            return Err(CoreError::CipherSuiteUnavailable(
                "no FIPS-approved cipher suite in client offer (only AES-256-GCM is approved under fips)"
                    .into(),
            ));
        }
        Ok(CipherSuite::Aes256Gcm)
    }
    #[cfg(not(feature = "fips"))]
    {
        let server_pref = server_caps.recommended_cipher();
        let client_offers_pref = client_preferred
            .iter()
            .any(|offered| *offered == server_pref);
        let chosen = if client_offers_pref {
            server_pref
        } else {
            // No overlap with the server preference: honour the client's
            // first choice, or ChaCha20-Poly1305 if the offer was empty.
            client_preferred
                .first()
                .copied()
                .unwrap_or(CipherSuite::ChaCha20Poly1305)
        };
        Ok(chosen)
    }
}
/// Handle to an AEAD session.
///
/// Cloning copies only the `Arc`, so every clone shares the same keys and the
/// same send/receive counters.
#[derive(Clone)]
pub struct CryptoSession {
// Shared session state; see `CryptoSessionInner`.
inner: Arc<CryptoSessionInner>,
}
/// State shared by all clones of a `CryptoSession`.
struct CryptoSessionInner {
// Cipher suite both keys were created for.
suite: CipherSuite,
// Key used by the encrypt routines.
send_key: LessSafeKey,
// Key used by the decrypt routines.
recv_key: LessSafeKey,
// Number of send-side invocations started so far; also supplies the low
// 8 bytes of counter-derived send nonces.
send_counter: AtomicU64,
// Number of receive-side invocations started so far; also supplies the low
// 8 bytes of counter-derived receive nonces.
recv_counter: AtomicU64,
// First 4 bytes of every counter-derived nonce, derived from the shared
// secret when the session is built.
nonce_prefix: [u8; 4],
}
impl CryptoSession {
    /// Builds the local side of a session, using the cipher suite recommended
    /// for this host (`HwCaps::detect().recommended_cipher()`).
    ///
    /// # Errors
    /// Returns `CoreError::CryptoError` if the backend rejects the derived key.
    pub fn from_shared_secret(shared_secret: &[u8; 32]) -> Result<Self, CoreError> {
        let suite = HwCaps::detect().recommended_cipher();
        Self::build(shared_secret, suite, false)
    }

    /// Builds the peer side of a session (send and receive keys swapped),
    /// using the cipher suite recommended for this host.
    ///
    /// # Errors
    /// Returns `CoreError::CryptoError` if the backend rejects the derived key.
    pub fn from_shared_secret_peer(shared_secret: &[u8; 32]) -> Result<Self, CoreError> {
        let suite = HwCaps::detect().recommended_cipher();
        Self::build(shared_secret, suite, true)
    }

    /// Builds the local side of a session with an explicitly chosen suite.
    ///
    /// # Errors
    /// Returns `CoreError::CipherSuiteUnavailable` when the `fips` feature is
    /// enabled and `suite` is ChaCha20-Poly1305, or `CoreError::CryptoError`
    /// if the backend rejects the derived key.
    pub fn with_suite(shared_secret: &[u8; 32], suite: CipherSuite) -> Result<Self, CoreError> {
        Self::guard_suite_under_fips(suite)?;
        Self::build(shared_secret, suite, false)
    }

    /// Builds the peer side of a session (send and receive keys swapped) with
    /// an explicitly chosen suite.
    ///
    /// # Errors
    /// Same conditions as [`CryptoSession::with_suite`].
    pub fn with_suite_peer(
        shared_secret: &[u8; 32],
        suite: CipherSuite,
    ) -> Result<Self, CoreError> {
        Self::guard_suite_under_fips(suite)?;
        Self::build(shared_secret, suite, true)
    }

    /// Rejects ChaCha20-Poly1305 when compiled with the `fips` feature; a no-op
    /// otherwise.
    #[inline]
    fn guard_suite_under_fips(suite: CipherSuite) -> Result<(), CoreError> {
        #[cfg(feature = "fips")]
        {
            if suite == CipherSuite::ChaCha20Poly1305 {
                return Err(CoreError::CipherSuiteUnavailable(
                    "ChaCha20-Poly1305 is not FIPS-approved; only AES-256-GCM is permitted under --features fips"
                        .into(),
                ));
            }
        }
        #[cfg(not(feature = "fips"))]
        {
            let _ = suite;
        }
        Ok(())
    }

    /// Derives both directional keys and the nonce prefix from `shared_secret`
    /// and assembles the session.
    ///
    /// `swap == false` uses the "send" label key for sending and the "recv"
    /// label key for receiving; `swap == true` reverses the roles so that the
    /// two ends of a connection can talk to each other.
    fn build(shared_secret: &[u8; 32], suite: CipherSuite, swap: bool) -> Result<Self, CoreError> {
        // Static labels (same bytes as before) avoid two heap allocations and
        // keep the domain-separation strings visible verbatim.
        let (send_label, recv_label) = match suite {
            CipherSuite::Aes256Gcm => ("phantom-aes-send-v1", "phantom-aes-recv-v1"),
            CipherSuite::ChaCha20Poly1305 => ("phantom-cc20-send-v1", "phantom-cc20-recv-v1"),
        };
        let key_a = zeroize::Zeroizing::new(crate::crypto::kdf::derive_key_32(
            send_label,
            shared_secret,
        ));
        let key_b = zeroize::Zeroizing::new(crate::crypto::kdf::derive_key_32(
            recv_label,
            shared_secret,
        ));
        let (send_bytes, recv_bytes) = if swap { (key_b, key_a) } else { (key_a, key_b) };

        let algo = suite.algorithm();
        let send_unbound = UnboundKey::new(algo, &*send_bytes)
            .map_err(|_| CoreError::CryptoError("Failed to create send key".into()))?;
        let recv_unbound = UnboundKey::new(algo, &*recv_bytes)
            .map_err(|_| CoreError::CryptoError("Failed to create recv key".into()))?;

        // Only 4 of the derived bytes are kept; wipe the rest of this
        // secret-derived buffer on drop, like the key material above.
        let prefix_bytes = zeroize::Zeroizing::new(crate::crypto::kdf::derive_key_32(
            "phantom-nonce-pfx-v1",
            shared_secret,
        ));
        let mut nonce_prefix = [0u8; 4];
        nonce_prefix.copy_from_slice(&prefix_bytes[..4]);

        Ok(Self {
            inner: Arc::new(CryptoSessionInner {
                suite,
                send_key: LessSafeKey::new(send_unbound),
                recv_key: LessSafeKey::new(recv_unbound),
                send_counter: AtomicU64::new(0),
                recv_counter: AtomicU64::new(0),
                nonce_prefix,
            }),
        })
    }

    /// Returns the cipher suite this session was built with.
    #[inline]
    pub fn cipher_suite(&self) -> CipherSuite {
        self.inner.suite
    }

    /// Encrypts `buf` in place under the next send nonce and appends the
    /// authentication tag to it.
    ///
    /// # Errors
    /// `CryptoError::NonceExhausted` once the send counter has reached
    /// `AEAD_MAX_INVOCATIONS`; `CryptoError::EncryptionFailed` if the backend
    /// reports a failure.
    #[inline]
    pub fn encrypt_in_place(&self, aad: &[u8], buf: &mut Vec<u8>) -> Result<(), CryptoError> {
        let seq = Self::reserve_invocation(&self.inner.send_counter)?;
        self.seal_append(self.make_nonce(seq), aad, buf)
    }

    /// Encrypts `buf[offset..]` in place under the next send nonce, leaving
    /// `buf[..offset]` untouched, and appends the authentication tag.
    ///
    /// Returns the number of bytes from `offset` to the end of `buf` after
    /// sealing (ciphertext plus tag).
    ///
    /// # Errors
    /// `CryptoError::EncryptionFailed` if `offset` is past the end of `buf`
    /// (checked before a sequence number is consumed) or if the backend
    /// reports a failure; `CryptoError::NonceExhausted` once the send counter
    /// has reached `AEAD_MAX_INVOCATIONS`.
    #[inline]
    pub fn encrypt_in_place_offset(
        &self,
        aad: &[u8],
        buf: &mut Vec<u8>,
        offset: usize,
    ) -> Result<usize, CryptoError> {
        // Slicing with an out-of-range offset would panic; report it as an
        // error instead, and do so before burning a nonce.
        if offset > buf.len() {
            return Err(CryptoError::EncryptionFailed);
        }
        let seq = Self::reserve_invocation(&self.inner.send_counter)?;
        let tag = self
            .inner
            .send_key
            .seal_in_place_separate_tag(self.make_nonce(seq), Aad::from(aad), &mut buf[offset..])
            .map_err(|_| CryptoError::EncryptionFailed)?;
        buf.extend_from_slice(tag.as_ref());
        Ok(buf.len() - offset)
    }

    /// Encrypts `plaintext` under the next send nonce and returns a new buffer
    /// holding the ciphertext followed by the tag.
    ///
    /// # Errors
    /// Same conditions as [`CryptoSession::encrypt_in_place`].
    #[inline]
    pub fn encrypt(&self, aad: &[u8], plaintext: &[u8]) -> Result<Vec<u8>, CryptoError> {
        let mut buf = Vec::with_capacity(plaintext.len() + AEAD_OVERHEAD);
        buf.extend_from_slice(plaintext);
        self.encrypt_in_place(aad, &mut buf)?;
        Ok(buf)
    }

    /// Authenticates and decrypts `buf` (ciphertext followed by tag) in place
    /// under the next receive nonce, returning the plaintext sub-slice.
    ///
    /// The receive counter is advanced before authentication, so a message
    /// that fails to authenticate still consumes a sequence number.
    /// NOTE(review): on a transport that can deliver forged or corrupted
    /// packets this desynchronises the two sides — confirm callers treat a
    /// decryption failure as fatal for the session.
    ///
    /// # Errors
    /// `CryptoError::NonceExhausted` once the receive counter has reached
    /// `AEAD_MAX_INVOCATIONS`; `CryptoError::DecryptionFailed` if
    /// authentication fails.
    #[inline]
    pub fn decrypt_in_place<'a>(
        &self,
        aad: &[u8],
        buf: &'a mut [u8],
    ) -> Result<&'a mut [u8], CryptoError> {
        let seq = Self::reserve_invocation(&self.inner.recv_counter)?;
        self.inner
            .recv_key
            .open_in_place(self.make_nonce(seq), Aad::from(aad), buf)
            .map_err(|_| CryptoError::DecryptionFailed)
    }

    /// Number of send-side invocations started on this session (including
    /// attempts rejected with `NonceExhausted`).
    #[inline]
    pub fn send_invocations(&self) -> u64 {
        self.inner.send_counter.load(Ordering::Relaxed)
    }

    /// Number of receive-side invocations started on this session (including
    /// failed or rejected attempts).
    #[inline]
    pub fn recv_invocations(&self) -> u64 {
        self.inner.recv_counter.load(Ordering::Relaxed)
    }

    /// Authenticates and decrypts `ciphertext` (ciphertext followed by tag)
    /// under the next receive nonce and returns the plaintext in a new buffer.
    ///
    /// # Errors
    /// Same conditions as [`CryptoSession::decrypt_in_place`].
    #[inline]
    pub fn decrypt(&self, aad: &[u8], ciphertext: &[u8]) -> Result<Vec<u8>, CryptoError> {
        let seq = Self::reserve_invocation(&self.inner.recv_counter)?;
        self.open_owned(self.make_nonce(seq), aad, ciphertext)
    }

    /// Encrypts `plaintext` under a caller-supplied nonce and returns the
    /// ciphertext followed by the tag.
    ///
    /// The call still counts against the send-side invocation budget, but the
    /// counter value is not used in the nonce. The same send key is used as
    /// for counter-derived nonces, so the caller is responsible for making
    /// sure `nonce_bytes` never repeats and never equals a counter-derived
    /// nonce of this session.
    ///
    /// # Errors
    /// `CryptoError::NonceExhausted` once the send counter has reached
    /// `AEAD_MAX_INVOCATIONS`; `CryptoError::EncryptionFailed` if the backend
    /// reports a failure.
    #[inline]
    pub fn encrypt_with_nonce(
        &self,
        nonce_bytes: [u8; 12],
        aad: &[u8],
        plaintext: &[u8],
    ) -> Result<Vec<u8>, CryptoError> {
        Self::reserve_invocation(&self.inner.send_counter)?;
        let mut buf = Vec::with_capacity(plaintext.len() + AEAD_OVERHEAD);
        buf.extend_from_slice(plaintext);
        self.seal_append(Nonce::assume_unique_for_key(nonce_bytes), aad, &mut buf)?;
        Ok(buf)
    }

    /// Authenticates and decrypts `ciphertext` under a caller-supplied nonce
    /// and returns the plaintext in a new buffer.
    ///
    /// The call counts against the receive-side invocation budget.
    ///
    /// # Errors
    /// `CryptoError::NonceExhausted` once the receive counter has reached
    /// `AEAD_MAX_INVOCATIONS`; `CryptoError::DecryptionFailed` if
    /// authentication fails.
    #[inline]
    pub fn decrypt_with_nonce(
        &self,
        nonce_bytes: [u8; 12],
        aad: &[u8],
        ciphertext: &[u8],
    ) -> Result<Vec<u8>, CryptoError> {
        Self::reserve_invocation(&self.inner.recv_counter)?;
        self.open_owned(Nonce::assume_unique_for_key(nonce_bytes), aad, ciphertext)
    }

    /// Returns the 4-byte prefix placed at the front of counter-derived nonces.
    #[inline]
    pub fn nonce_prefix(&self) -> [u8; 4] {
        self.inner.nonce_prefix
    }

    /// Claims the next sequence number from `counter`.
    ///
    /// `fetch_add` is an atomic read-modify-write, so concurrent callers get
    /// distinct values even with `Relaxed` ordering; no other memory is
    /// published through the counter.
    #[inline]
    fn reserve_invocation(counter: &AtomicU64) -> Result<u64, CryptoError> {
        let seq = counter.fetch_add(1, Ordering::Relaxed);
        if seq >= AEAD_MAX_INVOCATIONS {
            return Err(CryptoError::NonceExhausted);
        }
        Ok(seq)
    }

    /// Seals `buf` in place with the send key and appends the tag.
    #[inline]
    fn seal_append(&self, nonce: Nonce, aad: &[u8], buf: &mut Vec<u8>) -> Result<(), CryptoError> {
        self.inner
            .send_key
            .seal_in_place_append_tag(nonce, Aad::from(aad), buf)
            .map_err(|_| CryptoError::EncryptionFailed)
    }

    /// Copies `ciphertext`, opens the copy with the receive key and returns
    /// it truncated to the plaintext length.
    #[inline]
    fn open_owned(
        &self,
        nonce: Nonce,
        aad: &[u8],
        ciphertext: &[u8],
    ) -> Result<Vec<u8>, CryptoError> {
        let mut buf = ciphertext.to_vec();
        let plaintext_len = self
            .inner
            .recv_key
            .open_in_place(nonce, Aad::from(aad), &mut buf)
            .map_err(|_| CryptoError::DecryptionFailed)?
            .len();
        buf.truncate(plaintext_len);
        Ok(buf)
    }

    /// Builds the 12-byte nonce for sequence number `counter`: the 4-byte
    /// session prefix followed by the counter in big-endian.
    #[inline]
    fn make_nonce(&self, counter: u64) -> Nonce {
        let mut n = [0u8; 12];
        n[..4].copy_from_slice(&self.inner.nonce_prefix);
        n[4..12].copy_from_slice(&counter.to_be_bytes());
        Nonce::assume_unique_for_key(n)
    }
}
/// Failure modes of the session's encrypt and decrypt operations.
///
/// The type is `Copy` and comparable, so callers and tests can match or
/// `assert_eq!` on a specific failure.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CryptoError {
    /// The backend refused to seal the message.
    EncryptionFailed,
    /// The message could not be opened (authentication failure or malformed input).
    DecryptionFailed,
    /// The per-direction invocation counter has reached its limit; the
    /// session's keys must be rotated before further use.
    NonceExhausted,
}
impl std::fmt::Display for CryptoError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::EncryptionFailed => write!(f, "Encryption failed"),
Self::DecryptionFailed => write!(f, "Decryption / authentication failed"),
Self::NonceExhausted => write!(
f,
"AEAD nonce exhausted: per-direction counter exceeded {} invocations \
(rotate keys before reusing this session)",
AEAD_MAX_INVOCATIONS
),
}
}
}
// Marker impl so `CryptoError` can be used wherever a `std::error::Error` is
// expected; it relies on the trait's default methods only.
impl std::error::Error for CryptoError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds both ends of a session for `suite` from the same secret.
    fn session_pair(secret: &[u8; 32], suite: CipherSuite) -> (CryptoSession, CryptoSession) {
        let local = CryptoSession::with_suite(secret, suite).unwrap();
        let peer = CryptoSession::with_suite_peer(secret, suite).unwrap();
        (local, peer)
    }

    /// The recommendation must follow the detected capability (or always be
    /// AES-256-GCM under `fips`).
    #[test]
    fn hw_detection() {
        let caps = HwCaps::detect();
        let suite = caps.recommended_cipher();
        eprintln!("HW AES: {}, Recommended: {:?}", caps.has_hw_aes, suite);
        let expected = if cfg!(feature = "fips") || caps.has_hw_aes {
            CipherSuite::Aes256Gcm
        } else {
            CipherSuite::ChaCha20Poly1305
        };
        assert_eq!(suite, expected);
    }

    /// `from_byte` inverts `to_byte` and rejects unknown bytes.
    #[test]
    fn suite_byte_round_trip() {
        for suite in [CipherSuite::Aes256Gcm, CipherSuite::ChaCha20Poly1305] {
            assert_eq!(CipherSuite::from_byte(suite.to_byte()), Some(suite));
        }
        assert_eq!(CipherSuite::from_byte(0), None);
        assert_eq!(CipherSuite::from_byte(3), None);
    }

    #[test]
    fn round_trip_aes() {
        let secret = [0xABu8; 32];
        let (a, b) = session_pair(&secret, CipherSuite::Aes256Gcm);
        let msg = b"Hello, PQ AES world!";
        let ct = a.encrypt(&[], msg).unwrap();
        let pt = b.decrypt(&[], &ct).unwrap();
        assert_eq!(&pt, msg);
    }

    #[cfg(feature = "fips")]
    #[test]
    fn round_trip_aes_aws_lc_rs() {
        let secret = [0xCEu8; 32];
        let (a, b) = session_pair(&secret, CipherSuite::Aes256Gcm);
        let msg = b"Hello, FIPS-mode AES world!";
        let ct = a.encrypt(&[], msg).unwrap();
        let pt = b.decrypt(&[], &ct).unwrap();
        assert_eq!(&pt, msg);
    }

    #[cfg(not(feature = "fips"))]
    #[test]
    fn round_trip_chacha() {
        let secret = [0xCDu8; 32];
        let (a, b) = session_pair(&secret, CipherSuite::ChaCha20Poly1305);
        let msg = b"Hello, PQ ChaCha world!";
        let ct = a.encrypt(&[], msg).unwrap();
        let pt = b.decrypt(&[], &ct).unwrap();
        assert_eq!(&pt, msg);
    }

    #[cfg(feature = "fips")]
    #[test]
    fn chacha_rejected_under_fips() {
        let secret = [0xCDu8; 32];
        match CryptoSession::with_suite(&secret, CipherSuite::ChaCha20Poly1305) {
            Err(CoreError::CipherSuiteUnavailable(_)) => {}
            Err(e) => panic!("expected CipherSuiteUnavailable, got {e:?}"),
            Ok(_) => panic!("expected error, got ok"),
        }
        match CryptoSession::with_suite_peer(&secret, CipherSuite::ChaCha20Poly1305) {
            Err(CoreError::CipherSuiteUnavailable(_)) => {}
            Err(e) => panic!("expected CipherSuiteUnavailable, got {e:?}"),
            Ok(_) => panic!("expected error, got ok"),
        }
    }

    #[test]
    fn round_trip_auto() {
        let secret = [0xEFu8; 32];
        let a = CryptoSession::from_shared_secret(&secret).unwrap();
        let b = CryptoSession::from_shared_secret_peer(&secret).unwrap();
        assert_eq!(a.cipher_suite(), b.cipher_suite());
        let msg = b"Auto-detected cipher!";
        let ct = a.encrypt(&[], msg).unwrap();
        let pt = b.decrypt(&[], &ct).unwrap();
        assert_eq!(&pt, msg);
    }

    /// A flipped ciphertext bit must fail authentication.
    #[test]
    fn tampered_ciphertext_is_rejected() {
        let secret = [0x11u8; 32];
        let (a, b) = session_pair(&secret, CipherSuite::Aes256Gcm);
        let mut ct = a.encrypt(b"hdr", b"payload").unwrap();
        ct[0] ^= 0x01;
        assert!(matches!(
            b.decrypt(b"hdr", &ct),
            Err(CryptoError::DecryptionFailed)
        ));
    }

    /// Each encrypt/decrypt call advances the matching counter by one.
    #[test]
    fn counters_advance_per_invocation() {
        let secret = [0x22u8; 32];
        let (a, b) = session_pair(&secret, CipherSuite::Aes256Gcm);
        assert_eq!(a.send_invocations(), 0);
        assert_eq!(b.recv_invocations(), 0);
        let ct = a.encrypt(&[], b"x").unwrap();
        assert_eq!(a.send_invocations(), 1);
        let pt = b.decrypt(&[], &ct).unwrap();
        assert_eq!(pt, b"x");
        assert_eq!(b.recv_invocations(), 1);
    }

    #[test]
    fn in_place_with_offset() {
        let secret = [0xAB; 32];
        let (session, peer) = session_pair(&secret, CipherSuite::Aes256Gcm);
        let data = b"Payload after header";
        let header_len = 4usize;
        let mut buf = Vec::with_capacity(header_len + data.len() + AEAD_OVERHEAD);
        // Placeholder header, filled in with the sealed length below.
        buf.extend_from_slice(&[0u8; 4]);
        buf.extend_from_slice(data);
        let ct_len = session
            .encrypt_in_place_offset(&[0u8; 4], &mut buf, header_len)
            .unwrap();
        assert_eq!(ct_len, data.len() + AEAD_OVERHEAD);
        buf[..4].copy_from_slice(&(ct_len as u32).to_be_bytes());
        let len = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
        let (_header, payload) = buf.split_at_mut(4);
        let pt = peer
            .decrypt_in_place(&[0u8; 4], &mut payload[..len])
            .unwrap();
        assert_eq!(pt, data);
    }

    #[cfg(not(feature = "fips"))]
    #[test]
    fn negotiation() {
        let server_aes = HwCaps { has_hw_aes: true };
        let server_no_aes = HwCaps { has_hw_aes: false };
        let both = [CipherSuite::Aes256Gcm, CipherSuite::ChaCha20Poly1305];

        let result = negotiate_cipher(&both, &server_aes).unwrap();
        assert_eq!(result, CipherSuite::Aes256Gcm);

        let result = negotiate_cipher(&both, &server_no_aes).unwrap();
        assert_eq!(result, CipherSuite::ChaCha20Poly1305);

        let result = negotiate_cipher(&[CipherSuite::ChaCha20Poly1305], &server_aes).unwrap();
        assert_eq!(result, CipherSuite::ChaCha20Poly1305);
    }

    #[cfg(feature = "fips")]
    #[test]
    fn negotiation_rejects_chacha_only_under_fips() {
        let server_aes = HwCaps { has_hw_aes: true };
        let suite = negotiate_cipher(
            &[CipherSuite::ChaCha20Poly1305, CipherSuite::Aes256Gcm],
            &server_aes,
        )
        .unwrap();
        assert_eq!(suite, CipherSuite::Aes256Gcm);
        let err = negotiate_cipher(&[CipherSuite::ChaCha20Poly1305], &server_aes).unwrap_err();
        assert!(
            matches!(err, CoreError::CipherSuiteUnavailable(_)),
            "expected CipherSuiteUnavailable, got {err:?}"
        );
    }

    /// Throughput benchmark, not a correctness test: it encrypts roughly
    /// 1.6 GB in total, so it is skipped by default.
    #[cfg(not(feature = "fips"))]
    #[test]
    #[ignore = "benchmark, not a correctness test; run with cargo test -- --ignored"]
    fn throughput_comparison() {
        use std::time::Instant;
        let secret = [0xAB; 32];
        let data = vec![0u8; 16 * 1024];
        let iters = 50_000;
        for suite in [CipherSuite::Aes256Gcm, CipherSuite::ChaCha20Poly1305] {
            let session = CryptoSession::with_suite(&secret, suite).unwrap();
            let start = Instant::now();
            for _ in 0..iters {
                let e = session.encrypt(&[], &data).unwrap();
                std::hint::black_box(e);
            }
            let elapsed = start.elapsed();
            let tput = (data.len() * iters) as f64 / 1_048_576.0 / elapsed.as_secs_f64();
            eprintln!("{:?}: {:.0} MiB/s", suite, tput);
        }
    }
}