gluesql_core/executor/
execute.rs

1use {
2    super::{
3        alter::{
4            alter_table, create_index, create_table, delete_function, drop_table, insert_function,
5            CreateTableOptions,
6        },
7        delete::delete,
8        fetch::fetch,
9        insert::insert,
10        select::{select, select_with_labels},
11        update::Update,
12        validate::{validate_unique, ColumnValidation},
13    },
14    crate::{
15        ast::{
16            AstLiteral, BinaryOperator, DataType, Dictionary, Expr, Query, SelectItem, SetExpr,
17            Statement, TableAlias, TableFactor, TableWithJoins, Variable,
18        },
19        data::{Key, Row, Schema, Value},
20        result::Result,
21        store::{GStore, GStoreMut},
22    },
23    futures::stream::{StreamExt, TryStreamExt},
24    serde::{Deserialize, Serialize},
25    std::{collections::HashMap, env::var, fmt::Debug, rc::Rc},
26    thiserror::Error as ThisError,
27};
28
29#[derive(ThisError, Serialize, Debug, PartialEq, Eq)]
30pub enum ExecuteError {
31    #[error("table not found: {0}")]
32    TableNotFound(String),
33}
34
35#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
36pub enum Payload {
37    ShowColumns(Vec<(String, DataType)>),
38    Create,
39    Insert(usize),
40    Select {
41        labels: Vec<String>,
42        rows: Vec<Vec<Value>>,
43    },
44    SelectMap(Vec<HashMap<String, Value>>),
45    Delete(usize),
46    Update(usize),
47    DropTable(usize),
48    DropFunction,
49    AlterTable,
50    CreateIndex,
51    DropIndex,
52    StartTransaction,
53    Commit,
54    Rollback,
55    ShowVariable(PayloadVariable),
56}
57
58impl Payload {
59    /// Exports `select` payloads as an [`std::iter::Iterator`].
60    ///
61    /// The items of the Iterator are `HashMap<Column, Value>`, and they are borrowed by default.
62    /// If ownership is required, you need to acquire them directly.
63    ///
64    /// - Some: [`Payload::Select`], [`Payload::SelectMap`]
65    /// - None: otherwise
66    pub fn select(&self) -> Option<impl Iterator<Item = HashMap<&str, &Value>>> {
67        #[derive(iter_enum::Iterator)]
68        enum Iter<I1, I2> {
69            Schema(I1),
70            Schemaless(I2),
71        }
72
73        Some(match self {
74            Payload::Select { labels, rows } => Iter::Schema(rows.iter().map(move |row| {
75                labels
76                    .iter()
77                    .zip(row.iter())
78                    .map(|(label, value)| (label.as_str(), value))
79                    .collect::<HashMap<_, _>>()
80            })),
81            Payload::SelectMap(rows) => Iter::Schemaless(rows.iter().map(|row| {
82                row.iter()
83                    .map(|(k, v)| (k.as_str(), v))
84                    .collect::<HashMap<_, _>>()
85            })),
86            _ => return None,
87        })
88    }
89}
90
91#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
92pub enum PayloadVariable {
93    Tables(Vec<String>),
94    Functions(Vec<String>),
95    Version(String),
96}
97
98pub async fn execute<T: GStore + GStoreMut>(
99    storage: &mut T,
100    statement: &Statement,
101) -> Result<Payload> {
102    if matches!(
103        statement,
104        Statement::StartTransaction | Statement::Rollback | Statement::Commit
105    ) {
106        return execute_inner(storage, statement).await;
107    }
108
109    let autocommit = storage.begin(true).await?;
110    let result = execute_inner(storage, statement).await;
111
112    if !autocommit {
113        return result;
114    }
115
116    match result {
117        Ok(payload) => storage.commit().await.map(|_| payload),
118        Err(error) => {
119            storage.rollback().await?;
120
121            Err(error)
122        }
123    }
124}
125
126async fn execute_inner<T: GStore + GStoreMut>(
127    storage: &mut T,
128    statement: &Statement,
129) -> Result<Payload> {
130    match statement {
131        //- Modification
132        //-- Tables
133        Statement::CreateTable {
134            name,
135            columns,
136            if_not_exists,
137            source,
138            engine,
139            foreign_keys,
140            comment,
141        } => {
142            let options = CreateTableOptions {
143                target_table_name: name,
144                column_defs: columns.as_ref().map(Vec::as_slice),
145                if_not_exists: *if_not_exists,
146                source,
147                engine,
148                foreign_keys,
149                comment,
150            };
151
152            create_table(storage, options)
153                .await
154                .map(|_| Payload::Create)
155        }
156        Statement::DropTable {
157            names,
158            if_exists,
159            cascade,
160            ..
161        } => drop_table(storage, names, *if_exists, *cascade)
162            .await
163            .map(Payload::DropTable),
164        Statement::AlterTable { name, operation } => alter_table(storage, name, operation)
165            .await
166            .map(|_| Payload::AlterTable),
167        Statement::CreateIndex {
168            name,
169            table_name,
170            column,
171        } => create_index(storage, table_name, name, column)
172            .await
173            .map(|_| Payload::CreateIndex),
174        Statement::DropIndex { name, table_name } => storage
175            .drop_index(table_name, name)
176            .await
177            .map(|_| Payload::DropIndex),
178        //- Transaction
179        Statement::StartTransaction => storage
180            .begin(false)
181            .await
182            .map(|_| Payload::StartTransaction),
183        Statement::Commit => storage.commit().await.map(|_| Payload::Commit),
184        Statement::Rollback => storage.rollback().await.map(|_| Payload::Rollback),
185        //-- Rows
186        Statement::Insert {
187            table_name,
188            columns,
189            source,
190        } => insert(storage, table_name, columns, source)
191            .await
192            .map(Payload::Insert),
193        Statement::Update {
194            table_name,
195            selection,
196            assignments,
197        } => {
198            let Schema {
199                column_defs,
200                foreign_keys,
201                ..
202            } = storage
203                .fetch_schema(table_name)
204                .await?
205                .ok_or_else(|| ExecuteError::TableNotFound(table_name.to_owned()))?;
206
207            let all_columns = column_defs.as_deref().map(|columns| {
208                columns
209                    .iter()
210                    .map(|col_def| col_def.name.to_owned())
211                    .collect()
212            });
213            let columns_to_update: Vec<String> = assignments
214                .iter()
215                .map(|assignment| assignment.id.to_owned())
216                .collect();
217
218            let update = Update::new(storage, table_name, assignments, column_defs.as_deref())?;
219
220            let foreign_keys = Rc::new(foreign_keys);
221
222            let rows = fetch(storage, table_name, all_columns, selection.as_ref())
223                .await?
224                .and_then(|item| {
225                    let update = &update;
226                    let (key, row) = item;
227
228                    let foreign_keys = Rc::clone(&foreign_keys);
229                    async move {
230                        let row = update.apply(row, foreign_keys.as_ref()).await?;
231
232                        Ok((key, row))
233                    }
234                })
235                .try_collect::<Vec<(Key, Row)>>()
236                .await?;
237
238            if let Some(column_defs) = column_defs {
239                let column_validation =
240                    ColumnValidation::SpecifiedColumns(&column_defs, columns_to_update);
241                let rows = rows.iter().filter_map(|(_, row)| match row {
242                    Row::Vec { values, .. } => Some(values.as_slice()),
243                    Row::Map(_) => None,
244                });
245
246                validate_unique(storage, table_name, column_validation, rows).await?;
247            }
248
249            let num_rows = rows.len();
250            let rows = rows
251                .into_iter()
252                .map(|(key, row)| (key, row.into()))
253                .collect();
254
255            storage
256                .insert_data(table_name, rows)
257                .await
258                .map(|_| Payload::Update(num_rows))
259        }
260        Statement::Delete {
261            table_name,
262            selection,
263        } => delete(storage, table_name, selection).await,
264
265        //- Selection
266        Statement::Query(query) => {
267            let (labels, rows) = select_with_labels(storage, query, None).await?;
268
269            match labels {
270                Some(labels) => rows
271                    .map(|row| row?.try_into_vec())
272                    .try_collect::<Vec<_>>()
273                    .await
274                    .map(|rows| Payload::Select { labels, rows }),
275                None => rows
276                    .map(|row| row?.try_into_map())
277                    .try_collect::<Vec<_>>()
278                    .await
279                    .map(Payload::SelectMap),
280            }
281        }
282        Statement::ShowColumns { table_name } => {
283            let Schema { column_defs, .. } = storage
284                .fetch_schema(table_name)
285                .await?
286                .ok_or_else(|| ExecuteError::TableNotFound(table_name.to_owned()))?;
287
288            let output: Vec<(String, DataType)> = column_defs
289                .unwrap_or_default()
290                .into_iter()
291                .map(|key| (key.name, key.data_type))
292                .collect();
293
294            Ok(Payload::ShowColumns(output))
295        }
296        Statement::ShowIndexes(table_name) => {
297            let query = Query {
298                body: SetExpr::Select(Box::new(crate::ast::Select {
299                    projection: vec![SelectItem::Wildcard],
300                    from: TableWithJoins {
301                        relation: TableFactor::Dictionary {
302                            dict: Dictionary::GlueIndexes,
303                            alias: TableAlias {
304                                name: "GLUE_INDEXES".to_owned(),
305                                columns: Vec::new(),
306                            },
307                        },
308                        joins: Vec::new(),
309                    },
310                    selection: Some(Expr::BinaryOp {
311                        left: Box::new(Expr::Identifier("TABLE_NAME".to_owned())),
312                        op: BinaryOperator::Eq,
313                        right: Box::new(Expr::Literal(AstLiteral::QuotedString(
314                            table_name.to_owned(),
315                        ))),
316                    }),
317                    group_by: Vec::new(),
318                    having: None,
319                })),
320                order_by: Vec::new(),
321                limit: None,
322                offset: None,
323            };
324
325            let (labels, rows) = select_with_labels(storage, &query, None).await?;
326            let labels = labels.unwrap_or_default();
327            let rows = rows
328                .map(|row| row?.try_into_vec())
329                .try_collect::<Vec<_>>()
330                .await?;
331
332            if rows.is_empty() {
333                return Err(ExecuteError::TableNotFound(table_name.to_owned()).into());
334            }
335
336            Ok(Payload::Select { labels, rows })
337        }
338        Statement::ShowVariable(variable) => match variable {
339            Variable::Tables => {
340                let query = Query {
341                    body: SetExpr::Select(Box::new(crate::ast::Select {
342                        projection: vec![SelectItem::Expr {
343                            expr: Expr::Identifier("TABLE_NAME".to_owned()),
344                            label: "TABLE_NAME".to_owned(),
345                        }],
346                        from: TableWithJoins {
347                            relation: TableFactor::Dictionary {
348                                dict: Dictionary::GlueTables,
349                                alias: TableAlias {
350                                    name: "GLUE_TABLES".to_owned(),
351                                    columns: Vec::new(),
352                                },
353                            },
354                            joins: Vec::new(),
355                        },
356                        selection: None,
357                        group_by: Vec::new(),
358                        having: None,
359                    })),
360                    order_by: Vec::new(),
361                    limit: None,
362                    offset: None,
363                };
364
365                let table_names = select(storage, &query, None)
366                    .await?
367                    .map(|row| row?.try_into_vec())
368                    .try_collect::<Vec<Vec<Value>>>()
369                    .await?
370                    .iter()
371                    .flat_map(|values| values.iter().map(|value| value.into()))
372                    .collect::<Vec<_>>();
373
374                Ok(Payload::ShowVariable(PayloadVariable::Tables(table_names)))
375            }
376            Variable::Functions => {
377                let mut function_desc: Vec<_> = storage
378                    .fetch_all_functions()
379                    .await?
380                    .iter()
381                    .map(|f| f.to_str())
382                    .collect();
383                function_desc.sort();
384                Ok(Payload::ShowVariable(PayloadVariable::Functions(
385                    function_desc,
386                )))
387            }
388            Variable::Version => {
389                let version = var("CARGO_PKG_VERSION")
390                    .unwrap_or_else(|_| env!("CARGO_PKG_VERSION").to_owned());
391                let payload = Payload::ShowVariable(PayloadVariable::Version(version));
392
393                Ok(payload)
394            }
395        },
396        Statement::CreateFunction {
397            or_replace,
398            name,
399            args,
400            return_,
401        } => insert_function(storage, name, args, *or_replace, return_)
402            .await
403            .map(|_| Payload::Create),
404        Statement::DropFunction { if_exists, names } => delete_function(storage, names, *if_exists)
405            .await
406            .map(|_| Payload::DropFunction),
407    }
408}