use core::cmp;
use crate::{
channel::Channel,
command::Command,
error::{DefragmentationError, FragmentationError},
packet::{
ContinuationPacket, InitializationPacket, Packet, PacketType, CONTINUATION_HEADER_SIZE,
INITIALIZATION_HEADER_SIZE,
},
};
/// A complete message: a channel, a command and an arbitrary byte payload.
///
/// `T` is any container that can be viewed as a byte slice.
// NOTE: the field order determines the derived `Ord`/`PartialOrd` comparison.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub struct Message<T: AsRef<[u8]>> {
    /// The channel this message belongs to.
    pub channel: Channel,
    /// The command carried by this message.
    pub command: Command,
    /// The payload bytes of this message.
    pub data: T,
}
impl<T: AsRef<[u8]>> Message<T> {
    /// Returns an iterator over the packets this message is split into when
    /// every packet may occupy at most `packet_size` bytes.
    ///
    /// The returned packets borrow their payload from `self`.
    ///
    /// # Errors
    ///
    /// Propagates the [`FragmentationError`] reported while setting up the
    /// fragmentation (packet size too small, or payload too long).
    pub fn fragments(&self, packet_size: usize) -> Result<Fragments<'_>, FragmentationError> {
        let fragments = Fragments::new(self, packet_size)?;
        Ok(fragments)
    }
}
impl<T: AsRef<[u8]> + Default + Extend<u8>> Message<T> {
pub fn from_fragments<S: AsRef<[u8]>>(
packet: InitializationPacket<S>,
) -> DefragmentedMessage<T> {
packet.into()
}
pub fn try_from_fragments<S: AsRef<[u8]>>(
packet: Packet<S>,
) -> Result<DefragmentedMessage<T>, DefragmentationError> {
if let Packet::Initialization(packet) = packet {
Ok(Self::from_fragments(packet))
} else {
Err(DefragmentationError::InvalidPacketType {
expected: PacketType::Initialization,
actual: packet.packet_type(),
})
}
}
}
/// Iterator over the packets a [`Message`] is split into.
///
/// Yields the initialization packet first, followed by the continuation
/// packets in sequence order.
#[derive(Clone, Debug)]
pub struct Fragments<'a> {
    /// Channel copied into every continuation packet.
    channel: Channel,
    /// The initialization packet; taken (and replaced by `None`) when it is yielded.
    init: Option<InitializationPacket<&'a [u8]>>,
    /// Payload that did not fit into the initialization packet, split into
    /// continuation-sized chunks; the enumeration index becomes the sequence number.
    chunks: core::iter::Enumerate<core::slice::Chunks<'a, u8>>,
}
impl<'a> Fragments<'a> {
    /// Prepares the fragmentation of `message` into packets of at most
    /// `packet_size` bytes (header included).
    ///
    /// The initialization packet receives as much of the payload as fits
    /// after its header; the rest is split into continuation-sized chunks.
    ///
    /// # Errors
    ///
    /// * [`FragmentationError::PacketSizeTooSmall`] if `packet_size` leaves no
    ///   room for payload after either the initialization or the continuation
    ///   header.
    /// * [`FragmentationError::DataTooLong`] if the payload does not fit into
    ///   one initialization packet plus 128 continuation packets, or if its
    ///   length cannot be represented in the `u16` length field of the
    ///   initialization packet.
    fn new<T: AsRef<[u8]>>(
        message: &'a Message<T>,
        packet_size: usize,
    ) -> Result<Self, FragmentationError> {
        // Upper bound on the number of continuation packets per message.
        const MAX_CONTINUATION_PACKETS: usize = 128;

        if INITIALIZATION_HEADER_SIZE >= packet_size || CONTINUATION_HEADER_SIZE >= packet_size {
            return Err(FragmentationError::PacketSizeTooSmall);
        }
        let data_size_init = packet_size - INITIALIZATION_HEADER_SIZE;
        let data_size_cont = packet_size - CONTINUATION_HEADER_SIZE;
        let data = message.data.as_ref();

        // Saturating arithmetic: for very large packet sizes the capacity may
        // exceed `usize::MAX`, which must not overflow (and panic or wrap).
        let capacity = data_size_cont
            .saturating_mul(MAX_CONTINUATION_PACKETS)
            .saturating_add(data_size_init);
        if data.len() > capacity {
            return Err(FragmentationError::DataTooLong);
        }

        // The length field is only 16 bits wide. With large packet sizes the
        // capacity check above admits payloads longer than `u16::MAX`, so the
        // conversion has to be checked instead of silently truncated.
        let length = <u16 as core::convert::TryFrom<usize>>::try_from(data.len())
            .map_err(|_| FragmentationError::DataTooLong)?;

        let n = cmp::min(data_size_init, data.len());
        let (data_init, data_cont) = data.split_at(n);
        let init = InitializationPacket {
            channel: message.channel,
            command: message.command,
            length,
            data: data_init,
        };
        Ok(Self {
            channel: message.channel,
            init: Some(init),
            chunks: data_cont.chunks(data_size_cont).enumerate(),
        })
    }
}
impl<'a> Iterator for Fragments<'a> {
    type Item = Packet<&'a [u8]>;

    /// Yields the initialization packet on the first call, then one
    /// continuation packet per remaining payload chunk, then `None`.
    fn next(&mut self) -> Option<Self::Item> {
        // The initialization packet is handed out exactly once.
        if let Some(init) = self.init.take() {
            return Some(Packet::Initialization(init));
        }

        let channel = self.channel;
        self.chunks.next().map(|(index, data)| {
            Packet::Continuation(ContinuationPacket {
                channel,
                // The chunk index doubles as the sequence number.
                sequence: index as u8,
                data,
            })
        })
    }
}
/// State of a message that is being reassembled from packets.
// NOTE: the variant order determines the derived `Ord`/`PartialOrd` comparison.
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum DefragmentedMessage<T: AsRef<[u8]> + Default + Extend<u8>> {
    /// All announced payload bytes have been collected.
    Complete(Message<T>),
    /// More continuation packets are required to complete the message.
    Partial(PartialMessage<T>),
}
impl<T: AsRef<[u8]> + Default + Extend<u8>, S: AsRef<[u8]>> From<InitializationPacket<S>>
    for DefragmentedMessage<T>
{
    /// Begins reassembly from an initialization packet and classifies the
    /// result as complete or partial.
    fn from(packet: InitializationPacket<S>) -> Self {
        let partial = PartialMessage::<T>::from(packet);
        Self::from(partial)
    }
}
impl<T: AsRef<[u8]> + Default + Extend<u8>> From<PartialMessage<T>> for DefragmentedMessage<T> {
fn from(message: PartialMessage<T>) -> Self {
if message.length == message.data.as_ref().len() {
Self::Complete(Message {
channel: message.channel,
command: message.command,
data: message.data,
})
} else {
Self::Partial(message)
}
}
}
/// A message whose payload has not been fully received yet.
// NOTE: the field order determines the derived `Ord`/`PartialOrd` comparison.
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub struct PartialMessage<T: AsRef<[u8]> + Default + Extend<u8>> {
    /// Channel taken from the initialization packet; continuation packets must match it.
    channel: Channel,
    /// Command taken from the initialization packet.
    command: Command,
    /// Total payload length announced by the initialization packet.
    length: usize,
    /// Payload bytes collected so far.
    data: T,
    /// Sequence number the next continuation packet must carry.
    next_sequence: u8,
}
impl<T: AsRef<[u8]> + Default + Extend<u8>> PartialMessage<T> {
    /// Appends the payload of a continuation packet to this message.
    ///
    /// Returns [`DefragmentedMessage::Complete`] once all announced bytes have
    /// been collected and [`DefragmentedMessage::Partial`] otherwise. Payload
    /// bytes beyond the announced length are ignored.
    ///
    /// # Errors
    ///
    /// * [`DefragmentationError::InvalidChannel`] if the packet belongs to a
    ///   different channel.
    /// * [`DefragmentationError::InvalidSequence`] if the packet does not
    ///   carry the expected sequence number.
    ///
    /// The partial message is consumed in either case.
    pub fn extend<S: AsRef<[u8]>>(
        mut self,
        packet: &ContinuationPacket<S>,
    ) -> Result<DefragmentedMessage<T>, DefragmentationError> {
        if self.channel != packet.channel {
            return Err(DefragmentationError::InvalidChannel {
                expected: self.channel,
                actual: packet.channel,
            });
        }
        if self.next_sequence != packet.sequence {
            return Err(DefragmentationError::InvalidSequence {
                expected: self.next_sequence,
                actual: packet.sequence,
            });
        }
        self.extend_data(packet.data.as_ref());
        // Packets come from the peer: a long run of (possibly empty)
        // continuation packets must not trigger an arithmetic-overflow panic.
        // Wrapping matches what `+= 1` does when overflow checks are disabled.
        self.next_sequence = self.next_sequence.wrapping_add(1);
        Ok(self.into())
    }

    /// Like [`PartialMessage::extend`], but accepts any [`Packet`].
    ///
    /// # Errors
    ///
    /// Returns [`DefragmentationError::InvalidPacketType`] if `packet` is not
    /// a continuation packet, and otherwise whatever
    /// [`PartialMessage::extend`] returns.
    pub fn try_extend<S: AsRef<[u8]>>(
        self,
        packet: &Packet<S>,
    ) -> Result<DefragmentedMessage<T>, DefragmentationError> {
        match packet {
            Packet::Continuation(continuation) => self.extend(continuation),
            other => Err(DefragmentationError::InvalidPacketType {
                expected: PacketType::Continuation,
                actual: other.packet_type(),
            }),
        }
    }

    /// Appends at most as many bytes of `data` as are still missing to reach
    /// the announced length.
    fn extend_data(&mut self, data: &[u8]) {
        // `saturating_sub` keeps this panic-free even if the container already
        // holds more bytes than announced (e.g. a non-empty `T::default()`).
        let missing = self.length.saturating_sub(self.data.as_ref().len());
        let n = cmp::min(data.len(), missing);
        self.data.extend(data[..n].iter().copied());
    }
}
impl<T: AsRef<[u8]> + Default + Extend<u8>, S: AsRef<[u8]>> From<InitializationPacket<S>>
for PartialMessage<T>
{
fn from(packet: InitializationPacket<S>) -> Self {
let mut message = Self {
channel: packet.channel,
command: packet.command,
length: usize::from(packet.length),
data: T::default(),
next_sequence: 0,
};
message.extend_data(packet.data.as_ref());
message
}
}
#[cfg(test)]
mod test {
    use quickcheck::Arbitrary;
    use super::{DefragmentedMessage, Message};

    // Lets quickcheck generate (and shrink) random messages with a `Vec<u8>` payload.
    impl Arbitrary for Message<Vec<u8>> {
        fn arbitrary(g: &mut quickcheck::Gen) -> Self {
            // Keep the generation order (channel, command, data) stable.
            let channel = Arbitrary::arbitrary(g);
            let command = Arbitrary::arbitrary(g);
            let data = Arbitrary::arbitrary(g);
            Message {
                channel,
                command,
                data,
            }
        }

        fn shrink(&self) -> Box<dyn Iterator<Item = Self>> {
            // Only the payload is shrunk; channel and command stay fixed.
            let (channel, command) = (self.channel, self.command);
            let shrunk = self.data.shrink().map(move |data| Message {
                channel,
                command,
                data,
            });
            Box::new(shrunk)
        }
    }

    quickcheck::quickcheck! {
        // Round trip: fragmenting a message into 64-byte packets and feeding
        // them back must reproduce the original message.
        fn fragments(message: Message<Vec<u8>>) -> bool {
            use std::convert::TryInto;

            let mut state: Option<DefragmentedMessage<Vec<u8>>> = None;
            for fragment in message.fragments(64).unwrap() {
                let next = match state.take() {
                    // First packet: must be the initialization packet.
                    None => Message::from_fragments(fragment.try_into().unwrap()),
                    // Every later packet extends the partial message.
                    Some(DefragmentedMessage::Partial(partial)) => {
                        partial.extend(&fragment.try_into().unwrap()).unwrap()
                    }
                    // No packet may follow once the message is complete.
                    Some(DefragmentedMessage::Complete(_)) => unreachable!(),
                };
                state = Some(next);
            }

            match state.unwrap() {
                DefragmentedMessage::Complete(reassembled) => reassembled == message,
                DefragmentedMessage::Partial(_) => false,
            }
        }
    }
}