use thiserror::Error;
use super::{CodecId, DecompressError};
/// Outcome of every decode operation in this module: a decoded `T` or a [`DecompressError`].
pub type DecodeResult<T> = Result<T, DecompressError>;

/// Boxed, thread-shareable decoder callback.
///
/// It receives the codec identifier and the compressed bytes, and yields either the decoded
/// value or a [`DecompressError`].
pub type FallbackDecoderFn<T> =
    Box<dyn Fn(CodecId, &[u8]) -> Result<T, DecompressError> + Send + Sync + 'static>;
pub struct ExactFallbackAdapter<T = Vec<u8>> {
fallback_decoder: FallbackDecoderFn<T>,
strict_mode: bool,
}
impl<T> ExactFallbackAdapter<T> {
pub fn new(fallback_decoder: FallbackDecoderFn<T>) -> Self {
Self {
fallback_decoder,
strict_mode: true,
}
}
pub fn with_strict_mode(mut self, strict: bool) -> Self {
self.strict_mode = strict;
self
}
pub fn decode_exact(&self, codec_id: CodecId, compressed_data: &[u8]) -> DecodeResult<T> {
if codec_id == CodecId::Uncompressed {
return (self.fallback_decoder)(codec_id, compressed_data);
}
(self.fallback_decoder)(codec_id, compressed_data)
}
pub fn decode_batch(&self, items: &[(CodecId, &[u8])]) -> DecodeResult<Vec<T>> {
let mut results = Vec::with_capacity(items.len());
for (codec_id, data) in items {
results.push(self.decode_exact(*codec_id, data)?);
}
Ok(results)
}
pub fn is_strict(&self) -> bool {
self.strict_mode
}
}
/// Errors that attach adapter-level context to an underlying [`DecompressError`].
// No code visible in this file constructs these variants yet; the allow keeps the
// `dead_code` lint quiet until callers adopt them.
#[allow(dead_code)]
#[derive(Debug, Error)]
pub enum AdapterError {
    /// Decoding a single payload failed for the codec named by `codec_id`.
    #[error("decode failed for codec `{codec_id}`: {source}")]
    DecodeFailed {
        codec_id: CodecId,
        // `thiserror` uses a field named `source` as the error's `source()`,
        // so no explicit `#[source]` attribute is required.
        source: DecompressError,
    },
    /// A batch decode failed at position `index`.
    #[error("batch decode failed at index {index}: {source}")]
    BatchFailed { index: usize, source: DecompressError },
}
impl<T: Clone> ExactFallbackAdapter<T> {
    /// Decodes `compressed_data` for `codec_id`; behaves exactly like `decode_exact`.
    ///
    /// Despite the name and the `T: Clone` bound, nothing is cloned here. The call is
    /// forwarded unchanged to `decode_exact`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `decode_exact`.
    pub fn decode_clone(&self, codec_id: CodecId, compressed_data: &[u8]) -> DecodeResult<T> {
        Self::decode_exact(self, codec_id, compressed_data)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Adapter whose fallback applies a distinct, recognisable transform per codec,
    /// so each test can tell which branch handled its input.
    fn test_adapter() -> ExactFallbackAdapter<Vec<u8>> {
        ExactFallbackAdapter::new(Box::new(|codec_id, data| match codec_id {
            CodecId::Uncompressed => Ok(data.to_vec()),
            CodecId::TurboQuant => Ok(data.iter().rev().copied().collect()),
            CodecId::FibQuant => {
                let mut out = vec![0xF1, 0xB0];
                out.extend_from_slice(data);
                Ok(out)
            }
            CodecId::Polar | CodecId::Qjl => Ok(data.to_vec()),
        }))
    }

    #[test]
    fn new_adapter_is_strict_by_default() {
        assert!(test_adapter().is_strict());
    }

    #[test]
    fn decode_uncompressed_is_identity() {
        let adapter = test_adapter();
        let data = b"hello world";
        let result = adapter.decode_exact(CodecId::Uncompressed, data).unwrap();
        assert_eq!(result, data);
    }

    #[test]
    fn decode_uncompressed_is_routed_through_fallback() {
        // The adapter must not special-case `Uncompressed`. This fallback ignores its
        // input, so its marker output proves the call reached it.
        let adapter: ExactFallbackAdapter<Vec<u8>> =
            ExactFallbackAdapter::new(Box::new(|_codec_id, _data| Ok(vec![0xAA])));
        let result = adapter.decode_exact(CodecId::Uncompressed, b"xyz").unwrap();
        assert_eq!(result, vec![0xAA]);
    }

    #[test]
    fn decode_turbo_quant_reverses() {
        let adapter = test_adapter();
        let data = b"abcde";
        let result = adapter.decode_exact(CodecId::TurboQuant, data).unwrap();
        assert_eq!(result, b"edcba");
    }

    #[test]
    fn decode_fib_quant_prepends_marker() {
        let adapter = test_adapter();
        let data = b"test";
        let result = adapter.decode_exact(CodecId::FibQuant, data).unwrap();
        assert_eq!(result, &[0xF1, 0xB0, b't', b'e', b's', b't']);
    }

    #[test]
    fn decode_polar_and_qjl_pass_through() {
        let adapter = test_adapter();
        assert_eq!(adapter.decode_exact(CodecId::Polar, b"pq").unwrap(), b"pq");
        assert_eq!(adapter.decode_exact(CodecId::Qjl, b"pq").unwrap(), b"pq");
    }

    #[test]
    fn decode_batch_returns_results_in_input_order() {
        let adapter = test_adapter();
        let items = [
            (CodecId::Uncompressed, b"abc".as_slice()),
            (CodecId::TurboQuant, b"xyz".as_slice()),
            (CodecId::FibQuant, b"123".as_slice()),
        ];
        let results = adapter.decode_batch(&items).unwrap();
        assert_eq!(
            results,
            vec![
                b"abc".to_vec(),
                b"zyx".to_vec(),
                vec![0xF1, 0xB0, b'1', b'2', b'3'],
            ]
        );
    }

    #[test]
    fn decode_batch_of_empty_slice_is_empty() {
        let adapter = test_adapter();
        let results = adapter.decode_batch(&[]).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn decode_batch_calls_fallback_once_per_item() {
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);
        let adapter: ExactFallbackAdapter<Vec<u8>> =
            ExactFallbackAdapter::new(Box::new(move |_codec_id, data| {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok(data.to_vec())
            }));
        let items = [
            (CodecId::Uncompressed, b"a".as_slice()),
            (CodecId::TurboQuant, b"b".as_slice()),
        ];
        let results = adapter.decode_batch(&items).unwrap();
        assert_eq!(results, vec![b"a".to_vec(), b"b".to_vec()]);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    // NOTE(review): the error path of `decode_batch` (stop at the first `Err`) is not
    // covered. This module does not show how to construct a `DecompressError`, so no
    // fallback here can fail. Add a short-circuit test once one can be built.

    #[test]
    fn non_strict_mode_still_decodes() {
        let adapter: ExactFallbackAdapter<Vec<u8>> =
            ExactFallbackAdapter::new(Box::new(|_codec_id, data| Ok(data.to_vec())))
                .with_strict_mode(false);
        assert!(!adapter.is_strict());
        let result = adapter.decode_exact(CodecId::TurboQuant, b"hello").unwrap();
        assert_eq!(result, b"hello");
    }
}