1use std::{fmt, io, string};
4
5use super::{aggregate, function, Column, Error, Expr, Value, Visitor};
6use crate::{
7 data_type::{DataType, DataTyped},
8 display::{self, colors},
9 namer,
10 visitor::Acceptor,
11};
12
13impl From<string::FromUtf8Error> for Error {
14 fn from(err: string::FromUtf8Error) -> Self {
15 Error::Other(err.to_string())
16 }
17}
18
19#[derive(Clone, Debug)]
20pub struct Node<'a, T: Clone + fmt::Display>(&'a Expr, T);
21#[derive(Clone, Debug)]
22pub struct Edge<'a, T: Clone + fmt::Display>(&'a Expr, &'a Expr, T);
23#[derive(Clone, Debug)]
24pub struct VisitedExpr<'a, V>(&'a Expr, V);
25
26#[derive(Clone, Debug)]
27pub struct DotVisitor<'a>(pub &'a DataType);
28
29impl<'a> Visitor<'a, DataType> for DotVisitor<'a> {
30 fn column(&self, column: &'a Column) -> DataType {
31 self.0[column.clone()].clone()
32 }
33
34 fn value(&self, value: &'a Value) -> DataType {
35 value.data_type()
36 }
37
38 fn function(&self, function: &'a function::Function, arguments: Vec<DataType>) -> DataType {
39 function.clone().super_image(&arguments).unwrap()
40 }
41
42 fn aggregate(&self, aggregate: &'a aggregate::Aggregate, argument: DataType) -> DataType {
43 aggregate.clone().super_image(&argument).unwrap()
44 }
45
46 fn structured(&self, fields: Vec<(super::identifier::Identifier, DataType)>) -> DataType {
47 let fields: Vec<(String, DataType)> = fields
48 .into_iter()
49 .map(|(i, t)| (i.split_last().unwrap().0, t))
50 .collect();
51 DataType::structured(fields)
52 }
53}
54
55#[derive(Clone, Debug)]
56pub struct DotValueVisitor<'a>(pub &'a Value);
57
58impl<'a> Visitor<'a, Value> for DotValueVisitor<'a> {
59 fn column(&self, column: &'a Column) -> Value {
60 self.0[column.clone()].clone()
61 }
62
63 fn value(&self, value: &'a Value) -> Value {
64 value.clone()
65 }
66
67 fn function(&self, function: &'a function::Function, arguments: Vec<Value>) -> Value {
68 function.clone().value(&arguments).unwrap()
69 }
70
71 fn aggregate(&self, aggregate: &'a aggregate::Aggregate, argument: Value) -> Value {
72 aggregate.clone().value(&argument).unwrap()
73 }
74
75 fn structured(&self, fields: Vec<(super::identifier::Identifier, Value)>) -> Value {
76 let fields: Vec<(String, Value)> = fields
77 .into_iter()
78 .map(|(i, v)| (i.split_last().unwrap().0, v))
79 .collect();
80 Value::structured(fields)
81 }
82}
83
84impl<'a, T: Clone + fmt::Display, V: Visitor<'a, T>> dot::Labeller<'a, Node<'a, T>, Edge<'a, T>>
85 for VisitedExpr<'a, V>
86{
87 fn graph_id(&'a self) -> dot::Id<'a> {
88 dot::Id::new(namer::name_from_content("graph", self.0)).unwrap()
89 }
90
91 fn node_id(&'a self, node: &Node<'a, T>) -> dot::Id<'a> {
92 dot::Id::new(namer::name_from_content("graph", node.0)).unwrap()
93 }
94
95 fn node_label(&'a self, node: &Node<'a, T>) -> dot::LabelText<'a> {
96 dot::LabelText::html(match &node.0 {
97 Expr::Column(col) => format!(
98 "<b>{}</b><br/>{}",
99 dot::escape_html(&col.to_string()),
100 &node.1
101 ),
102 Expr::Value(val) => {
103 println!("{}", &val.to_string());
104 format!(
105 "<b>{}</b><br/>{}",
106 dot::escape_html(&val.to_string()),
107 &node.1
108 )
109 }
110 Expr::Function(fun) => {
111 format!(
112 "<b>{}</b><br/>{}",
113 dot::escape_html(&fun.function.to_string()),
114 &node.1
115 )
116 }
117 Expr::Aggregate(agg) => format!(
118 "<b>{}</b><br/>{}",
119 dot::escape_html(&agg.aggregate.to_string()),
120 &node.1
121 ),
122 Expr::Struct(s) => format!(
123 "<b>{}</b><br/>{}",
124 dot::escape_html(&s.to_string()),
125 &node.1
126 ),
127 })
128 }
129
130 fn node_color(&'a self, node: &Node<'a, T>) -> Option<dot::LabelText<'a>> {
131 Some(dot::LabelText::label(match &node.0 {
132 Expr::Column(_) => colors::MEDIUM_RED,
133 Expr::Value(_) => colors::LIGHT_RED,
134 Expr::Function(_) => colors::LIGHT_GREEN,
135 Expr::Aggregate(_) => colors::DARK_GREEN,
136 Expr::Struct(_) => colors::LIGHTER_GREEN,
137 }))
138 }
139}
140
141impl<'a, T: Clone + fmt::Display, V: Visitor<'a, T> + Clone>
142 dot::GraphWalk<'a, Node<'a, T>, Edge<'a, T>> for VisitedExpr<'a, V>
143{
144 fn nodes(&'a self) -> dot::Nodes<'a, Node<'a, T>> {
145 self.0
146 .iter_with(self.1.clone())
147 .map(|(expr, t)| Node(expr, t))
148 .collect()
149 }
150
151 fn edges(&'a self) -> dot::Edges<'a, Edge<'a, T>> {
152 self.0
153 .iter_with(self.1.clone())
154 .flat_map(|(expr, t)| match expr {
155 Expr::Column(_) | Expr::Value(_) => vec![],
156 Expr::Function(fun) => fun
157 .arguments
158 .iter()
159 .map(|arg| Edge(expr, arg, t.clone()))
160 .collect(),
161 Expr::Aggregate(agg) => vec![Edge(expr, &agg.argument, t)],
162 Expr::Struct(s) => s
163 .fields
164 .iter()
165 .map(|(_i, e)| Edge(expr, e, t.clone()))
166 .collect(),
167 })
168 .collect()
169 }
170
171 fn source(&'a self, edge: &Edge<'a, T>) -> Node<'a, T> {
172 Node(edge.0, edge.2.clone())
173 }
174
175 fn target(&'a self, edge: &Edge<'a, T>) -> Node<'a, T> {
176 Node(edge.1, edge.2.clone())
177 }
178}
179
180impl Expr {
181 pub fn dot<W: io::Write>(
183 &self,
184 data_type: DataType,
185 w: &mut W,
186 opts: &[&str],
187 ) -> io::Result<()> {
188 display::dot::render(&VisitedExpr(self, DotVisitor(&data_type)), w, opts)
189 }
190
191 pub fn dot_value<W: io::Write>(&self, val: Value, w: &mut W, opts: &[&str]) -> io::Result<()> {
193 display::dot::render(&VisitedExpr(self, DotValueVisitor(&val)), w, opts)
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        builder::{Ready, With},
        data_type::DataType,
        display::Dot,
        relation::{schema::Schema, Relation},
        WithoutContext as _,
    };
    use std::sync::Arc;

    /// Table with float columns `a, b, c, d, x, y, z, t`.
    ///
    /// Shared by the DSL rendering tests that need a typed relation.
    fn float_table() -> Arc<Relation> {
        Arc::new(
            Relation::table()
                .schema(
                    Schema::builder()
                        .with(("a", DataType::float_range(1.0..=1.1)))
                        .with(("b", DataType::float_values([0.1, 1.0, 5.0, -1.0, -5.0])))
                        .with(("c", DataType::float_range(0.0..=5.0)))
                        .with(("d", DataType::float_values([0.0, 1.0, 2.0, -1.0])))
                        .with(("x", DataType::float_range(0.0..=2.0)))
                        .with(("y", DataType::float_range(0.0..=5.0)))
                        .with(("z", DataType::float_range(9.0..=11.)))
                        .with(("t", DataType::float_range(0.9..=1.1)))
                        .build(),
                )
                .build(),
        )
    }

    #[test]
    fn test_dot() {
        let a = Expr::col("a");
        let b = Expr::col("b");
        let x = Expr::col("x");
        let expr = Expr::exp(Expr::sin(Expr::plus(Expr::multiply(a, x), b)));
        expr.with(DataType::Any).display_dot().unwrap();
    }

    #[test]
    fn test_dot_dsl() {
        let rel = float_table();
        expr!(exp(a * b) + cos(1. * z) * x - 0.2 * (y + 3.) + b + t * sin(c + 4. * (d + 5. + x)))
            .with(rel.data_type())
            .display_dot()
            .unwrap();
    }

    #[test]
    fn test_dot_dsl_squared() {
        let rel = float_table();
        let e = expr!(
            exp(a * b) + cos(1. * z) * x - 0.2 * (y + 3.) + b + t * sin(c + 4. * (d + 5. + x))
        );
        let e = Expr::multiply(e.clone(), e);
        e.with(rel.data_type()).display_dot().unwrap();
    }

    #[test]
    fn test_dot_distributivity_dsl() {
        let val = Value::structured([
            ("a", Value::float(1.)),
            ("b", Value::float(2.)),
            ("c", Value::float(3.)),
            ("d", Value::integer(4)),
        ]);
        expr! { a*b+d }.with(val.clone()).display_dot().unwrap();
        expr! { d+a*b }.with(val.clone()).display_dot().unwrap();
        expr! { (a*b+d) }.with(val).display_dot().unwrap();
    }

    #[test]
    fn test_dot_plus_minus_dsl() {
        let val = Value::structured([
            ("a", Value::float(1.)),
            ("b", Value::float(2.)),
            ("c", Value::float(3.)),
            ("d", Value::integer(4)),
        ]);
        expr! { a+b-c+d }.with(val).display_dot().unwrap();
    }

    #[test]
    fn test_dot_simple_value_dsl() {
        let val = Value::structured([
            ("a", Value::float(0.1)),
            ("b", Value::float(0.1)),
            ("z", Value::float(0.1)),
            ("d", Value::integer(0)),
            ("t", Value::float(0.1)),
            ("c", Value::float(0.0)),
            ("x", Value::float(0.0)),
        ]);
        expr! { exp(a*b + cos(2*z)*d - 2*z + t*sin(c+3*x)) }
            .with(val)
            .display_dot()
            .unwrap();
    }

    #[test]
    fn test_dot_value_dsl() {
        let val = Value::structured([
            ("a", Value::float(0.1)),
            ("b", Value::float(0.1)),
            ("c", Value::float(0.1)),
            ("d", Value::float(0.1)),
            ("x", Value::float(0.1)),
            ("y", Value::integer(0)),
            ("z", Value::float(0.1)),
            ("t", Value::float(0.0)),
        ]);
        expr!(exp(a * b) + cos(1. * z) * x - 0.2 * (y + 3.) + b + t * sin(c + 4. * (d + 5. + x)))
            .with(val)
            .display_dot()
            .unwrap();
    }

    #[test]
    fn test_dot_aggregate_dsl() {
        let data_types = DataType::structured([
            ("a", DataType::list(DataType::Any, 1, 10)),
            (
                "b",
                DataType::list(DataType::integer_interval(2, 18), 1, 10),
            ),
            ("c", DataType::list(DataType::float_interval(5., 7.), 1, 10)),
            ("d", DataType::float_interval(5., 7.)),
        ]);
        println!("data_types = {data_types}");
        let x = expr!((exp(d) + 2 + sum(b) * count(a) + sum(c)) / (1 + count(a)));
        println!("x = {x}");
        // Walk every sub-expression with its computed type to exercise `DotVisitor`.
        for (node, t) in x.iter_with(DotVisitor(&data_types)) {
            println!("({node}, {t})");
        }
        println!("END ITER");
        x.with(data_types).display_dot().unwrap();
    }

    #[test]
    fn test_dot_aggregate_any_dsl() {
        let data_types = DataType::structured([
            ("a", DataType::Any),
            (
                "b",
                DataType::list(DataType::integer_interval(2, 18), 1, 10),
            ),
            ("c", DataType::Any),
            ("d", DataType::Any),
        ]);
        expr!(sum(sum(a) + count(b)) * count(c))
            .with(data_types)
            .display_dot()
            .unwrap();
    }

    #[test]
    fn test_dot_escape_html() {
        let data_types = DataType::structured([("a", DataType::integer_interval(1, 10))]);

        let my_expr = expr!(lt_eq(a, 5));
        my_expr.with(data_types.clone()).display_dot().unwrap();
        assert_eq!(my_expr.to_string(), "(a <= 5)".to_string());

        let my_expr = expr!(gt(a, 5));
        my_expr.with(data_types.clone()).display_dot().unwrap();
        assert_eq!(my_expr.to_string(), "(a > 5)".to_string());

        let my_expr = expr!(modulo(a, 2));
        my_expr.with(data_types).display_dot().unwrap();
        assert_eq!(my_expr.to_string(), "(a % 2)".to_string());
    }

    #[test]
    fn test_max() {
        let data_types = DataType::structured([("a", DataType::float_interval(0., 4.))]);

        let my_expr = expr!((a + 1 + abs(a - 1)) / 2);
        my_expr.with(data_types.clone()).display_dot().unwrap();

        let my_expr = expr!(1 - gt(a, 1) * (1 - a));
        my_expr.with(data_types).display_dot().unwrap();
    }

    #[test]
    fn test_dot_struct_dsl() {
        let rel = float_table();
        Expr::structured([
            ("a", Arc::new(expr!(exp(a * b)))),
            (
                "b",
                Arc::new(expr!(
                    cos(1. * z) * x - 0.2 * (y + 3.) + b + t * sin(c + 4. * (d + 5. + x))
                )),
            ),
        ])
        .with(rel.data_type())
        .display_dot()
        .unwrap();
    }

    #[test]
    fn test_dot_case() {
        let data_types = DataType::structured([(
            "a",
            DataType::list(DataType::integer_interval(2, 18), 1, 10),
        )]);
        expr!(case(eq(a, 5), 5, a))
            .with(data_types)
            .display_dot()
            .unwrap();
    }
}