1use std::path::{Path, PathBuf};
2use std::sync::{Arc, Mutex};
3
4use polars::prelude::sync_on_close::SyncOnCloseType;
5use polars::prelude::{PartitionVariant, SinkFinishCallback, SinkOptions, SortColumn, SpecialEq};
6use polars_utils::IdxSize;
7use polars_utils::python_function::{PythonFunction, PythonObject};
8use pyo3::exceptions::PyValueError;
9use pyo3::pybacked::PyBackedStr;
10use pyo3::types::{PyAnyMethods, PyDict, PyDictMethods};
11use pyo3::{Bound, FromPyObject, PyAny, PyObject, PyResult, Python, pyclass, pymethods};
12
13use crate::expr::PyExpr;
14use crate::prelude::Wrap;
15
16#[derive(Clone)]
17pub enum SinkTarget {
18 File(polars_plan::dsl::SinkTarget),
19 Partition(PyPartitioning),
20}
21
22#[pyclass]
23#[derive(Clone)]
24pub struct PyPartitioning {
25 #[pyo3(get)]
26 pub base_path: PathBuf,
27 pub file_path_cb: Option<PythonFunction>,
28 pub variant: PartitionVariant,
29 pub per_partition_sort_by: Option<Vec<SortColumn>>,
30 pub finish_callback: Option<SinkFinishCallback>,
31}
32
33fn parse_per_partition_sort_by(sort_by: Option<Vec<PyExpr>>) -> Option<Vec<SortColumn>> {
34 sort_by.map(|exprs| {
35 exprs
36 .into_iter()
37 .map(|e| SortColumn {
38 expr: e.inner,
39 descending: false,
40 nulls_last: false,
41 })
42 .collect()
43 })
44}
45
46#[cfg(feature = "pymethods")]
47#[pymethods]
48impl PyPartitioning {
49 #[staticmethod]
50 #[pyo3(signature = (base_path, file_path_cb, max_size, per_partition_sort_by, finish_callback))]
51 pub fn new_max_size(
52 base_path: PathBuf,
53 file_path_cb: Option<PyObject>,
54 max_size: IdxSize,
55 per_partition_sort_by: Option<Vec<PyExpr>>,
56 finish_callback: Option<PyObject>,
57 ) -> PyPartitioning {
58 let file_path_cb = file_path_cb.map(|f| PythonObject(f.into_any()));
59 let finish_callback =
60 finish_callback.map(|f| SinkFinishCallback::Python(PythonObject(f.into_any())));
61
62 PyPartitioning {
63 base_path,
64 file_path_cb,
65 variant: PartitionVariant::MaxSize(max_size),
66 per_partition_sort_by: parse_per_partition_sort_by(per_partition_sort_by),
67 finish_callback,
68 }
69 }
70
71 #[staticmethod]
72 #[pyo3(signature = (base_path, file_path_cb, by, include_key, per_partition_sort_by, finish_callback))]
73 pub fn new_by_key(
74 base_path: PathBuf,
75 file_path_cb: Option<PyObject>,
76 by: Vec<PyExpr>,
77 include_key: bool,
78 per_partition_sort_by: Option<Vec<PyExpr>>,
79 finish_callback: Option<PyObject>,
80 ) -> PyPartitioning {
81 let file_path_cb = file_path_cb.map(|f| PythonObject(f.into_any()));
82 let finish_callback =
83 finish_callback.map(|f| SinkFinishCallback::Python(PythonObject(f.into_any())));
84
85 PyPartitioning {
86 base_path,
87 file_path_cb,
88 variant: PartitionVariant::ByKey {
89 key_exprs: by.into_iter().map(|e| e.inner).collect(),
90 include_key,
91 },
92 per_partition_sort_by: parse_per_partition_sort_by(per_partition_sort_by),
93 finish_callback,
94 }
95 }
96
97 #[staticmethod]
98 #[pyo3(signature = (base_path, file_path_cb, by, include_key, per_partition_sort_by, finish_callback))]
99 pub fn new_parted(
100 base_path: PathBuf,
101 file_path_cb: Option<PyObject>,
102 by: Vec<PyExpr>,
103 include_key: bool,
104 per_partition_sort_by: Option<Vec<PyExpr>>,
105 finish_callback: Option<PyObject>,
106 ) -> PyPartitioning {
107 let file_path_cb = file_path_cb.map(|f| PythonObject(f.into_any()));
108 let finish_callback =
109 finish_callback.map(|f| SinkFinishCallback::Python(PythonObject(f.into_any())));
110
111 PyPartitioning {
112 base_path,
113 file_path_cb,
114 variant: PartitionVariant::Parted {
115 key_exprs: by.into_iter().map(|e| e.inner).collect(),
116 include_key,
117 },
118 per_partition_sort_by: parse_per_partition_sort_by(per_partition_sort_by),
119 finish_callback,
120 }
121 }
122}
123
124impl<'py> FromPyObject<'py> for Wrap<polars_plan::dsl::SinkTarget> {
125 fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
126 if let Ok(v) = ob.extract::<PathBuf>() {
127 Ok(Wrap(polars::prelude::SinkTarget::Path(Arc::new(v))))
128 } else {
129 let writer = Python::with_gil(|py| {
130 let py_f = ob.clone();
131 PyResult::Ok(
132 crate::file::try_get_pyfile(py, py_f, true)?
133 .0
134 .into_writeable(),
135 )
136 })?;
137
138 Ok(Wrap(polars_plan::prelude::SinkTarget::Dyn(SpecialEq::new(
139 Arc::new(Mutex::new(Some(writer))),
140 ))))
141 }
142 }
143}
144
145impl<'py> FromPyObject<'py> for SinkTarget {
146 fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
147 if let Ok(v) = ob.extract::<PyPartitioning>() {
148 Ok(Self::Partition(v))
149 } else {
150 Ok(Self::File(
151 <Wrap<polars_plan::dsl::SinkTarget>>::extract_bound(ob)?.0,
152 ))
153 }
154 }
155}
156
157impl SinkTarget {
158 pub fn base_path(&self) -> Option<&Path> {
159 match self {
160 Self::File(t) => match t {
161 polars::prelude::SinkTarget::Path(p) => Some(p.as_path()),
162 polars::prelude::SinkTarget::Dyn(_) => None,
163 },
164 Self::Partition(p) => Some(&p.base_path),
165 }
166 }
167}
168
169impl<'py> FromPyObject<'py> for Wrap<SyncOnCloseType> {
170 fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
171 let parsed = match &*ob.extract::<PyBackedStr>()? {
172 "none" => SyncOnCloseType::None,
173 "data" => SyncOnCloseType::Data,
174 "all" => SyncOnCloseType::All,
175 v => {
176 return Err(PyValueError::new_err(format!(
177 "`sync_on_close` must be one of {{'none', 'data', 'all'}}, got {v}",
178 )));
179 },
180 };
181 Ok(Wrap(parsed))
182 }
183}
184
185impl<'py> FromPyObject<'py> for Wrap<SinkOptions> {
186 fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
187 let parsed = ob.extract::<pyo3::Bound<'_, PyDict>>()?;
188
189 if parsed.len() != 3 {
190 return Err(PyValueError::new_err(
191 "`sink_options` must be a dictionary with the exactly 3 field.",
192 ));
193 }
194
195 let sync_on_close =
196 PyDictMethods::get_item(&parsed, "sync_on_close")?.ok_or_else(|| {
197 PyValueError::new_err("`sink_options` must contain `sync_on_close` field")
198 })?;
199 let sync_on_close = sync_on_close.extract::<Wrap<SyncOnCloseType>>()?.0;
200
201 let maintain_order =
202 PyDictMethods::get_item(&parsed, "maintain_order")?.ok_or_else(|| {
203 PyValueError::new_err("`sink_options` must contain `maintain_order` field")
204 })?;
205 let maintain_order = maintain_order.extract::<bool>()?;
206
207 let mkdir = PyDictMethods::get_item(&parsed, "mkdir")?
208 .ok_or_else(|| PyValueError::new_err("`sink_options` must contain `mkdir` field"))?;
209 let mkdir = mkdir.extract::<bool>()?;
210
211 Ok(Wrap(SinkOptions {
212 sync_on_close,
213 maintain_order,
214 mkdir,
215 }))
216 }
217}