//! Utilities for working with physical expressions: predicate splitting and
//! conjunction helpers, column collection and re-mapping, and construction of
//! expression DAGs.

mod guarantee;
pub use guarantee::{Guarantee, LiteralGuarantee};

use std::borrow::Borrow;
use std::sync::Arc;

use crate::expressions::{BinaryExpr, Column};
use crate::tree_node::ExprContext;
use crate::PhysicalExpr;
use crate::PhysicalSortExpr;

use arrow::datatypes::Schema;
use datafusion_common::tree_node::{
    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
};
use datafusion_common::{HashMap, HashSet, Result};
use datafusion_expr::Operator;

use petgraph::graph::NodeIndex;
use petgraph::stable_graph::StableGraph;

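/// Assuming the predicate is a conjunction (expressions `AND`ed together),
/// splits it into its component expressions. For example,
/// `a = 1 AND (b > 2 AND c < 3)` yields `[a = 1, b > 2, c < 3]`; any other
/// expression is returned as a single-element vector.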
pub fn split_conjunction(
    predicate: &Arc<dyn PhysicalExpr>,
) -> Vec<&Arc<dyn PhysicalExpr>> {
    split_impl(Operator::And, predicate, vec![])
}

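/// Combines the given predicates into a single conjunctive (`AND`) expression;
/// returns a literal `true` expression when the iterator is empty.
///
/// A minimal sketch of intended usage (not compiled here; `a_gt_1` and
/// `b_lt_5` stand for existing `Arc<dyn PhysicalExpr>` predicates):
///
/// ```ignore
/// let pred = conjunction(vec![a_gt_1, b_lt_5]); // a_gt_1 AND b_lt_5
/// let always_true = conjunction(vec![]);        // lit(true)
/// ```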
pub fn conjunction(
    predicates: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>,
) -> Arc<dyn PhysicalExpr> {
    conjunction_opt(predicates).unwrap_or_else(|| crate::expressions::lit(true))
}

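/// Like [`conjunction`], but returns `None` rather than a literal `true`
/// expression when no predicates are supplied.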
pub fn conjunction_opt(
    predicates: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>,
) -> Option<Arc<dyn PhysicalExpr>> {
    predicates
        .into_iter()
        .fold(None, |acc, predicate| match acc {
            None => Some(predicate),
            Some(acc) => Some(Arc::new(BinaryExpr::new(acc, Operator::And, predicate))),
        })
}

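/// Assuming the predicate is a disjunction (expressions `OR`ed together),
/// splits it into its component expressions. For example,
/// `a = 1 OR (b > 2 OR c < 3)` yields `[a = 1, b > 2, c < 3]`.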
pub fn split_disjunction(
    predicate: &Arc<dyn PhysicalExpr>,
) -> Vec<&Arc<dyn PhysicalExpr>> {
    split_impl(Operator::Or, predicate, vec![])
}

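/// Recursively walks nested [`BinaryExpr`]s that use `operator` and appends
/// every leaf operand to `exprs`, returning the accumulated vector.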
fn split_impl<'a>(
    operator: Operator,
    predicate: &'a Arc<dyn PhysicalExpr>,
    mut exprs: Vec<&'a Arc<dyn PhysicalExpr>>,
) -> Vec<&'a Arc<dyn PhysicalExpr>> {
    match predicate.as_any().downcast_ref::<BinaryExpr>() {
        Some(binary) if binary.op() == &operator => {
            let exprs = split_impl(operator, binary.left(), exprs);
            split_impl(operator, binary.right(), exprs)
        }
        Some(_) | None => {
            exprs.push(predicate);
            exprs
        }
    }
}

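/// Maps columns required by a parent node back through a projection: for each
/// parent requirement that is a [`Column`], finds the projection entry whose
/// output name matches and whose expression is itself a column, and returns
/// that underlying input column. Requirements without such a match are
/// dropped from the result.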
pub fn map_columns_before_projection(
    parent_required: &[Arc<dyn PhysicalExpr>],
    proj_exprs: &[(Arc<dyn PhysicalExpr>, String)],
) -> Vec<Arc<dyn PhysicalExpr>> {
    if parent_required.is_empty() {
        return vec![];
    }
    let column_mapping = proj_exprs
        .iter()
        .filter_map(|(expr, name)| {
            expr.as_any()
                .downcast_ref::<Column>()
                .map(|column| (name.clone(), column.clone()))
        })
        .collect::<HashMap<_, _>>();
    parent_required
        .iter()
        .filter_map(|r| {
            r.as_any()
                .downcast_ref::<Column>()
                .and_then(|c| column_mapping.get(c.name()))
        })
        .map(|e| Arc::new(e.clone()) as _)
        .collect()
}

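/// Extracts the inner expressions from a sequence of [`PhysicalSortExpr`]s,
/// discarding the sort options.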
pub fn convert_to_expr<T: Borrow<PhysicalSortExpr>>(
    sequence: impl IntoIterator<Item = T>,
) -> Vec<Arc<dyn PhysicalExpr>> {
    sequence
        .into_iter()
        .map(|elem| Arc::clone(&elem.borrow().expr))
        .collect()
}

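/// Returns the index inside `items` of each target expression that occurs
/// there (compared with expression equality). Targets not present in `items`
/// are skipped, so the result may be shorter than `targets`.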
pub fn get_indices_of_exprs_strict<T: Borrow<Arc<dyn PhysicalExpr>>>(
    targets: impl IntoIterator<Item = T>,
    items: &[Arc<dyn PhysicalExpr>],
) -> Vec<usize> {
    targets
        .into_iter()
        .filter_map(|target| items.iter().position(|e| e.eq(target.borrow())))
        .collect()
}

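/// A node of the physical expression tree annotated with an optional payload,
/// used below to carry the graph index assigned to each sub-expression.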
pub type ExprTreeNode<T> = ExprContext<Option<T>>;

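/// Builds a directed acyclic expression graph (DAEG) from a physical
/// expression tree, reusing a single graph node for syntactically equal
/// sub-expressions.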
struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> {
    // The resulting expression graph.
    graph: StableGraph<T, usize>,
    // Expressions already added to the graph, with their node indices.
    visited_plans: Vec<(Arc<dyn PhysicalExpr>, NodeIndex)>,
    // Converts an expression tree node into the graph payload type `T`.
    constructor: &'a F,
}

impl<T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> PhysicalExprDAEGBuilder<'_, T, F> {
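    /// Adds the expression in `node` to the graph (reusing the existing graph
    /// node if an equal expression has been seen before), connects it to its
    /// children, and records the resulting `NodeIndex` in `node.data`.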
    fn mutate(
        &mut self,
        mut node: ExprTreeNode<NodeIndex>,
    ) -> Result<Transformed<ExprTreeNode<NodeIndex>>> {
        let expr = &node.expr;

        // Check whether an equal expression was already added to the graph.
        let node_idx = match self.visited_plans.iter().find(|(e, _)| expr.eq(e)) {
            // If so, reuse the existing node index:
            Some((_, idx)) => *idx,
            // Otherwise, add a new node, wire it to its children, and remember it:
            None => {
                let node_idx = self.graph.add_node((self.constructor)(&node)?);
                for expr_node in node.children.iter() {
                    self.graph.add_edge(node_idx, expr_node.data.unwrap(), 0);
                }
                self.visited_plans.push((Arc::clone(expr), node_idx));
                node_idx
            }
        };
        // Store the node index in the node's data field and return the node.
        node.data = Some(node_idx);
        Ok(Transformed::yes(node))
    }
}

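/// Builds a directed acyclic expression graph (DAEG) from the expression tree
/// rooted at `expr`, using `constructor` to turn each tree node into the graph
/// payload type `T`. Returns the index of the root node together with the
/// graph itself.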
pub fn build_dag<T, F>(
    expr: Arc<dyn PhysicalExpr>,
    constructor: &F,
) -> Result<(NodeIndex, StableGraph<T, usize>)>
where
    F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>,
{
    // Create an expression tree node wrapping the root expression:
    let init = ExprTreeNode::new_default(expr);
    // Create a DAEG builder with an empty graph and visited-expression list:
    let mut builder = PhysicalExprDAEGBuilder {
        graph: StableGraph::<T, usize>::new(),
        visited_plans: Vec::<(Arc<dyn PhysicalExpr>, NodeIndex)>::new(),
        constructor,
    };
    // Transform the expression tree bottom-up, populating the graph:
    let root = init.transform_up(|node| builder.mutate(node)).data()?;
    // Return the root node index and the graph:
    Ok((root.data.unwrap(), builder.graph))
}

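/// Recursively extracts all [`Column`]s referenced by the given expression.
///
/// A minimal sketch of intended usage (not compiled here; `a_plus_b` stands
/// for a physical expression such as `a@0 + b@1`):
///
/// ```ignore
/// let cols = collect_columns(&a_plus_b);
/// assert!(cols.contains(&Column::new("a", 0)));
/// ```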
pub fn collect_columns(expr: &Arc<dyn PhysicalExpr>) -> HashSet<Column> {
    let mut columns = HashSet::<Column>::new();
    expr.apply(|expr| {
        if let Some(column) = expr.as_any().downcast_ref::<Column>() {
            columns.get_or_insert_owned(column);
        }
        Ok(TreeNodeRecursion::Continue)
    })
    .expect("no way to return error during recursion");
    columns
}

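/// Rewrites every [`Column`] inside `expr` so that its index matches the
/// position of the field with the same name in `schema`, returning an error
/// if a referenced name is missing from the schema. Useful when an expression
/// built against one schema must be evaluated against another schema that
/// contains the same columns at different positions.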
pub fn reassign_expr_columns(
    expr: Arc<dyn PhysicalExpr>,
    schema: &Schema,
) -> Result<Arc<dyn PhysicalExpr>> {
    expr.transform_down(|expr| {
        if let Some(column) = expr.as_any().downcast_ref::<Column>() {
            let index = schema.index_of(column.name())?;

            return Ok(Transformed::yes(Arc::new(Column::new(
                column.name(),
                index,
            ))));
        }
        Ok(Transformed::no(expr))
    })
    .data()
}

#[cfg(test)]
pub(crate) mod tests {
    use std::any::Any;
    use std::fmt::{Display, Formatter};

    use super::*;
    use crate::expressions::{binary, cast, col, in_list, lit, Literal};

    use arrow::array::{ArrayRef, Float32Array, Float64Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::{exec_err, internal_datafusion_err, ScalarValue};
    use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
    use datafusion_expr::{
        ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
    };

    use petgraph::visit::Bfs;

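    /// A scalar UDF used in tests: it floors its `Float32`/`Float64` argument
    /// and reports that it preserves the sort properties of its input.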
    #[derive(Debug, PartialEq, Eq, Hash)]
    pub struct TestScalarUDF {
        pub(crate) signature: Signature,
    }

    impl TestScalarUDF {
        pub fn new() -> Self {
            use DataType::*;
            Self {
                signature: Signature::uniform(
                    1,
                    vec![Float64, Float32],
                    Volatility::Immutable,
                ),
            }
        }
    }

    impl ScalarUDFImpl for TestScalarUDF {
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn name(&self) -> &str {
            "test-scalar-udf"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
            let arg_type = &arg_types[0];

            match arg_type {
                DataType::Float32 => Ok(DataType::Float32),
                _ => Ok(DataType::Float64),
            }
        }

        fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
            Ok(input[0].sort_properties)
        }

        fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            let args = ColumnarValue::values_to_arrays(&args.args)?;

            let arr: ArrayRef = match args[0].data_type() {
                DataType::Float64 => Arc::new({
                    let arg = &args[0]
                        .as_any()
                        .downcast_ref::<Float64Array>()
                        .ok_or_else(|| {
                            internal_datafusion_err!(
                                "could not cast {} to {}",
                                self.name(),
                                std::any::type_name::<Float64Array>()
                            )
                        })?;

                    arg.iter()
                        .map(|a| a.map(f64::floor))
                        .collect::<Float64Array>()
                }),
                DataType::Float32 => Arc::new({
                    let arg = &args[0]
                        .as_any()
                        .downcast_ref::<Float32Array>()
                        .ok_or_else(|| {
                            internal_datafusion_err!(
                                "could not cast {} to {}",
                                self.name(),
                                std::any::type_name::<Float32Array>()
                            )
                        })?;

                    arg.iter()
                        .map(|a| a.map(f32::floor))
                        .collect::<Float32Array>()
                }),
                other => {
                    return exec_err!(
                        "Unsupported data type {other:?} for function {}",
                        self.name()
                    );
                }
            };
            Ok(ColumnarValue::Array(arr))
        }
    }

    #[derive(Clone)]
    struct DummyProperty {
        expr_type: String,
    }

    /// A dummy DAEG node holding an expression together with a property that
    /// records what kind of expression it is.
    #[derive(Clone)]
    struct PhysicalExprDummyNode {
        pub expr: Arc<dyn PhysicalExpr>,
        pub property: DummyProperty,
    }

    impl Display for PhysicalExprDummyNode {
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.expr)
        }
    }

    fn make_dummy_node(node: &ExprTreeNode<NodeIndex>) -> Result<PhysicalExprDummyNode> {
        let expr = Arc::clone(&node.expr);
        let dummy_property = if expr.as_any().is::<BinaryExpr>() {
            "Binary"
        } else if expr.as_any().is::<Column>() {
            "Column"
        } else if expr.as_any().is::<Literal>() {
            "Literal"
        } else {
            "Other"
        }
        .to_owned();
        Ok(PhysicalExprDummyNode {
            expr,
            property: DummyProperty {
                expr_type: dummy_property,
            },
        })
    }

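    // Builds the DAG for `CAST(0@0 + 1@1 AS Int64) > CAST(2@2 AS Int64) + 10`
    // and checks how many nodes of each kind are reachable from the root:
    // 3 binary expressions, 3 columns, 1 literal and 2 "other" nodes (the casts).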
    #[test]
    fn test_build_dag() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new("0", DataType::Int32, true),
            Field::new("1", DataType::Int32, true),
            Field::new("2", DataType::Int32, true),
        ]);
        let expr = binary(
            cast(
                binary(
                    col("0", &schema)?,
                    Operator::Plus,
                    col("1", &schema)?,
                    &schema,
                )?,
                &schema,
                DataType::Int64,
            )?,
            Operator::Gt,
            binary(
                cast(col("2", &schema)?, &schema, DataType::Int64)?,
                Operator::Plus,
                lit(ScalarValue::Int64(Some(10))),
                &schema,
            )?,
            &schema,
        )?;
        let mut vector_dummy_props = vec![];
        let (root, graph) = build_dag(expr, &make_dummy_node)?;
        let mut bfs = Bfs::new(&graph, root);
        while let Some(node_index) = bfs.next(&graph) {
            let node = &graph[node_index];
            vector_dummy_props.push(node.property.clone());
        }

        assert_eq!(
            vector_dummy_props
                .iter()
                .filter(|property| property.expr_type == "Binary")
                .count(),
            3
        );
        assert_eq!(
            vector_dummy_props
                .iter()
                .filter(|property| property.expr_type == "Column")
                .count(),
            3
        );
        assert_eq!(
            vector_dummy_props
                .iter()
                .filter(|property| property.expr_type == "Literal")
                .count(),
            1
        );
        assert_eq!(
            vector_dummy_props
                .iter()
                .filter(|property| property.expr_type == "Other")
                .count(),
            2
        );
        Ok(())
    }

    #[test]
    fn test_convert_to_expr() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::UInt64, false)]);
        let sort_expr = vec![PhysicalSortExpr {
            expr: col("a", &schema)?,
            options: Default::default(),
        }];
        assert!(convert_to_expr(&sort_expr)[0].eq(&sort_expr[0].expr));
        Ok(())
    }

    #[test]
    fn test_get_indices_of_exprs_strict() {
        let list1: Vec<Arc<dyn PhysicalExpr>> = vec![
            Arc::new(Column::new("a", 0)),
            Arc::new(Column::new("b", 1)),
            Arc::new(Column::new("c", 2)),
            Arc::new(Column::new("d", 3)),
        ];
        let list2: Vec<Arc<dyn PhysicalExpr>> = vec![
            Arc::new(Column::new("b", 1)),
            Arc::new(Column::new("c", 2)),
            Arc::new(Column::new("a", 0)),
        ];
        assert_eq!(get_indices_of_exprs_strict(&list1, &list2), vec![2, 0, 1]);
        assert_eq!(get_indices_of_exprs_strict(&list2, &list1), vec![1, 2, 0]);
    }

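    // `id` sits at index 1 in `schema_big` but at index 0 in `schema_small`;
    // the rewritten predicate must reference the new index.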
    #[test]
    fn test_reassign_expr_columns_in_list() {
        let int_field = Field::new("should_not_matter", DataType::Int64, true);
        let dict_field = Field::new(
            "id",
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
            true,
        );
        let schema_small = Arc::new(Schema::new(vec![dict_field.clone()]));
        let schema_big = Arc::new(Schema::new(vec![int_field, dict_field]));
        let pred = in_list(
            Arc::new(Column::new_with_schema("id", &schema_big).unwrap()),
            vec![lit(ScalarValue::Dictionary(
                Box::new(DataType::Int32),
                Box::new(ScalarValue::from("2")),
            ))],
            &false,
            &schema_big,
        )
        .unwrap();

        let actual = reassign_expr_columns(pred, &schema_small).unwrap();

        let expected = in_list(
            Arc::new(Column::new_with_schema("id", &schema_small).unwrap()),
            vec![lit(ScalarValue::Dictionary(
                Box::new(DataType::Int32),
                Box::new(ScalarValue::from("2")),
            ))],
            &false,
            &schema_small,
        )
        .unwrap();

        assert_eq!(actual.as_ref(), expected.as_ref());
    }

    #[test]
    fn test_collect_columns() -> Result<()> {
        let expr1 = Arc::new(Column::new("col1", 2)) as _;
        let mut expected = HashSet::new();
        expected.insert(Column::new("col1", 2));
        assert_eq!(collect_columns(&expr1), expected);

        let expr2 = Arc::new(Column::new("col2", 5)) as _;
        let mut expected = HashSet::new();
        expected.insert(Column::new("col2", 5));
        assert_eq!(collect_columns(&expr2), expected);

        let expr3 = Arc::new(BinaryExpr::new(expr1, Operator::Plus, expr2)) as _;
        let mut expected = HashSet::new();
        expected.insert(Column::new("col1", 2));
        expected.insert(Column::new("col2", 5));
        assert_eq!(collect_columns(&expr3), expected);
        Ok(())
    }
}