datafusion_comet_spark_expr/string_funcs/
prediction.rs1#![allow(deprecated)]
19
20use arrow::{
21 compute::{
22 contains_dyn, contains_utf8_scalar_dyn, ends_with_dyn, ends_with_utf8_scalar_dyn, like_dyn,
23 like_utf8_scalar_dyn, starts_with_dyn, starts_with_utf8_scalar_dyn,
24 },
25 record_batch::RecordBatch,
26};
27use arrow_schema::{DataType, Schema};
28use datafusion::logical_expr::ColumnarValue;
29use datafusion_common::{DataFusionError, ScalarValue::Utf8};
30use datafusion_physical_expr::PhysicalExpr;
31use std::{
32 any::Any,
33 fmt::{Display, Formatter},
34 hash::Hash,
35 sync::Arc,
36};
37
/// Generates a binary, boolean-valued predicate physical expression.
///
/// `make_predicate_function!(Name, array_kernel, scalar_kernel)` expands to a
/// `pub struct Name` holding a `left` and a `right` child expression, together
/// with `Display`, `Hash`, `PartialEq` and [`PhysicalExpr`] implementations.
///
/// At evaluation time:
/// * array `left` + non-null `Utf8` scalar `right` -> `$str_scalar_kernel(&left, &str)`;
/// * array `left` + array `right` -> `$kernel(&left, &right)`;
/// * array `left` + any other scalar (including a NULL `Utf8`) -> execution error;
/// * scalar `left` (with either kind of `right`) -> execution error.
///
/// Successful results are always returned as `ColumnarValue::Array`.
macro_rules! make_predicate_function {
    ($name: ident, $kernel: ident, $str_scalar_kernel: ident) => {
        #[doc = concat!(
            "Boolean predicate `",
            stringify!($name),
            "`, evaluated with `",
            stringify!($kernel),
            "` (array/array) or `",
            stringify!($str_scalar_kernel),
            "` (array/string scalar)."
        )]
        #[derive(Debug, Eq)]
        pub struct $name {
            left: Arc<dyn PhysicalExpr>,
            right: Arc<dyn PhysicalExpr>,
        }

        impl $name {
            /// Creates the predicate from its left (input) and right (pattern) children.
            pub fn new(left: Arc<dyn PhysicalExpr>, right: Arc<dyn PhysicalExpr>) -> Self {
                Self { left, right }
            }
        }

        impl Display for $name {
            fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
                // `$name` is not substituted inside a string literal by `macro_rules!`,
                // so the type name must be spliced in via `stringify!`.
                write!(
                    f,
                    "{} [left: {}, right: {}]",
                    stringify!($name),
                    self.left,
                    self.right
                )
            }
        }

        impl Hash for $name {
            /// Hashes both children, matching the fields compared by `PartialEq`.
            fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
                self.left.hash(state);
                self.right.hash(state);
            }
        }

        impl PartialEq for $name {
            /// Two predicates are equal when both their children are equal.
            fn eq(&self, other: &Self) -> bool {
                self.left.eq(&other.left) && self.right.eq(&other.right)
            }
        }

        impl PhysicalExpr for $name {
            fn as_any(&self) -> &dyn Any {
                self
            }

            /// The predicate always produces booleans.
            fn data_type(&self, _: &Schema) -> datafusion_common::Result<DataType> {
                Ok(DataType::Boolean)
            }

            /// Reported as nullable unconditionally, regardless of the children.
            fn nullable(&self, _: &Schema) -> datafusion_common::Result<bool> {
                Ok(true)
            }

            /// Evaluates both children and dispatches to the array/array or
            /// array/string-scalar kernel.
            ///
            /// # Errors
            /// Returns `DataFusionError::Execution` when the right side is a scalar
            /// other than a non-null `Utf8`, or when the left side is a scalar.
            /// Kernel errors are propagated via `?`.
            fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result<ColumnarValue> {
                let left_arg = self.left.evaluate(batch)?;
                let right_arg = self.right.evaluate(batch)?;

                let array = match (left_arg, right_arg) {
                    // Column vs. non-null string literal: specialised scalar kernel.
                    (ColumnarValue::Array(array), ColumnarValue::Scalar(Utf8(Some(string)))) => {
                        $str_scalar_kernel(&array, string.as_str())
                    }
                    // Column vs. any other scalar (including a NULL `Utf8`) is rejected.
                    (ColumnarValue::Array(_), ColumnarValue::Scalar(other)) => {
                        return Err(DataFusionError::Execution(format!(
                            "Should be String but got: {:?}",
                            other
                        )))
                    }
                    // Column vs. column: element-wise kernel.
                    (ColumnarValue::Array(array1), ColumnarValue::Array(array2)) => {
                        $kernel(&array1, &array2)
                    }
                    // Any scalar on the left. The message reflects the expectation that
                    // literal-only predicates are constant-folded before reaching here.
                    _ => {
                        return Err(DataFusionError::Execution(
                            "Predicate on two literals should be folded at Spark".to_string(),
                        ))
                    }
                }?;

                Ok(ColumnarValue::Array(Arc::new(array)))
            }

            fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
                vec![&self.left, &self.right]
            }

            /// Rebuilds the predicate from exactly two new children (left, right).
            ///
            /// # Errors
            /// Returns `DataFusionError::Internal` if `children` does not contain
            /// exactly two expressions, instead of panicking on an out-of-bounds index.
            fn with_new_children(
                self: Arc<Self>,
                children: Vec<Arc<dyn PhysicalExpr>>,
            ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
                if children.len() != 2 {
                    return Err(DataFusionError::Internal(format!(
                        "{} expects exactly 2 children but got {}",
                        stringify!($name),
                        children.len()
                    )));
                }
                Ok(Arc::new($name::new(
                    Arc::clone(&children[0]),
                    Arc::clone(&children[1]),
                )))
            }
        }
    };
}
130
131make_predicate_function!(Like, like_dyn, like_utf8_scalar_dyn);
132
133make_predicate_function!(StartsWith, starts_with_dyn, starts_with_utf8_scalar_dyn);
134
135make_predicate_function!(EndsWith, ends_with_dyn, ends_with_utf8_scalar_dyn);
136
137make_predicate_function!(Contains, contains_dyn, contains_utf8_scalar_dyn);