1use std::fmt::{self, Display};
2
3use serde::{Deserialize, Serialize};
4use sqlparser::{ast, dialect::GenericDialect, parser::Parser};
5use utoipa::{IntoParams, ToSchema};
6
7use crate::{
8 aggregation::{Aggregation, KoronFunction},
9 destructured_query::DestructuredQuery,
10 error::ParseError,
11 filter::{Filter, FilterExtractor},
12 support::case_fold_identifier,
13 table::{TabIdent, TableIdentWithAlias},
14 unsupported,
15};
16
/// Information extracted from a single-table `SELECT` query containing one
/// aggregation / analytic function, together with the SQL queries derived from it.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default, ToSchema, IntoParams)]
pub struct QueryMetadata {
    /// Aggregation / analytic function found in the SELECT clause, with the column it
    /// is applied to.
    pub aggregation: Aggregation,
    /// Table referenced in the FROM clause.
    pub table: TabIdent,
    /// Filter extracted from the WHERE clause, if the query has one.
    pub filter: Option<Filter>,
    /// Query selecting the aggregated column (plus the filter column, when it differs)
    /// from `table`; it carries no WHERE clause.
    pub data_extraction_query: String,
    /// The original SELECT with its projection cast to TEXT; `None` when the
    /// aggregation function is the median.
    pub data_aggregation_query: Option<String>,
}
31
32impl QueryMetadata {
33 pub fn parse(
35 sql_query: &str,
36 quote_style: Option<char>, ) -> Result<Self, ParseError> {
38 let statements = Parser::parse_sql(&GenericDialect {}, sql_query)?;
40 let statement = Self::extract_select_query(&statements)?;
42 let DestructuredQuery {
44 projection,
45 from,
46 selection,
47 } = DestructuredQuery::destructure(statement)?;
48 let TableIdentWithAlias(table_name, table_alias) = TableIdentWithAlias::extract(from)?;
50 let from_clause_identifier = table_alias.as_deref().map_or_else(
52 || FromClauseIdentifier::Base(&table_name),
53 |x| FromClauseIdentifier::Alias { alias: x },
54 );
55
56 let aggregation = Aggregation::extract(from_clause_identifier, projection)?;
58
59 let filter = selection
60 .map(|selection| FilterExtractor::new(from_clause_identifier).extract(selection))
61 .transpose()?;
62
63 let data_extraction_query =
64 Self::create_data_extraction_query(&aggregation, &table_name, &filter, quote_style);
65 let data_aggregation_query = match aggregation.function {
66 KoronFunction::Median => None,
67 _ => Some(Self::create_data_aggregation_query(
68 projection, from, selection,
69 )?),
70 };
71 Ok(Self {
72 aggregation,
73 table: table_name,
74 filter,
75 data_extraction_query,
76 data_aggregation_query,
77 })
78 }
79
80 fn extract_select_query(statements: &[ast::Statement]) -> Result<&ast::Query, ParseError> {
81 if let [ast::Statement::Query(query)] = statements {
82 Ok(query)
83 } else {
84 Err(unsupported!(
85 "statements different from single SELECT statement.".to_string()
86 ))
87 }
88 }
89
90 #[must_use]
91 pub fn create_data_extraction_query(
92 aggregation: &Aggregation,
93 table: &TabIdent,
94 filter: &Option<Filter>,
95 quote_style: Option<char>, ) -> String {
97 let mut projection = Vec::default();
98 let aggregation_column_ident =
99 ast::SelectItem::UnnamedExpr(ast::Expr::Identifier(ast::Ident {
100 value: aggregation.column.clone(),
101 quote_style,
102 }));
103 projection.push(aggregation_column_ident);
104 if let Some(filter) = &filter {
105 if filter.column != aggregation.column {
106 let filter_column_ident =
107 ast::SelectItem::UnnamedExpr(ast::Expr::Identifier(ast::Ident {
108 value: filter.column.clone(),
109 quote_style,
110 }));
111 projection.push(filter_column_ident);
112 }
113 }
114 let from = vec![ast::TableWithJoins {
115 relation: ast::TableFactor::Table {
116 name: table.into_object_name(quote_style),
117 alias: None,
118 args: None,
119 with_hints: Vec::default(),
120 version: None,
121 partitions: Vec::default(),
122 },
123 joins: Vec::default(),
124 }];
125 let select_expr = ast::Select {
126 distinct: None,
127 top: None,
128 projection,
129 into: None,
130 from,
131 lateral_views: Vec::default(),
132 selection: None,
133 group_by: ast::GroupByExpr::Expressions(Vec::default()),
134 cluster_by: Vec::default(),
135 distribute_by: Vec::default(),
136 sort_by: Vec::default(),
137 having: None,
138 qualify: None,
139 named_window: Vec::default(),
140 };
141 let query_body = ast::SetExpr::Select(Box::new(select_expr));
142 let query = ast::Query {
143 with: None,
144 body: Box::new(query_body),
145 order_by: Vec::default(),
146 limit: None,
147 offset: None,
148 fetch: None,
149 locks: Vec::default(),
150 limit_by: Vec::default(),
151 for_clause: None,
152 };
153 let select_statement = ast::Statement::Query(Box::new(query));
154 select_statement.to_string()
155 }
156
157 fn create_data_aggregation_query(
158 projection: &[ast::SelectItem],
159 from: &[ast::TableWithJoins],
160 selection: Option<&ast::Expr>,
161 ) -> Result<String, ParseError> {
162 let projection = match projection {
163 [ast::SelectItem::UnnamedExpr(expr)] => {
164 vec![ast::SelectItem::UnnamedExpr(ast::Expr::Cast {
165 expr: Box::new(expr.clone()),
166 data_type: ast::DataType::Text,
167 format: None,
168 })]
169 }
170 [ast::SelectItem::ExprWithAlias { expr, alias }] => {
171 vec![ast::SelectItem::ExprWithAlias {
172 expr: ast::Expr::Cast {
173 expr: Box::new(expr.clone()),
174 data_type: ast::DataType::Text,
175 format: None,
176 },
177 alias: alias.clone(),
178 }]
179 }
180 _ => {
181 return Err(unsupported!("the SELECT clause must contain exactly one aggregation / analytic function. Nothing else is accepted.".to_string()));
182 }
183 };
184 let select_expr = ast::Select {
185 distinct: None,
186 top: None,
187 projection,
188 into: None,
189 from: from.to_vec(),
190 lateral_views: Vec::default(),
191 selection: selection.cloned(),
192 group_by: ast::GroupByExpr::Expressions(Vec::default()),
193 cluster_by: Vec::default(),
194 distribute_by: Vec::default(),
195 sort_by: Vec::default(),
196 having: None,
197 qualify: None,
198 named_window: Vec::default(),
199 };
200 let query_body = ast::SetExpr::Select(Box::new(select_expr));
201 let query = ast::Query {
202 with: None,
203 body: Box::new(query_body),
204 order_by: Vec::default(),
205 limit: None,
206 offset: None,
207 fetch: None,
208 locks: Vec::default(),
209 limit_by: Vec::default(),
210 for_clause: None,
211 };
212 let select_statement = ast::Statement::Query(Box::new(query));
213 Ok(select_statement.to_string())
214 }
215}
216
/// Identifier under which the table of the FROM clause can be referenced by
/// qualified column references.
#[derive(Clone, Copy)]
pub(crate) enum FromClauseIdentifier<'a> {
    /// No alias was declared: references use the table identifier itself.
    Base(&'a TabIdent),
    /// An alias was declared in the FROM clause: references use the alias.
    Alias { alias: &'a str },
}
222
223impl FromClauseIdentifier<'_> {
224 pub fn matches(
225 self,
226 db: Option<&ast::Ident>,
227 schema: Option<&ast::Ident>,
228 table: &ast::Ident,
229 ) -> bool {
230 match self {
231 FromClauseIdentifier::Base(expected) => {
232 let db_matches = if expected.db.is_none() {
233 true
234 } else {
235 db.map_or(true, |db| {
236 expected
237 .db
238 .as_ref()
239 .map_or(true, |expected_db| &case_fold_identifier(db) == expected_db)
240 })
241 };
242 let schema_matches = if expected.schema.is_none() {
243 true
244 } else {
245 schema.map_or(true, |schema| {
246 expected.schema.as_ref().map_or(true, |expected_schema| {
247 &case_fold_identifier(schema) == expected_schema
248 })
249 })
250 };
251 let table_matches = case_fold_identifier(table) == expected.table;
252 db_matches && schema_matches && table_matches
253 }
254 FromClauseIdentifier::Alias { alias, .. } => {
255 schema.is_none() && case_fold_identifier(table) == alias
258 }
259 }
260 }
261}
262
263impl Display for FromClauseIdentifier<'_> {
264 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
265 match self {
266 FromClauseIdentifier::Base(table_info) => write!(f, "{table_info}"),
267 FromClauseIdentifier::Alias { alias } => {
268 write!(f, "{alias}")
269 }
270 }
271 }
272}