use eyre::{eyre, Result};
use crate::{
codec::{
self,
decode::{DecodeState, DecodeStats, EnvelopeDetails},
format::NetworkEnvelope,
},
frame::{
buffers::{ReadBuffer, COMPRESSED_DATA_BUFFER_CAPACITY, DECOMPRESSED_DATA_BUFFER_CAPACITY},
lz4::{DecompressState, DecompressStats, LZ4Buffer},
},
};
/// Selects how the incoming byte stream is framed and owns the matching reader.
pub(crate) enum FramedRead {
    /// Envelopes are carried inside LZ4-compressed frames.
    Lz4(LZ4FramedRead),
    /// Envelopes are decoded straight from the received bytes, with no decompression.
    None(NoneFramedRead),
}

impl FramedRead {
    /// Builds a reader for LZ4-framed input.
    pub(crate) fn lz4() -> Self {
        let reader = LZ4FramedRead::new();
        Self::Lz4(reader)
    }

    /// Builds a reader for uncompressed input.
    pub(crate) fn none() -> Self {
        let reader = NoneFramedRead::new();
        Self::None(reader)
    }
}
/// Outcome of a single [`FramedReadStrategy::read`] call.
pub(crate) enum FramedReadState<'a> {
    /// No further progress is possible until more input arrives.
    NeedMoreData {
        /// Writable region of the reader's input buffer, obtained from
        /// `extend_to_contain` with the decoder's/decompressor's length estimate.
        /// Presumably the caller writes into it and then reports the number of
        /// bytes written via `mark_filled`.
        buffer: &'a mut [u8],
    },
    /// The decoder skipped an envelope and supplied details about it.
    EnvelopeSkipped(EnvelopeDetails),
    /// A complete envelope was decoded.
    Done {
        /// The decoded envelope.
        decoded: NetworkEnvelope,
    },
}
/// Counters gathered by a reader; `take_stats` hands them out and resets them.
#[derive(Default)]
pub(crate) struct FramedReadStats {
    /// Decompression counters. The LZ4 reader passes these to `decompress_frame`;
    /// the uncompressed reader adds each consumed byte count to
    /// `total_uncompressed_bytes` itself.
    pub(crate) decompress_stats: DecompressStats,
    /// Envelope-decoding counters passed to `codec::decode::decode`.
    pub(crate) decode_stats: DecodeStats,
}
/// Incremental, buffer-driven envelope reader.
///
/// Each `read` call yields one [`FramedReadState`]; when it reports
/// `NeedMoreData`, input is supplied by writing into the returned slice and then
/// calling `mark_filled`.
pub(crate) trait FramedReadStrategy {
    /// Advances the reader and reports what happened.
    fn read(&mut self) -> Result<FramedReadState<'_>>;

    /// Records that `count` bytes were written into the most recent
    /// `NeedMoreData` buffer (the implementations here forward `count` to the
    /// input buffer's `extend_filled`).
    fn mark_filled(&mut self, count: usize);

    /// Returns the statistics collected so far and resets them to their defaults.
    fn take_stats(&mut self) -> FramedReadStats;
}
/// Static dispatch to whichever concrete reader this `FramedRead` holds.
impl FramedReadStrategy for FramedRead {
    fn read(&mut self) -> Result<FramedReadState<'_>> {
        match self {
            Self::Lz4(inner) => inner.read(),
            Self::None(inner) => inner.read(),
        }
    }

    fn mark_filled(&mut self, count: usize) {
        match self {
            Self::Lz4(inner) => inner.mark_filled(count),
            Self::None(inner) => inner.mark_filled(count),
        }
    }

    fn take_stats(&mut self) -> FramedReadStats {
        match self {
            Self::Lz4(inner) => inner.take_stats(),
            Self::None(inner) => inner.take_stats(),
        }
    }
}
pub(crate) struct LZ4FramedRead {
compressed_buffer: ReadBuffer,
decompressed_buffer: LZ4Buffer,
stats: FramedReadStats,
position: usize,
}
impl LZ4FramedRead {
pub(crate) fn new() -> Self {
Self {
compressed_buffer: ReadBuffer::with_capacity(COMPRESSED_DATA_BUFFER_CAPACITY),
decompressed_buffer: LZ4Buffer::with_capacity(DECOMPRESSED_DATA_BUFFER_CAPACITY),
stats: Default::default(),
position: 0,
}
}
}
impl FramedReadStrategy for LZ4FramedRead {
    /// Returns the next decoding event from the LZ4-framed stream.
    ///
    /// Envelopes are decoded from the current decompressed frame starting at
    /// `self.position`. Once that frame is exhausted, the next frame is
    /// decompressed from `compressed_buffer`; if it is not fully buffered yet,
    /// [`FramedReadState::NeedMoreData`] hands back a slice of `compressed_buffer`
    /// (from `extend_to_contain`) to be filled and reported through `mark_filled`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `decompress_frame` and `codec::decode::decode`, and
    /// fails if a decompressed frame ends in the middle of an envelope.
    fn read(&mut self) -> Result<FramedReadState<'_>> {
        'decompression: loop {
            // The current frame is exhausted once `position` reaches its end. `>=`
            // rather than `==`: `position` is only reset after a frame has actually
            // been produced, so if an earlier decompression stopped short and
            // `decompress_frame` shrank the buffer, `position` may point past it.
            if self.position >= self.decompressed_buffer.len() {
                let lz4_state = self.decompressed_buffer.decompress_frame(
                    self.compressed_buffer.filled_slice(),
                    &mut self.stats.decompress_stats,
                )?;
                match lz4_state {
                    DecompressState::NeedMoreData {
                        total_length_estimate,
                    } => {
                        debug_assert!(total_length_estimate > self.compressed_buffer.filled_len());
                        return Ok(FramedReadState::NeedMoreData {
                            buffer: self
                                .compressed_buffer
                                .extend_to_contain(total_length_estimate),
                        });
                    }
                    DecompressState::Done { compressed_size } => {
                        self.compressed_buffer.consume_filled(compressed_size);
                        // Reset only now that `decompressed_buffer` holds a fresh frame.
                        // Resetting before the call meant that, if decompression returned
                        // `NeedMoreData` or an error without clearing the buffer, the next
                        // call would re-decode envelopes that were already returned.
                        self.position = 0;
                    }
                }
            }
            loop {
                let envelope_buffer = &self.decompressed_buffer.filled_slice()[self.position..];
                let codec_state =
                    codec::decode::decode(envelope_buffer, &mut self.stats.decode_stats)?;
                match codec_state {
                    DecodeState::NeedMoreData { .. } => {
                        // Running dry exactly at the end of the frame means it held only
                        // whole envelopes, so move on to the next frame. Anything else
                        // means the frame ended mid-envelope.
                        if self.position == self.decompressed_buffer.len() {
                            continue 'decompression;
                        }
                        return Err(eyre!("lz4 decompressed data contains truncated envelopes"));
                    }
                    DecodeState::Skipped {
                        bytes_consumed,
                        details,
                    } => {
                        self.position += bytes_consumed;
                        // Skips that carry details are surfaced to the caller; the others
                        // just advance to the next envelope in this frame.
                        if let Some(details) = details {
                            return Ok(FramedReadState::EnvelopeSkipped(details));
                        }
                    }
                    DecodeState::Done {
                        bytes_consumed,
                        decoded,
                    } => {
                        self.position += bytes_consumed;
                        return Ok(FramedReadState::Done { decoded });
                    }
                }
            }
        }
    }

    /// Marks `count` more bytes of `compressed_buffer` as filled.
    fn mark_filled(&mut self, count: usize) {
        self.compressed_buffer.extend_filled(count);
    }

    /// Hands out the accumulated statistics, leaving defaults in their place.
    fn take_stats(&mut self) -> FramedReadStats {
        std::mem::take(&mut self.stats)
    }
}
pub(crate) struct NoneFramedRead {
buffer: ReadBuffer,
stats: FramedReadStats,
}
impl NoneFramedRead {
pub(crate) fn new() -> Self {
Self {
buffer: ReadBuffer::with_capacity(DECOMPRESSED_DATA_BUFFER_CAPACITY),
stats: Default::default(),
}
}
}
impl FramedReadStrategy for NoneFramedRead {
fn read(&mut self) -> Result<FramedReadState<'_>> {
loop {
let codec_state =
codec::decode::decode(self.buffer.filled_slice(), &mut self.stats.decode_stats)?;
match codec_state {
DecodeState::NeedMoreData {
total_length_estimate,
} => {
debug_assert!(total_length_estimate > self.buffer.filled_len());
break Ok(FramedReadState::NeedMoreData {
buffer: self.buffer.extend_to_contain(total_length_estimate),
});
}
DecodeState::Skipped {
bytes_consumed,
details,
} => {
self.stats.decompress_stats.total_uncompressed_bytes += bytes_consumed as u64;
self.buffer.consume_filled(bytes_consumed);
if let Some(details) = details {
return Ok(FramedReadState::EnvelopeSkipped(details));
} else {
continue;
}
}
DecodeState::Done {
bytes_consumed,
decoded,
} => {
self.stats.decompress_stats.total_uncompressed_bytes += bytes_consumed as u64;
self.buffer.consume_filled(bytes_consumed);
break Ok(FramedReadState::Done { decoded });
}
}
}
}
fn mark_filled(&mut self, count: usize) {
self.buffer.extend_filled(count);
}
fn take_stats(&mut self) -> FramedReadStats {
std::mem::take(&mut self.stats)
}
}