1use std::fmt;
2use std::io::{Read, Write};
3use std::sync::{Arc, Mutex};
4
5use polars_utils::arena::Node;
6#[cfg(feature = "serde")]
7use polars_utils::pl_serialize;
8use recursive::recursive;
9#[cfg(feature = "serde")]
10use serde::{Deserialize, Serialize};
11
12use super::*;
/// Version of the serialized DSL format, as `(major, minor)`.
///
/// `DslPlan::serialize_versioned` writes both components as little-endian
/// `u16`s directly after the magic bytes. `DslPlan::deserialize_versioned`
/// only reports the decoded version in its error message; it does not reject
/// mismatching versions up front.
pub static DSL_VERSION: (u16, u16) = (2, 0);
/// Magic prefix that starts every versioned DSL serialization.
static DSL_MAGIC_BYTES: &[u8] = "DSL_VERSION".as_bytes();
18
/// Logical query plan as produced by the DSL.
///
/// A `DslPlan` is lowered into an `IRPlan` by `DslPlan::to_alp`. Fields marked
/// with `serde(skip)` are not part of the serialized form; on deserialization
/// they are filled with their `Default` value.
///
/// NOTE: with the `serde` feature the serialized layout follows the order of
/// the variants and fields below, so reordering them changes the format.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum DslPlan {
    /// Scan whose source is provided from Python.
    #[cfg(feature = "python")]
    PythonScan {
        options: crate::dsl::python_dsl::PythonOptionsDsl,
    },
    /// Filter the rows of `input` with the `predicate` expression.
    Filter {
        input: Arc<DslPlan>,
        predicate: Expr,
    },
    /// Cache point over `input`, identified by `id`.
    // NOTE(review): presumably nodes sharing an `id` share a single cached
    // result — confirm against the IR conversion.
    Cache {
        input: Arc<DslPlan>,
        id: usize,
    },
    /// Scan over external `sources`; `scan_type` selects the kind of scan
    /// (presumably the file format — confirm).
    Scan {
        sources: ScanSources,
        /// File metadata, if already known.
        // NOTE(review): confirm in which cases this is still `None` at IR conversion.
        file_info: Option<FileInfo>,
        unified_scan_args: Box<UnifiedScanArgs>,
        scan_type: Box<FileScan>,
        /// Cache slot for the IR of this scan. Not serialized. `Clone` only
        /// clones the `Arc`, so clones of this node share the same cache.
        #[cfg_attr(feature = "serde", serde(skip))]
        cached_ir: Arc<Mutex<Option<IR>>>,
    },
    /// Scan over an in-memory `DataFrame` together with its `schema`.
    DataFrameScan {
        df: Arc<DataFrame>,
        schema: SchemaRef,
    },
    /// Projection: evaluate `expr` against `input`.
    Select {
        expr: Vec<Expr>,
        input: Arc<DslPlan>,
        options: ProjectionOptions,
    },
    /// Group `input` by `keys` and evaluate `aggs` per group.
    GroupBy {
        input: Arc<DslPlan>,
        keys: Vec<Expr>,
        aggs: Vec<Expr>,
        maintain_order: bool,
        options: Arc<GroupbyOptions>,
        /// Optional user-defined function plus a schema (presumably its output
        /// schema — confirm). Not serialized.
        #[cfg_attr(feature = "serde", serde(skip))]
        apply: Option<(Arc<dyn DataFrameUdf>, SchemaRef)>,
    },
    /// Join `input_left` with `input_right` on `left_on` / `right_on`.
    Join {
        input_left: Arc<DslPlan>,
        input_right: Arc<DslPlan>,
        left_on: Vec<Expr>,
        right_on: Vec<Expr>,
        /// Additional join predicates (presumably for non-equi joins — confirm).
        predicates: Vec<Expr>,
        options: Arc<JoinOptions>,
    },
    /// Add or replace columns of `input` using `exprs`.
    HStack {
        input: Arc<DslPlan>,
        exprs: Vec<Expr>,
        options: ProjectionOptions,
    },
    /// Remove duplicate rows from `input`.
    Distinct {
        input: Arc<DslPlan>,
        options: DistinctOptionsDSL,
    },
    /// Sort `input` by the `by_column` expressions.
    Sort {
        input: Arc<DslPlan>,
        by_column: Vec<Expr>,
        /// Optional slice of the sorted output; presumably `(offset, len)` — confirm.
        slice: Option<(i64, usize)>,
        sort_options: SortMultipleOptions,
    },
    /// Take `len` rows of `input` starting at `offset`.
    Slice {
        input: Arc<DslPlan>,
        offset: i64,
        len: IdxSize,
    },
    /// Apply a DSL-level `function` to `input`.
    MapFunction {
        input: Arc<DslPlan>,
        function: DslFunction,
    },
    /// Vertical concatenation of `inputs`.
    Union {
        inputs: Vec<DslPlan>,
        args: UnionArgs,
    },
    /// Horizontal concatenation of `inputs`.
    HConcat {
        inputs: Vec<DslPlan>,
        options: HConcatOptions,
    },
    /// `input` with extra `contexts` plans made available
    /// (presumably to its expressions — confirm).
    ExtContext {
        input: Arc<DslPlan>,
        contexts: Vec<DslPlan>,
    },
    /// Send the output of `input` to the sink described by `payload`.
    Sink {
        input: Arc<DslPlan>,
        payload: SinkType,
    },
    /// Several sink plans combined into one query.
    SinkMultiple {
        inputs: Vec<DslPlan>,
    },
    /// Merge `input_left` and `input_right` on `key`
    /// (presumably both inputs are already sorted by `key` — confirm).
    #[cfg(feature = "merge_sorted")]
    MergeSorted {
        input_left: Arc<DslPlan>,
        input_right: Arc<DslPlan>,
        key: PlSmallStr,
    },
    /// A node that has already been converted to IR.
    IR {
        /// The original DSL. `node` is not serialized, so this is what
        /// represents the node in the serialized form.
        dsl: Arc<DslPlan>,
        // NOTE(review): the meaning of `version` is not visible in this file.
        version: u32,
        /// Handle to the converted IR node (presumably in an IR arena). Not serialized.
        #[cfg_attr(feature = "serde", serde(skip))]
        node: Option<Node>,
    },
}
145
146impl Clone for DslPlan {
147 #[rustfmt::skip]
150 #[allow(clippy::clone_on_copy)]
151 #[recursive]
152 fn clone(&self) -> Self {
153 match self {
154 #[cfg(feature = "python")]
155 Self::PythonScan { options } => Self::PythonScan { options: options.clone() },
156 Self::Filter { input, predicate } => Self::Filter { input: input.clone(), predicate: predicate.clone() },
157 Self::Cache { input, id } => Self::Cache { input: input.clone(), id: id.clone() },
158 Self::Scan { sources, file_info, unified_scan_args, scan_type, cached_ir } => Self::Scan { sources: sources.clone(), file_info: file_info.clone(), unified_scan_args: unified_scan_args.clone(), scan_type: scan_type.clone(), cached_ir: cached_ir.clone() },
159 Self::DataFrameScan { df, schema, } => Self::DataFrameScan { df: df.clone(), schema: schema.clone(), },
160 Self::Select { expr, input, options } => Self::Select { expr: expr.clone(), input: input.clone(), options: options.clone() },
161 Self::GroupBy { input, keys, aggs, apply, maintain_order, options } => Self::GroupBy { input: input.clone(), keys: keys.clone(), aggs: aggs.clone(), apply: apply.clone(), maintain_order: maintain_order.clone(), options: options.clone() },
162 Self::Join { input_left, input_right, left_on, right_on, predicates, options } => Self::Join { input_left: input_left.clone(), input_right: input_right.clone(), left_on: left_on.clone(), right_on: right_on.clone(), options: options.clone(), predicates: predicates.clone() },
163 Self::HStack { input, exprs, options } => Self::HStack { input: input.clone(), exprs: exprs.clone(), options: options.clone() },
164 Self::Distinct { input, options } => Self::Distinct { input: input.clone(), options: options.clone() },
165 Self::Sort {input,by_column, slice, sort_options } => Self::Sort { input: input.clone(), by_column: by_column.clone(), slice: slice.clone(), sort_options: sort_options.clone() },
166 Self::Slice { input, offset, len } => Self::Slice { input: input.clone(), offset: offset.clone(), len: len.clone() },
167 Self::MapFunction { input, function } => Self::MapFunction { input: input.clone(), function: function.clone() },
168 Self::Union { inputs, args} => Self::Union { inputs: inputs.clone(), args: args.clone() },
169 Self::HConcat { inputs, options } => Self::HConcat { inputs: inputs.clone(), options: options.clone() },
170 Self::ExtContext { input, contexts, } => Self::ExtContext { input: input.clone(), contexts: contexts.clone() },
171 Self::Sink { input, payload } => Self::Sink { input: input.clone(), payload: payload.clone() },
172 Self::SinkMultiple { inputs } => Self::SinkMultiple { inputs: inputs.clone() },
173 #[cfg(feature = "merge_sorted")]
174 Self::MergeSorted { input_left, input_right, key } => Self::MergeSorted { input_left: input_left.clone(), input_right: input_right.clone(), key: key.clone() },
175 Self::IR {node, dsl, version} => Self::IR {node: *node, dsl: dsl.clone(), version: *version},
176 }
177 }
178}
179
180impl Default for DslPlan {
181 fn default() -> Self {
182 let df = DataFrame::empty();
183 let schema = df.schema().clone();
184 DslPlan::DataFrameScan {
185 df: Arc::new(df),
186 schema,
187 }
188 }
189}
190
191impl DslPlan {
192 pub fn describe(&self) -> PolarsResult<String> {
193 Ok(self.clone().to_alp()?.describe())
194 }
195
196 pub fn describe_tree_format(&self) -> PolarsResult<String> {
197 Ok(self.clone().to_alp()?.describe_tree_format())
198 }
199
200 pub fn display(&self) -> PolarsResult<impl fmt::Display> {
201 struct DslPlanDisplay(IRPlan);
202 impl fmt::Display for DslPlanDisplay {
203 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
204 fmt::Display::fmt(&self.0.as_ref().display(), f)
205 }
206 }
207 Ok(DslPlanDisplay(self.clone().to_alp()?))
208 }
209
210 pub fn to_alp(self) -> PolarsResult<IRPlan> {
211 let mut lp_arena = Arena::with_capacity(16);
212 let mut expr_arena = Arena::with_capacity(16);
213
214 let node = to_alp(
215 self,
216 &mut expr_arena,
217 &mut lp_arena,
218 &mut OptFlags::default(),
219 )?;
220 let plan = IRPlan::new(node, lp_arena, expr_arena);
221
222 Ok(plan)
223 }
224
225 #[cfg(feature = "serde")]
226 pub fn serialize_versioned<W: Write>(&self, mut writer: W) -> PolarsResult<()> {
227 let le_major = DSL_VERSION.0.to_le_bytes();
228 let le_minor = DSL_VERSION.1.to_le_bytes();
229 writer.write_all(DSL_MAGIC_BYTES)?;
230 writer.write_all(&le_major)?;
231 writer.write_all(&le_minor)?;
232 pl_serialize::SerializeOptions::default().serialize_into_writer::<_, _, true>(writer, self)
233 }
234
235 #[cfg(feature = "serde")]
236 pub fn deserialize_versioned<R: Read>(mut reader: R) -> PolarsResult<Self> {
237 const MAGIC_LEN: usize = DSL_MAGIC_BYTES.len();
238 let mut version_magic = [0u8; MAGIC_LEN + 4];
239 reader.read_exact(&mut version_magic)?;
240
241 if &version_magic[..MAGIC_LEN] != DSL_MAGIC_BYTES {
242 polars_bail!(ComputeError: "dsl magic bytes not found")
243 }
244
245 let major = u16::from_be_bytes(version_magic[MAGIC_LEN..MAGIC_LEN + 2].try_into().unwrap());
249 let minor = u16::from_be_bytes(
250 version_magic[MAGIC_LEN + 2..MAGIC_LEN + 4]
251 .try_into()
252 .unwrap(),
253 );
254
255 pl_serialize::SerializeOptions::default()
256 .deserialize_from_reader::<_, _, true>(reader).map_err(|e| {
257 polars_err!(ComputeError: "deserialization failed\n\ngiven DSL_VERSION: {:?} is not compatible with this Polars version which uses DSL_VERSION: {:?}\nerror: {}", (major, minor), DSL_VERSION, e)
258 })
259 }
260}