use bytes::{Bytes, BytesMut};
use parking_lot::Mutex;
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
use snow::{params::NoiseParams, Builder, HandshakeState};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use super::protocol::{NONCE_SIZE, TAG_SIZE};
/// Noise protocol name used by every handshake and keypair in this module:
/// the NK pattern with the pre-shared key mixed in at position 0 (see the
/// `.psk(0, ..)` calls), over Curve25519, ChaCha20-Poly1305 and BLAKE2s.
/// Parsed into `snow::params::NoiseParams` at each use site.
const NOISE_PATTERN: &str = "Noise_NKpsk0_25519_ChaChaPoly_BLAKE2s";
/// Builds a 32-byte handshake prologue identifying a (source, destination)
/// node pair.
///
/// Layout:
/// * bytes `0..16`: the ASCII tag `net-handshake-v1`
/// * bytes `16..24`: `src_node_id`, little-endian
/// * bytes `24..32`: `dest_node_id`, little-endian
pub fn handshake_prologue(src_node_id: u64, dest_node_id: u64) -> [u8; 32] {
    const TAG: &[u8; 16] = b"net-handshake-v1";
    let mut prologue = [0u8; 32];
    let (tag_part, ids_part) = prologue.split_at_mut(TAG.len());
    tag_part.copy_from_slice(TAG);
    let (src_part, dest_part) = ids_part.split_at_mut(8);
    src_part.copy_from_slice(&src_node_id.to_le_bytes());
    dest_part.copy_from_slice(&dest_node_id.to_le_bytes());
    prologue
}
/// Errors produced by the handshake and packet-encryption layer.
#[derive(Debug, Clone)]
pub enum CryptoError {
    /// Building, advancing, or finalizing the Noise handshake failed.
    Handshake(String),
    /// Sealing a packet failed.
    Encryption(String),
    /// Opening (authenticating + decrypting) a packet failed.
    Decryption(String),
    /// Key material was rejected.
    InvalidKey(String),
    /// A nonce could not be used.
    InvalidNonce,
}

impl std::fmt::Display for CryptoError {
    /// Formats as `"<category>: <detail>"`, or just `"invalid nonce"` for the
    /// variant that carries no detail.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (label, detail) = match self {
            Self::Handshake(msg) => ("handshake error", msg),
            Self::Encryption(msg) => ("encryption error", msg),
            Self::Decryption(msg) => ("decryption error", msg),
            Self::InvalidKey(msg) => ("invalid key", msg),
            Self::InvalidNonce => return f.write_str("invalid nonce"),
        };
        write!(f, "{label}: {detail}")
    }
}

impl std::error::Error for CryptoError {}
/// Directional traffic keys and identifiers produced by a completed handshake.
#[derive(Clone)]
pub struct SessionKeys {
    /// Key used to seal outbound packets.
    pub tx_key: [u8; 32],
    /// Key used to open inbound packets.
    pub rx_key: [u8; 32],
    /// Identifier derived from the handshake.
    pub session_id: u64,
    /// The peer's static public key.
    pub remote_static_pub: [u8; 32],
}

impl std::fmt::Debug for SessionKeys {
    // Both traffic keys are replaced by a redaction marker. Only the first four
    // bytes of the peer's static public key are printed, as a short fingerprint.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const HIDDEN: &str = "[REDACTED]";
        let peer = &self.remote_static_pub;
        let mut out = f.debug_struct("SessionKeys");
        out.field("session_id", &self.session_id);
        out.field("tx_key", &HIDDEN);
        out.field("rx_key", &HIDDEN);
        out.field(
            "remote_static_pub",
            &format_args!(
                "{:02x}{:02x}{:02x}{:02x}…",
                peer[0], peer[1], peer[2], peer[3]
            ),
        );
        out.finish()
    }
}
#[derive(Clone)]
pub struct StaticKeypair {
pub private: [u8; 32],
pub public: [u8; 32],
}
impl StaticKeypair {
#[expect(
clippy::expect_used,
reason = "NOISE_PATTERN is a compile-time-constant string, parses infallibly; the Noise builder generates keypairs deterministically from valid patterns"
)]
pub fn generate() -> Self {
let builder = Builder::new(
NOISE_PATTERN
.parse()
.expect("static noise pattern is valid"),
);
let keypair = builder
.generate_keypair()
.expect("keypair generation from valid pattern");
let mut private = [0u8; 32];
let mut public = [0u8; 32];
private.copy_from_slice(&keypair.private);
public.copy_from_slice(&keypair.public);
Self { private, public }
}
pub fn from_keys(private: [u8; 32], public: [u8; 32]) -> Self {
Self { private, public }
}
#[inline]
pub fn public_key(&self) -> &[u8; 32] {
&self.public
}
#[inline]
pub fn secret_key(&self) -> &[u8; 32] {
&self.private
}
}
impl std::fmt::Debug for StaticKeypair {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("StaticKeypair")
.field("public", &hex_string(&self.public))
.field("private", &"[REDACTED]")
.finish()
}
}
/// One side of an in-progress Noise handshake using [`NOISE_PATTERN`].
///
/// The initiator is built with the responder's static public key. The responder
/// is built with its own static keypair. Both sides are expected to use the same
/// PSK and prologue.
pub struct NoiseHandshake {
    /// Underlying `snow` handshake state machine.
    state: HandshakeState,
    /// `true` when built by one of the `initiator*` constructors.
    is_initiator: bool,
}

impl NoiseHandshake {
    /// Scratch-buffer size for one handshake message (65535 bytes, the Noise
    /// maximum message length).
    const MAX_MESSAGE_LEN: usize = 65535;

    /// Parses [`NOISE_PATTERN`] into `snow` parameters.
    fn noise_params() -> Result<NoiseParams, CryptoError> {
        NOISE_PATTERN
            .parse::<NoiseParams>()
            .map_err(|e| CryptoError::Handshake(format!("invalid noise params: {}", e)))
    }

    /// Copies up to 32 bytes of `src` into a zero-padded array.
    fn fixed_32(src: &[u8]) -> [u8; 32] {
        let mut dst = [0u8; 32];
        let n = src.len().min(dst.len());
        dst[..n].copy_from_slice(&src[..n]);
        dst
    }

    /// Starts a handshake as the initiator with an empty prologue.
    pub fn initiator(psk: &[u8; 32], responder_static: &[u8; 32]) -> Result<Self, CryptoError> {
        Self::initiator_with_prologue(psk, responder_static, &[])
    }

    /// Starts a handshake as the initiator.
    ///
    /// # Errors
    /// `CryptoError::Handshake` if the parameters, PSK, prologue or remote key
    /// are rejected by `snow`.
    pub fn initiator_with_prologue(
        psk: &[u8; 32],
        responder_static: &[u8; 32],
        prologue: &[u8],
    ) -> Result<Self, CryptoError> {
        let params = Self::noise_params()?;
        let builder = Builder::new(params)
            .psk(0, psk)
            .map_err(|e| CryptoError::Handshake(format!("failed to set psk: {}", e)))?
            .prologue(prologue)
            .map_err(|e| CryptoError::Handshake(format!("failed to set prologue: {}", e)))?
            .remote_public_key(responder_static)
            .map_err(|e| CryptoError::Handshake(format!("failed to set remote key: {}", e)))?;
        let state = builder
            .build_initiator()
            .map_err(|e| CryptoError::Handshake(format!("failed to build initiator: {}", e)))?;
        Ok(Self {
            is_initiator: true,
            state,
        })
    }

    /// Starts a handshake as the responder with an empty prologue.
    pub fn responder(psk: &[u8; 32], static_keypair: &StaticKeypair) -> Result<Self, CryptoError> {
        Self::responder_with_prologue(psk, static_keypair, &[])
    }

    /// Starts a handshake as the responder, using `static_keypair.private` as
    /// the local static key.
    ///
    /// # Errors
    /// `CryptoError::Handshake` if the parameters, PSK, prologue or local key
    /// are rejected by `snow`.
    pub fn responder_with_prologue(
        psk: &[u8; 32],
        static_keypair: &StaticKeypair,
        prologue: &[u8],
    ) -> Result<Self, CryptoError> {
        let params = Self::noise_params()?;
        let builder = Builder::new(params)
            .psk(0, psk)
            .map_err(|e| CryptoError::Handshake(format!("failed to set psk: {}", e)))?
            .prologue(prologue)
            .map_err(|e| CryptoError::Handshake(format!("failed to set prologue: {}", e)))?
            .local_private_key(&static_keypair.private)
            .map_err(|e| CryptoError::Handshake(format!("failed to set local key: {}", e)))?;
        let state = builder
            .build_responder()
            .map_err(|e| CryptoError::Handshake(format!("failed to build responder: {}", e)))?;
        Ok(Self {
            is_initiator: false,
            state,
        })
    }

    /// Whether all handshake messages have been exchanged.
    #[inline]
    pub fn is_finished(&self) -> bool {
        self.state.is_handshake_finished()
    }

    /// Whether this side is the initiator.
    #[inline]
    #[allow(dead_code)]
    pub fn is_initiator(&self) -> bool {
        self.is_initiator
    }

    /// Produces the next handshake message, carrying `payload`.
    pub fn write_message(&mut self, payload: &[u8]) -> Result<Vec<u8>, CryptoError> {
        let mut message = vec![0u8; Self::MAX_MESSAGE_LEN];
        match self.state.write_message(payload, &mut message) {
            Ok(written) => {
                message.truncate(written);
                Ok(message)
            }
            Err(e) => Err(CryptoError::Handshake(format!("write_message failed: {}", e))),
        }
    }

    /// Consumes a handshake message from the peer and returns its payload.
    pub fn read_message(&mut self, message: &[u8]) -> Result<Vec<u8>, CryptoError> {
        let mut payload = vec![0u8; Self::MAX_MESSAGE_LEN];
        match self.state.read_message(message, &mut payload) {
            Ok(read) => {
                payload.truncate(read);
                Ok(payload)
            }
            Err(e) => Err(CryptoError::Handshake(format!("read_message failed: {}", e))),
        }
    }

    /// Finishes the handshake and derives directional traffic keys.
    ///
    /// * `session_id` is the first 8 bytes of the handshake hash, little-endian.
    /// * The keys are `derive_key(hash, "initiator-tx" | "initiator-rx")`. The
    ///   labels are swapped on the responder so each side's tx key equals the
    ///   other's rx key.
    /// * `remote_static_pub` is all zeros if `snow` reports no remote static key.
    ///
    /// NOTE(review): the keys depend only on the handshake hash `h`. Per the Noise
    /// specification (SymmetricState: MixHash / MixKey / MixKeyAndHash), `h`
    /// absorbs the prologue, public keys, transmitted ciphertexts and a
    /// PSK-derived value, while the DH outputs only feed the chaining key. If
    /// that holds here, anyone who knows the PSK and recorded both handshake
    /// messages can recompute these keys, and the ephemeral DH adds no forward
    /// secrecy. Deriving from the transport split would avoid this but changes
    /// the wire protocol. Confirm before relying on these keys for confidentiality.
    ///
    /// # Errors
    /// `CryptoError::Handshake` if the handshake is unfinished or `snow` refuses
    /// to enter transport mode.
    pub fn into_session_keys(self) -> Result<SessionKeys, CryptoError> {
        if !self.is_finished() {
            return Err(CryptoError::Handshake("handshake not finished".to_string()));
        }
        let Self {
            state,
            is_initiator,
        } = self;

        let handshake_hash = Self::fixed_32(state.get_handshake_hash());
        let remote_static_pub = state
            .get_remote_static()
            .map(Self::fixed_32)
            .unwrap_or([0u8; 32]);

        // Entering transport mode is kept only as a final consistency check;
        // the resulting transport state is not used.
        let _transport = state
            .into_transport_mode()
            .map_err(|e| CryptoError::Handshake(format!("transport mode failed: {}", e)))?;

        let mut id_bytes = [0u8; 8];
        id_bytes.copy_from_slice(&handshake_hash[..8]);
        let session_id = u64::from_le_bytes(id_bytes);

        let (tx_label, rx_label): (&[u8], &[u8]) = if is_initiator {
            (b"initiator-tx", b"initiator-rx")
        } else {
            (b"initiator-rx", b"initiator-tx")
        };
        let mut tx_key = [0u8; 32];
        let mut rx_key = [0u8; 32];
        derive_key(&handshake_hash, tx_label, &mut tx_key);
        derive_key(&handshake_hash, rx_label, &mut rx_key);

        Ok(SessionKeys {
            tx_key,
            rx_key,
            session_id,
            remote_static_pub,
        })
    }
}
/// Builds the boxed `ring` ChaCha20-Poly1305 key used by [`PacketCipher`].
///
/// # Panics
/// Only if `ring` rejects the key length, which the `[u8; 32]` parameter rules out.
#[expect(
    clippy::expect_used,
    reason = "UnboundKey::new fails only on key-length mismatch; the [u8; 32] parameter makes that unrepresentable for CHACHA20_POLY1305"
)]
fn packet_key(key: &[u8; 32]) -> Box<LessSafeKey> {
    let unbound =
        UnboundKey::new(&CHACHA20_POLY1305, &key[..]).expect("32-byte ChaCha20-Poly1305 key");
    let sealing_key = LessSafeKey::new(unbound);
    Box::new(sealing_key)
}
/// Per-session ChaCha20-Poly1305 packet cipher with a transmit counter and an
/// anti-replay window over received counters.
pub struct PacketCipher {
    /// Boxed `ring` AEAD key built by `packet_key`.
    cipher: Box<LessSafeKey>,
    /// Bytes `0..4` hold the session prefix from `session_prefix_from_id`.
    /// Bytes `4..12` are overwritten with the little-endian packet counter
    /// to form each nonce.
    nonce_template: [u8; NONCE_SIZE],
    /// Next transmit counter. It may be shared with other ciphers via
    /// `with_shared_tx_counter`.
    tx_counter: Arc<AtomicU64>,
    /// Replay-detection state for received counters.
    rx_window: Mutex<ReplayWindow>,
}
/// Sliding anti-replay window over received packet counters.
///
/// `rx_counter` is one past the highest counter committed so far. For
/// `age < WINDOW_SIZE`, bit `age % 64` of word `age / 64` records whether
/// counter `rx_counter - 1 - age` has already been seen.
#[derive(Debug)]
struct ReplayWindow {
    rx_counter: u64,
    bitmap: [u64; Self::BITMAP_WORDS],
}

impl ReplayWindow {
    /// Number of counters below the high-water mark tracked by the bitmap.
    /// Must equal `BITMAP_WORDS * 64`.
    const WINDOW_SIZE: u64 = 1024;
    /// Largest distance above `rx_counter` that `is_valid` accepts.
    /// `commit` applies no forward bound.
    const MAX_FORWARD: u64 = Self::WINDOW_SIZE;
    /// Number of 64-bit words backing the bitmap.
    const BITMAP_WORDS: usize = 16;

    const fn new() -> Self {
        Self {
            rx_counter: 0,
            bitmap: [0; Self::BITMAP_WORDS],
        }
    }

    /// Maps an age below the high-water mark to `(word index, bit mask)`, or
    /// `None` if the age falls outside the window.
    fn slot(age: u64) -> Option<(usize, u64)> {
        (age < Self::WINDOW_SIZE).then(|| ((age / 64) as usize, 1u64 << (age % 64)))
    }

    /// Read-only check: would `received` be accepted?
    ///
    /// Rejects `u64::MAX`, counters more than `MAX_FORWARD` ahead, counters
    /// older than the window, and counters already marked as seen.
    fn is_valid(&self, received: u64) -> bool {
        if received == u64::MAX {
            return false;
        }
        if let Some(ahead) = received.checked_sub(self.rx_counter) {
            return ahead <= Self::MAX_FORWARD;
        }
        // `received < rx_counter` here, so `rx_counter >= 1`.
        let age = self.rx_counter - 1 - received;
        match Self::slot(age) {
            Some((word, mask)) => self.bitmap[word] & mask == 0,
            None => false,
        }
    }

    /// Records `received` and returns whether it was fresh.
    ///
    /// Returns `false` for `u64::MAX`, once `rx_counter` has reached `u64::MAX`,
    /// for counters older than the window, and for counters already seen.
    /// Forward jumps of any size advance the window.
    fn commit(&mut self, received: u64) -> bool {
        if received == u64::MAX || self.rx_counter == u64::MAX {
            return false;
        }
        match received.checked_sub(self.rx_counter) {
            Some(ahead) => {
                // New high-water mark: slide the existing bits up by the
                // distance moved, then mark `received` itself (age 0).
                self.shift_bitmap_up(ahead.saturating_add(1));
                self.rx_counter = received.saturating_add(1);
                self.bitmap[0] |= 1;
                true
            }
            None => {
                let age = self.rx_counter - 1 - received;
                let Some((word, mask)) = Self::slot(age) else {
                    return false;
                };
                let fresh = self.bitmap[word] & mask == 0;
                self.bitmap[word] |= mask;
                fresh
            }
        }
    }

    /// Ages every tracked counter by `shift`, dropping those that leave the
    /// window. A shift of a full window or more clears the bitmap.
    fn shift_bitmap_up(&mut self, shift: u64) {
        if shift == 0 {
            return;
        }
        if shift >= (Self::BITMAP_WORDS as u64) * 64 {
            tracing::warn!(
                shift,
                window_size = Self::WINDOW_SIZE,
                max_forward = Self::MAX_FORWARD,
                "anti-replay bitmap reset on large forward jump; \
                 prior {} counters lost replay tracking",
                Self::WINDOW_SIZE,
            );
            self.bitmap = [0; Self::BITMAP_WORDS];
            return;
        }
        let word_shift = (shift / 64) as usize;
        let bit_shift = (shift % 64) as u32;
        // Walk from the top word down. Every source word sits at an equal or
        // lower index, so it is read before it is overwritten.
        for dst in (0..Self::BITMAP_WORDS).rev() {
            let whole = dst
                .checked_sub(word_shift)
                .map_or(0, |src| self.bitmap[src] << bit_shift);
            let carry = match dst.checked_sub(word_shift + 1) {
                Some(src) if bit_shift != 0 => self.bitmap[src] >> (64 - bit_shift),
                _ => 0,
            };
            self.bitmap[dst] = whole | carry;
        }
    }
}
/// Folds a 64-bit session id into the 4-byte nonce prefix by XOR-ing its high
/// and low 32-bit halves. Returned in little-endian byte order.
#[inline]
pub(crate) fn session_prefix_from_id(session_id: u64) -> [u8; 4] {
    // The low 32 bits of `id ^ (id >> 32)` are exactly `low_half ^ high_half`.
    let folded = session_id ^ (session_id >> 32);
    (folded as u32).to_le_bytes()
}
impl PacketCipher {
    /// Creates a cipher for `session_id` with its own transmit counter starting
    /// at 0 and an empty replay window.
    pub fn new(key: &[u8; 32], session_id: u64) -> Self {
        Self::with_shared_tx_counter(key, session_id, Arc::new(AtomicU64::new(0)))
    }

    /// Creates a cipher that draws transmit counters from `tx_counter`, which
    /// may be shared with other ciphers.
    ///
    /// The nonce template carries `session_prefix_from_id(session_id)` in bytes
    /// `0..4`. Each packet's counter is written into bytes `4..12`.
    pub fn with_shared_tx_counter(
        key: &[u8; 32],
        session_id: u64,
        tx_counter: Arc<AtomicU64>,
    ) -> Self {
        let mut nonce_template = [0u8; NONCE_SIZE];
        nonce_template[0..4].copy_from_slice(&session_prefix_from_id(session_id));
        Self {
            cipher: packet_key(key),
            nonce_template,
            tx_counter,
            rx_window: Mutex::new(ReplayWindow::new()),
        }
    }

    /// Atomically reserves the next transmit counter.
    ///
    /// Unlike a bare `fetch_add`, this never wraps. The last counter handed out
    /// is `u64::MAX - 1`; after that the counter stays at `u64::MAX` and every
    /// call fails. Wrapping back to 0 would reuse a ChaCha20-Poly1305 nonce
    /// under the same key. `u64::MAX` itself is always rejected by the
    /// receive-side `ReplayWindow`, so it is never used.
    ///
    /// # Errors
    /// [`CryptoError::InvalidNonce`] once the counter space is exhausted.
    #[inline]
    fn reserve_tx_counter(&self) -> Result<u64, CryptoError> {
        self.tx_counter
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |c| c.checked_add(1))
            .map_err(|_| CryptoError::InvalidNonce)
    }

    /// Reserves a counter and returns the matching nonce.
    #[inline]
    #[allow(dead_code)]
    fn next_tx_nonce(&self) -> Result<[u8; NONCE_SIZE], CryptoError> {
        self.reserve_tx_counter()
            .map(|counter| self.nonce_from_counter(counter))
    }

    /// Builds the nonce for `counter`: session prefix followed by the
    /// little-endian counter.
    #[inline]
    fn nonce_from_counter(&self, counter: u64) -> [u8; NONCE_SIZE] {
        let mut nonce = self.nonce_template;
        nonce[4..12].copy_from_slice(&counter.to_le_bytes());
        nonce
    }

    /// Next counter that would be reserved (`u64::MAX` once exhausted).
    #[inline]
    pub fn current_tx_counter(&self) -> u64 {
        self.tx_counter.load(Ordering::Relaxed)
    }

    /// Seals `buffer` in place, appends the tag, and returns the counter used.
    ///
    /// # Errors
    /// `InvalidNonce` when counters are exhausted, or `Encryption` if `ring`
    /// refuses to seal. A reserved counter is not returned to the pool.
    #[inline]
    pub fn encrypt_in_place(&self, aad: &[u8], buffer: &mut BytesMut) -> Result<u64, CryptoError> {
        let counter = self.reserve_tx_counter()?;
        let nonce = self.nonce_from_counter(counter);
        let tag = self
            .cipher
            .seal_in_place_separate_tag(
                Nonce::assume_unique_for_key(nonce),
                Aad::from(aad),
                buffer.as_mut(),
            )
            .map_err(|_| CryptoError::Encryption("encryption failed".to_string()))?;
        buffer.extend_from_slice(tag.as_ref());
        Ok(counter)
    }

    /// Seals `buffer` in place and returns `(counter, tag)` without appending
    /// the tag.
    ///
    /// # Errors
    /// Same as [`PacketCipher::encrypt_in_place`].
    #[inline]
    pub fn encrypt_in_place_detached(
        &self,
        aad: &[u8],
        buffer: &mut [u8],
    ) -> Result<(u64, [u8; 16]), CryptoError> {
        let counter = self.reserve_tx_counter()?;
        let nonce = self.nonce_from_counter(counter);
        let tag = self
            .cipher
            .seal_in_place_separate_tag(Nonce::assume_unique_for_key(nonce), Aad::from(aad), buffer)
            .map_err(|_| CryptoError::Encryption("encryption failed".to_string()))?;
        let mut tag_bytes = [0u8; TAG_SIZE];
        tag_bytes.copy_from_slice(tag.as_ref());
        Ok((counter, tag_bytes))
    }

    /// Seals a copy of `plaintext` and returns `(ciphertext || tag, counter)`.
    ///
    /// # Errors
    /// Same as [`PacketCipher::encrypt_in_place`].
    #[inline]
    pub fn encrypt(&self, aad: &[u8], plaintext: &[u8]) -> Result<(Vec<u8>, u64), CryptoError> {
        let counter = self.reserve_tx_counter()?;
        let nonce = self.nonce_from_counter(counter);
        let mut ciphertext = Vec::with_capacity(plaintext.len() + TAG_SIZE);
        ciphertext.extend_from_slice(plaintext);
        self.cipher
            .seal_in_place_append_tag(
                Nonce::assume_unique_for_key(nonce),
                Aad::from(aad),
                &mut ciphertext,
            )
            .map_err(|_| CryptoError::Encryption("encryption failed".to_string()))?;
        Ok((ciphertext, counter))
    }

    /// Opens a copy of `ciphertext` (ciphertext || tag) sealed with
    /// `nonce_counter`. This does not consult or update the replay window.
    ///
    /// # Errors
    /// `Decryption` on authentication failure.
    #[inline]
    pub fn decrypt(
        &self,
        nonce_counter: u64,
        aad: &[u8],
        ciphertext: &[u8],
    ) -> Result<Vec<u8>, CryptoError> {
        let nonce = self.nonce_from_counter(nonce_counter);
        let mut buf = ciphertext.to_vec();
        let plaintext_len = self
            .cipher
            .open_in_place(
                Nonce::assume_unique_for_key(nonce),
                Aad::from(aad),
                &mut buf,
            )
            .map_err(|_| CryptoError::Decryption("decryption failed".to_string()))?
            .len();
        buf.truncate(plaintext_len);
        Ok(buf)
    }

    /// Opens `buffer` in place. Returns the plaintext length; the plaintext
    /// occupies `buffer[..len]`.
    ///
    /// # Errors
    /// `Decryption` if `buffer` is shorter than a tag or fails authentication.
    #[inline]
    pub fn decrypt_in_place(
        &self,
        nonce_counter: u64,
        aad: &[u8],
        buffer: &mut [u8],
    ) -> Result<usize, CryptoError> {
        if buffer.len() < TAG_SIZE {
            return Err(CryptoError::Decryption("buffer too small".to_string()));
        }
        let nonce = self.nonce_from_counter(nonce_counter);
        let plaintext_len = self
            .cipher
            .open_in_place(Nonce::assume_unique_for_key(nonce), Aad::from(aad), buffer)
            .map_err(|_| CryptoError::Decryption("decryption failed".to_string()))?
            .len();
        Ok(plaintext_len)
    }

    /// Opens `ciphertext` into a `Bytes`. It reuses the allocation when the
    /// buffer is uniquely owned and otherwise falls back to a copying decrypt.
    #[inline]
    pub fn decrypt_to_bytes(
        &self,
        nonce_counter: u64,
        aad: &[u8],
        ciphertext: Bytes,
    ) -> Result<Bytes, CryptoError> {
        match ciphertext.try_into_mut() {
            Ok(mut buf) => {
                let plaintext_len = self.decrypt_in_place(nonce_counter, aad, &mut buf)?;
                buf.truncate(plaintext_len);
                Ok(buf.freeze())
            }
            Err(shared) => self.decrypt(nonce_counter, aad, &shared).map(Bytes::from),
        }
    }

    /// Authenticates `ciphertext` without returning the plaintext. Works on a
    /// scratch copy, so the input is left untouched.
    ///
    /// # Errors
    /// `Decryption` if too short or not authentic.
    #[inline]
    pub fn verify(
        &self,
        nonce_counter: u64,
        aad: &[u8],
        ciphertext: &[u8],
    ) -> Result<(), CryptoError> {
        if ciphertext.len() < TAG_SIZE {
            return Err(CryptoError::Decryption("buffer too small".to_string()));
        }
        let mut buf = BytesMut::with_capacity(ciphertext.len());
        buf.extend_from_slice(ciphertext);
        self.decrypt_in_place(nonce_counter, aad, &mut buf)?;
        Ok(())
    }

    /// Atomically checks and records `received` in the replay window.
    ///
    /// Returns `false` for replays, counters older than the window, and
    /// `u64::MAX`. Forward jumps of any size are accepted here; the
    /// `MAX_FORWARD` bound is only applied by [`PacketCipher::is_valid_rx_counter`].
    #[inline]
    pub fn update_rx_counter(&self, received: u64) -> bool {
        let mut window = self.rx_window.lock();
        window.commit(received)
    }

    /// Same check-and-record as [`PacketCipher::update_rx_counter`].
    #[inline]
    pub fn try_admit_rx_counter(&self, received: u64) -> bool {
        self.update_rx_counter(received)
    }

    /// Read-only replay check that does not record `received`.
    #[inline]
    pub fn is_valid_rx_counter(&self, received: u64) -> bool {
        let window = self.rx_window.lock();
        window.is_valid(received)
    }
}
impl std::fmt::Debug for PacketCipher {
    // Prints counters only, never key material. `try_lock` never blocks; if the
    // replay window is currently locked, rx_counter is reported as 0.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rx_counter = match self.rx_window.try_lock() {
            Some(window) => window.rx_counter,
            None => 0,
        };
        let tx_counter = self.tx_counter.load(Ordering::Relaxed);
        let mut out = f.debug_struct("PacketCipher");
        out.field("algorithm", &"ChaCha20-Poly1305");
        out.field("tx_counter", &tx_counter);
        out.field("rx_counter", &rx_counter);
        out.finish()
    }
}
/// Derives a 32-byte subkey from `ikm` for the context label `info`.
///
/// Two keyed BLAKE2s-256 passes:
/// 1. `prk = BLAKE2s(key = ikm, msg = "net-kdf-v1")`
/// 2. `out = BLAKE2s(key = prk, msg = info)`
///
/// # Panics
/// If `ikm` is longer than 32 bytes, BLAKE2s's maximum key length. Every caller
/// in this file passes a 32-byte value.
#[expect(
    clippy::expect_used,
    reason = "Blake2sMac::new_from_slice rejects only keys longer than 32 bytes; the PRK is exactly 32 bytes and every caller passes a 32-byte IKM"
)]
fn derive_key(ikm: &[u8], info: &[u8], out: &mut [u8; 32]) {
    use blake2::{
        digest::{consts::U32, Mac},
        Blake2sMac,
    };
    type Prf = Blake2sMac<U32>;

    let prk = {
        let mut extract =
            <Prf as Mac>::new_from_slice(ikm).expect("BLAKE2s accepts variable-length keys");
        Mac::update(&mut extract, b"net-kdf-v1");
        extract.finalize().into_bytes()
    };

    let mut expand = <Prf as Mac>::new_from_slice(&prk).expect("BLAKE2s accepts 32-byte key");
    Mac::update(&mut expand, info);
    out.copy_from_slice(&expand.finalize().into_bytes());
}
/// Renders `bytes` as lowercase hex, two digits per byte, with no separators.
fn hex_string(bytes: &[u8]) -> String {
    bytes
        .iter()
        .fold(String::with_capacity(bytes.len() * 2), |mut acc, byte| {
            acc.push_str(&format!("{:02x}", byte));
            acc
        })
}
#[cfg(test)]
mod tests {
    //! Unit tests: AEAD wire compatibility with RustCrypto, nonce layout, the
    //! Noise handshake and key split, the packet cipher, the anti-replay
    //! window, and `Debug` redaction of secrets.
    use super::*;

    // A `ring` seal must be openable by RustCrypto with the same
    // prefix || little-endian-counter nonce.
    #[test]
    fn ring_seal_opens_with_rustcrypto() {
        use chacha20poly1305::{
            aead::{Aead, Payload},
            ChaCha20Poly1305, KeyInit,
        };
        let key = [0x42u8; 32];
        let session_id = 0xABCD_EF01_2345_6789u64;
        let cipher = PacketCipher::new(&key, session_id);
        let aad = [0x24u8; 56];
        let plaintext: &[u8] = b"wire-format compat across AEAD implementations";
        let (ct, counter) = cipher.encrypt(&aad, plaintext).unwrap();
        let mut nonce = [0u8; NONCE_SIZE];
        nonce[0..4].copy_from_slice(&session_prefix_from_id(session_id));
        nonce[4..12].copy_from_slice(&counter.to_le_bytes());
        let rustcrypto = ChaCha20Poly1305::new((&key).into());
        let opened = rustcrypto
            .decrypt(
                (&nonce).into(),
                Payload {
                    msg: &ct,
                    aad: &aad,
                },
            )
            .expect("RustCrypto must open ring's seal byte-for-byte");
        assert_eq!(opened, plaintext);
    }

    // Reverse direction: RustCrypto seals, `PacketCipher::decrypt` opens.
    #[test]
    fn rustcrypto_seal_opens_with_ring() {
        use chacha20poly1305::{
            aead::{Aead, Payload},
            ChaCha20Poly1305, KeyInit,
        };
        let key = [0x42u8; 32];
        let session_id = 0xABCD_EF01_2345_6789u64;
        let cipher = PacketCipher::new(&key, session_id);
        let aad = [0x24u8; 56];
        let plaintext: &[u8] = b"wire-format compat across AEAD implementations";
        let counter = 7u64;
        let mut nonce = [0u8; NONCE_SIZE];
        nonce[0..4].copy_from_slice(&session_prefix_from_id(session_id));
        nonce[4..12].copy_from_slice(&counter.to_le_bytes());
        let rustcrypto = ChaCha20Poly1305::new((&key).into());
        let ct = rustcrypto
            .encrypt(
                (&nonce).into(),
                Payload {
                    msg: plaintext,
                    aad: &aad,
                },
            )
            .unwrap();
        let opened = cipher
            .decrypt(counter, &aad, &ct)
            .expect("ring must open RustCrypto's seal byte-for-byte");
        assert_eq!(opened, plaintext);
    }

    #[test]
    fn session_prefix_uses_high_bits_of_session_id() {
        let a: u64 = 0x0000_0001_1234_5678;
        let b: u64 = 0xFFFF_FFFF_1234_5678;
        let pa = session_prefix_from_id(a);
        let pb = session_prefix_from_id(b);
        assert_ne!(
            pa, pb,
            "prefixes that only differ in high 32 bits of session_id must not collide"
        );
    }

    #[test]
    fn session_prefix_stable_for_same_id() {
        let id = 0xDEAD_BEEF_CAFE_F00D_u64;
        assert_eq!(session_prefix_from_id(id), session_prefix_from_id(id));
    }

    #[test]
    fn test_static_keypair_generate() {
        let keypair1 = StaticKeypair::generate();
        let keypair2 = StaticKeypair::generate();
        assert_ne!(keypair1.public, keypair2.public);
        assert_ne!(keypair1.private, keypair2.private);
    }

    // Full two-message handshake: both sides agree on the session id, and each
    // side's tx key is the other's rx key.
    #[test]
    fn test_noise_handshake() {
        let psk = [0x42u8; 32];
        let responder_keypair = StaticKeypair::generate();
        let mut initiator = NoiseHandshake::initiator(&psk, &responder_keypair.public).unwrap();
        let mut responder = NoiseHandshake::responder(&psk, &responder_keypair).unwrap();
        let msg1 = initiator.write_message(b"").unwrap();
        responder.read_message(&msg1).unwrap();
        let msg2 = responder.write_message(b"").unwrap();
        initiator.read_message(&msg2).unwrap();
        assert!(initiator.is_finished());
        assert!(responder.is_finished());
        let init_keys = initiator.into_session_keys().unwrap();
        let resp_keys = responder.into_session_keys().unwrap();
        assert_eq!(init_keys.session_id, resp_keys.session_id);
        assert_eq!(init_keys.tx_key, resp_keys.rx_key);
        assert_eq!(init_keys.rx_key, resp_keys.tx_key);
    }

    #[test]
    fn test_fast_cipher_roundtrip() {
        let key = [0x42u8; 32];
        let session_id = 0x1234567890ABCDEF_u64;
        let cipher = PacketCipher::new(&key, session_id);
        let aad = b"additional data";
        let plaintext = b"hello, world!";
        let (ciphertext, counter) = cipher.encrypt(aad, plaintext).unwrap();
        let rx_cipher = PacketCipher::new(&key, session_id);
        let decrypted = rx_cipher.decrypt(counter, aad, &ciphertext).unwrap();
        assert_eq!(&decrypted, plaintext);
    }

    #[test]
    fn test_fast_cipher_in_place() {
        let key = [0x42u8; 32];
        let session_id = 0x1234567890ABCDEF_u64;
        let cipher = PacketCipher::new(&key, session_id);
        let aad = b"additional data";
        let plaintext = b"hello, world!";
        let mut buffer = BytesMut::from(&plaintext[..]);
        let counter = cipher.encrypt_in_place(aad, &mut buffer).unwrap();
        assert_eq!(buffer.len(), plaintext.len() + TAG_SIZE);
        let rx_cipher = PacketCipher::new(&key, session_id);
        let len = rx_cipher
            .decrypt_in_place(counter, aad, &mut buffer[..])
            .unwrap();
        assert_eq!(len, plaintext.len());
        assert_eq!(&buffer[..len], plaintext);
    }

    #[test]
    fn nonce_template_carries_session_prefix_with_zero_counter() {
        let key = [0x55u8; 32];
        let session_id = 0x1234_5678_9ABC_DEF0_u64;
        let cipher = PacketCipher::new(&key, session_id);
        let expected_prefix = session_prefix_from_id(session_id);
        assert_eq!(
            &cipher.nonce_template[0..4],
            &expected_prefix,
            "template's prefix bytes must match `session_prefix_from_id`",
        );
        assert_eq!(
            &cipher.nonce_template[4..12],
            &[0u8; 8],
            "template's counter bytes start at zero — each per-packet \
             nonce overwrites them",
        );
        let nonce = cipher.nonce_from_counter(0xCAFE_BABE_DEAD_BEEF);
        assert_eq!(&nonce[0..4], &expected_prefix);
        assert_eq!(&nonce[4..12], &0xCAFE_BABE_DEAD_BEEF_u64.to_le_bytes(),);
    }

    // With a uniquely owned buffer, `decrypt_to_bytes` must reuse the same
    // allocation, so the data pointer is unchanged.
    #[test]
    fn decrypt_to_bytes_in_place_when_refcount_is_one() {
        let key = [0x77u8; 32];
        let session_id = 0xAABB_CCDD_EEFF_0011_u64;
        let cipher = PacketCipher::new(&key, session_id);
        let aad = b"aad";
        let plaintext = b"per-packet alloc gone";
        let mut tx_buf = BytesMut::from(&plaintext[..]);
        let counter = cipher.encrypt_in_place(aad, &mut tx_buf).unwrap();
        let inbound: Bytes = tx_buf.freeze();
        let inbound_ptr = inbound.as_ptr();
        let inbound_len = inbound.len();
        let rx_cipher = PacketCipher::new(&key, session_id);
        let plaintext_bytes = rx_cipher
            .decrypt_to_bytes(counter, aad, inbound)
            .expect("decrypt must succeed");
        assert_eq!(plaintext_bytes.as_ref(), plaintext);
        assert_eq!(
            plaintext_bytes.len() + TAG_SIZE,
            inbound_len,
            "plaintext shrinks by TAG_SIZE",
        );
        assert_eq!(
            plaintext_bytes.as_ptr(),
            inbound_ptr,
            "in-place decrypt: backing pointer must be unchanged",
        );
    }

    // A second live handle forces the copying fallback path.
    #[test]
    fn decrypt_to_bytes_falls_back_on_shared_buffer() {
        let key = [0x77u8; 32];
        let session_id = 0xAABB_CCDD_EEFF_0011_u64;
        let cipher = PacketCipher::new(&key, session_id);
        let aad = b"aad";
        let plaintext = b"shared inbound";
        let mut tx_buf = BytesMut::from(&plaintext[..]);
        let counter = cipher.encrypt_in_place(aad, &mut tx_buf).unwrap();
        let inbound = tx_buf.freeze();
        let _other_holder = inbound.clone();
        let rx_cipher = PacketCipher::new(&key, session_id);
        let plaintext_bytes = rx_cipher
            .decrypt_to_bytes(counter, aad, inbound)
            .expect("decrypt must still succeed via the fallback path");
        assert_eq!(plaintext_bytes.as_ref(), plaintext);
    }

    #[test]
    fn verify_admits_valid_tag_and_rejects_tampered() {
        let key = [0x77u8; 32];
        let session_id = 0xAABB_CCDD_EEFF_0011_u64;
        let cipher = PacketCipher::new(&key, session_id);
        let aad = b"heartbeat-aad";
        let plaintext: &[u8] = b"";
        let mut tx_buf = BytesMut::from(plaintext);
        let counter = cipher.encrypt_in_place(aad, &mut tx_buf).unwrap();
        assert_eq!(tx_buf.len(), TAG_SIZE);
        let inbound = tx_buf.freeze();
        let rx_cipher = PacketCipher::new(&key, session_id);
        rx_cipher
            .verify(counter, aad, &inbound)
            .expect("genuine heartbeat must verify");
        let mut tampered = inbound.to_vec();
        tampered[0] ^= 0xAA;
        let err = rx_cipher
            .verify(counter, aad, &tampered)
            .expect_err("tampered heartbeat must fail verify");
        assert!(matches!(err, CryptoError::Decryption(_)));
    }

    #[test]
    fn test_fast_cipher_counter_increments() {
        let key = [0x42u8; 32];
        let session_id = 0x1234567890ABCDEF_u64;
        let cipher = PacketCipher::new(&key, session_id);
        let aad = b"aad";
        let plaintext = b"test";
        let (_, counter1) = cipher.encrypt(aad, plaintext).unwrap();
        let (_, counter2) = cipher.encrypt(aad, plaintext).unwrap();
        let (_, counter3) = cipher.encrypt(aad, plaintext).unwrap();
        assert_eq!(counter1, 0);
        assert_eq!(counter2, 1);
        assert_eq!(counter3, 2);
    }

    // Same key and same counter, but different session prefixes, must give
    // different nonces and therefore different ciphertexts.
    #[test]
    fn test_fast_cipher_different_sessions() {
        let key = [0x42u8; 32];
        let cipher1 = PacketCipher::new(&key, 0x1111);
        let cipher2 = PacketCipher::new(&key, 0x2222);
        let aad = b"aad";
        let plaintext = b"test";
        let (ct1, c1) = cipher1.encrypt(aad, plaintext).unwrap();
        let (ct2, c2) = cipher2.encrypt(aad, plaintext).unwrap();
        assert_eq!(c1, c2);
        assert_ne!(ct1, ct2);
    }

    #[test]
    fn test_fast_cipher_tamper_detection() {
        let key = [0x42u8; 32];
        let session_id = 0x1234567890ABCDEF_u64;
        let cipher = PacketCipher::new(&key, session_id);
        let aad = b"additional data";
        let plaintext = b"hello, world!";
        let (mut ciphertext, counter) = cipher.encrypt(aad, plaintext).unwrap();
        ciphertext[0] ^= 0xFF;
        let rx_cipher = PacketCipher::new(&key, session_id);
        let result = rx_cipher.decrypt(counter, aad, &ciphertext);
        assert!(result.is_err());
    }

    #[test]
    fn test_fast_cipher_wrong_counter() {
        let key = [0x42u8; 32];
        let session_id = 0x1234567890ABCDEF_u64;
        let cipher = PacketCipher::new(&key, session_id);
        let aad = b"additional data";
        let plaintext = b"hello, world!";
        let (ciphertext, _counter) = cipher.encrypt(aad, plaintext).unwrap();
        let rx_cipher = PacketCipher::new(&key, session_id);
        let result = rx_cipher.decrypt(999, aad, &ciphertext);
        assert!(result.is_err());
    }

    // After committing 2000, rx_counter is 2001. The forward limit is
    // therefore 2001 + MAX_FORWARD (1024) = 3025.
    #[test]
    fn test_fast_cipher_replay_protection() {
        let key = [0x42u8; 32];
        let session_id = 0x1234567890ABCDEF_u64;
        let cipher = PacketCipher::new(&key, session_id);
        assert!(cipher.is_valid_rx_counter(0));
        cipher.update_rx_counter(100);
        assert!(cipher.is_valid_rx_counter(101));
        assert!(cipher.is_valid_rx_counter(200));
        assert!(cipher.is_valid_rx_counter(50));
        cipher.update_rx_counter(2000);
        assert!(!cipher.is_valid_rx_counter(0));
        assert!(
            !cipher.is_valid_rx_counter(u64::MAX),
            "counter far beyond MAX_FORWARD should be rejected"
        );
        assert!(
            cipher.is_valid_rx_counter(3025),
            "counter at MAX_FORWARD boundary should be accepted"
        );
        assert!(
            !cipher.is_valid_rx_counter(3026),
            "counter just past MAX_FORWARD should be rejected"
        );
    }

    // After committing 0..100, rx_counter is 100, so the forward limit is
    // 100 + 1024 = 1124.
    #[test]
    fn replay_window_rejects_jump_beyond_window_size() {
        let key = [0x42u8; 32];
        let cipher = PacketCipher::new(&key, 0xCAFEu64);
        for c in 0..100u64 {
            cipher.update_rx_counter(c);
        }
        assert!(
            cipher.is_valid_rx_counter(1124),
            "counter at WINDOW_SIZE boundary must still be accepted"
        );
        assert!(
            !cipher.is_valid_rx_counter(1125),
            "counter past WINDOW_SIZE must be rejected — accepting it \
             would zero the bitmap and re-open the prior {} counters \
             to replay",
            1024,
        );
        cipher.update_rx_counter(1124);
        assert!(
            !cipher.is_valid_rx_counter(1124),
            "just-committed counter must remain non-replayable"
        );
        assert!(
            !cipher.is_valid_rx_counter(99),
            "counter from before the jump must reject as too-old"
        );
    }

    #[test]
    fn replay_window_ceiling_counter_does_not_poison_receive_path() {
        let key = [0x42u8; 32];
        let cipher = PacketCipher::new(&key, 0xC0FFEEu64);
        {
            let mut w = cipher.rx_window.lock();
            w.rx_counter = u64::MAX - ReplayWindow::MAX_FORWARD;
        }
        assert!(
            !cipher.is_valid_rx_counter(u64::MAX),
            "u64::MAX must be rejected by is_valid even when in MAX_FORWARD range"
        );
        assert!(
            !cipher.update_rx_counter(u64::MAX),
            "commit on u64::MAX must reject — accepting it saturates rx_counter and poisons the receive path"
        );
        let post = cipher.rx_window.lock().rx_counter;
        assert_eq!(
            post,
            u64::MAX - ReplayWindow::MAX_FORWARD,
            "rx_counter must not have been mutated by the rejected u64::MAX commit"
        );
        let safe = u64::MAX - 1;
        assert!(
            cipher.is_valid_rx_counter(safe),
            "u64::MAX - 1 must still be acceptable when in MAX_FORWARD range"
        );
        assert!(
            cipher.update_rx_counter(safe),
            "u64::MAX - 1 must still commit when in MAX_FORWARD range"
        );
    }

    // Keys from a real handshake drive a PacketCipher round trip.
    #[test]
    fn test_fast_cipher_session_keys_integration() {
        let psk = [0x42u8; 32];
        let responder_keypair = StaticKeypair::generate();
        let mut initiator = NoiseHandshake::initiator(&psk, &responder_keypair.public).unwrap();
        let mut responder = NoiseHandshake::responder(&psk, &responder_keypair).unwrap();
        let msg1 = initiator.write_message(b"").unwrap();
        responder.read_message(&msg1).unwrap();
        let msg2 = responder.write_message(b"").unwrap();
        initiator.read_message(&msg2).unwrap();
        let init_keys = initiator.into_session_keys().unwrap();
        let resp_keys = responder.into_session_keys().unwrap();
        let init_cipher = PacketCipher::new(&init_keys.tx_key, init_keys.session_id);
        let resp_cipher = PacketCipher::new(&resp_keys.rx_key, resp_keys.session_id);
        let aad = b"test aad";
        let plaintext = b"secret message via fast cipher";
        let (ciphertext, counter) = init_cipher.encrypt(aad, plaintext).unwrap();
        let decrypted = resp_cipher.decrypt(counter, aad, &ciphertext).unwrap();
        assert_eq!(&decrypted, plaintext);
    }

    #[test]
    fn test_fast_cipher_not_clone() {
        // Compile-time proof that `PacketCipher` is not `Clone`. If it were,
        // both impls below would apply, `_` could not be inferred, and this
        // test would fail to build. The former empty generic helper accepted
        // every type and asserted nothing.
        trait AmbiguousIfClone<A> {
            fn some_item() {}
        }
        impl<T: ?Sized> AmbiguousIfClone<()> for T {}
        #[allow(dead_code)]
        struct IsClone;
        impl<T: ?Sized + Clone> AmbiguousIfClone<IsClone> for T {}
        let _ = <PacketCipher as AmbiguousIfClone<_>>::some_item;

        // Why it matters: two independent ciphers for the same key and session
        // both start at counter 0 and so reuse the nonce.
        let key = [0x42u8; 32];
        let cipher1 = PacketCipher::new(&key, 0x1111);
        let cipher2 = PacketCipher::new(&key, 0x1111);
        let aad = b"test";
        let (ct1, c1) = cipher1.encrypt(aad, b"hello").unwrap();
        let (ct2, c2) = cipher2.encrypt(aad, b"hello").unwrap();
        assert_eq!(c1, c2, "both start at counter 0");
        assert_eq!(ct1, ct2, "same nonce produces same ciphertext — Clone removal prevents this from happening accidentally");
    }

    #[test]
    fn test_derive_key_uses_cryptographic_prf() {
        let ikm = [0xABu8; 32];
        let mut key1 = [0u8; 32];
        let mut key2 = [0u8; 32];
        derive_key(&ikm, b"label-a", &mut key1);
        derive_key(&ikm, b"label-b", &mut key2);
        assert_ne!(key1, key2);
        let mut key1_again = [0u8; 32];
        derive_key(&ikm, b"label-a", &mut key1_again);
        assert_eq!(key1, key1_again);
        assert_ne!(key1, [0u8; 32]);
        assert_ne!(
            key1[..8],
            key1[8..16],
            "output should not be trivially repeating"
        );
    }

    #[test]
    fn test_regression_rx_counter_u64_max_no_wrap() {
        let key = [0x42u8; 32];
        let cipher = PacketCipher::new(&key, 0x1234);
        assert!(cipher.update_rx_counter(1000));
        assert!(
            !cipher.update_rx_counter(u64::MAX),
            "u64::MAX must be rejected at commit; pre-fix it was \
             accepted-then-saturated, poisoning the receive path"
        );
        let counter = cipher.rx_window.lock().rx_counter;
        assert_eq!(
            counter, 1001,
            "rx_counter must remain at the post-1000-commit value; \
             a rejected u64::MAX commit must not mutate state"
        );
    }

    #[test]
    fn test_replay_bitmap_rejects_duplicate_counter() {
        let key = [0x42u8; 32];
        let cipher = PacketCipher::new(&key, 0x1234);
        assert!(cipher.is_valid_rx_counter(100));
        assert!(cipher.update_rx_counter(100));
        assert!(
            !cipher.is_valid_rx_counter(100),
            "replayed counter must fail the validity check"
        );
        assert!(
            !cipher.update_rx_counter(100),
            "replayed counter must fail the commit-time check too, \
             closing the TOCTOU race between check and commit"
        );
        assert!(cipher.is_valid_rx_counter(50));
        assert!(cipher.update_rx_counter(50));
        assert!(!cipher.is_valid_rx_counter(50));
        assert!(!cipher.update_rx_counter(50));
        assert!(cipher.update_rx_counter(10_000));
        assert!(
            !cipher.is_valid_rx_counter(100),
            "counter that has slid out of the window is no longer valid"
        );
    }

    // A 500-counter slide moves seen-bits across bitmap word boundaries; they
    // must survive the shift.
    #[test]
    fn test_replay_window_tracks_bits_across_slide() {
        let key = [0x42u8; 32];
        let cipher = PacketCipher::new(&key, 0x1234);
        assert!(cipher.update_rx_counter(10));
        assert!(cipher.update_rx_counter(20));
        assert!(cipher.update_rx_counter(30));
        assert!(cipher.update_rx_counter(530));
        for c in [10u64, 20, 30, 530] {
            assert!(
                !cipher.is_valid_rx_counter(c),
                "counter {c} should remain marked as seen after window slide"
            );
            assert!(
                !cipher.update_rx_counter(c),
                "commit of already-seen counter {c} must return false"
            );
        }
        assert!(cipher.is_valid_rx_counter(25));
        assert!(cipher.update_rx_counter(25));
    }

    #[test]
    fn test_replay_commit_rejects_out_of_window_counter() {
        let key = [0x42u8; 32];
        let cipher = PacketCipher::new(&key, 0x1234);
        assert!(cipher.update_rx_counter(5_000));
        assert!(
            !cipher.is_valid_rx_counter(100),
            "out-of-window counter must fail validity"
        );
        assert!(
            !cipher.update_rx_counter(100),
            "out-of-window counter must also fail at commit time"
        );
    }

    #[test]
    fn test_regression_replay_rejected_at_u64_max_boundary() {
        let key = [0x42u8; 32];
        let cipher = PacketCipher::new(&key, 0x1234);
        assert!(
            !cipher.update_rx_counter(u64::MAX),
            "u64::MAX must be rejected at the gate — accepting it \
             saturates rx_counter and dead-ends the receive path"
        );
        assert!(
            !cipher.update_rx_counter(u64::MAX),
            "second commit of u64::MAX must also be rejected"
        );
        assert!(
            !cipher.is_valid_rx_counter(u64::MAX - 1),
            "u64::MAX - 1 from rx_counter=0 must reject at is_valid (past MAX_FORWARD)"
        );
    }

    // Table-driven: each (primer counters, probe, expected verdict) scenario
    // must give the same answer through both entry points.
    #[test]
    fn try_admit_rx_counter_matches_update_rx_counter_semantics() {
        let key = [0x9Au8; 32];
        let scenarios: &[(&[u64], u64, bool)] = &[
            (&[], 100, true),
            (&[100], 100, false),
            (&[100], 200, true),
            (&[100], 50, true),
            (&[100], u64::MAX, false),
            (&[10_000], 100, false),
            (&[10_000, 9_999], 9_999, false),
        ];
        for (i, (primer, probe, expected)) in scenarios.iter().enumerate() {
            let admit_cipher = PacketCipher::new(&key, 0x4242);
            let update_cipher = PacketCipher::new(&key, 0x4242);
            for &p in *primer {
                assert!(
                    admit_cipher.try_admit_rx_counter(p),
                    "scenario {i}: priming admit({p}) must succeed"
                );
                assert!(
                    update_cipher.update_rx_counter(p),
                    "scenario {i}: priming update({p}) must succeed"
                );
            }
            assert_eq!(
                admit_cipher.try_admit_rx_counter(*probe),
                *expected,
                "scenario {i}: try_admit_rx_counter({probe}) verdict differs from spec",
            );
            assert_eq!(
                update_cipher.update_rx_counter(*probe),
                *expected,
                "scenario {i}: update_rx_counter({probe}) verdict differs from spec",
            );
            let a = PacketCipher::new(&key, 0x4242);
            let b = PacketCipher::new(&key, 0x4242);
            for &p in *primer {
                a.update_rx_counter(p);
                b.update_rx_counter(p);
            }
            assert_eq!(
                a.try_admit_rx_counter(*probe),
                b.update_rx_counter(*probe),
                "scenario {i}: try_admit_rx_counter and update_rx_counter must agree",
            );
        }
    }

    /// Lowercase hex of `bytes`, used to search `Debug` output for leaked secrets.
    fn hex_of(bytes: &[u8]) -> String {
        bytes.iter().map(|b| format!("{:02x}", b)).collect()
    }

    /// Fails if the hex encoding of `secret` appears anywhere in `dbg_output`.
    fn assert_no_leak(dbg_output: &str, secret: &[u8], label: &str) {
        let hex = hex_of(secret).to_lowercase();
        assert!(
            !dbg_output.to_lowercase().contains(&hex),
            "{label} bytes leaked into Debug output: {dbg_output}",
        );
    }

    #[test]
    fn session_keys_debug_redacts_tx_and_rx_keys() {
        let tx_secret = [0xAB; 32];
        let rx_secret = [0xCD; 32];
        let keys = SessionKeys {
            tx_key: tx_secret,
            rx_key: rx_secret,
            session_id: 0x1234_5678_DEAD_BEEF,
            remote_static_pub: [0x11; 32],
        };
        let s = format!("{:?}", keys);
        assert!(
            s.contains("[REDACTED]"),
            "SessionKeys Debug must include [REDACTED]; got: {s}",
        );
        assert_no_leak(&s, &tx_secret, "tx_key");
        assert_no_leak(&s, &rx_secret, "rx_key");
        assert!(s.contains("session_id"));
    }

    #[test]
    fn static_keypair_debug_redacts_private_key() {
        let private = [0x77; 32];
        let public = [0x22; 32];
        let kp = StaticKeypair { private, public };
        let s = format!("{:?}", kp);
        assert!(
            s.contains("[REDACTED]"),
            "StaticKeypair Debug must include [REDACTED]; got: {s}",
        );
        assert_no_leak(&s, &private, "private key");
        assert!(
            s.to_lowercase().contains(&hex_of(&public).to_lowercase()),
            "public key should be visible in Debug; got: {s}",
        );
    }

    #[test]
    fn crypto_error_display_covers_every_variant() {
        assert_eq!(
            format!("{}", CryptoError::Handshake("bad msg1".into())),
            "handshake error: bad msg1"
        );
        assert_eq!(
            format!("{}", CryptoError::Encryption("tag mismatch".into())),
            "encryption error: tag mismatch"
        );
        assert_eq!(
            format!("{}", CryptoError::Decryption("auth fail".into())),
            "decryption error: auth fail"
        );
        assert_eq!(
            format!("{}", CryptoError::InvalidKey("zero key".into())),
            "invalid key: zero key"
        );
        assert_eq!(format!("{}", CryptoError::InvalidNonce), "invalid nonce");
    }
}