1use {
2 super::{
3 alter::{
4 CreateTableOptions, alter_table, create_index, create_table, delete_function,
5 drop_table, insert_function,
6 },
7 delete::delete,
8 fetch::fetch,
9 insert::insert,
10 select::{select, select_with_labels},
11 update::Update,
12 validate::{ColumnValidation, validate_unique},
13 },
14 crate::{
15 ast::{
16 AstLiteral, BinaryOperator, DataType, Dictionary, Expr, Query, SelectItem, SetExpr,
17 Statement, TableAlias, TableFactor, TableWithJoins, Variable,
18 },
19 data::{Key, Row, Schema, Value},
20 result::Result,
21 store::{GStore, GStoreMut},
22 },
23 futures::stream::{StreamExt, TryStreamExt},
24 serde::{Deserialize, Serialize},
25 std::{
26 collections::{BTreeMap, HashMap},
27 env::var,
28 fmt::Debug,
29 sync::Arc,
30 },
31 thiserror::Error as ThisError,
32};
33
34#[derive(ThisError, Serialize, Debug, PartialEq, Eq)]
35pub enum ExecuteError {
36 #[error("table not found: {0}")]
37 TableNotFound(String),
38}
39
40#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
41pub enum Payload {
42 ShowColumns(Vec<(String, DataType)>),
43 Create,
44 Insert(usize),
45 Select {
46 labels: Vec<String>,
47 rows: Vec<Vec<Value>>,
48 },
49 SelectMap(Vec<BTreeMap<String, Value>>),
50 Delete(usize),
51 Update(usize),
52 DropTable(usize),
53 DropFunction,
54 AlterTable,
55 CreateIndex,
56 DropIndex,
57 StartTransaction,
58 Commit,
59 Rollback,
60 ShowVariable(PayloadVariable),
61}
62
63impl Payload {
64 pub fn select(&self) -> Option<impl Iterator<Item = HashMap<&str, &Value>>> {
72 #[derive(iter_enum::Iterator)]
73 enum Iter<I1, I2> {
74 Schema(I1),
75 Schemaless(I2),
76 }
77
78 Some(match self {
79 Payload::Select { labels, rows } => Iter::Schema(rows.iter().map(move |row| {
80 labels
81 .iter()
82 .zip(row.iter())
83 .map(|(label, value)| (label.as_str(), value))
84 .collect::<HashMap<_, _>>()
85 })),
86 Payload::SelectMap(rows) => Iter::Schemaless(rows.iter().map(|row| {
87 row.iter()
88 .map(|(k, v)| (k.as_str(), v))
89 .collect::<HashMap<_, _>>()
90 })),
91 _ => return None,
92 })
93 }
94}
95
96#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
97pub enum PayloadVariable {
98 Tables(Vec<String>),
99 Functions(Vec<String>),
100 Version(String),
101}
102
103pub async fn execute<T: GStore + GStoreMut>(
104 storage: &mut T,
105 statement: &Statement,
106) -> Result<Payload> {
107 if matches!(
108 statement,
109 Statement::StartTransaction | Statement::Rollback | Statement::Commit
110 ) {
111 return execute_inner(storage, statement).await;
112 }
113
114 let autocommit = storage.begin(true).await?;
115 let result = execute_inner(storage, statement).await;
116
117 if !autocommit {
118 return result;
119 }
120
121 match result {
122 Ok(payload) => storage.commit().await.map(|_| payload),
123 Err(error) => {
124 storage.rollback().await?;
125
126 Err(error)
127 }
128 }
129}
130
131async fn execute_inner<T: GStore + GStoreMut>(
132 storage: &mut T,
133 statement: &Statement,
134) -> Result<Payload> {
135 match statement {
136 Statement::CreateTable {
139 name,
140 columns,
141 if_not_exists,
142 source,
143 engine,
144 foreign_keys,
145 comment,
146 } => {
147 let options = CreateTableOptions {
148 target_table_name: name,
149 column_defs: columns.as_ref().map(Vec::as_slice),
150 if_not_exists: *if_not_exists,
151 source,
152 engine,
153 foreign_keys,
154 comment,
155 };
156
157 create_table(storage, options)
158 .await
159 .map(|_| Payload::Create)
160 }
161 Statement::DropTable {
162 names,
163 if_exists,
164 cascade,
165 ..
166 } => drop_table(storage, names, *if_exists, *cascade)
167 .await
168 .map(Payload::DropTable),
169 Statement::AlterTable { name, operation } => alter_table(storage, name, operation)
170 .await
171 .map(|_| Payload::AlterTable),
172 Statement::CreateIndex {
173 name,
174 table_name,
175 column,
176 } => create_index(storage, table_name, name, column)
177 .await
178 .map(|_| Payload::CreateIndex),
179 Statement::DropIndex { name, table_name } => storage
180 .drop_index(table_name, name)
181 .await
182 .map(|_| Payload::DropIndex),
183 Statement::StartTransaction => storage
185 .begin(false)
186 .await
187 .map(|_| Payload::StartTransaction),
188 Statement::Commit => storage.commit().await.map(|_| Payload::Commit),
189 Statement::Rollback => storage.rollback().await.map(|_| Payload::Rollback),
190 Statement::Insert {
192 table_name,
193 columns,
194 source,
195 } => insert(storage, table_name, columns, source)
196 .await
197 .map(Payload::Insert),
198 Statement::Update {
199 table_name,
200 selection,
201 assignments,
202 } => {
203 let Schema {
204 column_defs,
205 foreign_keys,
206 ..
207 } = storage
208 .fetch_schema(table_name)
209 .await?
210 .ok_or_else(|| ExecuteError::TableNotFound(table_name.to_owned()))?;
211
212 let all_columns = column_defs.as_deref().map(|columns| {
213 columns
214 .iter()
215 .map(|col_def| col_def.name.to_owned())
216 .collect()
217 });
218 let columns_to_update: Vec<String> = assignments
219 .iter()
220 .map(|assignment| assignment.id.to_owned())
221 .collect();
222
223 let update = Update::new(storage, table_name, assignments, column_defs.as_deref())?;
224
225 let foreign_keys = Arc::new(foreign_keys);
226
227 let rows = fetch(storage, table_name, all_columns, selection.as_ref())
228 .await?
229 .and_then(|item| {
230 let update = &update;
231 let (key, row) = item;
232
233 let foreign_keys = Arc::clone(&foreign_keys);
234 async move {
235 let row = update.apply(row, foreign_keys.as_ref()).await?;
236
237 Ok((key, row))
238 }
239 })
240 .try_collect::<Vec<(Key, Row)>>()
241 .await?;
242
243 if let Some(column_defs) = column_defs {
244 let column_validation =
245 ColumnValidation::SpecifiedColumns(&column_defs, columns_to_update);
246 let rows = rows.iter().filter_map(|(_, row)| match row {
247 Row::Vec { values, .. } => Some(values.as_slice()),
248 Row::Map(_) => None,
249 });
250
251 validate_unique(storage, table_name, column_validation, rows).await?;
252 }
253
254 let num_rows = rows.len();
255 let rows = rows
256 .into_iter()
257 .map(|(key, row)| (key, row.into()))
258 .collect();
259
260 storage
261 .insert_data(table_name, rows)
262 .await
263 .map(|_| Payload::Update(num_rows))
264 }
265 Statement::Delete {
266 table_name,
267 selection,
268 } => delete(storage, table_name, selection).await,
269
270 Statement::Query(query) => {
272 let (labels, rows) = select_with_labels(storage, query, None).await?;
273
274 match labels {
275 Some(labels) => rows
276 .map(|row| row?.try_into_vec())
277 .try_collect::<Vec<_>>()
278 .await
279 .map(|rows| Payload::Select { labels, rows }),
280 None => rows
281 .map(|row| row?.try_into_map())
282 .try_collect::<Vec<_>>()
283 .await
284 .map(Payload::SelectMap),
285 }
286 }
287 Statement::ShowColumns { table_name } => {
288 let Schema { column_defs, .. } = storage
289 .fetch_schema(table_name)
290 .await?
291 .ok_or_else(|| ExecuteError::TableNotFound(table_name.to_owned()))?;
292
293 let output: Vec<(String, DataType)> = column_defs
294 .unwrap_or_default()
295 .into_iter()
296 .map(|key| (key.name, key.data_type))
297 .collect();
298
299 Ok(Payload::ShowColumns(output))
300 }
301 Statement::ShowIndexes(table_name) => {
302 let query = Query {
303 body: SetExpr::Select(Box::new(crate::ast::Select {
304 distinct: false,
305 projection: vec![SelectItem::Wildcard],
306 from: TableWithJoins {
307 relation: TableFactor::Dictionary {
308 dict: Dictionary::GlueIndexes,
309 alias: TableAlias {
310 name: "GLUE_INDEXES".to_owned(),
311 columns: Vec::new(),
312 },
313 },
314 joins: Vec::new(),
315 },
316 selection: Some(Expr::BinaryOp {
317 left: Box::new(Expr::Identifier("TABLE_NAME".to_owned())),
318 op: BinaryOperator::Eq,
319 right: Box::new(Expr::Literal(AstLiteral::QuotedString(
320 table_name.to_owned(),
321 ))),
322 }),
323 group_by: Vec::new(),
324 having: None,
325 })),
326 order_by: Vec::new(),
327 limit: None,
328 offset: None,
329 };
330
331 let (labels, rows) = select_with_labels(storage, &query, None).await?;
332 let labels = labels.unwrap_or_default();
333 let rows = rows
334 .map(|row| row?.try_into_vec())
335 .try_collect::<Vec<_>>()
336 .await?;
337
338 if rows.is_empty() {
339 return Err(ExecuteError::TableNotFound(table_name.to_owned()).into());
340 }
341
342 Ok(Payload::Select { labels, rows })
343 }
344 Statement::ShowVariable(variable) => match variable {
345 Variable::Tables => {
346 let query = Query {
347 body: SetExpr::Select(Box::new(crate::ast::Select {
348 distinct: false,
349 projection: vec![SelectItem::Expr {
350 expr: Expr::Identifier("TABLE_NAME".to_owned()),
351 label: "TABLE_NAME".to_owned(),
352 }],
353 from: TableWithJoins {
354 relation: TableFactor::Dictionary {
355 dict: Dictionary::GlueTables,
356 alias: TableAlias {
357 name: "GLUE_TABLES".to_owned(),
358 columns: Vec::new(),
359 },
360 },
361 joins: Vec::new(),
362 },
363 selection: None,
364 group_by: Vec::new(),
365 having: None,
366 })),
367 order_by: Vec::new(),
368 limit: None,
369 offset: None,
370 };
371
372 let table_names = select(storage, &query, None)
373 .await?
374 .map(|row| row?.try_into_vec())
375 .try_collect::<Vec<Vec<Value>>>()
376 .await?
377 .iter()
378 .flat_map(|values| values.iter().map(|value| value.into()))
379 .collect::<Vec<_>>();
380
381 Ok(Payload::ShowVariable(PayloadVariable::Tables(table_names)))
382 }
383 Variable::Functions => {
384 let mut function_desc: Vec<_> = storage
385 .fetch_all_functions()
386 .await?
387 .iter()
388 .map(|f| f.to_str())
389 .collect();
390 function_desc.sort();
391 Ok(Payload::ShowVariable(PayloadVariable::Functions(
392 function_desc,
393 )))
394 }
395 Variable::Version => {
396 let version = var("CARGO_PKG_VERSION")
397 .unwrap_or_else(|_| env!("CARGO_PKG_VERSION").to_owned());
398 let payload = Payload::ShowVariable(PayloadVariable::Version(version));
399
400 Ok(payload)
401 }
402 },
403 Statement::CreateFunction {
404 or_replace,
405 name,
406 args,
407 return_,
408 } => insert_function(storage, name, args, *or_replace, return_)
409 .await
410 .map(|_| Payload::Create),
411 Statement::DropFunction { if_exists, names } => delete_function(storage, names, *if_exists)
412 .await
413 .map(|_| Payload::DropFunction),
414 }
415}