use std::prelude::v1::*;
use numpy::{PyArray1, PyReadonlyArray1};
use pyo3::{prelude::*, types::PyTuple};
use crate::{
pybindings::array1_to_vec,
stream::{Decode, Encode},
Pos, Seek, UnwrapInfallible,
};
use super::model::{internals::EncoderDecoderModel, Model};
/// Initializes the Python submodule `stack`.
///
/// Registers the `AnsCoder` class with the module.
#[pymodule]
#[pyo3(name = "stack")]
pub fn init_module(module: &Bound<'_, PyModule>) -> PyResult<()> {
    // `add_class` already yields a `PyResult<()>`, so its result is forwarded as-is.
    module.add_class::<AnsCoder>()
}
/// Python-facing ANS (stack) entropy coder.
///
/// Thin wrapper exposed to Python via `#[pyclass]`; all work is delegated to
/// the wrapped `crate::stream::stack::DefaultAnsCoder`.
#[pyclass]
#[derive(Clone, Debug)]
pub struct AnsCoder {
    /// The Rust-side coder that every Python method forwards to.
    inner: crate::stream::stack::DefaultAnsCoder,
}
#[pymethods]
impl AnsCoder {
    /// Creates an ANS coder, optionally initialized with compressed data.
    ///
    /// - `compressed=None` (default): starts with an empty coder.
    /// - `compressed` given and `seal=False`: loads the data via the Rust
    ///   `AnsCoder::from_compressed`.
    /// - `compressed` given and `seal=True`: loads the data via the Rust
    ///   `AnsCoder::from_binary`.
    ///
    /// Raises `ValueError` if `seal=True` is passed without `compressed`, or if
    /// `compressed` is rejected as invalid (ANS compressed data never ends in a
    /// zero word).
    #[new]
    #[pyo3(signature = (compressed=None, seal=false))]
    pub fn new(compressed: Option<PyReadonlyArray1<'_, u32>>, seal: bool) -> PyResult<Self> {
        let inner = match (compressed, seal) {
            (None, true) => {
                return Err(pyo3::exceptions::PyValueError::new_err(
                    "Need compressed data to seal.",
                ));
            }
            (None, false) => crate::stream::stack::AnsCoder::new(),
            (Some(compressed), true) => {
                crate::stream::stack::AnsCoder::from_binary(array1_to_vec(compressed))
                    .unwrap_infallible()
            }
            (Some(compressed), false) => {
                crate::stream::stack::AnsCoder::from_compressed(array1_to_vec(compressed))
                    .map_err(|_| {
                        pyo3::exceptions::PyValueError::new_err(
                            "Invalid compressed data: ANS compressed data never ends in a zero word.",
                        )
                    })?
            }
        };
        Ok(Self { inner })
    }

    /// Returns the coder's current position as a `(position, state)` tuple,
    /// i.e., in the same form that `seek` accepts.
    #[pyo3(signature = ())]
    pub fn pos(&mut self) -> (usize, u64) {
        self.inner.pos()
    }

    /// Moves the coder to `(position, state)`, e.g., a value obtained from `pos()`.
    ///
    /// Raises `ValueError` if this would seek past the end of the stream. Since
    /// both decoding and seeking consume compressed data in an ANS coder, seeking
    /// backward is not supported.
    #[pyo3(signature = (position, state))]
    pub fn seek(&mut self, position: usize, state: u64) -> PyResult<()> {
        self.inner.seek((position, state)).map_err(|()| {
            pyo3::exceptions::PyValueError::new_err(
                "Tried to seek past end of stream. Note: in an ANS coder,\n\
                both decoding and seeking *consume* compressed data. The Python API of\n\
                `constriction`'s ANS coder currently does not support seeking backward.",
            )
        })
    }

    /// Resets the coder by clearing the underlying Rust coder.
    #[pyo3(signature = ())]
    pub fn clear(&mut self) {
        self.inner.clear();
    }

    /// Number of compressed words, as reported by the underlying coder.
    #[pyo3(signature = ())]
    pub fn num_words(&self) -> usize {
        self.inner.num_words()
    }

    /// Number of compressed bits, as reported by the underlying coder.
    #[pyo3(signature = ())]
    pub fn num_bits(&self) -> usize {
        self.inner.num_bits()
    }

    /// Number of valid bits, as reported by the underlying coder.
    #[pyo3(signature = ())]
    pub fn num_valid_bits(&self) -> usize {
        self.inner.num_valid_bits()
    }

    /// Returns `True` if the underlying coder reports that it is empty.
    #[pyo3(signature = ())]
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Returns the compressed data as a one-dimensional numpy array of `uint32` words.
    ///
    /// With `unseal=True`, returns the coder's binary representation instead
    /// (Rust `get_binary`). Raises `AssertionError` if that representation does
    /// not fit into an integer number of words.
    #[pyo3(signature = (unseal=false))]
    pub fn get_compressed<'py>(
        &mut self,
        py: Python<'py>,
        unseal: bool,
    ) -> PyResult<Bound<'py, PyArray1<u32>>> {
        if unseal {
            let binary = self.inner.get_binary().map_err(|_| {
                pyo3::exceptions::PyAssertionError::new_err(
                    "Cannot unseal compressed data because it doesn't fit into integer number of words. Did you create the encoder with `seal=True` and restore its original state?",
                )
            })?;
            Ok(PyArray1::from_slice(py, &binary))
        } else {
            Ok(PyArray1::from_slice(
                py,
                &self.inner.get_compressed().unwrap_infallible(),
            ))
        }
    }

    /// Encodes `symbols` with `model`, processing them in reverse order.
    ///
    /// - Scalar `symbols` (an int): encodes that single symbol. `model` must be
    ///   concrete, i.e., no extra model parameters may be passed (else `ValueError`).
    /// - Array `symbols`, no extra parameters: encodes all symbols i.i.d. with `model`.
    /// - Array `symbols` with extra parameters: each symbol uses its own model
    ///   parameters. Raises `ValueError` if the number of symbols does not match
    ///   the number of parameter entries.
    #[pyo3(signature = (symbols, model, *optional_model_params))]
    pub fn encode_reverse(
        &mut self,
        py: Python<'_>,
        symbols: &Bound<'_, PyAny>,
        model: &Model,
        optional_model_params: &Bound<'_, PyTuple>,
    ) -> PyResult<()> {
        if let Ok(symbol) = symbols.extract::<i32>() {
            if !optional_model_params.is_empty() {
                return Err(pyo3::exceptions::PyValueError::new_err(
                    "To encode a single symbol, use a concrete model, i.e., pass the\n\
                    model parameters directly to the constructor of the model and not to the\n\
                    `encode_reverse` method of the entropy coder. Delaying the specification of model\n\
                    parameters until calling `encode_reverse` is only useful if you want to encode\n\
                    several symbols in a row with individual model parameters for each symbol. If\n\
                    this is what you're trying to do then the `symbols` argument should be a numpy\n\
                    array, not a scalar.",
                ));
            }
            return model.0.as_parameterized(py, &mut |model| {
                self.inner
                    .encode_symbol(symbol, EncoderDecoderModel(model))?;
                Ok(())
            });
        }

        let symbols = symbols.extract::<PyReadonlyArray1<'_, i32>>()?;
        let symbols = symbols.as_array();

        if optional_model_params.is_empty() {
            model.0.as_parameterized(py, &mut |model| {
                self.inner
                    .encode_iid_symbols_reverse(symbols, EncoderDecoderModel(model))?;
                Ok(())
            })?;
        } else {
            let first_params = optional_model_params
                .get_borrowed_item(0)
                .expect("`optional_model_params` is non-empty in this branch");
            if symbols.len() != model.0.len(first_params)? {
                return Err(pyo3::exceptions::PyValueError::new_err(
                    "`symbols` argument has wrong length.",
                ));
            }

            // The parameterized model is traversed back to front (`reverse = true`),
            // so the symbols are consumed back to front as well.
            let mut symbol_iter = symbols.iter().rev();
            model
                .0
                .parameterize(py, optional_model_params, true, &mut |model| {
                    // Report a length mismatch as a Python error instead of panicking.
                    let symbol = symbol_iter.next().ok_or_else(|| {
                        pyo3::exceptions::PyValueError::new_err(
                            "`symbols` argument has wrong length.",
                        )
                    })?;
                    self.inner
                        .encode_symbol(*symbol, EncoderDecoderModel(model))?;
                    Ok(())
                })?;
        }
        Ok(())
    }

    /// Decodes symbols with `model`.
    ///
    /// - `decode(model)`: returns a single symbol as a Python int.
    /// - `decode(model, amt)` with a non-negative int `amt`: returns a numpy
    ///   array of `amt` symbols, decoded i.i.d. with `model`.
    /// - `decode(model, *params)`: returns a numpy array with one symbol per
    ///   model-parameter entry.
    #[pyo3(signature = (model, *optional_amt_or_model_params))]
    pub fn decode(
        &mut self,
        py: Python<'_>,
        model: &Model,
        optional_amt_or_model_params: &Bound<'_, PyTuple>,
    ) -> PyResult<Py<PyAny>> {
        match optional_amt_or_model_params.len() {
            0 => {
                let mut symbol = 0;
                model.0.as_parameterized(py, &mut |model| {
                    symbol = self
                        .inner
                        .decode_symbol(EncoderDecoderModel(model))
                        .unwrap_infallible();
                    Ok(())
                })?;
                return Ok(symbol
                    .into_pyobject(py)
                    .unwrap_infallible()
                    .into_any()
                    .unbind());
            }
            1 => {
                // A single argument that is a non-negative int is the number of i.i.d.
                // symbols to decode; anything else is treated as model parameters below.
                if let Ok(amt) = optional_amt_or_model_params
                    .get_borrowed_item(0)
                    .expect("tuple has exactly one item in this branch")
                    .extract::<usize>()
                {
                    let mut symbols = Vec::with_capacity(amt);
                    model.0.as_parameterized(py, &mut |model| {
                        symbols.extend(
                            self.inner
                                .decode_iid_symbols(amt, EncoderDecoderModel(model))
                                .map(|symbol| symbol.unwrap_infallible()),
                        );
                        Ok(())
                    })?;
                    // `from_vec` takes ownership of the buffer instead of copying it.
                    return Ok(PyArray1::from_vec(py, symbols).into_any().unbind());
                }
            }
            _ => {}
        }

        let num_symbols = model.0.len(
            optional_amt_or_model_params
                .get_borrowed_item(0)
                .expect("tuple is non-empty here since the `0` case returned early"),
        )?;
        let mut symbols = Vec::with_capacity(num_symbols);
        model
            .0
            .parameterize(py, optional_amt_or_model_params, false, &mut |model| {
                let symbol = self
                    .inner
                    .decode_symbol(EncoderDecoderModel(model))
                    .unwrap_infallible();
                symbols.push(symbol);
                Ok(())
            })?;
        Ok(PyArray1::from_vec(py, symbols).into_any().unbind())
    }

    /// Returns a copy of this coder.
    #[pyo3(signature = ())]
    pub fn clone(&self) -> Self {
        // Fully qualified call: `self.clone()` would resolve to this inherent
        // method and recurse forever.
        Clone::clone(self)
    }
}