Skip to main content

phantom_protocol/crypto/
aes_session.rs

//! High-Performance AES-GCM Session Encryption
//!
//! Uses the `ring` crate for AES-256-GCM with hardware acceleration.
//! On Apple Silicon (M1): ARM FEAT_AES intrinsics (~4-8 GB/s)
//! On x86_64: AES-NI instructions (~4-8 GB/s)
//!
//! `ring` uses hand-optimized assembly on both ARM64 and x86_64, which
//! typically outperforms pure-Rust AES-GCM implementations.
9
10use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, AES_256_GCM};
11use std::sync::atomic::{AtomicU64, Ordering};
12
/// Number of bytes AES-256-GCM appends to every ciphertext: the 128-bit
/// (16-byte) authentication tag.
pub const AES_GCM_OVERHEAD: usize = 128 / 8;
15
16/// High-performance session encryption using ring's AES-256-GCM
17/// (hardware accelerated on ARM64/x86_64)
18pub struct AesSession {
19    /// Send cipher key
20    send_key: LessSafeKey,
21    /// Receive cipher key
22    recv_key: LessSafeKey,
23    /// Send nonce counter
24    send_counter: AtomicU64,
25    /// Receive nonce counter
26    recv_counter: AtomicU64,
27    /// Nonce prefix (4 bytes, set per session)
28    nonce_prefix: [u8; 4],
29}
30
31impl AesSession {
32    /// Create from a 32-byte shared secret (derived from PQC handshake).
33    /// This is the "initiator" side.
34    /// Create from a 32-byte shared secret (derived from PQC handshake).
35    /// This is the "initiator" side.
36    pub fn from_shared_secret(shared_secret: &[u8; 32]) -> Result<Self, crate::CoreError> {
37        Self::build(shared_secret, false)
38    }
39
40    /// Create the "peer" (responder) side — send/recv keys are swapped so that
41    /// initiator's encrypt can be decrypted by peer's decrypt, and vice versa.
42    /// Create the "peer" (responder) side — send/recv keys are swapped so that
43    /// initiator's encrypt can be decrypted by peer's decrypt, and vice versa.
44    pub fn from_shared_secret_peer(shared_secret: &[u8; 32]) -> Result<Self, crate::CoreError> {
45        Self::build(shared_secret, true)
46    }
47
48    fn build(shared_secret: &[u8; 32], swap: bool) -> Result<Self, crate::CoreError> {
49        // `crypto::kdf::derive_key_32` cfg-dispatches between
50        // `blake3::derive_key` (default) and HKDF-SHA256 (`--features
51        // fips`). API shape and 32-byte output are identical.
52        // CRYPTO-3: wipe the per-direction AEAD key bytes on every exit path
53        // once copied into ring's opaque `UnboundKey`.
54        let key_a = zeroize::Zeroizing::new(crate::crypto::kdf::derive_key_32(
55            "phantom-aes-send-v1",
56            shared_secret,
57        ));
58        let key_b = zeroize::Zeroizing::new(crate::crypto::kdf::derive_key_32(
59            "phantom-aes-recv-v1",
60            shared_secret,
61        ));
62
63        let (send_bytes, recv_bytes) = if swap { (key_b, key_a) } else { (key_a, key_b) };
64
65        let send_unbound = UnboundKey::new(&AES_256_GCM, &*send_bytes)
66            .map_err(|_| crate::CoreError::CryptoError("Invalid key".into()))?;
67        let recv_unbound = UnboundKey::new(&AES_256_GCM, &*recv_bytes)
68            .map_err(|_| crate::CoreError::CryptoError("Invalid key".into()))?;
69
70        let prefix_bytes = crate::crypto::kdf::derive_key_32("phantom-nonce-pfx-v1", shared_secret);
71        let mut nonce_prefix = [0u8; 4];
72        nonce_prefix.copy_from_slice(&prefix_bytes[..4]);
73
74        Ok(Self {
75            send_key: LessSafeKey::new(send_unbound),
76            recv_key: LessSafeKey::new(recv_unbound),
77            send_counter: AtomicU64::new(0),
78            recv_counter: AtomicU64::new(0),
79            nonce_prefix,
80        })
81    }
82
83    /// Encrypt in place: appends 16-byte tag. Returns total ciphertext length.
84    #[inline]
85    pub fn encrypt_in_place(&self, aad: &[u8], buf: &mut Vec<u8>) -> Result<(), EncryptError> {
86        let counter = self.send_counter.fetch_add(1, Ordering::Relaxed);
87        let nonce = self.make_nonce(counter);
88        self.send_key
89            .seal_in_place_append_tag(nonce, Aad::from(aad), buf)
90            .map_err(|_| EncryptError::EncryptionFailed)?;
91        Ok(())
92    }
93
94    /// Encrypt: allocates a new Vec with ciphertext.
95    #[inline]
96    pub fn encrypt(&self, aad: &[u8], plaintext: &[u8]) -> Result<Vec<u8>, EncryptError> {
97        let mut buf = Vec::with_capacity(plaintext.len() + AES_GCM_OVERHEAD);
98        buf.extend_from_slice(plaintext);
99        self.encrypt_in_place(aad, &mut buf)?;
100        Ok(buf)
101    }
102
103    /// Decrypt in place: verifies tag and truncates. Returns plaintext slice.
104    #[inline]
105    pub fn decrypt_in_place<'a>(
106        &self,
107        aad: &[u8],
108        buf: &'a mut [u8],
109    ) -> Result<&'a mut [u8], EncryptError> {
110        let counter = self.recv_counter.fetch_add(1, Ordering::Relaxed);
111        let nonce = self.make_nonce(counter);
112        self.recv_key
113            .open_in_place(nonce, Aad::from(aad), buf)
114            .map_err(|_| EncryptError::DecryptionFailed)
115    }
116
117    /// Decrypt: allocates a new Vec with plaintext.
118    #[inline]
119    pub fn decrypt(&self, aad: &[u8], ciphertext: &[u8]) -> Result<Vec<u8>, EncryptError> {
120        let mut buf = ciphertext.to_vec();
121        let plaintext = self.decrypt_in_place(aad, &mut buf)?;
122        let len = plaintext.len();
123        buf.truncate(len);
124        Ok(buf)
125    }
126
127    #[inline(always)]
128    fn make_nonce(&self, counter: u64) -> Nonce {
129        let mut n = [0u8; 12];
130        n[..4].copy_from_slice(&self.nonce_prefix);
131        n[4..12].copy_from_slice(&counter.to_be_bytes());
132        Nonce::assume_unique_for_key(n)
133    }
134}
135
/// Errors returned by [`AesSession`] encryption and decryption.
///
/// The variants carry no detail on purpose; callers should not learn why
/// authentication failed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EncryptError {
    /// Encrypting (sealing) a message failed.
    EncryptionFailed,
    /// Decrypting a message failed. Typically the ciphertext did not
    /// authenticate: wrong key, AAD or nonce, or tampered data.
    DecryptionFailed,
}

impl std::fmt::Display for EncryptError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            Self::EncryptionFailed => "Encryption failed",
            Self::DecryptionFailed => "Decryption / authentication failed",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for EncryptError {}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// Initiator/responder pair derived from the same secret.
    fn pair() -> (AesSession, AesSession) {
        let secret = [0xABu8; 32];
        (
            AesSession::from_shared_secret(&secret).unwrap(),
            AesSession::from_shared_secret_peer(&secret).unwrap(),
        )
    }

    #[test]
    fn round_trip() {
        // Two "peers" derived from the same secret, but with swapped keys.
        let (session_a, session_b) = pair();

        let msg = b"Hello, PQC world!";
        let ct = session_a.encrypt(&[], msg).expect("Encryption failed");
        assert_eq!(ct.len(), msg.len() + AES_GCM_OVERHEAD);

        // session_b decrypt uses recv_key which matches session_a's send_key
        let pt = session_b.decrypt(&[], &ct).expect("Decryption failed");
        assert_eq!(&pt, msg);
    }

    #[test]
    fn round_trip_both_directions_in_sequence() {
        let (initiator, responder) = pair();
        for i in 0u8..4 {
            let msg = [i; 32];
            let ct = initiator.encrypt(b"hdr", &msg).unwrap();
            assert_eq!(responder.decrypt(b"hdr", &ct).unwrap(), msg);

            let reply = responder.encrypt(b"hdr", &msg).unwrap();
            assert_eq!(initiator.decrypt(b"hdr", &reply).unwrap(), msg);
        }
    }

    #[test]
    fn rejects_tampered_ciphertext() {
        let (a, b) = pair();
        let mut ct = a.encrypt(&[], b"payload").unwrap();
        ct[0] ^= 0x01;
        assert!(matches!(b.decrypt(&[], &ct), Err(EncryptError::DecryptionFailed)));
    }

    #[test]
    fn rejects_mismatched_aad() {
        let (a, b) = pair();
        let ct = a.encrypt(b"aad-1", b"payload").unwrap();
        assert!(matches!(b.decrypt(b"aad-2", &ct), Err(EncryptError::DecryptionFailed)));
    }

    #[test]
    fn rejects_replay() {
        let (a, b) = pair();
        let ct = a.encrypt(&[], b"payload").unwrap();
        assert!(b.decrypt(&[], &ct).is_ok());
        assert!(b.decrypt(&[], &ct).is_err());
    }

    #[test]
    fn rejects_reflected_ciphertext() {
        // A session must not accept its own output: the directions use different keys.
        let (a, _b) = pair();
        let ct = a.encrypt(&[], b"payload").unwrap();
        assert!(a.decrypt(&[], &ct).is_err());
    }

    /// Encrypts ~3 GiB; too slow for the default test run.
    #[test]
    #[ignore = "benchmark: run with `cargo test --release -- --ignored --nocapture`"]
    fn throughput_smoke() {
        use std::time::Instant;

        let session = AesSession::from_shared_secret(&[0xAB; 32]).unwrap();
        let data = vec![0u8; 64 * 1024];
        let iters = 50_000;

        let start = Instant::now();
        for _ in 0..iters {
            let enc = session.encrypt(&[], &data).expect("Encryption failed");
            std::hint::black_box(enc);
        }
        let elapsed = start.elapsed();

        let total_mb = (data.len() * iters) as f64 / 1024.0 / 1024.0;
        let throughput = total_mb / elapsed.as_secs_f64();
        eprintln!("ring AES-256-GCM: {:.0} MiB/s", throughput);
        // With HW AES should be well above 1 GiB/s
    }
}