use crate::association::{
state::{AckMode, AckState, AssociationState},
stats::AssociationStats,
};
use crate::chunk::{
chunk_abort::ChunkAbort, chunk_cookie_ack::ChunkCookieAck, chunk_cookie_echo::ChunkCookieEcho,
chunk_error::ChunkError, chunk_forward_tsn::ChunkForwardTsn,
chunk_forward_tsn::ChunkForwardTsnStream, chunk_heartbeat::ChunkHeartbeat,
chunk_heartbeat_ack::ChunkHeartbeatAck, chunk_init::ChunkInit, chunk_init::ChunkInitAck,
chunk_payload_data::ChunkPayloadData, chunk_payload_data::PayloadProtocolIdentifier,
chunk_reconfig::ChunkReconfig, chunk_selective_ack::ChunkSelectiveAck,
chunk_shutdown::ChunkShutdown, chunk_shutdown_ack::ChunkShutdownAck,
chunk_shutdown_complete::ChunkShutdownComplete, chunk_type::CT_FORWARD_TSN, Chunk,
ErrorCauseUnrecognizedChunkType,
};
use crate::config::{ServerConfig, TransportConfig, COMMON_HEADER_SIZE, DATA_CHUNK_HEADER_SIZE};
use crate::error::{Error, Result};
use crate::packet::{CommonHeader, Packet};
use crate::param::{
param_heartbeat_info::ParamHeartbeatInfo,
param_outgoing_reset_request::ParamOutgoingResetRequest,
param_reconfig_response::{ParamReconfigResponse, ReconfigResult},
param_state_cookie::ParamStateCookie,
param_supported_extensions::ParamSupportedExtensions,
Param,
};
use crate::queue::{payload_queue::PayloadQueue, pending_queue::PendingQueue};
use crate::shared::{AssociationEventInner, AssociationId, EndpointEvent, EndpointEventInner};
use crate::util::{sna16lt, sna32gt, sna32gte, sna32lt, sna32lte};
use crate::{AssociationEvent, Payload, Side, Transmit};
use stream::{ReliabilityType, Stream, StreamEvent, StreamId, StreamState};
use timer::{RtoManager, Timer, TimerTable, ACK_INTERVAL};
use crate::association::stream::RecvSendState;
use bytes::Bytes;
use fxhash::FxHashMap;
use log::{debug, error, trace, warn};
use rand::random;
use std::collections::{HashMap, VecDeque};
use std::net::{IpAddr, SocketAddr};
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use thiserror::Error;
pub(crate) mod state;
pub(crate) mod stats;
pub(crate) mod stream;
mod timer;
#[cfg(test)]
mod association_test;
/// Reasons why an association terminated or failed.
///
/// Surfaced to the application through [`Event::AssociationLost`].
#[derive(Debug, Error, Eq, Clone, PartialEq)]
pub enum AssociationError {
    /// The SCTP handshake failed with the wrapped protocol error.
    #[error("{0}")]
    HandshakeFailed(#[from] Error),
    /// The underlying transport failed.
    #[error("transport error")]
    TransportError,
    /// The peer aborted the association.
    #[error("aborted by peer")]
    AssociationClosed,
    /// The peer closed the association.
    #[error("closed by peer")]
    ApplicationClosed,
    /// The peer reset the association.
    #[error("reset by peer")]
    Reset,
    /// The association timed out.
    #[error("timed out")]
    TimedOut,
    /// The association was closed by the local endpoint.
    #[error("closed")]
    LocallyClosed,
}
/// Application-facing events produced by the association and returned
/// from `Association::poll`.
#[derive(Debug)]
pub enum Event {
    /// The handshake completed; the association is established.
    Connected,
    /// The association terminated; `reason` states why.
    AssociationLost {
        reason: AssociationError,
    },
    /// A stream-level event (e.g. readable, buffered-amount-low).
    Stream(StreamEvent),
    /// A DATA chunk was accepted into a stream's reassembly queue.
    DatagramReceived,
}
/// A single SCTP association state machine.
///
/// Owns handshake state, send/receive queues, congestion-control variables
/// (cwnd/ssthresh/rwnd), timers and per-stream state. Driven externally via
/// `handle_event`/`handle_timeout` and drained through `poll`,
/// `poll_endpoint_event` and `poll_transmit`.
#[derive(Debug)]
pub struct Association {
    // Client or server role for the handshake.
    side: Side,
    state: AssociationState,
    handshake_completed: bool,
    // Largest user message accepted for sending.
    max_message_size: u32,
    inflight_queue_length: usize,
    // The `will_*` flags below request that a control chunk be emitted on
    // the next write-loop pass.
    will_send_shutdown: bool,
    bytes_received: usize,
    bytes_sent: usize,
    // Verification tags exchanged in INIT/INIT-ACK; ours doubles as the
    // local association id (see `Association::new`).
    peer_verification_tag: u32,
    my_verification_tag: u32,
    // Next TSN to assign to outgoing DATA.
    my_next_tsn: u32,
    // Highest contiguous TSN received from the peer.
    peer_last_tsn: u32,
    // Chunks with TSNs below this watermark are not used for RTT sampling.
    min_tsn2measure_rtt: u32,
    will_send_forward_tsn: bool,
    will_retransmit_fast: bool,
    will_retransmit_reconfig: bool,
    will_send_shutdown_ack: bool,
    will_send_shutdown_complete: bool,
    // Next RE-CONFIG (stream reset) request sequence number.
    my_next_rsn: u32,
    // Our outstanding RE-CONFIG requests, keyed by request sequence number.
    reconfigs: FxHashMap<u32, ChunkReconfig>,
    // The peer's outstanding outgoing-reset requests, keyed likewise.
    reconfig_requests: FxHashMap<u32, ParamOutgoingResetRequest>,
    remote_addr: SocketAddr,
    local_ip: Option<IpAddr>,
    source_port: u16,
    destination_port: u16,
    // Stream-count limits, negotiated down during INIT/INIT-ACK.
    my_max_num_inbound_streams: u16,
    my_max_num_outbound_streams: u16,
    // State cookie we issue when acting as the handshake server.
    my_cookie: Option<ParamStateCookie>,
    // Received DATA awaiting cumulative-TSN advancement.
    payload_queue: PayloadQueue,
    // Sent DATA awaiting acknowledgement.
    inflight_queue: PayloadQueue,
    // Application data not yet put on the wire.
    pending_queue: PendingQueue,
    // Control packets queued for transmission.
    control_queue: VecDeque<Packet>,
    // Stream ids queued for `accept_stream`.
    stream_queue: VecDeque<u16>,
    pub(crate) mtu: u32,
    // mtu minus the SCTP common and DATA chunk headers.
    max_payload_size: u32,
    // Highest TSN cumulatively acknowledged by the peer.
    cumulative_tsn_ack_point: u32,
    // ForwardTSN: highest TSN we may ask the peer to skip ahead to.
    advanced_peer_tsn_ack_point: u32,
    // Whether the ForwardTSN extension was negotiated with the peer.
    use_forward_tsn: bool,
    pub(crate) rto_mgr: RtoManager,
    timers: TimerTable,
    max_receive_buffer_size: u32,
    // Congestion window, receiver window and slow-start threshold.
    pub(crate) cwnd: u32,
    rwnd: u32,
    pub(crate) ssthresh: u32,
    partial_bytes_acked: u32,
    pub(crate) in_fast_recovery: bool,
    // TSN whose acknowledgement ends fast recovery.
    fast_recover_exit_point: u32,
    // Handshake chunks kept for retransmission until answered.
    stored_init: Option<ChunkInit>,
    stored_cookie_echo: Option<ChunkCookieEcho>,
    pub(crate) streams: FxHashMap<StreamId, StreamState>,
    // Events for the application and for the owning endpoint.
    events: VecDeque<Event>,
    endpoint_events: VecDeque<EndpointEventInner>,
    // Fatal error, reported once via `poll`.
    error: Option<AssociationError>,
    // Per-packet SACK decision scratch flags (see handle_chunk_start/end).
    delayed_ack_triggered: bool,
    immediate_ack_triggered: bool,
    pub(crate) stats: AssociationStats,
    ack_state: AckState,
    pub(crate) ack_mode: AckMode,
}
impl Default for Association {
    /// Builds a zeroed/empty association: all counters 0, all queues empty,
    /// the remote address unspecified. Real initialization happens in
    /// `Association::new`, which overrides the relevant fields.
    fn default() -> Self {
        Association {
            side: Side::default(),
            state: AssociationState::default(),
            handshake_completed: false,
            max_message_size: 0,
            inflight_queue_length: 0,
            will_send_shutdown: false,
            bytes_received: 0,
            bytes_sent: 0,
            peer_verification_tag: 0,
            my_verification_tag: 0,
            my_next_tsn: 0,
            peer_last_tsn: 0,
            min_tsn2measure_rtt: 0,
            will_send_forward_tsn: false,
            will_retransmit_fast: false,
            will_retransmit_reconfig: false,
            will_send_shutdown_ack: false,
            will_send_shutdown_complete: false,
            my_next_rsn: 0,
            reconfigs: FxHashMap::default(),
            reconfig_requests: FxHashMap::default(),
            // Construct the placeholder address directly instead of parsing
            // "0.0.0.0:0" at runtime (no FromStr + unwrap needed).
            remote_addr: SocketAddr::from(([0, 0, 0, 0], 0)),
            local_ip: None,
            source_port: 0,
            destination_port: 0,
            my_max_num_inbound_streams: 0,
            my_max_num_outbound_streams: 0,
            my_cookie: None,
            payload_queue: PayloadQueue::default(),
            inflight_queue: PayloadQueue::default(),
            pending_queue: PendingQueue::default(),
            control_queue: VecDeque::default(),
            stream_queue: VecDeque::default(),
            mtu: 0,
            max_payload_size: 0,
            cumulative_tsn_ack_point: 0,
            advanced_peer_tsn_ack_point: 0,
            use_forward_tsn: false,
            rto_mgr: RtoManager::default(),
            timers: TimerTable::default(),
            max_receive_buffer_size: 0,
            cwnd: 0,
            rwnd: 0,
            ssthresh: 0,
            partial_bytes_acked: 0,
            in_fast_recovery: false,
            fast_recover_exit_point: 0,
            stored_init: None,
            stored_cookie_echo: None,
            streams: FxHashMap::default(),
            events: VecDeque::default(),
            endpoint_events: VecDeque::default(),
            error: None,
            delayed_ack_triggered: false,
            immediate_ack_triggered: false,
            stats: AssociationStats::default(),
            ack_state: AckState::default(),
            ack_mode: AckMode::default(),
        }
    }
}
impl Association {
/// Creates a new association.
///
/// The side is inferred from `server_config` (present => server). Clients
/// immediately build and queue an INIT chunk and start the T1-init timer.
/// `local_aid` doubles as our verification tag; a random non-zero value
/// seeds the initial TSN, RSN and RTT-measurement watermark.
pub(crate) fn new(
    server_config: Option<Arc<ServerConfig>>,
    config: Arc<TransportConfig>,
    max_payload_size: u32,
    local_aid: AssociationId,
    remote_addr: SocketAddr,
    local_ip: Option<IpAddr>,
    now: Instant,
) -> Self {
    let side = if server_config.is_some() {
        Side::Server
    } else {
        Side::Client
    };
    let mtu = max_payload_size + COMMON_HEADER_SIZE + DATA_CHUNK_HEADER_SIZE;
    // RFC 4960 §7.2.1: initial cwnd = min(4*MTU, max(2*MTU, 4380)).
    // Written with max/min rather than `clamp(4380, 4 * mtu)` because
    // `Ord::clamp` panics when the upper bound is below the lower one,
    // i.e. whenever 4 * mtu < 4380.
    let cwnd = (2 * mtu).max(4380).min(4 * mtu);
    // TSN 0 is avoided; peers treat initial_tsn == 0 as the wrap sentinel
    // (see handle_init / handle_init_ack).
    let mut tsn = random::<u32>();
    if tsn == 0 {
        tsn += 1;
    }
    let mut this = Association {
        side,
        handshake_completed: false,
        max_receive_buffer_size: config.max_receive_buffer_size(),
        max_message_size: config.max_message_size(),
        my_max_num_outbound_streams: config.max_num_outbound_streams(),
        my_max_num_inbound_streams: config.max_num_inbound_streams(),
        max_payload_size,
        rto_mgr: RtoManager::new(),
        timers: TimerTable::new(),
        mtu,
        cwnd,
        remote_addr,
        local_ip,
        my_verification_tag: local_aid,
        my_next_tsn: tsn,
        my_next_rsn: tsn,
        min_tsn2measure_rtt: tsn,
        // "Last acked" points start one below the first TSN we will send.
        cumulative_tsn_ack_point: tsn - 1,
        advanced_peer_tsn_ack_point: tsn - 1,
        error: None,
        ..Default::default()
    };
    if side.is_client() {
        // Client: prepare the INIT, queue it and arm the T1-init timer.
        let mut init = ChunkInit {
            initial_tsn: this.my_next_tsn,
            num_outbound_streams: this.my_max_num_outbound_streams,
            num_inbound_streams: this.my_max_num_inbound_streams,
            initiate_tag: this.my_verification_tag,
            advertised_receiver_window_credit: this.max_receive_buffer_size,
            ..Default::default()
        };
        init.set_supported_extensions();
        this.set_state(AssociationState::CookieWait);
        this.stored_init = Some(init);
        let _ = this.send_init();
        this.timers
            .start(Timer::T1Init, now, this.rto_mgr.get_rto());
    }
    this
}
/// Returns the next application-facing event, if any.
///
/// Queued events are drained first; once exhausted, a recorded fatal
/// error (if any) is surfaced exactly once as `AssociationLost`.
#[must_use]
pub fn poll(&mut self) -> Option<Event> {
    match self.events.pop_front() {
        Some(event) => Some(event),
        None => self
            .error
            .take()
            .map(|reason| Event::AssociationLost { reason }),
    }
}
#[must_use]
pub fn poll_endpoint_event(&mut self) -> Option<EndpointEvent> {
self.endpoint_events.pop_front().map(EndpointEvent)
}
/// Returns the earliest pending timer deadline, if any; the caller should
/// invoke `handle_timeout` no later than that instant.
#[must_use]
pub fn poll_timeout(&mut self) -> Option<Instant> {
    self.timers.next_timeout()
}
/// Builds the next outgoing transmit from whatever `gather_outbound`
/// produced. Returns `None` when there is nothing to send.
#[must_use]
pub fn poll_transmit(&mut self, now: Instant) -> Option<Transmit> {
    let (contents, _) = self.gather_outbound(now);
    if contents.is_empty() {
        return None;
    }
    trace!(
        "[{}] sending {} bytes (total {} datagrams)",
        self.side,
        contents.iter().map(|c| c.len()).sum::<usize>(),
        contents.len()
    );
    Some(Transmit {
        now,
        remote: self.remote_addr,
        payload: Payload::RawEncode(contents),
        ecn: None,
        local_ip: self.local_ip,
    })
}
/// Fires every timer that has expired by `now`.
///
/// An expired ACK timer flushes the delayed SACK; other timers either
/// report a retransmission failure (retry budget exhausted) or trigger a
/// retransmission and re-arm themselves with the current RTO.
pub fn handle_timeout(&mut self, now: Instant) {
    for &timer in Timer::VALUES.iter() {
        let (expired, failure, n_rtos) = self.timers.is_expired(timer, now);
        if !expired {
            continue;
        }
        // Clear the deadline before dispatching; handlers may re-arm it.
        self.timers.set(timer, None);
        match timer {
            Timer::Ack => self.on_ack_timeout(),
            _ if failure => self.on_retransmission_failure(timer),
            _ => {
                self.on_retransmission_timeout(timer, n_rtos);
                self.timers.start(timer, now, self.rto_mgr.get_rto());
            }
        }
    }
}
/// Feeds one endpoint-originated event (an inbound datagram) into the
/// association. Parse failures are logged and dropped; handler errors
/// close the association.
pub fn handle_event(&mut self, event: AssociationEvent) {
    match event.0 {
        AssociationEventInner::Datagram(transmit) => {
            let partial_decode = match transmit.payload {
                Payload::PartialDecode(pd) => pd,
                _ => {
                    trace!("discarding invalid partial_decode");
                    return;
                }
            };
            debug!(
                "[{}] recving {} bytes",
                self.side,
                COMMON_HEADER_SIZE as usize + partial_decode.remaining.len()
            );
            match partial_decode.finish() {
                Ok(pkt) => {
                    if let Err(err) = self.handle_inbound(pkt, transmit.now) {
                        error!("handle_inbound got err: {}", err);
                        let _ = self.close();
                    }
                }
                Err(err) => {
                    warn!("[{}] unable to parse SCTP packet {}", self.side, err);
                }
            }
        }
    }
}
/// Returns a snapshot of the association's transfer statistics.
pub fn stats(&self) -> AssociationStats {
    self.stats
}
/// True until the INIT/COOKIE handshake has completed.
pub fn is_handshaking(&self) -> bool {
    !self.handshake_completed
}
/// True once the association has reached the `Closed` state.
pub fn is_closed(&self) -> bool {
    self.state == AssociationState::Closed
}
/// True when the state machine considers the association fully drained.
pub fn is_drained(&self) -> bool {
    self.state.is_drained()
}
/// Whether this endpoint acted as client or server for the handshake.
pub fn side(&self) -> Side {
    self.side
}
/// The remote peer's transport address.
pub fn remote_addr(&self) -> SocketAddr {
    self.remote_addr
}
/// NOTE(review): despite the name this returns the current retransmission
/// timeout (RTO), not a measured round-trip time — confirm intended use.
pub fn rtt(&self) -> Duration {
    Duration::from_millis(self.rto_mgr.get_rto())
}
/// The local IP this association is bound to, when known.
pub fn local_ip(&self) -> Option<IpAddr> {
    self.local_ip
}
/// Initiates a graceful shutdown of the association.
///
/// Only valid in `Established`; otherwise returns
/// `ErrShutdownNonEstablished`. With nothing in flight the SHUTDOWN chunk
/// is scheduled immediately and the state advances to `ShutdownSent`;
/// otherwise it remains `ShutdownPending` until the in-flight queue drains.
pub fn shutdown(&mut self) -> Result<()> {
    debug!("[{}] closing association..", self.side);
    let state = self.state();
    if state != AssociationState::Established {
        return Err(Error::ErrShutdownNonEstablished);
    }
    self.set_state(AssociationState::ShutdownPending);
    if self.inflight_queue_length == 0 {
        self.will_send_shutdown = true;
        self.awake_write_loop();
        self.set_state(AssociationState::ShutdownSent);
    }
    // NOTE(review): Drained is pushed even while shutdown is still
    // pending — confirm the endpoint expects the event this early.
    self.endpoint_events.push_back(EndpointEventInner::Drained);
    Ok(())
}
/// Transitions the association to `Closed`, tears down all timers and
/// streams, and logs final statistics. Idempotent: a second call is a
/// no-op.
pub fn close(&mut self) -> Result<()> {
    if self.state() == AssociationState::Closed {
        return Ok(());
    }
    self.set_state(AssociationState::Closed);
    debug!("[{}] closing association..", self.side);
    self.close_all_timers();
    // Collect ids up front: unregister_stream mutates `self.streams`.
    let stream_ids: Vec<u16> = self.streams.keys().copied().collect();
    for si in stream_ids {
        self.unregister_stream(si);
    }
    debug!("[{}] association closed", self.side);
    debug!(
        "[{}] stats nDATAs (in) : {}",
        self.side,
        self.stats.get_num_datas()
    );
    debug!(
        "[{}] stats nSACKs (in) : {}",
        self.side,
        self.stats.get_num_sacks()
    );
    debug!(
        "[{}] stats nT3Timeouts : {}",
        self.side,
        self.stats.get_num_t3timeouts()
    );
    debug!(
        "[{}] stats nAckTimeouts: {}",
        self.side,
        self.stats.get_num_ack_timeouts()
    );
    debug!(
        "[{}] stats nFastRetrans: {}",
        self.side,
        self.stats.get_num_fast_retrans()
    );
    Ok(())
}
pub fn open_stream(
&mut self,
stream_identifier: StreamId,
default_payload_type: PayloadProtocolIdentifier,
) -> Result<Stream<'_>> {
if self.streams.contains_key(&stream_identifier) {
return Err(Error::ErrStreamAlreadyExist);
}
if let Some(s) = self.create_stream(stream_identifier, false, default_payload_type) {
Ok(s)
} else {
Err(Error::ErrStreamCreateFailed)
}
}
/// Accepts the next peer-initiated stream, if one is queued.
pub fn accept_stream(&mut self) -> Option<Stream<'_>> {
    let stream_identifier = self.stream_queue.pop_front()?;
    Some(Stream {
        stream_identifier,
        association: self,
    })
}
/// Returns a handle to an existing stream, or `ErrStreamNotExisted` if no
/// stream with that id is registered.
pub fn stream(&mut self, stream_identifier: StreamId) -> Result<Stream<'_>> {
    if self.streams.contains_key(&stream_identifier) {
        Ok(Stream {
            stream_identifier,
            association: self,
        })
    } else {
        Err(Error::ErrStreamNotExisted)
    }
}
/// Total payload bytes sent on this association.
pub(crate) fn bytes_sent(&self) -> usize {
    self.bytes_sent
}
/// Total payload bytes received on this association.
pub(crate) fn bytes_received(&self) -> usize {
    self.bytes_received
}
/// Maximum user-message size accepted for sending.
pub(crate) fn max_message_size(&self) -> u32 {
    self.max_message_size
}
/// Updates the maximum user-message size accepted for sending.
pub(crate) fn set_max_message_size(&mut self, max_message_size: u32) {
    self.max_message_size = max_message_size;
}
/// Removes a stream's state and marks it closed. No-op for unknown ids.
fn unregister_stream(&mut self, stream_identifier: StreamId) {
    let removed = self.streams.remove(&stream_identifier);
    if let Some(mut stream_state) = removed {
        debug!("[{}] unregister_stream {}", self.side, stream_identifier);
        stream_state.state = RecvSendState::Closed;
    }
}
/// Moves the state machine to `new_state`, logging the transition when it
/// actually changes.
fn set_state(&mut self, new_state: AssociationState) {
    let old_state = std::mem::replace(&mut self.state, new_state);
    if old_state != new_state {
        debug!(
            "[{}] state change: '{}' => '{}'",
            self.side, old_state, new_state,
        );
    }
}
/// Current handshake/shutdown state of the association.
pub(crate) fn state(&self) -> AssociationState {
    self.state
}
/// Queues the stored INIT chunk for transmission (client side of the
/// handshake). Fails with `ErrInitNotStoredToSend` if no INIT is stored.
fn send_init(&mut self) -> Result<()> {
    if let Some(stored_init) = &self.stored_init {
        debug!("[{}] sending INIT", self.side);
        // NOTE(review): both ports are hard-coded to 5000 here — confirm
        // whether these should come from configuration instead.
        self.source_port = 5000;
        self.destination_port = 5000;
        let outbound = Packet {
            common_header: CommonHeader {
                source_port: self.source_port,
                destination_port: self.destination_port,
                // Still 0 at this point; the peer's tag is only learned
                // from its INIT-ACK.
                verification_tag: self.peer_verification_tag,
            },
            chunks: vec![Box::new(stored_init.clone())],
        };
        self.control_queue.push_back(outbound);
        self.awake_write_loop();
        Ok(())
    } else {
        Err(Error::ErrInitNotStoredToSend)
    }
}
fn send_cookie_echo(&mut self) -> Result<()> {
if let Some(stored_cookie_echo) = &self.stored_cookie_echo {
debug!("[{}] sending COOKIE-ECHO", self.side);
let outbound = Packet {
common_header: CommonHeader {
source_port: self.source_port,
destination_port: self.destination_port,
verification_tag: self.peer_verification_tag,
},
chunks: vec![Box::new(stored_cookie_echo.clone())],
};
self.control_queue.push_back(outbound);
self.awake_write_loop();
Ok(())
} else {
Err(Error::ErrCookieEchoNotStoredToSend)
}
}
/// Validates an inbound packet and dispatches each of its chunks.
///
/// Packets that fail validation are logged and dropped without error;
/// chunk-handler failures propagate to the caller.
fn handle_inbound(&mut self, p: Packet, now: Instant) -> Result<()> {
    match p.check_packet() {
        Err(err) => {
            warn!("[{}] failed validating packet {}", self.side, err);
            Ok(())
        }
        Ok(_) => {
            self.handle_chunk_start();
            for c in &p.chunks {
                self.handle_chunk(&p, c, now)?;
            }
            self.handle_chunk_end(now);
            Ok(())
        }
    }
}
/// Resets the per-packet ACK decision flags before a packet's chunks are
/// handled; `handle_chunk_end` acts on whatever the handlers set.
fn handle_chunk_start(&mut self) {
    self.delayed_ack_triggered = false;
    self.immediate_ack_triggered = false;
}
/// Applies the ACK decision accumulated while handling a packet's chunks:
/// an immediate ACK cancels the delayed-ACK timer and wakes the writer,
/// while a delayed ACK (re)arms the ACK timer.
fn handle_chunk_end(&mut self, now: Instant) {
    if self.immediate_ack_triggered {
        // A SACK must go out right away.
        self.ack_state = AckState::Immediate;
        self.timers.stop(Timer::Ack);
        self.awake_write_loop();
    } else if self.delayed_ack_triggered {
        // Defer the SACK by ACK_INTERVAL.
        self.ack_state = AckState::Delay;
        self.timers.start(Timer::Ack, now, ACK_INTERVAL);
    }
}
/// Validates one chunk and dispatches it to the matching type-specific
/// handler; any reply packets are queued on the control queue and the
/// write loop is woken.
///
/// ABORT and ERROR chunks end processing with `ErrAbortChunk` carrying the
/// concatenated error causes; an unknown chunk type yields
/// `ErrChunkTypeUnhandled`.
#[allow(clippy::borrowed_box)]
fn handle_chunk(
    &mut self,
    p: &Packet,
    chunk: &Box<dyn Chunk + Send + Sync>,
    now: Instant,
) -> Result<()> {
    // Let the chunk validate its own invariants first.
    chunk.check()?;
    let chunk_any = chunk.as_any();
    let packets = if let Some(c) = chunk_any.downcast_ref::<ChunkInit>() {
        // INIT and INIT-ACK share one chunk struct, split on `is_ack`.
        if c.is_ack {
            self.handle_init_ack(p, c, now)?
        } else {
            self.handle_init(p, c)?
        }
    } else if let Some(c) = chunk_any.downcast_ref::<ChunkAbort>() {
        // Peer aborted: collect all error causes into one message.
        let mut err_str = String::new();
        for e in &c.error_causes {
            err_str += &format!("({})", e);
        }
        return Err(Error::ErrAbortChunk(err_str));
    } else if let Some(c) = chunk_any.downcast_ref::<ChunkError>() {
        // ERROR is handled like ABORT here: it tears down the association.
        let mut err_str = String::new();
        for e in &c.error_causes {
            err_str += &format!("({})", e);
        }
        return Err(Error::ErrAbortChunk(err_str));
    } else if let Some(c) = chunk_any.downcast_ref::<ChunkHeartbeat>() {
        self.handle_heartbeat(c)?
    } else if let Some(c) = chunk_any.downcast_ref::<ChunkCookieEcho>() {
        self.handle_cookie_echo(c)?
    } else if chunk_any.downcast_ref::<ChunkCookieAck>().is_some() {
        self.handle_cookie_ack()?
    } else if let Some(c) = chunk_any.downcast_ref::<ChunkPayloadData>() {
        self.handle_data(c)?
    } else if let Some(c) = chunk_any.downcast_ref::<ChunkSelectiveAck>() {
        self.handle_sack(c, now)?
    } else if let Some(c) = chunk_any.downcast_ref::<ChunkReconfig>() {
        self.handle_reconfig(c)?
    } else if let Some(c) = chunk_any.downcast_ref::<ChunkForwardTsn>() {
        self.handle_forward_tsn(c)?
    } else if let Some(c) = chunk_any.downcast_ref::<ChunkShutdown>() {
        self.handle_shutdown(c)?
    } else if let Some(c) = chunk_any.downcast_ref::<ChunkShutdownAck>() {
        self.handle_shutdown_ack(c)?
    } else if let Some(c) = chunk_any.downcast_ref::<ChunkShutdownComplete>() {
        self.handle_shutdown_complete(c)?
    } else {
        return Err(Error::ErrChunkTypeUnhandled);
    };
    // Queue any replies and wake the writer.
    if !packets.is_empty() {
        let mut buf: VecDeque<_> = packets.into_iter().collect();
        self.control_queue.append(&mut buf);
        self.awake_write_loop();
    }
    Ok(())
}
/// Handles an incoming INIT chunk (server side of the handshake) and
/// builds the INIT-ACK reply carrying our state cookie.
///
/// Valid only in `Closed`, `CookieWait` or `CookieEchoed`; any other state
/// yields `ErrHandleInitState`.
fn handle_init(&mut self, p: &Packet, i: &ChunkInit) -> Result<Vec<Packet>> {
    let state = self.state();
    debug!("[{}] chunkInit received in state '{}'", self.side, state);
    if state != AssociationState::Closed
        && state != AssociationState::CookieWait
        && state != AssociationState::CookieEchoed
    {
        return Err(Error::ErrHandleInitState);
    }
    // Negotiate stream counts down to what both sides support.
    self.my_max_num_inbound_streams =
        std::cmp::min(i.num_inbound_streams, self.my_max_num_inbound_streams);
    self.my_max_num_outbound_streams =
        std::cmp::min(i.num_outbound_streams, self.my_max_num_outbound_streams);
    self.peer_verification_tag = i.initiate_tag;
    // Reply on the flipped port pair of the received packet.
    self.source_port = p.common_header.destination_port;
    self.destination_port = p.common_header.source_port;
    // peer_last_tsn = initial_tsn - 1, wrapping to u32::MAX when the peer
    // starts at 0.
    self.peer_last_tsn = if i.initial_tsn == 0 {
        u32::MAX
    } else {
        i.initial_tsn - 1
    };
    // Enable the ForwardTSN extension if the peer advertises it.
    for param in &i.params {
        if let Some(v) = param.as_any().downcast_ref::<ParamSupportedExtensions>() {
            for t in &v.chunk_types {
                if *t == CT_FORWARD_TSN {
                    debug!("[{}] use ForwardTSN (on init)", self.side);
                    self.use_forward_tsn = true;
                }
            }
        }
    }
    if !self.use_forward_tsn {
        warn!("[{}] not using ForwardTSN (on init)", self.side);
    }
    let mut outbound = Packet {
        common_header: CommonHeader {
            verification_tag: self.peer_verification_tag,
            source_port: self.source_port,
            destination_port: self.destination_port,
        },
        chunks: vec![],
    };
    let mut init_ack = ChunkInit {
        is_ack: true,
        initial_tsn: self.my_next_tsn,
        num_outbound_streams: self.my_max_num_outbound_streams,
        num_inbound_streams: self.my_max_num_inbound_streams,
        initiate_tag: self.my_verification_tag,
        advertised_receiver_window_credit: self.max_receive_buffer_size,
        ..Default::default()
    };
    // Lazily create our state cookie and attach it to the INIT-ACK.
    if self.my_cookie.is_none() {
        self.my_cookie = Some(ParamStateCookie::new());
    }
    if let Some(my_cookie) = &self.my_cookie {
        init_ack.params = vec![Box::new(my_cookie.clone())];
    }
    init_ack.set_supported_extensions();
    outbound.chunks = vec![Box::new(init_ack)];
    Ok(vec![outbound])
}
/// Handles the INIT-ACK reply to our INIT (client side of the handshake).
///
/// Ignored outside `CookieWait`. Adopts the peer's verification tag,
/// stream limits, initial TSN and advertised receiver window, then echoes
/// the state cookie and starts the T1-cookie timer. A missing cookie
/// parameter is a protocol error (`ErrInitAckNoCookie`).
fn handle_init_ack(
    &mut self,
    p: &Packet,
    i: &ChunkInitAck,
    now: Instant,
) -> Result<Vec<Packet>> {
    let state = self.state();
    debug!("[{}] chunkInitAck received in state '{}'", self.side, state);
    if state != AssociationState::CookieWait {
        return Ok(vec![]);
    }
    // Negotiate stream counts down to what both sides support.
    self.my_max_num_inbound_streams =
        std::cmp::min(i.num_inbound_streams, self.my_max_num_inbound_streams);
    self.my_max_num_outbound_streams =
        std::cmp::min(i.num_outbound_streams, self.my_max_num_outbound_streams);
    self.peer_verification_tag = i.initiate_tag;
    // peer_last_tsn = initial_tsn - 1, wrapping to u32::MAX when the peer
    // starts at 0.
    self.peer_last_tsn = if i.initial_tsn == 0 {
        u32::MAX
    } else {
        i.initial_tsn - 1
    };
    // The reply must arrive on the same port pair our INIT used.
    if self.source_port != p.common_header.destination_port
        || self.destination_port != p.common_header.source_port
    {
        warn!("[{}] handle_init_ack: port mismatch", self.side);
        return Ok(vec![]);
    }
    self.rwnd = i.advertised_receiver_window_credit;
    debug!("[{}] initial rwnd={}", self.side, self.rwnd);
    // Slow-start threshold starts at the peer's advertised window.
    self.ssthresh = self.rwnd;
    trace!(
        "[{}] updated cwnd={} ssthresh={} inflight={} (INI)",
        self.side,
        self.cwnd,
        self.ssthresh,
        self.inflight_queue.get_num_bytes()
    );
    // INIT was answered: stop its timer and release the stored chunk.
    self.timers.stop(Timer::T1Init);
    self.stored_init = None;
    let mut cookie_param = None;
    for param in &i.params {
        if let Some(v) = param.as_any().downcast_ref::<ParamStateCookie>() {
            cookie_param = Some(v);
        } else if let Some(v) = param.as_any().downcast_ref::<ParamSupportedExtensions>() {
            for t in &v.chunk_types {
                if *t == CT_FORWARD_TSN {
                    debug!("[{}] use ForwardTSN (on initAck)", self.side);
                    self.use_forward_tsn = true;
                }
            }
        }
    }
    if !self.use_forward_tsn {
        warn!("[{}] not using ForwardTSN (on initAck)", self.side);
    }
    if let Some(v) = cookie_param {
        // Echo the cookie back and arm the T1-cookie timer.
        self.stored_cookie_echo = Some(ChunkCookieEcho {
            cookie: v.cookie.clone(),
        });
        self.send_cookie_echo()?;
        self.timers
            .start(Timer::T1Cookie, now, self.rto_mgr.get_rto());
        self.set_state(AssociationState::CookieEchoed);
        Ok(vec![])
    } else {
        Err(Error::ErrInitAckNoCookie)
    }
}
/// Handles a HEARTBEAT chunk by replying with a HEARTBEAT-ACK that echoes
/// the peer's heartbeat info. Heartbeats without a `ParamHeartbeatInfo`
/// first parameter are logged and ignored.
fn handle_heartbeat(&self, c: &ChunkHeartbeat) -> Result<Vec<Packet>> {
    trace!("[{}] chunkHeartbeat", self.side);
    if let Some(param) = c.params.first() {
        match param.as_any().downcast_ref::<ParamHeartbeatInfo>() {
            Some(hbi) => {
                // Echo the peer's heartbeat info back unchanged.
                let ack = ChunkHeartbeatAck {
                    params: vec![Box::new(ParamHeartbeatInfo {
                        heartbeat_information: hbi.heartbeat_information.clone(),
                    })],
                };
                return Ok(vec![Packet {
                    common_header: CommonHeader {
                        verification_tag: self.peer_verification_tag,
                        source_port: self.source_port,
                        destination_port: self.destination_port,
                    },
                    chunks: vec![Box::new(ack)],
                }]);
            }
            None => {
                warn!(
                    "[{}] failed to handle Heartbeat, no ParamHeartbeatInfo",
                    self.side,
                );
            }
        }
    }
    Ok(vec![])
}
/// Handles a COOKIE-ECHO chunk (server side of the handshake).
///
/// The echoed cookie must match the one we issued; mismatches, unexpected
/// states, or an uninitialized cookie are ignored silently. On first
/// acceptance the T1 timers are cancelled and the association becomes
/// `Established`. A COOKIE-ACK is replied whenever the cookie matched,
/// including retransmissions arriving while already `Established`.
fn handle_cookie_echo(&mut self, c: &ChunkCookieEcho) -> Result<Vec<Packet>> {
    let state = self.state();
    debug!("[{}] COOKIE-ECHO received in state '{}'", self.side, state);
    if let Some(my_cookie) = &self.my_cookie {
        match state {
            AssociationState::Established => {
                // Duplicate COOKIE-ECHO: just re-ack below if it matches.
                if my_cookie.cookie != c.cookie {
                    return Ok(vec![]);
                }
            }
            AssociationState::Closed
            | AssociationState::CookieWait
            | AssociationState::CookieEchoed => {
                if my_cookie.cookie != c.cookie {
                    return Ok(vec![]);
                }
                // Handshake done: drop stored INIT/COOKIE-ECHO and their
                // timers, then go to Established.
                self.timers.stop(Timer::T1Init);
                self.stored_init = None;
                self.timers.stop(Timer::T1Cookie);
                self.stored_cookie_echo = None;
                self.events.push_back(Event::Connected);
                self.set_state(AssociationState::Established);
                self.handshake_completed = true;
            }
            _ => return Ok(vec![]),
        };
    } else {
        debug!("[{}] COOKIE-ECHO received before initialization", self.side);
        return Ok(vec![]);
    }
    Ok(vec![Packet {
        common_header: CommonHeader {
            verification_tag: self.peer_verification_tag,
            source_port: self.source_port,
            destination_port: self.destination_port,
        },
        chunks: vec![Box::new(ChunkCookieAck {})],
    }])
}
/// Handles a COOKIE-ACK chunk (client side of the handshake).
///
/// Only meaningful in `CookieEchoed`: stops the T1-cookie timer, drops the
/// stored COOKIE-ECHO, and completes the handshake. Otherwise ignored.
fn handle_cookie_ack(&mut self) -> Result<Vec<Packet>> {
    let state = self.state();
    debug!("[{}] COOKIE-ACK received in state '{}'", self.side, state);
    if state == AssociationState::CookieEchoed {
        self.timers.stop(Timer::T1Cookie);
        self.stored_cookie_echo = None;
        self.events.push_back(Event::Connected);
        self.set_state(AssociationState::Established);
        self.handshake_completed = true;
    }
    Ok(vec![])
}
/// Handles an incoming DATA chunk: enqueues the payload (subject to
/// receive-window checks), feeds it to the owning stream's reassembly
/// queue, and decides how/when to SACK.
fn handle_data(&mut self, d: &ChunkPayloadData) -> Result<Vec<Packet>> {
    trace!(
        "[{}] DATA: tsn={} immediateSack={} len={}",
        self.side,
        d.tsn,
        d.immediate_sack,
        d.user_data.len()
    );
    self.stats.inc_datas();
    // Duplicates/out-of-window TSNs are not pushed, but still SACKed below.
    let can_push = self.payload_queue.can_push(d, self.peer_last_tsn);
    let mut stream_handle_data = false;
    if can_push {
        if self.get_or_create_stream(d.stream_identifier).is_some() {
            if self.get_my_receiver_window_credit() > 0 {
                self.payload_queue.push(d.clone(), self.peer_last_tsn);
                stream_handle_data = true;
            } else {
                // Buffer full: still accept a chunk that fills an earlier
                // gap (TSN below the highest received), otherwise drop.
                if let Some(last_tsn) = self.payload_queue.get_last_tsn_received() {
                    if sna32lt(d.tsn, *last_tsn) {
                        debug!("[{}] receive buffer full, but accepted as this is a missing chunk with tsn={} ssn={}", self.side, d.tsn, d.stream_sequence_number);
                        self.payload_queue.push(d.clone(), self.peer_last_tsn);
                        stream_handle_data = true;
                    }
                } else {
                    debug!(
                        "[{}] receive buffer full. dropping DATA with tsn={} ssn={}",
                        self.side, d.tsn, d.stream_sequence_number
                    );
                }
            }
        } else {
            // Stream could not be created; drop the chunk without SACKing.
            debug!("[{}] discard {}", self.side, d.stream_sequence_number);
            return Ok(vec![]);
        }
    }
    let immediate_sack = d.immediate_sack;
    if stream_handle_data {
        if let Some(s) = self.streams.get_mut(&d.stream_identifier) {
            self.events.push_back(Event::DatagramReceived);
            s.handle_data(d);
            // Wake the application if a complete message became readable.
            if s.reassembly_queue.is_readable() {
                self.events.push_back(Event::Stream(StreamEvent::Readable {
                    id: d.stream_identifier,
                }))
            }
        }
    }
    // Advance peer_last_tsn over now-contiguous TSNs and pick delayed vs.
    // immediate SACK.
    self.handle_peer_last_tsn_and_acknowledgement(immediate_sack)
}
/// Processes an incoming SACK chunk.
///
/// Stale SACKs (cumulative ack behind our ack point) are dropped.
/// Otherwise acked bytes are credited per stream, cwnd/rwnd are updated,
/// fast retransmission is evaluated, and — when ForwardTSN is in use —
/// the advanced peer-ack point is pushed past abandoned chunks.
fn handle_sack(&mut self, d: &ChunkSelectiveAck, now: Instant) -> Result<Vec<Packet>> {
    trace!(
        "[{}] {}, SACK: cumTSN={} a_rwnd={}",
        self.side,
        self.cumulative_tsn_ack_point,
        d.cumulative_tsn_ack,
        d.advertised_receiver_window_credit
    );
    let state = self.state();
    if state != AssociationState::Established
        && state != AssociationState::ShutdownPending
        && state != AssociationState::ShutdownReceived
    {
        return Ok(vec![]);
    }
    self.stats.inc_sacks();
    // Staleness check: ignore SACKs behind our current ack point.
    if sna32gt(self.cumulative_tsn_ack_point, d.cumulative_tsn_ack) {
        debug!(
            "[{}] SACK Cumulative ACK {} is older than ACK point {}",
            self.side, d.cumulative_tsn_ack, self.cumulative_tsn_ack_point
        );
        return Ok(vec![]);
    }
    let (bytes_acked_per_stream, htna) = self.process_selective_ack(d, now)?;
    let mut total_bytes_acked = 0;
    for n_bytes_acked in bytes_acked_per_stream.values() {
        total_bytes_acked += *n_bytes_acked;
    }
    let mut cum_tsn_ack_point_advanced = false;
    if sna32lt(self.cumulative_tsn_ack_point, d.cumulative_tsn_ack) {
        trace!(
            "[{}] SACK: cumTSN advanced: {} -> {}",
            self.side,
            self.cumulative_tsn_ack_point,
            d.cumulative_tsn_ack
        );
        self.cumulative_tsn_ack_point = d.cumulative_tsn_ack;
        cum_tsn_ack_point_advanced = true;
        self.on_cumulative_tsn_ack_point_advanced(total_bytes_acked, now);
    }
    // Notify streams whose send buffers dropped below their thresholds.
    for (si, n_bytes_acked) in &bytes_acked_per_stream {
        if let Some(s) = self.streams.get_mut(si) {
            if s.on_buffer_released(*n_bytes_acked) {
                self.events
                    .push_back(Event::Stream(StreamEvent::BufferedAmountLow { id: *si }))
            }
        }
    }
    // rwnd = peer's advertised window minus what is still in flight.
    let bytes_outstanding = self.inflight_queue.get_num_bytes() as u32;
    if bytes_outstanding >= d.advertised_receiver_window_credit {
        self.rwnd = 0;
    } else {
        self.rwnd = d.advertised_receiver_window_credit - bytes_outstanding;
    }
    self.process_fast_retransmission(d.cumulative_tsn_ack, htna, cum_tsn_ack_point_advanced)?;
    if self.use_forward_tsn {
        // Keep the advanced ack point at least at the cumulative point,
        // then slide it forward over abandoned chunks; if it moved past
        // the cumulative point, schedule a FORWARD-TSN.
        if sna32lt(
            self.advanced_peer_tsn_ack_point,
            self.cumulative_tsn_ack_point,
        ) {
            self.advanced_peer_tsn_ack_point = self.cumulative_tsn_ack_point
        }
        let mut i = self.advanced_peer_tsn_ack_point + 1;
        while let Some(c) = self.inflight_queue.get(i) {
            if !c.abandoned() {
                break;
            }
            self.advanced_peer_tsn_ack_point = i;
            i += 1;
        }
        if sna32gt(
            self.advanced_peer_tsn_ack_point,
            self.cumulative_tsn_ack_point,
        ) {
            self.will_send_forward_tsn = true;
            debug!(
                "[{}] handleSack {}: sna32GT({}, {})",
                self.side,
                self.will_send_forward_tsn,
                self.advanced_peer_tsn_ack_point,
                self.cumulative_tsn_ack_point
            );
        }
        self.awake_write_loop();
    }
    self.postprocess_sack(state, cum_tsn_ack_point_advanced, now);
    Ok(vec![])
}
/// Handles a RE-CONFIG chunk by processing its (up to two) parameters in
/// order and collecting any reply packets they produce.
fn handle_reconfig(&mut self, c: &ChunkReconfig) -> Result<Vec<Packet>> {
    trace!("[{}] handle_reconfig", self.side);
    let mut replies = vec![];
    // param_a is processed before param_b, matching wire order.
    for param in c.param_a.iter().chain(c.param_b.iter()) {
        if let Some(packet) = self.handle_reconfig_param(param)? {
            replies.push(packet);
        }
    }
    Ok(replies)
}
/// Handles a FORWARD-TSN chunk from the peer.
///
/// If ForwardTSN was never negotiated, replies with an ERROR packet
/// (unrecognized chunk type). If the chunk's cumulative TSN is not ahead
/// of `peer_last_tsn`, only an immediate SACK is scheduled. Otherwise
/// `peer_last_tsn` is advanced (dropping queued payloads in between) and
/// streams are told to skip the abandoned sequence numbers.
fn handle_forward_tsn(&mut self, c: &ChunkForwardTsn) -> Result<Vec<Packet>> {
    trace!("[{}] FwdTSN: {}", self.side, c.to_string());
    if !self.use_forward_tsn {
        warn!("[{}] received FwdTSN but not enabled", self.side);
        // Report the chunk back to the peer as unrecognized.
        let cerr = ChunkError {
            error_causes: vec![ErrorCauseUnrecognizedChunkType::default()],
        };
        let outbound = Packet {
            common_header: CommonHeader {
                verification_tag: self.peer_verification_tag,
                source_port: self.source_port,
                destination_port: self.destination_port,
            },
            chunks: vec![Box::new(cerr)],
        };
        return Ok(vec![outbound]);
    }
    trace!(
        "[{}] should send ack? newCumTSN={} peer_last_tsn={}",
        self.side,
        c.new_cumulative_tsn,
        self.peer_last_tsn
    );
    // Already at or past the peer's new cumulative TSN: SACK immediately.
    if sna32lte(c.new_cumulative_tsn, self.peer_last_tsn) {
        trace!("[{}] sending ack on Forward TSN", self.side);
        self.ack_state = AckState::Immediate;
        self.timers.stop(Timer::Ack);
        self.awake_write_loop();
        return Ok(vec![]);
    }
    // Advance peer_last_tsn, discarding any queued chunks being skipped.
    while sna32lt(self.peer_last_tsn, c.new_cumulative_tsn) {
        self.payload_queue.pop(self.peer_last_tsn + 1);
        self.peer_last_tsn += 1;
    }
    // Per-stream skip for ordered delivery...
    for forwarded in &c.streams {
        if let Some(s) = self.streams.get_mut(&forwarded.identifier) {
            s.handle_forward_tsn_for_ordered(forwarded.sequence);
        }
    }
    // ...and for unordered delivery on every stream.
    for s in self.streams.values_mut() {
        s.handle_forward_tsn_for_unordered(c.new_cumulative_tsn);
    }
    self.handle_peer_last_tsn_and_acknowledgement(false)
}
/// Handles a SHUTDOWN chunk from the peer.
///
/// In `Established` with data still in flight we only move to
/// `ShutdownReceived`; otherwise (nothing in flight, or we had already
/// sent our own SHUTDOWN) a SHUTDOWN-ACK is scheduled immediately.
fn handle_shutdown(&mut self, _: &ChunkShutdown) -> Result<Vec<Packet>> {
    match self.state() {
        AssociationState::Established => {
            if self.inflight_queue.is_empty() {
                self.will_send_shutdown_ack = true;
                self.set_state(AssociationState::ShutdownAckSent);
                self.awake_write_loop();
            } else {
                self.set_state(AssociationState::ShutdownReceived);
            }
        }
        AssociationState::ShutdownSent => {
            self.will_send_shutdown_ack = true;
            self.set_state(AssociationState::ShutdownAckSent);
            self.awake_write_loop();
        }
        _ => {}
    }
    Ok(vec![])
}
/// Handles a SHUTDOWN-ACK: while in a shutdown exchange, stops the
/// T2-shutdown timer and schedules SHUTDOWN-COMPLETE.
fn handle_shutdown_ack(&mut self, _: &ChunkShutdownAck) -> Result<Vec<Packet>> {
    if matches!(
        self.state(),
        AssociationState::ShutdownSent | AssociationState::ShutdownAckSent
    ) {
        self.timers.stop(Timer::T2Shutdown);
        self.will_send_shutdown_complete = true;
        self.awake_write_loop();
    }
    Ok(vec![])
}
/// Handles a SHUTDOWN-COMPLETE: if we were waiting in `ShutdownAckSent`,
/// stops the T2-shutdown timer and closes the association.
fn handle_shutdown_complete(&mut self, _: &ChunkShutdownComplete) -> Result<Vec<Packet>> {
    if self.state() == AssociationState::ShutdownAckSent {
        self.timers.stop(Timer::T2Shutdown);
        self.close()?;
    }
    Ok(vec![])
}
/// Pops now-contiguous chunks off the payload queue (advancing
/// `peer_last_tsn`), re-evaluates any pending incoming stream-reset
/// requests, and selects delayed vs. immediate SACK.
///
/// An immediate SACK is chosen when the chunk's immediate-SACK flag was
/// set, gaps (packet loss) remain in the queue, the ack state is already
/// `Immediate`, or the ack mode demands it; `AlwaysDelay` always defers.
fn handle_peer_last_tsn_and_acknowledgement(
    &mut self,
    sack_immediately: bool,
) -> Result<Vec<Packet>> {
    let mut reply = vec![];
    // A successful pop means TSN peer_last_tsn+1 had arrived out of order
    // earlier and is now contiguous.
    while self.payload_queue.pop(self.peer_last_tsn + 1).is_some() {
        self.peer_last_tsn += 1;
        // Re-check outstanding stream-reset requests; they can only be
        // honored once all TSNs up to the reset point have arrived.
        let rst_reqs: Vec<ParamOutgoingResetRequest> =
            self.reconfig_requests.values().cloned().collect();
        for rst_req in rst_reqs {
            let resp = self.reset_streams_if_any(&rst_req);
            debug!("[{}] RESET RESPONSE: {}", self.side, resp);
            reply.push(resp);
        }
    }
    // Anything left in the queue is a gap, i.e. packet loss.
    let has_packet_loss = !self.payload_queue.is_empty();
    if has_packet_loss {
        trace!(
            "[{}] packetloss: {}",
            self.side,
            self.payload_queue
                .get_gap_ack_blocks_string(self.peer_last_tsn)
        );
    }
    if (self.ack_state != AckState::Immediate
        && !sack_immediately
        && !has_packet_loss
        && self.ack_mode == AckMode::Normal)
        || self.ack_mode == AckMode::AlwaysDelay
    {
        if self.ack_state == AckState::Idle {
            self.delayed_ack_triggered = true;
        } else {
            self.immediate_ack_triggered = true;
        }
    } else {
        self.immediate_ack_triggered = true;
    }
    Ok(reply)
}
/// Handles a single RE-CONFIG parameter from the peer.
///
/// * Outgoing-reset requests are remembered (keyed by request sequence
///   number, so they can be re-evaluated as missing TSNs arrive) and
///   answered immediately.
/// * Reconfig responses settle one of our own pending requests; the
///   reconfig timer is stopped once none remain.
///
/// Any other parameter type yields `ErrParameterType`.
#[allow(clippy::borrowed_box)]
fn handle_reconfig_param(
    &mut self,
    raw: &Box<dyn Param + Send + Sync>,
) -> Result<Option<Packet>> {
    if let Some(p) = raw.as_any().downcast_ref::<ParamOutgoingResetRequest>() {
        self.reconfig_requests
            .insert(p.reconfig_request_sequence_number, p.clone());
        Ok(Some(self.reset_streams_if_any(p)))
    } else if let Some(p) = raw.as_any().downcast_ref::<ParamReconfigResponse>() {
        self.reconfigs.remove(&p.reconfig_response_sequence_number);
        if self.reconfigs.is_empty() {
            self.timers.stop(Timer::Reconfig);
        }
        Ok(None)
    } else {
        Err(Error::ErrParameterType)
    }
}
/// Walks a SACK's cumulative range and gap-ack blocks over the in-flight
/// queue, collecting how many bytes each stream had acknowledged.
///
/// Returns `(bytes_acked_per_stream, htna)`, where `htna` is the highest
/// TSN newly acknowledged by this SACK (input to fast-retransmit). RTT is
/// sampled only from chunks transmitted exactly once (`nsent == 1`,
/// Karn's rule) at or above the `min_tsn2measure_rtt` watermark.
fn process_selective_ack(
    &mut self,
    d: &ChunkSelectiveAck,
    now: Instant,
) -> Result<(HashMap<u16, i64>, u32)> {
    let mut bytes_acked_per_stream = HashMap::new();
    // 1) Everything up to the cumulative ack is fully delivered: pop it
    //    from the in-flight queue.
    let mut i = self.cumulative_tsn_ack_point + 1;
    while sna32lte(i, d.cumulative_tsn_ack) {
        if let Some(c) = self.inflight_queue.pop(i) {
            if !c.acked {
                // Stop T3-rtx when the first outstanding TSN gets acked.
                if i == self.cumulative_tsn_ack_point + 1 {
                    self.timers.stop(Timer::T3RTX);
                }
                let n_bytes_acked = c.user_data.len() as i64;
                if let Some(amount) = bytes_acked_per_stream.get_mut(&c.stream_identifier) {
                    *amount += n_bytes_acked;
                } else {
                    bytes_acked_per_stream.insert(c.stream_identifier, n_bytes_acked);
                }
                // RTT sample: first transmission only, above the watermark.
                if c.nsent == 1 && sna32gte(c.tsn, self.min_tsn2measure_rtt) {
                    self.min_tsn2measure_rtt = self.my_next_tsn;
                    if let Some(since) = &c.since {
                        let rtt = now.duration_since(*since);
                        let srtt = self.rto_mgr.set_new_rtt(rtt.as_millis() as u64);
                        trace!(
                            "[{}] SACK: measured-rtt={} srtt={} new-rto={}",
                            self.side,
                            rtt.as_millis(),
                            srtt,
                            self.rto_mgr.get_rto()
                        );
                    } else {
                        error!("[{}] invalid c.since", self.side);
                    }
                }
            }
            // Fast recovery ends when its exit-point TSN is acked.
            if self.in_fast_recovery && c.tsn == self.fast_recover_exit_point {
                debug!("[{}] exit fast-recovery", self.side);
                self.in_fast_recovery = false;
            }
        } else {
            return Err(Error::ErrInflightQueueTsnPop);
        }
        i += 1;
    }
    // 2) Gap-ack blocks: mark (but keep) chunks acked beyond the
    //    cumulative point. Block offsets are relative to cumulative_tsn_ack.
    let mut htna = d.cumulative_tsn_ack;
    for g in &d.gap_ack_blocks {
        for i in g.start..=g.end {
            let tsn = d.cumulative_tsn_ack + i as u32;
            let (is_existed, is_acked) = if let Some(c) = self.inflight_queue.get(tsn) {
                (true, c.acked)
            } else {
                (false, false)
            };
            let n_bytes_acked = if is_existed && !is_acked {
                self.inflight_queue.mark_as_acked(tsn) as i64
            } else {
                0
            };
            if let Some(c) = self.inflight_queue.get(tsn) {
                if !is_acked {
                    if let Some(amount) = bytes_acked_per_stream.get_mut(&c.stream_identifier) {
                        *amount += n_bytes_acked;
                    } else {
                        bytes_acked_per_stream.insert(c.stream_identifier, n_bytes_acked);
                    }
                    trace!("[{}] tsn={} has been sacked", self.side, c.tsn);
                    // Same RTT sampling rule as the cumulative range above.
                    if c.nsent == 1 {
                        self.min_tsn2measure_rtt = self.my_next_tsn;
                        if let Some(since) = &c.since {
                            let rtt = now.duration_since(*since);
                            let srtt = self.rto_mgr.set_new_rtt(rtt.as_millis() as u64);
                            trace!(
                                "[{}] SACK: measured-rtt={} srtt={} new-rto={}",
                                self.side,
                                rtt.as_millis(),
                                srtt,
                                self.rto_mgr.get_rto()
                            );
                        } else {
                            error!("[{}] invalid c.since", self.side);
                        }
                    }
                    // Track the highest newly acknowledged TSN.
                    if sna32lt(htna, tsn) {
                        htna = tsn;
                    }
                }
            } else {
                return Err(Error::ErrTsnRequestNotExist);
            }
        }
    }
    Ok((bytes_acked_per_stream, htna))
}
/// Congestion-control bookkeeping run after the cumulative TSN ack point
/// moved forward.
///
/// Stops T3-rtx when nothing remains in flight, otherwise restarts it.
/// Then grows cwnd: in slow start (`cwnd <= ssthresh`) cwnd grows by the
/// bytes acked, capped at cwnd (so at most doubling per call); in
/// congestion avoidance cwnd grows by one MTU per full cwnd of acked
/// bytes, accumulated in `partial_bytes_acked`.
fn on_cumulative_tsn_ack_point_advanced(&mut self, total_bytes_acked: i64, now: Instant) {
    if self.inflight_queue.is_empty() {
        trace!(
            "[{}] SACK: no more packet in-flight (pending={})",
            self.side,
            self.pending_queue.len()
        );
        self.timers.stop(Timer::T3RTX);
    } else {
        trace!("[{}] T3-rtx timer start (pt2)", self.side);
        self.timers.start(Timer::T3RTX, now, self.rto_mgr.get_rto());
    }
    if self.cwnd <= self.ssthresh {
        // Slow start — but only grow when not in fast recovery and there
        // is pending data to use the extra window.
        if !self.in_fast_recovery && !self.pending_queue.is_empty() {
            self.cwnd += std::cmp::min(total_bytes_acked as u32, self.cwnd);
            trace!(
                "[{}] updated cwnd={} ssthresh={} acked={} (SS)",
                self.side,
                self.cwnd,
                self.ssthresh,
                total_bytes_acked
            );
        } else {
            trace!(
                "[{}] cwnd did not grow: cwnd={} ssthresh={} acked={} FR={} pending={}",
                self.side,
                self.cwnd,
                self.ssthresh,
                total_bytes_acked,
                self.in_fast_recovery,
                self.pending_queue.len()
            );
        }
    } else {
        // Congestion avoidance: one MTU per cwnd worth of acked bytes.
        self.partial_bytes_acked += total_bytes_acked as u32;
        if self.partial_bytes_acked >= self.cwnd && !self.pending_queue.is_empty() {
            self.partial_bytes_acked -= self.cwnd;
            self.cwnd += self.mtu;
            trace!(
                "[{}] updated cwnd={} ssthresh={} acked={} (CA)",
                self.side,
                self.cwnd,
                self.ssthresh,
                total_bytes_acked
            );
        }
    }
}
/// Updates per-chunk miss indicators after a SACK and enters fast
/// recovery on the third miss (cf. RFC 4960 §7.2.4 fast retransmit).
///
/// `htna` is the highest TSN newly acked by this SACK; unacked chunks
/// below the scan limit are counted as missed once more. Entering fast
/// recovery halves ssthresh (floored at 4*MTU), collapses cwnd to
/// ssthresh and schedules a fast retransmission pass.
fn process_fast_retransmission(
    &mut self,
    cum_tsn_ack_point: u32,
    htna: u32,
    cum_tsn_ack_point_advanced: bool,
) -> Result<()> {
    // Miss counting runs when not in fast recovery, or when the
    // cumulative ack point advanced while in it.
    if !self.in_fast_recovery || cum_tsn_ack_point_advanced {
        // Outside fast recovery the scan stops at HTNA; inside it the
        // whole inflight range is examined.
        let max_tsn = if !self.in_fast_recovery {
            htna
        } else {
            cum_tsn_ack_point + (self.inflight_queue.len() as u32) + 1
        };
        let mut tsn = cum_tsn_ack_point + 1;
        while sna32lt(tsn, max_tsn) {
            if let Some(c) = self.inflight_queue.get_mut(tsn) {
                // Only live chunks that have not yet reached 3 misses
                // accumulate further misses.
                if !c.acked && !c.abandoned() && c.miss_indicator < 3 {
                    c.miss_indicator += 1;
                    if c.miss_indicator == 3 && !self.in_fast_recovery {
                        // Third miss: enter fast recovery.
                        self.in_fast_recovery = true;
                        self.fast_recover_exit_point = htna;
                        self.ssthresh = std::cmp::max(self.cwnd / 2, 4 * self.mtu);
                        self.cwnd = self.ssthresh;
                        self.partial_bytes_acked = 0;
                        self.will_retransmit_fast = true;
                        trace!(
                            "[{}] updated cwnd={} ssthresh={} inflight={} (FR)",
                            self.side,
                            self.cwnd,
                            self.ssthresh,
                            self.inflight_queue.get_num_bytes()
                        );
                    }
                }
            } else {
                return Err(Error::ErrTsnRequestNotExist);
            }
            tsn += 1;
        }
    }
    // While in fast recovery, every SACK that advances the cumulative ack
    // point triggers another fast retransmission pass.
    if self.in_fast_recovery && cum_tsn_ack_point_advanced {
        self.will_retransmit_fast = true;
    }
    Ok(())
}
/// Final SACK handling: re-arms T3-rtx while data is still outstanding,
/// or — once the inflight queue has drained — advances a pending shutdown
/// handshake. Wakes the write loop when anything new became sendable.
fn postprocess_sack(
    &mut self,
    state: AssociationState,
    mut should_awake_write_loop: bool,
    now: Instant,
) {
    if !self.inflight_queue.is_empty() {
        // Data is still outstanding: keep the retransmission timer going.
        trace!("[{}] T3-rtx timer start (pt3)", self.side);
        self.timers.start(Timer::T3RTX, now, self.rto_mgr.get_rto());
    } else {
        // Everything acked: a deferred shutdown can now proceed.
        match state {
            AssociationState::ShutdownPending => {
                should_awake_write_loop = true;
                self.will_send_shutdown = true;
                self.set_state(AssociationState::ShutdownSent);
            }
            AssociationState::ShutdownReceived => {
                should_awake_write_loop = true;
                self.will_send_shutdown_ack = true;
                self.set_state(AssociationState::ShutdownAckSent);
            }
            _ => {}
        }
    }
    if should_awake_write_loop {
        self.awake_write_loop();
    }
}
/// Handles an incoming outgoing-stream-reset request and builds the
/// RECONFIG response packet.
///
/// If every TSN the peer sent before the reset (`sender_last_tsn`) has
/// already been received, the listed streams are unregistered and the
/// response is `SuccessPerformed`; otherwise the request stays in
/// `reconfig_requests` for a later retry and the response is
/// `InProgress`.
fn reset_streams_if_any(&mut self, p: &ParamOutgoingResetRequest) -> Packet {
    let mut result = ReconfigResult::SuccessPerformed;
    if sna32lte(p.sender_last_tsn, self.peer_last_tsn) {
        debug!(
            "[{}] resetStream(): senderLastTSN={} <= peer_last_tsn={}",
            self.side, p.sender_last_tsn, self.peer_last_tsn
        );
        // All pre-reset data has arrived; tear down the listed streams.
        for id in &p.stream_identifiers {
            if self.streams.contains_key(id) {
                self.unregister_stream(*id);
            }
        }
        self.reconfig_requests
            .remove(&p.reconfig_request_sequence_number);
    } else {
        debug!(
            "[{}] resetStream(): senderLastTSN={} > peer_last_tsn={}",
            self.side, p.sender_last_tsn, self.peer_last_tsn
        );
        // Still waiting for data sent before the reset point.
        result = ReconfigResult::InProgress;
    }
    self.create_packet(vec![Box::new(ChunkReconfig {
        param_a: Some(Box::new(ParamReconfigResponse {
            reconfig_response_sequence_number: p.reconfig_request_sequence_number,
            result,
        })),
        param_b: None,
    })])
}
/// Wraps `chunks` in an SCTP packet stamped with this association's
/// common header (peer verification tag plus source/destination ports).
pub(crate) fn create_packet(&self, chunks: Vec<Box<dyn Chunk + Send + Sync>>) -> Packet {
    let common_header = CommonHeader {
        verification_tag: self.peer_verification_tag,
        source_port: self.source_port,
        destination_port: self.destination_port,
    };
    Packet {
        common_header,
        chunks,
    }
}
/// Registers a new stream state under `stream_identifier` and returns a
/// handle to it.
///
/// With `accept` set (inbound-opened streams), the id is queued for the
/// application to accept and a `StreamEvent::Opened` event is emitted.
fn create_stream(
    &mut self,
    stream_identifier: StreamId,
    accept: bool,
    default_payload_type: PayloadProtocolIdentifier,
) -> Option<Stream<'_>> {
    let state = StreamState::new(
        self.side,
        stream_identifier,
        self.max_payload_size,
        default_payload_type,
    );
    self.streams.insert(stream_identifier, state);
    if accept {
        // Surface the new stream to the application side.
        self.stream_queue.push_back(stream_identifier);
        self.events.push_back(Event::Stream(StreamEvent::Opened));
    }
    Some(Stream {
        stream_identifier,
        association: self,
    })
}
/// Returns a handle to the stream with `stream_identifier`, creating and
/// accepting it (with the default payload protocol id) when it does not
/// exist yet.
fn get_or_create_stream(&mut self, stream_identifier: StreamId) -> Option<Stream<'_>> {
    if !self.streams.contains_key(&stream_identifier) {
        return self.create_stream(
            stream_identifier,
            true,
            PayloadProtocolIdentifier::default(),
        );
    }
    Some(Stream {
        stream_identifier,
        association: self,
    })
}
/// Receiver window (a_rwnd) to advertise: the receive buffer budget minus
/// the bytes currently parked in all stream reassembly queues, floored at
/// zero.
pub(crate) fn get_my_receiver_window_credit(&self) -> u32 {
    let bytes_queued: u32 = self
        .streams
        .values()
        .map(|s| s.get_num_bytes_in_reassembly_queue() as u32)
        .sum();
    self.max_receive_buffer_size.saturating_sub(bytes_queued)
}
/// Gathers every packet ready to go out, already marshalled to raw bytes.
///
/// Control packets queued in `control_queue` are always flushed first;
/// what follows depends on the association state. The boolean is the
/// keep-alive flag from `gather_outbound_shutdown_packets`: it turns
/// `false` once a SHUTDOWN-COMPLETE has been emitted.
fn gather_outbound(&mut self, now: Instant) -> (Vec<Bytes>, bool) {
    let mut raw_packets = vec![];
    if !self.control_queue.is_empty() {
        for p in self.control_queue.drain(..) {
            if let Ok(raw) = p.marshal() {
                raw_packets.push(raw);
            } else {
                warn!("[{}] failed to serialize a control packet", self.side);
                continue;
            }
        }
    }
    let state = self.state();
    match state {
        AssociationState::Established => {
            // Full service: retransmissions, new DATA/RECONFIG, fast
            // retransmit, SACKs and FORWARD-TSN.
            raw_packets = self.gather_data_packets_to_retransmit(raw_packets, now);
            raw_packets = self.gather_outbound_data_and_reconfig_packets(raw_packets, now);
            raw_packets = self.gather_outbound_fast_retransmission_packets(raw_packets, now);
            raw_packets = self.gather_outbound_sack_packets(raw_packets);
            raw_packets = self.gather_outbound_forward_tsn_packets(raw_packets);
            (raw_packets, true)
        }
        AssociationState::ShutdownPending
        | AssociationState::ShutdownSent
        | AssociationState::ShutdownReceived => {
            // While shutting down: no new user data, only retransmissions,
            // SACKs and the shutdown handshake itself.
            raw_packets = self.gather_data_packets_to_retransmit(raw_packets, now);
            raw_packets = self.gather_outbound_fast_retransmission_packets(raw_packets, now);
            raw_packets = self.gather_outbound_sack_packets(raw_packets);
            self.gather_outbound_shutdown_packets(raw_packets, now)
        }
        AssociationState::ShutdownAckSent => {
            self.gather_outbound_shutdown_packets(raw_packets, now)
        }
        _ => (raw_packets, true),
    }
}
/// Serializes the DATA packets selected for retransmission by
/// `get_data_packets_to_retransmit` and appends them to `raw_packets`,
/// logging (and skipping) any packet that fails to marshal.
fn gather_data_packets_to_retransmit(
    &mut self,
    mut raw_packets: Vec<Bytes>,
    now: Instant,
) -> Vec<Bytes> {
    for p in &self.get_data_packets_to_retransmit(now) {
        match p.marshal() {
            Ok(raw) => raw_packets.push(raw),
            Err(_) => warn!(
                "[{}] failed to serialize a DATA packet to be retransmitted",
                self.side
            ),
        }
    }
    raw_packets
}
/// Emits new DATA packets from the pending queue plus any RECONFIG
/// (stream reset) chunks: retransmitted ones first, then a fresh
/// outgoing-reset request for streams flagged by zero-length chunks.
fn gather_outbound_data_and_reconfig_packets(
    &mut self,
    mut raw_packets: Vec<Bytes>,
    now: Instant,
) -> Vec<Bytes> {
    // Pop sendable chunks (window permitting); zero-length chunks come
    // back as stream ids to reset instead.
    let (chunks, sis_to_reset) = self.pop_pending_data_chunks_to_send(now);
    if !chunks.is_empty() {
        trace!("[{}] T3-rtx timer start (pt1)", self.side);
        self.timers.start(Timer::T3RTX, now, self.rto_mgr.get_rto());
        for p in &self.bundle_data_chunks_into_packets(chunks) {
            if let Ok(raw) = p.marshal() {
                raw_packets.push(raw);
            } else {
                warn!("[{}] failed to serialize a DATA packet", self.side);
            }
        }
    }
    if !sis_to_reset.is_empty() || self.will_retransmit_reconfig {
        if self.will_retransmit_reconfig {
            // Resend every RECONFIG that has not been answered yet.
            self.will_retransmit_reconfig = false;
            debug!(
                "[{}] retransmit {} RECONFIG chunk(s)",
                self.side,
                self.reconfigs.len()
            );
            for c in self.reconfigs.values() {
                let p = self.create_packet(vec![Box::new(c.clone())]);
                if let Ok(raw) = p.marshal() {
                    raw_packets.push(raw);
                } else {
                    warn!(
                        "[{}] failed to serialize a RECONFIG packet to be retransmitted",
                        self.side,
                    );
                }
            }
        }
        if !sis_to_reset.is_empty() {
            // New outgoing reset request covering all flagged streams;
            // sender_last_tsn is the last TSN actually assigned.
            let rsn = self.generate_next_rsn();
            let tsn = self.my_next_tsn - 1;
            debug!(
                "[{}] sending RECONFIG: rsn={} tsn={} streams={:?}",
                self.side,
                rsn,
                self.my_next_tsn - 1,
                sis_to_reset
            );
            let c = ChunkReconfig {
                param_a: Some(Box::new(ParamOutgoingResetRequest {
                    reconfig_request_sequence_number: rsn,
                    sender_last_tsn: tsn,
                    stream_identifiers: sis_to_reset,
                    ..Default::default()
                })),
                ..Default::default()
            };
            // Remember it for retransmission until a response arrives.
            self.reconfigs.insert(rsn, c.clone());
            let p = self.create_packet(vec![Box::new(c)]);
            if let Ok(raw) = p.marshal() {
                raw_packets.push(raw);
            } else {
                warn!(
                    "[{}] failed to serialize a RECONFIG packet to be transmitted",
                    self.side
                );
            }
        }
        if !self.reconfigs.is_empty() {
            // Arm the Reconfig retransmission timer while any request is
            // outstanding.
            self.timers
                .start(Timer::Reconfig, now, self.rto_mgr.get_rto());
        }
    }
    raw_packets
}
/// Bundles and serializes the chunks eligible for fast retransmission
/// into a single packet (bounded by one MTU), when `will_retransmit_fast`
/// was set by the SACK handler.
fn gather_outbound_fast_retransmission_packets(
    &mut self,
    mut raw_packets: Vec<Bytes>,
    now: Instant,
) -> Vec<Bytes> {
    if self.will_retransmit_fast {
        self.will_retransmit_fast = false;
        let mut to_fast_retrans: Vec<Box<dyn Chunk + Send + Sync>> = vec![];
        let mut fast_retrans_size = COMMON_HEADER_SIZE;
        let mut i = 0;
        loop {
            let tsn = self.cumulative_tsn_ack_point + i + 1;
            if let Some(c) = self.inflight_queue.get_mut(tsn) {
                // Fast-retransmit only chunks that are unacked, not
                // abandoned, sent exactly once, and missed 3+ times.
                if c.acked || c.abandoned() || c.nsent > 1 || c.miss_indicator < 3 {
                    i += 1;
                    continue;
                }
                // Everything must fit in one MTU-sized packet.
                let data_chunk_size = DATA_CHUNK_HEADER_SIZE + c.user_data.len() as u32;
                if self.mtu < fast_retrans_size + data_chunk_size {
                    break;
                }
                fast_retrans_size += data_chunk_size;
                self.stats.inc_fast_retrans();
                c.nsent += 1;
            } else {
                // Reached the end of the inflight queue.
                break;
            }
            // Re-borrow the chunk so `self.streams` can be passed to the
            // partial-reliability check alongside the mutable chunk ref.
            if let Some(c) = self.inflight_queue.get_mut(tsn) {
                Association::check_partial_reliability_status(
                    c,
                    now,
                    self.use_forward_tsn,
                    self.side,
                    &self.streams,
                );
                to_fast_retrans.push(Box::new(c.clone()));
                trace!(
                    "[{}] fast-retransmit: tsn={} sent={} htna={}",
                    self.side,
                    c.tsn,
                    c.nsent,
                    self.fast_recover_exit_point
                );
            }
            i += 1;
        }
        if !to_fast_retrans.is_empty() {
            if let Ok(raw) = self.create_packet(to_fast_retrans).marshal() {
                raw_packets.push(raw);
            } else {
                warn!(
                    "[{}] failed to serialize a DATA packet to be fast-retransmitted",
                    self.side
                );
            }
        }
    }
    raw_packets
}
/// Emits a SACK packet when one is due (`AckState::Immediate`) and resets
/// the ack state back to idle.
fn gather_outbound_sack_packets(&mut self, mut raw_packets: Vec<Bytes>) -> Vec<Bytes> {
    if self.ack_state != AckState::Immediate {
        return raw_packets;
    }
    self.ack_state = AckState::Idle;
    let sack = self.create_selective_ack_chunk();
    debug!("[{}] sending SACK: {}", self.side, sack);
    match self.create_packet(vec![Box::new(sack)]).marshal() {
        Ok(raw) => raw_packets.push(raw),
        Err(_) => warn!("[{}] failed to serialize a SACK packet", self.side),
    }
    raw_packets
}
/// Emits a FORWARD-TSN packet when one is scheduled and the advanced peer
/// TSN ack point is actually ahead of the cumulative ack point.
fn gather_outbound_forward_tsn_packets(&mut self, mut raw_packets: Vec<Bytes>) -> Vec<Bytes> {
    if !self.will_send_forward_tsn {
        return raw_packets;
    }
    self.will_send_forward_tsn = false;
    let advanced = sna32gt(
        self.advanced_peer_tsn_ack_point,
        self.cumulative_tsn_ack_point,
    );
    if advanced {
        let fwd_tsn = self.create_forward_tsn();
        match self.create_packet(vec![Box::new(fwd_tsn)]).marshal() {
            Ok(raw) => raw_packets.push(raw),
            Err(_) => warn!("[{}] failed to serialize a Forward TSN packet", self.side),
        }
    }
    raw_packets
}
/// Emits at most one pending shutdown-handshake packet (SHUTDOWN,
/// SHUTDOWN-ACK or SHUTDOWN-COMPLETE).
///
/// The returned flag is `true` while the connection should stay alive; it
/// becomes `false` only after a SHUTDOWN-COMPLETE is serialized, which
/// ends the handshake.
fn gather_outbound_shutdown_packets(
    &mut self,
    mut raw_packets: Vec<Bytes>,
    now: Instant,
) -> (Vec<Bytes>, bool) {
    let mut ok = true;
    if self.will_send_shutdown {
        self.will_send_shutdown = false;
        let shutdown = ChunkShutdown {
            cumulative_tsn_ack: self.cumulative_tsn_ack_point,
        };
        if let Ok(raw) = self.create_packet(vec![Box::new(shutdown)]).marshal() {
            // T2-shutdown drives retransmission of this SHUTDOWN.
            self.timers
                .start(Timer::T2Shutdown, now, self.rto_mgr.get_rto());
            raw_packets.push(raw);
        } else {
            warn!("[{}] failed to serialize a Shutdown packet", self.side);
        }
    } else if self.will_send_shutdown_ack {
        self.will_send_shutdown_ack = false;
        let shutdown_ack = ChunkShutdownAck {};
        if let Ok(raw) = self.create_packet(vec![Box::new(shutdown_ack)]).marshal() {
            self.timers
                .start(Timer::T2Shutdown, now, self.rto_mgr.get_rto());
            raw_packets.push(raw);
        } else {
            warn!("[{}] failed to serialize a ShutdownAck packet", self.side);
        }
    } else if self.will_send_shutdown_complete {
        self.will_send_shutdown_complete = false;
        let shutdown_complete = ChunkShutdownComplete {};
        if let Ok(raw) = self
            .create_packet(vec![Box::new(shutdown_complete)])
            .marshal()
        {
            raw_packets.push(raw);
            // Handshake finished; signal the caller to stop.
            ok = false;
        } else {
            warn!(
                "[{}] failed to serialize a ShutdownComplete packet",
                self.side
            );
        }
    }
    (raw_packets, ok)
}
/// Collects the inflight chunks flagged `retransmit` and bundles them
/// into packets.
///
/// Total retransmitted bytes are capped by awnd = min(cwnd, rwnd). The
/// very first chunk is sent even when it would not fit in rwnd (the
/// `done` path still processes it once) — effectively a window probe —
/// after which the scan stops.
fn get_data_packets_to_retransmit(&mut self, now: Instant) -> Vec<Packet> {
    let awnd = std::cmp::min(self.cwnd, self.rwnd);
    let mut chunks = vec![];
    let mut bytes_to_send = 0;
    let mut done = false;
    let mut i = 0;
    while !done {
        let tsn = self.cumulative_tsn_ack_point + i + 1;
        if let Some(c) = self.inflight_queue.get_mut(tsn) {
            if !c.retransmit {
                i += 1;
                continue;
            }
            if i == 0 && self.rwnd < c.user_data.len() as u32 {
                // First chunk exceeds rwnd: send it anyway, then stop.
                done = true;
            } else if bytes_to_send + c.user_data.len() > awnd as usize {
                break;
            }
            c.retransmit = false;
            bytes_to_send += c.user_data.len();
            c.nsent += 1;
        } else {
            // End of the inflight queue.
            break;
        }
        // Re-borrow so `self.streams` can be passed to the
        // partial-reliability check alongside the mutable chunk ref.
        if let Some(c) = self.inflight_queue.get_mut(tsn) {
            Association::check_partial_reliability_status(
                c,
                now,
                self.use_forward_tsn,
                self.side,
                &self.streams,
            );
            trace!(
                "[{}] retransmitting tsn={} ssn={} sent={}",
                self.side,
                c.tsn,
                c.stream_sequence_number,
                c.nsent
            );
            chunks.push(c.clone());
        }
        i += 1;
    }
    self.bundle_data_chunks_into_packets(chunks)
}
/// Moves as many pending chunks to the inflight queue as the congestion
/// (cwnd) and receiver (rwnd) windows allow, returning them for
/// transmission.
///
/// Zero-length chunks are not data: they mark a stream-reset request (see
/// `send_reset_request`) and their stream ids are returned separately.
/// When the windows allow nothing and the inflight queue is empty, one
/// chunk is sent anyway as a zero-window probe.
fn pop_pending_data_chunks_to_send(
    &mut self,
    now: Instant,
) -> (Vec<ChunkPayloadData>, Vec<u16>) {
    let mut chunks = vec![];
    let mut sis_to_reset = vec![];
    if !self.pending_queue.is_empty() {
        while let Some(c) = self.pending_queue.peek() {
            let (beginning_fragment, unordered, data_len, stream_identifier) = (
                c.beginning_fragment,
                c.unordered,
                c.user_data.len(),
                c.stream_identifier,
            );
            if data_len == 0 {
                // Zero-length chunk = stream reset marker; drop it and
                // record the stream id.
                sis_to_reset.push(stream_identifier);
                if self
                    .pending_queue
                    .pop(beginning_fragment, unordered)
                    .is_none()
                {
                    error!("[{}] failed to pop from pending queue", self.side);
                }
                continue;
            }
            // Congestion window check.
            if self.inflight_queue.get_num_bytes() + data_len > self.cwnd as usize {
                break;
            }
            // Receiver window check.
            if data_len > self.rwnd as usize {
                break;
            }
            self.rwnd -= data_len as u32;
            if let Some(chunk) = self.move_pending_data_chunk_to_inflight_queue(
                beginning_fragment,
                unordered,
                now,
            ) {
                chunks.push(chunk);
            }
        }
        // Zero-window probe: with nothing in flight and nothing selected,
        // force one chunk out so the peer can reopen the window.
        if chunks.is_empty() && self.inflight_queue.is_empty() {
            if let Some(c) = self.pending_queue.peek() {
                let (beginning_fragment, unordered) = (c.beginning_fragment, c.unordered);
                if let Some(chunk) = self.move_pending_data_chunk_to_inflight_queue(
                    beginning_fragment,
                    unordered,
                    now,
                ) {
                    chunks.push(chunk);
                }
            }
        }
    }
    (chunks, sis_to_reset)
}
fn bundle_data_chunks_into_packets(&self, chunks: Vec<ChunkPayloadData>) -> Vec<Packet> {
let mut packets = vec![];
let mut chunks_to_send = vec![];
let mut bytes_in_packet = COMMON_HEADER_SIZE;
for c in chunks {
if bytes_in_packet + c.user_data.len() as u32 > self.mtu {
packets.push(self.create_packet(chunks_to_send));
chunks_to_send = vec![];
bytes_in_packet = COMMON_HEADER_SIZE;
}
bytes_in_packet += DATA_CHUNK_HEADER_SIZE + c.user_data.len() as u32;
chunks_to_send.push(Box::new(c));
}
if !chunks_to_send.is_empty() {
packets.push(self.create_packet(chunks_to_send));
}
packets
}
/// Hands out the current TSN and advances the counter (post-increment).
fn generate_next_tsn(&mut self) -> u32 {
    let next = self.my_next_tsn + 1;
    std::mem::replace(&mut self.my_next_tsn, next)
}
/// Hands out the current RECONFIG sequence number and advances the
/// counter (post-increment).
fn generate_next_rsn(&mut self) -> u32 {
    let next = self.my_next_rsn + 1;
    std::mem::replace(&mut self.my_next_rsn, next)
}
fn check_partial_reliability_status(
c: &mut ChunkPayloadData,
now: Instant,
use_forward_tsn: bool,
side: Side,
streams: &FxHashMap<u16, StreamState>,
) {
if !use_forward_tsn {
return;
}
if c.payload_type == PayloadProtocolIdentifier::Dcep {
return;
}
if let Some(s) = streams.get(&c.stream_identifier) {
let reliability_type: ReliabilityType = s.reliability_type;
let reliability_value = s.reliability_value;
if reliability_type == ReliabilityType::Rexmit {
if c.nsent >= reliability_value {
c.set_abandoned(true);
trace!(
"[{}] marked as abandoned: tsn={} ppi={} (remix: {})",
side,
c.tsn,
c.payload_type,
c.nsent
);
}
} else if reliability_type == ReliabilityType::Timed {
if let Some(since) = &c.since {
let elapsed = now.duration_since(*since);
if elapsed.as_millis() as u32 >= reliability_value {
c.set_abandoned(true);
trace!(
"[{}] marked as abandoned: tsn={} ppi={} (timed: {:?})",
side,
c.tsn,
c.payload_type,
elapsed
);
}
} else {
error!("[{}] invalid c.since", side);
}
}
} else {
error!("[{}] stream {} not found)", side, c.stream_identifier);
}
}
/// Builds a SACK reflecting the current receive state: cumulative TSN
/// ack, advertised window credit, gap-ack blocks, and the duplicate TSNs
/// collected since the last SACK (popping drains the duplicate list).
fn create_selective_ack_chunk(&mut self) -> ChunkSelectiveAck {
    let cumulative_tsn_ack = self.peer_last_tsn;
    let advertised_receiver_window_credit = self.get_my_receiver_window_credit();
    let gap_ack_blocks = self.payload_queue.get_gap_ack_blocks(cumulative_tsn_ack);
    let duplicate_tsn = self.payload_queue.pop_duplicates();
    ChunkSelectiveAck {
        cumulative_tsn_ack,
        advertised_receiver_window_credit,
        gap_ack_blocks,
        duplicate_tsn,
    }
}
/// Builds a FORWARD-TSN chunk advancing the peer's cumulative TSN to
/// `advanced_peer_tsn_ack_point`, listing for each affected stream the
/// highest stream sequence number being skipped.
fn create_forward_tsn(&self) -> ChunkForwardTsn {
    // Per-stream highest SSN among the abandoned/acked chunks in
    // (cumulative_tsn_ack_point, advanced_peer_tsn_ack_point].
    let mut stream_map: HashMap<u16, u16> = HashMap::new();
    let mut i = self.cumulative_tsn_ack_point + 1;
    while sna32lte(i, self.advanced_peer_tsn_ack_point) {
        if let Some(c) = self.inflight_queue.get(i) {
            if let Some(ssn) = stream_map.get(&c.stream_identifier) {
                // Keep the larger SSN (serial-number comparison).
                if sna16lt(*ssn, c.stream_sequence_number) {
                    stream_map.insert(c.stream_identifier, c.stream_sequence_number);
                }
            } else {
                stream_map.insert(c.stream_identifier, c.stream_sequence_number);
            }
        } else {
            break;
        }
        i += 1;
    }
    let mut fwd_tsn = ChunkForwardTsn {
        new_cumulative_tsn: self.advanced_peer_tsn_ack_point,
        streams: vec![],
    };
    // stream_str only feeds the trace line below.
    let mut stream_str = String::new();
    for (si, ssn) in &stream_map {
        stream_str += format!("(si={} ssn={})", si, ssn).as_str();
        fwd_tsn.streams.push(ChunkForwardTsnStream {
            identifier: *si,
            sequence: *ssn,
        });
    }
    trace!(
        "[{}] building fwd_tsn: newCumulativeTSN={} cumTSN={} - {}",
        self.side,
        fwd_tsn.new_cumulative_tsn,
        self.cumulative_tsn_ack_point,
        stream_str
    );
    fwd_tsn
}
/// Pops the next chunk from the pending queue, assigns it a TSN, stamps
/// its first-send time, and pushes it onto the inflight queue.
///
/// Returns a clone of the chunk for transmission, or `None` (with an
/// error log) if the pending queue pop unexpectedly fails.
fn move_pending_data_chunk_to_inflight_queue(
    &mut self,
    beginning_fragment: bool,
    unordered: bool,
    now: Instant,
) -> Option<ChunkPayloadData> {
    if let Some(mut c) = self.pending_queue.pop(beginning_fragment, unordered) {
        // The final fragment of a message marks the whole message as
        // in flight.
        if c.ending_fragment {
            c.set_all_inflight();
        }
        c.tsn = self.generate_next_tsn();
        // First-send timestamp (for RTT measurement and timed
        // partial reliability).
        c.since = Some(now);
        c.nsent = 1;
        Association::check_partial_reliability_status(
            &mut c,
            now,
            self.use_forward_tsn,
            self.side,
            &self.streams,
        );
        trace!(
            "[{}] sending ppi={} tsn={} ssn={} sent={} len={} ({},{})",
            self.side,
            c.payload_type as u32,
            c.tsn,
            c.stream_sequence_number,
            c.nsent,
            c.user_data.len(),
            c.beginning_fragment,
            c.ending_fragment
        );
        self.inflight_queue.push_no_check(c.clone());
        Some(c)
    } else {
        error!("[{}] failed to pop from pending queue", self.side);
        None
    }
}
/// Queues an outgoing stream-reset request for `stream_identifier`.
///
/// Only valid while established. A zero-length DATA chunk queued for the
/// stream acts as the reset marker: `pop_pending_data_chunks_to_send`
/// later converts it into a RECONFIG request instead of sending data.
pub(crate) fn send_reset_request(&mut self, stream_identifier: StreamId) -> Result<()> {
    if self.state() != AssociationState::Established {
        return Err(Error::ErrResetPacketInStateNotExist);
    }
    self.pending_queue.push(ChunkPayloadData {
        stream_identifier,
        beginning_fragment: true,
        ending_fragment: true,
        user_data: Bytes::new(),
        ..Default::default()
    });
    self.awake_write_loop();
    Ok(())
}
/// Queues user DATA chunks for transmission and wakes the write loop.
/// Fails unless the association is established.
pub(crate) fn send_payload_data(&mut self, chunks: Vec<ChunkPayloadData>) -> Result<()> {
    if self.state() != AssociationState::Established {
        return Err(Error::ErrPayloadDataStateNotExist);
    }
    chunks
        .into_iter()
        .for_each(|c| self.pending_queue.push(c));
    self.awake_write_loop();
    Ok(())
}
/// Total bytes queued for sending: not-yet-sent (pending) plus
/// sent-but-unacked (inflight).
pub(crate) fn buffered_amount(&self) -> usize {
    let pending = self.pending_queue.get_num_bytes();
    let inflight = self.inflight_queue.get_num_bytes();
    pending + inflight
}
/// Intentionally a no-op in this poll-driven design: callers drain
/// outbound packets themselves via `gather_outbound`, so there is no
/// background write loop to wake. Call sites are kept so the control
/// flow mirrors implementations that do have one.
fn awake_write_loop(&self) {
}
/// Stops every timer the association may have armed.
fn close_all_timers(&mut self) {
    Timer::VALUES
        .iter()
        .for_each(|&timer| self.timers.stop(timer));
}
/// Delayed-ack timer fired: escalate to an immediate SACK so the next
/// outbound gathering pass sends one.
fn on_ack_timeout(&mut self) {
    trace!(
        "[{}] ack timed out (ack_state: {})",
        self.side,
        self.ack_state
    );
    self.stats.inc_ack_timeouts();
    self.ack_state = AckState::Immediate;
    self.awake_write_loop();
}
/// Handles expiry of a retransmission timer (`n_rtos` = how many times it
/// has fired so far).
///
/// * T1-init / T1-cookie: resend the handshake chunk.
/// * T2-shutdown: re-flag the pending SHUTDOWN / SHUTDOWN-ACK.
/// * T3-rtx: congestion collapse (ssthresh = max(cwnd/2, 4*MTU), cwnd =
///   one MTU), advance the FORWARD-TSN point past abandoned chunks, and
///   mark everything in flight for retransmission.
/// * Reconfig: re-flag outstanding RECONFIG requests.
fn on_retransmission_timeout(&mut self, timer_id: Timer, n_rtos: usize) {
    match timer_id {
        Timer::T1Init => {
            if let Err(err) = self.send_init() {
                debug!(
                    "[{}] failed to retransmit init (n_rtos={}): {:?}",
                    self.side, n_rtos, err
                );
            }
        }
        Timer::T1Cookie => {
            if let Err(err) = self.send_cookie_echo() {
                debug!(
                    "[{}] failed to retransmit cookie-echo (n_rtos={}): {:?}",
                    self.side, n_rtos, err
                );
            }
        }
        Timer::T2Shutdown => {
            debug!(
                "[{}] retransmission of shutdown timeout (n_rtos={})",
                self.side, n_rtos
            );
            let state = self.state();
            match state {
                AssociationState::ShutdownSent => {
                    self.will_send_shutdown = true;
                    self.awake_write_loop();
                }
                AssociationState::ShutdownAckSent => {
                    self.will_send_shutdown_ack = true;
                    self.awake_write_loop();
                }
                _ => {}
            }
        }
        Timer::T3RTX => {
            self.stats.inc_t3timeouts();
            // RTO expiry: halve ssthresh (floored at 4*MTU) and collapse
            // cwnd to a single MTU.
            self.ssthresh = std::cmp::max(self.cwnd / 2, 4 * self.mtu);
            self.cwnd = self.mtu;
            trace!(
                "[{}] updated cwnd={} ssthresh={} inflight={} (RTO)",
                self.side,
                self.cwnd,
                self.ssthresh,
                self.inflight_queue.get_num_bytes()
            );
            if self.use_forward_tsn {
                // Advance the peer ack point over consecutive abandoned
                // chunks and schedule a FORWARD-TSN if it moved ahead.
                let mut i = self.advanced_peer_tsn_ack_point + 1;
                while let Some(c) = self.inflight_queue.get(i) {
                    if !c.abandoned() {
                        break;
                    }
                    self.advanced_peer_tsn_ack_point = i;
                    i += 1;
                }
                if sna32gt(
                    self.advanced_peer_tsn_ack_point,
                    self.cumulative_tsn_ack_point,
                ) {
                    self.will_send_forward_tsn = true;
                    debug!(
                        "[{}] on_retransmission_timeout {}: sna32GT({}, {})",
                        self.side,
                        self.will_send_forward_tsn,
                        self.advanced_peer_tsn_ack_point,
                        self.cumulative_tsn_ack_point
                    );
                }
            }
            debug!(
                "[{}] T3-rtx timed out: n_rtos={} cwnd={} ssthresh={}",
                self.side, n_rtos, self.cwnd, self.ssthresh
            );
            self.inflight_queue.mark_all_to_retrasmit();
            self.awake_write_loop();
        }
        Timer::Reconfig => {
            self.will_retransmit_reconfig = true;
            self.awake_write_loop();
        }
        _ => {}
    }
}
/// A timer exhausted its retransmission budget.
///
/// Handshake timers (T1-init / T1-cookie) record a fatal
/// `HandshakeFailed` error on the association; T2-shutdown and T3-rtx
/// failures are only logged here.
fn on_retransmission_failure(&mut self, id: Timer) {
    match id {
        Timer::T1Init => {
            error!("[{}] retransmission failure: T1-init", self.side);
            self.error = Some(AssociationError::HandshakeFailed(
                Error::ErrHandshakeInitAck,
            ));
        }
        Timer::T1Cookie => {
            error!("[{}] retransmission failure: T1-cookie", self.side);
            self.error = Some(AssociationError::HandshakeFailed(
                Error::ErrHandshakeCookieEcho,
            ));
        }
        Timer::T2Shutdown => {
            error!("[{}] retransmission failure: T2-shutdown", self.side);
        }
        Timer::T3RTX => {
            error!("[{}] retransmission failure: T3-rtx (DATA)", self.side);
        }
        _ => {}
    }
}
#[cfg(test)]
/// Test helper: `true` when no timer is currently scheduled.
pub(crate) fn is_idle(&self) -> bool {
    // The original computed the minimum scheduled deadline and checked it
    // for `None`; the association is idle iff no timer is set, so test
    // that directly.
    Timer::VALUES.iter().all(|&t| self.timers.get(t).is_none())
}
}