use crate::message::ReaderOptions;
use crate::message::ReaderSegments;
use crate::private::units::BYTES_PER_WORD;
use crate::{Error, ErrorKind, Result};
use super::SEGMENTS_COUNT_LIMIT;
/// Width in bytes of each little-endian `u32` field in the segment table.
const U32_LEN_IN_BYTES: usize = (u32::BITS / 8) as usize;
/// Summary of a validated segment table, as produced by `read_segment_table`.
///
/// All sizes are in bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NoAllocSegmentTableInfo {
    /// Number of segments in the message (at least 1 when decoded from a header).
    pub segments_count: usize,
    /// Size of the segment table including alignment padding; segment data starts at this
    /// offset from the beginning of the message.
    pub segment_table_length_bytes: usize,
    /// Sum of the byte lengths of all segments.
    pub total_segments_length_bytes: usize,
}
/// Parses and validates the segment table at the start of `slice` without copying anything.
///
/// Layout read here:
/// - a little-endian `u32` holding `segments_count - 1`;
/// - one little-endian `u32` per segment holding that segment's length in words;
/// - four bytes of padding when `segments_count` is even, so segment data starts on a word
///   boundary.
///
/// Bytes after the end of the last segment are tolerated.
///
/// # Errors
///
/// - `slice` is not word aligned (unless the `unaligned` feature is enabled).
/// - The header is truncated, or a count/length does not fit in `usize`.
/// - `segments_count >= SEGMENTS_COUNT_LIMIT`.
/// - The summed segment sizes overflow `usize` or exceed `options.traversal_limit_in_words`.
/// - `slice` is shorter than the header plus the declared segment data.
///
/// # Panics
///
/// Only if the number of header bytes consumed disagrees with `calculate_data_offset`, which
/// would indicate a bug in this module rather than bad input.
fn read_segment_table(slice: &[u8], options: ReaderOptions) -> Result<NoAllocSegmentTableInfo> {
    let mut remaining = slice;

    verify_alignment(remaining.as_ptr())?;

    // The wire value is `segments_count - 1`; the helper adds the 1 back.
    let segments_count = u32_to_segments_count(read_u32_le(&mut remaining)?)?;
    if segments_count >= SEGMENTS_COUNT_LIMIT {
        return Err(Error::from_kind(ErrorKind::InvalidNumberOfSegments(
            segments_count,
        )));
    }

    // Sum the per-segment lengths with checked arithmetic, since the header is untrusted.
    let mut total_segments_length_bytes = 0_usize;
    for _ in 0..segments_count {
        let segment_length_in_bytes = u32_to_segment_length_bytes(read_u32_le(&mut remaining)?)?;
        total_segments_length_bytes = total_segments_length_bytes
            .checked_add(segment_length_in_bytes)
            .ok_or_else(|| Error::from_kind(ErrorKind::MessageSizeOverflow))?;
    }

    // Reject messages whose segment data alone is larger than the configured traversal limit.
    // Use the same word-size constant as the rest of this function instead of a literal 8.
    if let Some(limit) = options.traversal_limit_in_words {
        let total_segments_length_words = total_segments_length_bytes / BYTES_PER_WORD;
        if total_segments_length_words > limit {
            return Err(Error::from_kind(ErrorKind::MessageTooLarge(
                total_segments_length_words,
            )));
        }
    }

    // With an even segment count the table ends 4 bytes short of a word boundary; consume
    // the padding so `remaining` starts at the first segment.
    if segments_count % 2 == 0 {
        let _padding = read_u32_le(&mut remaining)?;
    }

    let expected_data_offset = calculate_data_offset(segments_count)
        .ok_or_else(|| Error::from_kind(ErrorKind::MessageSizeOverflow))?;

    // Cross-check the incremental parse against the closed-form header size.
    let consumed_bytes = slice.len() - remaining.len();
    assert_eq!(
        expected_data_offset, consumed_bytes,
        "Expected header size and actual header size must match, otherwise we have a bug in this code"
    );

    // Extra trailing bytes are fine; only a body shorter than declared is an error.
    if remaining.len() < total_segments_length_bytes {
        return Err(Error::from_kind(ErrorKind::MessageEndsPrematurely(
            total_segments_length_bytes / BYTES_PER_WORD,
            remaining.len() / BYTES_PER_WORD,
        )));
    }

    Ok(NoAllocSegmentTableInfo {
        segments_count,
        segment_table_length_bytes: expected_data_offset,
        total_segments_length_bytes,
    })
}
/// Segments borrowed directly from a byte slice, as returned by
/// `NoAllocBufferSegments::from_slice`.
pub type NoAllocSliceSegments<'a> = NoAllocBufferSegments<&'a [u8]>;
/// How `NoAllocBufferSegments` locates segment data inside its buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum NoAllocBufferSegmentType {
    /// Exactly one segment of the given byte length, stored right after the 8-byte segment
    /// table, so it can be sliced without re-reading the header.
    SingleSegment(usize),
    /// Any other segment count; segment offsets are recomputed from the segment table on each
    /// access.
    MultipleSegments,
}

/// `ReaderSegments` implementation over a caller-provided buffer that never allocates.
///
/// `buffer` holds a serialized message beginning with its segment table.
#[derive(Debug)]
pub struct NoAllocBufferSegments<T> {
    buffer: T,
    segment_type: NoAllocBufferSegmentType,
}
impl<T> NoAllocBufferSegments<T> {
    /// Wraps `buffer` using a segment table that has already been parsed.
    ///
    /// A one-segment message records that segment's byte length so `get_segment` can slice it
    /// directly; every other count re-parses the table on access.
    ///
    /// NOTE(review): `info` is trusted as-is — it is expected to describe `buffer` (e.g. the
    /// result of `read_segment_table` on the same bytes); nothing is re-validated here.
    pub fn from_segment_table_info(buffer: T, info: NoAllocSegmentTableInfo) -> Self {
        let segment_type = match info.segments_count {
            1 => NoAllocBufferSegmentType::SingleSegment(info.total_segments_length_bytes),
            _ => NoAllocBufferSegmentType::MultipleSegments,
        };
        Self {
            buffer,
            segment_type,
        }
    }
}
impl<'a> NoAllocBufferSegments<&'a [u8]> {
    /// Reads one message from the front of `*slice` and advances `*slice` past it.
    ///
    /// The returned value borrows exactly the message bytes (segment table plus segment data).
    /// Anything after the message is left in `*slice` for the caller.
    ///
    /// # Errors
    ///
    /// Propagates segment-table validation errors; `*slice` is left untouched in that case.
    pub fn from_slice(slice: &mut &'a [u8], options: ReaderOptions) -> Result<Self> {
        let info = read_segment_table(slice, options)?;
        let message_len = info.segment_table_length_bytes + info.total_segments_length_bytes;
        let (message, rest) = slice.split_at(message_len);
        *slice = rest;
        Ok(Self::from_segment_table_info(message, info))
    }
}
impl<T: AsRef<[u8]>> NoAllocBufferSegments<T> {
    /// Validates the segment table at the start of `buffer` and takes ownership of it.
    ///
    /// Unlike `from_slice`, the whole buffer is retained, including any bytes that follow the
    /// message.
    ///
    /// # Errors
    ///
    /// Returns the segment-table validation error if `buffer` does not begin with a
    /// well-formed message.
    pub fn from_buffer(buffer: T, options: ReaderOptions) -> Result<Self> {
        read_segment_table(buffer.as_ref(), options)
            .map(|info| Self::from_segment_table_info(buffer, info))
    }
}
impl<T: AsRef<[u8]>> ReaderSegments for NoAllocBufferSegments<T> {
    /// Returns the bytes of segment `idx`, or `None` if `idx` is past the last segment.
    ///
    /// The segment table is re-read on every call rather than cached, which keeps this type
    /// allocation-free. The `unwrap`s assume the buffer was validated when `self` was built;
    /// they can only fire if `from_segment_table_info` was given mismatched inputs.
    fn get_segment(&self, idx: u32) -> Option<&[u8]> {
        let wanted: usize = idx.try_into().unwrap();
        match self.segment_type {
            NoAllocBufferSegmentType::SingleSegment(length_bytes) => {
                // One-segment header = count word + one length word = 8 bytes.
                (wanted == 0).then(|| &self.buffer.as_ref()[8..8 + length_bytes])
            }
            NoAllocBufferSegmentType::MultipleSegments => {
                let mut table = self.buffer.as_ref();
                let count = u32_to_segments_count(read_u32_le(&mut table).unwrap()).unwrap();
                if wanted >= count {
                    return None;
                }

                // Each call consumes the next length field from the table.
                let mut next_length =
                    || u32_to_segment_length_bytes(read_u32_le(&mut table).unwrap()).unwrap();

                // Skip the data of every segment that precedes the requested one.
                let start = (0..wanted).fold(calculate_data_offset(count).unwrap(), |offset, _| {
                    offset.checked_add(next_length()).unwrap()
                });
                let end = start + next_length();
                Some(&self.buffer.as_ref()[start..end])
            }
        }
    }

    /// Number of segments: 1 on the single-segment path, otherwise the count re-read from the
    /// header.
    fn len(&self) -> usize {
        match self.segment_type {
            NoAllocBufferSegmentType::SingleSegment(_) => 1,
            NoAllocBufferSegmentType::MultipleSegments => {
                let mut header = self.buffer.as_ref();
                let raw_count = read_u32_le(&mut header).unwrap();
                u32_to_segments_count(raw_count).unwrap()
            }
        }
    }
}
fn verify_alignment(ptr: *const u8) -> Result<()> {
if cfg!(feature = "unaligned") {
return Ok(());
}
if ptr.align_offset(BYTES_PER_WORD) == 0 {
Ok(())
} else {
Err(Error::from_kind(
ErrorKind::MessageNotAlignedBy8BytesBoundary,
))
}
}
fn read_u32_le(slice: &mut &[u8]) -> Result<u32> {
if slice.len() < U32_LEN_IN_BYTES {
return Err(Error::from_kind(ErrorKind::MessageEndsPrematurely(
U32_LEN_IN_BYTES,
slice.len(),
)));
}
let u32_buf: [u8; U32_LEN_IN_BYTES] = slice[..U32_LEN_IN_BYTES].try_into().unwrap();
*slice = &slice[U32_LEN_IN_BYTES..];
Ok(u32::from_le_bytes(u32_buf))
}
fn u32_to_segments_count(val: u32) -> Result<usize> {
let result: Option<usize> = val.try_into().ok();
let result = result.and_then(|v: usize| v.checked_add(1));
result.ok_or_else(|| Error::from_kind(ErrorKind::FourByteLengthTooBigForUSize))
}
fn u32_to_segment_length_bytes(val: u32) -> Result<usize> {
let length_in_words: Option<usize> = val.try_into().ok();
let length_in_bytes = length_in_words.and_then(|l| l.checked_mul(BYTES_PER_WORD));
length_in_bytes.ok_or_else(|| Error::from_kind(ErrorKind::FourByteSegmentLengthTooBigForUSize))
}
/// Returns the byte size of the segment table for `segments_count` segments, i.e. the offset
/// at which the first segment's data begins.
///
/// The table holds one `u32` for the count and one `u32` per segment length. When that total
/// is not a multiple of `BYTES_PER_WORD`, 4 bytes of padding are added.
///
/// Returns `None` for `segments_count == 0` (not a valid message) or if the size would
/// overflow `usize`.
///
/// # Panics
///
/// Not reachable in practice: the unpadded size is `4 * (segments_count + 1)`, which is
/// either word aligned or off by exactly 4. The checks only guard against bugs here.
fn calculate_data_offset(segments_count: usize) -> Option<usize> {
    if segments_count == 0 {
        return None;
    }

    // One u32 for the segment count, plus one u32 per segment length.
    let unpadded_len = segments_count
        .checked_add(1)?
        .checked_mul(U32_LEN_IN_BYTES)?;

    // An even segment count leaves the table 4 bytes short of a word boundary.
    let padding_len = match unpadded_len % BYTES_PER_WORD {
        0 => 0,
        4 => 4,
        _ => unreachable!(
            "Mis-alignment by anything other than 4 should be impossible, this is a bug"
        ),
    };

    let data_offset = unpadded_len.checked_add(padding_len)?;

    // Sanity check that segment data will start on a word boundary.
    assert_eq!(
        data_offset % BYTES_PER_WORD,
        0,
        "data_offset after adding padding must be aligned by 8. \
         If it's not, it's a bug"
    );

    Some(data_offset)
}
// Unit tests for the header helpers, plus property tests (behind `alloc`) that round-trip
// messages through `serialize::write_message_segments` and read them back without allocating.
#[cfg(test)]
mod tests {
    // quickcheck drives the property tests below; they need `alloc` for `Vec`.
    #[cfg(feature = "alloc")]
    use quickcheck::{quickcheck, TestResult};

    use super::calculate_data_offset;
    #[cfg(feature = "alloc")]
    use crate::{
        message::{ReaderOptions, ReaderSegments},
        serialize, word, Word,
    };

    #[cfg(feature = "alloc")]
    use crate::OutputSegments;

    use super::{
        read_u32_le, u32_to_segment_length_bytes, u32_to_segments_count, verify_alignment,
    };

    #[cfg(feature = "alloc")]
    use super::{NoAllocBufferSegmentType, NoAllocBufferSegments, NoAllocSliceSegments};

    // An 8-byte buffer forced onto an 8-byte boundary: offset 0 is word aligned and offsets
    // 1..8 are not.
    #[repr(align(8))]
    struct Aligned([u8; 8]);

    // With `unaligned` enabled, every offset must be accepted.
    #[cfg(feature = "unaligned")]
    #[test]
    fn test_verify_alignment_unaligned_mode() {
        assert_eq!(core::mem::size_of::<Aligned>(), 8);
        let aligned = Aligned([0; 8]);
        for idx in 0..8 {
            // SAFETY: idx < 8, so the offset pointer stays within the 8-byte array.
            verify_alignment(unsafe { aligned.0.as_ptr().add(idx) }).unwrap();
        }
    }

    // By default only the word-aligned offset 0 is accepted; offsets 1..8 must be rejected.
    #[cfg(not(feature = "unaligned"))]
    #[test]
    fn test_verify_alignment() {
        assert_eq!(core::mem::size_of::<Aligned>(), 8);
        let aligned = Aligned([0; 8]);
        verify_alignment(aligned.0.as_ptr()).unwrap();
        for idx in 1..8 {
            // SAFETY: idx < 8, so the offset pointer stays within the 8-byte array.
            verify_alignment(unsafe { aligned.0.as_ptr().add(idx) }).unwrap_err();
        }
    }

    // Reads the first four bytes as little-endian and advances the input by exactly four.
    #[test]
    fn test_read_u32_le() {
        let buffer = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08];
        let mut buffer_remaining = &buffer[..];
        assert_eq!(read_u32_le(&mut buffer_remaining).unwrap(), 0x04030201);
        assert_eq!(buffer_remaining, &buffer[4..]);
    }

    // Fewer than four bytes: error, and the input slice must be left untouched.
    #[test]
    fn test_read_u32_le_truncated() {
        let buffer = [0x01, 0x02, 0x03];
        let mut buffer_remaining = &buffer[..];
        read_u32_le(&mut buffer_remaining).unwrap_err();
        assert_eq!(buffer_remaining, &buffer[..]);
    }

    // The header stores `count - 1`, so decoding adds one.
    #[test]
    fn test_u32_to_segments_count() {
        assert_eq!(u32_to_segments_count(0).unwrap(), 1);
        assert_eq!(u32_to_segments_count(10).unwrap(), 11);
    }

    // Segment lengths are stored in 8-byte words.
    #[test]
    fn test_u32_to_segment_length_bytes() {
        assert_eq!(u32_to_segment_length_bytes(0).unwrap(), 0);
        assert_eq!(u32_to_segment_length_bytes(123).unwrap(), 123 * 8);
    }

    // NOTE(review): despite the name, this covers both the padded (even counts: 2, 100) and
    // unpadded (odd counts: 1, 3, 101) cases, plus the invalid zero count.
    #[test]
    fn test_calculate_data_offset_no_padding() {
        assert_eq!(calculate_data_offset(0), None);
        assert_eq!(calculate_data_offset(1), Some(8));
        assert_eq!(calculate_data_offset(2), Some(16));
        assert_eq!(calculate_data_offset(3), Some(16));
        assert_eq!(calculate_data_offset(100), Some(408));
        assert_eq!(calculate_data_offset(101), Some(408));
    }

    #[cfg(feature = "alloc")]
    quickcheck! {
        // A one-segment message must take the `SingleSegment` fast path and return the
        // original segment bytes for index 0 only.
        #[cfg_attr(miri, ignore)]
        fn test_no_alloc_buffer_segments_single_segment_optimization(
            segment_0 : alloc::vec::Vec<Word>) -> TestResult
        {
            let words = &segment_0[..];
            let bytes = Word::words_to_bytes(words);
            let output_segments = OutputSegments::SingleSegment([bytes]);
            let mut msg = vec![];

            serialize::write_message_segments(&mut msg, &output_segments).unwrap();

            let no_alloc_segments =
                NoAllocSliceSegments::from_slice(&mut msg.as_slice(), ReaderOptions::new()).unwrap();

            assert!(matches!(
                no_alloc_segments,
                NoAllocBufferSegments { buffer: _,
                    segment_type : NoAllocBufferSegmentType::SingleSegment { .. },
                }
            ));

            assert_eq!(no_alloc_segments.len(), 1);
            assert_eq!(no_alloc_segments.get_segment(0), Some(bytes));
            assert_eq!(no_alloc_segments.get_segment(1), None);
            TestResult::from_bool(true)
        }

        // Round-trips an arbitrary non-empty list of segments and checks every segment and
        // the out-of-range index. A one-element list exercises the single-segment path too.
        #[cfg_attr(miri, ignore)]
        fn test_no_alloc_buffer_segments_multiple_segments(segments_vec: alloc::vec::Vec<alloc::vec::Vec<Word>>) -> TestResult {
            if segments_vec.is_empty() { return TestResult::discard() };

            let segments: alloc::vec::Vec<_> = segments_vec.iter().map(|s|
                Word::words_to_bytes(s.as_slice())).collect();

            let output_segments = OutputSegments::MultiSegment(segments.clone());

            let mut msg = vec![];

            serialize::write_message_segments(&mut msg, &output_segments).unwrap();

            let no_alloc_segments =
                NoAllocSliceSegments::from_slice(&mut msg.as_slice(), ReaderOptions::new()).unwrap();

            assert_eq!(no_alloc_segments.len(), segments.len());
            for (i, segment) in segments.iter().enumerate() {
                assert_eq!(no_alloc_segments.get_segment(i as u32), Some(*segment));
            }

            assert_eq!(
                no_alloc_segments.get_segment(no_alloc_segments.len() as u32),
                None
            );
            TestResult::from_bool(true)
        }
    }

    // `from_slice` must consume exactly one message and leave the following bytes in the
    // caller's slice: a 2-word message (8-byte header + one 1-word segment) followed by one
    // extra word.
    #[cfg(feature = "alloc")]
    #[test]
    fn test_no_alloc_buffer_segments_message_postfix() {
        let output_segments = OutputSegments::SingleSegment([&[1, 2, 3, 4, 5, 6, 7, 8]]);
        let mut buf = Word::allocate_zeroed_vec(2);
        serialize::write_message_segments(Word::words_to_bytes_mut(&mut buf), &output_segments)
            .unwrap();
        buf.push(word(11, 12, 13, 14, 15, 16, 0, 0));
        let remaining = &mut Word::words_to_bytes(&buf);
        NoAllocSliceSegments::from_slice(remaining, ReaderOptions::new()).unwrap();

        assert_eq!(*remaining, &[11, 12, 13, 14, 15, 16, 0, 0]);
    }

    // Malformed headers must be rejected rather than panic.
    #[cfg(feature = "alloc")]
    #[test]
    fn test_no_alloc_buffer_segments_message_invalid() {
        let mut buf = vec![];

        // Raw count 512 => 513 segments. The zero-filled table that follows would otherwise
        // parse, so this is expected to fail the SEGMENTS_COUNT_LIMIT check.
        buf.extend([0, 2, 0, 0]);
        buf.extend([0; 513 * 8]);
        assert!(NoAllocSliceSegments::from_slice(&mut &buf[..], ReaderOptions::new()).is_err());
        buf.clear();

        // One segment declared, but its length field is missing entirely.
        buf.extend([0, 0, 0, 0]);
        assert!(NoAllocSliceSegments::from_slice(&mut &buf[..], ReaderOptions::new()).is_err());
        buf.clear();

        // One segment declared, but its length field is truncated to 3 bytes.
        buf.extend([0, 0, 0, 0]);
        buf.extend([0; 3]);
        assert!(NoAllocSliceSegments::from_slice(&mut &buf[..], ReaderOptions::new()).is_err());
        buf.clear();

        // Raw count u32::MAX: too many segments on 64-bit, `+ 1` overflows on 32-bit.
        buf.extend([255, 255, 255, 255]);
        assert!(NoAllocSliceSegments::from_slice(&mut &buf[..], ReaderOptions::new()).is_err());
        buf.clear();
    }

    #[cfg(feature = "alloc")]
    quickcheck! {
        // Dropping the final byte of any valid message (header or body) must be detected.
        #[cfg_attr(miri, ignore)]
        fn test_no_alloc_buffer_segments_message_truncated(segments_vec: alloc::vec::Vec<alloc::vec::Vec<Word>>) -> TestResult {
            if segments_vec.is_empty() { return TestResult::discard() }

            let segments: alloc::vec::Vec<_> = segments_vec.iter()
                .map(|s| Word::words_to_bytes(s.as_slice())).collect();

            let output_segments = OutputSegments::MultiSegment(segments.clone());

            let mut msg = vec![];

            serialize::write_message_segments(&mut msg, &output_segments).unwrap();

            msg.pop().unwrap();

            let no_alloc_segments =
                NoAllocSliceSegments::from_slice(&mut msg.as_slice(), ReaderOptions::new());

            assert!(no_alloc_segments.is_err());
            TestResult::from_bool(true)
        }

        // A traversal limit equal to the total word count is accepted; one word less is
        // rejected. Zero-word messages are discarded (`word_count - 1` would underflow).
        #[cfg_attr(miri, ignore)]
        fn test_no_alloc_buffer_segments_message_options_limit(
            segments_vec: alloc::vec::Vec<alloc::vec::Vec<Word>>) -> TestResult
        {
            let mut word_count = 0;
            let segments: alloc::vec::Vec<_> = segments_vec.iter()
                .map(|s| {
                    let ws = Word::words_to_bytes(s.as_slice());
                    word_count += s.len();
                    ws
                }).collect();
            if word_count == 0 { return TestResult::discard() };

            let output_segments = OutputSegments::MultiSegment(segments.clone());

            let mut msg = vec![];

            serialize::write_message_segments(&mut msg, &output_segments).unwrap();

            let mut options = ReaderOptions::new();
            options.traversal_limit_in_words(Some(word_count));

            let _no_alloc_segments =
                NoAllocSliceSegments::from_slice(&mut msg.as_slice(), options).unwrap();

            let mut options = ReaderOptions::new();
            options.traversal_limit_in_words(Some(word_count - 1));

            let no_alloc_segments = NoAllocSliceSegments::from_slice(&mut msg.as_slice(), options);

            assert!(no_alloc_segments.is_err());
            TestResult::from_bool(true)
        }

        // Reading from a 1-byte offset must fail unless `unaligned` is enabled.
        // NOTE(review): assumes the Vec<u8> allocation itself is 8-byte aligned so that
        // `&msg[1..]` is misaligned; typical allocators do this, but Vec<u8> does not
        // guarantee it.
        #[cfg_attr(miri, ignore)]
        fn test_no_alloc_buffer_segments_bad_alignment(segment_0: alloc::vec::Vec<Word>) -> TestResult {
            if segment_0.is_empty() { return TestResult::discard(); }
            let output_segments = OutputSegments::SingleSegment([Word::words_to_bytes(&segment_0)]);

            let mut msg = vec![];

            serialize::write_message_segments(&mut msg, &output_segments).unwrap();
            msg.insert(0_usize, 0_u8);

            let no_alloc_segments = NoAllocSliceSegments::from_slice(&mut &msg[1..], ReaderOptions::new());

            if cfg!(feature = "unaligned") {
                no_alloc_segments.unwrap();
            } else {
                assert!(no_alloc_segments.is_err());
            }

            TestResult::from_bool(true)
        }
    }
}