datafusion_physical_expr_common/physical_expr.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::fmt;
20use std::fmt::{Debug, Display, Formatter};
21use std::hash::{Hash, Hasher};
22use std::sync::Arc;
23
24use crate::utils::scatter;
25
26use arrow::array::{ArrayRef, BooleanArray, new_empty_array};
27use arrow::compute::filter_record_batch;
28use arrow::datatypes::{DataType, Field, FieldRef, Schema};
29use arrow::record_batch::RecordBatch;
30use datafusion_common::tree_node::{
31 Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
32};
33use datafusion_common::{
34 Result, ScalarValue, assert_eq_or_internal_err, exec_err, not_impl_err,
35};
36use datafusion_expr_common::columnar_value::ColumnarValue;
37use datafusion_expr_common::interval_arithmetic::Interval;
38use datafusion_expr_common::sort_properties::ExprProperties;
39use datafusion_expr_common::statistics::Distribution;
40
41use itertools::izip;
42
43/// Shared [`PhysicalExpr`].
44pub type PhysicalExprRef = Arc<dyn PhysicalExpr>;
45
46/// [`PhysicalExpr`]s represent expressions such as `A + 1` or `CAST(c1 AS int)`.
47///
48/// `PhysicalExpr` knows its type, nullability and can be evaluated directly on
49/// a [`RecordBatch`] (see [`Self::evaluate`]).
50///
51/// `PhysicalExpr` are the physical counterpart to [`Expr`] used in logical
52/// planning. They are typically created from [`Expr`] by a [`PhysicalPlanner`]
53/// invoked from a higher level API
54///
55/// Some important examples of `PhysicalExpr` are:
56/// * [`Column`]: Represents a column at a given index in a RecordBatch
57///
58/// To create `PhysicalExpr` from `Expr`, see
59/// * [`SessionContext::create_physical_expr`]: A high level API
60/// * [`create_physical_expr`]: A low level API
61///
62/// # Formatting `PhysicalExpr` as strings
63/// There are three ways to format `PhysicalExpr` as a string:
64/// * [`Debug`]: Standard Rust debugging format (e.g. `Constant { value: ... }`)
65/// * [`Display`]: Detailed SQL-like format that shows expression structure (e.g. (`Utf8 ("foobar")`). This is often used for debugging and tests
66/// * [`Self::fmt_sql`]: SQL-like human readable format (e.g. ('foobar')`), See also [`sql_fmt`]
67///
68/// [`SessionContext::create_physical_expr`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.create_physical_expr
69/// [`PhysicalPlanner`]: https://docs.rs/datafusion/latest/datafusion/physical_planner/trait.PhysicalPlanner.html
70/// [`Expr`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Expr.html
71/// [`create_physical_expr`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/fn.create_physical_expr.html
72/// [`Column`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/expressions/struct.Column.html
73pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash {
74 /// Returns the physical expression as [`Any`] so that it can be
75 /// downcast to a specific implementation.
76 fn as_any(&self) -> &dyn Any;
77 /// Get the data type of this expression, given the schema of the input
78 fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
79 Ok(self.return_field(input_schema)?.data_type().to_owned())
80 }
81 /// Determine whether this expression is nullable, given the schema of the input
82 fn nullable(&self, input_schema: &Schema) -> Result<bool> {
83 Ok(self.return_field(input_schema)?.is_nullable())
84 }
85 /// Evaluate an expression against a RecordBatch
86 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue>;
87 /// The output field associated with this expression
88 fn return_field(&self, input_schema: &Schema) -> Result<FieldRef> {
89 Ok(Arc::new(Field::new(
90 format!("{self}"),
91 self.data_type(input_schema)?,
92 self.nullable(input_schema)?,
93 )))
94 }
95 /// Evaluate an expression against a RecordBatch after first applying a validity array
96 ///
97 /// # Errors
98 ///
99 /// Returns an `Err` if the expression could not be evaluated or if the length of the
100 /// `selection` validity array and the number of row in `batch` is not equal.
101 fn evaluate_selection(
102 &self,
103 batch: &RecordBatch,
104 selection: &BooleanArray,
105 ) -> Result<ColumnarValue> {
106 let row_count = batch.num_rows();
107 if row_count != selection.len() {
108 return exec_err!(
109 "Selection array length does not match batch row count: {} != {row_count}",
110 selection.len()
111 );
112 }
113
114 let selection_count = selection.true_count();
115
116 // First, check if we can avoid filtering altogether.
117 if selection_count == row_count {
118 // All values from the `selection` filter are true and match the input batch.
119 // No need to perform any filtering.
120 return self.evaluate(batch);
121 }
122
123 // Next, prepare the result array for each 'true' row in the selection vector.
124 let filtered_result = if selection_count == 0 {
125 // Do not call `evaluate` when the selection is empty.
126 // `evaluate_selection` is used to conditionally evaluate expressions.
127 // When the expression in question is fallible, evaluating it with an empty
128 // record batch may trigger a runtime error (e.g. division by zero).
129 //
130 // Instead, create an empty array matching the expected return type.
131 let datatype = self.data_type(batch.schema_ref().as_ref())?;
132 ColumnarValue::Array(new_empty_array(&datatype))
133 } else {
134 // If we reach this point, there's no other option than to filter the batch.
135 // This is a fairly costly operation since it requires creating partial copies
136 // (worst case of length `row_count - 1`) of all the arrays in the record batch.
137 // The resulting `filtered_batch` will contain `selection_count` rows.
138 let filtered_batch = filter_record_batch(batch, selection)?;
139 self.evaluate(&filtered_batch)?
140 };
141
142 // Finally, scatter the filtered result array so that the indices match the input rows again.
143 match &filtered_result {
144 ColumnarValue::Array(a) => {
145 scatter(selection, a.as_ref()).map(ColumnarValue::Array)
146 }
147 ColumnarValue::Scalar(ScalarValue::Boolean(value)) => {
148 // When the scalar is true or false, skip the scatter process
149 if let Some(v) = value {
150 if *v {
151 Ok(ColumnarValue::from(Arc::new(selection.clone()) as ArrayRef))
152 } else {
153 Ok(filtered_result)
154 }
155 } else {
156 let array = BooleanArray::from(vec![None; row_count]);
157 scatter(selection, &array).map(ColumnarValue::Array)
158 }
159 }
160 ColumnarValue::Scalar(_) => Ok(filtered_result),
161 }
162 }
163
164 /// Get a list of child PhysicalExpr that provide the input for this expr.
165 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>>;
166
167 /// Returns a new PhysicalExpr where all children were replaced by new exprs.
168 fn with_new_children(
169 self: Arc<Self>,
170 children: Vec<Arc<dyn PhysicalExpr>>,
171 ) -> Result<Arc<dyn PhysicalExpr>>;
172
173 /// Computes the output interval for the expression, given the input
174 /// intervals.
175 ///
176 /// # Parameters
177 ///
178 /// * `children` are the intervals for the children (inputs) of this
179 /// expression.
180 ///
181 /// # Returns
182 ///
183 /// A `Result` containing the output interval for the expression in
184 /// case of success, or an error object in case of failure.
185 ///
186 /// # Example
187 ///
188 /// If the expression is `a + b`, and the input intervals are `a: [1, 2]`
189 /// and `b: [3, 4]`, then the output interval would be `[4, 6]`.
190 fn evaluate_bounds(&self, _children: &[&Interval]) -> Result<Interval> {
191 not_impl_err!("Not implemented for {self}")
192 }
193
194 /// Updates bounds for child expressions, given a known interval for this
195 /// expression.
196 ///
197 /// This is used to propagate constraints down through an expression tree.
198 ///
199 /// # Parameters
200 ///
201 /// * `interval` is the currently known interval for this expression.
202 /// * `children` are the current intervals for the children of this expression.
203 ///
204 /// # Returns
205 ///
206 /// A `Result` containing a `Vec` of new intervals for the children (in order)
207 /// in case of success, or an error object in case of failure.
208 ///
209 /// If constraint propagation reveals an infeasibility for any child, returns
210 /// [`None`]. If none of the children intervals change as a result of
211 /// propagation, may return an empty vector instead of cloning `children`.
212 /// This is the default (and conservative) return value.
213 ///
214 /// # Example
215 ///
216 /// If the expression is `a + b`, the current `interval` is `[4, 5]` and the
217 /// inputs `a` and `b` are respectively given as `[0, 2]` and `[-∞, 4]`, then
218 /// propagation would return `[0, 2]` and `[2, 4]` as `b` must be at least
219 /// `2` to make the output at least `4`.
220 fn propagate_constraints(
221 &self,
222 _interval: &Interval,
223 _children: &[&Interval],
224 ) -> Result<Option<Vec<Interval>>> {
225 Ok(Some(vec![]))
226 }
227
228 /// Computes the output statistics for the expression, given the input
229 /// statistics.
230 ///
231 /// # Parameters
232 ///
233 /// * `children` are the statistics for the children (inputs) of this
234 /// expression.
235 ///
236 /// # Returns
237 ///
238 /// A `Result` containing the output statistics for the expression in
239 /// case of success, or an error object in case of failure.
240 ///
241 /// Expressions (should) implement this function and utilize the independence
242 /// assumption, match on children distribution types and compute the output
243 /// statistics accordingly. The default implementation simply creates an
244 /// unknown output distribution by combining input ranges. This logic loses
245 /// distribution information, but is a safe default.
246 fn evaluate_statistics(&self, children: &[&Distribution]) -> Result<Distribution> {
247 let children_ranges = children
248 .iter()
249 .map(|c| c.range())
250 .collect::<Result<Vec<_>>>()?;
251 let children_ranges_refs = children_ranges.iter().collect::<Vec<_>>();
252 let output_interval = self.evaluate_bounds(children_ranges_refs.as_slice())?;
253 let dt = output_interval.data_type();
254 if dt.eq(&DataType::Boolean) {
255 let p = if output_interval.eq(&Interval::TRUE) {
256 ScalarValue::new_one(&dt)
257 } else if output_interval.eq(&Interval::FALSE) {
258 ScalarValue::new_zero(&dt)
259 } else {
260 ScalarValue::try_from(&dt)
261 }?;
262 Distribution::new_bernoulli(p)
263 } else {
264 Distribution::new_from_interval(output_interval)
265 }
266 }
267
268 /// Updates children statistics using the given parent statistic for this
269 /// expression.
270 ///
271 /// This is used to propagate statistics down through an expression tree.
272 ///
273 /// # Parameters
274 ///
275 /// * `parent` is the currently known statistics for this expression.
276 /// * `children` are the current statistics for the children of this expression.
277 ///
278 /// # Returns
279 ///
280 /// A `Result` containing a `Vec` of new statistics for the children (in order)
281 /// in case of success, or an error object in case of failure.
282 ///
283 /// If statistics propagation reveals an infeasibility for any child, returns
284 /// [`None`]. If none of the children statistics change as a result of
285 /// propagation, may return an empty vector instead of cloning `children`.
286 /// This is the default (and conservative) return value.
287 ///
288 /// Expressions (should) implement this function and apply Bayes rule to
289 /// reconcile and update parent/children statistics. This involves utilizing
290 /// the independence assumption, and matching on distribution types. The
291 /// default implementation simply creates an unknown distribution if it can
292 /// narrow the range by propagating ranges. This logic loses distribution
293 /// information, but is a safe default.
294 fn propagate_statistics(
295 &self,
296 parent: &Distribution,
297 children: &[&Distribution],
298 ) -> Result<Option<Vec<Distribution>>> {
299 let children_ranges = children
300 .iter()
301 .map(|c| c.range())
302 .collect::<Result<Vec<_>>>()?;
303 let children_ranges_refs = children_ranges.iter().collect::<Vec<_>>();
304 let parent_range = parent.range()?;
305 let Some(propagated_children) =
306 self.propagate_constraints(&parent_range, children_ranges_refs.as_slice())?
307 else {
308 return Ok(None);
309 };
310 izip!(propagated_children.into_iter(), children_ranges, children)
311 .map(|(new_interval, old_interval, child)| {
312 if new_interval == old_interval {
313 // We weren't able to narrow the range, preserve the old statistics.
314 Ok((*child).clone())
315 } else if new_interval.data_type().eq(&DataType::Boolean) {
316 let dt = old_interval.data_type();
317 let p = if new_interval.eq(&Interval::TRUE) {
318 ScalarValue::new_one(&dt)
319 } else if new_interval.eq(&Interval::FALSE) {
320 ScalarValue::new_zero(&dt)
321 } else {
322 unreachable!("Given that we have a range reduction for a boolean interval, we should have certainty")
323 }?;
324 Distribution::new_bernoulli(p)
325 } else {
326 Distribution::new_from_interval(new_interval)
327 }
328 })
329 .collect::<Result<_>>()
330 .map(Some)
331 }
332
333 /// Calculates the properties of this [`PhysicalExpr`] based on its
334 /// children's properties (i.e. order and range), recursively aggregating
335 /// the information from its children. In cases where the [`PhysicalExpr`]
336 /// has no children (e.g., `Literal` or `Column`), these properties should
337 /// be specified externally, as the function defaults to unknown properties.
338 fn get_properties(&self, _children: &[ExprProperties]) -> Result<ExprProperties> {
339 Ok(ExprProperties::new_unknown())
340 }
341
342 /// Format this `PhysicalExpr` in nice human readable "SQL" format
343 ///
344 /// Specifically, this format is designed to be readable by humans, at the
345 /// expense of details. Use `Display` or `Debug` for more detailed
346 /// representation.
347 ///
348 /// See the [`fmt_sql`] function for an example of printing `PhysicalExpr`s as SQL.
349 fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result;
350
351 /// Take a snapshot of this `PhysicalExpr`, if it is dynamic.
352 ///
353 /// "Dynamic" in this case means containing references to structures that may change
354 /// during plan execution, such as hash tables.
355 ///
356 /// This method is used to capture the current state of `PhysicalExpr`s that may contain
357 /// dynamic references to other operators in order to serialize it over the wire
358 /// or treat it via downcast matching.
359 ///
360 /// You should not call this method directly as it does not handle recursion.
361 /// Instead use [`snapshot_physical_expr`] to handle recursion and capture the
362 /// full state of the `PhysicalExpr`.
363 ///
364 /// This is expected to return "simple" expressions that do not have mutable state
365 /// and are composed of DataFusion's built-in `PhysicalExpr` implementations.
366 /// Callers however should *not* assume anything about the returned expressions
367 /// since callers and implementers may not agree on what "simple" or "built-in"
368 /// means.
369 /// In other words, if you need to serialize a `PhysicalExpr` across the wire
370 /// you should call this method and then try to serialize the result,
371 /// but you should handle unknown or unexpected `PhysicalExpr` implementations gracefully
372 /// just as if you had not called this method at all.
373 ///
374 /// In particular, consider:
375 /// * A `PhysicalExpr` that references the current state of a `datafusion::physical_plan::TopK`
376 /// that is involved in a query with `SELECT * FROM t1 ORDER BY a LIMIT 10`.
377 /// This function may return something like `a >= 12`.
378 /// * A `PhysicalExpr` that references the current state of a `datafusion::physical_plan::joins::HashJoinExec`
379 /// from a query such as `SELECT * FROM t1 JOIN t2 ON t1.a = t2.b`.
380 /// This function may return something like `t2.b IN (1, 5, 7)`.
381 ///
382 /// A system or function that can only deal with a hardcoded set of `PhysicalExpr` implementations
383 /// or needs to serialize this state to bytes may not be able to handle these dynamic references.
384 /// In such cases, we should return a simplified version of the `PhysicalExpr` that does not
385 /// contain these dynamic references.
386 ///
387 /// Systems that implement remote execution of plans, e.g. serialize a portion of the query plan
388 /// and send it across the wire to a remote executor may want to call this method after
389 /// every batch on the source side and broadcast / update the current snapshot to the remote executor.
390 ///
391 /// Note for implementers: this method should *not* handle recursion.
392 /// Recursion is handled in [`snapshot_physical_expr`].
393 fn snapshot(&self) -> Result<Option<Arc<dyn PhysicalExpr>>> {
394 // By default, we return None to indicate that this PhysicalExpr does not
395 // have any dynamic references or state.
396 // This is a safe default behavior.
397 Ok(None)
398 }
399
400 /// Returns the generation of this `PhysicalExpr` for snapshotting purposes.
401 /// The generation is an arbitrary u64 that can be used to track changes
402 /// in the state of the `PhysicalExpr` over time without having to do an exhaustive comparison.
403 /// This is useful to avoid unnecessary computation or serialization if there are no changes to the expression.
404 /// In particular, dynamic expressions that may change over time; this allows cheap checks for changes.
405 /// Static expressions that do not change over time should return 0, as does the default implementation.
406 /// You should not call this method directly as it does not handle recursion.
407 /// Instead use [`snapshot_generation`] to handle recursion and capture the
408 /// full state of the `PhysicalExpr`.
409 fn snapshot_generation(&self) -> u64 {
410 // By default, we return 0 to indicate that this PhysicalExpr does not
411 // have any dynamic references or state.
412 // Since the recursive algorithm XORs the generations of all children the overall
413 // generation will be 0 if no children have a non-zero generation, meaning that
414 // static expressions will always return 0.
415 0
416 }
417
418 /// Returns true if the expression node is volatile, i.e. whether it can return
419 /// different results when evaluated multiple times with the same input.
420 ///
421 /// Note: unlike [`is_volatile`], this function does not consider inputs:
422 /// - `random()` returns `true`,
423 /// - `a + random()` returns `false` (because the operation `+` itself is not volatile.)
424 ///
425 /// The default to this function was set to `false` when it was created
426 /// to avoid imposing API churn on implementers, but this is not a safe default in general.
427 /// It is highly recommended that volatile expressions implement this method and return `true`.
428 /// This default may be removed in the future if it causes problems or we decide to
429 /// eat the cost of the breaking change and require all implementers to make a choice.
430 fn is_volatile_node(&self) -> bool {
431 false
432 }
433}
434
435#[deprecated(
436 since = "50.0.0",
437 note = "Use `datafusion_expr_common::dyn_eq` instead"
438)]
439pub use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
440
441impl PartialEq for dyn PhysicalExpr {
442 fn eq(&self, other: &Self) -> bool {
443 self.dyn_eq(other.as_any())
444 }
445}
446impl Eq for dyn PhysicalExpr {}
447
448impl Hash for dyn PhysicalExpr {
449 fn hash<H: Hasher>(&self, state: &mut H) {
450 self.dyn_hash(state);
451 }
452}
453
454/// Returns a copy of this expr if we change any child according to the pointer comparison.
455/// The size of `children` must be equal to the size of `PhysicalExpr::children()`.
456pub fn with_new_children_if_necessary(
457 expr: Arc<dyn PhysicalExpr>,
458 children: Vec<Arc<dyn PhysicalExpr>>,
459) -> Result<Arc<dyn PhysicalExpr>> {
460 let old_children = expr.children();
461 assert_eq_or_internal_err!(
462 children.len(),
463 old_children.len(),
464 "PhysicalExpr: Wrong number of children"
465 );
466
467 if children.is_empty()
468 || children
469 .iter()
470 .zip(old_children.iter())
471 .any(|(c1, c2)| !Arc::ptr_eq(c1, c2))
472 {
473 Ok(expr.with_new_children(children)?)
474 } else {
475 Ok(expr)
476 }
477}
478
/// Returns a [`Display`]-able rendering of a list of [`PhysicalExpr`]s (or any
/// other `Display` items).
///
/// Items are written with their `Display` implementation, separated by `", "`
/// and enclosed in square brackets. Formatting is lazy: the iterator is cloned
/// each time the returned value is displayed, so it can be rendered repeatedly.
///
/// Example output: `[a + 1, b]`
pub fn format_physical_expr_list<T>(exprs: T) -> impl Display
where
    T: IntoIterator,
    T::Item: Display,
    T::IntoIter: Clone,
{
    struct ListDisplay<I>(I);

    impl<I> Display for ListDisplay<I>
    where
        I: Iterator + Clone,
        I::Item: Display,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
            f.write_str("[")?;
            for (position, item) in self.0.clone().enumerate() {
                // Every item after the first is preceded by a separator.
                if position > 0 {
                    f.write_str(", ")?;
                }
                write!(f, "{item}")?;
            }
            f.write_str("]")
        }
    }

    ListDisplay(exprs.into_iter())
}
514
515/// Prints a [`PhysicalExpr`] in a SQL-like format
516///
517/// # Example
518/// ```
519/// # // The boilerplate needed to create a `PhysicalExpr` for the example
520/// # use std::any::Any;
521/// use std::collections::HashMap;
522/// # use std::fmt::Formatter;
523/// # use std::sync::Arc;
524/// # use arrow::array::RecordBatch;
525/// # use arrow::datatypes::{DataType, Field, FieldRef, Schema};
526/// # use datafusion_common::Result;
527/// # use datafusion_expr_common::columnar_value::ColumnarValue;
528/// # use datafusion_physical_expr_common::physical_expr::{fmt_sql, DynEq, PhysicalExpr};
529/// # #[derive(Debug, PartialEq, Eq, Hash)]
530/// # struct MyExpr {}
531/// # impl PhysicalExpr for MyExpr {fn as_any(&self) -> &dyn Any { unimplemented!() }
532/// # fn data_type(&self, input_schema: &Schema) -> Result<DataType> { unimplemented!() }
533/// # fn nullable(&self, input_schema: &Schema) -> Result<bool> { unimplemented!() }
534/// # fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> { unimplemented!() }
535/// # fn return_field(&self, input_schema: &Schema) -> Result<FieldRef> { unimplemented!() }
536/// # fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>>{ unimplemented!() }
537/// # fn with_new_children(self: Arc<Self>, children: Vec<Arc<dyn PhysicalExpr>>) -> Result<Arc<dyn PhysicalExpr>> { unimplemented!() }
538/// # fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "CASE a > b THEN 1 ELSE 0 END") }
539/// # }
540/// # impl std::fmt::Display for MyExpr {fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { unimplemented!() } }
541/// # fn make_physical_expr() -> Arc<dyn PhysicalExpr> { Arc::new(MyExpr{}) }
542/// let expr: Arc<dyn PhysicalExpr> = make_physical_expr();
543/// // wrap the expression in `sql_fmt` which can be used with
544/// // `format!`, `to_string()`, etc
545/// let expr_as_sql = fmt_sql(expr.as_ref());
546/// assert_eq!(
547/// "The SQL: CASE a > b THEN 1 ELSE 0 END",
548/// format!("The SQL: {expr_as_sql}")
549/// );
550/// ```
551pub fn fmt_sql(expr: &dyn PhysicalExpr) -> impl Display + '_ {
552 struct Wrapper<'a> {
553 expr: &'a dyn PhysicalExpr,
554 }
555
556 impl Display for Wrapper<'_> {
557 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
558 self.expr.fmt_sql(f)?;
559 Ok(())
560 }
561 }
562
563 Wrapper { expr }
564}
565
566/// Take a snapshot of the given `PhysicalExpr` if it is dynamic.
567///
568/// Take a snapshot of this `PhysicalExpr` if it is dynamic.
569/// This is used to capture the current state of `PhysicalExpr`s that may contain
570/// dynamic references to other operators in order to serialize it over the wire
571/// or treat it via downcast matching.
572///
573/// See the documentation of [`PhysicalExpr::snapshot`] for more details.
574///
575/// # Returns
576///
577/// Returns a snapshot of the `PhysicalExpr` if it is dynamic, otherwise
578/// returns itself.
579pub fn snapshot_physical_expr(
580 expr: Arc<dyn PhysicalExpr>,
581) -> Result<Arc<dyn PhysicalExpr>> {
582 snapshot_physical_expr_opt(expr).data()
583}
584
585/// Take a snapshot of the given `PhysicalExpr` if it is dynamic.
586///
587/// Take a snapshot of this `PhysicalExpr` if it is dynamic.
588/// This is used to capture the current state of `PhysicalExpr`s that may contain
589/// dynamic references to other operators in order to serialize it over the wire
590/// or treat it via downcast matching.
591///
592/// See the documentation of [`PhysicalExpr::snapshot`] for more details.
593///
594/// # Returns
595///
596/// Returns a `[`Transformed`] indicating whether a snapshot was taken,
597/// along with the resulting `PhysicalExpr`.
598pub fn snapshot_physical_expr_opt(
599 expr: Arc<dyn PhysicalExpr>,
600) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
601 expr.transform_up(|e| {
602 if let Some(snapshot) = e.snapshot()? {
603 Ok(Transformed::yes(snapshot))
604 } else {
605 Ok(Transformed::no(Arc::clone(&e)))
606 }
607 })
608}
609
610/// Check the generation of this `PhysicalExpr`.
611/// Dynamic `PhysicalExpr`s may have a generation that is incremented
612/// every time the state of the `PhysicalExpr` changes.
613/// If the generation changes that means this `PhysicalExpr` or one of its children
614/// has changed since the last time it was evaluated.
615///
616/// This algorithm will not produce collisions as long as the structure of the
617/// `PhysicalExpr` does not change and no `PhysicalExpr` decrements its own generation.
618pub fn snapshot_generation(expr: &Arc<dyn PhysicalExpr>) -> u64 {
619 let mut generation = 0u64;
620 expr.apply(|e| {
621 // Add the current generation of the `PhysicalExpr` to our global generation.
622 generation = generation.wrapping_add(e.snapshot_generation());
623 Ok(TreeNodeRecursion::Continue)
624 })
625 .expect("this traversal is infallible");
626
627 generation
628}
629
630/// Check if the given `PhysicalExpr` is dynamic.
631/// Internally this calls [`snapshot_generation`] to check if the generation is non-zero,
632/// any dynamic `PhysicalExpr` should have a non-zero generation.
633pub fn is_dynamic_physical_expr(expr: &Arc<dyn PhysicalExpr>) -> bool {
634 // If the generation is non-zero, then this `PhysicalExpr` is dynamic.
635 snapshot_generation(expr) != 0
636}
637
638/// Returns true if the expression is volatile, i.e. whether it can return different
639/// results when evaluated multiple times with the same input.
640///
641/// For example the function call `RANDOM()` is volatile as each call will
642/// return a different value.
643///
644/// This method recursively checks if any sub-expression is volatile, for example
645/// `1 + RANDOM()` will return `true`.
646pub fn is_volatile(expr: &Arc<dyn PhysicalExpr>) -> bool {
647 if expr.is_volatile_node() {
648 return true;
649 }
650 let mut is_volatile = false;
651 expr.apply(|e| {
652 if e.is_volatile_node() {
653 is_volatile = true;
654 Ok(TreeNodeRecursion::Stop)
655 } else {
656 Ok(TreeNodeRecursion::Continue)
657 }
658 })
659 .expect("infallible closure should not fail");
660 is_volatile
661}
662
#[cfg(test)]
mod test {
    use crate::physical_expr::PhysicalExpr;
    use arrow::array::{Array, ArrayRef, BooleanArray, Int64Array, RecordBatch};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatchOptions;
    use datafusion_expr_common::columnar_value::ColumnarValue;
    use std::fmt::{Display, Formatter};
    use std::sync::Arc;

    /// Leaf test expression: non-nullable `Int64` that evaluates to `1` for
    /// every row of the input batch.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct TestExpr {}

    impl PhysicalExpr for TestExpr {
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }

        fn data_type(&self, _schema: &Schema) -> datafusion_common::Result<DataType> {
            Ok(DataType::Int64)
        }

        fn nullable(&self, _schema: &Schema) -> datafusion_common::Result<bool> {
            Ok(false)
        }

        fn evaluate(
            &self,
            batch: &RecordBatch,
        ) -> datafusion_common::Result<ColumnarValue> {
            let data = vec![1; batch.num_rows()];
            Ok(ColumnarValue::Array(Arc::new(Int64Array::from(data))))
        }

        fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
            vec![]
        }

        fn with_new_children(
            self: Arc<Self>,
            _children: Vec<Arc<dyn PhysicalExpr>>,
        ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
            Ok(Arc::new(Self {}))
        }

        fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            f.write_str("TestExpr")
        }
    }

    impl Display for TestExpr {
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            self.fmt_sql(f)
        }
    }

    /// Asserts that `$ACTUAL` materializes to the same array as `$EXPECTED`.
    macro_rules! assert_arrays_eq {
        ($EXPECTED: expr, $ACTUAL: expr, $MESSAGE: expr) => {
            let expected = $EXPECTED.to_array(1).unwrap();
            let actual = $ACTUAL;

            let actual_array = actual.to_array(expected.len()).unwrap();
            let actual_ref = actual_array.as_ref();
            let expected_ref = expected.as_ref();
            assert!(
                actual_ref == expected_ref,
                "{}: expected: {:?}, actual: {:?}",
                $MESSAGE,
                $EXPECTED,
                actual_ref
            );
        };
    }

    /// Builds a batch with no columns that still reports `num_rows` rows.
    fn batch_without_columns(num_rows: usize) -> RecordBatch {
        RecordBatch::try_new_with_options(
            Arc::new(Schema::empty()),
            vec![],
            &RecordBatchOptions::new().with_row_count(Some(num_rows)),
        )
        .expect("a column-less batch with an explicit row count is valid")
    }

    fn test_evaluate_selection(
        batch: &RecordBatch,
        selection: &BooleanArray,
        expected: &ColumnarValue,
    ) {
        let expr = TestExpr {};

        // First check that the `evaluate_selection` is the expected one
        let selection_result = expr.evaluate_selection(batch, selection).unwrap();
        assert_eq!(
            expected.to_array(1).unwrap().len(),
            selection_result.to_array(1).unwrap().len(),
            "evaluate_selection should output row count should match input record batch"
        );
        assert_arrays_eq!(
            expected,
            &selection_result,
            "evaluate_selection returned unexpected value"
        );

        // If we're selecting all rows, the result should be the same as calling `evaluate`
        // with the full record batch.
        if (0..batch.num_rows())
            .all(|row_idx| row_idx < selection.len() && selection.value(row_idx))
        {
            let full_result = expr.evaluate(batch).unwrap();

            assert_arrays_eq!(
                full_result,
                &selection_result,
                "evaluate_selection does not match unfiltered evaluate result"
            );
        }
    }

    fn test_evaluate_selection_error(batch: &RecordBatch, selection: &BooleanArray) {
        let expr = TestExpr {};

        // A selection whose length differs from the batch row count must be rejected
        let selection_result = expr.evaluate_selection(batch, selection);
        assert!(selection_result.is_err(), "evaluate_selection should fail");
    }

    #[test]
    pub fn test_evaluate_selection_with_empty_record_batch() {
        test_evaluate_selection(
            &RecordBatch::new_empty(Arc::new(Schema::empty())),
            &BooleanArray::from(vec![false; 0]),
            &ColumnarValue::Array(Arc::new(Int64Array::new_null(0))),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_empty_record_batch_with_larger_false_selection() {
        test_evaluate_selection_error(
            &RecordBatch::new_empty(Arc::new(Schema::empty())),
            &BooleanArray::from(vec![false; 10]),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_empty_record_batch_with_larger_true_selection() {
        test_evaluate_selection_error(
            &RecordBatch::new_empty(Arc::new(Schema::empty())),
            &BooleanArray::from(vec![true; 10]),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_non_empty_record_batch() {
        test_evaluate_selection(
            &batch_without_columns(10),
            &BooleanArray::from(vec![true; 10]),
            &ColumnarValue::Array(Arc::new(Int64Array::from(vec![1; 10]))),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_non_empty_record_batch_with_all_false_selection() {
        // Nothing is selected, so `evaluate` must not run and every row is null.
        test_evaluate_selection(
            &batch_without_columns(3),
            &BooleanArray::from(vec![false; 3]),
            &ColumnarValue::Array(Arc::new(Int64Array::new_null(3))),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_partial_selection() {
        // Exercises the filter + scatter path with a real column to filter.
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
        let column: ArrayRef = Arc::new(Int64Array::from(vec![10, 20, 30, 40]));
        let batch = RecordBatch::try_new(schema, vec![column]).unwrap();
        test_evaluate_selection(
            &batch,
            &BooleanArray::from(vec![true, false, true, false]),
            &ColumnarValue::Array(Arc::new(Int64Array::from(vec![
                Some(1),
                None,
                Some(1),
                None,
            ]))),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_non_empty_record_batch_with_larger_false_selection()
    {
        test_evaluate_selection_error(
            &batch_without_columns(10),
            &BooleanArray::from(vec![false; 20]),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_non_empty_record_batch_with_larger_true_selection()
    {
        test_evaluate_selection_error(
            &batch_without_columns(10),
            &BooleanArray::from(vec![true; 20]),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_non_empty_record_batch_with_smaller_false_selection()
    {
        test_evaluate_selection_error(
            &batch_without_columns(10),
            &BooleanArray::from(vec![false; 5]),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_non_empty_record_batch_with_smaller_true_selection()
    {
        test_evaluate_selection_error(
            &batch_without_columns(10),
            &BooleanArray::from(vec![true; 5]),
        );
    }
}