use alloc::vec::Vec;
use crate::error::Error;
use crate::traits::{Algorithm, RawDecoder, RawEncoder, RawProgress};
/// Zero-sized marker type for the codec registered under the name `"snappy"`.
///
/// It carries no state; its [`Algorithm`] implementation ties it to this
/// module's [`Encoder`] and [`Decoder`].
#[derive(Debug, Clone, Copy, Default)]
pub struct Snappy;
/// Wires [`Snappy`] to this module's encoder and decoder types.
///
/// Neither side takes any configuration: both config types are `()`.
impl Algorithm for Snappy {
/// Name under which this algorithm is identified.
const NAME: &'static str = "snappy";
type Encoder = Encoder;
type Decoder = Decoder;
type EncoderConfig = ();
type DecoderConfig = ();
/// Returns a fresh encoder; the unit configuration is ignored.
fn encoder_with(_: ()) -> Encoder {
Encoder::new()
}
/// Returns a fresh decoder; the unit configuration is ignored.
fn decoder_with(_: ()) -> Decoder {
Decoder::new()
}
}
/// Appends `value` to `out` as a little-endian base-128 varint.
///
/// Each emitted byte carries seven payload bits, least-significant group
/// first; the high bit is set on every byte except the last. Between one and
/// five bytes are written. Existing contents of `out` are left untouched.
fn write_varint_u32(value: u32, out: &mut Vec<u8>) {
    let mut rest = value;
    loop {
        let low7 = (rest & 0x7F) as u8;
        rest >>= 7;
        if rest == 0 {
            // Final group: continuation bit stays clear.
            out.push(low7);
            return;
        }
        out.push(low7 | 0x80);
    }
}
fn read_varint_u32(buf: &[u8]) -> Result<(u32, usize), Error> {
let mut result: u64 = 0;
let mut shift: u32 = 0;
for (i, &b) in buf.iter().enumerate() {
if i == 5 {
return Err(Error::Corrupt);
}
result |= ((b & 0x7F) as u64) << shift;
if b & 0x80 == 0 {
if result > u32::MAX as u64 {
return Err(Error::Corrupt);
}
return Ok((result as u32, i + 1));
}
shift += 7;
}
Err(Error::UnexpectedEnd)
}
/// Buffering encoder: collects all input, then compresses it as a single
/// block when finishing.
#[derive(Debug, Default)]
pub struct Encoder {
// Every byte handed to `raw_encode` so far.
input: Vec<u8>,
// Compressed block produced by `compress_block` during `raw_finish`.
output: Vec<u8>,
// Number of bytes of `output` already copied out to the caller.
out_pos: usize,
// Set once `output` has been produced, so compression runs only once.
compressed: bool,
}
impl Encoder {
/// Creates an encoder with empty buffers and nothing compressed yet.
pub const fn new() -> Self {
Self {
input: Vec::new(),
output: Vec::new(),
out_pos: 0,
compressed: false,
}
}
}
impl RawEncoder for Encoder {
    /// Buffers `input` for later compression.
    ///
    /// All of `input` is consumed and nothing is written to the output slice;
    /// compressed data is produced only by [`raw_finish`](Self::raw_finish).
    fn raw_encode(&mut self, input: &[u8], _output: &mut [u8]) -> Result<RawProgress, Error> {
        let consumed = input.len();
        self.input.extend_from_slice(input);
        Ok(RawProgress {
            consumed,
            written: 0,
            done: false,
        })
    }

    /// Compresses the buffered input (on the first call only) and copies as
    /// much of the pending compressed data as fits into `output`.
    ///
    /// `done` is reported once every compressed byte has been handed out, so
    /// the caller may need to call this repeatedly with fresh output space.
    fn raw_finish(&mut self, output: &mut [u8]) -> Result<RawProgress, Error> {
        if !self.compressed {
            compress_block(&self.input, &mut self.output);
            self.compressed = true;
        }

        let pending = &self.output[self.out_pos..];
        let written = pending.len().min(output.len());
        output[..written].copy_from_slice(&pending[..written]);
        self.out_pos += written;

        Ok(RawProgress {
            consumed: 0,
            written,
            done: self.out_pos == self.output.len(),
        })
    }

    /// Returns the encoder to its freshly constructed state, keeping the
    /// buffers' allocations.
    fn raw_reset(&mut self) {
        self.compressed = false;
        self.out_pos = 0;
        self.output.clear();
        self.input.clear();
    }
}
/// Compresses `input` into `out` as one block: a varint holding the
/// uncompressed length, followed by literal and copy elements.
///
/// `out` is cleared first. Matches are found greedily through a hash table
/// keyed on four-byte sequences that remembers the most recent position seen
/// for each hash.
///
/// The length preamble and the table entries are 32-bit, so the input length
/// is cast with `as u32`; inputs of 4 GiB or more are not supported.
fn compress_block(input: &[u8], out: &mut Vec<u8>) {
    const HASH_BITS: u32 = 14;
    const HASH_SIZE: usize = 1 << HASH_BITS;
    /// Table value meaning "no position recorded for this hash yet".
    const NIL: u32 = u32::MAX;
    /// Shortest match that is searched for and emitted as a copy.
    const MIN_MATCH: usize = 4;

    out.clear();
    write_varint_u32(input.len() as u32, out);
    if input.is_empty() {
        return;
    }
    if input.len() < MIN_MATCH {
        // Too short to contain a match: store verbatim.
        emit_literal(input, out);
        return;
    }

    // Multiplicative hash of the four bytes at `pos`, read little-endian.
    let hash = |bytes: &[u8], pos: usize| -> usize {
        let v = (bytes[pos] as u32)
            | ((bytes[pos + 1] as u32) << 8)
            | ((bytes[pos + 2] as u32) << 16)
            | ((bytes[pos + 3] as u32) << 24);
        ((v.wrapping_mul(0x1E35A7BD)) >> (32 - HASH_BITS)) as usize
    };

    let mut table = alloc::vec![NIL; HASH_SIZE];
    let input_end = input.len();
    let match_limit = input_end.saturating_sub(4);
    let mut next_emit = 0usize;
    let mut ip = 0usize;

    while ip < match_limit {
        let h = hash(input, ip);
        // Read the previous occupant BEFORE overwriting the slot, and test
        // that value against NIL. (Testing `table[h]` after the store, as the
        // code used to, is always true and left NIL slots to be rejected only
        // incidentally by the `candidate < ip` comparison.)
        let slot = table[h];
        table[h] = ip as u32;

        let candidate = slot as usize;
        // `candidate < ip < match_limit` keeps both four-byte windows inside
        // `input`, so no separate upper-bound check is needed.
        let four_match = slot != NIL
            && candidate < ip
            && input[candidate..candidate + MIN_MATCH] == input[ip..ip + MIN_MATCH];
        if !four_match {
            ip += 1;
            continue;
        }

        // Flush the unmatched bytes that precede this match.
        if next_emit < ip {
            emit_literal(&input[next_emit..ip], out);
        }

        // Extend the match as far as the bytes agree. The source may overlap
        // the destination; that is fine for a back-reference.
        let match_len = MIN_MATCH
            + input[ip + MIN_MATCH..]
                .iter()
                .zip(&input[candidate + MIN_MATCH..])
                .take_while(|(a, b)| a == b)
                .count();
        let offset = (ip - candidate) as u32;

        // One copy element covers 4..=64 bytes, so long matches are split.
        // For 65..=67 bytes the first piece is shortened so that exactly four
        // bytes (an encodable length) remain for the last piece.
        let mut remaining = match_len;
        while remaining > 0 {
            let take = if remaining <= 64 {
                remaining
            } else if remaining < 68 {
                remaining - 4
            } else {
                64
            };
            emit_copy(offset, take as u32, out);
            ip += take;
            remaining -= take;
            // Remember the position just before the new cursor so that the
            // following data can match against the tail of this copy.
            if ip + 3 < input_end {
                let h2 = hash(input, ip - 1);
                table[h2] = (ip - 1) as u32;
            }
        }
        next_emit = ip;
    }

    // Whatever follows the last match is stored verbatim.
    if next_emit < input_end {
        emit_literal(&input[next_emit..], out);
    }
}
/// Appends a literal element holding `data` to `out`.
///
/// The tag byte has `00` in its two low bits. For lengths up to 60 the upper
/// six bits hold `len - 1` directly; otherwise they hold 60, 61, 62 or 63 to
/// announce that `len - 1` follows in 1, 2, 3 or 4 little-endian bytes. The
/// literal bytes themselves come last.
///
/// `data` must not be empty (checked by a debug assertion only).
fn emit_literal(data: &[u8], out: &mut Vec<u8>) {
    debug_assert!(!data.is_empty());
    let len_minus_one = (data.len() - 1) as u32;

    // How many trailing length bytes the header needs.
    let extra_bytes: usize = match len_minus_one {
        0..=59 => 0,
        60..=0xFF => 1,
        0x100..=0xFFFF => 2,
        0x1_0000..=0xFF_FFFF => 3,
        _ => 4,
    };

    if extra_bytes == 0 {
        out.push((len_minus_one as u8) << 2);
    } else {
        // 59 + extra_bytes is 60..=63, the marker for the length width.
        out.push(((59 + extra_bytes) as u8) << 2);
        out.extend_from_slice(&len_minus_one.to_le_bytes()[..extra_bytes]);
    }
    out.extend_from_slice(data);
}
/// Appends one copy (back-reference) element to `out`.
///
/// Three encodings are used, selected by `length` and `offset`:
///
/// * tag `01`, 2 bytes total: `length` in 4..=11 and `offset` below 2048.
///   The tag holds `length - 4` in bits 2..=4 and offset bits 8..=10 in bits
///   5..=7; the low offset byte follows.
/// * tag `10`, 3 bytes total: `offset` below 65536. The tag holds
///   `length - 1`; the offset follows as two little-endian bytes.
/// * tag `11`, 5 bytes total: any other offset, stored as four little-endian
///   bytes after a tag holding `length - 1`.
///
/// `length` must be in 4..=64 and `offset` at least 1 (debug assertions only).
fn emit_copy(offset: u32, length: u32, out: &mut Vec<u8>) {
    debug_assert!((4..=64).contains(&length));
    debug_assert!(offset >= 1);

    let off = offset.to_le_bytes();

    if length <= 11 && offset < (1 << 11) {
        let tag = 0b01 | (((length - 4) as u8) << 2) | (off[1] << 5);
        out.extend_from_slice(&[tag, off[0]]);
        return;
    }

    let len_field = ((length - 1) as u8) << 2;
    if offset < (1 << 16) {
        out.push(0b10 | len_field);
        out.extend_from_slice(&off[..2]);
    } else {
        out.push(0b11 | len_field);
        out.extend_from_slice(&off);
    }
}
/// Buffering decoder: collects the whole compressed block, then decompresses
/// it in one go when finishing.
#[derive(Debug, Default)]
pub struct Decoder {
// Every compressed byte handed to `raw_decode` so far.
input: Vec<u8>,
// Decompressed data produced by `decompress_block` during `raw_finish`.
output: Vec<u8>,
// Number of bytes of `output` already copied out to the caller.
out_pos: usize,
// Set once decompression has succeeded, so it runs only once.
decompressed: bool,
}
impl Decoder {
/// Creates a decoder with empty buffers and nothing decompressed yet.
pub const fn new() -> Self {
Self {
input: Vec::new(),
output: Vec::new(),
out_pos: 0,
decompressed: false,
}
}
}
impl RawDecoder for Decoder {
    /// Buffers compressed `input` for later decompression.
    ///
    /// All of `input` is consumed and nothing is written to the output slice;
    /// decompressed data is produced only by [`raw_finish`](Self::raw_finish).
    fn raw_decode(&mut self, input: &[u8], _output: &mut [u8]) -> Result<RawProgress, Error> {
        let consumed = input.len();
        self.input.extend_from_slice(input);
        Ok(RawProgress {
            consumed,
            written: 0,
            done: false,
        })
    }

    /// Decompresses the buffered block (until that has succeeded once) and
    /// copies as much of the pending decompressed data as fits into `output`.
    ///
    /// `done` is reported once every decompressed byte has been handed out.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `decompress_block`; in that case no
    /// data is written to `output`.
    fn raw_finish(&mut self, output: &mut [u8]) -> Result<RawProgress, Error> {
        if !self.decompressed {
            decompress_block(&self.input, &mut self.output)?;
            self.decompressed = true;
        }

        let pending = &self.output[self.out_pos..];
        let written = pending.len().min(output.len());
        output[..written].copy_from_slice(&pending[..written]);
        self.out_pos += written;

        Ok(RawProgress {
            consumed: 0,
            written,
            done: self.out_pos == self.output.len(),
        })
    }

    /// Returns the decoder to its freshly constructed state, keeping the
    /// buffers' allocations.
    fn raw_reset(&mut self) {
        self.decompressed = false;
        self.out_pos = 0;
        self.output.clear();
        self.input.clear();
    }
}
fn decompress_block(input: &[u8], out: &mut Vec<u8>) -> Result<(), Error> {
out.clear();
if input.is_empty() {
return Err(Error::UnexpectedEnd);
}
let (uncompressed_len, vi_len) = read_varint_u32(input)?;
let uncompressed_len = uncompressed_len as usize;
out.reserve(uncompressed_len);
let mut ip = vi_len;
let input_end = input.len();
while ip < input_end {
let tag = input[ip];
ip += 1;
match tag & 0b11 {
0b00 => {
let upper = (tag >> 2) as u32;
let length = if upper < 60 {
upper + 1
} else {
let extra = (upper - 59) as usize; if ip + extra > input_end {
return Err(Error::UnexpectedEnd);
}
let mut len_minus_1: u32 = 0;
for i in 0..extra {
len_minus_1 |= (input[ip + i] as u32) << (8 * i);
}
ip += extra;
len_minus_1.wrapping_add(1)
};
let length = length as usize;
if ip + length > input_end {
return Err(Error::UnexpectedEnd);
}
if out.len() + length > uncompressed_len {
return Err(Error::Corrupt);
}
out.extend_from_slice(&input[ip..ip + length]);
ip += length;
}
0b01 => {
if ip >= input_end {
return Err(Error::UnexpectedEnd);
}
let length = (((tag >> 2) & 0x07) as usize) + 4;
let off_hi = ((tag >> 5) & 0x07) as usize;
let offset = (off_hi << 8) | (input[ip] as usize);
ip += 1;
copy_from_back(out, offset, length, uncompressed_len)?;
}
0b10 => {
if ip + 2 > input_end {
return Err(Error::UnexpectedEnd);
}
let length = (((tag >> 2) & 0x3F) as usize) + 1;
let offset = (input[ip] as usize) | ((input[ip + 1] as usize) << 8);
ip += 2;
copy_from_back(out, offset, length, uncompressed_len)?;
}
0b11 => {
if ip + 4 > input_end {
return Err(Error::UnexpectedEnd);
}
let length = (((tag >> 2) & 0x3F) as usize) + 1;
let offset = (input[ip] as usize)
| ((input[ip + 1] as usize) << 8)
| ((input[ip + 2] as usize) << 16)
| ((input[ip + 3] as usize) << 24);
ip += 4;
copy_from_back(out, offset, length, uncompressed_len)?;
}
_ => unreachable!(),
}
}
if out.len() != uncompressed_len {
return Err(Error::Corrupt);
}
Ok(())
}
/// Appends `length` bytes to `out`, reading them from `offset` bytes before
/// its current end.
///
/// Bytes are copied one at a time, so the source range may overlap the bytes
/// being appended: when `offset < length` the most recently written bytes are
/// read again, which repeats the last `offset` bytes as a pattern.
///
/// # Errors
///
/// * [`Error::InvalidDistance`] if `offset` is zero or larger than the
///   amount of data already in `out`.
/// * [`Error::Corrupt`] if the copy would grow `out` beyond
///   `uncompressed_len`.
fn copy_from_back(
    out: &mut Vec<u8>,
    offset: usize,
    length: usize,
    uncompressed_len: usize,
) -> Result<(), Error> {
    let produced = out.len();
    if offset == 0 || offset > produced {
        return Err(Error::InvalidDistance);
    }
    if produced + length > uncompressed_len {
        return Err(Error::Corrupt);
    }

    let first = produced - offset;
    for src in first..first + length {
        // Index on every step: `src` may point at a byte pushed earlier in
        // this same loop.
        let byte = out[src];
        out.push(byte);
    }
    Ok(())
}