Skip to main content

datafusion_expr/
registry.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! FunctionRegistry trait
19
20use crate::expr_rewriter::FunctionRewrite;
21use crate::higher_order_function::HigherOrderUDF;
22use crate::planner::ExprPlanner;
23use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF};
24use arrow::datatypes::Field;
25use arrow_schema::DataType;
26use arrow_schema::extension::{
27    Bool8, ExtensionType, FixedShapeTensor, Json, Opaque, TimestampWithOffset, Uuid,
28    VariableShapeTensor,
29};
30use datafusion_common::types::{
31    DFBool8, DFExtensionTypeRef, DFFixedShapeTensor, DFJson, DFOpaque,
32    DFTimestampWithOffset, DFUuid, DFVariableShapeTensor,
33};
34use datafusion_common::{HashMap, Result, not_impl_err, plan_datafusion_err};
35use std::collections::HashSet;
36use std::fmt::{Debug, Formatter};
37use std::sync::{Arc, RwLock};
38
39/// A registry knows how to build logical expressions out of user-defined function' names
40pub trait FunctionRegistry {
41    /// Returns names of all available scalar user defined functions.
42    fn udfs(&self) -> HashSet<String>;
43
44    /// Returns names of all available higher order user defined functions.
45    fn higher_order_function_names(&self) -> HashSet<String>;
46
47    /// Returns names of all available aggregate user defined functions.
48    fn udafs(&self) -> HashSet<String>;
49
50    /// Returns names of all available window user defined functions.
51    fn udwfs(&self) -> HashSet<String>;
52
53    /// Returns a reference to the user defined scalar function (udf) named
54    /// `name`.
55    fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>>;
56
57    /// Returns a reference to the user defined higher order function named
58    /// `name`.
59    fn higher_order_function(&self, name: &str) -> Result<Arc<HigherOrderUDF>>;
60
61    /// Returns a reference to the user defined aggregate function (udaf) named
62    /// `name`.
63    fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>>;
64
65    /// Returns a reference to the user defined window function (udwf) named
66    /// `name`.
67    fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>>;
68
69    /// Registers a new [`ScalarUDF`], returning any previously registered
70    /// implementation.
71    ///
72    /// Returns an error (the default) if the function can not be registered,
73    /// for example if the registry is read only.
74    fn register_udf(&mut self, _udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
75        not_impl_err!("Registering ScalarUDF")
76    }
77    /// Registers a new [`HigherOrderUDF`], returning any previously registered
78    /// implementation.
79    ///
80    /// Returns an error (the default) if the function can not be registered,
81    /// for example if the registry is read only.
82    fn register_higher_order_function(
83        &mut self,
84        _function: Arc<HigherOrderUDF>,
85    ) -> Result<Option<Arc<HigherOrderUDF>>> {
86        not_impl_err!("Registering HigherOrderUDF")
87    }
88    /// Registers a new [`AggregateUDF`], returning any previously registered
89    /// implementation.
90    ///
91    /// Returns an error (the default) if the function can not be registered,
92    /// for example if the registry is read only.
93    fn register_udaf(
94        &mut self,
95        _udaf: Arc<AggregateUDF>,
96    ) -> Result<Option<Arc<AggregateUDF>>> {
97        not_impl_err!("Registering AggregateUDF")
98    }
99    /// Registers a new [`WindowUDF`], returning any previously registered
100    /// implementation.
101    ///
102    /// Returns an error (the default) if the function can not be registered,
103    /// for example if the registry is read only.
104    fn register_udwf(&mut self, _udaf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
105        not_impl_err!("Registering WindowUDF")
106    }
107
108    /// Deregisters a [`ScalarUDF`], returning the implementation that was
109    /// deregistered.
110    ///
111    /// Returns an error (the default) if the function can not be deregistered,
112    /// for example if the registry is read only.
113    fn deregister_udf(&mut self, _name: &str) -> Result<Option<Arc<ScalarUDF>>> {
114        not_impl_err!("Deregistering ScalarUDF")
115    }
116
117    /// Deregisters a [`HigherOrderUDF`], returning the implementation that was
118    /// deregistered.
119    ///
120    /// Returns an error (the default) if the function can not be deregistered,
121    /// for example if the registry is read only.
122    fn deregister_higher_order_function(
123        &mut self,
124        _name: &str,
125    ) -> Result<Option<Arc<HigherOrderUDF>>> {
126        not_impl_err!("Deregistering HigherOrderUDF")
127    }
128
129    /// Deregisters a [`AggregateUDF`], returning the implementation that was
130    /// deregistered.
131    ///
132    /// Returns an error (the default) if the function can not be deregistered,
133    /// for example if the registry is read only.
134    fn deregister_udaf(&mut self, _name: &str) -> Result<Option<Arc<AggregateUDF>>> {
135        not_impl_err!("Deregistering AggregateUDF")
136    }
137
138    /// Deregisters a [`WindowUDF`], returning the implementation that was
139    /// deregistered.
140    ///
141    /// Returns an error (the default) if the function can not be deregistered,
142    /// for example if the registry is read only.
143    fn deregister_udwf(&mut self, _name: &str) -> Result<Option<Arc<WindowUDF>>> {
144        not_impl_err!("Deregistering WindowUDF")
145    }
146
147    /// Registers a new [`FunctionRewrite`] with the registry.
148    ///
149    /// `FunctionRewrite` rules are used to rewrite certain / operators in the
150    /// logical plan to function calls.  For example `a || b` might be written to
151    /// `array_concat(a, b)`.
152    ///
153    /// This allows the behavior of operators to be customized by the user.
154    fn register_function_rewrite(
155        &mut self,
156        _rewrite: Arc<dyn FunctionRewrite + Send + Sync>,
157    ) -> Result<()> {
158        not_impl_err!("Registering FunctionRewrite")
159    }
160
161    /// Set of all registered [`ExprPlanner`]s
162    fn expr_planners(&self) -> Vec<Arc<dyn ExprPlanner>>;
163
164    /// Registers a new [`ExprPlanner`] with the registry.
165    fn register_expr_planner(
166        &mut self,
167        _expr_planner: Arc<dyn ExprPlanner>,
168    ) -> Result<()> {
169        not_impl_err!("Registering ExprPlanner")
170    }
171}
172
173/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode].
174pub trait SerializerRegistry: Debug + Send + Sync {
175    /// Serialize this node to a byte array. This serialization should not include
176    /// input plans.
177    fn serialize_logical_plan(
178        &self,
179        node: &dyn UserDefinedLogicalNode,
180    ) -> Result<Vec<u8>>;
181
182    /// Deserialize user defined logical plan node ([UserDefinedLogicalNode]) from
183    /// bytes.
184    fn deserialize_logical_plan(
185        &self,
186        name: &str,
187        bytes: &[u8],
188    ) -> Result<Arc<dyn UserDefinedLogicalNode>>;
189}
190
191/// A  [`FunctionRegistry`] that uses in memory [`HashMap`]s
192#[derive(Default, Debug)]
193pub struct MemoryFunctionRegistry {
194    /// Scalar Functions
195    udfs: HashMap<String, Arc<ScalarUDF>>,
196    /// Aggregate Functions
197    udafs: HashMap<String, Arc<AggregateUDF>>,
198    /// Window Functions
199    udwfs: HashMap<String, Arc<WindowUDF>>,
200    /// Higher Order Functions
201    higher_order_functions: HashMap<String, Arc<HigherOrderUDF>>,
202}
203
204impl MemoryFunctionRegistry {
205    pub fn new() -> Self {
206        Self::default()
207    }
208}
209
210impl FunctionRegistry for MemoryFunctionRegistry {
211    fn udfs(&self) -> HashSet<String> {
212        self.udfs.keys().cloned().collect()
213    }
214
215    fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
216        self.udfs
217            .get(name)
218            .cloned()
219            .ok_or_else(|| plan_datafusion_err!("Function {name} not found"))
220    }
221
222    fn higher_order_function(&self, name: &str) -> Result<Arc<HigherOrderUDF>> {
223        self.higher_order_functions
224            .get(name)
225            .cloned()
226            .ok_or_else(|| plan_datafusion_err!("Higher Order Function {name} not found"))
227    }
228
229    fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
230        self.udafs
231            .get(name)
232            .cloned()
233            .ok_or_else(|| plan_datafusion_err!("Aggregate Function {name} not found"))
234    }
235
236    fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
237        self.udwfs
238            .get(name)
239            .cloned()
240            .ok_or_else(|| plan_datafusion_err!("Window Function {name} not found"))
241    }
242
243    fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
244        Ok(self.udfs.insert(udf.name().to_string(), udf))
245    }
246    fn register_higher_order_function(
247        &mut self,
248        function: Arc<HigherOrderUDF>,
249    ) -> Result<Option<Arc<HigherOrderUDF>>> {
250        Ok(self
251            .higher_order_functions
252            .insert(function.name().into(), function))
253    }
254    fn register_udaf(
255        &mut self,
256        udaf: Arc<AggregateUDF>,
257    ) -> Result<Option<Arc<AggregateUDF>>> {
258        Ok(self.udafs.insert(udaf.name().into(), udaf))
259    }
260    fn register_udwf(&mut self, udaf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
261        Ok(self.udwfs.insert(udaf.name().into(), udaf))
262    }
263
264    fn expr_planners(&self) -> Vec<Arc<dyn ExprPlanner>> {
265        vec![]
266    }
267
268    fn higher_order_function_names(&self) -> HashSet<String> {
269        self.higher_order_functions.keys().cloned().collect()
270    }
271
272    fn udafs(&self) -> HashSet<String> {
273        self.udafs.keys().cloned().collect()
274    }
275
276    fn udwfs(&self) -> HashSet<String> {
277        self.udwfs.keys().cloned().collect()
278    }
279}
280
281/// A cheaply cloneable pointer to an [ExtensionTypeRegistry].
282pub type ExtensionTypeRegistryRef = Arc<dyn ExtensionTypeRegistry>;
283
284/// Manages [`ExtensionTypeRegistration`]s, which allow users to register custom behavior for
285/// extension types.
286///
287/// Each registration is connected to the extension type name, which can also be looked up to get
288/// the registration.
289pub trait ExtensionTypeRegistry: Debug + Send + Sync {
290    /// Returns a reference to registration of an extension type named `name`.
291    ///
292    /// Returns an error if there is no extension type with that name.
293    fn extension_type_registration(
294        &self,
295        name: &str,
296    ) -> Result<ExtensionTypeRegistrationRef>;
297
298    /// Creates a [`DFExtensionTypeRef`] from the type information in the `field`.
299    ///
300    /// The result `Ok(None)` indicates that there is no extension type metadata. Returns an error
301    /// if the extension type in the metadata is not found.
302    fn create_extension_type_for_field(
303        &self,
304        field: &Field,
305    ) -> Result<Option<DFExtensionTypeRef>> {
306        let Some(extension_type_name) = field.extension_type_name() else {
307            return Ok(None);
308        };
309
310        let registration = self.extension_type_registration(extension_type_name)?;
311        registration
312            .create_df_extension_type(field.data_type(), field.extension_type_metadata())
313            .map(Some)
314    }
315
316    /// Returns all registered [ExtensionTypeRegistration].
317    fn extension_type_registrations(&self) -> Vec<ExtensionTypeRegistrationRef>;
318
319    /// Registers a new [ExtensionTypeRegistrationRef], returning any previously registered
320    /// implementation.
321    ///
322    /// Returns an error if the type cannot be registered, for example, if the registry is
323    /// read-only.
324    fn add_extension_type_registration(
325        &self,
326        extension_type: ExtensionTypeRegistrationRef,
327    ) -> Result<Option<ExtensionTypeRegistrationRef>>;
328
329    /// Extends the registry with the provided extension types.
330    ///
331    /// Returns an error if the type cannot be registered, for example, if the registry is
332    /// read-only.
333    fn extend(&self, extension_types: &[ExtensionTypeRegistrationRef]) -> Result<()> {
334        for extension_type in extension_types.iter().cloned() {
335            self.add_extension_type_registration(extension_type)?;
336        }
337        Ok(())
338    }
339
340    /// Deregisters an extension type registration with the name `name`, returning the
341    /// implementation that was deregistered.
342    ///
343    /// Returns an error if the type cannot be deregistered, for example, if the registry is
344    /// read-only.
345    fn remove_extension_type_registration(
346        &self,
347        name: &str,
348    ) -> Result<Option<ExtensionTypeRegistrationRef>>;
349}
350
351/// A factory that creates instances of extension types from a storage [`DataType`] and the
352/// metadata.
353pub type ExtensionTypeFactory =
354    dyn Fn(&DataType, Option<&str>) -> Result<DFExtensionTypeRef> + Send + Sync;
355
356/// A cheaply cloneable pointer to an [ExtensionTypeRegistration].
357pub type ExtensionTypeRegistrationRef = Arc<ExtensionTypeRegistration>;
358
359/// The registration of an extension type. Implementations of this trait are responsible for
360/// *creating* instances of [`DFExtensionType`] that represent the entire semantics of an extension
361/// type.
362///
363/// # Why do we need a Registration?
364///
365/// A good question is why this trait is even necessary. Why not directly register the
366/// [`DFExtensionType`] in a registry?
367///
368/// While this works for extension types requiring no additional metadata (e.g., `arrow.uuid`), it
369/// does not work for more complex extension types with metadata. For example, consider an extension
370/// type `custom.shortened(n)` that aims to short the pretty-printing string to `n` characters.
371/// Here, `n` is a parameter of the extension type and should be a field in the struct that
372/// implements the [`DFExtensionType`]. The job of the registration is to read the metadata from the
373/// field and create the corresponding [`DFExtensionType`] instance with the correct `n` set.
374///
375/// [`DFExtensionType`]: datafusion_common::types::DFExtensionType
376pub struct ExtensionTypeRegistration {
377    /// The name of the extension type.
378    name: String,
379    /// A function that creates an instance of [`DFExtensionTypeRef`] from the storage type and the
380    /// metadata.
381    factory: Box<ExtensionTypeFactory>,
382}
383
384impl ExtensionTypeRegistration {
385    /// Creates a new registration for an extension type. The factory is required to validate that
386    /// the storage [`DataType`] is compatible with the extension type.
387    pub fn new_arc(
388        name: impl Into<String>,
389        factory: impl Fn(&DataType, Option<&str>) -> Result<DFExtensionTypeRef>
390        + Send
391        + Sync
392        + 'static,
393    ) -> ExtensionTypeRegistrationRef {
394        Arc::new(Self {
395            name: name.into(),
396            factory: Box::new(factory),
397        })
398    }
399}
400
401impl ExtensionTypeRegistration {
402    /// The name of the extension type.
403    ///
404    /// This name will be used to find the correct [ExtensionTypeRegistration] when an extension
405    /// type is encountered.
406    pub fn type_name(&self) -> &str {
407        &self.name
408    }
409
410    /// Creates an extension type instance from the optional metadata. The name of the extension
411    /// type is not a parameter as it's already defined by the registration itself.
412    pub fn create_df_extension_type(
413        &self,
414        storage_type: &DataType,
415        metadata: Option<&str>,
416    ) -> Result<DFExtensionTypeRef> {
417        self.factory.as_ref()(storage_type, metadata)
418    }
419}
420
421impl Debug for ExtensionTypeRegistration {
422    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
423        f.debug_struct("DefaultExtensionTypeRegistration")
424            .field("type_name", &self.name)
425            .finish()
426    }
427}
428
429/// An [`ExtensionTypeRegistry`] that uses in memory [`HashMap`]s.
430#[derive(Clone, Debug)]
431pub struct MemoryExtensionTypeRegistry {
432    /// Holds a mapping between the name of an extension type and its logical type.
433    extension_types: Arc<RwLock<HashMap<String, ExtensionTypeRegistrationRef>>>,
434}
435
436impl Default for MemoryExtensionTypeRegistry {
437    fn default() -> Self {
438        Self::new_empty()
439    }
440}
441
442impl MemoryExtensionTypeRegistry {
443    /// Creates an empty [MemoryExtensionTypeRegistry].
444    pub fn new_empty() -> Self {
445        Self {
446            extension_types: Arc::new(RwLock::new(HashMap::new())),
447        }
448    }
449
450    /// Pre-registers the [canonical extension types](https://arrow.apache.org/docs/format/CanonicalExtensions.html)
451    /// in the extension type registry.
452    pub fn new_with_canonical_extension_types() -> Self {
453        let mapping = [
454            ExtensionTypeRegistration::new_arc(
455                FixedShapeTensor::NAME,
456                |storage_type, metadata| {
457                    Ok(Arc::new(DFFixedShapeTensor::try_new(
458                        storage_type,
459                        FixedShapeTensor::deserialize_metadata(metadata)?,
460                    )?))
461                },
462            ),
463            ExtensionTypeRegistration::new_arc(
464                VariableShapeTensor::NAME,
465                |storage_type, metadata| {
466                    Ok(Arc::new(DFVariableShapeTensor::try_new(
467                        storage_type,
468                        VariableShapeTensor::deserialize_metadata(metadata)?,
469                    )?))
470                },
471            ),
472            ExtensionTypeRegistration::new_arc(Json::NAME, |storage_type, metadata| {
473                Ok(Arc::new(DFJson::try_new(
474                    storage_type,
475                    Json::deserialize_metadata(metadata)?,
476                )?))
477            }),
478            ExtensionTypeRegistration::new_arc(Uuid::NAME, |storage_type, metadata| {
479                Ok(Arc::new(DFUuid::try_new(
480                    storage_type,
481                    Uuid::deserialize_metadata(metadata)?,
482                )?))
483            }),
484            ExtensionTypeRegistration::new_arc(Opaque::NAME, |storage_type, metadata| {
485                Ok(Arc::new(DFOpaque::try_new(
486                    storage_type,
487                    Opaque::deserialize_metadata(metadata)?,
488                )?))
489            }),
490            ExtensionTypeRegistration::new_arc(Bool8::NAME, |storage_type, metadata| {
491                Ok(Arc::new(DFBool8::try_new(
492                    storage_type,
493                    Bool8::deserialize_metadata(metadata)?,
494                )?))
495            }),
496            ExtensionTypeRegistration::new_arc(
497                TimestampWithOffset::NAME,
498                |storage_type, metadata| {
499                    Ok(Arc::new(DFTimestampWithOffset::try_new(
500                        storage_type,
501                        TimestampWithOffset::deserialize_metadata(metadata)?,
502                    )?))
503                },
504            ),
505        ];
506
507        let mut extension_types = HashMap::new();
508        for registration in mapping.into_iter() {
509            extension_types.insert(registration.type_name().to_owned(), registration);
510        }
511
512        Self {
513            extension_types: Arc::new(RwLock::new(HashMap::from(extension_types))),
514        }
515    }
516
517    /// Creates a new [MemoryExtensionTypeRegistry] with the provided `types`.
518    ///
519    /// # Errors
520    ///
521    /// Returns an error if one of the `types` is a native type.
522    pub fn new_with_types(
523        types: impl IntoIterator<Item = ExtensionTypeRegistrationRef>,
524    ) -> Result<Self> {
525        let extension_types = types
526            .into_iter()
527            .map(|t| (t.type_name().to_owned(), t))
528            .collect::<HashMap<_, _>>();
529        Ok(Self {
530            extension_types: Arc::new(RwLock::new(extension_types)),
531        })
532    }
533
534    /// Returns a list of all registered types.
535    pub fn all_extension_types(&self) -> Vec<ExtensionTypeRegistrationRef> {
536        self.extension_types
537            .read()
538            .expect("Extension type registry lock poisoned")
539            .values()
540            .cloned()
541            .collect()
542    }
543}
544
545impl ExtensionTypeRegistry for MemoryExtensionTypeRegistry {
546    fn extension_type_registration(
547        &self,
548        name: &str,
549    ) -> Result<ExtensionTypeRegistrationRef> {
550        self.extension_types
551            .write()
552            .expect("Extension type registry lock poisoned")
553            .get(name)
554            .ok_or_else(|| plan_datafusion_err!("Logical type not found."))
555            .cloned()
556    }
557
558    fn extension_type_registrations(&self) -> Vec<ExtensionTypeRegistrationRef> {
559        self.extension_types
560            .read()
561            .expect("Extension type registry lock poisoned")
562            .values()
563            .cloned()
564            .collect()
565    }
566
567    fn add_extension_type_registration(
568        &self,
569        extension_type: ExtensionTypeRegistrationRef,
570    ) -> Result<Option<ExtensionTypeRegistrationRef>> {
571        Ok(self
572            .extension_types
573            .write()
574            .expect("Extension type registry lock poisoned")
575            .insert(extension_type.type_name().to_owned(), extension_type))
576    }
577
578    fn remove_extension_type_registration(
579        &self,
580        name: &str,
581    ) -> Result<Option<ExtensionTypeRegistrationRef>> {
582        Ok(self
583            .extension_types
584            .write()
585            .expect("Extension type registry lock poisoned")
586            .remove(name))
587    }
588}
589
590impl From<HashMap<String, ExtensionTypeRegistrationRef>> for MemoryExtensionTypeRegistry {
591    fn from(value: HashMap<String, ExtensionTypeRegistrationRef>) -> Self {
592        Self {
593            extension_types: Arc::new(RwLock::new(value)),
594        }
595    }
596}