polars_plan/plans/python/
pyarrow.rs

1use std::fmt::Write;
2
3use polars_core::datatypes::AnyValue;
4use polars_core::prelude::{TimeUnit, TimeZone};
5
6use crate::prelude::*;
7
/// Options that steer how a predicate is rendered as a pyarrow expression.
///
/// The default value has every option switched off.
#[derive(Debug, Default, Copy, Clone)]
pub struct PyarrowArgs {
    /// Whether a literal `Series` may be rendered as a Python list.
    ///
    /// pyarrow doesn't allow `filter([True, False])`
    /// but does allow `filter(field("a").isin([True, False]))`,
    /// so this is only switched on while rendering the values of an `is_in`.
    allow_literal_series: bool,
}
14
15fn to_py_datetime(v: i64, tu: &TimeUnit, tz: Option<&TimeZone>) -> String {
16    // note: `to_py_datetime` and the `Datetime`
17    // dtype have to be in-scope on the python side
18    match tz {
19        None => format!("to_py_datetime({},'{}')", v, tu.to_ascii()),
20        Some(tz) => format!("to_py_datetime({},'{}',{})", v, tu.to_ascii(), tz),
21    }
22}
23
24// convert to a pyarrow expression that can be evaluated with pythons eval
25pub fn predicate_to_pa(
26    predicate: Node,
27    expr_arena: &Arena<AExpr>,
28    args: PyarrowArgs,
29) -> Option<String> {
30    match expr_arena.get(predicate) {
31        AExpr::BinaryExpr { left, right, op } => {
32            if op.is_comparison_or_bitwise() {
33                let left = predicate_to_pa(*left, expr_arena, args)?;
34                let right = predicate_to_pa(*right, expr_arena, args)?;
35                Some(format!("({left} {op} {right})"))
36            } else {
37                None
38            }
39        },
40        AExpr::Column(name) => Some(format!("pa.compute.field('{name}')")),
41        AExpr::Literal(LiteralValue::Series(s)) => {
42            if !args.allow_literal_series || s.is_empty() || s.len() > 100 {
43                None
44            } else {
45                let mut list_repr = String::with_capacity(s.len() * 5);
46                list_repr.push('[');
47                for av in s.rechunk().iter() {
48                    match av {
49                        AnyValue::Boolean(v) => {
50                            let s = if v { "True" } else { "False" };
51                            write!(list_repr, "{s},").unwrap();
52                        },
53                        #[cfg(feature = "dtype-datetime")]
54                        AnyValue::Datetime(v, tu, tz) => {
55                            let dtm = to_py_datetime(v, &tu, tz);
56                            write!(list_repr, "{dtm},").unwrap();
57                        },
58                        #[cfg(feature = "dtype-date")]
59                        AnyValue::Date(v) => {
60                            write!(list_repr, "to_py_date({v}),").unwrap();
61                        },
62                        _ => {
63                            write!(list_repr, "{av},").unwrap();
64                        },
65                    }
66                }
67                // pop last comma
68                list_repr.pop();
69                list_repr.push(']');
70                Some(list_repr)
71            }
72        },
73        AExpr::Literal(lv) => {
74            let av = lv.to_any_value()?;
75            let dtype = av.dtype();
76            match av.as_borrowed() {
77                AnyValue::String(s) => Some(format!("'{s}'")),
78                AnyValue::Boolean(val) => {
79                    // python bools are capitalized
80                    if val {
81                        Some("pa.compute.scalar(True)".to_string())
82                    } else {
83                        Some("pa.compute.scalar(False)".to_string())
84                    }
85                },
86                #[cfg(feature = "dtype-date")]
87                AnyValue::Date(v) => {
88                    // the function `to_py_date` and the `Date`
89                    // dtype have to be in scope on the python side
90                    Some(format!("to_py_date({v})"))
91                },
92                #[cfg(feature = "dtype-datetime")]
93                AnyValue::Datetime(v, tu, tz) => Some(to_py_datetime(v, &tu, tz)),
94                // Activate once pyarrow supports them
95                // #[cfg(feature = "dtype-time")]
96                // AnyValue::Time(v) => {
97                //     // the function `to_py_time` has to be in scope
98                //     // on the python side
99                //     Some(format!("to_py_time(value={v})"))
100                // }
101                // #[cfg(feature = "dtype-duration")]
102                // AnyValue::Duration(v, tu) => {
103                //     // the function `to_py_timedelta` has to be in scope
104                //     // on the python side
105                //     Some(format!(
106                //         "to_py_timedelta(value={}, tu='{}')",
107                //         v,
108                //         tu.to_ascii()
109                //     ))
110                // }
111                av => {
112                    if dtype.is_float() {
113                        let val = av.extract::<f64>()?;
114                        Some(format!("{val}"))
115                    } else if dtype.is_integer() {
116                        let val = av.extract::<i64>()?;
117                        Some(format!("{val}"))
118                    } else {
119                        None
120                    }
121                },
122            }
123        },
124        #[cfg(feature = "is_in")]
125        AExpr::Function {
126            function: IRFunctionExpr::Boolean(IRBooleanFunction::IsIn { .. }),
127            input,
128            ..
129        } => {
130            let col = predicate_to_pa(input.first()?.node(), expr_arena, args)?;
131            let mut args = args;
132            args.allow_literal_series = true;
133            let values = predicate_to_pa(input.get(1)?.node(), expr_arena, args)?;
134
135            Some(format!("({col}).isin({values})"))
136        },
137        #[cfg(feature = "is_between")]
138        AExpr::Function {
139            function: IRFunctionExpr::Boolean(IRBooleanFunction::IsBetween { closed }),
140            input,
141            ..
142        } => {
143            if !matches!(expr_arena.get(input.first()?.node()), AExpr::Column(_)) {
144                None
145            } else {
146                let col = predicate_to_pa(input.first()?.node(), expr_arena, args)?;
147                let left_cmp_op = match closed {
148                    ClosedInterval::None | ClosedInterval::Right => Operator::Gt,
149                    ClosedInterval::Both | ClosedInterval::Left => Operator::GtEq,
150                };
151                let right_cmp_op = match closed {
152                    ClosedInterval::None | ClosedInterval::Left => Operator::Lt,
153                    ClosedInterval::Both | ClosedInterval::Right => Operator::LtEq,
154                };
155
156                let lower = predicate_to_pa(input.get(1)?.node(), expr_arena, args)?;
157                let upper = predicate_to_pa(input.get(2)?.node(), expr_arena, args)?;
158
159                Some(format!(
160                    "(({col} {left_cmp_op} {lower}) & ({col} {right_cmp_op} {upper}))"
161                ))
162            }
163        },
164        AExpr::Function {
165            function, input, ..
166        } => {
167            let input = input.first().unwrap().node();
168            let input = predicate_to_pa(input, expr_arena, args)?;
169
170            match function {
171                IRFunctionExpr::Boolean(IRBooleanFunction::Not) => Some(format!("~({input})")),
172                IRFunctionExpr::Boolean(IRBooleanFunction::IsNull) => {
173                    Some(format!("({input}).is_null()"))
174                },
175                IRFunctionExpr::Boolean(IRBooleanFunction::IsNotNull) => {
176                    Some(format!("~({input}).is_null()"))
177                },
178                _ => None,
179            }
180        },
181        _ => None,
182    }
183}