gluesql_core/executor/execute.rs

1use {
2    super::{
3        alter::{
4            CreateTableOptions, alter_table, create_index, create_table, delete_function,
5            drop_table, insert_function,
6        },
7        delete::delete,
8        fetch::fetch,
9        insert::insert,
10        select::{select, select_with_labels},
11        update::Update,
12        validate::{ColumnValidation, validate_unique},
13    },
14    crate::{
15        ast::{
16            BinaryOperator, DataType, Dictionary, Expr, Literal, Query, SelectItem, SetExpr,
17            Statement, TableAlias, TableFactor, TableWithJoins, Variable,
18        },
19        data::{Key, Row, Schema, Value},
20        result::Result,
21        store::{GStore, GStoreMut},
22    },
23    futures::stream::{StreamExt, TryStreamExt},
24    serde::{Deserialize, Serialize},
25    std::{
26        collections::{BTreeMap, HashMap},
27        env::var,
28        fmt::Debug,
29        sync::Arc,
30    },
31    thiserror::Error as ThisError,
32};
33
34#[derive(ThisError, Serialize, Debug, PartialEq, Eq)]
35pub enum ExecuteError {
36    #[error("table not found: {0}")]
37    TableNotFound(String),
38}
39
40#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
41pub enum Payload {
42    ShowColumns(Vec<(String, DataType)>),
43    Create,
44    Insert(usize),
45    Select {
46        labels: Vec<String>,
47        rows: Vec<Vec<Value>>,
48    },
49    SelectMap(Vec<BTreeMap<String, Value>>),
50    Delete(usize),
51    Update(usize),
52    DropTable(usize),
53    DropFunction,
54    AlterTable,
55    CreateIndex,
56    DropIndex,
57    StartTransaction,
58    Commit,
59    Rollback,
60    ShowVariable(PayloadVariable),
61}
62
63impl Payload {
64    /// Exports `select` payloads as an [`std::iter::Iterator`].
65    ///
66    /// The items of the Iterator are `HashMap<Column, Value>`, and they are borrowed by default.
67    /// If ownership is required, you need to acquire them directly.
68    ///
69    /// - Some: [`Payload::Select`], [`Payload::SelectMap`]
70    /// - None: otherwise
71    pub fn select(&self) -> Option<impl Iterator<Item = HashMap<&str, &Value>>> {
72        #[derive(iter_enum::Iterator)]
73        enum Iter<I1, I2> {
74            Schema(I1),
75            Schemaless(I2),
76        }
77
78        Some(match self {
79            Payload::Select { labels, rows } => Iter::Schema(rows.iter().map(move |row| {
80                labels
81                    .iter()
82                    .zip(row.iter())
83                    .map(|(label, value)| (label.as_str(), value))
84                    .collect::<HashMap<_, _>>()
85            })),
86            Payload::SelectMap(rows) => Iter::Schemaless(rows.iter().map(|row| {
87                row.iter()
88                    .map(|(k, v)| (k.as_str(), v))
89                    .collect::<HashMap<_, _>>()
90            })),
91            _ => return None,
92        })
93    }
94}
95
96#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
97pub enum PayloadVariable {
98    Tables(Vec<String>),
99    Functions(Vec<String>),
100    Version(String),
101}
102
103pub async fn execute<T: GStore + GStoreMut>(
104    storage: &mut T,
105    statement: &Statement,
106) -> Result<Payload> {
107    if matches!(
108        statement,
109        Statement::StartTransaction | Statement::Rollback | Statement::Commit
110    ) {
111        return execute_inner(storage, statement).await;
112    }
113
114    let autocommit = storage.begin(true).await?;
115    let result = execute_inner(storage, statement).await;
116
117    if !autocommit {
118        return result;
119    }
120
121    match result {
122        Ok(payload) => storage.commit().await.map(|()| payload),
123        Err(error) => {
124            storage.rollback().await?;
125
126            Err(error)
127        }
128    }
129}
130
131async fn execute_inner<T: GStore + GStoreMut>(
132    storage: &mut T,
133    statement: &Statement,
134) -> Result<Payload> {
135    match statement {
136        //- Modification
137        //-- Tables
138        Statement::CreateTable {
139            name,
140            columns,
141            if_not_exists,
142            source,
143            engine,
144            foreign_keys,
145            comment,
146        } => {
147            let options = CreateTableOptions {
148                target_table_name: name,
149                column_defs: columns.as_ref().map(Vec::as_slice),
150                if_not_exists: *if_not_exists,
151                source,
152                engine,
153                foreign_keys,
154                comment,
155            };
156
157            create_table(storage, options)
158                .await
159                .map(|()| Payload::Create)
160        }
161        Statement::DropTable {
162            names,
163            if_exists,
164            cascade,
165            ..
166        } => drop_table(storage, names, *if_exists, *cascade)
167            .await
168            .map(Payload::DropTable),
169        Statement::AlterTable { name, operation } => alter_table(storage, name, operation)
170            .await
171            .map(|()| Payload::AlterTable),
172        Statement::CreateIndex {
173            name,
174            table_name,
175            column,
176        } => create_index(storage, table_name, name, column)
177            .await
178            .map(|()| Payload::CreateIndex),
179        Statement::DropIndex { name, table_name } => storage
180            .drop_index(table_name, name)
181            .await
182            .map(|()| Payload::DropIndex),
183        //- Transaction
184        Statement::StartTransaction => storage
185            .begin(false)
186            .await
187            .map(|_| Payload::StartTransaction),
188        Statement::Commit => storage.commit().await.map(|()| Payload::Commit),
189        Statement::Rollback => storage.rollback().await.map(|()| Payload::Rollback),
190        //-- Rows
191        Statement::Insert {
192            table_name,
193            columns,
194            source,
195        } => insert(storage, table_name, columns, source)
196            .await
197            .map(Payload::Insert),
198        Statement::Update {
199            table_name,
200            selection,
201            assignments,
202        } => {
203            let Schema {
204                column_defs,
205                foreign_keys,
206                ..
207            } = storage
208                .fetch_schema(table_name)
209                .await?
210                .ok_or_else(|| ExecuteError::TableNotFound(table_name.to_owned()))?;
211
212            let all_columns = column_defs
213                .as_deref()
214                .map(|columns| columns.iter().map(|col_def| col_def.name.clone()).collect());
215            let columns_to_update: Vec<String> = assignments
216                .iter()
217                .map(|assignment| assignment.id.clone())
218                .collect();
219
220            let update = Update::new(storage, table_name, assignments, column_defs.as_deref())?;
221
222            let foreign_keys = Arc::new(foreign_keys);
223
224            let rows = fetch(storage, table_name, all_columns, selection.as_ref())
225                .await?
226                .and_then(|item| {
227                    let update = &update;
228                    let (key, row) = item;
229
230                    let foreign_keys = Arc::clone(&foreign_keys);
231                    async move {
232                        let row = update.apply(row, foreign_keys.as_ref()).await?;
233
234                        Ok((key, row))
235                    }
236                })
237                .try_collect::<Vec<(Key, Row)>>()
238                .await?;
239
240            if let Some(column_defs) = column_defs {
241                let column_validation =
242                    ColumnValidation::SpecifiedColumns(&column_defs, columns_to_update);
243                let rows = rows.iter().filter_map(|(_, row)| match row {
244                    Row::Vec { values, .. } => Some(values.as_slice()),
245                    Row::Map(_) => None,
246                });
247
248                validate_unique(storage, table_name, column_validation, rows).await?;
249            }
250
251            let num_rows = rows.len();
252            let rows = rows
253                .into_iter()
254                .map(|(key, row)| (key, row.into()))
255                .collect();
256
257            storage
258                .insert_data(table_name, rows)
259                .await
260                .map(|()| Payload::Update(num_rows))
261        }
262        Statement::Delete {
263            table_name,
264            selection,
265        } => delete(storage, table_name, selection.as_ref()).await,
266
267        //- Selection
268        Statement::Query(query) => {
269            let (labels, rows) = select_with_labels(storage, query, None).await?;
270
271            match labels {
272                Some(labels) => rows
273                    .map(|row| row?.try_into_vec())
274                    .try_collect::<Vec<_>>()
275                    .await
276                    .map(|rows| Payload::Select { labels, rows }),
277                None => rows
278                    .map(|row| row?.try_into_map())
279                    .try_collect::<Vec<_>>()
280                    .await
281                    .map(Payload::SelectMap),
282            }
283        }
284        Statement::ShowColumns { table_name } => {
285            let Schema { column_defs, .. } = storage
286                .fetch_schema(table_name)
287                .await?
288                .ok_or_else(|| ExecuteError::TableNotFound(table_name.to_owned()))?;
289
290            let output: Vec<(String, DataType)> = column_defs
291                .unwrap_or_default()
292                .into_iter()
293                .map(|key| (key.name, key.data_type))
294                .collect();
295
296            Ok(Payload::ShowColumns(output))
297        }
298        Statement::ShowIndexes(table_name) => {
299            let query = Query {
300                body: SetExpr::Select(Box::new(crate::ast::Select {
301                    distinct: false,
302                    projection: vec![SelectItem::Wildcard],
303                    from: TableWithJoins {
304                        relation: TableFactor::Dictionary {
305                            dict: Dictionary::GlueIndexes,
306                            alias: TableAlias {
307                                name: "GLUE_INDEXES".to_owned(),
308                                columns: Vec::new(),
309                            },
310                        },
311                        joins: Vec::new(),
312                    },
313                    selection: Some(Expr::BinaryOp {
314                        left: Box::new(Expr::Identifier("TABLE_NAME".to_owned())),
315                        op: BinaryOperator::Eq,
316                        right: Box::new(Expr::Literal(Literal::QuotedString(
317                            table_name.to_owned(),
318                        ))),
319                    }),
320                    group_by: Vec::new(),
321                    having: None,
322                })),
323                order_by: Vec::new(),
324                limit: None,
325                offset: None,
326            };
327
328            let (labels, rows) = select_with_labels(storage, &query, None).await?;
329            let labels = labels.unwrap_or_default();
330            let rows = rows
331                .map(|row| row?.try_into_vec())
332                .try_collect::<Vec<_>>()
333                .await?;
334
335            if rows.is_empty() {
336                return Err(ExecuteError::TableNotFound(table_name.to_owned()).into());
337            }
338
339            Ok(Payload::Select { labels, rows })
340        }
341        Statement::ShowVariable(variable) => match variable {
342            Variable::Tables => {
343                let query = Query {
344                    body: SetExpr::Select(Box::new(crate::ast::Select {
345                        distinct: false,
346                        projection: vec![SelectItem::Expr {
347                            expr: Expr::Identifier("TABLE_NAME".to_owned()),
348                            label: "TABLE_NAME".to_owned(),
349                        }],
350                        from: TableWithJoins {
351                            relation: TableFactor::Dictionary {
352                                dict: Dictionary::GlueTables,
353                                alias: TableAlias {
354                                    name: "GLUE_TABLES".to_owned(),
355                                    columns: Vec::new(),
356                                },
357                            },
358                            joins: Vec::new(),
359                        },
360                        selection: None,
361                        group_by: Vec::new(),
362                        having: None,
363                    })),
364                    order_by: Vec::new(),
365                    limit: None,
366                    offset: None,
367                };
368
369                let table_names = select(storage, &query, None)
370                    .await?
371                    .map(|row| row?.try_into_vec())
372                    .try_collect::<Vec<Vec<Value>>>()
373                    .await?
374                    .iter()
375                    .flat_map(|values| values.iter().map(Into::into))
376                    .collect::<Vec<_>>();
377
378                Ok(Payload::ShowVariable(PayloadVariable::Tables(table_names)))
379            }
380            Variable::Functions => {
381                let mut function_desc: Vec<_> = storage
382                    .fetch_all_functions()
383                    .await?
384                    .iter()
385                    .map(|f| f.to_str())
386                    .collect();
387                function_desc.sort();
388                Ok(Payload::ShowVariable(PayloadVariable::Functions(
389                    function_desc,
390                )))
391            }
392            Variable::Version => {
393                let version = var("CARGO_PKG_VERSION")
394                    .unwrap_or_else(|_| env!("CARGO_PKG_VERSION").to_owned());
395                let payload = Payload::ShowVariable(PayloadVariable::Version(version));
396
397                Ok(payload)
398            }
399        },
400        Statement::CreateFunction {
401            or_replace,
402            name,
403            args,
404            return_,
405        } => insert_function(storage, name, args, *or_replace, return_)
406            .await
407            .map(|()| Payload::Create),
408        Statement::DropFunction { if_exists, names } => delete_function(storage, names, *if_exists)
409            .await
410            .map(|()| Payload::DropFunction),
411    }
412}