datafusion_expr/
udf_eq.rs1use crate::{AggregateUDFImpl, HigherOrderUDFImpl, ScalarUDFImpl, WindowUDFImpl};
19use std::any::Any;
20use std::fmt::Debug;
21use std::hash::{DefaultHasher, Hash, Hasher};
22use std::ops::Deref;
23use std::sync::Arc;
24
25#[derive(Clone)]
30#[expect(private_bounds)] pub struct UdfEq<Ptr: UdfPointer>(Ptr);
32
33impl<Ptr> PartialEq for UdfEq<Ptr>
34where
35 Ptr: UdfPointer,
36{
37 fn eq(&self, other: &Self) -> bool {
38 self.0.equals(&other.0)
39 }
40}
41impl<Ptr> Eq for UdfEq<Ptr> where Ptr: UdfPointer {}
42impl<Ptr> Hash for UdfEq<Ptr>
43where
44 Ptr: UdfPointer,
45{
46 fn hash<H: Hasher>(&self, state: &mut H) {
47 self.0.hash_value().hash(state);
48 }
49}
50
51impl<Ptr> From<Ptr> for UdfEq<Ptr>
52where
53 Ptr: UdfPointer,
54{
55 fn from(ptr: Ptr) -> Self {
56 UdfEq(ptr)
57 }
58}
59
60impl<Ptr> Debug for UdfEq<Ptr>
61where
62 Ptr: UdfPointer + Debug,
63{
64 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65 self.0.fmt(f)
66 }
67}
68
69impl<Ptr> Deref for UdfEq<Ptr>
70where
71 Ptr: UdfPointer,
72{
73 type Target = Ptr;
74
75 fn deref(&self) -> &Self::Target {
76 &self.0
77 }
78}
79
/// Module-private abstraction over the pointer types `UdfEq` can wrap.
///
/// Implementors compare and hash the UDF they point to by value. The two
/// methods must agree: whenever `equals` reports two UDFs as equal,
/// `hash_value` must return the same digest for both. Otherwise `UdfEq`'s
/// `Hash` impl would violate the `Hash`/`Eq` contract.
trait UdfPointer
where
    Self: Deref,
{
    /// Returns a 64-bit digest of the pointed-to UDF's value.
    fn hash_value(&self) -> u64;

    /// Returns `true` if the UDF behind `self` is equal by value to `other`.
    fn equals(&self, other: &Self::Target) -> bool;
}
84
85impl UdfPointer for Arc<dyn ScalarUDFImpl + '_> {
86 fn equals(&self, other: &(dyn ScalarUDFImpl + '_)) -> bool {
87 self.as_ref().dyn_eq(other as &dyn Any)
88 }
89
90 fn hash_value(&self) -> u64 {
91 let hasher = &mut DefaultHasher::new();
92 self.as_ref().dyn_hash(hasher);
93 hasher.finish()
94 }
95}
96
97impl UdfPointer for Arc<dyn HigherOrderUDFImpl + '_> {
98 fn equals(&self, other: &Self::Target) -> bool {
99 self.as_ref().dyn_eq(other)
100 }
101
102 fn hash_value(&self) -> u64 {
103 let hasher = &mut DefaultHasher::new();
104 self.as_ref().dyn_hash(hasher);
105 hasher.finish()
106 }
107}
108
109impl UdfPointer for Arc<dyn AggregateUDFImpl + '_> {
110 fn equals(&self, other: &(dyn AggregateUDFImpl + '_)) -> bool {
111 self.as_ref().dyn_eq(other)
112 }
113
114 fn hash_value(&self) -> u64 {
115 let hasher = &mut DefaultHasher::new();
116 self.as_ref().dyn_hash(hasher);
117 hasher.finish()
118 }
119}
120
121impl UdfPointer for Arc<dyn WindowUDFImpl + '_> {
122 fn equals(&self, other: &(dyn WindowUDFImpl + '_)) -> bool {
123 self.as_ref().dyn_eq(other as &dyn Any)
124 }
125
126 fn hash_value(&self) -> u64 {
127 let hasher = &mut DefaultHasher::new();
128 self.as_ref().dyn_hash(hasher);
129 hasher.finish()
130 }
131}
132
#[cfg(test)]
mod tests {
    // `super::*` also brings in `Arc`, `DefaultHasher`, `Hash` and `Hasher`
    // from the parent module's imports.
    use super::*;
    use crate::ScalarFunctionArgs;
    use arrow::datatypes::DataType;
    use datafusion_expr_common::columnar_value::ColumnarValue;
    use datafusion_expr_common::signature::{Signature, Volatility};

    /// Minimal scalar UDF whose derived `PartialEq`/`Eq`/`Hash` give it value
    /// semantics; only `name` and `signature` are meaningful.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct TestScalarUDF {
        signature: Signature,
        name: &'static str,
    }

    impl ScalarUDFImpl for TestScalarUDF {
        fn name(&self) -> &str {
            self.name
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(
            &self,
            _arg_types: &[DataType],
        ) -> datafusion_common::Result<DataType> {
            unimplemented!()
        }

        fn invoke_with_args(
            &self,
            _args: ScalarFunctionArgs,
        ) -> datafusion_common::Result<ColumnarValue> {
            unimplemented!()
        }
    }

    /// Builds a fresh `TestScalarUDF` allocation behind an
    /// `Arc<dyn ScalarUDFImpl>`.
    fn scalar_udf(signature: &Signature, name: &'static str) -> Arc<dyn ScalarUDFImpl> {
        Arc::new(TestScalarUDF {
            signature: signature.clone(),
            name,
        })
    }

    #[test]
    fn test_eq_eq_wrapper() {
        let signature = Signature::any(1, Volatility::Immutable);

        let a1 = scalar_udf(&signature, "a");
        let a2 = scalar_udf(&signature, "a");
        let b = scalar_udf(&signature, "b");

        // Reflexivity: a clone shares the same underlying UDF instance.
        let wrapper = UdfEq(Arc::clone(&a1));
        assert_eq!(wrapper, wrapper.clone());

        // Same allocation: equal, with equal hashes.
        assert_eq!(UdfEq(Arc::clone(&a1)), UdfEq(Arc::clone(&a1)));
        assert_eq!(hash(UdfEq(Arc::clone(&a1))), hash(UdfEq(Arc::clone(&a1))));

        // Distinct allocations holding equal values: equal, with equal hashes.
        assert_eq!(UdfEq(Arc::clone(&a1)), UdfEq(Arc::clone(&a2)));
        assert_eq!(hash(UdfEq(Arc::clone(&a1))), hash(UdfEq(Arc::clone(&a2))));

        // Different values compare unequal.
        assert_ne!(UdfEq(Arc::clone(&a1)), UdfEq(Arc::clone(&b)));
    }

    #[test]
    fn test_from_deref_and_debug() {
        let signature = Signature::any(1, Volatility::Immutable);
        let a = scalar_udf(&signature, "a");

        // `From` produces the same wrapper as the tuple constructor.
        let wrapped = UdfEq::from(Arc::clone(&a));
        assert_eq!(wrapped, UdfEq(Arc::clone(&a)));

        // `Deref` exposes the wrapped pointer itself, and through it the UDF.
        assert!(Arc::ptr_eq(&*wrapped, &a));
        assert_eq!(wrapped.name(), "a");

        // `Debug` is transparent: identical to the wrapped pointer's output.
        assert_eq!(format!("{wrapped:?}"), format!("{a:?}"));
    }

    /// Hashes `value` with a fresh `DefaultHasher` and returns the digest.
    fn hash<T: Hash>(value: T) -> u64 {
        let hasher = &mut DefaultHasher::new();
        value.hash(hasher);
        hasher.finish()
    }
}