use std::{
collections::VecDeque,
io::{Read, Seek, Write},
};
use crate::{
compress::{DecompressorConfig, PushDecompressor},
error::{GeneralError, ProtocolError, ProtocolErrorKind},
header::WarcHeader,
io::LogicalPosition,
};
/// Chunk size used when refilling from the underlying input; also the default
/// upper bound for a single block-data chunk handed out by the push decoder.
const BUFFER_LENGTH: usize = crate::io::IO_BUFFER_LENGTH;
/// Largest amount of buffered decompressed data (32 KiB) allowed while still
/// searching for the end of a record header before failing with `HeaderTooBig`.
const MAX_HEADER_LENGTH: usize = 32 * 1024;
/// Options for constructing a [`Decoder`] or [`PushDecoder`].
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct DecoderConfig {
    /// Settings handed to the underlying decompressor when the push decoder
    /// is created.
    pub decompressor: DecompressorConfig,
}
/// Type-state marker: the [`Decoder`] is positioned before a record header.
#[derive(Debug)]
pub struct DecStateHeader;

/// Type-state marker: the [`Decoder`] is reading a record's block.
#[derive(Default, Debug)]
pub struct DecStateBlock {
    // Becomes `true` once the push decoder has reported the end of the record;
    // further block reads then return 0 immediately.
    is_end: bool,
}
/// Pull-style WARC decoder over a [`Read`] source.
///
/// `S` is a type-state marker ([`DecStateHeader`] or [`DecStateBlock`]) that
/// restricts which operations are available at each point of a record.
#[derive(Debug)]
pub struct Decoder<S, R: Read> {
    // Current type-state marker.
    state: S,
    // Underlying byte source.
    input: R,
    // State machine that does the actual decompression and parsing.
    push_decoder: PushDecoder,
    // Total number of bytes read from `input` by this decoder.
    logical_position: u64,
    // Scratch buffer reused for every refill from `input`.
    buf: Vec<u8>,
}
impl<S, R: Read> Decoder<S, R> {
    /// Returns a shared reference to the underlying input.
    pub fn get_ref(&self) -> &R {
        &self.input
    }

    /// Returns a mutable reference to the underlying input.
    ///
    /// Bytes read through this reference bypass the decoder, so they are not
    /// counted in its logical position.
    pub fn get_mut(&mut self) -> &mut R {
        &mut self.input
    }

    /// Returns the push decoder's consumed-input counter as captured at the
    /// most recent record boundary.
    pub fn record_boundary_position(&self) -> u64 {
        self.push_decoder.record_boundary_position()
    }

    /// Reads at most one chunk from the input and feeds it to the push decoder.
    ///
    /// A zero-length read is forwarded to the push decoder as end-of-input.
    /// Returns the number of bytes read.
    fn read_into_push_decoder(&mut self) -> std::io::Result<usize> {
        tracing::trace!("read into push decoder");
        self.buf.resize(BUFFER_LENGTH, 0);
        let n = self.input.read(self.buf.as_mut_slice())?;
        self.buf.truncate(n);
        self.logical_position += n as u64;
        self.push_decoder.write_all(&self.buf)?;
        if n == 0 {
            self.push_decoder.write_eof();
        }
        tracing::trace!(read_length = n, "read into push decoder");
        Ok(n)
    }

    /// Like [`Self::read_into_push_decoder`], but treats end-of-input as an
    /// `UnexpectedEof` error.
    fn read_nonzero_into_push_decoder(&mut self) -> std::io::Result<()> {
        match self.read_into_push_decoder()? {
            0 => Err(std::io::ErrorKind::UnexpectedEof.into()),
            _ => Ok(()),
        }
    }

    /// Resets the push decoder and returns the decoder in the header state.
    ///
    /// # Errors
    ///
    /// Returns an error if the push decoder cannot be reset.
    pub fn reset(self) -> std::io::Result<Decoder<DecStateHeader, R>> {
        let Self {
            input,
            mut push_decoder,
            logical_position,
            buf,
            ..
        } = self;
        push_decoder.reset()?;
        Ok(Decoder {
            state: DecStateHeader,
            input,
            push_decoder,
            logical_position,
            buf,
        })
    }
}
impl<R: Read> Decoder<DecStateHeader, R> {
    /// Creates a decoder that reads records from `input`.
    ///
    /// # Errors
    ///
    /// Returns an error if the push decoder (and its decompressor) cannot be
    /// created from `config`.
    pub fn new(input: R, config: DecoderConfig) -> std::io::Result<Self> {
        let push_decoder = PushDecoder::new(config)?;
        let buf = Vec::with_capacity(BUFFER_LENGTH);
        Ok(Self {
            state: DecStateHeader,
            input,
            push_decoder,
            logical_position: 0,
            buf,
        })
    }

    /// Consumes the decoder and returns the underlying input.
    pub fn into_inner(self) -> R {
        let Self { input, .. } = self;
        input
    }

    /// Returns `true` if the input was found not to use record-at-time
    /// compression.
    pub fn has_record_at_time_compression_fault(&self) -> bool {
        self.push_decoder.has_record_at_time_compression_fault()
    }

    /// Reports whether another record's bytes are available.
    ///
    /// When the push decoder is idle, one chunk is read from the input first.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from the input or the push decoder.
    pub fn has_next_record(&mut self) -> std::io::Result<bool> {
        if !self.push_decoder.is_ready() {
            return Ok(true);
        }
        self.read_into_push_decoder()?;
        Ok(!self.push_decoder.is_ready())
    }

    /// Reads the next record header and switches the decoder to the block state.
    ///
    /// # Errors
    ///
    /// Returns an error on malformed or oversized headers, on decompression or
    /// I/O failures, or `UnexpectedEof` if the input ends before a header completes.
    pub fn read_header(mut self) -> Result<(WarcHeader, Decoder<DecStateBlock, R>), GeneralError> {
        let header = loop {
            match self.push_decoder.get_event()? {
                PushDecoderEvent::Header { header } => break header,
                PushDecoderEvent::Ready | PushDecoderEvent::WantData => {
                    self.read_nonzero_into_push_decoder()?
                }
                PushDecoderEvent::Continue => {}
                PushDecoderEvent::WantDataOrEof
                | PushDecoderEvent::BlockData { .. }
                | PushDecoderEvent::EndRecord => unreachable!(),
            }
        };

        let Self {
            input,
            push_decoder,
            logical_position,
            buf,
            ..
        } = self;
        Ok((
            header,
            Decoder {
                state: DecStateBlock::default(),
                input,
                push_decoder,
                logical_position,
                buf,
            },
        ))
    }
}
impl<R: Read + Seek> Decoder<DecStateHeader, R> {
    /// Prepares the decoder so the caller can reposition the underlying input.
    ///
    /// This applies only to decompression formats that support concatenation;
    /// for other formats it does nothing. The steps are:
    ///
    /// 1. Feed input to the push decoder until it reports a header, block data,
    ///    or the end of a record, or until the input is exhausted.
    /// 2. Rewind the input to offset 0.
    /// 3. Reset the push decoder.
    ///
    /// # Errors
    ///
    /// Returns an error if reading, decoding, seeking, or resetting the
    /// decompressor fails.
    pub fn prepare_for_seek(&mut self) -> Result<(), GeneralError> {
        let supports_concatenation = self
            .push_decoder
            .config
            .decompressor
            .format
            .supports_concatenation();
        if !supports_concatenation {
            return Ok(());
        }

        loop {
            let read_length = self.read_into_push_decoder()?;
            match self.push_decoder.get_event()? {
                PushDecoderEvent::Header { .. }
                | PushDecoderEvent::BlockData { .. }
                | PushDecoderEvent::EndRecord => break,
                PushDecoderEvent::Ready
                | PushDecoderEvent::WantData
                | PushDecoderEvent::WantDataOrEof => {
                    // The decoder needs more input but the input is exhausted
                    // (empty or truncated). Re-reading would return 0 forever,
                    // so stop here instead of spinning.
                    if read_length == 0 {
                        break;
                    }
                }
                PushDecoderEvent::Continue => {}
            }
        }

        self.input.seek(std::io::SeekFrom::Start(0))?;
        self.push_decoder.reset()?;
        Ok(())
    }
}
impl<R: Read> Decoder<DecStateBlock, R> {
    /// Copies the next chunk of the current record's block into `buf`.
    ///
    /// Returns `Ok(0)` when `buf` is empty or once the record has ended.
    fn read_block_impl(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        if self.state.is_end || buf.is_empty() {
            return Ok(0);
        }
        // Cap the chunk size so every BlockData event fits into `buf`.
        self.push_decoder.set_max_buffer_len(buf.len());
        loop {
            match self
                .push_decoder
                .get_event()
                .map_err(std::io::Error::other)?
            {
                PushDecoderEvent::WantData => self.read_nonzero_into_push_decoder()?,
                PushDecoderEvent::WantDataOrEof => {
                    self.read_into_push_decoder()?;
                }
                PushDecoderEvent::Continue => {}
                PushDecoderEvent::BlockData { data } => {
                    debug_assert!(data.len() <= buf.len());
                    let len = data.len().min(buf.len());
                    tracing::trace!(read_length = len, "read block");
                    buf[..len].copy_from_slice(&data[..len]);
                    return Ok(len);
                }
                PushDecoderEvent::EndRecord => {
                    self.state.is_end = true;
                    return Ok(0);
                }
                PushDecoderEvent::Ready | PushDecoderEvent::Header { .. } => unreachable!(),
            }
        }
    }

    /// Skips whatever remains of the current block and returns the decoder in
    /// the header state.
    ///
    /// # Errors
    ///
    /// Returns an error on decoding or I/O failures, including `UnexpectedEof`
    /// if the input ends inside the block.
    pub fn finish_block(mut self) -> Result<Decoder<DecStateHeader, R>, GeneralError> {
        tracing::trace!("finish block");
        self.read_remaining_block()?;
        let Self {
            input,
            push_decoder,
            logical_position,
            buf,
            ..
        } = self;
        Ok(Decoder {
            state: DecStateHeader,
            input,
            push_decoder,
            logical_position,
            buf,
        })
    }

    /// Drives the push decoder until the end of the record, discarding block data.
    fn read_remaining_block(&mut self) -> Result<(), GeneralError> {
        tracing::trace!("read remaining block");
        self.push_decoder.set_max_buffer_len(BUFFER_LENGTH);
        while !self.state.is_end {
            match self.push_decoder.get_event()? {
                PushDecoderEvent::WantData => self.read_nonzero_into_push_decoder()?,
                PushDecoderEvent::WantDataOrEof => {
                    self.read_into_push_decoder()?;
                }
                // Skipped block bytes are released on the next `get_event` call.
                PushDecoderEvent::Continue | PushDecoderEvent::BlockData { .. } => {}
                PushDecoderEvent::EndRecord => self.state.is_end = true,
                PushDecoderEvent::Ready | PushDecoderEvent::Header { .. } => unreachable!(),
            }
        }
        Ok(())
    }
}
impl<R: Read> Read for Decoder<DecStateBlock, R> {
    /// Reads decoded bytes of the current record's block.
    ///
    /// Returns 0 once the record has ended.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        Self::read_block_impl(self, buf)
    }
}
impl<S, R: Read> LogicalPosition for Decoder<S, R> {
    /// Number of bytes this decoder has read from the underlying input since
    /// it was constructed.
    fn logical_position(&self) -> u64 {
        let Self {
            logical_position, ..
        } = self;
        *logical_position
    }
}
/// Result of one [`PushDecoder::get_event`] step.
#[derive(Debug)]
pub enum PushDecoderEvent<'a> {
    /// The decoder is idle and has not received any bytes of a new record.
    Ready,
    /// More input must be written before the decoder can make progress.
    WantData,
    /// More input may be written, or end-of-input signalled with `write_eof`.
    WantDataOrEof,
    /// The decoder advanced internally; call `get_event` again.
    Continue,
    /// A record header was parsed.
    Header { header: WarcHeader },
    /// A chunk of the record's block.
    ///
    /// The chunk borrows the decoder's internal buffer and is discarded by the
    /// next `get_event` call.
    BlockData { data: &'a [u8] },
    /// The current record, including its trailing record boundary, is finished.
    EndRecord,
}
impl<'a> PushDecoderEvent<'a> {
    /// Returns `true` for [`PushDecoderEvent::Ready`].
    #[must_use]
    pub fn is_ready(&self) -> bool {
        matches!(self, Self::Ready)
    }

    /// Returns `true` for [`PushDecoderEvent::WantData`].
    #[must_use]
    pub fn is_want_data(&self) -> bool {
        matches!(self, Self::WantData)
    }

    /// Returns `true` for [`PushDecoderEvent::WantDataOrEof`].
    #[must_use]
    pub fn is_want_data_or_eof(&self) -> bool {
        matches!(self, Self::WantDataOrEof)
    }

    /// Returns `true` for [`PushDecoderEvent::Continue`].
    #[must_use]
    pub fn is_continue(&self) -> bool {
        matches!(self, Self::Continue)
    }

    /// Returns `true` for [`PushDecoderEvent::Header`].
    #[must_use]
    pub fn is_header(&self) -> bool {
        matches!(self, Self::Header { .. })
    }

    /// Returns `true` for [`PushDecoderEvent::BlockData`].
    #[must_use]
    pub fn is_block_data(&self) -> bool {
        matches!(self, Self::BlockData { .. })
    }

    /// Returns the parsed header if this is a [`PushDecoderEvent::Header`].
    #[must_use]
    pub fn as_header(&self) -> Option<&WarcHeader> {
        if let Self::Header { header } = self {
            Some(header)
        } else {
            None
        }
    }

    /// Returns the block chunk if this is a [`PushDecoderEvent::BlockData`].
    ///
    /// The slice keeps the event's lifetime `'a`, not the borrow of `self`.
    #[must_use]
    pub fn as_block_data(&self) -> Option<&'a [u8]> {
        if let Self::BlockData { data } = self {
            Some(data)
        } else {
            None
        }
    }

    /// Returns `true` for [`PushDecoderEvent::EndRecord`].
    #[must_use]
    pub fn is_end_record(&self) -> bool {
        matches!(self, Self::EndRecord)
    }
}
/// Internal stage of the [`PushDecoder`] state machine.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum PushDecoderState {
    /// Idle; no bytes of the next record have been written yet.
    PendingHeader,
    /// Accumulating bytes until the header terminator is found.
    Header,
    /// Yielding the record's block bytes.
    Block,
    /// Expecting the `\r\n\r\n` record separator.
    RecordBoundary,
    /// Record separator consumed; deciding how to continue with the next segment.
    EndOfSegment,
}
/// Sans-I/O WARC decoder: bytes are pushed in with [`Write`] and progress is
/// pulled out as [`PushDecoderEvent`]s.
#[derive(Debug)]
pub struct PushDecoder {
    // Configuration the decoder was created with.
    config: DecoderConfig,
    // Current stage of the state machine.
    state: PushDecoderState,
    // Decompressor whose output (decompressed bytes) accumulates in a VecDeque.
    decompressor: PushDecompressor<VecDeque<u8>>,
    // Set when the decompressor accepted zero bytes of a non-empty write.
    decompressor_eof: bool,
    // Set when `write_eof` signalled the end of the input.
    input_eof: bool,
    // Input bytes the decompressor refused; replayed at the next record.
    unused_input_buf: VecDeque<u8>,
    // Input bytes accepted by the decompressor so far.
    bytes_consumed: u64,
    // Value of `bytes_consumed` when the most recent record ended.
    record_boundary_position: u64,
    // Content-Length of the current record's block.
    block_length: u64,
    // Block bytes already handed out for the current record.
    block_current_position: u64,
    // Maximum length of a single BlockData chunk.
    buf_output_max_len: usize,
    // Length of the last BlockData chunk; drained on the next `get_event`.
    buf_output_reference_len: usize,
    // Set once the input was detected not to use record-at-time compression.
    has_rat_comp_fault: bool,
}
impl PushDecoder {
    /// Creates a push decoder that starts idle.
    ///
    /// It reports [`PushDecoderEvent::Ready`] until input is written to it.
    ///
    /// # Errors
    ///
    /// Returns an error if the decompressor cannot be created from
    /// `config.decompressor`.
    pub fn new(config: DecoderConfig) -> std::io::Result<Self> {
        let decompressor =
            PushDecompressor::with_config(VecDeque::new(), config.decompressor.clone())?;
        Ok(Self {
            config,
            state: PushDecoderState::PendingHeader,
            decompressor,
            decompressor_eof: false,
            input_eof: false,
            unused_input_buf: VecDeque::with_capacity(BUFFER_LENGTH),
            bytes_consumed: 0,
            record_boundary_position: 0,
            block_length: 0,
            block_current_position: 0,
            buf_output_max_len: BUFFER_LENGTH,
            buf_output_reference_len: 0,
            has_rat_comp_fault: false,
        })
    }

    /// Returns the count of input bytes accepted by the decompressor, as
    /// captured when the most recent record ended.
    pub fn record_boundary_position(&self) -> u64 {
        self.record_boundary_position
    }

    /// Returns `true` if input bytes are buffered that the decompressor has not
    /// accepted yet.
    ///
    /// NOTE(review): this only inspects leftover input; it does not check
    /// whether those bytes form a complete record.
    pub fn has_next_record(&self) -> bool {
        !self.unused_input_buf.is_empty()
    }

    /// Returns the maximum length of a single [`PushDecoderEvent::BlockData`] chunk.
    pub fn max_buffer_len(&self) -> usize {
        self.buf_output_max_len
    }

    /// Sets the maximum length of a single [`PushDecoderEvent::BlockData`] chunk.
    ///
    /// A value of `0` restores the default, `BUFFER_LENGTH`.
    pub fn set_max_buffer_len(&mut self, value: usize) {
        self.buf_output_max_len = if value == 0 { BUFFER_LENGTH } else { value };
    }

    /// Returns `true` once a record ended with decompressed bytes still buffered.
    ///
    /// This only applies to a concatenation-capable format and means the input
    /// is not compressed one record at a time.
    pub fn has_record_at_time_compression_fault(&self) -> bool {
        self.has_rat_comp_fault
    }

    /// Returns `true` while the decoder is idle, waiting for a record's first bytes.
    pub fn is_ready(&self) -> bool {
        matches!(self.state, PushDecoderState::PendingHeader)
    }

    /// Advances the decoder by one step and reports what it needs or produced.
    ///
    /// Bytes handed out by the previous `BlockData` event are discarded first,
    /// so each block chunk is yielded exactly once.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - a header cannot be parsed or its content length cannot be determined;
    /// - buffered header bytes exceed `MAX_HEADER_LENGTH`;
    /// - the record separator is not `\r\n\r\n`;
    /// - the decompressor fails to start the next segment.
    pub fn get_event(&mut self) -> Result<PushDecoderEvent, GeneralError> {
        self.decompressor
            .get_mut()
            .drain(0..self.buf_output_reference_len);
        self.buf_output_reference_len = 0;
        match self.state {
            PushDecoderState::PendingHeader => Ok(PushDecoderEvent::Ready),
            PushDecoderState::Header => self.process_header(),
            PushDecoderState::Block => self.process_block(),
            PushDecoderState::RecordBoundary => self.process_record_boundary(),
            PushDecoderState::EndOfSegment => self.process_end_of_segment(),
        }
    }

    /// Discards all buffered data and returns the decoder to the idle state.
    ///
    /// The decompressor is also prepared for a new segment. The counters
    /// `bytes_consumed` and `record_boundary_position` are left unchanged.
    ///
    /// # Errors
    ///
    /// Returns an error if the decompressor cannot start a new segment.
    pub fn reset(&mut self) -> std::io::Result<()> {
        self.state = PushDecoderState::PendingHeader;
        self.decompressor.get_mut().clear();
        // The output buffer was just emptied, so forget the length of the last
        // BlockData chunk. Otherwise the next `get_event` would try to drain
        // bytes that no longer exist and panic (VecDeque::drain bounds check).
        self.buf_output_reference_len = 0;
        // The end-of-stream flags describe input that was just discarded.
        // A stale `true` would make `process_end_of_segment` skip waiting for
        // the rest of the next segment.
        self.decompressor_eof = false;
        self.input_eof = false;
        self.unused_input_buf.clear();
        self.decompressor.start_next_segment()?;
        Ok(())
    }

    /// Header state: parse the header once its terminator is buffered,
    /// otherwise ask for more data (failing if the header grows too large).
    fn process_header(&mut self) -> Result<PushDecoderEvent, GeneralError> {
        let buf = self.decompressor.get_mut().make_contiguous();
        if let Some(index) = crate::parse::scan_header_deliminator(buf) {
            let header = self.process_decodable_header(index)?;
            return Ok(PushDecoderEvent::Header { header });
        }
        self.check_max_header_length()?;
        Ok(PushDecoderEvent::WantData)
    }

    /// Parses the first `index` buffered bytes as a header, removes them from
    /// the buffer, and switches to the block state.
    fn process_decodable_header(&mut self, index: usize) -> Result<WarcHeader, GeneralError> {
        // `process_header` made the deque contiguous, so the first slice holds
        // every buffered byte.
        let (buf, _slice1) = self.decompressor.get_ref().as_slices();
        let header_bytes = &buf[0..index];
        let header = WarcHeader::parse(header_bytes)?;
        let length = header.content_length()?;
        let record_id = header.fields.get("WARC-Record-ID");
        let warc_type = header.fields.get("WARC-Type");
        self.decompressor.get_mut().drain(0..index);
        tracing::trace!(
            record_id,
            warc_type,
            content_length = length,
            "process decodable header"
        );
        self.block_current_position = 0;
        self.block_length = length;
        tracing::trace!("Header -> Block");
        self.state = PushDecoderState::Block;
        Ok(header)
    }

    /// Fails with `HeaderTooBig` once more than `MAX_HEADER_LENGTH` bytes are
    /// buffered without a complete header.
    fn check_max_header_length(&self) -> Result<(), ProtocolError> {
        tracing::trace!("check max header length");
        if self.decompressor.get_ref().len() > MAX_HEADER_LENGTH {
            Err(ProtocolError::new(ProtocolErrorKind::HeaderTooBig))
        } else {
            Ok(())
        }
    }

    /// Block state: hand out the next contiguous chunk of the block.
    ///
    /// The chunk is bounded by `buf_output_max_len` and the remaining block
    /// length. Once the block is complete, the decoder moves on to the record
    /// boundary.
    fn process_block(&mut self) -> Result<PushDecoderEvent, GeneralError> {
        tracing::trace!(
            self.block_length,
            self.block_current_position,
            "process block"
        );
        debug_assert!(self.block_length >= self.block_current_position);
        let remaining_bytes = self.block_length - self.block_current_position;
        if remaining_bytes == 0 {
            tracing::trace!("Block -> RecordBoundary");
            self.state = PushDecoderState::RecordBoundary;
            Ok(PushDecoderEvent::Continue)
        } else if self.decompressor.get_ref().is_empty() {
            Ok(PushDecoderEvent::WantData)
        } else {
            let (slice0, _slice1) = self.decompressor.get_ref().as_slices();
            let consume_len = self.buf_output_max_len.min(slice0.len());
            let consume_len = consume_len.min(remaining_bytes.try_into().unwrap_or(usize::MAX));
            self.block_current_position += consume_len as u64;
            // Remembered so the next `get_event` drops exactly these bytes.
            self.buf_output_reference_len = consume_len;
            tracing::trace!(consume_len, "process block");
            Ok(PushDecoderEvent::BlockData {
                data: &slice0[0..consume_len],
            })
        }
    }

    /// Record-boundary state: consume the mandatory `\r\n\r\n` after a block.
    fn process_record_boundary(&mut self) -> Result<PushDecoderEvent, GeneralError> {
        const RECORD_BOUNDARY: &[u8; 4] = b"\r\n\r\n";

        let pending = self.decompressor.get_ref();
        tracing::trace!(len = pending.len(), "process record boundary");
        if pending.len() < RECORD_BOUNDARY.len() {
            return Ok(PushDecoderEvent::WantData);
        }
        if !pending
            .range(..RECORD_BOUNDARY.len())
            .eq(RECORD_BOUNDARY.iter())
        {
            return Err(ProtocolError::new(ProtocolErrorKind::InvalidRecordBoundary).into());
        }
        self.decompressor.get_mut().drain(..RECORD_BOUNDARY.len());
        self.state = PushDecoderState::EndOfSegment;
        Ok(PushDecoderEvent::Continue)
    }

    /// End-of-segment state: decide whether the record is finished.
    ///
    /// For concatenation-capable formats with nothing buffered and neither end
    /// flag set, ask for more input (or EOF) so the decompressor can finish the
    /// segment first. Otherwise finish the record.
    fn process_end_of_segment(&mut self) -> Result<PushDecoderEvent, GeneralError> {
        tracing::trace!(self.decompressor_eof, "process end of segment");
        if self.config.decompressor.format.supports_concatenation()
            && self.decompressor.get_ref().is_empty()
            && !self.decompressor_eof
            && !self.input_eof
        {
            Ok(PushDecoderEvent::WantDataOrEof)
        } else {
            self.reset_for_next_record()?;
            Ok(PushDecoderEvent::EndRecord)
        }
    }

    /// Prepares state for the next record.
    ///
    /// Steps:
    /// - Start a new decompressor segment when the previous one was fully
    ///   drained; otherwise flag a record-at-time compression fault.
    /// - Record the boundary position.
    /// - Replay any input the decompressor had refused.
    fn reset_for_next_record(&mut self) -> Result<(), GeneralError> {
        tracing::trace!(
            remain_decomp_len = self.decompressor.get_ref().len(),
            "reset for next record"
        );
        if self.config.decompressor.format.supports_concatenation()
            && self.decompressor.get_ref().is_empty()
        {
            self.decompressor.start_next_segment()?;
        } else if self.config.decompressor.format.supports_concatenation()
            && !self.has_rat_comp_fault
        {
            tracing::warn!("file is not using Record-at-time compression");
            self.has_rat_comp_fault = true;
        }
        self.record_boundary_position = self.bytes_consumed;
        self.decompressor_eof = false;
        self.input_eof = false;
        self.consume_unused_input()?;
        if self.decompressor.get_ref().is_empty() {
            tracing::trace!("RecordBoundary -> PendingHeader");
            self.state = PushDecoderState::PendingHeader;
        } else {
            tracing::trace!("RecordBoundary -> Header");
            self.state = PushDecoderState::Header;
        }
        Ok(())
    }

    /// Feeds previously refused input back into the decompressor.
    ///
    /// Stops when everything is accepted or the decompressor accepts nothing.
    fn consume_unused_input(&mut self) -> Result<(), GeneralError> {
        tracing::trace!(len = self.unused_input_buf.len(), "consume unused input");
        while !self.unused_input_buf.is_empty() {
            let (slice0, _slice1) = self.unused_input_buf.as_slices();
            let write_len = self.decompressor.write(slice0)?;
            tracing::trace!(write_len, "consume unused input");
            if write_len == 0 {
                break;
            }
            self.bytes_consumed += write_len as u64;
            self.unused_input_buf.drain(..write_len);
        }
        Ok(())
    }

    /// Signals that no more input will be written.
    pub fn write_eof(&mut self) {
        tracing::trace!("push decoder got write eof");
        self.input_eof = true;
    }
}
impl Write for PushDecoder {
    /// Pushes raw (possibly compressed) input into the decoder.
    ///
    /// Bytes the decompressor refuses are stashed and replayed at the next
    /// record. In that case the whole `buf` is reported as written and the
    /// decompressor is marked as finished.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }
        if matches!(self.state, PushDecoderState::PendingHeader) {
            tracing::trace!("PendingHeader -> Header");
            self.state = PushDecoderState::Header;
        }

        let accepted = self.decompressor.write(buf)?;
        tracing::trace!(buf_len = buf.len(), write_len = accepted, "push decoder write");

        if accepted == 0 {
            self.decompressor_eof = true;
            self.unused_input_buf.extend(buf);
            return Ok(buf.len());
        }
        self.bytes_consumed += accepted as u64;
        Ok(accepted)
    }

    /// Flushes the decompressor.
    fn flush(&mut self) -> std::io::Result<()> {
        self.decompressor.flush()
    }
}
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use super::*;

    /// Two uncompressed records back to back: a 12-byte block followed by an
    /// empty one.
    const TWO_RECORDS: &[u8] = b"WARC/1.1\r\n\
        Content-Length: 12\r\n\
        \r\n\
        Hello world!\
        \r\n\r\n\
        WARC/1.1\r\n\
        Content-Length: 0\r\n\
        \r\n\
        \r\n\r\n";

    #[tracing_test::traced_test]
    #[test]
    fn test_reader() {
        let decoder = Decoder::new(Cursor::new(TWO_RECORDS), DecoderConfig::default()).unwrap();
        let mut block = Vec::new();

        // First record.
        let (_header, mut block_reader) = decoder.read_header().unwrap();
        block_reader.read_to_end(&mut block).unwrap();
        let mut decoder = block_reader.finish_block().unwrap();
        assert!(decoder.has_next_record().unwrap());

        // Second and final record.
        block.clear();
        let (_header, mut block_reader) = decoder.read_header().unwrap();
        block_reader.read_to_end(&mut block).unwrap();
        let mut decoder = block_reader.finish_block().unwrap();
        assert!(!decoder.has_next_record().unwrap());

        decoder.into_inner();
    }

    #[tracing_test::traced_test]
    #[test]
    fn test_push_reader() {
        let mut decoder = PushDecoder::new(DecoderConfig::default()).unwrap();

        // Nothing written yet.
        assert!(decoder.get_event().unwrap().is_ready());

        // The first header arrives in pieces, followed by part of its block.
        decoder.write_all(b"WARC/1.1\r\n").unwrap();
        assert!(decoder.get_event().unwrap().is_want_data());
        decoder.write_all(b"Content-Length: 12\r\n").unwrap();
        decoder.write_all(b"\r\n").unwrap();
        decoder.write_all(b"Hello ").unwrap();
        assert!(decoder.get_event().unwrap().is_header());

        // The block is delivered in two chunks.
        let event = decoder.get_event().unwrap();
        assert!(event.is_block_data());
        assert_eq!(event.as_block_data().unwrap(), b"Hello ");
        assert!(decoder.get_event().unwrap().is_want_data());
        decoder.write_all(b"world!\r\n").unwrap();
        let event = decoder.get_event().unwrap();
        assert!(event.is_block_data());
        assert_eq!(event.as_block_data().unwrap(), b"world!");

        // Block done; the record boundary is only half present.
        assert!(decoder.get_event().unwrap().is_continue());
        assert!(decoder.get_event().unwrap().is_want_data());
        decoder.write_all(b"\r\n").unwrap();
        decoder.write_all(b"WARC/1.1\r\n").unwrap();
        assert!(decoder.get_event().unwrap().is_continue());
        assert!(decoder.get_event().unwrap().is_end_record());

        // The second record completes once the rest of the input and EOF arrive.
        assert!(decoder.get_event().unwrap().is_want_data());
        decoder
            .write_all(b"Content-Length: 0\r\n\r\n\r\n\r\n")
            .unwrap();
        decoder.write_eof();
        assert!(decoder.get_event().unwrap().is_header());
        assert!(decoder.get_event().unwrap().is_continue());
        assert!(decoder.get_event().unwrap().is_continue());
        assert!(decoder.get_event().unwrap().is_end_record());
        assert!(decoder.get_event().unwrap().is_ready());
    }
}