// polars_plan/dsl/options/sink.rs

1use std::fmt;
2use std::hash::{Hash, Hasher};
3use std::path::PathBuf;
4use std::sync::Arc;
5
6use polars_core::error::PolarsResult;
7use polars_core::prelude::DataType;
8use polars_core::scalar::Scalar;
9use polars_io::cloud::CloudOptions;
10use polars_io::utils::file::{DynWriteable, Writeable};
11use polars_io::utils::sync_on_close::SyncOnCloseType;
12use polars_utils::IdxSize;
13use polars_utils::arena::Arena;
14use polars_utils::pl_str::PlSmallStr;
15
16use super::{ExprIR, FileType};
17use crate::dsl::{AExpr, Expr, SpecialEq};
18
19/// Options that apply to all sinks.
20#[derive(Clone, PartialEq, Eq, Debug, Hash)]
21#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
22pub struct SinkOptions {
23    /// Call sync when closing the file.
24    pub sync_on_close: SyncOnCloseType,
25
26    /// The output file needs to maintain order of the data that comes in.
27    pub maintain_order: bool,
28
29    /// Recursively create all the directories in the path.
30    pub mkdir: bool,
31}
32
33impl Default for SinkOptions {
34    fn default() -> Self {
35        Self {
36            sync_on_close: Default::default(),
37            maintain_order: true,
38            mkdir: false,
39        }
40    }
41}
42
43type DynSinkTarget = SpecialEq<Arc<std::sync::Mutex<Option<Box<dyn DynWriteable>>>>>;
44
45#[derive(Clone, PartialEq, Eq)]
46pub enum SinkTarget {
47    Path(Arc<PathBuf>),
48    Dyn(DynSinkTarget),
49}
50
51impl SinkTarget {
52    pub fn open_into_writeable(
53        &self,
54        sink_options: &SinkOptions,
55        cloud_options: Option<&CloudOptions>,
56    ) -> PolarsResult<Writeable> {
57        match self {
58            SinkTarget::Path(path) => {
59                if sink_options.mkdir {
60                    polars_io::utils::mkdir::mkdir_recursive(path.as_path())?;
61                }
62
63                let path = path.as_ref().display().to_string();
64                polars_io::utils::file::Writeable::try_new(&path, cloud_options)
65            },
66            SinkTarget::Dyn(memory_writer) => Ok(Writeable::Dyn(
67                memory_writer.lock().unwrap().take().unwrap(),
68            )),
69        }
70    }
71
72    #[cfg(feature = "cloud")]
73    pub async fn open_into_writeable_async(
74        &self,
75        sink_options: &SinkOptions,
76        cloud_options: Option<&CloudOptions>,
77    ) -> PolarsResult<Writeable> {
78        match self {
79            SinkTarget::Path(path) => {
80                if sink_options.mkdir {
81                    polars_io::utils::mkdir::tokio_mkdir_recursive(path.as_path()).await?;
82                }
83
84                let path = path.as_ref().display().to_string();
85                polars_io::utils::file::Writeable::try_new(&path, cloud_options)
86            },
87            SinkTarget::Dyn(memory_writer) => Ok(Writeable::Dyn(
88                memory_writer.lock().unwrap().take().unwrap(),
89            )),
90        }
91    }
92}
93
94impl fmt::Debug for SinkTarget {
95    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
96        f.write_str("SinkTarget::")?;
97        match self {
98            Self::Path(p) => write!(f, "Path({p:?})"),
99            Self::Dyn(_) => f.write_str("Dyn"),
100        }
101    }
102}
103
104impl std::hash::Hash for SinkTarget {
105    fn hash<H: Hasher>(&self, state: &mut H) {
106        std::mem::discriminant(self).hash(state);
107        match self {
108            Self::Path(p) => p.hash(state),
109            Self::Dyn(p) => Arc::as_ptr(p).hash(state),
110        }
111    }
112}
113
#[cfg(feature = "serde")]
impl serde::Serialize for SinkTarget {
    /// Serializes a `Path` target as the bare path. `Dyn` targets cannot be serialized and
    /// produce an error.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let Self::Path(path) = self else {
            return Err(serde::ser::Error::custom(
                "cannot serialize in-memory sink target",
            ));
        };
        path.serialize(serializer)
    }
}
128
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for SinkTarget {
    /// Reads a plain path and wraps it as a `Path` target, the only variant that serializes.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        PathBuf::deserialize(deserializer)
            .map(Arc::new)
            .map(Self::Path)
    }
}
138
139#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
140#[derive(Clone, Debug, PartialEq, Eq, Hash)]
141pub struct FileSinkType {
142    pub target: SinkTarget,
143    pub file_type: FileType,
144    pub sink_options: SinkOptions,
145    pub cloud_options: Option<polars_io::cloud::CloudOptions>,
146}
147
148#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
149#[derive(Clone, Debug, PartialEq)]
150pub enum SinkTypeIR {
151    Memory,
152    File(FileSinkType),
153    Partition(PartitionSinkTypeIR),
154}
155
156#[cfg_attr(feature = "python", pyo3::pyclass)]
157#[derive(Clone)]
158pub struct PartitionTargetContextKey {
159    pub name: PlSmallStr,
160    pub raw_value: Scalar,
161}
162
163#[cfg_attr(feature = "python", pyo3::pyclass)]
164pub struct PartitionTargetContext {
165    pub file_idx: usize,
166    pub part_idx: usize,
167    pub in_part_idx: usize,
168    pub keys: Vec<PartitionTargetContextKey>,
169    pub file_path: PathBuf,
170    pub full_path: PathBuf,
171}
172
#[cfg(feature = "python")]
#[pyo3::pymethods]
impl PartitionTargetContext {
    /// The `file_idx` field.
    #[getter]
    pub fn file_idx(&self) -> usize {
        self.file_idx
    }

    /// The `part_idx` field.
    #[getter]
    pub fn part_idx(&self) -> usize {
        self.part_idx
    }

    /// The `in_part_idx` field.
    #[getter]
    pub fn in_part_idx(&self) -> usize {
        self.in_part_idx
    }

    /// A copy of the partition keys.
    #[getter]
    pub fn keys(&self) -> Vec<PartitionTargetContextKey> {
        self.keys.to_vec()
    }

    /// The `file_path` field.
    #[getter]
    pub fn file_path(&self) -> &std::path::Path {
        &self.file_path
    }

    /// The `full_path` field.
    #[getter]
    pub fn full_path(&self) -> &std::path::Path {
        &self.full_path
    }
}
#[cfg(feature = "python")]
#[pyo3::pymethods]
impl PartitionTargetContextKey {
    /// Name of this partition key.
    #[getter]
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// The key value cast to a string and percent-encoded with `URL_ENCODE_CHAR_SET`.
    ///
    /// A null value becomes the text `null`. Raises `RuntimeError` if the cast to string fails.
    #[getter]
    pub fn str_value(&self) -> pyo3::PyResult<String> {
        let as_strings = self
            .raw_value
            .clone()
            .into_series(PlSmallStr::EMPTY)
            .strict_cast(&DataType::String)
            .map_err(|err| pyo3::exceptions::PyRuntimeError::new_err(err.to_string()))?;
        // The strict cast to `String` succeeded, so this is a string column.
        let text = as_strings.str().unwrap().get(0).unwrap_or("null");
        let encoded = percent_encoding::percent_encode(
            text.as_bytes(),
            polars_io::utils::URL_ENCODE_CHAR_SET,
        );
        Ok(encoded.to_string())
    }

    /// The key value converted to a Python object through the registered object converter.
    #[getter]
    pub fn raw_value(&self) -> pyo3::PyObject {
        let to_python = polars_core::chunked_array::object::registry::get_pyobject_converter();
        let converted = (to_python.as_ref())(self.raw_value.as_any_value());
        *converted.downcast::<pyo3::PyObject>().unwrap()
    }
}
229
230#[derive(Clone, Debug, PartialEq)]
231pub enum PartitionTargetCallback {
232    Rust(SpecialEq<Arc<dyn Fn(PartitionTargetContext) -> PolarsResult<SinkTarget> + Send + Sync>>),
233    #[cfg(feature = "python")]
234    Python(polars_utils::python_function::PythonFunction),
235}
236
237impl PartitionTargetCallback {
238    pub fn call(&self, ctx: PartitionTargetContext) -> PolarsResult<SinkTarget> {
239        match self {
240            Self::Rust(f) => f(ctx),
241            #[cfg(feature = "python")]
242            Self::Python(f) => pyo3::Python::with_gil(|py| {
243                let sink_target = f.call1(py, (ctx,))?;
244                let converter =
245                    polars_utils::python_convert_registry::get_python_convert_registry();
246                let sink_target = (converter.from_py.sink_target)(sink_target)?;
247                let sink_target = sink_target.downcast_ref::<SinkTarget>().unwrap().clone();
248                PolarsResult::Ok(sink_target)
249            }),
250        }
251    }
252}
253
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for PartitionTargetCallback {
    /// Deserializes a callback.
    ///
    /// Only Python callbacks are serializable. With the `python` feature the input is read as a
    /// `PythonFunction`; without it, deserialization always fails.
    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        #[cfg(feature = "python")]
        {
            polars_utils::python_function::PythonFunction::deserialize(_deserializer)
                .map(Self::Python)
        }
        #[cfg(not(feature = "python"))]
        {
            use serde::de::Error;
            // Name the actual type so the error points at the right place.
            Err(D::Error::custom(
                "cannot deserialize PartitionTargetCallback",
            ))
        }
    }
}
275
#[cfg(feature = "serde")]
impl serde::Serialize for PartitionTargetCallback {
    /// Serializes a Python callback. Any other callback yields an error that includes its
    /// `Debug` representation.
    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        #[cfg(feature = "python")]
        if let Self::Python(function) = self {
            return function.serialize(_serializer);
        }

        let message = format!("cannot serialize {self:?}");
        Err(serde::ser::Error::custom(message))
    }
}
292
293#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
294#[derive(Clone, Debug, PartialEq)]
295pub struct PartitionSinkType {
296    pub base_path: Arc<PathBuf>,
297    pub file_path_cb: Option<PartitionTargetCallback>,
298    pub file_type: FileType,
299    pub sink_options: SinkOptions,
300    pub variant: PartitionVariant,
301    pub cloud_options: Option<polars_io::cloud::CloudOptions>,
302}
303
304#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
305#[derive(Clone, Debug, PartialEq)]
306pub struct PartitionSinkTypeIR {
307    pub base_path: Arc<PathBuf>,
308    pub file_path_cb: Option<PartitionTargetCallback>,
309    pub file_type: FileType,
310    pub sink_options: SinkOptions,
311    pub variant: PartitionVariantIR,
312    pub cloud_options: Option<polars_io::cloud::CloudOptions>,
313}
314
315#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
316#[derive(Clone, Debug, PartialEq)]
317pub enum SinkType {
318    Memory,
319    File(FileSinkType),
320    Partition(PartitionSinkType),
321}
322
323#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
324#[derive(Clone, Debug, PartialEq, Eq, Hash)]
325pub enum PartitionVariant {
326    MaxSize(IdxSize),
327    Parted {
328        key_exprs: Vec<Expr>,
329        include_key: bool,
330    },
331    ByKey {
332        key_exprs: Vec<Expr>,
333        include_key: bool,
334    },
335}
336
337#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
338#[derive(Clone, Debug, PartialEq, Eq)]
339pub enum PartitionVariantIR {
340    MaxSize(IdxSize),
341    Parted {
342        key_exprs: Vec<ExprIR>,
343        include_key: bool,
344    },
345    ByKey {
346        key_exprs: Vec<ExprIR>,
347        include_key: bool,
348    },
349}
350
351impl SinkTypeIR {
352    #[cfg(feature = "cse")]
353    pub(crate) fn traverse_and_hash<H: Hasher>(&self, expr_arena: &Arena<AExpr>, state: &mut H) {
354        std::mem::discriminant(self).hash(state);
355        match self {
356            Self::Memory => {},
357            Self::File(f) => f.hash(state),
358            Self::Partition(f) => {
359                f.file_type.hash(state);
360                f.sink_options.hash(state);
361                f.variant.traverse_and_hash(expr_arena, state);
362                f.cloud_options.hash(state);
363            },
364        }
365    }
366}
367
368impl PartitionVariantIR {
369    #[cfg(feature = "cse")]
370    pub(crate) fn traverse_and_hash<H: Hasher>(&self, expr_arena: &Arena<AExpr>, state: &mut H) {
371        std::mem::discriminant(self).hash(state);
372        match self {
373            Self::MaxSize(size) => size.hash(state),
374            Self::Parted {
375                key_exprs,
376                include_key,
377            }
378            | Self::ByKey {
379                key_exprs,
380                include_key,
381            } => {
382                include_key.hash(state);
383                for key_expr in key_exprs.as_slice() {
384                    key_expr.traverse_and_hash(expr_arena, state);
385                }
386            },
387        }
388    }
389}
390
391#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
392#[derive(Clone, Debug)]
393pub struct FileSinkOptions {
394    pub path: Arc<PathBuf>,
395    pub file_type: FileType,
396}