datafusion_physical_expr/equivalence/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::borrow::Borrow;
19use std::sync::Arc;
20
21use crate::PhysicalExpr;
22
23use arrow::compute::SortOptions;
24use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
25
26mod class;
27mod ordering;
28mod projection;
29mod properties;
30
31pub use class::{AcrossPartitions, ConstExpr, EquivalenceClass, EquivalenceGroup};
32pub use ordering::OrderingEquivalenceClass;
33pub use projection::ProjectionMapping;
34pub use properties::{
35    calculate_union, join_equivalence_properties, EquivalenceProperties,
36};
37
38// Convert each tuple to a `PhysicalSortExpr` and construct a vector.
39pub fn convert_to_sort_exprs<T: Borrow<Arc<dyn PhysicalExpr>>>(
40    args: &[(T, SortOptions)],
41) -> Vec<PhysicalSortExpr> {
42    args.iter()
43        .map(|(expr, options)| PhysicalSortExpr::new(Arc::clone(expr.borrow()), *options))
44        .collect()
45}
46
47// Convert each vector of tuples to a `LexOrdering`.
48pub fn convert_to_orderings<T: Borrow<Arc<dyn PhysicalExpr>>>(
49    args: &[Vec<(T, SortOptions)>],
50) -> Vec<LexOrdering> {
51    args.iter()
52        .filter_map(|sort_exprs| LexOrdering::new(convert_to_sort_exprs(sort_exprs)))
53        .collect()
54}
55
56#[cfg(test)]
57mod tests {
58    use super::*;
59    use crate::expressions::{col, Column};
60    use crate::{LexRequirement, PhysicalSortExpr};
61
62    use arrow::compute::SortOptions;
63    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
64    use datafusion_common::{plan_err, Result};
65    use datafusion_physical_expr_common::sort_expr::PhysicalSortRequirement;
66
67    /// Converts a string to a physical sort expression
68    ///
69    /// # Example
70    /// * `"a"` -> (`"a"`, `SortOptions::default()`)
71    /// * `"a ASC"` -> (`"a"`, `SortOptions { descending: false, nulls_first: false }`)
72    pub fn parse_sort_expr(name: &str, schema: &SchemaRef) -> PhysicalSortExpr {
73        let mut parts = name.split_whitespace();
74        let name = parts.next().expect("empty sort expression");
75        let mut sort_expr = PhysicalSortExpr::new(
76            col(name, schema).expect("invalid column name"),
77            SortOptions::default(),
78        );
79
80        if let Some(options) = parts.next() {
81            sort_expr = match options {
82                "ASC" => sort_expr.asc(),
83                "DESC" => sort_expr.desc(),
84                _ => panic!(
85                    "unknown sort options. Expected 'ASC' or 'DESC', got {options}"
86                ),
87            }
88        }
89
90        assert!(
91            parts.next().is_none(),
92            "unexpected tokens in column name. Expected 'name' / 'name ASC' / 'name DESC' but got  '{name}'"
93        );
94
95        sort_expr
96    }
97
98    pub fn output_schema(
99        mapping: &ProjectionMapping,
100        input_schema: &Arc<Schema>,
101    ) -> Result<SchemaRef> {
102        // Calculate output schema:
103        let mut fields = vec![];
104        for (source, targets) in mapping.iter() {
105            let data_type = source.data_type(input_schema)?;
106            let nullable = source.nullable(input_schema)?;
107            for (target, _) in targets.iter() {
108                let Some(column) = target.as_any().downcast_ref::<Column>() else {
109                    return plan_err!("Expects to have column");
110                };
111                fields.push(Field::new(column.name(), data_type.clone(), nullable));
112            }
113        }
114
115        let output_schema = Arc::new(Schema::new_with_metadata(
116            fields,
117            input_schema.metadata().clone(),
118        ));
119
120        Ok(output_schema)
121    }
122
123    // Generate a schema which consists of 8 columns (a, b, c, d, e, f, g, h)
124    pub fn create_test_schema() -> Result<SchemaRef> {
125        let a = Field::new("a", DataType::Int32, true);
126        let b = Field::new("b", DataType::Int32, true);
127        let c = Field::new("c", DataType::Int32, true);
128        let d = Field::new("d", DataType::Int32, true);
129        let e = Field::new("e", DataType::Int32, true);
130        let f = Field::new("f", DataType::Int32, true);
131        let g = Field::new("g", DataType::Int32, true);
132        let h = Field::new("h", DataType::Int32, true);
133        let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f, g, h]));
134
135        Ok(schema)
136    }
137
138    /// Construct a schema with following properties
139    /// Schema satisfies following orderings:
140    /// [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC]
141    /// and
142    /// Column [a=c] (e.g they are aliases).
143    pub fn create_test_params() -> Result<(SchemaRef, EquivalenceProperties)> {
144        let test_schema = create_test_schema()?;
145        let col_a = col("a", &test_schema)?;
146        let col_b = col("b", &test_schema)?;
147        let col_c = col("c", &test_schema)?;
148        let col_d = col("d", &test_schema)?;
149        let col_e = col("e", &test_schema)?;
150        let col_f = col("f", &test_schema)?;
151        let col_g = col("g", &test_schema)?;
152        let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema));
153        eq_properties.add_equal_conditions(Arc::clone(&col_a), Arc::clone(&col_c))?;
154
155        let option_asc = SortOptions {
156            descending: false,
157            nulls_first: false,
158        };
159        let option_desc = SortOptions {
160            descending: true,
161            nulls_first: true,
162        };
163        let orderings = vec![
164            // [a ASC]
165            vec![(col_a, option_asc)],
166            // [d ASC, b ASC]
167            vec![(col_d, option_asc), (col_b, option_asc)],
168            // [e DESC, f ASC, g ASC]
169            vec![
170                (col_e, option_desc),
171                (col_f, option_asc),
172                (col_g, option_asc),
173            ],
174        ];
175        let orderings = convert_to_orderings(&orderings);
176        eq_properties.add_orderings(orderings);
177        Ok((test_schema, eq_properties))
178    }
179
180    // Convert each tuple to a `PhysicalSortRequirement` and construct a
181    // a `LexRequirement` from them.
182    pub fn convert_to_sort_reqs(
183        args: &[(&Arc<dyn PhysicalExpr>, Option<SortOptions>)],
184    ) -> LexRequirement {
185        let exprs = args.iter().map(|(expr, options)| {
186            PhysicalSortRequirement::new(Arc::clone(*expr), *options)
187        });
188        LexRequirement::new(exprs).unwrap()
189    }
190
191    #[test]
192    fn add_equal_conditions_test() -> Result<()> {
193        let schema = Arc::new(Schema::new(vec![
194            Field::new("a", DataType::Int64, true),
195            Field::new("b", DataType::Int64, true),
196            Field::new("c", DataType::Int64, true),
197            Field::new("x", DataType::Int64, true),
198            Field::new("y", DataType::Int64, true),
199        ]));
200
201        let mut eq_properties = EquivalenceProperties::new(schema);
202        let col_a = Arc::new(Column::new("a", 0)) as _;
203        let col_b = Arc::new(Column::new("b", 1)) as _;
204        let col_c = Arc::new(Column::new("c", 2)) as _;
205        let col_x = Arc::new(Column::new("x", 3)) as _;
206        let col_y = Arc::new(Column::new("y", 4)) as _;
207
208        // a and b are aliases
209        eq_properties.add_equal_conditions(Arc::clone(&col_a), Arc::clone(&col_b))?;
210        assert_eq!(eq_properties.eq_group().len(), 1);
211
212        // This new entry is redundant, size shouldn't increase
213        eq_properties.add_equal_conditions(Arc::clone(&col_b), Arc::clone(&col_a))?;
214        assert_eq!(eq_properties.eq_group().len(), 1);
215        let eq_groups = eq_properties.eq_group().iter().next().unwrap();
216        assert_eq!(eq_groups.len(), 2);
217        assert!(eq_groups.contains(&col_a));
218        assert!(eq_groups.contains(&col_b));
219
220        // b and c are aliases. Existing equivalence class should expand,
221        // however there shouldn't be any new equivalence class
222        eq_properties.add_equal_conditions(Arc::clone(&col_b), Arc::clone(&col_c))?;
223        assert_eq!(eq_properties.eq_group().len(), 1);
224        let eq_groups = eq_properties.eq_group().iter().next().unwrap();
225        assert_eq!(eq_groups.len(), 3);
226        assert!(eq_groups.contains(&col_a));
227        assert!(eq_groups.contains(&col_b));
228        assert!(eq_groups.contains(&col_c));
229
230        // This is a new set of equality. Hence equivalent class count should be 2.
231        eq_properties.add_equal_conditions(Arc::clone(&col_x), Arc::clone(&col_y))?;
232        assert_eq!(eq_properties.eq_group().len(), 2);
233
234        // This equality bridges distinct equality sets.
235        // Hence equivalent class count should decrease from 2 to 1.
236        eq_properties.add_equal_conditions(Arc::clone(&col_x), Arc::clone(&col_a))?;
237        assert_eq!(eq_properties.eq_group().len(), 1);
238        let eq_groups = eq_properties.eq_group().iter().next().unwrap();
239        assert_eq!(eq_groups.len(), 5);
240        assert!(eq_groups.contains(&col_a));
241        assert!(eq_groups.contains(&col_b));
242        assert!(eq_groups.contains(&col_c));
243        assert!(eq_groups.contains(&col_x));
244        assert!(eq_groups.contains(&col_y));
245
246        Ok(())
247    }
248}