use std::borrow::Cow;
use ndarray::{Array, Array1, ArrayBase, Data, Dimension, ShapeError};
use num_traits::Float;
use numcodecs::{
AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
};
use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use thiserror::Error;
#[cfg(test)]
use ::serde_json as _;
// Codec configuration for compressing arrays with `tthresh`.
//
// NOTE: plain `//` comments are used on this type on purpose. It derives
// `JsonSchema`, and schemars copies `///` doc comments into the generated
// schema (title/description), which would change the emitted schema.
//
// `deny_unknown_fields` is only requested from schemars here; serde documents
// its own `deny_unknown_fields` as unsupported in combination with
// `#[serde(flatten)]`, which the `error_bound` field uses.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[schemars(deny_unknown_fields)]
pub struct TthreshCodec {
// Error bound handed to `compress`. Flattened, so the `eb_mode` tag and the
// matching `eb_*` value appear as top-level keys of the codec config.
#[serde(flatten)]
pub error_bound: TthreshErrorBound,
// Version marker of type `StaticCodecVersion<0, 1, 0>`, (de)serialized under
// the `_version` key and filled with its default when the key is absent.
#[serde(default, rename = "_version")]
pub version: StaticCodecVersion<0, 1, 0>,
}
// Error bound that is forwarded to `tthresh` when compressing.
//
// Serialized as an internally tagged enum: the `eb_mode` key selects the
// variant ("eps", "rmse" or "psnr") and the matching `eb_*` key carries the
// bound value. `deny_unknown_fields` makes serde reject keys it does not know
// when deserializing this enum.
//
// NOTE: plain `//` comments are used on purpose, since the `JsonSchema`
// derive would copy `///` doc comments into the generated schema.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "eb_mode")]
#[serde(deny_unknown_fields)]
pub enum TthreshErrorBound {
// Forwarded as `tthresh::ErrorBound::Eps`.
// NOTE(review): presumably a relative error bound - confirm against the
// tthresh documentation.
#[serde(rename = "eps")]
Eps {
// Bound value, read from / written to the `eb_eps` key.
#[serde(rename = "eb_eps")]
eps: NonNegative<f64>,
},
// Forwarded as `tthresh::ErrorBound::RMSE`.
// NOTE(review): presumably a root-mean-square-error bound - confirm.
#[serde(rename = "rmse")]
RMSE {
// Bound value, read from / written to the `eb_rmse` key.
#[serde(rename = "eb_rmse")]
rmse: NonNegative<f64>,
},
// Forwarded as `tthresh::ErrorBound::PSNR`.
// NOTE(review): presumably a peak-signal-to-noise-ratio bound - confirm.
#[serde(rename = "psnr")]
PSNR {
// Bound value, read from / written to the `eb_psnr` key.
#[serde(rename = "eb_psnr")]
psnr: NonNegative<f64>,
},
}
impl Codec for TthreshCodec {
    type Error = TthreshCodecError;

    /// Compresses `data` with tthresh and returns the compressed stream as a
    /// one-dimensional `u8` array.
    ///
    /// # Errors
    ///
    /// - [`TthreshCodecError::UnsupportedDtype`] if `data` is not of dtype
    ///   `u8`, `u16`, `i32`, `f32`, or `f64`.
    /// - Any error returned by [`compress`].
    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
        let bound = &self.error_bound;

        // Only the element type differs between the arms; every supported
        // dtype ends up as the same kind of byte stream.
        let compressed = match data {
            AnyCowArray::U8(array) => compress(array, bound)?,
            AnyCowArray::U16(array) => compress(array, bound)?,
            AnyCowArray::I32(array) => compress(array, bound)?,
            AnyCowArray::F32(array) => compress(array, bound)?,
            AnyCowArray::F64(array) => compress(array, bound)?,
            unsupported => {
                return Err(TthreshCodecError::UnsupportedDtype(unsupported.dtype()));
            }
        };

        Ok(AnyArray::U8(Array1::from(compressed).into_dyn()))
    }

    /// Decompresses a one-dimensional `u8` array that holds a tthresh stream.
    ///
    /// # Errors
    ///
    /// - [`TthreshCodecError::EncodedDataNotBytes`] if `encoded` is not of
    ///   dtype `u8`.
    /// - [`TthreshCodecError::EncodedDataNotOneDimensional`] if `encoded` does
    ///   not have exactly one axis.
    /// - Any error returned by [`decompress`].
    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
        let AnyCowArray::U8(bytes) = encoded else {
            return Err(TthreshCodecError::EncodedDataNotBytes {
                dtype: encoded.dtype(),
            });
        };

        if bytes.shape().len() != 1 {
            return Err(TthreshCodecError::EncodedDataNotOneDimensional {
                shape: bytes.shape().to_vec(),
            });
        }

        // Re-wrap the array so that `as_bytes` can expose it as a byte slice.
        let bytes = AnyCowArray::U8(bytes);
        decompress(&bytes.as_bytes())
    }

    /// Decodes `encoded` into a freshly allocated array and then assigns that
    /// array to `decoded`.
    ///
    /// # Errors
    ///
    /// - Any error returned by [`Self::decode`].
    /// - [`TthreshCodecError::MismatchedDecodeIntoArray`] if the decoded array
    ///   cannot be assigned to `decoded`.
    fn decode_into(
        &self,
        encoded: AnyArrayView,
        mut decoded: AnyArrayViewMut,
    ) -> Result<(), Self::Error> {
        let decoded_in = self.decode(encoded.cow())?;
        decoded.assign(&decoded_in)?;
        Ok(())
    }
}
impl StaticCodec for TthreshCodec {
/// Identifier string of this codec.
const CODEC_ID: &'static str = "tthresh.rs";
/// The codec doubles as its own configuration type.
type Config<'de> = Self;
/// Builds the codec from its configuration; since `Config` is `Self`, the
/// configuration is returned unchanged.
fn from_config(config: Self::Config<'_>) -> Self {
config
}
/// Returns the codec's configuration, built from a borrow of `self`.
fn get_config(&self) -> StaticCodecConfig<'_, Self> {
StaticCodecConfig::from(self)
}
}
/// Errors that can occur when encoding or decoding with the [`TthreshCodec`].
#[derive(Debug, Error)]
pub enum TthreshCodecError {
/// The array passed to `encode` has a dtype that is not handled there.
#[error("Tthresh does not support the dtype {0}")]
UnsupportedDtype(AnyArrayDType),
/// `tthresh::compress` reported an error.
#[error("Tthresh failed to encode the data")]
TthreshEncodeFailed {
/// Opaque wrapper around the underlying tthresh error.
source: TthreshCodingError,
},
/// The array passed to `decode` is not of dtype `u8`.
#[error(
"Tthresh can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
)]
EncodedDataNotBytes {
/// The dtype that was received instead.
dtype: AnyArrayDType,
},
/// The byte array passed to `decode` does not have exactly one axis.
#[error(
"Tthresh can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}"
)]
EncodedDataNotOneDimensional {
/// The shape that was received instead.
shape: Vec<usize>,
},
/// `tthresh::decompress` reported an error.
#[error("Tthresh failed to decode the data")]
TthreshDecodeFailed {
/// Opaque wrapper around the underlying tthresh error.
source: TthreshCodingError,
},
/// The decoded shape and the decoded elements could not be combined into an
/// array. `#[from]` lets `?` convert a `ShapeError` into this variant.
#[error("Tthresh decoded an invalid array shape header which does not fit the decoded data")]
DecodeInvalidShapeHeader {
/// The shape error reported by `ndarray`.
#[from]
source: ShapeError,
},
/// Assigning the decoded array to the output of `decode_into` failed.
/// `#[from]` lets `?` convert an `AnyArrayAssignError` into this variant.
#[error("Tthresh cannot decode into the provided array")]
MismatchedDecodeIntoArray {
/// The assignment error reported by `numcodecs`.
#[from]
source: AnyArrayAssignError,
},
}
/// Opaque wrapper around `tthresh::Error`.
///
/// `#[error(transparent)]` forwards `Display` and `source` to the wrapped
/// error. The field is private, so code outside this module can neither
/// construct nor destructure the wrapper.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct TthreshCodingError(tthresh::Error);
#[expect(clippy::needless_pass_by_value)]
pub fn compress<T: TthreshElement, S: Data<Elem = T>, D: Dimension>(
data: ArrayBase<S, D>,
error_bound: &TthreshErrorBound,
) -> Result<Vec<u8>, TthreshCodecError> {
#[expect(clippy::option_if_let_else)]
let data_cow = match data.as_slice() {
Some(data) => Cow::Borrowed(data),
None => Cow::Owned(data.iter().copied().collect()),
};
let compressed = tthresh::compress(
&data_cow,
data.shape(),
match error_bound {
TthreshErrorBound::Eps { eps } => tthresh::ErrorBound::Eps(eps.0),
TthreshErrorBound::RMSE { rmse } => tthresh::ErrorBound::RMSE(rmse.0),
TthreshErrorBound::PSNR { psnr } => tthresh::ErrorBound::PSNR(psnr.0),
},
false,
false,
)
.map_err(|err| TthreshCodecError::TthreshEncodeFailed {
source: TthreshCodingError(err),
})?;
Ok(compressed)
}
pub fn decompress(encoded: &[u8]) -> Result<AnyArray, TthreshCodecError> {
let (decompressed, shape) = tthresh::decompress(encoded, false, false).map_err(|err| {
TthreshCodecError::TthreshDecodeFailed {
source: TthreshCodingError(err),
}
})?;
let decoded = match decompressed {
tthresh::Buffer::U8(decompressed) => {
AnyArray::U8(Array::from_shape_vec(shape, decompressed)?.into_dyn())
}
tthresh::Buffer::U16(decompressed) => {
AnyArray::U16(Array::from_shape_vec(shape, decompressed)?.into_dyn())
}
tthresh::Buffer::I32(decompressed) => {
AnyArray::I32(Array::from_shape_vec(shape, decompressed)?.into_dyn())
}
tthresh::Buffer::F32(decompressed) => {
AnyArray::F32(Array::from_shape_vec(shape, decompressed)?.into_dyn())
}
tthresh::Buffer::F64(decompressed) => {
AnyArray::F64(Array::from_shape_vec(shape, decompressed)?.into_dyn())
}
};
Ok(decoded)
}
/// Marker trait for the element types that [`compress`] accepts.
///
/// It adds no methods of its own; it only bundles the `Copy` and
/// `tthresh::Element` bounds. Implemented below for `u8`, `u16`, `i32`, `f32`
/// and `f64`, matching the dtypes handled by `encode` and `decompress`.
pub trait TthreshElement: Copy + tthresh::Element {}
impl TthreshElement for u8 {}
impl TthreshElement for u16 {}
impl TthreshElement for i32 {}
impl TthreshElement for f32 {}
impl TthreshElement for f64 {}
/// Floating point number that is checked to be non-negative on
/// deserialization.
///
/// The field is private and the `Deserialize` impl for `NonNegative<f64>`
/// only accepts values `>= 0.0`.
// `Eq` is intentionally not derived alongside `PartialEq` because floating
// point types do not implement `Eq`, hence the lint expectation.
#[expect(clippy::derive_partial_eq_without_eq)] #[derive(Copy, Clone, Debug, PartialEq, PartialOrd, Hash)]
pub struct NonNegative<T: Float>(T);
impl Serialize for NonNegative<f64> {
    /// Serializes the wrapped value as a plain `f64`, without any wrapper
    /// structure around it.
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let Self(value) = *self;
        serializer.serialize_f64(value)
    }
}
impl<'de> Deserialize<'de> for NonNegative<f64> {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let x = f64::deserialize(deserializer)?;
if x >= 0.0 {
Ok(Self(x))
} else {
Err(serde::de::Error::invalid_value(
serde::de::Unexpected::Float(x),
&"a non-negative value",
))
}
}
}
/// Hand-written schema for `NonNegative<f64>`, mirroring the check performed
/// by its `Deserialize` impl.
impl JsonSchema for NonNegative<f64> {
/// Short name used for this type in generated schemas.
fn schema_name() -> Cow<'static, str> {
Cow::Borrowed("NonNegativeF64")
}
/// Identifier for this type's schema, qualified with the module path.
fn schema_id() -> Cow<'static, str> {
Cow::Borrowed(concat!(module_path!(), "::", "NonNegative<f64>"))
}
/// Describes the value as a JSON number with an inclusive minimum of zero.
/// The generator is not needed because the schema refers to no other types.
fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
json_schema!({
"type": "number",
"minimum": 0.0
})
}
}