1use std::fmt;
2use std::io::{Read, Write};
3use std::sync::{Arc, Mutex};
4
5use polars_utils::arena::Node;
6#[cfg(feature = "serde")]
7use polars_utils::pl_serialize;
8use recursive::recursive;
9#[cfg(feature = "serde")]
10use serde::{Deserialize, Serialize};
11
12use super::*;
13
/// DSL format version as `(major, minor)`.
///
/// `DslPlan::serialize_versioned` writes both parts, each as a little-endian `u16`, directly
/// after the magic bytes. `DslPlan::deserialize_versioned` rejects input whose major part
/// differs from this one; the minor part only selects which error message is reported when
/// decoding fails or unknown fields are encountered.
pub static DSL_VERSION: (u16, u16) = (4, 1);
/// Marker bytes written at the very start of a versioned, serialized `DslPlan` and checked
/// first when reading one back.
static DSL_MAGIC_BYTES: &[u8] = b"DSL_VERSION";
53
/// One node of a DSL plan tree.
///
/// Nodes refer to their child plan(s) through the `input`, `input_left`/`input_right`,
/// `inputs`, `contexts` or `dsl` fields. `DslPlan::to_alp` converts a plan into an `IRPlan`.
///
/// With the `serde` feature enabled the enum derives `Serialize`/`Deserialize`; fields marked
/// `serde(skip)` are not part of the serialized form.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum DslPlan {
    /// Scan configured through Python-side options (only with the `python` feature).
    #[cfg(feature = "python")]
    PythonScan {
        options: crate::dsl::python_dsl::PythonOptionsDsl,
    },
    /// Filter node: `predicate` is applied to `input`.
    Filter {
        input: Arc<DslPlan>,
        predicate: Expr,
    },
    /// Cache node over `input`, tagged with `id`.
    Cache {
        input: Arc<DslPlan>,
        id: usize,
    },
    /// Scan over external `sources`.
    Scan {
        sources: ScanSources,
        file_info: Option<FileInfo>,
        unified_scan_args: Box<UnifiedScanArgs>,
        scan_type: Box<FileScan>,
        /// Not serialized. Stored behind `Arc<Mutex<..>>`, so clones of this node share the
        /// same slot.
        /// NOTE(review): presumably memoizes the IR produced for this scan — confirm against
        /// the conversion code.
        #[cfg_attr(feature = "serde", serde(skip))]
        cached_ir: Arc<Mutex<Option<IR>>>,
    },
    /// In-memory `DataFrame` together with its schema.
    DataFrameScan {
        df: Arc<DataFrame>,
        schema: SchemaRef,
    },
    /// Projection node: the expressions in `expr` evaluated over `input`.
    Select {
        expr: Vec<Expr>,
        input: Arc<DslPlan>,
        options: ProjectionOptions,
    },
    /// Group-by node with grouping `keys` and aggregation expressions `aggs`.
    GroupBy {
        input: Arc<DslPlan>,
        keys: Vec<Expr>,
        aggs: Vec<Expr>,
        maintain_order: bool,
        options: Arc<GroupbyOptions>,
        /// Optional user-defined function plus a schema; not serialized.
        /// NOTE(review): presumably the schema is the UDF's output schema — confirm.
        #[cfg_attr(feature = "serde", serde(skip))]
        apply: Option<(Arc<dyn DataFrameUdf>, SchemaRef)>,
    },
    /// Join of `input_left` and `input_right`.
    Join {
        input_left: Arc<DslPlan>,
        input_right: Arc<DslPlan>,
        left_on: Vec<Expr>,
        right_on: Vec<Expr>,
        predicates: Vec<Expr>,
        options: Arc<JoinOptions>,
    },
    /// Horizontal stack of `exprs` onto `input`.
    /// NOTE(review): presumably the `with_columns` operation — confirm.
    HStack {
        input: Arc<DslPlan>,
        exprs: Vec<Expr>,
        options: ProjectionOptions,
    },
    /// Matches `input` against `match_schema`.
    MatchToSchema {
        input: Arc<DslPlan>,
        match_schema: SchemaRef,

        per_column: Arc<[MatchToSchemaPerColumn]>,

        extra_columns: ExtraColumnsPolicy,
    },
    /// Distinct node over `input`.
    Distinct {
        input: Arc<DslPlan>,
        options: DistinctOptionsDSL,
    },
    /// Sort node ordering `input` by the `by_column` expressions.
    Sort {
        input: Arc<DslPlan>,
        by_column: Vec<Expr>,
        /// NOTE(review): presumably `(offset, length)` applied after sorting — confirm.
        slice: Option<(i64, usize)>,
        sort_options: SortMultipleOptions,
    },
    /// Slice node described by `offset` and `len`.
    Slice {
        input: Arc<DslPlan>,
        offset: i64,
        len: IdxSize,
    },
    /// Applies a `DslFunction` to `input`.
    MapFunction {
        input: Arc<DslPlan>,
        function: DslFunction,
    },
    /// Union of several plans.
    Union {
        inputs: Vec<DslPlan>,
        args: UnionArgs,
    },
    /// Horizontal concatenation of several plans.
    HConcat {
        inputs: Vec<DslPlan>,
        options: HConcatOptions,
    },
    /// `input` plus additional `contexts` plans.
    /// NOTE(review): presumably lets expressions on `input` reference the context plans — confirm.
    ExtContext {
        input: Arc<DslPlan>,
        contexts: Vec<DslPlan>,
    },
    /// Sink node: `input` is sent to the destination described by `payload`.
    Sink {
        input: Arc<DslPlan>,
        payload: SinkType,
    },
    /// Groups several plans under one node.
    /// NOTE(review): presumably each input is itself a sink — confirm.
    SinkMultiple {
        inputs: Vec<DslPlan>,
    },
    /// Merge of two inputs on `key` (only with the `merge_sorted` feature).
    #[cfg(feature = "merge_sorted")]
    MergeSorted {
        input_left: Arc<DslPlan>,
        input_right: Arc<DslPlan>,
        key: PlSmallStr,
    },
    /// A DSL plan paired with an optional arena `Node`.
    /// NOTE(review): presumably `node` points at an already converted IR for `dsl`, and
    /// `version` is used to validate it — confirm against the conversion code.
    IR {
        dsl: Arc<DslPlan>,
        version: u32,
        /// Not serialized.
        #[cfg_attr(feature = "serde", serde(skip))]
        node: Option<Node>,
    },
}
192
193impl Clone for DslPlan {
194 #[rustfmt::skip]
197 #[allow(clippy::clone_on_copy)]
198 #[recursive]
199 fn clone(&self) -> Self {
200 match self {
201 #[cfg(feature = "python")]
202 Self::PythonScan { options } => Self::PythonScan { options: options.clone() },
203 Self::Filter { input, predicate } => Self::Filter { input: input.clone(), predicate: predicate.clone() },
204 Self::Cache { input, id } => Self::Cache { input: input.clone(), id: id.clone() },
205 Self::Scan { sources, file_info, unified_scan_args, scan_type, cached_ir } => Self::Scan { sources: sources.clone(), file_info: file_info.clone(), unified_scan_args: unified_scan_args.clone(), scan_type: scan_type.clone(), cached_ir: cached_ir.clone() },
206 Self::DataFrameScan { df, schema, } => Self::DataFrameScan { df: df.clone(), schema: schema.clone(), },
207 Self::Select { expr, input, options } => Self::Select { expr: expr.clone(), input: input.clone(), options: options.clone() },
208 Self::GroupBy { input, keys, aggs, apply, maintain_order, options } => Self::GroupBy { input: input.clone(), keys: keys.clone(), aggs: aggs.clone(), apply: apply.clone(), maintain_order: maintain_order.clone(), options: options.clone() },
209 Self::Join { input_left, input_right, left_on, right_on, predicates, options } => Self::Join { input_left: input_left.clone(), input_right: input_right.clone(), left_on: left_on.clone(), right_on: right_on.clone(), options: options.clone(), predicates: predicates.clone() },
210 Self::HStack { input, exprs, options } => Self::HStack { input: input.clone(), exprs: exprs.clone(), options: options.clone() },
211 Self::MatchToSchema { input, match_schema, per_column, extra_columns } => Self::MatchToSchema { input: input.clone(), match_schema: match_schema.clone(), per_column: per_column.clone(), extra_columns: *extra_columns },
212 Self::Distinct { input, options } => Self::Distinct { input: input.clone(), options: options.clone() },
213 Self::Sort {input,by_column, slice, sort_options } => Self::Sort { input: input.clone(), by_column: by_column.clone(), slice: slice.clone(), sort_options: sort_options.clone() },
214 Self::Slice { input, offset, len } => Self::Slice { input: input.clone(), offset: offset.clone(), len: len.clone() },
215 Self::MapFunction { input, function } => Self::MapFunction { input: input.clone(), function: function.clone() },
216 Self::Union { inputs, args} => Self::Union { inputs: inputs.clone(), args: args.clone() },
217 Self::HConcat { inputs, options } => Self::HConcat { inputs: inputs.clone(), options: options.clone() },
218 Self::ExtContext { input, contexts, } => Self::ExtContext { input: input.clone(), contexts: contexts.clone() },
219 Self::Sink { input, payload } => Self::Sink { input: input.clone(), payload: payload.clone() },
220 Self::SinkMultiple { inputs } => Self::SinkMultiple { inputs: inputs.clone() },
221 #[cfg(feature = "merge_sorted")]
222 Self::MergeSorted { input_left, input_right, key } => Self::MergeSorted { input_left: input_left.clone(), input_right: input_right.clone(), key: key.clone() },
223 Self::IR {node, dsl, version} => Self::IR {node: *node, dsl: dsl.clone(), version: *version},
224 }
225 }
226}
227
228impl Default for DslPlan {
229 fn default() -> Self {
230 let df = DataFrame::empty();
231 let schema = df.schema().clone();
232 DslPlan::DataFrameScan {
233 df: Arc::new(df),
234 schema,
235 }
236 }
237}
238
239impl DslPlan {
240 pub fn describe(&self) -> PolarsResult<String> {
241 Ok(self.clone().to_alp()?.describe())
242 }
243
244 pub fn describe_tree_format(&self) -> PolarsResult<String> {
245 Ok(self.clone().to_alp()?.describe_tree_format())
246 }
247
248 pub fn display(&self) -> PolarsResult<impl fmt::Display> {
249 struct DslPlanDisplay(IRPlan);
250 impl fmt::Display for DslPlanDisplay {
251 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
252 fmt::Display::fmt(&self.0.as_ref().display(), f)
253 }
254 }
255 Ok(DslPlanDisplay(self.clone().to_alp()?))
256 }
257
258 pub fn to_alp(self) -> PolarsResult<IRPlan> {
259 let mut lp_arena = Arena::with_capacity(16);
260 let mut expr_arena = Arena::with_capacity(16);
261
262 let node = to_alp(
263 self,
264 &mut expr_arena,
265 &mut lp_arena,
266 &mut OptFlags::default(),
267 )?;
268 let plan = IRPlan::new(node, lp_arena, expr_arena);
269
270 Ok(plan)
271 }
272
273 #[cfg(feature = "serde")]
274 pub fn serialize_versioned<W: Write>(&self, mut writer: W) -> PolarsResult<()> {
275 let le_major = DSL_VERSION.0.to_le_bytes();
276 let le_minor = DSL_VERSION.1.to_le_bytes();
277 writer.write_all(DSL_MAGIC_BYTES)?;
278 writer.write_all(&le_major)?;
279 writer.write_all(&le_minor)?;
280 pl_serialize::SerializeOptions::default().serialize_into_writer::<_, _, true>(writer, self)
281 }
282
283 #[cfg(feature = "serde")]
284 pub fn deserialize_versioned<R: Read>(mut reader: R) -> PolarsResult<Self> {
285 const MAGIC_LEN: usize = DSL_MAGIC_BYTES.len();
286 let mut version_magic = [0u8; MAGIC_LEN + 4];
287 reader
288 .read_exact(&mut version_magic)
289 .map_err(|e| polars_err!(ComputeError: "failed to read incoming DSL_VERSION: {e}"))?;
290
291 if &version_magic[..MAGIC_LEN] != DSL_MAGIC_BYTES {
292 polars_bail!(ComputeError: "dsl magic bytes not found")
293 }
294
295 let major = u16::from_le_bytes(version_magic[MAGIC_LEN..MAGIC_LEN + 2].try_into().unwrap());
296 let minor = u16::from_le_bytes(
297 version_magic[MAGIC_LEN + 2..MAGIC_LEN + 4]
298 .try_into()
299 .unwrap(),
300 );
301
302 const MAJOR: u16 = DSL_VERSION.0;
303 const MINOR: u16 = DSL_VERSION.1;
304
305 if polars_core::config::verbose() {
306 eprintln!(
307 "incoming DSL_VERSION: {major}.{minor}, deserializer DSL_VERSION: {MAJOR}.{MINOR}"
308 );
309 }
310
311 if major != MAJOR {
312 polars_bail!(ComputeError:
313 "deserialization failed\n\ngiven DSL_VERSION: {major}.{minor} is not compatible with this Polars version which uses DSL_VERSION: {MAJOR}.{MINOR}\n{}",
314 "error: can't deserialize DSL with a different major version"
315 );
316 }
317
318 let (dsl, unknown_fields) = pl_serialize::SerializeOptions::default().deserialize_from_reader_with_unknown_fields(reader).map_err(|e| {
319 if minor > MINOR {
321 polars_err!(ComputeError:
323 "deserialization failed\n\ngiven DSL_VERSION: {major}.{minor} is higher than this Polars version which uses DSL_VERSION: {MAJOR}.{MINOR}\n{}\nerror: {e}",
324 "either the input is malformed, or the plan requires functionality not supported in this Polars version"
325 )
326 } else {
327 polars_err!(ComputeError:
328 "deserialization failed\n\nerror: {e}",
329 )
330 }
331 })?;
332
333 if !unknown_fields.is_empty() {
334 if minor > MINOR {
335 polars_bail!(ComputeError:
336 "deserialization failed\n\ngiven DSL_VERSION: {major}.{minor} is higher than this Polars version which uses DSL_VERSION: {MAJOR}.{MINOR}\n{}\nencountered unknown fields: {:?}",
337 "the plan requires functionality not supported in this Polars version",
338 unknown_fields,
339 )
340 } else {
341 polars_bail!(ComputeError:
342 "deserialization failed\n\ngiven DSL_VERSION: {major}.{minor} should be supported in this Polars version which uses DSL_VERSION: {MAJOR}.{MINOR}\nencountered unknown fields: {:?}",
343 unknown_fields,
344 )
345 }
346 }
347
348 Ok(dsl)
349 }
350}