use crate::error::{Error, Result};
use crate::random::OsRng;
use lru::LruCache;
use parking_lot::Mutex;
use rand::RngCore;
use std::num::NonZeroUsize;
use std::sync::atomic::{AtomicU64, Ordering};
use zeroize::Zeroize;
/// Fixed-length nonce made of exactly `N` bytes.
///
/// `Zeroize` is derived so callers can wipe the bytes explicitly when they are done with them.
#[derive(Zeroize, Clone)]
pub struct Nonce<const N: usize> {
    /// Raw nonce bytes, most significant byte first when treated as a counter.
    bytes: [u8; N],
}
impl<const N: usize> Nonce<N> {
pub fn new(bytes: [u8; N]) -> Self {
Self { bytes }
}
pub fn from_slice(slice: &[u8]) -> Result<Self> {
if slice.len() != N {
return Err(Error::InvalidNonceLength {
expected: N,
actual: slice.len(),
});
}
let mut bytes = [0u8; N];
bytes.copy_from_slice(slice);
Ok(Self { bytes })
}
pub fn random() -> Self {
let mut bytes = [0u8; N];
OsRng.fill_bytes(&mut bytes);
Self { bytes }
}
pub fn zero() -> Self {
Self { bytes: [0u8; N] }
}
pub fn as_bytes(&self) -> &[u8; N] {
&self.bytes
}
pub fn as_slice(&self) -> &[u8] {
&self.bytes
}
pub const fn len() -> usize {
N
}
pub fn increment(&mut self) -> Result<()> {
for byte in self.bytes.iter_mut().rev() {
if *byte == 255 {
*byte = 0;
} else {
*byte += 1;
return Ok(());
}
}
Err(Error::NonceExhausted)
}
pub fn from_counter(counter: u64) -> Self {
assert!(N >= 8, "Nonce must be at least 8 bytes to use from_counter");
let mut bytes = [0u8; N];
bytes[N - 8..].copy_from_slice(&counter.to_be_bytes());
Self { bytes }
}
}
impl<const N: usize> AsRef<[u8]> for Nonce<N> {
    /// Exposes the nonce bytes as a borrowed slice.
    fn as_ref(&self) -> &[u8] {
        &self.bytes[..]
    }
}
impl<const N: usize> std::fmt::Debug for Nonce<N> {
    /// Renders as `Nonce<N>(<hex bytes>)`, e.g. `Nonce<2>(ab01)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let encoded = hex::encode(self.bytes);
        write!(f, "Nonce<{}>({})", N, encoded)
    }
}
impl<const N: usize> std::fmt::Display for Nonce<N> {
    /// Writes only the hex-encoded bytes, with no type name or padding.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&hex::encode(self.bytes))
    }
}
/// Selects how a `NonceGenerator` produces nonce bytes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NonceStrategy {
    /// Each nonce is drawn independently at random.
    Random,
    /// Each nonce encodes an incrementing counter.
    Counter,
    /// Each nonce is a fixed random prefix followed by an incrementing counter.
    Hybrid,
}
pub struct NonceGenerator<const N: usize> {
strategy: NonceStrategy,
counter: AtomicU64,
random_prefix: [u8; 4],
generated_count: AtomicU64,
max_nonces: Option<u64>,
}
impl<const N: usize> NonceGenerator<N> {
pub fn random() -> Self {
Self {
strategy: NonceStrategy::Random,
counter: AtomicU64::new(0),
random_prefix: [0; 4],
generated_count: AtomicU64::new(0),
max_nonces: None,
}
}
pub fn counter(start: u64) -> Self {
Self {
strategy: NonceStrategy::Counter,
counter: AtomicU64::new(start),
random_prefix: [0; 4],
generated_count: AtomicU64::new(0),
max_nonces: None,
}
}
pub fn hybrid() -> Self {
let mut prefix = [0u8; 4];
OsRng.fill_bytes(&mut prefix);
Self {
strategy: NonceStrategy::Hybrid,
counter: AtomicU64::new(0),
random_prefix: prefix,
generated_count: AtomicU64::new(0),
max_nonces: None,
}
}
pub fn with_limit(mut self, max: u64) -> Self {
self.max_nonces = Some(max);
self
}
pub fn generate(&self) -> Result<Nonce<N>> {
let count = self.generated_count.fetch_add(1, Ordering::SeqCst);
if let Some(max) = self.max_nonces
&& count >= max
{
self.generated_count.fetch_sub(1, Ordering::SeqCst);
return Err(Error::NonceExhausted);
}
match self.strategy {
NonceStrategy::Random => Ok(Nonce::random()),
NonceStrategy::Counter => {
let counter = self.counter.fetch_add(1, Ordering::SeqCst);
if counter == u64::MAX {
return Err(Error::NonceExhausted);
}
let mut bytes = [0u8; N];
let counter_bytes = counter.to_be_bytes();
let start = N.saturating_sub(8);
bytes[start..].copy_from_slice(&counter_bytes[8 - (N - start)..]);
Ok(Nonce::new(bytes))
}
NonceStrategy::Hybrid => {
let counter = self.counter.fetch_add(1, Ordering::SeqCst);
if counter == u64::MAX {
return Err(Error::NonceExhausted);
}
let mut bytes = [0u8; N];
let prefix_len = 4.min(N);
bytes[..prefix_len].copy_from_slice(&self.random_prefix[..prefix_len]);
if N > 4 {
let counter_bytes = counter.to_be_bytes();
let counter_space = N - 4;
let counter_start = 8usize.saturating_sub(counter_space);
bytes[4..].copy_from_slice(&counter_bytes[counter_start..]);
}
Ok(Nonce::new(bytes))
}
}
}
pub fn count(&self) -> u64 {
self.generated_count.load(Ordering::SeqCst)
}
pub fn current_counter(&self) -> u64 {
self.counter.load(Ordering::SeqCst)
}
pub fn reset_dangerous_nonce_reuse_possible(&self) {
self.counter.store(0, Ordering::SeqCst);
self.generated_count.store(0, Ordering::SeqCst);
}
}
impl<const N: usize> Default for NonceGenerator<N> {
fn default() -> Self {
Self::random()
}
}
pub struct NonceTracker<const N: usize> {
cache: Mutex<LruCache<[u8; N], ()>>,
max_entries: NonZeroUsize,
eviction_count: AtomicU64,
}
impl<const N: usize> NonceTracker<N> {
pub fn new(max_entries: usize) -> Self {
let max_entries =
NonZeroUsize::new(max_entries).expect("NonceTracker capacity must be greater than 0");
Self {
cache: Mutex::new(LruCache::new(max_entries)),
max_entries,
eviction_count: AtomicU64::new(0),
}
}
#[must_use = "nonce reuse check must be verified - reuse is catastrophic"]
pub fn check(&self, nonce: &Nonce<N>) -> Result<()> {
let mut cache = self.cache.lock();
if cache.peek(nonce.as_bytes()).is_some() {
return Err(Error::NonceReuse);
}
if cache.len() >= self.max_entries.get() {
self.eviction_count.fetch_add(1, Ordering::Relaxed);
}
cache.put(*nonce.as_bytes(), ());
Ok(())
}
#[must_use = "nonce reuse check must be verified - reuse is catastrophic"]
pub fn check_and_touch(&self, nonce: &Nonce<N>) -> Result<()> {
let mut cache = self.cache.lock();
if cache.get(nonce.as_bytes()).is_some() {
return Err(Error::NonceReuse);
}
if cache.len() >= self.max_entries.get() {
self.eviction_count.fetch_add(1, Ordering::Relaxed);
}
cache.put(*nonce.as_bytes(), ());
Ok(())
}
pub fn clear(&self) {
self.cache.lock().clear();
}
pub fn len(&self) -> usize {
self.cache.lock().len()
}
pub fn is_empty(&self) -> bool {
self.cache.lock().is_empty()
}
pub fn capacity(&self) -> usize {
self.max_entries.get()
}
pub fn eviction_count(&self) -> u64 {
self.eviction_count.load(Ordering::Relaxed)
}
pub fn reset_eviction_count(&self) {
self.eviction_count.store(0, Ordering::Relaxed);
}
}
/// 64-bit (8-byte) nonce.
pub type Nonce64 = Nonce<8>;
/// 96-bit (12-byte) nonce.
pub type Nonce96 = Nonce<12>;
/// 128-bit (16-byte) nonce.
pub type Nonce128 = Nonce<16>;
/// 192-bit (24-byte) nonce.
pub type Nonce192 = Nonce<24>;
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_nonce_random() {
        let n1 = Nonce96::random();
        let n2 = Nonce96::random();
        assert_ne!(n1.as_bytes(), n2.as_bytes());
    }

    #[test]
    fn test_nonce_increment() {
        let mut nonce = Nonce96::zero();
        for i in 1..=256 {
            nonce.increment().unwrap();
            assert_eq!(nonce.as_bytes()[11], (i & 0xFF) as u8);
        }
        // The 256th increment carries exactly once into the next byte.
        assert_eq!(nonce.as_bytes()[10], 1);
    }

    #[test]
    fn test_nonce_increment_exhausted() {
        let mut nonce = Nonce::<2>::new([0xFF, 0xFF]);
        assert!(nonce.increment().is_err());
    }

    #[test]
    fn test_nonce_from_slice() {
        assert!(Nonce96::from_slice(&[0u8; 12]).is_ok());
        assert!(matches!(
            Nonce96::from_slice(&[0u8; 11]),
            Err(Error::InvalidNonceLength {
                expected: 12,
                actual: 11,
                ..
            })
        ));
    }

    #[test]
    fn test_nonce_from_counter() {
        let nonce = Nonce96::from_counter(0x0102_0304_0506_0708);
        assert_eq!(nonce.as_bytes(), &[0u8, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8]);
    }

    #[test]
    fn test_nonce_formatting() {
        let nonce = Nonce::<2>::new([0xab, 0x01]);
        assert_eq!(nonce.to_string(), "ab01");
        assert_eq!(format!("{:?}", nonce), "Nonce<2>(ab01)");
    }

    #[test]
    fn test_nonce_generator_counter() {
        let generator = NonceGenerator::<12>::counter(0);
        let n1 = generator.generate().unwrap();
        let n2 = generator.generate().unwrap();
        assert_ne!(n1.as_bytes(), n2.as_bytes());
        assert_eq!(generator.count(), 2);
    }

    #[test]
    fn test_nonce_generator_counter_layout() {
        let generator = NonceGenerator::<12>::counter(5);
        let nonce = generator.generate().unwrap();
        assert_eq!(nonce.as_bytes(), &[0u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5]);
        assert_eq!(generator.current_counter(), 6);
    }

    #[test]
    fn test_nonce_generator_hybrid_prefix() {
        let generator = NonceGenerator::<12>::hybrid();
        let n1 = generator.generate().unwrap();
        let n2 = generator.generate().unwrap();
        assert_eq!(n1.as_bytes()[..4], n2.as_bytes()[..4]);
        assert_eq!(n1.as_bytes()[4..], [0u8; 8]);
        assert_eq!(n2.as_bytes()[4..], [0u8, 0, 0, 0, 0, 0, 0, 1]);
    }

    #[test]
    fn test_nonce_generator_limit() {
        let generator = NonceGenerator::<12>::counter(0).with_limit(2);
        assert!(generator.generate().is_ok());
        assert!(generator.generate().is_ok());
        assert!(generator.generate().is_err());
    }

    #[test]
    fn test_nonce_tracker() {
        let tracker = NonceTracker::<12>::new(100);
        let nonce = Nonce96::random();
        assert!(tracker.check(&nonce).is_ok());
        assert!(tracker.check(&nonce).is_err());
        let nonce2 = Nonce96::random();
        assert!(tracker.check(&nonce2).is_ok());
    }

    #[test]
    fn test_nonce_tracker_lru_eviction() {
        let tracker = NonceTracker::<12>::new(3);
        let n1 = Nonce96::random();
        let n2 = Nonce96::random();
        let n3 = Nonce96::random();
        let n4 = Nonce96::random();
        assert!(tracker.check(&n1).is_ok());
        assert!(tracker.check(&n2).is_ok());
        assert!(tracker.check(&n3).is_ok());
        assert_eq!(tracker.len(), 3);
        assert_eq!(tracker.eviction_count(), 0);
        assert!(tracker.check(&n4).is_ok());
        assert_eq!(tracker.len(), 3);
        assert_eq!(tracker.eviction_count(), 1);
        // n1 was evicted by n4, so its reuse is no longer detected.
        assert!(tracker.check(&n1).is_ok());
        assert_eq!(tracker.eviction_count(), 2);
        assert!(tracker.check(&n3).is_err());
        assert!(tracker.check(&n4).is_err());
    }

    #[test]
    fn test_nonce_tracker_capacity() {
        let tracker = NonceTracker::<12>::new(50);
        assert_eq!(tracker.capacity(), 50);
        assert!(tracker.is_empty());
    }
}