polars_python/
extension.rs1use std::any::Any;
2use std::borrow::Cow;
3use std::hash::{BuildHasher, Hash, Hasher};
4use std::sync::Arc;
5
6use polars::prelude::PlFixedStateQuality;
7use polars::prelude::extension::{register_extension_type, unregister_extension_type};
8use polars_core::datatypes::DataType;
9use polars_core::datatypes::extension::{ExtensionTypeFactory, ExtensionTypeImpl};
10use pyo3::prelude::*;
11
12use crate::prelude::Wrap;
13use crate::utils::to_py_err;
14
15struct PyExtensionTypeFactory {
16 cls: Arc<Py<PyAny>>,
17}
18
/// Rust-side snapshot of a Python extension type instance.
///
/// All fields are plain owned strings, so the value can be cloned, compared and
/// displayed without touching the Python interpreter.
#[derive(Clone)]
struct PyExtensionTypeImpl {
    /// Extension type name it was registered under.
    name: String,
    /// Serialized metadata as reported by the Python object, if any.
    metadata: Option<String>,
    /// Human-readable representation used for display.
    display: String,
}
25
26impl ExtensionTypeFactory for PyExtensionTypeFactory {
27 fn create_type_instance(
28 &self,
29 name: &str,
30 storage: &DataType,
31 metadata: Option<&str>,
32 ) -> Box<dyn ExtensionTypeImpl> {
33 Python::attach(|py| {
34 let typ_obj = self
35 .cls
36 .bind(py)
37 .call_method1("ext_from_params", (name, &Wrap(storage.clone()), metadata))
38 .unwrap();
39
40 let display = typ_obj
41 .call_method0("_string_repr")
42 .unwrap()
43 .extract()
44 .unwrap();
45 let metadata = typ_obj
46 .call_method0("ext_metadata")
47 .unwrap()
48 .extract()
49 .unwrap();
50
51 Box::new(PyExtensionTypeImpl {
52 name: name.to_string(),
53 display,
54 metadata,
55 })
56 })
57 }
58}
59
60impl ExtensionTypeImpl for PyExtensionTypeImpl {
61 fn name(&self) -> Cow<'_, str> {
62 Cow::Borrowed(&self.name)
63 }
64
65 fn serialize_metadata(&self) -> Option<Cow<'_, str>> {
66 self.metadata.as_deref().map(Cow::Borrowed)
67 }
68
69 fn dyn_clone(&self) -> Box<dyn ExtensionTypeImpl> {
70 Box::new(self.clone())
71 }
72
73 fn dyn_eq(&self, other: &dyn ExtensionTypeImpl) -> bool {
74 let Some(other) = (other as &dyn Any).downcast_ref::<PyExtensionTypeImpl>() else {
75 return false;
76 };
77
78 self.name == other.name && self.metadata == other.metadata
79 }
80
81 fn dyn_hash(&self) -> u64 {
82 let mut hasher = PlFixedStateQuality::default().build_hasher();
83 self.name.hash(&mut hasher);
84 self.metadata.hash(&mut hasher);
85 hasher.finish()
86 }
87
88 fn dyn_display(&self) -> Cow<'_, str> {
89 Cow::Borrowed(&self.display)
90 }
91
92 fn dyn_debug(&self) -> Cow<'_, str> {
93 if let Some(md) = &self.metadata {
94 Cow::Owned(format!(
95 "PyExtensionType(name='{}', metadata='{}')",
96 self.name, md
97 ))
98 } else {
99 Cow::Owned(format!("PyExtensionType(name='{}')", self.name))
100 }
101 }
102}
103
104#[pyfunction]
105pub fn _register_extension_type(name: &str, cls: Option<&Bound<PyAny>>) -> PyResult<()> {
106 register_extension_type(
107 name,
108 cls.map(|c| {
109 Arc::new(PyExtensionTypeFactory {
110 cls: Arc::new(c.clone().unbind()),
111 }) as _
112 }),
113 )
114 .map_err(to_py_err)
115}
116
117#[pyfunction]
118pub fn _unregister_extension_type(name: &str) -> PyResult<()> {
119 unregister_extension_type(name).map(drop).map_err(to_py_err)
120}