#[cfg(test)]
mod data_channel_test;
use crate::message::{
message_channel_ack::*, message_channel_close::*, message_channel_open::*,
message_channel_threshold::*, *,
};
use bytes::{Buf, BytesMut};
use log::debug;
use sctp::{PayloadProtocolIdentifier, ReliabilityType};
use shared::error::{Error, Result};
use shared::marshal::*;
use std::collections::VecDeque;
/// Presumably the receive buffer size in bytes.
///
/// NOTE(review): not referenced anywhere in this part of the file; confirm whether it is
/// still used (e.g. by the test module) before relying on or removing it.
const RECEIVE_MTU: usize = 8192;
/// Parameters describing a data channel.
///
/// The same values are carried in the DATA_CHANNEL_OPEN message that `DataChannel::dial`
/// queues and that `DataChannel::accept` parses.
#[derive(Eq, PartialEq, Default, Clone, Debug)]
pub struct DataChannelConfig {
/// Reliability/ordering mode of the channel.
pub channel_type: ChannelType,
/// When `true`, `DataChannel::dial` does not queue a DATA_CHANNEL_OPEN message
/// (presumably because the channel was negotiated out of band).
pub negotiated: bool,
/// Priority value carried in the DATA_CHANNEL_OPEN message.
pub priority: u16,
/// Meaning depends on `channel_type`: the maximum retransmit count for the `*Rexmit*`
/// types or the maximum packet lifetime for the `*Timed*` types
/// (see `DataChannel::get_channel_type_and_reliability_parameter`, which sets it to 0
/// for the reliable types).
pub reliability_parameter: u32,
/// Channel label, sent as raw UTF-8 bytes in DATA_CHANNEL_OPEN.
pub label: String,
/// Sub-protocol name, sent as raw UTF-8 bytes in DATA_CHANNEL_OPEN.
pub protocol: String,
}
/// A single message exchanged between a `DataChannel` and the layer below it.
#[derive(Debug, Default, Clone)]
pub struct DataChannelMessage {
/// Association the message belongs to; `handle_write` overwrites it with the
/// channel's own handle.
pub association_handle: usize,
/// Stream identifier; `handle_write` overwrites it with the channel's own stream id.
pub stream_id: u16,
/// Payload protocol identifier, e.g. `Dcep` for control messages or one of the
/// string/binary (empty) identifiers for user data.
pub ppi: PayloadProtocolIdentifier,
/// Message body. Messages built by `DataChannel::get_data_channel_message` from empty
/// data carry a single zero byte here.
pub payload: BytesMut,
}
/// Sans-I/O state for one data channel bound to one association stream.
///
/// Inbound messages are fed in through `handle_read` and drained with `poll_read`;
/// outbound user data and DCEP control messages are queued and drained with `poll_write`.
#[derive(Debug, Default, Clone)]
pub struct DataChannel {
/// Channel parameters (label, protocol, reliability settings, ...).
config: DataChannelConfig,
/// Association this channel belongs to; stamped onto outgoing messages.
association_handle: usize,
/// Stream identifier; stamped onto outgoing messages.
stream_id: u16,
/// Inbound user messages waiting for `poll_read` (FIFO).
read_outs: VecDeque<DataChannelMessage>,
/// Outbound messages (user data and DCEP control) waiting for `poll_write` (FIFO).
write_outs: VecDeque<DataChannelMessage>,
/// Sent-message counter, maintained by `handle_write`.
messages_sent: usize,
/// Received-message counter, maintained by `handle_read`.
messages_received: usize,
/// Sent-byte counter, maintained by `handle_write`.
bytes_sent: usize,
/// Received-byte counter, maintained by `handle_read`.
bytes_received: usize,
}
impl DataChannel {
fn new(config: DataChannelConfig, association_handle: usize, stream_id: u16) -> Self {
Self {
config,
association_handle,
stream_id,
read_outs: VecDeque::new(),
write_outs: VecDeque::new(),
..Default::default()
}
}
pub fn dial(
config: DataChannelConfig,
association_handle: usize,
stream_id: u16,
) -> Result<Self> {
let mut data_channel = DataChannel::new(config.clone(), association_handle, stream_id);
if !config.negotiated {
let msg = Message::DataChannelOpen(DataChannelOpen {
channel_type: config.channel_type,
priority: config.priority,
reliability_parameter: config.reliability_parameter,
label: config.label.bytes().collect(),
protocol: config.protocol.bytes().collect(),
})
.marshal()?;
data_channel.write_outs.push_back(DataChannelMessage {
association_handle,
stream_id,
ppi: PayloadProtocolIdentifier::Dcep,
payload: msg,
});
}
Ok(data_channel)
}
pub fn accept(
mut config: DataChannelConfig,
association_handle: usize,
stream_id: u16,
ppi: PayloadProtocolIdentifier,
buf: &[u8],
) -> Result<Self> {
if ppi != PayloadProtocolIdentifier::Dcep {
return Err(Error::InvalidPayloadProtocolIdentifier(ppi as u8));
}
let mut read_buf = buf;
let msg = Message::unmarshal(&mut read_buf)?;
if let Message::DataChannelOpen(dco) = msg {
config.channel_type = dco.channel_type;
config.priority = dco.priority;
config.reliability_parameter = dco.reliability_parameter;
config.label = String::from_utf8(dco.label)?;
config.protocol = String::from_utf8(dco.protocol)?;
} else {
return Err(Error::InvalidMessageType(msg.message_type() as u8));
};
let mut data_channel = DataChannel::new(config, association_handle, stream_id);
data_channel.write_data_channel_ack()?;
Ok(data_channel)
}
pub fn messages_sent(&self) -> usize {
self.messages_sent
}
pub fn messages_received(&self) -> usize {
self.messages_received
}
pub fn bytes_sent(&self) -> usize {
self.bytes_sent
}
pub fn bytes_received(&self) -> usize {
self.bytes_received
}
pub fn association_handle(&self) -> usize {
self.association_handle
}
pub fn stream_identifier(&self) -> u16 {
self.stream_id
}
pub fn config(&self) -> &DataChannelConfig {
&self.config
}
fn handle_dcep<B>(&mut self, data: &mut B) -> Result<()>
where
B: Buf,
{
let msg = Message::unmarshal(data)?;
match msg {
Message::DataChannelOpen(_) => {
debug!("Received DATA_CHANNEL_OPEN");
self.write_data_channel_ack()?;
}
Message::DataChannelAck(_) => {
debug!("Received DATA_CHANNEL_ACK");
}
_ => {
return Err(Error::InvalidMessageType(msg.message_type() as u8));
}
};
Ok(())
}
fn write_data_channel_ack(&mut self) -> Result<()> {
let ack = Message::DataChannelAck(DataChannelAck {}).marshal()?;
self.write_outs.push_back(DataChannelMessage {
association_handle: self.association_handle,
stream_id: self.stream_id,
ppi: PayloadProtocolIdentifier::Dcep,
payload: ack,
});
Ok(())
}
fn write_data_channel_close(&mut self) -> Result<()> {
let close = Message::DataChannelClose(DataChannelClose {}).marshal()?;
self.write_outs.push_back(DataChannelMessage {
association_handle: self.association_handle,
stream_id: self.stream_id,
ppi: PayloadProtocolIdentifier::Dcep,
payload: close,
});
Ok(())
}
fn write_data_channel_high_threshold(&mut self, threshold: u32) -> Result<()> {
let low_threshold =
Message::DataChannelThreshold(DataChannelThreshold::High(threshold)).marshal()?;
self.write_outs.push_back(DataChannelMessage {
association_handle: self.association_handle,
stream_id: self.stream_id,
ppi: PayloadProtocolIdentifier::Dcep,
payload: low_threshold,
});
Ok(())
}
fn write_data_channel_low_threshold(&mut self, threshold: u32) -> Result<()> {
let low_threshold =
Message::DataChannelThreshold(DataChannelThreshold::Low(threshold)).marshal()?;
self.write_outs.push_back(DataChannelMessage {
association_handle: self.association_handle,
stream_id: self.stream_id,
ppi: PayloadProtocolIdentifier::Dcep,
payload: low_threshold,
});
Ok(())
}
pub fn set_buffered_amount_high_threshold(&mut self, threshold: u32) -> Result<()> {
self.write_data_channel_high_threshold(threshold)
}
pub fn set_buffered_amount_low_threshold(&mut self, threshold: u32) -> Result<()> {
self.write_data_channel_low_threshold(threshold)
}
pub fn get_reliability_params(channel_type: ChannelType) -> (bool, ReliabilityType) {
match channel_type {
ChannelType::Reliable => (false, ReliabilityType::Reliable),
ChannelType::ReliableUnordered => (true, ReliabilityType::Reliable),
ChannelType::PartialReliableRexmit => (false, ReliabilityType::Rexmit),
ChannelType::PartialReliableRexmitUnordered => (true, ReliabilityType::Rexmit),
ChannelType::PartialReliableTimed => (false, ReliabilityType::Timed),
ChannelType::PartialReliableTimedUnordered => (true, ReliabilityType::Timed),
}
}
pub fn get_channel_type_and_reliability_parameter(
ordered: bool,
max_retransmits: Option<u16>,
max_packet_life_time: Option<u16>,
) -> (ChannelType, u32) {
let channel_type;
let reliability_parameter;
match (max_retransmits, max_packet_life_time) {
(None, None) => {
reliability_parameter = 0u32;
if ordered {
channel_type = ChannelType::Reliable;
} else {
channel_type = ChannelType::ReliableUnordered;
}
}
(Some(max_retransmits), _) => {
reliability_parameter = max_retransmits as u32;
if ordered {
channel_type = ChannelType::PartialReliableRexmit;
} else {
channel_type = ChannelType::PartialReliableRexmitUnordered;
}
}
(None, Some(max_packet_lifetime)) => {
reliability_parameter = max_packet_lifetime as u32;
if ordered {
channel_type = ChannelType::PartialReliableTimed;
} else {
channel_type = ChannelType::PartialReliableTimedUnordered;
}
}
}
(channel_type, reliability_parameter)
}
pub fn get_data_channel_message(is_string: bool, data: BytesMut) -> DataChannelMessage {
let ppi = match (is_string, data.len()) {
(false, 0) => PayloadProtocolIdentifier::BinaryEmpty,
(false, _) => PayloadProtocolIdentifier::Binary,
(true, 0) => PayloadProtocolIdentifier::StringEmpty,
(true, _) => PayloadProtocolIdentifier::String,
};
if data.is_empty() {
DataChannelMessage {
ppi,
payload: BytesMut::from(&[0][..]),
..Default::default()
}
} else {
DataChannelMessage {
ppi,
payload: data,
..Default::default()
}
}
}
}
impl sansio::Protocol<DataChannelMessage, DataChannelMessage, ()> for DataChannel {
    type Rout = DataChannelMessage;
    type Wout = DataChannelMessage;
    type Eout = ();
    type Error = Error;
    type Time = ();

    /// Feeds a message received from the layer below into the channel.
    ///
    /// DCEP control messages are handled internally by `handle_dcep`. Like the control
    /// messages this channel emits, they are excluded from the message/byte statistics.
    /// User messages are counted and queued, unchanged, for `poll_read`.
    ///
    /// # Errors
    /// Propagates any error from `handle_dcep` for control messages.
    fn handle_read(&mut self, msg: DataChannelMessage) -> Result<()> {
        if msg.ppi == PayloadProtocolIdentifier::Dcep {
            let mut data_buf = &msg.payload[..];
            return self.handle_dcep(&mut data_buf);
        }

        // An empty user message travels as a single placeholder 0x00 byte tagged with an
        // *Empty PPI (RFC 8831 §6.6); it carries no payload bytes.
        let payload_len = if matches!(
            msg.ppi,
            PayloadProtocolIdentifier::BinaryEmpty | PayloadProtocolIdentifier::StringEmpty
        ) {
            0
        } else {
            msg.payload.len()
        };

        self.messages_received += 1;
        self.bytes_received += payload_len;
        self.read_outs.push_back(msg);
        Ok(())
    }

    /// Returns the oldest queued inbound user message, if any.
    fn poll_read(&mut self) -> Option<DataChannelMessage> {
        self.read_outs.pop_front()
    }

    /// Queues a user message for sending.
    ///
    /// The message is counted, then stamped with this channel's association handle and
    /// stream id, then appended to the write queue. The placeholder byte of an empty
    /// (*Empty PPI) message is not counted as payload.
    fn handle_write(&mut self, mut msg: DataChannelMessage) -> Result<()> {
        let payload_len = if matches!(
            msg.ppi,
            PayloadProtocolIdentifier::BinaryEmpty | PayloadProtocolIdentifier::StringEmpty
        ) {
            0
        } else {
            msg.payload.len()
        };

        self.messages_sent += 1;
        self.bytes_sent += payload_len;
        msg.association_handle = self.association_handle;
        msg.stream_id = self.stream_id;
        self.write_outs.push_back(msg);
        Ok(())
    }

    /// Returns the oldest queued outbound message (user data or DCEP control), if any.
    fn poll_write(&mut self) -> Option<DataChannelMessage> {
        self.write_outs.pop_front()
    }

    /// Closes the channel by queueing a DATA_CHANNEL_CLOSE control message.
    ///
    /// # Errors
    /// Returns the marshalling error if the control message cannot be encoded.
    fn close(&mut self) -> Result<()> {
        self.write_data_channel_close()
    }
}