use std::prelude::v1::*;
use numpy::{PyArray1, PyReadonlyArray1};
use pyo3::{prelude::*, types::PyTuple};
use crate::{
pybindings::array1_to_vec,
stream::{
queue::{DecoderFrontendError, RangeCoderState},
Decode, Encode,
},
Pos, Seek, UnwrapInfallible,
};
use super::model::{internals::EncoderDecoderModel, Model};
/// Populates the Python submodule `queue` with the `RangeEncoder` and `RangeDecoder`
/// classes, registering the encoder first.
#[pymodule]
#[pyo3(name = "queue")]
pub fn init_module(module: &Bound<'_, PyModule>) -> PyResult<()> {
    module
        .add_class::<RangeEncoder>()
        .and_then(|()| module.add_class::<RangeDecoder>())
}
/// Range encoder from `crate::stream::queue`, registered on the Python module `queue`
/// as `RangeEncoder`.
///
/// Compressed data is accumulated as a sequence of `u32` words.
#[pyclass]
#[derive(Debug, Default, Clone)]
pub struct RangeEncoder {
// The wrapped Rust encoder that all Python-facing methods operate on.
inner: crate::stream::queue::DefaultRangeEncoder,
}
#[pymethods]
impl RangeEncoder {
    /// Creates a new encoder with no compressed data.
    #[new]
    #[pyo3(signature = ())]
    pub fn new() -> Self {
        let inner = crate::stream::queue::DefaultRangeEncoder::new();
        Self { inner }
    }

    /// Resets the encoder by clearing the wrapped Rust encoder.
    #[pyo3(signature = ())]
    pub fn clear(&mut self) {
        self.inner.clear();
    }

    /// Returns the current position and coder state as `(position, (lower, range))`.
    ///
    /// The tuple has the same shape as the arguments accepted by `RangeDecoder.seek`.
    #[pyo3(signature = ())]
    pub fn pos(&mut self) -> (usize, (u64, u64)) {
        let (pos, state) = self.inner.pos();
        (pos, (state.lower(), state.range().get()))
    }

    /// Returns the size of the compressed data in (32-bit) words.
    #[pyo3(signature = ())]
    pub fn num_words(&self) -> usize {
        self.inner.num_words()
    }

    /// Returns the size of the compressed data in bits, as reported by the wrapped encoder.
    #[pyo3(signature = ())]
    pub fn num_bits(&self) -> usize {
        self.inner.num_bits()
    }

    /// Returns whether the wrapped encoder reports that it is empty.
    #[pyo3(signature = ())]
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Returns a copy of the compressed data as a 1-D numpy array of `uint32` words.
    #[pyo3(signature = ())]
    pub fn get_compressed<'py>(&mut self, py: Python<'py>) -> Bound<'py, PyArray1<u32>> {
        PyArray1::from_slice(py, &self.inner.get_compressed())
    }

    /// Returns a `RangeDecoder` that reads from a copy of the data compressed so far.
    ///
    /// The encoder is not consumed and can continue to be used afterwards.
    #[pyo3(signature = ())]
    pub fn get_decoder(&mut self) -> RangeDecoder {
        let compressed = self.inner.get_compressed().to_vec();
        RangeDecoder::from_vec(compressed)
    }

    /// Encodes one or more `i32` symbols with `model`.
    ///
    /// - `symbols` is a scalar: encodes a single symbol. `model` must be fully
    ///   specified, i.e., no extra positional arguments may be given.
    /// - `symbols` is a 1-D numpy array and no extra arguments are given: encodes all
    ///   symbols with the same model.
    /// - `symbols` is a 1-D numpy array with extra arguments: the extra arguments are
    ///   model parameters, and each symbol is encoded with its own parameterization.
    ///   `len(symbols)` must equal the length the model reports for the first parameter.
    ///
    /// # Errors
    ///
    /// Raises `ValueError` if model parameters are passed together with a scalar
    /// symbol, or if `symbols` has the wrong length. Errors from extracting the
    /// arguments or from the underlying encoder are propagated.
    #[pyo3(signature = (symbols, model, *optional_model_params))]
    pub fn encode(
        &mut self,
        py: Python<'_>,
        symbols: &Bound<'_, PyAny>,
        model: &Model,
        optional_model_params: &Bound<'_, PyTuple>,
    ) -> PyResult<()> {
        // Scalar symbol: only allowed with a concrete (fully parameterized) model.
        if let Ok(symbol) = symbols.extract::<i32>() {
            if !optional_model_params.is_empty() {
                return Err(pyo3::exceptions::PyValueError::new_err(
                    "To encode a single symbol, use a concrete model, i.e., pass the\n\
                    model parameters directly to the constructor of the model and not to the\n\
                    `encode` method of the entropy coder. Delaying the specification of model\n\
                    parameters until calling `encode` is only useful if you want to encode several\n\
                    symbols in a row with individual model parameters for each symbol. If this is\n\
                    what you're trying to do then the `symbols` argument should be a numpy array,\n\
                    not a scalar.",
                ));
            }
            return model.0.as_parameterized(py, &mut |model| {
                self.inner
                    .encode_symbol(symbol, EncoderDecoderModel(model))?;
                Ok(())
            });
        }

        let symbols = symbols.extract::<PyReadonlyArray1<'_, i32>>()?;
        let symbols = symbols.as_array();

        if optional_model_params.is_empty() {
            // All symbols share one concrete model.
            model.0.as_parameterized(py, &mut |model| {
                self.inner
                    .encode_iid_symbols(symbols, EncoderDecoderModel(model))?;
                Ok(())
            })?;
        } else {
            if symbols.len()
                != model.0.len(
                    optional_model_params
                        .get_borrowed_item(0)
                        .expect("len checked above"),
                )?
            {
                return Err(pyo3::exceptions::PyValueError::new_err(
                    "`symbols` argument has wrong length.",
                ));
            }

            let mut symbol_iter = symbols.iter();
            model
                .0
                .parameterize(py, optional_model_params, false, &mut |model| {
                    // The length check above should give exactly one symbol per
                    // parameterization. If that invariant is ever violated, raise a
                    // Python error instead of panicking across the FFI boundary.
                    let symbol = symbol_iter.next().ok_or_else(|| {
                        pyo3::exceptions::PyValueError::new_err(
                            "`symbols` argument has wrong length.",
                        )
                    })?;
                    self.inner
                        .encode_symbol(*symbol, EncoderDecoderModel(model))?;
                    Ok(())
                })?;
        }

        Ok(())
    }

    /// Returns an independent copy of this encoder, including its compressed data.
    #[pyo3(signature = ())]
    pub fn clone(&self) -> Self {
        Clone::clone(self)
    }
}
/// Range decoder from `crate::stream::queue`, registered on the Python module `queue`
/// as `RangeDecoder`.
///
/// Reads compressed data given as a sequence of `u32` words.
#[pyclass]
#[derive(Debug, Clone)]
pub struct RangeDecoder {
// The wrapped Rust decoder that all Python-facing methods operate on.
inner: crate::stream::queue::DefaultRangeDecoder,
}
#[pymethods]
impl RangeDecoder {
    /// Creates a decoder that reads from `compressed`, a 1-D numpy array of `uint32`
    /// words (copied into a Rust `Vec`).
    #[new]
    #[pyo3(signature = (compressed))]
    pub fn new(compressed: PyReadonlyArray1<'_, u32>) -> PyResult<Self> {
        Ok(Self::from_vec(array1_to_vec(compressed)))
    }

    /// Jumps to `position` and restores the coder `state`, given as `(lower, range)`.
    ///
    /// # Errors
    ///
    /// Raises `ValueError("Invalid coder state.")` if `(lower, range)` is rejected by
    /// `RangeCoderState::new`. Raises `ValueError("Tried to seek past end of stream.")`
    /// if the wrapped decoder refuses the seek.
    #[pyo3(signature = (position, state))]
    pub fn seek(&mut self, position: usize, state: (u64, u64)) -> PyResult<()> {
        let (lower, range) = state;
        let state = RangeCoderState::new(lower, range)
            .map_err(|()| pyo3::exceptions::PyValueError::new_err("Invalid coder state."))?;
        self.inner.seek((position, state)).map_err(|()| {
            pyo3::exceptions::PyValueError::new_err("Tried to seek past end of stream.")
        })
    }

    /// Returns the wrapped decoder's `maybe_exhausted` flag.
    #[pyo3(signature = ())]
    pub fn maybe_exhausted(&self) -> bool {
        self.inner.maybe_exhausted()
    }

    /// Decodes symbols with `model`.
    ///
    /// Behavior depends on the extra positional arguments:
    /// - none: decodes a single symbol and returns it as a Python integer;
    /// - exactly one argument convertible to a non-negative integer `amt`: decodes
    ///   `amt` symbols with the same (concrete) model and returns them as a numpy array;
    /// - otherwise: the arguments are model parameters, one symbol is decoded per
    ///   invocation of the parameterized model, and the symbols are returned as a
    ///   numpy array.
    ///
    /// # Errors
    ///
    /// Raises `MemoryError` if space for `amt` symbols cannot be reserved. Decoding
    /// and argument-conversion errors are propagated.
    #[pyo3(signature = (model, *optional_amt_or_model_params))]
    pub fn decode(
        &mut self,
        py: Python<'_>,
        model: &Model,
        optional_amt_or_model_params: &Bound<'_, PyTuple>,
    ) -> PyResult<Py<PyAny>> {
        match optional_amt_or_model_params.len() {
            0 => {
                let mut symbol = 0;
                model.0.as_parameterized(py, &mut |model| {
                    symbol = self.inner.decode_symbol(EncoderDecoderModel(model))?;
                    Ok(())
                })?;
                return Ok(symbol
                    .into_pyobject(py)
                    .unwrap_infallible()
                    .into_any()
                    .unbind());
            }
            1 => {
                if let Ok(amt) = optional_amt_or_model_params
                    .get_borrowed_item(0)
                    .expect("len checked above")
                    .extract::<usize>()
                {
                    // `amt` is an arbitrary user-supplied integer. `Vec::with_capacity`
                    // would panic on capacity overflow or abort the whole interpreter if
                    // the allocation fails, so reserve fallibly and raise `MemoryError`.
                    let mut symbols = Vec::new();
                    symbols.try_reserve_exact(amt).map_err(|_| {
                        pyo3::exceptions::PyMemoryError::new_err(
                            "Cannot allocate memory for the requested number of symbols.",
                        )
                    })?;
                    model.0.as_parameterized(py, &mut |model| {
                        for symbol in self
                            .inner
                            .decode_iid_symbols(amt, EncoderDecoderModel(model))
                        {
                            symbols.push(symbol?);
                        }
                        Ok(())
                    })?;
                    return Ok(PyArray1::from_vec(py, symbols).into_any().unbind());
                }
                // A single argument that is not a non-negative integer is treated as
                // a model parameter below.
            }
            _ => {}
        }

        // Parameterized model: at least one argument is present here, because the
        // `0` arm above always returns.
        let mut symbols = Vec::with_capacity(
            model.0.len(
                optional_amt_or_model_params
                    .get_borrowed_item(0)
                    .expect("len checked above"),
            )?,
        );
        model
            .0
            .parameterize(py, optional_amt_or_model_params, false, &mut |model| {
                let symbol = self.inner.decode_symbol(EncoderDecoderModel(model))?;
                symbols.push(symbol);
                Ok(())
            })?;
        Ok(PyArray1::from_vec(py, symbols).into_any().unbind())
    }

    /// Returns an independent copy of this decoder, including its current position.
    #[pyo3(signature = ())]
    pub fn clone(&self) -> Self {
        Clone::clone(self)
    }
}
impl RangeDecoder {
    /// Wraps a decoder that reads from the given compressed `u32` words.
    ///
    /// Construction from a `Vec<u32>` cannot fail (its error is unwrapped with
    /// `unwrap_infallible`), so this function always returns a decoder.
    pub fn from_vec(compressed: Vec<u32>) -> Self {
        Self {
            inner: crate::stream::queue::DefaultRangeDecoder::from_compressed(compressed)
                .unwrap_infallible(),
        }
    }
}
impl From<DecoderFrontendError> for pyo3::PyErr {
    /// Converts a decoder frontend error into a Python exception.
    ///
    /// `InvalidData` becomes an `AssertionError` whose message is the error's
    /// `Display` text.
    fn from(err: DecoderFrontendError) -> Self {
        let message = err.to_string();
        match err {
            DecoderFrontendError::InvalidData => {
                pyo3::exceptions::PyAssertionError::new_err(message)
            }
        }
    }
}