use alloc::borrow::ToOwned;
use alloc::vec;
use alloc::vec::Vec;
use std::println;
use super::*;
use crate::association::Event;
use crate::config::generate_snap_token;
use crate::error::{Error, Result};
use crate::association::state::{AckMode, AssociationState};
use crate::association::stream::{ReliabilityType, Stream};
use crate::chunk::chunk_abort::ChunkAbort;
use crate::chunk::chunk_cookie_echo::ChunkCookieEcho;
use crate::chunk::chunk_error::ChunkError;
use crate::chunk::chunk_forward_tsn::ChunkForwardTsn;
use crate::chunk::chunk_heartbeat::ChunkHeartbeat;
use crate::chunk::chunk_init::ChunkInit;
use crate::chunk::chunk_payload_data::{ChunkPayloadData, PayloadProtocolIdentifier};
use crate::chunk::chunk_reconfig::ChunkReconfig;
use crate::chunk::chunk_selective_ack::{ChunkSelectiveAck, GapAckBlock};
use crate::chunk::chunk_shutdown::ChunkShutdown;
use crate::chunk::chunk_shutdown_ack::ChunkShutdownAck;
use crate::chunk::chunk_shutdown_complete::ChunkShutdownComplete;
use crate::chunk::{ErrorCauseProtocolViolation, PROTOCOL_VIOLATION};
use crate::packet::{CommonHeader, Packet};
use crate::param::param_outgoing_reset_request::ParamOutgoingResetRequest;
use crate::param::param_reconfig_response::ParamReconfigResponse;
use assert_matches::assert_matches;
use core::net::Ipv6Addr;
use core::ops::RangeFrom;
use core::str::FromStr;
use core::time::Duration;
use core::{cmp, mem};
use log::{info, trace};
use std::net::UdpSocket;
use std::sync::{LazyLock, Mutex};
use std::time::Instant;
pub static SERVER_PORTS: LazyLock<Mutex<RangeFrom<u16>>> = LazyLock::new(|| Mutex::new(4433..));
pub static CLIENT_PORTS: LazyLock<Mutex<RangeFrom<u16>>> = LazyLock::new(|| Mutex::new(44433..));
fn min_opt<T: Ord>(x: Option<T>, y: Option<T>) -> Option<T> {
match (x, y) {
(Some(x), Some(y)) => Some(cmp::min(x, y)),
(Some(x), _) => Some(x),
(_, Some(y)) => Some(y),
_ => None,
}
}
const MAX_DATAGRAMS: usize = 10;
fn split_transmit(transmit: Transmit) -> Vec<Transmit> {
let mut transmits = Vec::new();
if let Payload::RawEncode(contents) = transmit.payload {
for content in contents {
transmits.push(Transmit {
now: transmit.now,
remote: transmit.remote,
payload: Payload::RawEncode(vec![content]),
ecn: transmit.ecn,
local_ip: transmit.local_ip,
});
}
}
transmits
}
pub fn client_config() -> ClientConfig {
ClientConfig::new()
}
pub fn server_config() -> ServerConfig {
ServerConfig::new()
}
struct TestEndpoint {
endpoint: Endpoint,
addr: SocketAddr,
socket: Option<UdpSocket>,
timeout: Option<Instant>,
outbound: VecDeque<Transmit>,
delayed: VecDeque<Transmit>,
inbound: VecDeque<(Instant, Option<EcnCodepoint>, Bytes)>,
accepted: Option<AssociationHandle>,
associations: HashMap<AssociationHandle, Association>,
conn_events: HashMap<AssociationHandle, VecDeque<AssociationEvent>>,
}
impl TestEndpoint {
fn new(endpoint: Endpoint, addr: SocketAddr) -> Self {
let socket = UdpSocket::bind(addr).expect("failed to bind UDP socket");
socket
.set_read_timeout(Some(Duration::new(0, 10_000_000)))
.unwrap();
Self {
endpoint,
addr,
socket: Some(socket),
timeout: None,
outbound: VecDeque::new(),
delayed: VecDeque::new(),
inbound: VecDeque::new(),
accepted: None,
associations: HashMap::default(),
conn_events: HashMap::default(),
}
}
pub fn drive(&mut self, now: Instant, remote: SocketAddr) {
if let Some(ref socket) = self.socket {
loop {
let mut buf = [0; 8192];
if socket.recv_from(&mut buf).is_err() {
break;
}
}
}
while self.inbound.front().is_some_and(|x| x.0 <= now) {
let (recv_time, ecn, packet) = self.inbound.pop_front().unwrap();
if let Some((ch, event)) = self.endpoint.handle(recv_time, remote, None, ecn, packet) {
match event {
DatagramEvent::NewAssociation(conn) => {
self.associations.insert(ch, conn);
self.accepted = Some(ch);
}
DatagramEvent::AssociationEvent(event) => {
self.conn_events.entry(ch).or_default().push_back(event);
}
}
}
}
while let Some(x) = self.poll_transmit() {
self.outbound.extend(split_transmit(x));
}
let mut endpoint_events: Vec<(AssociationHandle, EndpointEvent)> = vec![];
for (ch, conn) in self.associations.iter_mut() {
if self.timeout.is_some_and(|x| x <= now) {
self.timeout = None;
conn.handle_timeout(now);
}
for (_, mut events) in self.conn_events.drain() {
for event in events.drain(..) {
conn.handle_event(event);
}
}
while let Some(event) = conn.poll_endpoint_event() {
endpoint_events.push((*ch, event));
}
while let Some(x) = conn.poll_transmit(now) {
self.outbound.extend(split_transmit(x));
}
self.timeout = conn.poll_timeout();
}
for (ch, event) in endpoint_events {
if let Some(event) = self.handle_event(ch, event) {
if let Some(conn) = self.associations.get_mut(&ch) {
conn.handle_event(event);
}
}
}
}
pub fn next_wakeup(&self) -> Option<Instant> {
let next_inbound = self.inbound.front().map(|x| x.0);
min_opt(self.timeout, next_inbound)
}
fn is_idle(&self) -> bool {
self.associations.values().all(|x| x.is_idle())
}
pub fn delay_outbound(&mut self) {
assert!(self.delayed.is_empty());
mem::swap(&mut self.delayed, &mut self.outbound);
}
pub fn finish_delay(&mut self) {
self.outbound.extend(self.delayed.drain(..));
}
pub fn assert_accept(&mut self) -> AssociationHandle {
self.accepted.take().expect("server didn't connect")
}
}
impl ::core::ops::Deref for TestEndpoint {
type Target = Endpoint;
fn deref(&self) -> &Endpoint {
&self.endpoint
}
}
impl ::core::ops::DerefMut for TestEndpoint {
fn deref_mut(&mut self) -> &mut Endpoint {
&mut self.endpoint
}
}
struct Pair {
server: TestEndpoint,
client: TestEndpoint,
time: Instant,
latency: Duration, }
impl Pair {
pub fn new(endpoint_config: Arc<EndpointConfig>, server_config: ServerConfig) -> Self {
let server = Endpoint::new(endpoint_config.clone(), Some(Arc::new(server_config)));
let client = Endpoint::new(endpoint_config, None);
Pair::new_from_endpoint(client, server)
}
pub fn new_from_endpoint(client: Endpoint, server: Endpoint) -> Self {
let server_addr = SocketAddr::new(
Ipv6Addr::LOCALHOST.into(),
SERVER_PORTS.lock().unwrap().next().unwrap(),
);
let client_addr = SocketAddr::new(
Ipv6Addr::LOCALHOST.into(),
CLIENT_PORTS.lock().unwrap().next().unwrap(),
);
Self {
server: TestEndpoint::new(server, server_addr),
client: TestEndpoint::new(client, client_addr),
time: Instant::now(),
latency: Duration::new(0, 0),
}
}
pub fn step(&mut self) -> bool {
self.drive_client();
self.drive_server();
if self.client.is_idle() && self.server.is_idle() {
return false;
}
let client_t = self.client.next_wakeup();
let server_t = self.server.next_wakeup();
match min_opt(client_t, server_t) {
Some(t) if Some(t) == client_t => {
if t != self.time {
self.time = self.time.max(t);
trace!("advancing to {:?} for client", self.time);
}
true
}
Some(t) if Some(t) == server_t => {
if t != self.time {
self.time = self.time.max(t);
trace!("advancing to {:?} for server", self.time);
}
true
}
Some(_) => unreachable!(),
None => false,
}
}
pub fn drive(&mut self) {
while self.step() {}
}
pub fn drive_client(&mut self) {
self.client.drive(self.time, self.server.addr);
for x in self.client.outbound.drain(..) {
if let Payload::RawEncode(contents) = x.payload {
for content in contents {
if let Some(ref socket) = self.client.socket {
socket.send_to(&content, x.remote).unwrap();
}
if self.server.addr == x.remote {
self.server
.inbound
.push_back((self.time + self.latency, x.ecn, content));
}
}
}
}
}
pub fn drive_server(&mut self) {
self.server.drive(self.time, self.client.addr);
for x in self.server.outbound.drain(..) {
if let Payload::RawEncode(contents) = x.payload {
for content in contents {
if let Some(ref socket) = self.server.socket {
socket.send_to(&content, x.remote).unwrap();
}
if self.client.addr == x.remote {
self.client
.inbound
.push_back((self.time + self.latency, x.ecn, content));
}
}
}
}
}
pub fn connect(&mut self) -> (AssociationHandle, AssociationHandle) {
self.connect_with(client_config())
}
pub fn connect_with(&mut self, config: ClientConfig) -> (AssociationHandle, AssociationHandle) {
info!("connecting");
let client_ch = self.begin_connect(config);
self.drive();
let server_ch = self.server.assert_accept();
self.finish_connect(client_ch, server_ch);
(client_ch, server_ch)
}
pub fn begin_connect(&mut self, config: ClientConfig) -> AssociationHandle {
let (client_ch, client_conn) = self.client.connect(config, self.server.addr).unwrap();
self.client.associations.insert(client_ch, client_conn);
client_ch
}
fn finish_connect(&mut self, client_ch: AssociationHandle, server_ch: AssociationHandle) {
assert_matches!(
self.client_conn_mut(client_ch).poll(),
Some(Event::Connected)
);
assert_matches!(
self.server_conn_mut(server_ch).poll(),
Some(Event::Connected)
);
}
pub fn client_conn_mut(&mut self, ch: AssociationHandle) -> &mut Association {
self.client.associations.get_mut(&ch).unwrap()
}
pub fn client_stream(&mut self, ch: AssociationHandle, si: u16) -> Result<Stream<'_>> {
self.client_conn_mut(ch).stream(si)
}
pub fn server_conn_mut(&mut self, ch: AssociationHandle) -> &mut Association {
self.server.associations.get_mut(&ch).unwrap()
}
pub fn server_stream(&mut self, ch: AssociationHandle, si: u16) -> Result<Stream<'_>> {
self.server_conn_mut(ch).stream(si)
}
}
impl Default for Pair {
fn default() -> Self {
Pair::new(Default::default(), server_config())
}
}
fn create_association_pair(
ack_mode: AckMode,
recv_buf_size: u32,
) -> Result<(Pair, AssociationHandle, AssociationHandle)> {
let mut pair = Pair::new(
Arc::new(EndpointConfig::default()),
ServerConfig {
transport: Arc::new(if recv_buf_size > 0 {
TransportConfig::default().with_max_receive_buffer_size(recv_buf_size)
} else {
TransportConfig::default()
}),
..Default::default()
},
);
let (client_ch, server_ch) = pair.connect_with(ClientConfig {
transport: Arc::new(if recv_buf_size > 0 {
TransportConfig::default().with_max_receive_buffer_size(recv_buf_size)
} else {
TransportConfig::default()
}),
..Default::default()
});
pair.client_conn_mut(client_ch).ack_mode = ack_mode;
pair.server_conn_mut(server_ch).ack_mode = ack_mode;
Ok((pair, client_ch, server_ch))
}
fn establish_session_pair(
pair: &mut Pair,
client_ch: AssociationHandle,
server_ch: AssociationHandle,
si: u16,
) -> Result<()> {
let hello_msg = Bytes::from_static(b"Hello");
let _ = pair
.client_conn_mut(client_ch)
.open_stream(si, PayloadProtocolIdentifier::Binary)?;
let _ = pair
.client_stream(client_ch, si)?
.write_sctp(&hello_msg, PayloadProtocolIdentifier::Dcep)?;
pair.drive();
{
let s1 = pair.server_conn_mut(server_ch).accept_stream().unwrap();
if si != s1.stream_identifier {
return Err(Error::Other("si should match".to_owned()));
}
}
pair.drive();
let mut buf = vec![0u8; 1024];
let chunks = pair.server_stream(server_ch, si)?.read_sctp()?.unwrap();
let n = chunks.read(&mut buf)?;
if n != hello_msg.len() {
return Err(Error::Other("received data must by 3 bytes".to_owned()));
}
if chunks.ppi != PayloadProtocolIdentifier::Dcep {
return Err(Error::Other("unexpected ppi".to_owned()));
}
if buf[..n] != hello_msg {
return Err(Error::Other("received data mismatch".to_owned()));
}
pair.drive();
Ok(())
}
fn close_association_pair(
_pair: &mut Pair,
_client_ch: AssociationHandle,
_server_ch: AssociationHandle,
_si: u16,
) {
}
#[test]
fn test_assoc_reliable_simple() -> Result<()> {
let si: u16 = 1;
let msg: Bytes = Bytes::from_static(b"ABC");
let (mut pair, client_ch, server_ch) = create_association_pair(AckMode::NoDelay, 0)?;
establish_session_pair(&mut pair, client_ch, server_ch, si)?;
{
let a = pair.client_conn_mut(client_ch);
assert_eq!(0, a.buffered_amount(), "incorrect bufferedAmount");
}
let n = pair
.client_stream(client_ch, si)?
.write_sctp(&msg, PayloadProtocolIdentifier::Binary)?;
assert_eq!(msg.len(), n, "unexpected length of received data");
{
let a = pair.client_conn_mut(client_ch);
assert_eq!(msg.len(), a.buffered_amount(), "incorrect bufferedAmount");
}
pair.drive();
let chunks = pair.server_stream(server_ch, si)?.read_sctp()?.unwrap();
let (n, ppi) = (chunks.len(), chunks.ppi);
assert_eq!(n, msg.len(), "unexpected length of received data");
assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi");
{
let q = &pair
.client_conn_mut(client_ch)
.streams
.get(&si)
.unwrap()
.reassembly_queue;
assert!(!q.is_readable(), "should no longer be readable");
}
{
let a = pair.client_conn_mut(client_ch);
assert_eq!(0, a.buffered_amount(), "incorrect bufferedAmount");
}
close_association_pair(&mut pair, client_ch, server_ch, si);
Ok(())
}
#[test]
fn test_assoc_reliable_ordered_reordered() -> Result<()> {
let si: u16 = 2;
let mut sbuf = vec![0u8; 1000];
for (i, b) in sbuf.iter_mut().enumerate() {
*b = (i & 0xff) as u8;
}
let (mut pair, client_ch, server_ch) = create_association_pair(AckMode::NoDelay, 0)?;
establish_session_pair(&mut pair, client_ch, server_ch, si)?;
{
let a = pair.client_conn_mut(client_ch);
assert_eq!(0, a.buffered_amount(), "incorrect bufferedAmount");
}
sbuf[0..4].copy_from_slice(&0u32.to_be_bytes());
let n = pair.client_stream(client_ch, si)?.write_sctp(
&Bytes::from(sbuf.clone()),
PayloadProtocolIdentifier::Binary,
)?;
assert_eq!(sbuf.len(), n, "unexpected length of received data");
pair.client.drive(pair.time, pair.server.addr);
pair.client.delay_outbound();
sbuf[0..4].copy_from_slice(&1u32.to_be_bytes());
let n = pair.client_stream(client_ch, si)?.write_sctp(
&Bytes::from(sbuf.clone()),
PayloadProtocolIdentifier::Binary,
)?;
assert_eq!(sbuf.len(), n, "unexpected length of received data");
pair.client.drive(pair.time, pair.server.addr);
pair.client.finish_delay();
pair.drive();
let mut buf = vec![0u8; 2000];
let chunks = pair.server_stream(server_ch, si)?.read_sctp()?.unwrap();
let (n, ppi) = (chunks.len(), chunks.ppi);
chunks.read(&mut buf)?;
assert_eq!(n, sbuf.len(), "unexpected length of received data");
assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi");
assert_eq!(
u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]),
0,
"unexpected received data"
);
let chunks = pair.server_stream(server_ch, si)?.read_sctp()?.unwrap();
let (n, ppi) = (chunks.len(), chunks.ppi);
chunks.read(&mut buf)?;
assert_eq!(n, sbuf.len(), "unexpected length of received data");
assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi");
assert_eq!(
u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]),
1,
"unexpected received data"
);
pair.drive();
{
let q = &pair
.client_conn_mut(client_ch)
.streams
.get(&si)
.unwrap()
.reassembly_queue;
assert!(!q.is_readable(), "should no longer be readable");
}
close_association_pair(&mut pair, client_ch, server_ch, si);
Ok(())
}
#[test]
fn test_assoc_reliable_ordered_fragmented_then_defragmented() -> Result<()> {
let si: u16 = 3;
let mut sbuf = vec![0u8; 1000];
for (i, b) in sbuf.iter_mut().enumerate() {
*b = (i & 0xff) as u8;
}
let mut sbufl = vec![0u8; 2000];
for (i, b) in sbufl.iter_mut().enumerate() {
*b = (i & 0xff) as u8;
}
let (mut pair, client_ch, server_ch) = create_association_pair(AckMode::NoDelay, 0)?;
establish_session_pair(&mut pair, client_ch, server_ch, si)?;
pair.client_stream(client_ch, si)?.set_reliability_params(
false,
ReliabilityType::Reliable,
0,
)?;
pair.server_stream(server_ch, si)?.set_reliability_params(
false,
ReliabilityType::Reliable,
0,
)?;
let n = pair.client_stream(client_ch, si)?.write_sctp(
&Bytes::from(sbufl.clone()),
PayloadProtocolIdentifier::Binary,
)?;
assert_eq!(sbufl.len(), n, "unexpected length of received data");
pair.drive();
let mut rbuf = vec![0u8; 2000];
let chunks = pair.server_stream(server_ch, si)?.read_sctp()?.unwrap();
let (n, ppi) = (chunks.len(), chunks.ppi);
chunks.read(&mut rbuf)?;
assert_eq!(n, sbufl.len(), "unexpected length of received data");
assert_eq!(&rbuf[..n], &sbufl, "unexpected received data");
assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi");
pair.drive();
{
let q = &pair
.client_conn_mut(client_ch)
.streams
.get(&si)
.unwrap()
.reassembly_queue;
assert!(!q.is_readable(), "should no longer be readable");
}
close_association_pair(&mut pair, client_ch, server_ch, si);
Ok(())
}
#[test]
fn test_assoc_reliable_unordered_fragmented_then_defragmented() -> Result<()> {
let si: u16 = 4;
let mut sbuf = vec![0u8; 1000];
for (i, b) in sbuf.iter_mut().enumerate() {
*b = (i & 0xff) as u8;
}
let sbufl = vec![0u8; 2000];
for (i, b) in sbuf.iter_mut().enumerate() {
*b = (i & 0xff) as u8;
}
let (mut pair, client_ch, server_ch) = create_association_pair(AckMode::NoDelay, 0)?;
establish_session_pair(&mut pair, client_ch, server_ch, si)?;
pair.client_stream(client_ch, si)?.set_reliability_params(
true,
ReliabilityType::Reliable,
0,
)?;
pair.server_stream(server_ch, si)?.set_reliability_params(
true,
ReliabilityType::Reliable,
0,
)?;
let n = pair.client_stream(client_ch, si)?.write_sctp(
&Bytes::from(sbufl.clone()),
PayloadProtocolIdentifier::Binary,
)?;
assert_eq!(sbufl.len(), n, "unexpected length of received data");
pair.drive();
let mut rbuf = vec![0u8; 2000];
let chunks = pair.server_stream(server_ch, si)?.read_sctp()?.unwrap();
let (n, ppi) = (chunks.len(), chunks.ppi);
chunks.read(&mut rbuf)?;
assert_eq!(n, sbufl.len(), "unexpected length of received data");
assert_eq!(&rbuf[..n], &sbufl, "unexpected received data");
assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi");
pair.drive();
{
let q = &pair
.client_conn_mut(client_ch)
.streams
.get(&si)
.unwrap()
.reassembly_queue;
assert!(!q.is_readable(), "should no longer be readable");
}
close_association_pair(&mut pair, client_ch, server_ch, si);
Ok(())
}
#[test]
fn test_assoc_reliable_unordered_ordered() -> Result<()> {
let si: u16 = 5;
let mut sbuf = vec![0u8; 1000];
for (i, b) in sbuf.iter_mut().enumerate() {
*b = (i & 0xff) as u8;
}
let (mut pair, client_ch, server_ch) = create_association_pair(AckMode::NoDelay, 0)?;
establish_session_pair(&mut pair, client_ch, server_ch, si)?;
pair.client_stream(client_ch, si)?.set_reliability_params(
true,
ReliabilityType::Reliable,
0,
)?;
pair.server_stream(server_ch, si)?.set_reliability_params(
true,
ReliabilityType::Reliable,
0,
)?;
sbuf[0..4].copy_from_slice(&0u32.to_be_bytes());
let n = pair.client_stream(client_ch, si)?.write_sctp(
&Bytes::from(sbuf.clone()),
PayloadProtocolIdentifier::Binary,
)?;
assert_eq!(sbuf.len(), n, "unexpected length of received data");
pair.client.drive(pair.time, pair.server.addr);
pair.client.delay_outbound();
sbuf[0..4].copy_from_slice(&1u32.to_be_bytes());
let n = pair.client_stream(client_ch, si)?.write_sctp(
&Bytes::from(sbuf.clone()),
PayloadProtocolIdentifier::Binary,
)?;
assert_eq!(sbuf.len(), n, "unexpected length of received data");
pair.client.drive(pair.time, pair.server.addr);
pair.client.finish_delay();
pair.drive();
let mut buf = vec![0u8; 2000];
let chunks = pair.server_stream(server_ch, si)?.read_sctp()?.unwrap();
let (n, ppi) = (chunks.len(), chunks.ppi);
chunks.read(&mut buf)?;
assert_eq!(n, sbuf.len(), "unexpected length of received data");
assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi");
assert_eq!(
u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]),
1,
"unexpected received data"
);
let chunks = pair.server_stream(server_ch, si)?.read_sctp()?.unwrap();
let (n, ppi) = (chunks.len(), chunks.ppi);
chunks.read(&mut buf)?;
assert_eq!(n, sbuf.len(), "unexpected length of received data");
assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi");
assert_eq!(
u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]),
0,
"unexpected received data"
);
pair.drive();
{
let q = &pair
.client_conn_mut(client_ch)
.streams
.get(&si)
.unwrap()
.reassembly_queue;
assert!(!q.is_readable(), "should no longer be readable");
}
close_association_pair(&mut pair, client_ch, server_ch, si);
Ok(())
}
#[test]
fn test_assoc_reliable_retransmission() -> Result<()> {
let si: u16 = 6;
let msg1: Bytes = Bytes::from_static(b"ABC");
let msg2: Bytes = Bytes::from_static(b"DEFG");
let (mut pair, client_ch, server_ch) = create_association_pair(AckMode::NoDelay, 0)?;
{
let a = pair.client_conn_mut(client_ch);
a.rto_mgr.set_rto(100, true);
}
establish_session_pair(&mut pair, client_ch, server_ch, si)?;
let n = pair
.client_stream(client_ch, si)?
.write_sctp(&msg1, PayloadProtocolIdentifier::Binary)?;
assert_eq!(msg1.len(), n, "unexpected length of received data");
pair.drive_client(); pair.server.inbound.clear(); debug!("dropping packet");
let n = pair
.client_stream(client_ch, si)?
.write_sctp(&msg2, PayloadProtocolIdentifier::Binary)?;
assert_eq!(msg2.len(), n, "unexpected length of received data");
pair.drive();
let mut buf = vec![0u8; 32];
let chunks = pair.server_stream(server_ch, si)?.read_sctp()?.unwrap();
let (n, ppi) = (chunks.len(), chunks.ppi);
chunks.read(&mut buf)?;
assert_eq!(n, msg1.len(), "unexpected length of received data");
assert_eq!(&buf[..n], &msg1, "unexpected length of received data");
assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi");
let chunks = pair.server_stream(server_ch, si)?.read_sctp()?.unwrap();
let (n, ppi) = (chunks.len(), chunks.ppi);
chunks.read(&mut buf)?;
assert_eq!(n, msg2.len(), "unexpected length of received data");
assert_eq!(&buf[..n], &msg2, "unexpected length of received data");
assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi");
pair.drive();
{
let q = &pair
.client_conn_mut(client_ch)
.streams
.get(&si)
.unwrap()
.reassembly_queue;
assert!(!q.is_readable(), "should no longer be readable");
}
close_association_pair(&mut pair, client_ch, server_ch, si);
Ok(())
}
#[test]
fn test_assoc_reliable_short_buffer() -> Result<()> {
let si: u16 = 1;
let msg: Bytes = Bytes::from_static(b"Hello");
let (mut pair, client_ch, server_ch) = create_association_pair(AckMode::NoDelay, 0)?;
establish_session_pair(&mut pair, client_ch, server_ch, si)?;
{
let a = pair.client_conn_mut(client_ch);
assert_eq!(0, a.buffered_amount(), "incorrect bufferedAmount");
}
let n = pair
.client_stream(client_ch, si)?
.write_sctp(&msg, PayloadProtocolIdentifier::Binary)?;
assert_eq!(msg.len(), n, "unexpected length of received data");
{
let a = pair.client_conn_mut(client_ch);
assert_eq!(msg.len(), a.buffered_amount(), "incorrect bufferedAmount");
}
pair.drive();
let mut buf = vec![0u8; 3];
let chunks = pair.server_stream(server_ch, si)?.read_sctp()?.unwrap();
let result = chunks.read(&mut buf);
assert!(result.is_err(), "expected error to be io.ErrShortBuffer");
if let Err(err) = result {
assert_eq!(
Error::ErrShortBuffer,
err,
"expected error to be io.ErrShortBuffer"
);
}
{
let q = &pair
.client_conn_mut(client_ch)
.streams
.get(&si)
.unwrap()
.reassembly_queue;
assert!(!q.is_readable(), "should no longer be readable");
}
{
let a = pair.client_conn_mut(client_ch);
assert_eq!(0, a.buffered_amount(), "incorrect bufferedAmount");
}
close_association_pair(&mut pair, client_ch, server_ch, si);
Ok(())
}
#[test]
fn test_assoc_unreliable_rexmit_ordered_no_fragment() -> Result<()> {
let si: u16 = 1;
let mut sbuf = vec![0u8; 1000];
for (i, b) in sbuf.iter_mut().enumerate() {
*b = (i & 0xff) as u8;
}
let (mut pair, client_ch, server_ch) = create_association_pair(AckMode::NoDelay, 0)?;
establish_session_pair(&mut pair, client_ch, server_ch, si)?;
pair.client_stream(client_ch, si)?
.set_reliability_params(false, ReliabilityType::Rexmit, 0)?;
pair.server_stream(server_ch, si)?
.set_reliability_params(false, ReliabilityType::Rexmit, 0)?;
sbuf[0..4].copy_from_slice(&0u32.to_be_bytes());
let n = pair.client_stream(client_ch, si)?.write_sctp(
&Bytes::from(sbuf.clone()),
PayloadProtocolIdentifier::Binary,
)?;
assert_eq!(sbuf.len(), n, "unexpected length of received data");
pair.drive_client(); pair.server.inbound.clear(); debug!("dropping packet");
sbuf[0..4].copy_from_slice(&1u32.to_be_bytes());
let n = pair.client_stream(client_ch, si)?.write_sctp(
&Bytes::from(sbuf.clone()),
PayloadProtocolIdentifier::Binary,
)?;
assert_eq!(sbuf.len(), n, "unexpected length of received data");
debug!("flush_buffers");
pair.drive();
let mut buf = vec![0u8; 2000];
debug!("read_sctp");
let chunks = pair.server_stream(server_ch, si)?.read_sctp()?.unwrap();
let (n, ppi) = (chunks.len(), chunks.ppi);
chunks.read(&mut buf)?;
assert_eq!(n, sbuf.len(), "unexpected length of received data");
assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi");
assert_eq!(
u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]),
1,
"unexpected received data"
);
debug!("process");
pair.drive();
{
let q = &pair
.client_conn_mut(client_ch)
.streams
.get(&si)
.unwrap()
.reassembly_queue;
assert!(!q.is_readable(), "should no longer be readable");
}
close_association_pair(&mut pair, client_ch, server_ch, si);
Ok(())
}
#[test]
fn test_assoc_unreliable_rexmit_ordered_fragment() -> Result<()> {
let si: u16 = 1;
let mut sbuf = vec![0u8; 2000];
for (i, b) in sbuf.iter_mut().enumerate() {
*b = (i & 0xff) as u8;
}
let (mut pair, client_ch, server_ch) = create_association_pair(AckMode::NoDelay, 0)?;
establish_session_pair(&mut pair, client_ch, server_ch, si)?;
{
let a = pair.client_conn_mut(client_ch);
a.rto_mgr.set_rto(100, true);
}
pair.client_stream(client_ch, si)?
.set_reliability_params(false, ReliabilityType::Rexmit, 0)?;
pair.server_stream(server_ch, si)?
.set_reliability_params(false, ReliabilityType::Rexmit, 0)?;
sbuf[0..4].copy_from_slice(&0u32.to_be_bytes());
let n = pair.client_stream(client_ch, si)?.write_sctp(
&Bytes::from(sbuf.clone()),
PayloadProtocolIdentifier::Binary,
)?;
assert_eq!(sbuf.len(), n, "unexpected length of received data");
pair.drive_client(); pair.server.inbound.clear();
sbuf[0..4].copy_from_slice(&1u32.to_be_bytes());
let n = pair.client_stream(client_ch, si)?.write_sctp(
&Bytes::from(sbuf.clone()),
PayloadProtocolIdentifier::Binary,
)?;
assert_eq!(sbuf.len(), n, "unexpected length of received data");
pair.drive();
let mut buf = vec![0u8; 2000];
let chunks = pair.server_stream(server_ch, si)?.read_sctp()?.unwrap();
let (n, ppi) = (chunks.len(), chunks.ppi);
chunks.read(&mut buf)?;
assert_eq!(n, sbuf.len(), "unexpected length of received data");
assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi");
assert_eq!(
u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]),
1,
"unexpected received data"
);
pair.drive();
{
let q = &pair
.client_conn_mut(client_ch)
.streams
.get(&si)
.unwrap()
.reassembly_queue;
assert!(!q.is_readable(), "should no longer be readable");
}
close_association_pair(&mut pair, client_ch, server_ch, si);
Ok(())
}
#[test]
fn test_assoc_unreliable_rexmit_unordered_no_fragment() -> Result<()> {
let si: u16 = 2;
let mut sbuf = vec![0u8; 1000];
for (i, b) in sbuf.iter_mut().enumerate() {
*b = (i & 0xff) as u8;
}
let (mut pair, client_ch, server_ch) = create_association_pair(AckMode::NoDelay, 0)?;
establish_session_pair(&mut pair, client_ch, server_ch, si)?;
pair.client_stream(client_ch, si)?
.set_reliability_params(true, ReliabilityType::Rexmit, 0)?;
pair.server_stream(server_ch, si)?
.set_reliability_params(true, ReliabilityType::Rexmit, 0)?;
sbuf[0..4].copy_from_slice(&0u32.to_be_bytes());
let n = pair.client_stream(client_ch, si)?.write_sctp(
&Bytes::from(sbuf.clone()),
PayloadProtocolIdentifier::Binary,
)?;
assert_eq!(sbuf.len(), n, "unexpected length of received data");
pair.drive_client(); pair.server.inbound.clear();
sbuf[0..4].copy_from_slice(&1u32.to_be_bytes());
let n = pair.client_stream(client_ch, si)?.write_sctp(
&Bytes::from(sbuf.clone()),
PayloadProtocolIdentifier::Binary,
)?;
assert_eq!(sbuf.len(), n, "unexpected length of received data");
pair.drive();
let mut buf = vec![0u8; 2000];
let chunks = pair.server_stream(server_ch, si)?.read_sctp()?.unwrap();
let (n, ppi) = (chunks.len(), chunks.ppi);
chunks.read(&mut buf)?;
assert_eq!(n, sbuf.len(), "unexpected length of received data");
assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi");
assert_eq!(
u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]),
1,
"unexpected received data"
);
pair.drive();
{
let q = &pair
.client_conn_mut(client_ch)
.streams
.get(&si)
.unwrap()
.reassembly_queue;
assert!(!q.is_readable(), "should no longer be readable");
}
close_association_pair(&mut pair, client_ch, server_ch, si);
Ok(())
}
#[test]
fn test_assoc_unreliable_rexmit_unordered_fragment() -> Result<()> {
let si: u16 = 1;
let mut sbuf = vec![0u8; 2000];
for (i, b) in sbuf.iter_mut().enumerate() {
*b = (i & 0xff) as u8;
}
let (mut pair, client_ch, server_ch) = create_association_pair(AckMode::NoDelay, 0)?;
establish_session_pair(&mut pair, client_ch, server_ch, si)?;
pair.client_stream(client_ch, si)?
.set_reliability_params(true, ReliabilityType::Rexmit, 0)?;
pair.server_stream(server_ch, si)?
.set_reliability_params(true, ReliabilityType::Rexmit, 0)?;
sbuf[0..4].copy_from_slice(&0u32.to_be_bytes());
let n = pair.client_stream(client_ch, si)?.write_sctp(
&Bytes::from(sbuf.clone()),
PayloadProtocolIdentifier::Binary,
)?;
assert_eq!(sbuf.len(), n, "unexpected length of received data");
pair.client.drive(pair.time, pair.server.addr);
pair.client.outbound.clear();
sbuf[0..4].copy_from_slice(&1u32.to_be_bytes());
let n = pair.client_stream(client_ch, si)?.write_sctp(
&Bytes::from(sbuf.clone()),
PayloadProtocolIdentifier::Binary,
)?;
assert_eq!(sbuf.len(), n, "unexpected length of received data");
pair.drive();
let mut buf = vec![0u8; 2000];
let chunks = pair.server_stream(server_ch, si)?.read_sctp()?.unwrap();
let (n, ppi) = (chunks.len(), chunks.ppi);
chunks.read(&mut buf)?;
assert_eq!(n, sbuf.len(), "unexpected length of received data");
assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi");
assert_eq!(
u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]),
1,
"unexpected received data"
);
pair.drive();
{
let q = &pair
.client_conn_mut(client_ch)
.streams
.get(&si)
.unwrap()
.reassembly_queue;
assert!(!q.is_readable(), "should no longer be readable");
assert_eq!(
0,
q.unordered.len(),
"should be nothing in the unordered queue"
);
assert_eq!(
0,
q.unordered_chunks.len(),
"should be nothing in the unorderedChunks list"
);
}
close_association_pair(&mut pair, client_ch, server_ch, si);
Ok(())
}
#[test]
fn test_assoc_unreliable_rexmit_timed_ordered() -> Result<()> {
let si: u16 = 3;
let mut sbuf = vec![0u8; 1000];
for (i, b) in sbuf.iter_mut().enumerate() {
*b = (i & 0xff) as u8;
}
let (mut pair, client_ch, server_ch) = create_association_pair(AckMode::NoDelay, 0)?;
establish_session_pair(&mut pair, client_ch, server_ch, si)?;
pair.client_stream(client_ch, si)?
.set_reliability_params(false, ReliabilityType::Timed, 0)?;
pair.server_stream(server_ch, si)?
.set_reliability_params(false, ReliabilityType::Timed, 0)?;
sbuf[0..4].copy_from_slice(&0u32.to_be_bytes());
let n = pair.client_stream(client_ch, si)?.write_sctp(
&Bytes::from(sbuf.clone()),
PayloadProtocolIdentifier::Binary,
)?;
assert_eq!(sbuf.len(), n, "unexpected length of received data");
pair.client.drive(pair.time, pair.server.addr);
pair.client.outbound.clear();
sbuf[0..4].copy_from_slice(&1u32.to_be_bytes());
let n = pair.client_stream(client_ch, si)?.write_sctp(
&Bytes::from(sbuf.clone()),
PayloadProtocolIdentifier::Binary,
)?;
assert_eq!(sbuf.len(), n, "unexpected length of received data");
pair.drive();
let mut buf = vec![0u8; 2000];
let chunks = pair.server_stream(server_ch, si)?.read_sctp()?.unwrap();
let (n, ppi) = (chunks.len(), chunks.ppi);
chunks.read(&mut buf)?;
assert_eq!(n, sbuf.len(), "unexpected length of received data");
assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi");
assert_eq!(
u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]),
1,
"unexpected received data"
);
pair.drive();
{
let q = &pair
.client_conn_mut(client_ch)
.streams
.get(&si)
.unwrap()
.reassembly_queue;
assert!(!q.is_readable(), "should no longer be readable");
}
close_association_pair(&mut pair, client_ch, server_ch, si);
Ok(())
}
#[test]
fn test_assoc_unreliable_rexmit_timed_unordered() -> Result<()> {
let si: u16 = 3;
let mut sbuf = vec![0u8; 1000];
for (i, b) in sbuf.iter_mut().enumerate() {
*b = (i & 0xff) as u8;
}
let (mut pair, client_ch, server_ch) = create_association_pair(AckMode::NoDelay, 0)?;
establish_session_pair(&mut pair, client_ch, server_ch, si)?;
pair.client_stream(client_ch, si)?
.set_reliability_params(true, ReliabilityType::Timed, 0)?;
pair.server_stream(server_ch, si)?
.set_reliability_params(true, ReliabilityType::Timed, 0)?;
sbuf[0..4].copy_from_slice(&0u32.to_be_bytes());
let n = pair.client_stream(client_ch, si)?.write_sctp(
&Bytes::from(sbuf.clone()),
PayloadProtocolIdentifier::Binary,
)?;
assert_eq!(sbuf.len(), n, "unexpected length of received data");
pair.client.drive(pair.time, pair.server.addr);
pair.client.outbound.clear();
sbuf[0..4].copy_from_slice(&1u32.to_be_bytes());
let n = pair.client_stream(client_ch, si)?.write_sctp(
&Bytes::from(sbuf.clone()),
PayloadProtocolIdentifier::Binary,
)?;
assert_eq!(sbuf.len(), n, "unexpected length of received data");
pair.drive();
let mut buf = vec![0u8; 2000];
let chunks = pair.server_stream(server_ch, si)?.read_sctp()?.unwrap();
let (n, ppi) = (chunks.len(), chunks.ppi);
chunks.read(&mut buf)?;
assert_eq!(n, sbuf.len(), "unexpected length of received data");
assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi");
assert_eq!(
u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]),
1,
"unexpected received data"
);
pair.drive();
{
let q = &pair
.client_conn_mut(client_ch)
.streams
.get(&si)
.unwrap()
.reassembly_queue;
assert!(!q.is_readable(), "should no longer be readable");
assert_eq!(
0,
q.unordered.len(),
"should be nothing in the unordered queue"
);
assert_eq!(
0,
q.unordered_chunks.len(),
"should be nothing in the unorderedChunks list"
);
}
close_association_pair(&mut pair, client_ch, server_ch, si);
Ok(())
}
#[test]
fn test_assoc_congestion_control_congestion_avoidance() -> Result<()> {
let max_receive_buffer_size: u32 = 64 * 1024;
let si: u16 = 6;
let n_packets_to_send: u32 = 2000;
let mut sbuf = vec![0u8; 1000];
for (i, b) in sbuf.iter_mut().enumerate() {
*b = (i & 0xff) as u8;
}
let (mut pair, client_ch, server_ch) =
create_association_pair(AckMode::Normal, max_receive_buffer_size)?;
establish_session_pair(&mut pair, client_ch, server_ch, si)?;
{
pair.client_conn_mut(client_ch).stats.reset();
pair.server_conn_mut(server_ch).stats.reset();
}
for i in 0..n_packets_to_send {
sbuf[0..4].copy_from_slice(&i.to_be_bytes());
let n = pair.client_stream(client_ch, si)?.write_sctp(
&Bytes::from(sbuf.clone()),
PayloadProtocolIdentifier::Binary,
)?;
assert_eq!(sbuf.len(), n, "unexpected length of received data");
}
pair.drive_client();
let mut rbuf = vec![0u8; 3000];
let mut n_packets_received = 0u32;
while pair.client_conn_mut(client_ch).buffered_amount() > 0
&& n_packets_received < n_packets_to_send
{
pair.step();
while let Some(chunks) = pair.server_stream(server_ch, si)?.read_sctp()? {
let (n, ppi) = (chunks.len(), chunks.ppi);
chunks.read(&mut rbuf)?;
assert_eq!(sbuf.len(), n, "unexpected length of received data");
assert_eq!(
n_packets_received,
u32::from_be_bytes([rbuf[0], rbuf[1], rbuf[2], rbuf[3]]),
"unexpected length of received data"
);
assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi");
n_packets_received += 1;
}
}
pair.drive();
assert_eq!(
n_packets_received, n_packets_to_send,
"unexpected num of packets received"
);
{
let a = pair.client_conn_mut(client_ch);
assert!(!a.in_fast_recovery, "should not be in fast-recovery");
assert!(
a.cwnd > a.ssthresh,
"should be in congestion avoidance mode"
);
assert!(
a.ssthresh >= max_receive_buffer_size,
"{} should not be less than the initial size of 128KB {}",
a.ssthresh,
max_receive_buffer_size
);
debug!("nSACKs : {}", a.stats.get_num_sacks());
debug!("nT3Timeouts: {}", a.stats.get_num_t3timeouts());
assert!(
a.stats.get_num_sacks() <= n_packets_to_send as u64 / 2,
"too many sacks"
);
assert_eq!(0, a.stats.get_num_t3timeouts(), "should be no retransmit");
}
{
assert_eq!(
0,
pair.server_conn_mut(server_ch)
.streams
.get(&si)
.unwrap()
.get_num_bytes_in_reassembly_queue(),
"reassembly queue should be empty"
);
let b = pair.server_conn_mut(server_ch);
debug!("nDATAs : {}", b.stats.get_num_datas());
assert_eq!(
n_packets_to_send as u64,
b.stats.get_num_datas(),
"packet count mismatch"
);
}
close_association_pair(&mut pair, client_ch, server_ch, si);
Ok(())
}
#[test]
fn test_assoc_congestion_control_slow_reader() -> Result<()> {
let max_receive_buffer_size: u32 = 64 * 1024;
let si: u16 = 6;
let n_packets_to_send: u32 = 130;
let mut sbuf = vec![0u8; 1000];
for (i, b) in sbuf.iter_mut().enumerate() {
*b = (i & 0xff) as u8;
}
let (mut pair, client_ch, server_ch) =
create_association_pair(AckMode::Normal, max_receive_buffer_size)?;
establish_session_pair(&mut pair, client_ch, server_ch, si)?;
for i in 0..n_packets_to_send {
sbuf[0..4].copy_from_slice(&i.to_be_bytes());
let n = pair.client_stream(client_ch, si)?.write_sctp(
&Bytes::from(sbuf.clone()),
PayloadProtocolIdentifier::Binary,
)?;
assert_eq!(sbuf.len(), n, "unexpected length of received data");
}
pair.drive_client();
let mut rbuf = vec![0u8; 3000];
let mut n_packets_received = 0u32;
let mut has_rtoed = false;
while pair.client_conn_mut(client_ch).buffered_amount() > 0
&& n_packets_received < n_packets_to_send
{
if !has_rtoed {
let rwnd = pair
.server_conn_mut(server_ch)
.get_my_receiver_window_credit();
let cwnd = pair.client_conn_mut(client_ch).cwnd;
let cmtu = pair.client_conn_mut(client_ch).mtu;
if cwnd > cmtu || rwnd > 0 {
pair.step();
continue;
}
has_rtoed = true;
}
while let Some(chunks) = pair.server_stream(server_ch, si)?.read_sctp()? {
let (n, ppi) = (chunks.len(), chunks.ppi);
chunks.read(&mut rbuf)?;
assert_eq!(sbuf.len(), n, "unexpected length of received data");
assert_eq!(
n_packets_received,
u32::from_be_bytes([rbuf[0], rbuf[1], rbuf[2], rbuf[3]]),
"unexpected length of received data"
);
assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi");
n_packets_received += 1;
}
pair.step();
}
pair.drive();
assert_eq!(
n_packets_received, n_packets_to_send,
"unexpected num of packets received"
);
assert_eq!(
0,
pair.server_conn_mut(server_ch)
.streams
.get(&si)
.unwrap()
.get_num_bytes_in_reassembly_queue(),
"reassembly queue should be empty"
);
{
let a = pair.client_conn_mut(client_ch);
debug!("nSACKs : {}", a.stats.get_num_sacks());
}
{
let b = pair.server_conn_mut(server_ch);
debug!("nDATAs : {}", b.stats.get_num_datas());
debug!("nAckTimeouts: {}", b.stats.get_num_ack_timeouts());
}
close_association_pair(&mut pair, client_ch, server_ch, si);
Ok(())
}
#[test]
fn test_assoc_delayed_ack() -> Result<()> {
let si: u16 = 6;
let mut sbuf = vec![0u8; 1000];
let mut rbuf = vec![0u8; 1500];
for (i, b) in sbuf.iter_mut().enumerate() {
*b = (i & 0xff) as u8;
}
let (mut pair, client_ch, server_ch) = create_association_pair(AckMode::AlwaysDelay, 0)?;
establish_session_pair(&mut pair, client_ch, server_ch, si)?;
{
pair.client_conn_mut(client_ch).stats.reset();
pair.server_conn_mut(server_ch).stats.reset();
}
let n = pair.client_stream(client_ch, si)?.write_sctp(
&Bytes::from(sbuf.clone()),
PayloadProtocolIdentifier::Binary,
)?;
assert_eq!(sbuf.len(), n, "unexpected length of received data");
pair.drive_client();
let since = pair.time;
let mut n_packets_received = 0;
while pair.client_conn_mut(client_ch).buffered_amount() > 0 {
pair.step();
while let Some(chunks) = pair.server_stream(server_ch, si)?.read_sctp()? {
let (n, ppi) = (chunks.len(), chunks.ppi);
chunks.read(&mut rbuf)?;
assert_eq!(sbuf.len(), n, "unexpected length of received data");
assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi");
n_packets_received += 1;
}
}
let delay = (pair.time.duration_since(since).as_millis() as f64) / 1000.0;
debug!("received in {} seconds", delay);
assert!(delay >= 0.2, "should be >= 200msec");
pair.drive();
assert_eq!(n_packets_received, 1, "unexpected num of packets received");
assert_eq!(
0,
pair.server_conn_mut(server_ch)
.streams
.get(&si)
.unwrap()
.get_num_bytes_in_reassembly_queue(),
"reassembly queue should be empty"
);
let a_num_sacks = {
let a = pair.client_conn_mut(client_ch);
debug!("nSACKs : {}", a.stats.get_num_sacks());
assert_eq!(0, a.stats.get_num_t3timeouts(), "should be no retransmit");
a.stats.get_num_sacks()
};
{
let b = pair.server_conn_mut(server_ch);
debug!("nDATAs : {}", b.stats.get_num_datas());
debug!("nAckTimeouts: {}", b.stats.get_num_ack_timeouts());
assert_eq!(1, b.stats.get_num_datas(), "DATA chunk count mismatch");
assert_eq!(
a_num_sacks,
b.stats.get_num_datas(),
"sack count should be equal to the number of data chunks"
);
assert_eq!(
1,
b.stats.get_num_ack_timeouts(),
"ackTimeout count mismatch"
);
}
close_association_pair(&mut pair, client_ch, server_ch, si);
Ok(())
}
#[test]
fn test_assoc_reset_close_one_way() -> Result<()> {
let si: u16 = 1;
let msg: Bytes = Bytes::from_static(b"ABC");
let (mut pair, client_ch, server_ch) = create_association_pair(AckMode::NoDelay, 0)?;
establish_session_pair(&mut pair, client_ch, server_ch, si)?;
{
let a = pair.client_conn_mut(client_ch);
assert_eq!(0, a.buffered_amount(), "incorrect bufferedAmount");
}
let n = pair
.client_stream(client_ch, si)?
.write_sctp(&msg, PayloadProtocolIdentifier::Binary)?;
assert_eq!(msg.len(), n, "unexpected length of received data");
{
let a = pair.client_conn_mut(client_ch);
assert_eq!(msg.len(), a.buffered_amount(), "incorrect bufferedAmount");
}
pair.step();
let mut buf = vec![0u8; 32];
while pair.server_stream(server_ch, si).is_ok() {
debug!("s1.read_sctp begin");
match pair.server_stream(server_ch, si)?.read_sctp() {
Ok(chunks_opt) => {
if let Some(chunks) = chunks_opt {
let (n, ppi) = (chunks.len(), chunks.ppi);
chunks.read(&mut buf)?;
debug!("s1.read_sctp done with {:?}", &buf[..n]);
assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi");
assert_eq!(n, msg.len(), "unexpected length of received data");
}
debug!("s0.close");
pair.client_stream(client_ch, si)?.stop()?;
pair.step();
}
Err(err) => {
debug!("s1.read_sctp err {:?}", err);
break;
}
}
}
pair.drive();
close_association_pair(&mut pair, client_ch, server_ch, si);
Ok(())
}
#[test]
fn test_assoc_reset_close_both_ways() -> Result<()> {
let si: u16 = 1;
let msg: Bytes = Bytes::from_static(b"ABC");
let (mut pair, client_ch, server_ch) = create_association_pair(AckMode::NoDelay, 0)?;
establish_session_pair(&mut pair, client_ch, server_ch, si)?;
{
let a = pair.client_conn_mut(client_ch);
assert_eq!(0, a.buffered_amount(), "incorrect bufferedAmount");
}
let n = pair
.client_stream(client_ch, si)?
.write_sctp(&msg, PayloadProtocolIdentifier::Binary)?;
assert_eq!(msg.len(), n, "unexpected length of received data");
{
let a = pair.client_conn_mut(client_ch);
assert_eq!(msg.len(), a.buffered_amount(), "incorrect bufferedAmount");
}
pair.step();
let mut buf = vec![0u8; 32];
while pair.server_stream(server_ch, si).is_ok() || pair.client_stream(client_ch, si).is_ok() {
if pair.server_stream(server_ch, si).is_ok() {
debug!("s1.read_sctp begin");
match pair.server_stream(server_ch, si)?.read_sctp() {
Ok(chunks_opt) => {
if let Some(chunks) = chunks_opt {
let (n, ppi) = (chunks.len(), chunks.ppi);
chunks.read(&mut buf)?;
debug!("s1.read_sctp done with {:?}", &buf[..n]);
assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi");
assert_eq!(n, msg.len(), "unexpected length of received data");
}
}
Err(err) => {
debug!("s1.read_sctp err {:?}", err);
break;
}
}
}
if pair.client_stream(client_ch, si).is_ok() {
debug!("s0.read_sctp begin");
match pair.client_stream(client_ch, si)?.read_sctp() {
Ok(chunks_opt) => {
if let Some(chunks) = chunks_opt {
let (n, ppi) = (chunks.len(), chunks.ppi);
chunks.read(&mut buf)?;
debug!("s0.read_sctp done with {:?}", &buf[..n]);
assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi");
assert_eq!(n, msg.len(), "unexpected length of received data");
}
}
Err(err) => {
debug!("s0.read_sctp err {:?}", err);
break;
}
}
}
if pair.client_stream(client_ch, si).is_ok() {
pair.client_stream(client_ch, si)?.stop()?; }
if pair.server_stream(server_ch, si).is_ok() {
pair.server_stream(server_ch, si)?.stop()?; }
pair.step();
}
pair.drive();
close_association_pair(&mut pair, client_ch, server_ch, si);
Ok(())
}
#[test]
fn test_assoc_abort() -> Result<()> {
let si: u16 = 1;
let (mut pair, client_ch, server_ch) = create_association_pair(AckMode::NoDelay, 0)?;
establish_session_pair(&mut pair, client_ch, server_ch, si)?;
let transmit = {
let abort = ChunkAbort {
error_causes: vec![ErrorCauseProtocolViolation {
code: PROTOCOL_VIOLATION,
..Default::default()
}],
};
let packet = pair
.client_conn_mut(client_ch)
.create_packet(vec![Box::new(abort)])
.marshal()?;
Transmit {
now: pair.time,
remote: pair.server.addr,
ecn: None,
local_ip: None,
payload: Payload::RawEncode(vec![packet]),
}
};
assert_eq!(
AssociationState::Established,
pair.client_conn_mut(client_ch).state()
);
assert_eq!(
AssociationState::Established,
pair.server_conn_mut(server_ch).state()
);
debug!("send ChunkAbort");
pair.client.outbound.push_back(transmit);
pair.drive();
assert_eq!(
AssociationState::Established,
pair.client_conn_mut(client_ch).state()
);
assert_eq!(
AssociationState::Closed,
pair.server_conn_mut(server_ch).state()
);
close_association_pair(&mut pair, client_ch, server_ch, si);
Ok(())
}
#[test]
fn test_association_handle_packet_before_init() -> Result<()> {
let tests = vec![
(
"InitAck",
Packet {
common_header: CommonHeader {
source_port: 1,
destination_port: 1,
verification_tag: 0,
},
chunks: vec![Box::new(ChunkInit {
is_ack: true,
initiate_tag: 1,
num_inbound_streams: 1,
num_outbound_streams: 1,
advertised_receiver_window_credit: 1500,
..Default::default()
})],
},
),
(
"Abort",
Packet {
common_header: CommonHeader {
source_port: 1,
destination_port: 1,
verification_tag: 0,
},
chunks: vec![Box::<ChunkAbort>::default()],
},
),
(
"CoockeEcho",
Packet {
common_header: CommonHeader {
source_port: 1,
destination_port: 1,
verification_tag: 0,
},
chunks: vec![Box::<ChunkCookieEcho>::default()],
},
),
(
"HeartBeat",
Packet {
common_header: CommonHeader {
source_port: 1,
destination_port: 1,
verification_tag: 0,
},
chunks: vec![Box::<ChunkHeartbeat>::default()],
},
),
(
"PayloadData",
Packet {
common_header: CommonHeader {
source_port: 1,
destination_port: 1,
verification_tag: 0,
},
chunks: vec![Box::<ChunkPayloadData>::default()],
},
),
(
"Sack",
Packet {
common_header: CommonHeader {
source_port: 1,
destination_port: 1,
verification_tag: 0,
},
chunks: vec![Box::new(ChunkSelectiveAck {
cumulative_tsn_ack: 1000,
advertised_receiver_window_credit: 1500,
gap_ack_blocks: vec![GapAckBlock {
start: 100,
end: 200,
}],
..Default::default()
})],
},
),
(
"Reconfig",
Packet {
common_header: CommonHeader {
source_port: 1,
destination_port: 1,
verification_tag: 0,
},
chunks: vec![Box::new(ChunkReconfig {
param_a: Some(Box::<ParamOutgoingResetRequest>::default()),
param_b: Some(Box::<ParamReconfigResponse>::default()),
})],
},
),
(
"ForwardTSN",
Packet {
common_header: CommonHeader {
source_port: 1,
destination_port: 1,
verification_tag: 0,
},
chunks: vec![Box::new(ChunkForwardTsn {
new_cumulative_tsn: 100,
..Default::default()
})],
},
),
(
"Error",
Packet {
common_header: CommonHeader {
source_port: 1,
destination_port: 1,
verification_tag: 0,
},
chunks: vec![Box::<ChunkError>::default()],
},
),
(
"Shutdown",
Packet {
common_header: CommonHeader {
source_port: 1,
destination_port: 1,
verification_tag: 0,
},
chunks: vec![Box::<ChunkShutdown>::default()],
},
),
(
"ShutdownAck",
Packet {
common_header: CommonHeader {
source_port: 1,
destination_port: 1,
verification_tag: 0,
},
chunks: vec![Box::new(ChunkShutdownAck)],
},
),
(
"ShutdownComplete",
Packet {
common_header: CommonHeader {
source_port: 1,
destination_port: 1,
verification_tag: 0,
},
chunks: vec![Box::new(ChunkShutdownComplete)],
},
),
];
let remote = SocketAddr::from_str("0.0.0.0:0").unwrap();
for (name, packet) in tests {
debug!("testing {}", name);
let config = Arc::new(TransportConfig::default());
let mut a = Association::new(None, config, 1400, 0, remote, None, Instant::now());
let packet = packet.marshal()?;
a.handle_event(AssociationEvent(AssociationEventInner::Datagram(
Transmit {
now: Instant::now(),
remote,
ecn: None,
local_ip: None,
payload: Payload::RawEncode(vec![packet]),
},
)));
a.close()?;
}
Ok(())
}
#[test]
fn test_old_rtx_on_regular_acks() -> Result<()> {
let si: u16 = 6;
let mut sbuf = vec![0u8; 1000];
for (i, b) in sbuf.iter_mut().enumerate() {
*b = (i & 0xff) as u8;
}
let (mut pair, client_ch, server_ch) = create_association_pair(AckMode::Normal, 0)?;
pair.latency = Duration::from_millis(500);
establish_session_pair(&mut pair, client_ch, server_ch, si)?;
for i in 0..20u32 {
println!("sending packet {}", i);
sbuf[0..4].copy_from_slice(&i.to_be_bytes());
let n = pair.client_stream(client_ch, si)?.write_sctp(
&Bytes::from(sbuf.clone()),
PayloadProtocolIdentifier::Binary,
)?;
assert_eq!(sbuf.len(), n, "unexpected length of received data");
pair.client.drive(pair.time, pair.server.addr);
if (5..10).contains(&i) {
pair.client.outbound.clear();
}
pair.drive_client();
pair.drive_server();
pair.time += Duration::from_millis(500);
}
pair.drive_client();
pair.drive_server();
let mut buf = vec![0u8; 3000];
for i in 0..20 {
{
let q = &pair
.server_conn_mut(server_ch)
.streams
.get(&si)
.unwrap()
.reassembly_queue;
println!("q.is_readable()={}", q.is_readable());
assert!(q.is_readable(), "should be readable at {}", i);
}
let chunks = pair.server_stream(server_ch, si)?.read_sctp()?.unwrap();
let (n, ppi) = (chunks.len(), chunks.ppi);
chunks.read(&mut buf)?;
assert_eq!(n, sbuf.len(), "unexpected length of received data");
assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi");
assert_eq!(
u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]),
i,
"unexpected received data"
);
}
close_association_pair(&mut pair, client_ch, server_ch, si);
Ok(())
}
#[test]
fn test_assoc_reset_duplicate_reconfig_request() -> Result<()> {
let si: u16 = 1;
let (mut pair, client_ch, server_ch) = create_association_pair(AckMode::NoDelay, 0)?;
establish_session_pair(&mut pair, client_ch, server_ch, si)?;
pair.client_stream(client_ch, si)?.stop()?;
pair.drive_client();
let captured_packets: Vec<_> = pair.server.inbound.iter().cloned().collect();
pair.drive();
assert!(
pair.server_stream(server_ch, si).is_err(),
"stream 1 should be removed after reset"
);
let _ = pair
.server_conn_mut(server_ch)
.open_stream(si, PayloadProtocolIdentifier::Binary)?;
assert!(
pair.server_stream(server_ch, si).is_ok(),
"new stream 1 should exist"
);
for packet in captured_packets {
pair.server.inbound.push_back(packet);
}
pair.drive();
assert!(
pair.server_stream(server_ch, si).is_ok(),
"new stream 1 should NOT be destroyed by duplicate reconfig request"
);
Ok(())
}
#[test]
fn test_assoc_reset_inprogress_completed_via_tsn_advance_then_retransmit() -> Result<()> {
let si: u16 = 1;
let (mut pair, client_ch, server_ch) = create_association_pair(AckMode::NoDelay, 0)?;
establish_session_pair(&mut pair, client_ch, server_ch, si)?;
let _ = pair.client_stream(client_ch, si)?.write_sctp(
&Bytes::from_static(b"payload"),
PayloadProtocolIdentifier::Binary,
)?;
pair.drive_client();
let withheld: Vec<_> = pair.server.inbound.drain(..).collect();
pair.client_stream(client_ch, si)?.stop()?;
pair.drive_client();
let (reconfig_packets, data_packets): (Vec<_>, Vec<_>) =
pair.server.inbound.drain(..).partition(|(_, _, raw)| {
let mut offset = 12usize;
while offset + 4 <= raw.len() {
if raw[offset] == 130 {
return true;
}
let chunk_len = u16::from_be_bytes([raw[offset + 2], raw[offset + 3]]) as usize;
if chunk_len < 4 {
break;
}
offset += (chunk_len + 3) & !3;
}
false
});
assert!(!reconfig_packets.is_empty(), "expected a RECONFIG packet");
for pkt in &reconfig_packets {
pair.server.inbound.push_back(pkt.clone());
}
pair.drive_server();
assert!(
pair.server_stream(server_ch, si).is_ok(),
"stream should survive InProgress reset"
);
for pkt in withheld {
pair.server.inbound.push_back(pkt);
}
for pkt in data_packets {
pair.server.inbound.push_back(pkt);
}
pair.drive();
assert!(
pair.server_stream(server_ch, si).is_err(),
"stream should be reset after TSN advance completes the InProgress request"
);
let _ = pair
.server_conn_mut(server_ch)
.open_stream(si, PayloadProtocolIdentifier::Binary)?;
assert!(
pair.server_stream(server_ch, si).is_ok(),
"new stream 1 should exist"
);
for pkt in reconfig_packets {
pair.server.inbound.push_back(pkt);
}
pair.drive();
assert!(
pair.server_stream(server_ch, si).is_ok(),
"new stream 1 should NOT be destroyed by retransmitted reconfig \
after InProgress completion via TSN advance"
);
Ok(())
}
#[test]
fn test_assoc_reset_inprogress_reconfig_retransmission() -> Result<()> {
let si: u16 = 1;
let (mut pair, client_ch, server_ch) = create_association_pair(AckMode::NoDelay, 0)?;
establish_session_pair(&mut pair, client_ch, server_ch, si)?;
let _ = pair.client_stream(client_ch, si)?.write_sctp(
&Bytes::from_static(b"payload"),
PayloadProtocolIdentifier::Binary,
)?;
pair.drive_client();
let withheld: Vec<_> = pair.server.inbound.drain(..).collect();
pair.client_stream(client_ch, si)?.stop()?;
pair.drive_client();
let (reconfig_packets, data_packets): (Vec<_>, Vec<_>) =
pair.server.inbound.drain(..).partition(|(_, _, raw)| {
let mut offset = 12usize;
while offset + 4 <= raw.len() {
if raw[offset] == 130 {
return true;
}
let chunk_len = u16::from_be_bytes([raw[offset + 2], raw[offset + 3]]) as usize;
if chunk_len < 4 {
break;
}
offset += (chunk_len + 3) & !3; }
false
});
assert!(!reconfig_packets.is_empty(), "expected a RECONFIG packet");
for pkt in &reconfig_packets {
pair.server.inbound.push_back(pkt.clone());
}
pair.drive_server();
assert!(
pair.server_stream(server_ch, si).is_ok(),
"stream should survive InProgress reset"
);
for pkt in reconfig_packets {
pair.server.inbound.push_back(pkt);
}
for pkt in withheld {
pair.server.inbound.push_back(pkt);
}
for pkt in data_packets {
pair.server.inbound.push_back(pkt);
}
pair.drive();
assert!(
pair.server_stream(server_ch, si).is_err(),
"stream should be reset after InProgress retransmission completes"
);
Ok(())
}
#[test]
fn test_assoc_open_stream_rejects_pending_reset_id() -> Result<()> {
let si: u16 = 1;
let (mut pair, client_ch, server_ch) = create_association_pair(AckMode::NoDelay, 0)?;
establish_session_pair(&mut pair, client_ch, server_ch, si)?;
pair.server_stream(server_ch, si)?.stop()?;
pair.drive_server();
pair.drive_client();
assert!(
pair.client_stream(client_ch, si).is_err(),
"stream 1 should be removed from client"
);
match pair
.client_conn_mut(client_ch)
.open_stream(si, PayloadProtocolIdentifier::Binary)
{
Err(Error::ErrStreamResetPending) => {}
other => panic!("expected ErrStreamResetPending, got {:?}", other.err()),
}
pair.drive();
let _ = pair
.client_conn_mut(client_ch)
.open_stream(si, PayloadProtocolIdentifier::Binary)?;
assert!(
pair.client_stream(client_ch, si).is_ok(),
"stream 1 should be available after RE-CONFIG is acknowledged"
);
Ok(())
}
#[test]
fn test_assoc_reconfig_failure_clears_pending() -> Result<()> {
let si: u16 = 1;
let (mut pair, client_ch, server_ch) = create_association_pair(AckMode::NoDelay, 0)?;
establish_session_pair(&mut pair, client_ch, server_ch, si)?;
pair.server_stream(server_ch, si)?.stop()?;
pair.drive_server();
pair.drive_client();
pair.server.inbound.clear();
match pair
.client_conn_mut(client_ch)
.open_stream(si, PayloadProtocolIdentifier::Binary)
{
Err(Error::ErrStreamResetPending) => {}
Err(e) => panic!("expected ErrStreamResetPending, got Err({:?})", e),
Ok(_) => panic!("expected ErrStreamResetPending, got Ok"),
}
for _ in 0..10 {
pair.time += Duration::from_secs(61);
pair.drive_client();
pair.server.inbound.clear();
}
let _ = pair
.client_conn_mut(client_ch)
.open_stream(si, PayloadProtocolIdentifier::Binary)?;
assert!(
pair.client_stream(client_ch, si).is_ok(),
"stream 1 should be available after reconfig retransmission failure"
);
Ok(())
}
#[test]
fn test_snap_connect_established_and_transmit_uses_peer_verification_tag() {
let now = Instant::now();
let local_transport = TransportConfig::default();
let remote_transport = TransportConfig::default();
let local_init_bytes = generate_snap_token(&local_transport).expect("generate local init");
let remote_init_bytes = generate_snap_token(&remote_transport).expect("generate remote init");
let remote_init =
ChunkInit::unmarshal(&remote_init_bytes).expect("unmarshal remote INIT chunk");
let mut endpoint = Endpoint::new(Arc::new(EndpointConfig::default()), None);
let remote_addr: SocketAddr = "127.0.0.1:5000".parse().unwrap();
let client_config = ClientConfig::new().with_snap(local_init_bytes, remote_init_bytes);
let (_ch, mut assoc) = endpoint
.connect(client_config, remote_addr)
.expect("SNAP connect should succeed");
assert_eq!(assoc.state(), AssociationState::Established);
assert_matches!(assoc.poll(), Some(Event::Connected));
let mut stream = assoc
.open_stream(1, PayloadProtocolIdentifier::Binary)
.expect("open stream");
let msg = Bytes::from_static(b"hello");
stream
.write_sctp(&msg, PayloadProtocolIdentifier::Binary)
.expect("write_sctp");
let transmit = assoc
.poll_transmit(now)
.expect("expected at least one outbound datagram");
let Payload::RawEncode(datagrams) = transmit.payload else {
panic!("expected RawEncode transmit");
};
assert!(
!datagrams.is_empty(),
"expected at least one outbound packet"
);
let pkt = Packet::unmarshal(&datagrams[0]).expect("unmarshal outbound packet");
assert_eq!(
pkt.common_header.verification_tag, remote_init.initiate_tag,
"outbound packet should use peer initiate_tag as verification_tag"
);
}
#[test]
fn test_server_retransmitted_init_routes_to_existing_association() {
let now = Instant::now();
let mut endpoint = Endpoint::new(
Arc::new(EndpointConfig::default()),
Some(Arc::new(ServerConfig::default())),
);
let remote: SocketAddr = "127.0.0.1:5000".parse().unwrap();
let init_tag: u32 = 0x1122_3344;
let init = ChunkInit {
is_ack: false,
initiate_tag: init_tag,
..Default::default()
};
let pkt = Packet {
common_header: CommonHeader {
source_port: 5000,
destination_port: 5000,
verification_tag: 0,
},
chunks: vec![Box::new(init)],
};
let bytes = pkt.marshal().expect("marshal INIT packet");
let (ch1, ev1) = endpoint
.handle(now, remote, None, None, bytes.clone())
.expect("first INIT should be handled");
assert!(
matches!(ev1, DatagramEvent::NewAssociation(_)),
"first INIT should create a new association"
);
let (ch2, ev2) = endpoint
.handle(now, remote, None, None, bytes)
.expect("retransmitted INIT should be handled");
assert_eq!(
ch2, ch1,
"retransmitted INIT should map to same association"
);
assert!(
matches!(ev2, DatagramEvent::AssociationEvent(_)),
"retransmitted INIT should be routed to existing association"
);
}
#[test]
fn test_snap_rejects_invalid_remote_bytes() {
let mut endpoint = Endpoint::new(Arc::new(EndpointConfig::default()), None);
let remote_addr: SocketAddr = "127.0.0.1:5000".parse().unwrap();
let local_init = generate_snap_token(&TransportConfig::default()).expect("local init");
let client_config = ClientConfig::new().with_snap(local_init, Bytes::from_static(b"nope"));
let res = endpoint.connect(client_config, remote_addr);
assert!(
matches!(
res,
Err(ConnectError::Snap(SnapError::ParseFailed {
side: SnapSide::Remote,
..
}))
),
"invalid remote bytes should fail with ParseFailed"
);
}
#[test]
fn test_snap_rejects_invalid_local_bytes() {
let mut endpoint = Endpoint::new(Arc::new(EndpointConfig::default()), None);
let remote_addr: SocketAddr = "127.0.0.1:5000".parse().unwrap();
let remote_init = generate_snap_token(&TransportConfig::default()).expect("remote init");
let client_config = ClientConfig::new().with_snap(Bytes::from_static(b"nope"), remote_init);
let res = endpoint.connect(client_config, remote_addr);
assert!(
matches!(
res,
Err(ConnectError::Snap(SnapError::ParseFailed {
side: SnapSide::Local,
..
}))
),
"invalid local bytes should fail with ParseFailed"
);
}
#[test]
fn test_snap_rejects_remote_init_ack() {
let mut endpoint = Endpoint::new(Arc::new(EndpointConfig::default()), None);
let remote_addr: SocketAddr = "127.0.0.1:5000".parse().unwrap();
let local_init = generate_snap_token(&TransportConfig::default()).expect("local init");
let remote_init_ack = ChunkInit {
is_ack: true,
initiate_tag: 1,
..Default::default()
};
let remote_bytes = remote_init_ack.marshal().expect("marshal remote INIT-ACK");
let client_config = ClientConfig::new().with_snap(local_init, remote_bytes);
let res = endpoint.connect(client_config, remote_addr);
assert!(
matches!(
res,
Err(ConnectError::Snap(SnapError::InvalidInitAck {
side: SnapSide::Remote
}))
),
"remote INIT-ACK should be rejected"
);
}
#[test]
fn test_snap_rejects_remote_init_with_zero_initiate_tag() {
let mut endpoint = Endpoint::new(Arc::new(EndpointConfig::default()), None);
let remote_addr: SocketAddr = "127.0.0.1:5000".parse().unwrap();
let local_init = generate_snap_token(&TransportConfig::default()).expect("local init");
let remote_init = ChunkInit {
is_ack: false,
initiate_tag: 0,
..Default::default()
};
let remote_bytes = remote_init.marshal().expect("marshal remote INIT");
let client_config = ClientConfig::new().with_snap(local_init, remote_bytes);
let res = endpoint.connect(client_config, remote_addr);
assert!(
matches!(
res,
Err(ConnectError::Snap(SnapError::ZeroInitiateTag {
side: SnapSide::Remote
}))
),
"remote INIT with initiate_tag=0 should be rejected"
);
}
#[test]
fn test_snap_partial_config_falls_back_to_handshake() {
let mut endpoint = Endpoint::new(Arc::new(EndpointConfig::default()), None);
let remote_addr: SocketAddr = "127.0.0.1:5000".parse().unwrap();
let local_init = generate_snap_token(&TransportConfig::default()).expect("local init");
let client_config = ClientConfig {
local_sctp_init: Some(local_init),
..Default::default()
};
let (_ch, assoc) = endpoint
.connect(client_config, remote_addr)
.expect("partial SNAP config should fall back to handshake");
assert_eq!(
assoc.state(),
AssociationState::CookieWait,
"should use normal handshake when only local init provided"
);
let remote_init = generate_snap_token(&TransportConfig::default()).expect("remote init");
let client_config = ClientConfig {
remote_sctp_init: Some(remote_init),
..Default::default()
};
let (_ch, assoc) = endpoint
.connect(client_config, remote_addr)
.expect("partial SNAP config should fall back to handshake");
assert_eq!(
assoc.state(),
AssociationState::CookieWait,
"should use normal handshake when only remote init provided"
);
}
#[test]
fn test_snap_rejects_reused_local_init() {
let mut endpoint = Endpoint::new(Arc::new(EndpointConfig::default()), None);
let remote_addr: SocketAddr = "127.0.0.1:5000".parse().unwrap();
let local_init = generate_snap_token(&TransportConfig::default()).expect("local init");
let remote1 = generate_snap_token(&TransportConfig::default()).expect("remote init 1");
let remote2 = generate_snap_token(&TransportConfig::default()).expect("remote init 2");
let cfg1 = ClientConfig::new().with_snap(local_init.clone(), remote1);
let res1 = endpoint.connect(cfg1, remote_addr);
assert!(res1.is_ok(), "first SNAP connect should succeed");
let cfg2 = ClientConfig::new().with_snap(local_init, remote2);
let res2 = endpoint.connect(cfg2, remote_addr);
assert!(
matches!(
res2,
Err(ConnectError::Snap(SnapError::AidCollision { .. }))
),
"second SNAP connect should fail due to AID collision"
);
}
#[test]
fn test_snap_rejects_local_init_ack() {
let mut endpoint = Endpoint::new(Arc::new(EndpointConfig::default()), None);
let remote_addr: SocketAddr = "127.0.0.1:5000".parse().unwrap();
let remote_init = generate_snap_token(&TransportConfig::default()).expect("remote init");
let local_init_ack = ChunkInit {
is_ack: true,
initiate_tag: 1,
..Default::default()
};
let local_bytes = local_init_ack.marshal().expect("marshal local INIT-ACK");
let client_config = ClientConfig::new().with_snap(local_bytes, remote_init);
let res = endpoint.connect(client_config, remote_addr);
assert!(
matches!(
res,
Err(ConnectError::Snap(SnapError::InvalidInitAck {
side: SnapSide::Local
}))
),
"local INIT-ACK should be rejected"
);
}
#[test]
fn test_snap_rejects_local_init_with_zero_initiate_tag() {
let mut endpoint = Endpoint::new(Arc::new(EndpointConfig::default()), None);
let remote_addr: SocketAddr = "127.0.0.1:5000".parse().unwrap();
let remote_init = generate_snap_token(&TransportConfig::default()).expect("remote init");
let local_init = ChunkInit {
is_ack: false,
initiate_tag: 0,
..Default::default()
};
let local_bytes = local_init.marshal().expect("marshal local INIT");
let client_config = ClientConfig::new().with_snap(local_bytes, remote_init);
let res = endpoint.connect(client_config, remote_addr);
assert!(
matches!(
res,
Err(ConnectError::Snap(SnapError::ZeroInitiateTag {
side: SnapSide::Local
}))
),
"local INIT with initiate_tag=0 should be rejected"
);
}
#[test]
fn test_snap_rejects_empty_bytes() {
let mut endpoint = Endpoint::new(Arc::new(EndpointConfig::default()), None);
let remote_addr: SocketAddr = "127.0.0.1:5000".parse().unwrap();
let valid = generate_snap_token(&TransportConfig::default()).expect("valid init");
let cfg = ClientConfig::new().with_snap(valid.clone(), Bytes::new());
let res = endpoint.connect(cfg, remote_addr);
assert!(
matches!(
res,
Err(ConnectError::Snap(SnapError::ParseFailed {
side: SnapSide::Remote,
..
}))
),
"empty remote bytes should fail parsing"
);
let cfg = ClientConfig::new().with_snap(Bytes::new(), valid);
let res = endpoint.connect(cfg, remote_addr);
assert!(
matches!(
res,
Err(ConnectError::Snap(SnapError::ParseFailed {
side: SnapSide::Local,
..
}))
),
"empty local bytes should fail parsing"
);
}
#[test]
fn test_snap_rejects_truncated_bytes() {
let mut endpoint = Endpoint::new(Arc::new(EndpointConfig::default()), None);
let remote_addr: SocketAddr = "127.0.0.1:5000".parse().unwrap();
let valid = generate_snap_token(&TransportConfig::default()).expect("valid init");
let truncated = valid.slice(..10.min(valid.len()));
let cfg = ClientConfig::new().with_snap(valid.clone(), truncated.clone());
let res = endpoint.connect(cfg, remote_addr);
assert!(
matches!(
res,
Err(ConnectError::Snap(SnapError::ParseFailed {
side: SnapSide::Remote,
..
}))
),
"truncated remote bytes should fail parsing"
);
let cfg = ClientConfig::new().with_snap(truncated, valid);
let res = endpoint.connect(cfg, remote_addr);
assert!(
matches!(
res,
Err(ConnectError::Snap(SnapError::ParseFailed {
side: SnapSide::Local,
..
}))
),
"truncated local bytes should fail parsing"
);
}
#[test]
fn test_snap_rejects_oversized_bytes() {
use crate::config::MAX_SNAP_INIT_BYTES;
let mut endpoint = Endpoint::new(Arc::new(EndpointConfig::default()), None);
let remote_addr: SocketAddr = "127.0.0.1:5000".parse().unwrap();
let valid = generate_snap_token(&TransportConfig::default()).expect("valid init");
let oversized = Bytes::from(vec![0u8; MAX_SNAP_INIT_BYTES + 1]);
let cfg = ClientConfig::new().with_snap(valid.clone(), oversized.clone());
let res = endpoint.connect(cfg, remote_addr);
assert!(
matches!(
res,
Err(ConnectError::Snap(SnapError::OversizedInit {
side: SnapSide::Remote,
..
}))
),
"oversized remote bytes should be rejected"
);
let cfg = ClientConfig::new().with_snap(oversized, valid);
let res = endpoint.connect(cfg, remote_addr);
assert!(
matches!(
res,
Err(ConnectError::Snap(SnapError::OversizedInit {
side: SnapSide::Local,
..
}))
),
"oversized local bytes should be rejected"
);
}
#[test]
fn test_snap_rejects_remote_init_zero_inbound_streams() {
let mut endpoint = Endpoint::new(Arc::new(EndpointConfig::default()), None);
let remote_addr: SocketAddr = "127.0.0.1:5000".parse().unwrap();
let local_init = generate_snap_token(&TransportConfig::default()).expect("local init");
let bad_remote = ChunkInit {
is_ack: false,
initiate_tag: 0xDEAD_BEEF,
initial_tsn: 1,
num_outbound_streams: 1,
num_inbound_streams: 0, advertised_receiver_window_credit: 65535,
..Default::default()
};
let remote_bytes = bad_remote.marshal().expect("marshal bad remote INIT");
let cfg = ClientConfig::new().with_snap(local_init, remote_bytes);
let res = endpoint.connect(cfg, remote_addr);
assert!(
matches!(
res,
Err(ConnectError::Snap(SnapError::InvalidInit {
side: SnapSide::Remote,
..
}))
),
"remote INIT with num_inbound_streams=0 should be rejected: got {:?}",
res
);
}
#[test]
fn test_snap_rejects_remote_init_zero_outbound_streams() {
let mut endpoint = Endpoint::new(Arc::new(EndpointConfig::default()), None);
let remote_addr: SocketAddr = "127.0.0.1:5000".parse().unwrap();
let local_init = generate_snap_token(&TransportConfig::default()).expect("local init");
let bad_remote = ChunkInit {
is_ack: false,
initiate_tag: 0xDEAD_BEEF,
initial_tsn: 1,
num_outbound_streams: 0, num_inbound_streams: 1,
advertised_receiver_window_credit: 65535,
..Default::default()
};
let remote_bytes = bad_remote.marshal().expect("marshal bad remote INIT");
let cfg = ClientConfig::new().with_snap(local_init, remote_bytes);
let res = endpoint.connect(cfg, remote_addr);
assert!(
matches!(
res,
Err(ConnectError::Snap(SnapError::InvalidInit {
side: SnapSide::Remote,
..
}))
),
"remote INIT with num_outbound_streams=0 should be rejected: got {:?}",
res
);
}
#[test]
fn test_snap_rejects_remote_init_small_arwnd() {
let mut endpoint = Endpoint::new(Arc::new(EndpointConfig::default()), None);
let remote_addr: SocketAddr = "127.0.0.1:5000".parse().unwrap();
let local_init = generate_snap_token(&TransportConfig::default()).expect("local init");
let bad_remote = ChunkInit {
is_ack: false,
initiate_tag: 0xDEAD_BEEF,
initial_tsn: 1,
num_outbound_streams: 1,
num_inbound_streams: 1,
advertised_receiver_window_credit: 100, ..Default::default()
};
let remote_bytes = bad_remote.marshal().expect("marshal bad remote INIT");
let cfg = ClientConfig::new().with_snap(local_init, remote_bytes);
let res = endpoint.connect(cfg, remote_addr);
assert!(
matches!(
res,
Err(ConnectError::Snap(SnapError::InvalidInit {
side: SnapSide::Remote,
..
}))
),
"remote INIT with a_rwnd < 1500 should be rejected: got {:?}",
res
);
}
#[test]
fn test_snap_rejects_local_init_zero_inbound_streams() {
let mut endpoint = Endpoint::new(Arc::new(EndpointConfig::default()), None);
let remote_addr: SocketAddr = "127.0.0.1:5000".parse().unwrap();
let remote_init = generate_snap_token(&TransportConfig::default()).expect("remote init");
let bad_local = ChunkInit {
is_ack: false,
initiate_tag: 0xCAFE_BABE,
initial_tsn: 1,
num_outbound_streams: 1,
num_inbound_streams: 0,
advertised_receiver_window_credit: 65535,
..Default::default()
};
let local_bytes = bad_local.marshal().expect("marshal bad local INIT");
let cfg = ClientConfig::new().with_snap(local_bytes, remote_init);
let res = endpoint.connect(cfg, remote_addr);
assert!(
matches!(
res,
Err(ConnectError::Snap(SnapError::InvalidInit {
side: SnapSide::Local,
..
}))
),
"local INIT with num_inbound_streams=0 should be rejected: got {:?}",
res
);
}
#[test]
fn test_snap_rejects_local_init_small_arwnd() {
let mut endpoint = Endpoint::new(Arc::new(EndpointConfig::default()), None);
let remote_addr: SocketAddr = "127.0.0.1:5000".parse().unwrap();
let remote_init = generate_snap_token(&TransportConfig::default()).expect("remote init");
let bad_local = ChunkInit {
is_ack: false,
initiate_tag: 0xCAFE_BABE,
initial_tsn: 1,
num_outbound_streams: 1,
num_inbound_streams: 1,
advertised_receiver_window_credit: 0,
..Default::default()
};
let local_bytes = bad_local.marshal().expect("marshal bad local INIT");
let cfg = ClientConfig::new().with_snap(local_bytes, remote_init);
let res = endpoint.connect(cfg, remote_addr);
assert!(
matches!(
res,
Err(ConnectError::Snap(SnapError::InvalidInit {
side: SnapSide::Local,
..
}))
),
"local INIT with a_rwnd=0 should be rejected: got {:?}",
res
);
}
#[test]
fn test_snap_generate_token_roundtrip_is_valid() {
let config = TransportConfig::default();
let bytes = generate_snap_token(&config).expect("generate token");
assert!(
bytes.len() <= crate::config::MAX_SNAP_INIT_BYTES,
"generated token should be within MAX_SNAP_INIT_BYTES"
);
let init = ChunkInit::unmarshal(&bytes).expect("unmarshal generated token");
assert!(!init.is_ack, "generated token should be INIT, not INIT-ACK");
assert_ne!(init.initiate_tag, 0, "initiate_tag must not be zero");
assert_ne!(init.initial_tsn, 0, "initial_tsn must not be zero");
assert_eq!(init.num_outbound_streams, u16::MAX);
assert_eq!(init.num_inbound_streams, u16::MAX);
assert_eq!(
init.advertised_receiver_window_credit,
config.max_receive_buffer_size()
);
init.check()
.expect("check() should pass on generated token");
}
#[test]
fn test_snap_generate_token_unique_per_call() {
let config = TransportConfig::default();
let t1 = generate_snap_token(&config).expect("token 1");
let t2 = generate_snap_token(&config).expect("token 2");
assert_ne!(t1, t2, "two generated tokens should differ (random values)");
let i1 = ChunkInit::unmarshal(&t1).expect("parse t1");
let i2 = ChunkInit::unmarshal(&t2).expect("parse t2");
assert_ne!(
i1.initiate_tag, i2.initiate_tag,
"initiate_tags should differ"
);
}
#[test]
fn test_snap_rejects_wrong_chunk_type_bytes() {
use crate::chunk::Chunk;
use crate::chunk::chunk_selective_ack::ChunkSelectiveAck;
let mut endpoint = Endpoint::new(Arc::new(EndpointConfig::default()), None);
let remote_addr: SocketAddr = "127.0.0.1:5000".parse().unwrap();
let valid = generate_snap_token(&TransportConfig::default()).expect("valid init");
let sack = ChunkSelectiveAck {
cumulative_tsn_ack: 1,
advertised_receiver_window_credit: 65535,
gap_ack_blocks: vec![],
duplicate_tsn: vec![],
};
let sack_bytes = sack.marshal().expect("marshal SACK");
let cfg = ClientConfig::new().with_snap(valid, sack_bytes);
let res = endpoint.connect(cfg, remote_addr);
assert!(
matches!(
res,
Err(ConnectError::Snap(SnapError::ParseFailed {
side: SnapSide::Remote,
..
}))
),
"SACK bytes should fail to parse as INIT: got {:?}",
res
);
}
#[test]
fn test_snap_rejects_both_sides_empty() {
let mut endpoint = Endpoint::new(Arc::new(EndpointConfig::default()), None);
let remote_addr: SocketAddr = "127.0.0.1:5000".parse().unwrap();
let cfg = ClientConfig::new().with_snap(Bytes::new(), Bytes::new());
let res = endpoint.connect(cfg, remote_addr);
assert!(
matches!(res, Err(ConnectError::Snap(SnapError::ParseFailed { .. }))),
"two empty byte slices should fail"
);
}
#[test]
fn test_snap_rejects_random_garbage_bytes() {
let mut endpoint = Endpoint::new(Arc::new(EndpointConfig::default()), None);
let remote_addr: SocketAddr = "127.0.0.1:5000".parse().unwrap();
let valid = generate_snap_token(&TransportConfig::default()).expect("valid init");
let garbage = Bytes::from(vec![0xAB; 50]);
let cfg = ClientConfig::new().with_snap(valid.clone(), garbage.clone());
let res = endpoint.connect(cfg, remote_addr);
assert!(
matches!(
res,
Err(ConnectError::Snap(SnapError::ParseFailed {
side: SnapSide::Remote,
..
}))
),
"garbage remote bytes should fail: got {:?}",
res
);
let cfg = ClientConfig::new().with_snap(garbage, valid);
let res = endpoint.connect(cfg, remote_addr);
assert!(
matches!(
res,
Err(ConnectError::Snap(SnapError::ParseFailed {
side: SnapSide::Local,
..
}))
),
"garbage local bytes should fail: got {:?}",
res
);
}
#[test]
fn test_snap_identical_local_and_remote_init() {
let mut endpoint = Endpoint::new(Arc::new(EndpointConfig::default()), None);
let remote_addr: SocketAddr = "127.0.0.1:5000".parse().unwrap();
let token = generate_snap_token(&TransportConfig::default()).expect("token");
let cfg = ClientConfig::new().with_snap(token.clone(), token);
let res = endpoint.connect(cfg, remote_addr);
assert!(
matches!(res, Err(ConnectError::Snap(SnapError::AidCollision { .. }))),
"same token for both sides should collide: got {:?}",
res
);
}
#[test]
fn test_snap_end_to_end_bidirectional_data() {
let endpoint_config = Arc::new(EndpointConfig::default());
let transport = TransportConfig::default();
let init_a = generate_snap_token(&transport).expect("init A");
let init_b = generate_snap_token(&transport).expect("init B");
let addr_a: SocketAddr = SocketAddr::new(
Ipv6Addr::LOCALHOST.into(),
CLIENT_PORTS.lock().unwrap().next().unwrap(),
);
let addr_b: SocketAddr = SocketAddr::new(
Ipv6Addr::LOCALHOST.into(),
CLIENT_PORTS.lock().unwrap().next().unwrap(),
);
let mut ep_a = Endpoint::new(endpoint_config.clone(), None);
let mut ep_b = Endpoint::new(endpoint_config, None);
let cfg_a = ClientConfig::new().with_snap(init_a.clone(), init_b.clone());
let (ch_a, mut assoc_a) = ep_a.connect(cfg_a, addr_b).expect("SNAP connect A");
let cfg_b = ClientConfig::new().with_snap(init_b, init_a);
let (ch_b, mut assoc_b) = ep_b.connect(cfg_b, addr_a).expect("SNAP connect B");
assert_eq!(assoc_a.state(), AssociationState::Established);
assert_eq!(assoc_b.state(), AssociationState::Established);
assert_matches!(assoc_a.poll(), Some(Event::Connected));
assert_matches!(assoc_b.poll(), Some(Event::Connected));
let now = Instant::now();
let si = 1u16;
let msg_a = Bytes::from_static(b"hello from A");
let mut stream_a = assoc_a
.open_stream(si, PayloadProtocolIdentifier::Binary)
.expect("open stream A");
stream_a
.write_sctp(&msg_a, PayloadProtocolIdentifier::Binary)
.expect("write A");
while let Some(transmit) = assoc_a.poll_transmit(now) {
if let Payload::RawEncode(datagrams) = transmit.payload {
for dgram in datagrams {
if let Some((handle, event)) = ep_b.handle(now, addr_a, None, None, dgram) {
assert_eq!(handle, ch_b);
if let DatagramEvent::AssociationEvent(ev) = event {
assoc_b.handle_event(ev);
}
}
}
}
}
let accepted = assoc_b.accept_stream().expect("accept stream on B");
assert_eq!(accepted.stream_identifier, si);
let mut buf = vec![0u8; 64];
let chunks = assoc_b
.stream(si)
.expect("get stream B")
.read_sctp()
.expect("read B")
.expect("expected data");
let n = chunks.read(&mut buf).expect("read payload");
assert_eq!(&buf[..n], b"hello from A");
let msg_b = Bytes::from_static(b"hello from B");
assoc_b
.stream(si)
.expect("get stream B for write")
.write_sctp(&msg_b, PayloadProtocolIdentifier::Binary)
.expect("write B");
while let Some(transmit) = assoc_b.poll_transmit(now) {
if let Payload::RawEncode(datagrams) = transmit.payload {
for dgram in datagrams {
if let Some((handle, event)) = ep_a.handle(now, addr_b, None, None, dgram) {
assert_eq!(handle, ch_a);
if let DatagramEvent::AssociationEvent(ev) = event {
assoc_a.handle_event(ev);
}
}
}
}
}
let chunks = assoc_a
.stream(si)
.expect("get stream A")
.read_sctp()
.expect("read A")
.expect("expected data from B");
let n = chunks.read(&mut buf).expect("read payload from B");
assert_eq!(&buf[..n], b"hello from B");
}
#[test]
fn test_snap_association_drains_cleanly() {
let mut endpoint = Endpoint::new(Arc::new(EndpointConfig::default()), None);
let remote_addr: SocketAddr = "127.0.0.1:5000".parse().unwrap();
let local_init = generate_snap_token(&TransportConfig::default()).expect("local init");
let remote_init = generate_snap_token(&TransportConfig::default()).expect("remote init");
let local_tag = ChunkInit::unmarshal(&local_init)
.expect("parse local")
.initiate_tag;
let remote_tag = ChunkInit::unmarshal(&remote_init)
.expect("parse remote")
.initiate_tag;
let cfg = ClientConfig::new().with_snap(local_init, remote_init);
let (ch, _assoc) = endpoint.connect(cfg, remote_addr).expect("SNAP connect");
assert!(
endpoint.association_ids.contains_key(&local_tag),
"local_tag should be in association_ids"
);
assert!(
endpoint.association_ids_init.contains_key(&remote_tag),
"remote_tag should be in association_ids_init"
);
let drain_event = EndpointEvent(EndpointEventInner::Drained);
endpoint.handle_event(ch, drain_event);
assert!(
!endpoint.association_ids.contains_key(&local_tag),
"local_tag should be removed from association_ids after drain"
);
assert!(
!endpoint.association_ids_init.contains_key(&remote_tag),
"remote_tag should be removed from association_ids_init after drain"
);
}