use crate::protobufs;
use log::{debug, error, trace};
use prost::Message;
use thiserror::Error;
use tokio::sync::mpsc::UnboundedSender;
use super::wrappers::encoded_data::IncomingStreamData;
/// Accumulates the raw bytes passed to [`StreamBuffer::process_incoming_bytes`]
/// and splits them into framed `FromRadio` packets.
///
/// Each successfully decoded packet is sent on `decoded_packet_tx`; bytes that
/// do not yet form a complete packet stay in `buffer` until more data arrives.
#[derive(Clone, Debug)]
pub struct StreamBuffer {
    /// Bytes received but not yet consumed as part of a packet.
    buffer: Vec<u8>,
    /// Channel on which successfully decoded packets are delivered.
    decoded_packet_tx: UnboundedSender<protobufs::FromRadio>,
}
#[derive(Error, Debug, Clone)]
pub enum StreamBufferError {
#[error("Could not find header sequence [0x94, 0xc3] in buffer")]
MissingHeaderBytes,
#[error("Incorrect framing byte: got {found_framing_byte}, expected 0xc3")]
IncorrectFramingByte { found_framing_byte: u8 },
#[error("Buffer data is shorter than packet header size: buffer contains {buffer_size} bytes, expected at least {packet_size} bytes")]
IncompletePacket {
buffer_size: usize,
packet_size: usize,
},
#[error("Buffer does not contain a value at MSB buffer index of {msb_index}")]
MissingMSB { msb_index: usize },
#[error("Buffer does not contain a value at LSB buffer index of {lsb_index}")]
MissingLSB { lsb_index: usize },
#[error("Detected malformed packet, packet buffer contains a framing byte at index {next_packet_start_idx}")]
MalformedPacket { next_packet_start_idx: usize },
#[error(transparent)]
DecodeFailure(#[from] prost::DecodeError),
}
/// Length in bytes of the header preceding each payload: the framing bytes
/// `0x94 0xc3`, then the payload length as an MSB byte followed by an LSB byte.
const PACKET_HEADER_SIZE: usize = 4;
impl StreamBuffer {
    /// Creates an empty buffer whose decoded packets are sent on
    /// `decoded_packet_tx`.
    pub fn new(decoded_packet_tx: UnboundedSender<protobufs::FromRadio>) -> Self {
        StreamBuffer {
            buffer: vec![],
            decoded_packet_tx,
        }
    }

    /// Appends `message` to the internal buffer, then decodes as many complete
    /// packets as the buffer currently holds and sends each one on the
    /// decoded-packet channel.
    ///
    /// Processing stops, leaving unconsumed bytes buffered for the next call,
    /// when the buffer becomes empty, when the next packet is not yet complete,
    /// or when sending on the channel fails. Malformed frames and payloads that
    /// fail to decode are logged and skipped, and processing continues.
    pub fn process_incoming_bytes(&mut self, message: IncomingStreamData) {
        self.buffer.extend_from_slice(message.as_ref());

        while !self.buffer.is_empty() {
            let decoded_packet = match self.process_packet_buffer() {
                Ok(packet) => packet,
                Err(err) => match err {
                    // These variants mean no complete packet can be extracted
                    // right now: stop and wait for the next chunk of bytes.
                    StreamBufferError::MissingHeaderBytes => {
                        error!("Could not find header sequence [0x94, 0xc3], purging buffer and waiting for more data");
                        break;
                    }
                    StreamBufferError::IncorrectFramingByte { found_framing_byte } => {
                        error!(
                            "Byte {found_framing_byte} not equal to 0xc3, waiting for more data"
                        );
                        break;
                    }
                    StreamBufferError::IncompletePacket {
                        buffer_size,
                        packet_size,
                    } => {
                        error!(
                            "Incomplete packet data, expected {packet_size} bytes, found {buffer_size} bytes"
                        );
                        break;
                    }
                    StreamBufferError::MissingMSB { msb_index } => {
                        error!("Could not find MSB at index {msb_index}, waiting for more data");
                        break;
                    }
                    StreamBufferError::MissingLSB { lsb_index } => {
                        error!("Could not find LSB at index {lsb_index}, waiting for more data");
                        break;
                    }
                    // These variants have already removed the offending bytes
                    // from the buffer, so the next iteration makes progress.
                    StreamBufferError::MalformedPacket {
                        next_packet_start_idx,
                    } => {
                        error!(
                            "Detected malformed packet with next packet starting at index {next_packet_start_idx}, purged malformed packet"
                        );
                        continue;
                    }
                    StreamBufferError::DecodeFailure { .. } => {
                        error!("Failed to decode chunk from packet, this does not affect the next iteration");
                        continue;
                    }
                },
            };

            trace!("Successfully decoded packet");
            if let Err(e) = self.decoded_packet_tx.send(decoded_packet) {
                error!("Failed to send decoded packet: {e}");
                break;
            }
            trace!("Successfully sent decoded packet");
        }

        trace!(
            "Processing complete, buffer contains {} bytes",
            self.buffer.len()
        );
    }

    /// Attempts to extract and decode the first complete packet in the buffer.
    ///
    /// On success the packet's frame is removed from the buffer. On error the
    /// buffer may already have been partially consumed: leading garbage, a
    /// malformed frame, or an undecodable payload is discarded by the helper
    /// that detects it.
    fn process_packet_buffer(&mut self) -> Result<protobufs::FromRadio, StreamBufferError> {
        trace!(
            "Packet buffer with length {:?}: {:?}",
            self.buffer.len(),
            self.buffer
        );

        if self.buffer.len() < PACKET_HEADER_SIZE {
            return Err(StreamBufferError::IncompletePacket {
                buffer_size: self.buffer.len(),
                packet_size: PACKET_HEADER_SIZE,
            });
        }

        self.shift_buffer_to_first_valid_header()?;
        let incoming_packet_data_size = self.get_data_size_from_header()?;
        self.validate_packet_in_buffer(incoming_packet_data_size)?;
        let packet_data = self.extract_packet_from_buffer(incoming_packet_data_size)?;
        let decoded_packet = protobufs::FromRadio::decode(packet_data.as_slice())?;
        Ok(decoded_packet)
    }

    /// Drops every byte preceding the first `[0x94, 0xc3]` header so the buffer
    /// starts at a frame boundary.
    ///
    /// # Errors
    /// Returns `MissingHeaderBytes` when no header is present. The buffered
    /// bytes are then discarded, except for a trailing `0x94`, which is kept
    /// because its `0xc3` partner may arrive in the next chunk.
    fn shift_buffer_to_first_valid_header(&mut self) -> Result<(), StreamBufferError> {
        let framing_index = match Self::find_framing_index(&self.buffer) {
            Some(index) => index,
            None => {
                let discard_up_to = if self.buffer.ends_with(&[0x94]) {
                    self.buffer.len() - 1
                } else {
                    self.buffer.len()
                };
                self.buffer.drain(..discard_up_to);
                return Err(StreamBufferError::MissingHeaderBytes);
            }
        };

        if framing_index != 0 {
            debug!("Found framing byte at index {framing_index}, shifting buffer");
            self.buffer.drain(..framing_index);
            trace!("Buffer after shifting: {:?}", self.buffer);
        }
        Ok(())
    }

    /// Returns the index of the first `[0x94, 0xc3]` sequence in `buffer`, or
    /// `None` if it does not occur (including when `buffer` has fewer than two
    /// bytes).
    fn find_framing_index(buffer: &[u8]) -> Option<usize> {
        buffer.windows(2).position(|pair| pair == [0x94, 0xc3])
    }

    /// Reads the payload length from the header at the start of the buffer.
    ///
    /// Expects the buffer to have been aligned by
    /// `shift_buffer_to_first_valid_header`. Byte 1 must be `0xc3`; bytes 2 and
    /// 3 hold the payload length, most significant byte first.
    ///
    /// # Errors
    /// `IncompletePacket`, `IncorrectFramingByte`, `MissingMSB` or `MissingLSB`
    /// when the corresponding header byte is absent or wrong. The buffer is
    /// never modified.
    fn get_data_size_from_header(&self) -> Result<usize, StreamBufferError> {
        let found_framing_byte = match self.buffer.get(1) {
            Some(&byte) => byte,
            None => {
                debug!("Could not find framing byte, waiting for more data");
                return Err(StreamBufferError::IncompletePacket {
                    buffer_size: self.buffer.len(),
                    packet_size: PACKET_HEADER_SIZE,
                });
            }
        };
        if found_framing_byte != 0xc3 {
            return Err(StreamBufferError::IncorrectFramingByte { found_framing_byte });
        }

        let msb_index: usize = 2;
        let msb = *self
            .buffer
            .get(msb_index)
            .ok_or(StreamBufferError::MissingMSB { msb_index })?;

        let lsb_index: usize = 3;
        let lsb = *self
            .buffer
            .get(lsb_index)
            .ok_or(StreamBufferError::MissingLSB { lsb_index })?;

        Ok(usize::from(u16::from_be_bytes([msb, lsb])))
    }

    /// Checks that the buffer holds the complete frame and that no new frame
    /// header starts inside its payload.
    ///
    /// # Errors
    /// * `IncompletePacket` if fewer than `PACKET_HEADER_SIZE + packet_data_size`
    ///   bytes are buffered; the buffer is left untouched.
    /// * `MalformedPacket` if `[0x94, 0xc3]` occurs within the payload. Every
    ///   byte before that sequence is removed, so the next attempt starts at it.
    fn validate_packet_in_buffer(
        &mut self,
        packet_data_size: usize,
    ) -> Result<(), StreamBufferError> {
        let packet_end_index = PACKET_HEADER_SIZE + packet_data_size;
        if self.buffer.len() < packet_end_index {
            return Err(StreamBufferError::IncompletePacket {
                buffer_size: self.buffer.len(),
                packet_size: packet_end_index,
            });
        }

        // Scan one byte past the payload when it is available, so a header whose
        // 0x94 is the final payload byte is detected as well.
        let scan_end_index = (packet_end_index + 1).min(self.buffer.len());

        // NOTE(review): the payload is protobuf-encoded data and may contain any
        // byte values, so a valid packet whose payload happens to contain
        // 0x94 0xc3 is also rejected here.
        let next_packet_start_index =
            Self::find_framing_index(&self.buffer[PACKET_HEADER_SIZE..scan_end_index])
                .map(|idx| idx + PACKET_HEADER_SIZE);

        if let Some(next_packet_start_idx) = next_packet_start_index {
            self.buffer.drain(..next_packet_start_idx);
            return Err(StreamBufferError::MalformedPacket {
                next_packet_start_idx,
            });
        }
        Ok(())
    }

    /// Removes the frame at the start of the buffer and returns its payload,
    /// i.e. the `packet_data_size` bytes that follow the header.
    ///
    /// # Errors
    /// Returns `IncompletePacket` if the buffer holds fewer than
    /// `PACKET_HEADER_SIZE + packet_data_size` bytes; the buffer is unchanged.
    fn extract_packet_from_buffer(
        &mut self,
        packet_data_size: usize,
    ) -> Result<Vec<u8>, StreamBufferError> {
        let packet_end_index = PACKET_HEADER_SIZE + packet_data_size;
        // Compare against the whole frame (not only the payload) so the drain
        // below can never run past the end of the buffer.
        if self.buffer.len() < packet_end_index {
            return Err(StreamBufferError::IncompletePacket {
                buffer_size: self.buffer.len(),
                packet_size: packet_end_index,
            });
        }

        // Drain the whole frame in one pass and keep only the payload bytes.
        let packet_data: Vec<u8> = self
            .buffer
            .drain(..packet_end_index)
            .skip(PACKET_HEADER_SIZE)
            .collect();
        Ok(packet_data)
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use crate::{protobufs, utils_internal::format_data_packet};
    use futures_util::FutureExt;
    use prost::Message;
    use tokio::sync::mpsc::unbounded_channel;

    use super::*;

    /// Awaits `future`, panicking if it does not resolve within `timeout`
    /// (100 ms when `None` is passed).
    async fn timeout_test<F, T>(future: F, timeout: impl Into<Option<Duration>>) -> T
    where
        F: FutureExt<Output = T> + Send,
    {
        let timeout_opt: Option<Duration> = timeout.into();
        let timeout_duration = timeout_opt.unwrap_or(Duration::from_millis(100));
        tokio::time::timeout(timeout_duration, future)
            .await
            .expect("Future timed out")
    }

    /// Builds a `FromRadio` packet with the given payload and returns it along
    /// with its protobuf encoding (without stream framing).
    ///
    /// The tests below pass explicit ids: a random id can encode to bytes that
    /// contain the framing sequence `[0x94, 0xc3]`, which the stream buffer
    /// treats as a malformed packet, making the tests nondeterministic.
    fn mock_encoded_from_radio_packet(
        payload_variant: protobufs::from_radio::PayloadVariant,
        id: impl Into<Option<u32>>,
    ) -> (protobufs::FromRadio, Vec<u8>) {
        let packet_id = id.into().unwrap_or_else(rand::random);
        let packet = protobufs::FromRadio {
            id: packet_id,
            payload_variant: Some(payload_variant),
        };
        (packet.clone(), packet.encode_to_vec())
    }

    #[tokio::test]
    async fn process_single_complete_packet() {
        let payload_variant_1 =
            protobufs::from_radio::PayloadVariant::MyInfo(protobufs::MyNodeInfo::default());
        let (packet_1, packet_data_1) = mock_encoded_from_radio_packet(payload_variant_1, 1_u32);
        let encoded_packet_1 = format_data_packet(packet_data_1.into()).unwrap();

        let (mock_tx, mut mock_rx) = unbounded_channel::<protobufs::FromRadio>();
        let mut buffer = StreamBuffer::new(mock_tx);
        buffer.process_incoming_bytes(encoded_packet_1.data().into());

        assert_eq!(timeout_test(mock_rx.recv(), None).await, Some(packet_1));
        assert_eq!(buffer.buffer.len(), 0);
    }

    #[tokio::test]
    async fn handle_incomplete_packet_at_end() {
        let payload_variant_1 =
            protobufs::from_radio::PayloadVariant::MyInfo(protobufs::MyNodeInfo::default());
        let payload_variant_2 =
            protobufs::from_radio::PayloadVariant::MyInfo(protobufs::MyNodeInfo::default());
        let (packet_1, packet_data_1) = mock_encoded_from_radio_packet(payload_variant_1, 1_u32);
        let (_packet_2, packet_data_2) = mock_encoded_from_radio_packet(payload_variant_2, 2_u32);
        let encoded_packet_1 = format_data_packet(packet_data_1.into()).unwrap();
        let encoded_packet_2 = format_data_packet(packet_data_2.into()).unwrap();
        let incomplete_encoded_packet_2 = encoded_packet_2
            .data_vec()
            .into_iter()
            .take(6)
            .collect::<Vec<u8>>();

        let (mock_tx, mut mock_rx) = unbounded_channel::<protobufs::FromRadio>();
        let mut buffer = StreamBuffer::new(mock_tx);
        buffer.process_incoming_bytes(encoded_packet_1.data().into());
        buffer.process_incoming_bytes(incomplete_encoded_packet_2.clone().into());

        assert_eq!(timeout_test(mock_rx.recv(), None).await, Some(packet_1));
        assert_eq!(buffer.buffer.len(), 6);
        assert_eq!(buffer.buffer, incomplete_encoded_packet_2);
    }

    #[tokio::test]
    async fn process_multiple_complete_packets() {
        let payload_variant_1 =
            protobufs::from_radio::PayloadVariant::MyInfo(protobufs::MyNodeInfo::default());
        let payload_variant_2 =
            protobufs::from_radio::PayloadVariant::MyInfo(protobufs::MyNodeInfo::default());
        let (packet_1, packet_data_1) = mock_encoded_from_radio_packet(payload_variant_1, 1_u32);
        let (packet_2, packet_data_2) = mock_encoded_from_radio_packet(payload_variant_2, 2_u32);
        let encoded_packet_1 = format_data_packet(packet_data_1.into()).unwrap();
        let encoded_packet_2 = format_data_packet(packet_data_2.into()).unwrap();

        let (mock_tx, mut mock_rx) = unbounded_channel::<protobufs::FromRadio>();
        let mut buffer = StreamBuffer::new(mock_tx);
        buffer.process_incoming_bytes(encoded_packet_1.data().into());
        buffer.process_incoming_bytes(encoded_packet_2.data().into());

        assert_eq!(timeout_test(mock_rx.recv(), None).await, Some(packet_1));
        assert_eq!(timeout_test(mock_rx.recv(), None).await, Some(packet_2));
        assert_eq!(buffer.buffer.len(), 0);
    }

    #[tokio::test]
    async fn handle_malformed_packet_amid_valid_packets() {
        let payload_variant_1 =
            protobufs::from_radio::PayloadVariant::MyInfo(protobufs::MyNodeInfo::default());
        let payload_variant_2 =
            protobufs::from_radio::PayloadVariant::MyInfo(protobufs::MyNodeInfo::default());
        let payload_variant_3 =
            protobufs::from_radio::PayloadVariant::MyInfo(protobufs::MyNodeInfo::default());
        let (packet_1, packet_data_1) = mock_encoded_from_radio_packet(payload_variant_1, 1_u32);
        let (_packet_2, packet_data_2) = mock_encoded_from_radio_packet(payload_variant_2, 2_u32);
        let (packet_3, packet_data_3) = mock_encoded_from_radio_packet(payload_variant_3, 3_u32);
        let encoded_packet_1 = format_data_packet(packet_data_1.into()).unwrap();
        let encoded_packet_2 = format_data_packet(packet_data_2.into()).unwrap();
        let encoded_packet_3 = format_data_packet(packet_data_3.into()).unwrap();
        let malformed_encoded_packet_2 = encoded_packet_2
            .data_vec()
            .into_iter()
            .take(6)
            .collect::<Vec<u8>>();

        let (mock_tx, mut mock_rx) = unbounded_channel::<protobufs::FromRadio>();
        let mut buffer = StreamBuffer::new(mock_tx);
        buffer.process_incoming_bytes(encoded_packet_1.data().into());
        buffer.process_incoming_bytes(malformed_encoded_packet_2.clone().into());
        buffer.process_incoming_bytes(encoded_packet_3.data().into());

        assert_eq!(timeout_test(mock_rx.recv(), None).await, Some(packet_1));
        assert_eq!(timeout_test(mock_rx.recv(), None).await, Some(packet_3));
        assert_eq!(buffer.buffer.len(), 0);
    }

    #[tokio::test]
    async fn handle_buffer_ending_with_false_start() {
        let payload_variant_1 =
            protobufs::from_radio::PayloadVariant::MyInfo(protobufs::MyNodeInfo::default());
        let (packet_1, packet_data_1) = mock_encoded_from_radio_packet(payload_variant_1, 1_u32);
        let encoded_packet_1 = format_data_packet(packet_data_1.into()).unwrap();

        let (mock_tx, mut mock_rx) = unbounded_channel::<protobufs::FromRadio>();
        let mut buffer = StreamBuffer::new(mock_tx);
        buffer.process_incoming_bytes(encoded_packet_1.data().into());
        buffer.process_incoming_bytes(vec![0x94].into());

        assert_eq!(timeout_test(mock_rx.recv(), None).await, Some(packet_1));
        assert_eq!(buffer.buffer, vec![0x94]);
    }

    #[tokio::test]
    async fn clear_buffer_on_invalid_packet_start() {
        let malformed_packet_1 = vec![0x94, 0x00, 0x94, 0x94, 0x00];

        let (mock_tx, mut _mock_rx) = unbounded_channel::<protobufs::FromRadio>();
        let mut buffer = StreamBuffer::new(mock_tx);
        buffer.process_incoming_bytes(malformed_packet_1.into());

        assert_eq!(buffer.buffer.len(), 0);
    }

    #[tokio::test]
    async fn process_after_repeated_false_starts() {
        let payload_variant_2 =
            protobufs::from_radio::PayloadVariant::MyInfo(protobufs::MyNodeInfo::default());
        let (packet_2, packet_data_2) = mock_encoded_from_radio_packet(payload_variant_2, 2_u32);
        let encoded_packet_2 = format_data_packet(packet_data_2.into()).unwrap();
        let malformed_packet_1 = vec![0x94, 0x00, 0x94, 0x94, 0x00];

        let (mock_tx, mut mock_rx) = unbounded_channel::<protobufs::FromRadio>();
        let mut buffer = StreamBuffer::new(mock_tx);
        buffer.process_incoming_bytes(malformed_packet_1.into());
        buffer.process_incoming_bytes(encoded_packet_2.data().into());

        assert_eq!(timeout_test(mock_rx.recv(), None).await, Some(packet_2));
        assert_eq!(buffer.buffer.len(), 0);
    }

    #[tokio::test]
    async fn process_large_packet_spanning_multiple_chunks() {
        let payload_variant_1 =
            protobufs::from_radio::PayloadVariant::MyInfo(protobufs::MyNodeInfo::default());
        let (packet_1, packet_data_1) = mock_encoded_from_radio_packet(payload_variant_1, 1_u32);
        let encoded_packet_1 = format_data_packet(packet_data_1.into()).unwrap();
        let encoded_packet_1_chunk_1 = encoded_packet_1
            .clone()
            .data_vec()
            .into_iter()
            .take(6)
            .collect::<Vec<u8>>();
        let encoded_packet_1_chunk_2 = encoded_packet_1
            .data_vec()
            .into_iter()
            .skip(6)
            .collect::<Vec<u8>>();

        let (mock_tx, mut mock_rx) = unbounded_channel::<protobufs::FromRadio>();
        let mut buffer = StreamBuffer::new(mock_tx);
        buffer.process_incoming_bytes(encoded_packet_1_chunk_1.into());
        buffer.process_incoming_bytes(encoded_packet_1_chunk_2.into());

        assert_eq!(timeout_test(mock_rx.recv(), None).await, Some(packet_1));
        assert_eq!(buffer.buffer.len(), 0);
    }

    #[tokio::test]
    async fn process_packet_with_zero_length() {
        let payload_variant_2 =
            protobufs::from_radio::PayloadVariant::MyInfo(protobufs::MyNodeInfo::default());
        let (packet_2, packet_data_2) = mock_encoded_from_radio_packet(payload_variant_2, 2_u32);
        let encoded_packet_2 = format_data_packet(packet_data_2.into()).unwrap();
        let encoded_zero_length_packet = vec![0x94, 0xc3, 0x00, 0x00];

        let (mock_tx, mut mock_rx) = unbounded_channel::<protobufs::FromRadio>();
        let mut buffer = StreamBuffer::new(mock_tx);
        buffer.process_incoming_bytes(encoded_zero_length_packet.into());
        buffer.process_incoming_bytes(encoded_packet_2.data().into());

        let empty_packet = protobufs::FromRadio {
            id: 0,
            payload_variant: None,
        };
        assert_eq!(timeout_test(mock_rx.recv(), None).await, Some(empty_packet));
        assert_eq!(timeout_test(mock_rx.recv(), None).await, Some(packet_2));
        assert_eq!(buffer.buffer.len(), 0);
    }

    #[tokio::test]
    async fn detect_malformed_packets_with_internal_header_sequence() {
        let payload_variant_1 =
            protobufs::from_radio::PayloadVariant::MyInfo(protobufs::MyNodeInfo::default());
        let (packet_1, packet_data_1) = mock_encoded_from_radio_packet(payload_variant_1, 1_u32);
        let encoded_packet_1 = format_data_packet(packet_data_1.into()).unwrap();

        // A frame that declares an 8-byte payload, but a new frame header starts
        // after only two payload bytes. The truncated frame must be discarded and
        // the embedded frame decoded.
        let mut stream_data: Vec<u8> = vec![0x94, 0xc3, 0x00, 0x08, 0x01, 0x02];
        stream_data.extend(encoded_packet_1.data_vec());

        let (mock_tx, mut mock_rx) = unbounded_channel::<protobufs::FromRadio>();
        let mut buffer = StreamBuffer::new(mock_tx);
        buffer.process_incoming_bytes(stream_data.into());

        assert_eq!(timeout_test(mock_rx.recv(), None).await, Some(packet_1));
        assert_eq!(buffer.buffer.len(), 0);
    }
}