1use std::fmt::{Debug, Display, Formatter};
2use std::hash::{Hash, Hasher};
3
4use bytes::Bytes;
5use polars_compute::rolling::QuantileMethod;
6use polars_core::chunked_array::cast::CastOptions;
7use polars_core::error::feature_gated;
8use polars_core::prelude::*;
9#[cfg(feature = "serde")]
10use serde::{Deserialize, Serialize};
11
12pub use super::expr_dyn_fn::*;
13use crate::prelude::*;
14
15#[derive(PartialEq, Clone, Hash)]
16#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
17pub enum AggExpr {
18 Min {
19 input: Arc<Expr>,
20 propagate_nans: bool,
21 },
22 Max {
23 input: Arc<Expr>,
24 propagate_nans: bool,
25 },
26 Median(Arc<Expr>),
27 NUnique(Arc<Expr>),
28 First(Arc<Expr>),
29 Last(Arc<Expr>),
30 Mean(Arc<Expr>),
31 Implode(Arc<Expr>),
32 Count(Arc<Expr>, bool),
34 Quantile {
35 expr: Arc<Expr>,
36 quantile: Arc<Expr>,
37 method: QuantileMethod,
38 },
39 Sum(Arc<Expr>),
40 AggGroups(Arc<Expr>),
41 Std(Arc<Expr>, u8),
42 Var(Arc<Expr>, u8),
43}
44
45impl AsRef<Expr> for AggExpr {
46 fn as_ref(&self) -> &Expr {
47 use AggExpr::*;
48 match self {
49 Min { input, .. } => input,
50 Max { input, .. } => input,
51 Median(e) => e,
52 NUnique(e) => e,
53 First(e) => e,
54 Last(e) => e,
55 Mean(e) => e,
56 Implode(e) => e,
57 Count(e, _) => e,
58 Quantile { expr, .. } => expr,
59 Sum(e) => e,
60 AggGroups(e) => e,
61 Std(e, _) => e,
62 Var(e, _) => e,
63 }
64 }
65}
66
67#[derive(Clone, PartialEq)]
73#[must_use]
74#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
75pub enum Expr {
76 Alias(Arc<Expr>, PlSmallStr),
77 Column(PlSmallStr),
78 Columns(Arc<[PlSmallStr]>),
79 DtypeColumn(Vec<DataType>),
80 IndexColumn(Arc<[i64]>),
81 Literal(LiteralValue),
82 BinaryExpr {
83 left: Arc<Expr>,
84 op: Operator,
85 right: Arc<Expr>,
86 },
87 Cast {
88 expr: Arc<Expr>,
89 dtype: DataType,
90 options: CastOptions,
91 },
92 Sort {
93 expr: Arc<Expr>,
94 options: SortOptions,
95 },
96 Gather {
97 expr: Arc<Expr>,
98 idx: Arc<Expr>,
99 returns_scalar: bool,
100 },
101 SortBy {
102 expr: Arc<Expr>,
103 by: Vec<Expr>,
104 sort_options: SortMultipleOptions,
105 },
106 Agg(AggExpr),
107 Ternary {
110 predicate: Arc<Expr>,
111 truthy: Arc<Expr>,
112 falsy: Arc<Expr>,
113 },
114 Function {
115 input: Vec<Expr>,
117 function: FunctionExpr,
119 options: FunctionOptions,
120 },
121 Explode {
122 input: Arc<Expr>,
123 skip_empty: bool,
124 },
125 Filter {
126 input: Arc<Expr>,
127 by: Arc<Expr>,
128 },
129 Window {
131 function: Arc<Expr>,
133 partition_by: Vec<Expr>,
134 order_by: Option<(Arc<Expr>, SortOptions)>,
135 options: WindowType,
136 },
137 Wildcard,
138 Slice {
139 input: Arc<Expr>,
140 offset: Arc<Expr>,
142 length: Arc<Expr>,
143 },
144 Exclude(Arc<Expr>, Vec<Excluded>),
147 KeepName(Arc<Expr>),
149 Len,
150 Nth(i64),
152 RenameAlias {
153 function: SpecialEq<Arc<dyn RenameAliasFn>>,
154 expr: Arc<Expr>,
155 },
156 #[cfg(feature = "dtype-struct")]
157 Field(Arc<[PlSmallStr]>),
158 AnonymousFunction {
159 input: Vec<Expr>,
161 function: OpaqueColumnUdf,
163 output_type: GetOutput,
165 options: FunctionOptions,
166 },
167 SubPlan(SpecialEq<Arc<DslPlan>>, Vec<String>),
168 Selector(super::selector::Selector),
175}
176
177pub type OpaqueColumnUdf = LazySerde<SpecialEq<Arc<dyn ColumnsUdf>>>;
178pub(crate) fn new_column_udf<F: ColumnsUdf + 'static>(func: F) -> OpaqueColumnUdf {
179 LazySerde::Deserialized(SpecialEq::new(Arc::new(func)))
180}
181
182#[derive(Clone)]
183pub enum LazySerde<T: Clone> {
184 Deserialized(T),
185 Bytes(Bytes),
186}
187
188impl<T: PartialEq + Clone> PartialEq for LazySerde<T> {
189 fn eq(&self, other: &Self) -> bool {
190 use LazySerde as L;
191 match (self, other) {
192 (L::Deserialized(a), L::Deserialized(b)) => a == b,
193 (L::Bytes(a), L::Bytes(b)) => {
194 std::ptr::eq(a.as_ptr(), b.as_ptr()) && a.len() == b.len()
195 },
196 _ => false,
197 }
198 }
199}
200
201impl<T: Clone> Debug for LazySerde<T> {
202 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
203 match self {
204 Self::Bytes(_) => write!(f, "lazy-serde<Bytes>"),
205 Self::Deserialized(_) => write!(f, "lazy-serde<T>"),
206 }
207 }
208}
209
210impl OpaqueColumnUdf {
211 pub fn materialize(self) -> PolarsResult<SpecialEq<Arc<dyn ColumnsUdf>>> {
212 match self {
213 Self::Deserialized(t) => Ok(t),
214 Self::Bytes(_b) => {
215 feature_gated!("serde";"python", {
216 crate::dsl::python_dsl::PythonUdfExpression::try_deserialize(_b.as_ref()).map(SpecialEq::new)
217 })
218 },
219 }
220 }
221}
222
223#[allow(clippy::derived_hash_with_manual_eq)]
224impl Hash for Expr {
225 fn hash<H: Hasher>(&self, state: &mut H) {
226 let d = std::mem::discriminant(self);
227 d.hash(state);
228 match self {
229 Expr::Column(name) => name.hash(state),
230 Expr::Columns(names) => names.hash(state),
231 Expr::DtypeColumn(dtypes) => dtypes.hash(state),
232 Expr::IndexColumn(indices) => indices.hash(state),
233 Expr::Literal(lv) => std::mem::discriminant(lv).hash(state),
234 Expr::Selector(s) => s.hash(state),
235 Expr::Nth(v) => v.hash(state),
236 Expr::Filter { input, by } => {
237 input.hash(state);
238 by.hash(state);
239 },
240 Expr::BinaryExpr { left, op, right } => {
241 left.hash(state);
242 right.hash(state);
243 std::mem::discriminant(op).hash(state)
244 },
245 Expr::Cast {
246 expr,
247 dtype,
248 options: strict,
249 } => {
250 expr.hash(state);
251 dtype.hash(state);
252 strict.hash(state)
253 },
254 Expr::Sort { expr, options } => {
255 expr.hash(state);
256 options.hash(state);
257 },
258 Expr::Alias(input, name) => {
259 input.hash(state);
260 name.hash(state)
261 },
262 Expr::KeepName(input) => input.hash(state),
263 Expr::Ternary {
264 predicate,
265 truthy,
266 falsy,
267 } => {
268 predicate.hash(state);
269 truthy.hash(state);
270 falsy.hash(state);
271 },
272 Expr::Function {
273 input,
274 function,
275 options,
276 } => {
277 input.hash(state);
278 std::mem::discriminant(function).hash(state);
279 options.hash(state);
280 },
281 Expr::Gather {
282 expr,
283 idx,
284 returns_scalar,
285 } => {
286 expr.hash(state);
287 idx.hash(state);
288 returns_scalar.hash(state);
289 },
290 Expr::Wildcard | Expr::Len => {},
292 Expr::SortBy {
293 expr,
294 by,
295 sort_options,
296 } => {
297 expr.hash(state);
298 by.hash(state);
299 sort_options.hash(state);
300 },
301 Expr::Agg(input) => input.hash(state),
302 Expr::Explode { input, skip_empty } => {
303 skip_empty.hash(state);
304 input.hash(state)
305 },
306 Expr::Window {
307 function,
308 partition_by,
309 order_by,
310 options,
311 } => {
312 function.hash(state);
313 partition_by.hash(state);
314 order_by.hash(state);
315 options.hash(state);
316 },
317 Expr::Slice {
318 input,
319 offset,
320 length,
321 } => {
322 input.hash(state);
323 offset.hash(state);
324 length.hash(state);
325 },
326 Expr::Exclude(input, excl) => {
327 input.hash(state);
328 excl.hash(state);
329 },
330 Expr::RenameAlias { function: _, expr } => expr.hash(state),
331 Expr::AnonymousFunction {
332 input,
333 function: _,
334 output_type: _,
335 options,
336 } => {
337 input.hash(state);
338 options.hash(state);
339 },
340 Expr::SubPlan(_, names) => names.hash(state),
341 #[cfg(feature = "dtype-struct")]
342 Expr::Field(names) => names.hash(state),
343 }
344 }
345}
346
347impl Eq for Expr {}
348
349impl Default for Expr {
350 fn default() -> Self {
351 Expr::Literal(LiteralValue::Scalar(Scalar::default()))
352 }
353}
354
355#[derive(Debug, Clone, PartialEq, Eq, Hash)]
356#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
357pub enum Excluded {
358 Name(PlSmallStr),
359 Dtype(DataType),
360}
361
362impl Expr {
363 pub fn to_field(&self, schema: &Schema, ctxt: Context) -> PolarsResult<Field> {
365 let mut arena = Arena::with_capacity(5);
367 self.to_field_amortized(schema, ctxt, &mut arena)
368 }
369 pub(crate) fn to_field_amortized(
370 &self,
371 schema: &Schema,
372 ctxt: Context,
373 expr_arena: &mut Arena<AExpr>,
374 ) -> PolarsResult<Field> {
375 let root = to_aexpr(self.clone(), expr_arena)?;
376 expr_arena
377 .get(root)
378 .to_field_and_validate(schema, ctxt, expr_arena)
379 }
380
381 pub fn extract_usize(&self) -> PolarsResult<usize> {
383 match self {
384 Expr::Literal(n) => n.extract_usize(),
385 Expr::Cast { expr, dtype, .. } => {
386 if dtype.is_integer() {
388 expr.extract_usize()
389 } else {
390 polars_bail!(InvalidOperation: "expression must be constant literal to extract integer")
391 }
392 },
393 _ => {
394 polars_bail!(InvalidOperation: "expression must be constant literal to extract integer")
395 },
396 }
397 }
398
399 #[inline]
400 pub fn map_unary(self, function: impl Into<FunctionExpr>) -> Self {
401 Expr::n_ary(function, vec![self])
402 }
403 #[inline]
404 pub fn map_binary(self, function: impl Into<FunctionExpr>, rhs: Self) -> Self {
405 Expr::n_ary(function, vec![self, rhs])
406 }
407
408 #[inline]
409 pub fn map_ternary(self, function: impl Into<FunctionExpr>, arg1: Expr, arg2: Expr) -> Expr {
410 Expr::n_ary(function, vec![self, arg1, arg2])
411 }
412
413 #[inline]
414 pub fn try_map_n_ary(
415 self,
416 function: impl Into<FunctionExpr>,
417 exprs: impl IntoIterator<Item = PolarsResult<Expr>>,
418 ) -> PolarsResult<Expr> {
419 let exprs = exprs.into_iter();
420 let mut input = Vec::with_capacity(exprs.size_hint().0 + 1);
421 input.push(self);
422 for e in exprs {
423 input.push(e?);
424 }
425 Ok(Expr::n_ary(function, input))
426 }
427
428 #[inline]
429 pub fn map_n_ary(
430 self,
431 function: impl Into<FunctionExpr>,
432 exprs: impl IntoIterator<Item = Expr>,
433 ) -> Expr {
434 let exprs = exprs.into_iter();
435 let mut input = Vec::with_capacity(exprs.size_hint().0 + 1);
436 input.push(self);
437 input.extend(exprs);
438 Expr::n_ary(function, input)
439 }
440
441 #[inline]
442 pub fn n_ary(function: impl Into<FunctionExpr>, input: Vec<Expr>) -> Expr {
443 let function = function.into();
444 let options = function.function_options();
445 Expr::Function {
446 input,
447 function,
448 options,
449 }
450 }
451}
452
/// Operator of an [`Expr::BinaryExpr`].
///
/// The variant order is part of the type's observable behaviour (derived
/// `Hash`, discriminants), so new variants should be appended.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Operator {
    // Comparisons.
    Eq,
    // NOTE(review): the `*Validity` variants are presumably the null-aware
    // forms of `Eq`/`NotEq` — confirm.
    EqValidity,
    NotEq,
    NotEqValidity,
    Lt,
    LtEq,
    Gt,
    GtEq,
    // Arithmetic.
    Plus,
    Minus,
    Multiply,
    /// Displayed as `//`.
    Divide,
    /// Displayed as `/`.
    TrueDivide,
    /// Displayed as `floor_div`.
    FloorDivide,
    Modulus,
    // Bitwise.
    And,
    Or,
    Xor,
    // Logical; displayed with the same tokens as `And` / `Or`.
    LogicalAnd,
    LogicalOr,
}
477
478impl Display for Operator {
479 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
480 use Operator::*;
481 let tkn = match self {
482 Eq => "==",
483 EqValidity => "==v",
484 NotEq => "!=",
485 NotEqValidity => "!=v",
486 Lt => "<",
487 LtEq => "<=",
488 Gt => ">",
489 GtEq => ">=",
490 Plus => "+",
491 Minus => "-",
492 Multiply => "*",
493 Divide => "//",
494 TrueDivide => "/",
495 FloorDivide => "floor_div",
496 Modulus => "%",
497 And | LogicalAnd => "&",
498 Or | LogicalOr => "|",
499 Xor => "^",
500 };
501 write!(f, "{tkn}")
502 }
503}
504
505impl Operator {
506 pub fn is_comparison(&self) -> bool {
507 matches!(
508 self,
509 Self::Eq
510 | Self::NotEq
511 | Self::Lt
512 | Self::LtEq
513 | Self::Gt
514 | Self::GtEq
515 | Self::EqValidity
516 | Self::NotEqValidity
517 )
518 }
519
520 pub fn is_bitwise(&self) -> bool {
521 matches!(self, Self::And | Self::Or | Self::Xor)
522 }
523
524 pub fn is_comparison_or_bitwise(&self) -> bool {
525 self.is_comparison() || self.is_bitwise()
526 }
527
528 pub fn swap_operands(self) -> Self {
529 match self {
530 Operator::Eq => Operator::Eq,
531 Operator::Gt => Operator::Lt,
532 Operator::GtEq => Operator::LtEq,
533 Operator::LtEq => Operator::GtEq,
534 Operator::Or => Operator::Or,
535 Operator::LogicalAnd => Operator::LogicalAnd,
536 Operator::LogicalOr => Operator::LogicalOr,
537 Operator::Xor => Operator::Xor,
538 Operator::NotEq => Operator::NotEq,
539 Operator::EqValidity => Operator::EqValidity,
540 Operator::NotEqValidity => Operator::NotEqValidity,
541 Operator::Divide => Operator::Multiply,
542 Operator::Multiply => Operator::Divide,
543 Operator::And => Operator::And,
544 Operator::Plus => Operator::Minus,
545 Operator::Minus => Operator::Plus,
546 Operator::Lt => Operator::Gt,
547 _ => unimplemented!(),
548 }
549 }
550
551 pub fn is_arithmetic(&self) -> bool {
552 !(self.is_comparison_or_bitwise())
553 }
554}