datafusion_expr/
udf_eq.rs1use crate::{AggregateUDFImpl, ScalarUDFImpl, WindowUDFImpl};
19use std::fmt::Debug;
20use std::hash::{DefaultHasher, Hash, Hasher};
21use std::ops::Deref;
22use std::sync::Arc;
23
24#[derive(Clone)]
29#[allow(private_bounds)] pub struct UdfEq<Ptr: UdfPointer>(Ptr);
31
32impl<Ptr> PartialEq for UdfEq<Ptr>
33where
34 Ptr: UdfPointer,
35{
36 fn eq(&self, other: &Self) -> bool {
37 self.0.equals(&other.0)
38 }
39}
40impl<Ptr> Eq for UdfEq<Ptr> where Ptr: UdfPointer {}
41impl<Ptr> Hash for UdfEq<Ptr>
42where
43 Ptr: UdfPointer,
44{
45 fn hash<H: Hasher>(&self, state: &mut H) {
46 self.0.hash_value().hash(state);
47 }
48}
49
50impl<Ptr> From<Ptr> for UdfEq<Ptr>
51where
52 Ptr: UdfPointer,
53{
54 fn from(ptr: Ptr) -> Self {
55 UdfEq(ptr)
56 }
57}
58
59impl<Ptr> Debug for UdfEq<Ptr>
60where
61 Ptr: UdfPointer + Debug,
62{
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 self.0.fmt(f)
65 }
66}
67
68impl<Ptr> Deref for UdfEq<Ptr>
69where
70 Ptr: UdfPointer,
71{
72 type Target = Ptr;
73
74 fn deref(&self) -> &Self::Target {
75 &self.0
76 }
77}
78
/// Internal abstraction over the pointer types that [`UdfEq`] may wrap.
///
/// An implementor supplies equality against, and a 64-bit hash of, the value it
/// dereferences to. The trait is private, so only this module (and its children)
/// can add implementations.
trait UdfPointer: Deref {
    /// Returns `true` when the pointed-to value equals `other`.
    fn equals(&self, other: &<Self as Deref>::Target) -> bool;

    /// Returns a 64-bit hash of the pointed-to value. Values that `equals`
    /// considers equal are expected to produce the same hash.
    fn hash_value(&self) -> u64;
}
83
84impl UdfPointer for Arc<dyn ScalarUDFImpl + '_> {
85 fn equals(&self, other: &(dyn ScalarUDFImpl + '_)) -> bool {
86 self.as_ref().dyn_eq(other.as_any())
87 }
88
89 fn hash_value(&self) -> u64 {
90 let hasher = &mut DefaultHasher::new();
91 self.as_ref().dyn_hash(hasher);
92 hasher.finish()
93 }
94}
95
96impl UdfPointer for Arc<dyn AggregateUDFImpl + '_> {
97 fn equals(&self, other: &(dyn AggregateUDFImpl + '_)) -> bool {
98 self.as_ref().dyn_eq(other.as_any())
99 }
100
101 fn hash_value(&self) -> u64 {
102 let hasher = &mut DefaultHasher::new();
103 self.as_ref().dyn_hash(hasher);
104 hasher.finish()
105 }
106}
107
108impl UdfPointer for Arc<dyn WindowUDFImpl + '_> {
109 fn equals(&self, other: &(dyn WindowUDFImpl + '_)) -> bool {
110 self.as_ref().dyn_eq(other.as_any())
111 }
112
113 fn hash_value(&self) -> u64 {
114 let hasher = &mut DefaultHasher::new();
115 self.as_ref().dyn_hash(hasher);
116 hasher.finish()
117 }
118}
119
#[cfg(test)]
mod tests {
    // `super::*` also brings in the parent's `Arc`, `DefaultHasher`, `Hash` and
    // `Hasher` imports, so they are not re-imported here.
    use super::*;
    use crate::ScalarFunctionArgs;
    use arrow::datatypes::DataType;
    use datafusion_expr_common::columnar_value::ColumnarValue;
    use datafusion_expr_common::signature::{Signature, Volatility};
    use std::any::Any;

    /// Minimal scalar UDF whose equality and hash come from its fields, so two
    /// separately allocated instances with the same name compare equal.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct TestScalarUDF {
        signature: Signature,
        name: &'static str,
    }

    impl ScalarUDFImpl for TestScalarUDF {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            self.name
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(
            &self,
            _arg_types: &[DataType],
        ) -> datafusion_common::Result<DataType> {
            unimplemented!()
        }

        fn invoke_with_args(
            &self,
            _args: ScalarFunctionArgs,
        ) -> datafusion_common::Result<ColumnarValue> {
            unimplemented!()
        }
    }

    /// Allocates a fresh `TestScalarUDF` named `name` behind a trait-object `Arc`.
    fn scalar_udf(signature: &Signature, name: &'static str) -> Arc<dyn ScalarUDFImpl> {
        Arc::new(TestScalarUDF {
            signature: signature.clone(),
            name,
        })
    }

    #[test]
    fn test_eq_eq_wrapper() {
        let signature = Signature::any(1, Volatility::Immutable);

        let a1 = scalar_udf(&signature, "a");
        // Separate allocation with the same contents as `a1`.
        let a2 = scalar_udf(&signature, "a");
        let b = scalar_udf(&signature, "b");

        // Reflexivity: a wrapper equals a copy of itself. Comparing against a
        // clone (rather than `wrapper` itself) avoids `clippy::eq_op`, which
        // rejects `assert_eq!` calls with identical arguments.
        let wrapper = UdfEq(Arc::clone(&a1));
        assert_eq!(wrapper, wrapper.clone());

        // Same pointer: equal, and hashes agree.
        assert_eq!(UdfEq(Arc::clone(&a1)), UdfEq(Arc::clone(&a1)));
        assert_eq!(hash(UdfEq(Arc::clone(&a1))), hash(UdfEq(Arc::clone(&a1))));

        // Different pointers to equal implementations: still equal, hashes agree.
        assert_eq!(UdfEq(Arc::clone(&a1)), UdfEq(Arc::clone(&a2)));
        assert_eq!(hash(UdfEq(Arc::clone(&a1))), hash(UdfEq(Arc::clone(&a2))));

        // Different implementations: not equal.
        assert_ne!(UdfEq(Arc::clone(&a1)), UdfEq(Arc::clone(&b)));
    }

    #[test]
    fn test_udf_eq_conversions_and_formatting() {
        let signature = Signature::any(1, Volatility::Immutable);
        let udf = scalar_udf(&signature, "a");

        // `From`/`Into` builds the same wrapper as the tuple constructor.
        let converted: UdfEq<Arc<dyn ScalarUDFImpl>> = Arc::clone(&udf).into();
        assert_eq!(converted, UdfEq(Arc::clone(&udf)));

        // `Deref` exposes the very same pointer, and through it the UDF methods.
        assert!(Arc::ptr_eq(&*converted, &udf));
        assert_eq!(converted.name(), "a");

        // `Debug` is transparent: it prints exactly what the pointer prints.
        assert_eq!(format!("{converted:?}"), format!("{udf:?}"));
    }

    /// Hashes `value` with a fresh `DefaultHasher` and returns the digest.
    fn hash<T: Hash>(value: T) -> u64 {
        let hasher = &mut DefaultHasher::new();
        value.hash(hasher);
        hasher.finish()
    }
}
201}