use std::borrow::Cow;
use std::hash::{Hash, Hasher};
use ndarray::{Array, ArrayBase, Data, Dimension};
use num_traits::Float;
use numcodecs::{
AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
};
use rand::{
SeedableRng,
distr::{Distribution, Open01},
};
use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use thiserror::Error;
use wyhash::{WyHash, WyRng};
// Codec configuration for stochastic rounding of floating point data.
//
// Encoding rounds `f32`/`f64` arrays to multiples of `precision` (see
// `stochastic_rounding`); decoding hands the data back unchanged.
//
// NOTE(review): plain `//` comments are used here on purpose. This type
// derives `JsonSchema`, and schemars turns `///` doc comments into schema
// descriptions, so switching to `///` would change the generated schema.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct StochasticRoundingCodec {
// Spacing of the values that the data is rounded to. A precision of zero
// leaves the data untouched. For `f32` data the value is cast to `f32`.
pub precision: NonNegative<f64>,
// Seed that is mixed with the data's shape and contents to seed the random
// number generator used for rounding.
pub seed: u64,
// Codec version, (de)serialized under the `_version` key and filled in with
// its default when the key is missing.
#[serde(default, rename = "_version")]
pub version: StaticCodecVersion<1, 0, 0>,
}
/// Encoding applies stochastic rounding; decoding is a plain copy.
impl Codec for StochasticRoundingCodec {
    type Error = StochasticRoundingCodecError;

    /// Stochastically rounds `f32` or `f64` data to the configured precision.
    ///
    /// # Errors
    ///
    /// Returns [`StochasticRoundingCodecError::UnsupportedDtype`] for any
    /// other dtype.
    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
        let rounded = match data {
            // The configured `f64` precision is deliberately narrowed to the
            // element type of the data.
            #[expect(clippy::cast_possible_truncation)]
            AnyCowArray::F32(values) => {
                let precision = NonNegative(self.precision.0 as f32);
                AnyArray::F32(stochastic_rounding(values, precision, self.seed))
            }
            AnyCowArray::F64(values) => {
                AnyArray::F64(stochastic_rounding(values, self.precision, self.seed))
            }
            unsupported => {
                return Err(StochasticRoundingCodecError::UnsupportedDtype(
                    unsupported.dtype(),
                ));
            }
        };

        Ok(rounded)
    }

    /// Returns an owned copy of the `f32` or `f64` `encoded` data.
    ///
    /// # Errors
    ///
    /// Returns [`StochasticRoundingCodecError::UnsupportedDtype`] for any
    /// other dtype.
    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
        let decoded = match encoded {
            AnyCowArray::F32(values) => AnyArray::F32(values.into_owned()),
            AnyCowArray::F64(values) => AnyArray::F64(values.into_owned()),
            unsupported => {
                return Err(StochasticRoundingCodecError::UnsupportedDtype(
                    unsupported.dtype(),
                ));
            }
        };

        Ok(decoded)
    }

    /// Copies the `encoded` data into the `decoded` output array.
    ///
    /// # Errors
    ///
    /// Returns [`StochasticRoundingCodecError::UnsupportedDtype`] if `encoded`
    /// is neither `f32` nor `f64`, and
    /// [`StochasticRoundingCodecError::MismatchedDecodeIntoArray`] if the
    /// assignment into `decoded` fails.
    fn decode_into(
        &self,
        encoded: AnyArrayView,
        mut decoded: AnyArrayViewMut,
    ) -> Result<(), Self::Error> {
        match encoded.dtype() {
            AnyArrayDType::F32 | AnyArrayDType::F64 => {
                decoded.assign(&encoded)?;
                Ok(())
            }
            dtype => Err(StochasticRoundingCodecError::UnsupportedDtype(dtype)),
        }
    }
}
/// The codec doubles as its own configuration type.
impl StaticCodec for StochasticRoundingCodec {
    const CODEC_ID: &'static str = "stochastic-rounding.rs";

    type Config<'de> = Self;

    /// Builds the codec from its configuration, which is the codec itself.
    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    /// Returns the configuration of this codec.
    fn get_config(&self) -> StaticCodecConfig<'_, Self> {
        StaticCodecConfig::<Self>::from(self)
    }
}
/// Wrapper around a floating point value that is meant to be non-negative.
///
/// The inner field is private. The `Deserialize` implementation for
/// `NonNegative<f64>` in this file only accepts values `x` with `x >= 0.0`,
/// which also excludes NaN.
// NOTE(review): `Eq` is deliberately not derived (the lint expectation below
// acknowledges this) -- presumably because floating point types are not `Eq`.
#[expect(clippy::derive_partial_eq_without_eq)] #[derive(Copy, Clone, PartialEq, PartialOrd, Hash)]
pub struct NonNegative<T: Float>(T);
/// Serializes the wrapper transparently as its inner `f64`.
impl Serialize for NonNegative<f64> {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let Self(value) = *self;
        serializer.serialize_f64(value)
    }
}
impl<'de> Deserialize<'de> for NonNegative<f64> {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let x = f64::deserialize(deserializer)?;
if x >= 0.0 {
Ok(Self(x))
} else {
Err(serde::de::Error::invalid_value(
serde::de::Unexpected::Float(x),
&"a non-negative value",
))
}
}
}
/// JSON schema for a non-negative number: `{"type": "number", "minimum": 0.0}`.
impl JsonSchema for NonNegative<f64> {
    /// Short, human-readable name of the schema.
    fn schema_name() -> Cow<'static, str> {
        Cow::from("NonNegativeF64")
    }

    /// Identifier of the schema, qualified with the module path.
    fn schema_id() -> Cow<'static, str> {
        Cow::from(concat!(module_path!(), "::", "NonNegative<f64>"))
    }

    /// Builds the schema; the generator is not needed for this leaf type.
    fn json_schema(_generator: &mut SchemaGenerator) -> Schema {
        json_schema!({
            "type": "number",
            "minimum": 0.0
        })
    }
}
/// Errors that may occur when applying the [`StochasticRoundingCodec`].
#[derive(Debug, Error)]
pub enum StochasticRoundingCodecError {
/// The codec was given data of a dtype other than `f32` or `f64`; the
/// payload is the offending dtype.
#[error("StochasticRounding does not support the dtype {0}")]
UnsupportedDtype(AnyArrayDType),
/// Assigning the encoded data into the output array of `decode_into`
/// failed.
#[error("StochasticRounding cannot decode into the provided array")]
MismatchedDecodeIntoArray {
/// The underlying assignment error; `#[from]` lets `?` convert it into
/// this variant.
#[from]
source: AnyArrayAssignError,
},
}
/// Stochastically rounds every finite element of `data` to a neighbouring
/// multiple of `precision`.
///
/// For a finite value `x`, the two candidates are the multiples of
/// `precision` directly below (`lower`) and above (`upper`) `x`. `upper` is
/// picked with probability `remainder / precision`, where `remainder` is the
/// Euclidean remainder of `x` with respect to `precision`; otherwise `lower`
/// is picked. A value that already is a multiple has a zero remainder and is
/// therefore always rounded to `lower`.
///
/// Special cases:
/// - Non-finite elements (NaN and the infinities) are passed through
///   unchanged.
/// - A `precision` of zero returns the data unchanged.
/// - An infinite `precision` rounds every finite element to `0.0`.
///
/// The random number generator is seeded from a hash over `seed`, the shape
/// of `data`, and the bit patterns of all elements in iteration order, so the
/// same input and seed always produce the same output.
#[must_use]
pub fn stochastic_rounding<T: FloatExt, S: Data<Elem = T>, D: Dimension>(
    data: ArrayBase<S, D>,
    precision: NonNegative<T>,
    seed: u64,
) -> Array<T, D>
where
    Open01: Distribution<T>,
{
    let mut encoded = data.into_owned();

    // A zero precision requests no rounding at all.
    if precision.0.is_zero() {
        return encoded;
    }

    // With an infinite precision, zero is the only finite multiple, and the
    // general formula below breaks down for negative values: their Euclidean
    // remainder is infinite, so `precision - remainder` and
    // `remainder / precision` both evaluate to NaN and the value would be
    // rounded to NaN. Handle this case explicitly instead.
    if precision.0.is_infinite() {
        for x in &mut encoded {
            if x.is_finite() {
                *x = T::zero();
            }
        }
        return encoded;
    }

    // Derive the generator seed from the user seed and the data itself.
    let mut hasher = WyHash::with_seed(seed);
    // Hashing the shape first acts as a prefix for the flattened elements.
    encoded.shape().hash(&mut hasher);
    encoded
        .iter()
        .copied()
        .for_each(|x| x.hash_bits(&mut hasher));
    let seed = hasher.finish();

    let mut rng: WyRng = WyRng::seed_from_u64(seed);

    for x in &mut encoded {
        // NaN and infinities cannot be rounded and are kept as-is.
        if !x.is_finite() {
            continue;
        }

        let remainder = x.rem_euclid(precision.0);

        // Compute the neighbouring multiples from the remainder. Floating
        // point rounding in these computations may push a candidate slightly
        // more than `precision` away from `x`; in that case it is nudged one
        // representable step back towards `x`.
        let mut lower = *x - remainder;
        if (*x - lower) > precision.0 {
            lower = lower.next_up();
        }
        let mut upper = *x + (precision.0 - remainder);
        if (upper - *x) > precision.0 {
            upper = upper.next_down();
        }

        // Fraction of the way from `lower` to `upper` at which `x` lies.
        let threshold = remainder / precision.0;

        let u01: T = Open01.sample(&mut rng);

        // The closer `x` is to `upper`, the larger the threshold and thus the
        // more likely the sample falls below it and `upper` is chosen.
        *x = if u01 >= threshold { lower } else { upper };
    }

    encoded
}
/// Floating point operations used by [`stochastic_rounding`] that are not
/// part of the [`Float`] trait.
pub trait FloatExt: Float {
/// The constant `-0.5` (as set by the `f32` and `f64` implementations in
/// this file).
const NEG_HALF: Self;
/// Feeds the value into `hasher`; the implementations in this file write
/// the raw bit pattern obtained from `to_bits`.
fn hash_bits<H: Hasher>(self, hasher: &mut H);
/// Euclidean remainder of `self` with respect to `rhs`; the
/// implementations in this file forward to the primitive type's
/// `rem_euclid`.
#[must_use]
fn rem_euclid(self, rhs: Self) -> Self;
/// Neighbouring value above `self`; the implementations in this file
/// forward to the primitive type's `next_up`.
#[must_use]
fn next_up(self) -> Self;
/// Neighbouring value below `self`; the implementations in this file
/// forward to the primitive type's `next_down`.
#[must_use]
fn next_down(self) -> Self;
}
/// [`FloatExt`] for `f32`, forwarding to the primitive's own methods.
impl FloatExt for f32 {
    const NEG_HALF: Self = -0.5;

    /// Writes the 32-bit pattern of `self` into `hasher`.
    fn hash_bits<H: Hasher>(self, hasher: &mut H) {
        let bits = self.to_bits();
        hasher.write_u32(bits);
    }

    // The explicit `f32::` paths below name the inherent methods of the
    // primitive type rather than the trait methods being defined here.

    fn rem_euclid(self, rhs: Self) -> Self {
        f32::rem_euclid(self, rhs)
    }

    fn next_up(self) -> Self {
        f32::next_up(self)
    }

    fn next_down(self) -> Self {
        f32::next_down(self)
    }
}
/// [`FloatExt`] for `f64`, forwarding to the primitive's own methods.
impl FloatExt for f64 {
    const NEG_HALF: Self = -0.5;

    /// Writes the 64-bit pattern of `self` into `hasher`.
    fn hash_bits<H: Hasher>(self, hasher: &mut H) {
        let bits = self.to_bits();
        hasher.write_u64(bits);
    }

    // The explicit `f64::` paths below name the inherent methods of the
    // primitive type rather than the trait methods being defined here.

    fn rem_euclid(self, rhs: Self) -> Self {
        f64::rem_euclid(self, rhs)
    }

    fn next_up(self) -> Self {
        f64::next_up(self)
    }

    fn next_down(self) -> Self {
        f64::next_down(self)
    }
}
// Unit tests for `stochastic_rounding`. All tests use the fixed seed 42.
#[cfg(test)]
mod tests {
use ndarray::{array, linspace};
use super::*;
// A precision of zero must return the data unchanged.
#[test]
fn round_zero_precision() {
let data = array![1.1, 2.1];
let rounded = stochastic_rounding(data.view(), NonNegative(0.0), 42);
assert_eq!(data, rounded);
}
// With an infinite precision, these positive inputs are rounded to zero.
#[test]
fn round_infinite_precision() {
let data = array![1.1, 2.1];
let rounded = stochastic_rounding(data.view(), NonNegative(f64::INFINITY), 42);
assert_eq!(rounded, array![0.0, 0.0]);
}
// With the smallest positive normal `f64` as precision, the data must come
// back unchanged, even though dividing by the precision overflows to
// infinity (which the first assertion demonstrates).
#[test]
fn round_minimal_precision() {
let data = array![0.1, 1.0, 11.0, 21.0];
assert_eq!(11.0 / f64::MIN_POSITIVE, f64::INFINITY);
let rounded = stochastic_rounding(data.view(), NonNegative(f64::MIN_POSITIVE), 42);
assert_eq!(data, rounded);
}
// NaNs, infinities, signed zeros, and ordinary values: every result must
// either lie within `precision` of its input or be bit-identical to it
// (the latter covers the non-finite inputs, which are passed through).
#[test]
fn round_edge_cases() {
let data = array![
-f64::NAN,
-f64::INFINITY,
-42.0,
-4.2,
-0.0,
0.0,
4.2,
42.0,
f64::INFINITY,
f64::NAN
];
let precision = 1.0;
let rounded = stochastic_rounding(data.view(), NonNegative(precision), 42);
for (d, r) in data.into_iter().zip(rounded) {
assert!((r - d).abs() <= precision || d.to_bits() == r.to_bits());
}
}
// 3741 evenly spaced values in [-100, 100] with a precision of 0.1: no
// result may be further than `precision` away from its input.
#[test]
fn round_rounding_errors() {
let data = Array::from_iter(linspace(-100.0, 100.0, 3741));
let precision = 0.1;
let rounded = stochastic_rounding(data.view(), NonNegative(precision), 42);
for (d, r) in data.into_iter().zip(rounded) {
assert!((r - d).abs() <= precision);
}
}
// `f32` values around +-1.2353 with a precision of 0.00018: no result may be
// further than `precision` away from its input.
// NOTE(review): judging by the name, this is presumably a regression test
// for a previously observed rounding error -- confirm against the history.
#[test]
fn test_rounding_bug() {
let data = array![
-1.23540_f32,
-1.23539_f32,
-1.23538_f32,
-1.23537_f32,
-1.23536_f32,
-1.23535_f32,
-1.23534_f32,
-1.23533_f32,
-1.23532_f32,
-1.23531_f32,
-1.23530_f32,
1.23540_f32,
1.23539_f32,
1.23538_f32,
1.23537_f32,
1.23536_f32,
1.23535_f32,
1.23534_f32,
1.23533_f32,
1.23532_f32,
1.23531_f32,
1.23530_f32,
];
let precision = 0.00018_f32;
let rounded = stochastic_rounding(data.view(), NonNegative(precision), 42);
for (d, r) in data.into_iter().zip(rounded) {
assert!((r - d).abs() <= precision);
}
}
}