datafusion_physical_expr/equivalence/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use crate::expressions::Column;
21use crate::{LexRequirement, PhysicalExpr};
22
23use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
24
25mod class;
26mod ordering;
27mod projection;
28mod properties;
29
30pub use class::{AcrossPartitions, ConstExpr, EquivalenceClass, EquivalenceGroup};
31pub use ordering::OrderingEquivalenceClass;
32pub use projection::ProjectionMapping;
33pub use properties::{
34    calculate_union, join_equivalence_properties, EquivalenceProperties,
35};
36
37/// This function constructs a duplicate-free `LexOrderingReq` by filtering out
38/// duplicate entries that have same physical expression inside. For example,
39/// `vec![a Some(ASC), a Some(DESC)]` collapses to `vec![a Some(ASC)]`.
40///
41/// It will also filter out entries that are ordered if the next entry is;
42/// for instance, `vec![floor(a) Some(ASC), a Some(ASC)]` will be collapsed to
43/// `vec![a Some(ASC)]`.
44#[deprecated(since = "45.0.0", note = "Use LexRequirement::collapse")]
45pub fn collapse_lex_req(input: LexRequirement) -> LexRequirement {
46    input.collapse()
47}
48
49/// Adds the `offset` value to `Column` indices inside `expr`. This function is
50/// generally used during the update of the right table schema in join operations.
51pub fn add_offset_to_expr(
52    expr: Arc<dyn PhysicalExpr>,
53    offset: usize,
54) -> Arc<dyn PhysicalExpr> {
55    expr.transform_down(|e| match e.as_any().downcast_ref::<Column>() {
56        Some(col) => Ok(Transformed::yes(Arc::new(Column::new(
57            col.name(),
58            offset + col.index(),
59        )))),
60        None => Ok(Transformed::no(e)),
61    })
62    .data()
63    .unwrap()
64    // Note that we can safely unwrap here since our transform always returns
65    // an `Ok` value.
66}
67
68#[cfg(test)]
69mod tests {
70
71    use super::*;
72    use crate::expressions::col;
73    use crate::PhysicalSortExpr;
74
75    use arrow::compute::SortOptions;
76    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
77    use datafusion_common::{plan_datafusion_err, Result};
78    use datafusion_physical_expr_common::sort_expr::{
79        LexOrdering, PhysicalSortRequirement,
80    };
81
82    /// Converts a string to a physical sort expression
83    ///
84    /// # Example
85    /// * `"a"` -> (`"a"`, `SortOptions::default()`)
86    /// * `"a ASC"` -> (`"a"`, `SortOptions { descending: false, nulls_first: false }`)
87    pub fn parse_sort_expr(name: &str, schema: &SchemaRef) -> PhysicalSortExpr {
88        let mut parts = name.split_whitespace();
89        let name = parts.next().expect("empty sort expression");
90        let mut sort_expr = PhysicalSortExpr::new(
91            col(name, schema).expect("invalid column name"),
92            SortOptions::default(),
93        );
94
95        if let Some(options) = parts.next() {
96            sort_expr = match options {
97                "ASC" => sort_expr.asc(),
98                "DESC" => sort_expr.desc(),
99                _ => panic!(
100                    "unknown sort options. Expected 'ASC' or 'DESC', got {options}"
101                ),
102            }
103        }
104
105        assert!(
106            parts.next().is_none(),
107            "unexpected tokens in column name. Expected 'name' / 'name ASC' / 'name DESC' but got  '{name}'"
108        );
109
110        sort_expr
111    }
112
113    pub fn output_schema(
114        mapping: &ProjectionMapping,
115        input_schema: &Arc<Schema>,
116    ) -> Result<SchemaRef> {
117        // Calculate output schema
118        let fields: Result<Vec<Field>> = mapping
119            .iter()
120            .map(|(source, target)| {
121                let name = target
122                    .as_any()
123                    .downcast_ref::<Column>()
124                    .ok_or_else(|| plan_datafusion_err!("Expects to have column"))?
125                    .name();
126                let field = Field::new(
127                    name,
128                    source.data_type(input_schema)?,
129                    source.nullable(input_schema)?,
130                );
131
132                Ok(field)
133            })
134            .collect();
135
136        let output_schema = Arc::new(Schema::new_with_metadata(
137            fields?,
138            input_schema.metadata().clone(),
139        ));
140
141        Ok(output_schema)
142    }
143
144    // Generate a schema which consists of 8 columns (a, b, c, d, e, f, g, h)
145    pub fn create_test_schema() -> Result<SchemaRef> {
146        let a = Field::new("a", DataType::Int32, true);
147        let b = Field::new("b", DataType::Int32, true);
148        let c = Field::new("c", DataType::Int32, true);
149        let d = Field::new("d", DataType::Int32, true);
150        let e = Field::new("e", DataType::Int32, true);
151        let f = Field::new("f", DataType::Int32, true);
152        let g = Field::new("g", DataType::Int32, true);
153        let h = Field::new("h", DataType::Int32, true);
154        let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f, g, h]));
155
156        Ok(schema)
157    }
158
159    /// Construct a schema with following properties
160    /// Schema satisfies following orderings:
161    /// [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC]
162    /// and
163    /// Column [a=c] (e.g they are aliases).
164    pub fn create_test_params() -> Result<(SchemaRef, EquivalenceProperties)> {
165        let test_schema = create_test_schema()?;
166        let col_a = &col("a", &test_schema)?;
167        let col_b = &col("b", &test_schema)?;
168        let col_c = &col("c", &test_schema)?;
169        let col_d = &col("d", &test_schema)?;
170        let col_e = &col("e", &test_schema)?;
171        let col_f = &col("f", &test_schema)?;
172        let col_g = &col("g", &test_schema)?;
173        let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema));
174        eq_properties.add_equal_conditions(col_a, col_c)?;
175
176        let option_asc = SortOptions {
177            descending: false,
178            nulls_first: false,
179        };
180        let option_desc = SortOptions {
181            descending: true,
182            nulls_first: true,
183        };
184        let orderings = vec![
185            // [a ASC]
186            vec![(col_a, option_asc)],
187            // [d ASC, b ASC]
188            vec![(col_d, option_asc), (col_b, option_asc)],
189            // [e DESC, f ASC, g ASC]
190            vec![
191                (col_e, option_desc),
192                (col_f, option_asc),
193                (col_g, option_asc),
194            ],
195        ];
196        let orderings = convert_to_orderings(&orderings);
197        eq_properties.add_new_orderings(orderings);
198        Ok((test_schema, eq_properties))
199    }
200
201    // Convert each tuple to PhysicalSortRequirement
202    pub fn convert_to_sort_reqs(
203        in_data: &[(&Arc<dyn PhysicalExpr>, Option<SortOptions>)],
204    ) -> LexRequirement {
205        in_data
206            .iter()
207            .map(|(expr, options)| {
208                PhysicalSortRequirement::new(Arc::clone(*expr), *options)
209            })
210            .collect()
211    }
212
213    // Convert each tuple to PhysicalSortExpr
214    pub fn convert_to_sort_exprs(
215        in_data: &[(&Arc<dyn PhysicalExpr>, SortOptions)],
216    ) -> LexOrdering {
217        in_data
218            .iter()
219            .map(|(expr, options)| PhysicalSortExpr {
220                expr: Arc::clone(*expr),
221                options: *options,
222            })
223            .collect()
224    }
225
226    // Convert each inner tuple to PhysicalSortExpr
227    pub fn convert_to_orderings(
228        orderings: &[Vec<(&Arc<dyn PhysicalExpr>, SortOptions)>],
229    ) -> Vec<LexOrdering> {
230        orderings
231            .iter()
232            .map(|sort_exprs| convert_to_sort_exprs(sort_exprs))
233            .collect()
234    }
235
236    // Convert each tuple to PhysicalSortExpr
237    pub fn convert_to_sort_exprs_owned(
238        in_data: &[(Arc<dyn PhysicalExpr>, SortOptions)],
239    ) -> LexOrdering {
240        LexOrdering::new(
241            in_data
242                .iter()
243                .map(|(expr, options)| PhysicalSortExpr {
244                    expr: Arc::clone(expr),
245                    options: *options,
246                })
247                .collect(),
248        )
249    }
250
251    // Convert each inner tuple to PhysicalSortExpr
252    pub fn convert_to_orderings_owned(
253        orderings: &[Vec<(Arc<dyn PhysicalExpr>, SortOptions)>],
254    ) -> Vec<LexOrdering> {
255        orderings
256            .iter()
257            .map(|sort_exprs| convert_to_sort_exprs_owned(sort_exprs))
258            .collect()
259    }
260
261    #[test]
262    fn add_equal_conditions_test() -> Result<()> {
263        let schema = Arc::new(Schema::new(vec![
264            Field::new("a", DataType::Int64, true),
265            Field::new("b", DataType::Int64, true),
266            Field::new("c", DataType::Int64, true),
267            Field::new("x", DataType::Int64, true),
268            Field::new("y", DataType::Int64, true),
269        ]));
270
271        let mut eq_properties = EquivalenceProperties::new(schema);
272        let col_a_expr = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
273        let col_b_expr = Arc::new(Column::new("b", 1)) as Arc<dyn PhysicalExpr>;
274        let col_c_expr = Arc::new(Column::new("c", 2)) as Arc<dyn PhysicalExpr>;
275        let col_x_expr = Arc::new(Column::new("x", 3)) as Arc<dyn PhysicalExpr>;
276        let col_y_expr = Arc::new(Column::new("y", 4)) as Arc<dyn PhysicalExpr>;
277
278        // a and b are aliases
279        eq_properties.add_equal_conditions(&col_a_expr, &col_b_expr)?;
280        assert_eq!(eq_properties.eq_group().len(), 1);
281
282        // This new entry is redundant, size shouldn't increase
283        eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr)?;
284        assert_eq!(eq_properties.eq_group().len(), 1);
285        let eq_groups = eq_properties.eq_group().iter().next().unwrap();
286        assert_eq!(eq_groups.len(), 2);
287        assert!(eq_groups.contains(&col_a_expr));
288        assert!(eq_groups.contains(&col_b_expr));
289
290        // b and c are aliases. Existing equivalence class should expand,
291        // however there shouldn't be any new equivalence class
292        eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr)?;
293        assert_eq!(eq_properties.eq_group().len(), 1);
294        let eq_groups = eq_properties.eq_group().iter().next().unwrap();
295        assert_eq!(eq_groups.len(), 3);
296        assert!(eq_groups.contains(&col_a_expr));
297        assert!(eq_groups.contains(&col_b_expr));
298        assert!(eq_groups.contains(&col_c_expr));
299
300        // This is a new set of equality. Hence equivalent class count should be 2.
301        eq_properties.add_equal_conditions(&col_x_expr, &col_y_expr)?;
302        assert_eq!(eq_properties.eq_group().len(), 2);
303
304        // This equality bridges distinct equality sets.
305        // Hence equivalent class count should decrease from 2 to 1.
306        eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr)?;
307        assert_eq!(eq_properties.eq_group().len(), 1);
308        let eq_groups = eq_properties.eq_group().iter().next().unwrap();
309        assert_eq!(eq_groups.len(), 5);
310        assert!(eq_groups.contains(&col_a_expr));
311        assert!(eq_groups.contains(&col_b_expr));
312        assert!(eq_groups.contains(&col_c_expr));
313        assert!(eq_groups.contains(&col_x_expr));
314        assert!(eq_groups.contains(&col_y_expr));
315
316        Ok(())
317    }
318}