datafusion_physical_expr_common/physical_expr.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::fmt;
20use std::fmt::{Debug, Display, Formatter};
21use std::hash::{Hash, Hasher};
22use std::sync::Arc;
23
24use crate::utils::scatter;
25
26use arrow::array::{new_empty_array, ArrayRef, BooleanArray};
27use arrow::compute::filter_record_batch;
28use arrow::datatypes::{DataType, Field, FieldRef, Schema};
29use arrow::record_batch::RecordBatch;
30use datafusion_common::tree_node::{
31 Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
32};
33use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue};
34use datafusion_expr_common::columnar_value::ColumnarValue;
35use datafusion_expr_common::interval_arithmetic::Interval;
36use datafusion_expr_common::sort_properties::ExprProperties;
37use datafusion_expr_common::statistics::Distribution;
38
39use itertools::izip;
40
41/// Shared, reference-counted handle to a type-erased [`PhysicalExpr`].
42pub type PhysicalExprRef = Arc<dyn PhysicalExpr>;
43
44/// [`PhysicalExpr`]s represent expressions such as `A + 1` or `CAST(c1 AS int)`.
45///
46/// `PhysicalExpr` knows its type, nullability and can be evaluated directly on
47/// a [`RecordBatch`] (see [`Self::evaluate`]).
48///
49/// `PhysicalExpr` are the physical counterpart to [`Expr`] used in logical
50/// planning. They are typically created from [`Expr`] by a [`PhysicalPlanner`]
51/// invoked from a higher level API
52///
53/// Some important examples of `PhysicalExpr` are:
54/// * [`Column`]: Represents a column at a given index in a RecordBatch
55///
56/// To create `PhysicalExpr` from `Expr`, see
57/// * [`SessionContext::create_physical_expr`]: A high level API
58/// * [`create_physical_expr`]: A low level API
59///
60/// # Formatting `PhysicalExpr` as strings
61/// There are three ways to format `PhysicalExpr` as a string:
62/// * [`Debug`]: Standard Rust debugging format (e.g. `Constant { value: ... }`)
63/// * [`Display`]: Detailed SQL-like format that shows expression structure (e.g. (`Utf8 ("foobar")`). This is often used for debugging and tests
64/// * [`Self::fmt_sql`]: SQL-like human readable format (e.g. ('foobar')`), See also [`sql_fmt`]
65///
66/// [`SessionContext::create_physical_expr`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.create_physical_expr
67/// [`PhysicalPlanner`]: https://docs.rs/datafusion/latest/datafusion/physical_planner/trait.PhysicalPlanner.html
68/// [`Expr`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Expr.html
69/// [`create_physical_expr`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/fn.create_physical_expr.html
70/// [`Column`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/expressions/struct.Column.html
71pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash {
72 /// Returns the physical expression as [`Any`] so that it can be
73 /// downcast to a specific implementation.
74 fn as_any(&self) -> &dyn Any;
75 /// Get the data type of this expression, given the schema of the input
76 fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
77 Ok(self.return_field(input_schema)?.data_type().to_owned())
78 }
79 /// Determine whether this expression is nullable, given the schema of the input
80 fn nullable(&self, input_schema: &Schema) -> Result<bool> {
81 Ok(self.return_field(input_schema)?.is_nullable())
82 }
83 /// Evaluate an expression against a RecordBatch
84 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue>;
85 /// The output field associated with this expression
86 fn return_field(&self, input_schema: &Schema) -> Result<FieldRef> {
87 Ok(Arc::new(Field::new(
88 format!("{self}"),
89 self.data_type(input_schema)?,
90 self.nullable(input_schema)?,
91 )))
92 }
93 /// Evaluate an expression against a RecordBatch after first applying a validity array
94 ///
95 /// # Errors
96 ///
97 /// Returns an `Err` if the expression could not be evaluated or if the length of the
98 /// `selection` validity array and the number of row in `batch` is not equal.
99 fn evaluate_selection(
100 &self,
101 batch: &RecordBatch,
102 selection: &BooleanArray,
103 ) -> Result<ColumnarValue> {
104 let row_count = batch.num_rows();
105 if row_count != selection.len() {
106 return exec_err!("Selection array length does not match batch row count: {} != {row_count}", selection.len());
107 }
108
109 let selection_count = selection.true_count();
110
111 // First, check if we can avoid filtering altogether.
112 if selection_count == row_count {
113 // All values from the `selection` filter are true and match the input batch.
114 // No need to perform any filtering.
115 return self.evaluate(batch);
116 }
117
118 // Next, prepare the result array for each 'true' row in the selection vector.
119 let filtered_result = if selection_count == 0 {
120 // Do not call `evaluate` when the selection is empty.
121 // `evaluate_selection` is used to conditionally evaluate expressions.
122 // When the expression in question is fallible, evaluating it with an empty
123 // record batch may trigger a runtime error (e.g. division by zero).
124 //
125 // Instead, create an empty array matching the expected return type.
126 let datatype = self.data_type(batch.schema_ref().as_ref())?;
127 ColumnarValue::Array(new_empty_array(&datatype))
128 } else {
129 // If we reach this point, there's no other option than to filter the batch.
130 // This is a fairly costly operation since it requires creating partial copies
131 // (worst case of length `row_count - 1`) of all the arrays in the record batch.
132 // The resulting `filtered_batch` will contain `selection_count` rows.
133 let filtered_batch = filter_record_batch(batch, selection)?;
134 self.evaluate(&filtered_batch)?
135 };
136
137 // Finally, scatter the filtered result array so that the indices match the input rows again.
138 match &filtered_result {
139 ColumnarValue::Array(a) => {
140 scatter(selection, a.as_ref()).map(ColumnarValue::Array)
141 }
142 ColumnarValue::Scalar(ScalarValue::Boolean(value)) => {
143 // When the scalar is true or false, skip the scatter process
144 if let Some(v) = value {
145 if *v {
146 Ok(ColumnarValue::from(Arc::new(selection.clone()) as ArrayRef))
147 } else {
148 Ok(filtered_result)
149 }
150 } else {
151 let array = BooleanArray::from(vec![None; row_count]);
152 scatter(selection, &array).map(ColumnarValue::Array)
153 }
154 }
155 ColumnarValue::Scalar(_) => Ok(filtered_result),
156 }
157 }
158
159 /// Get a list of child PhysicalExpr that provide the input for this expr.
160 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>>;
161
162 /// Returns a new PhysicalExpr where all children were replaced by new exprs.
163 fn with_new_children(
164 self: Arc<Self>,
165 children: Vec<Arc<dyn PhysicalExpr>>,
166 ) -> Result<Arc<dyn PhysicalExpr>>;
167
168 /// Computes the output interval for the expression, given the input
169 /// intervals.
170 ///
171 /// # Parameters
172 ///
173 /// * `children` are the intervals for the children (inputs) of this
174 /// expression.
175 ///
176 /// # Returns
177 ///
178 /// A `Result` containing the output interval for the expression in
179 /// case of success, or an error object in case of failure.
180 ///
181 /// # Example
182 ///
183 /// If the expression is `a + b`, and the input intervals are `a: [1, 2]`
184 /// and `b: [3, 4]`, then the output interval would be `[4, 6]`.
185 fn evaluate_bounds(&self, _children: &[&Interval]) -> Result<Interval> {
186 not_impl_err!("Not implemented for {self}")
187 }
188
189 /// Updates bounds for child expressions, given a known interval for this
190 /// expression.
191 ///
192 /// This is used to propagate constraints down through an expression tree.
193 ///
194 /// # Parameters
195 ///
196 /// * `interval` is the currently known interval for this expression.
197 /// * `children` are the current intervals for the children of this expression.
198 ///
199 /// # Returns
200 ///
201 /// A `Result` containing a `Vec` of new intervals for the children (in order)
202 /// in case of success, or an error object in case of failure.
203 ///
204 /// If constraint propagation reveals an infeasibility for any child, returns
205 /// [`None`]. If none of the children intervals change as a result of
206 /// propagation, may return an empty vector instead of cloning `children`.
207 /// This is the default (and conservative) return value.
208 ///
209 /// # Example
210 ///
211 /// If the expression is `a + b`, the current `interval` is `[4, 5]` and the
212 /// inputs `a` and `b` are respectively given as `[0, 2]` and `[-∞, 4]`, then
213 /// propagation would return `[0, 2]` and `[2, 4]` as `b` must be at least
214 /// `2` to make the output at least `4`.
215 fn propagate_constraints(
216 &self,
217 _interval: &Interval,
218 _children: &[&Interval],
219 ) -> Result<Option<Vec<Interval>>> {
220 Ok(Some(vec![]))
221 }
222
223 /// Computes the output statistics for the expression, given the input
224 /// statistics.
225 ///
226 /// # Parameters
227 ///
228 /// * `children` are the statistics for the children (inputs) of this
229 /// expression.
230 ///
231 /// # Returns
232 ///
233 /// A `Result` containing the output statistics for the expression in
234 /// case of success, or an error object in case of failure.
235 ///
236 /// Expressions (should) implement this function and utilize the independence
237 /// assumption, match on children distribution types and compute the output
238 /// statistics accordingly. The default implementation simply creates an
239 /// unknown output distribution by combining input ranges. This logic loses
240 /// distribution information, but is a safe default.
241 fn evaluate_statistics(&self, children: &[&Distribution]) -> Result<Distribution> {
242 let children_ranges = children
243 .iter()
244 .map(|c| c.range())
245 .collect::<Result<Vec<_>>>()?;
246 let children_ranges_refs = children_ranges.iter().collect::<Vec<_>>();
247 let output_interval = self.evaluate_bounds(children_ranges_refs.as_slice())?;
248 let dt = output_interval.data_type();
249 if dt.eq(&DataType::Boolean) {
250 let p = if output_interval.eq(&Interval::CERTAINLY_TRUE) {
251 ScalarValue::new_one(&dt)
252 } else if output_interval.eq(&Interval::CERTAINLY_FALSE) {
253 ScalarValue::new_zero(&dt)
254 } else {
255 ScalarValue::try_from(&dt)
256 }?;
257 Distribution::new_bernoulli(p)
258 } else {
259 Distribution::new_from_interval(output_interval)
260 }
261 }
262
263 /// Updates children statistics using the given parent statistic for this
264 /// expression.
265 ///
266 /// This is used to propagate statistics down through an expression tree.
267 ///
268 /// # Parameters
269 ///
270 /// * `parent` is the currently known statistics for this expression.
271 /// * `children` are the current statistics for the children of this expression.
272 ///
273 /// # Returns
274 ///
275 /// A `Result` containing a `Vec` of new statistics for the children (in order)
276 /// in case of success, or an error object in case of failure.
277 ///
278 /// If statistics propagation reveals an infeasibility for any child, returns
279 /// [`None`]. If none of the children statistics change as a result of
280 /// propagation, may return an empty vector instead of cloning `children`.
281 /// This is the default (and conservative) return value.
282 ///
283 /// Expressions (should) implement this function and apply Bayes rule to
284 /// reconcile and update parent/children statistics. This involves utilizing
285 /// the independence assumption, and matching on distribution types. The
286 /// default implementation simply creates an unknown distribution if it can
287 /// narrow the range by propagating ranges. This logic loses distribution
288 /// information, but is a safe default.
289 fn propagate_statistics(
290 &self,
291 parent: &Distribution,
292 children: &[&Distribution],
293 ) -> Result<Option<Vec<Distribution>>> {
294 let children_ranges = children
295 .iter()
296 .map(|c| c.range())
297 .collect::<Result<Vec<_>>>()?;
298 let children_ranges_refs = children_ranges.iter().collect::<Vec<_>>();
299 let parent_range = parent.range()?;
300 let Some(propagated_children) =
301 self.propagate_constraints(&parent_range, children_ranges_refs.as_slice())?
302 else {
303 return Ok(None);
304 };
305 izip!(propagated_children.into_iter(), children_ranges, children)
306 .map(|(new_interval, old_interval, child)| {
307 if new_interval == old_interval {
308 // We weren't able to narrow the range, preserve the old statistics.
309 Ok((*child).clone())
310 } else if new_interval.data_type().eq(&DataType::Boolean) {
311 let dt = old_interval.data_type();
312 let p = if new_interval.eq(&Interval::CERTAINLY_TRUE) {
313 ScalarValue::new_one(&dt)
314 } else if new_interval.eq(&Interval::CERTAINLY_FALSE) {
315 ScalarValue::new_zero(&dt)
316 } else {
317 unreachable!("Given that we have a range reduction for a boolean interval, we should have certainty")
318 }?;
319 Distribution::new_bernoulli(p)
320 } else {
321 Distribution::new_from_interval(new_interval)
322 }
323 })
324 .collect::<Result<_>>()
325 .map(Some)
326 }
327
328 /// Calculates the properties of this [`PhysicalExpr`] based on its
329 /// children's properties (i.e. order and range), recursively aggregating
330 /// the information from its children. In cases where the [`PhysicalExpr`]
331 /// has no children (e.g., `Literal` or `Column`), these properties should
332 /// be specified externally, as the function defaults to unknown properties.
333 fn get_properties(&self, _children: &[ExprProperties]) -> Result<ExprProperties> {
334 Ok(ExprProperties::new_unknown())
335 }
336
337 /// Format this `PhysicalExpr` in nice human readable "SQL" format
338 ///
339 /// Specifically, this format is designed to be readable by humans, at the
340 /// expense of details. Use `Display` or `Debug` for more detailed
341 /// representation.
342 ///
343 /// See the [`fmt_sql`] function for an example of printing `PhysicalExpr`s as SQL.
344 fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result;
345
346 /// Take a snapshot of this `PhysicalExpr`, if it is dynamic.
347 ///
348 /// "Dynamic" in this case means containing references to structures that may change
349 /// during plan execution, such as hash tables.
350 ///
351 /// This method is used to capture the current state of `PhysicalExpr`s that may contain
352 /// dynamic references to other operators in order to serialize it over the wire
353 /// or treat it via downcast matching.
354 ///
355 /// You should not call this method directly as it does not handle recursion.
356 /// Instead use [`snapshot_physical_expr`] to handle recursion and capture the
357 /// full state of the `PhysicalExpr`.
358 ///
359 /// This is expected to return "simple" expressions that do not have mutable state
360 /// and are composed of DataFusion's built-in `PhysicalExpr` implementations.
361 /// Callers however should *not* assume anything about the returned expressions
362 /// since callers and implementers may not agree on what "simple" or "built-in"
363 /// means.
364 /// In other words, if you need to serialize a `PhysicalExpr` across the wire
365 /// you should call this method and then try to serialize the result,
366 /// but you should handle unknown or unexpected `PhysicalExpr` implementations gracefully
367 /// just as if you had not called this method at all.
368 ///
369 /// In particular, consider:
370 /// * A `PhysicalExpr` that references the current state of a `datafusion::physical_plan::TopK`
371 /// that is involved in a query with `SELECT * FROM t1 ORDER BY a LIMIT 10`.
372 /// This function may return something like `a >= 12`.
373 /// * A `PhysicalExpr` that references the current state of a `datafusion::physical_plan::joins::HashJoinExec`
374 /// from a query such as `SELECT * FROM t1 JOIN t2 ON t1.a = t2.b`.
375 /// This function may return something like `t2.b IN (1, 5, 7)`.
376 ///
377 /// A system or function that can only deal with a hardcoded set of `PhysicalExpr` implementations
378 /// or needs to serialize this state to bytes may not be able to handle these dynamic references.
379 /// In such cases, we should return a simplified version of the `PhysicalExpr` that does not
380 /// contain these dynamic references.
381 ///
382 /// Systems that implement remote execution of plans, e.g. serialize a portion of the query plan
383 /// and send it across the wire to a remote executor may want to call this method after
384 /// every batch on the source side and broadcast / update the current snapshot to the remote executor.
385 ///
386 /// Note for implementers: this method should *not* handle recursion.
387 /// Recursion is handled in [`snapshot_physical_expr`].
388 fn snapshot(&self) -> Result<Option<Arc<dyn PhysicalExpr>>> {
389 // By default, we return None to indicate that this PhysicalExpr does not
390 // have any dynamic references or state.
391 // This is a safe default behavior.
392 Ok(None)
393 }
394
395 /// Returns the generation of this `PhysicalExpr` for snapshotting purposes.
396 /// The generation is an arbitrary u64 that can be used to track changes
397 /// in the state of the `PhysicalExpr` over time without having to do an exhaustive comparison.
398 /// This is useful to avoid unnecessary computation or serialization if there are no changes to the expression.
399 /// In particular, dynamic expressions that may change over time; this allows cheap checks for changes.
400 /// Static expressions that do not change over time should return 0, as does the default implementation.
401 /// You should not call this method directly as it does not handle recursion.
402 /// Instead use [`snapshot_generation`] to handle recursion and capture the
403 /// full state of the `PhysicalExpr`.
404 fn snapshot_generation(&self) -> u64 {
405 // By default, we return 0 to indicate that this PhysicalExpr does not
406 // have any dynamic references or state.
407 // Since the recursive algorithm XORs the generations of all children the overall
408 // generation will be 0 if no children have a non-zero generation, meaning that
409 // static expressions will always return 0.
410 0
411 }
412
413 /// Returns true if the expression node is volatile, i.e. whether it can return
414 /// different results when evaluated multiple times with the same input.
415 ///
416 /// Note: unlike [`is_volatile`], this function does not consider inputs:
417 /// - `random()` returns `true`,
418 /// - `a + random()` returns `false` (because the operation `+` itself is not volatile.)
419 ///
420 /// The default to this function was set to `false` when it was created
421 /// to avoid imposing API churn on implementers, but this is not a safe default in general.
422 /// It is highly recommended that volatile expressions implement this method and return `true`.
423 /// This default may be removed in the future if it causes problems or we decide to
424 /// eat the cost of the breaking change and require all implementers to make a choice.
425 fn is_volatile_node(&self) -> bool {
426 false
427 }
428}
429
430#[deprecated(
431 since = "50.0.0",
432 note = "Use `datafusion_expr_common::dyn_eq` instead"
433)]
434pub use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
435
436impl PartialEq for dyn PhysicalExpr {
437 fn eq(&self, other: &Self) -> bool {
438 self.dyn_eq(other.as_any())
439 }
440}
441impl Eq for dyn PhysicalExpr {}
442
443impl Hash for dyn PhysicalExpr {
444 fn hash<H: Hasher>(&self, state: &mut H) {
445 self.dyn_hash(state);
446 }
447}
448
449/// Returns a copy of this expr if we change any child according to the pointer comparison.
450/// The size of `children` must be equal to the size of `PhysicalExpr::children()`.
451pub fn with_new_children_if_necessary(
452 expr: Arc<dyn PhysicalExpr>,
453 children: Vec<Arc<dyn PhysicalExpr>>,
454) -> Result<Arc<dyn PhysicalExpr>> {
455 let old_children = expr.children();
456 if children.len() != old_children.len() {
457 internal_err!("PhysicalExpr: Wrong number of children")
458 } else if children.is_empty()
459 || children
460 .iter()
461 .zip(old_children.iter())
462 .any(|(c1, c2)| !Arc::ptr_eq(c1, c2))
463 {
464 Ok(expr.with_new_children(children)?)
465 } else {
466 Ok(expr)
467 }
468}
469
/// Returns a [`Display`]-able view over a list of [`PhysicalExpr`]s (or any
/// other displayable items).
///
/// The items are rendered comma separated inside square brackets.
///
/// Example output: `[a + 1, b]`
pub fn format_physical_expr_list<T>(exprs: T) -> impl Display
where
    T: IntoIterator,
    T::Item: Display,
    T::IntoIter: Clone,
{
    /// Holds the iterator; it is cloned on every `fmt` call so that the
    /// wrapper can be displayed more than once.
    struct ListDisplay<I>(I)
    where
        I: Iterator + Clone,
        I::Item: Display;

    impl<I> Display for ListDisplay<I>
    where
        I: Iterator + Clone,
        I::Item: Display,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
            f.write_str("[")?;
            for (position, item) in self.0.clone().enumerate() {
                if position > 0 {
                    f.write_str(", ")?;
                }
                write!(f, "{item}")?;
            }
            f.write_str("]")
        }
    }

    ListDisplay(exprs.into_iter())
}
505
506/// Prints a [`PhysicalExpr`] in a SQL-like format
507///
508/// # Example
509/// ```
510/// # // The boilerplate needed to create a `PhysicalExpr` for the example
511/// # use std::any::Any;
512/// use std::collections::HashMap;
513/// # use std::fmt::Formatter;
514/// # use std::sync::Arc;
515/// # use arrow::array::RecordBatch;
516/// # use arrow::datatypes::{DataType, Field, FieldRef, Schema};
517/// # use datafusion_common::Result;
518/// # use datafusion_expr_common::columnar_value::ColumnarValue;
519/// # use datafusion_physical_expr_common::physical_expr::{fmt_sql, DynEq, PhysicalExpr};
520/// # #[derive(Debug, PartialEq, Eq, Hash)]
521/// # struct MyExpr {}
522/// # impl PhysicalExpr for MyExpr {fn as_any(&self) -> &dyn Any { unimplemented!() }
523/// # fn data_type(&self, input_schema: &Schema) -> Result<DataType> { unimplemented!() }
524/// # fn nullable(&self, input_schema: &Schema) -> Result<bool> { unimplemented!() }
525/// # fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> { unimplemented!() }
526/// # fn return_field(&self, input_schema: &Schema) -> Result<FieldRef> { unimplemented!() }
527/// # fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>>{ unimplemented!() }
528/// # fn with_new_children(self: Arc<Self>, children: Vec<Arc<dyn PhysicalExpr>>) -> Result<Arc<dyn PhysicalExpr>> { unimplemented!() }
529/// # fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "CASE a > b THEN 1 ELSE 0 END") }
530/// # }
531/// # impl std::fmt::Display for MyExpr {fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { unimplemented!() } }
532/// # fn make_physical_expr() -> Arc<dyn PhysicalExpr> { Arc::new(MyExpr{}) }
533/// let expr: Arc<dyn PhysicalExpr> = make_physical_expr();
534/// // wrap the expression in `sql_fmt` which can be used with
535/// // `format!`, `to_string()`, etc
536/// let expr_as_sql = fmt_sql(expr.as_ref());
537/// assert_eq!(
538/// "The SQL: CASE a > b THEN 1 ELSE 0 END",
539/// format!("The SQL: {expr_as_sql}")
540/// );
541/// ```
542pub fn fmt_sql(expr: &dyn PhysicalExpr) -> impl Display + '_ {
543 struct Wrapper<'a> {
544 expr: &'a dyn PhysicalExpr,
545 }
546
547 impl Display for Wrapper<'_> {
548 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
549 self.expr.fmt_sql(f)?;
550 Ok(())
551 }
552 }
553
554 Wrapper { expr }
555}
556
557/// Take a snapshot of the given `PhysicalExpr` if it is dynamic.
558///
559/// Take a snapshot of this `PhysicalExpr` if it is dynamic.
560/// This is used to capture the current state of `PhysicalExpr`s that may contain
561/// dynamic references to other operators in order to serialize it over the wire
562/// or treat it via downcast matching.
563///
564/// See the documentation of [`PhysicalExpr::snapshot`] for more details.
565///
566/// # Returns
567///
568/// Returns a snapshot of the `PhysicalExpr` if it is dynamic, otherwise
569/// returns itself.
570pub fn snapshot_physical_expr(
571 expr: Arc<dyn PhysicalExpr>,
572) -> Result<Arc<dyn PhysicalExpr>> {
573 expr.transform_up(|e| {
574 if let Some(snapshot) = e.snapshot()? {
575 Ok(Transformed::yes(snapshot))
576 } else {
577 Ok(Transformed::no(Arc::clone(&e)))
578 }
579 })
580 .data()
581}
582
583/// Check the generation of this `PhysicalExpr`.
584/// Dynamic `PhysicalExpr`s may have a generation that is incremented
585/// every time the state of the `PhysicalExpr` changes.
586/// If the generation changes that means this `PhysicalExpr` or one of its children
587/// has changed since the last time it was evaluated.
588///
589/// This algorithm will not produce collisions as long as the structure of the
590/// `PhysicalExpr` does not change and no `PhysicalExpr` decrements its own generation.
591pub fn snapshot_generation(expr: &Arc<dyn PhysicalExpr>) -> u64 {
592 let mut generation = 0u64;
593 expr.apply(|e| {
594 // Add the current generation of the `PhysicalExpr` to our global generation.
595 generation = generation.wrapping_add(e.snapshot_generation());
596 Ok(TreeNodeRecursion::Continue)
597 })
598 .expect("this traversal is infallible");
599
600 generation
601}
602
603/// Check if the given `PhysicalExpr` is dynamic.
604/// Internally this calls [`snapshot_generation`] to check if the generation is non-zero,
605/// any dynamic `PhysicalExpr` should have a non-zero generation.
606pub fn is_dynamic_physical_expr(expr: &Arc<dyn PhysicalExpr>) -> bool {
607 // If the generation is non-zero, then this `PhysicalExpr` is dynamic.
608 snapshot_generation(expr) != 0
609}
610
611/// Returns true if the expression is volatile, i.e. whether it can return different
612/// results when evaluated multiple times with the same input.
613///
614/// For example the function call `RANDOM()` is volatile as each call will
615/// return a different value.
616///
617/// This method recursively checks if any sub-expression is volatile, for example
618/// `1 + RANDOM()` will return `true`.
619pub fn is_volatile(expr: &Arc<dyn PhysicalExpr>) -> bool {
620 if expr.is_volatile_node() {
621 return true;
622 }
623 let mut is_volatile = false;
624 expr.apply(|e| {
625 if e.is_volatile_node() {
626 is_volatile = true;
627 Ok(TreeNodeRecursion::Stop)
628 } else {
629 Ok(TreeNodeRecursion::Continue)
630 }
631 })
632 .expect("infallible closure should not fail");
633 is_volatile
634}
635
#[cfg(test)]
mod test {
    use crate::physical_expr::PhysicalExpr;
    use arrow::array::{Array, BooleanArray, Int64Array, RecordBatch};
    use arrow::datatypes::{DataType, Schema};
    use arrow::record_batch::RecordBatchOptions;
    use datafusion_expr_common::columnar_value::ColumnarValue;
    use std::fmt::{Display, Formatter};
    use std::sync::Arc;

    /// Minimal expression that evaluates to a non-null `Int64` array holding
    /// the value `1` for every row of the input batch.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct TestExpr {}

    impl PhysicalExpr for TestExpr {
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }

        fn data_type(&self, _schema: &Schema) -> datafusion_common::Result<DataType> {
            Ok(DataType::Int64)
        }

        fn nullable(&self, _schema: &Schema) -> datafusion_common::Result<bool> {
            Ok(false)
        }

        fn evaluate(
            &self,
            batch: &RecordBatch,
        ) -> datafusion_common::Result<ColumnarValue> {
            let data = vec![1; batch.num_rows()];
            Ok(ColumnarValue::Array(Arc::new(Int64Array::from(data))))
        }

        fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
            vec![]
        }

        fn with_new_children(
            self: Arc<Self>,
            _children: Vec<Arc<dyn PhysicalExpr>>,
        ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
            Ok(Arc::new(Self {}))
        }

        fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            f.write_str("TestExpr")
        }
    }

    impl Display for TestExpr {
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            self.fmt_sql(f)
        }
    }

    /// Asserts that two `ColumnarValue`s materialize to equal arrays.
    macro_rules! assert_arrays_eq {
        ($EXPECTED: expr, $ACTUAL: expr, $MESSAGE: expr) => {
            let expected = $EXPECTED.to_array(1).unwrap();
            let actual = $ACTUAL;

            let actual_array = actual.to_array(expected.len()).unwrap();
            let actual_ref = actual_array.as_ref();
            let expected_ref = expected.as_ref();
            assert!(
                actual_ref == expected_ref,
                "{}: expected: {:?}, actual: {:?}",
                $MESSAGE,
                $EXPECTED,
                actual_ref
            );
        };
    }

    /// Builds a batch without columns that nevertheless reports `num_rows`
    /// rows, using the safe, validating constructor with an explicit row
    /// count instead of `unsafe { RecordBatch::new_unchecked(..) }`.
    fn batch_with_rows(num_rows: usize) -> RecordBatch {
        let options = RecordBatchOptions::new().with_row_count(Some(num_rows));
        RecordBatch::try_new_with_options(Arc::new(Schema::empty()), vec![], &options)
            .expect("a column-less batch with an explicit row count is valid")
    }

    /// Checks that `evaluate_selection` succeeds and yields `expected`.
    fn test_evaluate_selection(
        batch: &RecordBatch,
        selection: &BooleanArray,
        expected: &ColumnarValue,
    ) {
        let expr = TestExpr {};

        // First check that the `evaluate_selection` result is the expected one
        let selection_result = expr.evaluate_selection(batch, selection).unwrap();
        assert_eq!(
            expected.to_array(1).unwrap().len(),
            selection_result.to_array(1).unwrap().len(),
            "evaluate_selection should output row count should match input record batch"
        );
        assert_arrays_eq!(
            expected,
            &selection_result,
            "evaluate_selection returned unexpected value"
        );

        // If we're selecting all rows, the result should be the same as calling `evaluate`
        // with the full record batch.
        let all_rows_selected = (0..batch.num_rows())
            .all(|row_idx| row_idx < selection.len() && selection.value(row_idx));
        if all_rows_selected {
            let unfiltered_result = expr.evaluate(batch).unwrap();

            assert_arrays_eq!(
                unfiltered_result,
                &selection_result,
                "evaluate_selection does not match unfiltered evaluate result"
            );
        }
    }

    /// Checks that `evaluate_selection` rejects the given batch/selection pair.
    fn test_evaluate_selection_error(batch: &RecordBatch, selection: &BooleanArray) {
        let expr = TestExpr {};

        // A selection whose length differs from the batch row count must be
        // reported as an error.
        let selection_result = expr.evaluate_selection(batch, selection);
        assert!(selection_result.is_err(), "evaluate_selection should fail");
    }

    #[test]
    pub fn test_evaluate_selection_with_empty_record_batch() {
        test_evaluate_selection(
            &RecordBatch::new_empty(Arc::new(Schema::empty())),
            &BooleanArray::from(vec![false; 0]),
            &ColumnarValue::Array(Arc::new(Int64Array::new_null(0))),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_empty_record_batch_with_larger_false_selection() {
        test_evaluate_selection_error(
            &RecordBatch::new_empty(Arc::new(Schema::empty())),
            &BooleanArray::from(vec![false; 10]),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_empty_record_batch_with_larger_true_selection() {
        test_evaluate_selection_error(
            &RecordBatch::new_empty(Arc::new(Schema::empty())),
            &BooleanArray::from(vec![true; 10]),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_non_empty_record_batch() {
        test_evaluate_selection(
            &batch_with_rows(10),
            &BooleanArray::from(vec![true; 10]),
            &ColumnarValue::Array(Arc::new(Int64Array::from(vec![1; 10]))),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_non_empty_record_batch_with_larger_false_selection(
    ) {
        test_evaluate_selection_error(
            &batch_with_rows(10),
            &BooleanArray::from(vec![false; 20]),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_non_empty_record_batch_with_larger_true_selection(
    ) {
        test_evaluate_selection_error(
            &batch_with_rows(10),
            &BooleanArray::from(vec![true; 20]),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_non_empty_record_batch_with_smaller_false_selection(
    ) {
        test_evaluate_selection_error(
            &batch_with_rows(10),
            &BooleanArray::from(vec![false; 5]),
        );
    }

    #[test]
    pub fn test_evaluate_selection_with_non_empty_record_batch_with_smaller_true_selection(
    ) {
        test_evaluate_selection_error(
            &batch_with_rows(10),
            &BooleanArray::from(vec![true; 5]),
        );
    }
}