polars_io/parquet/write/
key_value_metadata.rs

1use std::fmt::Debug;
2use std::hash::Hash;
3use std::sync::Arc;
4
5use polars_error::PolarsResult;
6use polars_parquet::write::KeyValue;
7#[cfg(feature = "python")]
8use polars_utils::python_function::PythonObject;
9#[cfg(feature = "python")]
10use pyo3::PyObject;
11#[cfg(feature = "serde")]
12use serde::{Deserialize, Serialize, de, ser};
13
/// Context that can be used to construct custom file-level key value metadata for a Parquet file.
///
/// A value of this type is handed to the dynamic metadata functions when
/// [`KeyValueMetadata::collect`] is called.
pub struct ParquetMetadataContext<'a> {
    /// The Arrow schema, borrowed as a string slice for the lifetime `'a`.
    // NOTE(review): the exact textual encoding of the schema is chosen by the caller — confirm.
    pub arrow_schema: &'a str,
}
18
/// Key/value pairs that can be attached to a Parquet file as file-level metadata.
///
/// The pairs are either stored directly or computed on demand by a function that receives a
/// [`ParquetMetadataContext`]; [`KeyValueMetadata::collect`] resolves either form into a
/// `Vec<KeyValue>`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum KeyValueMetadata {
    /// Static key value metadata.
    Static(
        // With the `serde` feature, this field is (de)serialized through the two helper
        // functions named in the attribute rather than through the field type itself.
        #[cfg_attr(
            feature = "serde",
            serde(
                serialize_with = "serialize_vec_key_value",
                deserialize_with = "deserialize_vec_key_value"
            )
        )]
        Vec<KeyValue>,
    ),
    /// Rust function to dynamically compute key value metadata.
    ///
    /// With the `serde` feature, serializing or deserializing this variant always returns an
    /// error, because the wrapped function type rejects both operations.
    DynamicRust(RustKeyValueMetadataFunction),
    /// Python function to dynamically compute key value metadata.
    #[cfg(feature = "python")]
    DynamicPython(python_impl::PythonKeyValueMetadataFunction),
}
40
/// Serializes key/value pairs as a sequence of `(key, value)` tuples, where `value` is
/// optional.
///
/// This is the `serialize_with` helper for [`KeyValueMetadata::Static`] and the counterpart of
/// `deserialize_vec_key_value`.
#[cfg(feature = "serde")]
fn serialize_vec_key_value<S>(kv: &[KeyValue], serializer: S) -> Result<S::Ok, S::Error>
where
    S: ser::Serializer,
{
    // Stream the borrowed tuples straight into the serializer. `collect_seq` is what `Vec`'s
    // own `Serialize` implementation delegates to, and the slice iterator reports an exact
    // length, so the output is unchanged while the temporary `Vec` is no longer allocated.
    serializer.collect_seq(kv.iter().map(|item| (&item.key, item.value.as_ref())))
}
51
/// Reads a sequence of `(key, optional value)` tuples and converts each one into a
/// [`KeyValue`].
///
/// This is the `deserialize_with` helper for [`KeyValueMetadata::Static`].
#[cfg(feature = "serde")]
fn deserialize_vec_key_value<'de, D>(deserializer: D) -> Result<Vec<KeyValue>, D::Error>
where
    D: de::Deserializer<'de>,
{
    // Decode into plain tuples first; any decoding error is propagated unchanged.
    let pairs: Vec<(String, Option<String>)> = Deserialize::deserialize(deserializer)?;
    Ok(pairs
        .into_iter()
        .map(|(key, value)| KeyValue { key, value })
        .collect())
}
64
65impl KeyValueMetadata {
66    /// Create a key value metadata object from a static key value mapping.
67    pub fn from_static(kv: Vec<(String, String)>) -> Self {
68        Self::Static(
69            kv.into_iter()
70                .map(|(key, value)| KeyValue {
71                    key,
72                    value: Some(value),
73                })
74                .collect(),
75        )
76    }
77
78    /// Create a key value metadata object from a Python function.
79    #[cfg(feature = "python")]
80    pub fn from_py_function(py_object: PyObject) -> Self {
81        Self::DynamicPython(python_impl::PythonKeyValueMetadataFunction(Arc::new(
82            PythonObject(py_object),
83        )))
84    }
85
86    /// Turn the metadata into the key/value pairs to write to the Parquet file.
87    /// The context is used to dynamically construct key/value pairs.
88    pub fn collect(&self, ctx: ParquetMetadataContext) -> PolarsResult<Vec<KeyValue>> {
89        match self {
90            Self::Static(kv) => Ok(kv.clone()),
91            Self::DynamicRust(func) => Ok(func.0(ctx)),
92            #[cfg(feature = "python")]
93            Self::DynamicPython(py_func) => py_func.call(ctx),
94        }
95    }
96}
97
/// Shareable Rust callback that computes key value metadata from a [`ParquetMetadataContext`].
///
/// The callback lives behind an [`Arc`], so cloning this wrapper shares the same callback
/// instead of duplicating it. The field is not `pub`, so values can only be built where the
/// field is visible.
#[derive(Clone)]
pub struct RustKeyValueMetadataFunction(
    // `Send + Sync` lets the wrapper be moved to and shared between threads.
    Arc<dyn Fn(ParquetMetadataContext) -> Vec<KeyValue> + Send + Sync>,
);
102
103impl Debug for RustKeyValueMetadataFunction {
104    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105        write!(
106            f,
107            "key value metadata function at 0x{:016x}",
108            self.0.as_ref() as *const _ as *const () as usize
109        )
110    }
111}
112
113impl Eq for RustKeyValueMetadataFunction {}
114
115impl PartialEq for RustKeyValueMetadataFunction {
116    fn eq(&self, other: &Self) -> bool {
117        Arc::ptr_eq(&self.0, &other.0)
118    }
119}
120
121impl Hash for RustKeyValueMetadataFunction {
122    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
123        state.write_usize(Arc::as_ptr(&self.0) as *const () as usize);
124    }
125}
126
#[cfg(feature = "serde")]
impl Serialize for RustKeyValueMetadataFunction {
    /// Always fails: a Rust callback cannot be serialized.
    ///
    /// The returned error embeds the `Debug` rendering of the wrapper.
    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let message = format!("cannot serialize {:?}", self);
        Err(<S::Error as serde::ser::Error>::custom(message))
    }
}
137
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for RustKeyValueMetadataFunction {
    /// Always fails: a Rust callback cannot be reconstructed from serialized data.
    ///
    /// The deserializer is not consulted; a custom error is returned immediately.
    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::Error;
        // Name the real type so the message can be matched against the source.
        Err(D::Error::custom(
            "cannot deserialize RustKeyValueMetadataFunction",
        ))
    }
}
150
/// Python-backed metadata callback and the context object that is passed to it.
#[cfg(feature = "python")]
mod python_impl {
    use std::hash::Hash;
    use std::sync::Arc;

    use polars_error::{PolarsResult, to_compute_err};
    use polars_parquet::write::KeyValue;
    // `PythonObject` is only named by the serde attribute paths on the wrapper below, so it is
    // gated together with them.
    #[cfg(feature = "serde")]
    use polars_utils::python_function::PythonObject;
    use pyo3::types::PyAnyMethods;
    use pyo3::{PyResult, Python, pyclass};
    // Gated like every other serde use in this file, so enabling `python` without `serde`
    // does not pull in these names.
    #[cfg(feature = "serde")]
    use serde::{Deserialize, Serialize};

    use super::ParquetMetadataContext;

    /// Shared handle to a Python callable that computes key value metadata.
    ///
    /// With the `serde` feature, the handle is (de)serialized through the
    /// `PythonObject::*_with_pyversion` helpers named in the attribute.
    #[derive(Clone, Debug, PartialEq, Eq)]
    #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
    pub struct PythonKeyValueMetadataFunction(
        #[cfg_attr(
            feature = "serde",
            serde(
                serialize_with = "PythonObject::serialize_with_pyversion",
                deserialize_with = "PythonObject::deserialize_with_pyversion"
            )
        )]
        pub Arc<polars_utils::python_function::PythonFunction>,
    );

    impl PythonKeyValueMetadataFunction {
        /// Calls the Python function with a single argument built from `ctx` and converts its
        /// return value into key/value pairs.
        ///
        /// The return value must be extractable as a sequence of `(str, str)` pairs; every
        /// value is stored as `Some(value)`.
        ///
        /// # Errors
        ///
        /// Any Python-side failure (the call itself or the extraction) is converted with
        /// `to_compute_err`.
        pub fn call(&self, ctx: ParquetMetadataContext) -> PolarsResult<Vec<KeyValue>> {
            // Copy the borrowed context into an owned Python class before taking the GIL.
            let ctx = PythonParquetMetadataContext::from_key_value_metadata_context(ctx);
            Python::with_gil(|py| {
                let args = (ctx,);
                let out: Vec<(String, String)> =
                    self.0.call1(py, args)?.into_bound(py).extract()?;
                let result = out
                    .into_iter()
                    .map(|item| KeyValue {
                        key: item.0,
                        value: Some(item.1),
                    })
                    .collect::<Vec<_>>();
                // Spelling out `PyResult` fixes the closure's error type for the `?` above.
                PyResult::Ok(result)
            })
            .map_err(to_compute_err)
        }
    }

    impl Hash for PythonKeyValueMetadataFunction {
        /// Hashes the address of the shared handle.
        // NOTE(review): equality is derived and therefore delegates to the inner value, while
        // hashing uses the `Arc` address — confirm that values comparing equal always share
        // one allocation, otherwise the `Hash`/`Eq` contract is violated.
        fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
            state.write_usize(Arc::as_ptr(&self.0) as *const () as usize);
        }
    }

    /// Owned counterpart of [`ParquetMetadataContext`] that is exposed to Python as a class.
    #[pyclass]
    pub struct PythonParquetMetadataContext {
        /// Owned copy of the Arrow schema string; readable from Python via a getter.
        #[pyo3(get)]
        arrow_schema: String,
    }

    impl PythonParquetMetadataContext {
        /// Copies the borrowed schema string out of `ctx` into an owned value.
        pub fn from_key_value_metadata_context(ctx: ParquetMetadataContext) -> Self {
            Self {
                arrow_schema: ctx.arrow_schema.to_string(),
            }
        }
    }
}