1use std::fmt;
2use std::hash::{Hash, Hasher};
3use std::path::PathBuf;
4use std::sync::Arc;
5
6use polars_core::error::PolarsResult;
7use polars_core::prelude::DataType;
8use polars_core::scalar::Scalar;
9use polars_io::cloud::CloudOptions;
10use polars_io::utils::file::{DynWriteable, Writeable};
11use polars_io::utils::sync_on_close::SyncOnCloseType;
12use polars_utils::IdxSize;
13use polars_utils::arena::Arena;
14use polars_utils::pl_str::PlSmallStr;
15
16use super::{ExprIR, FileType};
17use crate::dsl::{AExpr, Expr, SpecialEq};
18
19#[derive(Clone, PartialEq, Eq, Debug, Hash)]
21#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
22pub struct SinkOptions {
23 pub sync_on_close: SyncOnCloseType,
25
26 pub maintain_order: bool,
28
29 pub mkdir: bool,
31}
32
33impl Default for SinkOptions {
34 fn default() -> Self {
35 Self {
36 sync_on_close: Default::default(),
37 maintain_order: true,
38 mkdir: false,
39 }
40 }
41}
42
43type DynSinkTarget = SpecialEq<Arc<std::sync::Mutex<Option<Box<dyn DynWriteable>>>>>;
44
45#[derive(Clone, PartialEq, Eq)]
46pub enum SinkTarget {
47 Path(Arc<PathBuf>),
48 Dyn(DynSinkTarget),
49}
50
51impl SinkTarget {
52 pub fn open_into_writeable(
53 &self,
54 sink_options: &SinkOptions,
55 cloud_options: Option<&CloudOptions>,
56 ) -> PolarsResult<Writeable> {
57 match self {
58 SinkTarget::Path(path) => {
59 if sink_options.mkdir {
60 polars_io::utils::mkdir::mkdir_recursive(path.as_path())?;
61 }
62
63 let path = path.as_ref().display().to_string();
64 polars_io::utils::file::Writeable::try_new(&path, cloud_options)
65 },
66 SinkTarget::Dyn(memory_writer) => Ok(Writeable::Dyn(
67 memory_writer.lock().unwrap().take().unwrap(),
68 )),
69 }
70 }
71
72 #[cfg(not(feature = "cloud"))]
73 pub async fn open_into_writeable_async(
74 &self,
75 sink_options: &SinkOptions,
76 cloud_options: Option<&CloudOptions>,
77 ) -> PolarsResult<Writeable> {
78 self.open_into_writeable(sink_options, cloud_options)
79 }
80
81 #[cfg(feature = "cloud")]
82 pub async fn open_into_writeable_async(
83 &self,
84 sink_options: &SinkOptions,
85 cloud_options: Option<&CloudOptions>,
86 ) -> PolarsResult<Writeable> {
87 match self {
88 SinkTarget::Path(path) => {
89 if sink_options.mkdir {
90 polars_io::utils::mkdir::tokio_mkdir_recursive(path.as_path()).await?;
91 }
92
93 let path = path.as_ref().display().to_string();
94 polars_io::utils::file::Writeable::try_new(&path, cloud_options)
95 },
96 SinkTarget::Dyn(memory_writer) => Ok(Writeable::Dyn(
97 memory_writer.lock().unwrap().take().unwrap(),
98 )),
99 }
100 }
101}
102
103impl fmt::Debug for SinkTarget {
104 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
105 f.write_str("SinkTarget::")?;
106 match self {
107 Self::Path(p) => write!(f, "Path({p:?})"),
108 Self::Dyn(_) => f.write_str("Dyn"),
109 }
110 }
111}
112
113impl std::hash::Hash for SinkTarget {
114 fn hash<H: Hasher>(&self, state: &mut H) {
115 std::mem::discriminant(self).hash(state);
116 match self {
117 Self::Path(p) => p.hash(state),
118 Self::Dyn(p) => Arc::as_ptr(p).hash(state),
119 }
120 }
121}
122
#[cfg(feature = "serde")]
impl serde::Serialize for SinkTarget {
    /// Serializes a `Path` target as its path.
    ///
    /// # Errors
    /// `Dyn` targets cannot be serialized and yield a custom serializer
    /// error.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let Self::Path(path) = self else {
            return Err(serde::ser::Error::custom(
                "cannot serialize in-memory sink target",
            ));
        };

        path.serialize(serializer)
    }
}
137
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for SinkTarget {
    /// Reads a `PathBuf` and wraps it as a `Path` target; `Dyn` targets are
    /// never produced by deserialization.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let path = PathBuf::deserialize(deserializer)?;
        Ok(Self::Path(Arc::new(path)))
    }
}
147
148#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
149#[derive(Clone, Debug, PartialEq, Eq, Hash)]
150pub struct FileSinkType {
151 pub target: SinkTarget,
152 pub file_type: FileType,
153 pub sink_options: SinkOptions,
154 pub cloud_options: Option<polars_io::cloud::CloudOptions>,
155}
156
157#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
158#[derive(Clone, Debug, PartialEq)]
159pub enum SinkTypeIR {
160 Memory,
161 File(FileSinkType),
162 Partition(PartitionSinkTypeIR),
163}
164
165#[cfg_attr(feature = "python", pyo3::pyclass)]
166#[derive(Clone)]
167pub struct PartitionTargetContextKey {
168 pub name: PlSmallStr,
169 pub raw_value: Scalar,
170}
171
172#[cfg_attr(feature = "python", pyo3::pyclass)]
173pub struct PartitionTargetContext {
174 pub file_idx: usize,
175 pub part_idx: usize,
176 pub in_part_idx: usize,
177 pub keys: Vec<PartitionTargetContextKey>,
178 pub file_path: PathBuf,
179 pub full_path: PathBuf,
180}
181
// Python getters that expose each field of `PartitionTargetContext`.
#[cfg(feature = "python")]
#[pyo3::pymethods]
impl PartitionTargetContext {
    // Returns the `file_idx` field.
    #[getter]
    pub fn file_idx(&self) -> usize {
        self.file_idx
    }

    // Returns the `part_idx` field.
    #[getter]
    pub fn part_idx(&self) -> usize {
        self.part_idx
    }

    // Returns the `in_part_idx` field.
    #[getter]
    pub fn in_part_idx(&self) -> usize {
        self.in_part_idx
    }

    // Returns a copy of the keys.
    #[getter]
    pub fn keys(&self) -> Vec<PartitionTargetContextKey> {
        self.keys.to_vec()
    }

    // Returns the `file_path` field as a borrowed path.
    #[getter]
    pub fn file_path(&self) -> &std::path::Path {
        &self.file_path
    }

    // Returns the `full_path` field as a borrowed path.
    #[getter]
    pub fn full_path(&self) -> &std::path::Path {
        &self.full_path
    }
}
// Python getters for `PartitionTargetContextKey`.
#[cfg(feature = "python")]
#[pyo3::pymethods]
impl PartitionTargetContextKey {
    // Returns the key name.
    #[getter]
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    // Returns the key value strictly cast to a string and percent-encoded
    // with `URL_ENCODE_CHAR_SET`; a missing value is rendered as "null".
    // A failing cast is surfaced as a Python `RuntimeError`.
    #[getter]
    pub fn str_value(&self) -> pyo3::PyResult<String> {
        let as_string = self
            .raw_value
            .clone()
            .into_series(PlSmallStr::EMPTY)
            .strict_cast(&DataType::String)
            .map_err(|err| pyo3::exceptions::PyRuntimeError::new_err(err.to_string()))?;

        // The series was just cast to `String`, so `str()` is expected to succeed.
        let text = as_string.str().unwrap().get(0).unwrap_or("null");
        let encoded =
            percent_encoding::percent_encode(text.as_bytes(), polars_io::utils::URL_ENCODE_CHAR_SET);
        Ok(encoded.to_string())
    }

    // Returns the key value converted to a Python object through the
    // registered PyObject converter. Panics if the converter's result is not
    // a `pyo3::PyObject`.
    #[getter]
    pub fn raw_value(&self) -> pyo3::PyObject {
        let converter = polars_core::chunked_array::object::registry::get_pyobject_converter();
        let converted = (converter.as_ref())(self.raw_value.as_any_value());
        *converted.downcast::<pyo3::PyObject>().unwrap()
    }
}
238
239#[derive(Clone, Debug, PartialEq)]
240pub enum PartitionTargetCallback {
241 Rust(SpecialEq<Arc<dyn Fn(PartitionTargetContext) -> PolarsResult<SinkTarget> + Send + Sync>>),
242 #[cfg(feature = "python")]
243 Python(polars_utils::python_function::PythonFunction),
244}
245
246impl PartitionTargetCallback {
247 pub fn call(&self, ctx: PartitionTargetContext) -> PolarsResult<SinkTarget> {
248 match self {
249 Self::Rust(f) => f(ctx),
250 #[cfg(feature = "python")]
251 Self::Python(f) => pyo3::Python::with_gil(|py| {
252 let sink_target = f.call1(py, (ctx,))?;
253 let converter =
254 polars_utils::python_convert_registry::get_python_convert_registry();
255 let sink_target = (converter.from_py.sink_target)(sink_target)?;
256 let sink_target = sink_target.downcast_ref::<SinkTarget>().unwrap().clone();
257 PolarsResult::Ok(sink_target)
258 }),
259 }
260 }
261}
262
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for PartitionTargetCallback {
    /// Deserializes a callback.
    ///
    /// With the `python` feature the input is read as a `PythonFunction` and
    /// wrapped in the `Python` variant. Without it no variant can be
    /// reconstructed, so deserialization always fails.
    ///
    /// # Errors
    /// Propagates the `PythonFunction` deserialization error (`python`
    /// feature), or returns a custom error (no `python` feature).
    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        #[cfg(feature = "python")]
        {
            Ok(Self::Python(
                polars_utils::python_function::PythonFunction::deserialize(_deserializer)?,
            ))
        }
        #[cfg(not(feature = "python"))]
        {
            use serde::de::Error;
            // The message names this type; it previously referred to a
            // `PartitionOutputCallback`, which does not exist in this file.
            Err(D::Error::custom(
                "cannot deserialize PartitionTargetCallback",
            ))
        }
    }
}
284
#[cfg(feature = "serde")]
impl serde::Serialize for PartitionTargetCallback {
    /// Serializes the callback.
    ///
    /// Only the `Python` variant (with the `python` feature) can be
    /// serialized; it is written as its `PythonFunction`.
    ///
    /// # Errors
    /// The `Rust` variant yields a custom serializer error that includes the
    /// value's `Debug` output.
    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        match self {
            #[cfg(feature = "python")]
            Self::Python(function) => function.serialize(_serializer),
            Self::Rust(_) => Err(<S::Error as serde::ser::Error>::custom(format!(
                "cannot serialize {self:?}"
            ))),
        }
    }
}
301
302#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
303#[derive(Clone, Debug, PartialEq)]
304pub struct PartitionSinkType {
305 pub base_path: Arc<PathBuf>,
306 pub file_path_cb: Option<PartitionTargetCallback>,
307 pub file_type: FileType,
308 pub sink_options: SinkOptions,
309 pub variant: PartitionVariant,
310 pub cloud_options: Option<polars_io::cloud::CloudOptions>,
311}
312
313#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
314#[derive(Clone, Debug, PartialEq)]
315pub struct PartitionSinkTypeIR {
316 pub base_path: Arc<PathBuf>,
317 pub file_path_cb: Option<PartitionTargetCallback>,
318 pub file_type: FileType,
319 pub sink_options: SinkOptions,
320 pub variant: PartitionVariantIR,
321 pub cloud_options: Option<polars_io::cloud::CloudOptions>,
322}
323
324#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
325#[derive(Clone, Debug, PartialEq)]
326pub enum SinkType {
327 Memory,
328 File(FileSinkType),
329 Partition(PartitionSinkType),
330}
331
332#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
333#[derive(Clone, Debug, PartialEq, Eq, Hash)]
334pub enum PartitionVariant {
335 MaxSize(IdxSize),
336 Parted {
337 key_exprs: Vec<Expr>,
338 include_key: bool,
339 },
340 ByKey {
341 key_exprs: Vec<Expr>,
342 include_key: bool,
343 },
344}
345
346#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
347#[derive(Clone, Debug, PartialEq, Eq)]
348pub enum PartitionVariantIR {
349 MaxSize(IdxSize),
350 Parted {
351 key_exprs: Vec<ExprIR>,
352 include_key: bool,
353 },
354 ByKey {
355 key_exprs: Vec<ExprIR>,
356 include_key: bool,
357 },
358}
359
360impl SinkTypeIR {
361 #[cfg(feature = "cse")]
362 pub(crate) fn traverse_and_hash<H: Hasher>(&self, expr_arena: &Arena<AExpr>, state: &mut H) {
363 std::mem::discriminant(self).hash(state);
364 match self {
365 Self::Memory => {},
366 Self::File(f) => f.hash(state),
367 Self::Partition(f) => {
368 f.file_type.hash(state);
369 f.sink_options.hash(state);
370 f.variant.traverse_and_hash(expr_arena, state);
371 f.cloud_options.hash(state);
372 },
373 }
374 }
375}
376
377impl PartitionVariantIR {
378 #[cfg(feature = "cse")]
379 pub(crate) fn traverse_and_hash<H: Hasher>(&self, expr_arena: &Arena<AExpr>, state: &mut H) {
380 std::mem::discriminant(self).hash(state);
381 match self {
382 Self::MaxSize(size) => size.hash(state),
383 Self::Parted {
384 key_exprs,
385 include_key,
386 }
387 | Self::ByKey {
388 key_exprs,
389 include_key,
390 } => {
391 include_key.hash(state);
392 for key_expr in key_exprs.as_slice() {
393 key_expr.traverse_and_hash(expr_arena, state);
394 }
395 },
396 }
397 }
398}
399
400#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
401#[derive(Clone, Debug)]
402pub struct FileSinkOptions {
403 pub path: Arc<PathBuf>,
404 pub file_type: FileType,
405}