Skip to main content

ferray_ufunc/ops/
rounding.rs

1// ferray-ufunc: Rounding functions
2//
3// round (banker's rounding!), floor, ceil, trunc, fix, rint, around
4
5use ferray_core::Array;
6use ferray_core::dimension::Dimension;
7use ferray_core::dtype::Element;
8use ferray_core::error::FerrayResult;
9use num_traits::Float;
10
11use crate::helpers::unary_float_op;
12
13/// Banker's rounding: round half to even (AC-9).
14///
15/// `round(0.5) == 0`, `round(1.5) == 2`, `round(2.5) == 2`.
16fn bankers_round<T: Float>(x: T) -> T {
17    // Check if x is exactly at a .5 boundary
18    let half = T::from(0.5).unwrap();
19    let two = T::from(2.0).unwrap();
20
21    // Get the fractional part: x - floor(x)
22    let floored = x.floor();
23    let frac = x - floored;
24
25    // Check if fractional part is exactly 0.5
26    if frac == half {
27        // At exact .5 -- round to even
28        let ceiled = x.ceil();
29        // Check which of floor/ceil is even
30        // A number is even if dividing by 2 and flooring gives back the same
31        if (floored / two).floor() * two == floored {
32            floored
33        } else {
34            ceiled
35        }
36    } else if frac == -half {
37        // Negative half case: x is negative, frac = x - floor(x) could be 0.5 for negatives
38        // Actually for negative numbers like -0.5: floor(-0.5) = -1, frac = -0.5 - (-1) = 0.5
39        // So the above branch handles it. This branch is for safety.
40        x.ceil()
41    } else {
42        // Not at a .5 boundary, standard rounding is fine
43        x.round()
44    }
45}
46
47/// Elementwise banker's rounding (round half to even).
48///
49/// This matches `NumPy`'s `np.round` / `np.around` behavior.
50/// AC-9: `round(0.5)==0`, `round(1.5)==2`.
51pub fn round<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
52where
53    T: Element + Float,
54    D: Dimension,
55{
56    unary_float_op(input, bankers_round)
57}
58
59/// Alias for [`round`] -- matches `NumPy`'s `around`.
60pub fn around<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
61where
62    T: Element + Float,
63    D: Dimension,
64{
65    round(input)
66}
67
68/// Alias for [`round`] -- matches `NumPy`'s `rint`.
69pub fn rint<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
70where
71    T: Element + Float,
72    D: Dimension,
73{
74    round(input)
75}
76
77/// Elementwise floor (round toward negative infinity).
78pub fn floor<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
79where
80    T: Element + Float,
81    D: Dimension,
82{
83    unary_float_op(input, T::floor)
84}
85
86/// Elementwise ceiling (round toward positive infinity).
87pub fn ceil<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
88where
89    T: Element + Float,
90    D: Dimension,
91{
92    unary_float_op(input, T::ceil)
93}
94
95/// Elementwise truncation (round toward zero).
96pub fn trunc<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
97where
98    T: Element + Float,
99    D: Dimension,
100{
101    unary_float_op(input, T::trunc)
102}
103
104/// Elementwise fix: round toward zero (same as trunc for real numbers).
105pub fn fix<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
106where
107    T: Element + Float,
108    D: Dimension,
109{
110    trunc(input)
111}
112
113// ---------------------------------------------------------------------------
114// f16 variants (f32-promoted) — generated via the shared unary_f16_fn!
115// macro (#142).
116// ---------------------------------------------------------------------------
117
118use crate::helpers::unary_f16_fn;
119
// Each invocation below passes, in order: a doc comment plus a
// `#[cfg(feature = "f16")]` gate, the name of the f16 function, and the `f32`
// scalar operation to apply.
// NOTE(review): `unary_f16_fn!` is defined in `crate::helpers`, outside this
// view — presumably the attributes end up on the generated function; confirm
// the expansion and the generated signature there.
120unary_f16_fn!(
121    /// Elementwise floor for f16 arrays via f32 promotion.
122    #[cfg(feature = "f16")]
123    floor_f16,
124    f32::floor
125);
126unary_f16_fn!(
127    /// Elementwise ceiling for f16 arrays via f32 promotion.
128    #[cfg(feature = "f16")]
129    ceil_f16,
130    f32::ceil
131);
132unary_f16_fn!(
133    /// Elementwise truncation for f16 arrays via f32 promotion.
134    #[cfg(feature = "f16")]
135    trunc_f16,
136    f32::trunc
137);
// Unlike the three invocations above, this one supplies the file-local generic
// `bankers_round`, instantiated at `f32`, rather than a standard `f32` method,
// so the f16 variant rounds halves to even as well.
138unary_f16_fn!(
139    /// Elementwise banker's rounding for f16 arrays via f32 promotion.
140    ///
141    /// Reuses the generic [`bankers_round`] via monomorphization on
142    /// `f32`; the hand-rolled f32 copy was deleted in #144.
143    #[cfg(feature = "f16")]
144    round_f16,
145    bankers_round::<f32>
146);
147
#[cfg(test)]
mod tests {
    use super::*;

    use crate::test_util::arr1;
    use ferray_core::Array;
    use ferray_core::dimension::Ix1;

    // Unwraps the result of a rounding call, takes its contiguous slice and
    // checks it position by position against the listed expected values.
    macro_rules! assert_elements {
        ($result:expr, [$($expected:expr),+ $(,)?]) => {{
            let output = $result.unwrap();
            let actual = output.as_slice().unwrap();
            let wanted = [$($expected),+];
            for (index, expected) in wanted.iter().enumerate() {
                assert_eq!(actual[index], *expected);
            }
        }};
    }

    #[test]
    fn test_bankers_round_half_to_even_ac9() {
        // AC-9 ties: 0.5 -> 0, 1.5 -> 2, 2.5 -> 2, 3.5 -> 4, -0.5 -> 0, -1.5 -> -2.
        let input = arr1(vec![0.5, 1.5, 2.5, 3.5, -0.5, -1.5]);
        assert_elements!(round(&input), [0.0, 2.0, 2.0, 4.0, 0.0, -2.0]);
    }

    #[test]
    fn test_round_normal() {
        let input = arr1(vec![1.2, 2.7, -1.3, -2.8]);
        assert_elements!(round(&input), [1.0, 3.0, -1.0, -3.0]);
    }

    #[test]
    fn test_floor() {
        let input = arr1(vec![1.7, -1.7, 0.0]);
        assert_elements!(floor(&input), [1.0, -2.0, 0.0]);
    }

    #[test]
    fn test_ceil() {
        let input = arr1(vec![1.2, -1.2, 0.0]);
        assert_elements!(ceil(&input), [2.0, -1.0, 0.0]);
    }

    #[test]
    fn test_trunc() {
        let input = arr1(vec![1.9, -1.9, 0.0]);
        assert_elements!(trunc(&input), [1.0, -1.0, 0.0]);
    }

    #[test]
    fn test_fix() {
        let input = arr1(vec![2.9, -2.9]);
        assert_elements!(fix(&input), [2.0, -2.0]);
    }

    #[test]
    fn test_around_alias() {
        let input = arr1(vec![0.5, 1.5]);
        assert_elements!(around(&input), [0.0, 2.0]);
    }

    #[test]
    fn test_rint_alias() {
        let input = arr1(vec![0.5, 1.5]);
        assert_elements!(rint(&input), [0.0, 2.0]);
    }

    // ----------------------------------------------------------------------
    // f32 sibling tests (#152) — the same cases as above, fed through f32
    // arrays so every rounding op is also exercised at single precision.
    // ----------------------------------------------------------------------

    // Builds a one-dimensional f32 array holding `data`.
    fn arr1_f32(data: Vec<f32>) -> Array<f32, Ix1> {
        let shape = Ix1::new([data.len()]);
        Array::<f32, Ix1>::from_vec(shape, data).unwrap()
    }

    #[test]
    fn test_bankers_round_half_to_even_f32() {
        let input = arr1_f32(vec![0.5, 1.5, 2.5, 3.5, -0.5, -1.5]);
        assert_elements!(round(&input), [0.0, 2.0, 2.0, 4.0, 0.0, -2.0]);
    }

    #[test]
    fn test_round_normal_f32() {
        let input = arr1_f32(vec![1.2, 2.7, -1.3, -2.8]);
        assert_elements!(round(&input), [1.0, 3.0, -1.0, -3.0]);
    }

    #[test]
    fn test_floor_f32() {
        let input = arr1_f32(vec![1.7, -1.7, 0.0]);
        assert_elements!(floor(&input), [1.0, -2.0, 0.0]);
    }

    #[test]
    fn test_ceil_f32() {
        let input = arr1_f32(vec![1.2, -1.2, 0.0]);
        assert_elements!(ceil(&input), [2.0, -1.0, 0.0]);
    }

    #[test]
    fn test_trunc_f32() {
        let input = arr1_f32(vec![1.9, -1.9, 0.0]);
        assert_elements!(trunc(&input), [1.0, -1.0, 0.0]);
    }

    #[test]
    fn test_fix_f32() {
        let input = arr1_f32(vec![2.9, -2.9]);
        assert_elements!(fix(&input), [2.0, -2.0]);
    }

    #[test]
    fn test_around_alias_f32() {
        let input = arr1_f32(vec![0.5, 1.5]);
        assert_elements!(around(&input), [0.0, 2.0]);
    }

    #[test]
    fn test_rint_alias_f32() {
        let input = arr1_f32(vec![0.5, 1.5]);
        assert_elements!(rint(&input), [0.0, 2.0]);
    }
}