Skip to main content

provable_contracts/kernels/
ulp.rs

1//! ULP (Unit in the Last Place) distance utilities for floating-point comparison.
2//!
3//! Used to verify SIMD and PTX kernels produce results within acceptable
4//! tolerance of the scalar reference implementation.
5
/// Compute the ULP distance between two f32 values.
///
/// Returns how many representable-float steps separate `a` from `b`:
/// `0` when they compare equal, `1` when they are adjacent, and so on.
///
/// Special cases:
/// - If either value is NaN, returns `u32::MAX`.
/// - `+0.0` and `-0.0` compare equal, so a zero of either sign is measured
///   against the magnitude of the other operand; the sign of the zero is
///   ignored.
/// - If signs differ and neither value is zero, returns `u32::MAX`.
#[must_use]
pub fn ulp_distance(a: f32, b: f32) -> u32 {
    const SIGN_MASK: u32 = 0x8000_0000;

    if a.is_nan() || b.is_nan() {
        return u32::MAX;
    }
    // Also covers +0.0 vs -0.0, which compare equal.
    if a == b {
        return 0;
    }

    let (a_bits, b_bits) = (a.to_bits(), b.to_bits());
    let a_mag = a_bits & !SIGN_MASK;
    let b_mag = b_bits & !SIGN_MASK;
    let signs_differ = ((a_bits ^ b_bits) & SIGN_MASK) != 0;

    // Only reject a sign mismatch when both operands are non-zero: the sign
    // bit of a zero carries no ordering information, so -0.0 must behave
    // exactly like +0.0 here.
    if signs_differ && a_mag != 0 && b_mag != 0 {
        return u32::MAX;
    }

    // Same sign (or one operand is a zero): the distance is the difference
    // of the magnitude bit patterns.
    a_mag.abs_diff(b_mag)
}
28
29/// Assert that two f32 slices are equal within the given ULP tolerance.
30///
31/// # Panics
32///
33/// Panics if slices have different lengths or any element pair exceeds
34/// the ULP tolerance.
35pub fn assert_ulp_eq(a: &[f32], b: &[f32], max_ulp: u32) {
36    assert_eq!(
37        a.len(),
38        b.len(),
39        "slice length mismatch: {} vs {}",
40        a.len(),
41        b.len()
42    );
43    for (i, (&va, &vb)) in a.iter().zip(b.iter()).enumerate() {
44        let dist = ulp_distance(va, vb);
45        assert!(
46            dist <= max_ulp,
47            "ULP violation at index {i}: {va} vs {vb} (ULP distance {dist}, max {max_ulp})"
48        );
49    }
50}
51
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the float whose bit pattern is `steps` above that of `x`.
    fn step_up(x: f32, steps: u32) -> f32 {
        f32::from_bits(x.to_bits() + steps)
    }

    /// A value compared with itself is always zero ULPs away.
    #[test]
    fn test_ulp_distance_identical() {
        for v in [1.0_f32, 0.0, -1.0] {
            assert_eq!(ulp_distance(v, v), 0);
        }
    }

    /// Neighbouring bit patterns are exactly one ULP apart.
    #[test]
    fn test_ulp_distance_adjacent() {
        let base = 1.0_f32;
        assert_eq!(ulp_distance(base, step_up(base, 1)), 1);
    }

    /// NaN on either side (or both) yields the sentinel distance.
    #[test]
    fn test_ulp_distance_nan() {
        let nan = f32::NAN;
        for (x, y) in [(nan, 1.0), (1.0, nan), (nan, nan)] {
            assert_eq!(ulp_distance(x, y), u32::MAX);
        }
    }

    /// Non-zero values of opposite sign yield the sentinel distance.
    #[test]
    fn test_ulp_distance_sign_mismatch() {
        let (positive, negative) = (1.0_f32, -1.0_f32);
        assert_eq!(ulp_distance(positive, negative), u32::MAX);
    }

    /// A gap of ten bit patterns is reported as ten ULPs.
    #[test]
    fn test_ulp_distance_small_gap() {
        let base = 1.0_f32;
        assert_eq!(ulp_distance(base, step_up(base, 10)), 10);
    }

    /// Identical slices pass even with a zero tolerance.
    #[test]
    fn test_assert_ulp_eq_passes() {
        let expected = [1.0_f32, 2.0, 3.0];
        let actual = expected;
        assert_ulp_eq(&expected, &actual, 0);
    }

    /// Differing elements with a zero tolerance must panic.
    #[test]
    #[should_panic(expected = "ULP violation")]
    fn test_assert_ulp_eq_fails() {
        let lhs = [1.0_f32];
        let rhs = [2.0_f32];
        assert_ulp_eq(&lhs, &rhs, 0);
    }

    /// Slices of different lengths must panic before any comparison.
    #[test]
    #[should_panic(expected = "slice length mismatch")]
    fn test_assert_ulp_eq_length_mismatch() {
        let short = [1.0_f32];
        let long = [1.0_f32, 2.0];
        assert_ulp_eq(&short, &long, 0);
    }

    /// Positive and negative zero are treated as the same value.
    #[test]
    fn test_ulp_distance_negative_zero() {
        let (pos_zero, neg_zero) = (0.0_f32, -0.0_f32);
        assert_eq!(ulp_distance(pos_zero, neg_zero), 0);
    }
}