use crate::association::Association;
use crate::association::state::AssociationState;
use crate::chunk::chunk_payload_data::{ChunkPayloadData, PayloadProtocolIdentifier};
use crate::queue::reassembly_queue::{Chunks, ReassemblyQueue};
use crate::{ErrorCauseCode, Event, Side};
use shared::error::{Error, Result};
use crate::util::{ByteSlice, BytesArray, BytesSource};
use bytes::Bytes;
use log::{debug, error, trace};
use std::fmt;
/// Identifier of a stream within an SCTP association.
pub type StreamId = u16;

/// Stream-level notifications delivered through the association's event queue
/// (wrapped in `Event::Stream`).
///
/// NOTE(review): within this file only `BufferedAmountHigh` is emitted (by
/// `Stream::write_source`). The other variants are presumably pushed by the
/// association; confirm their exact triggers there.
#[derive(Debug, PartialEq, Eq)]
pub enum StreamEvent {
    /// A stream became open (presumably a newly accepted stream).
    Opened { id: StreamId },
    /// The stream presumably has data ready to be read.
    Readable { id: StreamId },
    /// The stream presumably accepts writes again.
    Writable { id: StreamId },
    /// The stream presumably finished (peer closed its write side).
    Finished { id: StreamId },
    /// The stream was stopped, with the cause code carried alongside.
    Stopped {
        id: StreamId,
        error_code: ErrorCauseCode,
    },
    /// A stream presumably became available to open.
    Available,
    /// Presumably raised when `StreamState::on_buffer_released` reports that the
    /// buffered amount dropped to or below the low-water mark.
    BufferedAmountLow { id: StreamId },
    /// The buffered amount rose to or above the high-water mark during a write.
    BufferedAmountHigh { id: StreamId },
}
/// Reliability policy configured on a stream via `Stream::set_reliability_params`.
///
/// The explicit discriminants match the values decoded by `From<u8>`.
/// NOTE(review): how the accompanying `reliability_value` is interpreted per
/// variant is decided outside this file; confirm there.
///
/// Derives `Eq` (in addition to `PartialEq`) for consistency with the other
/// fieldless enums in this module, e.g. `RecvSendState`.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)]
pub enum ReliabilityType {
    /// Fully reliable delivery; the default.
    #[default]
    Reliable = 0,
    /// Presumably a retransmission-count-limited policy.
    Rexmit = 1,
    /// Presumably a time-limited policy.
    Timed = 2,
}
impl fmt::Display for ReliabilityType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let s = match *self {
ReliabilityType::Reliable => "Reliable",
ReliabilityType::Rexmit => "Rexmit",
ReliabilityType::Timed => "Timed",
};
write!(f, "{}", s)
}
}
impl From<u8> for ReliabilityType {
fn from(v: u8) -> ReliabilityType {
match v {
1 => ReliabilityType::Rexmit,
2 => ReliabilityType::Timed,
_ => ReliabilityType::Reliable,
}
}
}
/// Short-lived handle for operating on one stream of an association.
///
/// The handle holds no per-stream data itself: every operation looks up the
/// stream's `StreamState` in `association.streams` by `stream_identifier`.
pub struct Stream<'a> {
    /// Key of this stream in the association's stream table.
    pub(crate) stream_identifier: StreamId,
    /// The owning association, borrowed mutably for the handle's lifetime.
    pub(crate) association: &'a mut Association,
}
impl Stream<'_> {
pub fn read(&mut self) -> Result<Option<Chunks>> {
self.read_sctp()
}
pub fn read_sctp(&mut self) -> Result<Option<Chunks>> {
if let Some(s) = self.association.streams.get_mut(&self.stream_identifier)
&& (s.state == RecvSendState::ReadWritable || s.state == RecvSendState::Readable)
{
Ok(s.reassembly_queue.read())
} else {
Err(Error::ErrStreamClosed)
}
}
pub fn write_sctp(&mut self, p: &Bytes, ppi: PayloadProtocolIdentifier) -> Result<usize> {
self.write_source(&mut ByteSlice::from_slice(p), ppi)
}
pub fn write(&mut self, data: &[u8]) -> Result<usize> {
self.write_with_ppi(data, self.get_default_payload_type()?)
}
pub fn write_with_ppi(&mut self, data: &[u8], ppi: PayloadProtocolIdentifier) -> Result<usize> {
self.write_source(&mut ByteSlice::from_slice(data), ppi)
}
pub fn write_chunk(&mut self, p: &Bytes) -> Result<usize> {
self.write_source(
&mut ByteSlice::from_slice(p),
self.get_default_payload_type()?,
)
}
pub fn write_chunks(&mut self, data: &mut [Bytes]) -> Result<usize> {
self.write_source(
&mut BytesArray::from_chunks(data),
self.get_default_payload_type()?,
)
}
fn write_source<B: BytesSource>(
&mut self,
source: &mut B,
ppi: PayloadProtocolIdentifier,
) -> Result<usize> {
if !self.is_writable() {
return Err(Error::ErrStreamClosed);
}
if source.remaining() > self.association.max_message_size() as usize {
return Err(Error::ErrOutboundPacketTooLarge);
}
let state: AssociationState = self.association.state();
match state {
AssociationState::ShutdownSent
| AssociationState::ShutdownAckSent
| AssociationState::ShutdownPending
| AssociationState::ShutdownReceived => return Err(Error::ErrStreamClosed),
_ => {}
};
let (p, _) = source.pop_chunk(self.association.max_message_size() as usize);
if let Some(s) = self.association.streams.get_mut(&self.stream_identifier) {
let (is_buffered_amount_high, chunks) = s.packetize(&p, ppi);
if is_buffered_amount_high {
trace!("StreamEvent::BufferedAmountHigh");
self.association
.events
.push_back(Event::Stream(StreamEvent::BufferedAmountHigh {
id: self.stream_identifier,
}))
}
self.association.send_payload_data(chunks)?;
Ok(p.len())
} else {
Err(Error::ErrStreamClosed)
}
}
pub fn is_readable(&self) -> bool {
if let Some(s) = self.association.streams.get(&self.stream_identifier) {
s.state == RecvSendState::Readable || s.state == RecvSendState::ReadWritable
} else {
false
}
}
pub fn is_writable(&self) -> bool {
if let Some(s) = self.association.streams.get(&self.stream_identifier) {
s.state == RecvSendState::Writable || s.state == RecvSendState::ReadWritable
} else {
false
}
}
pub fn stop(&mut self) -> Result<()> {
let mut reset = false;
if let Some(s) = self.association.streams.get_mut(&self.stream_identifier) {
if s.state == RecvSendState::Readable || s.state == RecvSendState::ReadWritable {
reset = true;
}
s.state = ((s.state as u8) & 0x2).into();
}
if reset {
self.association
.send_reset_request(self.stream_identifier)?;
}
Ok(())
}
pub fn finish(&mut self) -> Result<()> {
if let Some(s) = self.association.streams.get_mut(&self.stream_identifier) {
s.state = ((s.state as u8) & 0x1).into();
}
Ok(())
}
pub fn close(&mut self) -> Result<()> {
self.finish()?;
self.stop()
}
pub fn stream_identifier(&self) -> StreamId {
self.stream_identifier
}
pub fn set_default_payload_type(
&mut self,
default_payload_type: PayloadProtocolIdentifier,
) -> Result<()> {
if let Some(s) = self.association.streams.get_mut(&self.stream_identifier) {
s.default_payload_type = default_payload_type;
Ok(())
} else {
Err(Error::ErrStreamClosed)
}
}
pub fn get_default_payload_type(&self) -> Result<PayloadProtocolIdentifier> {
if let Some(s) = self.association.streams.get(&self.stream_identifier) {
Ok(s.default_payload_type)
} else {
Err(Error::ErrStreamClosed)
}
}
pub fn set_reliability_params(
&mut self,
unordered: bool,
rel_type: ReliabilityType,
rel_val: u32,
) -> Result<()> {
if let Some(s) = self.association.streams.get_mut(&self.stream_identifier) {
debug!(
"[{}] reliability params: ordered={} type={} value={}",
s.side, !unordered, rel_type, rel_val
);
s.unordered = unordered;
s.reliability_type = rel_type;
s.reliability_value = rel_val;
Ok(())
} else {
Err(Error::ErrStreamClosed)
}
}
pub fn buffered_amount(&self) -> Result<usize> {
if let Some(s) = self.association.streams.get(&self.stream_identifier) {
Ok(s.buffered_amount)
} else {
Err(Error::ErrStreamClosed)
}
}
pub fn buffered_amount_low_threshold(&self) -> Result<usize> {
if let Some(s) = self.association.streams.get(&self.stream_identifier) {
Ok(s.buffered_amount_low)
} else {
Err(Error::ErrStreamClosed)
}
}
pub fn set_buffered_amount_low_threshold(&mut self, th: usize) -> Result<()> {
if let Some(s) = self.association.streams.get_mut(&self.stream_identifier) {
s.buffered_amount_low = th;
Ok(())
} else {
Err(Error::ErrStreamClosed)
}
}
pub fn buffered_amount_high_threshold(&self) -> Result<usize> {
if let Some(s) = self.association.streams.get(&self.stream_identifier) {
Ok(s.buffered_amount_high)
} else {
Err(Error::ErrStreamClosed)
}
}
pub fn set_buffered_amount_high_threshold(&mut self, th: usize) -> Result<()> {
if let Some(s) = self.association.streams.get_mut(&self.stream_identifier) {
s.buffered_amount_high = th;
Ok(())
} else {
Err(Error::ErrStreamClosed)
}
}
}
/// Read/write availability of a stream, encoded as a two-bit mask.
///
/// Bit `0b01` means readable and bit `0b10` means writable. `Stream::stop` and
/// `Stream::finish` clear one bit by masking the discriminant and converting
/// back with `From<u8>`.
#[derive(Default, Debug, Copy, Clone, Eq, PartialEq)]
pub enum RecvSendState {
    /// Neither readable nor writable. This is the `Default`.
    #[default]
    Closed = 0b00,
    /// Read side open, write side closed.
    Readable = 0b01,
    /// Write side open, read side closed.
    Writable = 0b10,
    /// Both sides open.
    ReadWritable = 0b11,
}
impl From<u8> for RecvSendState {
fn from(v: u8) -> Self {
match v {
1 => RecvSendState::Readable,
2 => RecvSendState::Writable,
3 => RecvSendState::ReadWritable,
_ => RecvSendState::Closed,
}
}
}
/// Per-stream bookkeeping stored in the association's `streams` table and
/// accessed through a `Stream` handle.
#[derive(Default, Debug)]
pub struct StreamState {
    /// Association side; used here as the log-line prefix.
    pub(crate) side: Side,
    /// Maximum user-data bytes per DATA chunk when fragmenting outgoing messages.
    pub(crate) max_payload_size: u32,
    /// This stream's identifier.
    pub(crate) stream_identifier: StreamId,
    /// PPI used by the `write*` helpers that take no explicit PPI.
    pub(crate) default_payload_type: PayloadProtocolIdentifier,
    /// Incoming chunks awaiting reassembly.
    pub(crate) reassembly_queue: ReassemblyQueue,
    /// SSN stamped on the next outgoing ordered message.
    pub(crate) sequence_number: u16,
    /// Read/write availability bits.
    pub(crate) state: RecvSendState,
    /// Whether outgoing (non-DCEP) messages are sent unordered.
    pub(crate) unordered: bool,
    /// Reliability policy set via `Stream::set_reliability_params`.
    pub(crate) reliability_type: ReliabilityType,
    /// Parameter for `reliability_type`; its meaning depends on the policy.
    pub(crate) reliability_value: u32,
    /// Bytes added by `packetize` and not yet released via `on_buffer_released`.
    pub(crate) buffered_amount: usize,
    /// Low-water mark checked by `on_buffer_released`.
    pub(crate) buffered_amount_low: usize,
    /// High-water mark checked by `packetize`.
    pub(crate) buffered_amount_high: usize,
}
impl StreamState {
    /// Creates the state for a freshly opened stream.
    ///
    /// The stream starts `ReadWritable`, ordered, `Reliable`, with SSN 0, nothing
    /// buffered, a low-water mark of 0 and a high-water mark of `u32::MAX`.
    pub(crate) fn new(
        side: Side,
        stream_identifier: StreamId,
        max_payload_size: u32,
        default_payload_type: PayloadProtocolIdentifier,
    ) -> Self {
        StreamState {
            side,
            stream_identifier,
            max_payload_size,
            default_payload_type,
            reassembly_queue: ReassemblyQueue::new(stream_identifier),
            sequence_number: 0,
            state: RecvSendState::ReadWritable,
            unordered: false,
            reliability_type: ReliabilityType::Reliable,
            reliability_value: 0,
            buffered_amount: 0,
            buffered_amount_low: 0,
            buffered_amount_high: u32::MAX as usize,
        }
    }

    /// Pushes a copy of an incoming DATA chunk into the reassembly queue.
    ///
    /// Returns whatever `ReassemblyQueue::push` reports (presumably whether the
    /// chunk was accepted; confirm against the queue implementation).
    pub(crate) fn handle_data(&mut self, pd: &ChunkPayloadData) -> bool {
        self.reassembly_queue.push(pd.clone())
    }

    /// Forwards an ordered forward-TSN skip up to `ssn` to the reassembly queue.
    /// Ignored when this stream is configured as unordered.
    pub(crate) fn handle_forward_tsn_for_ordered(&mut self, ssn: u16) {
        if self.unordered {
            return;
        }
        self.reassembly_queue.forward_tsn_for_ordered(ssn);
    }

    /// Forwards an unordered forward-TSN skip up to `new_cumulative_tsn` to the
    /// reassembly queue. Ignored when this stream is configured as ordered.
    pub(crate) fn handle_forward_tsn_for_unordered(&mut self, new_cumulative_tsn: u32) {
        if !self.unordered {
            return;
        }
        self.reassembly_queue
            .forward_tsn_for_unordered(new_cumulative_tsn);
    }

    /// Splits `raw` into DATA chunks of at most `max_payload_size` bytes each.
    ///
    /// Fragments are zero-copy `Bytes::slice`s of `raw`. All fragments share
    /// this message's SSN. The first fragment is marked `beginning_fragment`
    /// and the last one `ending_fragment`.
    ///
    /// Side effects:
    /// - advances the SSN for ordered messages that produced at least one chunk;
    /// - adds `raw.len()` to `buffered_amount`.
    ///
    /// Returns `(crossed_high, chunks)`. `crossed_high` is true iff this call
    /// moved `buffered_amount` from below `buffered_amount_high` to at or above it.
    fn packetize(
        &mut self,
        raw: &Bytes,
        ppi: PayloadProtocolIdentifier,
    ) -> (bool, Vec<ChunkPayloadData>) {
        // DCEP messages are always sent ordered (RFC 8832 §6), regardless of
        // the stream's configured ordering.
        let unordered = ppi != PayloadProtocolIdentifier::Dcep && self.unordered;

        // A fragment size of 0 would never make progress; clamp it to 1 byte.
        let fragment_limit = (self.max_payload_size as usize).max(1);
        let total = raw.len();

        let chunks: Vec<ChunkPayloadData> = (0..total)
            .step_by(fragment_limit)
            .map(|start| {
                // Compute the length first so `start + len` cannot overflow.
                let end = start + fragment_limit.min(total - start);
                ChunkPayloadData {
                    stream_identifier: self.stream_identifier,
                    user_data: raw.slice(start..end),
                    unordered,
                    beginning_fragment: start == 0,
                    ending_fragment: end == total,
                    immediate_sack: false,
                    payload_type: ppi,
                    stream_sequence_number: self.sequence_number,
                    abandoned: false,
                    all_inflight: false,
                    ..Default::default()
                }
            })
            .collect();

        // RFC 4960 §6.6: the SSN is not advanced for unordered (U=1) chunks.
        // It is also not advanced when nothing was packetized (empty `raw`).
        // Consuming an SSN without sending a chunk would leave a gap in the
        // ordered sequence that the peer could never fill.
        if !unordered && !chunks.is_empty() {
            self.sequence_number = self.sequence_number.wrapping_add(1);
        }

        let old_amount = self.buffered_amount;
        let n_bytes_added = total;
        self.buffered_amount += n_bytes_added;
        let new_amount = self.buffered_amount;
        trace!(
            "[{}] new_amount = {}, old_amount = {}, buffered_amount_high = {}, n_bytes_added = {}",
            self.side, new_amount, old_amount, self.buffered_amount_high, n_bytes_added,
        );

        // Report only the upward crossing, so one write raises at most one signal.
        let is_buffered_amount_high =
            old_amount < self.buffered_amount_high && new_amount >= self.buffered_amount_high;
        (is_buffered_amount_high, chunks)
    }

    /// Subtracts `n_bytes_released` from `buffered_amount`.
    ///
    /// Non-positive values are ignored. Releasing more than is buffered logs an
    /// error and clamps `buffered_amount` to 0.
    ///
    /// Returns `true` iff this call moved `buffered_amount` from above
    /// `buffered_amount_low` to at or below it.
    pub(crate) fn on_buffer_released(&mut self, n_bytes_released: i64) -> bool {
        if n_bytes_released <= 0 {
            return false;
        }
        // Saturate instead of silently truncating where usize is narrower than i64.
        let released = usize::try_from(n_bytes_released).unwrap_or(usize::MAX);

        let old_amount = self.buffered_amount;
        let new_amount = old_amount.checked_sub(released).unwrap_or_else(|| {
            // Report the actual bound that was exceeded (the amount buffered).
            error!(
                "[{}] released buffer size {} should be <= {}",
                self.side, n_bytes_released, old_amount,
            );
            0
        });
        self.buffered_amount = new_amount;

        trace!(
            "[{}] new_amount = {}, old_amount = {}, buffered_amount_low = {}, n_bytes_released = {}",
            self.side, new_amount, old_amount, self.buffered_amount_low, n_bytes_released,
        );
        old_amount > self.buffered_amount_low && new_amount <= self.buffered_amount_low
    }

    /// Number of bytes currently held in the reassembly queue.
    pub(crate) fn get_num_bytes_in_reassembly_queue(&self) -> usize {
        self.reassembly_queue.get_num_bytes()
    }
}