Skip to main content

polars_python/
extension.rs

1use std::any::Any;
2use std::borrow::Cow;
3use std::hash::{BuildHasher, Hash, Hasher};
4use std::sync::Arc;
5
6use polars::prelude::PlFixedStateQuality;
7use polars::prelude::extension::{register_extension_type, unregister_extension_type};
8use polars_core::datatypes::DataType;
9use polars_core::datatypes::extension::{ExtensionTypeFactory, ExtensionTypeImpl};
10use pyo3::prelude::*;
11
12use crate::prelude::Wrap;
13use crate::utils::to_py_err;
14
15struct PyExtensionTypeFactory {
16    cls: Arc<Py<PyAny>>,
17}
18
/// Rust-side descriptor for one instance of a Python-defined extension type.
///
/// All values are captured eagerly when the descriptor is built, so answering
/// name/metadata/display queries never requires calling back into Python.
/// Equality and hashing consider only `name` and `metadata`; `display` is purely
/// presentational.
#[derive(Clone)]
struct PyExtensionTypeImpl {
    /// Name the extension type was requested under.
    name: String,
    /// Serialized metadata reported by the Python object's `ext_metadata()`,
    /// or `None` when it reported none.
    metadata: Option<String>,
    /// Human-readable rendering taken from the Python object's `_string_repr()`.
    display: String,
}
25
26impl ExtensionTypeFactory for PyExtensionTypeFactory {
27    fn create_type_instance(
28        &self,
29        name: &str,
30        storage: &DataType,
31        metadata: Option<&str>,
32    ) -> Box<dyn ExtensionTypeImpl> {
33        Python::attach(|py| {
34            let typ_obj = self
35                .cls
36                .bind(py)
37                .call_method1("ext_from_params", (name, &Wrap(storage.clone()), metadata))
38                .unwrap();
39
40            let display = typ_obj
41                .call_method0("_string_repr")
42                .unwrap()
43                .extract()
44                .unwrap();
45            let metadata = typ_obj
46                .call_method0("ext_metadata")
47                .unwrap()
48                .extract()
49                .unwrap();
50
51            Box::new(PyExtensionTypeImpl {
52                name: name.to_string(),
53                display,
54                metadata,
55            })
56        })
57    }
58}
59
60impl ExtensionTypeImpl for PyExtensionTypeImpl {
61    fn name(&self) -> Cow<'_, str> {
62        Cow::Borrowed(&self.name)
63    }
64
65    fn serialize_metadata(&self) -> Option<Cow<'_, str>> {
66        self.metadata.as_deref().map(Cow::Borrowed)
67    }
68
69    fn dyn_clone(&self) -> Box<dyn ExtensionTypeImpl> {
70        Box::new(self.clone())
71    }
72
73    fn dyn_eq(&self, other: &dyn ExtensionTypeImpl) -> bool {
74        let Some(other) = (other as &dyn Any).downcast_ref::<PyExtensionTypeImpl>() else {
75            return false;
76        };
77
78        self.name == other.name && self.metadata == other.metadata
79    }
80
81    fn dyn_hash(&self) -> u64 {
82        let mut hasher = PlFixedStateQuality::default().build_hasher();
83        self.name.hash(&mut hasher);
84        self.metadata.hash(&mut hasher);
85        hasher.finish()
86    }
87
88    fn dyn_display(&self) -> Cow<'_, str> {
89        Cow::Borrowed(&self.display)
90    }
91
92    fn dyn_debug(&self) -> Cow<'_, str> {
93        if let Some(md) = &self.metadata {
94            Cow::Owned(format!(
95                "PyExtensionType(name='{}', metadata='{}')",
96                self.name, md
97            ))
98        } else {
99            Cow::Owned(format!("PyExtensionType(name='{}')", self.name))
100        }
101    }
102}
103
104#[pyfunction]
105pub fn _register_extension_type(name: &str, cls: Option<&Bound<PyAny>>) -> PyResult<()> {
106    register_extension_type(
107        name,
108        cls.map(|c| {
109            Arc::new(PyExtensionTypeFactory {
110                cls: Arc::new(c.clone().unbind()),
111            }) as _
112        }),
113    )
114    .map_err(to_py_err)
115}
116
117#[pyfunction]
118pub fn _unregister_extension_type(name: &str) -> PyResult<()> {
119    unregister_extension_type(name).map(drop).map_err(to_py_err)
120}