Skip to main content

diskann_benchmark_runner/utils/
num.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6//! Number utilities for enforcing deserialization constraints and computing relative errors.
7
8use serde::{Deserialize, Deserializer, Serialize, Serializer};
9use thiserror::Error;
10
11/// Compute the relative change from `before` to `after`.
12///
13/// This helper is intentionally opinionated for benchmark-style metrics:
14///
15/// - `before` must be finite and strictly positive.
16/// - `after` must be finite and non-negative.
17///
18/// In other words, this computes:
19/// ```text
20/// (after - before) / before
21/// ```
22///
23/// Negative values indicate improvements while positive values indicate regressions.
24pub fn relative_change(before: f64, after: f64) -> Result<f64, RelativeChangeError> {
25    if !before.is_finite() {
26        return Err(RelativeChangeError::NonFiniteBefore);
27    }
28    if before <= 0.0 {
29        return Err(RelativeChangeError::NonPositiveBefore);
30    }
31
32    let after = NonNegativeFinite::new(after).map_err(RelativeChangeError::InvalidAfter)?;
33    let after = after.get();
34
35    let change = (after - before) / before;
36    if !change.is_finite() {
37        return Err(RelativeChangeError::NonFiniteComputedChange);
38    }
39
40    Ok(change)
41}
42
43/// Error returned when attempting to compute a relative change.
44#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
45pub enum RelativeChangeError {
46    #[error("expected \"before\" to be a finite number")]
47    NonFiniteBefore,
48    #[error("expected \"before\" to be greater than zero")]
49    NonPositiveBefore,
50    #[error("invalid \"after\" value: {0}")]
51    InvalidAfter(InvalidNonNegativeFinite),
52    #[error("computed relative change is not finite")]
53    NonFiniteComputedChange,
54}
55
56/// A finite floating-point value that is greater than or equal to zero.
57#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
58pub struct NonNegativeFinite(f64);
59
60impl NonNegativeFinite {
61    /// Attempt to construct `Self` from `value`.
62    pub const fn new(value: f64) -> Result<Self, InvalidNonNegativeFinite> {
63        if !value.is_finite() {
64            Err(InvalidNonNegativeFinite::NonFinite)
65        } else if value < 0.0 {
66            Err(InvalidNonNegativeFinite::Negative)
67        } else if value == 0.0 {
68            Ok(Self(0.0))
69        } else {
70            Ok(Self(value))
71        }
72    }
73
74    /// Return the underlying floating-point value.
75    pub const fn get(self) -> f64 {
76        self.0
77    }
78}
79
80impl std::fmt::Display for NonNegativeFinite {
81    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82        write!(f, "{}", self.0)
83    }
84}
85
86impl TryFrom<f64> for NonNegativeFinite {
87    type Error = InvalidNonNegativeFinite;
88
89    fn try_from(value: f64) -> Result<Self, Self::Error> {
90        Self::new(value)
91    }
92}
93
94impl From<NonNegativeFinite> for f64 {
95    fn from(value: NonNegativeFinite) -> Self {
96        value.get()
97    }
98}
99
100impl Serialize for NonNegativeFinite {
101    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
102    where
103        S: Serializer,
104    {
105        serializer.serialize_f64(self.0)
106    }
107}
108
109impl<'de> Deserialize<'de> for NonNegativeFinite {
110    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
111    where
112        D: Deserializer<'de>,
113    {
114        let value = f64::deserialize(deserializer)?;
115        Self::new(value).map_err(serde::de::Error::custom)
116    }
117}
118
119/// Error returned when attempting to construct a [`NonNegativeFinite`].
120#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
121pub enum InvalidNonNegativeFinite {
122    #[error("expected a finite number")]
123    NonFinite,
124    #[error("expected a non-negative number")]
125    Negative,
126}
127
128///////////
129// Tests //
130///////////
131
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction through `new`, `TryFrom`, and `From`, plus `Display`.
    #[test]
    fn test_new() {
        let to_non_negative =
            |x: f64| -> Result<NonNegativeFinite, InvalidNonNegativeFinite> { x.try_into() };
        let to_f64 = |x: NonNegativeFinite| -> f64 { x.into() };

        assert_eq!(NonNegativeFinite::new(0.0).unwrap().get(), 0.0);
        assert_eq!(NonNegativeFinite::new(-0.0).unwrap().get(), 0.0);
        assert_eq!(NonNegativeFinite::new(0.25).unwrap().get(), 0.25);

        assert_eq!(to_f64(to_non_negative(0.0).unwrap()), 0.0);
        assert_eq!(to_f64(to_non_negative(-0.0).unwrap()), 0.0);
        assert_eq!(to_f64(to_non_negative(0.25).unwrap()), 0.25);

        // `-0.0 == 0.0` holds for either sign, so the equality checks above cannot detect
        // whether negative zero was actually normalised. Check the sign bit explicitly.
        assert!(NonNegativeFinite::new(-0.0)
            .unwrap()
            .get()
            .is_sign_positive());
        assert!(to_f64(to_non_negative(-0.0).unwrap()).is_sign_positive());

        assert_eq!(to_non_negative(0.25).unwrap().to_string(), 0.25.to_string());

        assert_eq!(
            NonNegativeFinite::new(-1.0).unwrap_err(),
            InvalidNonNegativeFinite::Negative
        );
        assert_eq!(
            to_non_negative(-1.0).unwrap_err(),
            InvalidNonNegativeFinite::Negative
        );

        assert_eq!(
            NonNegativeFinite::new(f64::INFINITY).unwrap_err(),
            InvalidNonNegativeFinite::NonFinite
        );
        assert_eq!(
            NonNegativeFinite::new(f64::NEG_INFINITY).unwrap_err(),
            InvalidNonNegativeFinite::NonFinite
        );
        assert_eq!(
            NonNegativeFinite::new(f64::NAN).unwrap_err(),
            InvalidNonNegativeFinite::NonFinite
        );
    }

    /// Round trip through JSON and rejection of invalid input during deserialization.
    #[test]
    fn test_serde() {
        let value: NonNegativeFinite = serde_json::from_str("0.1").unwrap();
        assert_eq!(value.get(), 0.1);

        let serialized = serde_json::to_string(&value).unwrap();
        assert_eq!(serialized, "0.1");

        let err = serde_json::from_str::<NonNegativeFinite>("-0.5").unwrap_err();
        assert!(err.to_string().contains("expected a non-negative number"));
    }

    /// Happy paths and every error variant of `relative_change`.
    #[test]
    fn test_relative_change() {
        assert_eq!(relative_change(10.0, 10.0).unwrap(), 0.0);
        assert_eq!(relative_change(10.0, 12.5).unwrap(), 0.25);
        assert_eq!(relative_change(10.0, 8.0).unwrap(), -0.2);
        assert_eq!(relative_change(10.0, -0.0).unwrap(), -1.0);

        assert_eq!(
            relative_change(0.0, 1.0).unwrap_err(),
            RelativeChangeError::NonPositiveBefore
        );
        assert_eq!(
            relative_change(-1.0, 1.0).unwrap_err(),
            RelativeChangeError::NonPositiveBefore
        );
        assert_eq!(
            relative_change(f64::NAN, 1.0).unwrap_err(),
            RelativeChangeError::NonFiniteBefore
        );
        assert_eq!(
            relative_change(f64::INFINITY, 1.0).unwrap_err(),
            RelativeChangeError::NonFiniteBefore
        );
        assert_eq!(
            relative_change(f64::NEG_INFINITY, 1.0).unwrap_err(),
            RelativeChangeError::NonFiniteBefore
        );
        assert_eq!(
            relative_change(1.0, -1.0).unwrap_err(),
            RelativeChangeError::InvalidAfter(InvalidNonNegativeFinite::Negative)
        );
        assert_eq!(
            relative_change(1.0, f64::NAN).unwrap_err(),
            RelativeChangeError::InvalidAfter(InvalidNonNegativeFinite::NonFinite)
        );
        assert_eq!(
            relative_change(1.0, f64::INFINITY).unwrap_err(),
            RelativeChangeError::InvalidAfter(InvalidNonNegativeFinite::NonFinite)
        );
        assert_eq!(
            relative_change(f64::MIN_POSITIVE, f64::MAX).unwrap_err(),
            RelativeChangeError::NonFiniteComputedChange
        );

        // `before` is validated first, so its error wins when both inputs are invalid.
        assert_eq!(
            relative_change(f64::NAN, -1.0).unwrap_err(),
            RelativeChangeError::NonFiniteBefore
        );
        assert_eq!(
            relative_change(0.0, f64::NAN).unwrap_err(),
            RelativeChangeError::NonPositiveBefore
        );
    }

    /// The rendered messages are user-facing; pin their text.
    #[test]
    fn test_error_messages() {
        assert_eq!(
            InvalidNonNegativeFinite::NonFinite.to_string(),
            "expected a finite number"
        );
        assert_eq!(
            InvalidNonNegativeFinite::Negative.to_string(),
            "expected a non-negative number"
        );
        assert_eq!(
            RelativeChangeError::NonFiniteBefore.to_string(),
            "expected \"before\" to be a finite number"
        );
        assert_eq!(
            RelativeChangeError::NonPositiveBefore.to_string(),
            "expected \"before\" to be greater than zero"
        );
        assert_eq!(
            RelativeChangeError::InvalidAfter(InvalidNonNegativeFinite::Negative).to_string(),
            "invalid \"after\" value: expected a non-negative number"
        );
        assert_eq!(
            RelativeChangeError::NonFiniteComputedChange.to_string(),
            "computed relative change is not finite"
        );
    }
}
223}