1use {
2 super::expr::PlanExpr,
3 crate::{
4 ast::{
5 Expr, Join, JoinConstraint, JoinOperator, Query, Select, SelectItem, SetExpr,
6 Statement, TableFactor, TableWithJoins,
7 },
8 data::Schema,
9 result::Result,
10 store::Store,
11 },
12 async_recursion::async_recursion,
13 futures::stream::{self, StreamExt, TryStreamExt},
14 std::collections::HashMap,
15};
16
17pub async fn fetch_schema_map<T: Store>(
18 storage: &T,
19 statement: &Statement,
20) -> Result<HashMap<String, Schema>> {
21 match statement {
22 Statement::Query(query) => scan_query(storage, query).await,
23 Statement::Insert {
24 table_name, source, ..
25 } => {
26 let table_schema = storage
27 .fetch_schema(table_name)
28 .await?
29 .map(|schema| HashMap::from([(table_name.to_owned(), schema)]))
30 .unwrap_or_else(HashMap::new);
31 let source_schema_list = scan_query(storage, source).await?;
32 let schema_list = table_schema.into_iter().chain(source_schema_list).collect();
33
34 Ok(schema_list)
35 }
36 Statement::CreateTable { name, source, .. } => {
37 let table_schema = storage
38 .fetch_schema(name)
39 .await?
40 .map(|schema| HashMap::from([(name.to_owned(), schema)]))
41 .unwrap_or_else(HashMap::new);
42 let source_schema_list = match source {
43 Some(source) => scan_query(storage, source).await?,
44 None => HashMap::new(),
45 };
46 let schema_list = table_schema.into_iter().chain(source_schema_list).collect();
47
48 Ok(schema_list)
49 }
50 Statement::DropTable { names, .. } => {
51 stream::iter(names)
52 .filter_map(|table_name| async {
53 storage
54 .fetch_schema(table_name)
55 .await
56 .map(|schema| Some((table_name.clone(), schema?)))
57 .transpose()
58 })
59 .try_collect()
60 .await
61 }
62 _ => Ok(HashMap::new()),
63 }
64}
65
66async fn scan_query<T: Store>(storage: &T, query: &Query) -> Result<HashMap<String, Schema>> {
67 let Query {
68 body,
69 limit,
70 offset,
71 ..
72 } = query;
73
74 let schema_list = match body {
75 SetExpr::Select(select) => scan_select(storage, select).await?,
76 SetExpr::Values(_) => HashMap::new(),
77 };
78
79 let schema_list = match (limit, offset) {
80 (Some(limit), Some(offset)) => schema_list
81 .into_iter()
82 .chain(scan_expr(storage, limit).await?)
83 .chain(scan_expr(storage, offset).await?)
84 .collect(),
85 (Some(expr), None) | (None, Some(expr)) => schema_list
86 .into_iter()
87 .chain(scan_expr(storage, expr).await?)
88 .collect(),
89 (None, None) => schema_list,
90 };
91
92 Ok(schema_list)
93}
94
95async fn scan_select<T: Store>(storage: &T, select: &Select) -> Result<HashMap<String, Schema>> {
96 let Select {
97 distinct: _,
98 projection,
99 from,
100 selection,
101 group_by,
102 having,
103 } = select;
104
105 let projection = stream::iter(projection)
106 .then(|select_item| async move {
107 match select_item {
108 SelectItem::Expr { expr, .. } => scan_expr(storage, expr).await,
109 SelectItem::QualifiedWildcard(_) | SelectItem::Wildcard => Ok(HashMap::new()),
110 }
111 })
112 .try_collect::<Vec<HashMap<String, Schema>>>()
113 .await?
114 .into_iter()
115 .flatten();
116
117 let from = scan_table_with_joins(storage, from).await?;
118
119 let exprs = selection.iter().chain(group_by.iter()).chain(having.iter());
120
121 Ok(stream::iter(exprs)
122 .then(|expr| scan_expr(storage, expr))
123 .try_collect::<Vec<HashMap<String, Schema>>>()
124 .await?
125 .into_iter()
126 .flatten()
127 .chain(projection)
128 .chain(from)
129 .collect())
130}
131
132async fn scan_table_with_joins<T: Store>(
133 storage: &T,
134 table_with_joins: &TableWithJoins,
135) -> Result<HashMap<String, Schema>> {
136 let TableWithJoins { relation, joins } = table_with_joins;
137 let schema_list = scan_table_factor(storage, relation).await?;
138
139 Ok(stream::iter(joins)
140 .then(|join| scan_join(storage, join))
141 .try_collect::<Vec<HashMap<String, Schema>>>()
142 .await?
143 .into_iter()
144 .flatten()
145 .chain(schema_list)
146 .collect())
147}
148
149async fn scan_join<T: Store>(storage: &T, join: &Join) -> Result<HashMap<String, Schema>> {
150 let Join {
151 relation,
152 join_operator,
153 ..
154 } = join;
155
156 let schema_list = scan_table_factor(storage, relation).await?;
157 let schema_list = match join_operator {
158 JoinOperator::Inner(JoinConstraint::On(expr))
159 | JoinOperator::LeftOuter(JoinConstraint::On(expr)) => scan_expr(storage, expr)
160 .await?
161 .into_iter()
162 .chain(schema_list)
163 .collect(),
164 JoinOperator::Inner(JoinConstraint::None)
165 | JoinOperator::LeftOuter(JoinConstraint::None) => schema_list,
166 };
167
168 Ok(schema_list)
169}
170
171#[async_recursion]
172async fn scan_table_factor<T>(
173 storage: &T,
174 table_factor: &TableFactor,
175) -> Result<HashMap<String, Schema>>
176where
177 T: Store,
178{
179 match table_factor {
180 TableFactor::Table { name, .. } => {
181 let schema = storage.fetch_schema(name).await?;
182 let schema_list: HashMap<String, Schema> = schema.map_or_else(HashMap::new, |schema| {
183 HashMap::from([(name.to_owned(), schema)])
184 });
185
186 Ok(schema_list)
187 }
188 TableFactor::Derived { subquery, .. } => scan_query(storage, subquery).await,
189 TableFactor::Series { .. } | TableFactor::Dictionary { .. } => Ok(HashMap::new()),
190 }
191}
192
193#[async_recursion]
194async fn scan_expr<T>(storage: &T, expr: &Expr) -> Result<HashMap<String, Schema>>
195where
196 T: Store,
197{
198 let schema_list = match expr.into() {
199 PlanExpr::None | PlanExpr::Identifier(_) | PlanExpr::CompoundIdentifier { .. } => {
200 HashMap::new()
201 }
202 PlanExpr::Expr(expr) => scan_expr(storage, expr).await?,
203 PlanExpr::TwoExprs(expr, expr2) => scan_expr(storage, expr)
204 .await?
205 .into_iter()
206 .chain(scan_expr(storage, expr2).await?)
207 .collect(),
208 PlanExpr::ThreeExprs(expr, expr2, expr3) => scan_expr(storage, expr)
209 .await?
210 .into_iter()
211 .chain(scan_expr(storage, expr2).await?)
212 .chain(scan_expr(storage, expr3).await?)
213 .collect(),
214 PlanExpr::MultiExprs(exprs) => stream::iter(exprs)
215 .then(|expr| scan_expr(storage, expr))
216 .try_collect::<Vec<HashMap<String, Schema>>>()
217 .await?
218 .into_iter()
219 .flatten()
220 .collect(),
221 PlanExpr::Query(query) => scan_query(storage, query).await?,
222 PlanExpr::QueryAndExpr { query, expr } => scan_query(storage, query)
223 .await?
224 .into_iter()
225 .chain(scan_expr(storage, expr).await?)
226 .collect(),
227 };
228
229 Ok(schema_list)
230}
231
#[cfg(test)]
mod tests {
    use {
        super::fetch_schema_map,
        crate::{
            mock::{MockStorage, run},
            parse_sql::parse,
            result::Result,
            translate::translate,
        },
        futures::executor::block_on,
        utils::Vector,
    };

    /// Parses and translates the first statement in `sql`, runs
    /// `fetch_schema_map` on it, and returns the resolved table names sorted.
    fn plan(storage: &MockStorage, sql: &str) -> Result<Vec<String>> {
        let mut statements = parse(sql).expect(sql).into_iter();
        let parsed = statements.next().unwrap();
        let statement = translate(&parsed).unwrap();

        let schema_map = block_on(fetch_schema_map(storage, &statement))?;
        let table_names: Vector<String> = schema_map.into_keys().collect();

        Ok(table_names.sort().into())
    }

    /// Asserts that planning `sql` resolves exactly the `expected` table names.
    fn assert_tables(storage: &MockStorage, sql: &str, expected: &[&str]) {
        let resolved = plan(storage, sql).unwrap();

        assert_eq!(resolved.as_slice(), expected, "{sql}");
    }

    #[test]
    fn basic() {
        let storage = run("
            CREATE TABLE Foo (id INTEGER);
            CREATE TABLE Bar (name TEXT);
        ");

        let cases: &[(&str, &[&str])] = &[
            ("SELECT * FROM Foo", &["Foo"]),
            ("INSERT INTO Foo VALUES (1), (2), (3);", &["Foo"]),
            ("DROP TABLE Foo, Bar;", &["Bar", "Foo"]),
            ("DELETE FROM Foo;", &[]),
        ];

        for &(sql, expected) in cases {
            assert_tables(&storage, sql, expected);
        }
    }

    #[test]
    fn expr() {
        let storage = run("
            CREATE TABLE Foo (id INTEGER);
            CREATE TABLE Bar (name TEXT);
        ");

        let cases: &[(&str, &[&str])] = &[
            (
                r#"SELECT Foo.*, * FROM Foo WHERE id = DATE "2021-01-01";"#,
                &["Foo"],
            ),
            (
                "
                SELECT * FROM Foo
                WHERE
                    Foo.id IS NULL
                    AND id IS NOT NULL
                    OR (id IS NULL)
                ",
                &["Foo"],
            ),
            ("SELECT * FROM Foo WHERE id = 1", &["Foo"]),
            ("SELECT * FROM Foo WHERE id BETWEEN 1 AND 20", &["Foo"]),
            ("SELECT * FROM Foo WHERE id IN (1, 2, 3)", &["Foo"]),
            (
                "
                SELECT * FROM Bar
                WHERE
                    EXISTS(SELECT id FROM Foo)
                    AND Bar.id = (SELECT id FROM Bar LIMIT 1);
                ",
                &["Bar", "Foo"],
            ),
            (
                "SELECT * FROM Foo WHERE Foo.id IN (SELECT 1 FROM Bar);",
                &["Bar", "Foo"],
            ),
        ];

        for &(sql, expected) in cases {
            assert_tables(&storage, sql, expected);
        }
    }

    #[test]
    fn select() {
        let storage = run("
            CREATE TABLE Foo (id INTEGER);
            CREATE TABLE Bar (
                id INTEGER,
                foo_id INTEGER
            );
            CREATE TABLE Baz (flag BOOLEAN);
        ");

        let cases: &[(&str, &[&str])] = &[
            (
                "
                SELECT foo_id, COUNT(*)
                FROM Bar
                WHERE id IS NOT NULL
                GROUP BY foo_id
                HAVING foo_id > 10;
                ",
                &["Bar"],
            ),
            ("SELECT * FROM Foo JOIN Bar ORDER BY Foo.id", &["Bar", "Foo"]),
            ("SELECT * FROM Foo LEFT OUTER JOIN Bar", &["Bar", "Foo"]),
            (
                "SELECT * FROM Foo LEFT JOIN Bar ON Bar.foo_id = Foo.id",
                &["Bar", "Foo"],
            ),
            (
                "
                SELECT * FROM Foo
                INNER JOIN Bar ON Bar.id = Foo.bar_id
                LEFT JOIN Baz ON False;
                ",
                &["Bar", "Baz", "Foo"],
            ),
            (
                "
                SELECT Bar.*, id, *
                FROM Foo
                JOIN Bar ON True
                LEFT JOIN Baz ON True
                WHERE Foo.id = 1
                LIMIT 1 OFFSET 1
                ",
                &["Bar", "Baz", "Foo"],
            ),
            ("SELECT * FROM Railway", &[]),
            ("SELECT * FROM Foo WHERE Foo.id = Lab.foo_id", &["Foo"]),
        ];

        for &(sql, expected) in cases {
            assert_tables(&storage, sql, expected);
        }
    }

    #[test]
    fn storage_err() {
        let storage = run("
            CREATE TABLE Foo (id INTEGER);
            CREATE TABLE Bar (id INTEGER);
            CREATE TABLE Baz (flag BOOLEAN);
        ");

        let failing = [
            "SELECT * FROM __Err__",
            "INSERT INTO __Err__ VALUES (1), (2)",
            "DROP TABLE __Err__",
            "SELECT * FROM Foo WHERE id = (SELECT foo_id FROM __Err__ LIMIT 1)",
            "SELECT * FROM Foo WHERE (SELECT foo_id FROM __Err__ LIMIT 1) = id",
            "SELECT * FROM Foo WHERE id BETWEEN (SELECT foo_id FROM __Err__ LIMIT 1) AND 100",
            "SELECT * FROM Foo WHERE (SELECT id FROM __Err__ LIMIT 1) BETWEEN 20 AND 50",
            "SELECT * FROM Foo WHERE id IN (1, 2, (SELECT foo_id FROM __Err__ LIMIT 1), 5)",
            "SELECT * FROM Foo WHERE id IN (SELECT * FROM __Err__)",
            "SELECT * FROM Foo LEFT JOIN Bar ON Bar.id = (SELECT id FROM __Err__ LIMIT 1)",
            "SELECT id, (SELECT id FROM __Err__ LIMIT 1) AS cc FROM Foo;",
        ];

        for sql in failing {
            assert!(plan(&storage, sql).is_err(), "{sql}");
        }
    }
}