pub mod huffman;
use core::{
convert::Infallible,
sync::atomic::{AtomicBool, Ordering},
};
use std::prelude::v1::*;
use numpy::{PyArray1, PyReadonlyArray1};
use pyo3::{prelude::*, wrap_pymodule};
use crate::{
backends::Cursor,
symbol::{
DefaultQueueDecoder, DefaultQueueEncoder, DefaultStackCoder, ReadBitStream,
SymbolCodeError, WriteBitStream,
},
};
use super::array1_to_vec;
/// Populates the Python submodule `symbol`.
///
/// Registers the nested `huffman` submodule as well as the classes `StackCoder`,
/// `QueueEncoder`, and `QueueDecoder`. The first registration that fails aborts the
/// initialization and its error is returned to the caller.
#[pymodule]
#[pyo3(name = "symbol")]
pub fn init_module(module: &Bound<'_, PyModule>) -> PyResult<()> {
    // Nested submodule that provides the Huffman codebooks used by the coders below.
    module.add_wrapped(wrap_pymodule!(huffman::init_module))?;

    module.add_class::<StackCoder>()?;
    module.add_class::<QueueEncoder>()?;

    // The last registration already yields the `PyResult<()>` we have to return.
    module.add_class::<QueueDecoder>()
}
/// Python class that wraps a `DefaultStackCoder`.
///
/// A single object offers both `encode_symbol` and `decode_symbol`, each of which takes
/// a Huffman codebook from the `huffman` submodule. The compressed representation is
/// exchanged with Python as a one-dimensional numpy array of `u32` words.
#[pyclass]
#[derive(Debug)]
pub struct StackCoder {
// The wrapped coder; all methods of this class forward to it.
inner: DefaultStackCoder,
}
#[pymethods]
impl StackCoder {
    /// Constructs a stack coder, optionally initialized with existing compressed data.
    ///
    /// If `compressed` is omitted (or `None`), a fresh coder is created. Otherwise, the
    /// provided one-dimensional array of `u32` words is converted to a `Vec` and used
    /// as the initial compressed data.
    ///
    /// Raises a `ValueError` if the underlying coder rejects the provided data; the
    /// error message states that compressed data for a stack must not end in a zero
    /// word.
    #[new]
    #[pyo3(signature = (compressed=None))]
    pub fn new(compressed: Option<PyReadonlyArray1<'_, u32>>) -> PyResult<Self> {
        let inner = match compressed {
            None => DefaultStackCoder::new(),
            Some(compressed) => DefaultStackCoder::from_compressed(array1_to_vec(compressed))
                .map_err(|_| {
                    pyo3::exceptions::PyValueError::new_err(
                        "Compressed data for a stack must not end in a zero word.",
                    )
                })?,
        };
        Ok(Self { inner })
    }

    /// Encodes `symbol` with the provided Huffman encoder `codebook`.
    ///
    /// Errors reported by the underlying coder are raised as Python exceptions.
    #[pyo3(signature = (symbol, codebook))]
    pub fn encode_symbol(
        &mut self,
        symbol: usize,
        codebook: &huffman::EncoderHuffmanTree,
    ) -> PyResult<()> {
        Ok(self.inner.encode_symbol(symbol, &codebook.inner)?)
    }

    /// Decodes and returns a single symbol using the provided Huffman decoder
    /// `codebook`.
    ///
    /// Errors reported by the underlying coder are raised as Python exceptions (e.g.,
    /// a `ValueError` if the coder runs out of compressed data).
    #[pyo3(signature = (codebook))]
    pub fn decode_symbol(&mut self, codebook: &huffman::DecoderHuffmanTree) -> PyResult<usize> {
        Ok(self.inner.decode_symbol(&codebook.inner)?)
    }

    /// Returns a tuple `(compressed, bitrate)`.
    ///
    /// `compressed` is a new numpy array holding a copy of the current compressed data
    /// (`u32` words), and `bitrate` is the length reported by the underlying coder.
    #[pyo3(signature = ())]
    pub fn get_compressed_and_bitrate<'py>(
        &mut self,
        py: Python<'py>,
    ) -> (Bound<'py, PyArray1<u32>>, usize) {
        // Query the length first, before `get_compressed` borrows the coder.
        let len = self.inner.len();
        (PyArray1::from_slice(py, &self.inner.get_compressed()), len)
    }

    /// Deprecated alias of `get_compressed_and_bitrate`.
    ///
    /// Prints a warning about the rename the first time it is called (once per
    /// process, shared by all instances), then forwards to
    /// `get_compressed_and_bitrate`.
    #[pyo3(signature = ())]
    pub fn get_compressed<'py>(&mut self, py: Python<'py>) -> (Bound<'py, PyArray1<u32>>, usize) {
        static WARNED: AtomicBool = AtomicBool::new(false);
        if !WARNED.swap(true, Ordering::AcqRel) {
            // The snippet must be a complete Python statement (closing `')` included);
            // otherwise Python rejects it with a `SyntaxError` and nothing is printed.
            // Printing is best effort: a failure must never prevent the actual call.
            let _ = py.run(
                pyo3::ffi::c_str!(
                    "print('WARNING: `StackCoder.get_compressed` has been renamed to\\n\
                     \x20 `StackCoder.get_compressed_and_bitrate` to avoid confusion.')"
                ),
                None,
                None,
            );
        }
        self.get_compressed_and_bitrate(py)
    }
}
/// Python class that wraps a `DefaultQueueEncoder`.
///
/// Symbols are encoded with `encode_symbol` using a Huffman encoder codebook from the
/// `huffman` submodule. The result can be exported as a numpy array of `u32` words or
/// handed to a `QueueDecoder` via `get_decoder`.
#[pyclass]
#[derive(Debug, Default)]
pub struct QueueEncoder {
// The wrapped encoder; all methods of this class forward to it.
inner: DefaultQueueEncoder,
}
#[pymethods]
impl QueueEncoder {
    /// Constructs a new queue encoder via `DefaultQueueEncoder::new()`.
    #[new]
    #[pyo3(signature = ())]
    pub fn new() -> Self {
        Self {
            inner: DefaultQueueEncoder::new(),
        }
    }

    /// Encodes `symbol` with the provided Huffman encoder `codebook`.
    ///
    /// Errors reported by the underlying encoder are raised as Python exceptions.
    #[pyo3(signature = (symbol, codebook))]
    pub fn encode_symbol(
        &mut self,
        symbol: usize,
        codebook: &huffman::EncoderHuffmanTree,
    ) -> PyResult<()> {
        Ok(self.inner.encode_symbol(symbol, &codebook.inner)?)
    }

    /// Returns a tuple `(compressed, bitrate)`.
    ///
    /// `compressed` is a new numpy array holding a copy of the current compressed data
    /// (`u32` words), and `bitrate` is the length reported by the underlying encoder.
    #[pyo3(signature = ())]
    pub fn get_compressed_and_bitrate<'py>(
        &mut self,
        py: Python<'py>,
    ) -> (Bound<'py, PyArray1<u32>>, usize) {
        // Query the length first, before `get_compressed` borrows the encoder.
        let len = self.inner.len();
        (PyArray1::from_slice(py, &self.inner.get_compressed()), len)
    }

    /// Deprecated alias of `get_compressed_and_bitrate`.
    ///
    /// Prints a warning about the rename the first time it is called (once per
    /// process, shared by all instances), then forwards to
    /// `get_compressed_and_bitrate`.
    #[pyo3(signature = ())]
    pub fn get_compressed<'py>(&mut self, py: Python<'py>) -> (Bound<'py, PyArray1<u32>>, usize) {
        static WARNED: AtomicBool = AtomicBool::new(false);
        if !WARNED.swap(true, Ordering::AcqRel) {
            // The snippet must be a complete Python statement (closing `')` included);
            // otherwise Python rejects it with a `SyntaxError` and nothing is printed.
            // Printing is best effort: a failure must never prevent the actual call.
            let _ = py.run(
                pyo3::ffi::c_str!(
                    "print('WARNING: `QueueEncoder.get_compressed` has been renamed to\\n\
                     \x20 `QueueEncoder.get_compressed_and_bitrate` to avoid confusion.')"
                ),
                None,
                None,
            );
        }
        self.get_compressed_and_bitrate(py)
    }

    /// Returns a `QueueDecoder` that is initialized with a copy of the data compressed
    /// so far.
    #[pyo3(signature = ())]
    pub fn get_decoder(&mut self) -> QueueDecoder {
        let compressed = self.inner.get_compressed().to_vec();
        QueueDecoder::from_vec(compressed)
    }
}
/// Python class that wraps a `DefaultQueueDecoder`.
///
/// Instances are created either from a numpy array of `u32` words or by calling
/// `QueueEncoder.get_decoder`. Symbols are read with `decode_symbol` using a Huffman
/// decoder codebook from the `huffman` submodule.
#[pyclass]
#[derive(Debug)]
pub struct QueueDecoder {
// The wrapped decoder; `decode_symbol` forwards to it.
inner: DefaultQueueDecoder,
}
#[pymethods]
impl QueueDecoder {
#[new]
#[pyo3(signature = (compressed))]
pub fn new(compressed: PyReadonlyArray1<'_, u32>) -> PyResult<Self> {
Ok(Self::from_vec(array1_to_vec(compressed)))
}
#[pyo3(signature = (codebook))]
pub fn decode_symbol(&mut self, codebook: &huffman::DecoderHuffmanTree) -> PyResult<usize> {
Ok(self.inner.decode_symbol(&codebook.inner)?)
}
}
impl QueueDecoder {
    /// Builds a decoder that owns `compressed`.
    ///
    /// The words are wrapped in a `Cursor` created with
    /// `Cursor::new_at_write_beginning`, which is then handed to
    /// `DefaultQueueDecoder::from_compressed`.
    fn from_vec(compressed: Vec<u32>) -> Self {
        let inner =
            DefaultQueueDecoder::from_compressed(Cursor::new_at_write_beginning(compressed));
        Self { inner }
    }
}
/// Converts symbol-code errors into Python exceptions so that `?` can be used in
/// methods that return a `PyResult`.
impl From<SymbolCodeError<Infallible>> for PyErr {
    fn from(err: SymbolCodeError<Infallible>) -> Self {
        use pyo3::exceptions::PyValueError;

        match err {
            // The payload is `Infallible`, so this variant can never be constructed;
            // the empty match proves that to the compiler without a panic path.
            SymbolCodeError::InvalidCodeword(never) => match never {},
            SymbolCodeError::OutOfCompressedData => {
                PyValueError::new_err("Ran out of bits in compressed data.")
            }
        }
    }
}