1use ferray_core::Array;
6use ferray_core::dimension::Dimension;
7use ferray_core::dtype::Element;
8use ferray_core::error::FerrayResult;
9use num_traits::Float;
10
11use crate::helpers::unary_float_op;
12
13fn bankers_round<T: Float>(x: T) -> T {
17 let half = T::from(0.5).unwrap();
19 let two = T::from(2.0).unwrap();
20
21 let floored = x.floor();
23 let frac = x - floored;
24
25 if frac == half {
27 let ceiled = x.ceil();
29 if (floored / two).floor() * two == floored {
32 floored
33 } else {
34 ceiled
35 }
36 } else if frac == -half {
37 x.ceil()
41 } else {
42 x.round()
44 }
45}
46
47pub fn round<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
52where
53 T: Element + Float,
54 D: Dimension,
55{
56 unary_float_op(input, bankers_round)
57}
58
59pub fn around<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
61where
62 T: Element + Float,
63 D: Dimension,
64{
65 round(input)
66}
67
68pub fn rint<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
70where
71 T: Element + Float,
72 D: Dimension,
73{
74 round(input)
75}
76
77pub fn floor<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
79where
80 T: Element + Float,
81 D: Dimension,
82{
83 unary_float_op(input, T::floor)
84}
85
86pub fn ceil<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
88where
89 T: Element + Float,
90 D: Dimension,
91{
92 unary_float_op(input, T::ceil)
93}
94
95pub fn trunc<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
97where
98 T: Element + Float,
99 D: Dimension,
100{
101 unary_float_op(input, T::trunc)
102}
103
104pub fn fix<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
106where
107 T: Element + Float,
108 D: Dimension,
109{
110 trunc(input)
111}
112
113use crate::helpers::unary_f16_fn;
119
// Half-precision rounding kernels: `floor_f16`, `ceil_f16`, `trunc_f16` and
// `round_f16`, each generated by `unary_f16_fn!` from an `f32 -> f32` function.
// NOTE(review): the macro lives in `crate::helpers`. It presumably widens each
// f16 element to f32, applies the function and narrows the result back, and it
// presumably uses the forwarded `#[cfg(feature = "f16")]` to gate the generated
// item. Confirm both against the macro definition.

// Element-wise floor for f16 arrays, computed via `f32::floor`.
unary_f16_fn!(
    #[cfg(feature = "f16")]
    floor_f16,
    f32::floor
);
// Element-wise ceil for f16 arrays, computed via `f32::ceil`.
unary_f16_fn!(
    #[cfg(feature = "f16")]
    ceil_f16,
    f32::ceil
);
// Element-wise truncation toward zero for f16 arrays, via `f32::trunc`.
unary_f16_fn!(
    #[cfg(feature = "f16")]
    trunc_f16,
    f32::trunc
);
// Element-wise round-half-to-even for f16 arrays. It reuses the same
// `bankers_round` kernel as the generic `round`, instantiated at f32.
unary_f16_fn!(
    #[cfg(feature = "f16")]
    round_f16,
    bankers_round::<f32>
);
147
#[cfg(test)]
mod tests {
    use super::*;

    use crate::test_util::arr1;
    use ferray_core::Array;
    use ferray_core::dimension::Ix1;

    /// Builds a 1-D `f32` array for the single-precision variants below.
    fn arr1_f32(data: Vec<f32>) -> Array<f32, Ix1> {
        Array::<f32, Ix1>::from_vec(Ix1::new([data.len()]), data).unwrap()
    }

    // ---------------------------------------------------------------------
    // Default-precision tests (arrays built with `test_util::arr1`)
    // ---------------------------------------------------------------------

    #[test]
    fn test_bankers_round_half_to_even_ac9() {
        let a = arr1(vec![0.5, 1.5, 2.5, 3.5, -0.5, -1.5]);
        let r = round(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s, [0.0, 2.0, 2.0, 4.0, 0.0, -2.0]);
    }

    #[test]
    fn test_bankers_round_more_ties() {
        // Negative odd/even ties and ties well away from zero.
        let a = arr1(vec![-2.5, -3.5, 4.5, 5.5, 1_000_001.5, -1_000_000.5]);
        let r = round(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s, [-2.0, -4.0, 4.0, 6.0, 1_000_002.0, -1_000_000.0]);
    }

    #[test]
    fn test_round_preserves_sign_of_zero() {
        let a = arr1(vec![-0.5, -0.4, 0.5, 0.4]);
        let r = round(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s, [0.0, 0.0, 0.0, 0.0]);
        assert!(s[0].is_sign_negative());
        assert!(s[1].is_sign_negative());
        assert!(s[2].is_sign_positive());
        assert!(s[3].is_sign_positive());
    }

    #[test]
    fn test_round_values_adjacent_to_half_are_not_ties() {
        // Largest f64 strictly below 0.5, and neighbours of 1.5 / 2.5.
        let below_half = 0.5_f64 - f64::EPSILON / 4.0;
        let a = arr1(vec![
            below_half,
            -below_half,
            1.5 - f64::EPSILON,
            2.5 + 2.0 * f64::EPSILON,
        ]);
        let r = round(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s, [0.0, 0.0, 1.0, 3.0]);
    }

    #[test]
    fn test_round_non_finite_passthrough() {
        let a = arr1(vec![f64::NAN, f64::INFINITY, f64::NEG_INFINITY]);
        let r = round(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert!(s[0].is_nan());
        assert_eq!(s[1], f64::INFINITY);
        assert_eq!(s[2], f64::NEG_INFINITY);
    }

    #[test]
    fn test_round_normal() {
        let a = arr1(vec![1.2, 2.7, -1.3, -2.8]);
        let r = round(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s, [1.0, 3.0, -1.0, -3.0]);
    }

    #[test]
    fn test_floor() {
        let a = arr1(vec![1.7, -1.7, 0.0]);
        let r = floor(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s, [1.0, -2.0, 0.0]);
    }

    #[test]
    fn test_ceil() {
        let a = arr1(vec![1.2, -1.2, 0.0]);
        let r = ceil(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s, [2.0, -1.0, 0.0]);
    }

    #[test]
    fn test_trunc() {
        let a = arr1(vec![1.9, -1.9, 0.0]);
        let r = trunc(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s, [1.0, -1.0, 0.0]);
    }

    #[test]
    fn test_fix() {
        let a = arr1(vec![2.9, -2.9]);
        let r = fix(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s, [2.0, -2.0]);
    }

    #[test]
    fn test_around_alias() {
        let a = arr1(vec![0.5, 1.5]);
        let r = around(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s, [0.0, 2.0]);
    }

    #[test]
    fn test_rint_alias() {
        let a = arr1(vec![0.5, 1.5]);
        let r = rint(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s, [0.0, 2.0]);
    }

    // ---------------------------------------------------------------------
    // f32 tests
    // ---------------------------------------------------------------------

    #[test]
    fn test_bankers_round_half_to_even_f32() {
        let a = arr1_f32(vec![0.5, 1.5, 2.5, 3.5, -0.5, -1.5]);
        let r = round(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s, [0.0, 2.0, 2.0, 4.0, 0.0, -2.0]);
    }

    #[test]
    fn test_bankers_round_negative_ties_f32() {
        let a = arr1_f32(vec![-2.5, -3.5, 4.5, 5.5]);
        let r = round(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s, [-2.0, -4.0, 4.0, 6.0]);
    }

    #[test]
    fn test_round_normal_f32() {
        let a = arr1_f32(vec![1.2, 2.7, -1.3, -2.8]);
        let r = round(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s, [1.0, 3.0, -1.0, -3.0]);
    }

    #[test]
    fn test_floor_f32() {
        let a = arr1_f32(vec![1.7, -1.7, 0.0]);
        let r = floor(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s, [1.0, -2.0, 0.0]);
    }

    #[test]
    fn test_ceil_f32() {
        let a = arr1_f32(vec![1.2, -1.2, 0.0]);
        let r = ceil(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s, [2.0, -1.0, 0.0]);
    }

    #[test]
    fn test_trunc_f32() {
        let a = arr1_f32(vec![1.9, -1.9, 0.0]);
        let r = trunc(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s, [1.0, -1.0, 0.0]);
    }

    #[test]
    fn test_fix_f32() {
        let a = arr1_f32(vec![2.9, -2.9]);
        let r = fix(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s, [2.0, -2.0]);
    }

    #[test]
    fn test_around_alias_f32() {
        let a = arr1_f32(vec![0.5, 1.5]);
        let r = around(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s, [0.0, 2.0]);
    }

    #[test]
    fn test_rint_alias_f32() {
        let a = arr1_f32(vec![0.5, 1.5]);
        let r = rint(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s, [0.0, 2.0]);
    }
}