use super::PacketData;
use crate::noise::errors::WireGuardError;
use parking_lot::Mutex;
use portable_atomic::{AtomicU64, Ordering};
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
pub struct Session {
pub(crate) receiving_index: u32,
sending_index: u32,
receiver: LessSafeKey,
sender: LessSafeKey,
sending_key_counter: AtomicU64,
receiving_key_counter: Mutex<ReceivingKeyCounterValidator>,
}
impl std::fmt::Debug for Session {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"Session: {}<- ->{}",
self.receiving_index, self.sending_index
)
}
}
const DATA_OFFSET: usize = 16;
const AEAD_SIZE: usize = 16;
const WORD_SIZE: u64 = 64;
const N_WORDS: u64 = 16; const N_BITS: u64 = WORD_SIZE * N_WORDS;
#[derive(Debug, Clone, Default)]
struct ReceivingKeyCounterValidator {
next: u64,
receive_cnt: u64,
bitmap: [u64; N_WORDS as usize],
}
impl ReceivingKeyCounterValidator {
#[inline(always)]
fn set_bit(&mut self, idx: u64) {
let bit_idx = idx % N_BITS;
let word = (bit_idx / WORD_SIZE) as usize;
let bit = (bit_idx % WORD_SIZE) as usize;
self.bitmap[word] |= 1 << bit;
}
#[inline(always)]
fn clear_bit(&mut self, idx: u64) {
let bit_idx = idx % N_BITS;
let word = (bit_idx / WORD_SIZE) as usize;
let bit = (bit_idx % WORD_SIZE) as usize;
self.bitmap[word] &= !(1u64 << bit);
}
#[inline(always)]
fn clear_word(&mut self, idx: u64) {
let bit_idx = idx % N_BITS;
let word = (bit_idx / WORD_SIZE) as usize;
self.bitmap[word] = 0;
}
#[inline(always)]
fn check_bit(&self, idx: u64) -> bool {
let bit_idx = idx % N_BITS;
let word = (bit_idx / WORD_SIZE) as usize;
let bit = (bit_idx % WORD_SIZE) as usize;
((self.bitmap[word] >> bit) & 1) == 1
}
#[inline(always)]
fn will_accept(&self, counter: u64) -> Result<(), WireGuardError> {
if counter >= self.next {
return Ok(());
}
if counter + N_BITS < self.next {
return Err(WireGuardError::InvalidCounter);
}
if !self.check_bit(counter) {
Ok(())
} else {
Err(WireGuardError::DuplicateCounter)
}
}
#[inline(always)]
fn mark_did_receive(&mut self, counter: u64) -> Result<(), WireGuardError> {
if counter + N_BITS < self.next {
return Err(WireGuardError::InvalidCounter);
}
if counter == self.next {
self.set_bit(counter);
self.next += 1;
return Ok(());
}
if counter < self.next {
if self.check_bit(counter) {
return Err(WireGuardError::InvalidCounter);
}
self.set_bit(counter);
return Ok(());
}
if counter - self.next >= N_BITS {
for c in self.bitmap.iter_mut() {
*c = 0;
}
} else {
let mut i = self.next;
while i % WORD_SIZE != 0 && i < counter {
self.clear_bit(i);
i += 1;
}
while i + WORD_SIZE < counter {
self.clear_word(i);
i = (i + WORD_SIZE) & 0u64.wrapping_sub(WORD_SIZE);
}
while i < counter {
self.clear_bit(i);
i += 1;
}
}
self.set_bit(counter);
self.next = counter + 1;
Ok(())
}
}
impl Session {
pub(super) fn new(
local_index: u32,
peer_index: u32,
receiving_key: [u8; 32],
sending_key: [u8; 32],
) -> Session {
Session {
receiving_index: local_index,
sending_index: peer_index,
receiver: LessSafeKey::new(
UnboundKey::new(&CHACHA20_POLY1305, &receiving_key).unwrap(),
),
sender: LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, &sending_key).unwrap()),
sending_key_counter: AtomicU64::new(0),
receiving_key_counter: Mutex::new(Default::default()),
}
}
pub(super) fn local_index(&self) -> usize {
self.receiving_index as usize
}
fn receiving_counter_quick_check(&self, counter: u64) -> Result<(), WireGuardError> {
let counter_validator = self.receiving_key_counter.lock();
counter_validator.will_accept(counter)
}
fn receiving_counter_mark(&self, counter: u64) -> Result<(), WireGuardError> {
let mut counter_validator = self.receiving_key_counter.lock();
let ret = counter_validator.mark_did_receive(counter);
if ret.is_ok() {
counter_validator.receive_cnt += 1;
}
ret
}
pub(super) fn format_packet_data<'a>(&self, src: &[u8], dst: &'a mut [u8]) -> &'a mut [u8] {
if dst.len() < src.len() + super::DATA_OVERHEAD_SZ {
panic!("The destination buffer is too small");
}
let sending_key_counter = self.sending_key_counter.fetch_add(1, Ordering::Relaxed) as u64;
let (message_type, rest) = dst.split_at_mut(4);
let (receiver_index, rest) = rest.split_at_mut(4);
let (counter, data) = rest.split_at_mut(8);
message_type.copy_from_slice(&super::DATA.to_le_bytes());
receiver_index.copy_from_slice(&self.sending_index.to_le_bytes());
counter.copy_from_slice(&sending_key_counter.to_le_bytes());
let n = {
let mut nonce = [0u8; 12];
nonce[4..12].copy_from_slice(&sending_key_counter.to_le_bytes());
data[..src.len()].copy_from_slice(src);
self.sender
.seal_in_place_separate_tag(
Nonce::assume_unique_for_key(nonce),
Aad::from(&[]),
&mut data[..src.len()],
)
.map(|tag| {
data[src.len()..src.len() + AEAD_SIZE].copy_from_slice(tag.as_ref());
src.len() + AEAD_SIZE
})
.unwrap()
};
&mut dst[..DATA_OFFSET + n]
}
pub(super) fn receive_packet_data<'a>(
&self,
packet: PacketData,
dst: &'a mut [u8],
) -> Result<&'a mut [u8], WireGuardError> {
let ct_len = packet.encrypted_encapsulated_packet.len();
if dst.len() < ct_len {
panic!("The destination buffer is too small");
}
if packet.receiver_idx != self.receiving_index {
return Err(WireGuardError::WrongIndex);
}
self.receiving_counter_quick_check(packet.counter)?;
let ret = {
let mut nonce = [0u8; 12];
nonce[4..12].copy_from_slice(&packet.counter.to_le_bytes());
dst[..ct_len].copy_from_slice(packet.encrypted_encapsulated_packet);
self.receiver
.open_in_place(
Nonce::assume_unique_for_key(nonce),
Aad::from(&[]),
&mut dst[..ct_len],
)
.map_err(|_| WireGuardError::InvalidAeadTag)?
};
self.receiving_counter_mark(packet.counter)?;
Ok(ret)
}
pub(super) fn current_packet_cnt(&self) -> (u64, u64) {
let counter_validator = self.receiving_key_counter.lock();
(counter_validator.next, counter_validator.receive_cnt)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_replay_counter() {
let mut c: ReceivingKeyCounterValidator = Default::default();
assert!(c.mark_did_receive(0).is_ok());
assert!(c.mark_did_receive(0).is_err());
assert!(c.mark_did_receive(1).is_ok());
assert!(c.mark_did_receive(1).is_err());
assert!(c.mark_did_receive(63).is_ok());
assert!(c.mark_did_receive(63).is_err());
assert!(c.mark_did_receive(15).is_ok());
assert!(c.mark_did_receive(15).is_err());
for i in 64..N_BITS + 128 {
assert!(c.mark_did_receive(i).is_ok());
assert!(c.mark_did_receive(i).is_err());
}
assert!(c.mark_did_receive(N_BITS * 3).is_ok());
for i in 0..=N_BITS * 2 {
assert!(matches!(
c.will_accept(i),
Err(WireGuardError::InvalidCounter)
));
assert!(c.mark_did_receive(i).is_err());
}
for i in N_BITS * 2 + 1..N_BITS * 3 {
assert!(c.will_accept(i).is_ok());
}
assert!(matches!(
c.will_accept(N_BITS * 3),
Err(WireGuardError::DuplicateCounter)
));
for i in (N_BITS * 2 + 1..N_BITS * 3).rev() {
assert!(c.mark_did_receive(i).is_ok());
assert!(c.mark_did_receive(i).is_err());
}
assert!(c.mark_did_receive(N_BITS * 3 + 70).is_ok());
assert!(c.mark_did_receive(N_BITS * 3 + 71).is_ok());
assert!(c.mark_did_receive(N_BITS * 3 + 72).is_ok());
assert!(c.mark_did_receive(N_BITS * 3 + 72 + 125).is_ok());
assert!(c.mark_did_receive(N_BITS * 3 + 63).is_ok());
assert!(c.mark_did_receive(N_BITS * 3 + 70).is_err());
assert!(c.mark_did_receive(N_BITS * 3 + 71).is_err());
assert!(c.mark_did_receive(N_BITS * 3 + 72).is_err());
}
}