// polars_compute/arithmetic/mod.rs

1use std::any::TypeId;
2
3use arrow::array::{Array, PrimitiveArray};
4use arrow::bitmap::BitmapBuilder;
5use arrow::types::NativeType;
6
7// Low-level comparison kernel.
8pub trait ArithmeticKernel: Sized + Array {
9    type Scalar;
10    type TrueDivT: NativeType;
11
12    fn wrapping_abs(self) -> Self;
13    fn wrapping_neg(self) -> Self;
14    fn wrapping_add(self, rhs: Self) -> Self;
15    fn wrapping_sub(self, rhs: Self) -> Self;
16    fn wrapping_mul(self, rhs: Self) -> Self;
17    fn wrapping_floor_div(self, rhs: Self) -> Self;
18    fn wrapping_trunc_div(self, rhs: Self) -> Self;
19    fn wrapping_mod(self, rhs: Self) -> Self;
20
21    fn wrapping_add_scalar(self, rhs: Self::Scalar) -> Self;
22    fn wrapping_sub_scalar(self, rhs: Self::Scalar) -> Self;
23    fn wrapping_sub_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self;
24    fn wrapping_mul_scalar(self, rhs: Self::Scalar) -> Self;
25    fn wrapping_floor_div_scalar(self, rhs: Self::Scalar) -> Self;
26    fn wrapping_floor_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self;
27    fn wrapping_trunc_div_scalar(self, rhs: Self::Scalar) -> Self;
28    fn wrapping_trunc_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self;
29    fn wrapping_mod_scalar(self, rhs: Self::Scalar) -> Self;
30    fn wrapping_mod_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self;
31
32    fn checked_mul_scalar(self, rhs: Self::Scalar) -> Self;
33
34    fn true_div(self, rhs: Self) -> PrimitiveArray<Self::TrueDivT>;
35    fn true_div_scalar(self, rhs: Self::Scalar) -> PrimitiveArray<Self::TrueDivT>;
36    fn true_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> PrimitiveArray<Self::TrueDivT>;
37
38    // TODO: remove these.
39    // These are flooring division for integer types, true division for floating point types.
40    fn legacy_div(self, rhs: Self) -> Self {
41        if TypeId::of::<Self>() == TypeId::of::<PrimitiveArray<Self::TrueDivT>>() {
42            let ret = self.true_div(rhs);
43            unsafe {
44                let cast_ret = std::mem::transmute_copy(&ret);
45                std::mem::forget(ret);
46                cast_ret
47            }
48        } else {
49            self.wrapping_floor_div(rhs)
50        }
51    }
52    fn legacy_div_scalar(self, rhs: Self::Scalar) -> Self {
53        if TypeId::of::<Self>() == TypeId::of::<PrimitiveArray<Self::TrueDivT>>() {
54            let ret = self.true_div_scalar(rhs);
55            unsafe {
56                let cast_ret = std::mem::transmute_copy(&ret);
57                std::mem::forget(ret);
58                cast_ret
59            }
60        } else {
61            self.wrapping_floor_div_scalar(rhs)
62        }
63    }
64
65    fn legacy_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self {
66        if TypeId::of::<Self>() == TypeId::of::<PrimitiveArray<Self::TrueDivT>>() {
67            let ret = ArithmeticKernel::true_div_scalar_lhs(lhs, rhs);
68            unsafe {
69                let cast_ret = std::mem::transmute_copy(&ret);
70                std::mem::forget(ret);
71                cast_ret
72            }
73        } else {
74            ArithmeticKernel::wrapping_floor_div_scalar_lhs(lhs, rhs)
75        }
76    }
77}
78
79// Proxy trait so one can bound T: HasPrimitiveArithmeticKernel. Sadly Rust
80// doesn't support adding supertraits for other types.
81#[allow(private_bounds)]
82pub trait HasPrimitiveArithmeticKernel: NativeType + PrimitiveArithmeticKernelImpl {}
83impl<T: NativeType + PrimitiveArithmeticKernelImpl> HasPrimitiveArithmeticKernel for T {}
84
85use PrimitiveArray as PArr;
86use num_traits::{CheckedMul, WrappingMul};
87use polars_utils::vec::PushUnchecked;
88
89#[doc(hidden)]
90pub trait PrimitiveArithmeticKernelImpl: NativeType {
91    type TrueDivT: NativeType;
92
93    fn prim_wrapping_abs(lhs: PArr<Self>) -> PArr<Self>;
94    fn prim_wrapping_neg(lhs: PArr<Self>) -> PArr<Self>;
95    fn prim_wrapping_add(lhs: PArr<Self>, rhs: PArr<Self>) -> PArr<Self>;
96    fn prim_wrapping_sub(lhs: PArr<Self>, rhs: PArr<Self>) -> PArr<Self>;
97    fn prim_wrapping_mul(lhs: PArr<Self>, rhs: PArr<Self>) -> PArr<Self>;
98    fn prim_wrapping_floor_div(lhs: PArr<Self>, rhs: PArr<Self>) -> PArr<Self>;
99    fn prim_wrapping_trunc_div(lhs: PArr<Self>, rhs: PArr<Self>) -> PArr<Self>;
100    fn prim_wrapping_mod(lhs: PArr<Self>, rhs: PArr<Self>) -> PArr<Self>;
101
102    fn prim_wrapping_add_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self>;
103    fn prim_wrapping_sub_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self>;
104    fn prim_wrapping_sub_scalar_lhs(lhs: Self, rhs: PArr<Self>) -> PArr<Self>;
105    fn prim_wrapping_mul_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self>;
106    fn prim_wrapping_floor_div_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self>;
107    fn prim_wrapping_floor_div_scalar_lhs(lhs: Self, rhs: PArr<Self>) -> PArr<Self>;
108    fn prim_wrapping_trunc_div_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self>;
109    fn prim_wrapping_trunc_div_scalar_lhs(lhs: Self, rhs: PArr<Self>) -> PArr<Self>;
110    fn prim_wrapping_mod_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self>;
111    fn prim_wrapping_mod_scalar_lhs(lhs: Self, rhs: PArr<Self>) -> PArr<Self>;
112
113    fn prim_checked_mul_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self>;
114
115    fn prim_true_div(lhs: PArr<Self>, rhs: PArr<Self>) -> PArr<Self::TrueDivT>;
116    fn prim_true_div_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self::TrueDivT>;
117    fn prim_true_div_scalar_lhs(lhs: Self, rhs: PArr<Self>) -> PArr<Self::TrueDivT>;
118}
119
120#[rustfmt::skip]
121impl<T: HasPrimitiveArithmeticKernel> ArithmeticKernel for PrimitiveArray<T> {
122    type Scalar = T;
123    type TrueDivT = T::TrueDivT;
124
125    fn wrapping_abs(self) -> Self { T::prim_wrapping_abs(self) }
126    fn wrapping_neg(self) -> Self { T::prim_wrapping_neg(self) }
127    fn wrapping_add(self, rhs: Self) -> Self { T::prim_wrapping_add(self, rhs) }
128    fn wrapping_sub(self, rhs: Self) -> Self { T::prim_wrapping_sub(self, rhs) }
129    fn wrapping_mul(self, rhs: Self) -> Self { T::prim_wrapping_mul(self, rhs) }
130    fn wrapping_floor_div(self, rhs: Self) -> Self { T::prim_wrapping_floor_div(self, rhs) }
131    fn wrapping_trunc_div(self, rhs: Self) -> Self { T::prim_wrapping_trunc_div(self, rhs) }
132    fn wrapping_mod(self, rhs: Self) -> Self { T::prim_wrapping_mod(self, rhs) }
133
134    fn wrapping_add_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_add_scalar(self, rhs) }
135    fn wrapping_sub_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_sub_scalar(self, rhs) }
136    fn wrapping_sub_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self { T::prim_wrapping_sub_scalar_lhs(lhs, rhs) }
137    fn wrapping_mul_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_mul_scalar(self, rhs) }
138    fn wrapping_floor_div_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_floor_div_scalar(self, rhs) }
139    fn wrapping_floor_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self { T::prim_wrapping_floor_div_scalar_lhs(lhs, rhs) }
140    fn wrapping_trunc_div_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_trunc_div_scalar(self, rhs) }
141    fn wrapping_trunc_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self { T::prim_wrapping_trunc_div_scalar_lhs(lhs, rhs) }
142    fn wrapping_mod_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_mod_scalar(self, rhs) }
143    fn wrapping_mod_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self { T::prim_wrapping_mod_scalar_lhs(lhs, rhs) }
144
145    fn checked_mul_scalar(self, rhs: Self::Scalar) -> Self { T::prim_checked_mul_scalar(self, rhs) }
146
147    fn true_div(self, rhs: Self) -> PrimitiveArray<Self::TrueDivT> { T::prim_true_div(self, rhs) }
148    fn true_div_scalar(self, rhs: Self::Scalar) -> PrimitiveArray<Self::TrueDivT> { T::prim_true_div_scalar(self, rhs) }
149    fn true_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> PrimitiveArray<Self::TrueDivT> { T::prim_true_div_scalar_lhs(lhs, rhs) }
150}
151
152mod float;
153pub mod pl_num;
154mod signed;
155mod unsigned;
156
157fn prim_checked_mul_scalar<I: NativeType + CheckedMul + WrappingMul>(
158    array: &PrimitiveArray<I>,
159    factor: I,
160) -> PrimitiveArray<I> {
161    let values = array.values();
162    let mut out = Vec::with_capacity(array.len());
163    let mut i = 0;
164
165    while i < array.len() && values[i].checked_mul(&factor).is_some() {
166        // SAFETY: We allocated enough before.
167        unsafe { out.push_unchecked(values[i].wrapping_mul(&factor)) };
168        i += 1;
169    }
170
171    if out.len() == array.len() {
172        return PrimitiveArray::<I>::new(
173            I::PRIMITIVE.into(),
174            out.into(),
175            array.validity().cloned(),
176        );
177    }
178
179    let mut validity = BitmapBuilder::with_capacity(array.len());
180    validity.extend_constant(out.len(), true);
181
182    for &value in &values[out.len()..] {
183        // SAFETY: We allocated enough before.
184        unsafe {
185            out.push_unchecked(value.wrapping_mul(&factor));
186            validity.push_unchecked(value.checked_mul(&factor).is_some());
187        }
188    }
189
190    debug_assert_eq!(out.len(), array.len());
191    debug_assert_eq!(validity.len(), array.len());
192
193    let validity = validity.freeze();
194    let validity = match array.validity() {
195        None => validity,
196        Some(arr_validity) => arrow::bitmap::and(&validity, arr_validity),
197    };
198
199    PrimitiveArray::<I>::new(I::PRIMITIVE.into(), out.into(), Some(validity))
200}