use crate::error::{Error, MultipartError};
/// Per-fragment header of a multipart transfer.
///
/// All three fields are plain `u16` values; see [`MultipartHeader::encode`]
/// for the wire representation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct MultipartHeader {
    /// Length in bytes of the complete body being transferred.
    pub total: u16,
    /// Byte position inside the body at which this fragment starts.
    pub offset: u16,
    /// Number of payload bytes that accompany this header.
    pub fragment: u16,
}
impl MultipartHeader {
pub const WIRE_LEN: usize = 6;
pub const fn encode(self) -> [u8; Self::WIRE_LEN] {
let t = self.total.to_le_bytes();
let o = self.offset.to_le_bytes();
let f = self.fragment.to_le_bytes();
[t[0], t[1], o[0], o[1], f[0], f[1]]
}
pub fn decode(bytes: &[u8]) -> Result<Self, Error> {
if bytes.len() < Self::WIRE_LEN {
return Err(Error::Truncated {
have: bytes.len(),
need: Self::WIRE_LEN,
});
}
Ok(Self {
total: u16::from_le_bytes([bytes[0], bytes[1]]),
offset: u16::from_le_bytes([bytes[2], bytes[3]]),
fragment: u16::from_le_bytes([bytes[4], bytes[5]]),
})
}
pub const fn is_abort(self) -> bool {
self.fragment == 0 && self.offset >= self.total
}
}
#[cfg(feature = "alloc")]
mod alloc_impls {
use super::*;
use alloc::vec::Vec;
/// Sender-side cursor that walks a borrowed body in fragments of bounded size.
#[derive(Clone, Debug)]
pub struct MultipartTx<'a> {
    /// The complete body being split.
    body: &'a [u8],
    /// Start position within `body` of the next fragment to hand out.
    offset: u16,
    /// Upper bound on the payload length of a single fragment.
    fragment_size: u16,
}
impl<'a> MultipartTx<'a> {
pub fn new(body: &'a [u8], fragment_size: u16) -> Result<Self, Error> {
if fragment_size == 0 {
return Err(Error::Multipart(MultipartError::BadFragmentSize));
}
if body.len() > u16::MAX as usize {
return Err(Error::BufferOverflow {
need: body.len(),
have: u16::MAX as usize,
});
}
Ok(Self {
body,
offset: 0,
fragment_size,
})
}
pub fn total(&self) -> u16 {
self.body.len() as u16
}
pub fn remaining(&self) -> u16 {
self.total().saturating_sub(self.offset)
}
}
/// Yields `(header, payload)` pairs covering the body front to back, each
/// payload at most `fragment_size` bytes long. Yields nothing for an empty body.
impl<'a> Iterator for MultipartTx<'a> {
    type Item = (MultipartHeader, &'a [u8]);

    fn next(&mut self) -> Option<Self::Item> {
        // Copy the reference out so the returned slice borrows for `'a`,
        // not for the duration of `&mut self`.
        let body: &'a [u8] = self.body;
        let start = usize::from(self.offset);

        // Nothing left once the cursor has reached the end of the body.
        let pending = body.get(start..).filter(|rest| !rest.is_empty())?;

        let len = pending.len().min(usize::from(self.fragment_size));
        // `len` is bounded by `fragment_size`, so it fits in a `u16`.
        let fragment = len as u16;

        let header = MultipartHeader {
            total: body.len() as u16,
            offset: self.offset,
            fragment,
        };
        self.offset += fragment;
        Some((header, &pending[..len]))
    }
}
/// Receiver-side state for reassembling a body from in-order fragments.
#[derive(Clone, Debug, Default)]
pub struct MultipartRx {
    /// Payload bytes accepted so far for the transfer in progress.
    buf: Vec<u8>,
    /// Total length announced by the first fragment; `None` while idle.
    total: Option<u16>,
}
impl MultipartRx {
    /// Creates an idle receiver with no transfer in progress.
    pub fn new() -> Self {
        Self::default()
    }

    /// Drops any partially assembled body and returns to the idle state.
    pub fn reset(&mut self) {
        self.buf.clear();
        self.total = None;
    }

    /// Feeds one fragment (`header` plus its payload `data`) into the receiver.
    ///
    /// Returns `Ok(Some(body))` when this fragment completes the body (the
    /// receiver is idle again afterwards) and `Ok(None)` when more fragments
    /// are still expected.
    ///
    /// # Errors
    ///
    /// * `MultipartError::Aborted` — `header` is an abort marker; the receiver
    ///   is reset.
    /// * `MultipartError::BadFragmentSize` — `header.fragment` does not match
    ///   `data.len()`.
    /// * `MultipartError::UnexpectedFirstOffset` — the receiver is idle and the
    ///   fragment does not start at offset 0.
    /// * `MultipartError::InconsistentTotal` — `header.total` differs from the
    ///   total announced by the first fragment.
    /// * `MultipartError::OutOfOrderOffset` — the fragment does not start where
    ///   the previously accepted data ends.
    /// * `MultipartError::OverflowsTotal` — the fragment would extend past
    ///   `header.total`.
    ///
    /// Apart from `Aborted`, a rejected fragment leaves the receiver state
    /// exactly as it was before the call.
    pub fn push(
        &mut self,
        header: MultipartHeader,
        data: &[u8],
    ) -> Result<Option<Vec<u8>>, Error> {
        if header.is_abort() {
            self.reset();
            return Err(Error::Multipart(MultipartError::Aborted));
        }
        if header.fragment as usize != data.len() {
            return Err(Error::Multipart(MultipartError::BadFragmentSize));
        }

        // `data.len()` equals `header.fragment` here, so the end position can
        // be computed in `u32`, where the sum of two `u16` values cannot wrap.
        let end = u32::from(header.offset) + u32::from(header.fragment);
        let overflows = end > u32::from(header.total);

        match self.total {
            None => {
                if header.offset != 0 {
                    return Err(Error::Multipart(MultipartError::UnexpectedFirstOffset(
                        header.offset,
                    )));
                }
                // Validate BEFORE committing to this transfer: a rejected first
                // fragment must not lock the receiver onto its total.
                if overflows {
                    return Err(Error::Multipart(MultipartError::OverflowsTotal));
                }
                self.total = Some(header.total);
                self.buf.clear();
                self.buf.reserve(header.total as usize);
            }
            Some(first) if first != header.total => {
                return Err(Error::Multipart(MultipartError::InconsistentTotal {
                    first,
                    now: header.total,
                }));
            }
            Some(_) => {
                if header.offset as usize != self.buf.len() {
                    return Err(Error::Multipart(MultipartError::OutOfOrderOffset {
                        // The buffer never grows past the `u16` total.
                        expected: self.buf.len() as u16,
                        got: header.offset,
                    }));
                }
                if overflows {
                    return Err(Error::Multipart(MultipartError::OverflowsTotal));
                }
            }
        }

        self.buf.extend_from_slice(data);
        if self.buf.len() == header.total as usize {
            // Hand the assembled body to the caller and go back to idle.
            self.total = None;
            Ok(Some(core::mem::take(&mut self.buf)))
        } else {
            Ok(None)
        }
    }

    /// Number of payload bytes accepted so far for the transfer in progress.
    pub fn progress(&self) -> usize {
        self.buf.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a fragment header.
    fn header(total: u16, offset: u16, fragment: u16) -> MultipartHeader {
        MultipartHeader {
            total,
            offset,
            fragment,
        }
    }

    /// Receiver that has already accepted the first 4 bytes of an 8-byte body.
    fn half_filled_rx() -> MultipartRx {
        let mut rx = MultipartRx::new();
        rx.push(header(8, 0, 4), &[0u8; 4]).unwrap();
        rx
    }

    /// Splitting a body and pushing every fragment in order rebuilds it.
    #[test]
    fn split_then_assemble() {
        let mut body: Vec<u8> = Vec::new();
        for n in 0u16..1024 {
            body.extend_from_slice(&n.to_le_bytes());
        }

        let mut rx = MultipartRx::new();
        let assembled = MultipartTx::new(&body, 200)
            .unwrap()
            .filter_map(|(h, frag)| rx.push(h, frag).unwrap())
            .last();

        assert_eq!(assembled.unwrap(), body);
    }

    /// A fragment that skips ahead of the accepted data is refused.
    #[test]
    fn reject_out_of_order() {
        let mut rx = half_filled_rx();
        let outcome = rx.push(header(8, 6, 2), &[0u8; 2]);
        assert!(outcome.is_err());
    }

    /// A fragment announcing a different total mid-transfer is refused.
    #[test]
    fn reject_inconsistent_total() {
        let mut rx = half_filled_rx();
        let outcome = rx.push(header(12, 4, 4), &[0u8; 4]);
        assert!(outcome.is_err());
    }

    /// An empty fragment positioned at the total is reported as an abort.
    #[test]
    fn abort_signal() {
        let mut rx = half_filled_rx();
        let outcome = rx.push(header(8, 8, 0), &[]);
        assert!(matches!(
            outcome,
            Err(Error::Multipart(MultipartError::Aborted))
        ));
    }

    /// Encoding and then decoding a header yields the same header.
    #[test]
    fn header_roundtrip() {
        let original = header(0x1234, 0x0050, 0x0010);
        let wire = original.encode();
        assert_eq!(MultipartHeader::decode(&wire).unwrap(), original);
    }
}
}
#[cfg(feature = "alloc")]
pub use alloc_impls::{MultipartRx, MultipartTx};