use aes_gcm::aead::{Aead, KeyInit};
use aes_gcm::{Aes256Gcm, Nonce};
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
/// One sealed message: the 12-byte AES-GCM nonce it was encrypted under, plus the
/// ciphertext (which, for frames built by `Aead::encrypt`, carries the authentication tag).
#[derive(Debug, Clone)]
pub struct EncryptedFrame {
    /// Nonce used for this frame: 8-byte prefix followed by a 4-byte big-endian counter.
    pub nonce: [u8; 12],
    /// Encrypted payload bytes.
    pub ciphertext: Vec<u8>,
}
/// AES-256-GCM cipher bound to a single session.
///
/// Outgoing nonces are generated from a fixed prefix plus an atomic counter, and incoming
/// frames are checked against the highest counter accepted so far.
pub struct SessionCipher {
    /// Keyed AES-256-GCM instance used for every operation.
    cipher: Aes256Gcm,
    /// First 8 bytes of every nonce this instance generates.
    nonce_prefix: [u8; 8],
    /// Counter for the next outgoing frame; encoded big-endian into nonce bytes 8..12.
    counter: AtomicU32,
    /// Highest counter accepted from incoming frames, or `u64::MAX` when none has been
    /// accepted yet.
    last_recv_counter: AtomicU64,
}
/// Exclusive upper bound for send counters (equal to `u32::MAX - 1`): a counter value at or
/// above this is refused with `SessionError::NonceExhausted`.
const MAX_NONCE_COUNTER: u32 = 0xFFFF_FFFE;
impl SessionCipher {
pub fn new(key: &[u8; 32], nonce_prefix: [u8; 8]) -> Self {
let cipher = Aes256Gcm::new_from_slice(key).expect("32-byte key is valid for AES-256");
Self {
cipher,
nonce_prefix,
counter: AtomicU32::new(0),
last_recv_counter: AtomicU64::new(u64::MAX),
}
}
pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedFrame, SessionError> {
let count = self.counter.fetch_add(1, Ordering::SeqCst);
if count >= MAX_NONCE_COUNTER {
return Err(SessionError::NonceExhausted);
}
let nonce_bytes = self.build_nonce(count);
let nonce = Nonce::from_slice(&nonce_bytes);
let ciphertext = self
.cipher
.encrypt(nonce, plaintext)
.map_err(|_| SessionError::EncryptionFailed)?;
Ok(EncryptedFrame {
nonce: nonce_bytes,
ciphertext,
})
}
pub fn decrypt(&self, frame: &EncryptedFrame) -> Result<Vec<u8>, SessionError> {
let counter = u32::from_be_bytes(
frame
.nonce
.get(8..12)
.ok_or(SessionError::InvalidNonce)?
.try_into()
.map_err(|_| SessionError::InvalidNonce)?,
) as u64;
let last = self.last_recv_counter.load(Ordering::SeqCst);
if last != u64::MAX && counter <= last {
return Err(SessionError::ReplayedNonce);
}
let nonce = Nonce::from_slice(&frame.nonce);
let plaintext = self
.cipher
.decrypt(nonce, frame.ciphertext.as_ref())
.map_err(|_| SessionError::DecryptionFailed)?;
self.last_recv_counter.store(counter, Ordering::SeqCst);
Ok(plaintext)
}
pub fn encrypt_in_place_detached(
&self,
payload: &mut [u8],
) -> Result<([u8; 12], [u8; 16]), SessionError> {
let count = self.counter.fetch_add(1, Ordering::SeqCst);
if count >= MAX_NONCE_COUNTER {
return Err(SessionError::NonceExhausted);
}
let nonce_bytes = self.build_nonce(count);
let nonce = Nonce::from_slice(&nonce_bytes);
let tag = aes_gcm::aead::AeadInPlace::encrypt_in_place_detached(
&self.cipher,
nonce,
b"",
payload,
)
.map_err(|_| SessionError::EncryptionFailed)?;
let mut tag_bytes = [0u8; 16];
tag_bytes.copy_from_slice(&tag);
Ok((nonce_bytes, tag_bytes))
}
pub fn decrypt_in_place_detached(
&self,
nonce_bytes: &[u8; 12],
payload: &mut [u8],
tag_bytes: &[u8; 16],
) -> Result<(), SessionError> {
let counter = u32::from_be_bytes(
nonce_bytes
.get(8..12)
.ok_or(SessionError::InvalidNonce)?
.try_into()
.map_err(|_| SessionError::InvalidNonce)?,
) as u64;
let last = self.last_recv_counter.load(Ordering::SeqCst);
if last != u64::MAX && counter <= last {
return Err(SessionError::ReplayedNonce);
}
let nonce = Nonce::from_slice(nonce_bytes);
let tag = aes_gcm::aead::Tag::<aes_gcm::Aes256Gcm>::from_slice(tag_bytes);
aes_gcm::aead::AeadInPlace::decrypt_in_place_detached(
&self.cipher,
nonce,
b"",
payload,
tag,
)
.map_err(|_| SessionError::DecryptionFailed)?;
self.last_recv_counter.store(counter, Ordering::SeqCst);
Ok(())
}
fn build_nonce(&self, counter: u32) -> [u8; 12] {
let mut nonce = [0u8; 12];
nonce[..8].copy_from_slice(&self.nonce_prefix);
nonce[8..12].copy_from_slice(&counter.to_be_bytes());
nonce
}
}
#[derive(Debug, thiserror::Error)]
pub enum SessionError {
#[error("invalid nonce format")]
InvalidNonce,
#[error("nonce counter exhausted — session must be rekeyed")]
NonceExhausted,
#[error("encryption failed")]
EncryptionFailed,
#[error("decryption failed (invalid ciphertext or wrong key)")]
DecryptionFailed,
#[error("replayed nonce: frame counter has already been accepted")]
ReplayedNonce,
}
#[cfg(test)]
mod tests;