koron_query_parser/
query_metadata.rs

1use std::fmt::{self, Display};
2
3use serde::{Deserialize, Serialize};
4use sqlparser::{ast, dialect::GenericDialect, parser::Parser};
5use utoipa::{IntoParams, ToSchema};
6
7use crate::{
8    aggregation::{Aggregation, KoronFunction},
9    destructured_query::DestructuredQuery,
10    error::ParseError,
11    filter::{Filter, FilterExtractor},
12    support::case_fold_identifier,
13    table::{TabIdent, TableIdentWithAlias},
14    unsupported,
15};
16
17/// QueryMetadata extracted from the query.
18#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default, ToSchema, IntoParams)]
19pub struct QueryMetadata {
20    /// Aggregation performed.
21    pub aggregation: Aggregation,
22    /// Table subject to query.
23    pub table: TabIdent,
24    /// Filter applied.
25    pub filter: Option<Filter>,
26    /// Data Extraction Query in SQL
27    pub data_extraction_query: String,
28    /// Data Aggregation Query in SQL
29    pub data_aggregation_query: Option<String>,
30}
31
32impl QueryMetadata {
33    /// Generates `QueryMetadata` from a SQL query using [`crate::config::Config`].
34    pub fn parse(
35        sql_query: &str,
36        quote_style: Option<char>, /* e.g. "'" for PostgreSQL, "`" for MySQL */
37    ) -> Result<Self, ParseError> {
38        //extract all the statement from the sql query.
39        let statements = Parser::parse_sql(&GenericDialect {}, sql_query)?;
40        //check if the sql query is: single, and is a select.
41        let statement = Self::extract_select_query(&statements)?;
42        //check and extract query clauses from statement
43        let DestructuredQuery {
44            projection,
45            from,
46            selection,
47        } = DestructuredQuery::destructure(statement)?;
48        //check and extract table informations from FROM clause
49        let TableIdentWithAlias(table_name, table_alias) = TableIdentWithAlias::extract(from)?;
50        //extract table name to be used in the SELECT clause
51        let from_clause_identifier = table_alias.as_deref().map_or_else(
52            || FromClauseIdentifier::Base(&table_name),
53            |x| FromClauseIdentifier::Alias { alias: x },
54        );
55
56        //extract analytic functions
57        let aggregation = Aggregation::extract(from_clause_identifier, projection)?;
58
59        let filter = selection
60            .map(|selection| FilterExtractor::new(from_clause_identifier).extract(selection))
61            .transpose()?;
62
63        let data_extraction_query =
64            Self::create_data_extraction_query(&aggregation, &table_name, &filter, quote_style);
65        let data_aggregation_query = match aggregation.function {
66            KoronFunction::Median => None,
67            _ => Some(Self::create_data_aggregation_query(
68                projection, from, selection,
69            )?),
70        };
71        Ok(Self {
72            aggregation,
73            table: table_name,
74            filter,
75            data_extraction_query,
76            data_aggregation_query,
77        })
78    }
79
80    fn extract_select_query(statements: &[ast::Statement]) -> Result<&ast::Query, ParseError> {
81        if let [ast::Statement::Query(query)] = statements {
82            Ok(query)
83        } else {
84            Err(unsupported!(
85                "statements different from single SELECT statement.".to_string()
86            ))
87        }
88    }
89
90    #[must_use]
91    pub fn create_data_extraction_query(
92        aggregation: &Aggregation,
93        table: &TabIdent,
94        filter: &Option<Filter>,
95        quote_style: Option<char>, // e.g. "'" for PostgreSQL, "`" for MySQL
96    ) -> String {
97        let mut projection = Vec::default();
98        let aggregation_column_ident =
99            ast::SelectItem::UnnamedExpr(ast::Expr::Identifier(ast::Ident {
100                value: aggregation.column.clone(),
101                quote_style,
102            }));
103        projection.push(aggregation_column_ident);
104        if let Some(filter) = &filter {
105            if filter.column != aggregation.column {
106                let filter_column_ident =
107                    ast::SelectItem::UnnamedExpr(ast::Expr::Identifier(ast::Ident {
108                        value: filter.column.clone(),
109                        quote_style,
110                    }));
111                projection.push(filter_column_ident);
112            }
113        }
114        let from = vec![ast::TableWithJoins {
115            relation: ast::TableFactor::Table {
116                name: table.into_object_name(quote_style),
117                alias: None,
118                args: None,
119                with_hints: Vec::default(),
120                version: None,
121                partitions: Vec::default(),
122            },
123            joins: Vec::default(),
124        }];
125        let select_expr = ast::Select {
126            distinct: None,
127            top: None,
128            projection,
129            into: None,
130            from,
131            lateral_views: Vec::default(),
132            selection: None,
133            group_by: ast::GroupByExpr::Expressions(Vec::default()),
134            cluster_by: Vec::default(),
135            distribute_by: Vec::default(),
136            sort_by: Vec::default(),
137            having: None,
138            qualify: None,
139            named_window: Vec::default(),
140        };
141        let query_body = ast::SetExpr::Select(Box::new(select_expr));
142        let query = ast::Query {
143            with: None,
144            body: Box::new(query_body),
145            order_by: Vec::default(),
146            limit: None,
147            offset: None,
148            fetch: None,
149            locks: Vec::default(),
150            limit_by: Vec::default(),
151            for_clause: None,
152        };
153        let select_statement = ast::Statement::Query(Box::new(query));
154        select_statement.to_string()
155    }
156
157    fn create_data_aggregation_query(
158        projection: &[ast::SelectItem],
159        from: &[ast::TableWithJoins],
160        selection: Option<&ast::Expr>,
161    ) -> Result<String, ParseError> {
162        let projection = match projection {
163            [ast::SelectItem::UnnamedExpr(expr)] => {
164                vec![ast::SelectItem::UnnamedExpr(ast::Expr::Cast {
165                    expr: Box::new(expr.clone()),
166                    data_type: ast::DataType::Text,
167                    format: None,
168                })]
169            }
170            [ast::SelectItem::ExprWithAlias { expr, alias }] => {
171                vec![ast::SelectItem::ExprWithAlias {
172                    expr: ast::Expr::Cast {
173                        expr: Box::new(expr.clone()),
174                        data_type: ast::DataType::Text,
175                        format: None,
176                    },
177                    alias: alias.clone(),
178                }]
179            }
180            _ => {
181                return Err(unsupported!("the SELECT clause must contain exactly one aggregation / analytic function. Nothing else is accepted.".to_string()));
182            }
183        };
184        let select_expr = ast::Select {
185            distinct: None,
186            top: None,
187            projection,
188            into: None,
189            from: from.to_vec(),
190            lateral_views: Vec::default(),
191            selection: selection.cloned(),
192            group_by: ast::GroupByExpr::Expressions(Vec::default()),
193            cluster_by: Vec::default(),
194            distribute_by: Vec::default(),
195            sort_by: Vec::default(),
196            having: None,
197            qualify: None,
198            named_window: Vec::default(),
199        };
200        let query_body = ast::SetExpr::Select(Box::new(select_expr));
201        let query = ast::Query {
202            with: None,
203            body: Box::new(query_body),
204            order_by: Vec::default(),
205            limit: None,
206            offset: None,
207            fetch: None,
208            locks: Vec::default(),
209            limit_by: Vec::default(),
210            for_clause: None,
211        };
212        let select_statement = ast::Statement::Query(Box::new(query));
213        Ok(select_statement.to_string())
214    }
215}
216
217#[derive(Clone, Copy)]
218pub(crate) enum FromClauseIdentifier<'a> {
219    Base(&'a TabIdent),
220    Alias { alias: &'a str },
221}
222
223impl FromClauseIdentifier<'_> {
224    pub fn matches(
225        self,
226        db: Option<&ast::Ident>,
227        schema: Option<&ast::Ident>,
228        table: &ast::Ident,
229    ) -> bool {
230        match self {
231            FromClauseIdentifier::Base(expected) => {
232                let db_matches = if expected.db.is_none() {
233                    true
234                } else {
235                    db.map_or(true, |db| {
236                        expected
237                            .db
238                            .as_ref()
239                            .map_or(true, |expected_db| &case_fold_identifier(db) == expected_db)
240                    })
241                };
242                let schema_matches = if expected.schema.is_none() {
243                    true
244                } else {
245                    schema.map_or(true, |schema| {
246                        expected.schema.as_ref().map_or(true, |expected_schema| {
247                            &case_fold_identifier(schema) == expected_schema
248                        })
249                    })
250                };
251                let table_matches = case_fold_identifier(table) == expected.table;
252                db_matches && schema_matches && table_matches
253            }
254            FromClauseIdentifier::Alias { alias, .. } => {
255                // An alias name is always unqualified, so it can never match a schema-qualified
256                // table name.
257                schema.is_none() && case_fold_identifier(table) == alias
258            }
259        }
260    }
261}
262
263impl Display for FromClauseIdentifier<'_> {
264    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
265        match self {
266            FromClauseIdentifier::Base(table_info) => write!(f, "{table_info}"),
267            FromClauseIdentifier::Alias { alias } => {
268                write!(f, "{alias}")
269            }
270        }
271    }
272}