polars_compute/arithmetic/
mod.rs1use std::any::TypeId;
2
3use arrow::array::{Array, PrimitiveArray};
4use arrow::bitmap::BitmapBuilder;
5use arrow::types::NativeType;
6
7pub trait ArithmeticKernel: Sized + Array {
9 type Scalar;
10 type TrueDivT: NativeType;
11
12 fn wrapping_abs(self) -> Self;
13 fn wrapping_neg(self) -> Self;
14 fn wrapping_add(self, rhs: Self) -> Self;
15 fn wrapping_sub(self, rhs: Self) -> Self;
16 fn wrapping_mul(self, rhs: Self) -> Self;
17 fn wrapping_floor_div(self, rhs: Self) -> Self;
18 fn wrapping_trunc_div(self, rhs: Self) -> Self;
19 fn wrapping_mod(self, rhs: Self) -> Self;
20
21 fn wrapping_add_scalar(self, rhs: Self::Scalar) -> Self;
22 fn wrapping_sub_scalar(self, rhs: Self::Scalar) -> Self;
23 fn wrapping_sub_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self;
24 fn wrapping_mul_scalar(self, rhs: Self::Scalar) -> Self;
25 fn wrapping_floor_div_scalar(self, rhs: Self::Scalar) -> Self;
26 fn wrapping_floor_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self;
27 fn wrapping_trunc_div_scalar(self, rhs: Self::Scalar) -> Self;
28 fn wrapping_trunc_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self;
29 fn wrapping_mod_scalar(self, rhs: Self::Scalar) -> Self;
30 fn wrapping_mod_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self;
31
32 fn checked_mul_scalar(self, rhs: Self::Scalar) -> Self;
33
34 fn true_div(self, rhs: Self) -> PrimitiveArray<Self::TrueDivT>;
35 fn true_div_scalar(self, rhs: Self::Scalar) -> PrimitiveArray<Self::TrueDivT>;
36 fn true_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> PrimitiveArray<Self::TrueDivT>;
37
38 fn legacy_div(self, rhs: Self) -> Self {
41 if TypeId::of::<Self>() == TypeId::of::<PrimitiveArray<Self::TrueDivT>>() {
42 let ret = self.true_div(rhs);
43 unsafe {
44 let cast_ret = std::mem::transmute_copy(&ret);
45 std::mem::forget(ret);
46 cast_ret
47 }
48 } else {
49 self.wrapping_floor_div(rhs)
50 }
51 }
52 fn legacy_div_scalar(self, rhs: Self::Scalar) -> Self {
53 if TypeId::of::<Self>() == TypeId::of::<PrimitiveArray<Self::TrueDivT>>() {
54 let ret = self.true_div_scalar(rhs);
55 unsafe {
56 let cast_ret = std::mem::transmute_copy(&ret);
57 std::mem::forget(ret);
58 cast_ret
59 }
60 } else {
61 self.wrapping_floor_div_scalar(rhs)
62 }
63 }
64
65 fn legacy_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self {
66 if TypeId::of::<Self>() == TypeId::of::<PrimitiveArray<Self::TrueDivT>>() {
67 let ret = ArithmeticKernel::true_div_scalar_lhs(lhs, rhs);
68 unsafe {
69 let cast_ret = std::mem::transmute_copy(&ret);
70 std::mem::forget(ret);
71 cast_ret
72 }
73 } else {
74 ArithmeticKernel::wrapping_floor_div_scalar_lhs(lhs, rhs)
75 }
76 }
77}
78
79#[allow(private_bounds)]
82pub trait HasPrimitiveArithmeticKernel: NativeType + PrimitiveArithmeticKernelImpl {}
83impl<T: NativeType + PrimitiveArithmeticKernelImpl> HasPrimitiveArithmeticKernel for T {}
84
85use PrimitiveArray as PArr;
86use num_traits::{CheckedMul, WrappingMul};
87use polars_utils::vec::PushUnchecked;
88
89#[doc(hidden)]
90pub trait PrimitiveArithmeticKernelImpl: NativeType {
91 type TrueDivT: NativeType;
92
93 fn prim_wrapping_abs(lhs: PArr<Self>) -> PArr<Self>;
94 fn prim_wrapping_neg(lhs: PArr<Self>) -> PArr<Self>;
95 fn prim_wrapping_add(lhs: PArr<Self>, rhs: PArr<Self>) -> PArr<Self>;
96 fn prim_wrapping_sub(lhs: PArr<Self>, rhs: PArr<Self>) -> PArr<Self>;
97 fn prim_wrapping_mul(lhs: PArr<Self>, rhs: PArr<Self>) -> PArr<Self>;
98 fn prim_wrapping_floor_div(lhs: PArr<Self>, rhs: PArr<Self>) -> PArr<Self>;
99 fn prim_wrapping_trunc_div(lhs: PArr<Self>, rhs: PArr<Self>) -> PArr<Self>;
100 fn prim_wrapping_mod(lhs: PArr<Self>, rhs: PArr<Self>) -> PArr<Self>;
101
102 fn prim_wrapping_add_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self>;
103 fn prim_wrapping_sub_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self>;
104 fn prim_wrapping_sub_scalar_lhs(lhs: Self, rhs: PArr<Self>) -> PArr<Self>;
105 fn prim_wrapping_mul_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self>;
106 fn prim_wrapping_floor_div_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self>;
107 fn prim_wrapping_floor_div_scalar_lhs(lhs: Self, rhs: PArr<Self>) -> PArr<Self>;
108 fn prim_wrapping_trunc_div_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self>;
109 fn prim_wrapping_trunc_div_scalar_lhs(lhs: Self, rhs: PArr<Self>) -> PArr<Self>;
110 fn prim_wrapping_mod_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self>;
111 fn prim_wrapping_mod_scalar_lhs(lhs: Self, rhs: PArr<Self>) -> PArr<Self>;
112
113 fn prim_checked_mul_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self>;
114
115 fn prim_true_div(lhs: PArr<Self>, rhs: PArr<Self>) -> PArr<Self::TrueDivT>;
116 fn prim_true_div_scalar(lhs: PArr<Self>, rhs: Self) -> PArr<Self::TrueDivT>;
117 fn prim_true_div_scalar_lhs(lhs: Self, rhs: PArr<Self>) -> PArr<Self::TrueDivT>;
118}
119
120#[rustfmt::skip]
121impl<T: HasPrimitiveArithmeticKernel> ArithmeticKernel for PrimitiveArray<T> {
122 type Scalar = T;
123 type TrueDivT = T::TrueDivT;
124
125 fn wrapping_abs(self) -> Self { T::prim_wrapping_abs(self) }
126 fn wrapping_neg(self) -> Self { T::prim_wrapping_neg(self) }
127 fn wrapping_add(self, rhs: Self) -> Self { T::prim_wrapping_add(self, rhs) }
128 fn wrapping_sub(self, rhs: Self) -> Self { T::prim_wrapping_sub(self, rhs) }
129 fn wrapping_mul(self, rhs: Self) -> Self { T::prim_wrapping_mul(self, rhs) }
130 fn wrapping_floor_div(self, rhs: Self) -> Self { T::prim_wrapping_floor_div(self, rhs) }
131 fn wrapping_trunc_div(self, rhs: Self) -> Self { T::prim_wrapping_trunc_div(self, rhs) }
132 fn wrapping_mod(self, rhs: Self) -> Self { T::prim_wrapping_mod(self, rhs) }
133
134 fn wrapping_add_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_add_scalar(self, rhs) }
135 fn wrapping_sub_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_sub_scalar(self, rhs) }
136 fn wrapping_sub_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self { T::prim_wrapping_sub_scalar_lhs(lhs, rhs) }
137 fn wrapping_mul_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_mul_scalar(self, rhs) }
138 fn wrapping_floor_div_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_floor_div_scalar(self, rhs) }
139 fn wrapping_floor_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self { T::prim_wrapping_floor_div_scalar_lhs(lhs, rhs) }
140 fn wrapping_trunc_div_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_trunc_div_scalar(self, rhs) }
141 fn wrapping_trunc_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self { T::prim_wrapping_trunc_div_scalar_lhs(lhs, rhs) }
142 fn wrapping_mod_scalar(self, rhs: Self::Scalar) -> Self { T::prim_wrapping_mod_scalar(self, rhs) }
143 fn wrapping_mod_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> Self { T::prim_wrapping_mod_scalar_lhs(lhs, rhs) }
144
145 fn checked_mul_scalar(self, rhs: Self::Scalar) -> Self { T::prim_checked_mul_scalar(self, rhs) }
146
147 fn true_div(self, rhs: Self) -> PrimitiveArray<Self::TrueDivT> { T::prim_true_div(self, rhs) }
148 fn true_div_scalar(self, rhs: Self::Scalar) -> PrimitiveArray<Self::TrueDivT> { T::prim_true_div_scalar(self, rhs) }
149 fn true_div_scalar_lhs(lhs: Self::Scalar, rhs: Self) -> PrimitiveArray<Self::TrueDivT> { T::prim_true_div_scalar_lhs(lhs, rhs) }
150}
151
152mod float;
153pub mod pl_num;
154mod signed;
155mod unsigned;
156
157fn prim_checked_mul_scalar<I: NativeType + CheckedMul + WrappingMul>(
158 array: &PrimitiveArray<I>,
159 factor: I,
160) -> PrimitiveArray<I> {
161 let values = array.values();
162 let mut out = Vec::with_capacity(array.len());
163 let mut i = 0;
164
165 while i < array.len() && values[i].checked_mul(&factor).is_some() {
166 unsafe { out.push_unchecked(values[i].wrapping_mul(&factor)) };
168 i += 1;
169 }
170
171 if out.len() == array.len() {
172 return PrimitiveArray::<I>::new(
173 I::PRIMITIVE.into(),
174 out.into(),
175 array.validity().cloned(),
176 );
177 }
178
179 let mut validity = BitmapBuilder::with_capacity(array.len());
180 validity.extend_constant(out.len(), true);
181
182 for &value in &values[out.len()..] {
183 unsafe {
185 out.push_unchecked(value.wrapping_mul(&factor));
186 validity.push_unchecked(value.checked_mul(&factor).is_some());
187 }
188 }
189
190 debug_assert_eq!(out.len(), array.len());
191 debug_assert_eq!(validity.len(), array.len());
192
193 let validity = validity.freeze();
194 let validity = match array.validity() {
195 None => validity,
196 Some(arr_validity) => arrow::bitmap::and(&validity, arr_validity),
197 };
198
199 PrimitiveArray::<I>::new(I::PRIMITIVE.into(), out.into(), Some(validity))
200}