gluesql_core/executor/
execute.rs

1use {
2    super::{
3        alter::{
4            CreateTableOptions, alter_table, create_index, create_table, delete_function,
5            drop_table, insert_function,
6        },
7        delete::delete,
8        fetch::fetch,
9        insert::insert,
10        select::{select, select_with_labels},
11        update::Update,
12        validate::{ColumnValidation, validate_unique},
13    },
14    crate::{
15        ast::{
16            AstLiteral, BinaryOperator, DataType, Dictionary, Expr, Query, SelectItem, SetExpr,
17            Statement, TableAlias, TableFactor, TableWithJoins, Variable,
18        },
19        data::{Key, Row, Schema, Value},
20        result::Result,
21        store::{GStore, GStoreMut},
22    },
23    futures::stream::{StreamExt, TryStreamExt},
24    serde::{Deserialize, Serialize},
25    std::{
26        collections::{BTreeMap, HashMap},
27        env::var,
28        fmt::Debug,
29        sync::Arc,
30    },
31    thiserror::Error as ThisError,
32};
33
34#[derive(ThisError, Serialize, Debug, PartialEq, Eq)]
35pub enum ExecuteError {
36    #[error("table not found: {0}")]
37    TableNotFound(String),
38}
39
40#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
41pub enum Payload {
42    ShowColumns(Vec<(String, DataType)>),
43    Create,
44    Insert(usize),
45    Select {
46        labels: Vec<String>,
47        rows: Vec<Vec<Value>>,
48    },
49    SelectMap(Vec<BTreeMap<String, Value>>),
50    Delete(usize),
51    Update(usize),
52    DropTable(usize),
53    DropFunction,
54    AlterTable,
55    CreateIndex,
56    DropIndex,
57    StartTransaction,
58    Commit,
59    Rollback,
60    ShowVariable(PayloadVariable),
61}
62
63impl Payload {
64    /// Exports `select` payloads as an [`std::iter::Iterator`].
65    ///
66    /// The items of the Iterator are `HashMap<Column, Value>`, and they are borrowed by default.
67    /// If ownership is required, you need to acquire them directly.
68    ///
69    /// - Some: [`Payload::Select`], [`Payload::SelectMap`]
70    /// - None: otherwise
71    pub fn select(&self) -> Option<impl Iterator<Item = HashMap<&str, &Value>>> {
72        #[derive(iter_enum::Iterator)]
73        enum Iter<I1, I2> {
74            Schema(I1),
75            Schemaless(I2),
76        }
77
78        Some(match self {
79            Payload::Select { labels, rows } => Iter::Schema(rows.iter().map(move |row| {
80                labels
81                    .iter()
82                    .zip(row.iter())
83                    .map(|(label, value)| (label.as_str(), value))
84                    .collect::<HashMap<_, _>>()
85            })),
86            Payload::SelectMap(rows) => Iter::Schemaless(rows.iter().map(|row| {
87                row.iter()
88                    .map(|(k, v)| (k.as_str(), v))
89                    .collect::<HashMap<_, _>>()
90            })),
91            _ => return None,
92        })
93    }
94}
95
96#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
97pub enum PayloadVariable {
98    Tables(Vec<String>),
99    Functions(Vec<String>),
100    Version(String),
101}
102
103pub async fn execute<T: GStore + GStoreMut>(
104    storage: &mut T,
105    statement: &Statement,
106) -> Result<Payload> {
107    if matches!(
108        statement,
109        Statement::StartTransaction | Statement::Rollback | Statement::Commit
110    ) {
111        return execute_inner(storage, statement).await;
112    }
113
114    let autocommit = storage.begin(true).await?;
115    let result = execute_inner(storage, statement).await;
116
117    if !autocommit {
118        return result;
119    }
120
121    match result {
122        Ok(payload) => storage.commit().await.map(|_| payload),
123        Err(error) => {
124            storage.rollback().await?;
125
126            Err(error)
127        }
128    }
129}
130
131async fn execute_inner<T: GStore + GStoreMut>(
132    storage: &mut T,
133    statement: &Statement,
134) -> Result<Payload> {
135    match statement {
136        //- Modification
137        //-- Tables
138        Statement::CreateTable {
139            name,
140            columns,
141            if_not_exists,
142            source,
143            engine,
144            foreign_keys,
145            comment,
146        } => {
147            let options = CreateTableOptions {
148                target_table_name: name,
149                column_defs: columns.as_ref().map(Vec::as_slice),
150                if_not_exists: *if_not_exists,
151                source,
152                engine,
153                foreign_keys,
154                comment,
155            };
156
157            create_table(storage, options)
158                .await
159                .map(|_| Payload::Create)
160        }
161        Statement::DropTable {
162            names,
163            if_exists,
164            cascade,
165            ..
166        } => drop_table(storage, names, *if_exists, *cascade)
167            .await
168            .map(Payload::DropTable),
169        Statement::AlterTable { name, operation } => alter_table(storage, name, operation)
170            .await
171            .map(|_| Payload::AlterTable),
172        Statement::CreateIndex {
173            name,
174            table_name,
175            column,
176        } => create_index(storage, table_name, name, column)
177            .await
178            .map(|_| Payload::CreateIndex),
179        Statement::DropIndex { name, table_name } => storage
180            .drop_index(table_name, name)
181            .await
182            .map(|_| Payload::DropIndex),
183        //- Transaction
184        Statement::StartTransaction => storage
185            .begin(false)
186            .await
187            .map(|_| Payload::StartTransaction),
188        Statement::Commit => storage.commit().await.map(|_| Payload::Commit),
189        Statement::Rollback => storage.rollback().await.map(|_| Payload::Rollback),
190        //-- Rows
191        Statement::Insert {
192            table_name,
193            columns,
194            source,
195        } => insert(storage, table_name, columns, source)
196            .await
197            .map(Payload::Insert),
198        Statement::Update {
199            table_name,
200            selection,
201            assignments,
202        } => {
203            let Schema {
204                column_defs,
205                foreign_keys,
206                ..
207            } = storage
208                .fetch_schema(table_name)
209                .await?
210                .ok_or_else(|| ExecuteError::TableNotFound(table_name.to_owned()))?;
211
212            let all_columns = column_defs.as_deref().map(|columns| {
213                columns
214                    .iter()
215                    .map(|col_def| col_def.name.to_owned())
216                    .collect()
217            });
218            let columns_to_update: Vec<String> = assignments
219                .iter()
220                .map(|assignment| assignment.id.to_owned())
221                .collect();
222
223            let update = Update::new(storage, table_name, assignments, column_defs.as_deref())?;
224
225            let foreign_keys = Arc::new(foreign_keys);
226
227            let rows = fetch(storage, table_name, all_columns, selection.as_ref())
228                .await?
229                .and_then(|item| {
230                    let update = &update;
231                    let (key, row) = item;
232
233                    let foreign_keys = Arc::clone(&foreign_keys);
234                    async move {
235                        let row = update.apply(row, foreign_keys.as_ref()).await?;
236
237                        Ok((key, row))
238                    }
239                })
240                .try_collect::<Vec<(Key, Row)>>()
241                .await?;
242
243            if let Some(column_defs) = column_defs {
244                let column_validation =
245                    ColumnValidation::SpecifiedColumns(&column_defs, columns_to_update);
246                let rows = rows.iter().filter_map(|(_, row)| match row {
247                    Row::Vec { values, .. } => Some(values.as_slice()),
248                    Row::Map(_) => None,
249                });
250
251                validate_unique(storage, table_name, column_validation, rows).await?;
252            }
253
254            let num_rows = rows.len();
255            let rows = rows
256                .into_iter()
257                .map(|(key, row)| (key, row.into()))
258                .collect();
259
260            storage
261                .insert_data(table_name, rows)
262                .await
263                .map(|_| Payload::Update(num_rows))
264        }
265        Statement::Delete {
266            table_name,
267            selection,
268        } => delete(storage, table_name, selection).await,
269
270        //- Selection
271        Statement::Query(query) => {
272            let (labels, rows) = select_with_labels(storage, query, None).await?;
273
274            match labels {
275                Some(labels) => rows
276                    .map(|row| row?.try_into_vec())
277                    .try_collect::<Vec<_>>()
278                    .await
279                    .map(|rows| Payload::Select { labels, rows }),
280                None => rows
281                    .map(|row| row?.try_into_map())
282                    .try_collect::<Vec<_>>()
283                    .await
284                    .map(Payload::SelectMap),
285            }
286        }
287        Statement::ShowColumns { table_name } => {
288            let Schema { column_defs, .. } = storage
289                .fetch_schema(table_name)
290                .await?
291                .ok_or_else(|| ExecuteError::TableNotFound(table_name.to_owned()))?;
292
293            let output: Vec<(String, DataType)> = column_defs
294                .unwrap_or_default()
295                .into_iter()
296                .map(|key| (key.name, key.data_type))
297                .collect();
298
299            Ok(Payload::ShowColumns(output))
300        }
301        Statement::ShowIndexes(table_name) => {
302            let query = Query {
303                body: SetExpr::Select(Box::new(crate::ast::Select {
304                    distinct: false,
305                    projection: vec![SelectItem::Wildcard],
306                    from: TableWithJoins {
307                        relation: TableFactor::Dictionary {
308                            dict: Dictionary::GlueIndexes,
309                            alias: TableAlias {
310                                name: "GLUE_INDEXES".to_owned(),
311                                columns: Vec::new(),
312                            },
313                        },
314                        joins: Vec::new(),
315                    },
316                    selection: Some(Expr::BinaryOp {
317                        left: Box::new(Expr::Identifier("TABLE_NAME".to_owned())),
318                        op: BinaryOperator::Eq,
319                        right: Box::new(Expr::Literal(AstLiteral::QuotedString(
320                            table_name.to_owned(),
321                        ))),
322                    }),
323                    group_by: Vec::new(),
324                    having: None,
325                })),
326                order_by: Vec::new(),
327                limit: None,
328                offset: None,
329            };
330
331            let (labels, rows) = select_with_labels(storage, &query, None).await?;
332            let labels = labels.unwrap_or_default();
333            let rows = rows
334                .map(|row| row?.try_into_vec())
335                .try_collect::<Vec<_>>()
336                .await?;
337
338            if rows.is_empty() {
339                return Err(ExecuteError::TableNotFound(table_name.to_owned()).into());
340            }
341
342            Ok(Payload::Select { labels, rows })
343        }
344        Statement::ShowVariable(variable) => match variable {
345            Variable::Tables => {
346                let query = Query {
347                    body: SetExpr::Select(Box::new(crate::ast::Select {
348                        distinct: false,
349                        projection: vec![SelectItem::Expr {
350                            expr: Expr::Identifier("TABLE_NAME".to_owned()),
351                            label: "TABLE_NAME".to_owned(),
352                        }],
353                        from: TableWithJoins {
354                            relation: TableFactor::Dictionary {
355                                dict: Dictionary::GlueTables,
356                                alias: TableAlias {
357                                    name: "GLUE_TABLES".to_owned(),
358                                    columns: Vec::new(),
359                                },
360                            },
361                            joins: Vec::new(),
362                        },
363                        selection: None,
364                        group_by: Vec::new(),
365                        having: None,
366                    })),
367                    order_by: Vec::new(),
368                    limit: None,
369                    offset: None,
370                };
371
372                let table_names = select(storage, &query, None)
373                    .await?
374                    .map(|row| row?.try_into_vec())
375                    .try_collect::<Vec<Vec<Value>>>()
376                    .await?
377                    .iter()
378                    .flat_map(|values| values.iter().map(|value| value.into()))
379                    .collect::<Vec<_>>();
380
381                Ok(Payload::ShowVariable(PayloadVariable::Tables(table_names)))
382            }
383            Variable::Functions => {
384                let mut function_desc: Vec<_> = storage
385                    .fetch_all_functions()
386                    .await?
387                    .iter()
388                    .map(|f| f.to_str())
389                    .collect();
390                function_desc.sort();
391                Ok(Payload::ShowVariable(PayloadVariable::Functions(
392                    function_desc,
393                )))
394            }
395            Variable::Version => {
396                let version = var("CARGO_PKG_VERSION")
397                    .unwrap_or_else(|_| env!("CARGO_PKG_VERSION").to_owned());
398                let payload = Payload::ShowVariable(PayloadVariable::Version(version));
399
400                Ok(payload)
401            }
402        },
403        Statement::CreateFunction {
404            or_replace,
405            name,
406            args,
407            return_,
408        } => insert_function(storage, name, args, *or_replace, return_)
409            .await
410            .map(|_| Payload::Create),
411        Statement::DropFunction { if_exists, names } => delete_function(storage, names, *if_exists)
412            .await
413            .map(|_| Payload::DropFunction),
414    }
415}