datafusion_comet_spark_expr/string_funcs/
prediction.rs1#![allow(deprecated)]
19
20use arrow::datatypes::{DataType, Schema};
21use arrow::{
22 compute::{
23 contains_dyn, contains_utf8_scalar_dyn, ends_with_dyn, ends_with_utf8_scalar_dyn, like_dyn,
24 like_utf8_scalar_dyn, starts_with_dyn, starts_with_utf8_scalar_dyn,
25 },
26 record_batch::RecordBatch,
27};
28use datafusion::common::{DataFusionError, ScalarValue::Utf8};
29use datafusion::logical_expr::ColumnarValue;
30use datafusion::physical_expr::PhysicalExpr;
31use std::{
32 any::Any,
33 fmt::{Display, Formatter},
34 hash::Hash,
35 sync::Arc,
36};
37
/// Generates a binary string-predicate physical expression.
///
/// Arguments:
///
/// * `$name` - identifier of the generated struct.
/// * `$kernel` - function called as `$kernel(&left_array, &right_array)` when both
///   children evaluate to arrays.
/// * `$str_scalar_kernel` - function called as `$str_scalar_kernel(&left_array, &str)`
///   when the left child evaluates to an array and the right child evaluates to a
///   non-null `Utf8` scalar.
///
/// The generated type holds a `left` and a `right` child expression, reports a
/// `Boolean` data type, and always reports itself as nullable.
macro_rules! make_predicate_function {
    ($name: ident, $kernel: ident, $str_scalar_kernel: ident) => {
        /// Binary predicate expression over a `left` and a `right` child expression.
        #[derive(Debug, Eq)]
        pub struct $name {
            left: Arc<dyn PhysicalExpr>,
            right: Arc<dyn PhysicalExpr>,
        }

        impl $name {
            /// Creates the predicate from its left and right operand expressions.
            pub fn new(left: Arc<dyn PhysicalExpr>, right: Arc<dyn PhysicalExpr>) -> Self {
                Self { left, right }
            }
        }

        impl Display for $name {
            fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
                // Metavariables are not substituted inside string literals, so the type
                // name has to be passed as a format argument via `stringify!`.
                write!(
                    f,
                    "{} [left: {}, right: {}]",
                    stringify!($name),
                    self.left,
                    self.right
                )
            }
        }

        impl Hash for $name {
            /// Hashes both operands, left first.
            fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
                self.left.hash(state);
                self.right.hash(state);
            }
        }

        impl PartialEq for $name {
            /// Two predicates are equal when both of their operands are equal.
            fn eq(&self, other: &Self) -> bool {
                self.left.eq(&other.left) && self.right.eq(&other.right)
            }
        }

        impl PhysicalExpr for $name {
            fn as_any(&self) -> &dyn Any {
                self
            }

            /// SQL rendering is not implemented for this expression.
            ///
            /// # Panics
            ///
            /// Always panics via `unimplemented!()`.
            fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
                unimplemented!()
            }

            /// The predicate always produces `DataType::Boolean`, regardless of schema.
            fn data_type(&self, _: &Schema) -> datafusion::common::Result<DataType> {
                Ok(DataType::Boolean)
            }

            /// The result is always reported as nullable, regardless of schema.
            fn nullable(&self, _: &Schema) -> datafusion::common::Result<bool> {
                Ok(true)
            }

            /// Evaluates both operands against `batch` and applies the matching kernel.
            ///
            /// # Errors
            ///
            /// Returns an execution error when the left operand is an array and the
            /// right operand is a scalar other than a non-null `Utf8`, or when the left
            /// operand is a scalar. Errors from evaluating the operands or from the
            /// kernels themselves are propagated.
            fn evaluate(&self, batch: &RecordBatch) -> datafusion::common::Result<ColumnarValue> {
                let left_arg = self.left.evaluate(batch)?;
                let right_arg = self.right.evaluate(batch)?;

                let array = match (left_arg, right_arg) {
                    // Array compared against a string literal.
                    (ColumnarValue::Array(array), ColumnarValue::Scalar(Utf8(Some(string)))) => {
                        $str_scalar_kernel(&array, string.as_str())
                    }
                    // Any other scalar on the right (including a null `Utf8`) is rejected.
                    (ColumnarValue::Array(_), ColumnarValue::Scalar(other)) => {
                        return Err(DataFusionError::Execution(format!(
                            "Should be String but got: {:?}",
                            other
                        )))
                    }
                    // Array compared element-wise against another array.
                    (ColumnarValue::Array(array1), ColumnarValue::Array(array2)) => {
                        $kernel(&array1, &array2)
                    }
                    // Remaining cases all have a scalar on the left.
                    _ => {
                        return Err(DataFusionError::Execution(
                            "Predicate on two literals should be folded at Spark".to_string(),
                        ))
                    }
                }?;

                Ok(ColumnarValue::Array(Arc::new(array)))
            }

            /// Returns the operands in `[left, right]` order.
            fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
                vec![&self.left, &self.right]
            }

            /// Rebuilds the predicate from `children[0]` (left) and `children[1]` (right).
            ///
            /// # Errors
            ///
            /// Returns an internal error instead of panicking when fewer than two
            /// children are supplied.
            fn with_new_children(
                self: Arc<Self>,
                children: Vec<Arc<dyn PhysicalExpr>>,
            ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
                let [left, right, ..] = children.as_slice() else {
                    return Err(DataFusionError::Internal(format!(
                        "{} expects 2 children but got {}",
                        stringify!($name),
                        children.len()
                    )));
                };
                Ok(Arc::new($name::new(Arc::clone(left), Arc::clone(right))))
            }
        }
    };
}
134
// `Like`: uses `like_dyn` for array/array operands and `like_utf8_scalar_dyn` when the
// right operand is a string literal.
make_predicate_function!(Like, like_dyn, like_utf8_scalar_dyn);

// `StartsWith`: uses `starts_with_dyn` for array/array operands and
// `starts_with_utf8_scalar_dyn` when the right operand is a string literal.
make_predicate_function!(StartsWith, starts_with_dyn, starts_with_utf8_scalar_dyn);

// `EndsWith`: uses `ends_with_dyn` for array/array operands and
// `ends_with_utf8_scalar_dyn` when the right operand is a string literal.
make_predicate_function!(EndsWith, ends_with_dyn, ends_with_utf8_scalar_dyn);

// `Contains`: uses `contains_dyn` for array/array operands and
// `contains_utf8_scalar_dyn` when the right operand is a string literal.
make_predicate_function!(Contains, contains_dyn, contains_utf8_scalar_dyn);