proof_of_sql_parser/
select_statement.rs

1use super::intermediate_ast::{OrderBy, SetExpression, Slice, TableExpression};
2use crate::{sql::SelectStatementParser, Identifier, ParseError, ParseResult, ResourceId};
3use alloc::{boxed::Box, string::ToString, vec::Vec};
4use core::{fmt, str::FromStr};
5use serde::{Deserialize, Serialize};
6
/// Representation of a select statement, that is, the only type of queries allowed.
///
/// Note: the fields are declared in the order `expr`, `order_by`, `slice`. The derived
/// `Serialize`/`Deserialize` impls follow that declaration order, so reordering the fields
/// changes the serialized form.
#[derive(Serialize, Deserialize, PartialEq, Eq, Clone)]
pub struct SelectStatement {
    /// the query expression
    pub expr: Box<SetExpression>,

    /// if non-empty, a sort order that is applied to the rows returned as result
    pub order_by: Vec<OrderBy>,

    /// an optional slice clause, which can restrict the rows returned to a window within the
    /// set of rows as generated by `expr` and `order_by`.
    pub slice: Option<Slice>,
}
20
21impl fmt::Debug for SelectStatement {
22    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23        write!(
24            f,
25            "SelectStatement \n[{:#?},\n{:#?},\n{:#?}\n]",
26            self.expr, self.order_by, self.slice
27        )
28    }
29}
30
31impl SelectStatement {
32    /// This function returns the referenced tables in the provided `intermediate_ast`
33    ///
34    /// Note that we provide a `default_schema` in case the table expression
35    /// does not have any associated schema. This `default_schema` is
36    /// used to construct the `resource_id`, as we cannot have this field empty.
37    /// In case the table expression already has an associated schema,
38    /// then it's used instead of `default_schema`. Although the DQL endpoint
39    /// would require both to be equal, we have chosen to not fail here
40    /// as this would imply the caller to always know beforehand the referenced
41    /// schemas.
42    ///
43    /// Return:
44    /// - The vector with all tables referenced by the intermediate ast, encoded as resource ids.
45    #[must_use]
46    pub fn get_table_references(&self, default_schema: Identifier) -> Vec<ResourceId> {
47        let set_expression: &SetExpression = &(self.expr);
48
49        match set_expression {
50            SetExpression::Query {
51                result_exprs: _,
52                from,
53                where_expr: _,
54                group_by: _,
55            } => convert_table_expr_to_resource_id_vector(&from[..], default_schema),
56        }
57    }
58}
59
60impl FromStr for SelectStatement {
61    type Err = crate::ParseError;
62
63    fn from_str(query: &str) -> ParseResult<Self> {
64        SelectStatementParser::new()
65            .parse(query)
66            .map_err(|e| ParseError::QueryParseError {
67                error: e.to_string(),
68            })
69    }
70}
71
72/// # Panics
73///
74/// This function will panic in the following cases:
75/// - If `ResourceId::try_new` fails to create a valid `ResourceId`,
76///   the `.unwrap()` call will cause a panic.
77fn convert_table_expr_to_resource_id_vector(
78    table_expressions: &[Box<TableExpression>],
79    default_schema: Identifier,
80) -> Vec<ResourceId> {
81    let mut tables = Vec::new();
82
83    for table_expression in table_expressions {
84        let table_ref: &TableExpression = table_expression;
85
86        match table_ref {
87            TableExpression::Named { table, schema } => {
88                let schema = schema.as_ref().map_or_else(
89                    || default_schema.name(),
90                    super::identifier::Identifier::as_str,
91                );
92
93                tables.push(ResourceId::try_new(schema, table.as_str()).unwrap());
94            }
95        }
96    }
97
98    tables
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::SelectStatementParser;

    /// Parses `query`, then resolves its table references against `default_schema`.
    fn referenced_tables(query: &str, default_schema: &str) -> Vec<ResourceId> {
        let statement = SelectStatementParser::new().parse(query).unwrap();
        let schema = Identifier::try_new(default_schema).unwrap();
        statement.get_table_references(schema)
    }

    #[test]
    fn we_can_get_the_correct_table_references_using_a_default_schema() {
        let tables = referenced_tables("SELECT A FROM TAB WHERE C = 3", "ETH");

        // The parsed table name comes back lower case, so the expectation is lower case too.
        assert_eq!(tables, [ResourceId::try_new("eth", "tab").unwrap()]);
    }

    #[test]
    fn we_can_get_the_correct_table_references_in_case_the_default_schema_equals_the_original_schema(
    ) {
        let tables = referenced_tables("SELECT A FROM SCHEMA.TAB WHERE C = 3", "SCHEMA");

        assert_eq!(tables, [ResourceId::try_new("schema", "tab").unwrap()]);
    }

    #[test]
    fn we_can_get_the_correct_table_references_in_case_the_default_schema_differs_from_the_original_schema(
    ) {
        // The explicit `SCHEMA` qualifier must win over the (differing) default.
        let tables = referenced_tables("SELECT A FROM SCHEMA.TAB WHERE C = 3", "  ETH  ");

        assert_eq!(tables, [ResourceId::try_new("schema", "tab").unwrap()]);
    }
}