Skip to main content

ferray_ufunc/ops/
rounding.rs

//! ferray-ufunc: Rounding functions
//!
//! round (banker's rounding!), floor, ceil, trunc, fix, rint, around
//!
//! ## REQ status — REQ-24 floor/ceil/trunc/round integer-input identity
//!
//! SHIPPED: `floor_int`/`ceil_int`/`trunc_int`/`round_int`/`fix_int`/
//! `around_int` accept `T: Element + Copy` (any int/bool) and return the input
//! UNCHANGED in the INPUT-int-dtype via `round_identity` — `np.floor(int32)→
//! int32`, `np.round(int64)→int64`. This mirrors NumPy registering `TD(bints)`
//! FIRST for floor/ceil/trunc (floor `generate_umath.py:1011`, ceil `:983`,
//! trunc `:993`).
//!
//! BOOL EXCEPTION for round/around: `floor`/`ceil`/`trunc`/`fix` keep bool as
//! bool (those `TD(bints)` loops cover bool). `round`/`around` on bool instead
//! PROMOTE to **float16**, like `rint`. The reason is that `round`/`around`
//! dispatch to `ndarray.round`, whose bool kernel has no `TD(bints)`
//! (`generate_umath.py:1021`, same as `rint`), so `np.round(bool)→float16`
//! `[1.0,0.0,1.0]`.
//!
//! That bool path is `round_bool`/`around_bool` here (behind the `f16`
//! feature), routing through `rint_promote` (`PromoteFloat`, bool→f16).
//! Note that `round_int`/`around_int` still *accept* bool and return a bool
//! copy, which is not NumPy's result; use `round_bool`/`around_bool` for bool.
//! int8..uint64 round/around stay input-dtype identity.
//!
//! This is the OPPOSITE of `rint`, which has NO `TD(bints)` for ANY integer
//! and promotes every integer input to float (REQ-23) — served by
//! `rint_promote` in `promoted.rs`. The float `floor`/`ceil`/`trunc`/`round`/
//! `rint` (`T: Float`) are untouched (f32/f64 byte-identical).
//!
//! Consumer: the `*_int`/`*_bool` entries are the integer/bool-rounding public
//! surface re-exported from `lib.rs`. Verified by
//! `tests/divergence_unary_promote.rs` and `divergence_unary_promote_audit.rs`.

27use ferray_core::Array;
28use ferray_core::dimension::Dimension;
29use ferray_core::dtype::Element;
30use ferray_core::error::FerrayResult;
31use num_traits::Float;
32
33use crate::helpers::unary_float_op;
34
35/// Banker's rounding: round half to even (AC-9).
36///
37/// `round(0.5) == 0`, `round(1.5) == 2`, `round(2.5) == 2`.
38fn bankers_round<T: Float>(x: T) -> T {
39    // Check if x is exactly at a .5 boundary
40    let half = T::from(0.5).unwrap();
41    let two = T::from(2.0).unwrap();
42
43    // Get the fractional part: x - floor(x)
44    let floored = x.floor();
45    let frac = x - floored;
46
47    // Check if fractional part is exactly 0.5
48    if frac == half {
49        // At exact .5 -- round to even
50        let ceiled = x.ceil();
51        // Check which of floor/ceil is even
52        // A number is even if dividing by 2 and flooring gives back the same
53        if (floored / two).floor() * two == floored {
54            floored
55        } else {
56            ceiled
57        }
58    } else if frac == -half {
59        // Negative half case: x is negative, frac = x - floor(x) could be 0.5 for negatives
60        // Actually for negative numbers like -0.5: floor(-0.5) = -1, frac = -0.5 - (-1) = 0.5
61        // So the above branch handles it. This branch is for safety.
62        x.ceil()
63    } else {
64        // Not at a .5 boundary, standard rounding is fine
65        x.round()
66    }
67}
68
69/// Elementwise banker's rounding (round half to even).
70///
71/// This matches `NumPy`'s `np.round` / `np.around` behavior.
72/// AC-9: `round(0.5)==0`, `round(1.5)==2`.
73pub fn round<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
74where
75    T: Element + Float,
76    D: Dimension,
77{
78    unary_float_op(input, bankers_round)
79}
80
81/// Alias for [`round`] -- matches `NumPy`'s `around`.
82pub fn around<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
83where
84    T: Element + Float,
85    D: Dimension,
86{
87    round(input)
88}
89
90/// Alias for [`round`] -- matches `NumPy`'s `rint`.
91pub fn rint<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
92where
93    T: Element + Float,
94    D: Dimension,
95{
96    round(input)
97}
98
99/// Elementwise floor (round toward negative infinity).
100pub fn floor<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
101where
102    T: Element + Float,
103    D: Dimension,
104{
105    unary_float_op(input, T::floor)
106}
107
108/// Elementwise ceiling (round toward positive infinity).
109pub fn ceil<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
110where
111    T: Element + Float,
112    D: Dimension,
113{
114    unary_float_op(input, T::ceil)
115}
116
117/// Elementwise truncation (round toward zero).
118pub fn trunc<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
119where
120    T: Element + Float,
121    D: Dimension,
122{
123    unary_float_op(input, T::trunc)
124}
125
126/// Elementwise fix: round toward zero (same as trunc for real numbers).
127pub fn fix<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
128where
129    T: Element + Float,
130    D: Dimension,
131{
132    trunc(input)
133}
134
// ---------------------------------------------------------------------------
// Integer-input rounding (REQ-24): floor/ceil/trunc/round/fix/around are
// INT-IDENTITY on integer input (and on bool input for floor/ceil/trunc/fix).
//
// NumPy registers `TD(bints)` FIRST for floor/ceil/trunc, so an integer or
// bool input array selects the integer loop and is returned UNCHANGED, in the
// INPUT dtype — `np.floor(int64)->int64`, `np.ceil(int32)->int32`,
// `np.floor(bool)->bool`
// (numpy/_core/code_generators/generate_umath.py:983 ceil, :993 trunc,
// :1011 floor — `TD(bints)` is the first registered loop).
//
// round/around/fix keep non-bool integers unchanged in the same way. Bool
// round/around, however, promote to float16 (see the bool section further
// down).
//
// This is the OPPOSITE of `rint`, which has NO `TD(bints)`
// (`generate_umath.py:1021`) and promotes integer input to float (REQ-23,
// `rint_promote`). ferray's float `floor`/`ceil`/`trunc`/`round` above are
// `T: Float` and reject integer input at compile time. These `*_int` siblings
// therefore carry the integer contract, and the float entry points are
// untouched (f32/f64 byte-identical).
//
// `T: Element + Copy` (no `Float`) accepts every integer and bool element
// type. The op is a pure value copy preserving shape and dtype. Consequently
// `round_int`/`around_int` also return a bool copy for bool input, which is
// NOT NumPy's result; callers wanting NumPy bool semantics must use
// `round_bool`/`around_bool`.
// ---------------------------------------------------------------------------

157/// Identity copy preserving the input dtype/shape — the shared kernel for
158/// the integer-input rounding ops (REQ-24).
159#[inline]
160fn round_identity<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
161where
162    T: Element + Copy,
163    D: Dimension,
164{
165    let data: Vec<T> = input.iter().copied().collect();
166    Array::from_vec(input.dim().clone(), data)
167}
168
169/// `floor` on integer/bool input: returns the values unchanged in the input
170/// dtype (REQ-24, `TD(bints)` first). `np.floor(int64 [1,2,4]) == [1,2,4]`
171/// with `.dtype == int64`; `np.floor(bool [T,F]).dtype == bool`.
172pub fn floor_int<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
173where
174    T: Element + Copy,
175    D: Dimension,
176{
177    round_identity(input)
178}
179
180/// `ceil` on integer/bool input: int-identity (REQ-24).
181/// `np.ceil(int32 [5]).dtype == int32`.
182pub fn ceil_int<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
183where
184    T: Element + Copy,
185    D: Dimension,
186{
187    round_identity(input)
188}
189
190/// `trunc` on integer/bool input: int-identity (REQ-24).
191pub fn trunc_int<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
192where
193    T: Element + Copy,
194    D: Dimension,
195{
196    round_identity(input)
197}
198
199/// `round` on integer/bool input: int-identity (REQ-24). Unlike `rint`
200/// (REQ-23 float-promote), `round`/`around` keep integer input integer.
201pub fn round_int<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
202where
203    T: Element + Copy,
204    D: Dimension,
205{
206    round_identity(input)
207}
208
209/// `around` (alias of `round`) on integer/bool input: int-identity (REQ-24).
210pub fn around_int<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
211where
212    T: Element + Copy,
213    D: Dimension,
214{
215    round_identity(input)
216}
217
// ---------------------------------------------------------------------------
// bool-input round/around: promote to float16 (NOT bool identity).
//
// floor/ceil/trunc/fix register `TD(bints)` FIRST, so bool stays bool there
// (handled by the `*_int` identity above). `round`/`around` are different:
// they dispatch to `ndarray.round()`, whose bool kernel promotes to float16.
// That is the SAME promotion the `rint` ufunc applies (`rint` registers
// `TD('e', f='rint')` FIRST with NO `TD(bints)`, `generate_umath.py:1021`).
// Live numpy 2.4.5:
//   np.round(np.array([True,False,True])).dtype  == float16
//   np.round(np.array([True,False,True])).tolist() == [1.0, 0.0, 1.0]
// So bool round/around route through the existing `PromoteFloat` machinery
// (`crate::rint_promote`, bool->float16) rather than `round_identity`.
// Round-half-to-even of the promoted 0.0/1.0 values is the identity, so the
// f16 output equals numpy's `[1.0, 0.0, 1.0]`. The non-bool integer dtypes
// (int8..uint64) keep their input-dtype identity via `round_int`/`around_int`.
// ---------------------------------------------------------------------------

/// Bool `round`: converts to `float16` instead of keeping `bool` (REQ-24
/// bool exception).
///
/// NumPy's bool round has no `TD(bints)` loop and promotes the same way
/// `rint` does (`generate_umath.py:1021`). Hence
/// `np.round(np.array([True,False,True]))` is `float16 [1.0,0.0,1.0]`.
/// `floor`/`ceil`/`trunc` keep bool as bool. The conversion is delegated
/// to `crate::rint_promote`.
#[cfg(feature = "f16")]
pub fn round_bool<D: Dimension>(input: &Array<bool, D>) -> FerrayResult<Array<half::f16, D>> {
    crate::rint_promote(input)
}

/// Bool `around`: the `around` spelling of [`round_bool`], to which it
/// forwards. The result is `float16` (REQ-24 bool exception).
#[cfg(feature = "f16")]
pub fn around_bool<D: Dimension>(input: &Array<bool, D>) -> FerrayResult<Array<half::f16, D>> {
    round_bool::<D>(input)
}

258/// `fix` (alias of `trunc`) on integer/bool input: int-identity (REQ-24).
259pub fn fix_int<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
260where
261    T: Element + Copy,
262    D: Dimension,
263{
264    round_identity(input)
265}
266
267// ---------------------------------------------------------------------------
268// f16 variants (f32-promoted) — generated via the shared unary_f16_fn!
269// macro (#142).
270// ---------------------------------------------------------------------------
271
272use crate::helpers::unary_f16_fn;
273
274unary_f16_fn!(
275    /// Elementwise floor for f16 arrays via f32 promotion.
276    #[cfg(feature = "f16")]
277    floor_f16,
278    f32::floor
279);
280unary_f16_fn!(
281    /// Elementwise ceiling for f16 arrays via f32 promotion.
282    #[cfg(feature = "f16")]
283    ceil_f16,
284    f32::ceil
285);
286unary_f16_fn!(
287    /// Elementwise truncation for f16 arrays via f32 promotion.
288    #[cfg(feature = "f16")]
289    trunc_f16,
290    f32::trunc
291);
292unary_f16_fn!(
293    /// Elementwise banker's rounding for f16 arrays via f32 promotion.
294    ///
295    /// Reuses the generic [`bankers_round`] via monomorphization on
296    /// `f32`; the hand-rolled f32 copy was deleted in #144.
297    #[cfg(feature = "f16")]
298    round_f16,
299    bankers_round::<f32>
300);
301
#[cfg(test)]
mod tests {
    use super::*;

    use crate::test_util::arr1;
    use ferray_core::Array;
    use ferray_core::dimension::Ix1;

    #[test]
    fn test_bankers_round_half_to_even_ac9() {
        // AC-9: round(0.5)==0, round(1.5)==2
        let a = arr1(vec![0.5, 1.5, 2.5, 3.5, -0.5, -1.5]);
        let r = round(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s[0], 0.0); // 0.5 -> 0 (even)
        assert_eq!(s[1], 2.0); // 1.5 -> 2 (even)
        assert_eq!(s[2], 2.0); // 2.5 -> 2 (even)
        assert_eq!(s[3], 4.0); // 3.5 -> 4 (even)
        assert_eq!(s[4], 0.0); // -0.5 -> 0 (even)
        assert_eq!(s[5], -2.0); // -1.5 -> -2 (even)
    }

    #[test]
    fn test_round_negative_half_keeps_sign_of_zero() {
        // NumPy: np.round(-0.5) is -0.0; np.round(-2.5) is -2.0.
        let a = arr1(vec![-0.5_f64, -2.5]);
        let r = round(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s[0], 0.0);
        assert!(s[0].is_sign_negative());
        assert_eq!(s[1], -2.0);
    }

    #[test]
    fn test_round_non_finite_and_edge_values() {
        let big = 4_503_599_627_370_497.0_f64; // 2^52 + 1, already integral
        let just_below_half = f64::from_bits(0.5_f64.to_bits() - 1);
        let a = arr1(vec![
            f64::NAN,
            f64::INFINITY,
            f64::NEG_INFINITY,
            big,
            just_below_half,
        ]);
        let r = round(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert!(s[0].is_nan());
        assert_eq!(s[1], f64::INFINITY);
        assert_eq!(s[2], f64::NEG_INFINITY);
        assert_eq!(s[3], big);
        assert_eq!(s[4], 0.0);
    }

    #[test]
    fn test_round_normal() {
        let a = arr1(vec![1.2, 2.7, -1.3, -2.8]);
        let r = round(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s[0], 1.0);
        assert_eq!(s[1], 3.0);
        assert_eq!(s[2], -1.0);
        assert_eq!(s[3], -3.0);
    }

    #[test]
    fn test_floor() {
        let a = arr1(vec![1.7, -1.7, 0.0]);
        let r = floor(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s[0], 1.0);
        assert_eq!(s[1], -2.0);
        assert_eq!(s[2], 0.0);
    }

    #[test]
    fn test_ceil() {
        let a = arr1(vec![1.2, -1.2, 0.0]);
        let r = ceil(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s[0], 2.0);
        assert_eq!(s[1], -1.0);
        assert_eq!(s[2], 0.0);
    }

    #[test]
    fn test_trunc() {
        let a = arr1(vec![1.9, -1.9, 0.0]);
        let r = trunc(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s[0], 1.0);
        assert_eq!(s[1], -1.0);
        assert_eq!(s[2], 0.0);
    }

    #[test]
    fn test_fix() {
        let a = arr1(vec![2.9, -2.9]);
        let r = fix(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s[0], 2.0);
        assert_eq!(s[1], -2.0);
    }

    #[test]
    fn test_around_alias() {
        let a = arr1(vec![0.5, 1.5]);
        let r = around(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s[0], 0.0);
        assert_eq!(s[1], 2.0);
    }

    #[test]
    fn test_rint_alias() {
        let a = arr1(vec![0.5, 1.5]);
        let r = rint(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s[0], 0.0);
        assert_eq!(s[1], 2.0);
    }

    // ----------------------------------------------------------------------
    // Integer / bool identity ops (REQ-24).
    // ----------------------------------------------------------------------

    #[test]
    fn test_int_rounding_is_exact_identity() {
        // i64::MAX is not representable in f64; an identity copy must keep it.
        let a =
            Array::<i64, Ix1>::from_vec(Ix1::new([4]), vec![-3, 0, 2, i64::MAX]).unwrap();
        for r in [
            floor_int(&a),
            ceil_int(&a),
            trunc_int(&a),
            round_int(&a),
            around_int(&a),
            fix_int(&a),
        ] {
            assert_eq!(r.unwrap().as_slice().unwrap(), &[-3, 0, 2, i64::MAX]);
        }
    }

    #[test]
    fn test_bool_floor_family_keeps_bool() {
        let a = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, false, true]).unwrap();
        for r in [floor_int(&a), ceil_int(&a), trunc_int(&a), fix_int(&a)] {
            assert_eq!(r.unwrap().as_slice().unwrap(), &[true, false, true]);
        }
    }

    #[cfg(feature = "f16")]
    #[test]
    fn test_round_bool_promotes_to_f16() {
        // np.round(np.array([True, False, True])) -> float16 [1.0, 0.0, 1.0]
        let a = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, false, true]).unwrap();
        for r in [round_bool(&a), around_bool(&a)] {
            let got: Vec<f32> = r.unwrap().iter().map(|v| v.to_f32()).collect();
            assert_eq!(got, vec![1.0, 0.0, 1.0]);
        }
    }

    // ----------------------------------------------------------------------
    // f32 sibling tests (#152) — every rounding op exercised on f32 to
    // verify the SIMD f32 path and confirm bit-exact rounding behaviour
    // matches the f64 path on values both representable.
    // ----------------------------------------------------------------------

    fn arr1_f32(data: Vec<f32>) -> Array<f32, Ix1> {
        Array::<f32, Ix1>::from_vec(Ix1::new([data.len()]), data).unwrap()
    }

    #[test]
    fn test_bankers_round_half_to_even_f32() {
        let a = arr1_f32(vec![0.5, 1.5, 2.5, 3.5, -0.5, -1.5]);
        let r = round(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s[0], 0.0);
        assert_eq!(s[1], 2.0);
        assert_eq!(s[2], 2.0);
        assert_eq!(s[3], 4.0);
        assert_eq!(s[4], 0.0);
        assert_eq!(s[5], -2.0);
    }

    #[test]
    fn test_round_normal_f32() {
        let a = arr1_f32(vec![1.2, 2.7, -1.3, -2.8]);
        let r = round(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s[0], 1.0);
        assert_eq!(s[1], 3.0);
        assert_eq!(s[2], -1.0);
        assert_eq!(s[3], -3.0);
    }

    #[test]
    fn test_floor_f32() {
        let a = arr1_f32(vec![1.7, -1.7, 0.0]);
        let r = floor(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s[0], 1.0);
        assert_eq!(s[1], -2.0);
        assert_eq!(s[2], 0.0);
    }

    #[test]
    fn test_ceil_f32() {
        let a = arr1_f32(vec![1.2, -1.2, 0.0]);
        let r = ceil(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s[0], 2.0);
        assert_eq!(s[1], -1.0);
        assert_eq!(s[2], 0.0);
    }

    #[test]
    fn test_trunc_f32() {
        let a = arr1_f32(vec![1.9, -1.9, 0.0]);
        let r = trunc(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s[0], 1.0);
        assert_eq!(s[1], -1.0);
        assert_eq!(s[2], 0.0);
    }

    #[test]
    fn test_fix_f32() {
        let a = arr1_f32(vec![2.9, -2.9]);
        let r = fix(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s[0], 2.0);
        assert_eq!(s[1], -2.0);
    }

    #[test]
    fn test_around_alias_f32() {
        let a = arr1_f32(vec![0.5, 1.5]);
        let r = around(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s[0], 0.0);
        assert_eq!(s[1], 2.0);
    }

    #[test]
    fn test_rint_alias_f32() {
        let a = arr1_f32(vec![0.5, 1.5]);
        let r = rint(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s[0], 0.0);
        assert_eq!(s[1], 2.0);
    }
}