Skip to main content

datafusion_expr/
udf_eq.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::{AggregateUDFImpl, HigherOrderUDFImpl, ScalarUDFImpl, WindowUDFImpl};
19use std::any::Any;
20use std::fmt::Debug;
21use std::hash::{DefaultHasher, Hash, Hasher};
22use std::ops::Deref;
23use std::sync::Arc;
24
25/// A wrapper around a pointer to UDF that implements `Eq` and `Hash` delegating to
26/// corresponding methods on the UDF trait.
27///
28/// If you want to just compare pointers for equality, use [`super::ptr_eq::PtrEq`].
29#[derive(Clone)]
30#[expect(private_bounds)] // This is so that UdfEq can only be used with allowed pointer types (e.g. Arc), without allowing misuse.
31pub struct UdfEq<Ptr: UdfPointer>(Ptr);
32
33impl<Ptr> PartialEq for UdfEq<Ptr>
34where
35    Ptr: UdfPointer,
36{
37    fn eq(&self, other: &Self) -> bool {
38        self.0.equals(&other.0)
39    }
40}
41impl<Ptr> Eq for UdfEq<Ptr> where Ptr: UdfPointer {}
42impl<Ptr> Hash for UdfEq<Ptr>
43where
44    Ptr: UdfPointer,
45{
46    fn hash<H: Hasher>(&self, state: &mut H) {
47        self.0.hash_value().hash(state);
48    }
49}
50
51impl<Ptr> From<Ptr> for UdfEq<Ptr>
52where
53    Ptr: UdfPointer,
54{
55    fn from(ptr: Ptr) -> Self {
56        UdfEq(ptr)
57    }
58}
59
60impl<Ptr> Debug for UdfEq<Ptr>
61where
62    Ptr: UdfPointer + Debug,
63{
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        self.0.fmt(f)
66    }
67}
68
69impl<Ptr> Deref for UdfEq<Ptr>
70where
71    Ptr: UdfPointer,
72{
73    type Target = Ptr;
74
75    fn deref(&self) -> &Self::Target {
76        &self.0
77    }
78}
79
/// Private trait implemented by the pointer types that `UdfEq` can wrap.
///
/// It supplies the comparison and hashing that `UdfEq` uses for its
/// `PartialEq` and `Hash` implementations.
trait UdfPointer: Deref {
    /// Returns whether the value behind `self` is considered equal to `other`.
    fn equals(&self, other: &<Self as Deref>::Target) -> bool;

    /// Returns a 64-bit hash for the value behind `self`.
    fn hash_value(&self) -> u64;
}
84
85impl UdfPointer for Arc<dyn ScalarUDFImpl + '_> {
86    fn equals(&self, other: &(dyn ScalarUDFImpl + '_)) -> bool {
87        self.as_ref().dyn_eq(other as &dyn Any)
88    }
89
90    fn hash_value(&self) -> u64 {
91        let hasher = &mut DefaultHasher::new();
92        self.as_ref().dyn_hash(hasher);
93        hasher.finish()
94    }
95}
96
97impl UdfPointer for Arc<dyn HigherOrderUDFImpl + '_> {
98    fn equals(&self, other: &Self::Target) -> bool {
99        self.as_ref().dyn_eq(other)
100    }
101
102    fn hash_value(&self) -> u64 {
103        let hasher = &mut DefaultHasher::new();
104        self.as_ref().dyn_hash(hasher);
105        hasher.finish()
106    }
107}
108
109impl UdfPointer for Arc<dyn AggregateUDFImpl + '_> {
110    fn equals(&self, other: &(dyn AggregateUDFImpl + '_)) -> bool {
111        self.as_ref().dyn_eq(other)
112    }
113
114    fn hash_value(&self) -> u64 {
115        let hasher = &mut DefaultHasher::new();
116        self.as_ref().dyn_hash(hasher);
117        hasher.finish()
118    }
119}
120
121impl UdfPointer for Arc<dyn WindowUDFImpl + '_> {
122    fn equals(&self, other: &(dyn WindowUDFImpl + '_)) -> bool {
123        self.as_ref().dyn_eq(other as &dyn Any)
124    }
125
126    fn hash_value(&self) -> u64 {
127        let hasher = &mut DefaultHasher::new();
128        self.as_ref().dyn_hash(hasher);
129        hasher.finish()
130    }
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ScalarFunctionArgs;
    use arrow::datatypes::DataType;
    use datafusion_expr_common::columnar_value::ColumnarValue;
    use datafusion_expr_common::signature::{Signature, Volatility};
    use std::hash::DefaultHasher;

    /// Minimal scalar UDF fixture; equality and hashing are derived from its fields.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct TestScalarUDF {
        signature: Signature,
        name: &'static str,
    }

    impl ScalarUDFImpl for TestScalarUDF {
        fn name(&self) -> &str {
            self.name
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        // Never invoked by these tests.
        fn return_type(
            &self,
            _arg_types: &[DataType],
        ) -> datafusion_common::Result<DataType> {
            unimplemented!()
        }

        // Never invoked by these tests.
        fn invoke_with_args(
            &self,
            _args: ScalarFunctionArgs,
        ) -> datafusion_common::Result<ColumnarValue> {
            unimplemented!()
        }
    }

    /// Builds a type-erased fixture UDF with the given name and a fixed signature.
    fn make_udf(name: &'static str) -> Arc<dyn ScalarUDFImpl> {
        Arc::new(TestScalarUDF {
            signature: Signature::any(1, Volatility::Immutable),
            name,
        })
    }

    /// Hashes `value` with a fresh `DefaultHasher` and returns the digest.
    fn hash<T: Hash>(value: T) -> u64 {
        let mut hasher = DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    pub fn test_eq_eq_wrapper() {
        let first_a = make_udf("a");
        let second_a = make_udf("a");
        let only_b = make_udf("b");

        // Wraps a fresh clone of the given pointer.
        let wrap = |ptr: &Arc<dyn ScalarUDFImpl>| UdfEq(Arc::clone(ptr));

        // Reflexivity
        let wrapped = wrap(&first_a);
        assert_eq!(wrapped, wrapped);

        // Two wrappers around the same pointer
        assert_eq!(wrap(&first_a), wrap(&first_a));
        assert_eq!(hash(wrap(&first_a)), hash(wrap(&first_a)));

        // Two wrappers around different pointers whose UDFs compare equal
        assert_eq!(wrap(&first_a), wrap(&second_a));
        assert_eq!(hash(wrap(&first_a)), hash(wrap(&second_a)));

        // Different functions (not equal)
        assert_ne!(wrap(&first_a), wrap(&only_b));
    }
}
209}