use crate::expr_rewriter::FunctionRewrite;
use crate::higher_order_function::HigherOrderUDF;
use crate::planner::ExprPlanner;
use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF};
use arrow::datatypes::Field;
use arrow_schema::DataType;
use arrow_schema::extension::{
Bool8, ExtensionType, FixedShapeTensor, Json, Opaque, TimestampWithOffset, Uuid,
VariableShapeTensor,
};
use datafusion_common::types::{
DFBool8, DFExtensionTypeRef, DFFixedShapeTensor, DFJson, DFOpaque,
DFTimestampWithOffset, DFUuid, DFVariableShapeTensor,
};
use datafusion_common::{HashMap, Result, not_impl_err, plan_datafusion_err};
use std::collections::HashSet;
use std::fmt::{Debug, Formatter};
use std::sync::{Arc, RwLock};
pub trait FunctionRegistry {
fn udfs(&self) -> HashSet<String>;
fn higher_order_function_names(&self) -> HashSet<String>;
fn udafs(&self) -> HashSet<String>;
fn udwfs(&self) -> HashSet<String>;
fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>>;
fn higher_order_function(&self, name: &str) -> Result<Arc<HigherOrderUDF>>;
fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>>;
fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>>;
fn register_udf(&mut self, _udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
not_impl_err!("Registering ScalarUDF")
}
fn register_higher_order_function(
&mut self,
_function: Arc<HigherOrderUDF>,
) -> Result<Option<Arc<HigherOrderUDF>>> {
not_impl_err!("Registering HigherOrderUDF")
}
fn register_udaf(
&mut self,
_udaf: Arc<AggregateUDF>,
) -> Result<Option<Arc<AggregateUDF>>> {
not_impl_err!("Registering AggregateUDF")
}
fn register_udwf(&mut self, _udaf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
not_impl_err!("Registering WindowUDF")
}
fn deregister_udf(&mut self, _name: &str) -> Result<Option<Arc<ScalarUDF>>> {
not_impl_err!("Deregistering ScalarUDF")
}
fn deregister_higher_order_function(
&mut self,
_name: &str,
) -> Result<Option<Arc<HigherOrderUDF>>> {
not_impl_err!("Deregistering HigherOrderUDF")
}
fn deregister_udaf(&mut self, _name: &str) -> Result<Option<Arc<AggregateUDF>>> {
not_impl_err!("Deregistering AggregateUDF")
}
fn deregister_udwf(&mut self, _name: &str) -> Result<Option<Arc<WindowUDF>>> {
not_impl_err!("Deregistering WindowUDF")
}
fn register_function_rewrite(
&mut self,
_rewrite: Arc<dyn FunctionRewrite + Send + Sync>,
) -> Result<()> {
not_impl_err!("Registering FunctionRewrite")
}
fn expr_planners(&self) -> Vec<Arc<dyn ExprPlanner>>;
fn register_expr_planner(
&mut self,
_expr_planner: Arc<dyn ExprPlanner>,
) -> Result<()> {
not_impl_err!("Registering ExprPlanner")
}
}
/// Converts [`UserDefinedLogicalNode`]s to and from bytes.
///
/// The `Send + Sync` supertraits allow a registry to be shared across threads.
pub trait SerializerRegistry: Debug + Send + Sync {
    /// Serializes `node` into a byte buffer.
    fn serialize_logical_plan(
        &self,
        node: &dyn UserDefinedLogicalNode,
    ) -> Result<Vec<u8>>;

    /// Reconstructs the node identified by `name` from `bytes`.
    ///
    /// NOTE(review): presumably `bytes` is the output of
    /// `serialize_logical_plan` for a node with that name — confirm with
    /// implementors.
    fn deserialize_logical_plan(
        &self,
        name: &str,
        bytes: &[u8],
    ) -> Result<Arc<dyn UserDefinedLogicalNode>>;
}
/// A [`FunctionRegistry`] that keeps every function in in-memory hash maps
/// keyed by function name.
#[derive(Default, Debug)]
pub struct MemoryFunctionRegistry {
    /// Scalar functions, keyed by name.
    udfs: HashMap<String, Arc<ScalarUDF>>,
    /// Aggregate functions, keyed by name.
    udafs: HashMap<String, Arc<AggregateUDF>>,
    /// Window functions, keyed by name.
    udwfs: HashMap<String, Arc<WindowUDF>>,
    /// Higher order functions, keyed by name.
    higher_order_functions: HashMap<String, Arc<HigherOrderUDF>>,
}

impl MemoryFunctionRegistry {
    /// Creates a registry with no functions registered.
    ///
    /// Equivalent to [`Default::default`].
    pub fn new() -> Self {
        <Self as Default>::default()
    }
}
/// All lookups, registrations and removals operate directly on the four
/// name-keyed maps. A function is stored only under its `name()`.
impl FunctionRegistry for MemoryFunctionRegistry {
    fn udfs(&self) -> HashSet<String> {
        self.udfs.keys().cloned().collect()
    }

    fn higher_order_function_names(&self) -> HashSet<String> {
        self.higher_order_functions.keys().cloned().collect()
    }

    fn udafs(&self) -> HashSet<String> {
        self.udafs.keys().cloned().collect()
    }

    fn udwfs(&self) -> HashSet<String> {
        self.udwfs.keys().cloned().collect()
    }

    fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
        self.udfs
            .get(name)
            .cloned()
            .ok_or_else(|| plan_datafusion_err!("Function {name} not found"))
    }

    fn higher_order_function(&self, name: &str) -> Result<Arc<HigherOrderUDF>> {
        self.higher_order_functions
            .get(name)
            .cloned()
            .ok_or_else(|| plan_datafusion_err!("Higher Order Function {name} not found"))
    }

    fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
        self.udafs
            .get(name)
            .cloned()
            .ok_or_else(|| plan_datafusion_err!("Aggregate Function {name} not found"))
    }

    fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
        self.udwfs
            .get(name)
            .cloned()
            .ok_or_else(|| plan_datafusion_err!("Window Function {name} not found"))
    }

    /// Stores `udf` under `udf.name()` and returns the function it replaced, if any.
    fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
        Ok(self.udfs.insert(udf.name().into(), udf))
    }

    /// Stores `function` under `function.name()` and returns the function it
    /// replaced, if any.
    fn register_higher_order_function(
        &mut self,
        function: Arc<HigherOrderUDF>,
    ) -> Result<Option<Arc<HigherOrderUDF>>> {
        Ok(self
            .higher_order_functions
            .insert(function.name().into(), function))
    }

    /// Stores `udaf` under `udaf.name()` and returns the function it replaced, if any.
    fn register_udaf(
        &mut self,
        udaf: Arc<AggregateUDF>,
    ) -> Result<Option<Arc<AggregateUDF>>> {
        Ok(self.udafs.insert(udaf.name().into(), udaf))
    }

    /// Stores `udwf` under `udwf.name()` and returns the function it replaced, if any.
    fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
        Ok(self.udwfs.insert(udwf.name().into(), udwf))
    }

    /// Removes and returns the scalar function stored under `name`, if any.
    fn deregister_udf(&mut self, name: &str) -> Result<Option<Arc<ScalarUDF>>> {
        Ok(self.udfs.remove(name))
    }

    /// Removes and returns the higher order function stored under `name`, if any.
    fn deregister_higher_order_function(
        &mut self,
        name: &str,
    ) -> Result<Option<Arc<HigherOrderUDF>>> {
        Ok(self.higher_order_functions.remove(name))
    }

    /// Removes and returns the aggregate function stored under `name`, if any.
    fn deregister_udaf(&mut self, name: &str) -> Result<Option<Arc<AggregateUDF>>> {
        Ok(self.udafs.remove(name))
    }

    /// Removes and returns the window function stored under `name`, if any.
    fn deregister_udwf(&mut self, name: &str) -> Result<Option<Arc<WindowUDF>>> {
        Ok(self.udwfs.remove(name))
    }

    /// This registry does not hold expression planners, so the list is always empty.
    fn expr_planners(&self) -> Vec<Arc<dyn ExprPlanner>> {
        vec![]
    }
}
/// Shared handle to an [`ExtensionTypeRegistry`]. Because of the trait's
/// `Send + Sync` supertraits, it can be shared across threads.
pub type ExtensionTypeRegistryRef = Arc<dyn ExtensionTypeRegistry>;

/// Maps extension type names to [`ExtensionTypeRegistration`]s.
///
/// Each registration can build a [`DFExtensionTypeRef`] from a storage type
/// plus optional extension metadata.
pub trait ExtensionTypeRegistry: Debug + Send + Sync {
    /// Looks up the registration for the extension type called `name`.
    ///
    /// # Errors
    /// Implementations are expected to return an error if `name` is not
    /// registered.
    fn extension_type_registration(
        &self,
        name: &str,
    ) -> Result<ExtensionTypeRegistrationRef>;

    /// Builds the DataFusion extension type described by `field`.
    ///
    /// The field's extension type name selects the registration. Its data type
    /// and extension metadata are handed to the registration's factory.
    /// Returns `Ok(None)` when `field` has no extension type name.
    ///
    /// # Errors
    /// Propagates errors from [`Self::extension_type_registration`] and from
    /// the registration's factory.
    fn create_extension_type_for_field(
        &self,
        field: &Field,
    ) -> Result<Option<DFExtensionTypeRef>> {
        match field.extension_type_name() {
            None => Ok(None),
            Some(type_name) => {
                let registration = self.extension_type_registration(type_name)?;
                let storage_type = field.data_type();
                let metadata = field.extension_type_metadata();
                registration
                    .create_df_extension_type(storage_type, metadata)
                    .map(Some)
            }
        }
    }

    /// Returns all registrations currently held by the registry.
    fn extension_type_registrations(&self) -> Vec<ExtensionTypeRegistrationRef>;

    /// Adds `extension_type`, returning the registration it replaced, if any.
    fn add_extension_type_registration(
        &self,
        extension_type: ExtensionTypeRegistrationRef,
    ) -> Result<Option<ExtensionTypeRegistrationRef>>;

    /// Adds every registration in `extension_types`, in order.
    ///
    /// # Errors
    /// Stops at the first failing addition and returns its error. Registrations
    /// added before the failure remain in the registry.
    fn extend(&self, extension_types: &[ExtensionTypeRegistrationRef]) -> Result<()> {
        extension_types.iter().try_for_each(|registration| {
            self.add_extension_type_registration(Arc::clone(registration))
                .map(|_replaced| ())
        })
    }

    /// Removes the registration named `name`, returning it if present.
    fn remove_extension_type_registration(
        &self,
        name: &str,
    ) -> Result<Option<ExtensionTypeRegistrationRef>>;
}
/// Factory closure that builds a DataFusion extension type.
///
/// It receives the storage [`DataType`] and the optional extension metadata
/// string. The canonical factories in this module pass that string to
/// `deserialize_metadata`.
pub type ExtensionTypeFactory =
    dyn Fn(&DataType, Option<&str>) -> Result<DFExtensionTypeRef> + Send + Sync;

/// Shared handle to an [`ExtensionTypeRegistration`].
pub type ExtensionTypeRegistrationRef = Arc<ExtensionTypeRegistration>;

/// Associates an extension type name with the factory that instantiates it.
pub struct ExtensionTypeRegistration {
    /// Name of the extension type (returned by `type_name`).
    name: String,
    /// Builds instances of the extension type.
    factory: Box<ExtensionTypeFactory>,
}
impl ExtensionTypeRegistration {
pub fn new_arc(
name: impl Into<String>,
factory: impl Fn(&DataType, Option<&str>) -> Result<DFExtensionTypeRef>
+ Send
+ Sync
+ 'static,
) -> ExtensionTypeRegistrationRef {
Arc::new(Self {
name: name.into(),
factory: Box::new(factory),
})
}
}
impl ExtensionTypeRegistration {
    /// Returns the name of the extension type this registration handles.
    pub fn type_name(&self) -> &str {
        self.name.as_str()
    }

    /// Runs the factory to build an extension type for `storage_type` and the
    /// optional `metadata`.
    ///
    /// # Errors
    /// Returns whatever error the factory produces. The exact conditions depend
    /// on the factory.
    pub fn create_df_extension_type(
        &self,
        storage_type: &DataType,
        metadata: Option<&str>,
    ) -> Result<DFExtensionTypeRef> {
        (self.factory)(storage_type, metadata)
    }
}
/// Formats the registration by its type name.
///
/// The factory closure has no `Debug` impl, so it is omitted and the output
/// is marked non-exhaustive (`..`).
impl Debug for ExtensionTypeRegistration {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Use the real type name; "DefaultExtensionTypeRegistration" was stale.
        f.debug_struct("ExtensionTypeRegistration")
            .field("type_name", &self.name)
            .finish_non_exhaustive()
    }
}
/// An [`ExtensionTypeRegistry`] backed by an in-memory map from extension type
/// name to registration.
///
/// The map sits behind `Arc<RwLock<..>>`, and `Clone` is derived. Clones of a
/// registry therefore share the same map, and each observes the others'
/// additions and removals.
#[derive(Clone, Debug)]
pub struct MemoryExtensionTypeRegistry {
    /// Registrations keyed by extension type name.
    extension_types: Arc<RwLock<HashMap<String, ExtensionTypeRegistrationRef>>>,
}

/// The default registry is empty.
///
/// Use `new_with_canonical_extension_types` to start with the canonical Arrow
/// extension types registered.
impl Default for MemoryExtensionTypeRegistry {
    fn default() -> Self {
        MemoryExtensionTypeRegistry::new_empty()
    }
}
impl MemoryExtensionTypeRegistry {
    /// Creates a registry with no extension types registered.
    pub fn new_empty() -> Self {
        Self {
            extension_types: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Creates a registry pre-populated with the canonical Arrow extension types.
    ///
    /// These are: fixed shape tensor, variable shape tensor, JSON, UUID,
    /// opaque, 8-bit boolean and timestamp with offset.
    ///
    /// Each factory first deserializes the Arrow extension metadata. It then
    /// passes the storage type and the parsed metadata to the matching
    /// DataFusion wrapper's fallible `try_new`.
    pub fn new_with_canonical_extension_types() -> Self {
        let registrations = [
            ExtensionTypeRegistration::new_arc(
                FixedShapeTensor::NAME,
                |storage_type, metadata| {
                    Ok(Arc::new(DFFixedShapeTensor::try_new(
                        storage_type,
                        FixedShapeTensor::deserialize_metadata(metadata)?,
                    )?))
                },
            ),
            ExtensionTypeRegistration::new_arc(
                VariableShapeTensor::NAME,
                |storage_type, metadata| {
                    Ok(Arc::new(DFVariableShapeTensor::try_new(
                        storage_type,
                        VariableShapeTensor::deserialize_metadata(metadata)?,
                    )?))
                },
            ),
            ExtensionTypeRegistration::new_arc(Json::NAME, |storage_type, metadata| {
                Ok(Arc::new(DFJson::try_new(
                    storage_type,
                    Json::deserialize_metadata(metadata)?,
                )?))
            }),
            ExtensionTypeRegistration::new_arc(Uuid::NAME, |storage_type, metadata| {
                Ok(Arc::new(DFUuid::try_new(
                    storage_type,
                    Uuid::deserialize_metadata(metadata)?,
                )?))
            }),
            ExtensionTypeRegistration::new_arc(Opaque::NAME, |storage_type, metadata| {
                Ok(Arc::new(DFOpaque::try_new(
                    storage_type,
                    Opaque::deserialize_metadata(metadata)?,
                )?))
            }),
            ExtensionTypeRegistration::new_arc(Bool8::NAME, |storage_type, metadata| {
                Ok(Arc::new(DFBool8::try_new(
                    storage_type,
                    Bool8::deserialize_metadata(metadata)?,
                )?))
            }),
            ExtensionTypeRegistration::new_arc(
                TimestampWithOffset::NAME,
                |storage_type, metadata| {
                    Ok(Arc::new(DFTimestampWithOffset::try_new(
                        storage_type,
                        TimestampWithOffset::deserialize_metadata(metadata)?,
                    )?))
                },
            ),
        ];

        // Key each registration by its own type name. The map is moved straight
        // into the lock; no extra conversion is needed.
        let extension_types = registrations
            .into_iter()
            .map(|registration| (registration.type_name().to_owned(), registration))
            .collect::<HashMap<_, _>>();

        Self {
            extension_types: Arc::new(RwLock::new(extension_types)),
        }
    }

    /// Creates a registry containing `types`, each keyed by its
    /// [`ExtensionTypeRegistration::type_name`].
    ///
    /// If several registrations share a name, the last one wins.
    ///
    /// # Errors
    /// The current implementation never returns `Err`.
    pub fn new_with_types(
        types: impl IntoIterator<Item = ExtensionTypeRegistrationRef>,
    ) -> Result<Self> {
        let extension_types = types
            .into_iter()
            .map(|t| (t.type_name().to_owned(), t))
            .collect::<HashMap<_, _>>();
        Ok(Self {
            extension_types: Arc::new(RwLock::new(extension_types)),
        })
    }

    /// Returns every registered extension type, in no particular order.
    ///
    /// Same as [`ExtensionTypeRegistry::extension_type_registrations`].
    ///
    /// # Panics
    /// Panics if the registry's lock is poisoned.
    pub fn all_extension_types(&self) -> Vec<ExtensionTypeRegistrationRef> {
        self.extension_type_registrations()
    }
}
/// Every method here panics if the registry's lock is poisoned.
impl ExtensionTypeRegistry for MemoryExtensionTypeRegistry {
    /// Looks up `name` under a shared read lock, so concurrent lookups do not
    /// block one another.
    ///
    /// # Errors
    /// Returns a plan error naming `name` when it is not registered.
    fn extension_type_registration(
        &self,
        name: &str,
    ) -> Result<ExtensionTypeRegistrationRef> {
        self.extension_types
            .read()
            .expect("Extension type registry lock poisoned")
            .get(name)
            .cloned()
            .ok_or_else(|| plan_datafusion_err!("Extension type {name} not found"))
    }

    fn extension_type_registrations(&self) -> Vec<ExtensionTypeRegistrationRef> {
        self.extension_types
            .read()
            .expect("Extension type registry lock poisoned")
            .values()
            .cloned()
            .collect()
    }

    /// Inserts `extension_type` under its type name and returns the
    /// registration it replaced, if any.
    fn add_extension_type_registration(
        &self,
        extension_type: ExtensionTypeRegistrationRef,
    ) -> Result<Option<ExtensionTypeRegistrationRef>> {
        Ok(self
            .extension_types
            .write()
            .expect("Extension type registry lock poisoned")
            .insert(extension_type.type_name().to_owned(), extension_type))
    }

    /// Removes and returns the registration stored under `name`, if any.
    fn remove_extension_type_registration(
        &self,
        name: &str,
    ) -> Result<Option<ExtensionTypeRegistrationRef>> {
        Ok(self
            .extension_types
            .write()
            .expect("Extension type registry lock poisoned")
            .remove(name))
    }
}
/// Wraps an existing name → registration map in a new registry.
///
/// The keys are used as-is. They are not checked against each registration's
/// `type_name`.
impl From<HashMap<String, ExtensionTypeRegistrationRef>> for MemoryExtensionTypeRegistry {
    fn from(value: HashMap<String, ExtensionTypeRegistrationRef>) -> Self {
        let guarded = RwLock::new(value);
        Self {
            extension_types: Arc::new(guarded),
        }
    }
}