polars_plan/dsl/function_expr/
pow.rs1use num_traits::pow::Pow;
2use num_traits::{Float, One, ToPrimitive, Zero};
3use polars_core::prelude::arity::{broadcast_binary_elementwise, unary_elementwise_values};
4use polars_core::with_match_physical_integer_type;
5
6use super::*;
7
/// Selects which power-style operation a `FunctionExpr::Pow` node performs.
///
/// All variants are elementwise (see `PowFunction::function_options`).
// NOTE(review): with the `serde` feature enabled this enum is (de)serialized; keep the variant
// order stable in case a non-self-describing format encodes variants by index.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Copy, PartialEq, Debug, Eq, Hash)]
pub enum PowFunction {
    /// General power operation; displayed as `pow`.
    Generic,
    /// Square root; displayed as `sqrt`.
    Sqrt,
    /// Cube root; displayed as `cbrt`.
    Cbrt,
}
15
16impl PowFunction {
17 pub fn function_options(&self) -> FunctionOptions {
18 use PowFunction as P;
19 match self {
20 P::Generic | P::Sqrt | P::Cbrt => FunctionOptions::elementwise(),
21 }
22 }
23}
24
25impl Display for PowFunction {
26 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
27 use self::*;
28 match self {
29 PowFunction::Generic => write!(f, "pow"),
30 PowFunction::Sqrt => write!(f, "sqrt"),
31 PowFunction::Cbrt => write!(f, "cbrt"),
32 }
33 }
34}
35
36impl From<PowFunction> for FunctionExpr {
37 fn from(value: PowFunction) -> Self {
38 Self::Pow(value)
39 }
40}
41
42fn pow_on_chunked_arrays<T, F>(
43 base: &ChunkedArray<T>,
44 exponent: &ChunkedArray<F>,
45) -> ChunkedArray<T>
46where
47 T: PolarsNumericType,
48 F: PolarsNumericType,
49 T::Native: Pow<F::Native, Output = T::Native> + ToPrimitive,
50{
51 if exponent.len() == 1 {
52 if let Some(e) = exponent.get(0) {
53 if e == F::Native::zero() {
54 return unary_elementwise_values(base, |_| T::Native::one());
55 }
56 if e == F::Native::one() {
57 return base.clone();
58 }
59 if e == F::Native::one() + F::Native::one() {
60 return base * base;
61 }
62 }
63 }
64
65 broadcast_binary_elementwise(base, exponent, |b, e| Some(Pow::pow(b?, e?)))
66}
67
68fn pow_on_floats<T>(
69 base: &ChunkedArray<T>,
70 exponent: &ChunkedArray<T>,
71) -> PolarsResult<Option<Column>>
72where
73 T: PolarsFloatType,
74 T::Native: Pow<T::Native, Output = T::Native> + ToPrimitive + Float,
75 ChunkedArray<T>: IntoColumn,
76{
77 let dtype = T::get_dtype();
78
79 if exponent.len() == 1 {
80 let Some(exponent_value) = exponent.get(0) else {
81 return Ok(Some(Column::full_null(
82 base.name().clone(),
83 base.len(),
84 &dtype,
85 )));
86 };
87 let s = match exponent_value.to_f64().unwrap() {
88 1.0 => base.clone().into_column(),
89 0.5 => base.apply_values(|v| v.sqrt()).into_column(),
92 a if a.fract() == 0.0 && a < 10.0 && a > 1.0 => {
93 let mut out = base.clone();
94
95 for _ in 1..exponent_value.to_u8().unwrap() {
96 out = out * base.clone()
97 }
98 out.into_column()
99 },
100 _ => base
101 .apply_values(|v| Pow::pow(v, exponent_value))
102 .into_column(),
103 };
104 Ok(Some(s))
105 } else {
106 Ok(Some(pow_on_chunked_arrays(base, exponent).into_column()))
107 }
108}
109
110fn pow_to_uint_dtype<T, F>(
111 base: &ChunkedArray<T>,
112 exponent: &ChunkedArray<F>,
113) -> PolarsResult<Option<Column>>
114where
115 T: PolarsIntegerType,
116 F: PolarsIntegerType,
117 T::Native: Pow<F::Native, Output = T::Native> + ToPrimitive,
118 ChunkedArray<T>: IntoColumn,
119{
120 let dtype = T::get_dtype();
121
122 if exponent.len() == 1 {
123 let Some(exponent_value) = exponent.get(0) else {
124 return Ok(Some(Column::full_null(
125 base.name().clone(),
126 base.len(),
127 &dtype,
128 )));
129 };
130 let s = match exponent_value.to_u64().unwrap() {
131 1 => base.clone().into_column(),
132 2..=10 => {
133 let mut out = base.clone();
134
135 for _ in 1..exponent_value.to_u8().unwrap() {
136 out = out * base.clone()
137 }
138 out.into_column()
139 },
140 _ => base
141 .apply_values(|v| Pow::pow(v, exponent_value))
142 .into_column(),
143 };
144 Ok(Some(s))
145 } else {
146 Ok(Some(pow_on_chunked_arrays(base, exponent).into_column()))
147 }
148}
149
150fn pow_on_series(base: &Column, exponent: &Column) -> PolarsResult<Option<Column>> {
151 use DataType::*;
152
153 let base_dtype = base.dtype();
154 polars_ensure!(
155 base_dtype.is_primitive_numeric(),
156 InvalidOperation: "`pow` operation not supported for dtype `{}` as base", base_dtype
157 );
158 let exponent_dtype = exponent.dtype();
159 polars_ensure!(
160 exponent_dtype.is_primitive_numeric(),
161 InvalidOperation: "`pow` operation not supported for dtype `{}` as exponent", exponent_dtype
162 );
163
164 if base_dtype.is_integer() {
166 with_match_physical_integer_type!(base_dtype, |$native_type| {
167 if exponent_dtype.is_float() {
168 match exponent_dtype {
169 Float32 => {
170 let ca = base.cast(&DataType::Float32)?;
171 pow_on_floats(ca.f32().unwrap(), exponent.f32().unwrap())
172 },
173 Float64 => {
174 let ca = base.cast(&DataType::Float64)?;
175 pow_on_floats(ca.f64().unwrap(), exponent.f64().unwrap())
176 },
177 _ => unreachable!(),
178 }
179 } else {
180 let ca = base.$native_type().unwrap();
181 let exponent = exponent.strict_cast(&DataType::UInt32).map_err(|err| polars_err!(
182 InvalidOperation:
183 "{}\n\nHint: if you were trying to raise an integer to a negative integer power, please cast your base or exponent to float first.",
184 err
185 ))?;
186 pow_to_uint_dtype(ca, exponent.u32().unwrap())
187 }
188 })
189 } else {
190 match base_dtype {
191 Float32 => {
192 let ca = base.f32().unwrap();
193 let exponent = exponent.strict_cast(&DataType::Float32)?;
194 pow_on_floats(ca, exponent.f32().unwrap())
195 },
196 Float64 => {
197 let ca = base.f64().unwrap();
198 let exponent = exponent.strict_cast(&DataType::Float64)?;
199 pow_on_floats(ca, exponent.f64().unwrap())
200 },
201 _ => unreachable!(),
202 }
203 }
204}
205
206pub(super) fn pow(s: &mut [Column]) -> PolarsResult<Option<Column>> {
207 let base = &s[0];
208 let exponent = &s[1];
209
210 let base_len = base.len();
211 let exp_len = exponent.len();
212 match (base_len, exp_len) {
213 (1, _) | (_, 1) => pow_on_series(base, exponent),
214 (len_a, len_b) if len_a == len_b => pow_on_series(base, exponent),
215 _ => polars_bail!(
216 ComputeError:
217 "exponent shape: {} in `pow` expression does not match that of the base: {}",
218 exp_len, base_len,
219 ),
220 }
221}
222
223pub(super) fn sqrt(base: &Column) -> PolarsResult<Column> {
224 use DataType::*;
225 match base.dtype() {
226 Float32 => {
227 let ca = base.f32().unwrap();
228 sqrt_on_floats(ca)
229 },
230 Float64 => {
231 let ca = base.f64().unwrap();
232 sqrt_on_floats(ca)
233 },
234 _ => {
235 let base = base.cast(&DataType::Float64)?;
236 sqrt(&base)
237 },
238 }
239}
240
241fn sqrt_on_floats<T>(base: &ChunkedArray<T>) -> PolarsResult<Column>
242where
243 T: PolarsFloatType,
244 T::Native: Pow<T::Native, Output = T::Native> + ToPrimitive + Float,
245 ChunkedArray<T>: IntoColumn,
246{
247 Ok(base.apply_values(|v| v.sqrt()).into_column())
248}
249
250pub(super) fn cbrt(base: &Column) -> PolarsResult<Column> {
251 use DataType::*;
252 match base.dtype() {
253 Float32 => {
254 let ca = base.f32().unwrap();
255 cbrt_on_floats(ca)
256 },
257 Float64 => {
258 let ca = base.f64().unwrap();
259 cbrt_on_floats(ca)
260 },
261 _ => {
262 let base = base.cast(&DataType::Float64)?;
263 cbrt(&base)
264 },
265 }
266}
267
268fn cbrt_on_floats<T>(base: &ChunkedArray<T>) -> PolarsResult<Column>
269where
270 T: PolarsFloatType,
271 T::Native: Pow<T::Native, Output = T::Native> + ToPrimitive + Float,
272 ChunkedArray<T>: IntoColumn,
273{
274 Ok(base.apply_values(|v| v.cbrt()).into_column())
275}