gluesql_core/plan/
schema.rs

1use {
2    super::expr::PlanExpr,
3    crate::{
4        ast::{
5            Expr, Join, JoinConstraint, JoinOperator, Query, Select, SelectItem, SetExpr,
6            Statement, TableFactor, TableWithJoins,
7        },
8        data::Schema,
9        result::Result,
10        store::Store,
11    },
12    async_recursion::async_recursion,
13    futures::stream::{self, StreamExt, TryStreamExt},
14    std::collections::HashMap,
15};
16
17pub async fn fetch_schema_map<T: Store>(
18    storage: &T,
19    statement: &Statement,
20) -> Result<HashMap<String, Schema>> {
21    match statement {
22        Statement::Query(query) => scan_query(storage, query).await,
23        Statement::Insert {
24            table_name, source, ..
25        } => {
26            let table_schema = storage
27                .fetch_schema(table_name)
28                .await?
29                .map(|schema| HashMap::from([(table_name.to_owned(), schema)]))
30                .unwrap_or_else(HashMap::new);
31            let source_schema_list = scan_query(storage, source).await?;
32            let schema_list = table_schema.into_iter().chain(source_schema_list).collect();
33
34            Ok(schema_list)
35        }
36        Statement::CreateTable { name, source, .. } => {
37            let table_schema = storage
38                .fetch_schema(name)
39                .await?
40                .map(|schema| HashMap::from([(name.to_owned(), schema)]))
41                .unwrap_or_else(HashMap::new);
42            let source_schema_list = match source {
43                Some(source) => scan_query(storage, source).await?,
44                None => HashMap::new(),
45            };
46            let schema_list = table_schema.into_iter().chain(source_schema_list).collect();
47
48            Ok(schema_list)
49        }
50        Statement::DropTable { names, .. } => {
51            stream::iter(names)
52                .filter_map(|table_name| async {
53                    storage
54                        .fetch_schema(table_name)
55                        .await
56                        .map(|schema| Some((table_name.clone(), schema?)))
57                        .transpose()
58                })
59                .try_collect()
60                .await
61        }
62        _ => Ok(HashMap::new()),
63    }
64}
65
66async fn scan_query<T: Store>(storage: &T, query: &Query) -> Result<HashMap<String, Schema>> {
67    let Query {
68        body,
69        limit,
70        offset,
71        ..
72    } = query;
73
74    let schema_list = match body {
75        SetExpr::Select(select) => scan_select(storage, select).await?,
76        SetExpr::Values(_) => HashMap::new(),
77    };
78
79    let schema_list = match (limit, offset) {
80        (Some(limit), Some(offset)) => schema_list
81            .into_iter()
82            .chain(scan_expr(storage, limit).await?)
83            .chain(scan_expr(storage, offset).await?)
84            .collect(),
85        (Some(expr), None) | (None, Some(expr)) => schema_list
86            .into_iter()
87            .chain(scan_expr(storage, expr).await?)
88            .collect(),
89        (None, None) => schema_list,
90    };
91
92    Ok(schema_list)
93}
94
95async fn scan_select<T: Store>(storage: &T, select: &Select) -> Result<HashMap<String, Schema>> {
96    let Select {
97        distinct: _,
98        projection,
99        from,
100        selection,
101        group_by,
102        having,
103    } = select;
104
105    let projection = stream::iter(projection)
106        .then(|select_item| async move {
107            match select_item {
108                SelectItem::Expr { expr, .. } => scan_expr(storage, expr).await,
109                SelectItem::QualifiedWildcard(_) | SelectItem::Wildcard => Ok(HashMap::new()),
110            }
111        })
112        .try_collect::<Vec<HashMap<String, Schema>>>()
113        .await?
114        .into_iter()
115        .flatten();
116
117    let from = scan_table_with_joins(storage, from).await?;
118
119    let exprs = selection.iter().chain(group_by.iter()).chain(having.iter());
120
121    Ok(stream::iter(exprs)
122        .then(|expr| scan_expr(storage, expr))
123        .try_collect::<Vec<HashMap<String, Schema>>>()
124        .await?
125        .into_iter()
126        .flatten()
127        .chain(projection)
128        .chain(from)
129        .collect())
130}
131
132async fn scan_table_with_joins<T: Store>(
133    storage: &T,
134    table_with_joins: &TableWithJoins,
135) -> Result<HashMap<String, Schema>> {
136    let TableWithJoins { relation, joins } = table_with_joins;
137    let schema_list = scan_table_factor(storage, relation).await?;
138
139    Ok(stream::iter(joins)
140        .then(|join| scan_join(storage, join))
141        .try_collect::<Vec<HashMap<String, Schema>>>()
142        .await?
143        .into_iter()
144        .flatten()
145        .chain(schema_list)
146        .collect())
147}
148
149async fn scan_join<T: Store>(storage: &T, join: &Join) -> Result<HashMap<String, Schema>> {
150    let Join {
151        relation,
152        join_operator,
153        ..
154    } = join;
155
156    let schema_list = scan_table_factor(storage, relation).await?;
157    let schema_list = match join_operator {
158        JoinOperator::Inner(JoinConstraint::On(expr))
159        | JoinOperator::LeftOuter(JoinConstraint::On(expr)) => scan_expr(storage, expr)
160            .await?
161            .into_iter()
162            .chain(schema_list)
163            .collect(),
164        JoinOperator::Inner(JoinConstraint::None)
165        | JoinOperator::LeftOuter(JoinConstraint::None) => schema_list,
166    };
167
168    Ok(schema_list)
169}
170
171#[async_recursion]
172async fn scan_table_factor<T>(
173    storage: &T,
174    table_factor: &TableFactor,
175) -> Result<HashMap<String, Schema>>
176where
177    T: Store,
178{
179    match table_factor {
180        TableFactor::Table { name, .. } => {
181            let schema = storage.fetch_schema(name).await?;
182            let schema_list: HashMap<String, Schema> = schema.map_or_else(HashMap::new, |schema| {
183                HashMap::from([(name.to_owned(), schema)])
184            });
185
186            Ok(schema_list)
187        }
188        TableFactor::Derived { subquery, .. } => scan_query(storage, subquery).await,
189        TableFactor::Series { .. } | TableFactor::Dictionary { .. } => Ok(HashMap::new()),
190    }
191}
192
193#[async_recursion]
194async fn scan_expr<T>(storage: &T, expr: &Expr) -> Result<HashMap<String, Schema>>
195where
196    T: Store,
197{
198    let schema_list = match expr.into() {
199        PlanExpr::None | PlanExpr::Identifier(_) | PlanExpr::CompoundIdentifier { .. } => {
200            HashMap::new()
201        }
202        PlanExpr::Expr(expr) => scan_expr(storage, expr).await?,
203        PlanExpr::TwoExprs(expr, expr2) => scan_expr(storage, expr)
204            .await?
205            .into_iter()
206            .chain(scan_expr(storage, expr2).await?)
207            .collect(),
208        PlanExpr::ThreeExprs(expr, expr2, expr3) => scan_expr(storage, expr)
209            .await?
210            .into_iter()
211            .chain(scan_expr(storage, expr2).await?)
212            .chain(scan_expr(storage, expr3).await?)
213            .collect(),
214        PlanExpr::MultiExprs(exprs) => stream::iter(exprs)
215            .then(|expr| scan_expr(storage, expr))
216            .try_collect::<Vec<HashMap<String, Schema>>>()
217            .await?
218            .into_iter()
219            .flatten()
220            .collect(),
221        PlanExpr::Query(query) => scan_query(storage, query).await?,
222        PlanExpr::QueryAndExpr { query, expr } => scan_query(storage, query)
223            .await?
224            .into_iter()
225            .chain(scan_expr(storage, expr).await?)
226            .collect(),
227    };
228
229    Ok(schema_list)
230}
231
#[cfg(test)]
mod tests {
    use {
        super::fetch_schema_map,
        crate::{
            mock::{MockStorage, run},
            parse_sql::parse,
            result::Result,
            translate::translate,
        },
        futures::executor::block_on,
    };

    /// Parses and translates the first statement in `sql`, runs
    /// `fetch_schema_map` against `storage`, and returns the collected table
    /// names sorted ascending.
    fn plan(storage: &MockStorage, sql: &str) -> Result<Vec<String>> {
        let parsed = parse(sql).expect(sql).into_iter().next().unwrap();
        let statement = translate(&parsed).unwrap();

        let mut table_names: Vec<String> = block_on(fetch_schema_map(storage, &statement))?
            .into_keys()
            .collect();
        table_names.sort();

        Ok(table_names)
    }

    /// Asserts that planning `sql` collects exactly the `expected` tables.
    fn run_test(storage: &MockStorage, sql: &str, expected: &[&str]) {
        let table_names = plan(storage, sql).unwrap();

        assert_eq!(table_names.as_slice(), expected, "{sql}");
    }

    #[test]
    fn basic() {
        let storage = run("
            CREATE TABLE Foo (id INTEGER);
            CREATE TABLE Bar (name TEXT);
        ");

        let check = |sql, expected| run_test(&storage, sql, expected);

        check("SELECT * FROM Foo", &["Foo"]);
        check("INSERT INTO Foo VALUES (1), (2), (3);", &["Foo"]);
        check("DROP TABLE Foo, Bar;", &["Bar", "Foo"]);

        // DELETE is not scanned yet, so no schemas are collected for it.
        check("DELETE FROM Foo;", &[]);
    }

    #[test]
    fn expr() {
        let storage = run("
            CREATE TABLE Foo (id INTEGER);
            CREATE TABLE Bar (name TEXT);
        ");
        let check = |sql, expected| run_test(&storage, sql, expected);

        // PlanExpr::None
        check(
            r#"SELECT Foo.*, * FROM Foo WHERE id = DATE "2021-01-01";"#,
            &["Foo"],
        );

        // PlanExpr::Expr
        check(
            "
            SELECT * FROM Foo
            WHERE
                Foo.id IS NULL
                AND id IS NOT NULL
                OR (id IS NULL)
        ",
            &["Foo"],
        );

        // PlanExpr::TwoExprs
        check("SELECT * FROM Foo WHERE id = 1", &["Foo"]);

        // PlanExpr::ThreeExprs
        check("SELECT * FROM Foo WHERE id BETWEEN 1 AND 20", &["Foo"]);

        // PlanExpr::MultiExprs
        check("SELECT * FROM Foo WHERE id IN (1, 2, 3)", &["Foo"]);

        // PlanExpr::Query
        check(
            "
            SELECT * FROM Bar
            WHERE
                EXISTS(SELECT id FROM Foo)
                AND Bar.id = (SELECT id FROM Bar LIMIT 1);
        ",
            &["Bar", "Foo"],
        );

        // PlanExpr::QueryAndExpr
        check(
            "SELECT * FROM Foo WHERE Foo.id IN (SELECT 1 FROM Bar);",
            &["Bar", "Foo"],
        );
    }

    #[test]
    fn select() {
        let storage = run("
            CREATE TABLE Foo (id INTEGER);
            CREATE TABLE Bar (
                id INTEGER,
                foo_id INTEGER
            );
            CREATE TABLE Baz (flag BOOLEAN);
        ");

        let check = |sql, expected| run_test(&storage, sql, expected);

        check(
            "
            SELECT foo_id, COUNT(*)
            FROM Bar
            WHERE id IS NOT NULL
            GROUP BY foo_id
            HAVING foo_id > 10;
            ",
            &["Bar"],
        );
        check(
            "SELECT * FROM Foo JOIN Bar ORDER BY Foo.id",
            &["Bar", "Foo"],
        );
        check("SELECT * FROM Foo LEFT OUTER JOIN Bar", &["Bar", "Foo"]);
        check(
            "SELECT * FROM Foo LEFT JOIN Bar ON Bar.foo_id = Foo.id",
            &["Bar", "Foo"],
        );
        check(
            "
            SELECT * FROM Foo
            INNER JOIN Bar ON Bar.id = Foo.bar_id
            LEFT JOIN Baz ON False;
        ",
            &["Bar", "Baz", "Foo"],
        );
        check(
            "
            SELECT Bar.*, id, *
            FROM Foo
            JOIN Bar ON True
            LEFT JOIN Baz ON True
            WHERE Foo.id = 1
            LIMIT 1 OFFSET 1
            ",
            &["Bar", "Baz", "Foo"],
        );

        // Unknown tables are skipped rather than reported as errors.
        check("SELECT * FROM Railway", &[]);
        check("SELECT * FROM Foo WHERE Foo.id = Lab.foo_id", &["Foo"]);
    }

    #[test]
    fn storage_err() {
        let storage = run("
            CREATE TABLE Foo (id INTEGER);
            CREATE TABLE Bar (id INTEGER);
            CREATE TABLE Baz (flag BOOLEAN);
        ");

        // Each statement reaches the `__Err__` table somewhere. The mock
        // storage is expected to fail on it, so planning must return an
        // error instead of skipping the table.
        let failing_sqls = [
            "SELECT * FROM __Err__",
            "INSERT INTO __Err__ VALUES (1), (2)",
            "DROP TABLE __Err__",
            "SELECT * FROM Foo WHERE id = (SELECT foo_id FROM __Err__ LIMIT 1)",
            "SELECT * FROM Foo WHERE (SELECT foo_id FROM __Err__ LIMIT 1) = id",
            "SELECT * FROM Foo WHERE id BETWEEN (SELECT foo_id FROM __Err__ LIMIT 1) AND 100",
            "SELECT * FROM Foo WHERE (SELECT id FROM __Err__ LIMIT 1) BETWEEN 20 AND 50",
            "SELECT * FROM Foo WHERE id IN (1, 2, (SELECT foo_id FROM __Err__ LIMIT 1), 5)",
            "SELECT * FROM Foo WHERE id IN (SELECT * FROM __Err__)",
            "SELECT * FROM Foo LEFT JOIN Bar ON Bar.id = (SELECT id FROM __Err__ LIMIT 1)",
            "SELECT id, (SELECT id FROM __Err__ LIMIT 1) AS cc FROM Foo;",
        ];

        for sql in failing_sqls {
            assert!(plan(&storage, sql).is_err(), "{sql}");
        }
    }
}