Skip to main content

robin_sparkless_expr/
udf_registry.rs

1//! UDF registry: session-scoped storage for Rust UDFs.
2//! PySpark parity: register_udf; call_udf resolves by name.
3
4#[allow(unused_imports)]
5use polars::prelude::{DataType, PolarsError, Series};
6use std::collections::HashMap;
7use std::sync::Arc;
8
9/// Rust UDF: takes columns as Series, returns one Series. Used via Expr::map / map_many.
10pub trait RustUdf: Send + Sync {
11    fn apply(&self, columns: &[Series]) -> Result<Series, PolarsError>;
12}
13
/// Adapter that lets a plain closure be stored in the registry as `Arc<dyn RustUdf>`.
///
/// The struct deliberately carries no trait bounds. The closure signature is only
/// required by the `RustUdf` impl, which is the single place that actually calls `f`.
/// Keeping bounds off the definition avoids repeating them on every impl and
/// every use site.
struct RustUdfWrapper<F> {
    /// The wrapped UDF closure.
    f: F,
}
21
22impl<F> RustUdf for RustUdfWrapper<F>
23where
24    F: Fn(&[Series]) -> Result<Series, PolarsError> + Send + Sync,
25{
26    fn apply(&self, columns: &[Series]) -> Result<Series, PolarsError> {
27        (self.f)(columns)
28    }
29}
30
31/// Session-scoped UDF registry. Rust UDFs run lazily via Polars Expr::map.
32#[derive(Clone)]
33pub struct UdfRegistry {
34    rust_udfs: Arc<std::sync::RwLock<HashMap<String, Arc<dyn RustUdf>>>>,
35}
36
37impl Default for UdfRegistry {
38    fn default() -> Self {
39        Self {
40            rust_udfs: Arc::new(std::sync::RwLock::new(HashMap::new())),
41        }
42    }
43}
44
45impl UdfRegistry {
46    pub fn new() -> Self {
47        Self::default()
48    }
49
50    /// Register a Rust UDF. Runs lazily when used in DataFrame operations.
51    pub fn register_rust_udf<F>(&self, name: &str, f: F) -> Result<(), PolarsError>
52    where
53        F: Fn(&[Series]) -> Result<Series, PolarsError> + Send + Sync + 'static,
54    {
55        let wrapper = Arc::new(RustUdfWrapper { f });
56        self.rust_udfs
57            .write()
58            .map_err(|_| PolarsError::ComputeError("udf registry lock poisoned".into()))?
59            .insert(name.to_string(), wrapper);
60        Ok(())
61    }
62
63    /// Look up a Rust UDF by name. Case sensitivity follows session config.
64    pub fn get_rust_udf(&self, name: &str, case_sensitive: bool) -> Option<Arc<dyn RustUdf>> {
65        let guard = self.rust_udfs.read().ok()?;
66        if case_sensitive {
67            guard.get(name).cloned()
68        } else {
69            let name_lower = name.to_lowercase();
70            guard
71                .iter()
72                .find(|(k, _)| k.to_lowercase() == name_lower)
73                .map(|(_, v)| v.clone())
74        }
75    }
76
77    /// Check if a Rust UDF exists.
78    #[allow(dead_code)] // used by SQL translator
79    pub fn has_udf(&self, name: &str, case_sensitive: bool) -> bool {
80        if self.get_rust_udf(name, case_sensitive).is_some() {
81            return true;
82        }
83        false
84    }
85
86    /// Clear all registered UDFs (used by SparkSession.stop()).
87    pub fn clear(&self) -> Result<(), PolarsError> {
88        self.rust_udfs
89            .write()
90            .map_err(|_| PolarsError::ComputeError("udf registry lock poisoned".into()))?
91            .clear();
92        Ok(())
93    }
94}