polars_plan/dsl/plan.rs

1use std::fmt;
2use std::io::{Read, Write};
3use std::sync::{Arc, Mutex};
4
5use polars_utils::arena::Node;
6#[cfg(feature = "serde")]
7use polars_utils::pl_serialize;
8use recursive::recursive;
9#[cfg(feature = "serde")]
10use serde::{Deserialize, Serialize};
11
12use super::*;
// DSL serialization format version, encoded as (major, minor).
// - Add a field -> increment minor.
// - Remove or modify a field -> increment major and reset minor to 0.
pub static DSL_VERSION: (u16, u16) = (2, 0);
// Magic prefix written before the version header so `deserialize_versioned`
// can recognize a versioned DSL byte stream before attempting to decode it.
static DSL_MAGIC_BYTES: &[u8] = b"DSL_VERSION";
18
/// The logical plan in DSL (pre-IR) form: a tree of operations built by the
/// lazy API. Nodes reference their inputs via `Arc<DslPlan>`; lowering to the
/// arena-based IR happens in `to_alp`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum DslPlan {
    /// Scan driven by Python-provided options (only with the `python` feature).
    #[cfg(feature = "python")]
    PythonScan {
        options: crate::dsl::python_dsl::PythonOptionsDsl,
    },
    /// Filter on a boolean mask
    Filter {
        input: Arc<DslPlan>,
        predicate: Expr,
    },
    /// Cache the input at this point in the LP
    Cache {
        input: Arc<DslPlan>,
        id: usize,
    },
    /// Scan one or more sources (e.g. files) described by `scan_type`.
    Scan {
        sources: ScanSources,
        /// Materialized at IR except for AnonymousScan.
        file_info: Option<FileInfo>,
        unified_scan_args: Box<UnifiedScanArgs>,
        scan_type: Box<FileScan>,
        /// Local use cases often repeatedly collect the same `LazyFrame` (e.g. in interactive notebook use-cases),
        /// so we cache the IR conversion here, as the path expansion can be quite slow (especially for cloud paths).
        /// We don't have the arena, as this is always a source node.
        // Skipped by serde: the cache is a local optimization, not part of the plan.
        #[cfg_attr(feature = "serde", serde(skip))]
        cached_ir: Arc<Mutex<Option<IR>>>,
    },
    // we keep track of the projection and selection as it is cheaper to first project and then filter
    /// In memory DataFrame
    DataFrameScan {
        df: Arc<DataFrame>,
        schema: SchemaRef,
    },
    /// Polars' `select` operation, this can mean projection, but also full data access.
    Select {
        expr: Vec<Expr>,
        input: Arc<DslPlan>,
        options: ProjectionOptions,
    },
    /// Groupby aggregation
    GroupBy {
        input: Arc<DslPlan>,
        keys: Vec<Expr>,
        aggs: Vec<Expr>,
        maintain_order: bool,
        options: Arc<GroupbyOptions>,
        // Skipped by serde: an opaque Rust closure cannot be serialized.
        #[cfg_attr(feature = "serde", serde(skip))]
        apply: Option<(Arc<dyn DataFrameUdf>, SchemaRef)>,
    },
    /// Join operation
    Join {
        input_left: Arc<DslPlan>,
        input_right: Arc<DslPlan>,
        // Invariant: left_on and right_on are equal length.
        left_on: Vec<Expr>,
        right_on: Vec<Expr>,
        // Invariant: Either left_on/right_on or predicates is set (non-empty).
        predicates: Vec<Expr>,
        options: Arc<JoinOptions>,
    },
    /// Adding columns to the table without a Join
    HStack {
        input: Arc<DslPlan>,
        exprs: Vec<Expr>,
        options: ProjectionOptions,
    },
    /// Remove duplicates from the table
    Distinct {
        input: Arc<DslPlan>,
        options: DistinctOptionsDSL,
    },
    /// Sort the table
    Sort {
        input: Arc<DslPlan>,
        by_column: Vec<Expr>,
        slice: Option<(i64, usize)>,
        sort_options: SortMultipleOptions,
    },
    /// Slice the table
    Slice {
        input: Arc<DslPlan>,
        offset: i64,
        len: IdxSize,
    },
    /// A (User Defined) Function
    MapFunction {
        input: Arc<DslPlan>,
        function: DslFunction,
    },
    /// Vertical concatenation
    Union {
        inputs: Vec<DslPlan>,
        args: UnionArgs,
    },
    /// Horizontal concatenation of multiple plans
    HConcat {
        inputs: Vec<DslPlan>,
        options: HConcatOptions,
    },
    /// This allows expressions to access other tables
    ExtContext {
        input: Arc<DslPlan>,
        contexts: Vec<DslPlan>,
    },
    /// Write the result of `input` to the destination described by `payload`.
    Sink {
        input: Arc<DslPlan>,
        payload: SinkType,
    },
    /// Sink several independent plans.
    SinkMultiple {
        inputs: Vec<DslPlan>,
    },
    /// Merge two inputs on the sorted column `key` (only with the `merge_sorted` feature).
    #[cfg(feature = "merge_sorted")]
    MergeSorted {
        input_left: Arc<DslPlan>,
        input_right: Arc<DslPlan>,
        key: PlSmallStr,
    },
    /// A plan that has already been lowered to IR (the arena `node`).
    IR {
        // Keep the original Dsl around as we need that for serialization.
        dsl: Arc<DslPlan>,
        version: u32,
        // Skipped by serde: the node index is only valid within a live arena.
        #[cfg_attr(feature = "serde", serde(skip))]
        node: Option<Node>,
    },
}
145
impl Clone for DslPlan {
    // Autogenerated by rust-analyzer, don't care about it looking nice, it just
    // calls clone on every member of every enum variant.
    // NOTE: `#[recursive]` guards against call-stack overflow when cloning
    // deeply nested plans, which is why this is not a plain `#[derive(Clone)]`.
    #[rustfmt::skip]
    #[allow(clippy::clone_on_copy)]
    #[recursive]
    fn clone(&self) -> Self {
        match self {
            #[cfg(feature = "python")]
            Self::PythonScan { options } => Self::PythonScan { options: options.clone() },
            Self::Filter { input, predicate } => Self::Filter { input: input.clone(), predicate: predicate.clone() },
            Self::Cache { input, id } => Self::Cache { input: input.clone(), id: id.clone() },
            Self::Scan { sources, file_info, unified_scan_args, scan_type, cached_ir } => Self::Scan { sources: sources.clone(), file_info: file_info.clone(), unified_scan_args: unified_scan_args.clone(), scan_type: scan_type.clone(), cached_ir: cached_ir.clone() },
            Self::DataFrameScan { df, schema, } => Self::DataFrameScan { df: df.clone(), schema: schema.clone(),  },
            Self::Select { expr, input, options } => Self::Select { expr: expr.clone(), input: input.clone(), options: options.clone() },
            Self::GroupBy { input, keys, aggs,  apply, maintain_order, options } => Self::GroupBy { input: input.clone(), keys: keys.clone(), aggs: aggs.clone(), apply: apply.clone(), maintain_order: maintain_order.clone(), options: options.clone() },
            Self::Join { input_left, input_right, left_on, right_on, predicates, options } => Self::Join { input_left: input_left.clone(), input_right: input_right.clone(), left_on: left_on.clone(), right_on: right_on.clone(), options: options.clone(), predicates: predicates.clone() },
            Self::HStack { input, exprs, options } => Self::HStack { input: input.clone(), exprs: exprs.clone(),  options: options.clone() },
            Self::Distinct { input, options } => Self::Distinct { input: input.clone(), options: options.clone() },
            Self::Sort {input,by_column, slice, sort_options } => Self::Sort { input: input.clone(), by_column: by_column.clone(), slice: slice.clone(), sort_options: sort_options.clone() },
            Self::Slice { input, offset, len } => Self::Slice { input: input.clone(), offset: offset.clone(), len: len.clone() },
            Self::MapFunction { input, function } => Self::MapFunction { input: input.clone(), function: function.clone() },
            Self::Union { inputs, args} => Self::Union { inputs: inputs.clone(), args: args.clone() },
            Self::HConcat { inputs, options } => Self::HConcat { inputs: inputs.clone(), options: options.clone() },
            Self::ExtContext { input, contexts, } => Self::ExtContext { input: input.clone(), contexts: contexts.clone() },
            Self::Sink { input, payload } => Self::Sink { input: input.clone(), payload: payload.clone() },
            Self::SinkMultiple { inputs } => Self::SinkMultiple { inputs: inputs.clone() },
            #[cfg(feature = "merge_sorted")]
            Self::MergeSorted { input_left, input_right, key } => Self::MergeSorted { input_left: input_left.clone(), input_right: input_right.clone(), key: key.clone() },
            Self::IR {node, dsl, version} => Self::IR {node: *node, dsl: dsl.clone(), version: *version},
        }
    }
}
179
180impl Default for DslPlan {
181    fn default() -> Self {
182        let df = DataFrame::empty();
183        let schema = df.schema().clone();
184        DslPlan::DataFrameScan {
185            df: Arc::new(df),
186            schema,
187        }
188    }
189}
190
191impl DslPlan {
192    pub fn describe(&self) -> PolarsResult<String> {
193        Ok(self.clone().to_alp()?.describe())
194    }
195
196    pub fn describe_tree_format(&self) -> PolarsResult<String> {
197        Ok(self.clone().to_alp()?.describe_tree_format())
198    }
199
200    pub fn display(&self) -> PolarsResult<impl fmt::Display> {
201        struct DslPlanDisplay(IRPlan);
202        impl fmt::Display for DslPlanDisplay {
203            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
204                fmt::Display::fmt(&self.0.as_ref().display(), f)
205            }
206        }
207        Ok(DslPlanDisplay(self.clone().to_alp()?))
208    }
209
210    pub fn to_alp(self) -> PolarsResult<IRPlan> {
211        let mut lp_arena = Arena::with_capacity(16);
212        let mut expr_arena = Arena::with_capacity(16);
213
214        let node = to_alp(
215            self,
216            &mut expr_arena,
217            &mut lp_arena,
218            &mut OptFlags::default(),
219        )?;
220        let plan = IRPlan::new(node, lp_arena, expr_arena);
221
222        Ok(plan)
223    }
224
225    #[cfg(feature = "serde")]
226    pub fn serialize_versioned<W: Write>(&self, mut writer: W) -> PolarsResult<()> {
227        let le_major = DSL_VERSION.0.to_le_bytes();
228        let le_minor = DSL_VERSION.1.to_le_bytes();
229        writer.write_all(DSL_MAGIC_BYTES)?;
230        writer.write_all(&le_major)?;
231        writer.write_all(&le_minor)?;
232        pl_serialize::SerializeOptions::default().serialize_into_writer::<_, _, true>(writer, self)
233    }
234
235    #[cfg(feature = "serde")]
236    pub fn deserialize_versioned<R: Read>(mut reader: R) -> PolarsResult<Self> {
237        const MAGIC_LEN: usize = DSL_MAGIC_BYTES.len();
238        let mut version_magic = [0u8; MAGIC_LEN + 4];
239        reader.read_exact(&mut version_magic)?;
240
241        if &version_magic[..MAGIC_LEN] != DSL_MAGIC_BYTES {
242            polars_bail!(ComputeError: "dsl magic bytes not found")
243        }
244
245        // The DSL serialization is forward compatible if fields don't change,
246        // so we don't check equality here, we just use this version
247        // to inform users when the deserialization fails.
248        let major = u16::from_be_bytes(version_magic[MAGIC_LEN..MAGIC_LEN + 2].try_into().unwrap());
249        let minor = u16::from_be_bytes(
250            version_magic[MAGIC_LEN + 2..MAGIC_LEN + 4]
251                .try_into()
252                .unwrap(),
253        );
254
255        pl_serialize::SerializeOptions::default()
256                    .deserialize_from_reader::<_, _, true>(reader).map_err(|e| {
257                    polars_err!(ComputeError: "deserialization failed\n\ngiven DSL_VERSION: {:?} is not compatible with this Polars version which uses DSL_VERSION: {:?}\nerror: {}", (major, minor), DSL_VERSION, e)
258                })
259    }
260}