Skip to main content

robin_sparkless_polars/
udf_registry.rs

//! UDF registry: session-scoped storage for Rust UDFs.
//! PySpark parity: register_udf; call_udf resolves by name.
3
4#[allow(unused_imports)]
5use polars::prelude::{DataType, PolarsError, Series};
6use std::collections::HashMap;
7use std::sync::Arc;
8
9/// Rust UDF: takes columns as Series, returns one Series. Used via Expr::map / map_many.
10pub trait RustUdf: Send + Sync {
11    fn apply(&self, columns: &[Series]) -> Result<Series, PolarsError>;
12}
13
/// Type-erased wrapper for Rust UDF closures.
///
/// The struct only stores the callable. The `Fn(&[Series]) -> Result<Series, PolarsError>
/// + Send + Sync` requirement is stated on the `RustUdf` impl for this wrapper
/// rather than repeated on the type definition.
struct RustUdfWrapper<F> {
    /// The wrapped closure (or function) that performs the actual computation.
    f: F,
}
21
22impl<F> RustUdf for RustUdfWrapper<F>
23where
24    F: Fn(&[Series]) -> Result<Series, PolarsError> + Send + Sync,
25{
26    fn apply(&self, columns: &[Series]) -> Result<Series, PolarsError> {
27        (self.f)(columns)
28    }
29}
30
31/// Session-scoped UDF registry. Rust UDFs run lazily via Polars Expr::map.
32#[derive(Clone)]
33pub struct UdfRegistry {
34    rust_udfs: Arc<std::sync::RwLock<HashMap<String, Arc<dyn RustUdf>>>>,
35}
36
37impl Default for UdfRegistry {
38    fn default() -> Self {
39        Self {
40            rust_udfs: Arc::new(std::sync::RwLock::new(HashMap::new())),
41        }
42    }
43}
44
45impl UdfRegistry {
46    pub fn new() -> Self {
47        Self::default()
48    }
49
50    /// Register a Rust UDF. Runs lazily when used in DataFrame operations.
51    pub fn register_rust_udf<F>(&self, name: &str, f: F) -> Result<(), PolarsError>
52    where
53        F: Fn(&[Series]) -> Result<Series, PolarsError> + Send + Sync + 'static,
54    {
55        let wrapper = Arc::new(RustUdfWrapper { f });
56        self.rust_udfs
57            .write()
58            .map_err(|_| PolarsError::ComputeError("udf registry lock poisoned".into()))?
59            .insert(name.to_string(), wrapper);
60        Ok(())
61    }
62
63    /// Look up a Rust UDF by name. Case sensitivity follows session config.
64    /// Returns `Err` if the registry lock is poisoned (e.g. a thread panicked while holding it).
65    pub fn get_rust_udf(
66        &self,
67        name: &str,
68        case_sensitive: bool,
69    ) -> Result<Option<Arc<dyn RustUdf>>, PolarsError> {
70        let guard = self
71            .rust_udfs
72            .read()
73            .map_err(|_| PolarsError::ComputeError("udf registry lock poisoned".into()))?;
74        Ok(if case_sensitive {
75            guard.get(name).cloned()
76        } else {
77            let name_lower = name.to_lowercase();
78            guard
79                .iter()
80                .find(|(k, _)| k.to_lowercase() == name_lower)
81                .map(|(_, v)| v.clone())
82        })
83    }
84
85    /// Check if a Rust UDF exists. Returns `Err` if the registry lock is poisoned.
86    #[allow(dead_code)] // used by SQL translator
87    pub fn has_udf(&self, name: &str, case_sensitive: bool) -> Result<bool, PolarsError> {
88        self.get_rust_udf(name, case_sensitive).map(|o| o.is_some())
89    }
90
91    /// Clear all registered UDFs (used by SparkSession.stop()).
92    pub fn clear(&self) -> Result<(), PolarsError> {
93        self.rust_udfs
94            .write()
95            .map_err(|_| PolarsError::ComputeError("udf registry lock poisoned".into()))?
96            .clear();
97        Ok(())
98    }
99}