datafusion_comet_spark_expr/string_funcs/
prediction.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18#![allow(deprecated)]
19
20use arrow::datatypes::{DataType, Schema};
21use arrow::{
22    compute::{
23        contains_dyn, contains_utf8_scalar_dyn, ends_with_dyn, ends_with_utf8_scalar_dyn, like_dyn,
24        like_utf8_scalar_dyn, starts_with_dyn, starts_with_utf8_scalar_dyn,
25    },
26    record_batch::RecordBatch,
27};
28use datafusion::common::{DataFusionError, ScalarValue::Utf8};
29use datafusion::logical_expr::ColumnarValue;
30use datafusion::physical_expr::PhysicalExpr;
31use std::{
32    any::Any,
33    fmt::{Display, Formatter},
34    hash::Hash,
35    sync::Arc,
36};
37
/// Generates a binary string-predicate physical expression named `$name`.
///
/// The generated struct holds two child expressions (`left`, `right`) and,
/// when evaluated, dispatches to:
/// * `$str_scalar_kernel` when `left` is an array and `right` is a non-null
///   `Utf8` scalar;
/// * `$kernel` when both operands are arrays (a scalar `left` is first
///   broadcast to the length of the `right` array).
///
/// The generated expression always reports a nullable `Boolean` result type.
macro_rules! make_predicate_function {
    ($name: ident, $kernel: ident, $str_scalar_kernel: ident) => {
        #[doc = concat!(
            "String predicate `",
            stringify!($name),
            "(left, right)` that evaluates to a Boolean array."
        )]
        #[derive(Debug, Eq)]
        pub struct $name {
            left: Arc<dyn PhysicalExpr>,
            right: Arc<dyn PhysicalExpr>,
        }

        impl $name {
            /// Creates the predicate from its left (input string) and right
            /// (pattern / operand) child expressions.
            pub fn new(left: Arc<dyn PhysicalExpr>, right: Arc<dyn PhysicalExpr>) -> Self {
                Self { left, right }
            }
        }

        impl Display for $name {
            fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
                // `macro_rules!` does not substitute `$name` inside a string
                // literal, so the type name must be spliced in via `stringify!`.
                write!(
                    f,
                    "{} [left: {}, right: {}]",
                    stringify!($name),
                    self.left,
                    self.right
                )
            }
        }

        impl Hash for $name {
            fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
                self.left.hash(state);
                self.right.hash(state);
            }
        }

        impl PartialEq for $name {
            fn eq(&self, other: &Self) -> bool {
                self.left.eq(&other.left) && self.right.eq(&other.right)
            }
        }

        impl PhysicalExpr for $name {
            fn as_any(&self) -> &dyn Any {
                self
            }

            /// Renders the expression as `Name(left, right)`, delegating the
            /// operands to their own SQL formatting instead of panicking.
            fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
                write!(f, "{}(", stringify!($name))?;
                self.left.fmt_sql(f)?;
                write!(f, ", ")?;
                self.right.fmt_sql(f)?;
                write!(f, ")")
            }

            fn data_type(&self, _: &Schema) -> datafusion::common::Result<DataType> {
                Ok(DataType::Boolean)
            }

            fn nullable(&self, _: &Schema) -> datafusion::common::Result<bool> {
                Ok(true)
            }

            /// Evaluates both children and applies the matching arrow kernel.
            ///
            /// # Errors
            /// Returns an execution error when the right scalar is not `Utf8`,
            /// when both operands are scalars (Spark is expected to fold that
            /// case), or when the underlying arrow kernel fails.
            fn evaluate(&self, batch: &RecordBatch) -> datafusion::common::Result<ColumnarValue> {
                let left_arg = self.left.evaluate(batch)?;
                let right_arg = self.right.evaluate(batch)?;

                let array = match (left_arg, right_arg) {
                    // array (op) string scalar
                    (ColumnarValue::Array(array), ColumnarValue::Scalar(Utf8(Some(string)))) => {
                        $str_scalar_kernel(&array, string.as_str())
                    }
                    // array (op) NULL string: the predicate is NULL for every row
                    (ColumnarValue::Array(array), ColumnarValue::Scalar(Utf8(None))) => {
                        Ok(arrow::array::BooleanArray::new_null(array.len()))
                    }
                    (ColumnarValue::Array(_), ColumnarValue::Scalar(other)) => {
                        return Err(DataFusionError::Execution(format!(
                            "Should be String but got: {:?}",
                            other
                        )))
                    }
                    // array (op) array
                    (ColumnarValue::Array(array1), ColumnarValue::Array(array2)) => {
                        $kernel(&array1, &array2)
                    }
                    // scalar (op) array: broadcast the scalar, then use the array kernel
                    (ColumnarValue::Scalar(scalar), ColumnarValue::Array(array)) => {
                        let expanded = scalar.to_array_of_size(array.len())?;
                        $kernel(&expanded, &array)
                    }
                    // scalar (op) scalar should be folded at Spark optimizer
                    _ => {
                        return Err(DataFusionError::Execution(
                            "Predicate on two literals should be folded at Spark".to_string(),
                        ))
                    }
                }?;

                Ok(ColumnarValue::Array(Arc::new(array)))
            }

            fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
                vec![&self.left, &self.right]
            }

            /// Rebuilds the predicate from exactly two new children.
            ///
            /// # Errors
            /// Returns an internal error if `children` does not contain exactly
            /// two expressions, rather than panicking on out-of-range indexing.
            fn with_new_children(
                self: Arc<Self>,
                children: Vec<Arc<dyn PhysicalExpr>>,
            ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
                match children.as_slice() {
                    [left, right] => Ok(Arc::new($name::new(
                        Arc::clone(left),
                        Arc::clone(right),
                    ))),
                    other => Err(DataFusionError::Internal(format!(
                        "{} expects exactly 2 children but got {}",
                        stringify!($name),
                        other.len()
                    ))),
                }
            }
        }
    };
}
134
135make_predicate_function!(Like, like_dyn, like_utf8_scalar_dyn);
136
137make_predicate_function!(StartsWith, starts_with_dyn, starts_with_utf8_scalar_dyn);
138
139make_predicate_function!(EndsWith, ends_with_dyn, ends_with_utf8_scalar_dyn);
140
141make_predicate_function!(Contains, contains_dyn, contains_utf8_scalar_dyn);