1mod guarantee;
19pub use guarantee::{Guarantee, LiteralGuarantee};
20
21use std::borrow::Borrow;
22use std::sync::Arc;
23
24use crate::expressions::{BinaryExpr, Column, Literal};
25use crate::tree_node::ExprContext;
26use crate::{
27 AcrossPartitions, ConstExpr, EquivalenceProperties, PhysicalExpr, PhysicalSortExpr,
28};
29
30use arrow::datatypes::Schema;
31use datafusion_common::tree_node::{
32 Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
33};
34use datafusion_common::{HashMap, HashSet, Result};
35use datafusion_expr::Operator;
36
37use petgraph::graph::NodeIndex;
38use petgraph::stable_graph::StableGraph;
39
40pub fn split_conjunction(
44 predicate: &Arc<dyn PhysicalExpr>,
45) -> Vec<&Arc<dyn PhysicalExpr>> {
46 split_impl(Operator::And, predicate, vec![])
47}
48
49impl ConstExpr {
50 pub fn collect_predicate_constants(
63 input_eqs: &EquivalenceProperties,
64 predicate: &Arc<dyn PhysicalExpr>,
65 ) -> Vec<ConstExpr> {
66 fn expr_constant_or_literal(
70 expr: &Arc<dyn PhysicalExpr>,
71 input_eqs: &EquivalenceProperties,
72 ) -> Option<AcrossPartitions> {
73 input_eqs.is_expr_constant(expr).or_else(|| {
74 expr.downcast_ref::<Literal>()
75 .map(|l| AcrossPartitions::Uniform(Some(l.value().clone())))
76 })
77 }
78
79 let mut constants = Vec::new();
80 for conjunction in split_conjunction(predicate) {
81 if let Some(binary) = conjunction.downcast_ref::<BinaryExpr>()
82 && binary.op() == &Operator::Eq
83 {
84 let left_const = expr_constant_or_literal(binary.left(), input_eqs);
88 let right_const = expr_constant_or_literal(binary.right(), input_eqs);
89
90 if let Some(left_across) = left_const {
91 let across = right_const.unwrap_or(left_across);
95 constants.push(ConstExpr::new(Arc::clone(binary.right()), across));
96 } else if let Some(right_across) = right_const {
97 constants
99 .push(ConstExpr::new(Arc::clone(binary.left()), right_across));
100 }
101 }
102 }
103
104 constants
105 }
106}
107
108pub fn conjunction(
113 predicates: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>,
114) -> Arc<dyn PhysicalExpr> {
115 conjunction_opt(predicates).unwrap_or_else(|| crate::expressions::lit(true))
116}
117
118pub fn conjunction_opt(
123 predicates: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>,
124) -> Option<Arc<dyn PhysicalExpr>> {
125 predicates
126 .into_iter()
127 .fold(None, |acc, predicate| match acc {
128 None => Some(predicate),
129 Some(acc) => Some(Arc::new(BinaryExpr::new(acc, Operator::And, predicate))),
130 })
131}
132
133pub fn split_disjunction(
137 predicate: &Arc<dyn PhysicalExpr>,
138) -> Vec<&Arc<dyn PhysicalExpr>> {
139 split_impl(Operator::Or, predicate, vec![])
140}
141
142fn split_impl<'a>(
143 operator: Operator,
144 predicate: &'a Arc<dyn PhysicalExpr>,
145 mut exprs: Vec<&'a Arc<dyn PhysicalExpr>>,
146) -> Vec<&'a Arc<dyn PhysicalExpr>> {
147 match predicate.downcast_ref::<BinaryExpr>() {
148 Some(binary) if binary.op() == &operator => {
149 let exprs = split_impl(operator, binary.left(), exprs);
150 split_impl(operator, binary.right(), exprs)
151 }
152 Some(_) | None => {
153 exprs.push(predicate);
154 exprs
155 }
156 }
157}
158
159pub fn map_columns_before_projection(
168 parent_required: &[Arc<dyn PhysicalExpr>],
169 proj_exprs: &[(Arc<dyn PhysicalExpr>, String)],
170) -> Vec<Arc<dyn PhysicalExpr>> {
171 if parent_required.is_empty() {
172 return vec![];
174 }
175 let column_mapping = proj_exprs
176 .iter()
177 .filter_map(|(expr, name)| {
178 expr.downcast_ref::<Column>()
179 .map(|column| (name.clone(), column.clone()))
180 })
181 .collect::<HashMap<_, _>>();
182 parent_required
183 .iter()
184 .filter_map(|r| {
185 r.downcast_ref::<Column>()
186 .and_then(|c| column_mapping.get(c.name()))
187 })
188 .map(|e| Arc::new(e.clone()) as _)
189 .collect()
190}
191
192pub fn convert_to_expr<T: Borrow<PhysicalSortExpr>>(
195 sequence: impl IntoIterator<Item = T>,
196) -> Vec<Arc<dyn PhysicalExpr>> {
197 sequence
198 .into_iter()
199 .map(|elem| Arc::clone(&elem.borrow().expr))
200 .collect()
201}
202
203pub fn get_indices_of_exprs_strict<T: Borrow<Arc<dyn PhysicalExpr>>>(
206 targets: impl IntoIterator<Item = T>,
207 items: &[Arc<dyn PhysicalExpr>],
208) -> Vec<usize> {
209 targets
210 .into_iter()
211 .filter_map(|target| items.iter().position(|e| e.eq(target.borrow())))
212 .collect()
213}
214
/// An expression tree node (an [`ExprContext`]) whose `data` payload is
/// optional.
///
/// [`build_dag`] uses the payload to record the graph node index assigned to
/// each expression.
pub type ExprTreeNode<T> = ExprContext<Option<T>>;
216
217struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> {
222 graph: StableGraph<T, usize>,
224 visited_plans: Vec<(Arc<dyn PhysicalExpr>, NodeIndex)>,
226 constructor: &'a F,
228}
229
230impl<T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> PhysicalExprDAEGBuilder<'_, T, F> {
231 fn mutate(
234 &mut self,
235 mut node: ExprTreeNode<NodeIndex>,
236 ) -> Result<Transformed<ExprTreeNode<NodeIndex>>> {
237 let expr = &node.expr;
239
240 let node_idx = match self.visited_plans.iter().find(|(e, _)| expr.eq(e)) {
242 Some((_, idx)) => *idx,
244 None => {
248 let node_idx = self.graph.add_node((self.constructor)(&node)?);
249 for expr_node in node.children.iter() {
250 self.graph.add_edge(node_idx, expr_node.data.unwrap(), 0);
251 }
252 self.visited_plans.push((Arc::clone(expr), node_idx));
253 node_idx
254 }
255 };
256 node.data = Some(node_idx);
258 Ok(Transformed::yes(node))
260 }
261}
262
263pub fn build_dag<T, F>(
265 expr: Arc<dyn PhysicalExpr>,
266 constructor: &F,
267) -> Result<(NodeIndex, StableGraph<T, usize>)>
268where
269 F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>,
270{
271 let init = ExprTreeNode::new_default(expr);
273 let mut builder = PhysicalExprDAEGBuilder {
275 graph: StableGraph::<T, usize>::new(),
276 visited_plans: Vec::<(Arc<dyn PhysicalExpr>, NodeIndex)>::new(),
277 constructor,
278 };
279 let root = init.transform_up(|node| builder.mutate(node)).data()?;
281 Ok((root.data.unwrap(), builder.graph))
283}
284
285pub fn collect_columns(expr: &Arc<dyn PhysicalExpr>) -> HashSet<Column> {
287 let mut columns = HashSet::<Column>::new();
288 expr.apply(|expr| {
289 if let Some(column) = expr.downcast_ref::<Column>() {
290 columns.get_or_insert_with(column, |c| c.clone());
291 }
292 Ok(TreeNodeRecursion::Continue)
293 })
294 .expect("no way to return error during recursion");
296 columns
297}
298
299pub fn reassign_expr_columns(
309 expr: Arc<dyn PhysicalExpr>,
310 schema: &Schema,
311) -> Result<Arc<dyn PhysicalExpr>> {
312 expr.transform_down(|expr| {
313 if let Some(column) = expr.downcast_ref::<Column>() {
314 let index = schema.index_of(column.name())?;
315
316 return Ok(Transformed::yes(Arc::new(Column::new(
317 column.name(),
318 index,
319 ))));
320 }
321 Ok(Transformed::no(expr))
322 })
323 .data()
324}
325
#[cfg(test)]
pub(crate) mod tests {

    use std::fmt::{Display, Formatter};

    use super::*;
    use crate::expressions::{Literal, binary, cast, col, in_list, lit};

    use arrow::array::{ArrayRef, Float32Array, Float64Array};
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::{ScalarValue, exec_err, internal_datafusion_err};
    use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
    use datafusion_expr::{
        ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
    };

    use petgraph::visit::Bfs;

    /// Scalar UDF used by tests. It applies `floor` to a single `Float64` or
    /// `Float32` argument and reports the same sort properties as its input.
    #[derive(Debug, PartialEq, Eq, Hash)]
    pub struct TestScalarUDF {
        pub(crate) signature: Signature,
    }

    impl TestScalarUDF {
        /// Creates the UDF with an immutable, one-argument signature that
        /// accepts `Float64` or `Float32`.
        pub fn new() -> Self {
            use DataType::*;
            Self {
                signature: Signature::uniform(
                    1,
                    vec![Float64, Float32],
                    Volatility::Immutable,
                ),
            }
        }
    }

    impl ScalarUDFImpl for TestScalarUDF {
        fn name(&self) -> &str {
            "test-scalar-udf"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        // A `Float32` argument yields `Float32`; every other argument type
        // yields `Float64`.
        fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
            let arg_type = &arg_types[0];

            match arg_type {
                DataType::Float32 => Ok(DataType::Float32),
                _ => Ok(DataType::Float64),
            }
        }

        // `floor` is non-decreasing, so the output keeps the input's ordering.
        fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
            Ok(input[0].sort_properties)
        }

        // Floors each element of the first argument, keeping nulls as nulls.
        // Types other than Float64/Float32 produce an execution error.
        fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            let args = ColumnarValue::values_to_arrays(&args.args)?;

            let arr: ArrayRef = match args[0].data_type() {
                DataType::Float64 => Arc::new({
                    let arg = &args[0]
                        .as_any()
                        .downcast_ref::<Float64Array>()
                        .ok_or_else(|| {
                            internal_datafusion_err!(
                                "could not cast {} to {}",
                                self.name(),
                                std::any::type_name::<Float64Array>()
                            )
                        })?;

                    arg.iter()
                        .map(|a| a.map(f64::floor))
                        .collect::<Float64Array>()
                }),
                DataType::Float32 => Arc::new({
                    let arg = &args[0]
                        .as_any()
                        .downcast_ref::<Float32Array>()
                        .ok_or_else(|| {
                            internal_datafusion_err!(
                                "could not cast {} to {}",
                                self.name(),
                                std::any::type_name::<Float32Array>()
                            )
                        })?;

                    arg.iter()
                        .map(|a| a.map(f32::floor))
                        .collect::<Float32Array>()
                }),
                other => {
                    return exec_err!(
                        "Unsupported data type {other:?} for function {}",
                        self.name()
                    );
                }
            };
            Ok(ColumnarValue::Array(arr))
        }
    }

    /// Property attached to each DAG node in `test_build_dag`: a coarse label
    /// ("Binary", "Column", "Literal" or "Other") for the node's expression.
    #[derive(Clone)]
    struct DummyProperty {
        expr_type: String,
    }

    /// Graph node weight built by `make_dummy_node`: the expression plus its
    /// classification.
    #[derive(Clone)]
    struct PhysicalExprDummyNode {
        pub expr: Arc<dyn PhysicalExpr>,
        pub property: DummyProperty,
    }

    impl Display for PhysicalExprDummyNode {
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.expr)
        }
    }

    /// `build_dag` constructor: classifies the node's expression by its
    /// concrete type.
    fn make_dummy_node(node: &ExprTreeNode<NodeIndex>) -> Result<PhysicalExprDummyNode> {
        let expr = Arc::clone(&node.expr);
        let dummy_property = if expr.is::<BinaryExpr>() {
            "Binary"
        } else if expr.is::<Column>() {
            "Column"
        } else if expr.is::<Literal>() {
            "Literal"
        } else {
            "Other"
        }
        .to_owned();
        Ok(PhysicalExprDummyNode {
            expr,
            property: DummyProperty {
                expr_type: dummy_property,
            },
        })
    }

    // Builds a DAG for `CAST(c0 + c1 AS Int64) > CAST(c2 AS Int64) + 10` and
    // checks, via a BFS from the root, how many nodes of each kind it holds:
    // 3 binary ops, 3 columns, 1 literal and 2 casts ("Other").
    #[test]
    fn test_build_dag() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new("0", DataType::Int32, true),
            Field::new("1", DataType::Int32, true),
            Field::new("2", DataType::Int32, true),
        ]);
        let expr = binary(
            cast(
                binary(
                    col("0", &schema)?,
                    Operator::Plus,
                    col("1", &schema)?,
                    &schema,
                )?,
                &schema,
                DataType::Int64,
            )?,
            Operator::Gt,
            binary(
                cast(col("2", &schema)?, &schema, DataType::Int64)?,
                Operator::Plus,
                lit(ScalarValue::Int64(Some(10))),
                &schema,
            )?,
            &schema,
        )?;
        let mut vector_dummy_props = vec![];
        let (root, graph) = build_dag(expr, &make_dummy_node)?;
        let mut bfs = Bfs::new(&graph, root);
        while let Some(node_index) = bfs.next(&graph) {
            let node = &graph[node_index];
            vector_dummy_props.push(node.property.clone());
        }

        assert_eq!(
            vector_dummy_props
                .iter()
                .filter(|property| property.expr_type == "Binary")
                .count(),
            3
        );
        assert_eq!(
            vector_dummy_props
                .iter()
                .filter(|property| property.expr_type == "Column")
                .count(),
            3
        );
        assert_eq!(
            vector_dummy_props
                .iter()
                .filter(|property| property.expr_type == "Literal")
                .count(),
            1
        );
        assert_eq!(
            vector_dummy_props
                .iter()
                .filter(|property| property.expr_type == "Other")
                .count(),
            2
        );
        Ok(())
    }

    // The extracted expression must equal the sort key's expression.
    #[test]
    fn test_convert_to_expr() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::UInt64, false)]);
        let sort_expr = vec![PhysicalSortExpr {
            expr: col("a", &schema)?,
            options: Default::default(),
        }];
        assert!(convert_to_expr(&sort_expr)[0].eq(&sort_expr[0].expr));
        Ok(())
    }

    // Targets absent from `items` (column "d" in the first call) are skipped.
    #[test]
    fn test_get_indices_of_exprs_strict() {
        let list1: Vec<Arc<dyn PhysicalExpr>> = vec![
            Arc::new(Column::new("a", 0)),
            Arc::new(Column::new("b", 1)),
            Arc::new(Column::new("c", 2)),
            Arc::new(Column::new("d", 3)),
        ];
        let list2: Vec<Arc<dyn PhysicalExpr>> = vec![
            Arc::new(Column::new("b", 1)),
            Arc::new(Column::new("c", 2)),
            Arc::new(Column::new("a", 0)),
        ];
        assert_eq!(get_indices_of_exprs_strict(&list1, &list2), vec![2, 0, 1]);
        assert_eq!(get_indices_of_exprs_strict(&list2, &list1), vec![1, 2, 0]);
    }

    // An IN-list bound to a wider schema ("id" at index 1) is re-bound to a
    // narrower schema ("id" at index 0) and must equal the expression built
    // directly against the narrower schema.
    #[test]
    fn test_reassign_expr_columns_in_list() {
        let int_field = Field::new("should_not_matter", DataType::Int64, true);
        let dict_field = Field::new(
            "id",
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
            true,
        );
        let schema_small = Arc::new(Schema::new(vec![dict_field.clone()]));
        let schema_big = Arc::new(Schema::new(vec![int_field, dict_field]));
        let pred = in_list(
            Arc::new(Column::new_with_schema("id", &schema_big).unwrap()),
            vec![lit(ScalarValue::Dictionary(
                Box::new(DataType::Int32),
                Box::new(ScalarValue::from("2")),
            ))],
            &false,
            &schema_big,
        )
        .unwrap();

        let actual = reassign_expr_columns(pred, &schema_small).unwrap();

        let expected = in_list(
            Arc::new(Column::new_with_schema("id", &schema_small).unwrap()),
            vec![lit(ScalarValue::Dictionary(
                Box::new(DataType::Int32),
                Box::new(ScalarValue::from("2")),
            ))],
            &false,
            &schema_small,
        )
        .unwrap();

        assert_eq!(actual.as_ref(), expected.as_ref());
    }

    // Columns are collected from single-column expressions and from both
    // operands of a binary expression.
    #[test]
    fn test_collect_columns() -> Result<()> {
        let expr1 = Arc::new(Column::new("col1", 2)) as _;
        let mut expected = HashSet::new();
        expected.insert(Column::new("col1", 2));
        assert_eq!(collect_columns(&expr1), expected);

        let expr2 = Arc::new(Column::new("col2", 5)) as _;
        let mut expected = HashSet::new();
        expected.insert(Column::new("col2", 5));
        assert_eq!(collect_columns(&expr2), expected);

        let expr3 = Arc::new(BinaryExpr::new(expr1, Operator::Plus, expr2)) as _;
        let mut expected = HashSet::new();
        expected.insert(Column::new("col1", 2));
        expected.insert(Column::new("col2", 5));
        assert_eq!(collect_columns(&expr3), expected);
        Ok(())
    }

    // `ticker = 'NGJ26'` with no prior constants must mark `ticker` as constant
    // with the literal's value, uniform across partitions.
    #[test]
    fn test_collect_predicate_constants_propagates_uniform_literal_value() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ticker",
            DataType::Utf8,
            false,
        )]));
        let predicate = binary(
            col("ticker", schema.as_ref())?,
            Operator::Eq,
            lit(ScalarValue::Utf8(Some("NGJ26".to_string()))),
            schema.as_ref(),
        )?;
        let eq_properties = EquivalenceProperties::new(schema);

        let constants =
            ConstExpr::collect_predicate_constants(&eq_properties, &predicate);

        assert_eq!(constants.len(), 1);
        assert_eq!(
            constants[0].across_partitions,
            AcrossPartitions::Uniform(Some(ScalarValue::Utf8(Some("NGJ26".to_string()))))
        );

        Ok(())
    }
}