use std::time::{Duration, Instant};
use soe_protocol::OpCode;
use soe_protocol::channel::{
InputConfig, OutputConfig, ReliableDataInputChannel, ReliableDataOutputChannel,
};
/// Maximum data length handed to the output channel as `OutputConfig::max_data_length`.
/// NOTE(review): `largest_single_data_packet` treats `MAX_DATA_LENGTH - 2` bytes as the largest
/// unfragmented payload, which suggests 2 bytes of per-packet overhead — confirm against the
/// channel implementation.
const MAX_DATA_LENGTH: usize = 508;
/// Window size shared by both channels: the output channel's `max_queued_outgoing`, and the
/// input channel's `max_queued_incoming` and `data_ack_window`.
const WINDOW: usize = 32;
/// Builds a deterministic pseudo-random payload of exactly `size` bytes.
///
/// The bytes are the little-endian output words of a SplitMix64 generator. It is re-seeded
/// with the same constant on every call, so:
/// - calls with the same `size` return identical data;
/// - a shorter packet is always a prefix of a longer one.
///
/// Because of this, equal-sized packets in one test cannot be told apart by content.
///
/// If the payload would begin with `0x00 0x19`, its first byte is flipped. This keeps the
/// data from carrying that two-byte prefix, which is presumably the protocol's multi-data
/// indicator (confirm). The check covers every payload of at least two bytes, including one
/// that consists of nothing but the prefix.
fn generate_packet(size: usize) -> Vec<u8> {
    let mut state: u64 = 23445;
    // Fill in place: each 8-byte chunk receives one generator word, and the trailing short
    // chunk receives that word's leading bytes. This matches "append words, then truncate"
    // without over-allocating.
    let mut out = vec![0u8; size];
    for chunk in out.chunks_mut(8) {
        state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = state;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^= z >> 31;
        chunk.copy_from_slice(&z.to_le_bytes()[..chunk.len()]);
    }
    // `starts_with` is false for payloads shorter than two bytes, so no length guard is needed.
    if out.starts_with(&[0x00, 0x19]) {
        out[0] ^= 0xFF;
    }
    out
}
/// Pushes `packets` through a reliable output channel into a reliable input channel. It then
/// asserts that the input side yields exactly the same packets, in the same order.
///
/// Both channels are created at the same `Instant`. A simulated clock advances by one
/// millisecond per iteration, and every outgoing packet is delivered directly (nothing is
/// dropped). Each iteration does three things:
/// 1. ticks the sender and forwards its data/fragment packets to the receiver;
/// 2. collects any application data the receiver has recomposed;
/// 3. ticks the receiver and feeds its acknowledgements back to the sender.
///
/// The exchange ends once the sender's queue is empty.
///
/// # Panics
///
/// Panics if:
/// - either channel emits an unexpected op code;
/// - the receiver rejects a packet;
/// - the sender's queue has not drained within an iteration budget derived from the total
///   payload size;
/// - the received packets differ from `packets`.
fn assert_roundtrip(packets: &[Vec<u8>]) {
    let epoch = Instant::now();

    let output_config = OutputConfig {
        max_data_length: MAX_DATA_LENGTH,
        max_queued_outgoing: WINDOW,
        ack_wait: Duration::from_millis(500),
    };
    let mut sender = ReliableDataOutputChannel::new(output_config, None, epoch);

    let input_config = InputConfig {
        max_queued_incoming: WINDOW as u16,
        acknowledge_all_data: false,
        data_ack_window: WINDOW as u16,
        max_ack_delay: Duration::ZERO,
    };
    let mut receiver = ReliableDataInputChannel::new(input_config, None, epoch);

    for packet in packets {
        sender.enqueue_data(packet);
    }

    // Iteration budget: a rough per-packet fragment estimate with generous slack, so a
    // stalled exchange fails the test instead of spinning forever.
    let fragment_estimate: usize = packets
        .iter()
        .map(|packet| packet.len() / MAX_DATA_LENGTH + 1)
        .sum();
    let iteration_budget = fragment_estimate * 4 + 1000;

    let step = Duration::from_millis(1);
    let mut clock = epoch;
    let mut received: Vec<Vec<u8>> = Vec::new();
    let mut converged = false;

    for _ in 0..iteration_budget {
        clock += step;

        // Sender -> receiver: reliable data and fragments.
        sender.run_tick(clock);
        for outbound in sender.take_outgoing() {
            match outbound.op_code {
                OpCode::ReliableData => receiver
                    .handle_reliable_data(outbound.payload, clock)
                    .expect("valid reliable data"),
                OpCode::ReliableDataFragment => receiver
                    .handle_reliable_data_fragment(outbound.payload, clock)
                    .expect("valid reliable data fragment"),
                other => panic!("unexpected output op code: {other:?}"),
            }
        }
        received.extend(receiver.take_app_data().into_iter().map(|data| data.to_vec()));

        // Receiver -> sender: acknowledgements.
        receiver.run_tick(clock);
        for reply in receiver.take_outgoing() {
            match reply.op_code {
                OpCode::Acknowledge => sender.notify_of_acknowledge(reply.sequence, clock),
                OpCode::AcknowledgeAll => sender.notify_of_acknowledge_all(reply.sequence, clock),
                other => panic!("unexpected acknowledgement op code: {other:?}"),
            }
        }

        if sender.queued_len() == 0 {
            // Pick up anything the receiver surfaced after the last drain before stopping.
            received.extend(receiver.take_app_data().into_iter().map(|data| data.to_vec()));
            converged = true;
            break;
        }
    }
    assert!(converged, "channels did not converge");

    assert_eq!(
        received.len(),
        packets.len(),
        "received packet count mismatch"
    );
    for (i, (expected, got)) in packets.iter().zip(received.iter()).enumerate() {
        assert_eq!(
            got, expected,
            "recomposed packet {i} differs from the original"
        );
    }
}
/// A single tiny payload round-trips unchanged.
#[test]
fn single_small_packet() {
    let packets = [generate_packet(5)];
    assert_roundtrip(&packets);
}
/// Several small payloads of varying sizes round-trip unchanged and in order.
#[test]
fn multiple_small_packets() {
    let sizes = [3usize, 45, 1, 214];
    let packets: Vec<Vec<u8>> = sizes.iter().copied().map(generate_packet).collect();
    assert_roundtrip(&packets);
}
/// Small payloads whose combined size (691 bytes) exceeds `MAX_DATA_LENGTH`.
///
/// NOTE(review): per the test name, this is expected to force fragmentation — presumably
/// because the output channel batches small packets together. Confirm.
#[test]
fn multiple_small_packets_requiring_fragmentation() {
    let sizes = [3usize, 45, 1, 214, 214, 214];
    let packets: Vec<Vec<u8>> = sizes.iter().copied().map(generate_packet).collect();
    assert_roundtrip(&packets);
}
/// The largest payload expected to fit in a single reliable-data packet.
///
/// This is `MAX_DATA_LENGTH` minus what appears to be a 2-byte header; confirm.
#[test]
fn largest_single_data_packet() {
    let size = MAX_DATA_LENGTH - 2;
    let packets = [generate_packet(size)];
    assert_roundtrip(&packets);
}
/// A payload one byte larger than the one in `largest_single_data_packet`.
///
/// This presumably forces the fragmented path.
#[test]
fn single_large_packet() {
    let size = MAX_DATA_LENGTH - 1;
    let packets = [generate_packet(size)];
    assert_roundtrip(&packets);
}
/// Several payloads, each larger than `MAX_DATA_LENGTH`, round-trip unchanged and in order.
#[test]
fn multiple_large_packets() {
    const BASE: usize = 512;
    let sizes = [BASE, BASE + 7, BASE + 54, BASE * 2];
    let packets: Vec<Vec<u8>> = sizes.iter().copied().map(generate_packet).collect();
    assert_roundtrip(&packets);
}
/// A single 400 KiB payload, split across many fragments, round-trips unchanged.
#[test]
fn single_huge_packet_fragmentation() {
    const HUGE_SIZE: usize = 400 * 1024;
    let packets = [generate_packet(HUGE_SIZE)];
    assert_roundtrip(&packets);
}
/// 256 payloads round-trip unchanged and in order.
///
/// The sizes grow in 256-byte steps from 256 bytes up to 64 KiB.
#[test]
fn all_the_packets() {
    let mut packets: Vec<Vec<u8>> = Vec::with_capacity(256);
    for step in 1..=256 {
        packets.push(generate_packet(step * 256));
    }
    assert_roundtrip(&packets);
}