gluesql_core/executor/
insert.rs

1use {
2    super::{
3        select::select,
4        validate::{ColumnValidation, validate_unique},
5    },
6    crate::{
7        ast::{ColumnDef, ColumnUniqueOption, Expr, ForeignKey, Query, SetExpr, Values},
8        data::{Key, Row, Schema, Value},
9        executor::{evaluate::evaluate_stateless, limit::Limit},
10        result::Result,
11        store::{DataRow, GStore, GStoreMut},
12    },
13    futures::stream::{self, StreamExt, TryStreamExt},
14    serde::Serialize,
15    std::{fmt::Debug, sync::Arc},
16    thiserror::Error as ThisError,
17};
18
19#[derive(ThisError, Serialize, Debug, PartialEq, Eq)]
20pub enum InsertError {
21    #[error("table not found: {0}")]
22    TableNotFound(String),
23
24    #[error("lack of required column: {0}")]
25    LackOfRequiredColumn(String),
26
27    #[error("wrong column name: {0}")]
28    WrongColumnName(String),
29
30    #[error("column and values not matched")]
31    ColumnAndValuesNotMatched,
32
33    #[error("literals have more values than target columns")]
34    TooManyValues,
35
36    #[error("only single value accepted for schemaless row insert")]
37    OnlySingleValueAcceptedForSchemalessRow,
38
39    #[error("map type required: {0}")]
40    MapTypeValueRequired(String),
41
42    #[error(
43        "cannot find referenced value on {table_name}.{column_name} with value {referenced_value:?}"
44    )]
45    CannotFindReferencedValue {
46        table_name: String,
47        column_name: String,
48        referenced_value: String,
49    },
50
51    #[error("unreachable referencing column name: {0}")]
52    ConflictReferencingColumnName(String),
53}
54
55enum RowsData {
56    Append(Vec<DataRow>),
57    Insert(Vec<(Key, DataRow)>),
58}
59
60pub async fn insert<T: GStore + GStoreMut>(
61    storage: &mut T,
62    table_name: &str,
63    columns: &[String],
64    source: &Query,
65) -> Result<usize> {
66    let Schema {
67        column_defs,
68        foreign_keys,
69        ..
70    } = storage
71        .fetch_schema(table_name)
72        .await?
73        .ok_or_else(|| InsertError::TableNotFound(table_name.to_owned()))?;
74
75    let rows = match column_defs {
76        Some(column_defs) => {
77            fetch_vec_rows(
78                storage,
79                table_name,
80                column_defs,
81                columns,
82                source,
83                foreign_keys,
84            )
85            .await
86        }
87        None => fetch_map_rows(storage, source).await.map(RowsData::Append),
88    }?;
89
90    match rows {
91        RowsData::Append(rows) => {
92            let num_rows = rows.len();
93
94            storage
95                .append_data(table_name, rows)
96                .await
97                .map(|_| num_rows)
98        }
99        RowsData::Insert(rows) => {
100            let num_rows = rows.len();
101
102            storage
103                .insert_data(table_name, rows)
104                .await
105                .map(|_| num_rows)
106        }
107    }
108}
109
110async fn fetch_vec_rows<T: GStore>(
111    storage: &T,
112    table_name: &str,
113    column_defs: Vec<ColumnDef>,
114    columns: &[String],
115    source: &Query,
116    foreign_keys: Vec<ForeignKey>,
117) -> Result<RowsData> {
118    let labels = Arc::from(
119        column_defs
120            .iter()
121            .map(|column_def| column_def.name.to_owned())
122            .collect::<Vec<_>>(),
123    );
124    let column_defs = Arc::from(column_defs);
125    let column_validation = ColumnValidation::All(&column_defs);
126
127    #[derive(futures_enum::Stream)]
128    enum Rows<I1, I2> {
129        Values(I1),
130        Select(I2),
131    }
132
133    let rows = match &source.body {
134        SetExpr::Values(Values(values_list)) => {
135            let limit = Limit::new(source.limit.as_ref(), source.offset.as_ref()).await?;
136            let rows = stream::iter(values_list).then(|values| {
137                let column_defs = Arc::clone(&column_defs);
138                let labels = Arc::clone(&labels);
139
140                async move {
141                    Ok(Row::Vec {
142                        columns: labels,
143                        values: fill_values(&column_defs, columns, values).await?,
144                    })
145                }
146            });
147            let rows = limit.apply(rows);
148            let rows = rows.map(|row| row?.try_into_vec());
149
150            Rows::Values(rows)
151        }
152        SetExpr::Select(_) => {
153            let rows = select(storage, source, None).await?.map(|row| {
154                let values = row?.try_into_vec()?;
155
156                column_defs
157                    .iter()
158                    .zip(values.iter())
159                    .try_for_each(|(column_def, value)| {
160                        let ColumnDef {
161                            data_type,
162                            nullable,
163                            ..
164                        } = column_def;
165
166                        value.validate_type(data_type)?;
167                        value.validate_null(*nullable)
168                    })?;
169
170                Ok(values)
171            });
172
173            Rows::Select(rows)
174        }
175    }
176    .try_collect::<Vec<Vec<Value>>>()
177    .await?;
178
179    validate_unique(
180        storage,
181        table_name,
182        column_validation,
183        rows.iter().map(|values| values.as_slice()),
184    )
185    .await?;
186
187    validate_foreign_key(storage, &column_defs, foreign_keys, &rows).await?;
188
189    let primary_key = column_defs.iter().position(|ColumnDef { unique, .. }| {
190        unique == &Some(ColumnUniqueOption { is_primary: true })
191    });
192
193    match primary_key {
194        Some(i) => rows
195            .into_iter()
196            .filter_map(|values| {
197                values
198                    .get(i)
199                    .map(Key::try_from)
200                    .map(|result| result.map(|key| (key, values.into())))
201            })
202            .collect::<Result<Vec<_>>>()
203            .map(RowsData::Insert),
204        None => Ok(RowsData::Append(rows.into_iter().map(Into::into).collect())),
205    }
206}
207
208async fn validate_foreign_key<T: GStore>(
209    storage: &T,
210    column_defs: &Arc<[ColumnDef]>,
211    foreign_keys: Vec<ForeignKey>,
212    rows: &[Vec<Value>],
213) -> Result<()> {
214    for foreign_key in foreign_keys {
215        let ForeignKey {
216            referencing_column_name,
217            referenced_table_name,
218            referenced_column_name,
219            ..
220        } = &foreign_key;
221
222        let target_index = column_defs
223            .iter()
224            .enumerate()
225            .find(|(_, c)| &c.name == referencing_column_name)
226            .ok_or_else(|| {
227                InsertError::ConflictReferencingColumnName(referencing_column_name.to_owned())
228            })?;
229
230        for row in rows.iter() {
231            let value =
232                row.get(target_index.0)
233                    .ok_or(InsertError::ConflictReferencingColumnName(
234                        referencing_column_name.to_owned(),
235                    ))?;
236
237            if value == &Value::Null {
238                continue;
239            }
240
241            let no_referenced = storage
242                .fetch_data(referenced_table_name, &Key::try_from(value)?)
243                .await?
244                .is_none();
245
246            if no_referenced {
247                return Err(InsertError::CannotFindReferencedValue {
248                    table_name: referenced_table_name.to_owned(),
249                    column_name: referenced_column_name.to_owned(),
250                    referenced_value: String::from(value),
251                }
252                .into());
253            }
254        }
255    }
256
257    Ok(())
258}
259
260async fn fetch_map_rows<T: GStore>(storage: &T, source: &Query) -> Result<Vec<DataRow>> {
261    #[derive(futures_enum::Stream)]
262    enum Rows<I1, I2> {
263        Values(I1),
264        Select(I2),
265    }
266
267    let rows = match &source.body {
268        SetExpr::Values(Values(values_list)) => {
269            let limit = Limit::new(source.limit.as_ref(), source.offset.as_ref()).await?;
270            let rows = stream::iter(values_list).then(|values| async move {
271                if values.len() > 1 {
272                    return Err(InsertError::OnlySingleValueAcceptedForSchemalessRow.into());
273                }
274
275                evaluate_stateless(None, &values[0])
276                    .await?
277                    .try_into()
278                    .map(Row::Map)
279            });
280            let rows = limit.apply(rows);
281            let rows = rows.map_ok(Into::into);
282
283            Rows::Values(rows)
284        }
285        SetExpr::Select(_) => {
286            let rows = select(storage, source, None).await?.map(|row| {
287                let row = row?;
288
289                if let Row::Vec { values, .. } = &row {
290                    if values.len() > 1 {
291                        return Err(InsertError::OnlySingleValueAcceptedForSchemalessRow.into());
292                    } else if !matches!(&values[0], Value::Map(_)) {
293                        return Err(InsertError::MapTypeValueRequired((&values[0]).into()).into());
294                    }
295                }
296
297                Ok(row.into())
298            });
299
300            Rows::Select(rows)
301        }
302    }
303    .try_collect::<Vec<DataRow>>()
304    .await?;
305
306    Ok(rows)
307}
308
309async fn fill_values(
310    column_defs: &[ColumnDef],
311    columns: &[String],
312    values: &[Expr],
313) -> Result<Vec<Value>> {
314    if !columns.is_empty() && values.len() != columns.len() {
315        return Err(InsertError::ColumnAndValuesNotMatched.into());
316    } else if values.len() > column_defs.len() {
317        return Err(InsertError::TooManyValues.into());
318    }
319
320    if let Some(wrong_column_name) = columns.iter().find(|column_name| {
321        !column_defs
322            .iter()
323            .any(|column_def| &&column_def.name == column_name)
324    }) {
325        return Err(InsertError::WrongColumnName(wrong_column_name.to_owned()).into());
326    }
327
328    #[derive(iter_enum::Iterator)]
329    enum Columns<I1, I2> {
330        All(I1),
331        Specified(I2),
332    }
333
334    let columns = if columns.is_empty() {
335        Columns::All(column_defs.iter().map(|ColumnDef { name, .. }| name))
336    } else {
337        Columns::Specified(columns.iter())
338    };
339
340    let column_name_value_list = columns.zip(values.iter()).collect::<Vec<(_, _)>>();
341
342    let values = stream::iter(column_defs)
343        .then(|column_def| {
344            let column_name_value_list = &column_name_value_list;
345
346            async move {
347                let ColumnDef {
348                    name: def_name,
349                    data_type,
350                    nullable,
351                    ..
352                } = column_def;
353
354                let value = column_name_value_list
355                    .iter()
356                    .find(|(name, _)| name == &def_name)
357                    .map(|(_, value)| value);
358
359                match (value, &column_def.default, nullable) {
360                    (Some(&expr), _, _) | (None, Some(expr), _) => evaluate_stateless(None, expr)
361                        .await?
362                        .try_into_value(data_type, *nullable),
363                    (None, None, true) => Ok(Value::Null),
364                    (None, None, false) => {
365                        Err(InsertError::LackOfRequiredColumn(def_name.to_owned()).into())
366                    }
367                }
368            }
369        })
370        .try_collect::<Vec<Value>>()
371        .await?;
372
373    Ok(values)
374}