robin_sparkless_expr/
udf_registry.rs1#[allow(unused_imports)]
5use polars::prelude::{DataType, PolarsError, Series};
6use std::collections::HashMap;
7use std::sync::Arc;
8
9pub trait RustUdf: Send + Sync {
11 fn apply(&self, columns: &[Series]) -> Result<Series, PolarsError>;
12}
13
/// Adapts a plain closure (or fn item) into a [`RustUdf`] implementation.
///
/// The struct itself places no bound on `F`; the
/// `Fn(&[Series]) -> Result<Series, PolarsError> + Send + Sync` requirement is
/// stated only on the `RustUdf` impl, which is the one place that needs it.
struct RustUdfWrapper<F> {
    /// The wrapped user function, invoked by `RustUdf::apply`.
    f: F,
}
21
22impl<F> RustUdf for RustUdfWrapper<F>
23where
24 F: Fn(&[Series]) -> Result<Series, PolarsError> + Send + Sync,
25{
26 fn apply(&self, columns: &[Series]) -> Result<Series, PolarsError> {
27 (self.f)(columns)
28 }
29}
30
31#[derive(Clone)]
33pub struct UdfRegistry {
34 rust_udfs: Arc<std::sync::RwLock<HashMap<String, Arc<dyn RustUdf>>>>,
35}
36
37impl Default for UdfRegistry {
38 fn default() -> Self {
39 Self {
40 rust_udfs: Arc::new(std::sync::RwLock::new(HashMap::new())),
41 }
42 }
43}
44
45impl UdfRegistry {
46 pub fn new() -> Self {
47 Self::default()
48 }
49
50 pub fn register_rust_udf<F>(&self, name: &str, f: F) -> Result<(), PolarsError>
52 where
53 F: Fn(&[Series]) -> Result<Series, PolarsError> + Send + Sync + 'static,
54 {
55 let wrapper = Arc::new(RustUdfWrapper { f });
56 self.rust_udfs
57 .write()
58 .map_err(|_| PolarsError::ComputeError("udf registry lock poisoned".into()))?
59 .insert(name.to_string(), wrapper);
60 Ok(())
61 }
62
63 pub fn get_rust_udf(&self, name: &str, case_sensitive: bool) -> Option<Arc<dyn RustUdf>> {
65 let guard = self.rust_udfs.read().ok()?;
66 if case_sensitive {
67 guard.get(name).cloned()
68 } else {
69 let name_lower = name.to_lowercase();
70 guard
71 .iter()
72 .find(|(k, _)| k.to_lowercase() == name_lower)
73 .map(|(_, v)| v.clone())
74 }
75 }
76
77 #[allow(dead_code)] pub fn has_udf(&self, name: &str, case_sensitive: bool) -> bool {
80 if self.get_rust_udf(name, case_sensitive).is_some() {
81 return true;
82 }
83 false
84 }
85
86 pub fn clear(&self) -> Result<(), PolarsError> {
88 self.rust_udfs
89 .write()
90 .map_err(|_| PolarsError::ComputeError("udf registry lock poisoned".into()))?
91 .clear();
92 Ok(())
93 }
94}