/// Sender-side state for a sliding window of sequentially numbered segments.
///
/// Tracks the byte size of every in-flight segment, a random nonce bit per
/// segment id, and two 32-bit shift registers (`ack_history`,
/// `nonce_history`) used to validate receiver reports and detect drops.
/// Ids are `u32` and wrap around; because `size` is a power of two it divides
/// 2^32, so the id-to-slot mappings below stay consistent across the wrap.
#[derive(Debug)]
pub struct SegmentTx {
    /// First id not yet covered by a cumulative acknowledgement.
    base_id: u32,
    /// Id that the next `mark_sent` call will assign.
    next_id: u32,
    /// Window capacity in segments (a power of two, at least 64).
    size: u32,
    /// Byte size of each in-flight segment, indexed by `id & transit_buf_mask`.
    transit_buf: Box<[usize]>,
    transit_buf_mask: u32,
    /// One nonce bit per id, packed MSB-first: id `k` is bit `63 - k % 64` of
    /// word `(k / 64) & nonce_buf_mask`.
    nonce_buf: Box<[u64]>,
    nonce_buf_mask: u32,
    /// Sum of the `transit_buf` entries for ids in `[base_id, next_id)`.
    transit_total: usize,
    /// Bit `i` tracks id `base_id - 1 - i`; it is set once the receiver
    /// reports that id (the register starts all-ones).
    ack_history: u32,
    /// Bit `i` holds the nonce of id `base_id - 1 - i`.
    nonce_history: u32,
}

impl SegmentTx {
    /// Creates a sender whose first segment id is `base_id` and whose window
    /// holds `size` segments.
    ///
    /// # Panics
    ///
    /// Panics if `size` is not a power of two or is smaller than 64.
    pub fn new(base_id: u32, size: u32) -> Self {
        assert!(size.is_power_of_two(), "size must be a power of two");
        assert!(size >= 64, "size must be 64 or greater");
        // `size` is a power of two >= 64, so `nonce_words` is a non-zero power
        // of two and both `- 1` masks below are valid modulo masks.
        let nonce_words = size / 64;
        Self {
            base_id,
            next_id: base_id,
            size,
            transit_buf: vec![0; size as usize].into_boxed_slice(),
            transit_buf_mask: size - 1,
            nonce_buf: vec![0; nonce_words as usize].into_boxed_slice(),
            nonce_buf_mask: nonce_words - 1,
            transit_total: 0,
            ack_history: u32::MAX,
            nonce_history: 0,
        }
    }

    /// Returns the id that the next sent segment will receive.
    pub fn next_id(&self) -> u32 {
        self.next_id
    }

    /// Draws a random nonce bit for segment `next_id`, stores it, and
    /// returns it.
    ///
    /// The bit is recorded under the current `next_id`, so call this before
    /// the `mark_sent` for that segment and only while `can_send()` is true:
    /// with a full window the slot for `next_id` is the slot of the in-flight
    /// segment `base_id`, whose nonce is still needed to verify its ack.
    pub fn compute_next_nonce(&mut self) -> bool {
        debug_assert!(self.can_send());
        let nonce = rand::random::<bool>();
        let word_idx = (self.next_id / 64) & self.nonce_buf_mask;
        let bit = 1u64 << (63 - self.next_id % 64);
        let word = &mut self.nonce_buf[word_idx as usize];
        if nonce {
            *word |= bit;
        } else {
            *word &= !bit;
        }
        nonce
    }

    /// Returns `true` while fewer than `size` segments are in flight.
    pub fn can_send(&self) -> bool {
        self.next_id.wrapping_sub(self.base_id) < self.size
    }

    /// Records that segment `next_id`, carrying `size` bytes, was sent and
    /// advances `next_id`.
    pub fn mark_sent(&mut self, size: usize) {
        debug_assert!(self.can_send());
        let slot = (self.next_id & self.transit_buf_mask) as usize;
        self.transit_buf[slot] = size;
        self.transit_total += size;
        self.next_id = self.next_id.wrapping_add(1);
    }

    /// Total bytes of segments sent but not yet cumulatively acknowledged.
    pub fn bytes_in_transit(&self) -> usize {
        self.transit_total
    }

    /// Shifts the nonces of newly acknowledged ids into `nonce_history`.
    ///
    /// `base_id_new` is the new `base_id`. The ids
    /// `[base_id_new - new_bits, base_id_new)` (at most the last 32) are
    /// appended so that bit 0 ends up holding the nonce of `base_id_new - 1`.
    /// Requires `ack_delta > 0`.
    fn update_nonce_history(&mut self, ack_delta: u32, base_id_new: u32) {
        debug_assert!(ack_delta > 0);
        // Only the 32 most recent ids fit in the register.
        let new_bits = ack_delta.min(32);
        let p1 = base_id_new;
        let p0 = base_id_new.wrapping_sub(new_bits);
        let p0_idx = p0 / 64;
        let p1_idx = p1 / 64;
        // Number of new ids that live in word `p1_idx`.
        let p1_bits = p1 % 64;
        let word = |idx: u32| self.nonce_buf[(idx & self.nonce_buf_mask) as usize];
        // A 32-bit shift overflows `u32`; `checked_shl` maps it to a full reset.
        let mut x_new = self.nonce_history.checked_shl(new_bits).unwrap_or(0);
        if p0_idx == p1_idx {
            // Every new id sits in word `p1_idx` just before bit position
            // `p1_bits`; with MSB-first packing, shifting right by
            // `64 - p1_bits` lands id `p1 - 1` on bit 0.
            let mask = (1u64 << new_bits) - 1;
            x_new |= ((word(p1_idx) >> (64 - p1_bits)) & mask) as u32;
        } else if p1_bits == 0 {
            // `p1` starts a fresh word, so the new ids are the trailing
            // `new_bits` ids of the previous word: already its low bits.
            let mask = (1u64 << new_bits) - 1;
            x_new |= (word(p0_idx) & mask) as u32;
        } else {
            // The ids straddle a word boundary: word `p1_idx` holds the newest
            // `p1_bits` ids (low result bits), word `p0_idx` the older
            // `new_bits - p1_bits` ids, which are shifted above them.
            let mask_older = (1u64 << (new_bits - p1_bits)) - 1;
            let mask_newer = (1u64 << p1_bits) - 1;
            x_new |= ((word(p1_idx) >> (64 - p1_bits)) & mask_newer) as u32;
            x_new |= ((word(p0_idx) & mask_older) << p1_bits) as u32;
        }
        self.nonce_history = x_new;
    }

    /// Parity of the nonces of the ids flagged in `rx_history`: `true` when an
    /// odd number of those ids carry a set nonce bit.
    fn expected_checksum(nonce_history: u32, rx_history: u32) -> bool {
        (nonce_history & rx_history).count_ones() % 2 == 1
    }

    /// Folds a receiver report into `ack_history` and decides whether a
    /// segment is considered dropped. The history is reset to all-ones
    /// whenever this returns `true`.
    fn update_ack_history(&mut self, ack_delta: u32, rx_history: u32, rx_checksum: bool) -> bool {
        // Advancing by `ack_delta` pushes the top `ack_delta` bits out of the
        // register; if any of them is clear, an unacknowledged id is lost.
        if ack_delta > self.ack_history.leading_ones() {
            self.ack_history = u32::MAX;
            return true;
        }
        // Here `ack_delta <= 32`; a shift by 32 clears the register.
        self.ack_history = self.ack_history.checked_shl(ack_delta).unwrap_or(0);
        // Merge the report only when its checksum matches the parity of the
        // nonces of the ids it claims; otherwise its bits are discarded.
        if Self::expected_checksum(self.nonce_history, rx_history) == rx_checksum {
            self.ack_history |= rx_history;
        }
        if self.ack_history == u32::MAX {
            return false;
        }
        let leading_ones = self.ack_history.leading_ones();
        // The oldest tracked id (bit 31) is still unacknowledged.
        if leading_ones == 0 {
            self.ack_history = u32::MAX;
            return true;
        }
        debug_assert!(leading_ones < 32);
        // Count acknowledged ids newer than the oldest gap by masking off the
        // leading run of ones; three or more declare the gap dropped.
        let ones_after_zero = (self.ack_history & (u32::MAX >> leading_ones)).count_ones();
        if ones_after_zero >= 3 {
            self.ack_history = u32::MAX;
            true
        } else {
            false
        }
    }

    /// Processes a receiver report.
    ///
    /// `rx_base_id` is the receiver's cumulative acknowledgement point, bit
    /// `i` of `rx_history` refers to id `rx_base_id - 1 - i`, and
    /// `rx_checksum` must equal the parity of the nonces of the ids set in
    /// `rx_history` for those bits to be merged.
    ///
    /// A report whose `rx_base_id` lies outside `[base_id, next_id]` is
    /// ignored and yields `false`. Otherwise the bytes of every segment before
    /// `rx_base_id` are released, `base_id` advances to `rx_base_id`, and the
    /// report is folded into the drop detector.
    ///
    /// Returns `true` when a segment is judged dropped. That happens when an
    /// unacknowledged id would be pushed out of the 32-id history, when the
    /// oldest tracked id is unacknowledged, or when at least three ids newer
    /// than the oldest gap are acknowledged.
    pub fn acknowledge(&mut self, rx_base_id: u32, rx_history: u32, rx_checksum: bool) -> bool {
        let ack_delta = rx_base_id.wrapping_sub(self.base_id);
        if ack_delta > self.next_id.wrapping_sub(self.base_id) {
            return false;
        }
        if ack_delta > 0 {
            let released: usize = (0..ack_delta)
                .map(|i| {
                    let slot = self.base_id.wrapping_add(i) & self.transit_buf_mask;
                    self.transit_buf[slot as usize]
                })
                .sum();
            self.transit_total -= released;
            self.update_nonce_history(ack_delta, rx_base_id);
            self.base_id = rx_base_id;
        }
        self.update_ack_history(ack_delta, rx_history, rx_checksum)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn send_receive_window() {
        let mut tx = SegmentTx::new(0, 128);
        for i in 0..128 {
            assert_eq!(tx.can_send(), true);
            assert_eq!(tx.next_id(), i);
            assert_eq!(tx.bytes_in_transit(), (i * 5) as usize);
            tx.mark_sent(5);
        }
        assert_eq!(tx.can_send(), false);
        for i in 0..128 {
            assert_eq!(tx.acknowledge(i + 1, 0x1, false), false);
            assert_eq!(tx.bytes_in_transit(), ((128 - i - 1) * 5) as usize);
            assert_eq!(tx.can_send(), true);
        }
    }

    #[test]
    fn ideal_nonce_verification() {
        for n in 1..7 {
            // Start near `u32::MAX` so the 2048 ids exercised below wrap around.
            let initial_id = u32::MAX - 1024;
            let mut tx = SegmentTx::new(initial_id, 1024);
            // Number of ids sent so far, relative to `initial_id`.
            let mut i: u32 = 0;
            while i < 2048 {
                let id = initial_id.wrapping_add(i);
                assert_eq!(tx.next_id(), id);
                let mut nonces = [false; 7];
                for j in 0..n {
                    nonces[j] = tx.compute_next_nonce();
                    tx.mark_sent(0);
                }
                let rx_history = (1 << n) - 1;
                let mut rx_checksum = false;
                for j in 0..n {
                    rx_checksum ^= nonces[j];
                }
                let rx_base_id = id.wrapping_add(n as u32);
                assert_eq!(
                    tx.acknowledge(rx_base_id, rx_history as u32, rx_checksum),
                    false
                );
                // Advance the relative counter; `rx_base_id` itself is an
                // absolute id far above 2048 and would end the loop at once.
                i += n as u32;
            }
        }
    }

    #[test]
    fn bad_nonce_verification() {
        let mut tx = SegmentTx::new(0, 1024);
        let mut nonces = [false; 24];
        for i in 0..24 {
            nonces[i] = tx.compute_next_nonce();
            tx.mark_sent(0);
        }
        let checksum_0_2_bad = !(nonces[0] ^ nonces[1] ^ nonces[2]);
        assert_eq!(tx.acknowledge(3, 0b111, checksum_0_2_bad), false);
        assert_eq!(tx.acknowledge(4, 0b1, nonces[3]), false);
        assert_eq!(tx.acknowledge(5, 0b1, nonces[4]), false);
        assert_eq!(tx.acknowledge(6, 0b1, nonces[5]), true);
    }

    #[test]
    fn full_window_nonce_verification() {
        let mut tx = SegmentTx::new(0, 1024);
        let mut nonces = [false; 9];
        for i in 0..1024 {
            if i < nonces.len() {
                nonces[i] = tx.compute_next_nonce();
            }
            tx.mark_sent(0);
        }
        let checksum_0_2 = nonces[0] ^ nonces[1] ^ nonces[2];
        assert_eq!(tx.acknowledge(3, 0b111, checksum_0_2), false);
        for _ in 0..3 {
            tx.compute_next_nonce();
            tx.mark_sent(0);
        }
        let checksum_0_5 = checksum_0_2 ^ nonces[3] ^ nonces[4] ^ nonces[5];
        assert_eq!(tx.acknowledge(6, 0b111111, checksum_0_5), false);
        let checksum_6_8 = nonces[6] ^ nonces[7] ^ nonces[8];
        assert_eq!(tx.acknowledge(9, 0b111, checksum_6_8), false);
    }

    /// Returns a sender with a window of `size` segments, all of them in flight.
    fn new_filled_window(size: usize) -> SegmentTx {
        let mut tx = SegmentTx::new(0, size as u32);
        for _ in 0..size {
            tx.mark_sent(0);
        }
        tx
    }

    #[test]
    fn drop_detection() {
        const WINDOW_SIZE: usize = 128;
        const DUMMY_NONCE: bool = false;
        let mut tx = new_filled_window(WINDOW_SIZE);
        assert_eq!(tx.acknowledge(5, 0b01001, DUMMY_NONCE), false);
        let mut tx = new_filled_window(WINDOW_SIZE);
        assert_eq!(tx.acknowledge(5, 0b01001, DUMMY_NONCE), false);
        let mut tx = new_filled_window(WINDOW_SIZE);
        assert_eq!(tx.acknowledge(5, 0b01101, DUMMY_NONCE), true);
        let mut tx = new_filled_window(WINDOW_SIZE);
        assert_eq!(tx.acknowledge(5, 0b10011, DUMMY_NONCE), false);
        let mut tx = new_filled_window(WINDOW_SIZE);
        assert_eq!(tx.acknowledge(5, 0b10111, DUMMY_NONCE), true);
        let mut tx = new_filled_window(WINDOW_SIZE);
        assert_eq!(tx.acknowledge(5, 0b10011, DUMMY_NONCE), false);
        assert_eq!(tx.acknowledge(5, 0b11011, DUMMY_NONCE), false);
        assert_eq!(tx.acknowledge(7, 0b11101, DUMMY_NONCE), false);
        assert_eq!(tx.acknowledge(7, 0b11111, DUMMY_NONCE), false);
        let mut tx = new_filled_window(WINDOW_SIZE);
        assert_eq!(tx.acknowledge(5, 0b10011, DUMMY_NONCE), false);
        assert_eq!(tx.acknowledge(7, 0b11111, DUMMY_NONCE), true);
        let mut tx = new_filled_window(WINDOW_SIZE);
        assert_eq!(tx.acknowledge(32, (0b1 << 31) | 0b001, DUMMY_NONCE), false);
        let mut tx = new_filled_window(WINDOW_SIZE);
        assert_eq!(tx.acknowledge(32, (0b1 << 31) | 0b011, DUMMY_NONCE), false);
        let mut tx = new_filled_window(WINDOW_SIZE);
        assert_eq!(tx.acknowledge(32, (0b1 << 31) | 0b111, DUMMY_NONCE), true);
        let mut tx = new_filled_window(WINDOW_SIZE);
        assert_eq!(tx.acknowledge(32, u32::MAX, DUMMY_NONCE), false);
        let mut tx = new_filled_window(WINDOW_SIZE);
        assert_eq!(tx.acknowledge(33, u32::MAX, DUMMY_NONCE), true);
        let mut tx = new_filled_window(WINDOW_SIZE);
        assert_eq!(tx.acknowledge(31, 0b0, DUMMY_NONCE), false);
        assert_eq!(tx.acknowledge(32, 0b0, DUMMY_NONCE), true);
        let mut tx = new_filled_window(WINDOW_SIZE);
        assert_eq!(tx.acknowledge(32, 0b1, DUMMY_NONCE), true);
        let mut tx = new_filled_window(WINDOW_SIZE);
        assert_eq!(tx.acknowledge(16, 0b1, DUMMY_NONCE), false);
        assert_eq!(tx.acknowledge(32, 0b1, DUMMY_NONCE), true);
        let mut tx = new_filled_window(WINDOW_SIZE);
        assert_eq!(tx.acknowledge(16, 0b1, DUMMY_NONCE), false);
        assert_eq!(tx.acknowledge(31, 0b1, DUMMY_NONCE), false);
        assert_eq!(tx.acknowledge(32, 0b0, DUMMY_NONCE), true);
    }
}