use crate::error::{IoError, Result};
/// One byte-level transformation stage that can be chained in a pipeline.
///
/// Implementations are shared across threads (`Send + Sync`) and must be
/// `Debug` so a pipeline of boxed codecs can itself be printed.
pub trait Codec: Send + Sync + std::fmt::Debug {
    /// Short identifier of this codec, e.g. `"zstd"` or `"shuffle"`.
    fn name(&self) -> &str;

    /// Applies the forward transformation to `data`, producing a new buffer.
    fn encode(&self, data: &[u8]) -> Result<Vec<u8>>;

    /// Applies the inverse transformation; expected to undo [`Codec::encode`].
    fn decode(&self, data: &[u8]) -> Result<Vec<u8>>;
}
/// Codec that converts fixed-width elements between host byte order and a
/// configured byte order by reversing the bytes of each element when needed.
#[derive(Debug, Clone, Copy)]
pub struct BytesCodec {
    /// Byte order of the encoded representation.
    pub endian: Endian,
    /// Width in bytes of one element; 0 or 1 means there is nothing to swap.
    pub element_size: usize,
}

/// Byte order of multi-byte values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Endian {
    /// Least-significant byte first.
    Little,
    /// Most-significant byte first.
    Big,
}

impl BytesCodec {
    /// Builds a codec targeting `endian` for elements `element_size` bytes wide.
    pub fn new(endian: Endian, element_size: usize) -> Self {
        BytesCodec {
            element_size,
            endian,
        }
    }

    /// True when the configured byte order differs from the compile target's.
    fn needs_swap(&self) -> bool {
        let wants_little = self.endian == Endian::Little;
        let host_is_little = cfg!(target_endian = "little");
        wants_little != host_is_little
    }

    /// Copies `data`, reversing the bytes inside every complete
    /// `elem_size`-byte chunk.
    ///
    /// Trailing bytes that do not fill a complete chunk are copied unchanged.
    /// Widths of 0 or 1 yield a plain copy.
    fn swap_bytes(data: &[u8], elem_size: usize) -> Vec<u8> {
        let mut swapped = Vec::from(data);
        if elem_size > 1 {
            swapped
                .chunks_exact_mut(elem_size)
                .for_each(|element| element.reverse());
        }
        swapped
    }
}
impl Codec for BytesCodec {
fn name(&self) -> &str {
"bytes"
}
fn encode(&self, data: &[u8]) -> Result<Vec<u8>> {
if self.needs_swap() {
Ok(Self::swap_bytes(data, self.element_size))
} else {
Ok(data.to_vec())
}
}
fn decode(&self, data: &[u8]) -> Result<Vec<u8>> {
if self.needs_swap() {
Ok(Self::swap_bytes(data, self.element_size))
} else {
Ok(data.to_vec())
}
}
}
/// Codec that reorders an n-dimensional array of fixed-size elements from
/// row-major (C) order to column-major (Fortran) order and back.
#[derive(Debug, Clone)]
pub struct TransposeCodec {
    /// Extent of each dimension; the last one varies fastest in C order.
    shape: Vec<usize>,
    /// Width in bytes of one element.
    element_size: usize,
}

impl TransposeCodec {
    /// Creates a codec for an array of the given `shape` whose elements are
    /// `element_size` bytes wide.
    pub fn new(shape: Vec<usize>, element_size: usize) -> Self {
        Self {
            shape,
            element_size,
        }
    }

    /// Number of elements in the array (1 for an empty, i.e. scalar, shape).
    fn total_elements(&self) -> usize {
        self.shape.iter().product()
    }

    /// Maps a C-order linear index to the F-order linear index of the same
    /// element.
    ///
    /// Coordinates are peeled off starting at the last axis, which is the
    /// fastest-varying axis in C order. The Fortran offset is accumulated with
    /// Horner's rule in that same order, so no per-call allocation is needed.
    ///
    /// # Panics
    /// Panics if any dimension is zero (remainder by zero).
    fn c_to_f_index(&self, c_linear: usize) -> usize {
        let mut rem = c_linear;
        let mut f_linear = 0usize;
        for &dim in self.shape.iter().rev() {
            let coord = rem % dim;
            rem /= dim;
            f_linear = f_linear * dim + coord;
        }
        f_linear
    }

    /// Maps an F-order linear index to the C-order linear index of the same
    /// element.
    ///
    /// Coordinates are peeled off starting at the first axis, which is the
    /// fastest-varying axis in F order. The C offset is accumulated with
    /// Horner's rule in that same order.
    ///
    /// # Panics
    /// Panics if any dimension is zero (remainder by zero).
    fn f_to_c_index(&self, f_linear: usize) -> usize {
        let mut rem = f_linear;
        let mut c_linear = 0usize;
        for &dim in &self.shape {
            let coord = rem % dim;
            rem /= dim;
            c_linear = c_linear * dim + coord;
        }
        c_linear
    }
}
impl TransposeCodec {
    /// Shared body of the transpose `encode`/`decode`.
    ///
    /// Checks that `data` holds exactly one `element_size`-byte value per
    /// array element. It then copies the element at source position `i` to
    /// destination position `dest_index(self, i)`.
    ///
    /// # Errors
    /// `IoError::FormatError` if the array's byte size overflows `usize`, or
    /// if `data.len()` differs from that byte size.
    fn permute_elements<F>(&self, data: &[u8], op: &str, dest_index: F) -> Result<Vec<u8>>
    where
        F: Fn(&Self, usize) -> usize,
    {
        // Checked arithmetic: a very large shape (e.g. presumably read from
        // external metadata) would otherwise panic in debug builds or wrap
        // silently in release builds.
        let expected = self
            .shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
            .and_then(|n| n.checked_mul(self.element_size))
            .ok_or_else(|| {
                IoError::FormatError(format!(
                    "Transpose {op}: shape {:?} with element size {} overflows usize",
                    self.shape, self.element_size
                ))
            })?;
        if data.len() != expected {
            return Err(IoError::FormatError(format!(
                "Transpose {op}: expected {} bytes, got {}",
                expected,
                data.len()
            )));
        }
        // Nothing to move: a zero-length dimension or zero-width elements.
        // Returning early also avoids iterating over (possibly many)
        // zero-width elements.
        if expected == 0 {
            return Ok(Vec::new());
        }
        let width = self.element_size;
        let mut out = vec![0u8; expected];
        for (src_idx, element) in data.chunks_exact(width).enumerate() {
            let dst = dest_index(self, src_idx) * width;
            out[dst..dst + width].copy_from_slice(element);
        }
        Ok(out)
    }
}

impl Codec for TransposeCodec {
    fn name(&self) -> &str {
        "transpose"
    }

    /// Reorders C-order element data into F order.
    ///
    /// # Errors
    /// `IoError::FormatError` on a size mismatch or an overflowing shape.
    fn encode(&self, data: &[u8]) -> Result<Vec<u8>> {
        self.permute_elements(data, "encode", Self::c_to_f_index)
    }

    /// Reorders F-order element data back into C order.
    ///
    /// # Errors
    /// `IoError::FormatError` on a size mismatch or an overflowing shape.
    fn decode(&self, data: &[u8]) -> Result<Vec<u8>> {
        self.permute_elements(data, "decode", Self::f_to_c_index)
    }
}
/// Zstandard compression stage.
#[derive(Debug, Clone, Copy)]
pub struct ZstdCodec {
    /// Compression level requested by the caller.
    /// NOTE(review): the `oxiarc_zstd::compress` call in this file's `Codec`
    /// impl does not receive this value — confirm whether it should.
    pub level: i32,
}

impl Default for ZstdCodec {
    /// A codec requesting compression level 3.
    fn default() -> Self {
        ZstdCodec::new(3)
    }
}

impl ZstdCodec {
    /// Creates a codec requesting the given compression `level`.
    pub fn new(level: i32) -> Self {
        ZstdCodec { level }
    }
}
impl Codec for ZstdCodec {
    fn name(&self) -> &str {
        "zstd"
    }

    /// Compresses `data` with `oxiarc_zstd::compress`.
    ///
    /// # Errors
    /// Wraps any compressor failure in `IoError::CompressionError`.
    fn encode(&self, data: &[u8]) -> Result<Vec<u8>> {
        // NOTE(review): `self.level` is not passed to the compressor, so the
        // configured level currently has no effect here. Forward it once the
        // level-aware `oxiarc_zstd` entry point is confirmed.
        let compressed = oxiarc_zstd::compress(data).map_err(|cause| {
            IoError::CompressionError(format!("Zstd compression failed: {cause}"))
        })?;
        Ok(compressed)
    }

    /// Decompresses a Zstandard frame with `oxiarc_zstd::decompress`.
    ///
    /// # Errors
    /// Wraps any decompressor failure in `IoError::DecompressionError`.
    fn decode(&self, data: &[u8]) -> Result<Vec<u8>> {
        let restored = oxiarc_zstd::decompress(data).map_err(|cause| {
            IoError::DecompressionError(format!("Zstd decompression failed: {cause}"))
        })?;
        Ok(restored)
    }
}
/// Byte-shuffle filter: gathers the k-th byte of every fixed-width element
/// into one contiguous run, which typically helps a later compressor.
#[derive(Debug, Clone, Copy)]
pub struct ShuffleCodec {
    /// Width in bytes of one element; values of 0 or 1 leave data unchanged.
    pub element_size: usize,
}

impl ShuffleCodec {
    /// Creates a shuffle codec for elements `element_size` bytes wide.
    pub fn new(element_size: usize) -> Self {
        ShuffleCodec { element_size }
    }
}
impl Codec for ShuffleCodec {
    fn name(&self) -> &str {
        "shuffle"
    }

    /// Splits `data` into byte planes: plane `k` holds byte `k` of every
    /// element, in element order. Widths of 0 or 1 return an unchanged copy.
    ///
    /// # Errors
    /// `IoError::FormatError` if `data.len()` is not a multiple of
    /// `element_size`.
    fn encode(&self, data: &[u8]) -> Result<Vec<u8>> {
        let width = self.element_size;
        if width <= 1 {
            return Ok(data.to_vec());
        }
        if data.len() % width != 0 {
            return Err(IoError::FormatError(format!(
                "Shuffle encode: data length {} not divisible by element size {}",
                data.len(),
                width
            )));
        }
        let count = data.len() / width;
        let mut planes = vec![0u8; data.len()];
        for (elem_idx, element) in data.chunks_exact(width).enumerate() {
            for (byte_idx, &byte) in element.iter().enumerate() {
                planes[byte_idx * count + elem_idx] = byte;
            }
        }
        Ok(planes)
    }

    /// Reassembles elements from byte planes produced by `encode`.
    /// Widths of 0 or 1 return an unchanged copy.
    ///
    /// # Errors
    /// `IoError::FormatError` if `data.len()` is not a multiple of
    /// `element_size`.
    fn decode(&self, data: &[u8]) -> Result<Vec<u8>> {
        let width = self.element_size;
        if width <= 1 {
            return Ok(data.to_vec());
        }
        if data.len() % width != 0 {
            return Err(IoError::FormatError(format!(
                "Shuffle decode: data length {} not divisible by element size {}",
                data.len(),
                width
            )));
        }
        let count = data.len() / width;
        let mut elements = vec![0u8; data.len()];
        // Iterate over output elements so an empty input yields zero chunks.
        for (elem_idx, element) in elements.chunks_exact_mut(width).enumerate() {
            for (byte_idx, slot) in element.iter_mut().enumerate() {
                *slot = data[byte_idx * count + elem_idx];
            }
        }
        Ok(elements)
    }
}
/// Ordered chain of codecs.
///
/// Encoding runs the codecs in insertion order; decoding runs them in
/// reverse order so each stage undoes the one applied after it.
#[derive(Debug)]
pub struct CodecPipeline {
    codecs: Vec<Box<dyn Codec>>,
}

impl Default for CodecPipeline {
    /// An empty pipeline, identical to [`CodecPipeline::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl CodecPipeline {
    /// Creates a pipeline with no stages; it passes data through unchanged.
    pub fn new() -> Self {
        CodecPipeline {
            codecs: Vec::new(),
        }
    }

    /// Appends `codec` as the last encoding stage (and the first decoding one).
    pub fn push<C>(&mut self, codec: C)
    where
        C: Codec + 'static,
    {
        let stage: Box<dyn Codec> = Box::new(codec);
        self.codecs.push(stage);
    }

    /// Number of stages in the pipeline.
    pub fn len(&self) -> usize {
        self.codecs.len()
    }

    /// True when the pipeline has no stages.
    pub fn is_empty(&self) -> bool {
        self.codecs.is_empty()
    }

    /// Runs every stage's `encode` in insertion order.
    ///
    /// # Errors
    /// Stops at and returns the first stage error.
    pub fn encode(&self, data: &[u8]) -> Result<Vec<u8>> {
        self.codecs
            .iter()
            .try_fold(data.to_vec(), |current, stage| stage.encode(&current))
    }

    /// Runs every stage's `decode` in reverse insertion order.
    ///
    /// # Errors
    /// Stops at and returns the first stage error.
    pub fn decode(&self, data: &[u8]) -> Result<Vec<u8>> {
        self.codecs
            .iter()
            .rev()
            .try_fold(data.to_vec(), |current, stage| stage.decode(&current))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Byte order of the compile target, so a `BytesCodec` using it never swaps.
    fn native_endian() -> Endian {
        if cfg!(target_endian = "little") {
            Endian::Little
        } else {
            Endian::Big
        }
    }

    /// The opposite of the compile target's byte order, forcing a swap.
    fn foreign_endian() -> Endian {
        if cfg!(target_endian = "little") {
            Endian::Big
        } else {
            Endian::Little
        }
    }

    #[test]
    fn test_codec_names() {
        assert_eq!(BytesCodec::new(Endian::Little, 4).name(), "bytes");
        assert_eq!(TransposeCodec::new(vec![1], 1).name(), "transpose");
        assert_eq!(ZstdCodec::default().name(), "zstd");
        assert_eq!(ShuffleCodec::new(2).name(), "shuffle");
    }

    #[test]
    fn test_bytes_codec_no_swap() {
        let codec = BytesCodec::new(native_endian(), 4);
        let data = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let encoded = codec.encode(&data).expect("encode");
        assert_eq!(encoded, data);
        let decoded = codec.decode(&encoded).expect("decode");
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_bytes_codec_swap() {
        let codec = BytesCodec::new(foreign_endian(), 2);
        let data = vec![0x01, 0x02, 0x03, 0x04];
        let encoded = codec.encode(&data).expect("encode");
        assert_eq!(encoded, vec![0x02, 0x01, 0x04, 0x03]);
        let decoded = codec.decode(&encoded).expect("decode");
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_transpose_codec_roundtrip() {
        let codec = TransposeCodec::new(vec![2, 3], 4);
        let mut data = Vec::new();
        for val in [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0] {
            data.extend_from_slice(&val.to_ne_bytes());
        }
        let encoded = codec.encode(&data).expect("encode");
        assert_ne!(encoded, data);
        let decoded = codec.decode(&encoded).expect("decode");
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_transpose_codec_known_layout() {
        // [[0, 1, 2], [3, 4, 5]] in C order becomes column-major [0, 3, 1, 4, 2, 5].
        let codec = TransposeCodec::new(vec![2, 3], 1);
        let data: Vec<u8> = (0..6).collect();
        let encoded = codec.encode(&data).expect("encode");
        assert_eq!(encoded, vec![0, 3, 1, 4, 2, 5]);
        let decoded = codec.decode(&encoded).expect("decode");
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_transpose_codec_rejects_wrong_length() {
        let codec = TransposeCodec::new(vec![2, 2], 4);
        assert!(codec.encode(&[0u8; 15]).is_err());
        assert!(codec.decode(&[0u8; 17]).is_err());
    }

    #[test]
    fn test_transpose_codec_1d() {
        let codec = TransposeCodec::new(vec![8], 4);
        let data: Vec<u8> = (0..32).collect();
        let encoded = codec.encode(&data).expect("encode");
        assert_eq!(encoded, data);
        let decoded = codec.decode(&encoded).expect("decode");
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_zstd_codec_roundtrip() {
        let codec = ZstdCodec::new(3);
        let data: Vec<u8> = vec![42u8; 4096];
        let compressed = codec.encode(&data).expect("compress");
        assert!(compressed.len() < data.len());
        let decompressed = codec.decode(&compressed).expect("decompress");
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_shuffle_codec_roundtrip() {
        let codec = ShuffleCodec::new(4);
        let data: Vec<u8> = (0..32).collect();
        let encoded = codec.encode(&data).expect("encode");
        assert_ne!(encoded, data);
        let decoded = codec.decode(&encoded).expect("decode");
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_shuffle_codec_known_layout() {
        // Elements (1,2), (3,4), (5,6) -> plane of first bytes, then second bytes.
        let codec = ShuffleCodec::new(2);
        let data = vec![1, 2, 3, 4, 5, 6];
        let encoded = codec.encode(&data).expect("encode");
        assert_eq!(encoded, vec![1, 3, 5, 2, 4, 6]);
        let decoded = codec.decode(&encoded).expect("decode");
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_shuffle_rejects_partial_element() {
        let codec = ShuffleCodec::new(4);
        assert!(codec.encode(&[1, 2, 3]).is_err());
        assert!(codec.decode(&[1, 2, 3]).is_err());
    }

    #[test]
    fn test_shuffle_empty_input() {
        let codec = ShuffleCodec::new(4);
        assert_eq!(codec.encode(&[]).expect("encode"), Vec::<u8>::new());
        assert_eq!(codec.decode(&[]).expect("decode"), Vec::<u8>::new());
    }

    #[test]
    fn test_shuffle_single_byte_passthrough() {
        let codec = ShuffleCodec::new(1);
        let data = vec![10, 20, 30];
        let encoded = codec.encode(&data).expect("encode");
        assert_eq!(encoded, data);
    }

    #[test]
    fn test_codec_pipeline_chain() {
        let mut pipeline = CodecPipeline::new();
        pipeline.push(ShuffleCodec::new(8));
        pipeline.push(ZstdCodec::new(1));
        assert_eq!(pipeline.len(), 2);
        let data: Vec<u8> = (0..800).map(|i| (i % 256) as u8).collect();
        let encoded = pipeline.encode(&data).expect("pipeline encode");
        let decoded = pipeline.decode(&encoded).expect("pipeline decode");
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_codec_pipeline_empty() {
        let pipeline = CodecPipeline::new();
        assert!(pipeline.is_empty());
        let data = vec![1, 2, 3];
        let encoded = pipeline.encode(&data).expect("encode");
        assert_eq!(encoded, data);
        let decoded = pipeline.decode(&encoded).expect("decode");
        assert_eq!(decoded, data);
    }
}