1mod guarantee;
19pub use guarantee::{Guarantee, LiteralGuarantee};
20
21use std::borrow::Borrow;
22use std::sync::Arc;
23
24use crate::expressions::{BinaryExpr, Column};
25use crate::tree_node::ExprContext;
26use crate::PhysicalExpr;
27use crate::PhysicalSortExpr;
28
29use arrow::datatypes::SchemaRef;
30use datafusion_common::tree_node::{
31 Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
32};
33use datafusion_common::{HashMap, HashSet, Result};
34use datafusion_expr::Operator;
35
36use datafusion_physical_expr_common::sort_expr::LexOrdering;
37use itertools::Itertools;
38use petgraph::graph::NodeIndex;
39use petgraph::stable_graph::StableGraph;
40
41pub fn split_conjunction(
45 predicate: &Arc<dyn PhysicalExpr>,
46) -> Vec<&Arc<dyn PhysicalExpr>> {
47 split_impl(Operator::And, predicate, vec![])
48}
49
50pub fn conjunction(
55 predicates: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>,
56) -> Arc<dyn PhysicalExpr> {
57 conjunction_opt(predicates).unwrap_or_else(|| crate::expressions::lit(true))
58}
59
60pub fn conjunction_opt(
65 predicates: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>,
66) -> Option<Arc<dyn PhysicalExpr>> {
67 predicates
68 .into_iter()
69 .fold(None, |acc, predicate| match acc {
70 None => Some(predicate),
71 Some(acc) => Some(Arc::new(BinaryExpr::new(acc, Operator::And, predicate))),
72 })
73}
74
75pub fn split_disjunction(
79 predicate: &Arc<dyn PhysicalExpr>,
80) -> Vec<&Arc<dyn PhysicalExpr>> {
81 split_impl(Operator::Or, predicate, vec![])
82}
83
84fn split_impl<'a>(
85 operator: Operator,
86 predicate: &'a Arc<dyn PhysicalExpr>,
87 mut exprs: Vec<&'a Arc<dyn PhysicalExpr>>,
88) -> Vec<&'a Arc<dyn PhysicalExpr>> {
89 match predicate.as_any().downcast_ref::<BinaryExpr>() {
90 Some(binary) if binary.op() == &operator => {
91 let exprs = split_impl(operator, binary.left(), exprs);
92 split_impl(operator, binary.right(), exprs)
93 }
94 Some(_) | None => {
95 exprs.push(predicate);
96 exprs
97 }
98 }
99}
100
101pub fn map_columns_before_projection(
110 parent_required: &[Arc<dyn PhysicalExpr>],
111 proj_exprs: &[(Arc<dyn PhysicalExpr>, String)],
112) -> Vec<Arc<dyn PhysicalExpr>> {
113 if parent_required.is_empty() {
114 return vec![];
116 }
117 let column_mapping = proj_exprs
118 .iter()
119 .filter_map(|(expr, name)| {
120 expr.as_any()
121 .downcast_ref::<Column>()
122 .map(|column| (name.clone(), column.clone()))
123 })
124 .collect::<HashMap<_, _>>();
125 parent_required
126 .iter()
127 .filter_map(|r| {
128 r.as_any()
129 .downcast_ref::<Column>()
130 .and_then(|c| column_mapping.get(c.name()))
131 })
132 .map(|e| Arc::new(e.clone()) as _)
133 .collect()
134}
135
136pub fn convert_to_expr<T: Borrow<PhysicalSortExpr>>(
139 sequence: impl IntoIterator<Item = T>,
140) -> Vec<Arc<dyn PhysicalExpr>> {
141 sequence
142 .into_iter()
143 .map(|elem| Arc::clone(&elem.borrow().expr))
144 .collect()
145}
146
147pub fn get_indices_of_exprs_strict<T: Borrow<Arc<dyn PhysicalExpr>>>(
150 targets: impl IntoIterator<Item = T>,
151 items: &[Arc<dyn PhysicalExpr>],
152) -> Vec<usize> {
153 targets
154 .into_iter()
155 .filter_map(|target| items.iter().position(|e| e.eq(target.borrow())))
156 .collect()
157}
158
159pub type ExprTreeNode<T> = ExprContext<Option<T>>;
160
161struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> {
166 graph: StableGraph<T, usize>,
168 visited_plans: Vec<(Arc<dyn PhysicalExpr>, NodeIndex)>,
170 constructor: &'a F,
172}
173
174impl<T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> PhysicalExprDAEGBuilder<'_, T, F> {
175 fn mutate(
178 &mut self,
179 mut node: ExprTreeNode<NodeIndex>,
180 ) -> Result<Transformed<ExprTreeNode<NodeIndex>>> {
181 let expr = &node.expr;
183
184 let node_idx = match self.visited_plans.iter().find(|(e, _)| expr.eq(e)) {
186 Some((_, idx)) => *idx,
188 None => {
192 let node_idx = self.graph.add_node((self.constructor)(&node)?);
193 for expr_node in node.children.iter() {
194 self.graph.add_edge(node_idx, expr_node.data.unwrap(), 0);
195 }
196 self.visited_plans.push((Arc::clone(expr), node_idx));
197 node_idx
198 }
199 };
200 node.data = Some(node_idx);
202 Ok(Transformed::yes(node))
204 }
205}
206
207pub fn build_dag<T, F>(
209 expr: Arc<dyn PhysicalExpr>,
210 constructor: &F,
211) -> Result<(NodeIndex, StableGraph<T, usize>)>
212where
213 F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>,
214{
215 let init = ExprTreeNode::new_default(expr);
217 let mut builder = PhysicalExprDAEGBuilder {
219 graph: StableGraph::<T, usize>::new(),
220 visited_plans: Vec::<(Arc<dyn PhysicalExpr>, NodeIndex)>::new(),
221 constructor,
222 };
223 let root = init.transform_up(|node| builder.mutate(node)).data()?;
225 Ok((root.data.unwrap(), builder.graph))
227}
228
229pub fn collect_columns(expr: &Arc<dyn PhysicalExpr>) -> HashSet<Column> {
231 let mut columns = HashSet::<Column>::new();
232 expr.apply(|expr| {
233 if let Some(column) = expr.as_any().downcast_ref::<Column>() {
234 columns.get_or_insert_owned(column);
235 }
236 Ok(TreeNodeRecursion::Continue)
237 })
238 .expect("no way to return error during recursion");
240 columns
241}
242
243pub fn reassign_predicate_columns(
246 pred: Arc<dyn PhysicalExpr>,
247 schema: &SchemaRef,
248 ignore_not_found: bool,
249) -> Result<Arc<dyn PhysicalExpr>> {
250 pred.transform_down(|expr| {
251 let expr_any = expr.as_any();
252
253 if let Some(column) = expr_any.downcast_ref::<Column>() {
254 let index = match schema.index_of(column.name()) {
255 Ok(idx) => idx,
256 Err(_) if ignore_not_found => usize::MAX,
257 Err(e) => return Err(e.into()),
258 };
259 return Ok(Transformed::yes(Arc::new(Column::new(
260 column.name(),
261 index,
262 ))));
263 }
264 Ok(Transformed::no(expr))
265 })
266 .data()
267}
268
269pub fn merge_vectors(left: &LexOrdering, right: &LexOrdering) -> LexOrdering {
271 left.iter()
272 .cloned()
273 .chain(right.iter().cloned())
274 .unique()
275 .collect()
276}
277
#[cfg(test)]
pub(crate) mod tests {
    use std::any::Any;
    use std::fmt::{Display, Formatter};

    use super::*;
    use crate::expressions::{binary, cast, col, in_list, lit, Literal};

    use arrow::array::{ArrayRef, Float32Array, Float64Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::{exec_err, DataFusionError, ScalarValue};
    use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
    use datafusion_expr::{
        ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
    };

    use petgraph::visit::Bfs;

    /// Scalar UDF for tests: takes one `Float64` or `Float32` argument and
    /// returns its floor, keeping the input's float type.
    #[derive(Debug, Clone)]
    pub struct TestScalarUDF {
        pub(crate) signature: Signature,
    }

    impl TestScalarUDF {
        /// Creates the UDF with a one-argument, immutable signature.
        pub fn new() -> Self {
            let accepted_types = vec![DataType::Float64, DataType::Float32];
            let signature =
                Signature::uniform(1, accepted_types, Volatility::Immutable);
            Self { signature }
        }
    }

    impl ScalarUDFImpl for TestScalarUDF {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            "test-scalar-udf"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        /// `Float32` input yields `Float32`; anything else yields `Float64`.
        fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
            let output_type = if matches!(arg_types[0], DataType::Float32) {
                DataType::Float32
            } else {
                DataType::Float64
            };
            Ok(output_type)
        }

        /// The output carries the same sort properties as the first input.
        fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
            Ok(input[0].sort_properties)
        }

        /// Applies `floor` element-wise to the first argument.
        fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            let arrays = ColumnarValue::values_to_arrays(&args.args)?;

            let floored: ArrayRef = match arrays[0].data_type() {
                DataType::Float64 => {
                    let values = arrays[0]
                        .as_any()
                        .downcast_ref::<Float64Array>()
                        .ok_or_else(|| {
                            DataFusionError::Internal(format!(
                                "could not cast {} to {}",
                                self.name(),
                                std::any::type_name::<Float64Array>()
                            ))
                        })?;

                    Arc::new(
                        values
                            .iter()
                            .map(|value| value.map(f64::floor))
                            .collect::<Float64Array>(),
                    )
                }
                DataType::Float32 => {
                    let values = arrays[0]
                        .as_any()
                        .downcast_ref::<Float32Array>()
                        .ok_or_else(|| {
                            DataFusionError::Internal(format!(
                                "could not cast {} to {}",
                                self.name(),
                                std::any::type_name::<Float32Array>()
                            ))
                        })?;

                    Arc::new(
                        values
                            .iter()
                            .map(|value| value.map(f32::floor))
                            .collect::<Float32Array>(),
                    )
                }
                other => {
                    return exec_err!(
                        "Unsupported data type {other:?} for function {}",
                        self.name()
                    );
                }
            };

            Ok(ColumnarValue::Array(floored))
        }
    }

    /// Payload attached to each graph node in `test_build_dag`.
    #[derive(Clone)]
    struct DummyProperty {
        expr_type: String,
    }

    /// Graph node weight used in `test_build_dag`: the expression plus a
    /// label describing its kind.
    #[derive(Clone)]
    struct PhysicalExprDummyNode {
        pub expr: Arc<dyn PhysicalExpr>,
        pub property: DummyProperty,
    }

    impl Display for PhysicalExprDummyNode {
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.expr)
        }
    }

    /// Node constructor handed to `build_dag`: labels the expression as
    /// "Binary", "Column", "Literal" or "Other".
    fn make_dummy_node(node: &ExprTreeNode<NodeIndex>) -> Result<PhysicalExprDummyNode> {
        let any = node.expr.as_any();
        let label = if any.is::<BinaryExpr>() {
            "Binary"
        } else if any.is::<Column>() {
            "Column"
        } else if any.is::<Literal>() {
            "Literal"
        } else {
            "Other"
        };

        Ok(PhysicalExprDummyNode {
            expr: Arc::clone(&node.expr),
            property: DummyProperty {
                expr_type: label.to_owned(),
            },
        })
    }

    #[test]
    fn test_build_dag() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new("0", DataType::Int32, true),
            Field::new("1", DataType::Int32, true),
            Field::new("2", DataType::Int32, true),
        ]);

        // CAST("0" + "1" AS Int64) > CAST("2" AS Int64) + 10
        let sum_of_first_two = binary(
            col("0", &schema)?,
            Operator::Plus,
            col("1", &schema)?,
            &schema,
        )?;
        let lhs = cast(sum_of_first_two, &schema, DataType::Int64)?;
        let rhs = binary(
            cast(col("2", &schema)?, &schema, DataType::Int64)?,
            Operator::Plus,
            lit(ScalarValue::Int64(Some(10))),
            &schema,
        )?;
        let expr = binary(lhs, Operator::Gt, rhs, &schema)?;

        let (root, graph) = build_dag(expr, &make_dummy_node)?;

        // Walk every node reachable from the root and collect its label.
        let mut visited_props = Vec::new();
        let mut bfs = Bfs::new(&graph, root);
        while let Some(node_index) = bfs.next(&graph) {
            visited_props.push(graph[node_index].property.clone());
        }

        let count_of = |kind: &str| {
            visited_props
                .iter()
                .filter(|property| property.expr_type == kind)
                .count()
        };

        assert_eq!(count_of("Binary"), 3);
        assert_eq!(count_of("Column"), 3);
        assert_eq!(count_of("Literal"), 1);
        assert_eq!(count_of("Other"), 2);
        Ok(())
    }

    #[test]
    fn test_convert_to_expr() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::UInt64, false)]);
        let sort_expr = vec![PhysicalSortExpr {
            expr: col("a", &schema)?,
            options: Default::default(),
        }];

        let converted = convert_to_expr(&sort_expr);
        assert!(converted[0].eq(&sort_expr[0].expr));
        Ok(())
    }

    #[test]
    fn test_get_indices_of_exprs_strict() {
        let column = |name: &str, index: usize| -> Arc<dyn PhysicalExpr> {
            Arc::new(Column::new(name, index))
        };

        let list1 = vec![
            column("a", 0),
            column("b", 1),
            column("c", 2),
            column("d", 3),
        ];
        let list2 = vec![column("b", 1), column("c", 2), column("a", 0)];

        // "d" has no counterpart in `list2` and is skipped.
        assert_eq!(get_indices_of_exprs_strict(&list1, &list2), vec![2, 0, 1]);
        assert_eq!(get_indices_of_exprs_strict(&list2, &list1), vec![1, 2, 0]);
    }

    #[test]
    fn test_reassign_predicate_columns_in_list() {
        let int_field = Field::new("should_not_matter", DataType::Int64, true);
        let dict_field = Field::new(
            "id",
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
            true,
        );
        let schema_small = Arc::new(Schema::new(vec![dict_field.clone()]));
        let schema_big = Arc::new(Schema::new(vec![int_field, dict_field]));

        // Builds `id IN ('2')` with the column resolved against `schema`.
        let id_in_list = |schema: &SchemaRef| {
            in_list(
                Arc::new(Column::new_with_schema("id", schema).unwrap()),
                vec![lit(ScalarValue::Dictionary(
                    Box::new(DataType::Int32),
                    Box::new(ScalarValue::from("2")),
                ))],
                &false,
                schema,
            )
            .unwrap()
        };

        let pred = id_in_list(&schema_big);
        let actual = reassign_predicate_columns(pred, &schema_small, false).unwrap();
        let expected = id_in_list(&schema_small);

        assert_eq!(actual.as_ref(), expected.as_ref());
    }

    #[test]
    fn test_collect_columns() -> Result<()> {
        let expr1: Arc<dyn PhysicalExpr> = Arc::new(Column::new("col1", 2));
        let expected = [Column::new("col1", 2)]
            .into_iter()
            .collect::<HashSet<_>>();
        assert_eq!(collect_columns(&expr1), expected);

        let expr2: Arc<dyn PhysicalExpr> = Arc::new(Column::new("col2", 5));
        let expected = [Column::new("col2", 5)]
            .into_iter()
            .collect::<HashSet<_>>();
        assert_eq!(collect_columns(&expr2), expected);

        let expr3: Arc<dyn PhysicalExpr> =
            Arc::new(BinaryExpr::new(expr1, Operator::Plus, expr2));
        let expected = [Column::new("col1", 2), Column::new("col2", 5)]
            .into_iter()
            .collect::<HashSet<_>>();
        assert_eq!(collect_columns(&expr3), expected);
        Ok(())
    }
}