Skip to main content

qusql_type/
type_insert_replace.rs

1// Licensed under the Apache License, Version 2.0 (the "License");
2// you may not use this file except in compliance with the License.
3// You may obtain a copy of the License at
4//
5// http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13use alloc::{format, vec::Vec};
14use qusql_parse::{
15    Identifier, InsertReplace, InsertReplaceFlag, InsertReplaceSetPair, InsertReplaceType,
16    OptSpanned, Spanned, issue_todo,
17};
18
19use crate::{
20    BaseType, SelectTypeColumn, Type,
21    type_expression::{ExpressionFlags, type_expression},
22    type_select::{SelectType, type_select, type_select_exprs},
23    typer::{ReferenceType, Typer, did_you_mean, typer_stack, unqualified_name},
24};
25
/// Whether an insert statement yields an auto-increment id.
///
/// As chosen by `type_insert_replace`:
/// * `Yes` - a plain `INSERT` (no `IGNORE` flag, no `ON DUPLICATE KEY UPDATE`) into a
///   table that has an auto-increment column.
/// * `Optional` - such an `INSERT` that carries `IGNORE` or `ON DUPLICATE KEY UPDATE`.
/// * `No` - every other case (no auto-increment column, or a `REPLACE`).
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum AutoIncrementId {
    /// An auto-increment id is produced.
    Yes,
    /// No auto-increment id is produced.
    No,
    /// An auto-increment id may or may not be produced.
    Optional,
}
33
34pub(crate) fn type_insert_replace<'a>(
35    typer: &mut Typer<'a, '_>,
36    ior: &InsertReplace<'a>,
37) -> (AutoIncrementId, Option<SelectType<'a>>) {
38    let table = unqualified_name(typer.issues, &ior.table);
39    let columns = &ior.columns;
40
41    let (s, auto_increment) = if let Some(schema) = typer.schemas.schemas.get(table.value) {
42        if schema.view {
43            typer.err("Inserts into views not yet implemented", table);
44        }
45        let mut col_types = Vec::new();
46
47        for col in columns {
48            if let Some(schema_col) = schema.get_column(col.value) {
49                col_types.push((schema_col.type_.clone(), col.span()));
50            } else {
51                typer.err("No such column in schema", col);
52            }
53        }
54
55        if let Some(set) = &ior.set {
56            for col in &schema.columns {
57                if col.auto_increment
58                    || col.default
59                    || !col.type_.not_null
60                    || col.as_.is_some()
61                    || col.generated
62                    || set.pairs.iter().any(|v| v.column == col.identifier)
63                {
64                    continue;
65                }
66                typer.err(
67                    format!(
68                        "No value for column {} provided, but it has no default value",
69                        &col.identifier
70                    ),
71                    set,
72                );
73            }
74        } else {
75            for col in &schema.columns {
76                if col.auto_increment
77                    || col.default
78                    || !col.type_.not_null
79                    || col.as_.is_some()
80                    || col.generated
81                    || columns.contains(&col.identifier)
82                {
83                    continue;
84                }
85                typer.err(
86                    format!(
87                        "No value for column {} provided, but it has no default value",
88                        &col.identifier
89                    ),
90                    &columns.opt_span().unwrap_or(table.span()),
91                );
92            }
93        }
94
95        (
96            Some(col_types),
97            schema.columns.iter().any(|c| c.auto_increment),
98        )
99    } else {
100        typer.err("Unknown table", table);
101        (None, false)
102    };
103
104    if let Some(values) = &ior.values {
105        for row in &values.1 {
106            for (j, e) in row.iter().enumerate() {
107                if let Some((et, ets)) = s.as_ref().and_then(|v| v.get(j)) {
108                    let t = type_expression(typer, e, ExpressionFlags::default(), et.base());
109                    if typer.matched_type(&t, et).is_none() {
110                        typer
111                            .err(format!("Got type {}", t.t), e)
112                            .frag(format!("Expected {}", et.t), ets);
113                    } else if let Type::Args(_, args) = &t.t {
114                        for (idx, arg_type, _) in args.iter() {
115                            typer.constrain_arg(*idx, arg_type, et);
116                        }
117                    }
118                } else {
119                    type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
120                }
121            }
122            if let Some(s) = &s
123                && s.len() != row.len()
124            {
125                typer
126                    .err(
127                        format!("Got {} columns", row.len()),
128                        &row.opt_span().unwrap(),
129                    )
130                    .frag(
131                        format!("Expected {}", columns.len()),
132                        &columns.opt_span().unwrap_or(table.span()),
133                    );
134            }
135        }
136    }
137
138    if let Some(select_stmt) = &ior.select
139        && let qusql_parse::Statement::Select(select_inner) = select_stmt
140    {
141        let select = type_select(typer, select_inner, true);
142        if let Some(s) = s {
143            for i in 0..usize::max(s.len(), select.columns.len()) {
144                match (s.get(i), select.columns.get(i)) {
145                    (Some((et, ets)), Some(t)) => {
146                        if typer.matched_type(&t.type_, et).is_none() {
147                            typer
148                                .err(format!("Got type {}", t.type_.t), &t.span)
149                                .frag(format!("Expected {}", et), ets);
150                        }
151                    }
152                    (None, Some(t)) => {
153                        typer.err("Column in select not in insert", &t.span);
154                    }
155                    (Some((_, ets)), None) => {
156                        typer.err("Missing column in select", ets);
157                    }
158                    (None, None) => {
159                        panic!("ICE")
160                    }
161                }
162            }
163        }
164    }
165    // Compound queries (UNION/INTERSECT/EXCEPT) skip the column type check
166
167    let mut guard = typer_stack(
168        typer,
169        |t| core::mem::take(&mut t.reference_types),
170        |t, v| t.reference_types = v,
171    );
172    let typer = &mut guard.typer;
173
174    if let Some(s) = typer.schemas.schemas.get(table.value) {
175        let mut columns = Vec::new();
176        for c in &s.columns {
177            columns.push((c.identifier.clone(), c.type_.clone()));
178        }
179        for v in &typer.reference_types {
180            if v.name == Some(table.clone()) {
181                typer
182                    .issues
183                    .err("Duplicate definitions", table)
184                    .frag("Already defined here", &v.span);
185            }
186        }
187        typer.reference_types.push(ReferenceType {
188            name: Some(table.clone()),
189            span: table.span(),
190            columns,
191        });
192    }
193
194    if let Some(set) = &ior.set {
195        for InsertReplaceSetPair { column, value, .. } in &set.pairs {
196            let mut cnt = 0;
197            let mut t = None;
198            for r in &typer.reference_types {
199                for c in &r.columns {
200                    if c.0 == *column {
201                        cnt += 1;
202                        t = Some(c.clone());
203                    }
204                }
205            }
206            if cnt > 1 {
207                type_expression(typer, value, ExpressionFlags::default(), BaseType::Any);
208                let mut issue = typer.issues.err("Ambiguous reference", column);
209                for r in &typer.reference_types {
210                    for c in &r.columns {
211                        if c.0 == *column {
212                            issue.frag("Defined here", &r.span);
213                        }
214                    }
215                }
216            } else if let Some(t) = t {
217                let value_type =
218                    type_expression(typer, value, ExpressionFlags::default(), t.1.base());
219                if typer.matched_type(&value_type, &t.1).is_none() {
220                    typer.err(format!("Got type {} expected {}", value_type, t.1), value);
221                } else if let Type::Args(_, args) = &value_type.t {
222                    for (idx, arg_type, _) in args.iter() {
223                        typer.constrain_arg(*idx, arg_type, &t.1);
224                    }
225                }
226            } else {
227                type_expression(typer, value, ExpressionFlags::default(), BaseType::Any);
228                let suggestion = did_you_mean(
229                    column.value,
230                    typer
231                        .reference_types
232                        .iter()
233                        .flat_map(|r| r.columns.iter().map(|(id, _)| id.value)),
234                );
235                let mut issue = typer.err("Unknown identifier", column);
236                if let Some(s) = suggestion {
237                    issue.help(alloc::format!("did you mean `{s}`?"));
238                }
239            }
240        }
241    }
242
243    if let Some(up) = &ior.on_duplicate_key_update {
244        for InsertReplaceSetPair { value, column, .. } in &up.pairs {
245            let mut cnt = 0;
246            let mut t = None;
247            for r in &typer.reference_types {
248                for c in &r.columns {
249                    if c.0 == *column {
250                        cnt += 1;
251                        t = Some(c.clone());
252                    }
253                }
254            }
255            let flags = ExpressionFlags::default().with_in_on_duplicate_key_update(true);
256            if cnt > 1 {
257                type_expression(typer, value, flags, BaseType::Any);
258                let mut issue = typer.issues.err("Ambiguous reference", column);
259                for r in &typer.reference_types {
260                    for c in &r.columns {
261                        if c.0 == *column {
262                            issue.frag("Defined here", &r.span);
263                        }
264                    }
265                }
266            } else if let Some(t) = t {
267                let value_type = type_expression(typer, value, flags, t.1.base());
268                if typer.matched_type(&value_type, &t.1).is_none() {
269                    typer.err(format!("Got type {} expected {}", value_type, t.1), value);
270                } else if let Type::Args(_, args) = &value_type.t {
271                    for (idx, arg_type, _) in args.iter() {
272                        typer.constrain_arg(*idx, arg_type, &t.1);
273                    }
274                }
275            } else {
276                type_expression(typer, value, flags, BaseType::Any);
277                let suggestion = did_you_mean(
278                    column.value,
279                    typer
280                        .reference_types
281                        .iter()
282                        .flat_map(|r| r.columns.iter().map(|(id, _)| id.value)),
283                );
284                let mut issue = typer.err("Unknown identifier", column);
285                if let Some(s) = suggestion {
286                    issue.help(alloc::format!("did you mean `{s}`?"));
287                }
288            }
289        }
290    }
291
292    if let Some(on_conflict) = &ior.on_conflict {
293        match &on_conflict.target {
294            qusql_parse::OnConflictTarget::Columns { names } => {
295                for name in names {
296                    let mut t = None;
297                    for r in &typer.reference_types {
298                        for c in &r.columns {
299                            if c.0 == *name {
300                                t = Some(c.clone());
301                            }
302                        }
303                    }
304                    if t.is_none() {
305                        let suggestion = did_you_mean(
306                            name.value,
307                            typer
308                                .reference_types
309                                .iter()
310                                .flat_map(|r| r.columns.iter().map(|(id, _)| id.value)),
311                        );
312                        let mut issue = typer.err("Unknown identifier", name);
313                        if let Some(s) = suggestion {
314                            issue.help(alloc::format!("did you mean `{s}`?"));
315                        }
316                    }
317                }
318                //TODO check if there is a unique constraint on column
319            }
320            qusql_parse::OnConflictTarget::OnConstraint {
321                on_constraint_span, ..
322            } => {
323                issue_todo!(typer.issues, on_constraint_span);
324            }
325            qusql_parse::OnConflictTarget::None => (),
326        }
327
328        match &on_conflict.action {
329            qusql_parse::OnConflictAction::DoNothing(_) => (),
330            qusql_parse::OnConflictAction::DoUpdateSet {
331                sets,
332                where_,
333                do_update_set_span,
334            } => {
335                let mut excluded_columns = Vec::new();
336                if let Some(schema) = typer.schemas.schemas.get(table.value)
337                    && !schema.view
338                {
339                    for col in &ior.columns {
340                        if let Some(schema_col) = schema.get_column(col.value) {
341                            excluded_columns
342                                .push((schema_col.identifier.clone(), schema_col.type_.clone()));
343                        }
344                    }
345                }
346
347                typer.reference_types.push(ReferenceType {
348                    name: Some(Identifier::new("EXCLUDED", do_update_set_span.clone())),
349                    span: do_update_set_span.span(),
350                    columns: excluded_columns,
351                });
352
353                for (key, value) in sets {
354                    let mut cnt = 0;
355                    let mut t = None;
356                    for r in &typer.reference_types {
357                        if let Some(name) = &r.name
358                            && name.value == "EXCLUDED"
359                        {
360                            continue;
361                        }
362                        for c in &r.columns {
363                            if c.0 == *key {
364                                cnt += 1;
365                                t = Some(c.clone());
366                            }
367                        }
368                    }
369                    let flags = ExpressionFlags::default().with_in_on_duplicate_key_update(true);
370                    if cnt > 1 {
371                        type_expression(typer, value, flags, BaseType::Any);
372                        let mut issue = typer.issues.err("Ambiguous reference", key);
373                        for r in &typer.reference_types {
374                            for c in &r.columns {
375                                if c.0 == *key {
376                                    issue.frag("Defined here", &r.span);
377                                }
378                            }
379                        }
380                    } else if let Some(t) = t {
381                        let value_type = type_expression(typer, value, flags, t.1.base());
382                        if typer.matched_type(&value_type, &t.1).is_none() {
383                            typer.err(format!("Got type {} expected {}", value_type, t.1), value);
384                        } else if let Type::Args(_, args) = &value_type.t {
385                            for (idx, arg_type, _) in args.iter() {
386                                typer.constrain_arg(*idx, arg_type, &t.1);
387                            }
388                        }
389                    } else {
390                        type_expression(typer, value, flags, BaseType::Any);
391                        let suggestion = did_you_mean(
392                            key.value,
393                            typer
394                                .reference_types
395                                .iter()
396                                .flat_map(|r| r.columns.iter().map(|(id, _)| id.value)),
397                        );
398                        let mut issue = typer.err("Unknown identifier", key);
399                        if let Some(s) = suggestion {
400                            issue.help(alloc::format!("did you mean `{s}`?"));
401                        }
402                    }
403                }
404                if let Some((_, where_)) = where_ {
405                    type_expression(typer, where_, ExpressionFlags::default(), BaseType::Bool);
406                }
407            }
408        }
409    }
410
411    let returning_select = match &ior.returning {
412        Some((returning_span, returning_exprs)) => {
413            let columns = type_select_exprs(typer, returning_exprs, true)
414                .into_iter()
415                .map(|(name, type_, span)| SelectTypeColumn { name, type_, span })
416                .collect();
417            Some(SelectType {
418                columns,
419                select_span: returning_span.join_span(returning_exprs),
420            })
421        }
422        None => None,
423    };
424
425    core::mem::drop(guard);
426
427    let auto_increment_id = if auto_increment && matches!(ior.type_, InsertReplaceType::Insert(_)) {
428        if ior
429            .flags
430            .iter()
431            .any(|f| matches!(f, InsertReplaceFlag::Ignore(_)))
432            || ior.on_duplicate_key_update.is_some()
433        {
434            AutoIncrementId::Optional
435        } else {
436            AutoIncrementId::Yes
437        }
438    } else {
439        AutoIncrementId::No
440    };
441
442    (auto_increment_id, returning_select)
443}