use crate::{
noise::errors::WireGuardError,
noise::index_table::Index,
packet::{Packet, WgData, WgDataHeader, WgKind},
};
use bytes::{Buf, BytesMut};
use parking_lot::Mutex;
use ring::aead::{Aad, CHACHA20_POLY1305, LessSafeKey, Nonce, UnboundKey};
use std::sync::atomic::{AtomicUsize, Ordering};
use zerocopy::FromBytes;
/// Transport-data state for one established session: the two directional
/// ChaCha20-Poly1305 keys, the two indices that address the session, the
/// outgoing packet counter and the incoming replay window.
pub struct Session {
    // Inbound direction.
    /// Our index; an incoming data packet is only processed when its header
    /// `receiver_idx` equals this value.
    pub(crate) receiving_index: Index,
    /// Key used to open (decrypt and authenticate) incoming data packets.
    receiver: LessSafeKey,
    /// Replay window tracking which incoming counters were already accepted.
    receiving_key_counter: Mutex<ReceivingKeyCounterValidator>,

    // Outbound direction.
    /// Index written as `receiver_idx` into every outgoing data packet header.
    sending_index: u32,
    /// Key used to seal (encrypt) outgoing data packets.
    sender: LessSafeKey,
    /// Counter for the next outgoing packet; its value is also the AEAD nonce.
    ///
    /// NOTE(review): `AtomicUsize` is only 32 bits wide on 32-bit targets, and
    /// `fetch_add` wraps on overflow, so there the counter (and nonce) would
    /// repeat after 2^32 packets — confirm sessions always rekey before that.
    sending_key_counter: AtomicUsize,
}
impl std::fmt::Debug for Session {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"Session: {}<- ->{}",
self.receiving_index, self.sending_index
)
}
}
/// Number of bits in one word of the replay bitmap.
const WORD_SIZE: u64 = 64;
/// Number of words in the replay bitmap.
const N_WORDS: u64 = 16;
/// Width of the replay window in counter values (`64 * 16 = 1024`).
const N_BITS: u64 = WORD_SIZE * N_WORDS;

/// Sliding-window replay filter for the counters of incoming data packets.
///
/// `bitmap` is used as a ring of `N_BITS` bits in which counter `c` lives at
/// bit `c % N_BITS`; a set bit means that counter was already accepted.
/// `next` is one past the highest counter accepted so far (0 initially), so
/// only counters `c` with `c + N_BITS >= next` can still be accepted.
/// `receive_cnt` is not touched here; the owning `Session` increments it after
/// every successful `mark_did_receive`.
#[derive(Debug, Clone, Default)]
struct ReceivingKeyCounterValidator {
    next: u64,
    receive_cnt: u64,
    bitmap: [u64; N_WORDS as usize],
}

impl ReceivingKeyCounterValidator {
    /// Sets the ring bit that represents `idx`.
    #[inline(always)]
    fn set_bit(&mut self, idx: u64) {
        let bit_idx = idx % N_BITS;
        let word = (bit_idx / WORD_SIZE) as usize;
        let bit = (bit_idx % WORD_SIZE) as usize;
        self.bitmap[word] |= 1u64 << bit;
    }

    /// Clears the ring bit that represents `idx`.
    #[inline(always)]
    fn clear_bit(&mut self, idx: u64) {
        let bit_idx = idx % N_BITS;
        let word = (bit_idx / WORD_SIZE) as usize;
        let bit = (bit_idx % WORD_SIZE) as usize;
        self.bitmap[word] &= !(1u64 << bit);
    }

    /// Clears the whole 64-bit word that contains the ring bit for `idx`.
    #[inline(always)]
    fn clear_word(&mut self, idx: u64) {
        let bit_idx = idx % N_BITS;
        let word = (bit_idx / WORD_SIZE) as usize;
        self.bitmap[word] = 0;
    }

    /// Returns whether the ring bit that represents `idx` is set.
    #[inline(always)]
    fn check_bit(&self, idx: u64) -> bool {
        let bit_idx = idx % N_BITS;
        let word = (bit_idx / WORD_SIZE) as usize;
        let bit = (bit_idx % WORD_SIZE) as usize;
        ((self.bitmap[word] >> bit) & 1) == 1
    }

    /// Returns `true` when `counter` has fallen behind the window, i.e.
    /// `counter + N_BITS < self.next`, evaluated without overflowing.
    #[inline(always)]
    fn is_behind_window(&self, counter: u64) -> bool {
        // `saturating_sub` yields 0 when `counter >= next`, which is never behind.
        self.next.saturating_sub(counter) > N_BITS
    }

    /// Read-only pre-check of `counter`; does not modify the window.
    ///
    /// # Errors
    /// * `InvalidCounter` if `counter` is `u64::MAX` (it can never be marked,
    ///   see `mark_did_receive`) or lies behind the window;
    /// * `DuplicateCounter` if it is inside the window and already marked.
    #[inline(always)]
    fn will_accept(&self, counter: u64) -> Result<(), WireGuardError> {
        if counter == u64::MAX {
            return Err(WireGuardError::InvalidCounter);
        }
        if counter >= self.next {
            // Newer than anything accepted so far: cannot be a replay.
            return Ok(());
        }
        if self.is_behind_window(counter) {
            return Err(WireGuardError::InvalidCounter);
        }
        if self.check_bit(counter) {
            Err(WireGuardError::DuplicateCounter)
        } else {
            Ok(())
        }
    }

    /// Records `counter` as received, sliding the window forward if needed.
    ///
    /// All checks from `will_accept` are repeated here because the state may
    /// have changed since a pre-check was done.
    ///
    /// # Errors
    /// * `InvalidCounter` if `counter` is `u64::MAX` or lies behind the window;
    /// * `DuplicateCounter` if it is inside the window and already marked.
    #[inline(always)]
    fn mark_did_receive(&mut self, counter: u64) -> Result<(), WireGuardError> {
        // `next` becomes `counter + 1` on the forward paths below. Rejecting
        // `u64::MAX` up front keeps that, and every sum below, from overflowing
        // (a panic in debug builds) on a crafted counter.
        let Some(new_next) = counter.checked_add(1) else {
            return Err(WireGuardError::InvalidCounter);
        };
        if self.is_behind_window(counter) {
            return Err(WireGuardError::InvalidCounter);
        }
        if counter == self.next {
            // Common case: packets arrive in order.
            self.set_bit(counter);
            self.next = new_next;
            return Ok(());
        }
        if counter < self.next {
            // Late arrival inside the window: accept once.
            if self.check_bit(counter) {
                return Err(WireGuardError::DuplicateCounter);
            }
            self.set_bit(counter);
            return Ok(());
        }
        // Jump forward: ring slots for the skipped counters `next..counter`
        // still hold bits from `N_BITS` counters earlier, so clear them.
        if counter - self.next >= N_BITS {
            self.bitmap.fill(0);
        } else {
            let mut i = self.next;
            // Clear single bits until `i` is word-aligned.
            while !i.is_multiple_of(WORD_SIZE) && i < counter {
                self.clear_bit(i);
                i += 1;
            }
            // Here `i <= counter`; `counter - i > WORD_SIZE` is the
            // overflow-free form of `i + WORD_SIZE < counter`. `i` is aligned,
            // so each step clears one whole word of skipped counters.
            while counter - i > WORD_SIZE {
                self.clear_word(i);
                i += WORD_SIZE;
            }
            // Clear whatever single bits remain below `counter`.
            while i < counter {
                self.clear_bit(i);
                i += 1;
            }
        }
        self.set_bit(counter);
        self.next = new_next;
        Ok(())
    }
}
impl Session {
pub(super) fn new(
local_index: Index,
sending_index: u32,
receiving_key: [u8; 32],
sending_key: [u8; 32],
) -> Session {
Session {
receiving_index: local_index,
sending_index,
receiver: LessSafeKey::new(
UnboundKey::new(&CHACHA20_POLY1305, &receiving_key).unwrap(),
),
sender: LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, &sending_key).unwrap()),
sending_key_counter: AtomicUsize::new(0),
receiving_key_counter: Mutex::new(Default::default()),
}
}
fn receiving_counter_quick_check(&self, counter: u64) -> Result<(), WireGuardError> {
let counter_validator = self.receiving_key_counter.lock();
counter_validator.will_accept(counter)
}
fn receiving_counter_mark(&self, counter: u64) -> Result<(), WireGuardError> {
let mut counter_validator = self.receiving_key_counter.lock();
let ret = counter_validator.mark_did_receive(counter);
if ret.is_ok() {
counter_validator.receive_cnt += 1;
}
ret
}
pub(super) fn format_packet_data(&self, packet: Packet) -> Packet<WgData> {
let sending_key_counter = self.sending_key_counter.fetch_add(1, Ordering::Relaxed) as u64;
let len = WgData::OVERHEAD + packet.len();
let mut buf = Packet::from_bytes(BytesMut::zeroed(len));
let data = WgData::mut_from_bytes(buf.buf_mut())
.expect("buffer size is at least WgData::OVERHEAD");
data.header = WgDataHeader::new()
.with_receiver_idx(self.sending_index)
.with_counter(sending_key_counter);
debug_assert_eq!(packet.len(), data.encrypted_encapsulated_packet_mut().len());
data.encrypted_encapsulated_packet_mut()
.copy_from_slice(&packet);
let mut nonce = [0u8; 12];
nonce[4..12].copy_from_slice(&sending_key_counter.to_le_bytes());
let tag = self
.sender
.seal_in_place_separate_tag(
Nonce::assume_unique_for_key(nonce),
Aad::from(&[]),
data.encrypted_encapsulated_packet_mut(),
)
.expect("encryption must succeed");
data.tag_mut().copy_from_slice(tag.as_ref());
let packet = buf.try_into_wg().expect("is a wireguard packet");
let WgKind::Data(packet) = packet else {
unreachable!("is a wireguard data packet");
};
packet
}
pub(super) fn receive_packet_data(
&self,
mut packet: Packet<WgData>,
) -> Result<Packet, WireGuardError> {
if packet.header.receiver_idx.get() != self.receiving_index.value() {
return Err(WireGuardError::WrongIndex);
}
let counter = packet.header.counter.get();
self.receiving_counter_quick_check(counter)?;
let mut nonce = [0u8; 12];
nonce[4..12].copy_from_slice(&packet.header.counter.to_bytes());
let decrypted_len = self
.receiver
.open_in_place(
Nonce::assume_unique_for_key(nonce),
Aad::from(&[]),
&mut packet.encrypted_encapsulated_packet_and_tag,
)
.map_err(|_| WireGuardError::InvalidAeadTag)?
.len();
let mut packet = packet.into_bytes();
let buf = packet.buf_mut();
buf.advance(WgDataHeader::LEN);
buf.truncate(decrypted_len);
self.receiving_counter_mark(counter)?;
Ok(packet)
}
pub(super) fn current_packet_cnt(&self) -> (u64, u64) {
let counter_validator = self.receiving_key_counter.lock();
(counter_validator.next, counter_validator.receive_cnt)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `counter` is accepted by `window` exactly once.
    fn accept_once(window: &mut ReceivingKeyCounterValidator, counter: u64) {
        assert!(window.mark_did_receive(counter).is_ok());
        assert!(window.mark_did_receive(counter).is_err());
    }

    /// Drives the replay window through in-order, out-of-order, far-ahead
    /// and far-behind counters, checking each value is accepted at most once.
    #[test]
    fn test_replay_counter() {
        let mut window = ReceivingKeyCounterValidator::default();

        // In order, a forward jump, a late arrival, then a long in-order run.
        for counter in [0, 1, 63, 15].iter().copied().chain(64..N_BITS + 128) {
            accept_once(&mut window, counter);
        }

        // Jump far ahead: everything up to 2 * N_BITS drops out of the window.
        let top = N_BITS * 3;
        assert!(window.mark_did_receive(top).is_ok());
        for stale in 0..=N_BITS * 2 {
            assert!(matches!(
                window.will_accept(stale),
                Err(WireGuardError::InvalidCounter)
            ));
            assert!(window.mark_did_receive(stale).is_err());
        }

        // Counters just below `top` are still acceptable; `top` itself is not.
        let in_window = N_BITS * 2 + 1..top;
        for late in in_window.clone() {
            assert!(window.will_accept(late).is_ok());
        }
        assert!(matches!(
            window.will_accept(top),
            Err(WireGuardError::DuplicateCounter)
        ));
        for late in in_window.rev() {
            accept_once(&mut window, late);
        }

        // Small forward jumps that cross word boundaries, then replays.
        for offset in [70, 71, 72, 72 + 125, 63] {
            assert!(window.mark_did_receive(top + offset).is_ok());
        }
        for offset in [70, 71, 72] {
            assert!(window.mark_did_receive(top + offset).is_err());
        }
    }
}