phantom_protocol/crypto/
aes_session.rs1use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, AES_256_GCM};
11use std::sync::atomic::{AtomicU64, Ordering};
12
/// Number of bytes a sealed message grows by: the 128-bit AES-GCM authentication tag.
pub const AES_GCM_OVERHEAD: usize = 128 / 8;
15
16pub struct AesSession {
19 send_key: LessSafeKey,
21 recv_key: LessSafeKey,
23 send_counter: AtomicU64,
25 recv_counter: AtomicU64,
27 nonce_prefix: [u8; 4],
29}
30
31impl AesSession {
32 pub fn from_shared_secret(shared_secret: &[u8; 32]) -> Result<Self, crate::CoreError> {
37 Self::build(shared_secret, false)
38 }
39
40 pub fn from_shared_secret_peer(shared_secret: &[u8; 32]) -> Result<Self, crate::CoreError> {
45 Self::build(shared_secret, true)
46 }
47
48 fn build(shared_secret: &[u8; 32], swap: bool) -> Result<Self, crate::CoreError> {
49 let key_a = zeroize::Zeroizing::new(crate::crypto::kdf::derive_key_32(
55 "phantom-aes-send-v1",
56 shared_secret,
57 ));
58 let key_b = zeroize::Zeroizing::new(crate::crypto::kdf::derive_key_32(
59 "phantom-aes-recv-v1",
60 shared_secret,
61 ));
62
63 let (send_bytes, recv_bytes) = if swap { (key_b, key_a) } else { (key_a, key_b) };
64
65 let send_unbound = UnboundKey::new(&AES_256_GCM, &*send_bytes)
66 .map_err(|_| crate::CoreError::CryptoError("Invalid key".into()))?;
67 let recv_unbound = UnboundKey::new(&AES_256_GCM, &*recv_bytes)
68 .map_err(|_| crate::CoreError::CryptoError("Invalid key".into()))?;
69
70 let prefix_bytes = crate::crypto::kdf::derive_key_32("phantom-nonce-pfx-v1", shared_secret);
71 let mut nonce_prefix = [0u8; 4];
72 nonce_prefix.copy_from_slice(&prefix_bytes[..4]);
73
74 Ok(Self {
75 send_key: LessSafeKey::new(send_unbound),
76 recv_key: LessSafeKey::new(recv_unbound),
77 send_counter: AtomicU64::new(0),
78 recv_counter: AtomicU64::new(0),
79 nonce_prefix,
80 })
81 }
82
83 #[inline]
85 pub fn encrypt_in_place(&self, aad: &[u8], buf: &mut Vec<u8>) -> Result<(), EncryptError> {
86 let counter = self.send_counter.fetch_add(1, Ordering::Relaxed);
87 let nonce = self.make_nonce(counter);
88 self.send_key
89 .seal_in_place_append_tag(nonce, Aad::from(aad), buf)
90 .map_err(|_| EncryptError::EncryptionFailed)?;
91 Ok(())
92 }
93
94 #[inline]
96 pub fn encrypt(&self, aad: &[u8], plaintext: &[u8]) -> Result<Vec<u8>, EncryptError> {
97 let mut buf = Vec::with_capacity(plaintext.len() + AES_GCM_OVERHEAD);
98 buf.extend_from_slice(plaintext);
99 self.encrypt_in_place(aad, &mut buf)?;
100 Ok(buf)
101 }
102
103 #[inline]
105 pub fn decrypt_in_place<'a>(
106 &self,
107 aad: &[u8],
108 buf: &'a mut [u8],
109 ) -> Result<&'a mut [u8], EncryptError> {
110 let counter = self.recv_counter.fetch_add(1, Ordering::Relaxed);
111 let nonce = self.make_nonce(counter);
112 self.recv_key
113 .open_in_place(nonce, Aad::from(aad), buf)
114 .map_err(|_| EncryptError::DecryptionFailed)
115 }
116
117 #[inline]
119 pub fn decrypt(&self, aad: &[u8], ciphertext: &[u8]) -> Result<Vec<u8>, EncryptError> {
120 let mut buf = ciphertext.to_vec();
121 let plaintext = self.decrypt_in_place(aad, &mut buf)?;
122 let len = plaintext.len();
123 buf.truncate(len);
124 Ok(buf)
125 }
126
127 #[inline(always)]
128 fn make_nonce(&self, counter: u64) -> Nonce {
129 let mut n = [0u8; 12];
130 n[..4].copy_from_slice(&self.nonce_prefix);
131 n[4..12].copy_from_slice(&counter.to_be_bytes());
132 Nonce::assume_unique_for_key(n)
133 }
134}
135
/// Failure modes of [`AesSession`] sealing and opening.
///
/// The variants carry no further detail; the underlying ring error is discarded at the
/// call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EncryptError {
    /// Sealing an outgoing message failed.
    EncryptionFailed,
    /// Opening an incoming message failed (for example, its authentication tag did not verify).
    DecryptionFailed,
}

impl std::fmt::Display for EncryptError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::EncryptionFailed => "Encryption failed",
            Self::DecryptionFailed => "Decryption / authentication failed",
        })
    }
}

impl std::error::Error for EncryptError {}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a matched pair of endpoints from one shared secret.
    fn pair() -> (AesSession, AesSession) {
        let secret = [0xABu8; 32];
        (
            AesSession::from_shared_secret(&secret).unwrap(),
            AesSession::from_shared_secret_peer(&secret).unwrap(),
        )
    }

    #[test]
    fn round_trip() {
        let (session_a, session_b) = pair();

        let msg = b"Hello, PQC world!";
        let ct = session_a.encrypt(&[], msg).expect("Encryption failed");
        assert_eq!(ct.len(), msg.len() + AES_GCM_OVERHEAD);

        let pt = session_b.decrypt(&[], &ct).expect("Decryption failed");
        assert_eq!(&pt, msg);
    }

    #[test]
    fn both_directions_in_sequence() {
        let (a, b) = pair();
        for i in 0u8..4 {
            let msg = [i; 8];
            let ct = a.encrypt(b"hdr", &msg).unwrap();
            assert_eq!(b.decrypt(b"hdr", &ct).unwrap(), msg);
            let ct = b.encrypt(b"hdr", &msg).unwrap();
            assert_eq!(a.decrypt(b"hdr", &ct).unwrap(), msg);
        }
    }

    #[test]
    fn repeated_plaintext_gets_fresh_nonce() {
        let (a, _) = pair();
        let c1 = a.encrypt(&[], b"same").unwrap();
        let c2 = a.encrypt(&[], b"same").unwrap();
        assert_ne!(c1, c2);
    }

    #[test]
    fn tampered_ciphertext_is_rejected() {
        let (a, b) = pair();
        let mut ct = a.encrypt(&[], b"payload").unwrap();
        ct[0] ^= 1;
        assert!(matches!(b.decrypt(&[], &ct), Err(EncryptError::DecryptionFailed)));
    }

    #[test]
    fn mismatched_aad_is_rejected() {
        let (a, b) = pair();
        let ct = a.encrypt(b"aad-1", b"payload").unwrap();
        assert!(matches!(b.decrypt(b"aad-2", &ct), Err(EncryptError::DecryptionFailed)));
    }

    #[test]
    fn short_ciphertext_is_rejected() {
        let (_, b) = pair();
        assert!(matches!(b.decrypt(&[], &[]), Err(EncryptError::DecryptionFailed)));
        assert!(matches!(b.decrypt(&[], &[0u8; 4]), Err(EncryptError::DecryptionFailed)));
    }

    #[test]
    fn same_role_endpoints_cannot_talk() {
        // Both sides built with `from_shared_secret`: the sender seals with the "send" key
        // while the receiver opens with the "recv" key, so authentication must fail.
        let secret = [0xABu8; 32];
        let a = AesSession::from_shared_secret(&secret).unwrap();
        let also_a = AesSession::from_shared_secret(&secret).unwrap();
        let ct = a.encrypt(&[], b"payload").unwrap();
        assert!(also_a.decrypt(&[], &ct).is_err());
    }

    /// Rough throughput probe. It pushes ~3 GiB through AES-GCM and asserts nothing, so it
    /// stays out of the default test run.
    #[test]
    #[ignore = "benchmark; run with `cargo test --release -- --ignored --nocapture`"]
    fn throughput_smoke() {
        use std::time::Instant;

        let session = AesSession::from_shared_secret(&[0xAB; 32]).unwrap();
        let data = vec![0u8; 64 * 1024];
        let iters = 50_000;

        let start = Instant::now();
        for _ in 0..iters {
            let enc = session.encrypt(&[], &data).expect("Encryption failed");
            std::hint::black_box(enc);
        }
        let elapsed = start.elapsed();

        let total_mb = (data.len() * iters) as f64 / 1024.0 / 1024.0;
        let throughput = total_mb / elapsed.as_secs_f64();
        eprintln!("ring AES-256-GCM: {:.0} MiB/s", throughput);
    }
}