numcodecs_round/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.85.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-round
10//! [crates.io]: https://crates.io/crates/numcodecs-round
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-round
13//! [docs.rs]: https://docs.rs/numcodecs-round/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_round
17//!
18//! Rounding codec implementation for the [`numcodecs`] API.
19
20use std::borrow::Cow;
21
22use ndarray::{Array, ArrayBase, Data, Dimension};
23use num_traits::Float;
24use numcodecs::{
25    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
26    Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
27};
28use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
29use serde::{Deserialize, Deserializer, Serialize, Serializer};
30use thiserror::Error;
31
32#[derive(Clone, Serialize, Deserialize, JsonSchema)]
33#[serde(deny_unknown_fields)]
34/// Codec that rounds the data on encoding and passes through the input
35/// unchanged during decoding.
36///
37/// The codec only supports floating point data.
38pub struct RoundCodec {
39    /// Precision of the rounding operation
40    pub precision: NonNegative<f64>,
41    /// The codec's encoding format version. Do not provide this parameter explicitly.
42    #[serde(default, rename = "_version")]
43    pub version: StaticCodecVersion<1, 0, 0>,
44}
45
46impl Codec for RoundCodec {
47    type Error = RoundCodecError;
48
49    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
50        match data {
51            #[expect(clippy::cast_possible_truncation)]
52            AnyCowArray::F32(data) => Ok(AnyArray::F32(round(
53                data,
54                NonNegative(self.precision.0 as f32),
55            ))),
56            AnyCowArray::F64(data) => Ok(AnyArray::F64(round(data, self.precision))),
57            encoded => Err(RoundCodecError::UnsupportedDtype(encoded.dtype())),
58        }
59    }
60
61    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
62        match encoded {
63            AnyCowArray::F32(encoded) => Ok(AnyArray::F32(encoded.into_owned())),
64            AnyCowArray::F64(encoded) => Ok(AnyArray::F64(encoded.into_owned())),
65            encoded => Err(RoundCodecError::UnsupportedDtype(encoded.dtype())),
66        }
67    }
68
69    fn decode_into(
70        &self,
71        encoded: AnyArrayView,
72        mut decoded: AnyArrayViewMut,
73    ) -> Result<(), Self::Error> {
74        if !matches!(encoded.dtype(), AnyArrayDType::F32 | AnyArrayDType::F64) {
75            return Err(RoundCodecError::UnsupportedDtype(encoded.dtype()));
76        }
77
78        Ok(decoded.assign(&encoded)?)
79    }
80}
81
82impl StaticCodec for RoundCodec {
83    const CODEC_ID: &'static str = "round.rs";
84
85    type Config<'de> = Self;
86
87    fn from_config(config: Self::Config<'_>) -> Self {
88        config
89    }
90
91    fn get_config(&self) -> StaticCodecConfig<Self> {
92        StaticCodecConfig::from(self)
93    }
94}
95
96#[expect(clippy::derive_partial_eq_without_eq)] // floats are not Eq
97#[derive(Copy, Clone, PartialEq, PartialOrd, Hash)]
98/// Non-negative floating point number
99pub struct NonNegative<T: Float>(T);
100
101impl Serialize for NonNegative<f64> {
102    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
103        serializer.serialize_f64(self.0)
104    }
105}
106
107impl<'de> Deserialize<'de> for NonNegative<f64> {
108    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
109        let x = f64::deserialize(deserializer)?;
110
111        if x >= 0.0 {
112            Ok(Self(x))
113        } else {
114            Err(serde::de::Error::invalid_value(
115                serde::de::Unexpected::Float(x),
116                &"a non-negative value",
117            ))
118        }
119    }
120}
121
122impl JsonSchema for NonNegative<f64> {
123    fn schema_name() -> Cow<'static, str> {
124        Cow::Borrowed("NonNegativeF64")
125    }
126
127    fn schema_id() -> Cow<'static, str> {
128        Cow::Borrowed(concat!(module_path!(), "::", "NonNegative<f64>"))
129    }
130
131    fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
132        json_schema!({
133            "type": "number",
134            "exclusiveMinimum": 0.0
135        })
136    }
137}
138
139#[derive(Debug, Error)]
140/// Errors that may occur when applying the [`RoundCodec`].
141pub enum RoundCodecError {
142    /// [`RoundCodec`] does not support the dtype
143    #[error("Round does not support the dtype {0}")]
144    UnsupportedDtype(AnyArrayDType),
145    /// [`RoundCodec`] cannot decode into the provided array
146    #[error("Round cannot decode into the provided array")]
147    MismatchedDecodeIntoArray {
148        /// The source of the error
149        #[from]
150        source: AnyArrayAssignError,
151    },
152}
153
154#[must_use]
155/// Rounds the input `data` using
156/// `$c = \text{round}\left( \frac{x}{precision} \right) \cdot precision$`
157///
158/// If precision is zero, the `data` is returned unchanged.
159pub fn round<T: Float, S: Data<Elem = T>, D: Dimension>(
160    data: ArrayBase<S, D>,
161    precision: NonNegative<T>,
162) -> Array<T, D> {
163    let mut encoded = data.into_owned();
164
165    if precision.0.is_zero() {
166        return encoded;
167    }
168
169    encoded.mapv_inplace(|x| {
170        let n = x / precision.0;
171
172        // if x / precision is not finite, don't try to round
173        //  e.g. when x / eps = inf
174        if !n.is_finite() {
175            return x;
176        }
177
178        // round x to be a multiple of precision
179        n.round() * precision.0
180    });
181
182    encoded
183}
184
185#[cfg(test)]
186mod tests {
187    use ndarray::array;
188
189    use super::*;
190
191    #[test]
192    fn round_zero_precision() {
193        let data = array![1.1, 2.1];
194
195        let rounded = round(data.view(), NonNegative(0.0));
196
197        assert_eq!(data, rounded);
198    }
199
200    #[test]
201    fn round_minimal_precision() {
202        let data = array![0.1, 1.0, 11.0, 21.0];
203
204        assert_eq!(11.0 / f64::MIN_POSITIVE, f64::INFINITY);
205        let rounded = round(data.view(), NonNegative(f64::MIN_POSITIVE));
206
207        assert_eq!(data, rounded);
208    }
209
210    #[test]
211    fn round_roundoff_errors() {
212        let data = array![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0];
213
214        let rounded = round(data.view(), NonNegative(0.1));
215
216        assert_eq!(
217            rounded,
218            array![
219                0.0,
220                0.1,
221                0.2,
222                0.30000000000000004,
223                0.4,
224                0.5,
225                0.6000000000000001,
226                0.7000000000000001,
227                0.8,
228                0.9,
229                1.0
230            ]
231        );
232
233        let rounded_twice = round(rounded.view(), NonNegative(0.1));
234
235        assert_eq!(rounded, rounded_twice);
236    }
237
238    #[test]
239    fn round_edge_cases() {
240        let data = array![
241            -f64::NAN,
242            -f64::INFINITY,
243            -42.0,
244            -0.0,
245            0.0,
246            42.0,
247            f64::INFINITY,
248            f64::NAN
249        ];
250
251        let rounded = round(data.view(), NonNegative(1.0));
252
253        for (d, r) in data.into_iter().zip(rounded) {
254            assert!(d == r || d.to_bits() == r.to_bits());
255        }
256    }
257}