polars_plan/dsl/options/
sink.rs

1use std::fmt;
2use std::hash::{Hash, Hasher};
3use std::path::PathBuf;
4use std::sync::Arc;
5
6use polars_core::error::PolarsResult;
7use polars_core::prelude::DataType;
8use polars_core::scalar::Scalar;
9use polars_io::cloud::CloudOptions;
10use polars_io::utils::file::{DynWriteable, Writeable};
11use polars_io::utils::sync_on_close::SyncOnCloseType;
12use polars_utils::IdxSize;
13use polars_utils::arena::Arena;
14use polars_utils::pl_str::PlSmallStr;
15
16use super::{ExprIR, FileType};
17use crate::dsl::{AExpr, Expr, SpecialEq};
18
19/// Options that apply to all sinks.
20#[derive(Clone, PartialEq, Eq, Debug, Hash)]
21#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
22pub struct SinkOptions {
23    /// Call sync when closing the file.
24    pub sync_on_close: SyncOnCloseType,
25
26    /// The output file needs to maintain order of the data that comes in.
27    pub maintain_order: bool,
28
29    /// Recursively create all the directories in the path.
30    pub mkdir: bool,
31}
32
33impl Default for SinkOptions {
34    fn default() -> Self {
35        Self {
36            sync_on_close: Default::default(),
37            maintain_order: true,
38            mkdir: false,
39        }
40    }
41}
42
43type DynSinkTarget = SpecialEq<Arc<std::sync::Mutex<Option<Box<dyn DynWriteable>>>>>;
44
45#[derive(Clone, PartialEq, Eq)]
46pub enum SinkTarget {
47    Path(Arc<PathBuf>),
48    Dyn(DynSinkTarget),
49}
50
51impl SinkTarget {
52    pub fn open_into_writeable(
53        &self,
54        sink_options: &SinkOptions,
55        cloud_options: Option<&CloudOptions>,
56    ) -> PolarsResult<Writeable> {
57        match self {
58            SinkTarget::Path(path) => {
59                if sink_options.mkdir {
60                    polars_io::utils::mkdir::mkdir_recursive(path.as_path())?;
61                }
62
63                let path = path.as_ref().display().to_string();
64                polars_io::utils::file::Writeable::try_new(&path, cloud_options)
65            },
66            SinkTarget::Dyn(memory_writer) => Ok(Writeable::Dyn(
67                memory_writer.lock().unwrap().take().unwrap(),
68            )),
69        }
70    }
71
72    #[cfg(not(feature = "cloud"))]
73    pub async fn open_into_writeable_async(
74        &self,
75        sink_options: &SinkOptions,
76        cloud_options: Option<&CloudOptions>,
77    ) -> PolarsResult<Writeable> {
78        self.open_into_writeable(sink_options, cloud_options)
79    }
80
81    #[cfg(feature = "cloud")]
82    pub async fn open_into_writeable_async(
83        &self,
84        sink_options: &SinkOptions,
85        cloud_options: Option<&CloudOptions>,
86    ) -> PolarsResult<Writeable> {
87        match self {
88            SinkTarget::Path(path) => {
89                if sink_options.mkdir {
90                    polars_io::utils::mkdir::tokio_mkdir_recursive(path.as_path()).await?;
91                }
92
93                let path = path.as_ref().display().to_string();
94                polars_io::utils::file::Writeable::try_new(&path, cloud_options)
95            },
96            SinkTarget::Dyn(memory_writer) => Ok(Writeable::Dyn(
97                memory_writer.lock().unwrap().take().unwrap(),
98            )),
99        }
100    }
101}
102
103impl fmt::Debug for SinkTarget {
104    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
105        f.write_str("SinkTarget::")?;
106        match self {
107            Self::Path(p) => write!(f, "Path({p:?})"),
108            Self::Dyn(_) => f.write_str("Dyn"),
109        }
110    }
111}
112
113impl std::hash::Hash for SinkTarget {
114    fn hash<H: Hasher>(&self, state: &mut H) {
115        std::mem::discriminant(self).hash(state);
116        match self {
117            Self::Path(p) => p.hash(state),
118            Self::Dyn(p) => Arc::as_ptr(p).hash(state),
119        }
120    }
121}
122
#[cfg(feature = "serde")]
impl serde::Serialize for SinkTarget {
    /// Serializes a `Path` target as its path; a `Dyn` target is rejected with an error.
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        match self {
            Self::Dyn(_) => {
                let reason = "cannot serialize in-memory sink target";
                Err(serde::ser::Error::custom(reason))
            },
            Self::Path(path) => path.serialize(serializer),
        }
    }
}
137
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for SinkTarget {
    /// Reads a path and wraps it as a `Path` target; `Dyn` targets are never produced.
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let path = PathBuf::deserialize(deserializer)?;
        Ok(Self::Path(Arc::new(path)))
    }
}
147
148#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
149#[derive(Clone, Debug, PartialEq, Eq, Hash)]
150pub struct FileSinkType {
151    pub target: SinkTarget,
152    pub file_type: FileType,
153    pub sink_options: SinkOptions,
154    pub cloud_options: Option<polars_io::cloud::CloudOptions>,
155}
156
157#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
158#[derive(Clone, Debug, PartialEq)]
159pub enum SinkTypeIR {
160    Memory,
161    File(FileSinkType),
162    Partition(PartitionSinkTypeIR),
163}
164
165#[cfg_attr(feature = "python", pyo3::pyclass)]
166#[derive(Clone)]
167pub struct PartitionTargetContextKey {
168    pub name: PlSmallStr,
169    pub raw_value: Scalar,
170}
171
172#[cfg_attr(feature = "python", pyo3::pyclass)]
173pub struct PartitionTargetContext {
174    pub file_idx: usize,
175    pub part_idx: usize,
176    pub in_part_idx: usize,
177    pub keys: Vec<PartitionTargetContextKey>,
178    pub file_path: PathBuf,
179    pub full_path: PathBuf,
180}
181
// Read-only Python properties mirroring the public fields.
#[cfg(feature = "python")]
#[pyo3::pymethods]
impl PartitionTargetContext {
    /// Returns the `file_idx` field.
    #[getter]
    pub fn file_idx(&self) -> usize {
        self.file_idx
    }

    /// Returns the `part_idx` field.
    #[getter]
    pub fn part_idx(&self) -> usize {
        self.part_idx
    }

    /// Returns the `in_part_idx` field.
    #[getter]
    pub fn in_part_idx(&self) -> usize {
        self.in_part_idx
    }

    /// Returns a copy of the `keys` field.
    #[getter]
    pub fn keys(&self) -> Vec<PartitionTargetContextKey> {
        self.keys.to_vec()
    }

    /// Returns the `file_path` field.
    #[getter]
    pub fn file_path(&self) -> &std::path::Path {
        &self.file_path
    }

    /// Returns the `full_path` field.
    #[getter]
    pub fn full_path(&self) -> &std::path::Path {
        &self.full_path
    }
}
// Python properties exposing the key name and value.
#[cfg(feature = "python")]
#[pyo3::pymethods]
impl PartitionTargetContextKey {
    /// Returns the key name.
    #[getter]
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the value strictly cast to a string and percent-encoded; a missing value is
    /// rendered as `null`.
    ///
    /// A failing cast is reported as a Python `RuntimeError`.
    #[getter]
    pub fn str_value(&self) -> pyo3::PyResult<String> {
        let as_string = self
            .raw_value
            .clone()
            .into_series(PlSmallStr::EMPTY)
            .strict_cast(&DataType::String)
            .map_err(|err| pyo3::exceptions::PyRuntimeError::new_err(err.to_string()))?;
        let text = as_string.str().unwrap().get(0).unwrap_or("null");
        let encoded = percent_encoding::percent_encode(
            text.as_bytes(),
            polars_io::utils::URL_ENCODE_CHAR_SET,
        );
        Ok(encoded.to_string())
    }

    /// Returns the value converted to a Python object by the registered converter.
    #[getter]
    pub fn raw_value(&self) -> pyo3::PyObject {
        let converter = polars_core::chunked_array::object::registry::get_pyobject_converter();
        let converted = (converter.as_ref())(self.raw_value.as_any_value());
        let boxed = converted.downcast::<pyo3::PyObject>().unwrap();
        *boxed
    }
}
238
239#[derive(Clone, Debug, PartialEq)]
240pub enum PartitionTargetCallback {
241    Rust(SpecialEq<Arc<dyn Fn(PartitionTargetContext) -> PolarsResult<SinkTarget> + Send + Sync>>),
242    #[cfg(feature = "python")]
243    Python(polars_utils::python_function::PythonFunction),
244}
245
246impl PartitionTargetCallback {
247    pub fn call(&self, ctx: PartitionTargetContext) -> PolarsResult<SinkTarget> {
248        match self {
249            Self::Rust(f) => f(ctx),
250            #[cfg(feature = "python")]
251            Self::Python(f) => pyo3::Python::with_gil(|py| {
252                let sink_target = f.call1(py, (ctx,))?;
253                let converter =
254                    polars_utils::python_convert_registry::get_python_convert_registry();
255                let sink_target = (converter.from_py.sink_target)(sink_target)?;
256                let sink_target = sink_target.downcast_ref::<SinkTarget>().unwrap().clone();
257                PolarsResult::Ok(sink_target)
258            }),
259        }
260    }
261}
262
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for PartitionTargetCallback {
    /// Deserializes a callback.
    ///
    /// With the `python` feature the input is read as a Python function and wrapped in the
    /// `Python` variant. Without it no variant can be reconstructed, so an error is returned.
    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        #[cfg(feature = "python")]
        {
            Ok(Self::Python(
                polars_utils::python_function::PythonFunction::deserialize(_deserializer)?,
            ))
        }
        #[cfg(not(feature = "python"))]
        {
            use serde::de::Error;
            // The message names this type so the failure can be traced back to it.
            Err(D::Error::custom(
                "cannot deserialize PartitionTargetCallback",
            ))
        }
    }
}
284
#[cfg(feature = "serde")]
impl serde::Serialize for PartitionTargetCallback {
    /// Serializes a Python callback through its own serializer; a Rust callback cannot be
    /// serialized and produces an error.
    fn serialize<S: serde::Serializer>(&self, _serializer: S) -> Result<S::Ok, S::Error> {
        match self {
            #[cfg(feature = "python")]
            Self::Python(function) => function.serialize(_serializer),
            Self::Rust(_) => {
                let reason = format!("cannot serialize {self:?}");
                Err(serde::ser::Error::custom(reason))
            },
        }
    }
}
301
302#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
303#[derive(Clone, Debug, PartialEq)]
304pub struct PartitionSinkType {
305    pub base_path: Arc<PathBuf>,
306    pub file_path_cb: Option<PartitionTargetCallback>,
307    pub file_type: FileType,
308    pub sink_options: SinkOptions,
309    pub variant: PartitionVariant,
310    pub cloud_options: Option<polars_io::cloud::CloudOptions>,
311}
312
313#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
314#[derive(Clone, Debug, PartialEq)]
315pub struct PartitionSinkTypeIR {
316    pub base_path: Arc<PathBuf>,
317    pub file_path_cb: Option<PartitionTargetCallback>,
318    pub file_type: FileType,
319    pub sink_options: SinkOptions,
320    pub variant: PartitionVariantIR,
321    pub cloud_options: Option<polars_io::cloud::CloudOptions>,
322}
323
324#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
325#[derive(Clone, Debug, PartialEq)]
326pub enum SinkType {
327    Memory,
328    File(FileSinkType),
329    Partition(PartitionSinkType),
330}
331
332#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
333#[derive(Clone, Debug, PartialEq, Eq, Hash)]
334pub enum PartitionVariant {
335    MaxSize(IdxSize),
336    Parted {
337        key_exprs: Vec<Expr>,
338        include_key: bool,
339    },
340    ByKey {
341        key_exprs: Vec<Expr>,
342        include_key: bool,
343    },
344}
345
346#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
347#[derive(Clone, Debug, PartialEq, Eq)]
348pub enum PartitionVariantIR {
349    MaxSize(IdxSize),
350    Parted {
351        key_exprs: Vec<ExprIR>,
352        include_key: bool,
353    },
354    ByKey {
355        key_exprs: Vec<ExprIR>,
356        include_key: bool,
357    },
358}
359
360impl SinkTypeIR {
361    #[cfg(feature = "cse")]
362    pub(crate) fn traverse_and_hash<H: Hasher>(&self, expr_arena: &Arena<AExpr>, state: &mut H) {
363        std::mem::discriminant(self).hash(state);
364        match self {
365            Self::Memory => {},
366            Self::File(f) => f.hash(state),
367            Self::Partition(f) => {
368                f.file_type.hash(state);
369                f.sink_options.hash(state);
370                f.variant.traverse_and_hash(expr_arena, state);
371                f.cloud_options.hash(state);
372            },
373        }
374    }
375}
376
377impl PartitionVariantIR {
378    #[cfg(feature = "cse")]
379    pub(crate) fn traverse_and_hash<H: Hasher>(&self, expr_arena: &Arena<AExpr>, state: &mut H) {
380        std::mem::discriminant(self).hash(state);
381        match self {
382            Self::MaxSize(size) => size.hash(state),
383            Self::Parted {
384                key_exprs,
385                include_key,
386            }
387            | Self::ByKey {
388                key_exprs,
389                include_key,
390            } => {
391                include_key.hash(state);
392                for key_expr in key_exprs.as_slice() {
393                    key_expr.traverse_and_hash(expr_arena, state);
394                }
395            },
396        }
397    }
398}
399
400#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
401#[derive(Clone, Debug)]
402pub struct FileSinkOptions {
403    pub path: Arc<PathBuf>,
404    pub file_type: FileType,
405}