Skip to main content

datafusion_physical_expr_common/
physical_expr.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::fmt;
20use std::fmt::{Debug, Display, Formatter};
21use std::hash::{Hash, Hasher};
22use std::sync::Arc;
23
24use crate::utils::scatter;
25
26use arrow::array::{Array, ArrayRef, BooleanArray, new_empty_array};
27use arrow::compute::filter_record_batch;
28use arrow::datatypes::{DataType, Field, FieldRef, Schema};
29use arrow::record_batch::RecordBatch;
30use datafusion_common::tree_node::{
31    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
32};
33use datafusion_common::{
34    Result, ScalarValue, assert_eq_or_internal_err, exec_err, not_impl_err,
35};
36use datafusion_expr_common::columnar_value::ColumnarValue;
37use datafusion_expr_common::interval_arithmetic::Interval;
38use datafusion_expr_common::placement::ExpressionPlacement;
39use datafusion_expr_common::sort_properties::ExprProperties;
40#[expect(deprecated)]
41use datafusion_expr_common::statistics::Distribution;
42
43use itertools::izip;
44
/// Shared, reference-counted handle to a [`PhysicalExpr`] trait object.
pub type PhysicalExprRef = Arc<dyn PhysicalExpr>;
47
48/// [`PhysicalExpr`]s represent expressions such as `A + 1` or `CAST(c1 AS int)`.
49///
50/// `PhysicalExpr` knows its type, nullability and can be evaluated directly on
51/// a [`RecordBatch`] (see [`Self::evaluate`]).
52///
53/// `PhysicalExpr` are the physical counterpart to [`Expr`] used in logical
54/// planning. They are typically created from [`Expr`] by a [`PhysicalPlanner`]
55/// invoked from a higher level API
56///
57/// Some important examples of `PhysicalExpr` are:
58/// * [`Column`]: Represents a column at a given index in a RecordBatch
59///
60/// To create `PhysicalExpr` from  `Expr`, see
61/// * [`SessionContext::create_physical_expr`]: A high level API
62/// * [`create_physical_expr`]: A low level API
63///
64/// # Formatting `PhysicalExpr` as strings
65/// There are three ways to format `PhysicalExpr` as a string:
66/// * [`Debug`]: Standard Rust debugging format (e.g. `Constant { value: ... }`)
67/// * [`Display`]: Detailed SQL-like format that shows expression structure (e.g. (`Utf8 ("foobar")`). This is often used for debugging and tests
68/// * [`Self::fmt_sql`]: SQL-like human readable format (e.g. ('foobar')`), See also [`sql_fmt`]
69///
70/// [`SessionContext::create_physical_expr`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.create_physical_expr
71/// [`PhysicalPlanner`]: https://docs.rs/datafusion/latest/datafusion/physical_planner/trait.PhysicalPlanner.html
72/// [`Expr`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Expr.html
73/// [`create_physical_expr`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/fn.create_physical_expr.html
74/// [`Column`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/expressions/struct.Column.html
75pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash {
76    /// Get the data type of this expression, given the schema of the input.
77    /// Returns an error if the data type cannot be determined, ex. if the
78    /// schema is missing a required field.
79    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
80        Ok(self.return_field(input_schema)?.data_type().to_owned())
81    }
82    /// Determine whether this expression is nullable, given the schema of the input
83    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
84        Ok(self.return_field(input_schema)?.is_nullable())
85    }
86    /// Evaluate an expression against a RecordBatch
87    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue>;
88    /// The output field associated with this expression
89    fn return_field(&self, input_schema: &Schema) -> Result<FieldRef> {
90        Ok(Arc::new(Field::new(
91            format!("{self}"),
92            self.data_type(input_schema)?,
93            self.nullable(input_schema)?,
94        )))
95    }
96    /// Evaluate an expression against a RecordBatch after first applying a validity array
97    ///
98    /// # Errors
99    ///
100    /// Returns an `Err` if the expression could not be evaluated or if the length of the
101    /// `selection` validity array and the number of row in `batch` is not equal.
102    fn evaluate_selection(
103        &self,
104        batch: &RecordBatch,
105        selection: &BooleanArray,
106    ) -> Result<ColumnarValue> {
107        let row_count = batch.num_rows();
108        if row_count != selection.len() {
109            return exec_err!(
110                "Selection array length does not match batch row count: {} != {row_count}",
111                selection.len()
112            );
113        }
114
115        // First, check if we can avoid filtering altogether.
116        if selection.null_count() == 0 && !selection.has_false() {
117            // All values from the `selection` filter are true and match the input batch.
118            // No need to perform any filtering.
119            return self.evaluate(batch);
120        }
121
122        // Next, prepare the result array for each 'true' row in the selection vector.
123        let filtered_result = if !selection.has_true() {
124            // Do not call `evaluate` when the selection is empty.
125            // `evaluate_selection` is used to conditionally evaluate expressions.
126            // When the expression in question is fallible, evaluating it with an empty
127            // record batch may trigger a runtime error (e.g. division by zero).
128            //
129            // Instead, create an empty array matching the expected return type.
130            let datatype = self.data_type(batch.schema_ref().as_ref())?;
131            ColumnarValue::Array(new_empty_array(&datatype))
132        } else {
133            // If we reach this point, there's no other option than to filter the batch.
134            // This is a fairly costly operation since it requires creating partial copies
135            // (worst case of length `row_count - 1`) of all the arrays in the record batch.
136            // The resulting `filtered_batch` will contain one row per true in `selection`.
137            let filtered_batch = filter_record_batch(batch, selection)?;
138            self.evaluate(&filtered_batch)?
139        };
140
141        // Finally, scatter the filtered result array so that the indices match the input rows again.
142        match &filtered_result {
143            ColumnarValue::Array(a) => {
144                scatter(selection, a.as_ref()).map(ColumnarValue::Array)
145            }
146            ColumnarValue::Scalar(ScalarValue::Boolean(value)) => {
147                // When the scalar is true or false, skip the scatter process
148                if let Some(v) = value {
149                    if *v {
150                        Ok(ColumnarValue::from(Arc::new(selection.clone()) as ArrayRef))
151                    } else {
152                        Ok(filtered_result)
153                    }
154                } else {
155                    let array = BooleanArray::from(vec![None; row_count]);
156                    scatter(selection, &array).map(ColumnarValue::Array)
157                }
158            }
159            ColumnarValue::Scalar(_) => Ok(filtered_result),
160        }
161    }
162
163    /// Get a list of child PhysicalExpr that provide the input for this expr.
164    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>>;
165
166    /// Returns a new PhysicalExpr where all children were replaced by new exprs.
167    ///
168    /// If the implementation returns a [`PhysicalExpr::expression_id`], then
169    /// the identifier should be preserved by the new expression.
170    fn with_new_children(
171        self: Arc<Self>,
172        children: Vec<Arc<dyn PhysicalExpr>>,
173    ) -> Result<Arc<dyn PhysicalExpr>>;
174
175    /// Computes the output interval for the expression, given the input
176    /// intervals.
177    ///
178    /// # Parameters
179    ///
180    /// * `children` are the intervals for the children (inputs) of this
181    ///   expression.
182    ///
183    /// # Returns
184    ///
185    /// A `Result` containing the output interval for the expression in
186    /// case of success, or an error object in case of failure.
187    ///
188    /// Note that the output bounds must form an **envelope** that contains all
189    /// possible outputs of the expression given the input bounds. While
190    /// expressions should output the tightest possible bounds, they do not need
191    /// to be exact and can be conservative.
192    ///
193    /// # Example
194    ///
195    /// If the expression is `a + b`, and the input intervals are `a: [1, 2]`
196    /// and `b: [3, 4]`, then the output interval would be `[4, 6]`.
197    ///
198    /// If the expression is `sin(a)`, it is correct (though not precise) to
199    /// produce the interval `[-1, 1]` for any input interval for `a`.
200    fn evaluate_bounds(&self, _children: &[&Interval]) -> Result<Interval> {
201        not_impl_err!("Not implemented for {self}")
202    }
203
204    /// Updates bounds for child expressions, given a known interval for this
205    /// expression.
206    ///
207    /// This is used to propagate constraints down through an expression tree.
208    ///
209    /// # Parameters
210    ///
211    /// * `interval` is the currently known interval for this expression.
212    /// * `children` are the current intervals for the children of this expression.
213    ///
214    /// # Returns
215    ///
216    /// A `Result` containing a `Vec` of new intervals for the children (in order)
217    /// in case of success, or an error object in case of failure.
218    ///
219    /// If constraint propagation reveals an infeasibility for any child, returns
220    /// [`None`]. If none of the children intervals change as a result of
221    /// propagation, may return an empty vector instead of cloning `children`.
222    /// This is the default (and conservative) return value.
223    ///
224    /// # Example
225    ///
226    /// If the expression is `a + b`, the current `interval` is `[4, 5]` and the
227    /// inputs `a` and `b` are respectively given as `[0, 2]` and `[-∞, 4]`, then
228    /// propagation would return `[0, 2]` and `[2, 4]` as `b` must be at least
229    /// `2` to make the output at least `4`.
230    fn propagate_constraints(
231        &self,
232        _interval: &Interval,
233        _children: &[&Interval],
234    ) -> Result<Option<Vec<Interval>>> {
235        Ok(Some(vec![]))
236    }
237
238    /// Computes the output statistics for the expression, given the input
239    /// statistics.
240    ///
241    /// # Parameters
242    ///
243    /// * `children` are the statistics for the children (inputs) of this
244    ///   expression.
245    ///
246    /// # Returns
247    ///
248    /// A `Result` containing the output statistics for the expression in
249    /// case of success, or an error object in case of failure.
250    ///
251    /// Expressions (should) implement this function and utilize the independence
252    /// assumption, match on children distribution types and compute the output
253    /// statistics accordingly. The default implementation simply creates an
254    /// unknown output distribution by combining input ranges. This logic loses
255    /// distribution information, but is a safe default.
256    #[deprecated(
257        since = "54.0.0",
258        note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071"
259    )]
260    #[expect(deprecated)]
261    fn evaluate_statistics(&self, children: &[&Distribution]) -> Result<Distribution> {
262        let children_ranges = children
263            .iter()
264            .map(|c| c.range())
265            .collect::<Result<Vec<_>>>()?;
266        let children_ranges_refs = children_ranges.iter().collect::<Vec<_>>();
267        let output_interval = self.evaluate_bounds(children_ranges_refs.as_slice())?;
268        let dt = output_interval.data_type();
269        if dt.eq(&DataType::Boolean) {
270            let p = if output_interval.eq(&Interval::TRUE) {
271                ScalarValue::new_one(&dt)
272            } else if output_interval.eq(&Interval::FALSE) {
273                ScalarValue::new_zero(&dt)
274            } else {
275                ScalarValue::try_from(&dt)
276            }?;
277            Distribution::new_bernoulli(p)
278        } else {
279            Distribution::new_from_interval(output_interval)
280        }
281    }
282
283    /// Updates children statistics using the given parent statistic for this
284    /// expression.
285    ///
286    /// This is used to propagate statistics down through an expression tree.
287    ///
288    /// # Parameters
289    ///
290    /// * `parent` is the currently known statistics for this expression.
291    /// * `children` are the current statistics for the children of this expression.
292    ///
293    /// # Returns
294    ///
295    /// A `Result` containing a `Vec` of new statistics for the children (in order)
296    /// in case of success, or an error object in case of failure.
297    ///
298    /// If statistics propagation reveals an infeasibility for any child, returns
299    /// [`None`]. If none of the children statistics change as a result of
300    /// propagation, may return an empty vector instead of cloning `children`.
301    /// This is the default (and conservative) return value.
302    ///
303    /// Expressions (should) implement this function and apply Bayes rule to
304    /// reconcile and update parent/children statistics. This involves utilizing
305    /// the independence assumption, and matching on distribution types. The
306    /// default implementation simply creates an unknown distribution if it can
307    /// narrow the range by propagating ranges. This logic loses distribution
308    /// information, but is a safe default.
309    #[deprecated(
310        since = "54.0.0",
311        note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071"
312    )]
313    #[expect(deprecated)]
314    fn propagate_statistics(
315        &self,
316        parent: &Distribution,
317        children: &[&Distribution],
318    ) -> Result<Option<Vec<Distribution>>> {
319        let children_ranges = children
320            .iter()
321            .map(|c| c.range())
322            .collect::<Result<Vec<_>>>()?;
323        let children_ranges_refs = children_ranges.iter().collect::<Vec<_>>();
324        let parent_range = parent.range()?;
325        let Some(propagated_children) =
326            self.propagate_constraints(&parent_range, children_ranges_refs.as_slice())?
327        else {
328            return Ok(None);
329        };
330        izip!(propagated_children.into_iter(), children_ranges, children)
331            .map(|(new_interval, old_interval, child)| {
332                if new_interval == old_interval {
333                    // We weren't able to narrow the range, preserve the old statistics.
334                    Ok((*child).clone())
335                } else if new_interval.data_type().eq(&DataType::Boolean) {
336                    let dt = old_interval.data_type();
337                    let p = if new_interval.eq(&Interval::TRUE) {
338                        ScalarValue::new_one(&dt)
339                    } else if new_interval.eq(&Interval::FALSE) {
340                        ScalarValue::new_zero(&dt)
341                    } else {
342                        unreachable!("Given that we have a range reduction for a boolean interval, we should have certainty")
343                    }?;
344                    Distribution::new_bernoulli(p)
345                } else {
346                    Distribution::new_from_interval(new_interval)
347                }
348            })
349            .collect::<Result<_>>()
350            .map(Some)
351    }
352
353    /// Calculates the properties of this [`PhysicalExpr`] based on its
354    /// children's properties (i.e. order and range), recursively aggregating
355    /// the information from its children. In cases where the [`PhysicalExpr`]
356    /// has no children (e.g., `Literal` or `Column`), these properties should
357    /// be specified externally, as the function defaults to unknown properties.
358    fn get_properties(&self, _children: &[ExprProperties]) -> Result<ExprProperties> {
359        Ok(ExprProperties::new_unknown())
360    }
361
362    /// Format this `PhysicalExpr` in nice human readable "SQL" format
363    ///
364    /// Specifically, this format is designed to be readable by humans, at the
365    /// expense of details. Use `Display` or `Debug` for more detailed
366    /// representation.
367    ///
368    /// See the [`fmt_sql`] function for an example of printing `PhysicalExpr`s as SQL.
369    fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result;
370
371    /// Take a snapshot of this `PhysicalExpr`, if it is dynamic.
372    ///
373    /// "Dynamic" in this case means containing references to structures that may change
374    /// during plan execution, such as hash tables.
375    ///
376    /// This method is used to capture the current state of `PhysicalExpr`s that may contain
377    /// dynamic references to other operators in order to serialize it over the wire
378    /// or treat it via downcast matching.
379    ///
380    /// You should not call this method directly as it does not handle recursion.
381    /// Instead use [`snapshot_physical_expr`] to handle recursion and capture the
382    /// full state of the `PhysicalExpr`.
383    ///
384    /// This is expected to return "simple" expressions that do not have mutable state
385    /// and are composed of DataFusion's built-in `PhysicalExpr` implementations.
386    /// Callers however should *not* assume anything about the returned expressions
387    /// since callers and implementers may not agree on what "simple" or "built-in"
388    /// means.
389    /// In other words, if you need to serialize a `PhysicalExpr` across the wire
390    /// you should call this method and then try to serialize the result,
391    /// but you should handle unknown or unexpected `PhysicalExpr` implementations gracefully
392    /// just as if you had not called this method at all.
393    ///
394    /// In particular, consider:
395    /// * A `PhysicalExpr` that references the current state of a `datafusion::physical_plan::TopK`
396    ///   that is involved in a query with `SELECT * FROM t1 ORDER BY a LIMIT 10`.
397    ///   This function may return something like `a >= 12`.
398    /// * A `PhysicalExpr` that references the current state of a `datafusion::physical_plan::joins::HashJoinExec`
399    ///   from a query such as `SELECT * FROM t1 JOIN t2 ON t1.a = t2.b`.
400    ///   This function may return something like `t2.b IN (1, 5, 7)`.
401    ///
402    /// A system or function that can only deal with a hardcoded set of `PhysicalExpr` implementations
403    /// or needs to serialize this state to bytes may not be able to handle these dynamic references.
404    /// In such cases, we should return a simplified version of the `PhysicalExpr` that does not
405    /// contain these dynamic references.
406    ///
407    /// Systems that implement remote execution of plans, e.g. serialize a portion of the query plan
408    /// and send it across the wire to a remote executor may want to call this method after
409    /// every batch on the source side and broadcast / update the current snapshot to the remote executor.
410    ///
411    /// Note for implementers: this method should *not* handle recursion.
412    /// Recursion is handled in [`snapshot_physical_expr`].
413    fn snapshot(&self) -> Result<Option<Arc<dyn PhysicalExpr>>> {
414        // By default, we return None to indicate that this PhysicalExpr does not
415        // have any dynamic references or state.
416        // This is a safe default behavior.
417        Ok(None)
418    }
419
420    /// Returns the generation of this `PhysicalExpr` for snapshotting purposes.
421    /// The generation is an arbitrary u64 that can be used to track changes
422    /// in the state of the `PhysicalExpr` over time without having to do an exhaustive comparison.
423    /// This is useful to avoid unnecessary computation or serialization if there are no changes to the expression.
424    /// In particular, dynamic expressions that may change over time; this allows cheap checks for changes.
425    /// Static expressions that do not change over time should return 0, as does the default implementation.
426    /// You should not call this method directly as it does not handle recursion.
427    /// Instead use [`snapshot_generation`] to handle recursion and capture the
428    /// full state of the `PhysicalExpr`.
429    fn snapshot_generation(&self) -> u64 {
430        // By default, we return 0 to indicate that this PhysicalExpr does not
431        // have any dynamic references or state.
432        // Since the recursive algorithm XORs the generations of all children the overall
433        // generation will be 0 if no children have a non-zero generation, meaning that
434        // static expressions will always return 0.
435        0
436    }
437
438    /// Returns true if the expression node is volatile, i.e. whether it can return
439    /// different results when evaluated multiple times with the same input.
440    ///
441    /// Note: unlike [`is_volatile`], this function does not consider inputs:
442    /// - `random()` returns `true`,
443    /// - `a + random()` returns `false` (because the operation `+` itself is not volatile.)
444    ///
445    /// The default to this function was set to `false` when it was created
446    /// to avoid imposing API churn on implementers, but this is not a safe default in general.
447    /// It is highly recommended that volatile expressions implement this method and return `true`.
448    /// This default may be removed in the future if it causes problems or we decide to
449    /// eat the cost of the breaking change and require all implementers to make a choice.
450    fn is_volatile_node(&self) -> bool {
451        false
452    }
453
454    /// Returns placement information for this expression.
455    ///
456    /// This is used by optimizers to make decisions about expression placement,
457    /// such as whether to push expressions down through projections.
458    ///
459    /// The default implementation returns [`ExpressionPlacement::KeepInPlace`].
460    fn placement(&self) -> ExpressionPlacement {
461        ExpressionPlacement::KeepInPlace
462    }
463
464    /// Return a stable, globally-unique identifier for this [`PhysicalExpr`], if it
465    /// has one.
466    ///
467    /// This identifier tracks which expressions which are connected (e.g. `DynamicFilterPhysicalExpr`
468    /// where two expressions may be different but store the same mutable inner state). Tracking
469    /// connected expressions helps preserve referential integrity within plan nodes
470    /// during serialization and deserialization.
471    ///
472    /// This id must be preserved across [`PhysicalExpr::with_new_children`] or any other
473    /// methods which may want to preserve identity.
474    ///
475    /// Default is `None`: the expression has no identity worth preserving across a
476    /// serialization boundary.
477    fn expression_id(&self) -> Option<u64> {
478        None
479    }
480}
481
482#[deprecated(
483    since = "50.0.0",
484    note = "Use `datafusion_expr_common::dyn_eq` instead"
485)]
486pub use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
487
488impl dyn PhysicalExpr {
489    /// Returns `true` if the expression is of type `T`.
490    ///
491    /// Prefer this over `downcast_ref::<T>().is_some()`. Works correctly when
492    /// called on `Arc<dyn PhysicalExpr>` via auto-deref.
493    pub fn is<T: PhysicalExpr>(&self) -> bool {
494        (self as &dyn Any).is::<T>()
495    }
496
497    /// Attempts to downcast this expression to a concrete type `T`, returning
498    /// `None` if the expression is not of that type.
499    ///
500    /// Works correctly when called on `Arc<dyn PhysicalExpr>` via auto-deref,
501    /// unlike `(&arc as &dyn Any).downcast_ref::<T>()` which would attempt to
502    /// downcast the `Arc` itself.
503    pub fn downcast_ref<T: PhysicalExpr>(&self) -> Option<&T> {
504        (self as &dyn Any).downcast_ref()
505    }
506}
507
508impl PartialEq for dyn PhysicalExpr {
509    fn eq(&self, other: &Self) -> bool {
510        self.dyn_eq(other as &dyn Any)
511    }
512}
513impl Eq for dyn PhysicalExpr {}
514
515impl Hash for dyn PhysicalExpr {
516    fn hash<H: Hasher>(&self, state: &mut H) {
517        self.dyn_hash(state);
518    }
519}
520
521/// Returns a copy of this expr if we change any child according to the pointer comparison.
522/// The size of `children` must be equal to the size of `PhysicalExpr::children()`.
523pub fn with_new_children_if_necessary(
524    expr: Arc<dyn PhysicalExpr>,
525    children: Vec<Arc<dyn PhysicalExpr>>,
526) -> Result<Arc<dyn PhysicalExpr>> {
527    let old_children = expr.children();
528    assert_eq_or_internal_err!(
529        children.len(),
530        old_children.len(),
531        "PhysicalExpr: Wrong number of children"
532    );
533
534    if children.is_empty()
535        || children
536            .iter()
537            .zip(old_children.iter())
538            .any(|(c1, c2)| !Arc::ptr_eq(c1, c2))
539    {
540        Ok(expr.with_new_children(children)?)
541    } else {
542        Ok(expr)
543    }
544}
545
/// Wraps a sequence of displayable items (typically [`PhysicalExpr`]s) in a
/// value that implements [`Display`].
///
/// The items are rendered comma-separated inside square brackets. The
/// iterator is cloned on every formatting call, so the returned value can be
/// formatted more than once.
///
/// Example output: `[a + 1, b]`
pub fn format_physical_expr_list<T>(exprs: T) -> impl Display
where
    T: IntoIterator,
    T::Item: Display,
    T::IntoIter: Clone,
{
    /// Adapter that renders the wrapped iterator as a bracketed list.
    struct ListDisplay<I>(I);

    impl<I> Display for ListDisplay<I>
    where
        I: Iterator + Clone,
        I::Item: Display,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
            f.write_str("[")?;
            for (position, item) in self.0.clone().enumerate() {
                // Every item except the first is preceded by a separator.
                if position > 0 {
                    f.write_str(", ")?;
                }
                write!(f, "{item}")?;
            }
            f.write_str("]")
        }
    }

    ListDisplay(exprs.into_iter())
}
581
582/// Prints a [`PhysicalExpr`] in a SQL-like format
583///
584/// # Example
585/// ```
586/// # // The boilerplate needed to create a `PhysicalExpr` for the example
587/// use std::collections::HashMap;
588/// # use std::fmt::Formatter;
589/// # use std::sync::Arc;
590/// # use arrow::array::RecordBatch;
591/// # use arrow::datatypes::{DataType, Field, FieldRef, Schema};
592/// # use datafusion_common::Result;
593/// # use datafusion_expr_common::columnar_value::ColumnarValue;
594/// # use datafusion_physical_expr_common::physical_expr::{fmt_sql, DynEq, PhysicalExpr};
595/// # #[derive(Debug, PartialEq, Eq, Hash)]
596/// # struct MyExpr {}
597/// # impl PhysicalExpr for MyExpr {
598/// # fn data_type(&self, input_schema: &Schema) -> Result<DataType> { unimplemented!() }
599/// # fn nullable(&self, input_schema: &Schema) -> Result<bool> { unimplemented!() }
600/// # fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> { unimplemented!() }
601/// # fn return_field(&self, input_schema: &Schema) -> Result<FieldRef> { unimplemented!() }
602/// # fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>>{ unimplemented!() }
603/// # fn with_new_children(self: Arc<Self>, children: Vec<Arc<dyn PhysicalExpr>>) -> Result<Arc<dyn PhysicalExpr>> { unimplemented!() }
604/// # fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "CASE a > b THEN 1 ELSE 0 END") }
605/// # }
606/// # impl std::fmt::Display for MyExpr {fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { unimplemented!() } }
607/// # fn make_physical_expr() -> Arc<dyn PhysicalExpr> { Arc::new(MyExpr{}) }
608/// let expr: Arc<dyn PhysicalExpr> = make_physical_expr();
609/// // wrap the expression in `sql_fmt` which can be used with
610/// // `format!`, `to_string()`, etc
611/// let expr_as_sql = fmt_sql(expr.as_ref());
612/// assert_eq!(
613///   "The SQL: CASE a > b THEN 1 ELSE 0 END",
614///   format!("The SQL: {expr_as_sql}")
615/// );
616/// ```
617pub fn fmt_sql(expr: &dyn PhysicalExpr) -> impl Display + '_ {
618    struct Wrapper<'a> {
619        expr: &'a dyn PhysicalExpr,
620    }
621
622    impl Display for Wrapper<'_> {
623        fn fmt(&self, f: &mut Formatter) -> fmt::Result {
624            self.expr.fmt_sql(f)?;
625            Ok(())
626        }
627    }
628
629    Wrapper { expr }
630}
631
632/// Take a snapshot of the given `PhysicalExpr` if it is dynamic.
633///
634/// Take a snapshot of this `PhysicalExpr` if it is dynamic.
635/// This is used to capture the current state of `PhysicalExpr`s that may contain
636/// dynamic references to other operators in order to serialize it over the wire
637/// or treat it via downcast matching.
638///
639/// See the documentation of [`PhysicalExpr::snapshot`] for more details.
640///
641/// # Returns
642///
643/// Returns a snapshot of the `PhysicalExpr` if it is dynamic, otherwise
644/// returns itself.
645pub fn snapshot_physical_expr(
646    expr: Arc<dyn PhysicalExpr>,
647) -> Result<Arc<dyn PhysicalExpr>> {
648    snapshot_physical_expr_opt(expr).data()
649}
650
651/// Take a snapshot of the given `PhysicalExpr` if it is dynamic.
652///
653/// Take a snapshot of this `PhysicalExpr` if it is dynamic.
654/// This is used to capture the current state of `PhysicalExpr`s that may contain
655/// dynamic references to other operators in order to serialize it over the wire
656/// or treat it via downcast matching.
657///
658/// See the documentation of [`PhysicalExpr::snapshot`] for more details.
659///
660/// # Returns
661///
662/// Returns a `[`Transformed`] indicating whether a snapshot was taken,
663/// along with the resulting `PhysicalExpr`.
664pub fn snapshot_physical_expr_opt(
665    expr: Arc<dyn PhysicalExpr>,
666) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
667    expr.transform_up(|e| {
668        if let Some(snapshot) = e.snapshot()? {
669            Ok(Transformed::yes(snapshot))
670        } else {
671            Ok(Transformed::no(Arc::clone(&e)))
672        }
673    })
674}
675
676/// Check the generation of this `PhysicalExpr`.
677/// Dynamic `PhysicalExpr`s may have a generation that is incremented
678/// every time the state of the `PhysicalExpr` changes.
679/// If the generation changes that means this `PhysicalExpr` or one of its children
680/// has changed since the last time it was evaluated.
681///
682/// This algorithm will not produce collisions as long as the structure of the
683/// `PhysicalExpr` does not change and no `PhysicalExpr` decrements its own generation.
684pub fn snapshot_generation(expr: &Arc<dyn PhysicalExpr>) -> u64 {
685    let mut generation = 0u64;
686    expr.apply(|e| {
687        // Add the current generation of the `PhysicalExpr` to our global generation.
688        generation = generation.wrapping_add(e.snapshot_generation());
689        Ok(TreeNodeRecursion::Continue)
690    })
691    .expect("this traversal is infallible");
692
693    generation
694}
695
696/// Check if the given `PhysicalExpr` is dynamic.
697/// Internally this calls [`snapshot_generation`] to check if the generation is non-zero,
698/// any dynamic `PhysicalExpr` should have a non-zero generation.
699pub fn is_dynamic_physical_expr(expr: &Arc<dyn PhysicalExpr>) -> bool {
700    // If the generation is non-zero, then this `PhysicalExpr` is dynamic.
701    snapshot_generation(expr) != 0
702}
703
704/// Returns true if the expression is volatile, i.e. whether it can return different
705/// results when evaluated multiple times with the same input.
706///
707/// For example the function call `RANDOM()` is volatile as each call will
708/// return a different value.
709///
710/// This method recursively checks if any sub-expression is volatile, for example
711/// `1 + RANDOM()` will return `true`.
712pub fn is_volatile(expr: &Arc<dyn PhysicalExpr>) -> bool {
713    if expr.is_volatile_node() {
714        return true;
715    }
716    let mut is_volatile = false;
717    expr.apply(|e| {
718        if e.is_volatile_node() {
719            is_volatile = true;
720            Ok(TreeNodeRecursion::Stop)
721        } else {
722            Ok(TreeNodeRecursion::Continue)
723        }
724    })
725    .expect("infallible closure should not fail");
726    is_volatile
727}
728
#[cfg(test)]
mod test {
    use crate::physical_expr::PhysicalExpr;
    use arrow::array::{Array, BooleanArray, Int64Array, RecordBatch};
    use arrow::datatypes::{DataType, Schema};
    use datafusion_expr_common::columnar_value::ColumnarValue;
    use std::fmt::{Display, Formatter};
    use std::sync::Arc;

    /// Leaf expression used to exercise the default `evaluate_selection`
    /// implementation: it ignores the batch contents and produces an `Int64`
    /// array of `1`s with one entry per input row.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct TestExpr {}

    impl PhysicalExpr for TestExpr {
        fn data_type(&self, _schema: &Schema) -> datafusion_common::Result<DataType> {
            Ok(DataType::Int64)
        }

        fn nullable(&self, _schema: &Schema) -> datafusion_common::Result<bool> {
            Ok(false)
        }

        fn evaluate(
            &self,
            batch: &RecordBatch,
        ) -> datafusion_common::Result<ColumnarValue> {
            // One `1` per row of the incoming batch.
            let data = vec![1; batch.num_rows()];
            Ok(ColumnarValue::Array(Arc::new(Int64Array::from(data))))
        }

        fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
            vec![]
        }

        fn with_new_children(
            self: Arc<Self>,
            _children: Vec<Arc<dyn PhysicalExpr>>,
        ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
            Ok(Arc::new(Self {}))
        }

        fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            f.write_str("TestExpr")
        }
    }

    impl Display for TestExpr {
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            self.fmt_sql(f)
        }
    }

    /// Asserts that two `ColumnarValue`s materialize to equal arrays.
    ///
    /// `$EXPECTED` is converted with `to_array(1)`, and `$ACTUAL` is then
    /// converted to an array of the same length before the comparison.
    macro_rules! assert_arrays_eq {
        ($EXPECTED: expr, $ACTUAL: expr, $MESSAGE: expr) => {
            let expected = $EXPECTED.to_array(1).unwrap();
            let actual = $ACTUAL;

            let actual_array = actual.to_array(expected.len()).unwrap();
            let actual_ref = actual_array.as_ref();
            let expected_ref = expected.as_ref();
            assert!(
                actual_ref == expected_ref,
                "{}: expected: {:?}, actual: {:?}",
                $MESSAGE,
                $EXPECTED,
                actual_ref
            );
        };
    }

    /// Builds a `RecordBatch` with an empty schema and no columns that
    /// nevertheless reports `num_rows` rows.
    fn batch_without_columns(num_rows: usize) -> RecordBatch {
        // SAFETY: the schema has no fields and no columns are supplied, so the
        // field/column count, column length and column type requirements of
        // `new_unchecked` hold vacuously for any `num_rows`.
        unsafe {
            RecordBatch::new_unchecked(Arc::new(Schema::empty()), vec![], num_rows)
        }
    }

    /// Runs `TestExpr::evaluate_selection` on `batch` with `selection` and
    /// checks that the result equals `expected`. When every row of the batch
    /// is selected, the result must also equal a plain `evaluate` call.
    fn test_evaluate_selection(
        batch: &RecordBatch,
        selection: &BooleanArray,
        expected: &ColumnarValue,
    ) {
        let expr = TestExpr {};

        // First check that the `evaluate_selection` is the expected one
        let selection_result = expr.evaluate_selection(batch, selection).unwrap();
        assert_eq!(
            expected.to_array(1).unwrap().len(),
            selection_result.to_array(1).unwrap().len(),
            "evaluate_selection output row count should match input record batch"
        );
        assert_arrays_eq!(
            expected,
            &selection_result,
            "evaluate_selection returned unexpected value"
        );

        // If we're selecting all rows, the result should be the same as calling `evaluate`
        // with the full record batch. (An empty batch trivially selects "all" rows.)
        if (0..batch.num_rows())
            .all(|row_idx| row_idx < selection.len() && selection.value(row_idx))
        {
            let full_result = expr.evaluate(batch).unwrap();

            assert_arrays_eq!(
                full_result,
                &selection_result,
                "evaluate_selection does not match unfiltered evaluate result"
            );
        }
    }

    /// Asserts that `TestExpr::evaluate_selection` rejects the given
    /// `batch`/`selection` combination with an error.
    fn test_evaluate_selection_error(batch: &RecordBatch, selection: &BooleanArray) {
        let expr = TestExpr {};

        // The call itself must fail; there is no result value to inspect.
        let selection_result = expr.evaluate_selection(batch, selection);
        assert!(selection_result.is_err(), "evaluate_selection should fail");
    }

    #[test]
    pub fn test_evaluate_selection_with_empty_record_batch() {
        test_evaluate_selection(
            &RecordBatch::new_empty(Arc::new(Schema::empty())),
            &BooleanArray::from(vec![false; 0]),
            &ColumnarValue::Array(Arc::new(Int64Array::new_null(0))),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_empty_record_batch_with_larger_false_selection() {
        test_evaluate_selection_error(
            &RecordBatch::new_empty(Arc::new(Schema::empty())),
            &BooleanArray::from(vec![false; 10]),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_empty_record_batch_with_larger_true_selection() {
        test_evaluate_selection_error(
            &RecordBatch::new_empty(Arc::new(Schema::empty())),
            &BooleanArray::from(vec![true; 10]),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_non_empty_record_batch() {
        test_evaluate_selection(
            &batch_without_columns(10),
            &BooleanArray::from(vec![true; 10]),
            &ColumnarValue::Array(Arc::new(Int64Array::from(vec![1; 10]))),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_non_empty_record_batch_with_larger_false_selection()
     {
        test_evaluate_selection_error(
            &batch_without_columns(10),
            &BooleanArray::from(vec![false; 20]),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_non_empty_record_batch_with_larger_true_selection()
     {
        test_evaluate_selection_error(
            &batch_without_columns(10),
            &BooleanArray::from(vec![true; 20]),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_non_empty_record_batch_with_smaller_false_selection()
     {
        test_evaluate_selection_error(
            &batch_without_columns(10),
            &BooleanArray::from(vec![false; 5]),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_non_empty_record_batch_with_smaller_true_selection()
     {
        test_evaluate_selection_error(
            &batch_without_columns(10),
            &BooleanArray::from(vec![true; 5]),
        );
    }
}