1use alloc::{format, vec::Vec};
14use qusql_parse::{
15 Identifier, InsertReplace, InsertReplaceFlag, InsertReplaceSetPair, InsertReplaceType,
16 OptSpanned, Spanned, issue_todo,
17};
18
19use crate::{
20 BaseType, SelectTypeColumn, Type,
21 type_expression::{ExpressionFlags, type_expression},
22 type_select::{SelectType, type_select, type_select_exprs},
23 typer::{ReferenceType, Typer, did_you_mean, typer_stack, unqualified_name},
24};
25
/// Whether executing an insert/replace statement yields an auto-increment id.
///
/// Computed by `type_insert_replace` from the target table's schema and the
/// shape of the statement.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum AutoIncrementId {
    /// A plain `INSERT` into a table that has an auto-increment column.
    Yes,
    /// The table has no auto-increment column, or the statement is not an
    /// `INSERT`.
    No,
    /// An `INSERT` into a table with an auto-increment column that carries the
    /// `IGNORE` flag or an `ON DUPLICATE KEY UPDATE` clause.
    Optional,
}
33
34pub(crate) fn type_insert_replace<'a>(
35 typer: &mut Typer<'a, '_>,
36 ior: &InsertReplace<'a>,
37) -> (AutoIncrementId, Option<SelectType<'a>>) {
38 let table = unqualified_name(typer.issues, &ior.table);
39 let columns = &ior.columns;
40
41 let (s, auto_increment) = if let Some(schema) = typer.schemas.schemas.get(table.value) {
42 if schema.view {
43 typer.err("Inserts into views not yet implemented", table);
44 }
45 let mut col_types = Vec::new();
46
47 for col in columns {
48 if let Some(schema_col) = schema.get_column(col.value) {
49 col_types.push((schema_col.type_.clone(), col.span()));
50 } else {
51 typer.err("No such column in schema", col);
52 }
53 }
54
55 if let Some(set) = &ior.set {
56 for col in &schema.columns {
57 if col.auto_increment
58 || col.default
59 || !col.type_.not_null
60 || col.as_.is_some()
61 || col.generated
62 || set.pairs.iter().any(|v| v.column == col.identifier)
63 {
64 continue;
65 }
66 typer.err(
67 format!(
68 "No value for column {} provided, but it has no default value",
69 &col.identifier
70 ),
71 set,
72 );
73 }
74 } else {
75 for col in &schema.columns {
76 if col.auto_increment
77 || col.default
78 || !col.type_.not_null
79 || col.as_.is_some()
80 || col.generated
81 || columns.contains(&col.identifier)
82 {
83 continue;
84 }
85 typer.err(
86 format!(
87 "No value for column {} provided, but it has no default value",
88 &col.identifier
89 ),
90 &columns.opt_span().unwrap_or(table.span()),
91 );
92 }
93 }
94
95 (
96 Some(col_types),
97 schema.columns.iter().any(|c| c.auto_increment),
98 )
99 } else {
100 typer.err("Unknown table", table);
101 (None, false)
102 };
103
104 if let Some(values) = &ior.values {
105 for row in &values.1 {
106 for (j, e) in row.iter().enumerate() {
107 if let Some((et, ets)) = s.as_ref().and_then(|v| v.get(j)) {
108 let t = type_expression(typer, e, ExpressionFlags::default(), et.base());
109 if typer.matched_type(&t, et).is_none() {
110 typer
111 .err(format!("Got type {}", t.t), e)
112 .frag(format!("Expected {}", et.t), ets);
113 } else if let Type::Args(_, args) = &t.t {
114 for (idx, arg_type, _) in args.iter() {
115 typer.constrain_arg(*idx, arg_type, et);
116 }
117 }
118 } else {
119 type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
120 }
121 }
122 if let Some(s) = &s
123 && s.len() != row.len()
124 {
125 typer
126 .err(
127 format!("Got {} columns", row.len()),
128 &row.opt_span().unwrap(),
129 )
130 .frag(
131 format!("Expected {}", columns.len()),
132 &columns.opt_span().unwrap_or(table.span()),
133 );
134 }
135 }
136 }
137
138 if let Some(select_stmt) = &ior.select
139 && let qusql_parse::Statement::Select(select_inner) = select_stmt
140 {
141 let select = type_select(typer, select_inner, true);
142 if let Some(s) = s {
143 for i in 0..usize::max(s.len(), select.columns.len()) {
144 match (s.get(i), select.columns.get(i)) {
145 (Some((et, ets)), Some(t)) => {
146 if typer.matched_type(&t.type_, et).is_none() {
147 typer
148 .err(format!("Got type {}", t.type_.t), &t.span)
149 .frag(format!("Expected {}", et), ets);
150 }
151 }
152 (None, Some(t)) => {
153 typer.err("Column in select not in insert", &t.span);
154 }
155 (Some((_, ets)), None) => {
156 typer.err("Missing column in select", ets);
157 }
158 (None, None) => {
159 panic!("ICE")
160 }
161 }
162 }
163 }
164 }
165 let mut guard = typer_stack(
168 typer,
169 |t| core::mem::take(&mut t.reference_types),
170 |t, v| t.reference_types = v,
171 );
172 let typer = &mut guard.typer;
173
174 if let Some(s) = typer.schemas.schemas.get(table.value) {
175 let mut columns = Vec::new();
176 for c in &s.columns {
177 columns.push((c.identifier.clone(), c.type_.clone()));
178 }
179 for v in &typer.reference_types {
180 if v.name == Some(table.clone()) {
181 typer
182 .issues
183 .err("Duplicate definitions", table)
184 .frag("Already defined here", &v.span);
185 }
186 }
187 typer.reference_types.push(ReferenceType {
188 name: Some(table.clone()),
189 span: table.span(),
190 columns,
191 });
192 }
193
194 if let Some(set) = &ior.set {
195 for InsertReplaceSetPair { column, value, .. } in &set.pairs {
196 let mut cnt = 0;
197 let mut t = None;
198 for r in &typer.reference_types {
199 for c in &r.columns {
200 if c.0 == *column {
201 cnt += 1;
202 t = Some(c.clone());
203 }
204 }
205 }
206 if cnt > 1 {
207 type_expression(typer, value, ExpressionFlags::default(), BaseType::Any);
208 let mut issue = typer.issues.err("Ambiguous reference", column);
209 for r in &typer.reference_types {
210 for c in &r.columns {
211 if c.0 == *column {
212 issue.frag("Defined here", &r.span);
213 }
214 }
215 }
216 } else if let Some(t) = t {
217 let value_type =
218 type_expression(typer, value, ExpressionFlags::default(), t.1.base());
219 if typer.matched_type(&value_type, &t.1).is_none() {
220 typer.err(format!("Got type {} expected {}", value_type, t.1), value);
221 } else if let Type::Args(_, args) = &value_type.t {
222 for (idx, arg_type, _) in args.iter() {
223 typer.constrain_arg(*idx, arg_type, &t.1);
224 }
225 }
226 } else {
227 type_expression(typer, value, ExpressionFlags::default(), BaseType::Any);
228 let suggestion = did_you_mean(
229 column.value,
230 typer
231 .reference_types
232 .iter()
233 .flat_map(|r| r.columns.iter().map(|(id, _)| id.value)),
234 );
235 let mut issue = typer.err("Unknown identifier", column);
236 if let Some(s) = suggestion {
237 issue.help(alloc::format!("did you mean `{s}`?"));
238 }
239 }
240 }
241 }
242
243 if let Some(up) = &ior.on_duplicate_key_update {
244 for InsertReplaceSetPair { value, column, .. } in &up.pairs {
245 let mut cnt = 0;
246 let mut t = None;
247 for r in &typer.reference_types {
248 for c in &r.columns {
249 if c.0 == *column {
250 cnt += 1;
251 t = Some(c.clone());
252 }
253 }
254 }
255 let flags = ExpressionFlags::default().with_in_on_duplicate_key_update(true);
256 if cnt > 1 {
257 type_expression(typer, value, flags, BaseType::Any);
258 let mut issue = typer.issues.err("Ambiguous reference", column);
259 for r in &typer.reference_types {
260 for c in &r.columns {
261 if c.0 == *column {
262 issue.frag("Defined here", &r.span);
263 }
264 }
265 }
266 } else if let Some(t) = t {
267 let value_type = type_expression(typer, value, flags, t.1.base());
268 if typer.matched_type(&value_type, &t.1).is_none() {
269 typer.err(format!("Got type {} expected {}", value_type, t.1), value);
270 } else if let Type::Args(_, args) = &value_type.t {
271 for (idx, arg_type, _) in args.iter() {
272 typer.constrain_arg(*idx, arg_type, &t.1);
273 }
274 }
275 } else {
276 type_expression(typer, value, flags, BaseType::Any);
277 let suggestion = did_you_mean(
278 column.value,
279 typer
280 .reference_types
281 .iter()
282 .flat_map(|r| r.columns.iter().map(|(id, _)| id.value)),
283 );
284 let mut issue = typer.err("Unknown identifier", column);
285 if let Some(s) = suggestion {
286 issue.help(alloc::format!("did you mean `{s}`?"));
287 }
288 }
289 }
290 }
291
292 if let Some(on_conflict) = &ior.on_conflict {
293 match &on_conflict.target {
294 qusql_parse::OnConflictTarget::Columns { names } => {
295 for name in names {
296 let mut t = None;
297 for r in &typer.reference_types {
298 for c in &r.columns {
299 if c.0 == *name {
300 t = Some(c.clone());
301 }
302 }
303 }
304 if t.is_none() {
305 let suggestion = did_you_mean(
306 name.value,
307 typer
308 .reference_types
309 .iter()
310 .flat_map(|r| r.columns.iter().map(|(id, _)| id.value)),
311 );
312 let mut issue = typer.err("Unknown identifier", name);
313 if let Some(s) = suggestion {
314 issue.help(alloc::format!("did you mean `{s}`?"));
315 }
316 }
317 }
318 }
320 qusql_parse::OnConflictTarget::OnConstraint {
321 on_constraint_span, ..
322 } => {
323 issue_todo!(typer.issues, on_constraint_span);
324 }
325 qusql_parse::OnConflictTarget::None => (),
326 }
327
328 match &on_conflict.action {
329 qusql_parse::OnConflictAction::DoNothing(_) => (),
330 qusql_parse::OnConflictAction::DoUpdateSet {
331 sets,
332 where_,
333 do_update_set_span,
334 } => {
335 let mut excluded_columns = Vec::new();
336 if let Some(schema) = typer.schemas.schemas.get(table.value)
337 && !schema.view
338 {
339 for col in &ior.columns {
340 if let Some(schema_col) = schema.get_column(col.value) {
341 excluded_columns
342 .push((schema_col.identifier.clone(), schema_col.type_.clone()));
343 }
344 }
345 }
346
347 typer.reference_types.push(ReferenceType {
348 name: Some(Identifier::new("EXCLUDED", do_update_set_span.clone())),
349 span: do_update_set_span.span(),
350 columns: excluded_columns,
351 });
352
353 for (key, value) in sets {
354 let mut cnt = 0;
355 let mut t = None;
356 for r in &typer.reference_types {
357 if let Some(name) = &r.name
358 && name.value == "EXCLUDED"
359 {
360 continue;
361 }
362 for c in &r.columns {
363 if c.0 == *key {
364 cnt += 1;
365 t = Some(c.clone());
366 }
367 }
368 }
369 let flags = ExpressionFlags::default().with_in_on_duplicate_key_update(true);
370 if cnt > 1 {
371 type_expression(typer, value, flags, BaseType::Any);
372 let mut issue = typer.issues.err("Ambiguous reference", key);
373 for r in &typer.reference_types {
374 for c in &r.columns {
375 if c.0 == *key {
376 issue.frag("Defined here", &r.span);
377 }
378 }
379 }
380 } else if let Some(t) = t {
381 let value_type = type_expression(typer, value, flags, t.1.base());
382 if typer.matched_type(&value_type, &t.1).is_none() {
383 typer.err(format!("Got type {} expected {}", value_type, t.1), value);
384 } else if let Type::Args(_, args) = &value_type.t {
385 for (idx, arg_type, _) in args.iter() {
386 typer.constrain_arg(*idx, arg_type, &t.1);
387 }
388 }
389 } else {
390 type_expression(typer, value, flags, BaseType::Any);
391 let suggestion = did_you_mean(
392 key.value,
393 typer
394 .reference_types
395 .iter()
396 .flat_map(|r| r.columns.iter().map(|(id, _)| id.value)),
397 );
398 let mut issue = typer.err("Unknown identifier", key);
399 if let Some(s) = suggestion {
400 issue.help(alloc::format!("did you mean `{s}`?"));
401 }
402 }
403 }
404 if let Some((_, where_)) = where_ {
405 type_expression(typer, where_, ExpressionFlags::default(), BaseType::Bool);
406 }
407 }
408 }
409 }
410
411 let returning_select = match &ior.returning {
412 Some((returning_span, returning_exprs)) => {
413 let columns = type_select_exprs(typer, returning_exprs, true)
414 .into_iter()
415 .map(|(name, type_, span)| SelectTypeColumn { name, type_, span })
416 .collect();
417 Some(SelectType {
418 columns,
419 select_span: returning_span.join_span(returning_exprs),
420 })
421 }
422 None => None,
423 };
424
425 core::mem::drop(guard);
426
427 let auto_increment_id = if auto_increment && matches!(ior.type_, InsertReplaceType::Insert(_)) {
428 if ior
429 .flags
430 .iter()
431 .any(|f| matches!(f, InsertReplaceFlag::Ignore(_)))
432 || ior.on_duplicate_key_update.is_some()
433 {
434 AutoIncrementId::Optional
435 } else {
436 AutoIncrementId::Yes
437 }
438 } else {
439 AutoIncrementId::No
440 };
441
442 (auto_increment_id, returning_select)
443}