1mod guarantee;
19pub use guarantee::{Guarantee, LiteralGuarantee};
20
21use std::borrow::Borrow;
22use std::sync::Arc;
23
24use crate::expressions::{BinaryExpr, Column};
25use crate::tree_node::ExprContext;
26use crate::PhysicalExpr;
27use crate::PhysicalSortExpr;
28
29use arrow::datatypes::Schema;
30use datafusion_common::tree_node::{
31 Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
32};
33use datafusion_common::{HashMap, HashSet, Result};
34use datafusion_expr::Operator;
35
36use petgraph::graph::NodeIndex;
37use petgraph::stable_graph::StableGraph;
38
39pub fn split_conjunction(
43 predicate: &Arc<dyn PhysicalExpr>,
44) -> Vec<&Arc<dyn PhysicalExpr>> {
45 split_impl(Operator::And, predicate, vec![])
46}
47
48pub fn conjunction(
53 predicates: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>,
54) -> Arc<dyn PhysicalExpr> {
55 conjunction_opt(predicates).unwrap_or_else(|| crate::expressions::lit(true))
56}
57
58pub fn conjunction_opt(
63 predicates: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>,
64) -> Option<Arc<dyn PhysicalExpr>> {
65 predicates
66 .into_iter()
67 .fold(None, |acc, predicate| match acc {
68 None => Some(predicate),
69 Some(acc) => Some(Arc::new(BinaryExpr::new(acc, Operator::And, predicate))),
70 })
71}
72
73pub fn split_disjunction(
77 predicate: &Arc<dyn PhysicalExpr>,
78) -> Vec<&Arc<dyn PhysicalExpr>> {
79 split_impl(Operator::Or, predicate, vec![])
80}
81
82fn split_impl<'a>(
83 operator: Operator,
84 predicate: &'a Arc<dyn PhysicalExpr>,
85 mut exprs: Vec<&'a Arc<dyn PhysicalExpr>>,
86) -> Vec<&'a Arc<dyn PhysicalExpr>> {
87 match predicate.as_any().downcast_ref::<BinaryExpr>() {
88 Some(binary) if binary.op() == &operator => {
89 let exprs = split_impl(operator, binary.left(), exprs);
90 split_impl(operator, binary.right(), exprs)
91 }
92 Some(_) | None => {
93 exprs.push(predicate);
94 exprs
95 }
96 }
97}
98
99pub fn map_columns_before_projection(
108 parent_required: &[Arc<dyn PhysicalExpr>],
109 proj_exprs: &[(Arc<dyn PhysicalExpr>, String)],
110) -> Vec<Arc<dyn PhysicalExpr>> {
111 if parent_required.is_empty() {
112 return vec![];
114 }
115 let column_mapping = proj_exprs
116 .iter()
117 .filter_map(|(expr, name)| {
118 expr.as_any()
119 .downcast_ref::<Column>()
120 .map(|column| (name.clone(), column.clone()))
121 })
122 .collect::<HashMap<_, _>>();
123 parent_required
124 .iter()
125 .filter_map(|r| {
126 r.as_any()
127 .downcast_ref::<Column>()
128 .and_then(|c| column_mapping.get(c.name()))
129 })
130 .map(|e| Arc::new(e.clone()) as _)
131 .collect()
132}
133
134pub fn convert_to_expr<T: Borrow<PhysicalSortExpr>>(
137 sequence: impl IntoIterator<Item = T>,
138) -> Vec<Arc<dyn PhysicalExpr>> {
139 sequence
140 .into_iter()
141 .map(|elem| Arc::clone(&elem.borrow().expr))
142 .collect()
143}
144
145pub fn get_indices_of_exprs_strict<T: Borrow<Arc<dyn PhysicalExpr>>>(
148 targets: impl IntoIterator<Item = T>,
149 items: &[Arc<dyn PhysicalExpr>],
150) -> Vec<usize> {
151 targets
152 .into_iter()
153 .filter_map(|target| items.iter().position(|e| e.eq(target.borrow())))
154 .collect()
155}
156
/// An [`ExprContext`] whose per-node payload is an optional `T`.
pub type ExprTreeNode<T> = ExprContext<Option<T>>;
158
159struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> {
164 graph: StableGraph<T, usize>,
166 visited_plans: Vec<(Arc<dyn PhysicalExpr>, NodeIndex)>,
168 constructor: &'a F,
170}
171
172impl<T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> PhysicalExprDAEGBuilder<'_, T, F> {
173 fn mutate(
176 &mut self,
177 mut node: ExprTreeNode<NodeIndex>,
178 ) -> Result<Transformed<ExprTreeNode<NodeIndex>>> {
179 let expr = &node.expr;
181
182 let node_idx = match self.visited_plans.iter().find(|(e, _)| expr.eq(e)) {
184 Some((_, idx)) => *idx,
186 None => {
190 let node_idx = self.graph.add_node((self.constructor)(&node)?);
191 for expr_node in node.children.iter() {
192 self.graph.add_edge(node_idx, expr_node.data.unwrap(), 0);
193 }
194 self.visited_plans.push((Arc::clone(expr), node_idx));
195 node_idx
196 }
197 };
198 node.data = Some(node_idx);
200 Ok(Transformed::yes(node))
202 }
203}
204
205pub fn build_dag<T, F>(
207 expr: Arc<dyn PhysicalExpr>,
208 constructor: &F,
209) -> Result<(NodeIndex, StableGraph<T, usize>)>
210where
211 F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>,
212{
213 let init = ExprTreeNode::new_default(expr);
215 let mut builder = PhysicalExprDAEGBuilder {
217 graph: StableGraph::<T, usize>::new(),
218 visited_plans: Vec::<(Arc<dyn PhysicalExpr>, NodeIndex)>::new(),
219 constructor,
220 };
221 let root = init.transform_up(|node| builder.mutate(node)).data()?;
223 Ok((root.data.unwrap(), builder.graph))
225}
226
227pub fn collect_columns(expr: &Arc<dyn PhysicalExpr>) -> HashSet<Column> {
229 let mut columns = HashSet::<Column>::new();
230 expr.apply(|expr| {
231 if let Some(column) = expr.as_any().downcast_ref::<Column>() {
232 columns.get_or_insert_owned(column);
233 }
234 Ok(TreeNodeRecursion::Continue)
235 })
236 .expect("no way to return error during recursion");
238 columns
239}
240
241pub fn reassign_predicate_columns(
244 pred: Arc<dyn PhysicalExpr>,
245 schema: &Schema,
246 ignore_not_found: bool,
247) -> Result<Arc<dyn PhysicalExpr>> {
248 pred.transform_down(|expr| {
249 let expr_any = expr.as_any();
250
251 if let Some(column) = expr_any.downcast_ref::<Column>() {
252 let index = match schema.index_of(column.name()) {
253 Ok(idx) => idx,
254 Err(_) if ignore_not_found => usize::MAX,
255 Err(e) => return Err(e.into()),
256 };
257 return Ok(Transformed::yes(Arc::new(Column::new(
258 column.name(),
259 index,
260 ))));
261 }
262 Ok(Transformed::no(expr))
263 })
264 .data()
265}
266
267#[cfg(test)]
268pub(crate) mod tests {
269 use std::any::Any;
270 use std::fmt::{Display, Formatter};
271
272 use super::*;
273 use crate::expressions::{binary, cast, col, in_list, lit, Literal};
274
275 use arrow::array::{ArrayRef, Float32Array, Float64Array};
276 use arrow::datatypes::{DataType, Field, Schema};
277 use datafusion_common::{exec_err, DataFusionError, ScalarValue};
278 use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
279 use datafusion_expr::{
280 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
281 };
282
283 use petgraph::visit::Bfs;
284
    /// Scalar UDF used in tests; its only state is the [`Signature`] it reports.
    #[derive(Debug, Clone)]
    pub struct TestScalarUDF {
        /// Signature returned from `ScalarUDFImpl::signature`.
        pub(crate) signature: Signature,
    }
289
290 impl TestScalarUDF {
291 pub fn new() -> Self {
292 use DataType::*;
293 Self {
294 signature: Signature::uniform(
295 1,
296 vec![Float64, Float32],
297 Volatility::Immutable,
298 ),
299 }
300 }
301 }
302
303 impl ScalarUDFImpl for TestScalarUDF {
304 fn as_any(&self) -> &dyn Any {
305 self
306 }
307 fn name(&self) -> &str {
308 "test-scalar-udf"
309 }
310
311 fn signature(&self) -> &Signature {
312 &self.signature
313 }
314
315 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
316 let arg_type = &arg_types[0];
317
318 match arg_type {
319 DataType::Float32 => Ok(DataType::Float32),
320 _ => Ok(DataType::Float64),
321 }
322 }
323
324 fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
325 Ok(input[0].sort_properties)
326 }
327
328 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
329 let args = ColumnarValue::values_to_arrays(&args.args)?;
330
331 let arr: ArrayRef = match args[0].data_type() {
332 DataType::Float64 => Arc::new({
333 let arg = &args[0]
334 .as_any()
335 .downcast_ref::<Float64Array>()
336 .ok_or_else(|| {
337 DataFusionError::Internal(format!(
338 "could not cast {} to {}",
339 self.name(),
340 std::any::type_name::<Float64Array>()
341 ))
342 })?;
343
344 arg.iter()
345 .map(|a| a.map(f64::floor))
346 .collect::<Float64Array>()
347 }),
348 DataType::Float32 => Arc::new({
349 let arg = &args[0]
350 .as_any()
351 .downcast_ref::<Float32Array>()
352 .ok_or_else(|| {
353 DataFusionError::Internal(format!(
354 "could not cast {} to {}",
355 self.name(),
356 std::any::type_name::<Float32Array>()
357 ))
358 })?;
359
360 arg.iter()
361 .map(|a| a.map(f32::floor))
362 .collect::<Float32Array>()
363 }),
364 other => {
365 return exec_err!(
366 "Unsupported data type {other:?} for function {}",
367 self.name()
368 );
369 }
370 };
371 Ok(ColumnarValue::Array(arr))
372 }
373 }
374
    /// Test-only node property: a label naming the kind of expression.
    #[derive(Clone)]
    struct DummyProperty {
        /// One of "Binary", "Column", "Literal" or "Other" (see `make_dummy_node`).
        expr_type: String,
    }
379
    /// Test-only graph node weight: an expression plus its [`DummyProperty`].
    #[derive(Clone)]
    struct PhysicalExprDummyNode {
        /// The expression this graph node stands for.
        pub expr: Arc<dyn PhysicalExpr>,
        /// Label describing the kind of `expr`.
        pub property: DummyProperty,
    }
387
388 impl Display for PhysicalExprDummyNode {
389 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
390 write!(f, "{}", self.expr)
391 }
392 }
393
394 fn make_dummy_node(node: &ExprTreeNode<NodeIndex>) -> Result<PhysicalExprDummyNode> {
395 let expr = Arc::clone(&node.expr);
396 let dummy_property = if expr.as_any().is::<BinaryExpr>() {
397 "Binary"
398 } else if expr.as_any().is::<Column>() {
399 "Column"
400 } else if expr.as_any().is::<Literal>() {
401 "Literal"
402 } else {
403 "Other"
404 }
405 .to_owned();
406 Ok(PhysicalExprDummyNode {
407 expr,
408 property: DummyProperty {
409 expr_type: dummy_property,
410 },
411 })
412 }
413
414 #[test]
415 fn test_build_dag() -> Result<()> {
416 let schema = Schema::new(vec![
417 Field::new("0", DataType::Int32, true),
418 Field::new("1", DataType::Int32, true),
419 Field::new("2", DataType::Int32, true),
420 ]);
421 let expr = binary(
422 cast(
423 binary(
424 col("0", &schema)?,
425 Operator::Plus,
426 col("1", &schema)?,
427 &schema,
428 )?,
429 &schema,
430 DataType::Int64,
431 )?,
432 Operator::Gt,
433 binary(
434 cast(col("2", &schema)?, &schema, DataType::Int64)?,
435 Operator::Plus,
436 lit(ScalarValue::Int64(Some(10))),
437 &schema,
438 )?,
439 &schema,
440 )?;
441 let mut vector_dummy_props = vec![];
442 let (root, graph) = build_dag(expr, &make_dummy_node)?;
443 let mut bfs = Bfs::new(&graph, root);
444 while let Some(node_index) = bfs.next(&graph) {
445 let node = &graph[node_index];
446 vector_dummy_props.push(node.property.clone());
447 }
448
449 assert_eq!(
450 vector_dummy_props
451 .iter()
452 .filter(|property| property.expr_type == "Binary")
453 .count(),
454 3
455 );
456 assert_eq!(
457 vector_dummy_props
458 .iter()
459 .filter(|property| property.expr_type == "Column")
460 .count(),
461 3
462 );
463 assert_eq!(
464 vector_dummy_props
465 .iter()
466 .filter(|property| property.expr_type == "Literal")
467 .count(),
468 1
469 );
470 assert_eq!(
471 vector_dummy_props
472 .iter()
473 .filter(|property| property.expr_type == "Other")
474 .count(),
475 2
476 );
477 Ok(())
478 }
479
480 #[test]
481 fn test_convert_to_expr() -> Result<()> {
482 let schema = Schema::new(vec![Field::new("a", DataType::UInt64, false)]);
483 let sort_expr = vec![PhysicalSortExpr {
484 expr: col("a", &schema)?,
485 options: Default::default(),
486 }];
487 assert!(convert_to_expr(&sort_expr)[0].eq(&sort_expr[0].expr));
488 Ok(())
489 }
490
491 #[test]
492 fn test_get_indices_of_exprs_strict() {
493 let list1: Vec<Arc<dyn PhysicalExpr>> = vec![
494 Arc::new(Column::new("a", 0)),
495 Arc::new(Column::new("b", 1)),
496 Arc::new(Column::new("c", 2)),
497 Arc::new(Column::new("d", 3)),
498 ];
499 let list2: Vec<Arc<dyn PhysicalExpr>> = vec![
500 Arc::new(Column::new("b", 1)),
501 Arc::new(Column::new("c", 2)),
502 Arc::new(Column::new("a", 0)),
503 ];
504 assert_eq!(get_indices_of_exprs_strict(&list1, &list2), vec![2, 0, 1]);
505 assert_eq!(get_indices_of_exprs_strict(&list2, &list1), vec![1, 2, 0]);
506 }
507
508 #[test]
509 fn test_reassign_predicate_columns_in_list() {
510 let int_field = Field::new("should_not_matter", DataType::Int64, true);
511 let dict_field = Field::new(
512 "id",
513 DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
514 true,
515 );
516 let schema_small = Arc::new(Schema::new(vec![dict_field.clone()]));
517 let schema_big = Arc::new(Schema::new(vec![int_field, dict_field]));
518 let pred = in_list(
519 Arc::new(Column::new_with_schema("id", &schema_big).unwrap()),
520 vec![lit(ScalarValue::Dictionary(
521 Box::new(DataType::Int32),
522 Box::new(ScalarValue::from("2")),
523 ))],
524 &false,
525 &schema_big,
526 )
527 .unwrap();
528
529 let actual = reassign_predicate_columns(pred, &schema_small, false).unwrap();
530
531 let expected = in_list(
532 Arc::new(Column::new_with_schema("id", &schema_small).unwrap()),
533 vec![lit(ScalarValue::Dictionary(
534 Box::new(DataType::Int32),
535 Box::new(ScalarValue::from("2")),
536 ))],
537 &false,
538 &schema_small,
539 )
540 .unwrap();
541
542 assert_eq!(actual.as_ref(), expected.as_ref());
543 }
544
545 #[test]
546 fn test_collect_columns() -> Result<()> {
547 let expr1 = Arc::new(Column::new("col1", 2)) as _;
548 let mut expected = HashSet::new();
549 expected.insert(Column::new("col1", 2));
550 assert_eq!(collect_columns(&expr1), expected);
551
552 let expr2 = Arc::new(Column::new("col2", 5)) as _;
553 let mut expected = HashSet::new();
554 expected.insert(Column::new("col2", 5));
555 assert_eq!(collect_columns(&expr2), expected);
556
557 let expr3 = Arc::new(BinaryExpr::new(expr1, Operator::Plus, expr2)) as _;
558 let mut expected = HashSet::new();
559 expected.insert(Column::new("col1", 2));
560 expected.insert(Column::new("col2", 5));
561 assert_eq!(collect_columns(&expr3), expected);
562 Ok(())
563 }
564}