polars_plan/dsl/
plan.rs

1use std::fmt;
2use std::io::{Read, Write};
3use std::sync::{Arc, Mutex};
4
5use polars_utils::arena::Node;
6#[cfg(feature = "serde")]
7use polars_utils::pl_serialize;
8use recursive::recursive;
9#[cfg(feature = "serde")]
10use serde::{Deserialize, Serialize};
11
12use super::*;
13
14// DSL version in a form of (Major, Minor).
15//
16// Serialized DSL is compatible with a deserializer, if:
17// - the serialized Major version and the deserializer Major version are equal, and
18// - there are no unknown fields in the serialized DSL.
19//
20// The following sections describe when to increment the version. If unsure, ask.
21//
22// # Minor version
23//
24// Increment Minor if you're extending the DSL without breaking backward compatibility.
25// - DSL serialized with this Polars version is NOT fully compatible with the previous version,
26// - DSL serialized with the previous Polars version is still fully compatible with this version.
27//
28// You need to be sure that every possible DSL serialized with the previous Polars version is still
29// valid and has the same meaning in this Polars version.
30//
31// Allowed changes:
32// - adding a new enum variant,
33// - adding a field with a default value, where the default value matches the behavior of the
34//   previous Polars version that didn't have this field,
35// - adding new flags to bitflags; again, the default value has to preserve the previous behavior,
// - allowing field values that were previously rejected (e.g. a value that would cause an error or
//   panic if it was greater than 10 can be allowed to go up to 20 in the new version).
38//
39// # Major version
40//
41// Increment Major and reset Minor to zero if you're breaking backward compatibility:
42// - DSL serialized with the previous Polars version is NOT compatible with this Polars version.
43//
44// Examples:
45// - adding a field that doesn't have a default (or the default doesn't match the behavior
46//   of the previous version),
47// - removing a field or an enum variant
48// - changing a name, type, or meaning of a field or an enum variant
49// - changing a default value of a field or a default enum variant
50// - restricting the range of allowed values a field can have
/// Current DSL serialization version as (Major, Minor).
/// See the versioning policy in the comment block above for when each
/// component must be incremented.
pub static DSL_VERSION: (u16, u16) = (4, 1);
/// Magic byte prefix written before the version header of a serialized DSL.
static DSL_MAGIC_BYTES: &[u8] = b"DSL_VERSION";
53
// NOTE: this enum is serde-serialized and covered by the DSL_VERSION policy
// documented above — adding/removing/reordering variants or fields has
// compatibility implications.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum DslPlan {
    /// Scan driven by Python-side options (only with the `python` feature).
    #[cfg(feature = "python")]
    PythonScan {
        options: crate::dsl::python_dsl::PythonOptionsDsl,
    },
    /// Filter on a boolean mask
    Filter {
        input: Arc<DslPlan>,
        predicate: Expr,
    },
    /// Cache the input at this point in the LP
    Cache {
        input: Arc<DslPlan>,
        // Identifier of this cache node (NOTE(review): presumably used to
        // share/deduplicate the cached subplan — confirm at IR conversion).
        id: usize,
    },
    /// Scan data sources (e.g. files).
    Scan {
        sources: ScanSources,
        /// Materialized at IR except for AnonymousScan.
        file_info: Option<FileInfo>,
        unified_scan_args: Box<UnifiedScanArgs>,
        scan_type: Box<FileScan>,
        /// Local use cases often repeatedly collect the same `LazyFrame` (e.g. in interactive notebook use-cases),
        /// so we cache the IR conversion here, as the path expansion can be quite slow (especially for cloud paths).
        /// We don't have the arena, as this is always a source node.
        // Skipped by serde: the cached IR is a process-local optimization only.
        #[cfg_attr(feature = "serde", serde(skip))]
        cached_ir: Arc<Mutex<Option<IR>>>,
    },
    // we keep track of the projection and selection as it is cheaper to first project and then filter
    /// In memory DataFrame
    DataFrameScan {
        df: Arc<DataFrame>,
        schema: SchemaRef,
    },
    /// Polars' `select` operation, this can mean projection, but also full data access.
    Select {
        expr: Vec<Expr>,
        input: Arc<DslPlan>,
        options: ProjectionOptions,
    },
    /// Groupby aggregation
    GroupBy {
        input: Arc<DslPlan>,
        keys: Vec<Expr>,
        aggs: Vec<Expr>,
        maintain_order: bool,
        options: Arc<GroupbyOptions>,
        // Opaque `DataFrameUdf` with its output schema; not serializable,
        // hence skipped by serde.
        #[cfg_attr(feature = "serde", serde(skip))]
        apply: Option<(Arc<dyn DataFrameUdf>, SchemaRef)>,
    },
    /// Join operation
    Join {
        input_left: Arc<DslPlan>,
        input_right: Arc<DslPlan>,
        // Invariant: left_on and right_on are equal length.
        left_on: Vec<Expr>,
        right_on: Vec<Expr>,
        // Invariant: Either left_on/right_on or predicates is set (non-empty).
        predicates: Vec<Expr>,
        options: Arc<JoinOptions>,
    },
    /// Adding columns to the table without a Join
    HStack {
        input: Arc<DslPlan>,
        exprs: Vec<Expr>,
        options: ProjectionOptions,
    },
    /// Match / Evolve into a schema
    MatchToSchema {
        input: Arc<DslPlan>,
        /// The schema to match to.
        ///
        /// This is also always the output schema.
        match_schema: SchemaRef,

        /// Per-column policies applied while matching to `match_schema`.
        per_column: Arc<[MatchToSchemaPerColumn]>,

        /// What to do with input columns that are not in `match_schema`.
        extra_columns: ExtraColumnsPolicy,
    },
    /// Remove duplicates from the table
    Distinct {
        input: Arc<DslPlan>,
        options: DistinctOptionsDSL,
    },
    /// Sort the table
    Sort {
        input: Arc<DslPlan>,
        by_column: Vec<Expr>,
        /// Optional (offset, length) slice fused into the sort.
        slice: Option<(i64, usize)>,
        sort_options: SortMultipleOptions,
    },
    /// Slice the table
    Slice {
        input: Arc<DslPlan>,
        offset: i64,
        len: IdxSize,
    },
    /// A (User Defined) Function
    MapFunction {
        input: Arc<DslPlan>,
        function: DslFunction,
    },
    /// Vertical concatenation
    Union {
        inputs: Vec<DslPlan>,
        args: UnionArgs,
    },
    /// Horizontal concatenation of multiple plans
    HConcat {
        inputs: Vec<DslPlan>,
        options: HConcatOptions,
    },
    /// This allows expressions to access other tables
    ExtContext {
        input: Arc<DslPlan>,
        contexts: Vec<DslPlan>,
    },
    /// Write the input plan to the sink described by `payload`.
    Sink {
        input: Arc<DslPlan>,
        payload: SinkType,
    },
    /// Sink multiple input plans.
    SinkMultiple {
        inputs: Vec<DslPlan>,
    },
    /// Merge two plans on `key` (NOTE(review): presumably both inputs are
    /// sorted on `key` — confirm with the `merge_sorted` implementation).
    #[cfg(feature = "merge_sorted")]
    MergeSorted {
        input_left: Arc<DslPlan>,
        input_right: Arc<DslPlan>,
        key: PlSmallStr,
    },
    /// A plan that has already been lowered to IR.
    IR {
        // Keep the original Dsl around as we need that for serialization.
        dsl: Arc<DslPlan>,
        version: u32,
        // Skipped by serde: an arena `Node` index is only meaningful in-process.
        #[cfg_attr(feature = "serde", serde(skip))]
        node: Option<Node>,
    },
}
192
193impl Clone for DslPlan {
194    // Autogenerated by rust-analyzer, don't care about it looking nice, it just
195    // calls clone on every member of every enum variant.
196    #[rustfmt::skip]
197    #[allow(clippy::clone_on_copy)]
198    #[recursive]
199    fn clone(&self) -> Self {
200        match self {
201            #[cfg(feature = "python")]
202            Self::PythonScan { options } => Self::PythonScan { options: options.clone() },
203            Self::Filter { input, predicate } => Self::Filter { input: input.clone(), predicate: predicate.clone() },
204            Self::Cache { input, id } => Self::Cache { input: input.clone(), id: id.clone() },
205            Self::Scan { sources, file_info, unified_scan_args, scan_type, cached_ir } => Self::Scan { sources: sources.clone(), file_info: file_info.clone(), unified_scan_args: unified_scan_args.clone(), scan_type: scan_type.clone(), cached_ir: cached_ir.clone() },
206            Self::DataFrameScan { df, schema, } => Self::DataFrameScan { df: df.clone(), schema: schema.clone(),  },
207            Self::Select { expr, input, options } => Self::Select { expr: expr.clone(), input: input.clone(), options: options.clone() },
208            Self::GroupBy { input, keys, aggs,  apply, maintain_order, options } => Self::GroupBy { input: input.clone(), keys: keys.clone(), aggs: aggs.clone(), apply: apply.clone(), maintain_order: maintain_order.clone(), options: options.clone() },
209            Self::Join { input_left, input_right, left_on, right_on, predicates, options } => Self::Join { input_left: input_left.clone(), input_right: input_right.clone(), left_on: left_on.clone(), right_on: right_on.clone(), options: options.clone(), predicates: predicates.clone() },
210            Self::HStack { input, exprs, options } => Self::HStack { input: input.clone(), exprs: exprs.clone(),  options: options.clone() },
211            Self::MatchToSchema { input, match_schema, per_column, extra_columns } => Self::MatchToSchema { input: input.clone(), match_schema: match_schema.clone(), per_column: per_column.clone(), extra_columns: *extra_columns },
212            Self::Distinct { input, options } => Self::Distinct { input: input.clone(), options: options.clone() },
213            Self::Sort {input,by_column, slice, sort_options } => Self::Sort { input: input.clone(), by_column: by_column.clone(), slice: slice.clone(), sort_options: sort_options.clone() },
214            Self::Slice { input, offset, len } => Self::Slice { input: input.clone(), offset: offset.clone(), len: len.clone() },
215            Self::MapFunction { input, function } => Self::MapFunction { input: input.clone(), function: function.clone() },
216            Self::Union { inputs, args} => Self::Union { inputs: inputs.clone(), args: args.clone() },
217            Self::HConcat { inputs, options } => Self::HConcat { inputs: inputs.clone(), options: options.clone() },
218            Self::ExtContext { input, contexts, } => Self::ExtContext { input: input.clone(), contexts: contexts.clone() },
219            Self::Sink { input, payload } => Self::Sink { input: input.clone(), payload: payload.clone() },
220            Self::SinkMultiple { inputs } => Self::SinkMultiple { inputs: inputs.clone() },
221            #[cfg(feature = "merge_sorted")]
222            Self::MergeSorted { input_left, input_right, key } => Self::MergeSorted { input_left: input_left.clone(), input_right: input_right.clone(), key: key.clone() },
223            Self::IR {node, dsl, version} => Self::IR {node: *node, dsl: dsl.clone(), version: *version},
224        }
225    }
226}
227
228impl Default for DslPlan {
229    fn default() -> Self {
230        let df = DataFrame::empty();
231        let schema = df.schema().clone();
232        DslPlan::DataFrameScan {
233            df: Arc::new(df),
234            schema,
235        }
236    }
237}
238
impl DslPlan {
    /// Lowers the plan to IR and renders a human-readable description.
    pub fn describe(&self) -> PolarsResult<String> {
        Ok(self.clone().to_alp()?.describe())
    }

    /// Lowers the plan to IR and renders a tree-formatted description.
    pub fn describe_tree_format(&self) -> PolarsResult<String> {
        Ok(self.clone().to_alp()?.describe_tree_format())
    }

    /// Returns a `Display`able wrapper around the lowered (IR) plan.
    pub fn display(&self) -> PolarsResult<impl fmt::Display> {
        // Newtype so we can return `impl Display` without exposing `IRPlan`.
        struct DslPlanDisplay(IRPlan);
        impl fmt::Display for DslPlanDisplay {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                fmt::Display::fmt(&self.0.as_ref().display(), f)
            }
        }
        Ok(DslPlanDisplay(self.clone().to_alp()?))
    }

    /// Converts this DSL plan into an [`IRPlan`], allocating fresh plan and
    /// expression arenas and lowering with default optimization flags.
    pub fn to_alp(self) -> PolarsResult<IRPlan> {
        let mut lp_arena = Arena::with_capacity(16);
        let mut expr_arena = Arena::with_capacity(16);

        let node = to_alp(
            self,
            &mut expr_arena,
            &mut lp_arena,
            &mut OptFlags::default(),
        )?;
        let plan = IRPlan::new(node, lp_arena, expr_arena);

        Ok(plan)
    }

    /// Serializes the plan prefixed with a version header:
    /// `DSL_MAGIC_BYTES`, then the little-endian major and minor components
    /// of [`DSL_VERSION`], then the serde payload.
    #[cfg(feature = "serde")]
    pub fn serialize_versioned<W: Write>(&self, mut writer: W) -> PolarsResult<()> {
        let le_major = DSL_VERSION.0.to_le_bytes();
        let le_minor = DSL_VERSION.1.to_le_bytes();
        writer.write_all(DSL_MAGIC_BYTES)?;
        writer.write_all(&le_major)?;
        writer.write_all(&le_minor)?;
        pl_serialize::SerializeOptions::default().serialize_into_writer::<_, _, true>(writer, self)
    }

    /// Deserializes a plan written by [`Self::serialize_versioned`], enforcing
    /// the compatibility policy documented at the top of this file:
    /// - the magic bytes must match,
    /// - the serialized major version must equal this deserializer's major,
    /// - unknown fields in the payload are an error (they indicate the plan
    ///   uses functionality this Polars version does not support).
    #[cfg(feature = "serde")]
    pub fn deserialize_versioned<R: Read>(mut reader: R) -> PolarsResult<Self> {
        const MAGIC_LEN: usize = DSL_MAGIC_BYTES.len();
        // Header layout: magic bytes, then 2-byte LE major, 2-byte LE minor.
        let mut version_magic = [0u8; MAGIC_LEN + 4];
        reader
            .read_exact(&mut version_magic)
            .map_err(|e| polars_err!(ComputeError: "failed to read incoming DSL_VERSION: {e}"))?;

        if &version_magic[..MAGIC_LEN] != DSL_MAGIC_BYTES {
            polars_bail!(ComputeError: "dsl magic bytes not found")
        }

        let major = u16::from_le_bytes(version_magic[MAGIC_LEN..MAGIC_LEN + 2].try_into().unwrap());
        let minor = u16::from_le_bytes(
            version_magic[MAGIC_LEN + 2..MAGIC_LEN + 4]
                .try_into()
                .unwrap(),
        );

        const MAJOR: u16 = DSL_VERSION.0;
        const MINOR: u16 = DSL_VERSION.1;

        if polars_core::config::verbose() {
            eprintln!(
                "incoming DSL_VERSION: {major}.{minor}, deserializer DSL_VERSION: {MAJOR}.{MINOR}"
            );
        }

        // A differing major version is a hard incompatibility per the policy.
        if major != MAJOR {
            polars_bail!(ComputeError:
                "deserialization failed\n\ngiven DSL_VERSION: {major}.{minor} is not compatible with this Polars version which uses DSL_VERSION: {MAJOR}.{MINOR}\n{}",
                "error: can't deserialize DSL with a different major version"
            );
        }

        let (dsl, unknown_fields) = pl_serialize::SerializeOptions::default().deserialize_from_reader_with_unknown_fields(reader).map_err(|e| {
            // The DSL serialization is forward compatible if there are no unknown fields
            if minor > MINOR {
                // Convey that the failure might also be due to broken forward compatibility
                polars_err!(ComputeError:
                    "deserialization failed\n\ngiven DSL_VERSION: {major}.{minor} is higher than this Polars version which uses DSL_VERSION: {MAJOR}.{MINOR}\n{}\nerror: {e}",
                    "either the input is malformed, or the plan requires functionality not supported in this Polars version"
                )
            } else {
                polars_err!(ComputeError:
                    "deserialization failed\n\nerror: {e}",
                )
            }
        })?;

        // Unknown fields break the "no unknown fields" compatibility rule;
        // tailor the error to whether the sender was a newer minor version.
        if !unknown_fields.is_empty() {
            if minor > MINOR {
                polars_bail!(ComputeError:
                    "deserialization failed\n\ngiven DSL_VERSION: {major}.{minor} is higher than this Polars version which uses DSL_VERSION: {MAJOR}.{MINOR}\n{}\nencountered unknown fields: {:?}",
                    "the plan requires functionality not supported in this Polars version",
                    unknown_fields,
                )
            } else {
                polars_bail!(ComputeError:
                    "deserialization failed\n\ngiven DSL_VERSION: {major}.{minor} should be supported in this Polars version which uses DSL_VERSION: {MAJOR}.{MINOR}\nencountered unknown fields: {:?}",
                    unknown_fields,
                )
            }
        }

        Ok(dsl)
    }
}
350}