polars_io/parquet/write/
key_value_metadata.rs1use std::fmt::Debug;
2use std::hash::Hash;
3use std::sync::Arc;
4
5use polars_error::PolarsResult;
6use polars_parquet::write::KeyValue;
7#[cfg(feature = "python")]
8use polars_utils::python_function::PythonObject;
9#[cfg(feature = "python")]
10use pyo3::PyObject;
11#[cfg(feature = "serde")]
12use serde::{Deserialize, Serialize, de, ser};
13
/// Context handed to a key-value metadata callback while a Parquet file is written.
///
/// It only borrows data, so it is cheap to copy. A caller can therefore keep using a
/// context after passing it by value (e.g. to `KeyValueMetadata::collect`).
#[derive(Debug, Clone, Copy)]
pub struct ParquetMetadataContext<'a> {
    /// Arrow schema of the data being written, borrowed as a string.
    /// NOTE(review): the exact string encoding is decided by the caller — confirm at the call site.
    pub arrow_schema: &'a str,
}
18
19#[derive(Clone, Debug, PartialEq, Eq, Hash)]
21#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
22pub enum KeyValueMetadata {
23 Static(
25 #[cfg_attr(
26 feature = "serde",
27 serde(
28 serialize_with = "serialize_vec_key_value",
29 deserialize_with = "deserialize_vec_key_value"
30 )
31 )]
32 Vec<KeyValue>,
33 ),
34 DynamicRust(RustKeyValueMetadataFunction),
36 #[cfg(feature = "python")]
38 DynamicPython(python_impl::PythonKeyValueMetadataFunction),
39}
40
/// Serialize `KeyValue`s as a sequence of `(key, Option<value>)` tuples.
///
/// This is the inverse of `deserialize_vec_key_value`. Elements are streamed straight
/// into the serializer instead of first being collected into a temporary `Vec`. The
/// output is unchanged, because `Vec`'s own `Serialize` impl also goes through
/// `Serializer::collect_seq`, and the length hint stays exact.
#[cfg(feature = "serde")]
fn serialize_vec_key_value<S>(kv: &[KeyValue], serializer: S) -> Result<S::Ok, S::Error>
where
    S: ser::Serializer,
{
    serializer.collect_seq(kv.iter().map(|item| (&item.key, item.value.as_ref())))
}
51
/// Deserialize a sequence of `(key, Option<value>)` tuples back into `KeyValue`s.
///
/// This is the inverse of `serialize_vec_key_value`. A missing value stays `None`.
#[cfg(feature = "serde")]
fn deserialize_vec_key_value<'de, D>(deserializer: D) -> Result<Vec<KeyValue>, D::Error>
where
    D: de::Deserializer<'de>,
{
    let pairs: Vec<(String, Option<String>)> = Deserialize::deserialize(deserializer)?;
    let mut key_values = Vec::with_capacity(pairs.len());
    for (key, value) in pairs {
        key_values.push(KeyValue { key, value });
    }
    Ok(key_values)
}
64
65impl KeyValueMetadata {
66 pub fn from_static(kv: Vec<(String, String)>) -> Self {
68 Self::Static(
69 kv.into_iter()
70 .map(|(key, value)| KeyValue {
71 key,
72 value: Some(value),
73 })
74 .collect(),
75 )
76 }
77
78 #[cfg(feature = "python")]
80 pub fn from_py_function(py_object: PyObject) -> Self {
81 Self::DynamicPython(python_impl::PythonKeyValueMetadataFunction(Arc::new(
82 PythonObject(py_object),
83 )))
84 }
85
86 pub fn collect(&self, ctx: ParquetMetadataContext) -> PolarsResult<Vec<KeyValue>> {
89 match self {
90 Self::Static(kv) => Ok(kv.clone()),
91 Self::DynamicRust(func) => Ok(func.0(ctx)),
92 #[cfg(feature = "python")]
93 Self::DynamicPython(py_func) => py_func.call(ctx),
94 }
95 }
96}
97
/// Key-value metadata callback implemented as a Rust closure.
///
/// The closure sits behind an `Arc`, so clones share the same closure. The field is
/// private, so instances can only be created inside this module. The `Debug`,
/// `PartialEq` and `Hash` impls for this type identify a value by the address of that
/// shared allocation rather than by comparing closures.
#[derive(Clone)]
pub struct RustKeyValueMetadataFunction(
    Arc<dyn Fn(ParquetMetadataContext) -> Vec<KeyValue> + Send + Sync>,
);
102
103impl Debug for RustKeyValueMetadataFunction {
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 write!(
106 f,
107 "key value metadata function at 0x{:016x}",
108 self.0.as_ref() as *const _ as *const () as usize
109 )
110 }
111}
112
113impl Eq for RustKeyValueMetadataFunction {}
114
115impl PartialEq for RustKeyValueMetadataFunction {
116 fn eq(&self, other: &Self) -> bool {
117 Arc::ptr_eq(&self.0, &other.0)
118 }
119}
120
121impl Hash for RustKeyValueMetadataFunction {
122 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
123 state.write_usize(Arc::as_ptr(&self.0) as *const () as usize);
124 }
125}
126
#[cfg(feature = "serde")]
impl Serialize for RustKeyValueMetadataFunction {
    /// Always fails: a Rust closure has no serializable form.
    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let message = format!("cannot serialize {:?}", self);
        Err(<S::Error as serde::ser::Error>::custom(message))
    }
}
137
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for RustKeyValueMetadataFunction {
    /// Always fails without reading any input: a Rust closure cannot be reconstructed
    /// from serialized data.
    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        Err(<D::Error as serde::de::Error>::custom(
            "cannot deserialize RustKeyValueMetadataFn",
        ))
    }
}
150
#[cfg(feature = "python")]
mod python_impl {
    use std::hash::Hash;
    use std::sync::Arc;

    use polars_error::{PolarsResult, to_compute_err};
    use polars_parquet::write::KeyValue;
    // Only referenced by the serde `serialize_with`/`deserialize_with` attributes below,
    // so it is gated on the same feature to avoid an unused import.
    #[cfg(feature = "serde")]
    use polars_utils::python_function::PythonObject;
    use pyo3::types::PyAnyMethods;
    use pyo3::{PyResult, Python, pyclass};
    // The derives that need these are themselves behind `cfg_attr(feature = "serde")`.
    #[cfg(feature = "serde")]
    use serde::{Deserialize, Serialize};

    use super::ParquetMetadataContext;

    /// Key-value metadata callback implemented as a Python callable.
    ///
    /// Equality and hashing both use the identity of the shared `Arc`. Clones compare
    /// equal, and equal values always hash equally.
    #[derive(Clone, Debug)]
    #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
    pub struct PythonKeyValueMetadataFunction(
        #[cfg_attr(
            feature = "serde",
            serde(
                serialize_with = "PythonObject::serialize_with_pyversion",
                deserialize_with = "PythonObject::deserialize_with_pyversion"
            )
        )]
        pub Arc<polars_utils::python_function::PythonFunction>,
    );

    impl PythonKeyValueMetadataFunction {
        /// Call the Python callable with a `PythonParquetMetadataContext` built from `ctx`.
        ///
        /// The callable must return a sequence of `(str, str)` pairs. Each pair becomes a
        /// `KeyValue` with `Some(value)`.
        ///
        /// # Errors
        /// A Python exception raised by the callable, or a return value that cannot be
        /// extracted as `Vec<(String, String)>`, is returned as a compute error.
        pub fn call(&self, ctx: ParquetMetadataContext) -> PolarsResult<Vec<KeyValue>> {
            let ctx = PythonParquetMetadataContext::from_key_value_metadata_context(ctx);
            Python::with_gil(|py| {
                let args = (ctx,);
                let out: Vec<(String, String)> =
                    self.0.call1(py, args)?.into_bound(py).extract()?;
                let result = out
                    .into_iter()
                    .map(|(key, value)| KeyValue {
                        key,
                        value: Some(value),
                    })
                    .collect::<Vec<_>>();
                PyResult::Ok(result)
            })
            .map_err(to_compute_err)
        }
    }

    // Identity-based equality, matching the pointer-based `Hash` below. A derived
    // `PartialEq` would delegate to the wrapped `PythonFunction`'s own equality. That
    // equality can treat two distinct `Arc`s as equal while `Hash` hashes them
    // differently, which breaks the `Hash`/`Eq` contract.
    impl PartialEq for PythonKeyValueMetadataFunction {
        fn eq(&self, other: &Self) -> bool {
            Arc::ptr_eq(&self.0, &other.0)
        }
    }

    impl Eq for PythonKeyValueMetadataFunction {}

    impl Hash for PythonKeyValueMetadataFunction {
        fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
            state.write_usize(Arc::as_ptr(&self.0) as *const () as usize);
        }
    }

    /// Python-visible copy of [`ParquetMetadataContext`], passed to the Python callable.
    #[pyclass]
    pub struct PythonParquetMetadataContext {
        /// Read-only from Python (getter only).
        #[pyo3(get)]
        arrow_schema: String,
    }

    impl PythonParquetMetadataContext {
        /// Copy the borrowed schema into an owned `String` so the value can be handed to Python.
        pub fn from_key_value_metadata_context(ctx: ParquetMetadataContext) -> Self {
            Self {
                arrow_schema: ctx.arrow_schema.to_string(),
            }
        }
    }
}