robin_sparkless_polars/
udf_registry.rs1#[allow(unused_imports)]
5use polars::prelude::{DataType, PolarsError, Series};
6use std::collections::HashMap;
7use std::sync::Arc;
8
9pub trait RustUdf: Send + Sync {
11 fn apply(&self, columns: &[Series]) -> Result<Series, PolarsError>;
12}
13
14struct RustUdfWrapper<F>
16where
17 F: Fn(&[Series]) -> Result<Series, PolarsError> + Send + Sync,
18{
19 f: F,
20}
21
22impl<F> RustUdf for RustUdfWrapper<F>
23where
24 F: Fn(&[Series]) -> Result<Series, PolarsError> + Send + Sync,
25{
26 fn apply(&self, columns: &[Series]) -> Result<Series, PolarsError> {
27 (self.f)(columns)
28 }
29}
30
31#[derive(Clone)]
33pub struct UdfRegistry {
34 rust_udfs: Arc<std::sync::RwLock<HashMap<String, Arc<dyn RustUdf>>>>,
35}
36
37impl Default for UdfRegistry {
38 fn default() -> Self {
39 Self {
40 rust_udfs: Arc::new(std::sync::RwLock::new(HashMap::new())),
41 }
42 }
43}
44
45impl UdfRegistry {
46 pub fn new() -> Self {
47 Self::default()
48 }
49
50 pub fn register_rust_udf<F>(&self, name: &str, f: F) -> Result<(), PolarsError>
52 where
53 F: Fn(&[Series]) -> Result<Series, PolarsError> + Send + Sync + 'static,
54 {
55 let wrapper = Arc::new(RustUdfWrapper { f });
56 self.rust_udfs
57 .write()
58 .map_err(|_| PolarsError::ComputeError("udf registry lock poisoned".into()))?
59 .insert(name.to_string(), wrapper);
60 Ok(())
61 }
62
63 pub fn get_rust_udf(
66 &self,
67 name: &str,
68 case_sensitive: bool,
69 ) -> Result<Option<Arc<dyn RustUdf>>, PolarsError> {
70 let guard = self
71 .rust_udfs
72 .read()
73 .map_err(|_| PolarsError::ComputeError("udf registry lock poisoned".into()))?;
74 Ok(if case_sensitive {
75 guard.get(name).cloned()
76 } else {
77 let name_lower = name.to_lowercase();
78 guard
79 .iter()
80 .find(|(k, _)| k.to_lowercase() == name_lower)
81 .map(|(_, v)| v.clone())
82 })
83 }
84
85 #[allow(dead_code)] pub fn has_udf(&self, name: &str, case_sensitive: bool) -> Result<bool, PolarsError> {
88 self.get_rust_udf(name, case_sensitive).map(|o| o.is_some())
89 }
90
91 pub fn clear(&self) -> Result<(), PolarsError> {
93 self.rust_udfs
94 .write()
95 .map_err(|_| PolarsError::ComputeError("udf registry lock poisoned".into()))?
96 .clear();
97 Ok(())
98 }
99}