use std::borrow::Cow;
use ndarray::{Array, ArrayBase, Data, Dimension};
use num_traits::Float;
use numcodecs::{
AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
};
use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use thiserror::Error;
/// Codec that rounds floating-point data to the nearest multiple of a
/// configurable precision.
///
/// Encoding computes `round(x / precision) * precision` element-wise and is
/// therefore lossy; decoding returns the encoded data unchanged. Only `f32`
/// and `f64` arrays are supported.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct RoundCodec {
    /// Precision of the rounding: data is rounded to the nearest multiple of
    /// this non-negative value. A precision of zero disables rounding.
    pub precision: NonNegative<f64>,
    /// Encoding format version of the codec; optional in configurations.
    #[serde(default, rename = "_version")]
    pub version: StaticCodecVersion<1, 0, 0>,
}
impl Codec for RoundCodec {
    type Error = RoundCodecError;

    /// Rounds floating-point `data` to the configured precision.
    ///
    /// `f32` arrays are rounded with the precision narrowed to `f32`, while
    /// `f64` arrays use the configured precision directly.
    ///
    /// # Errors
    ///
    /// Returns [`RoundCodecError::UnsupportedDtype`] for any dtype other than
    /// `f32` or `f64`.
    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
        let encoded = match data {
            AnyCowArray::F32(values) => {
                // the precision is configured as f64; narrowing it to f32 is
                // intended (an `as` cast yields infinity on overflow)
                #[expect(clippy::cast_possible_truncation)]
                let precision = NonNegative(self.precision.0 as f32);
                AnyArray::F32(round(values, precision))
            }
            AnyCowArray::F64(values) => AnyArray::F64(round(values, self.precision)),
            other => return Err(RoundCodecError::UnsupportedDtype(other.dtype())),
        };
        Ok(encoded)
    }

    /// Returns the `encoded` data as an owned array.
    ///
    /// Rounding maps many inputs onto the same output, so it cannot be
    /// undone and decoding is the identity.
    ///
    /// # Errors
    ///
    /// Returns [`RoundCodecError::UnsupportedDtype`] for any dtype other than
    /// `f32` or `f64`.
    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
        let decoded = match encoded {
            AnyCowArray::F32(values) => AnyArray::F32(values.into_owned()),
            AnyCowArray::F64(values) => AnyArray::F64(values.into_owned()),
            other => return Err(RoundCodecError::UnsupportedDtype(other.dtype())),
        };
        Ok(decoded)
    }

    /// Copies the `encoded` data into the pre-allocated `decoded` array.
    ///
    /// # Errors
    ///
    /// Returns [`RoundCodecError::UnsupportedDtype`] if `encoded` is neither
    /// `f32` nor `f64`, and [`RoundCodecError::MismatchedDecodeIntoArray`] if
    /// assigning into `decoded` fails.
    fn decode_into(
        &self,
        encoded: AnyArrayView,
        mut decoded: AnyArrayViewMut,
    ) -> Result<(), Self::Error> {
        let dtype = encoded.dtype();
        match dtype {
            AnyArrayDType::F32 | AnyArrayDType::F64 => {
                decoded.assign(&encoded)?;
                Ok(())
            }
            _ => Err(RoundCodecError::UnsupportedDtype(dtype)),
        }
    }
}
impl StaticCodec for RoundCodec {
    /// Identifier of this codec, as required by [`StaticCodec`].
    const CODEC_ID: &'static str = "round.rs";

    // The codec struct doubles as its own (de)serializable configuration.
    type Config<'de> = Self;

    fn from_config(config: Self::Config<'_>) -> Self {
        // `Config<'_>` is `Self`, so the configuration already is the codec
        config
    }

    fn get_config(&self) -> StaticCodecConfig<'_, Self> {
        StaticCodecConfig::<Self>::from(self)
    }
}
/// Floating-point wrapper for values that are non-negative (`>= 0`).
///
/// Outside this module, a value can only be obtained through deserialization,
/// which rejects negative numbers and NaN.
// The lint suggests also deriving `Eq`, which the float types used here
// (`f32`, `f64`) do not implement because NaN != NaN.
#[expect(clippy::derive_partial_eq_without_eq)]
#[derive(Copy, Clone, PartialEq, PartialOrd, Hash)]
pub struct NonNegative<T: Float>(T);
impl Serialize for NonNegative<f64> {
    /// Serializes the wrapped value as a plain `f64`.
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let Self(value) = *self;
        serializer.serialize_f64(value)
    }
}
impl<'de> Deserialize<'de> for NonNegative<f64> {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let x = f64::deserialize(deserializer)?;
if x >= 0.0 {
Ok(Self(x))
} else {
Err(serde::de::Error::invalid_value(
serde::de::Unexpected::Float(x),
&"a non-negative value",
))
}
}
}
impl JsonSchema for NonNegative<f64> {
    fn schema_name() -> Cow<'static, str> {
        "NonNegativeF64".into()
    }

    fn schema_id() -> Cow<'static, str> {
        // qualified with the module path so the id is unique across crates
        concat!(module_path!(), "::NonNegative<f64>").into()
    }

    fn json_schema(_generator: &mut SchemaGenerator) -> Schema {
        // a JSON number with an inclusive lower bound of zero
        json_schema!({
            "type": "number",
            "minimum": 0.0
        })
    }
}
/// Errors that may occur when applying the [`RoundCodec`].
#[derive(Debug, Error)]
pub enum RoundCodecError {
    /// [`RoundCodec`] was given an array whose dtype is neither `f32` nor
    /// `f64`.
    #[error("Round does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// [`RoundCodec`] could not assign the decoded data to the provided
    /// output array.
    #[error("Round cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The underlying assignment error.
        #[from]
        source: AnyArrayAssignError,
    },
}
/// Rounds every element of `data` to the nearest multiple of `precision`,
/// computing `round(x / precision) * precision` element-wise.
///
/// Half-way cases are rounded away from zero (see [`Float::round`]).
///
/// Special cases:
/// - a `precision` of zero returns `data` unchanged;
/// - elements for which `x / precision` is not finite are returned unchanged.
///   This covers NaN and infinite elements, and quotients that overflow for a
///   tiny precision;
/// - elements that round to a zero multiple become a zero with the sign of
///   `x`. As a result, an infinite `precision` maps every finite element to
///   zero instead of producing `0 * inf = NaN`;
/// - elements whose rounded multiple would overflow to infinity are returned
///   unchanged.
#[must_use]
pub fn round<T: Float, S: Data<Elem = T>, D: Dimension>(
    data: ArrayBase<S, D>,
    precision: NonNegative<T>,
) -> Array<T, D> {
    let mut encoded = data.into_owned();

    // a zero precision disables rounding
    if precision.0.is_zero() {
        return encoded;
    }

    encoded.mapv_inplace(|x| {
        let n = x / precision.0;

        // don't round if x / precision is not finite, e.g. when x is NaN or
        // infinite, or when the division overflows for a tiny precision
        if !n.is_finite() {
            return x;
        }

        let multiple = n.round();

        // for a positive precision, a zero multiple always yields a zero
        // carrying the sign of `multiple` (and thus of x). Returning it
        // directly avoids computing 0 * inf = NaN for an infinite precision.
        if multiple.is_zero() {
            return multiple;
        }

        let rounded = multiple * precision.0;

        // a finite non-zero multiple times the precision can only be
        // non-finite if the product overflowed; keep x as-is in that case
        if rounded.is_finite() { rounded } else { x }
    });

    encoded
}

#[cfg(test)]
mod tests {
    use ndarray::array;

    use super::*;

    #[test]
    fn round_zero_precision() {
        let data = array![1.1, 2.1];
        let rounded = round(data.view(), NonNegative(0.0));
        assert_eq!(data, rounded);
    }

    #[test]
    fn round_minimal_precision() {
        let data = array![0.1, 1.0, 11.0, 21.0];
        assert_eq!(11.0 / f64::MIN_POSITIVE, f64::INFINITY);
        let rounded = round(data.view(), NonNegative(f64::MIN_POSITIVE));
        assert_eq!(data, rounded);
    }

    #[test]
    fn round_roundoff_errors() {
        let data = array![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0];
        let rounded = round(data.view(), NonNegative(0.1));
        assert_eq!(
            rounded,
            array![
                0.0,
                0.1,
                0.2,
                0.30000000000000004,
                0.4,
                0.5,
                0.6000000000000001,
                0.7000000000000001,
                0.8,
                0.9,
                1.0
            ]
        );
        let rounded_twice = round(rounded.view(), NonNegative(0.1));
        assert_eq!(rounded, rounded_twice);
    }

    #[test]
    fn round_edge_cases() {
        let data = array![
            -f64::NAN,
            -f64::INFINITY,
            -42.0,
            -0.0,
            0.0,
            42.0,
            f64::INFINITY,
            f64::NAN
        ];
        let rounded = round(data.view(), NonNegative(1.0));
        for (d, r) in data.into_iter().zip(rounded) {
            assert!(d == r || d.to_bits() == r.to_bits());
        }
    }

    #[test]
    fn round_infinite_precision() {
        let data = array![
            -f64::NAN,
            -f64::INFINITY,
            -42.5,
            -1.0,
            -0.0,
            0.0,
            1.0,
            42.5,
            f64::INFINITY,
            f64::NAN
        ];
        let rounded = round(data.view(), NonNegative(f64::INFINITY));
        for (d, r) in data.iter().zip(rounded.iter()) {
            if d.is_finite() {
                // every finite value rounds to a zero that keeps its sign
                assert_eq!(*r, 0.0);
                assert_eq!(r.is_sign_negative(), d.is_sign_negative());
            } else {
                // NaN and infinite values are passed through unchanged
                assert_eq!(d.to_bits(), r.to_bits());
            }
        }
    }

    #[test]
    fn round_infinite_precision_f32() {
        // an f64 precision beyond the f32 range becomes infinite when the
        // codec narrows it to f32
        let data = array![-3.5_f32, 0.25, 1e30];
        let rounded = round(data.view(), NonNegative(f32::INFINITY));
        assert_eq!(rounded, array![-0.0_f32, 0.0, 0.0]);
        assert!(rounded.iter().all(|r| !r.is_nan()));
    }

    #[test]
    fn round_overflow() {
        // multiplying round(f64::MAX / 1.5) back by 1.5 overflows to infinity
        let data = array![-f64::MAX, f64::MAX];
        let rounded = round(data.view(), NonNegative(1.5));
        assert_eq!(data, rounded);
    }
}