use ferray_core::Array;
use ferray_core::dimension::Dimension;
use ferray_core::dtype::Element;
use ferray_core::error::FerrayResult;
use num_traits::Float;
use crate::helpers::unary_float_op;
/// Round `x` to the nearest integer, resolving exact `.5` ties toward the even
/// neighbour ("banker's rounding", round-half-to-even).
///
/// Non-tie inputs defer to [`Float::round`]. NaN and ±infinity therefore pass
/// through unchanged.
fn bankers_round<T: Float>(x: T) -> T {
    // Built from `one()` so the constants are exact and there is no panic path.
    let one = T::one();
    let two = one + one;
    let half = one / two;

    let floored = x.floor();
    // For finite `x`, `x - floor(x)` is never negative (it lies in `[0, 1]`).
    // A tie can therefore only show up as `+0.5`, and that includes negative
    // inputs: `-2.5 - (-3.0) == 0.5`. For ±inf the difference is NaN, which
    // never compares equal and falls through to `round`.
    if x - floored == half {
        // Exact tie: keep `floor(x)` if it is even, otherwise step up to
        // `ceil(x)`. `ceil` (not `floored + 1`) keeps the sign of zero,
        // e.g. -0.5 -> -0.0.
        if (floored / two).floor() * two == floored {
            floored
        } else {
            x.ceil()
        }
    } else {
        x.round()
    }
}
/// Round every element to the nearest integer, sending exact `.5` ties to the
/// even neighbour (banker's rounding).
///
/// For example `0.5 -> 0`, `1.5 -> 2`, `2.5 -> 2`, `-1.5 -> -2`.
///
/// # Errors
/// Returns whatever error `unary_float_op` reports for `input`.
pub fn round<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
where
    T: Element + Float,
    D: Dimension,
{
    unary_float_op(input, bankers_round::<T>)
}
pub fn around<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
where
T: Element + Float,
D: Dimension,
{
round(input)
}
pub fn rint<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
where
T: Element + Float,
D: Dimension,
{
round(input)
}
/// Element-wise floor: the largest integer value not greater than each element.
///
/// # Errors
/// Returns whatever error `unary_float_op` reports for `input`.
pub fn floor<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
where
    T: Element + Float,
    D: Dimension,
{
    unary_float_op(input, <T as Float>::floor)
}
/// Element-wise ceiling: the smallest integer value not less than each element.
///
/// # Errors
/// Returns whatever error `unary_float_op` reports for `input`.
pub fn ceil<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
where
    T: Element + Float,
    D: Dimension,
{
    unary_float_op(input, <T as Float>::ceil)
}
/// Element-wise truncation: drop the fractional part, rounding toward zero.
///
/// # Errors
/// Returns whatever error `unary_float_op` reports for `input`.
pub fn trunc<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
where
    T: Element + Float,
    D: Dimension,
{
    unary_float_op(input, <T as Float>::trunc)
}
pub fn fix<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
where
T: Element + Float,
D: Dimension,
{
trunc(input)
}
use crate::helpers::unary_f16_fn;
// Half-precision (f16) variants, each gated behind the `f16` feature.
// Every variant is generated by `unary_f16_fn!` from an `f32` kernel.
// NOTE(review): presumably the macro widens each element to f32, applies the
// kernel, and narrows back. Confirm against `crate::helpers`.

// f16 floor via `f32::floor`.
unary_f16_fn!(
    #[cfg(feature = "f16")]
    floor_f16,
    f32::floor
);
// f16 ceil via `f32::ceil`.
unary_f16_fn!(
    #[cfg(feature = "f16")]
    ceil_f16,
    f32::ceil
);
// f16 trunc via `f32::trunc`.
unary_f16_fn!(
    #[cfg(feature = "f16")]
    trunc_f16,
    f32::trunc
);
// f16 round reuses the generic `bankers_round` at `f32`. Provided the
// widening assumed above holds, f16 ties resolve half-to-even like `round`.
unary_f16_fn!(
    #[cfg(feature = "f16")]
    round_f16,
    bankers_round::<f32>
);
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util::arr1;

    #[test]
    fn test_bankers_round_half_to_even_ac9() {
        let a = arr1(vec![0.5, 1.5, 2.5, 3.5, -0.5, -1.5]);
        let r = round(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s[0], 0.0);
        assert_eq!(s[1], 2.0);
        assert_eq!(s[2], 2.0);
        assert_eq!(s[3], 4.0);
        assert_eq!(s[4], 0.0);
        assert_eq!(s[5], -2.0);
    }

    #[test]
    fn test_bankers_round_negative_ties() {
        // Negative ties: floor(x) is odd for -2.5 and -4.5, even for -3.5.
        let a = arr1(vec![-2.5, -3.5, -4.5, 4.5]);
        let r = round(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s[0], -2.0);
        assert_eq!(s[1], -4.0);
        assert_eq!(s[2], -4.0);
        assert_eq!(s[3], 4.0);
    }

    #[test]
    fn test_round_keeps_negative_zero() {
        // `assert_eq!` cannot tell -0.0 from 0.0, so check the sign explicitly.
        let a = arr1(vec![-0.5, -0.2]);
        let r = round(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s[0], 0.0);
        assert!(f64::is_sign_negative(s[0]));
        assert_eq!(s[1], 0.0);
        assert!(f64::is_sign_negative(s[1]));
    }

    #[test]
    fn test_round_just_below_half() {
        // The largest f64 below 0.5 is not a tie and must round to zero.
        // A naive `floor(x + 0.5)` would return 1 here.
        let a = arr1(vec![0.49999999999999994, -0.49999999999999994]);
        let r = round(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s[0], 0.0);
        assert_eq!(s[1], 0.0);
    }

    #[test]
    fn test_round_non_finite_passthrough() {
        let a = arr1(vec![f64::NAN, f64::INFINITY, f64::NEG_INFINITY]);
        let r = round(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert!(f64::is_nan(s[0]));
        assert_eq!(s[1], f64::INFINITY);
        assert_eq!(s[2], f64::NEG_INFINITY);
    }

    #[test]
    fn test_round_normal() {
        let a = arr1(vec![1.2, 2.7, -1.3, -2.8]);
        let r = round(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s[0], 1.0);
        assert_eq!(s[1], 3.0);
        assert_eq!(s[2], -1.0);
        assert_eq!(s[3], -3.0);
    }

    #[test]
    fn test_floor() {
        let a = arr1(vec![1.7, -1.7, 0.0]);
        let r = floor(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s[0], 1.0);
        assert_eq!(s[1], -2.0);
        assert_eq!(s[2], 0.0);
    }

    #[test]
    fn test_ceil() {
        let a = arr1(vec![1.2, -1.2, 0.0]);
        let r = ceil(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s[0], 2.0);
        assert_eq!(s[1], -1.0);
        assert_eq!(s[2], 0.0);
    }

    #[test]
    fn test_trunc() {
        let a = arr1(vec![1.9, -1.9, 0.0]);
        let r = trunc(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s[0], 1.0);
        assert_eq!(s[1], -1.0);
        assert_eq!(s[2], 0.0);
    }

    #[test]
    fn test_fix() {
        let a = arr1(vec![2.9, -2.9]);
        let r = fix(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s[0], 2.0);
        assert_eq!(s[1], -2.0);
    }

    #[test]
    fn test_around_alias() {
        let a = arr1(vec![0.5, 1.5]);
        let r = around(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s[0], 0.0);
        assert_eq!(s[1], 2.0);
    }

    #[test]
    fn test_rint_alias() {
        let a = arr1(vec![0.5, 1.5]);
        let r = rint(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s[0], 0.0);
        assert_eq!(s[1], 2.0);
    }
}