#[allow(unused_imports)]
use polars::prelude::{DataType, PolarsError, Series};
use std::collections::HashMap;
use std::sync::Arc;
/// A user-defined function implemented in Rust that turns a slice of input columns into a
/// single output [`Series`].
///
/// The `Send + Sync` supertraits let implementations be stored as `Arc<dyn RustUdf>` inside
/// the shared, lock-protected maps of [`UdfRegistry`].
pub trait RustUdf: Send + Sync {
    /// Applies the UDF to `columns` and returns the resulting series.
    ///
    /// # Errors
    /// Returns whatever [`PolarsError`] the implementation reports.
    fn apply(&self, columns: &[Series]) -> Result<Series, PolarsError>;
}
/// Adapts a plain closure into a [`RustUdf`] so it can be stored as `Arc<dyn RustUdf>`.
///
/// The struct itself places no bound on `F`; the `Fn(&[Series]) -> Result<Series, PolarsError>`
/// requirement is expressed only on the `RustUdf` implementation, which is the one place that
/// actually calls the closure.
struct RustUdfWrapper<F> {
    /// The wrapped user function; `RustUdf::apply` forwards its `columns` argument to it.
    f: F,
}
// Every closure with the UDF signature becomes a `RustUdf`; `apply` simply forwards to it.
impl<F> RustUdf for RustUdfWrapper<F>
where
    F: Fn(&[Series]) -> Result<Series, PolarsError> + Send + Sync,
{
    /// Invokes the wrapped closure with `columns` and returns its result unchanged.
    fn apply(&self, columns: &[Series]) -> Result<Series, PolarsError> {
        let Self { f: callback } = self;
        callback(columns)
    }
}
/// The calling convention a Python UDF was registered with.
///
/// Each variant is set by a different `UdfRegistry` registration method (noted per variant).
#[cfg(feature = "pyo3")]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PythonUdfKind {
    /// Set by `UdfRegistry::register_python_udf`.
    /// NOTE(review): presumably invoked once per value — confirm against the evaluator.
    Scalar,
    /// Set by `UdfRegistry::register_vectorized_python_udf`.
    /// NOTE(review): presumably receives whole columns/batches — confirm against the evaluator.
    Vectorized,
    /// Set by `UdfRegistry::register_grouped_vectorized_python_udf`.
    /// NOTE(review): presumably a grouped aggregation over vectorized input — confirm.
    GroupedVectorizedAgg,
}
/// A Python UDF as stored in [`UdfRegistry`]: the callable plus its registration metadata.
#[cfg(feature = "pyo3")]
pub struct PythonUdfEntry {
    /// The Python object to call. Registration does not verify that it is callable.
    pub callable: pyo3::Py<pyo3::PyAny>,
    /// The result dtype declared by the caller at registration time.
    pub return_type: DataType,
    /// Which registration method (and therefore calling convention) was used.
    pub kind: PythonUdfKind,
}
/// Registry of named UDFs implemented in Rust and, with the `pyo3` feature, in Python.
///
/// Both maps sit behind `Arc<RwLock<..>>`, so `Clone` is shallow: every clone shares the
/// same maps, and a UDF registered through one clone is visible through all of them.
#[derive(Clone)]
pub struct UdfRegistry {
    /// Rust UDFs keyed by the exact name they were registered under (case preserved).
    rust_udfs: Arc<std::sync::RwLock<HashMap<String, Arc<dyn RustUdf>>>>,
    /// Python UDFs keyed by the exact name they were registered under (case preserved).
    #[cfg(feature = "pyo3")]
    python_udfs: Arc<std::sync::RwLock<HashMap<String, Arc<PythonUdfEntry>>>>,
}
impl Default for UdfRegistry {
    /// Builds a registry with no UDFs registered: each map is a fresh, empty, unpoisoned
    /// `RwLock<HashMap>` behind its own `Arc`.
    fn default() -> Self {
        let rust_udfs = Arc::default();
        #[cfg(feature = "pyo3")]
        let python_udfs = Arc::default();

        Self {
            rust_udfs,
            #[cfg(feature = "pyo3")]
            python_udfs,
        }
    }
}
impl UdfRegistry {
pub fn new() -> Self {
Self::default()
}
pub fn register_rust_udf<F>(&self, name: &str, f: F) -> Result<(), PolarsError>
where
F: Fn(&[Series]) -> Result<Series, PolarsError> + Send + Sync + 'static,
{
let wrapper = Arc::new(RustUdfWrapper { f });
self.rust_udfs
.write()
.map_err(|_| PolarsError::ComputeError("udf registry lock poisoned".into()))?
.insert(name.to_string(), wrapper);
Ok(())
}
pub fn get_rust_udf(&self, name: &str, case_sensitive: bool) -> Option<Arc<dyn RustUdf>> {
let guard = self.rust_udfs.read().ok()?;
if case_sensitive {
guard.get(name).cloned()
} else {
let name_lower = name.to_lowercase();
guard
.iter()
.find(|(k, _)| k.to_lowercase() == name_lower)
.map(|(_, v)| v.clone())
}
}
#[allow(dead_code)] pub fn has_udf(&self, name: &str, case_sensitive: bool) -> bool {
if self.get_rust_udf(name, case_sensitive).is_some() {
return true;
}
#[cfg(feature = "pyo3")]
{
self.get_python_udf(name, case_sensitive).is_some()
}
#[cfg(not(feature = "pyo3"))]
false
}
#[cfg(feature = "pyo3")]
pub fn register_python_udf(
&self,
name: &str,
callable: pyo3::Py<pyo3::PyAny>,
return_type: DataType,
) -> Result<(), PolarsError> {
self.register_python_udf_with_kind(name, callable, return_type, PythonUdfKind::Scalar)
}
#[cfg(feature = "pyo3")]
pub fn register_vectorized_python_udf(
&self,
name: &str,
callable: pyo3::Py<pyo3::PyAny>,
return_type: DataType,
) -> Result<(), PolarsError> {
self.register_python_udf_with_kind(name, callable, return_type, PythonUdfKind::Vectorized)
}
#[cfg(feature = "pyo3")]
pub fn register_grouped_vectorized_python_udf(
&self,
name: &str,
callable: pyo3::Py<pyo3::PyAny>,
return_type: DataType,
) -> Result<(), PolarsError> {
self.register_python_udf_with_kind(
name,
callable,
return_type,
PythonUdfKind::GroupedVectorizedAgg,
)
}
#[cfg(feature = "pyo3")]
fn register_python_udf_with_kind(
&self,
name: &str,
callable: pyo3::Py<pyo3::PyAny>,
return_type: DataType,
kind: PythonUdfKind,
) -> Result<(), PolarsError> {
let entry = PythonUdfEntry {
callable,
return_type,
kind,
};
self.python_udfs
.write()
.map_err(|_| PolarsError::ComputeError("udf registry lock poisoned".into()))?
.insert(name.to_string(), Arc::new(entry));
Ok(())
}
#[cfg(feature = "pyo3")]
pub fn get_python_udf(&self, name: &str, case_sensitive: bool) -> Option<Arc<PythonUdfEntry>> {
let guard = self.python_udfs.read().ok()?;
if case_sensitive {
guard.get(name).cloned()
} else {
let name_lower = name.to_lowercase();
guard
.iter()
.find(|(k, _)| k.to_lowercase() == name_lower)
.map(|(_, v)| v.clone())
}
}
}