use heapless::Vec;
use super::ExtendedHeader;
pub use super::ExtendedHeader as ChunkExtendedHeader;
use crate::protocol_layer::message::ParseError;
use crate::protocol_layer::message::header::{ExtendedMessageType, Header};
/// Maximum number of payload bytes carried by a single chunk.
pub const MAX_EXTENDED_MSG_CHUNK_LEN: usize = 26;
/// Maximum payload size of a reassembled extended message, in bytes.
pub const MAX_EXTENDED_MSG_LEN: usize = 260;
/// Number of chunks needed to carry a maximum-size message.
///
/// Rounded up so that a trailing partial chunk is still counted if the two
/// limits above are ever changed to values that do not divide evenly.
pub const MAX_CHUNKS: usize = MAX_EXTENDED_MSG_LEN.div_ceil(MAX_EXTENDED_MSG_CHUNK_LEN);
/// Metadata describing a single chunk of an extended message.
///
/// NOTE(review): this type is neither constructed nor read in this file; the
/// field descriptions below are inferred from the field names — confirm
/// against the callers that build it.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct ChunkInfo {
/// Presumably the index of this chunk within the message (0 = first).
pub chunk_number: u8,
/// Presumably the total payload size of the whole message, in bytes.
pub total_data_size: u16,
/// Presumably `true` when this header asks for a chunk instead of carrying data.
pub request_chunk: bool,
/// Extended message type the chunk belongs to.
pub message_type: ExtendedMessageType,
/// Message header associated with the chunk (assumed to be the one it arrived with).
pub header: Header,
}
/// Outcome of feeding one chunk to a `ChunkedMessageAssembler`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum ChunkResult<T> {
/// Every expected byte has arrived; carries the assembled payload.
Complete(T),
/// More data is required; carries the chunk number expected next.
NeedMoreChunks(u8),
/// The incoming header was a chunk request rather than data; carries the
/// requested chunk number.
ChunkRequested(u8),
}
/// Reassembles a chunked extended message from its individual chunks.
///
/// Chunks are fed in order, starting with chunk 0, through `process_chunk`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct ChunkedMessageAssembler {
/// Payload bytes collected so far; capacity is `MAX_EXTENDED_MSG_LEN`.
buffer: Vec<u8, MAX_EXTENDED_MSG_LEN>,
/// Total payload size announced in chunk 0's extended header.
expected_size: u16,
/// Number of payload bytes appended to `buffer` so far.
received_bytes: usize,
/// Message type taken from chunk 0's header, once a message has started.
message_type: Option<ExtendedMessageType>,
/// Header that accompanied chunk 0.
/// NOTE(review): stored but never read in this file.
header_template: Option<Header>,
/// Chunk number the assembler will accept next.
next_chunk: u8,
/// `true` between accepting chunk 0 and completing the message.
in_progress: bool,
}
impl Default for ChunkedMessageAssembler {
fn default() -> Self {
Self::new()
}
}
impl ChunkedMessageAssembler {
    /// Creates an idle assembler with an empty buffer and no message in progress.
    pub const fn new() -> Self {
        Self {
            buffer: Vec::new(),
            expected_size: 0,
            received_bytes: 0,
            message_type: None,
            header_template: None,
            next_chunk: 0,
            in_progress: false,
        }
    }

    /// Discards all buffered data and state, as if freshly created with [`Self::new`].
    pub fn reset(&mut self) {
        *self = Self::new();
    }

    /// Creates a new assembler and immediately feeds it one chunk.
    ///
    /// Returns the assembler together with the result of that first chunk.
    ///
    /// # Errors
    ///
    /// Any error returned by [`Self::process_chunk`].
    pub fn new_from_chunk(
        header: Header,
        ext_header: ExtendedHeader,
        chunk_data: &[u8],
    ) -> Result<(Self, ChunkResult<Vec<u8, MAX_EXTENDED_MSG_LEN>>), ParseError> {
        let mut assembler = Self::new();
        let result = assembler.process_chunk(header, ext_header, chunk_data)?;
        Ok((assembler, result))
    }

    /// Returns `true` after chunk 0 has been accepted and before the message completes.
    pub fn is_in_progress(&self) -> bool {
        self.in_progress
    }

    /// Message type recorded from the most recently accepted chunk 0, if any.
    pub fn message_type(&self) -> Option<ExtendedMessageType> {
        self.message_type
    }

    /// Feeds one received chunk into the assembler.
    ///
    /// * A header with the request-chunk flag set carries no data. It is
    ///   reported as [`ChunkResult::ChunkRequested`], and the assembler is
    ///   left untouched.
    /// * Chunk 0 starts a new message. The total size comes from `ext_header`
    ///   and the message type from `header`. Bytes left over from a previously
    ///   completed message are discarded.
    /// * Later chunks must arrive in order.
    ///
    /// Returns [`ChunkResult::Complete`] once enough bytes have arrived, with
    /// the payload truncated to the announced size. Otherwise it returns
    /// [`ChunkResult::NeedMoreChunks`] with the next expected chunk number.
    ///
    /// # Errors
    ///
    /// * [`ParseError::ParserReuse`]: chunk 0 arrived while a message is in progress.
    /// * [`ParseError::Other`]: a non-zero chunk arrived with no message in
    ///   progress, a chunk arrived out of order, or the accumulated payload
    ///   would exceed `MAX_EXTENDED_MSG_LEN`.
    /// * [`ParseError::ChunkOverflow`]: `chunk_data` is longer than
    ///   `MAX_EXTENDED_MSG_CHUNK_LEN`.
    ///
    /// The chunk-number and chunk-length checks run before any state is
    /// modified. A rejected chunk therefore never leaves a half-started
    /// message behind. Call [`Self::reset`] to abandon an in-progress message.
    pub fn process_chunk(
        &mut self,
        header: Header,
        ext_header: ExtendedHeader,
        chunk_data: &[u8],
    ) -> Result<ChunkResult<Vec<u8, MAX_EXTENDED_MSG_LEN>>, ParseError> {
        let chunk_number = ext_header.chunk_number();
        let data_size = ext_header.data_size();

        // A chunk request from the peer has no payload to assemble.
        if ext_header.request_chunk() {
            return Ok(ChunkResult::ChunkRequested(chunk_number));
        }

        // Validate before mutating anything, keeping the original error precedence.
        if chunk_number == 0 {
            if self.in_progress {
                return Err(ParseError::ParserReuse);
            }
        } else if !self.in_progress {
            return Err(ParseError::Other("Received non-zero chunk without chunk 0"));
        } else if chunk_number != self.next_chunk {
            return Err(ParseError::Other("Unexpected chunk number"));
        }
        if chunk_data.len() > MAX_EXTENDED_MSG_CHUNK_LEN {
            return Err(ParseError::ChunkOverflow(chunk_data.len(), MAX_EXTENDED_MSG_CHUNK_LEN));
        }

        if chunk_number == 0 {
            // Starting a new message: drop bytes from any previously completed
            // message so they are not prepended to this one.
            self.buffer.clear();
            self.received_bytes = 0;
            self.expected_size = data_size;
            self.message_type = Some(header.message_type_raw().into());
            self.header_template = Some(header);
            self.in_progress = true;
        }

        if self.buffer.extend_from_slice(chunk_data).is_err() {
            return Err(ParseError::Other("Chunk buffer overflow"));
        }
        self.received_bytes += chunk_data.len();
        self.next_chunk = chunk_number + 1;

        if self.received_bytes >= self.expected_size as usize {
            self.in_progress = false;
            // The final chunk may be padded past the announced size; trim it.
            let final_size = core::cmp::min(self.buffer.len(), self.expected_size as usize);
            self.buffer.truncate(final_size);
            Ok(ChunkResult::Complete(self.buffer.clone()))
        } else {
            Ok(ChunkResult::NeedMoreChunks(self.next_chunk))
        }
    }

    /// Builds the extended header used to ask the peer for chunk `chunk_number`.
    ///
    /// The header has data size 0, with the chunked and request-chunk flags set.
    pub fn build_chunk_request_header(chunk_number: u8) -> ExtendedHeader {
        ExtendedHeader::new(0)
            .with_chunked(true)
            .with_request_chunk(true)
            .with_chunk_number(chunk_number)
    }

    /// Bytes accumulated for the current (or most recently completed) message.
    pub fn buffer(&self) -> &[u8] {
        &self.buffer
    }

    /// Number of payload bytes received for the current message.
    pub fn received_bytes(&self) -> usize {
        self.received_bytes
    }

    /// Total payload size announced by the current message's chunk 0.
    pub fn expected_size(&self) -> u16 {
        self.expected_size
    }
}
/// Splits an outgoing payload into chunks of at most `MAX_EXTENDED_MSG_CHUNK_LEN` bytes.
///
/// Iterate over it, or call `get_chunk`, to obtain each chunk's extended
/// header and payload slice.
#[derive(Debug, Clone)]
pub struct ChunkedMessageSender<'a> {
    /// The complete payload being sent.
    data: &'a [u8],
    /// Index of the next chunk the iterator will yield.
    current_chunk: u8,
    /// Number of chunks the payload is split into, as computed by `new`.
    total_chunks: u8,
}
impl<'a> ChunkedMessageSender<'a> {
pub fn new(data: &'a [u8]) -> Self {
let total_chunks = if data.is_empty() {
1
} else {
data.len().div_ceil(MAX_EXTENDED_MSG_CHUNK_LEN) as u8
};
Self {
data,
current_chunk: 0,
total_chunks,
}
}
pub fn is_complete(&self) -> bool {
self.current_chunk >= self.total_chunks
}
pub fn current_chunk(&self) -> u8 {
self.current_chunk
}
pub fn total_chunks(&self) -> u8 {
self.total_chunks
}
pub fn data_size(&self) -> u16 {
self.data.len() as u16
}
pub fn get_chunk(&self, chunk_number: u8) -> Option<(ExtendedHeader, &[u8])> {
if chunk_number >= self.total_chunks {
return None;
}
let start = chunk_number as usize * MAX_EXTENDED_MSG_CHUNK_LEN;
let end = core::cmp::min(start + MAX_EXTENDED_MSG_CHUNK_LEN, self.data.len());
let chunk_data = &self.data[start..end];
let ext_header = ExtendedHeader::new(self.data.len() as u16)
.with_chunked(true)
.with_chunk_number(chunk_number);
Some((ext_header, chunk_data))
}
pub fn reset(&mut self) {
self.current_chunk = 0;
}
}
impl<'a> Iterator for ChunkedMessageSender<'a> {
    /// The chunk's extended header and its slice of the payload.
    type Item = (ExtendedHeader, &'a [u8]);

    /// Yields the next chunk and advances the cursor, or returns `None` once
    /// every chunk has been produced.
    fn next(&mut self) -> Option<Self::Item> {
        if self.is_complete() {
            return None;
        }
        let start = self.current_chunk as usize * MAX_EXTENDED_MSG_CHUNK_LEN;
        let end = core::cmp::min(start + MAX_EXTENDED_MSG_CHUNK_LEN, self.data.len());
        let chunk_data = &self.data[start..end];
        let ext_header = ExtendedHeader::new(self.data.len() as u16)
            .with_chunked(true)
            .with_chunk_number(self.current_chunk);
        self.current_chunk += 1;
        Some((ext_header, chunk_data))
    }

    /// Exact count of chunks still to be yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = usize::from(self.total_chunks.saturating_sub(self.current_chunk));
        (remaining, Some(remaining))
    }
}

/// `size_hint` is always exact, so callers may use `len()`.
///
/// `FusedIterator` is deliberately not implemented. `reset()` can make `next`
/// yield chunks again after it has returned `None`.
impl ExactSizeIterator for ChunkedMessageSender<'_> {}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chunked_sender_single_chunk() {
        let data = [1u8, 2, 3, 4, 5];
        let mut sender = ChunkedMessageSender::new(&data);
        assert_eq!(sender.total_chunks(), 1);
        assert!(!sender.is_complete());
        let (ext_hdr, chunk) = sender.next().unwrap();
        assert_eq!(chunk, &data);
        assert_eq!(ext_hdr.data_size(), 5);
        assert_eq!(ext_hdr.chunk_number(), 0);
        assert!(ext_hdr.chunked());
        assert!(sender.is_complete());
        assert!(sender.next().is_none());
    }

    #[test]
    fn test_chunked_sender_multiple_chunks() {
        let data = [0u8; 30];
        let mut sender = ChunkedMessageSender::new(&data);
        assert_eq!(sender.total_chunks(), 2);
        let (ext_hdr, chunk) = sender.next().unwrap();
        assert_eq!(chunk.len(), 26);
        assert_eq!(ext_hdr.chunk_number(), 0);
        let (ext_hdr, chunk) = sender.next().unwrap();
        assert_eq!(chunk.len(), 4);
        assert_eq!(ext_hdr.chunk_number(), 1);
        assert!(sender.is_complete());
    }

    #[test]
    fn test_assembler_single_chunk() {
        let mut assembler = ChunkedMessageAssembler::new();
        let header = Header(0x1000);
        let ext_header = ExtendedHeader::new(5).with_chunked(true).with_chunk_number(0);
        let data = [1u8, 2, 3, 4, 5];
        match assembler.process_chunk(header, ext_header, &data).unwrap() {
            ChunkResult::Complete(buf) => {
                assert_eq!(&buf[..], &data);
            }
            _ => panic!("Expected complete"),
        }
    }

    #[test]
    fn test_assembler_parser_reuse_error() {
        let mut assembler = ChunkedMessageAssembler::new();
        let header = Header(0x1000);
        let ext_header = ExtendedHeader::new(30).with_chunked(true).with_chunk_number(0);
        let data = [1u8; 26];
        match assembler.process_chunk(header, ext_header, &data).unwrap() {
            ChunkResult::NeedMoreChunks(next) => assert_eq!(next, 1),
            _ => panic!("Expected NeedMoreChunks"),
        }
        let result = assembler.process_chunk(header, ext_header, &data);
        assert!(matches!(result, Err(ParseError::ParserReuse)));
    }

    #[test]
    fn test_new_from_chunk() {
        let header = Header(0x1000);
        let ext_header = ExtendedHeader::new(5).with_chunked(true).with_chunk_number(0);
        let data = [1u8, 2, 3, 4, 5];
        let (assembler, result) =
            ChunkedMessageAssembler::new_from_chunk(header, ext_header, &data).unwrap();
        match result {
            ChunkResult::Complete(buf) => assert_eq!(&buf[..], &data),
            _ => panic!("Expected Complete"),
        }
        assert!(!assembler.is_in_progress());
    }

    #[test]
    fn test_new_from_chunk_multi_chunk() {
        let header = Header(0x1000);
        let ext_header = ExtendedHeader::new(30).with_chunked(true).with_chunk_number(0);
        let chunk_0 = [0u8; 26];
        let (mut assembler, result) =
            ChunkedMessageAssembler::new_from_chunk(header, ext_header, &chunk_0).unwrap();
        match result {
            ChunkResult::NeedMoreChunks(next) => assert_eq!(next, 1),
            _ => panic!("Expected NeedMoreChunks"),
        }
        assert!(assembler.is_in_progress());
        let ext_header_1 = ExtendedHeader::new(30).with_chunked(true).with_chunk_number(1);
        let chunk_1 = [0u8; 4];
        match assembler.process_chunk(header, ext_header_1, &chunk_1).unwrap() {
            ChunkResult::Complete(_) => {}
            _ => panic!("Expected Complete"),
        }
        assert!(!assembler.is_in_progress());
    }

    #[test]
    fn test_chunk_overflow_error() {
        let mut assembler = ChunkedMessageAssembler::new();
        let header = Header(0x1000);
        let ext_header = ExtendedHeader::new(30).with_chunked(true).with_chunk_number(0);
        let oversized_chunk = [0u8; 27];
        let result = assembler.process_chunk(header, ext_header, &oversized_chunk);
        assert!(matches!(
            result,
            Err(ParseError::ChunkOverflow(27, MAX_EXTENDED_MSG_CHUNK_LEN))
        ));
    }

    #[test]
    fn test_assembler_chunk_request_passthrough() {
        let mut assembler = ChunkedMessageAssembler::new();
        let header = Header(0x1000);
        let ext_header = ChunkedMessageAssembler::build_chunk_request_header(2);
        let result = assembler.process_chunk(header, ext_header, &[]);
        assert!(matches!(result, Ok(ChunkResult::ChunkRequested(2))));
        // A request carries no data and must not start a message.
        assert!(!assembler.is_in_progress());
    }

    #[test]
    fn test_assembler_rejects_chunk_without_start() {
        let mut assembler = ChunkedMessageAssembler::new();
        let header = Header(0x1000);
        let ext_header = ExtendedHeader::new(30).with_chunked(true).with_chunk_number(1);
        let result = assembler.process_chunk(header, ext_header, &[0u8; 4]);
        assert!(matches!(result, Err(ParseError::Other(_))));
    }

    #[test]
    fn test_assembler_rejects_out_of_order_chunk() {
        let mut assembler = ChunkedMessageAssembler::new();
        let header = Header(0x1000);
        let first = ExtendedHeader::new(30).with_chunked(true).with_chunk_number(0);
        assert!(matches!(
            assembler.process_chunk(header, first, &[0u8; 26]),
            Ok(ChunkResult::NeedMoreChunks(1))
        ));
        let skipped = ExtendedHeader::new(30).with_chunked(true).with_chunk_number(2);
        let result = assembler.process_chunk(header, skipped, &[0u8; 4]);
        assert!(matches!(result, Err(ParseError::Other(_))));
    }

    #[test]
    fn test_chunked_sender_as_iterator() {
        let data = [0u8; 30];
        let mut sender = ChunkedMessageSender::new(&data);
        let (ext_hdr0, chunk0) = sender.next().unwrap();
        assert_eq!(ext_hdr0.chunk_number(), 0);
        assert_eq!(chunk0.len(), 26);
        let (ext_hdr1, chunk1) = sender.next().unwrap();
        assert_eq!(ext_hdr1.chunk_number(), 1);
        assert_eq!(chunk1.len(), 4);
        assert!(sender.next().is_none());
    }

    #[test]
    fn test_chunked_sender_for_loop() {
        let data = [1u8, 2, 3, 4, 5];
        let sender = ChunkedMessageSender::new(&data);
        let mut count = 0;
        for (ext_hdr, chunk) in sender {
            assert_eq!(ext_hdr.chunk_number(), count);
            assert_eq!(chunk, &data);
            count += 1;
        }
        assert_eq!(count, 1);
    }

    #[test]
    fn test_chunked_sender_get_chunk_is_random_access() {
        let data = [7u8; 30];
        let sender = ChunkedMessageSender::new(&data);
        let (ext_hdr, chunk) = sender.get_chunk(1).unwrap();
        assert_eq!(ext_hdr.chunk_number(), 1);
        assert_eq!(ext_hdr.data_size(), 30);
        assert_eq!(chunk.len(), 4);
        assert!(sender.get_chunk(2).is_none());
        // get_chunk must not move the iteration cursor.
        assert_eq!(sender.current_chunk(), 0);
    }

    #[test]
    fn test_chunked_sender_empty_data() {
        let mut sender = ChunkedMessageSender::new(&[]);
        assert_eq!(sender.total_chunks(), 1);
        assert_eq!(sender.data_size(), 0);
        let (ext_hdr, chunk) = sender.next().unwrap();
        assert!(chunk.is_empty());
        assert_eq!(ext_hdr.data_size(), 0);
        assert!(sender.next().is_none());
    }

    #[test]
    fn test_chunked_sender_size_hint_and_reset() {
        let data = [0u8; 30];
        let mut sender = ChunkedMessageSender::new(&data);
        assert_eq!(sender.size_hint(), (2, Some(2)));
        assert!(sender.next().is_some());
        assert_eq!(sender.size_hint(), (1, Some(1)));
        assert!(sender.next().is_some());
        assert!(sender.is_complete());
        sender.reset();
        assert!(!sender.is_complete());
        assert_eq!(sender.current_chunk(), 0);
    }
}