1use alloc::{format, vec::Vec};
14use qusql_parse::{
15 Identifier, InsertReplace, InsertReplaceFlag, InsertReplaceSetPair, InsertReplaceType,
16 OptSpanned, Spanned, issue_todo,
17};
18
19use crate::{
20 BaseType, SelectTypeColumn, Type,
21 type_expression::{ExpressionFlags, type_expression},
22 type_select::{SelectType, type_select, type_select_exprs},
23 typer::{ReferenceType, Typer, typer_stack, unqualified_name},
24};
25
/// Whether an insert/replace statement is expected to produce an auto-increment id.
///
/// This is the first element of the value returned by `type_insert_replace`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AutoIncrementId {
    /// An id is produced: an `INSERT` into a table that has an auto-increment
    /// column, without `IGNORE` and without `ON DUPLICATE KEY UPDATE`.
    Yes,
    /// No id is produced (e.g. `REPLACE`, or the table has no auto-increment column).
    No,
    /// An id may or may not be produced: an `INSERT` into a table with an
    /// auto-increment column that uses `IGNORE` or `ON DUPLICATE KEY UPDATE`.
    Optional,
}
33
34pub(crate) fn type_insert_replace<'a>(
35 typer: &mut Typer<'a, '_>,
36 ior: &InsertReplace<'a>,
37) -> (AutoIncrementId, Option<SelectType<'a>>) {
38 let table = unqualified_name(typer.issues, &ior.table);
39 let columns = &ior.columns;
40
41 let (s, auto_increment) = if let Some(schema) = typer.schemas.schemas.get(table.value) {
42 if schema.view {
43 typer.err("Inserts into views not yet implemented", table);
44 }
45 let mut col_types = Vec::new();
46
47 for col in columns {
48 if let Some(schema_col) = schema.get_column(col.value) {
49 col_types.push((schema_col.type_.clone(), col.span()));
50 } else {
51 typer.err("No such column in schema", col);
52 }
53 }
54
55 if let Some(set) = &ior.set {
56 for col in &schema.columns {
57 if col.auto_increment
58 || col.default
59 || !col.type_.not_null
60 || col.as_.is_some()
61 || col.generated
62 || set.pairs.iter().any(|v| v.column == col.identifier)
63 {
64 continue;
65 }
66 typer.err(
67 format!(
68 "No value for column {} provided, but it has no default value",
69 &col.identifier
70 ),
71 set,
72 );
73 }
74 } else {
75 for col in &schema.columns {
76 if col.auto_increment
77 || col.default
78 || !col.type_.not_null
79 || col.as_.is_some()
80 || col.generated
81 || columns.contains(&col.identifier)
82 {
83 continue;
84 }
85 typer.err(
86 format!(
87 "No value for column {} provided, but it has no default value",
88 &col.identifier
89 ),
90 &columns.opt_span().unwrap_or(table.span()),
91 );
92 }
93 }
94
95 (
96 Some(col_types),
97 schema.columns.iter().any(|c| c.auto_increment),
98 )
99 } else {
100 typer.err("Unknown table", table);
101 (None, false)
102 };
103
104 if let Some(values) = &ior.values {
105 for row in &values.1 {
106 for (j, e) in row.iter().enumerate() {
107 if let Some((et, ets)) = s.as_ref().and_then(|v| v.get(j)) {
108 let t = type_expression(typer, e, ExpressionFlags::default(), et.base());
109 if typer.matched_type(&t, et).is_none() {
110 typer
111 .err(format!("Got type {}", t.t), e)
112 .frag(format!("Expected {}", et.t), ets);
113 } else if let Type::Args(_, args) = &t.t {
114 for (idx, arg_type, _) in args.iter() {
115 typer.constrain_arg(*idx, arg_type, et);
116 }
117 }
118 } else {
119 type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
120 }
121 }
122 if let Some(s) = &s
123 && s.len() != row.len()
124 {
125 typer
126 .err(
127 format!("Got {} columns", row.len()),
128 &row.opt_span().unwrap(),
129 )
130 .frag(
131 format!("Expected {}", columns.len()),
132 &columns.opt_span().unwrap_or(table.span()),
133 );
134 }
135 }
136 }
137
138 if let Some(select_stmt) = &ior.select
139 && let qusql_parse::Statement::Select(select_inner) = select_stmt
140 {
141 let select = type_select(typer, select_inner, true);
142 if let Some(s) = s {
143 for i in 0..usize::max(s.len(), select.columns.len()) {
144 match (s.get(i), select.columns.get(i)) {
145 (Some((et, ets)), Some(t)) => {
146 if typer.matched_type(&t.type_, et).is_none() {
147 typer
148 .err(format!("Got type {}", t.type_.t), &t.span)
149 .frag(format!("Expected {}", et), ets);
150 }
151 }
152 (None, Some(t)) => {
153 typer.err("Column in select not in insert", &t.span);
154 }
155 (Some((_, ets)), None) => {
156 typer.err("Missing column in select", ets);
157 }
158 (None, None) => {
159 panic!("ICE")
160 }
161 }
162 }
163 }
164 }
165 let mut guard = typer_stack(
168 typer,
169 |t| core::mem::take(&mut t.reference_types),
170 |t, v| t.reference_types = v,
171 );
172 let typer = &mut guard.typer;
173
174 if let Some(s) = typer.schemas.schemas.get(table.value) {
175 let mut columns = Vec::new();
176 for c in &s.columns {
177 columns.push((c.identifier.clone(), c.type_.clone()));
178 }
179 for v in &typer.reference_types {
180 if v.name == Some(table.clone()) {
181 typer
182 .issues
183 .err("Duplicate definitions", table)
184 .frag("Already defined here", &v.span);
185 }
186 }
187 typer.reference_types.push(ReferenceType {
188 name: Some(table.clone()),
189 span: table.span(),
190 columns,
191 });
192 }
193
194 if let Some(set) = &ior.set {
195 for InsertReplaceSetPair { column, value, .. } in &set.pairs {
196 let mut cnt = 0;
197 let mut t = None;
198 for r in &typer.reference_types {
199 for c in &r.columns {
200 if c.0 == *column {
201 cnt += 1;
202 t = Some(c.clone());
203 }
204 }
205 }
206 if cnt > 1 {
207 type_expression(typer, value, ExpressionFlags::default(), BaseType::Any);
208 let mut issue = typer.issues.err("Ambiguous reference", column);
209 for r in &typer.reference_types {
210 for c in &r.columns {
211 if c.0 == *column {
212 issue.frag("Defined here", &r.span);
213 }
214 }
215 }
216 } else if let Some(t) = t {
217 let value_type =
218 type_expression(typer, value, ExpressionFlags::default(), t.1.base());
219 if typer.matched_type(&value_type, &t.1).is_none() {
220 typer.err(format!("Got type {} expected {}", value_type, t.1), value);
221 } else if let Type::Args(_, args) = &value_type.t {
222 for (idx, arg_type, _) in args.iter() {
223 typer.constrain_arg(*idx, arg_type, &t.1);
224 }
225 }
226 } else {
227 type_expression(typer, value, ExpressionFlags::default(), BaseType::Any);
228 typer.err("Unknown identifier", column);
229 }
230 }
231 }
232
233 if let Some(up) = &ior.on_duplicate_key_update {
234 for InsertReplaceSetPair { value, column, .. } in &up.pairs {
235 let mut cnt = 0;
236 let mut t = None;
237 for r in &typer.reference_types {
238 for c in &r.columns {
239 if c.0 == *column {
240 cnt += 1;
241 t = Some(c.clone());
242 }
243 }
244 }
245 let flags = ExpressionFlags::default().with_in_on_duplicate_key_update(true);
246 if cnt > 1 {
247 type_expression(typer, value, flags, BaseType::Any);
248 let mut issue = typer.issues.err("Ambiguous reference", column);
249 for r in &typer.reference_types {
250 for c in &r.columns {
251 if c.0 == *column {
252 issue.frag("Defined here", &r.span);
253 }
254 }
255 }
256 } else if let Some(t) = t {
257 let value_type = type_expression(typer, value, flags, t.1.base());
258 if typer.matched_type(&value_type, &t.1).is_none() {
259 typer.err(format!("Got type {} expected {}", value_type, t.1), value);
260 } else if let Type::Args(_, args) = &value_type.t {
261 for (idx, arg_type, _) in args.iter() {
262 typer.constrain_arg(*idx, arg_type, &t.1);
263 }
264 }
265 } else {
266 type_expression(typer, value, flags, BaseType::Any);
267 typer.err("Unknown identifier", column);
268 }
269 }
270 }
271
272 if let Some(on_conflict) = &ior.on_conflict {
273 match &on_conflict.target {
274 qusql_parse::OnConflictTarget::Columns { names } => {
275 for name in names {
276 let mut t = None;
277 for r in &typer.reference_types {
278 for c in &r.columns {
279 if c.0 == *name {
280 t = Some(c.clone());
281 }
282 }
283 }
284 if t.is_none() {
285 typer.err("Unknown identifier", name);
286 }
287 }
288 }
290 qusql_parse::OnConflictTarget::OnConstraint {
291 on_constraint_span, ..
292 } => {
293 issue_todo!(typer.issues, on_constraint_span);
294 }
295 qusql_parse::OnConflictTarget::None => (),
296 }
297
298 match &on_conflict.action {
299 qusql_parse::OnConflictAction::DoNothing(_) => (),
300 qusql_parse::OnConflictAction::DoUpdateSet {
301 sets,
302 where_,
303 do_update_set_span,
304 } => {
305 let mut excluded_columns = Vec::new();
306 if let Some(schema) = typer.schemas.schemas.get(table.value)
307 && !schema.view
308 {
309 for col in &ior.columns {
310 if let Some(schema_col) = schema.get_column(col.value) {
311 excluded_columns
312 .push((schema_col.identifier.clone(), schema_col.type_.clone()));
313 }
314 }
315 }
316
317 typer.reference_types.push(ReferenceType {
318 name: Some(Identifier::new("EXCLUDED", do_update_set_span.clone())),
319 span: do_update_set_span.span(),
320 columns: excluded_columns,
321 });
322
323 for (key, value) in sets {
324 let mut cnt = 0;
325 let mut t = None;
326 for r in &typer.reference_types {
327 if let Some(name) = &r.name
328 && name.value == "EXCLUDED"
329 {
330 continue;
331 }
332 for c in &r.columns {
333 if c.0 == *key {
334 cnt += 1;
335 t = Some(c.clone());
336 }
337 }
338 }
339 let flags = ExpressionFlags::default().with_in_on_duplicate_key_update(true);
340 if cnt > 1 {
341 type_expression(typer, value, flags, BaseType::Any);
342 let mut issue = typer.issues.err("Ambiguous reference", key);
343 for r in &typer.reference_types {
344 for c in &r.columns {
345 if c.0 == *key {
346 issue.frag("Defined here", &r.span);
347 }
348 }
349 }
350 } else if let Some(t) = t {
351 let value_type = type_expression(typer, value, flags, t.1.base());
352 if typer.matched_type(&value_type, &t.1).is_none() {
353 typer.err(format!("Got type {} expected {}", value_type, t.1), value);
354 } else if let Type::Args(_, args) = &value_type.t {
355 for (idx, arg_type, _) in args.iter() {
356 typer.constrain_arg(*idx, arg_type, &t.1);
357 }
358 }
359 } else {
360 type_expression(typer, value, flags, BaseType::Any);
361 typer.err("Unknown identifier", key);
362 }
363 }
364 if let Some((_, where_)) = where_ {
365 type_expression(typer, where_, ExpressionFlags::default(), BaseType::Bool);
366 }
367 }
368 }
369 }
370
371 let returning_select = match &ior.returning {
372 Some((returning_span, returning_exprs)) => {
373 let columns = type_select_exprs(typer, returning_exprs, true)
374 .into_iter()
375 .map(|(name, type_, span)| SelectTypeColumn { name, type_, span })
376 .collect();
377 Some(SelectType {
378 columns,
379 select_span: returning_span.join_span(returning_exprs),
380 })
381 }
382 None => None,
383 };
384
385 core::mem::drop(guard);
386
387 let auto_increment_id = if auto_increment && matches!(ior.type_, InsertReplaceType::Insert(_)) {
388 if ior
389 .flags
390 .iter()
391 .any(|f| matches!(f, InsertReplaceFlag::Ignore(_)))
392 || ior.on_duplicate_key_update.is_some()
393 {
394 AutoIncrementId::Optional
395 } else {
396 AutoIncrementId::Yes
397 }
398 } else {
399 AutoIncrementId::No
400 };
401
402 (auto_increment_id, returning_select)
403}