polars_python/lazyframe/sink.rs

1use std::sync::{Arc, Mutex};
2
3use polars::prelude::sync_on_close::SyncOnCloseType;
4use polars::prelude::{
5    PartitionTargetCallbackResult, PartitionVariant, PlPath, SinkFinishCallback, SinkOptions,
6    SortColumn, SpecialEq,
7};
8use polars_utils::IdxSize;
9use polars_utils::plpath::PlPathRef;
10use polars_utils::python_function::{PythonFunction, PythonObject};
11use pyo3::exceptions::PyValueError;
12use pyo3::pybacked::PyBackedStr;
13use pyo3::types::{PyAnyMethods, PyDict, PyDictMethods};
14use pyo3::{Bound, FromPyObject, PyAny, PyObject, PyResult, Python, pyclass, pymethods};
15
16use crate::expr::PyExpr;
17use crate::prelude::Wrap;
18
19#[derive(Clone)]
20pub enum SinkTarget {
21    File(polars_plan::dsl::SinkTarget),
22    Partition(PyPartitioning),
23}
24
25#[pyclass]
26#[derive(Clone)]
27pub struct PyPartitioning {
28    #[pyo3(get)]
29    pub base_path: Wrap<PlPath>,
30    pub file_path_cb: Option<PythonFunction>,
31    pub variant: PartitionVariant,
32    pub per_partition_sort_by: Option<Vec<SortColumn>>,
33    pub finish_callback: Option<SinkFinishCallback>,
34}
35
36fn parse_per_partition_sort_by(sort_by: Option<Vec<PyExpr>>) -> Option<Vec<SortColumn>> {
37    sort_by.map(|exprs| {
38        exprs
39            .into_iter()
40            .map(|e| SortColumn {
41                expr: e.inner,
42                descending: false,
43                nulls_last: false,
44            })
45            .collect()
46    })
47}
48
#[cfg(feature = "pymethods")]
impl PyPartitioning {
    /// Assembles a `PyPartitioning` from the arguments every Python constructor shares.
    ///
    /// Optional Python callables are wrapped as `PythonObject`s; the finish callback is
    /// additionally tagged as `SinkFinishCallback::Python`. Sort expressions go through
    /// `parse_per_partition_sort_by`. Only `variant` differs between the constructors.
    fn from_python_parts(
        base_path: Wrap<PlPath>,
        file_path_cb: Option<PyObject>,
        variant: PartitionVariant,
        per_partition_sort_by: Option<Vec<PyExpr>>,
        finish_callback: Option<PyObject>,
    ) -> Self {
        Self {
            base_path,
            file_path_cb: file_path_cb.map(|cb| PythonObject(cb.into_any())),
            variant,
            per_partition_sort_by: parse_per_partition_sort_by(per_partition_sort_by),
            finish_callback: finish_callback
                .map(|cb| SinkFinishCallback::Python(PythonObject(cb.into_any()))),
        }
    }
}

#[cfg(feature = "pymethods")]
#[pymethods]
impl PyPartitioning {
    /// Partitioning using `PartitionVariant::MaxSize(max_size)`.
    #[staticmethod]
    #[pyo3(signature = (base_path, file_path_cb, max_size, per_partition_sort_by, finish_callback))]
    pub fn new_max_size(
        base_path: Wrap<PlPath>,
        file_path_cb: Option<PyObject>,
        max_size: IdxSize,
        per_partition_sort_by: Option<Vec<PyExpr>>,
        finish_callback: Option<PyObject>,
    ) -> PyPartitioning {
        let variant = PartitionVariant::MaxSize(max_size);
        Self::from_python_parts(
            base_path,
            file_path_cb,
            variant,
            per_partition_sort_by,
            finish_callback,
        )
    }

    /// Partitioning keyed on the `by` expressions, using `PartitionVariant::ByKey`.
    ///
    /// `include_key` is forwarded unchanged to the variant.
    #[staticmethod]
    #[pyo3(signature = (base_path, file_path_cb, by, include_key, per_partition_sort_by, finish_callback))]
    pub fn new_by_key(
        base_path: Wrap<PlPath>,
        file_path_cb: Option<PyObject>,
        by: Vec<PyExpr>,
        include_key: bool,
        per_partition_sort_by: Option<Vec<PyExpr>>,
        finish_callback: Option<PyObject>,
    ) -> PyPartitioning {
        let key_exprs = by.into_iter().map(|py_expr| py_expr.inner).collect();
        let variant = PartitionVariant::ByKey {
            key_exprs,
            include_key,
        };
        Self::from_python_parts(
            base_path,
            file_path_cb,
            variant,
            per_partition_sort_by,
            finish_callback,
        )
    }

    /// Partitioning on the `by` expressions, using `PartitionVariant::Parted`.
    ///
    /// `include_key` is forwarded unchanged to the variant.
    #[staticmethod]
    #[pyo3(signature = (base_path, file_path_cb, by, include_key, per_partition_sort_by, finish_callback))]
    pub fn new_parted(
        base_path: Wrap<PlPath>,
        file_path_cb: Option<PyObject>,
        by: Vec<PyExpr>,
        include_key: bool,
        per_partition_sort_by: Option<Vec<PyExpr>>,
        finish_callback: Option<PyObject>,
    ) -> PyPartitioning {
        let key_exprs = by.into_iter().map(|py_expr| py_expr.inner).collect();
        let variant = PartitionVariant::Parted {
            key_exprs,
            include_key,
        };
        Self::from_python_parts(
            base_path,
            file_path_cb,
            variant,
            per_partition_sort_by,
            finish_callback,
        )
    }
}
126
127impl<'py> FromPyObject<'py> for Wrap<polars_plan::dsl::SinkTarget> {
128    fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
129        if let Ok(v) = ob.extract::<PyBackedStr>() {
130            Ok(Wrap(polars::prelude::SinkTarget::Path(PlPath::new(&v))))
131        } else {
132            let writer = Python::with_gil(|py| {
133                let py_f = ob.clone();
134                PyResult::Ok(
135                    crate::file::try_get_pyfile(py, py_f, true)?
136                        .0
137                        .into_writeable(),
138                )
139            })?;
140
141            Ok(Wrap(polars_plan::prelude::SinkTarget::Dyn(SpecialEq::new(
142                Arc::new(Mutex::new(Some(writer))),
143            ))))
144        }
145    }
146}
147
148impl<'py> FromPyObject<'py> for Wrap<PartitionTargetCallbackResult> {
149    fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
150        if let Ok(v) = ob.extract::<PyBackedStr>() {
151            Ok(Wrap(polars::prelude::PartitionTargetCallbackResult::Str(
152                v.to_string(),
153            )))
154        } else if let Ok(v) = ob.extract::<std::path::PathBuf>() {
155            Ok(Wrap(polars::prelude::PartitionTargetCallbackResult::Str(
156                v.to_str().unwrap().to_string(),
157            )))
158        } else {
159            let writer = Python::with_gil(|py| {
160                let py_f = ob.clone();
161                PyResult::Ok(
162                    crate::file::try_get_pyfile(py, py_f, true)?
163                        .0
164                        .into_writeable(),
165                )
166            })?;
167
168            Ok(Wrap(
169                polars_plan::prelude::PartitionTargetCallbackResult::Dyn(SpecialEq::new(Arc::new(
170                    Mutex::new(Some(writer)),
171                ))),
172            ))
173        }
174    }
175}
176
177impl<'py> FromPyObject<'py> for SinkTarget {
178    fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
179        if let Ok(v) = ob.extract::<PyPartitioning>() {
180            Ok(Self::Partition(v))
181        } else {
182            Ok(Self::File(
183                <Wrap<polars_plan::dsl::SinkTarget>>::extract_bound(ob)?.0,
184            ))
185        }
186    }
187}
188
189impl SinkTarget {
190    pub fn base_path(&self) -> Option<PlPathRef<'_>> {
191        match self {
192            Self::File(t) => match t {
193                polars::prelude::SinkTarget::Path(p) => Some(p.as_ref()),
194                polars::prelude::SinkTarget::Dyn(_) => None,
195            },
196            Self::Partition(p) => Some(p.base_path.0.as_ref()),
197        }
198    }
199}
200
201impl<'py> FromPyObject<'py> for Wrap<SyncOnCloseType> {
202    fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
203        let parsed = match &*ob.extract::<PyBackedStr>()? {
204            "none" => SyncOnCloseType::None,
205            "data" => SyncOnCloseType::Data,
206            "all" => SyncOnCloseType::All,
207            v => {
208                return Err(PyValueError::new_err(format!(
209                    "`sync_on_close` must be one of {{'none', 'data', 'all'}}, got {v}",
210                )));
211            },
212        };
213        Ok(Wrap(parsed))
214    }
215}
216
217impl<'py> FromPyObject<'py> for Wrap<SinkOptions> {
218    fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
219        let parsed = ob.extract::<pyo3::Bound<'_, PyDict>>()?;
220
221        if parsed.len() != 3 {
222            return Err(PyValueError::new_err(
223                "`sink_options` must be a dictionary with the exactly 3 field.",
224            ));
225        }
226
227        let sync_on_close =
228            PyDictMethods::get_item(&parsed, "sync_on_close")?.ok_or_else(|| {
229                PyValueError::new_err("`sink_options` must contain `sync_on_close` field")
230            })?;
231        let sync_on_close = sync_on_close.extract::<Wrap<SyncOnCloseType>>()?.0;
232
233        let maintain_order =
234            PyDictMethods::get_item(&parsed, "maintain_order")?.ok_or_else(|| {
235                PyValueError::new_err("`sink_options` must contain `maintain_order` field")
236            })?;
237        let maintain_order = maintain_order.extract::<bool>()?;
238
239        let mkdir = PyDictMethods::get_item(&parsed, "mkdir")?
240            .ok_or_else(|| PyValueError::new_err("`sink_options` must contain `mkdir` field"))?;
241        let mkdir = mkdir.extract::<bool>()?;
242
243        Ok(Wrap(SinkOptions {
244            sync_on_close,
245            maintain_order,
246            mkdir,
247        }))
248    }
249}