1use {
2 super::{
3 select::select,
4 validate::{ColumnValidation, validate_unique},
5 },
6 crate::{
7 ast::{ColumnDef, ColumnUniqueOption, Expr, ForeignKey, Query, SetExpr, Values},
8 data::{Key, Row, Schema, Value},
9 executor::{evaluate::evaluate_stateless, limit::Limit},
10 result::Result,
11 store::{DataRow, GStore, GStoreMut},
12 },
13 futures::stream::{self, StreamExt, TryStreamExt},
14 serde::Serialize,
15 std::{fmt::Debug, sync::Arc},
16 thiserror::Error as ThisError,
17};
18
/// Errors raised while executing an `INSERT` statement.
#[derive(ThisError, Serialize, Debug, PartialEq, Eq)]
pub enum InsertError {
    /// `fetch_schema` returned no schema for the target table (payload: table name).
    #[error("table not found: {0}")]
    TableNotFound(String),

    /// A non-nullable column received neither an explicit value nor a `DEFAULT`
    /// (payload: column name).
    #[error("lack of required column: {0}")]
    LackOfRequiredColumn(String),

    /// A name in the explicit column list does not match any column definition
    /// (payload: the offending name).
    #[error("wrong column name: {0}")]
    WrongColumnName(String),

    /// The number of values does not match the number of target columns, e.g. an
    /// explicit column list paired with a different number of values.
    #[error("column and values not matched")]
    ColumnAndValuesNotMatched,

    /// A `VALUES` row has more entries than the table has columns.
    #[error("literals have more values than target columns")]
    TooManyValues,

    /// A row for a schemaless table (one without column definitions) did not
    /// consist of a single value.
    #[error("only single value accepted for schemaless row insert")]
    OnlySingleValueAcceptedForSchemalessRow,

    /// A value selected for a schemaless table was not a map (payload: the
    /// value converted to a string).
    #[error("map type required: {0}")]
    MapTypeValueRequired(String),

    /// A foreign-key value has no matching row in the referenced table. The
    /// value is rendered with `{:?}`, so it appears quoted in the message.
    #[error(
        "cannot find referenced value on {table_name}.{column_name} with value {referenced_value:?}"
    )]
    CannotFindReferencedValue {
        table_name: String,
        column_name: String,
        referenced_value: String,
    },

    /// A foreign key names a referencing column that is not among the table's
    /// column definitions, or a row has no value at that column's position.
    /// The message calls this "unreachable", i.e. an internal inconsistency
    /// rather than a user error.
    #[error("unreachable referencing column name: {0}")]
    ConflictReferencingColumnName(String),
}
54
/// Rows prepared for writing, tagged by the storage call that persists them.
enum RowsData {
    /// Rows for tables without a primary-key column (including schemaless
    /// tables); written with `append_data`.
    Append(Vec<DataRow>),
    /// Rows paired with the [`Key`] built from their primary-key column;
    /// written with `insert_data`.
    Insert(Vec<(Key, DataRow)>),
}
59
60pub async fn insert<T: GStore + GStoreMut>(
61 storage: &mut T,
62 table_name: &str,
63 columns: &[String],
64 source: &Query,
65) -> Result<usize> {
66 let Schema {
67 column_defs,
68 foreign_keys,
69 ..
70 } = storage
71 .fetch_schema(table_name)
72 .await?
73 .ok_or_else(|| InsertError::TableNotFound(table_name.to_owned()))?;
74
75 let rows = match column_defs {
76 Some(column_defs) => {
77 fetch_vec_rows(
78 storage,
79 table_name,
80 column_defs,
81 columns,
82 source,
83 foreign_keys,
84 )
85 .await
86 }
87 None => fetch_map_rows(storage, source).await.map(RowsData::Append),
88 }?;
89
90 match rows {
91 RowsData::Append(rows) => {
92 let num_rows = rows.len();
93
94 storage
95 .append_data(table_name, rows)
96 .await
97 .map(|_| num_rows)
98 }
99 RowsData::Insert(rows) => {
100 let num_rows = rows.len();
101
102 storage
103 .insert_data(table_name, rows)
104 .await
105 .map(|_| num_rows)
106 }
107 }
108}
109
110async fn fetch_vec_rows<T: GStore>(
111 storage: &T,
112 table_name: &str,
113 column_defs: Vec<ColumnDef>,
114 columns: &[String],
115 source: &Query,
116 foreign_keys: Vec<ForeignKey>,
117) -> Result<RowsData> {
118 let labels = Arc::from(
119 column_defs
120 .iter()
121 .map(|column_def| column_def.name.to_owned())
122 .collect::<Vec<_>>(),
123 );
124 let column_defs = Arc::from(column_defs);
125 let column_validation = ColumnValidation::All(&column_defs);
126
127 #[derive(futures_enum::Stream)]
128 enum Rows<I1, I2> {
129 Values(I1),
130 Select(I2),
131 }
132
133 let rows = match &source.body {
134 SetExpr::Values(Values(values_list)) => {
135 let limit = Limit::new(source.limit.as_ref(), source.offset.as_ref()).await?;
136 let rows = stream::iter(values_list).then(|values| {
137 let column_defs = Arc::clone(&column_defs);
138 let labels = Arc::clone(&labels);
139
140 async move {
141 Ok(Row::Vec {
142 columns: labels,
143 values: fill_values(&column_defs, columns, values).await?,
144 })
145 }
146 });
147 let rows = limit.apply(rows);
148 let rows = rows.map(|row| row?.try_into_vec());
149
150 Rows::Values(rows)
151 }
152 SetExpr::Select(_) => {
153 let rows = select(storage, source, None).await?.map(|row| {
154 let values = row?.try_into_vec()?;
155
156 column_defs
157 .iter()
158 .zip(values.iter())
159 .try_for_each(|(column_def, value)| {
160 let ColumnDef {
161 data_type,
162 nullable,
163 ..
164 } = column_def;
165
166 value.validate_type(data_type)?;
167 value.validate_null(*nullable)
168 })?;
169
170 Ok(values)
171 });
172
173 Rows::Select(rows)
174 }
175 }
176 .try_collect::<Vec<Vec<Value>>>()
177 .await?;
178
179 validate_unique(
180 storage,
181 table_name,
182 column_validation,
183 rows.iter().map(|values| values.as_slice()),
184 )
185 .await?;
186
187 validate_foreign_key(storage, &column_defs, foreign_keys, &rows).await?;
188
189 let primary_key = column_defs.iter().position(|ColumnDef { unique, .. }| {
190 unique == &Some(ColumnUniqueOption { is_primary: true })
191 });
192
193 match primary_key {
194 Some(i) => rows
195 .into_iter()
196 .filter_map(|values| {
197 values
198 .get(i)
199 .map(Key::try_from)
200 .map(|result| result.map(|key| (key, values.into())))
201 })
202 .collect::<Result<Vec<_>>>()
203 .map(RowsData::Insert),
204 None => Ok(RowsData::Append(rows.into_iter().map(Into::into).collect())),
205 }
206}
207
208async fn validate_foreign_key<T: GStore>(
209 storage: &T,
210 column_defs: &Arc<[ColumnDef]>,
211 foreign_keys: Vec<ForeignKey>,
212 rows: &[Vec<Value>],
213) -> Result<()> {
214 for foreign_key in foreign_keys {
215 let ForeignKey {
216 referencing_column_name,
217 referenced_table_name,
218 referenced_column_name,
219 ..
220 } = &foreign_key;
221
222 let target_index = column_defs
223 .iter()
224 .enumerate()
225 .find(|(_, c)| &c.name == referencing_column_name)
226 .ok_or_else(|| {
227 InsertError::ConflictReferencingColumnName(referencing_column_name.to_owned())
228 })?;
229
230 for row in rows.iter() {
231 let value =
232 row.get(target_index.0)
233 .ok_or(InsertError::ConflictReferencingColumnName(
234 referencing_column_name.to_owned(),
235 ))?;
236
237 if value == &Value::Null {
238 continue;
239 }
240
241 let no_referenced = storage
242 .fetch_data(referenced_table_name, &Key::try_from(value)?)
243 .await?
244 .is_none();
245
246 if no_referenced {
247 return Err(InsertError::CannotFindReferencedValue {
248 table_name: referenced_table_name.to_owned(),
249 column_name: referenced_column_name.to_owned(),
250 referenced_value: String::from(value),
251 }
252 .into());
253 }
254 }
255 }
256
257 Ok(())
258}
259
260async fn fetch_map_rows<T: GStore>(storage: &T, source: &Query) -> Result<Vec<DataRow>> {
261 #[derive(futures_enum::Stream)]
262 enum Rows<I1, I2> {
263 Values(I1),
264 Select(I2),
265 }
266
267 let rows = match &source.body {
268 SetExpr::Values(Values(values_list)) => {
269 let limit = Limit::new(source.limit.as_ref(), source.offset.as_ref()).await?;
270 let rows = stream::iter(values_list).then(|values| async move {
271 if values.len() > 1 {
272 return Err(InsertError::OnlySingleValueAcceptedForSchemalessRow.into());
273 }
274
275 evaluate_stateless(None, &values[0])
276 .await?
277 .try_into()
278 .map(Row::Map)
279 });
280 let rows = limit.apply(rows);
281 let rows = rows.map_ok(Into::into);
282
283 Rows::Values(rows)
284 }
285 SetExpr::Select(_) => {
286 let rows = select(storage, source, None).await?.map(|row| {
287 let row = row?;
288
289 if let Row::Vec { values, .. } = &row {
290 if values.len() > 1 {
291 return Err(InsertError::OnlySingleValueAcceptedForSchemalessRow.into());
292 } else if !matches!(&values[0], Value::Map(_)) {
293 return Err(InsertError::MapTypeValueRequired((&values[0]).into()).into());
294 }
295 }
296
297 Ok(row.into())
298 });
299
300 Rows::Select(rows)
301 }
302 }
303 .try_collect::<Vec<DataRow>>()
304 .await?;
305
306 Ok(rows)
307}
308
309async fn fill_values(
310 column_defs: &[ColumnDef],
311 columns: &[String],
312 values: &[Expr],
313) -> Result<Vec<Value>> {
314 if !columns.is_empty() && values.len() != columns.len() {
315 return Err(InsertError::ColumnAndValuesNotMatched.into());
316 } else if values.len() > column_defs.len() {
317 return Err(InsertError::TooManyValues.into());
318 }
319
320 if let Some(wrong_column_name) = columns.iter().find(|column_name| {
321 !column_defs
322 .iter()
323 .any(|column_def| &&column_def.name == column_name)
324 }) {
325 return Err(InsertError::WrongColumnName(wrong_column_name.to_owned()).into());
326 }
327
328 #[derive(iter_enum::Iterator)]
329 enum Columns<I1, I2> {
330 All(I1),
331 Specified(I2),
332 }
333
334 let columns = if columns.is_empty() {
335 Columns::All(column_defs.iter().map(|ColumnDef { name, .. }| name))
336 } else {
337 Columns::Specified(columns.iter())
338 };
339
340 let column_name_value_list = columns.zip(values.iter()).collect::<Vec<(_, _)>>();
341
342 let values = stream::iter(column_defs)
343 .then(|column_def| {
344 let column_name_value_list = &column_name_value_list;
345
346 async move {
347 let ColumnDef {
348 name: def_name,
349 data_type,
350 nullable,
351 ..
352 } = column_def;
353
354 let value = column_name_value_list
355 .iter()
356 .find(|(name, _)| name == &def_name)
357 .map(|(_, value)| value);
358
359 match (value, &column_def.default, nullable) {
360 (Some(&expr), _, _) | (None, Some(expr), _) => evaluate_stateless(None, expr)
361 .await?
362 .try_into_value(data_type, *nullable),
363 (None, None, true) => Ok(Value::Null),
364 (None, None, false) => {
365 Err(InsertError::LackOfRequiredColumn(def_name.to_owned()).into())
366 }
367 }
368 }
369 })
370 .try_collect::<Vec<Value>>()
371 .await?;
372
373 Ok(values)
374}