Skip to main content

alopex_dataframe/expr/
expr.rs

1/// Expression AST used by `DataFrame` and `LazyFrame`.
2#[derive(Debug, Clone, PartialEq)]
3pub enum Expr {
4    /// Column reference.
5    Column(String),
6    /// Literal scalar value.
7    Literal(Scalar),
8    /// Binary operator expression.
9    BinaryOp {
10        left: Box<Expr>,
11        op: Operator,
12        right: Box<Expr>,
13    },
14    /// Unary operator expression.
15    UnaryOp { op: UnaryOperator, expr: Box<Expr> },
16    /// Aggregation expression (only valid under `group_by().agg()`).
17    Agg { func: AggFunc, expr: Box<Expr> },
18    /// Expression alias (renames the resulting column).
19    Alias { expr: Box<Expr>, name: String },
20    /// Wildcard (`*`) that expands to all columns in projections.
21    Wildcard,
22}
23
/// Binary operators that can appear in an expression.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Operator {
    /// Arithmetic addition.
    Add,
    /// Arithmetic subtraction.
    Sub,
    /// Arithmetic multiplication.
    Mul,
    /// Arithmetic division.
    Div,
    /// Comparison: equal.
    Eq,
    /// Comparison: not equal.
    Neq,
    /// Comparison: greater than.
    Gt,
    /// Comparison: less than.
    Lt,
    /// Comparison: greater than or equal.
    Ge,
    /// Comparison: less than or equal.
    Le,
    /// Logical AND.
    And,
    /// Logical OR.
    Or,
}
52
/// Unary operators that can appear in an expression.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum UnaryOperator {
    /// Logical NOT.
    Not,
}
59
/// Aggregation functions that can appear in an expression.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AggFunc {
    /// Sum over the non-null values.
    Sum,
    /// Mean over the non-null values.
    Mean,
    /// Number of non-null values.
    Count,
    /// Smallest of the non-null values.
    Min,
    /// Largest of the non-null values.
    Max,
}
74
/// Scalar values that can be used as literals.
#[derive(Clone, Debug, PartialEq)]
pub enum Scalar {
    /// The null literal.
    Null,
    /// A boolean literal.
    Boolean(bool),
    /// A 64-bit signed integer literal.
    Int64(i64),
    /// A 64-bit floating-point literal.
    Float64(f64),
    /// A UTF-8 string literal.
    Utf8(String),
}
89
90impl From<()> for Scalar {
91    fn from(_: ()) -> Self {
92        Scalar::Null
93    }
94}
95
96impl From<bool> for Scalar {
97    fn from(v: bool) -> Self {
98        Scalar::Boolean(v)
99    }
100}
101
102impl From<i64> for Scalar {
103    fn from(v: i64) -> Self {
104        Scalar::Int64(v)
105    }
106}
107
108impl From<f64> for Scalar {
109    fn from(v: f64) -> Self {
110        Scalar::Float64(v)
111    }
112}
113
114impl From<String> for Scalar {
115    fn from(v: String) -> Self {
116        Scalar::Utf8(v)
117    }
118}
119
120impl From<&str> for Scalar {
121    fn from(v: &str) -> Self {
122        Scalar::Utf8(v.to_string())
123    }
124}
125
126impl Expr {
127    /// Alias this expression (used to name output columns).
128    pub fn alias(self, name: impl Into<String>) -> Expr {
129        Expr::Alias {
130            expr: Box::new(self),
131            name: name.into(),
132        }
133    }
134
135    /// Build an addition expression.
136    #[allow(clippy::should_implement_trait)]
137    pub fn add(self, rhs: Expr) -> Expr {
138        Expr::BinaryOp {
139            left: Box::new(self),
140            op: Operator::Add,
141            right: Box::new(rhs),
142        }
143    }
144
145    /// Build a subtraction expression.
146    #[allow(clippy::should_implement_trait)]
147    pub fn sub(self, rhs: Expr) -> Expr {
148        Expr::BinaryOp {
149            left: Box::new(self),
150            op: Operator::Sub,
151            right: Box::new(rhs),
152        }
153    }
154
155    /// Build a multiplication expression.
156    #[allow(clippy::should_implement_trait)]
157    pub fn mul(self, rhs: Expr) -> Expr {
158        Expr::BinaryOp {
159            left: Box::new(self),
160            op: Operator::Mul,
161            right: Box::new(rhs),
162        }
163    }
164
165    /// Build a division expression.
166    #[allow(clippy::should_implement_trait)]
167    pub fn div(self, rhs: Expr) -> Expr {
168        Expr::BinaryOp {
169            left: Box::new(self),
170            op: Operator::Div,
171            right: Box::new(rhs),
172        }
173    }
174
175    /// Build an equality predicate.
176    pub fn eq(self, rhs: Expr) -> Expr {
177        Expr::BinaryOp {
178            left: Box::new(self),
179            op: Operator::Eq,
180            right: Box::new(rhs),
181        }
182    }
183
184    /// Build an inequality predicate.
185    pub fn neq(self, rhs: Expr) -> Expr {
186        Expr::BinaryOp {
187            left: Box::new(self),
188            op: Operator::Neq,
189            right: Box::new(rhs),
190        }
191    }
192
193    /// Build a greater-than predicate.
194    pub fn gt(self, rhs: Expr) -> Expr {
195        Expr::BinaryOp {
196            left: Box::new(self),
197            op: Operator::Gt,
198            right: Box::new(rhs),
199        }
200    }
201
202    /// Build a less-than predicate.
203    pub fn lt(self, rhs: Expr) -> Expr {
204        Expr::BinaryOp {
205            left: Box::new(self),
206            op: Operator::Lt,
207            right: Box::new(rhs),
208        }
209    }
210
211    /// Build a greater-than-or-equal predicate.
212    pub fn ge(self, rhs: Expr) -> Expr {
213        Expr::BinaryOp {
214            left: Box::new(self),
215            op: Operator::Ge,
216            right: Box::new(rhs),
217        }
218    }
219
220    /// Build a less-than-or-equal predicate.
221    pub fn le(self, rhs: Expr) -> Expr {
222        Expr::BinaryOp {
223            left: Box::new(self),
224            op: Operator::Le,
225            right: Box::new(rhs),
226        }
227    }
228
229    /// Build a boolean AND predicate.
230    pub fn and_(self, rhs: Expr) -> Expr {
231        Expr::BinaryOp {
232            left: Box::new(self),
233            op: Operator::And,
234            right: Box::new(rhs),
235        }
236    }
237
238    /// Build a boolean OR predicate.
239    pub fn or_(self, rhs: Expr) -> Expr {
240        Expr::BinaryOp {
241            left: Box::new(self),
242            op: Operator::Or,
243            right: Box::new(rhs),
244        }
245    }
246
247    /// Build a boolean NOT predicate.
248    pub fn not_(self) -> Expr {
249        Expr::UnaryOp {
250            op: UnaryOperator::Not,
251            expr: Box::new(self),
252        }
253    }
254
255    /// Build a `sum` aggregation.
256    pub fn sum(self) -> Expr {
257        Expr::Agg {
258            func: AggFunc::Sum,
259            expr: Box::new(self),
260        }
261    }
262
263    /// Build a `mean` aggregation.
264    pub fn mean(self) -> Expr {
265        Expr::Agg {
266            func: AggFunc::Mean,
267            expr: Box::new(self),
268        }
269    }
270
271    /// Build a `count` aggregation (nulls excluded).
272    pub fn count(self) -> Expr {
273        Expr::Agg {
274            func: AggFunc::Count,
275            expr: Box::new(self),
276        }
277    }
278
279    /// Build a `min` aggregation.
280    pub fn min(self) -> Expr {
281        Expr::Agg {
282            func: AggFunc::Min,
283            expr: Box::new(self),
284        }
285    }
286
287    /// Build a `max` aggregation.
288    pub fn max(self) -> Expr {
289        Expr::Agg {
290            func: AggFunc::Max,
291            expr: Box::new(self),
292        }
293    }
294}
295
#[cfg(test)]
mod tests {
    use super::{AggFunc, Expr, Operator, Scalar, UnaryOperator};
    use crate::expr::{col, lit};

    /// Boxed column reference, used when spelling out expected expression trees.
    fn boxed_column(name: &str) -> Box<Expr> {
        Box::new(Expr::Column(name.to_string()))
    }

    /// Chaining `add` and `alias` yields an `Alias` wrapping the `BinaryOp`.
    #[test]
    fn builder_and_chaining_works() {
        let expected = Expr::Alias {
            expr: Box::new(Expr::BinaryOp {
                left: boxed_column("a"),
                op: Operator::Add,
                right: Box::new(Expr::Literal(Scalar::Int64(1))),
            }),
            name: "b".to_string(),
        };

        let built = col("a").add(lit(1_i64)).alias("b");
        assert_eq!(built, expected);
    }

    /// Logical, aggregation and unary builders produce the expected variants.
    #[test]
    fn logical_and_agg_works() {
        // Only the outermost alias name is checked for the combined predicate.
        let lhs = col("x").gt(lit(1_i64));
        let rhs = col("y").lt(lit(10_i64)).not_();
        let predicate = lhs.and_(rhs).alias("p");
        assert!(matches!(predicate, Expr::Alias { name, .. } if name == "p"));

        let expected_sum = Expr::Agg {
            func: AggFunc::Sum,
            expr: boxed_column("v"),
        };
        assert_eq!(col("v").sum(), expected_sum);

        let expected_not = Expr::UnaryOp {
            op: UnaryOperator::Not,
            expr: boxed_column("a"),
        };
        let negated = Expr::Column("a".to_string()).not_();
        assert_eq!(negated, expected_not);
    }
}