1use alloc::{format, vec::Vec};
14use qusql_parse::{
15 CompoundOperator, CompoundQuery, Expression, Identifier, IdentifierPart, Issues, OptSpanned,
16 Select, SelectExpr, Span, Spanned, Statement, issue_ice, issue_todo,
17};
18
19use crate::{
20 Type,
21 type_::{BaseType, FullType},
22 type_expression::{ExpressionFlags, type_expression},
23 type_reference::type_reference,
24 typer::{ReferenceType, Typer, did_you_mean, typer_stack},
25};
26
/// A single result column produced by a select (or compound) query.
#[derive(Debug, Clone)]
pub struct SelectTypeColumn<'a> {
    /// Name of the column: the explicit `AS` alias, the referenced column's name,
    /// or `None` for an unnamed expression.
    pub name: Option<Identifier<'a>>,
    /// Type of the column (base type plus nullability).
    pub type_: FullType<'a>,
    /// Source span the column is attributed to; used when reporting diagnostics
    /// such as duplicate column names.
    pub span: Span,
}
37
38impl<'a> Spanned for SelectTypeColumn<'a> {
39 fn span(&self) -> Span {
40 self.span.span()
41 }
42}
43
/// The typed result of a select statement: its columns plus the span of the
/// statement itself.
#[derive(Debug, Clone)]
pub(crate) struct SelectType<'a> {
    /// Result columns, in output order.
    pub columns: Vec<SelectTypeColumn<'a>>,
    /// Span of the whole statement; used as the overall span when there are no
    /// columns to derive one from.
    pub select_span: Span,
}
49
50impl<'a> Spanned for SelectType<'a> {
51 fn span(&self) -> Span {
52 self.columns
53 .opt_span()
54 .unwrap_or_else(|| self.select_span.clone())
55 }
56}
57
58pub(crate) fn resolve_kleene_identifier<'a, 'b>(
59 typer: &mut Typer<'a, 'b>,
60 parts: &[IdentifierPart<'a>],
61 as_: &Option<Identifier<'a>>,
62 mut cb: impl FnMut(&mut Issues<'a>, Option<Identifier<'a>>, FullType<'a>, Span, bool),
63) {
64 match parts {
65 [qusql_parse::IdentifierPart::Name(col)] => {
66 let mut cnt = 0;
67 let mut t = None;
68 for r in &typer.reference_types {
69 for c in &r.columns {
70 if c.0 == *col {
71 cnt += 1;
72 t = Some(c);
73 }
74 }
75 }
76 let name = as_.as_ref().unwrap_or(col);
77 if cnt > 1 {
78 let mut issue = typer.issues.err("Ambigious reference", col);
79 for r in &typer.reference_types {
80 for c in &r.columns {
81 if c.0 == *col {
82 issue.frag("Defined here", &r.span);
83 }
84 }
85 }
86 cb(
87 typer.issues,
88 Some(name.clone()),
89 FullType::invalid(),
90 name.span(),
91 as_.is_some(),
92 );
93 } else if let Some(t) = t {
94 cb(
95 typer.issues,
96 Some(name.clone()),
97 t.1.clone(),
98 name.span(),
99 as_.is_some(),
100 );
101 } else {
102 let suggestion = did_you_mean(
103 col.value,
104 typer
105 .reference_types
106 .iter()
107 .flat_map(|r| r.columns.iter().map(|(id, _)| id.value)),
108 );
109 let mut issue = typer.err("Unknown identifier", col);
110 if let Some(s) = suggestion {
111 issue.help(alloc::format!("did you mean `{s}`?"));
112 }
113 cb(
114 typer.issues,
115 Some(name.clone()),
116 FullType::invalid(),
117 name.span(),
118 as_.is_some(),
119 );
120 }
121 }
122 [qusql_parse::IdentifierPart::Star(v)] => {
123 if let Some(as_) = as_ {
124 typer.err("As not supported for *", as_);
125 }
126 for r in &typer.reference_types {
127 for c in &r.columns {
128 cb(
129 typer.issues,
130 Some(c.0.clone()),
131 c.1.clone(),
132 v.clone(),
133 false,
134 );
135 }
136 }
137 }
138 [
139 qusql_parse::IdentifierPart::Name(tbl),
140 qusql_parse::IdentifierPart::Name(col),
141 ] => {
142 let mut t = None;
143 for r in &typer.reference_types {
144 if r.name == Some(tbl.clone()) {
145 for c in &r.columns {
146 if c.0 == *col {
147 t = Some(c);
148 }
149 }
150 }
151 }
152 let name = as_.as_ref().unwrap_or(col);
153 if let Some(t) = t {
154 cb(
155 typer.issues,
156 Some(name.clone()),
157 t.1.clone(),
158 name.span(),
159 as_.is_some(),
160 );
161 } else {
162 let suggestion = did_you_mean(
163 col.value,
164 typer
165 .reference_types
166 .iter()
167 .filter(|r| r.name.as_deref() == Some(tbl.as_ref()))
168 .flat_map(|r| r.columns.iter().map(|(id, _)| id.value)),
169 );
170 let mut issue = typer.err("Unknown identifier", col);
171 if let Some(s) = suggestion {
172 issue.help(alloc::format!("did you mean `{s}`?"));
173 }
174 cb(
175 typer.issues,
176 Some(name.clone()),
177 FullType::invalid(),
178 name.span(),
179 as_.is_some(),
180 );
181 }
182 }
183 [
184 qusql_parse::IdentifierPart::Name(tbl),
185 qusql_parse::IdentifierPart::Star(v),
186 ] => {
187 if let Some(as_) = as_ {
188 typer.err("As not supported for *", as_);
189 }
190 let mut t = None;
191 for r in &typer.reference_types {
192 if r.name == Some(tbl.clone()) {
193 t = Some(r);
194 }
195 }
196 if let Some(t) = t {
197 for c in &t.columns {
198 cb(
199 typer.issues,
200 Some(c.0.clone()),
201 c.1.clone(),
202 v.clone(),
203 false,
204 );
205 }
206 } else {
207 typer.err("Unknown table", tbl);
208 }
209 }
210 [qusql_parse::IdentifierPart::Star(v), _] => {
211 typer.err("Not supported here", v);
212 }
213 _ => {
214 typer.err("Invalid identifier", &parts.opt_span().expect("parts span"));
215 }
216 }
217}
218
219pub(crate) fn type_select<'a>(
220 typer: &mut Typer<'a, '_>,
221 select: &Select<'a>,
222 warn_duplicate: bool,
223) -> SelectType<'a> {
224 let mut guard = typer_stack(
225 typer,
226 |t| {
227 let refs = core::mem::take(&mut t.reference_types);
228 let old_outer = core::mem::take(&mut t.outer_reference_types);
229 let mut new_outer = refs.clone();
232 new_outer.extend(old_outer.iter().cloned());
233 t.outer_reference_types = new_outer;
234 (refs, old_outer)
235 },
236 |t, (refs, old_outer)| {
237 t.reference_types = refs;
238 t.outer_reference_types = old_outer;
239 },
240 );
241 let typer = &mut guard.typer;
242
243 for flag in &select.flags {
244 match &flag {
245 qusql_parse::SelectFlag::All(_) => issue_todo!(typer.issues, flag),
246 qusql_parse::SelectFlag::Distinct(_)
247 | qusql_parse::SelectFlag::DistinctOn(_)
248 | qusql_parse::SelectFlag::DistinctRow(_) => (),
249 qusql_parse::SelectFlag::StraightJoin(_) => issue_todo!(typer.issues, flag),
250 qusql_parse::SelectFlag::HighPriority(_)
251 | qusql_parse::SelectFlag::SqlSmallResult(_)
252 | qusql_parse::SelectFlag::SqlBigResult(_)
253 | qusql_parse::SelectFlag::SqlBufferResult(_)
254 | qusql_parse::SelectFlag::SqlNoCache(_)
255 | qusql_parse::SelectFlag::SqlCalcFoundRows(_) => (),
256 }
257 }
258
259 if let Some(references) = &select.table_references {
260 for reference in references {
261 type_reference(typer, reference, false);
262 }
263 }
264
265 if let Some((where_, _)) = &select.where_ {
266 let t = type_expression(
267 typer,
268 where_,
269 ExpressionFlags::default()
270 .with_not_null(true)
271 .with_true(true),
272 BaseType::Bool,
273 );
274 typer.ensure_base(where_, &t, BaseType::Bool);
275 }
276
277 let result = type_select_exprs(typer, &select.select_exprs, warn_duplicate);
278
279 if let Some((_, group_by)) = &select.group_by {
280 for e in group_by {
281 type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
282 }
283 }
284
285 if let Some((_, order_by)) = &select.order_by {
286 for (e, _) in order_by {
287 type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
288 }
289 }
290
291 if let Some((having, _)) = &select.having {
292 let t = type_expression(
293 typer,
294 having,
295 ExpressionFlags::default()
296 .with_not_null(true)
297 .with_true(true),
298 BaseType::Bool,
299 );
300 typer.ensure_base(having, &t, BaseType::Bool);
301 }
302
303 if let Some((_, offset, count)) = &select.limit {
304 if let Some(offset) = offset {
305 let t = type_expression(typer, offset, ExpressionFlags::default(), BaseType::Integer);
306 if typer
307 .matched_type(&t, &FullType::new(Type::U64, true))
308 .is_none()
309 {
310 typer.err(format!("Expected integer type got {}", t.t), offset);
311 }
312 }
313 let t = type_expression(typer, count, ExpressionFlags::default(), BaseType::Integer);
314 if typer
315 .matched_type(&t, &FullType::new(Type::U64, true))
316 .is_none()
317 {
318 typer.err(format!("Expected integer type got {}", t.t), count);
319 }
320 }
321
322 SelectType {
323 columns: result
324 .into_iter()
325 .map(|(name, type_, span)| SelectTypeColumn { name, type_, span })
326 .collect(),
327 select_span: select.span(),
328 }
329}
330
331pub(crate) fn type_select_exprs<'a, 'b>(
332 typer: &mut Typer<'a, 'b>,
333 select_exprs: &[SelectExpr<'a>],
334 warn_duplicate: bool,
335) -> Vec<(Option<Identifier<'a>>, FullType<'a>, Span)> {
336 let mut result = Vec::new();
337 let mut select_reference = ReferenceType {
338 name: None,
339 span: select_exprs.opt_span().expect("select_exprs span"),
340 columns: Vec::new(),
341 };
342
343 for e in select_exprs {
344 let mut add_result = |issues: &mut Issues<'a>,
345 name: Option<Identifier<'a>>,
346 type_: FullType<'a>,
347 span: Span,
348 as_: bool| {
349 if let Some(name) = name.clone() {
350 if as_ {
351 select_reference.columns.push((name.clone(), type_.clone()));
352 }
353 for (on, _, os) in &result {
354 if Some(name.clone()) == *on && warn_duplicate {
355 issues
356 .warn("Also defined here", &span)
357 .frag(format!("Multiple columns with the name '{name}'"), os);
358 }
359 }
360 }
361 result.push((name, type_, span));
362 };
363 if let Expression::Identifier(parts) = &e.expr {
364 resolve_kleene_identifier(typer, &parts.parts, &e.as_, add_result);
365 } else {
366 let type_ = type_expression(typer, &e.expr, ExpressionFlags::default(), BaseType::Any);
367 if let Some(as_) = &e.as_ {
368 add_result(typer.issues, Some(as_.clone()), type_, as_.span(), true);
369 } else {
370 if typer.options.warn_unnamed_column_in_select {
371 typer.issues.warn("Unnamed column in select", e);
372 }
373 add_result(typer.issues, None, type_, 0..0, false);
374 };
375 }
376 }
377
378 typer.reference_types.push(select_reference);
379
380 result
381}
382
383pub(crate) fn type_compound_query<'a>(
384 typer: &mut Typer<'a, '_>,
385 query: &CompoundQuery<'a>,
386) -> SelectType<'a> {
387 let mut t = type_union_select(typer, &query.left, true);
388 let mut left = query.left.span();
389 for w in &query.with {
390 if w.operator != CompoundOperator::Union {
391 issue_todo!(typer.issues, w);
392 }
393
394 let t2 = type_union_select(typer, &w.statement, true);
395
396 for i in 0..usize::max(t.columns.len(), t2.columns.len()) {
397 if let Some(l) = t.columns.get_mut(i) {
398 if let Some(r) = t2.columns.get(i) {
399 if l.name != r.name {
400 if let Some(ln) = &l.name {
401 if let Some(rn) = &r.name {
402 typer
403 .err("Incompatible names in union", &w.operator_span)
404 .frag(format!("Column {i} is named {ln}"), &left)
405 .frag(format!("Column {i} is named {rn}"), &w.statement);
406 } else {
407 typer
408 .err("Incompatible names in union", &w.operator_span)
409 .frag(format!("Column {i} is named {ln}"), &left)
410 .frag(format!("Column {i} has no name"), &w.statement);
411 }
412 } else {
413 typer
414 .err("Incompatible names in union", &w.operator_span)
415 .frag(format!("Column {i} has no name"), &left)
416 .frag(
417 format!(
418 "Column {} is named {}",
419 i,
420 r.name.as_ref().expect("name")
421 ),
422 &w.statement,
423 );
424 }
425 }
426 if l.type_.t == r.type_.t {
427 l.type_ =
428 FullType::new(l.type_.t.clone(), l.type_.not_null && r.type_.not_null);
429 } else if let Some(t) = typer.matched_type(&l.type_, &r.type_) {
430 l.type_ = FullType::new(t, l.type_.not_null && r.type_.not_null);
431 } else {
432 typer
433 .err("Incompatible types in union", &w.operator_span)
434 .frag(format!("Column {} is of type {}", i, l.type_.t), &left)
435 .frag(
436 format!("Column {} is of type {}", i, r.type_.t),
437 &w.statement,
438 );
439 }
440 } else if let Some(n) = &l.name {
441 typer
442 .err("Incompatible types in union", &w.operator_span)
443 .frag(format!("Column {i} ({n}) only on this side"), &left);
444 } else {
445 typer
446 .err("Incompatible types in union", &w.operator_span)
447 .frag(format!("Column {i} only on this side"), &left);
448 }
449 } else if let Some(n) = &t2.columns[i].name {
450 typer
451 .err("Incompatible types in union", &w.operator_span)
452 .frag(format!("Column {i} ({n}) only on this side"), &w.statement);
453 } else {
454 typer
455 .err("Incompatible types in union", &w.operator_span)
456 .frag(format!("Column {i} only on this side"), &w.statement);
457 }
458 }
459 left = left.join_span(&w.statement);
460 }
461
462 typer.reference_types.push(ReferenceType {
463 name: None,
464 span: t.span(),
465 columns: t
466 .columns
467 .iter()
468 .filter_map(|v| v.name.as_ref().map(|name| (name.clone(), v.type_.clone())))
469 .collect(),
470 });
471
472 if let Some((_, order_by)) = &query.order_by {
473 for (e, _) in order_by {
474 type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
475 }
476 }
477
478 if let Some((_, offset, count)) = &query.limit {
479 if let Some(offset) = offset {
480 let t = type_expression(typer, offset, ExpressionFlags::default(), BaseType::Integer);
481 if typer
482 .matched_type(&t, &FullType::new(Type::U64, true))
483 .is_none()
484 {
485 typer.err(format!("Expected integer type got {}", t.t), offset);
486 }
487 }
488 let t = type_expression(typer, count, ExpressionFlags::default(), BaseType::Integer);
489 if typer
490 .matched_type(&t, &FullType::new(Type::U64, true))
491 .is_none()
492 {
493 typer.err(format!("Expected integer type got {}", t.t), count);
494 }
495 }
496
497 typer.reference_types.pop();
498
499 t
500}
501
502pub(crate) fn type_union_select<'a>(
503 typer: &mut Typer<'a, '_>,
504 statement: &Statement<'a>,
505 warn_duplicate: bool,
506) -> SelectType<'a> {
507 match statement {
508 Statement::Select(s) => type_select(typer, s, warn_duplicate),
509 Statement::CompoundQuery(q) => type_compound_query(typer, q),
510 s => {
511 issue_ice!(typer.issues, s);
512 SelectType {
513 columns: Vec::new(),
514 select_span: s.span(),
515 }
516 }
517 }
518}