use crate::error::Error;
/// Reusable output buffer holding unstuffed scan data and the start offset of
/// every restart interval within it.
///
/// Storage is kept as `u32` words: each restart interval is padded to begin on
/// a 4-byte boundary, and its start is recorded as a word index.
#[derive(Debug, Default)]
pub struct ScanBuffer {
    /// Unstuffed scan bytes, stored as words (viewed as bytes by the accessors).
    words: Vec<u32>,
    /// Word index into `words` at which each restart interval begins.
    start_positions: Vec<u32>,
}
impl ScanBuffer {
    /// Creates an empty buffer; no memory is allocated until [`ScanBuffer::process`] runs.
    pub fn new() -> Self {
        Self {
            words: Vec::new(),
            start_positions: Vec::new(),
        }
    }

    /// Unstuffs `scan_data` into this buffer and records where each restart interval starts.
    ///
    /// `FF 00` pairs are collapsed to a single `FF`. Any other `FF xx` pair is treated as a
    /// restart marker: it is dropped, and the following interval starts on the next 4-byte
    /// boundary. Padding bytes are zero.
    ///
    /// The buffer may be reused across calls; the previous contents are discarded.
    ///
    /// # Errors
    ///
    /// Returns an error if the number of restart intervals found differs from
    /// `expected_restart_intervals`. The count found is one more than the number of markers.
    /// After an error, the buffer contents should not be relied upon.
    pub fn process(
        &mut self,
        scan_data: &[u8],
        expected_restart_intervals: u32,
    ) -> crate::Result<()> {
        // Upper bound on output size. A marker consumes two input bytes and pads by at most
        // three. Any interval that gets padded holds at least one data byte. So every 3 input
        // bytes produce at most 4 output bytes.
        let out_bytes = scan_data.len() + scan_data.len() / 3;

        // `resize` keeps existing elements, and padding bytes are skipped rather than written.
        // Without clearing first, a reused buffer would leak bytes from the previous scan into
        // the padding. Stale values could likewise survive in `start_positions[0]`.
        self.words.clear();
        self.words.resize((out_bytes + 3) / 4, 0);

        // Power-of-two length lets `preprocess_scalar` index with a mask instead of a
        // bounds-checked modulo. Surplus markers wrap around and are reported below.
        let start_pos_buffer_length = (expected_restart_intervals as usize).next_power_of_two();
        self.start_positions.clear();
        self.start_positions.resize(start_pos_buffer_length, 0);

        let out: &mut [u8] = bytemuck::cast_slice_mut(&mut self.words);
        let res = preprocess_scalar(out, &mut self.start_positions, scan_data);
        self.words.truncate((res.bytes_out + 3) / 4);
        self.start_positions.truncate(res.ri);

        if res.ri != expected_restart_intervals as usize {
            return Err(Error::from(format!(
                "restart interval count mismatch: counted {}, expected {}",
                res.ri, expected_restart_intervals
            )));
        }
        Ok(())
    }

    /// Unstuffed scan data from the last successful [`ScanBuffer::process`] call.
    ///
    /// The length is a multiple of 4. Bytes after the end of an interval's data, up to the
    /// next word boundary, are zero.
    pub fn processed_scan_data(&self) -> &[u8] {
        bytemuck::cast_slice(&self.words)
    }

    /// Start of each restart interval, as word indices into [`ScanBuffer::processed_scan_data`].
    ///
    /// Each entry is a native-endian `u32` exposed as 4 raw bytes.
    pub fn start_positions(&self) -> &[u8] {
        bytemuck::cast_slice(&self.start_positions)
    }
}
/// Outcome of [`preprocess_scalar`].
struct PreprocessResult {
    /// Restart intervals seen: one for the data before the first marker, plus one per marker.
    ri: usize,
    /// Final position of the write cursor. This includes padding after a trailing marker.
    bytes_out: usize,
}

/// Removes byte stuffing from `scan_data` and splits it at restart markers.
///
/// Input bytes are copied to `out` unchanged, except:
/// * `FF 00` is written as a single `FF`;
/// * any other `FF xx` pair is dropped and treated as a restart marker. The write cursor is
///   rounded up to the next multiple of 4 (the skipped bytes are not written), and the
///   resulting word index is stored in `start_positions`;
/// * a lone `FF` as the last input byte is dropped.
///
/// `start_positions` is indexed modulo its length, which must be a power of two. Surplus
/// markers therefore overwrite earlier slots instead of panicking, and the caller spots that
/// case by comparing `ri` with the count it expected. Slot 0 is left untouched unless the
/// indices wrap.
///
/// NOTE(review): a fill-byte run such as `FF FF D0` is not special-cased. The second `FF` is
/// taken as the marker code, and `D0` is then emitted as data.
///
/// # Panics
///
/// Panics if `start_positions.len()` is not a power of two, or if `out` is too short to hold
/// the unstuffed output.
#[inline(never)]
fn preprocess_scalar(
    out: &mut [u8],
    start_positions: &mut [u32],
    scan_data: &[u8],
) -> PreprocessResult {
    assert!(start_positions.len().is_power_of_two());
    let slot_mask = start_positions.len() - 1;

    // Interval 0 implicitly starts at word 0, so counting begins at 1.
    let mut intervals = 1;
    let mut cursor = 0;
    let mut input = scan_data.iter().copied();

    while let Some(byte) = input.next() {
        if byte != 0xff {
            out[cursor] = byte;
            cursor += 1;
            continue;
        }
        match input.next() {
            None => break,
            Some(0x00) => {
                out[cursor] = 0xff;
                cursor += 1;
            }
            Some(_) => {
                let word = (cursor + 3) / 4;
                cursor = word * 4;
                start_positions[intervals & slot_mask] = word as u32;
                intervals += 1;
            }
        }
    }

    PreprocessResult {
        ri: intervals,
        bytes_out: cursor,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `process` expecting success, then checks the unstuffed bytes and the recorded
    /// restart-interval start positions. `start_positions.len()` is the expected count.
    fn check(scan_data: &[u8], output: &[u8], start_positions: &[u32]) {
        let mut buf = ScanBuffer::new();
        buf.process(scan_data, start_positions.len() as u32).unwrap();
        let bytes: &[u8] = bytemuck::cast_slice(&buf.words);
        assert_eq!(output, bytes);
        assert_eq!(buf.start_positions, start_positions);
    }

    /// Runs `process` expecting failure. Only `start_positions.len()` is used, as the
    /// expected restart-interval count.
    fn check_err(scan_data: &[u8], start_positions: &[u32]) -> Error {
        let mut buf = ScanBuffer::new();
        buf.process(scan_data, start_positions.len() as u32)
            .unwrap_err()
    }

    #[test]
    fn process_scan_data() {
        check(&[0x12, 0x34, 0x56, 0x78], &[0x12, 0x34, 0x56, 0x78], &[0]);
        check(&[0xFF, 0xD0, 0xFF, 0xD0], &[], &[0, 0, 0]);
        let scan = &[0xFF, 0x00, 0x44, 0x55, 0xFF, 0xD0, 0x34];
        check(scan, &[0xFF, 0x44, 0x55, 0x00, 0x34, 0, 0, 0], &[0, 1]);
    }

    #[test]
    fn test_expanding_output() {
        check(
            &[0x11, 0xFF, 0xD0, 0x11, 0xFF, 0xD0, 0x11],
            &[0x11, 0, 0, 0, 0x11, 0, 0, 0, 0x11, 0, 0, 0],
            &[0, 1, 2],
        );
    }

    #[test]
    fn test_empty_input() {
        check(&[], &[], &[0]);
    }

    #[test]
    fn test_trailing_bytes() {
        // A lone 0xFF at the very end has no partner and is dropped.
        check(&[0x12, 0xFF], &[0x12, 0, 0, 0], &[0]);
        // A stuffed 0xFF at the very end is still emitted.
        check(&[0xFF, 0x00], &[0xFF, 0, 0, 0], &[0]);
        // A marker at the very end starts a final, empty interval.
        check(&[0x11, 0xFF, 0xD0], &[0x11, 0, 0, 0], &[0, 1]);
    }

    #[test]
    fn test_too_many_rst_markers() {
        let err = check_err(&[0x11, 0xFF, 0xD0, 0x11, 0xFF, 0xD0, 0x11], &[0]);
        assert_eq!(
            err.to_string(),
            "restart interval count mismatch: counted 3, expected 1"
        );
    }

    #[test]
    fn test_too_few_rst_markers() {
        let err = check_err(&[0x11], &[0, 0]);
        assert_eq!(
            err.to_string(),
            "restart interval count mismatch: counted 1, expected 2"
        );
    }
}