use std::cmp;
use rust_sodium::randombytes::randombytes;
use crate::errors::{SignalingError, SignalingResult};
use crate::helpers::libsodium_init_or_panic;
/// A mutable combined sequence number (CSN).
///
/// The value is split into a 16-bit overflow counter and a 32-bit sequence
/// counter which together form a 48-bit number (see
/// `combined_sequence_number`). It is advanced in place by `increment`, and
/// `CombinedSequencePair` uses it for our own side of a connection.
#[derive(Debug, Hash, PartialEq, Eq)]
pub(crate) struct CombinedSequence {
    /// High part of the combined value; bumped when `sequence` wraps.
    overflow: u16,
    /// Low 32 bits of the combined value.
    sequence: u32,
}
impl CombinedSequence {
pub(crate) fn new(overflow: u16, sequence: u32) -> Self {
CombinedSequence { overflow, sequence }
}
pub(crate) fn random() -> Self {
libsodium_init_or_panic();
let rand = randombytes(4);
let overflow = 0u16;
let sequence = (u32::from(rand[0]) << 24)
+ (u32::from(rand[1]) << 16)
+ (u32::from(rand[2]) << 8)
+ u32::from(rand[3]);
CombinedSequence { overflow, sequence }
}
pub(crate) fn combined_sequence_number(&self) -> u64 {
(u64::from(self.overflow) << 32) + u64::from(self.sequence)
}
pub(crate) fn increment(&mut self) -> SignalingResult<CombinedSequenceSnapshot> {
let next = match self.sequence.checked_add(1) {
Some(incremented) => CombinedSequence::new(self.overflow, incremented),
None => match self.overflow.checked_add(1) {
Some(incremented) => CombinedSequence::new(incremented, 0),
None => return Err(SignalingError::CsnOverflow),
}
};
let snapshot = (&next).into();
*self = next;
Ok(snapshot)
}
}
impl<'a> From<&'a CombinedSequenceSnapshot> for CombinedSequence {
fn from(val: &'a CombinedSequenceSnapshot) -> Self {
Self {
overflow: val.overflow,
sequence: val.sequence,
}
}
}
/// A `CombinedSequence` equals a snapshot when their combined values match.
impl cmp::PartialEq<CombinedSequenceSnapshot> for CombinedSequence {
    fn eq(&self, other: &CombinedSequenceSnapshot) -> bool {
        let ours = self.combined_sequence_number();
        let theirs = other.combined_sequence_number();
        ours == theirs
    }
}
/// Order a `CombinedSequence` against a snapshot by their combined values.
/// The order is total, so this always returns `Some`.
impl cmp::PartialOrd<CombinedSequenceSnapshot> for CombinedSequence {
    fn partial_cmp(&self, other: &CombinedSequenceSnapshot) -> Option<cmp::Ordering> {
        let ordering = cmp::Ord::cmp(
            &self.combined_sequence_number(),
            &other.combined_sequence_number(),
        );
        Some(ordering)
    }
}
/// An immutable, cloneable copy of a combined sequence number.
///
/// It has the same two parts as `CombinedSequence`: a 16-bit overflow
/// counter and a 32-bit sequence counter.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct CombinedSequenceSnapshot {
    /// High part of the combined value.
    overflow: u16,
    /// Low 32 bits of the combined value.
    sequence: u32,
}
impl CombinedSequenceSnapshot {
pub fn new(overflow: u16, sequence: u32) -> Self {
CombinedSequenceSnapshot { overflow, sequence }
}
#[cfg(test)]
pub fn random() -> Self {
let cs = CombinedSequence::random();
CombinedSequenceSnapshot {
sequence: cs.sequence,
overflow: cs.overflow,
}
}
pub fn overflow_number(&self) -> u16 {
self.overflow
}
pub fn sequence_number(&self) -> u32 {
self.sequence
}
pub fn combined_sequence_number(&self) -> u64 {
(u64::from(self.overflow) << 32) + u64::from(self.sequence)
}
}
impl<'a> From<&'a CombinedSequence> for CombinedSequenceSnapshot {
fn from(val: &'a CombinedSequence) -> Self {
Self {
overflow: val.overflow,
sequence: val.sequence,
}
}
}
/// Snapshots are totally ordered by their combined 48-bit value.
impl cmp::Ord for CombinedSequenceSnapshot {
    fn cmp(&self, other: &Self) -> cmp::Ordering {
        let (lhs, rhs) = (
            self.combined_sequence_number(),
            other.combined_sequence_number(),
        );
        lhs.cmp(&rhs)
    }
}
/// Defers to the total order from `Ord`, so this always returns `Some`.
impl cmp::PartialOrd for CombinedSequenceSnapshot {
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
        Some(cmp::Ord::cmp(self, other))
    }
}
/// A snapshot equals a `CombinedSequence` when their combined values match.
impl cmp::PartialEq<CombinedSequence> for CombinedSequenceSnapshot {
    fn eq(&self, other: &CombinedSequence) -> bool {
        let mine = self.combined_sequence_number();
        let other_value = other.combined_sequence_number();
        mine == other_value
    }
}
/// Order a snapshot against a `CombinedSequence` by their combined values.
/// The order is total, so this always returns `Some`.
impl cmp::PartialOrd<CombinedSequence> for CombinedSequenceSnapshot {
    fn partial_cmp(&self, other: &CombinedSequence) -> Option<cmp::Ordering> {
        let ordering = cmp::Ord::cmp(
            &self.combined_sequence_number(),
            &other.combined_sequence_number(),
        );
        Some(ordering)
    }
}
/// Our own (mutable) combined sequence number together with the peer's,
/// if one has been recorded.
#[derive(Debug, PartialEq, Eq)]
pub(crate) struct CombinedSequencePair {
    /// Our CSN, advanced in place via `CombinedSequence::increment`.
    pub(crate) ours: CombinedSequence,
    /// The peer's CSN; `None` until a value has been stored.
    pub(crate) theirs: Option<CombinedSequenceSnapshot>,
}
impl CombinedSequencePair {
    /// Create a pair whose own CSN is `CombinedSequence::random()` and whose
    /// peer CSN starts out as `None`.
    pub(crate) fn new() -> Self {
        let ours = CombinedSequence::random();
        let theirs = None;
        CombinedSequencePair { ours, theirs }
    }
}
/// A plain, copyable pair of 64-bit sequence values for one peer.
///
/// NOTE(review): the `u64` fields match the type returned by
/// `combined_sequence_number`, so these are presumably combined sequence
/// numbers; confirm against the code that fills them in.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct PeerSequenceNumbers {
    /// Presumably the CSN on the incoming direction (TODO confirm).
    pub incoming: u64,
    /// Presumably the CSN on the outgoing direction (TODO confirm).
    pub outgoing: u64,
}
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// 100 random CSNs are expected to be pairwise distinct. A collision among
    /// 32-bit random values is possible but astronomically unlikely.
    #[test]
    fn random_distinct() {
        let numbers: HashSet<CombinedSequence> =
            (0..100).map(|_| CombinedSequence::random()).collect();
        assert_eq!(numbers.len(), 100);
    }

    /// Freshly generated CSNs always fit in 48 bits.
    #[test]
    fn combined_value_range() {
        let limit: u64 = 1 << 48;
        for _ in 0..100 {
            let number = CombinedSequence::random().combined_sequence_number();
            assert!(number < limit);
        }
    }

    /// Incrementing below the sequence maximum only touches the sequence part.
    #[test]
    fn increment_without_overflow() {
        // Re-roll until the sequence part can be bumped without wrapping.
        let mut csn = loop {
            let candidate = CombinedSequence::random();
            if candidate.sequence != ::std::u32::MAX {
                break candidate;
            }
        };
        let (sequence_before, overflow_before, combined_before) =
            (csn.sequence, csn.overflow, csn.combined_sequence_number());
        let snapshot = csn.increment().unwrap();
        assert_eq!(sequence_before + 1, snapshot.sequence_number());
        assert_eq!(overflow_before, snapshot.overflow_number());
        assert_eq!(combined_before + 1, snapshot.combined_sequence_number());
    }

    /// A wrapping sequence part resets to zero and bumps the overflow part.
    #[test]
    fn increment_with_sequence_overflow() {
        let mut csn = CombinedSequence::new(0, ::std::u32::MAX);
        let snapshot = csn.increment().unwrap();
        assert_eq!(snapshot.sequence_number(), 0);
        assert_eq!(snapshot.overflow_number(), 1);
        assert_eq!(
            snapshot.combined_sequence_number(),
            u64::from(::std::u32::MAX) + 1
        );
    }

    /// When both parts are exhausted, incrementing reports `CsnOverflow`.
    #[test]
    fn increment_with_overflow_overflow() {
        let mut csn = CombinedSequence::new(::std::u16::MAX, ::std::u32::MAX);
        let outcome = csn.increment();
        assert!(outcome.is_err());
        if let Err(error) = outcome {
            match error {
                SignalingError::CsnOverflow => {}
                ref other => panic!("Wrong error type: {:?}", other),
            }
        }
    }
}