polars_arrow_rvsry99dx/compute/kernels/
arithmetic.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Defines basic arithmetic kernels for `PrimitiveArrays`.
//!
//! These kernels can leverage SIMD if available on your system.  Currently no runtime
//! detection is provided, so you should enable the specific SIMD intrinsics at compile
//! time, for example with `RUSTFLAGS="-C target-feature=+avx2"`.  See the
//! [`core::arch` documentation](https://doc.rust-lang.org/stable/core/arch/) for more
//! information.
24
25#[cfg(feature = "simd")]
26use std::mem;
27use std::ops::{Add, Div, Mul, Sub};
28#[cfg(feature = "simd")]
29use std::slice::from_raw_parts_mut;
30use std::sync::Arc;
31
32use num::{One, Zero};
33
34use crate::array::*;
35#[cfg(feature = "simd")]
36use crate::bitmap::Bitmap;
37use crate::buffer::Buffer;
38#[cfg(feature = "simd")]
39use crate::buffer::MutableBuffer;
40use crate::compute::util::apply_bin_op_to_option_bitmap;
41#[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
42use crate::compute::util::simd_load_set_invalid;
43use crate::datatypes;
44use crate::datatypes::ToByteSlice;
45use crate::error::{ArrowError, Result};
46use crate::util::bit_util;
47
48/// Helper function to perform math lambda function on values from two arrays. If either
49/// left or right value is null then the output value is also null, so `1 + null` is
50/// `null`.
51pub fn math_op<T, F>(
52    left: &PrimitiveArray<T>,
53    right: &PrimitiveArray<T>,
54    op: F,
55) -> Result<PrimitiveArray<T>>
56where
57    T: datatypes::ArrowNumericType,
58    F: Fn(T::Native, T::Native) -> Result<T::Native>,
59{
60    if left.len() != right.len() {
61        return Err(ArrowError::ComputeError(
62            "Cannot perform math operation on arrays of different length".to_string(),
63        ));
64    }
65
66    let null_bit_buffer = apply_bin_op_to_option_bitmap(
67        left.data().null_bitmap(),
68        right.data().null_bitmap(),
69        |a, b| a & b,
70    )?;
71
72    let mut values = Vec::with_capacity(left.len());
73    if let Some(b) = &null_bit_buffer {
74        for i in 0..left.len() {
75            unsafe {
76                if bit_util::get_bit_raw(b.raw_data(), i) {
77                    values.push(op(left.value(i), right.value(i))?);
78                } else {
79                    values.push(T::default_value())
80                }
81            }
82        }
83    } else {
84        for i in 0..left.len() {
85            values.push(op(left.value(i), right.value(i))?);
86        }
87    }
88
89    let data = ArrayData::new(
90        T::get_data_type(),
91        left.len(),
92        None,
93        null_bit_buffer,
94        left.offset(),
95        vec![Buffer::from(values.to_byte_slice())],
96        vec![],
97    );
98    Ok(PrimitiveArray::<T>::from(Arc::new(data)))
99}
100
/// SIMD vectorized counterpart of `math_op`.
///
/// Processes both inputs in chunks of `T::lanes()` values, applying the infallible
/// vector function `op` to each pair of chunks.  The result validity is the bitwise AND
/// of the two input validity bitmaps; `op` is applied to every chunk regardless of
/// validity.
///
/// # Errors
///
/// * `ArrowError::ComputeError` if `left` and `right` differ in length.
/// * Any error produced while combining the null bitmaps.
#[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
fn simd_math_op<T, F>(
    left: &PrimitiveArray<T>,
    right: &PrimitiveArray<T>,
    op: F,
) -> Result<PrimitiveArray<T>>
where
    T: datatypes::ArrowNumericType,
    T::Simd: Add<Output = T::Simd>
        + Sub<Output = T::Simd>
        + Mul<Output = T::Simd>
        + Div<Output = T::Simd>,
    F: Fn(T::Simd, T::Simd) -> T::Simd,
{
    let len = left.len();
    if len != right.len() {
        return Err(ArrowError::ComputeError(
            "Cannot perform math operation on arrays of different length".to_string(),
        ));
    }

    // A result slot is valid only when both inputs are valid.
    let validity = apply_bin_op_to_option_bitmap(
        left.data().null_bitmap(),
        right.data().null_bitmap(),
        |a, b| a & b,
    )?;

    let lanes = T::lanes();
    let byte_len = len * mem::size_of::<T::Native>();
    let mut output = MutableBuffer::new(byte_len).with_bitset(byte_len, false);

    for start in (0..len).step_by(lanes) {
        let combined = T::bin_op(
            T::load(left.value_slice(start, lanes)),
            T::load(right.value_slice(start, lanes)),
            &op,
        );

        let out_base = output.data_mut().as_mut_ptr() as *mut T::Native;
        // NOTE(review): when `len` is not a multiple of `lanes` the last chunk spans
        // slots past `len`; this presumably relies on the buffers being padded — confirm.
        let out_chunk: &mut [T::Native] =
            unsafe { from_raw_parts_mut(out_base.add(start), lanes) };
        T::write(combined, out_chunk);
    }

    // The result array reuses `left`'s offset.
    let data = ArrayData::new(
        T::get_data_type(),
        len,
        None,
        validity,
        left.offset(),
        vec![output.freeze()],
        vec![],
    );
    Ok(PrimitiveArray::<T>::from(Arc::new(data)))
}
157
/// SIMD vectorized version of `divide`.  The divide kernel needs its own implementation
/// because a division by `0` has to be detected and reported, which is complicated by
/// `NULL` slots and padding.
///
/// Each chunk of `T::lanes()` divisors is loaded through `simd_load_set_invalid` with a
/// fill value of `1` (presumably substituted for slots the combined bitmap marks as
/// invalid — confirm against that helper), compared against zero, and only then used
/// for the division.
///
/// # Errors
///
/// * `ArrowError::ComputeError` if `left` and `right` differ in length.
/// * Any error produced while combining the null bitmaps.
/// * `ArrowError::DivideByZero` if any lane of a loaded divisor chunk equals zero.
#[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
fn simd_divide<T>(
    left: &PrimitiveArray<T>,
    right: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
where
    T: datatypes::ArrowNumericType,
    T::Native: One + Zero,
{
    let len = left.len();
    if len != right.len() {
        return Err(ArrowError::ComputeError(
            "Cannot perform math operation on arrays of different length".to_string(),
        ));
    }

    // Combine both validity bitmaps: a result slot is valid only when both inputs are.
    let combined_validity = apply_bin_op_to_option_bitmap(
        left.data().null_bitmap(),
        right.data().null_bitmap(),
        |a, b| a & b,
    )?;
    let bitmap = combined_validity.map(Bitmap::from);

    let lanes = T::lanes();
    let byte_len = len * mem::size_of::<T::Native>();
    let mut output = MutableBuffer::new(byte_len).with_bitset(byte_len, false);

    for start in (0..len).step_by(lanes) {
        // First load: consumed by the zero check below.
        let divisor_probe =
            unsafe { simd_load_set_invalid(right, &bitmap, start, lanes, T::Native::one()) };
        let zero_lanes = T::eq(T::init(T::Native::zero()), divisor_probe);
        if T::mask_any(zero_lanes) {
            return Err(ArrowError::DivideByZero);
        }

        // Second load: `T::eq` took the first vector by value, so the divisors are
        // loaded again for the actual division.
        let divisor =
            unsafe { simd_load_set_invalid(right, &bitmap, start, lanes, T::Native::one()) };
        let dividend = T::load(left.value_slice(start, lanes));
        let quotient = T::bin_op(dividend, divisor, |a, b| a / b);

        let out_base = output.data_mut().as_mut_ptr() as *mut T::Native;
        // NOTE(review): when `len` is not a multiple of `lanes` the last chunk spans
        // slots past `len`; this presumably relies on the buffers being padded — confirm.
        let out_chunk: &mut [T::Native] =
            unsafe { from_raw_parts_mut(out_base.add(start), lanes) };
        T::write(quotient, out_chunk);
    }

    // Unwrap the `Bitmap` back into its underlying buffer for the result array.
    let validity = bitmap.map(|b| b.bits);

    // The result array reuses `left`'s offset.
    let data = ArrayData::new(
        T::get_data_type(),
        len,
        None,
        validity,
        left.offset(),
        vec![output.freeze()],
        vec![],
    );
    Ok(PrimitiveArray::<T>::from(Arc::new(data)))
}
222
223/// Perform `left + right` operation on two arrays. If either left or right value is null
224/// then the result is also null.
225pub fn add<T>(
226    left: &PrimitiveArray<T>,
227    right: &PrimitiveArray<T>,
228) -> Result<PrimitiveArray<T>>
229where
230    T: datatypes::ArrowNumericType,
231    T::Native: Add<Output = T::Native>
232        + Sub<Output = T::Native>
233        + Mul<Output = T::Native>
234        + Div<Output = T::Native>
235        + Zero,
236{
237    #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
238    return simd_math_op(&left, &right, |a, b| a + b);
239
240    #[allow(unreachable_code)]
241    math_op(left, right, |a, b| Ok(a + b))
242}
243
244/// Perform `left - right` operation on two arrays. If either left or right value is null
245/// then the result is also null.
246pub fn subtract<T>(
247    left: &PrimitiveArray<T>,
248    right: &PrimitiveArray<T>,
249) -> Result<PrimitiveArray<T>>
250where
251    T: datatypes::ArrowNumericType,
252    T::Native: Add<Output = T::Native>
253        + Sub<Output = T::Native>
254        + Mul<Output = T::Native>
255        + Div<Output = T::Native>
256        + Zero,
257{
258    #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
259    return simd_math_op(&left, &right, |a, b| a - b);
260
261    #[allow(unreachable_code)]
262    math_op(left, right, |a, b| Ok(a - b))
263}
264
265/// Perform `left * right` operation on two arrays. If either left or right value is null
266/// then the result is also null.
267pub fn multiply<T>(
268    left: &PrimitiveArray<T>,
269    right: &PrimitiveArray<T>,
270) -> Result<PrimitiveArray<T>>
271where
272    T: datatypes::ArrowNumericType,
273    T::Native: Add<Output = T::Native>
274        + Sub<Output = T::Native>
275        + Mul<Output = T::Native>
276        + Div<Output = T::Native>
277        + Zero,
278{
279    #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
280    return simd_math_op(&left, &right, |a, b| a * b);
281
282    #[allow(unreachable_code)]
283    math_op(left, right, |a, b| Ok(a * b))
284}
285
286/// Perform `left / right` operation on two arrays. If either left or right value is null
287/// then the result is also null. If any right hand value is zero then the result of this
288/// operation will be `Err(ArrowError::DivideByZero)`.
289pub fn divide<T>(
290    left: &PrimitiveArray<T>,
291    right: &PrimitiveArray<T>,
292) -> Result<PrimitiveArray<T>>
293where
294    T: datatypes::ArrowNumericType,
295    T::Native: Add<Output = T::Native>
296        + Sub<Output = T::Native>
297        + Mul<Output = T::Native>
298        + Div<Output = T::Native>
299        + Zero
300        + One,
301{
302    #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
303    return simd_divide(&left, &right);
304
305    #[allow(unreachable_code)]
306    math_op(left, right, |a, b| {
307        if b.is_zero() {
308            Err(ArrowError::DivideByZero)
309        } else {
310            Ok(a / b)
311        }
312    })
313}
314
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array::Int32Array;

    /// Asserts that the leading slots of `actual` hold the values in `expected`.
    fn assert_int32_values(expected: &[i32], actual: &Int32Array) {
        for (idx, want) in expected.iter().enumerate() {
            assert_eq!(*want, actual.value(idx));
        }
    }

    #[test]
    fn test_primitive_array_add() {
        let lhs = Int32Array::from(vec![5, 6, 7, 8, 9]);
        let rhs = Int32Array::from(vec![6, 7, 8, 9, 8]);
        let sum = add(&lhs, &rhs).unwrap();
        assert_int32_values(&[11, 13, 15, 17, 17], &sum);
    }

    #[test]
    fn test_primitive_array_add_mismatched_length() {
        let lhs = Int32Array::from(vec![5, 6, 7, 8, 9]);
        let rhs = Int32Array::from(vec![6, 7, 8]);
        let failure = add(&lhs, &rhs)
            .err()
            .expect("should have failed due to different lengths");
        let rendered = format!("{:?}", failure);
        assert_eq!(
            "ComputeError(\"Cannot perform math operation on arrays of different length\")",
            rendered
        );
    }

    #[test]
    fn test_primitive_array_subtract() {
        let lhs = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let rhs = Int32Array::from(vec![5, 4, 3, 2, 1]);
        let difference = subtract(&lhs, &rhs).unwrap();
        assert_int32_values(&[-4, -2, 0, 2, 4], &difference);
    }

    #[test]
    fn test_primitive_array_multiply() {
        let lhs = Int32Array::from(vec![5, 6, 7, 8, 9]);
        let rhs = Int32Array::from(vec![6, 7, 8, 9, 8]);
        let product = multiply(&lhs, &rhs).unwrap();
        assert_int32_values(&[30, 42, 56, 72, 72], &product);
    }

    #[test]
    fn test_primitive_array_divide() {
        let lhs = Int32Array::from(vec![15, 15, 8, 1, 9]);
        let rhs = Int32Array::from(vec![5, 6, 8, 9, 1]);
        let quotient = divide(&lhs, &rhs).unwrap();
        assert_int32_values(&[3, 2, 1, 0, 9], &quotient);
    }

    #[test]
    fn test_primitive_array_divide_with_nulls() {
        let lhs = Int32Array::from(vec![Some(15), None, Some(8), Some(1), Some(9), None]);
        let rhs = Int32Array::from(vec![Some(5), Some(6), Some(8), Some(9), None, None]);
        let quotient = divide(&lhs, &rhs).unwrap();

        // `Some(v)` slots are checked by value, `None` slots must be null.
        let expected = [Some(3), None, Some(1), Some(0), None, None];
        for (idx, want) in expected.iter().enumerate() {
            match want {
                Some(value) => assert_eq!(*value, quotient.value(idx)),
                None => assert!(quotient.is_null(idx)),
            }
        }
    }

    #[test]
    fn test_primitive_array_divide_by_zero() {
        let lhs = Int32Array::from(vec![15]);
        let rhs = Int32Array::from(vec![0]);
        let failure = divide(&lhs, &rhs)
            .err()
            .expect("divide by zero should fail");
        assert_eq!(ArrowError::DivideByZero, failure);
    }

    #[test]
    fn test_primitive_array_divide_f64() {
        let lhs = Float64Array::from(vec![15.0, 15.0, 8.0]);
        let rhs = Float64Array::from(vec![5.0, 6.0, 8.0]);
        let quotient = divide(&lhs, &rhs).unwrap();
        for (idx, want) in [3.0, 2.5, 1.0].iter().enumerate() {
            assert_eq!(*want, quotient.value(idx));
        }
    }

    #[test]
    fn test_primitive_array_add_with_nulls() {
        let lhs = Int32Array::from(vec![Some(5), None, Some(7), None]);
        let rhs = Int32Array::from(vec![None, None, Some(6), Some(7)]);
        let sum = add(&lhs, &rhs).unwrap();

        // Only slot 2 has both inputs present.
        let expected_nulls = [true, true, false, true];
        for (idx, want_null) in expected_nulls.iter().enumerate() {
            assert_eq!(*want_null, sum.is_null(idx));
        }
        assert_eq!(13, sum.value(2));
    }
}