sedona_expr/
function_set.rs1use crate::{
18 aggregate_udf::{IntoSedonaAccumulatorRefs, SedonaAggregateUDF},
19 scalar_udf::{IntoScalarKernelRefs, SedonaScalarUDF},
20};
21use datafusion_common::error::Result;
22use datafusion_expr::{AggregateUDFImpl, ScalarUDFImpl};
23use std::collections::HashMap;
24
25pub struct FunctionSet {
32 scalar_udfs: HashMap<String, SedonaScalarUDF>,
33 aggregate_udfs: HashMap<String, SedonaAggregateUDF>,
34}
35
36impl FunctionSet {
37 pub fn new() -> Self {
39 Self {
40 scalar_udfs: HashMap::new(),
41 aggregate_udfs: HashMap::new(),
42 }
43 }
44
45 pub fn scalar_udfs(&self) -> impl Iterator<Item = &SedonaScalarUDF> + '_ {
47 self.scalar_udfs.values()
48 }
49
50 pub fn scalar_udf(&self, name: &str) -> Option<&SedonaScalarUDF> {
52 self.scalar_udfs.get(name)
53 }
54
55 pub fn scalar_udf_mut(&mut self, name: &str) -> Option<&mut SedonaScalarUDF> {
57 self.scalar_udfs.get_mut(name)
58 }
59
60 pub fn insert_scalar_udf(&mut self, udf: SedonaScalarUDF) -> Option<SedonaScalarUDF> {
62 self.scalar_udfs.insert(udf.name().to_string(), udf)
63 }
64
65 pub fn aggregate_udfs(&self) -> impl Iterator<Item = &SedonaAggregateUDF> + '_ {
67 self.aggregate_udfs.values()
68 }
69
70 pub fn aggregate_udf(&self, name: &str) -> Option<&SedonaAggregateUDF> {
72 self.aggregate_udfs.get(name)
73 }
74
75 pub fn aggregate_udf_mut(&mut self, name: &str) -> Option<&mut SedonaAggregateUDF> {
77 self.aggregate_udfs.get_mut(name)
78 }
79
80 pub fn insert_aggregate_udf(&mut self, udf: SedonaAggregateUDF) -> Option<SedonaAggregateUDF> {
82 self.aggregate_udfs.insert(udf.name().to_string(), udf)
83 }
84
85 pub fn merge(&mut self, other: FunctionSet) {
87 for (k, v) in other.scalar_udfs.into_iter() {
88 self.scalar_udfs.insert(k, v);
89 }
90
91 for (k, v) in other.aggregate_udfs.into_iter() {
92 self.aggregate_udfs.insert(k, v);
93 }
94 }
95
96 pub fn add_scalar_udf_impl(
101 &mut self,
102 name: &str,
103 kernels: impl IntoScalarKernelRefs,
104 ) -> Result<&SedonaScalarUDF> {
105 if let Some(function) = self.scalar_udf_mut(name) {
106 function.add_kernels(kernels);
107 } else {
108 let function = SedonaScalarUDF::from_impl(name, kernels);
109 self.insert_scalar_udf(function);
110 }
111
112 Ok(self.scalar_udf(name).unwrap())
113 }
114
115 pub fn add_aggregate_udf_kernel(
120 &mut self,
121 name: &str,
122 kernel: impl IntoSedonaAccumulatorRefs,
123 ) -> Result<&SedonaAggregateUDF> {
124 if let Some(function) = self.aggregate_udf_mut(name) {
125 function.add_kernel(kernel);
126 } else {
127 let function = SedonaAggregateUDF::from_impl(name, kernel);
128 self.insert_aggregate_udf(function);
129 }
130
131 Ok(self.aggregate_udf(name).unwrap())
132 }
133}
134
135impl Default for FunctionSet {
136 fn default() -> Self {
137 Self::new()
138 }
139}
140
#[cfg(test)]
mod tests {
    use std::{collections::HashSet, sync::Arc};

    use arrow_schema::{DataType, FieldRef};
    use datafusion_common::{not_impl_err, scalar::ScalarValue};

    use datafusion_expr::{Accumulator, ColumnarValue, Volatility};
    use sedona_schema::{datatypes::SedonaType, matchers::ArgMatcher};

    use crate::{
        aggregate_udf::{SedonaAccumulator, SedonaAccumulatorRef},
        scalar_udf::SimpleSedonaScalarKernel,
    };

    use super::*;

    /// Exercises the scalar half of `FunctionSet`: lookup on an empty set,
    /// insertion, adding kernels to existing and missing functions, and
    /// merging a second set.
    #[test]
    fn function_set() {
        let mut functions = FunctionSet::new();
        assert_eq!(functions.scalar_udfs().count(), 0);
        assert!(functions.scalar_udf("simple_udf").is_none());
        assert!(functions.scalar_udf_mut("simple_udf").is_none());

        let kernel = SimpleSedonaScalarKernel::new_ref(
            ArgMatcher::new(
                vec![ArgMatcher::is_arrow(DataType::Boolean)],
                SedonaType::Arrow(DataType::Boolean),
            ),
            Arc::new(|_, _| Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)))),
        );

        let udf = SedonaScalarUDF::new("simple_udf", vec![kernel.clone()], Volatility::Immutable);

        functions.insert_scalar_udf(udf);
        assert_eq!(functions.scalar_udfs().count(), 1);
        assert!(functions.scalar_udf("simple_udf").is_some());
        assert!(functions.scalar_udf_mut("simple_udf").is_some());

        // Adding a kernel to an existing function returns that function.
        assert_eq!(
            functions
                .add_scalar_udf_impl("simple_udf", kernel.clone())
                .unwrap()
                .name(),
            "simple_udf"
        );

        // Adding a kernel under an unknown name creates the function.
        let inserted_udf = functions
            .add_scalar_udf_impl("function that does not yet exist", kernel.clone())
            .unwrap();
        assert_eq!(inserted_udf.name(), "function that does not yet exist");

        let kernel2 = SimpleSedonaScalarKernel::new_ref(
            ArgMatcher::new(
                vec![ArgMatcher::is_arrow(DataType::Utf8)],
                SedonaType::Arrow(DataType::Utf8),
            ),
            Arc::new(|_, _| Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)))),
        );

        let udf2 = SedonaScalarUDF::new("simple_udf2", vec![kernel2], Volatility::Immutable);
        let mut functions2 = FunctionSet::new();
        functions2.insert_scalar_udf(udf2);
        functions.merge(functions2);

        // Compare as sets because map iteration order is unspecified.
        let merged_names = functions
            .scalar_udfs()
            .map(|s| s.name())
            .collect::<HashSet<_>>();
        assert_eq!(
            merged_names,
            HashSet::from([
                "simple_udf",
                "simple_udf2",
                "function that does not yet exist"
            ])
        );
    }

    /// Accumulator stub whose methods all return a "not implemented" error;
    /// it only needs to exist so it can be registered as a kernel.
    #[derive(Debug, Clone)]
    struct TestAccumulator {}

    impl SedonaAccumulator for TestAccumulator {
        fn return_type(&self, _args: &[SedonaType]) -> Result<Option<SedonaType>> {
            not_impl_err!("")
        }

        fn accumulator(
            &self,
            _args: &[SedonaType],
            _output_type: &SedonaType,
        ) -> Result<Box<dyn Accumulator>> {
            not_impl_err!("")
        }

        fn state_fields(&self, _args: &[SedonaType]) -> Result<Vec<FieldRef>> {
            not_impl_err!("")
        }
    }

    /// Exercises the aggregate half of `FunctionSet`: lookup on an empty set,
    /// insertion, adding kernels to existing and missing functions, and
    /// merging a second set.
    #[test]
    fn function_set_with_aggregates() {
        let mut functions = FunctionSet::new();
        // A new set must be empty for both kinds of function; the aggregate
        // map is the one under test here.
        assert_eq!(functions.scalar_udfs().count(), 0);
        assert_eq!(functions.aggregate_udfs().count(), 0);
        assert!(functions.aggregate_udf("simple_udaf").is_none());
        assert!(functions.aggregate_udf_mut("simple_udaf").is_none());

        let udaf = SedonaAggregateUDF::new(
            "simple_udaf",
            Vec::<SedonaAccumulatorRef>::new(),
            Volatility::Immutable,
        );
        let kernel = TestAccumulator {};

        functions.insert_aggregate_udf(udaf);
        assert_eq!(functions.aggregate_udfs().count(), 1);
        assert!(functions.aggregate_udf("simple_udaf").is_some());
        assert!(functions.aggregate_udf_mut("simple_udaf").is_some());

        // Adding a kernel to an existing function returns that function.
        assert_eq!(
            functions
                .add_aggregate_udf_kernel("simple_udaf", kernel.clone())
                .unwrap()
                .name(),
            "simple_udaf"
        );

        // Adding a kernel under an unknown name creates the function.
        let added_func = functions
            .add_aggregate_udf_kernel("function that does not exist yet", kernel.clone())
            .unwrap();
        assert_eq!(added_func.name(), "function that does not exist yet");

        let udaf2 = SedonaAggregateUDF::new(
            "simple_udaf2",
            vec![Arc::new(kernel.clone())],
            Volatility::Immutable,
        );
        let mut functions2 = FunctionSet::new();
        functions2.insert_aggregate_udf(udaf2);
        functions.merge(functions2);

        // Compare as sets because map iteration order is unspecified.
        let merged_names = functions
            .aggregate_udfs()
            .map(|s| s.name())
            .collect::<HashSet<_>>();
        assert_eq!(
            merged_names,
            HashSet::from([
                "simple_udaf",
                "simple_udaf2",
                "function that does not exist yet"
            ])
        );
    }
}