1mod expr_dyn_fn;
2use std::fmt::{Debug, Display, Formatter};
3use std::hash::{Hash, Hasher};
4
5use bytes::Bytes;
6pub use expr_dyn_fn::*;
7use polars_compute::rolling::QuantileMethod;
8use polars_core::chunked_array::cast::CastOptions;
9use polars_core::error::feature_gated;
10use polars_core::prelude::*;
11#[cfg(feature = "serde")]
12use serde::{Deserialize, Serialize};
13#[cfg(feature = "serde")]
14pub mod named_serde;
15#[cfg(feature = "serde")]
16mod serde_expr;
17
18use crate::prelude::*;
19
20#[derive(PartialEq, Clone, Hash)]
21#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
22pub enum AggExpr {
23 Min {
24 input: Arc<Expr>,
25 propagate_nans: bool,
26 },
27 Max {
28 input: Arc<Expr>,
29 propagate_nans: bool,
30 },
31 Median(Arc<Expr>),
32 NUnique(Arc<Expr>),
33 First(Arc<Expr>),
34 Last(Arc<Expr>),
35 Mean(Arc<Expr>),
36 Implode(Arc<Expr>),
37 Count(Arc<Expr>, bool),
39 Quantile {
40 expr: Arc<Expr>,
41 quantile: Arc<Expr>,
42 method: QuantileMethod,
43 },
44 Sum(Arc<Expr>),
45 AggGroups(Arc<Expr>),
46 Std(Arc<Expr>, u8),
47 Var(Arc<Expr>, u8),
48}
49
50impl AsRef<Expr> for AggExpr {
51 fn as_ref(&self) -> &Expr {
52 use AggExpr::*;
53 match self {
54 Min { input, .. } => input,
55 Max { input, .. } => input,
56 Median(e) => e,
57 NUnique(e) => e,
58 First(e) => e,
59 Last(e) => e,
60 Mean(e) => e,
61 Implode(e) => e,
62 Count(e, _) => e,
63 Quantile { expr, .. } => expr,
64 Sum(e) => e,
65 AggGroups(e) => e,
66 Std(e, _) => e,
67 Var(e, _) => e,
68 }
69 }
70}
71
72#[derive(Clone, PartialEq)]
78#[must_use]
79#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
80pub enum Expr {
81 Alias(Arc<Expr>, PlSmallStr),
82 Column(PlSmallStr),
83 Columns(Arc<[PlSmallStr]>),
84 DtypeColumn(Vec<DataType>),
85 IndexColumn(Arc<[i64]>),
86 Literal(LiteralValue),
87 BinaryExpr {
88 left: Arc<Expr>,
89 op: Operator,
90 right: Arc<Expr>,
91 },
92 Cast {
93 expr: Arc<Expr>,
94 dtype: DataType,
95 options: CastOptions,
96 },
97 Sort {
98 expr: Arc<Expr>,
99 options: SortOptions,
100 },
101 Gather {
102 expr: Arc<Expr>,
103 idx: Arc<Expr>,
104 returns_scalar: bool,
105 },
106 SortBy {
107 expr: Arc<Expr>,
108 by: Vec<Expr>,
109 sort_options: SortMultipleOptions,
110 },
111 Agg(AggExpr),
112 Ternary {
115 predicate: Arc<Expr>,
116 truthy: Arc<Expr>,
117 falsy: Arc<Expr>,
118 },
119 Function {
120 input: Vec<Expr>,
122 function: FunctionExpr,
124 options: FunctionOptions,
125 },
126 Explode {
127 input: Arc<Expr>,
128 skip_empty: bool,
129 },
130 Filter {
131 input: Arc<Expr>,
132 by: Arc<Expr>,
133 },
134 Window {
136 function: Arc<Expr>,
138 partition_by: Vec<Expr>,
139 order_by: Option<(Arc<Expr>, SortOptions)>,
140 options: WindowType,
141 },
142 Wildcard,
143 Slice {
144 input: Arc<Expr>,
145 offset: Arc<Expr>,
147 length: Arc<Expr>,
148 },
149 Exclude(Arc<Expr>, Vec<Excluded>),
152 KeepName(Arc<Expr>),
154 Len,
155 Nth(i64),
157 RenameAlias {
158 function: SpecialEq<Arc<dyn RenameAliasFn>>,
159 expr: Arc<Expr>,
160 },
161 #[cfg(feature = "dtype-struct")]
162 Field(Arc<[PlSmallStr]>),
163 AnonymousFunction {
164 input: Vec<Expr>,
166 function: OpaqueColumnUdf,
168 output_type: GetOutput,
170 options: FunctionOptions,
171 },
172 SubPlan(SpecialEq<Arc<DslPlan>>, Vec<String>),
173 Selector(super::selector::Selector),
180}
181
182#[derive(Clone)]
183pub enum LazySerde<T: Clone> {
184 Deserialized(T),
185 Bytes(Bytes),
186 Named {
188 name: String,
191 payload: Option<Bytes>,
192 value: Option<T>,
195 },
196}
197
198impl<T: PartialEq + Clone> PartialEq for LazySerde<T> {
199 fn eq(&self, other: &Self) -> bool {
200 use LazySerde as L;
201 match (self, other) {
202 (L::Deserialized(a), L::Deserialized(b)) => a == b,
203 (L::Bytes(a), L::Bytes(b)) => {
204 std::ptr::eq(a.as_ptr(), b.as_ptr()) && a.len() == b.len()
205 },
206 (
207 L::Named {
208 name: l,
209 payload: pl,
210 value: _,
211 },
212 L::Named {
213 name: r,
214 payload: pr,
215 value: _,
216 },
217 ) => {
218 #[cfg(debug_assertions)]
219 {
220 if l == r {
221 assert_eq!(pl, pr, "name should point to unique payload")
222 }
223 }
224 _ = pl;
225 _ = pr;
226 l == r
227 },
228 _ => false,
229 }
230 }
231}
232
233impl<T: Clone> Debug for LazySerde<T> {
234 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
235 match self {
236 Self::Bytes(_) => write!(f, "lazy-serde<Bytes>"),
237 Self::Deserialized(_) => write!(f, "lazy-serde<T>"),
238 Self::Named {
239 name,
240 payload: _,
241 value: _,
242 } => write!(f, "lazy-serde<Named>: {}", name),
243 }
244 }
245}
246
247#[allow(clippy::derived_hash_with_manual_eq)]
248impl Hash for Expr {
249 fn hash<H: Hasher>(&self, state: &mut H) {
250 let d = std::mem::discriminant(self);
251 d.hash(state);
252 match self {
253 Expr::Column(name) => name.hash(state),
254 Expr::Columns(names) => names.hash(state),
255 Expr::DtypeColumn(dtypes) => dtypes.hash(state),
256 Expr::IndexColumn(indices) => indices.hash(state),
257 Expr::Literal(lv) => std::mem::discriminant(lv).hash(state),
258 Expr::Selector(s) => s.hash(state),
259 Expr::Nth(v) => v.hash(state),
260 Expr::Filter { input, by } => {
261 input.hash(state);
262 by.hash(state);
263 },
264 Expr::BinaryExpr { left, op, right } => {
265 left.hash(state);
266 right.hash(state);
267 std::mem::discriminant(op).hash(state)
268 },
269 Expr::Cast {
270 expr,
271 dtype,
272 options: strict,
273 } => {
274 expr.hash(state);
275 dtype.hash(state);
276 strict.hash(state)
277 },
278 Expr::Sort { expr, options } => {
279 expr.hash(state);
280 options.hash(state);
281 },
282 Expr::Alias(input, name) => {
283 input.hash(state);
284 name.hash(state)
285 },
286 Expr::KeepName(input) => input.hash(state),
287 Expr::Ternary {
288 predicate,
289 truthy,
290 falsy,
291 } => {
292 predicate.hash(state);
293 truthy.hash(state);
294 falsy.hash(state);
295 },
296 Expr::Function {
297 input,
298 function,
299 options,
300 } => {
301 input.hash(state);
302 std::mem::discriminant(function).hash(state);
303 options.hash(state);
304 },
305 Expr::Gather {
306 expr,
307 idx,
308 returns_scalar,
309 } => {
310 expr.hash(state);
311 idx.hash(state);
312 returns_scalar.hash(state);
313 },
314 Expr::Wildcard | Expr::Len => {},
316 Expr::SortBy {
317 expr,
318 by,
319 sort_options,
320 } => {
321 expr.hash(state);
322 by.hash(state);
323 sort_options.hash(state);
324 },
325 Expr::Agg(input) => input.hash(state),
326 Expr::Explode { input, skip_empty } => {
327 skip_empty.hash(state);
328 input.hash(state)
329 },
330 Expr::Window {
331 function,
332 partition_by,
333 order_by,
334 options,
335 } => {
336 function.hash(state);
337 partition_by.hash(state);
338 order_by.hash(state);
339 options.hash(state);
340 },
341 Expr::Slice {
342 input,
343 offset,
344 length,
345 } => {
346 input.hash(state);
347 offset.hash(state);
348 length.hash(state);
349 },
350 Expr::Exclude(input, excl) => {
351 input.hash(state);
352 excl.hash(state);
353 },
354 Expr::RenameAlias { function: _, expr } => expr.hash(state),
355 Expr::AnonymousFunction {
356 input,
357 function: _,
358 output_type: _,
359 options,
360 } => {
361 input.hash(state);
362 options.hash(state);
363 },
364 Expr::SubPlan(_, names) => names.hash(state),
365 #[cfg(feature = "dtype-struct")]
366 Expr::Field(names) => names.hash(state),
367 }
368 }
369}
370
371impl Eq for Expr {}
372
373impl Default for Expr {
374 fn default() -> Self {
375 Expr::Literal(LiteralValue::Scalar(Scalar::default()))
376 }
377}
378
379#[derive(Debug, Clone, PartialEq, Eq, Hash)]
380#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
381pub enum Excluded {
382 Name(PlSmallStr),
383 Dtype(DataType),
384}
385
386impl Expr {
387 pub fn to_field(&self, schema: &Schema, ctxt: Context) -> PolarsResult<Field> {
389 let mut arena = Arena::with_capacity(5);
391 self.to_field_amortized(schema, ctxt, &mut arena)
392 }
393 pub(crate) fn to_field_amortized(
394 &self,
395 schema: &Schema,
396 ctxt: Context,
397 expr_arena: &mut Arena<AExpr>,
398 ) -> PolarsResult<Field> {
399 let root = to_aexpr(self.clone(), expr_arena)?;
400 expr_arena
401 .get(root)
402 .to_field_and_validate(schema, ctxt, expr_arena)
403 }
404
405 pub fn extract_usize(&self) -> PolarsResult<usize> {
407 match self {
408 Expr::Literal(n) => n.extract_usize(),
409 Expr::Cast { expr, dtype, .. } => {
410 if dtype.is_integer() {
412 expr.extract_usize()
413 } else {
414 polars_bail!(InvalidOperation: "expression must be constant literal to extract integer")
415 }
416 },
417 _ => {
418 polars_bail!(InvalidOperation: "expression must be constant literal to extract integer")
419 },
420 }
421 }
422
423 #[inline]
424 pub fn map_unary(self, function: impl Into<FunctionExpr>) -> Self {
425 Expr::n_ary(function, vec![self])
426 }
427 #[inline]
428 pub fn map_binary(self, function: impl Into<FunctionExpr>, rhs: Self) -> Self {
429 Expr::n_ary(function, vec![self, rhs])
430 }
431
432 #[inline]
433 pub fn map_ternary(self, function: impl Into<FunctionExpr>, arg1: Expr, arg2: Expr) -> Expr {
434 Expr::n_ary(function, vec![self, arg1, arg2])
435 }
436
437 #[inline]
438 pub fn try_map_n_ary(
439 self,
440 function: impl Into<FunctionExpr>,
441 exprs: impl IntoIterator<Item = PolarsResult<Expr>>,
442 ) -> PolarsResult<Expr> {
443 let exprs = exprs.into_iter();
444 let mut input = Vec::with_capacity(exprs.size_hint().0 + 1);
445 input.push(self);
446 for e in exprs {
447 input.push(e?);
448 }
449 Ok(Expr::n_ary(function, input))
450 }
451
452 #[inline]
453 pub fn map_n_ary(
454 self,
455 function: impl Into<FunctionExpr>,
456 exprs: impl IntoIterator<Item = Expr>,
457 ) -> Expr {
458 let exprs = exprs.into_iter();
459 let mut input = Vec::with_capacity(exprs.size_hint().0 + 1);
460 input.push(self);
461 input.extend(exprs);
462 Expr::n_ary(function, input)
463 }
464
465 #[inline]
466 pub fn n_ary(function: impl Into<FunctionExpr>, input: Vec<Expr>) -> Expr {
467 let function = function.into();
468 let options = function.function_options();
469 Expr::Function {
470 input,
471 function,
472 options,
473 }
474 }
475}
476
/// Binary operators, as stored in the `op` field of `Expr::BinaryExpr`.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Operator {
    Eq,
    EqValidity,
    NotEq,
    NotEqValidity,
    Lt,
    LtEq,
    Gt,
    GtEq,
    Plus,
    Minus,
    Multiply,
    Divide,
    TrueDivide,
    FloorDivide,
    Modulus,
    And,
    Or,
    Xor,
    LogicalAnd,
    LogicalOr,
}

impl Display for Operator {
    /// Write the operator's textual token.
    ///
    /// Note that `Divide` prints as `//`, `TrueDivide` as `/` and
    /// `FloorDivide` as `floor_div`, and that the bitwise and logical forms of
    /// and/or share the tokens `&` and `|`.
    ///
    /// Width, fill and alignment requested by the caller (e.g. `{:>3}`) are
    /// honored; a precision, if given, truncates the token. Plain `{}` output
    /// is unaffected.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let token = match self {
            Self::Eq => "==",
            Self::EqValidity => "==v",
            Self::NotEq => "!=",
            Self::NotEqValidity => "!=v",
            Self::Lt => "<",
            Self::LtEq => "<=",
            Self::Gt => ">",
            Self::GtEq => ">=",
            Self::Plus => "+",
            Self::Minus => "-",
            Self::Multiply => "*",
            Self::Divide => "//",
            Self::TrueDivide => "/",
            Self::FloorDivide => "floor_div",
            Self::Modulus => "%",
            Self::And | Self::LogicalAnd => "&",
            Self::Or | Self::LogicalOr => "|",
            Self::Xor => "^",
        };
        // `pad` instead of `write!(f, "{token}")`: the latter silently drops
        // the formatting flags the caller put on the outer placeholder.
        f.pad(token)
    }
}

impl Operator {
    /// `true` for the six ordering/equality comparisons and their two
    /// `*Validity` forms.
    pub fn is_comparison(&self) -> bool {
        matches!(
            self,
            Self::Eq
                | Self::NotEq
                | Self::Lt
                | Self::LtEq
                | Self::Gt
                | Self::GtEq
                | Self::EqValidity
                | Self::NotEqValidity
        )
    }

    /// `true` for `And`, `Or` and `Xor` (not for `LogicalAnd`/`LogicalOr`).
    pub fn is_bitwise(&self) -> bool {
        matches!(self, Self::And | Self::Or | Self::Xor)
    }

    /// `true` when either [`Self::is_comparison`] or [`Self::is_bitwise`]
    /// holds.
    pub fn is_comparison_or_bitwise(&self) -> bool {
        self.is_comparison() || self.is_bitwise()
    }

    /// Map the operator to the one returned for it when operands are swapped.
    ///
    /// Ordering comparisons are mirrored (`Gt` <-> `Lt`, `GtEq` <-> `LtEq`);
    /// equality, bitwise and logical operators map to themselves.
    ///
    /// NOTE(review): `Plus` <-> `Minus` and `Multiply` <-> `Divide` are mapped
    /// to each other, which is an inverse operation rather than an operand
    /// swap. Kept as-is because callers may rely on it — confirm intent.
    ///
    /// # Panics
    ///
    /// Panics (`unimplemented!`) for `TrueDivide`, `FloorDivide` and
    /// `Modulus`.
    pub fn swap_operands(self) -> Self {
        match self {
            // Symmetric operators.
            Self::Eq => Self::Eq,
            Self::NotEq => Self::NotEq,
            Self::EqValidity => Self::EqValidity,
            Self::NotEqValidity => Self::NotEqValidity,
            Self::And => Self::And,
            Self::Or => Self::Or,
            Self::Xor => Self::Xor,
            Self::LogicalAnd => Self::LogicalAnd,
            Self::LogicalOr => Self::LogicalOr,
            // Mirrored orderings.
            Self::Gt => Self::Lt,
            Self::Lt => Self::Gt,
            Self::GtEq => Self::LtEq,
            Self::LtEq => Self::GtEq,
            // See the NOTE(review) above.
            Self::Plus => Self::Minus,
            Self::Minus => Self::Plus,
            Self::Multiply => Self::Divide,
            Self::Divide => Self::Multiply,
            // Listed explicitly (no `_` arm) so that adding a variant to the
            // enum forces a decision here at compile time.
            Self::TrueDivide | Self::FloorDivide | Self::Modulus => unimplemented!(),
        }
    }

    /// `true` for everything that is neither a comparison nor bitwise.
    ///
    /// NOTE(review): as written this also returns `true` for `LogicalAnd` and
    /// `LogicalOr` — confirm that is intended.
    pub fn is_arithmetic(&self) -> bool {
        !self.is_comparison_or_bitwise()
    }
}
578}