datafusion_comet_spark_expr/string_funcs/
prediction.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18#![allow(deprecated)]
19
20use arrow::{
21    compute::{
22        contains_dyn, contains_utf8_scalar_dyn, ends_with_dyn, ends_with_utf8_scalar_dyn, like_dyn,
23        like_utf8_scalar_dyn, starts_with_dyn, starts_with_utf8_scalar_dyn,
24    },
25    record_batch::RecordBatch,
26};
27use arrow_schema::{DataType, Schema};
28use datafusion::logical_expr::ColumnarValue;
29use datafusion_common::{DataFusionError, ScalarValue::Utf8};
30use datafusion_physical_expr::PhysicalExpr;
31use std::{
32    any::Any,
33    fmt::{Display, Formatter},
34    hash::Hash,
35    sync::Arc,
36};
37
/// Generates a binary string-predicate physical expression named `$name`.
///
/// The generated struct holds a `left` and a `right` child expression and
/// implements `Display`, `Hash`, `PartialEq` and `PhysicalExpr`.
///
/// * `$name` - identifier of the struct to generate.
/// * `$kernel` - function invoked as `$kernel(&left_array, &right_array)` when
///   both children evaluate to arrays.
/// * `$str_scalar_kernel` - function invoked as
///   `$str_scalar_kernel(&left_array, pattern)` when the right child evaluates
///   to a non-null `Utf8` scalar.
macro_rules! make_predicate_function {
    ($name: ident, $kernel: ident, $str_scalar_kernel: ident) => {
        #[derive(Debug, Eq)]
        pub struct $name {
            left: Arc<dyn PhysicalExpr>,
            right: Arc<dyn PhysicalExpr>,
        }

        impl $name {
            /// Creates the predicate over the given `left` and `right` operands.
            pub fn new(left: Arc<dyn PhysicalExpr>, right: Arc<dyn PhysicalExpr>) -> Self {
                Self { left, right }
            }
        }

        impl Display for $name {
            fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
                // Metavariables are not substituted inside string literals, so
                // the type name has to be spliced in through `stringify!`.
                write!(
                    f,
                    "{} [left: {}, right: {}]",
                    stringify!($name),
                    self.left,
                    self.right
                )
            }
        }

        impl Hash for $name {
            fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
                self.left.hash(state);
                self.right.hash(state);
            }
        }

        impl PartialEq for $name {
            fn eq(&self, other: &Self) -> bool {
                self.left.eq(&other.left) && self.right.eq(&other.right)
            }
        }

        impl PhysicalExpr for $name {
            fn as_any(&self) -> &dyn Any {
                self
            }

            /// The predicate always produces a boolean result.
            fn data_type(&self, _: &Schema) -> datafusion_common::Result<DataType> {
                Ok(DataType::Boolean)
            }

            /// The result is always reported as nullable.
            fn nullable(&self, _: &Schema) -> datafusion_common::Result<bool> {
                Ok(true)
            }

            /// Evaluates both children against `batch` and applies the kernel
            /// matching the shape of the operands.
            ///
            /// # Errors
            ///
            /// Returns `DataFusionError::Execution` when the left operand is an
            /// array and the right operand is a scalar other than a non-null
            /// `Utf8`, or when the left operand is a scalar. Errors from the
            /// children and from the kernels are propagated.
            fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result<ColumnarValue> {
                let left_arg = self.left.evaluate(batch)?;
                let right_arg = self.right.evaluate(batch)?;

                let array = match (left_arg, right_arg) {
                    // array (op) scalar
                    (ColumnarValue::Array(array), ColumnarValue::Scalar(Utf8(Some(string)))) => {
                        $str_scalar_kernel(&array, string.as_str())
                    }
                    // array (op) any other scalar, including a null `Utf8`
                    (ColumnarValue::Array(_), ColumnarValue::Scalar(other)) => {
                        return Err(DataFusionError::Execution(format!(
                            "Should be String but got: {:?}",
                            other
                        )))
                    }
                    // array (op) array
                    (ColumnarValue::Array(array1), ColumnarValue::Array(array2)) => {
                        $kernel(&array1, &array2)
                    }
                    // scalar (op) scalar should be folded at Spark optimizer.
                    // Note that scalar (op) array also lands in this arm.
                    _ => {
                        return Err(DataFusionError::Execution(
                            "Predicate on two literals should be folded at Spark".to_string(),
                        ))
                    }
                }?;

                Ok(ColumnarValue::Array(Arc::new(array)))
            }

            fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
                vec![&self.left, &self.right]
            }

            /// Rebuilds the expression from `children`, given as `[left, right]`.
            ///
            /// # Errors
            ///
            /// Returns `DataFusionError::Internal` unless exactly two children
            /// are supplied.
            fn with_new_children(
                self: Arc<Self>,
                children: Vec<Arc<dyn PhysicalExpr>>,
            ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
                match children.as_slice() {
                    [left, right] => Ok(Arc::new($name::new(
                        Arc::clone(left),
                        Arc::clone(right),
                    ))),
                    other => Err(DataFusionError::Internal(format!(
                        "{} expects exactly 2 children but got {}",
                        stringify!($name),
                        other.len()
                    ))),
                }
            }
        }
    };
}
130
131make_predicate_function!(Like, like_dyn, like_utf8_scalar_dyn);
132
133make_predicate_function!(StartsWith, starts_with_dyn, starts_with_utf8_scalar_dyn);
134
135make_predicate_function!(EndsWith, ends_with_dyn, ends_with_utf8_scalar_dyn);
136
137make_predicate_function!(Contains, contains_dyn, contains_utf8_scalar_dyn);