1use crate::expr_rewriter::FunctionRewrite;
21use crate::higher_order_function::HigherOrderUDF;
22use crate::planner::ExprPlanner;
23use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF};
24use arrow::datatypes::Field;
25use arrow_schema::DataType;
26use arrow_schema::extension::{
27 Bool8, ExtensionType, FixedShapeTensor, Json, Opaque, TimestampWithOffset, Uuid,
28 VariableShapeTensor,
29};
30use datafusion_common::types::{
31 DFBool8, DFExtensionTypeRef, DFFixedShapeTensor, DFJson, DFOpaque,
32 DFTimestampWithOffset, DFUuid, DFVariableShapeTensor,
33};
34use datafusion_common::{HashMap, Result, not_impl_err, plan_datafusion_err};
35use std::collections::HashSet;
36use std::fmt::{Debug, Formatter};
37use std::sync::{Arc, RwLock};
38
39pub trait FunctionRegistry {
41 fn udfs(&self) -> HashSet<String>;
43
44 fn higher_order_function_names(&self) -> HashSet<String>;
46
47 fn udafs(&self) -> HashSet<String>;
49
50 fn udwfs(&self) -> HashSet<String>;
52
53 fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>>;
56
57 fn higher_order_function(&self, name: &str) -> Result<Arc<HigherOrderUDF>>;
60
61 fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>>;
64
65 fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>>;
68
69 fn register_udf(&mut self, _udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
75 not_impl_err!("Registering ScalarUDF")
76 }
77 fn register_higher_order_function(
83 &mut self,
84 _function: Arc<HigherOrderUDF>,
85 ) -> Result<Option<Arc<HigherOrderUDF>>> {
86 not_impl_err!("Registering HigherOrderUDF")
87 }
88 fn register_udaf(
94 &mut self,
95 _udaf: Arc<AggregateUDF>,
96 ) -> Result<Option<Arc<AggregateUDF>>> {
97 not_impl_err!("Registering AggregateUDF")
98 }
99 fn register_udwf(&mut self, _udaf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
105 not_impl_err!("Registering WindowUDF")
106 }
107
108 fn deregister_udf(&mut self, _name: &str) -> Result<Option<Arc<ScalarUDF>>> {
114 not_impl_err!("Deregistering ScalarUDF")
115 }
116
117 fn deregister_higher_order_function(
123 &mut self,
124 _name: &str,
125 ) -> Result<Option<Arc<HigherOrderUDF>>> {
126 not_impl_err!("Deregistering HigherOrderUDF")
127 }
128
129 fn deregister_udaf(&mut self, _name: &str) -> Result<Option<Arc<AggregateUDF>>> {
135 not_impl_err!("Deregistering AggregateUDF")
136 }
137
138 fn deregister_udwf(&mut self, _name: &str) -> Result<Option<Arc<WindowUDF>>> {
144 not_impl_err!("Deregistering WindowUDF")
145 }
146
147 fn register_function_rewrite(
155 &mut self,
156 _rewrite: Arc<dyn FunctionRewrite + Send + Sync>,
157 ) -> Result<()> {
158 not_impl_err!("Registering FunctionRewrite")
159 }
160
161 fn expr_planners(&self) -> Vec<Arc<dyn ExprPlanner>>;
163
164 fn register_expr_planner(
166 &mut self,
167 _expr_planner: Arc<dyn ExprPlanner>,
168 ) -> Result<()> {
169 not_impl_err!("Registering ExprPlanner")
170 }
171}
172
173pub trait SerializerRegistry: Debug + Send + Sync {
175 fn serialize_logical_plan(
178 &self,
179 node: &dyn UserDefinedLogicalNode,
180 ) -> Result<Vec<u8>>;
181
182 fn deserialize_logical_plan(
185 &self,
186 name: &str,
187 bytes: &[u8],
188 ) -> Result<Arc<dyn UserDefinedLogicalNode>>;
189}
190
191#[derive(Default, Debug)]
193pub struct MemoryFunctionRegistry {
194 udfs: HashMap<String, Arc<ScalarUDF>>,
196 udafs: HashMap<String, Arc<AggregateUDF>>,
198 udwfs: HashMap<String, Arc<WindowUDF>>,
200 higher_order_functions: HashMap<String, Arc<HigherOrderUDF>>,
202}
203
204impl MemoryFunctionRegistry {
205 pub fn new() -> Self {
206 Self::default()
207 }
208}
209
210impl FunctionRegistry for MemoryFunctionRegistry {
211 fn udfs(&self) -> HashSet<String> {
212 self.udfs.keys().cloned().collect()
213 }
214
215 fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
216 self.udfs
217 .get(name)
218 .cloned()
219 .ok_or_else(|| plan_datafusion_err!("Function {name} not found"))
220 }
221
222 fn higher_order_function(&self, name: &str) -> Result<Arc<HigherOrderUDF>> {
223 self.higher_order_functions
224 .get(name)
225 .cloned()
226 .ok_or_else(|| plan_datafusion_err!("Higher Order Function {name} not found"))
227 }
228
229 fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
230 self.udafs
231 .get(name)
232 .cloned()
233 .ok_or_else(|| plan_datafusion_err!("Aggregate Function {name} not found"))
234 }
235
236 fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
237 self.udwfs
238 .get(name)
239 .cloned()
240 .ok_or_else(|| plan_datafusion_err!("Window Function {name} not found"))
241 }
242
243 fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
244 Ok(self.udfs.insert(udf.name().to_string(), udf))
245 }
246 fn register_higher_order_function(
247 &mut self,
248 function: Arc<HigherOrderUDF>,
249 ) -> Result<Option<Arc<HigherOrderUDF>>> {
250 Ok(self
251 .higher_order_functions
252 .insert(function.name().into(), function))
253 }
254 fn register_udaf(
255 &mut self,
256 udaf: Arc<AggregateUDF>,
257 ) -> Result<Option<Arc<AggregateUDF>>> {
258 Ok(self.udafs.insert(udaf.name().into(), udaf))
259 }
260 fn register_udwf(&mut self, udaf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
261 Ok(self.udwfs.insert(udaf.name().into(), udaf))
262 }
263
264 fn expr_planners(&self) -> Vec<Arc<dyn ExprPlanner>> {
265 vec![]
266 }
267
268 fn higher_order_function_names(&self) -> HashSet<String> {
269 self.higher_order_functions.keys().cloned().collect()
270 }
271
272 fn udafs(&self) -> HashSet<String> {
273 self.udafs.keys().cloned().collect()
274 }
275
276 fn udwfs(&self) -> HashSet<String> {
277 self.udwfs.keys().cloned().collect()
278 }
279}
280
281pub type ExtensionTypeRegistryRef = Arc<dyn ExtensionTypeRegistry>;
283
284pub trait ExtensionTypeRegistry: Debug + Send + Sync {
290 fn extension_type_registration(
294 &self,
295 name: &str,
296 ) -> Result<ExtensionTypeRegistrationRef>;
297
298 fn create_extension_type_for_field(
303 &self,
304 field: &Field,
305 ) -> Result<Option<DFExtensionTypeRef>> {
306 let Some(extension_type_name) = field.extension_type_name() else {
307 return Ok(None);
308 };
309
310 let registration = self.extension_type_registration(extension_type_name)?;
311 registration
312 .create_df_extension_type(field.data_type(), field.extension_type_metadata())
313 .map(Some)
314 }
315
316 fn extension_type_registrations(&self) -> Vec<ExtensionTypeRegistrationRef>;
318
319 fn add_extension_type_registration(
325 &self,
326 extension_type: ExtensionTypeRegistrationRef,
327 ) -> Result<Option<ExtensionTypeRegistrationRef>>;
328
329 fn extend(&self, extension_types: &[ExtensionTypeRegistrationRef]) -> Result<()> {
334 for extension_type in extension_types.iter().cloned() {
335 self.add_extension_type_registration(extension_type)?;
336 }
337 Ok(())
338 }
339
340 fn remove_extension_type_registration(
346 &self,
347 name: &str,
348 ) -> Result<Option<ExtensionTypeRegistrationRef>>;
349}
350
351pub type ExtensionTypeFactory =
354 dyn Fn(&DataType, Option<&str>) -> Result<DFExtensionTypeRef> + Send + Sync;
355
356pub type ExtensionTypeRegistrationRef = Arc<ExtensionTypeRegistration>;
358
359pub struct ExtensionTypeRegistration {
377 name: String,
379 factory: Box<ExtensionTypeFactory>,
382}
383
384impl ExtensionTypeRegistration {
385 pub fn new_arc(
388 name: impl Into<String>,
389 factory: impl Fn(&DataType, Option<&str>) -> Result<DFExtensionTypeRef>
390 + Send
391 + Sync
392 + 'static,
393 ) -> ExtensionTypeRegistrationRef {
394 Arc::new(Self {
395 name: name.into(),
396 factory: Box::new(factory),
397 })
398 }
399}
400
401impl ExtensionTypeRegistration {
402 pub fn type_name(&self) -> &str {
407 &self.name
408 }
409
410 pub fn create_df_extension_type(
413 &self,
414 storage_type: &DataType,
415 metadata: Option<&str>,
416 ) -> Result<DFExtensionTypeRef> {
417 self.factory.as_ref()(storage_type, metadata)
418 }
419}
420
421impl Debug for ExtensionTypeRegistration {
422 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
423 f.debug_struct("DefaultExtensionTypeRegistration")
424 .field("type_name", &self.name)
425 .finish()
426 }
427}
428
429#[derive(Clone, Debug)]
431pub struct MemoryExtensionTypeRegistry {
432 extension_types: Arc<RwLock<HashMap<String, ExtensionTypeRegistrationRef>>>,
434}
435
436impl Default for MemoryExtensionTypeRegistry {
437 fn default() -> Self {
438 Self::new_empty()
439 }
440}
441
442impl MemoryExtensionTypeRegistry {
443 pub fn new_empty() -> Self {
445 Self {
446 extension_types: Arc::new(RwLock::new(HashMap::new())),
447 }
448 }
449
450 pub fn new_with_canonical_extension_types() -> Self {
453 let mapping = [
454 ExtensionTypeRegistration::new_arc(
455 FixedShapeTensor::NAME,
456 |storage_type, metadata| {
457 Ok(Arc::new(DFFixedShapeTensor::try_new(
458 storage_type,
459 FixedShapeTensor::deserialize_metadata(metadata)?,
460 )?))
461 },
462 ),
463 ExtensionTypeRegistration::new_arc(
464 VariableShapeTensor::NAME,
465 |storage_type, metadata| {
466 Ok(Arc::new(DFVariableShapeTensor::try_new(
467 storage_type,
468 VariableShapeTensor::deserialize_metadata(metadata)?,
469 )?))
470 },
471 ),
472 ExtensionTypeRegistration::new_arc(Json::NAME, |storage_type, metadata| {
473 Ok(Arc::new(DFJson::try_new(
474 storage_type,
475 Json::deserialize_metadata(metadata)?,
476 )?))
477 }),
478 ExtensionTypeRegistration::new_arc(Uuid::NAME, |storage_type, metadata| {
479 Ok(Arc::new(DFUuid::try_new(
480 storage_type,
481 Uuid::deserialize_metadata(metadata)?,
482 )?))
483 }),
484 ExtensionTypeRegistration::new_arc(Opaque::NAME, |storage_type, metadata| {
485 Ok(Arc::new(DFOpaque::try_new(
486 storage_type,
487 Opaque::deserialize_metadata(metadata)?,
488 )?))
489 }),
490 ExtensionTypeRegistration::new_arc(Bool8::NAME, |storage_type, metadata| {
491 Ok(Arc::new(DFBool8::try_new(
492 storage_type,
493 Bool8::deserialize_metadata(metadata)?,
494 )?))
495 }),
496 ExtensionTypeRegistration::new_arc(
497 TimestampWithOffset::NAME,
498 |storage_type, metadata| {
499 Ok(Arc::new(DFTimestampWithOffset::try_new(
500 storage_type,
501 TimestampWithOffset::deserialize_metadata(metadata)?,
502 )?))
503 },
504 ),
505 ];
506
507 let mut extension_types = HashMap::new();
508 for registration in mapping.into_iter() {
509 extension_types.insert(registration.type_name().to_owned(), registration);
510 }
511
512 Self {
513 extension_types: Arc::new(RwLock::new(HashMap::from(extension_types))),
514 }
515 }
516
517 pub fn new_with_types(
523 types: impl IntoIterator<Item = ExtensionTypeRegistrationRef>,
524 ) -> Result<Self> {
525 let extension_types = types
526 .into_iter()
527 .map(|t| (t.type_name().to_owned(), t))
528 .collect::<HashMap<_, _>>();
529 Ok(Self {
530 extension_types: Arc::new(RwLock::new(extension_types)),
531 })
532 }
533
534 pub fn all_extension_types(&self) -> Vec<ExtensionTypeRegistrationRef> {
536 self.extension_types
537 .read()
538 .expect("Extension type registry lock poisoned")
539 .values()
540 .cloned()
541 .collect()
542 }
543}
544
545impl ExtensionTypeRegistry for MemoryExtensionTypeRegistry {
546 fn extension_type_registration(
547 &self,
548 name: &str,
549 ) -> Result<ExtensionTypeRegistrationRef> {
550 self.extension_types
551 .write()
552 .expect("Extension type registry lock poisoned")
553 .get(name)
554 .ok_or_else(|| plan_datafusion_err!("Logical type not found."))
555 .cloned()
556 }
557
558 fn extension_type_registrations(&self) -> Vec<ExtensionTypeRegistrationRef> {
559 self.extension_types
560 .read()
561 .expect("Extension type registry lock poisoned")
562 .values()
563 .cloned()
564 .collect()
565 }
566
567 fn add_extension_type_registration(
568 &self,
569 extension_type: ExtensionTypeRegistrationRef,
570 ) -> Result<Option<ExtensionTypeRegistrationRef>> {
571 Ok(self
572 .extension_types
573 .write()
574 .expect("Extension type registry lock poisoned")
575 .insert(extension_type.type_name().to_owned(), extension_type))
576 }
577
578 fn remove_extension_type_registration(
579 &self,
580 name: &str,
581 ) -> Result<Option<ExtensionTypeRegistrationRef>> {
582 Ok(self
583 .extension_types
584 .write()
585 .expect("Extension type registry lock poisoned")
586 .remove(name))
587 }
588}
589
590impl From<HashMap<String, ExtensionTypeRegistrationRef>> for MemoryExtensionTypeRegistry {
591 fn from(value: HashMap<String, ExtensionTypeRegistrationRef>) -> Self {
592 Self {
593 extension_types: Arc::new(RwLock::new(value)),
594 }
595 }
596}