1use std::sync::{Arc, Mutex};
2
3use polars::prelude::sync_on_close::SyncOnCloseType;
4use polars::prelude::{
5 PartitionTargetCallbackResult, PartitionVariant, PlPath, SinkFinishCallback, SinkOptions,
6 SortColumn, SpecialEq,
7};
8use polars_utils::IdxSize;
9use polars_utils::plpath::PlPathRef;
10use polars_utils::python_function::{PythonFunction, PythonObject};
11use pyo3::exceptions::PyValueError;
12use pyo3::pybacked::PyBackedStr;
13use pyo3::types::{PyAnyMethods, PyDict, PyDictMethods};
14use pyo3::{Bound, FromPyObject, PyAny, PyObject, PyResult, Python, pyclass, pymethods};
15
16use crate::expr::PyExpr;
17use crate::prelude::Wrap;
18
/// Destination of a sink operation as received from Python.
///
/// Either a single target (`File`: a path or a dynamic Python writer) or a
/// partitioned target described by a [`PyPartitioning`].
#[derive(Clone)]
pub enum SinkTarget {
    /// A single target: either a filesystem path or a dynamic writer wrapping a Python file.
    File(polars_plan::dsl::SinkTarget),
    /// A partitioned target built by one of the `PyPartitioning::new_*` constructors.
    Partition(PyPartitioning),
}
24
/// Python-visible description of a partitioned sink.
///
/// When the `pymethods` feature is enabled, instances are created through the static
/// `new_max_size`, `new_by_key` and `new_parted` constructors.
#[pyclass]
#[derive(Clone)]
pub struct PyPartitioning {
    /// Base path of the partitioned output; exposed to Python as a read-only attribute.
    #[pyo3(get)]
    pub base_path: Wrap<PlPath>,
    /// Optional Python callable supplied by the user. Presumably it yields the per-partition
    /// target (a path or file object, cf. `Wrap<PartitionTargetCallbackResult>`); confirm
    /// against the consumer of this field.
    pub file_path_cb: Option<PythonFunction>,
    /// Partitioning strategy (`MaxSize`, `ByKey` or `Parted`, depending on the constructor).
    pub variant: PartitionVariant,
    /// Optional sort columns applied within each partition. When built by the constructors,
    /// every column has `descending: false` and `nulls_last: false`.
    pub per_partition_sort_by: Option<Vec<SortColumn>>,
    /// Optional Python callback wrapped as `SinkFinishCallback::Python`; presumably invoked
    /// once the sink completes (TODO confirm).
    pub finish_callback: Option<SinkFinishCallback>,
}
35
36fn parse_per_partition_sort_by(sort_by: Option<Vec<PyExpr>>) -> Option<Vec<SortColumn>> {
37 sort_by.map(|exprs| {
38 exprs
39 .into_iter()
40 .map(|e| SortColumn {
41 expr: e.inner,
42 descending: false,
43 nulls_last: false,
44 })
45 .collect()
46 })
47}
48
/// Builds a [`PyPartitioning`] from the arguments shared by every `new_*` constructor.
///
/// Wraps the optional Python callables into their Rust-side representations and converts
/// the optional per-partition sort expressions with `parse_per_partition_sort_by`.
/// This is a free function because every `fn` inside a `#[pymethods]` impl is exported
/// to Python.
#[cfg(feature = "pymethods")]
fn build_partitioning(
    base_path: Wrap<PlPath>,
    file_path_cb: Option<PyObject>,
    variant: PartitionVariant,
    per_partition_sort_by: Option<Vec<PyExpr>>,
    finish_callback: Option<PyObject>,
) -> PyPartitioning {
    PyPartitioning {
        base_path,
        file_path_cb: file_path_cb.map(|f| PythonObject(f.into_any())),
        variant,
        per_partition_sort_by: parse_per_partition_sort_by(per_partition_sort_by),
        finish_callback: finish_callback
            .map(|f| SinkFinishCallback::Python(PythonObject(f.into_any()))),
    }
}

#[cfg(feature = "pymethods")]
#[pymethods]
impl PyPartitioning {
    /// Creates a partitioning using `PartitionVariant::MaxSize(max_size)`.
    ///
    /// `max_size` is an `IdxSize`, presumably a per-file row limit (TODO confirm).
    #[staticmethod]
    #[pyo3(signature = (base_path, file_path_cb, max_size, per_partition_sort_by, finish_callback))]
    pub fn new_max_size(
        base_path: Wrap<PlPath>,
        file_path_cb: Option<PyObject>,
        max_size: IdxSize,
        per_partition_sort_by: Option<Vec<PyExpr>>,
        finish_callback: Option<PyObject>,
    ) -> PyPartitioning {
        build_partitioning(
            base_path,
            file_path_cb,
            PartitionVariant::MaxSize(max_size),
            per_partition_sort_by,
            finish_callback,
        )
    }

    /// Creates a partitioning using `PartitionVariant::ByKey` over the `by` expressions.
    ///
    /// `include_key` is forwarded unchanged; presumably it controls whether the key columns
    /// are written into each partition (TODO confirm).
    #[staticmethod]
    #[pyo3(signature = (base_path, file_path_cb, by, include_key, per_partition_sort_by, finish_callback))]
    pub fn new_by_key(
        base_path: Wrap<PlPath>,
        file_path_cb: Option<PyObject>,
        by: Vec<PyExpr>,
        include_key: bool,
        per_partition_sort_by: Option<Vec<PyExpr>>,
        finish_callback: Option<PyObject>,
    ) -> PyPartitioning {
        let variant = PartitionVariant::ByKey {
            key_exprs: by.into_iter().map(|e| e.inner).collect(),
            include_key,
        };
        build_partitioning(
            base_path,
            file_path_cb,
            variant,
            per_partition_sort_by,
            finish_callback,
        )
    }

    /// Creates a partitioning using `PartitionVariant::Parted` over the `by` expressions.
    ///
    /// Arguments mirror [`PyPartitioning::new_by_key`]; only the variant differs.
    #[staticmethod]
    #[pyo3(signature = (base_path, file_path_cb, by, include_key, per_partition_sort_by, finish_callback))]
    pub fn new_parted(
        base_path: Wrap<PlPath>,
        file_path_cb: Option<PyObject>,
        by: Vec<PyExpr>,
        include_key: bool,
        per_partition_sort_by: Option<Vec<PyExpr>>,
        finish_callback: Option<PyObject>,
    ) -> PyPartitioning {
        let variant = PartitionVariant::Parted {
            key_exprs: by.into_iter().map(|e| e.inner).collect(),
            include_key,
        };
        build_partitioning(
            base_path,
            file_path_cb,
            variant,
            per_partition_sort_by,
            finish_callback,
        )
    }
}
126
127impl<'py> FromPyObject<'py> for Wrap<polars_plan::dsl::SinkTarget> {
128 fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
129 if let Ok(v) = ob.extract::<PyBackedStr>() {
130 Ok(Wrap(polars::prelude::SinkTarget::Path(PlPath::new(&v))))
131 } else {
132 let writer = Python::with_gil(|py| {
133 let py_f = ob.clone();
134 PyResult::Ok(
135 crate::file::try_get_pyfile(py, py_f, true)?
136 .0
137 .into_writeable(),
138 )
139 })?;
140
141 Ok(Wrap(polars_plan::prelude::SinkTarget::Dyn(SpecialEq::new(
142 Arc::new(Mutex::new(Some(writer))),
143 ))))
144 }
145 }
146}
147
148impl<'py> FromPyObject<'py> for Wrap<PartitionTargetCallbackResult> {
149 fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
150 if let Ok(v) = ob.extract::<PyBackedStr>() {
151 Ok(Wrap(polars::prelude::PartitionTargetCallbackResult::Str(
152 v.to_string(),
153 )))
154 } else if let Ok(v) = ob.extract::<std::path::PathBuf>() {
155 Ok(Wrap(polars::prelude::PartitionTargetCallbackResult::Str(
156 v.to_str().unwrap().to_string(),
157 )))
158 } else {
159 let writer = Python::with_gil(|py| {
160 let py_f = ob.clone();
161 PyResult::Ok(
162 crate::file::try_get_pyfile(py, py_f, true)?
163 .0
164 .into_writeable(),
165 )
166 })?;
167
168 Ok(Wrap(
169 polars_plan::prelude::PartitionTargetCallbackResult::Dyn(SpecialEq::new(Arc::new(
170 Mutex::new(Some(writer)),
171 ))),
172 ))
173 }
174 }
175}
176
177impl<'py> FromPyObject<'py> for SinkTarget {
178 fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
179 if let Ok(v) = ob.extract::<PyPartitioning>() {
180 Ok(Self::Partition(v))
181 } else {
182 Ok(Self::File(
183 <Wrap<polars_plan::dsl::SinkTarget>>::extract_bound(ob)?.0,
184 ))
185 }
186 }
187}
188
189impl SinkTarget {
190 pub fn base_path(&self) -> Option<PlPathRef<'_>> {
191 match self {
192 Self::File(t) => match t {
193 polars::prelude::SinkTarget::Path(p) => Some(p.as_ref()),
194 polars::prelude::SinkTarget::Dyn(_) => None,
195 },
196 Self::Partition(p) => Some(p.base_path.0.as_ref()),
197 }
198 }
199}
200
201impl<'py> FromPyObject<'py> for Wrap<SyncOnCloseType> {
202 fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
203 let parsed = match &*ob.extract::<PyBackedStr>()? {
204 "none" => SyncOnCloseType::None,
205 "data" => SyncOnCloseType::Data,
206 "all" => SyncOnCloseType::All,
207 v => {
208 return Err(PyValueError::new_err(format!(
209 "`sync_on_close` must be one of {{'none', 'data', 'all'}}, got {v}",
210 )));
211 },
212 };
213 Ok(Wrap(parsed))
214 }
215}
216
217impl<'py> FromPyObject<'py> for Wrap<SinkOptions> {
218 fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
219 let parsed = ob.extract::<pyo3::Bound<'_, PyDict>>()?;
220
221 if parsed.len() != 3 {
222 return Err(PyValueError::new_err(
223 "`sink_options` must be a dictionary with the exactly 3 field.",
224 ));
225 }
226
227 let sync_on_close =
228 PyDictMethods::get_item(&parsed, "sync_on_close")?.ok_or_else(|| {
229 PyValueError::new_err("`sink_options` must contain `sync_on_close` field")
230 })?;
231 let sync_on_close = sync_on_close.extract::<Wrap<SyncOnCloseType>>()?.0;
232
233 let maintain_order =
234 PyDictMethods::get_item(&parsed, "maintain_order")?.ok_or_else(|| {
235 PyValueError::new_err("`sink_options` must contain `maintain_order` field")
236 })?;
237 let maintain_order = maintain_order.extract::<bool>()?;
238
239 let mkdir = PyDictMethods::get_item(&parsed, "mkdir")?
240 .ok_or_else(|| PyValueError::new_err("`sink_options` must contain `mkdir` field"))?;
241 let mkdir = mkdir.extract::<bool>()?;
242
243 Ok(Wrap(SinkOptions {
244 sync_on_close,
245 maintain_order,
246 mkdir,
247 }))
248 }
249}