1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
use super::intermediate_ast::{OrderBy, SetExpression, Slice, TableExpression};
use crate::{sql::SelectStatementParser, Identifier, ParseError, ParseResult, ResourceId};
use serde::{Deserialize, Serialize};
use std::{fmt, ops::Deref, str::FromStr};

/// Representation of a select statement, that is, the only type of queries allowed.
///
/// A statement is made of the query expression itself plus the ordering and
/// slicing clauses that are applied to the rows it produces.
#[derive(Serialize, Deserialize, PartialEq, Eq, Clone)]
pub struct SelectStatement {
    /// The query expression.
    pub expr: Box<SetExpression>,

    /// If non-empty, a sort order that is applied to the rows returned as result.
    pub order_by: Vec<OrderBy>,

    /// An optional slice clause, which can restrict the rows returned to a window within the
    /// set of rows as generated by `expr` and `order_by`.
    pub slice: Option<Slice>,
}

impl fmt::Debug for SelectStatement {
    /// Writes a `SelectStatement` header followed by the pretty-printed
    /// (`{:#?}`) expression, ordering and slice, one per line inside brackets.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            expr,
            order_by,
            slice,
        } = self;

        f.write_fmt(format_args!(
            "SelectStatement \n[{:#?},\n{:#?},\n{:#?}\n]",
            expr, order_by, slice
        ))
    }
}

impl SelectStatement {
    /// Collects the tables referenced by this statement, encoded as resource ids.
    ///
    /// A table expression that carries its own schema keeps that schema. One
    /// without a schema is qualified with `default_schema`, because a resource
    /// id cannot be built with an empty schema. Although the DQL endpoint would
    /// require both schemas to be equal, a mismatch is deliberately not treated
    /// as a failure here: doing so would force the caller to always know the
    /// referenced schemas beforehand.
    ///
    /// # Returns
    ///
    /// The vector with all tables referenced by the intermediate ast, one
    /// resource id per table expression of the `from` clause, in order.
    pub fn get_table_references(&self, default_schema: Identifier) -> Vec<ResourceId> {
        // `Query` is the only kind of set expression, so the pattern is
        // irrefutable; only the `from` clause matters for table references.
        let SetExpression::Query { from, .. } = &*self.expr;

        convert_table_expr_to_resource_id_vector(&from[..], default_schema)
    }
}

impl FromStr for SelectStatement {
    type Err = crate::ParseError;

    fn from_str(query: &str) -> ParseResult<Self> {
        SelectStatementParser::new()
            .parse(query)
            .map_err(|e| ParseError::QueryParseError(e.to_string()))
    }
}

/// Maps each table expression to a [`ResourceId`], preserving input order.
///
/// The schema written in the table expression is used when present;
/// otherwise the name of `default_schema` is used in its place.
///
/// # Panics
///
/// Panics if `ResourceId::try_new` rejects a schema/table pair.
fn convert_table_expr_to_resource_id_vector(
    table_expressions: &[Box<TableExpression>],
    default_schema: Identifier,
) -> Vec<ResourceId> {
    table_expressions
        .iter()
        .map(|boxed_expr| match &**boxed_expr {
            TableExpression::Named { table, schema } => {
                // Prefer the explicitly written schema over the fallback.
                let schema_name = match schema {
                    Some(explicit) => explicit.as_str(),
                    None => default_schema.name(),
                };

                ResourceId::try_new(schema_name, table.as_str()).unwrap()
            }
        })
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::SelectStatementParser;

    /// Parses `query` and returns the tables it references, qualifying
    /// schema-less tables with `default_schema`.
    fn referenced_tables(query: &str, default_schema: &str) -> Vec<ResourceId> {
        let statement = SelectStatementParser::new().parse(query).unwrap();
        let fallback_schema = Identifier::try_new(default_schema).unwrap();

        statement.get_table_references(fallback_schema)
    }

    #[test]
    fn we_can_get_the_correct_table_references_using_a_default_schema() {
        let found = referenced_tables("SELECT A FROM TAB WHERE C = 3", "ETH");

        // note: the parsed table is always lower case
        assert_eq!(found, [ResourceId::try_new("eth", "tab").unwrap()]);
    }

    #[test]
    fn we_can_get_the_correct_table_references_in_case_the_default_schema_equals_the_original_schema(
    ) {
        let found = referenced_tables("SELECT A FROM SCHEMA.TAB WHERE C = 3", "SCHEMA");

        assert_eq!(found, [ResourceId::try_new("schema", "tab").unwrap()]);
    }

    #[test]
    fn we_can_get_the_correct_table_references_in_case_the_default_schema_differs_from_the_original_schema(
    ) {
        // The explicit schema in the query wins over the default one.
        let found = referenced_tables("SELECT A FROM SCHEMA.TAB WHERE C = 3", "  ETH  ");

        assert_eq!(found, [ResourceId::try_new("schema", "tab").unwrap()]);
    }
}