1use aes_gcm::aead::{Aead, KeyInit};
4use aes_gcm::{Aes256Gcm, Nonce};
5use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
6
/// A sealed message: the nonce it was sealed under plus the AEAD output.
///
/// Frames produced by `SessionCipher` carry a nonce made of the session's
/// 8-byte prefix followed by a big-endian 32-bit message counter; the receiver
/// reads that counter back out of bytes 8..12 for replay protection.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct EncryptedFrame {
    /// 12-byte AES-GCM nonce used to seal `ciphertext`.
    pub nonce: [u8; 12],
    /// Output of `Aead::encrypt` (for AES-GCM this is the ciphertext with the
    /// authentication tag appended).
    pub ciphertext: Vec<u8>,
}
15
16pub struct SessionCipher {
24 cipher: Aes256Gcm,
25 nonce_prefix: [u8; 8],
26 counter: AtomicU32,
27 last_recv_counter: AtomicU64,
30}
31
/// Exclusive upper bound (`u32::MAX - 1`) on the send counter: a message can
/// only be sealed while the counter is below this value; beyond it the session
/// must be rekeyed.
const MAX_NONCE_COUNTER: u32 = 0xFFFF_FFFE;
34
35impl SessionCipher {
36 pub fn new(key: &[u8; 32], nonce_prefix: [u8; 8]) -> Self {
38 let cipher = Aes256Gcm::new_from_slice(key).expect("32-byte key is valid for AES-256");
39 Self {
40 cipher,
41 nonce_prefix,
42 counter: AtomicU32::new(0),
43 last_recv_counter: AtomicU64::new(u64::MAX),
44 }
45 }
46
47 pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedFrame, SessionError> {
49 let count = self.counter.fetch_add(1, Ordering::SeqCst);
50 if count >= MAX_NONCE_COUNTER {
51 return Err(SessionError::NonceExhausted);
52 }
53
54 let nonce_bytes = self.build_nonce(count);
55 let nonce = Nonce::from_slice(&nonce_bytes);
56
57 let ciphertext = self
58 .cipher
59 .encrypt(nonce, plaintext)
60 .map_err(|_| SessionError::EncryptionFailed)?;
61
62 Ok(EncryptedFrame {
63 nonce: nonce_bytes,
64 ciphertext,
65 })
66 }
67
68 pub fn decrypt(&self, frame: &EncryptedFrame) -> Result<Vec<u8>, SessionError> {
73 let counter = u32::from_be_bytes(
74 frame
75 .nonce
76 .get(8..12)
77 .ok_or(SessionError::InvalidNonce)?
78 .try_into()
79 .map_err(|_| SessionError::InvalidNonce)?,
80 ) as u64;
81 let last = self.last_recv_counter.load(Ordering::SeqCst);
82 if last != u64::MAX && counter <= last {
83 return Err(SessionError::ReplayedNonce);
84 }
85
86 let nonce = Nonce::from_slice(&frame.nonce);
87 let plaintext = self
88 .cipher
89 .decrypt(nonce, frame.ciphertext.as_ref())
90 .map_err(|_| SessionError::DecryptionFailed)?;
91
92 self.last_recv_counter.store(counter, Ordering::SeqCst);
93 Ok(plaintext)
94 }
95
96 pub fn encrypt_in_place_detached(
99 &self,
100 payload: &mut [u8],
101 ) -> Result<([u8; 12], [u8; 16]), SessionError> {
102 let count = self.counter.fetch_add(1, Ordering::SeqCst);
103 if count >= MAX_NONCE_COUNTER {
104 return Err(SessionError::NonceExhausted);
105 }
106
107 let nonce_bytes = self.build_nonce(count);
108 let nonce = Nonce::from_slice(&nonce_bytes);
109
110 let tag = aes_gcm::aead::AeadInPlace::encrypt_in_place_detached(
111 &self.cipher,
112 nonce,
113 b"",
114 payload,
115 )
116 .map_err(|_| SessionError::EncryptionFailed)?;
117
118 let mut tag_bytes = [0u8; 16];
119 tag_bytes.copy_from_slice(&tag);
120
121 Ok((nonce_bytes, tag_bytes))
122 }
123
124 pub fn decrypt_in_place_detached(
127 &self,
128 nonce_bytes: &[u8; 12],
129 payload: &mut [u8],
130 tag_bytes: &[u8; 16],
131 ) -> Result<(), SessionError> {
132 let counter = u32::from_be_bytes(
133 nonce_bytes
134 .get(8..12)
135 .ok_or(SessionError::InvalidNonce)?
136 .try_into()
137 .map_err(|_| SessionError::InvalidNonce)?,
138 ) as u64;
139 let last = self.last_recv_counter.load(Ordering::SeqCst);
140 if last != u64::MAX && counter <= last {
141 return Err(SessionError::ReplayedNonce);
142 }
143
144 let nonce = Nonce::from_slice(nonce_bytes);
145 let tag = aes_gcm::aead::Tag::<aes_gcm::Aes256Gcm>::from_slice(tag_bytes);
146 aes_gcm::aead::AeadInPlace::decrypt_in_place_detached(
147 &self.cipher,
148 nonce,
149 b"",
150 payload,
151 tag,
152 )
153 .map_err(|_| SessionError::DecryptionFailed)?;
154
155 self.last_recv_counter.store(counter, Ordering::SeqCst);
156 Ok(())
157 }
158
159 fn build_nonce(&self, counter: u32) -> [u8; 12] {
161 let mut nonce = [0u8; 12];
162 nonce[..8].copy_from_slice(&self.nonce_prefix);
163 nonce[8..12].copy_from_slice(&counter.to_be_bytes());
164 nonce
165 }
166}
167
168#[derive(Debug, thiserror::Error)]
169pub enum SessionError {
171 #[error("invalid nonce format")]
173 InvalidNonce,
174 #[error("nonce counter exhausted — session must be rekeyed")]
176 NonceExhausted,
177 #[error("encryption failed")]
179 EncryptionFailed,
180 #[error("decryption failed (invalid ciphertext or wrong key)")]
182 DecryptionFailed,
183 #[error("replayed nonce: frame counter has already been accepted")]
185 ReplayedNonce,
186}
187
188#[cfg(test)]
189mod tests;