Skip to main content

qusql_type/
type_insert_replace.rs

1// Licensed under the Apache License, Version 2.0 (the "License");
2// you may not use this file except in compliance with the License.
3// You may obtain a copy of the License at
4//
5// http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13use alloc::{format, vec::Vec};
14use qusql_parse::{
15    Identifier, InsertReplace, InsertReplaceFlag, InsertReplaceSetPair, InsertReplaceType,
16    OptSpanned, Spanned, issue_todo,
17};
18
19use crate::{
20    BaseType, SelectTypeColumn, Type,
21    type_expression::{ExpressionFlags, type_expression},
22    type_select::{SelectType, type_select, type_select_exprs},
23    typer::{ReferenceType, Typer, typer_stack, unqualified_name},
24};
25
/// Whether an insert statement yields an auto increment id.
///
/// Only `INSERT` statements into tables with an auto increment column can
/// yield an id; `REPLACE` statements and tables without such a column get
/// [`AutoIncrementId::No`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AutoIncrementId {
    /// The statement yields an auto increment id.
    Yes,
    /// The statement does not yield an auto increment id.
    No,
    /// The statement may or may not yield an auto increment id; this is the
    /// case for `INSERT IGNORE` and `INSERT ... ON DUPLICATE KEY UPDATE`.
    Optional,
}
33
34pub(crate) fn type_insert_replace<'a>(
35    typer: &mut Typer<'a, '_>,
36    ior: &InsertReplace<'a>,
37) -> (AutoIncrementId, Option<SelectType<'a>>) {
38    let table = unqualified_name(typer.issues, &ior.table);
39    let columns = &ior.columns;
40
41    let (s, auto_increment) = if let Some(schema) = typer.schemas.schemas.get(table.value) {
42        if schema.view {
43            typer.err("Inserts into views not yet implemented", table);
44        }
45        let mut col_types = Vec::new();
46
47        for col in columns {
48            if let Some(schema_col) = schema.get_column(col.value) {
49                col_types.push((schema_col.type_.clone(), col.span()));
50            } else {
51                typer.err("No such column in schema", col);
52            }
53        }
54
55        if let Some(set) = &ior.set {
56            for col in &schema.columns {
57                if col.auto_increment
58                    || col.default
59                    || !col.type_.not_null
60                    || col.as_.is_some()
61                    || col.generated
62                    || set.pairs.iter().any(|v| v.column == col.identifier)
63                {
64                    continue;
65                }
66                typer.err(
67                    format!(
68                        "No value for column {} provided, but it has no default value",
69                        &col.identifier
70                    ),
71                    set,
72                );
73            }
74        } else {
75            for col in &schema.columns {
76                if col.auto_increment
77                    || col.default
78                    || !col.type_.not_null
79                    || col.as_.is_some()
80                    || col.generated
81                    || columns.contains(&col.identifier)
82                {
83                    continue;
84                }
85                typer.err(
86                    format!(
87                        "No value for column {} provided, but it has no default value",
88                        &col.identifier
89                    ),
90                    &columns.opt_span().unwrap_or(table.span()),
91                );
92            }
93        }
94
95        (
96            Some(col_types),
97            schema.columns.iter().any(|c| c.auto_increment),
98        )
99    } else {
100        typer.err("Unknown table", table);
101        (None, false)
102    };
103
104    if let Some(values) = &ior.values {
105        for row in &values.1 {
106            for (j, e) in row.iter().enumerate() {
107                if let Some((et, ets)) = s.as_ref().and_then(|v| v.get(j)) {
108                    let t = type_expression(typer, e, ExpressionFlags::default(), et.base());
109                    if typer.matched_type(&t, et).is_none() {
110                        typer
111                            .err(format!("Got type {}", t.t), e)
112                            .frag(format!("Expected {}", et.t), ets);
113                    } else if let Type::Args(_, args) = &t.t {
114                        for (idx, arg_type, _) in args.iter() {
115                            typer.constrain_arg(*idx, arg_type, et);
116                        }
117                    }
118                } else {
119                    type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
120                }
121            }
122            if let Some(s) = &s
123                && s.len() != row.len()
124            {
125                typer
126                    .err(
127                        format!("Got {} columns", row.len()),
128                        &row.opt_span().unwrap(),
129                    )
130                    .frag(
131                        format!("Expected {}", columns.len()),
132                        &columns.opt_span().unwrap_or(table.span()),
133                    );
134            }
135        }
136    }
137
138    if let Some(select_stmt) = &ior.select
139        && let qusql_parse::Statement::Select(select_inner) = select_stmt
140    {
141        let select = type_select(typer, select_inner, true);
142        if let Some(s) = s {
143            for i in 0..usize::max(s.len(), select.columns.len()) {
144                match (s.get(i), select.columns.get(i)) {
145                    (Some((et, ets)), Some(t)) => {
146                        if typer.matched_type(&t.type_, et).is_none() {
147                            typer
148                                .err(format!("Got type {}", t.type_.t), &t.span)
149                                .frag(format!("Expected {}", et), ets);
150                        }
151                    }
152                    (None, Some(t)) => {
153                        typer.err("Column in select not in insert", &t.span);
154                    }
155                    (Some((_, ets)), None) => {
156                        typer.err("Missing column in select", ets);
157                    }
158                    (None, None) => {
159                        panic!("ICE")
160                    }
161                }
162            }
163        }
164    }
165    // Compound queries (UNION/INTERSECT/EXCEPT) skip the column type check
166
167    let mut guard = typer_stack(
168        typer,
169        |t| core::mem::take(&mut t.reference_types),
170        |t, v| t.reference_types = v,
171    );
172    let typer = &mut guard.typer;
173
174    if let Some(s) = typer.schemas.schemas.get(table.value) {
175        let mut columns = Vec::new();
176        for c in &s.columns {
177            columns.push((c.identifier.clone(), c.type_.clone()));
178        }
179        for v in &typer.reference_types {
180            if v.name == Some(table.clone()) {
181                typer
182                    .issues
183                    .err("Duplicate definitions", table)
184                    .frag("Already defined here", &v.span);
185            }
186        }
187        typer.reference_types.push(ReferenceType {
188            name: Some(table.clone()),
189            span: table.span(),
190            columns,
191        });
192    }
193
194    if let Some(set) = &ior.set {
195        for InsertReplaceSetPair { column, value, .. } in &set.pairs {
196            let mut cnt = 0;
197            let mut t = None;
198            for r in &typer.reference_types {
199                for c in &r.columns {
200                    if c.0 == *column {
201                        cnt += 1;
202                        t = Some(c.clone());
203                    }
204                }
205            }
206            if cnt > 1 {
207                type_expression(typer, value, ExpressionFlags::default(), BaseType::Any);
208                let mut issue = typer.issues.err("Ambiguous reference", column);
209                for r in &typer.reference_types {
210                    for c in &r.columns {
211                        if c.0 == *column {
212                            issue.frag("Defined here", &r.span);
213                        }
214                    }
215                }
216            } else if let Some(t) = t {
217                let value_type =
218                    type_expression(typer, value, ExpressionFlags::default(), t.1.base());
219                if typer.matched_type(&value_type, &t.1).is_none() {
220                    typer.err(format!("Got type {} expected {}", value_type, t.1), value);
221                } else if let Type::Args(_, args) = &value_type.t {
222                    for (idx, arg_type, _) in args.iter() {
223                        typer.constrain_arg(*idx, arg_type, &t.1);
224                    }
225                }
226            } else {
227                type_expression(typer, value, ExpressionFlags::default(), BaseType::Any);
228                typer.err("Unknown identifier", column);
229            }
230        }
231    }
232
233    if let Some(up) = &ior.on_duplicate_key_update {
234        for InsertReplaceSetPair { value, column, .. } in &up.pairs {
235            let mut cnt = 0;
236            let mut t = None;
237            for r in &typer.reference_types {
238                for c in &r.columns {
239                    if c.0 == *column {
240                        cnt += 1;
241                        t = Some(c.clone());
242                    }
243                }
244            }
245            let flags = ExpressionFlags::default().with_in_on_duplicate_key_update(true);
246            if cnt > 1 {
247                type_expression(typer, value, flags, BaseType::Any);
248                let mut issue = typer.issues.err("Ambiguous reference", column);
249                for r in &typer.reference_types {
250                    for c in &r.columns {
251                        if c.0 == *column {
252                            issue.frag("Defined here", &r.span);
253                        }
254                    }
255                }
256            } else if let Some(t) = t {
257                let value_type = type_expression(typer, value, flags, t.1.base());
258                if typer.matched_type(&value_type, &t.1).is_none() {
259                    typer.err(format!("Got type {} expected {}", value_type, t.1), value);
260                } else if let Type::Args(_, args) = &value_type.t {
261                    for (idx, arg_type, _) in args.iter() {
262                        typer.constrain_arg(*idx, arg_type, &t.1);
263                    }
264                }
265            } else {
266                type_expression(typer, value, flags, BaseType::Any);
267                typer.err("Unknown identifier", column);
268            }
269        }
270    }
271
272    if let Some(on_conflict) = &ior.on_conflict {
273        match &on_conflict.target {
274            qusql_parse::OnConflictTarget::Columns { names } => {
275                for name in names {
276                    let mut t = None;
277                    for r in &typer.reference_types {
278                        for c in &r.columns {
279                            if c.0 == *name {
280                                t = Some(c.clone());
281                            }
282                        }
283                    }
284                    if t.is_none() {
285                        typer.err("Unknown identifier", name);
286                    }
287                }
288                //TODO check if there is a unique constraint on column
289            }
290            qusql_parse::OnConflictTarget::OnConstraint {
291                on_constraint_span, ..
292            } => {
293                issue_todo!(typer.issues, on_constraint_span);
294            }
295            qusql_parse::OnConflictTarget::None => (),
296        }
297
298        match &on_conflict.action {
299            qusql_parse::OnConflictAction::DoNothing(_) => (),
300            qusql_parse::OnConflictAction::DoUpdateSet {
301                sets,
302                where_,
303                do_update_set_span,
304            } => {
305                let mut excluded_columns = Vec::new();
306                if let Some(schema) = typer.schemas.schemas.get(table.value)
307                    && !schema.view
308                {
309                    for col in &ior.columns {
310                        if let Some(schema_col) = schema.get_column(col.value) {
311                            excluded_columns
312                                .push((schema_col.identifier.clone(), schema_col.type_.clone()));
313                        }
314                    }
315                }
316
317                typer.reference_types.push(ReferenceType {
318                    name: Some(Identifier::new("EXCLUDED", do_update_set_span.clone())),
319                    span: do_update_set_span.span(),
320                    columns: excluded_columns,
321                });
322
323                for (key, value) in sets {
324                    let mut cnt = 0;
325                    let mut t = None;
326                    for r in &typer.reference_types {
327                        if let Some(name) = &r.name
328                            && name.value == "EXCLUDED"
329                        {
330                            continue;
331                        }
332                        for c in &r.columns {
333                            if c.0 == *key {
334                                cnt += 1;
335                                t = Some(c.clone());
336                            }
337                        }
338                    }
339                    let flags = ExpressionFlags::default().with_in_on_duplicate_key_update(true);
340                    if cnt > 1 {
341                        type_expression(typer, value, flags, BaseType::Any);
342                        let mut issue = typer.issues.err("Ambiguous reference", key);
343                        for r in &typer.reference_types {
344                            for c in &r.columns {
345                                if c.0 == *key {
346                                    issue.frag("Defined here", &r.span);
347                                }
348                            }
349                        }
350                    } else if let Some(t) = t {
351                        let value_type = type_expression(typer, value, flags, t.1.base());
352                        if typer.matched_type(&value_type, &t.1).is_none() {
353                            typer.err(format!("Got type {} expected {}", value_type, t.1), value);
354                        } else if let Type::Args(_, args) = &value_type.t {
355                            for (idx, arg_type, _) in args.iter() {
356                                typer.constrain_arg(*idx, arg_type, &t.1);
357                            }
358                        }
359                    } else {
360                        type_expression(typer, value, flags, BaseType::Any);
361                        typer.err("Unknown identifier", key);
362                    }
363                }
364                if let Some((_, where_)) = where_ {
365                    type_expression(typer, where_, ExpressionFlags::default(), BaseType::Bool);
366                }
367            }
368        }
369    }
370
371    let returning_select = match &ior.returning {
372        Some((returning_span, returning_exprs)) => {
373            let columns = type_select_exprs(typer, returning_exprs, true)
374                .into_iter()
375                .map(|(name, type_, span)| SelectTypeColumn { name, type_, span })
376                .collect();
377            Some(SelectType {
378                columns,
379                select_span: returning_span.join_span(returning_exprs),
380            })
381        }
382        None => None,
383    };
384
385    core::mem::drop(guard);
386
387    let auto_increment_id = if auto_increment && matches!(ior.type_, InsertReplaceType::Insert(_)) {
388        if ior
389            .flags
390            .iter()
391            .any(|f| matches!(f, InsertReplaceFlag::Ignore(_)))
392            || ior.on_duplicate_key_update.is_some()
393        {
394            AutoIncrementId::Optional
395        } else {
396            AutoIncrementId::Yes
397        }
398    } else {
399        AutoIncrementId::No
400    };
401
402    (auto_increment_id, returning_select)
403}