datafusion_common/
rounding.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Floating point rounding mode utility library
19//! TODO: Remove this custom implementation and the "libc" dependency when
20//!       floating-point rounding mode manipulation functions become available
21//!       in Rust.
22
23use std::ops::{Add, BitAnd, Sub};
24
25use crate::Result;
26use crate::ScalarValue;
27
28// Define constants for ARM
29#[cfg(all(target_arch = "aarch64", not(target_os = "windows")))]
30const FE_UPWARD: i32 = 0x00400000;
31#[cfg(all(target_arch = "aarch64", not(target_os = "windows")))]
32const FE_DOWNWARD: i32 = 0x00800000;
33
34// Define constants for x86_64
35#[cfg(all(target_arch = "x86_64", not(target_os = "windows")))]
36const FE_UPWARD: i32 = 0x0800;
37#[cfg(all(target_arch = "x86_64", not(target_os = "windows")))]
38const FE_DOWNWARD: i32 = 0x0400;
39
40#[cfg(all(
41    any(target_arch = "x86_64", target_arch = "aarch64"),
42    not(target_os = "windows")
43))]
44extern crate libc;
45
46#[cfg(all(
47    any(target_arch = "x86_64", target_arch = "aarch64"),
48    not(target_os = "windows")
49))]
50extern "C" {
51    fn fesetround(round: i32);
52    fn fegetround() -> i32;
53}
54
/// Bit-level access to a floating-point type.
///
/// Implementors expose the raw integer representation of a float together
/// with the handful of bit patterns and special values that `next_up` and
/// `next_down` need in order to step between adjacent representable values.
pub trait FloatBits {
    /// Integer type that carries the raw bit pattern; it must support the
    /// comparison, masking and +/- 1 stepping used by the helpers in this
    /// module.
    type Item: Copy
        + PartialEq
        + Add<Output = Self::Item>
        + Sub<Output = Self::Item>
        + BitAnd<Output = Self::Item>;

    /// The integer value 0. `next_up`/`next_down` compare raw bits against
    /// it to recognise positive zero.
    const ZERO: Self::Item;

    /// Bit pattern of negative zero; `next_down` returns it for an input
    /// whose bits equal [`Self::ZERO`].
    const NEG_ZERO: Self::Item;

    /// The integer value 1, the step between adjacent bit patterns.
    const ONE: Self::Item;

    /// Bit pattern of the smallest positive value representable by this type.
    const TINY_BITS: Self::Item;

    /// Bit pattern of the smallest (in magnitude) negative value
    /// representable by this type.
    const NEG_TINY_BITS: Self::Item;

    /// Mask that clears the sign bit of a bit pattern, leaving the magnitude.
    const CLEAR_SIGN_MASK: Self::Item;

    /// Returns the bitwise representation of the floating-point value.
    fn to_bits(self) -> Self::Item;

    /// Builds the floating-point value that has the given bitwise
    /// representation.
    fn from_bits(bits: Self::Item) -> Self;

    /// Returns `true` when the value is NaN (not a number).
    fn float_is_nan(self) -> bool;

    /// Returns positive infinity for this floating-point type.
    fn infinity() -> Self;

    /// Returns negative infinity for this floating-point type.
    fn neg_infinity() -> Self;
}
97
98impl FloatBits for f32 {
99    type Item = u32;
100    const TINY_BITS: u32 = 0x1; // Smallest positive f32.
101    const NEG_TINY_BITS: u32 = 0x8000_0001; // Smallest (in magnitude) negative f32.
102    const CLEAR_SIGN_MASK: u32 = 0x7fff_ffff;
103    const ONE: Self::Item = 1;
104    const ZERO: Self::Item = 0;
105    const NEG_ZERO: Self::Item = 0x8000_0000;
106
107    fn to_bits(self) -> Self::Item {
108        self.to_bits()
109    }
110
111    fn from_bits(bits: Self::Item) -> Self {
112        f32::from_bits(bits)
113    }
114
115    fn float_is_nan(self) -> bool {
116        self.is_nan()
117    }
118
119    fn infinity() -> Self {
120        f32::INFINITY
121    }
122
123    fn neg_infinity() -> Self {
124        f32::NEG_INFINITY
125    }
126}
127
128impl FloatBits for f64 {
129    type Item = u64;
130    const TINY_BITS: u64 = 0x1; // Smallest positive f64.
131    const NEG_TINY_BITS: u64 = 0x8000_0000_0000_0001; // Smallest (in magnitude) negative f64.
132    const CLEAR_SIGN_MASK: u64 = 0x7fff_ffff_ffff_ffff;
133    const ONE: Self::Item = 1;
134    const ZERO: Self::Item = 0;
135    const NEG_ZERO: Self::Item = 0x8000_0000_0000_0000;
136
137    fn to_bits(self) -> Self::Item {
138        self.to_bits()
139    }
140
141    fn from_bits(bits: Self::Item) -> Self {
142        f64::from_bits(bits)
143    }
144
145    fn float_is_nan(self) -> bool {
146        self.is_nan()
147    }
148
149    fn infinity() -> Self {
150        f64::INFINITY
151    }
152
153    fn neg_infinity() -> Self {
154        f64::NEG_INFINITY
155    }
156}
157
158/// Returns the next representable floating-point value greater than the input value.
159///
160/// This function takes a floating-point value that implements the FloatBits trait,
161/// calculates the next representable value greater than the input, and returns it.
162///
163/// If the input value is NaN or positive infinity, the function returns the input value.
164///
165/// # Examples
166///
167/// ```
168/// use datafusion_common::rounding::next_up;
169///
170/// let f: f32 = 1.0;
171/// let next_f = next_up(f);
172/// assert_eq!(next_f, 1.0000001);
173/// ```
174pub fn next_up<F: FloatBits + Copy>(float: F) -> F {
175    let bits = float.to_bits();
176    if float.float_is_nan() || bits == F::infinity().to_bits() {
177        return float;
178    }
179
180    let abs = bits & F::CLEAR_SIGN_MASK;
181    let next_bits = if bits == F::ZERO {
182        F::TINY_BITS
183    } else if abs == F::ZERO {
184        F::ZERO
185    } else if bits == abs {
186        bits + F::ONE
187    } else {
188        bits - F::ONE
189    };
190    F::from_bits(next_bits)
191}
192
193/// Returns the next representable floating-point value smaller than the input value.
194///
195/// This function takes a floating-point value that implements the FloatBits trait,
196/// calculates the next representable value smaller than the input, and returns it.
197///
198/// If the input value is NaN or negative infinity, the function returns the input value.
199///
200/// # Examples
201///
202/// ```
203/// use datafusion_common::rounding::next_down;
204///
205/// let f: f32 = 1.0;
206/// let next_f = next_down(f);
207/// assert_eq!(next_f, 0.99999994);
208/// ```
209pub fn next_down<F: FloatBits + Copy>(float: F) -> F {
210    let bits = float.to_bits();
211    if float.float_is_nan() || bits == F::neg_infinity().to_bits() {
212        return float;
213    }
214
215    let abs = bits & F::CLEAR_SIGN_MASK;
216    let next_bits = if bits == F::ZERO {
217        F::NEG_ZERO
218    } else if abs == F::ZERO {
219        F::NEG_TINY_BITS
220    } else if bits == abs {
221        bits - F::ONE
222    } else {
223        bits + F::ONE
224    };
225    F::from_bits(next_bits)
226}
227
/// Fallback used on targets where the hardware rounding mode is not switched
/// (anything other than non-Windows x86_64/aarch64).
///
/// Runs `operation(lhs, rhs)` and, when the result is a non-null `Float32`
/// or `Float64` scalar, moves it one representable step outward: up via
/// `next_up` when `UPPER` is true, down via `next_down` otherwise. Any other
/// result is returned as produced, and errors from `operation` are
/// propagated.
#[cfg(any(
    not(any(target_arch = "x86_64", target_arch = "aarch64")),
    target_os = "windows"
))]
fn alter_fp_rounding_mode_conservative<const UPPER: bool, F>(
    lhs: &ScalarValue,
    rhs: &ScalarValue,
    operation: F,
) -> Result<ScalarValue>
where
    F: FnOnce(&ScalarValue, &ScalarValue) -> Result<ScalarValue>,
{
    let widened = match operation(lhs, rhs)? {
        ScalarValue::Float64(Some(v)) => {
            let nudged = if UPPER { next_up(v) } else { next_down(v) };
            ScalarValue::Float64(Some(nudged))
        }
        ScalarValue::Float32(Some(v)) => {
            let nudged = if UPPER { next_up(v) } else { next_down(v) };
            ScalarValue::Float32(Some(nudged))
        }
        // Nulls and non-float scalars pass through unchanged.
        untouched => untouched,
    };
    Ok(widened)
}
260
261pub fn alter_fp_rounding_mode<const UPPER: bool, F>(
262    lhs: &ScalarValue,
263    rhs: &ScalarValue,
264    operation: F,
265) -> Result<ScalarValue>
266where
267    F: FnOnce(&ScalarValue, &ScalarValue) -> Result<ScalarValue>,
268{
269    #[cfg(all(
270        any(target_arch = "x86_64", target_arch = "aarch64"),
271        not(target_os = "windows")
272    ))]
273    unsafe {
274        let current = fegetround();
275        fesetround(if UPPER { FE_UPWARD } else { FE_DOWNWARD });
276        let result = operation(lhs, rhs);
277        fesetround(current);
278        result
279    }
280    #[cfg(any(
281        not(any(target_arch = "x86_64", target_arch = "aarch64")),
282        target_os = "windows"
283    ))]
284    alter_fp_rounding_mode_conservative::<UPPER, _>(lhs, rhs, operation)
285}
286
#[cfg(test)]
mod tests {
    //! Unit tests for `next_up` / `next_down` on `f32` and `f64`.
    //!
    //! Signed-zero cases compare raw bit patterns: `0.0 == -0.0` holds for
    //! floats, so a plain `assert_eq!` on the values cannot tell which zero
    //! was returned.

    use super::{next_down, next_up};

    #[test]
    fn test_next_down() {
        let x = 1.0f64;
        // Clamp value into range [0, 1).
        let clamped = x.clamp(0.0, next_down(1.0f64));
        assert!(clamped < 1.0);
        assert_eq!(next_up(clamped), 1.0);
    }

    #[test]
    fn test_next_up_small_positive() {
        let value: f64 = 1.0;
        let result = next_up(value);
        assert_eq!(result, 1.0000000000000002);
    }

    #[test]
    fn test_next_up_small_negative() {
        let value: f64 = -1.0;
        let result = next_up(value);
        assert_eq!(result, -0.9999999999999999);
    }

    #[test]
    fn test_next_up_pos_infinity() {
        let value: f64 = f64::INFINITY;
        let result = next_up(value);
        assert_eq!(result, f64::INFINITY);
    }

    #[test]
    fn test_next_up_nan() {
        let value: f64 = f64::NAN;
        let result = next_up(value);
        assert!(result.is_nan());
    }

    #[test]
    fn test_next_down_small_positive() {
        let value: f64 = 1.0;
        let result = next_down(value);
        assert_eq!(result, 0.9999999999999999);
    }

    #[test]
    fn test_next_down_small_negative() {
        let value: f64 = -1.0;
        let result = next_down(value);
        assert_eq!(result, -1.0000000000000002);
    }

    #[test]
    fn test_next_down_neg_infinity() {
        let value: f64 = f64::NEG_INFINITY;
        let result = next_down(value);
        assert_eq!(result, f64::NEG_INFINITY);
    }

    #[test]
    fn test_next_down_nan() {
        let value: f64 = f64::NAN;
        let result = next_down(value);
        assert!(result.is_nan());
    }

    #[test]
    fn test_next_up_small_positive_f32() {
        let value: f32 = 1.0;
        let result = next_up(value);
        assert_eq!(result, 1.0000001);
    }

    #[test]
    fn test_next_up_small_negative_f32() {
        let value: f32 = -1.0;
        let result = next_up(value);
        assert_eq!(result, -0.99999994);
    }

    #[test]
    fn test_next_up_pos_infinity_f32() {
        let value: f32 = f32::INFINITY;
        let result = next_up(value);
        assert_eq!(result, f32::INFINITY);
    }

    #[test]
    fn test_next_up_nan_f32() {
        let value: f32 = f32::NAN;
        let result = next_up(value);
        assert!(result.is_nan());
    }

    #[test]
    fn test_next_down_small_positive_f32() {
        let value: f32 = 1.0;
        let result = next_down(value);
        assert_eq!(result, 0.99999994);
    }

    #[test]
    fn test_next_down_small_negative_f32() {
        let value: f32 = -1.0;
        let result = next_down(value);
        assert_eq!(result, -1.0000001);
    }

    #[test]
    fn test_next_down_neg_infinity_f32() {
        let value: f32 = f32::NEG_INFINITY;
        let result = next_down(value);
        assert_eq!(result, f32::NEG_INFINITY);
    }

    #[test]
    fn test_next_down_nan_f32() {
        let value: f32 = f32::NAN;
        let result = next_down(value);
        assert!(result.is_nan());
    }

    #[test]
    fn test_next_up_neg_zero_f32() {
        let value: f32 = -0.0;
        let result = next_up(value);
        // Must be positive zero specifically, so compare bit patterns.
        assert_eq!(result.to_bits(), 0.0_f32.to_bits());
    }

    #[test]
    fn test_next_down_zero_f32() {
        let value: f32 = 0.0;
        let result = next_down(value);
        // Must be negative zero specifically, so compare bit patterns.
        assert_eq!(result.to_bits(), (-0.0_f32).to_bits());
    }

    #[test]
    fn test_next_up_neg_zero_f64() {
        let value: f64 = -0.0;
        let result = next_up(value);
        // Must be positive zero specifically, so compare bit patterns.
        assert_eq!(result.to_bits(), 0.0_f64.to_bits());
    }

    #[test]
    fn test_next_down_zero_f64() {
        let value: f64 = 0.0;
        let result = next_down(value);
        // Must be negative zero specifically, so compare bit patterns.
        assert_eq!(result.to_bits(), (-0.0_f64).to_bits());
    }

    #[test]
    fn test_next_up_zero_f32() {
        // Positive zero steps up to the smallest positive value (bits 0x1).
        let result = next_up(0.0_f32);
        assert_eq!(result.to_bits(), 0x1);
        assert!(result > 0.0);
    }

    #[test]
    fn test_next_down_neg_zero_f32() {
        // Negative zero steps down to the smallest-magnitude negative value.
        let result = next_down(-0.0_f32);
        assert_eq!(result.to_bits(), 0x8000_0001);
        assert!(result < 0.0);
    }

    #[test]
    fn test_next_up_zero_f64() {
        // Positive zero steps up to the smallest positive value (bits 0x1).
        let result = next_up(0.0_f64);
        assert_eq!(result.to_bits(), 0x1);
        assert!(result > 0.0);
    }

    #[test]
    fn test_next_down_neg_zero_f64() {
        // Negative zero steps down to the smallest-magnitude negative value.
        let result = next_down(-0.0_f64);
        assert_eq!(result.to_bits(), 0x8000_0000_0000_0001);
        assert!(result < 0.0);
    }
}
435}