1use std::borrow::Cow;
21
22use ndarray::{Array, ArrayBase, Data, Dimension};
23use num_traits::Float;
24use numcodecs::{
25 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
26 Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
27};
28use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
29use serde::{Deserialize, Deserializer, Serialize, Serializer};
30use thiserror::Error;
31
/// Codec that rounds floating-point data to a multiple of `precision` on
/// encoding and passes the data through unchanged on decoding.
///
/// Only `f32` and `f64` data are supported; other dtypes are rejected.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct RoundCodec {
    /// Precision of the rounding operation: values are rounded to the nearest
    /// integer multiple of it. A precision of zero disables rounding.
    pub precision: NonNegative<f64>,
    /// The codec's encoding format version. Serialized as `_version` and
    /// filled in with its default value when omitted from the configuration.
    #[serde(default, rename = "_version")]
    pub version: StaticCodecVersion<1, 0, 0>,
}
45
46impl Codec for RoundCodec {
47 type Error = RoundCodecError;
48
49 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
50 match data {
51 #[expect(clippy::cast_possible_truncation)]
52 AnyCowArray::F32(data) => Ok(AnyArray::F32(round(
53 data,
54 NonNegative(self.precision.0 as f32),
55 ))),
56 AnyCowArray::F64(data) => Ok(AnyArray::F64(round(data, self.precision))),
57 encoded => Err(RoundCodecError::UnsupportedDtype(encoded.dtype())),
58 }
59 }
60
61 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
62 match encoded {
63 AnyCowArray::F32(encoded) => Ok(AnyArray::F32(encoded.into_owned())),
64 AnyCowArray::F64(encoded) => Ok(AnyArray::F64(encoded.into_owned())),
65 encoded => Err(RoundCodecError::UnsupportedDtype(encoded.dtype())),
66 }
67 }
68
69 fn decode_into(
70 &self,
71 encoded: AnyArrayView,
72 mut decoded: AnyArrayViewMut,
73 ) -> Result<(), Self::Error> {
74 if !matches!(encoded.dtype(), AnyArrayDType::F32 | AnyArrayDType::F64) {
75 return Err(RoundCodecError::UnsupportedDtype(encoded.dtype()));
76 }
77
78 Ok(decoded.assign(&encoded)?)
79 }
80}
81
82impl StaticCodec for RoundCodec {
83 const CODEC_ID: &'static str = "round.rs";
84
85 type Config<'de> = Self;
86
87 fn from_config(config: Self::Config<'_>) -> Self {
88 config
89 }
90
91 fn get_config(&self) -> StaticCodecConfig<Self> {
92 StaticCodecConfig::from(self)
93 }
94}
95
/// Floating-point wrapper for non-negative values.
///
/// The inner field is private, so outside this module a value can only be
/// obtained through deserialization, which rejects negative numbers and NaN.
// Float types do not implement `Eq`, so only `PartialEq` can be derived.
#[expect(clippy::derive_partial_eq_without_eq)]
#[derive(Copy, Clone, PartialEq, PartialOrd, Hash)]
pub struct NonNegative<T: Float>(T);
100
101impl Serialize for NonNegative<f64> {
102 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
103 serializer.serialize_f64(self.0)
104 }
105}
106
107impl<'de> Deserialize<'de> for NonNegative<f64> {
108 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
109 let x = f64::deserialize(deserializer)?;
110
111 if x >= 0.0 {
112 Ok(Self(x))
113 } else {
114 Err(serde::de::Error::invalid_value(
115 serde::de::Unexpected::Float(x),
116 &"a non-negative value",
117 ))
118 }
119 }
120}
121
122impl JsonSchema for NonNegative<f64> {
123 fn schema_name() -> Cow<'static, str> {
124 Cow::Borrowed("NonNegativeF64")
125 }
126
127 fn schema_id() -> Cow<'static, str> {
128 Cow::Borrowed(concat!(module_path!(), "::", "NonNegative<f64>"))
129 }
130
131 fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
132 json_schema!({
133 "type": "number",
134 "exclusiveMinimum": 0.0
135 })
136 }
137}
138
/// Errors that may occur when applying the [`RoundCodec`].
#[derive(Debug, Error)]
pub enum RoundCodecError {
    /// The array's dtype is not supported (only `f32` and `f64` are).
    #[error("Round does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// Copying the encoded data into the caller-provided output array failed.
    #[error("Round cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The underlying assignment error; `#[from]` lets `?` convert it.
        #[from]
        source: AnyArrayAssignError,
    },
}
153
154#[must_use]
155pub fn round<T: Float, S: Data<Elem = T>, D: Dimension>(
160 data: ArrayBase<S, D>,
161 precision: NonNegative<T>,
162) -> Array<T, D> {
163 let mut encoded = data.into_owned();
164
165 if precision.0.is_zero() {
166 return encoded;
167 }
168
169 encoded.mapv_inplace(|x| {
170 let n = x / precision.0;
171
172 if !n.is_finite() {
175 return x;
176 }
177
178 n.round() * precision.0
180 });
181
182 encoded
183}
184
#[cfg(test)]
mod tests {
    use ndarray::array;

    use super::*;

    /// A precision of zero must leave the data untouched.
    #[test]
    fn round_zero_precision() {
        let original = array![1.1, 2.1];

        let output = round(original.view(), NonNegative(0.0));

        assert_eq!(output, original);
    }

    /// Rounding to the smallest positive normal `f64` precision must not
    /// alter these values.
    #[test]
    fn round_minimal_precision() {
        let original = array![0.1, 1.0, 11.0, 21.0];

        // Sanity check: the quotient overflows for the larger values, which
        // exercises the non-finite pass-through path.
        assert_eq!(11.0 / f64::MIN_POSITIVE, f64::INFINITY);

        let output = round(original.view(), NonNegative(f64::MIN_POSITIVE));

        assert_eq!(output, original);
    }

    /// Rounding to 0.1 exposes floating-point round-off in the resulting
    /// multiples, but rounding the result again must not change it.
    #[test]
    fn round_roundoff_errors() {
        let step = NonNegative(0.1);
        let original = array![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0];
        let expected = array![
            0.0,
            0.1,
            0.2,
            0.30000000000000004,
            0.4,
            0.5,
            0.6000000000000001,
            0.7000000000000001,
            0.8,
            0.9,
            1.0
        ];

        let once = round(original.view(), step);
        assert_eq!(once, expected);

        let twice = round(once.view(), step);
        assert_eq!(once, twice);
    }

    /// NaNs, infinities and zeros must survive rounding unchanged. NaNs are
    /// compared by bit pattern since `NaN != NaN`.
    #[test]
    fn round_edge_cases() {
        let original = array![
            -f64::NAN,
            -f64::INFINITY,
            -42.0,
            -0.0,
            0.0,
            42.0,
            f64::INFINITY,
            f64::NAN
        ];

        let output = round(original.view(), NonNegative(1.0));

        for (before, after) in original.iter().zip(output.iter()) {
            assert!(before == after || before.to_bits() == after.to_bits());
        }
    }
}
257}