qrlew/expr/
dot.rs

//! Render the dot graph of an expression, for debugging purposes
2
3use std::{fmt, io, string};
4
5use super::{aggregate, function, Column, Error, Expr, Value, Visitor};
6use crate::{
7    data_type::{DataType, DataTyped},
8    display::{self, colors},
9    namer,
10    visitor::Acceptor,
11};
12
13impl From<string::FromUtf8Error> for Error {
14    fn from(err: string::FromUtf8Error) -> Self {
15        Error::Other(err.to_string())
16    }
17}
18
/// A node of the rendered graph: a borrowed expression (field 0) paired with a displayable
/// annotation `T` (field 1), which is printed under the node title in its label.
#[derive(Clone, Debug)]
pub struct Node<'a, T: Clone + fmt::Display>(&'a Expr, T);
/// An edge of the rendered graph: the source expression (field 0), the target expression
/// (field 1, a direct argument/field of the source) and the annotation `T` computed by the
/// visitor for the source expression (field 2).
#[derive(Clone, Debug)]
pub struct Edge<'a, T: Clone + fmt::Display>(&'a Expr, &'a Expr, T);
/// An expression (field 0) together with the visitor `V` (field 1) used to annotate each of its
/// sub-expressions; this is the graph object handed to the dot renderer.
#[derive(Clone, Debug)]
pub struct VisitedExpr<'a, V>(&'a Expr, V);
25
/// Visitor computing a `DataType` for every sub-expression; the wrapped `DataType` is the one
/// columns are looked up in.
#[derive(Clone, Debug)]
pub struct DotVisitor<'a>(pub &'a DataType);
28
29impl<'a> Visitor<'a, DataType> for DotVisitor<'a> {
30    fn column(&self, column: &'a Column) -> DataType {
31        self.0[column.clone()].clone()
32    }
33
34    fn value(&self, value: &'a Value) -> DataType {
35        value.data_type()
36    }
37
38    fn function(&self, function: &'a function::Function, arguments: Vec<DataType>) -> DataType {
39        function.clone().super_image(&arguments).unwrap()
40    }
41
42    fn aggregate(&self, aggregate: &'a aggregate::Aggregate, argument: DataType) -> DataType {
43        aggregate.clone().super_image(&argument).unwrap()
44    }
45
46    fn structured(&self, fields: Vec<(super::identifier::Identifier, DataType)>) -> DataType {
47        let fields: Vec<(String, DataType)> = fields
48            .into_iter()
49            .map(|(i, t)| (i.split_last().unwrap().0, t))
50            .collect();
51        DataType::structured(fields)
52    }
53}
54
/// Visitor computing a `Value` for every sub-expression; the wrapped `Value` is the one
/// columns are looked up in.
#[derive(Clone, Debug)]
pub struct DotValueVisitor<'a>(pub &'a Value);
57
58impl<'a> Visitor<'a, Value> for DotValueVisitor<'a> {
59    fn column(&self, column: &'a Column) -> Value {
60        self.0[column.clone()].clone()
61    }
62
63    fn value(&self, value: &'a Value) -> Value {
64        value.clone()
65    }
66
67    fn function(&self, function: &'a function::Function, arguments: Vec<Value>) -> Value {
68        function.clone().value(&arguments).unwrap()
69    }
70
71    fn aggregate(&self, aggregate: &'a aggregate::Aggregate, argument: Value) -> Value {
72        aggregate.clone().value(&argument).unwrap()
73    }
74
75    fn structured(&self, fields: Vec<(super::identifier::Identifier, Value)>) -> Value {
76        let fields: Vec<(String, Value)> = fields
77            .into_iter()
78            .map(|(i, v)| (i.split_last().unwrap().0, v))
79            .collect();
80        Value::structured(fields)
81    }
82}
83
84impl<'a, T: Clone + fmt::Display, V: Visitor<'a, T>> dot::Labeller<'a, Node<'a, T>, Edge<'a, T>>
85    for VisitedExpr<'a, V>
86{
87    fn graph_id(&'a self) -> dot::Id<'a> {
88        dot::Id::new(namer::name_from_content("graph", self.0)).unwrap()
89    }
90
91    fn node_id(&'a self, node: &Node<'a, T>) -> dot::Id<'a> {
92        dot::Id::new(namer::name_from_content("graph", node.0)).unwrap()
93    }
94
95    fn node_label(&'a self, node: &Node<'a, T>) -> dot::LabelText<'a> {
96        dot::LabelText::html(match &node.0 {
97            Expr::Column(col) => format!(
98                "<b>{}</b><br/>{}",
99                dot::escape_html(&col.to_string()),
100                &node.1
101            ),
102            Expr::Value(val) => {
103                println!("{}", &val.to_string());
104                format!(
105                    "<b>{}</b><br/>{}",
106                    dot::escape_html(&val.to_string()),
107                    &node.1
108                )
109            }
110            Expr::Function(fun) => {
111                format!(
112                    "<b>{}</b><br/>{}",
113                    dot::escape_html(&fun.function.to_string()),
114                    &node.1
115                )
116            }
117            Expr::Aggregate(agg) => format!(
118                "<b>{}</b><br/>{}",
119                dot::escape_html(&agg.aggregate.to_string()),
120                &node.1
121            ),
122            Expr::Struct(s) => format!(
123                "<b>{}</b><br/>{}",
124                dot::escape_html(&s.to_string()),
125                &node.1
126            ),
127        })
128    }
129
130    fn node_color(&'a self, node: &Node<'a, T>) -> Option<dot::LabelText<'a>> {
131        Some(dot::LabelText::label(match &node.0 {
132            Expr::Column(_) => colors::MEDIUM_RED,
133            Expr::Value(_) => colors::LIGHT_RED,
134            Expr::Function(_) => colors::LIGHT_GREEN,
135            Expr::Aggregate(_) => colors::DARK_GREEN,
136            Expr::Struct(_) => colors::LIGHTER_GREEN,
137        }))
138    }
139}
140
141impl<'a, T: Clone + fmt::Display, V: Visitor<'a, T> + Clone>
142    dot::GraphWalk<'a, Node<'a, T>, Edge<'a, T>> for VisitedExpr<'a, V>
143{
144    fn nodes(&'a self) -> dot::Nodes<'a, Node<'a, T>> {
145        self.0
146            .iter_with(self.1.clone())
147            .map(|(expr, t)| Node(expr, t))
148            .collect()
149    }
150
151    fn edges(&'a self) -> dot::Edges<'a, Edge<'a, T>> {
152        self.0
153            .iter_with(self.1.clone())
154            .flat_map(|(expr, t)| match expr {
155                Expr::Column(_) | Expr::Value(_) => vec![],
156                Expr::Function(fun) => fun
157                    .arguments
158                    .iter()
159                    .map(|arg| Edge(expr, arg, t.clone()))
160                    .collect(),
161                Expr::Aggregate(agg) => vec![Edge(expr, &agg.argument, t)],
162                Expr::Struct(s) => s
163                    .fields
164                    .iter()
165                    .map(|(_i, e)| Edge(expr, e, t.clone()))
166                    .collect(),
167            })
168            .collect()
169    }
170
171    fn source(&'a self, edge: &Edge<'a, T>) -> Node<'a, T> {
172        Node(edge.0, edge.2.clone())
173    }
174
175    fn target(&'a self, edge: &Edge<'a, T>) -> Node<'a, T> {
176        Node(edge.1, edge.2.clone())
177    }
178}
179
180impl Expr {
181    /// Render the Expr to dot
182    pub fn dot<W: io::Write>(
183        &self,
184        data_type: DataType,
185        w: &mut W,
186        opts: &[&str],
187    ) -> io::Result<()> {
188        display::dot::render(&VisitedExpr(self, DotVisitor(&data_type)), w, opts)
189    }
190
191    /// Render the Expr to dot
192    pub fn dot_value<W: io::Write>(&self, val: Value, w: &mut W, opts: &[&str]) -> io::Result<()> {
193        display::dot::render(&VisitedExpr(self, DotValueVisitor(&val)), w, opts)
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        builder::{Ready, With},
        data_type::DataType,
        display::Dot,
        relation::{schema::Schema, Relation},
        WithoutContext as _,
    };
    use std::sync::Arc;

    /// Table relation over the float columns `a`, `b`, `c`, `d`, `x`, `y`, `z` and `t`,
    /// shared by the tests that type an expression against a schema.
    fn float_relation() -> Arc<Relation> {
        Arc::new(
            Relation::table()
                .schema(
                    Schema::builder()
                        .with(("a", DataType::float_range(1.0..=1.1)))
                        .with(("b", DataType::float_values([0.1, 1.0, 5.0, -1.0, -5.0])))
                        .with(("c", DataType::float_range(0.0..=5.0)))
                        .with(("d", DataType::float_values([0.0, 1.0, 2.0, -1.0])))
                        .with(("x", DataType::float_range(0.0..=2.0)))
                        .with(("y", DataType::float_range(0.0..=5.0)))
                        .with(("z", DataType::float_range(9.0..=11.)))
                        .with(("t", DataType::float_range(0.9..=1.1)))
                        .build(),
                )
                .build(),
        )
    }

    #[test]
    fn test_dot() {
        // exp(sin(a * x + b)), assembled through the constructor API
        let product = Expr::multiply(Expr::col("a"), Expr::col("x"));
        let affine = Expr::plus(product, Expr::col("b"));
        let expr = Expr::exp(Expr::sin(affine));
        expr.with(DataType::Any).display_dot().unwrap();
    }

    #[test]
    fn test_dot_dsl() {
        let rel = float_relation();
        // Expression written with the DSL, typed against the relation
        let e = expr!(
            exp(a * b) + cos(1. * z) * x - 0.2 * (y + 3.) + b + t * sin(c + 4. * (d + 5. + x))
        );
        e.with(rel.data_type()).display_dot().unwrap();
    }

    #[test]
    fn test_dot_dsl_squared() {
        let rel = float_relation();
        // The same expression multiplied by itself
        let factor = expr!(
            exp(a * b) + cos(1. * z) * x - 0.2 * (y + 3.) + b + t * sin(c + 4. * (d + 5. + x))
        );
        let squared = Expr::multiply(factor.clone(), factor);
        squared.with(rel.data_type()).display_dot().unwrap();
    }

    #[test]
    fn test_dot_distributivity_dsl() {
        let val = Value::structured([
            ("a", Value::float(1.)),
            ("b", Value::float(2.)),
            ("c", Value::float(3.)),
            ("d", Value::integer(4)),
        ]);
        let product_first = expr! { a*b+d };
        product_first.with(val.clone()).display_dot().unwrap();
        let product_last = expr! { d+a*b };
        product_last.with(val.clone()).display_dot().unwrap();
        let parenthesized = expr! { (a*b+d) };
        parenthesized.with(val).display_dot().unwrap();
    }

    #[test]
    fn test_dot_plus_minus_dsl() {
        let val = Value::structured([
            ("a", Value::float(1.)),
            ("b", Value::float(2.)),
            ("c", Value::float(3.)),
            ("d", Value::integer(4)),
        ]);
        let e = expr! { a+b-c+d };
        e.with(val).display_dot().unwrap();
    }

    #[test]
    fn test_dot_simple_value_dsl() {
        let val = Value::structured([
            ("a", Value::float(0.1)),
            ("b", Value::float(0.1)),
            ("z", Value::float(0.1)),
            ("d", Value::integer(0)),
            ("t", Value::float(0.1)),
            ("c", Value::float(0.0)),
            ("x", Value::float(0.0)),
        ]);
        let e = expr! { exp(a*b + cos(2*z)*d - 2*z + t*sin(c+3*x)) };
        e.with(val).display_dot().unwrap();
    }

    #[test]
    fn test_dot_value_dsl() {
        let val = Value::structured([
            ("a", Value::float(0.1)),
            ("b", Value::float(0.1)),
            ("c", Value::float(0.1)),
            ("d", Value::float(0.1)),
            ("x", Value::float(0.1)),
            ("y", Value::integer(0)),
            ("z", Value::float(0.1)),
            ("t", Value::float(0.0)),
        ]);
        // Expression written with the DSL, evaluated on the value
        let e = expr!(
            exp(a * b) + cos(1. * z) * x - 0.2 * (y + 3.) + b + t * sin(c + 4. * (d + 5. + x))
        );
        e.with(val).display_dot().unwrap();
    }

    #[test]
    fn test_dot_aggregate_dsl() {
        let data_types = DataType::structured([
            ("a", DataType::list(DataType::Any, 1, 10)),
            (
                "b",
                DataType::list(DataType::integer_interval(2, 18), 1, 10),
            ),
            ("c", DataType::list(DataType::float_interval(5., 7.), 1, 10)),
            ("d", DataType::float_interval(5., 7.)),
        ]);
        println!("data_types = {data_types}");
        let x = expr!((exp(d) + 2 + sum(b) * count(a) + sum(c)) / (1 + count(a)));
        println!("x = {x}");
        // Dump every visited sub-expression with its computed type
        let visited = x.iter_with(DotVisitor(&data_types));
        for (x, t) in visited {
            println!("({x}, {t})");
        }
        println!("END ITER");
        x.with(data_types).display_dot().unwrap();
    }

    #[test]
    fn test_dot_aggregate_any_dsl() {
        let data_types = DataType::structured([
            ("a", DataType::Any),
            (
                "b",
                DataType::list(DataType::integer_interval(2, 18), 1, 10),
            ),
            ("c", DataType::Any),
            ("d", DataType::Any),
        ]);
        let e = expr!(sum(sum(a) + count(b)) * count(c));
        e.with(data_types).display_dot().unwrap();
    }

    #[test]
    fn test_dot_escape_html() {
        let data_types = DataType::structured([("a", DataType::integer_interval(1, 10))]);

        // Operators whose display contains characters that matter in HTML labels
        let cases = [
            (expr!(lt_eq(a, 5)), "(a <= 5)"),
            (expr!(gt(a, 5)), "(a > 5)"),
            (expr!(modulo(a, 2)), "(a % 2)"),
        ];
        for (my_expr, expected) in cases {
            my_expr.with(data_types.clone()).display_dot().unwrap();
            assert_eq!(my_expr.to_string(), expected);
        }
    }

    #[test]
    fn test_max() {
        let data_types = DataType::structured([("a", DataType::float_interval(0., 4.))]);

        let with_abs = expr!((a + 1 + abs(a - 1)) / 2);
        with_abs.with(data_types.clone()).display_dot().unwrap();

        let with_gt = expr!(1 - gt(a, 1) * (1 - a));
        with_gt.with(data_types).display_dot().unwrap();
    }

    #[test]
    fn test_dot_struct_dsl() {
        let rel = float_relation();
        // A struct expression with two fields
        let field_a = Arc::new(expr!(exp(a * b)));
        let field_b = Arc::new(expr!(
            cos(1. * z) * x - 0.2 * (y + 3.) + b + t * sin(c + 4. * (d + 5. + x))
        ));
        let structured = Expr::structured([("a", field_a), ("b", field_b)]);
        structured.with(rel.data_type()).display_dot().unwrap();
    }

    #[test]
    fn test_dot_case() {
        let data_types = DataType::structured([(
            "a",
            DataType::list(DataType::integer_interval(2, 18), 1, 10),
        )]);
        let e = expr!(case(eq(a, 5), 5, a));
        e.with(data_types).display_dot().unwrap();
    }
}