polars_plan/dsl/function_expr/pow.rs

1use num_traits::pow::Pow;
2use num_traits::{Float, One, ToPrimitive, Zero};
3use polars_core::prelude::arity::{broadcast_binary_elementwise, unary_elementwise_values};
4use polars_core::with_match_physical_integer_type;
5
6use super::*;
7
/// The power-family functions available in the expression DSL.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum PowFunction {
    /// `base ** exponent` with an arbitrary exponent expression (displayed as `pow`).
    Generic,
    /// Square root (displayed as `sqrt`).
    Sqrt,
    /// Cube root (displayed as `cbrt`).
    Cbrt,
}
15
16impl PowFunction {
17    pub fn function_options(&self) -> FunctionOptions {
18        use PowFunction as P;
19        match self {
20            P::Generic | P::Sqrt | P::Cbrt => FunctionOptions::elementwise(),
21        }
22    }
23}
24
25impl Display for PowFunction {
26    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
27        use self::*;
28        match self {
29            PowFunction::Generic => write!(f, "pow"),
30            PowFunction::Sqrt => write!(f, "sqrt"),
31            PowFunction::Cbrt => write!(f, "cbrt"),
32        }
33    }
34}
35
36impl From<PowFunction> for FunctionExpr {
37    fn from(value: PowFunction) -> Self {
38        Self::Pow(value)
39    }
40}
41
42fn pow_on_chunked_arrays<T, F>(
43    base: &ChunkedArray<T>,
44    exponent: &ChunkedArray<F>,
45) -> ChunkedArray<T>
46where
47    T: PolarsNumericType,
48    F: PolarsNumericType,
49    T::Native: Pow<F::Native, Output = T::Native> + ToPrimitive,
50{
51    if exponent.len() == 1 {
52        if let Some(e) = exponent.get(0) {
53            if e == F::Native::zero() {
54                return unary_elementwise_values(base, |_| T::Native::one());
55            }
56            if e == F::Native::one() {
57                return base.clone();
58            }
59            if e == F::Native::one() + F::Native::one() {
60                return base * base;
61            }
62        }
63    }
64
65    broadcast_binary_elementwise(base, exponent, |b, e| Some(Pow::pow(b?, e?)))
66}
67
68fn pow_on_floats<T>(
69    base: &ChunkedArray<T>,
70    exponent: &ChunkedArray<T>,
71) -> PolarsResult<Option<Column>>
72where
73    T: PolarsFloatType,
74    T::Native: Pow<T::Native, Output = T::Native> + ToPrimitive + Float,
75    ChunkedArray<T>: IntoColumn,
76{
77    let dtype = T::get_dtype();
78
79    if exponent.len() == 1 {
80        let Some(exponent_value) = exponent.get(0) else {
81            return Ok(Some(Column::full_null(
82                base.name().clone(),
83                base.len(),
84                &dtype,
85            )));
86        };
87        let s = match exponent_value.to_f64().unwrap() {
88            1.0 => base.clone().into_column(),
89            // specialized sqrt will ensure (-inf)^0.5 = NaN
90            // and will likely be faster as well.
91            0.5 => base.apply_values(|v| v.sqrt()).into_column(),
92            a if a.fract() == 0.0 && a < 10.0 && a > 1.0 => {
93                let mut out = base.clone();
94
95                for _ in 1..exponent_value.to_u8().unwrap() {
96                    out = out * base.clone()
97                }
98                out.into_column()
99            },
100            _ => base
101                .apply_values(|v| Pow::pow(v, exponent_value))
102                .into_column(),
103        };
104        Ok(Some(s))
105    } else {
106        Ok(Some(pow_on_chunked_arrays(base, exponent).into_column()))
107    }
108}
109
110fn pow_to_uint_dtype<T, F>(
111    base: &ChunkedArray<T>,
112    exponent: &ChunkedArray<F>,
113) -> PolarsResult<Option<Column>>
114where
115    T: PolarsIntegerType,
116    F: PolarsIntegerType,
117    T::Native: Pow<F::Native, Output = T::Native> + ToPrimitive,
118    ChunkedArray<T>: IntoColumn,
119{
120    let dtype = T::get_dtype();
121
122    if exponent.len() == 1 {
123        let Some(exponent_value) = exponent.get(0) else {
124            return Ok(Some(Column::full_null(
125                base.name().clone(),
126                base.len(),
127                &dtype,
128            )));
129        };
130        let s = match exponent_value.to_u64().unwrap() {
131            1 => base.clone().into_column(),
132            2..=10 => {
133                let mut out = base.clone();
134
135                for _ in 1..exponent_value.to_u8().unwrap() {
136                    out = out * base.clone()
137                }
138                out.into_column()
139            },
140            _ => base
141                .apply_values(|v| Pow::pow(v, exponent_value))
142                .into_column(),
143        };
144        Ok(Some(s))
145    } else {
146        Ok(Some(pow_on_chunked_arrays(base, exponent).into_column()))
147    }
148}
149
150fn pow_on_series(base: &Column, exponent: &Column) -> PolarsResult<Option<Column>> {
151    use DataType::*;
152
153    let base_dtype = base.dtype();
154    polars_ensure!(
155        base_dtype.is_primitive_numeric(),
156        InvalidOperation: "`pow` operation not supported for dtype `{}` as base", base_dtype
157    );
158    let exponent_dtype = exponent.dtype();
159    polars_ensure!(
160        exponent_dtype.is_primitive_numeric(),
161        InvalidOperation: "`pow` operation not supported for dtype `{}` as exponent", exponent_dtype
162    );
163
164    // if false, dtype is float
165    if base_dtype.is_integer() {
166        with_match_physical_integer_type!(base_dtype, |$native_type| {
167            if exponent_dtype.is_float() {
168                match exponent_dtype {
169                    Float32 => {
170                        let ca = base.cast(&DataType::Float32)?;
171                        pow_on_floats(ca.f32().unwrap(), exponent.f32().unwrap())
172                    },
173                    Float64 => {
174                        let ca = base.cast(&DataType::Float64)?;
175                        pow_on_floats(ca.f64().unwrap(), exponent.f64().unwrap())
176                    },
177                    _ => unreachable!(),
178                }
179            } else {
180                let ca = base.$native_type().unwrap();
181                let exponent = exponent.strict_cast(&DataType::UInt32).map_err(|err| polars_err!(
182                    InvalidOperation:
183                    "{}\n\nHint: if you were trying to raise an integer to a negative integer power, please cast your base or exponent to float first.",
184                    err
185                ))?;
186                pow_to_uint_dtype(ca, exponent.u32().unwrap())
187            }
188        })
189    } else {
190        match base_dtype {
191            Float32 => {
192                let ca = base.f32().unwrap();
193                let exponent = exponent.strict_cast(&DataType::Float32)?;
194                pow_on_floats(ca, exponent.f32().unwrap())
195            },
196            Float64 => {
197                let ca = base.f64().unwrap();
198                let exponent = exponent.strict_cast(&DataType::Float64)?;
199                pow_on_floats(ca, exponent.f64().unwrap())
200            },
201            _ => unreachable!(),
202        }
203    }
204}
205
206pub(super) fn pow(s: &mut [Column]) -> PolarsResult<Option<Column>> {
207    let base = &s[0];
208    let exponent = &s[1];
209
210    let base_len = base.len();
211    let exp_len = exponent.len();
212    match (base_len, exp_len) {
213        (1, _) | (_, 1) => pow_on_series(base, exponent),
214        (len_a, len_b) if len_a == len_b => pow_on_series(base, exponent),
215        _ => polars_bail!(
216            ComputeError:
217            "exponent shape: {} in `pow` expression does not match that of the base: {}",
218            exp_len, base_len,
219        ),
220    }
221}
222
223pub(super) fn sqrt(base: &Column) -> PolarsResult<Column> {
224    use DataType::*;
225    match base.dtype() {
226        Float32 => {
227            let ca = base.f32().unwrap();
228            sqrt_on_floats(ca)
229        },
230        Float64 => {
231            let ca = base.f64().unwrap();
232            sqrt_on_floats(ca)
233        },
234        _ => {
235            let base = base.cast(&DataType::Float64)?;
236            sqrt(&base)
237        },
238    }
239}
240
241fn sqrt_on_floats<T>(base: &ChunkedArray<T>) -> PolarsResult<Column>
242where
243    T: PolarsFloatType,
244    T::Native: Pow<T::Native, Output = T::Native> + ToPrimitive + Float,
245    ChunkedArray<T>: IntoColumn,
246{
247    Ok(base.apply_values(|v| v.sqrt()).into_column())
248}
249
250pub(super) fn cbrt(base: &Column) -> PolarsResult<Column> {
251    use DataType::*;
252    match base.dtype() {
253        Float32 => {
254            let ca = base.f32().unwrap();
255            cbrt_on_floats(ca)
256        },
257        Float64 => {
258            let ca = base.f64().unwrap();
259            cbrt_on_floats(ca)
260        },
261        _ => {
262            let base = base.cast(&DataType::Float64)?;
263            cbrt(&base)
264        },
265    }
266}
267
268fn cbrt_on_floats<T>(base: &ChunkedArray<T>) -> PolarsResult<Column>
269where
270    T: PolarsFloatType,
271    T::Native: Pow<T::Native, Output = T::Native> + ToPrimitive + Float,
272    ChunkedArray<T>: IntoColumn,
273{
274    Ok(base.apply_values(|v| v.cbrt()).into_column())
275}