polars_plan/plans/ir/
dot.rs

1use std::fmt;
2use std::path::PathBuf;
3
4use polars_core::schema::Schema;
5use polars_utils::pl_str::PlSmallStr;
6
7use super::format::ExprIRSliceDisplay;
8use crate::constants::UNLIMITED_CACHE;
9use crate::prelude::ir::format::ColumnsDisplay;
10use crate::prelude::*;
11
12pub struct IRDotDisplay<'a> {
13    is_streaming: bool,
14    lp: IRPlanRef<'a>,
15}
16
/// Indentation prefix written before every statement inside the DOT graph body.
const INDENT: &str = "  ";
18
/// Identifier of a node in the emitted DOT graph.
///
/// `Plain` ids are taken from a running counter while `Cache` ids reuse the plan's cache id;
/// the distinct `p` / `c` prefixes keep the two id spaces from colliding.
#[derive(Clone, Copy)]
enum DotNode {
    Plain(usize),
    Cache(usize),
}

impl fmt::Display for DotNode {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (prefix, index) = match *self {
            Self::Plain(index) => ('p', index),
            Self::Cache(index) => ('c', index),
        };
        write!(f, "{prefix}{index}")
    }
}
33
34#[inline(always)]
35fn write_label<'a, 'b>(
36    f: &'a mut fmt::Formatter<'b>,
37    id: DotNode,
38    mut w: impl FnMut(&mut EscapeLabel<'a>) -> fmt::Result,
39) -> fmt::Result {
40    write!(f, "{INDENT}{id}[label=\"")?;
41
42    let mut escaped = EscapeLabel(f);
43    w(&mut escaped)?;
44    let EscapeLabel(f) = escaped;
45
46    writeln!(f, "\"]")?;
47
48    Ok(())
49}
50
51impl<'a> IRDotDisplay<'a> {
52    pub fn new(lp: IRPlanRef<'a>) -> Self {
53        if let Some(streaming_lp) = lp.extract_streaming_plan() {
54            return Self::new_streaming(streaming_lp);
55        }
56
57        Self {
58            is_streaming: false,
59            lp,
60        }
61    }
62
63    fn new_streaming(lp: IRPlanRef<'a>) -> Self {
64        Self {
65            is_streaming: true,
66            lp,
67        }
68    }
69
70    fn with_root(&self, root: Node) -> Self {
71        Self {
72            is_streaming: false,
73            lp: self.lp.with_root(root),
74        }
75    }
76
77    fn display_expr(&self, expr: &'a ExprIR) -> ExprIRDisplay<'a> {
78        expr.display(self.lp.expr_arena)
79    }
80
81    fn display_exprs(&self, exprs: &'a [ExprIR]) -> ExprIRSliceDisplay<'a, ExprIR> {
82        ExprIRSliceDisplay {
83            exprs,
84            expr_arena: self.lp.expr_arena,
85        }
86    }
87
88    fn _format(
89        &self,
90        f: &mut fmt::Formatter<'_>,
91        parent: Option<DotNode>,
92        last: &mut usize,
93    ) -> std::fmt::Result {
94        use fmt::Write;
95
96        let root = self.lp.root();
97
98        let mut parent = parent;
99        if self.is_streaming {
100            *last += 1;
101            let streaming_node = DotNode::Plain(*last);
102
103            if let Some(parent) = parent {
104                writeln!(f, "{INDENT}{parent} -- {streaming_node}")?;
105                write_label(f, streaming_node, |f| f.write_str("STREAMING"))?;
106            }
107
108            parent = Some(streaming_node);
109        }
110        let parent = parent;
111
112        let id = if let IR::Cache { id, .. } = root {
113            DotNode::Cache(*id)
114        } else {
115            *last += 1;
116            DotNode::Plain(*last)
117        };
118
119        if let Some(parent) = parent {
120            writeln!(f, "{INDENT}{parent} -- {id}")?;
121        }
122
123        use IR::*;
124        match root {
125            Union { inputs, .. } => {
126                for input in inputs {
127                    self.with_root(*input)._format(f, Some(id), last)?;
128                }
129
130                write_label(f, id, |f| f.write_str("UNION"))?;
131            },
132            HConcat { inputs, .. } => {
133                for input in inputs {
134                    self.with_root(*input)._format(f, Some(id), last)?;
135                }
136
137                write_label(f, id, |f| f.write_str("HCONCAT"))?;
138            },
139            Cache {
140                input, cache_hits, ..
141            } => {
142                self.with_root(*input)._format(f, Some(id), last)?;
143
144                if *cache_hits == UNLIMITED_CACHE {
145                    write_label(f, id, |f| f.write_str("CACHE"))?;
146                } else {
147                    write_label(f, id, |f| write!(f, "CACHE: {cache_hits} times"))?;
148                };
149            },
150            Filter { predicate, input } => {
151                self.with_root(*input)._format(f, Some(id), last)?;
152
153                let pred = self.display_expr(predicate);
154                write_label(f, id, |f| write!(f, "FILTER BY {pred}"))?;
155            },
156            #[cfg(feature = "python")]
157            PythonScan { options } => {
158                let predicate = match &options.predicate {
159                    PythonPredicate::Polars(e) => format!("{}", self.display_expr(e)),
160                    PythonPredicate::PyArrow(s) => s.clone(),
161                    PythonPredicate::None => "none".to_string(),
162                };
163                let with_columns = NumColumns(options.with_columns.as_ref().map(|s| s.as_ref()));
164                let total_columns = options.schema.len();
165
166                write_label(f, id, |f| {
167                    write!(
168                        f,
169                        "PYTHON SCAN\nπ {with_columns}/{total_columns};\nσ {predicate}"
170                    )
171                })?
172            },
173            Select {
174                expr,
175                input,
176                schema,
177                ..
178            } => {
179                self.with_root(*input)._format(f, Some(id), last)?;
180                write_label(f, id, |f| write!(f, "π {}/{}", expr.len(), schema.len()))?;
181            },
182            Sort {
183                input, by_column, ..
184            } => {
185                let by_column = self.display_exprs(by_column);
186                self.with_root(*input)._format(f, Some(id), last)?;
187                write_label(f, id, |f| write!(f, "SORT BY {by_column}"))?;
188            },
189            GroupBy {
190                input, keys, aggs, ..
191            } => {
192                let keys = self.display_exprs(keys);
193                let aggs = self.display_exprs(aggs);
194                self.with_root(*input)._format(f, Some(id), last)?;
195                write_label(f, id, |f| write!(f, "AGG {aggs}\nBY\n{keys}"))?;
196            },
197            HStack { input, exprs, .. } => {
198                let exprs = self.display_exprs(exprs);
199                self.with_root(*input)._format(f, Some(id), last)?;
200                write_label(f, id, |f| write!(f, "WITH COLUMNS {exprs}"))?;
201            },
202            Slice { input, offset, len } => {
203                self.with_root(*input)._format(f, Some(id), last)?;
204                write_label(f, id, |f| write!(f, "SLICE offset: {offset}; len: {len}"))?;
205            },
206            Distinct { input, options, .. } => {
207                self.with_root(*input)._format(f, Some(id), last)?;
208                write_label(f, id, |f| {
209                    f.write_str("DISTINCT")?;
210
211                    if let Some(subset) = &options.subset {
212                        f.write_str(" BY ")?;
213
214                        let mut subset = subset.iter();
215
216                        if let Some(fst) = subset.next() {
217                            f.write_str(fst)?;
218                            for name in subset {
219                                write!(f, ", \"{name}\"")?;
220                            }
221                        } else {
222                            f.write_str("None")?;
223                        }
224                    }
225
226                    Ok(())
227                })?;
228            },
229            DataFrameScan {
230                schema,
231                output_schema,
232                ..
233            } => {
234                let num_columns = NumColumnsSchema(output_schema.as_ref().map(|p| p.as_ref()));
235                let total_columns = schema.len();
236
237                write_label(f, id, |f| {
238                    write!(f, "TABLE\nπ {num_columns}/{total_columns}")
239                })?;
240            },
241            Scan {
242                sources,
243                file_info,
244                hive_parts: _,
245                predicate,
246                scan_type,
247                file_options: options,
248                output_schema: _,
249            } => {
250                let name: &str = scan_type.into();
251                let path = ScanSourcesDisplay(sources);
252                let with_columns = options.with_columns.as_ref().map(|cols| cols.as_ref());
253                let with_columns = NumColumns(with_columns);
254                let total_columns =
255                    file_info.schema.len() - usize::from(options.row_index.is_some());
256
257                write_label(f, id, |f| {
258                    write!(f, "{name} SCAN {path}\nπ {with_columns}/{total_columns};",)?;
259
260                    if let Some(predicate) = predicate.as_ref() {
261                        write!(f, "\nσ {}", self.display_expr(predicate))?;
262                    }
263
264                    if let Some(row_index) = options.row_index.as_ref() {
265                        write!(f, "\nrow index: {} (+{})", row_index.name, row_index.offset)?;
266                    }
267
268                    Ok(())
269                })?;
270            },
271            Join {
272                input_left,
273                input_right,
274                left_on,
275                right_on,
276                options,
277                ..
278            } => {
279                self.with_root(*input_left)._format(f, Some(id), last)?;
280                self.with_root(*input_right)._format(f, Some(id), last)?;
281
282                let left_on = self.display_exprs(left_on);
283                let right_on = self.display_exprs(right_on);
284
285                write_label(f, id, |f| {
286                    write!(
287                        f,
288                        "JOIN {}\nleft: {left_on};\nright: {right_on}",
289                        options.args.how
290                    )
291                })?;
292            },
293            MapFunction {
294                input, function, ..
295            } => {
296                if let Some(streaming_lp) = function.to_streaming_lp() {
297                    Self::new_streaming(streaming_lp)._format(f, Some(id), last)?;
298                } else {
299                    self.with_root(*input)._format(f, Some(id), last)?;
300                    write_label(f, id, |f| write!(f, "{function}"))?;
301                }
302            },
303            ExtContext { input, .. } => {
304                self.with_root(*input)._format(f, Some(id), last)?;
305                write_label(f, id, |f| f.write_str("EXTERNAL_CONTEXT"))?;
306            },
307            Sink { input, payload, .. } => {
308                self.with_root(*input)._format(f, Some(id), last)?;
309
310                write_label(f, id, |f| {
311                    f.write_str(match payload {
312                        SinkType::Memory => "SINK (MEMORY)",
313                        SinkType::File { .. } => "SINK (FILE)",
314                    })
315                })?;
316            },
317            SimpleProjection { input, columns } => {
318                let num_columns = columns.as_ref().len();
319                let total_columns = self.lp.lp_arena.get(*input).schema(self.lp.lp_arena).len();
320
321                let columns = ColumnsDisplay(columns.as_ref());
322                self.with_root(*input)._format(f, Some(id), last)?;
323                write_label(f, id, |f| {
324                    write!(f, "simple π {num_columns}/{total_columns}\n[{columns}]")
325                })?;
326            },
327            Invalid => write_label(f, id, |f| f.write_str("INVALID"))?,
328        }
329
330        Ok(())
331    }
332}
333
334// A few utility structures for formatting
335pub struct PathsDisplay<'a>(pub &'a [PathBuf]);
336pub struct ScanSourcesDisplay<'a>(pub &'a ScanSources);
337struct NumColumns<'a>(Option<&'a [PlSmallStr]>);
338struct NumColumnsSchema<'a>(Option<&'a Schema>);
339
340impl fmt::Display for ScanSourceRef<'_> {
341    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
342        match self {
343            ScanSourceRef::Path(path) => path.display().fmt(f),
344            ScanSourceRef::File(_) => f.write_str("open-file"),
345            ScanSourceRef::Buffer(buff) => write!(f, "{} in-mem bytes", buff.len()),
346        }
347    }
348}
349
350impl fmt::Display for ScanSourcesDisplay<'_> {
351    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
352        match self.0.len() {
353            0 => write!(f, "[]"),
354            1 => write!(f, "[{}]", self.0.at(0)),
355            2 => write!(f, "[{}, {}]", self.0.at(0), self.0.at(1)),
356            _ => write!(
357                f,
358                "[{}, ... {} other sources]",
359                self.0.at(0),
360                self.0.len() - 1,
361            ),
362        }
363    }
364}
365
366impl fmt::Display for PathsDisplay<'_> {
367    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
368        match self.0.len() {
369            0 => write!(f, "[]"),
370            1 => write!(f, "[{}]", self.0[0].display()),
371            2 => write!(f, "[{}, {}]", self.0[0].display(), self.0[1].display()),
372            _ => write!(
373                f,
374                "[{}, ... {} other files]",
375                self.0[0].display(),
376                self.0.len() - 1,
377            ),
378        }
379    }
380}
381
382impl fmt::Display for NumColumns<'_> {
383    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
384        match self.0 {
385            None => f.write_str("*"),
386            Some(columns) => columns.len().fmt(f),
387        }
388    }
389}
390
391impl fmt::Display for NumColumnsSchema<'_> {
392    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
393        match self.0 {
394            None => f.write_str("*"),
395            Some(columns) => columns.len().fmt(f),
396        }
397    }
398}
399
/// Adapter around a [`fmt::Write`] sink that escapes text for use inside a quoted DOT label.
///
/// A double quote is written as `\"` and a newline as the two characters `\n`. Every other
/// character, including a backslash, is forwarded unchanged. Input that already contains
/// escape sequences such as `\"` is therefore not handled correctly; that is acceptable for
/// the labels produced here.
pub struct EscapeLabel<'a>(pub &'a mut dyn fmt::Write);

impl fmt::Write for EscapeLabel<'_> {
    fn write_str(&mut self, s: &str) -> fmt::Result {
        // Byte offset of the start of the run of `s` that has not been forwarded yet.
        let mut pending = 0;

        for (idx, ch) in s.char_indices() {
            let escaped = match ch {
                '"' => r#"\""#,
                '\n' => r#"\n"#,
                _ => continue,
            };

            self.0.write_str(&s[pending..idx])?;
            self.0.write_str(escaped)?;
            pending = idx + ch.len_utf8();
        }

        self.0.write_str(&s[pending..])
    }
}
431
432impl fmt::Display for IRDotDisplay<'_> {
433    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
434        writeln!(f, "graph  polars_query {{")?;
435
436        let mut last = 0;
437        self._format(f, None, &mut last)?;
438
439        writeln!(f, "}}")?;
440
441        Ok(())
442    }
443}