Skip to main content

sedona_expr/
function_set.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17use crate::{
18    aggregate_udf::{IntoSedonaAccumulatorRefs, SedonaAggregateUDF},
19    scalar_udf::{IntoScalarKernelRefs, SedonaScalarUDF},
20};
21use datafusion_common::error::Result;
22use datafusion_expr::{AggregateUDFImpl, ScalarUDFImpl};
23use std::collections::HashMap;
24
25/// Helper for managing groups of functions
26///
27/// Sedona coordinates the assembly of a large number of spatial functions with potentially
28/// different sets of dependencies (e.g., geography vs. geometry), multiple implementations,
29/// and/or implementations that live in different crates. This structure helps coordinate
30/// these implementations.
31pub struct FunctionSet {
32    scalar_udfs: HashMap<String, SedonaScalarUDF>,
33    aggregate_udfs: HashMap<String, SedonaAggregateUDF>,
34}
35
36impl FunctionSet {
37    /// Create a new, empty FunctionSet
38    pub fn new() -> Self {
39        Self {
40            scalar_udfs: HashMap::new(),
41            aggregate_udfs: HashMap::new(),
42        }
43    }
44
45    /// Iterate over references to all [SedonaScalarUDF]s
46    pub fn scalar_udfs(&self) -> impl Iterator<Item = &SedonaScalarUDF> + '_ {
47        self.scalar_udfs.values()
48    }
49
50    /// Return a reference to the function corresponding to the name
51    pub fn scalar_udf(&self, name: &str) -> Option<&SedonaScalarUDF> {
52        self.scalar_udfs.get(name)
53    }
54
55    /// Return a mutable reference to the function corresponding to the name
56    pub fn scalar_udf_mut(&mut self, name: &str) -> Option<&mut SedonaScalarUDF> {
57        self.scalar_udfs.get_mut(name)
58    }
59
60    /// Insert a new ScalarUDF and return the UDF that had previously been added, if any
61    pub fn insert_scalar_udf(&mut self, udf: SedonaScalarUDF) -> Option<SedonaScalarUDF> {
62        self.scalar_udfs.insert(udf.name().to_string(), udf)
63    }
64
65    /// Iterate over references to all [SedonaAggregateUDF]s
66    pub fn aggregate_udfs(&self) -> impl Iterator<Item = &SedonaAggregateUDF> + '_ {
67        self.aggregate_udfs.values()
68    }
69
70    /// Return a reference to the aggregate function corresponding to the name
71    pub fn aggregate_udf(&self, name: &str) -> Option<&SedonaAggregateUDF> {
72        self.aggregate_udfs.get(name)
73    }
74
75    /// Return a mutable reference to the aggregate function corresponding to the name
76    pub fn aggregate_udf_mut(&mut self, name: &str) -> Option<&mut SedonaAggregateUDF> {
77        self.aggregate_udfs.get_mut(name)
78    }
79
80    /// Insert a new AggregateUDF and return the UDF that had previously been added, if any
81    pub fn insert_aggregate_udf(&mut self, udf: SedonaAggregateUDF) -> Option<SedonaAggregateUDF> {
82        self.aggregate_udfs.insert(udf.name().to_string(), udf)
83    }
84
85    /// Consume another function set and merge its contents into this one
86    pub fn merge(&mut self, other: FunctionSet) {
87        for (k, v) in other.scalar_udfs.into_iter() {
88            self.scalar_udfs.insert(k, v);
89        }
90
91        for (k, v) in other.aggregate_udfs.into_iter() {
92            self.aggregate_udfs.insert(k, v);
93        }
94    }
95
96    /// Add a kernel to a function in this set
97    ///
98    /// This adds a scalar UDF with immutable output if a function of that name does not
99    /// exist in this set. A reference to the matching or inserted function is returned.
100    pub fn add_scalar_udf_impl(
101        &mut self,
102        name: &str,
103        kernels: impl IntoScalarKernelRefs,
104    ) -> Result<&SedonaScalarUDF> {
105        if let Some(function) = self.scalar_udf_mut(name) {
106            function.add_kernels(kernels);
107        } else {
108            let function = SedonaScalarUDF::from_impl(name, kernels);
109            self.insert_scalar_udf(function);
110        }
111
112        Ok(self.scalar_udf(name).unwrap())
113    }
114
115    /// Add an aggregate kernel to a function in this set
116    ///
117    /// This adds an aggregate UDF with immutable output if a function of that name does not
118    /// exist in this set. A reference to the matching or inserted function is returned.
119    pub fn add_aggregate_udf_kernel(
120        &mut self,
121        name: &str,
122        kernel: impl IntoSedonaAccumulatorRefs,
123    ) -> Result<&SedonaAggregateUDF> {
124        if let Some(function) = self.aggregate_udf_mut(name) {
125            function.add_kernel(kernel);
126        } else {
127            let function = SedonaAggregateUDF::from_impl(name, kernel);
128            self.insert_aggregate_udf(function);
129        }
130
131        Ok(self.aggregate_udf(name).unwrap())
132    }
133}
134
135impl Default for FunctionSet {
136    fn default() -> Self {
137        Self::new()
138    }
139}
140
#[cfg(test)]
mod tests {
    use std::{collections::HashSet, sync::Arc};

    use arrow_schema::{DataType, FieldRef};
    use datafusion_common::{not_impl_err, scalar::ScalarValue};

    use datafusion_expr::{Accumulator, ColumnarValue, Volatility};
    use sedona_schema::{datatypes::SedonaType, matchers::ArgMatcher};

    use crate::{
        aggregate_udf::{SedonaAccumulator, SedonaAccumulatorRef},
        scalar_udf::SimpleSedonaScalarKernel,
    };

    use super::*;

    /// Exercises lookup, insertion, kernel addition and merging for scalar functions
    #[test]
    fn function_set() {
        // A new set has no scalar functions and lookups miss
        let mut functions = FunctionSet::new();
        assert_eq!(functions.scalar_udfs().count(), 0);
        assert!(functions.scalar_udf("simple_udf").is_none());
        assert!(functions.scalar_udf_mut("simple_udf").is_none());

        let kernel = SimpleSedonaScalarKernel::new_ref(
            ArgMatcher::new(
                vec![ArgMatcher::is_arrow(DataType::Boolean)],
                SedonaType::Arrow(DataType::Boolean),
            ),
            Arc::new(|_, _| Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)))),
        );

        let udf = SedonaScalarUDF::new("simple_udf", vec![kernel.clone()], Volatility::Immutable);

        // After insertion the function is visible through every accessor
        functions.insert_scalar_udf(udf);
        assert_eq!(functions.scalar_udfs().count(), 1);
        assert!(functions.scalar_udf("simple_udf").is_some());
        assert!(functions.scalar_udf_mut("simple_udf").is_some());

        // Adding a kernel to an existing function returns that function...
        assert_eq!(
            functions
                .add_scalar_udf_impl("simple_udf", kernel.clone())
                .unwrap()
                .name(),
            "simple_udf"
        );

        // ...and adding a kernel under an unknown name registers a new function
        let inserted_udf = functions
            .add_scalar_udf_impl("function that does not yet exist", kernel.clone())
            .unwrap();
        assert_eq!(inserted_udf.name(), "function that does not yet exist");

        let kernel2 = SimpleSedonaScalarKernel::new_ref(
            ArgMatcher::new(
                vec![ArgMatcher::is_arrow(DataType::Utf8)],
                SedonaType::Arrow(DataType::Utf8),
            ),
            Arc::new(|_, _| Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)))),
        );

        // Merging brings in the functions of the other set
        let udf2 = SedonaScalarUDF::new("simple_udf2", vec![kernel2], Volatility::Immutable);
        let mut functions2 = FunctionSet::new();
        functions2.insert_scalar_udf(udf2);
        functions.merge(functions2);
        assert_eq!(
            functions
                .scalar_udfs()
                .map(|s| s.name())
                .collect::<HashSet<_>>(),
            HashSet::from([
                "simple_udf",
                "simple_udf2",
                "function that does not yet exist"
            ])
        );
    }

    /// Accumulator stub whose methods all return a not-implemented error; the tests
    /// below only register it as a kernel and never invoke it.
    #[derive(Debug, Clone)]
    struct TestAccumulator {}

    impl SedonaAccumulator for TestAccumulator {
        fn return_type(&self, _args: &[SedonaType]) -> Result<Option<SedonaType>> {
            not_impl_err!("")
        }

        fn accumulator(
            &self,
            _args: &[SedonaType],
            _output_type: &SedonaType,
        ) -> Result<Box<dyn Accumulator>> {
            not_impl_err!("")
        }

        fn state_fields(&self, _args: &[SedonaType]) -> Result<Vec<FieldRef>> {
            not_impl_err!("")
        }
    }

    /// Exercises lookup, insertion, kernel addition and merging for aggregate functions
    #[test]
    fn function_set_with_aggregates() {
        // A new set has no aggregate functions and lookups miss
        let mut functions = FunctionSet::new();
        assert_eq!(functions.aggregate_udfs().count(), 0);
        assert!(functions.aggregate_udf("simple_udaf").is_none());
        assert!(functions.aggregate_udf_mut("simple_udaf").is_none());

        let udaf = SedonaAggregateUDF::new(
            "simple_udaf",
            Vec::<SedonaAccumulatorRef>::new(),
            Volatility::Immutable,
        );
        let kernel = TestAccumulator {};

        // After insertion the function is visible through every accessor
        functions.insert_aggregate_udf(udaf);
        assert_eq!(functions.aggregate_udfs().count(), 1);
        assert!(functions.aggregate_udf("simple_udaf").is_some());
        assert!(functions.aggregate_udf_mut("simple_udaf").is_some());

        // Adding a kernel to an existing function returns that function...
        assert_eq!(
            functions
                .add_aggregate_udf_kernel("simple_udaf", kernel.clone())
                .unwrap()
                .name(),
            "simple_udaf"
        );

        // ...and adding a kernel under an unknown name registers a new function
        let added_func = functions
            .add_aggregate_udf_kernel("function that does not exist yet", kernel.clone())
            .unwrap();
        assert_eq!(added_func.name(), "function that does not exist yet");

        // Merging brings in the functions of the other set
        let udaf2 = SedonaAggregateUDF::new(
            "simple_udaf2",
            vec![Arc::new(kernel.clone())],
            Volatility::Immutable,
        );
        let mut functions2 = FunctionSet::new();
        functions2.insert_aggregate_udf(udaf2);
        functions.merge(functions2);
        assert_eq!(
            functions
                .aggregate_udfs()
                .map(|s| s.name())
                .collect::<HashSet<_>>(),
            HashSet::from([
                "simple_udaf",
                "simple_udaf2",
                "function that does not exist yet"
            ])
        );
    }
}