use crate::messages::{
ApplicationMessage, DTLSMessageEvent, DataChannelEvent, DataChannelMessage,
DataChannelMessageParams, DataChannelMessageType, MessageEvent, TaggedMessageEvent,
};
use datachannel::message::{message_channel_ack::*, message_channel_open::*, message_type::*, *};
use log::{debug, error, warn};
use retty::channel::{Context, Handler};
use sctp::ReliabilityType;
use shared::error::Result;
use shared::marshal::*;
use std::collections::VecDeque;
#[derive(Default)]
pub struct DataChannelHandler {
transmits: VecDeque<TaggedMessageEvent>,
}
impl DataChannelHandler {
pub fn new() -> Self {
Self {
transmits: VecDeque::new(),
}
}
}
impl Handler for DataChannelHandler {
type Rin = TaggedMessageEvent;
type Rout = Self::Rin;
type Win = TaggedMessageEvent;
type Wout = Self::Win;
fn name(&self) -> &str {
"DataChannelHandler"
}
fn handle_read(
&mut self,
ctx: &Context<Self::Rin, Self::Rout, Self::Win, Self::Wout>,
msg: Self::Rin,
) {
if let MessageEvent::Dtls(DTLSMessageEvent::Sctp(message)) = msg.message {
debug!(
"recv SCTP DataChannelMessage from {:?}",
msg.transport.peer_addr
);
let try_read =
|| -> Result<(Option<ApplicationMessage>, Option<DataChannelMessage>)> {
if message.data_message_type == DataChannelMessageType::Control {
let mut buf = &message.payload[..];
if MessageType::unmarshal(&mut buf)? == MessageType::DataChannelOpen {
debug!("DataChannelOpen for association_handle {} and stream_id {} and data_message_type {:?}",
message.association_handle,
message.stream_id,
message.data_message_type);
let data_channel_open = DataChannelOpen::unmarshal(&mut buf)?;
let (unordered, reliability_type) =
get_reliability_params(data_channel_open.channel_type);
let payload = Message::DataChannelAck(DataChannelAck {}).marshal()?;
Ok((
Some(ApplicationMessage {
association_handle: message.association_handle,
stream_id: message.stream_id,
data_channel_event: DataChannelEvent::Open,
}),
Some(DataChannelMessage {
association_handle: message.association_handle,
stream_id: message.stream_id,
data_message_type: DataChannelMessageType::Control,
params: Some(DataChannelMessageParams {
unordered,
reliability_type,
reliability_parameter: data_channel_open
.reliability_parameter,
}),
payload,
}),
))
} else {
Ok((None, None))
}
} else {
Ok((
Some(ApplicationMessage {
association_handle: message.association_handle,
stream_id: message.stream_id,
data_channel_event: DataChannelEvent::Message(message.payload),
}),
None,
))
}
};
match try_read() {
Ok((inbound_message, outbound_message)) => {
if let Some(data_channel_message) = outbound_message {
debug!("send DataChannelAck message {:?}", msg.transport.peer_addr);
self.transmits.push_back(TaggedMessageEvent {
now: msg.now,
transport: msg.transport,
message: MessageEvent::Dtls(DTLSMessageEvent::Sctp(
data_channel_message,
)),
});
}
if let Some(application_message) = inbound_message {
debug!("recv application message {:?}", msg.transport.peer_addr);
ctx.fire_read(TaggedMessageEvent {
now: msg.now,
transport: msg.transport,
message: MessageEvent::Dtls(DTLSMessageEvent::DataChannel(
application_message,
)),
})
}
}
Err(err) => {
error!("try_read with error {}", err);
ctx.fire_exception(Box::new(err))
}
};
} else {
debug!("bypass DataChannel read {:?}", msg.transport.peer_addr);
ctx.fire_read(msg);
}
}
fn poll_write(
&mut self,
ctx: &Context<Self::Rin, Self::Rout, Self::Win, Self::Wout>,
) -> Option<Self::Wout> {
if let Some(msg) = ctx.fire_poll_write() {
if let MessageEvent::Dtls(DTLSMessageEvent::DataChannel(message)) = msg.message {
debug!("send application message {:?}", msg.transport.peer_addr);
if let DataChannelEvent::Message(payload) = message.data_channel_event {
self.transmits.push_back(TaggedMessageEvent {
now: msg.now,
transport: msg.transport,
message: MessageEvent::Dtls(DTLSMessageEvent::Sctp(DataChannelMessage {
association_handle: message.association_handle,
stream_id: message.stream_id,
data_message_type: DataChannelMessageType::Text,
params: None,
payload,
})),
});
} else {
warn!(
"drop unsupported DATACHANNEL message to {}",
msg.transport.peer_addr
);
}
} else {
debug!("bypass DataChannel write {:?}", msg.transport.peer_addr);
self.transmits.push_back(msg);
}
}
self.transmits.pop_front()
}
}
fn get_reliability_params(channel_type: ChannelType) -> (bool, ReliabilityType) {
let (unordered, reliability_type) = match channel_type {
ChannelType::Reliable => (false, ReliabilityType::Reliable),
ChannelType::ReliableUnordered => (true, ReliabilityType::Reliable),
ChannelType::PartialReliableRexmit => (false, ReliabilityType::Rexmit),
ChannelType::PartialReliableRexmitUnordered => (true, ReliabilityType::Rexmit),
ChannelType::PartialReliableTimed => (false, ReliabilityType::Timed),
ChannelType::PartialReliableTimedUnordered => (true, ReliabilityType::Timed),
};
(unordered, reliability_type)
}