qrlew/relation/
sql.rs

1//! Methods to convert Relations to ast::Query
2
3use super::{
4    Join, Map, OrderBy, Reduce, Relation, Set, SetOperator, SetQuantifier, Table, Values,
5    Variant as _, Visitor,
6};
7use crate::{
8    ast,
9    dialect_translation::{postgresql::PostgreSqlTranslator, RelationToQueryTranslator},
10    expr::{identifier::Identifier, Expr},
11    visitor::Acceptor,
12};
13use std::{collections::HashSet, iter::Iterator, ops::Deref};
14
/// A simple Relation -> ast::Query conversion Visitor using CTE
///
/// Every non-table relation is rendered as a common table expression (CTE), and the query
/// returned for it selects everything from that CTE. The `translator` (a
/// `RelationToQueryTranslator`) decides how identifiers, expressions and clauses are written
/// for a given SQL dialect.
///
/// The translator bound is placed on the `impl` blocks that need it rather than on the struct
/// itself, so merely naming or storing the type imposes no requirement.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct FromRelationVisitor<T> {
    translator: T,
}
20
21impl<T: RelationToQueryTranslator> FromRelationVisitor<T> {
22    pub fn new(translator: T) -> Self {
23        FromRelationVisitor { translator }
24    }
25}
26
27impl From<Identifier> for ast::ObjectName {
28    fn from(value: Identifier) -> Self {
29        ast::ObjectName(value.into_iter().map(|s| ast::Ident::new(s)).collect())
30    }
31}
32
33impl From<SetOperator> for ast::SetOperator {
34    fn from(value: SetOperator) -> Self {
35        match value {
36            SetOperator::Union => ast::SetOperator::Union,
37            SetOperator::Except => ast::SetOperator::Except,
38            SetOperator::Intersect => ast::SetOperator::Intersect,
39        }
40    }
41}
42
43impl From<SetQuantifier> for ast::SetQuantifier {
44    fn from(value: SetQuantifier) -> Self {
45        match value {
46            SetQuantifier::All => ast::SetQuantifier::All,
47            SetQuantifier::Distinct => ast::SetQuantifier::Distinct,
48            SetQuantifier::None => ast::SetQuantifier::None,
49            SetQuantifier::ByName => ast::SetQuantifier::ByName,
50            SetQuantifier::AllByName => ast::SetQuantifier::AllByName,
51            SetQuantifier::DistinctByName => ast::SetQuantifier::DistinctByName,
52        }
53    }
54}
55
56fn values_query(rows: Vec<Vec<ast::Expr>>) -> ast::Query {
57    ast::Query {
58        with: None,
59        body: Box::new(ast::SetExpr::Values(ast::Values {
60            explicit_row: false,
61            rows,
62        })),
63        order_by: vec![],
64        limit: None,
65        offset: None,
66        fetch: None,
67        locks: vec![],
68        limit_by: vec![],
69        for_clause: None,
70    }
71}
72
73fn table_with_joins(relation: ast::TableFactor, joins: Vec<ast::Join>) -> ast::TableWithJoins {
74    ast::TableWithJoins { relation, joins }
75}
76
77fn ctes_from_query(query: ast::Query) -> Vec<ast::Cte> {
78    query.with.map(|with| with.cte_tables).unwrap_or_default()
79}
80
81fn all() -> Vec<ast::SelectItem> {
82    vec![ast::SelectItem::Wildcard(
83        ast::WildcardAdditionalOptions::default(),
84    )]
85}
86
87fn select_from_query(query: ast::Query) -> ast::Select {
88    match query.body.as_ref() {
89        ast::SetExpr::Select(select) => select.as_ref().clone(),
90        _ => panic!("Non select query"), // It is okay to panic as this should not happen in our context and is a private function
91    }
92}
93
94impl<'a, T: RelationToQueryTranslator> Visitor<'a, ast::Query> for FromRelationVisitor<T> {
95    fn table(&self, table: &'a Table) -> ast::Query {
96        self.translator.query(
97            vec![],
98            vec![ast::SelectItem::Wildcard(
99                ast::WildcardAdditionalOptions::default(),
100            )],
101            table_with_joins(
102                self.translator.table_factor(&table.clone().into(), None),
103                vec![],
104            ),
105            None,
106            ast::GroupByExpr::Expressions(vec![]),
107            vec![],
108            None,
109            None,
110        )
111    }
112
113    fn map(&self, map: &'a Map, input: ast::Query) -> ast::Query {
114        // Pull the existing CTEs
115        let mut input_ctes = ctes_from_query(input);
116        // Add input query to CTEs
117        input_ctes.push(
118            self.translator.cte(
119                self.translator.identifier(&(map.name().into()))[0].clone(),
120                map.schema()
121                    .iter()
122                    .map(|field| self.translator.identifier(&(field.name().into()))[0].clone())
123                    .collect(),
124                self.translator.query(
125                    vec![],
126                    map.projection
127                        .clone()
128                        .into_iter()
129                        .zip(map.schema.clone())
130                        .map(|(expr, field)| ast::SelectItem::ExprWithAlias {
131                            expr: self.translator.expr(&expr),
132                            alias: self.translator.identifier(&(field.name().into()))[0].clone(),
133                        })
134                        .collect(),
135                    table_with_joins(
136                        self.translator
137                            .table_factor(map.input.as_ref().into(), None),
138                        vec![],
139                    ),
140                    map.filter.as_ref().map(|expr| self.translator.expr(expr)),
141                    ast::GroupByExpr::Expressions(vec![]),
142                    map.order_by
143                        .iter()
144                        .map(|OrderBy { expr, asc }| ast::OrderByExpr {
145                            expr: self.translator.expr(expr),
146                            asc: Some(*asc),
147                            nulls_first: None,
148                        })
149                        .collect(),
150                    map.limit.map(|limit| {
151                        ast::Expr::Value(ast::Value::Number(limit.to_string(), false))
152                    }),
153                    map.offset.map(|offset| ast::Offset {
154                        value: ast::Expr::Value(ast::Value::Number(offset.to_string(), false)),
155                        rows: ast::OffsetRows::None,
156                    }),
157                ),
158            ),
159        );
160        self.translator.query(
161            input_ctes,
162            all(),
163            table_with_joins(
164                self.translator.table_factor(&map.clone().into(), None),
165                vec![],
166            ),
167            None,
168            ast::GroupByExpr::Expressions(vec![]),
169            vec![],
170            map.limit
171                .map(|limit| ast::Expr::Value(ast::Value::Number(limit.to_string(), false))),
172            None,
173        )
174    }
175
176    fn reduce(&self, reduce: &'a Reduce, input: ast::Query) -> ast::Query {
177        // Pull the existing CTEs
178        let mut input_ctes = ctes_from_query(input);
179        // Add input query to CTEs
180        input_ctes.push(
181            self.translator.cte(
182                self.translator.identifier(&(reduce.name().into()))[0].clone(),
183                reduce
184                    .schema()
185                    .iter()
186                    .map(|field| self.translator.identifier(&(field.name().into()))[0].clone())
187                    .collect(),
188                self.translator.query(
189                    vec![],
190                    reduce
191                        .aggregate
192                        .clone()
193                        .into_iter()
194                        .zip(reduce.schema.clone())
195                        .map(|(aggregate, field)| ast::SelectItem::ExprWithAlias {
196                            expr: self.translator.expr(aggregate.deref()),
197                            alias: self.translator.identifier(&(field.name().into()))[0].clone(),
198                        })
199                        .collect(),
200                    table_with_joins(
201                        self.translator
202                            .table_factor(reduce.input.as_ref().into(), None),
203                        vec![],
204                    ),
205                    None,
206                    ast::GroupByExpr::Expressions(
207                        reduce
208                            .group_by
209                            .iter()
210                            .map(|col| self.translator.expr(&Expr::Column(col.clone())))
211                            .collect(),
212                    ),
213                    vec![],
214                    None,
215                    None,
216                ),
217            ),
218        );
219        self.translator.query(
220            input_ctes,
221            all(),
222            table_with_joins(
223                self.translator.table_factor(&reduce.clone().into(), None),
224                vec![],
225            ),
226            None,
227            ast::GroupByExpr::Expressions(vec![]),
228            vec![],
229            None,
230            None,
231        )
232    }
233
234    fn join(&self, join: &'a Join, left: ast::Query, right: ast::Query) -> ast::Query {
235        // Pull the existing CTEs
236        let mut exist: HashSet<ast::Cte> = HashSet::new();
237        let mut input_ctes: Vec<ast::Cte> = vec![];
238        ctes_from_query(left).into_iter().for_each(|cte| {
239            if exist.insert(cte.clone()) {
240                input_ctes.push(cte)
241            }
242        });
243        ctes_from_query(right).into_iter().for_each(|cte| {
244            if exist.insert(cte.clone()) {
245                input_ctes.push(cte)
246            }
247        });
248
249        // Add input query to CTEs
250        input_ctes.push(
251            self.translator.cte(
252                self.translator.identifier(&(join.name().into()))[0].clone(),
253                join.schema()
254                    .iter()
255                    .map(|field| self.translator.identifier(&(field.name().into()))[0].clone())
256                    .collect(),
257                self.translator.query(
258                    vec![],
259                    self.translator.join_projection(join), //self.translator.join_projection(),
260                    table_with_joins(
261                        self.translator
262                            .table_factor(join.left.as_ref().into(), Some(Join::left_name())),
263                        vec![ast::Join {
264                            relation: self
265                                .translator
266                                .table_factor(join.right.as_ref().into(), Some(Join::right_name())),
267                            join_operator: self.translator.join_operator(&join.operator),
268                        }],
269                    ),
270                    None,
271                    ast::GroupByExpr::Expressions(vec![]),
272                    vec![],
273                    None,
274                    None,
275                ),
276            ),
277        );
278        self.translator.query(
279            input_ctes,
280            all(),
281            table_with_joins(
282                self.translator.table_factor(&join.clone().into(), None),
283                vec![],
284            ),
285            None,
286            ast::GroupByExpr::Expressions(vec![]),
287            vec![],
288            None,
289            None,
290        )
291    }
292
293    fn set(&self, set: &'a Set, left: ast::Query, right: ast::Query) -> ast::Query {
294        // Pull the existing CTEs
295        let mut exist: HashSet<ast::Cte> = HashSet::new();
296        let mut input_ctes: Vec<ast::Cte> = vec![];
297        ctes_from_query(left.clone()).into_iter().for_each(|cte| {
298            if exist.insert(cte.clone()) {
299                input_ctes.push(cte)
300            }
301        });
302        ctes_from_query(right.clone()).into_iter().for_each(|cte| {
303            if exist.insert(cte.clone()) {
304                input_ctes.push(cte)
305            }
306        });
307        // Add input query to CTEs
308        input_ctes.push(
309            self.translator.cte(
310                set.name().into(),
311                set.schema()
312                    .iter()
313                    .map(|field| self.translator.identifier(&(field.name().into()))[0].clone())
314                    .collect(),
315                self.translator.set_operation(
316                    vec![],
317                    set.operator.clone().into(),
318                    set.quantifier.clone().into(),
319                    select_from_query(left),
320                    select_from_query(right),
321                ),
322            ),
323        );
324        self.translator.query(
325            input_ctes,
326            all(),
327            table_with_joins(
328                self.translator.table_factor(&set.clone().into(), None),
329                vec![],
330            ),
331            None,
332            ast::GroupByExpr::Expressions(vec![]),
333            vec![],
334            None,
335            None,
336        )
337    }
338
339    fn values(&self, values: &'a Values) -> ast::Query {
340        let rows = values
341            .values
342            .iter()
343            .cloned()
344            .map(|v| vec![ast::Expr::from(&Expr::Value(v))])
345            .collect();
346
347        let value_name = self.translator.identifier(&(values.name.as_str().into()))[0].clone();
348        let from = ast::TableWithJoins {
349            relation: ast::TableFactor::Derived {
350                lateral: false,
351                subquery: Box::new(values_query(rows)),
352                alias: Some(ast::TableAlias {
353                    name: value_name.clone(),
354                    columns: vec![value_name],
355                }),
356            },
357            joins: vec![],
358        };
359        let cte_query = self.translator.query(
360            vec![],
361            all(),
362            from,
363            None,
364            ast::GroupByExpr::Expressions(vec![]),
365            vec![],
366            None,
367            None,
368        );
369        let value_name = self.translator.identifier(&(values.name().into()))[0].clone();
370        let input_ctes = vec![self
371            .translator
372            .cte(value_name.clone(), vec![value_name], cte_query)];
373        self.translator.query(
374            input_ctes,
375            all(),
376            table_with_joins(
377                self.translator.table_factor(&values.clone().into(), None),
378                vec![],
379            ),
380            None,
381            ast::GroupByExpr::Expressions(vec![]),
382            vec![],
383            None,
384            None,
385        )
386    }
387}
388
389/// Based on the FromRelationVisitor implement the From trait
390impl From<&Relation> for ast::Query {
391    fn from(value: &Relation) -> Self {
392        let dialect_translator = PostgreSqlTranslator;
393        value.accept(FromRelationVisitor::new(dialect_translator))
394    }
395}
396
397impl Table {
398    /// Build the CREATE TABLE statement
399    pub fn create<T: RelationToQueryTranslator>(&self, translator: T) -> ast::Statement {
400        translator.create(self)
401    }
402
403    pub fn insert<T: RelationToQueryTranslator>(
404        &self,
405        prefix: &str,
406        translator: T,
407    ) -> ast::Statement {
408        translator.insert(prefix, self)
409    }
410}
411
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        builder::{Ready, With},
        data_type::{DataType, Value},
        display::Dot,
        namer,
        relation::schema::Schema,
    };
    use std::sync::Arc;

    /// Builds `join_2 = join x map_2`, where `map_2` reads `join = table x map_1` and `map_1`
    /// reads `table`. Both `table` and `join` are therefore shared ancestors.
    fn build_complex_relation() -> Arc<Relation> {
        namer::reset();
        let schema: Schema = vec![
            ("a", DataType::float()),
            ("b", DataType::float_interval(-2., 2.)),
            ("c", DataType::float()),
            ("d", DataType::float_interval(0., 1.)),
        ]
        .into_iter()
        .collect();
        let table: Arc<Relation> = Arc::new(
            Relation::table()
                .name("table")
                .schema(schema)
                .size(100)
                .build(),
        );
        let map: Arc<Relation> = Arc::new(
            Relation::map()
                .name("map_1")
                .with(Expr::exp(Expr::col("a")))
                .input(table.clone())
                .with(Expr::col("b") + Expr::col("d"))
                .build(),
        );
        let join: Arc<Relation> = Arc::new(
            Relation::join()
                .name("join")
                .cross()
                .left(table.clone())
                .right(map.clone())
                .build(),
        );
        let map_2: Arc<Relation> = Arc::new(
            Relation::map()
                .name("map_2")
                .with(Expr::exp(Expr::col(join[4].name())))
                .input(join.clone())
                .with(Expr::col(join[0].name()) + Expr::col(join[1].name()))
                .limit(100)
                .offset(20)
                .build(),
        );
        Arc::new(
            Relation::join()
                .name("join_2")
                .cross()
                .left(join.clone())
                .right(map_2.clone())
                .build(),
        )
    }

    #[test]
    fn test_from_table_relation() {
        let schema: Schema = vec![
            ("a", DataType::float()),
            ("b", DataType::float_interval(-2., 2.)),
            ("c", DataType::float()),
            ("d", DataType::float_interval(0., 1.)),
        ]
        .into_iter()
        .collect();
        let table: Relation = Relation::table().name("Name").schema(schema).build();
        let query = ast::Query::from(&table);
        println!("query = {query}");
    }

    #[test]
    fn test_from_complex_relation() {
        let relation = build_complex_relation();
        let relation = relation.as_ref();
        relation.display_dot().unwrap();
        let query = ast::Query::from(relation);
        println!("query = {query}");
    }

    #[test]
    fn test_display_join() {
        namer::reset();
        let schema: Schema = vec![("b", DataType::float_interval(-2., 2.))]
            .into_iter()
            .collect();
        let left: Relation = Relation::table()
            .name("left")
            .schema(schema.clone())
            .size(1000)
            .build();
        let right: Relation = Relation::table()
            .name("right")
            .schema(schema)
            .size(1000)
            .build();

        let join: Relation = Relation::join()
            .name("join")
            .left_outer(Expr::val(true))
            .on_eq("b", "b")
            .left(left)
            .right(right)
            .build();

        let query = ast::Query::from(&join);
        println!("query = {query}");
    }

    #[ignore] // Too fragile
    #[test]
    fn test_display_values() {
        namer::reset();
        let values: Relation = Relation::values()
            .name("my_values")
            .values([Value::from(3.), Value::from(4)])
            .build();

        let query = ast::Query::from(&values);
        assert_eq!(
            query.to_string(),
            r#"WITH "my_values" ("my_values") AS (SELECT * FROM (VALUES (3), (4)) AS "my_values" ("my_values")) SELECT * FROM "my_values""#.to_string()
        );

        let schema: Schema = vec![("b", DataType::float_interval(-2., 2.))]
            .into_iter()
            .collect();
        let table: Relation = Relation::table()
            .name("table")
            .schema(schema)
            .size(1000)
            .build();

        let join: Relation = Relation::join().left(values).right(table).cross().build();
        let query = ast::Query::from(&join);
        assert_eq!(
            query.to_string(),
            r#"WITH "my_values" ("my_values") AS (SELECT * FROM (VALUES (3), (4)) AS "my_values" ("my_values")), "join_zs1x" ("field_gu2a", "field_b8x4") AS (SELECT * FROM "my_values" AS "_LEFT_" CROSS JOIN "table" AS "_RIGHT_") SELECT * FROM "join_zs1x""#.to_string()
        );
    }
}