1use {
2 super::{
3 alter::{
4 CreateTableOptions, alter_table, create_index, create_table, delete_function,
5 drop_table, insert_function,
6 },
7 delete::delete,
8 fetch::fetch,
9 insert::insert,
10 select::{select, select_with_labels},
11 update::Update,
12 validate::{ColumnValidation, validate_unique},
13 },
14 crate::{
15 ast::{
16 BinaryOperator, DataType, Dictionary, Expr, Literal, Query, SelectItem, SetExpr,
17 Statement, TableAlias, TableFactor, TableWithJoins, Variable,
18 },
19 data::{Key, Row, Schema, Value},
20 result::Result,
21 store::{GStore, GStoreMut},
22 },
23 futures::stream::{StreamExt, TryStreamExt},
24 serde::{Deserialize, Serialize},
25 std::{
26 collections::{BTreeMap, HashMap},
27 env::var,
28 fmt::Debug,
29 sync::Arc,
30 },
31 thiserror::Error as ThisError,
32};
33
34#[derive(ThisError, Serialize, Debug, PartialEq, Eq)]
35pub enum ExecuteError {
36 #[error("table not found: {0}")]
37 TableNotFound(String),
38}
39
40#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
41pub enum Payload {
42 ShowColumns(Vec<(String, DataType)>),
43 Create,
44 Insert(usize),
45 Select {
46 labels: Vec<String>,
47 rows: Vec<Vec<Value>>,
48 },
49 SelectMap(Vec<BTreeMap<String, Value>>),
50 Delete(usize),
51 Update(usize),
52 DropTable(usize),
53 DropFunction,
54 AlterTable,
55 CreateIndex,
56 DropIndex,
57 StartTransaction,
58 Commit,
59 Rollback,
60 ShowVariable(PayloadVariable),
61}
62
63impl Payload {
64 pub fn select(&self) -> Option<impl Iterator<Item = HashMap<&str, &Value>>> {
72 #[derive(iter_enum::Iterator)]
73 enum Iter<I1, I2> {
74 Schema(I1),
75 Schemaless(I2),
76 }
77
78 Some(match self {
79 Payload::Select { labels, rows } => Iter::Schema(rows.iter().map(move |row| {
80 labels
81 .iter()
82 .zip(row.iter())
83 .map(|(label, value)| (label.as_str(), value))
84 .collect::<HashMap<_, _>>()
85 })),
86 Payload::SelectMap(rows) => Iter::Schemaless(rows.iter().map(|row| {
87 row.iter()
88 .map(|(k, v)| (k.as_str(), v))
89 .collect::<HashMap<_, _>>()
90 })),
91 _ => return None,
92 })
93 }
94}
95
96#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
97pub enum PayloadVariable {
98 Tables(Vec<String>),
99 Functions(Vec<String>),
100 Version(String),
101}
102
103pub async fn execute<T: GStore + GStoreMut>(
104 storage: &mut T,
105 statement: &Statement,
106) -> Result<Payload> {
107 if matches!(
108 statement,
109 Statement::StartTransaction | Statement::Rollback | Statement::Commit
110 ) {
111 return execute_inner(storage, statement).await;
112 }
113
114 let autocommit = storage.begin(true).await?;
115 let result = execute_inner(storage, statement).await;
116
117 if !autocommit {
118 return result;
119 }
120
121 match result {
122 Ok(payload) => storage.commit().await.map(|()| payload),
123 Err(error) => {
124 storage.rollback().await?;
125
126 Err(error)
127 }
128 }
129}
130
/// Executes one statement against `storage` without any implicit transaction
/// handling (`execute` adds the autocommit wrapper around this).
///
/// Each `Statement` variant is dispatched to its helper or storage call, and
/// the result is mapped to the matching [`Payload`] variant.
///
/// # Errors
///
/// Propagates errors from storage and from the helpers. Returns
/// [`ExecuteError::TableNotFound`] in three cases:
/// - `Update` or `ShowColumns` target a table with no schema.
/// - `ShowIndexes` finds no index rows.
async fn execute_inner<T: GStore + GStoreMut>(
    storage: &mut T,
    statement: &Statement,
) -> Result<Payload> {
    match statement {
        // ---- Schema modification: tables ----
        Statement::CreateTable {
            name,
            columns,
            if_not_exists,
            source,
            engine,
            foreign_keys,
            comment,
        } => {
            // Bundle the borrowed statement fields for `create_table`.
            let options = CreateTableOptions {
                target_table_name: name,
                column_defs: columns.as_ref().map(Vec::as_slice),
                if_not_exists: *if_not_exists,
                source,
                engine,
                foreign_keys,
                comment,
            };

            create_table(storage, options)
                .await
                .map(|()| Payload::Create)
        }
        Statement::DropTable {
            names,
            if_exists,
            cascade,
            ..
        } => drop_table(storage, names, *if_exists, *cascade)
            .await
            .map(Payload::DropTable),
        Statement::AlterTable { name, operation } => alter_table(storage, name, operation)
            .await
            .map(|()| Payload::AlterTable),
        // ---- Schema modification: indexes ----
        Statement::CreateIndex {
            name,
            table_name,
            column,
        } => create_index(storage, table_name, name, column)
            .await
            .map(|()| Payload::CreateIndex),
        Statement::DropIndex { name, table_name } => storage
            .drop_index(table_name, name)
            .await
            .map(|()| Payload::DropIndex),
        // ---- Transaction control ----
        // Explicit transaction start: `begin(false)`; the flag it returns is ignored here.
        Statement::StartTransaction => storage
            .begin(false)
            .await
            .map(|_| Payload::StartTransaction),
        Statement::Commit => storage.commit().await.map(|()| Payload::Commit),
        Statement::Rollback => storage.rollback().await.map(|()| Payload::Rollback),
        // ---- Data modification ----
        Statement::Insert {
            table_name,
            columns,
            source,
        } => insert(storage, table_name, columns, source)
            .await
            .map(Payload::Insert),
        Statement::Update {
            table_name,
            selection,
            assignments,
        } => {
            // The target table must exist; its column defs and foreign keys drive the update.
            let Schema {
                column_defs,
                foreign_keys,
                ..
            } = storage
                .fetch_schema(table_name)
                .await?
                .ok_or_else(|| ExecuteError::TableNotFound(table_name.to_owned()))?;

            // Column names to fetch; `None` when the schema declares no column defs.
            let all_columns = column_defs
                .as_deref()
                .map(|columns| columns.iter().map(|col_def| col_def.name.clone()).collect());
            // Only the assigned columns are re-checked for uniqueness below.
            let columns_to_update: Vec<String> = assignments
                .iter()
                .map(|assignment| assignment.id.clone())
                .collect();

            let update = Update::new(storage, table_name, assignments, column_defs.as_deref())?;

            // Shared by every per-row future created in `and_then` below; each
            // `async move` block takes its own `Arc` handle.
            let foreign_keys = Arc::new(foreign_keys);

            // Fetch the rows matching `selection` and apply the assignments to each;
            // the first failing row aborts the whole update.
            let rows = fetch(storage, table_name, all_columns, selection.as_ref())
                .await?
                .and_then(|item| {
                    let update = &update;
                    let (key, row) = item;

                    let foreign_keys = Arc::clone(&foreign_keys);
                    async move {
                        let row = update.apply(row, foreign_keys.as_ref()).await?;

                        Ok((key, row))
                    }
                })
                .try_collect::<Vec<(Key, Row)>>()
                .await?;

            // Uniqueness is only validated when column defs exist, and only for
            // `Row::Vec` rows; `Row::Map` rows are skipped.
            if let Some(column_defs) = column_defs {
                let column_validation =
                    ColumnValidation::SpecifiedColumns(&column_defs, columns_to_update);
                let rows = rows.iter().filter_map(|(_, row)| match row {
                    Row::Vec { values, .. } => Some(values.as_slice()),
                    Row::Map(_) => None,
                });

                validate_unique(storage, table_name, column_validation, rows).await?;
            }

            // Write the updated rows back under their original keys
            // (presumably `insert_data` overwrites existing entries — confirm).
            let num_rows = rows.len();
            let rows = rows
                .into_iter()
                .map(|(key, row)| (key, row.into()))
                .collect();

            storage
                .insert_data(table_name, rows)
                .await
                .map(|()| Payload::Update(num_rows))
        }
        Statement::Delete {
            table_name,
            selection,
        } => delete(storage, table_name, selection.as_ref()).await,

        // ---- Queries ----
        Statement::Query(query) => {
            // With a known label list, rows become positional vectors
            // (`Payload::Select`). Otherwise they become name -> value maps
            // (`Payload::SelectMap`). The first row that fails to convert
            // aborts the collection.
            let (labels, rows) = select_with_labels(storage, query, None).await?;

            match labels {
                Some(labels) => rows
                    .map(|row| row?.try_into_vec())
                    .try_collect::<Vec<_>>()
                    .await
                    .map(|rows| Payload::Select { labels, rows }),
                None => rows
                    .map(|row| row?.try_into_map())
                    .try_collect::<Vec<_>>()
                    .await
                    .map(Payload::SelectMap),
            }
        }
        Statement::ShowColumns { table_name } => {
            let Schema { column_defs, .. } = storage
                .fetch_schema(table_name)
                .await?
                .ok_or_else(|| ExecuteError::TableNotFound(table_name.to_owned()))?;

            // A schema without column defs yields an empty list rather than an error.
            let output: Vec<(String, DataType)> = column_defs
                .unwrap_or_default()
                .into_iter()
                .map(|key| (key.name, key.data_type))
                .collect();

            Ok(Payload::ShowColumns(output))
        }
        Statement::ShowIndexes(table_name) => {
            // Equivalent to: SELECT * FROM GLUE_INDEXES WHERE TABLE_NAME = '<table_name>'
            let query = Query {
                body: SetExpr::Select(Box::new(crate::ast::Select {
                    distinct: false,
                    projection: vec![SelectItem::Wildcard],
                    from: TableWithJoins {
                        relation: TableFactor::Dictionary {
                            dict: Dictionary::GlueIndexes,
                            alias: TableAlias {
                                name: "GLUE_INDEXES".to_owned(),
                                columns: Vec::new(),
                            },
                        },
                        joins: Vec::new(),
                    },
                    selection: Some(Expr::BinaryOp {
                        left: Box::new(Expr::Identifier("TABLE_NAME".to_owned())),
                        op: BinaryOperator::Eq,
                        right: Box::new(Expr::Literal(Literal::QuotedString(
                            table_name.to_owned(),
                        ))),
                    }),
                    group_by: Vec::new(),
                    having: None,
                })),
                order_by: Vec::new(),
                limit: None,
                offset: None,
            };

            let (labels, rows) = select_with_labels(storage, &query, None).await?;
            let labels = labels.unwrap_or_default();
            let rows = rows
                .map(|row| row?.try_into_vec())
                .try_collect::<Vec<_>>()
                .await?;

            // NOTE(review): an empty result is reported as TableNotFound, so an
            // existing table that simply has no indexes also takes this path.
            if rows.is_empty() {
                return Err(ExecuteError::TableNotFound(table_name.to_owned()).into());
            }

            Ok(Payload::Select { labels, rows })
        }
        Statement::ShowVariable(variable) => match variable {
            Variable::Tables => {
                // Equivalent to: SELECT TABLE_NAME FROM GLUE_TABLES
                let query = Query {
                    body: SetExpr::Select(Box::new(crate::ast::Select {
                        distinct: false,
                        projection: vec![SelectItem::Expr {
                            expr: Expr::Identifier("TABLE_NAME".to_owned()),
                            label: "TABLE_NAME".to_owned(),
                        }],
                        from: TableWithJoins {
                            relation: TableFactor::Dictionary {
                                dict: Dictionary::GlueTables,
                                alias: TableAlias {
                                    name: "GLUE_TABLES".to_owned(),
                                    columns: Vec::new(),
                                },
                            },
                            joins: Vec::new(),
                        },
                        selection: None,
                        group_by: Vec::new(),
                        having: None,
                    })),
                    order_by: Vec::new(),
                    limit: None,
                    offset: None,
                };

                // Flatten the single-column rows into a list of names; each `&Value`
                // is converted with `Into::into` (relies on a conversion defined elsewhere).
                let table_names = select(storage, &query, None)
                    .await?
                    .map(|row| row?.try_into_vec())
                    .try_collect::<Vec<Vec<Value>>>()
                    .await?
                    .iter()
                    .flat_map(|values| values.iter().map(Into::into))
                    .collect::<Vec<_>>();

                Ok(Payload::ShowVariable(PayloadVariable::Tables(table_names)))
            }
            Variable::Functions => {
                // Sorted so the reported order does not depend on storage iteration order.
                let mut function_desc: Vec<_> = storage
                    .fetch_all_functions()
                    .await?
                    .iter()
                    .map(|f| f.to_str())
                    .collect();
                function_desc.sort();
                Ok(Payload::ShowVariable(PayloadVariable::Functions(
                    function_desc,
                )))
            }
            Variable::Version => {
                // A runtime CARGO_PKG_VERSION env var, if set, overrides the
                // version baked in at compile time.
                let version = var("CARGO_PKG_VERSION")
                    .unwrap_or_else(|_| env!("CARGO_PKG_VERSION").to_owned());
                let payload = Payload::ShowVariable(PayloadVariable::Version(version));

                Ok(payload)
            }
        },
        // ---- Functions ----
        Statement::CreateFunction {
            or_replace,
            name,
            args,
            return_,
        } => insert_function(storage, name, args, *or_replace, return_)
            .await
            .map(|()| Payload::Create),
        Statement::DropFunction { if_exists, names } => delete_function(storage, names, *if_exists)
            .await
            .map(|()| Payload::DropFunction),
    }
}