use alloc::vec;
use alloc::vec::Vec;
use core::ops::Range;
use stun_proto::agent::Transmit;
use stun_types::message::{Message, MessageHeader};
use tracing::{debug, trace};
use crate::channel::ChannelData;
/// A complete STUN message or ChannelData frame extracted from incoming TCP data.
///
/// The `Complete*` variants refer to bytes that still live inside the original
/// [`Transmit`]; the `Stored*` variants own a copy of the frame that was taken
/// out of the internal reassembly buffer.
#[derive(Debug)]
pub enum IncomingTcp<T: AsRef<[u8]> + core::fmt::Debug> {
/// The transmit starts with a complete STUN message. The range selects the
/// message bytes inside the transmit's data.
CompleteMessage(Transmit<T>, Range<usize>),
/// The transmit starts with a complete ChannelData frame. The range selects
/// the frame bytes (header included) inside the transmit's data.
CompleteChannel(Transmit<T>, Range<usize>),
/// A STUN message taken from previously buffered data, together with the
/// transmit that was passed in when the message became available.
StoredMessage(Vec<u8>, Transmit<T>),
/// A ChannelData frame taken from previously buffered data, together with
/// the transmit that was passed in when the frame became available.
StoredChannel(Vec<u8>, Transmit<T>),
}
impl<T: AsRef<[u8]> + core::fmt::Debug> IncomingTcp<T> {
    /// The raw bytes of the frame (STUN message or ChannelData, header included).
    ///
    /// For the `Complete*` variants this is the sub-slice of the transmit's data
    /// selected by the stored range; for the `Stored*` variants it is the owned
    /// copy of the frame.
    pub fn data(&self) -> &[u8] {
        match self {
            Self::CompleteMessage(transmit, span) | Self::CompleteChannel(transmit, span) => {
                let bytes: &[u8] = transmit.data.as_ref();
                &bytes[span.start..span.end]
            }
            Self::StoredMessage(bytes, _) | Self::StoredChannel(bytes, _) => bytes.as_slice(),
        }
    }

    /// Parses the frame as a STUN [`Message`].
    ///
    /// Returns `None` for the channel variants, or when the bytes fail to parse.
    pub fn message(&self) -> Option<Message<'_>> {
        match self {
            Self::CompleteMessage(..) | Self::StoredMessage(..) => {
                Message::from_bytes(self.data()).ok()
            }
            Self::CompleteChannel(..) | Self::StoredChannel(..) => None,
        }
    }

    /// Parses the frame as [`ChannelData`].
    ///
    /// Returns `None` for the message variants, or when the bytes fail to parse.
    pub fn channel(&self) -> Option<ChannelData<'_>> {
        match self {
            Self::CompleteChannel(..) | Self::StoredChannel(..) => {
                ChannelData::parse(self.data()).ok()
            }
            Self::CompleteMessage(..) | Self::StoredMessage(..) => None,
        }
    }
}
/// Exposes the frame bytes, exactly as returned by [`IncomingTcp::data`].
impl<T: AsRef<[u8]> + core::fmt::Debug> AsRef<[u8]> for IncomingTcp<T> {
    fn as_ref(&self) -> &[u8] {
        Self::data(self)
    }
}
/// A complete frame that was taken out of the internal reassembly buffer.
///
/// Each variant owns the bytes of the whole frame, header included.
#[derive(Debug)]
pub enum StoredTcp {
/// The bytes of a STUN message.
Message(Vec<u8>),
/// The bytes of a ChannelData frame.
Channel(Vec<u8>),
}
impl StoredTcp {
    /// The raw bytes of the stored frame, whichever kind it is.
    pub fn data(&self) -> &[u8] {
        match self {
            Self::Message(bytes) | Self::Channel(bytes) => bytes.as_slice(),
        }
    }

    /// Pairs the stored frame with `transmit`, producing the matching
    /// `Stored*` variant of [`IncomingTcp`].
    fn into_incoming<T: AsRef<[u8]> + core::fmt::Debug>(
        self,
        transmit: Transmit<T>,
    ) -> IncomingTcp<T> {
        match self {
            Self::Message(bytes) => IncomingTcp::StoredMessage(bytes, transmit),
            Self::Channel(bytes) => IncomingTcp::StoredChannel(bytes, transmit),
        }
    }
}
/// Exposes the frame bytes, exactly as returned by [`StoredTcp::data`].
impl AsRef<[u8]> for StoredTcp {
    fn as_ref(&self) -> &[u8] {
        Self::data(self)
    }
}
/// Reassembles STUN messages and ChannelData frames from a TCP byte stream.
///
/// Bytes that do not yet form a complete frame are kept until more data arrives.
#[derive(Debug, Default)]
pub struct TurnTcpBuffer {
// Received bytes that have not yet been handed out as part of a complete frame.
tcp_buffer: Vec<u8>,
}
impl TurnTcpBuffer {
    /// Creates a buffer with no pending bytes.
    pub fn new() -> Self {
        Self { tcp_buffer: vec![] }
    }

    /// Feeds one chunk of received TCP data into the buffer.
    ///
    /// If nothing is buffered and `transmit` starts with a complete STUN message
    /// or ChannelData frame, the frame is returned without copying as
    /// [`IncomingTcp::CompleteMessage`] / [`IncomingTcp::CompleteChannel`]; any
    /// bytes following that frame are copied into the internal buffer and can be
    /// retrieved with [`TurnTcpBuffer::poll_recv`].
    ///
    /// Otherwise the data is appended to the internal buffer and a single
    /// [`TurnTcpBuffer::poll_recv`] is attempted; a resulting frame is returned
    /// as [`IncomingTcp::StoredMessage`] / [`IncomingTcp::StoredChannel`].
    ///
    /// Returns `None` when no complete frame is available yet.
    ///
    /// NOTE(review): a ChannelData frame is treated as exactly `4 + length`
    /// bytes. RFC 8656 section 12.5 requires ChannelData sent over TCP to be
    /// padded to a multiple of four bytes; no padding is skipped here. Confirm
    /// whether padding is handled elsewhere.
    #[tracing::instrument(
        level = "trace",
        skip(self, transmit),
        fields(
            transmit.data_len = transmit.data.as_ref().len(),
            from = ?transmit.from
        )
    )]
    pub fn incoming_tcp<T: AsRef<[u8]> + core::fmt::Debug>(
        &mut self,
        transmit: Transmit<T>,
    ) -> Option<IncomingTcp<T>> {
        if self.tcp_buffer.is_empty() {
            // Fast path: try to hand out a frame straight from `transmit`
            // without copying it into the internal buffer first.
            let data = transmit.data.as_ref();
            trace!("Trying to parse incoming data as a complete message/channel");
            let Ok(hdr) = MessageHeader::from_bytes(data) else {
                let Ok(channel) = ChannelData::parse(data) else {
                    // Neither a STUN header nor a complete ChannelData frame:
                    // keep everything and wait for more bytes.
                    self.tcp_buffer.extend_from_slice(data);
                    return None;
                };
                let channel_len = 4 + channel.data().len();
                debug!(
                    channel.id = channel.id(),
                    channel.len = channel_len - 4,
                    "Incoming data contains a channel",
                );
                if channel_len < data.len() {
                    // Keep whatever follows the frame for a later `poll_recv`.
                    self.tcp_buffer.extend_from_slice(&data[channel_len..]);
                }
                return Some(IncomingTcp::CompleteChannel(transmit, 0..channel_len));
            };
            let msg_len = MessageHeader::LENGTH + hdr.data_length() as usize;
            debug!(
                msg.transaction = %hdr.transaction_id(),
                msg.len = msg_len,
                "Incoming data contains a message",
            );
            if data.len() < msg_len {
                // Header is valid but the body is still incomplete.
                self.tcp_buffer.extend_from_slice(data);
                return None;
            }
            if msg_len < data.len() {
                // Keep whatever follows the message for a later `poll_recv`.
                self.tcp_buffer.extend_from_slice(&data[msg_len..]);
            }
            return Some(IncomingTcp::CompleteMessage(transmit, 0..msg_len));
        }

        // Slow path: there is already pending data, so the new bytes must be
        // appended before a frame can be recognised.
        self.tcp_buffer.extend_from_slice(transmit.data.as_ref());
        self.poll_recv().map(|recv| recv.into_incoming(transmit))
    }

    /// Tries to extract one complete frame from the buffered bytes.
    ///
    /// Returns the frame (header included) as an owned [`StoredTcp`] and removes
    /// it from the buffer, or returns `None` if the buffered bytes do not yet
    /// contain a complete STUN message or ChannelData frame. Call repeatedly to
    /// drain several buffered frames.
    ///
    /// NOTE(review): as in [`TurnTcpBuffer::incoming_tcp`], a ChannelData frame
    /// is treated as exactly `4 + length` bytes and no TCP padding is skipped.
    /// Confirm whether padding is handled elsewhere.
    #[tracing::instrument(
        level = "trace",
        skip(self),
        fields(
            buffered_len = self.tcp_buffer.len(),
        )
    )]
    pub fn poll_recv(&mut self) -> Option<StoredTcp> {
        let Ok(hdr) = MessageHeader::from_bytes(&self.tcp_buffer) else {
            let Ok((id, channel_data_len)) = ChannelData::parse_header(&self.tcp_buffer) else {
                trace!(
                    buffered.len = self.tcp_buffer.len(),
                    "cannot parse stored data"
                );
                return None;
            };
            let channel_len = 4 + channel_data_len;
            if self.tcp_buffer.len() < channel_len {
                trace!(
                    buffered.len = self.tcp_buffer.len(),
                    required = channel_len,
                    "need more bytes to complete channel data"
                );
                return None;
            }
            let remaining_len = self.tcp_buffer.len() - channel_len;
            debug!(
                channel.id = id,
                channel.len = channel_data_len,
                remaining = remaining_len,
                "buffered data contains a channel",
            );
            return Some(StoredTcp::Channel(self.take_front(channel_len)));
        };
        let msg_len = MessageHeader::LENGTH + hdr.data_length() as usize;
        if self.tcp_buffer.len() < msg_len {
            trace!(
                buffered.len = self.tcp_buffer.len(),
                required = msg_len,
                "need more bytes to complete STUN message"
            );
            return None;
        }
        let remaining_len = self.tcp_buffer.len() - msg_len;
        debug!(
            msg.transaction = %hdr.transaction_id(),
            msg.len = msg_len,
            remaining = remaining_len,
            "stored data contains a message",
        );
        Some(StoredTcp::Message(self.take_front(msg_len)))
    }

    /// Removes the first `len` bytes from the buffer and returns them.
    ///
    /// Draining moves the remaining bytes to the front in place, so only the
    /// returned frame is allocated. Callers must ensure `len` does not exceed
    /// the buffered length.
    fn take_front(&mut self, len: usize) -> Vec<u8> {
        self.tcp_buffer.drain(..len).collect()
    }

    /// Consumes the buffer and returns the bytes that are still pending.
    pub fn into_inner(self) -> Vec<u8> {
        self.tcp_buffer
    }

    /// The number of bytes currently pending in the buffer.
    pub fn len(&self) -> usize {
        self.tcp_buffer.len()
    }

    /// Whether no bytes are currently pending in the buffer.
    pub fn is_empty(&self) -> bool {
        self.tcp_buffer.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use core::net::SocketAddr;

    use stun_types::{
        attribute::Software,
        message::{Message, MessageWriteVec},
        prelude::{MessageWrite, MessageWriteExt},
        TransportType,
    };
    use tracing::info;

    use crate::message::ALLOCATE;

    use super::*;

    /// Fixed `(local, remote)` address pair used by every test.
    fn generate_addresses() -> (SocketAddr, SocketAddr) {
        let local: SocketAddr = "192.168.0.1:1000".parse().unwrap();
        let remote: SocketAddr = "10.0.0.2:2000".parse().unwrap();
        (local, remote)
    }

    /// Serialized ALLOCATE request carrying a SOFTWARE attribute and a fingerprint.
    fn generate_message() -> Vec<u8> {
        let mut builder = Message::builder_request(ALLOCATE, MessageWriteVec::new());
        let software = Software::new("turn-types").unwrap();
        builder.add_attribute(&software).unwrap();
        builder.add_fingerprint().unwrap();
        builder.finish()
    }

    /// The message from `generate_message` wrapped in a ChannelData frame on channel 0x4000.
    fn generate_message_in_channel() -> Vec<u8> {
        let payload = generate_message();
        let frame = ChannelData::new(0x4000, &payload);
        let mut framed = vec![0; payload.len() + 4];
        frame.write_into_unchecked(&mut framed);
        framed
    }

    /// Wraps `payload` in a TCP transmit from the remote to the local test address.
    fn tcp_transmit<T: AsRef<[u8]> + core::fmt::Debug>(payload: T) -> Transmit<T> {
        let (local_addr, remote_addr) = generate_addresses();
        Transmit::new(payload, TransportType::Tcp, remote_addr, local_addr)
    }

    /// Asserts through every accessor that `tcp` holds no pending bytes.
    fn assert_drained(tcp: TurnTcpBuffer) {
        assert!(tcp.is_empty());
        assert_eq!(tcp.len(), 0);
        assert!(tcp.into_inner().is_empty());
    }

    /// Feeds all but the last byte of `frame` one byte at a time.
    ///
    /// After every byte the buffered contents are checked, then replayed into a
    /// fresh buffer in a single call, which must also not produce a frame.
    /// Returns the buffer holding everything except the final byte.
    fn feed_all_but_last_byte(frame: &[u8]) -> TurnTcpBuffer {
        let mut tcp = TurnTcpBuffer::new();
        for fed in 1..frame.len() {
            let ret = tcp.incoming_tcp(tcp_transmit(&frame[fed - 1..fed]));
            assert!(ret.is_none());
            let buffered = tcp.into_inner();
            assert_eq!(&buffered, &frame[..fed]);
            tcp = TurnTcpBuffer::new();
            let ret = tcp.incoming_tcp(tcp_transmit(&buffered));
            assert!(ret.is_none());
            assert!(!tcp.is_empty());
            assert_eq!(tcp.len(), fed);
        }
        tcp
    }

    #[test]
    fn test_incoming_tcp_complete_message() {
        let _init = crate::tests::test_init_log();
        let mut tcp = TurnTcpBuffer::new();
        let msg = generate_message();
        let ret = tcp.incoming_tcp(tcp_transmit(msg.clone())).unwrap();
        assert!(matches!(ret, IncomingTcp::CompleteMessage(_, _)));
        assert_eq!(ret.data(), &msg);
        assert_eq!(ret.as_ref(), &msg);
        assert!(ret.message().is_some());
        assert_drained(tcp);
    }

    #[test]
    fn test_incoming_tcp_complete_message_in_channel() {
        let _init = crate::tests::test_init_log();
        let mut tcp = TurnTcpBuffer::new();
        let msg = generate_message_in_channel();
        let ret = tcp.incoming_tcp(tcp_transmit(msg.clone())).unwrap();
        assert!(matches!(ret, IncomingTcp::CompleteChannel(_, _)));
        assert_eq!(ret.data(), &msg);
        assert_eq!(ret.as_ref(), &msg);
        assert!(ret.channel().is_some());
        assert_drained(tcp);
    }

    #[test]
    fn test_incoming_tcp_partial_message() {
        let _init = crate::tests::test_init_log();
        let msg = generate_message();
        info!("message: {msg:x?}");
        let mut tcp = feed_all_but_last_byte(&msg);
        // The final byte completes the message, which comes from the buffer.
        let ret = tcp
            .incoming_tcp(tcp_transmit(&msg[msg.len() - 1..]))
            .unwrap();
        assert_eq!(ret.data(), &msg);
        assert_eq!(ret.as_ref(), &msg);
        assert!(ret.message().is_some());
        let IncomingTcp::StoredMessage(produced, _) = ret else {
            unreachable!();
        };
        assert_eq!(produced, msg);
        assert_drained(tcp);
    }

    #[test]
    fn test_incoming_tcp_partial_channel() {
        let _init = crate::tests::test_init_log();
        let channel = generate_message_in_channel();
        info!("message: {channel:x?}");
        let mut tcp = feed_all_but_last_byte(&channel);
        // The final byte completes the frame, which comes from the buffer.
        let ret = tcp
            .incoming_tcp(tcp_transmit(&channel[channel.len() - 1..]))
            .unwrap();
        assert_eq!(ret.data(), &channel);
        assert_eq!(ret.as_ref(), &channel);
        assert!(ret.channel().is_some());
        let IncomingTcp::StoredChannel(produced, _) = ret else {
            unreachable!()
        };
        assert_eq!(produced, channel);
        assert!(tcp.into_inner().is_empty());
    }

    #[test]
    fn test_incoming_tcp_message_and_channel() {
        let _init = crate::tests::test_init_log();
        let mut tcp = TurnTcpBuffer::new();
        let msg = generate_message();
        let channel = generate_message_in_channel();
        // One transmit carrying a message immediately followed by a channel frame.
        let input: Vec<u8> = msg.iter().chain(channel.iter()).copied().collect();
        let ret = tcp.incoming_tcp(tcp_transmit(input.clone())).unwrap();
        assert_eq!(ret.data(), &msg);
        assert_eq!(ret.as_ref(), &msg);
        assert!(ret.message().is_some());
        let IncomingTcp::CompleteMessage(transmit, msg_range) = ret else {
            unreachable!();
        };
        assert_eq!(msg_range, 0..msg.len());
        assert_eq!(transmit.data, input);
        // The trailing channel frame was buffered and is available by polling.
        let ret = tcp.poll_recv().unwrap();
        assert_eq!(ret.data(), &channel);
        assert_eq!(ret.as_ref(), &channel);
        let StoredTcp::Channel(produced) = ret else {
            unreachable!()
        };
        assert_eq!(produced, channel);
    }

    #[test]
    fn test_incoming_tcp_channel_and_message() {
        let _init = crate::tests::test_init_log();
        let mut tcp = TurnTcpBuffer::new();
        let msg = generate_message();
        let channel = generate_message_in_channel();
        // One transmit carrying a channel frame immediately followed by a message.
        let input: Vec<u8> = channel.iter().chain(msg.iter()).copied().collect();
        let ret = tcp.incoming_tcp(tcp_transmit(input.clone())).unwrap();
        assert_eq!(ret.data(), &channel);
        assert_eq!(ret.as_ref(), &channel);
        assert!(ret.channel().is_some());
        let IncomingTcp::CompleteChannel(transmit, channel_range) = ret else {
            unreachable!()
        };
        assert_eq!(channel_range, 0..channel.len());
        assert_eq!(transmit.data, input);
        // The trailing message was buffered and is available by polling.
        let ret = tcp.poll_recv().unwrap();
        assert_eq!(ret.data(), &msg);
        assert_eq!(ret.as_ref(), &msg);
        let StoredTcp::Message(produced) = ret else {
            unreachable!()
        };
        assert_eq!(produced, msg);
    }
}