// datafusion_expr/udf_eq.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use crate::{AggregateUDFImpl, ScalarUDFImpl, WindowUDFImpl};
19use std::fmt::Debug;
20use std::hash::{DefaultHasher, Hash, Hasher};
21use std::ops::Deref;
22use std::sync::Arc;
23
24/// A wrapper around a pointer to UDF that implements `Eq` and `Hash` delegating to
25/// corresponding methods on the UDF trait.
26///
27/// If you want to just compare pointers for equality, use [`super::ptr_eq::PtrEq`].
28#[derive(Clone)]
29#[allow(private_bounds)] // This is so that UdfEq can only be used with allowed pointer types (e.g. Arc), without allowing misuse.
30pub struct UdfEq<Ptr: UdfPointer>(Ptr);
31
32impl<Ptr> PartialEq for UdfEq<Ptr>
33where
34    Ptr: UdfPointer,
35{
36    fn eq(&self, other: &Self) -> bool {
37        self.0.equals(&other.0)
38    }
39}
40impl<Ptr> Eq for UdfEq<Ptr> where Ptr: UdfPointer {}
41impl<Ptr> Hash for UdfEq<Ptr>
42where
43    Ptr: UdfPointer,
44{
45    fn hash<H: Hasher>(&self, state: &mut H) {
46        self.0.hash_value().hash(state);
47    }
48}
49
50impl<Ptr> From<Ptr> for UdfEq<Ptr>
51where
52    Ptr: UdfPointer,
53{
54    fn from(ptr: Ptr) -> Self {
55        UdfEq(ptr)
56    }
57}
58
59impl<Ptr> Debug for UdfEq<Ptr>
60where
61    Ptr: UdfPointer + Debug,
62{
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        self.0.fmt(f)
65    }
66}
67
68impl<Ptr> Deref for UdfEq<Ptr>
69where
70    Ptr: UdfPointer,
71{
72    type Target = Ptr;
73
74    fn deref(&self) -> &Self::Target {
75        &self.0
76    }
77}
78
/// Pointer types that `UdfEq` is allowed to wrap.
///
/// The trait is private, so `UdfEq` stays limited to the pointer types that implement it in
/// this module.
trait UdfPointer: Deref {
    /// Returns a 64-bit hash of the pointed-to value. Values that `equals` treats as equal
    /// should produce the same hash.
    fn hash_value(&self) -> u64;

    /// Returns `true` when the value behind `self` equals `other`.
    fn equals(&self, other: &<Self as Deref>::Target) -> bool;
}
83
84impl UdfPointer for Arc<dyn ScalarUDFImpl + '_> {
85    fn equals(&self, other: &(dyn ScalarUDFImpl + '_)) -> bool {
86        self.as_ref().dyn_eq(other.as_any())
87    }
88
89    fn hash_value(&self) -> u64 {
90        let hasher = &mut DefaultHasher::new();
91        self.as_ref().dyn_hash(hasher);
92        hasher.finish()
93    }
94}
95
96impl UdfPointer for Arc<dyn AggregateUDFImpl + '_> {
97    fn equals(&self, other: &(dyn AggregateUDFImpl + '_)) -> bool {
98        self.as_ref().dyn_eq(other.as_any())
99    }
100
101    fn hash_value(&self) -> u64 {
102        let hasher = &mut DefaultHasher::new();
103        self.as_ref().dyn_hash(hasher);
104        hasher.finish()
105    }
106}
107
108impl UdfPointer for Arc<dyn WindowUDFImpl + '_> {
109    fn equals(&self, other: &(dyn WindowUDFImpl + '_)) -> bool {
110        self.as_ref().dyn_eq(other.as_any())
111    }
112
113    fn hash_value(&self) -> u64 {
114        let hasher = &mut DefaultHasher::new();
115        self.as_ref().dyn_hash(hasher);
116        hasher.finish()
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ScalarFunctionArgs;
    use arrow::datatypes::DataType;
    use datafusion_expr_common::columnar_value::ColumnarValue;
    use datafusion_expr_common::signature::{Signature, Volatility};
    use std::any::Any;
    use std::hash::DefaultHasher;

    /// Minimal scalar UDF. It derives `PartialEq`, `Eq` and `Hash`, so two instances built
    /// from the same fields are equal (the tests below rely on this).
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct TestScalarUDF {
        signature: Signature,
        name: &'static str,
    }
    impl ScalarUDFImpl for TestScalarUDF {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            self.name
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(
            &self,
            _arg_types: &[DataType],
        ) -> datafusion_common::Result<DataType> {
            unimplemented!()
        }

        fn invoke_with_args(
            &self,
            _args: ScalarFunctionArgs,
        ) -> datafusion_common::Result<ColumnarValue> {
            unimplemented!()
        }
    }

    #[test]
    fn test_eq_eq_wrapper() {
        let signature = Signature::any(1, Volatility::Immutable);

        let a1: Arc<dyn ScalarUDFImpl> = Arc::new(TestScalarUDF {
            signature: signature.clone(),
            name: "a",
        });
        let a2: Arc<dyn ScalarUDFImpl> = Arc::new(TestScalarUDF {
            signature: signature.clone(),
            name: "a",
        });
        let b: Arc<dyn ScalarUDFImpl> = Arc::new(TestScalarUDF {
            signature: signature.clone(),
            name: "b",
        });

        // Reflexivity
        let wrapper = UdfEq(Arc::clone(&a1));
        assert_eq!(wrapper, wrapper);

        // Two wrappers around equal pointer
        assert_eq!(UdfEq(Arc::clone(&a1)), UdfEq(Arc::clone(&a1)));
        assert_eq!(hash(UdfEq(Arc::clone(&a1))), hash(UdfEq(Arc::clone(&a1))));

        // Two wrappers around different pointers but equal in ScalarUDFImpl::equals sense
        assert_eq!(UdfEq(Arc::clone(&a1)), UdfEq(Arc::clone(&a2)));
        assert_eq!(hash(UdfEq(Arc::clone(&a1))), hash(UdfEq(Arc::clone(&a2))));

        // different functions (not equal)
        assert_ne!(UdfEq(Arc::clone(&a1)), UdfEq(Arc::clone(&b)));
    }

    #[test]
    fn test_udf_eq_considers_whole_udf_state() {
        // Same name but a different signature: equality must look beyond the name.
        let a: Arc<dyn ScalarUDFImpl> = Arc::new(TestScalarUDF {
            signature: Signature::any(1, Volatility::Immutable),
            name: "a",
        });
        let a_other_signature: Arc<dyn ScalarUDFImpl> = Arc::new(TestScalarUDF {
            signature: Signature::any(2, Volatility::Immutable),
            name: "a",
        });
        assert_ne!(UdfEq(a), UdfEq(a_other_signature));
    }

    #[test]
    fn test_udf_eq_from_deref_and_debug() {
        let signature = Signature::any(1, Volatility::Immutable);
        let udf: Arc<dyn ScalarUDFImpl> = Arc::new(TestScalarUDF {
            signature,
            name: "a",
        });

        // `From` wraps the pointer unchanged: `Deref` yields the very same allocation.
        let wrapper = UdfEq::from(Arc::clone(&udf));
        assert!(Arc::ptr_eq(&*wrapper, &udf));

        // UDF methods are reachable through auto-deref (UdfEq -> Arc -> dyn ScalarUDFImpl).
        assert_eq!(wrapper.name(), "a");

        // `Debug` is transparent: it prints exactly what the wrapped pointer prints.
        assert_eq!(format!("{wrapper:?}"), format!("{udf:?}"));
    }

    fn hash<T: Hash>(value: T) -> u64 {
        let hasher = &mut DefaultHasher::new();
        value.hash(hasher);
        hasher.finish()
    }
}