1use alloc::{format, vec::Vec};
14use qusql_parse::{
15 CompoundOperator, CompoundQuery, Expression, Identifier, IdentifierPart, Issues, OptSpanned,
16 Select, SelectExpr, Span, Spanned, Statement, issue_ice, issue_todo,
17};
18
19use crate::{
20 Type,
21 type_::{BaseType, FullType},
22 type_expression::{ExpressionFlags, type_expression},
23 type_reference::type_reference,
24 typer::{ReferenceType, Typer, typer_stack},
25};
26
27#[derive(Debug, Clone)]
29pub struct SelectTypeColumn<'a> {
30 pub name: Option<Identifier<'a>>,
32 pub type_: FullType<'a>,
34 pub span: Span,
36}
37
38impl<'a> Spanned for SelectTypeColumn<'a> {
39 fn span(&self) -> Span {
40 self.span.span()
41 }
42}
43
44#[derive(Debug, Clone)]
45pub(crate) struct SelectType<'a> {
46 pub columns: Vec<SelectTypeColumn<'a>>,
47 pub select_span: Span,
48}
49
50impl<'a> Spanned for SelectType<'a> {
51 fn span(&self) -> Span {
52 self.columns
53 .opt_span()
54 .unwrap_or_else(|| self.select_span.clone())
55 }
56}
57
58pub(crate) fn resolve_kleene_identifier<'a, 'b>(
59 typer: &mut Typer<'a, 'b>,
60 parts: &[IdentifierPart<'a>],
61 as_: &Option<Identifier<'a>>,
62 mut cb: impl FnMut(&mut Issues<'a>, Option<Identifier<'a>>, FullType<'a>, Span, bool),
63) {
64 match parts {
65 [qusql_parse::IdentifierPart::Name(col)] => {
66 let mut cnt = 0;
67 let mut t = None;
68 for r in &typer.reference_types {
69 for c in &r.columns {
70 if c.0 == *col {
71 cnt += 1;
72 t = Some(c);
73 }
74 }
75 }
76 let name = as_.as_ref().unwrap_or(col);
77 if cnt > 1 {
78 let mut issue = typer.issues.err("Ambigious reference", col);
79 for r in &typer.reference_types {
80 for c in &r.columns {
81 if c.0 == *col {
82 issue.frag("Defined here", &r.span);
83 }
84 }
85 }
86 cb(
87 typer.issues,
88 Some(name.clone()),
89 FullType::invalid(),
90 name.span(),
91 as_.is_some(),
92 );
93 } else if let Some(t) = t {
94 cb(
95 typer.issues,
96 Some(name.clone()),
97 t.1.clone(),
98 name.span(),
99 as_.is_some(),
100 );
101 } else {
102 typer.err("Unknown identifier", col);
103 cb(
104 typer.issues,
105 Some(name.clone()),
106 FullType::invalid(),
107 name.span(),
108 as_.is_some(),
109 );
110 }
111 }
112 [qusql_parse::IdentifierPart::Star(v)] => {
113 if let Some(as_) = as_ {
114 typer.err("As not supported for *", as_);
115 }
116 for r in &typer.reference_types {
117 for c in &r.columns {
118 cb(
119 typer.issues,
120 Some(c.0.clone()),
121 c.1.clone(),
122 v.clone(),
123 false,
124 );
125 }
126 }
127 }
128 [
129 qusql_parse::IdentifierPart::Name(tbl),
130 qusql_parse::IdentifierPart::Name(col),
131 ] => {
132 let mut t = None;
133 for r in &typer.reference_types {
134 if r.name == Some(tbl.clone()) {
135 for c in &r.columns {
136 if c.0 == *col {
137 t = Some(c);
138 }
139 }
140 }
141 }
142 let name = as_.as_ref().unwrap_or(col);
143 if let Some(t) = t {
144 cb(
145 typer.issues,
146 Some(name.clone()),
147 t.1.clone(),
148 name.span(),
149 as_.is_some(),
150 );
151 } else {
152 typer.err("Unknown identifier", col);
153 cb(
154 typer.issues,
155 Some(name.clone()),
156 FullType::invalid(),
157 name.span(),
158 as_.is_some(),
159 );
160 }
161 }
162 [
163 qusql_parse::IdentifierPart::Name(tbl),
164 qusql_parse::IdentifierPart::Star(v),
165 ] => {
166 if let Some(as_) = as_ {
167 typer.err("As not supported for *", as_);
168 }
169 let mut t = None;
170 for r in &typer.reference_types {
171 if r.name == Some(tbl.clone()) {
172 t = Some(r);
173 }
174 }
175 if let Some(t) = t {
176 for c in &t.columns {
177 cb(
178 typer.issues,
179 Some(c.0.clone()),
180 c.1.clone(),
181 v.clone(),
182 false,
183 );
184 }
185 } else {
186 typer.err("Unknown table", tbl);
187 }
188 }
189 [qusql_parse::IdentifierPart::Star(v), _] => {
190 typer.err("Not supported here", v);
191 }
192 _ => {
193 typer.err("Invalid identifier", &parts.opt_span().expect("parts span"));
194 }
195 }
196}
197
198pub(crate) fn type_select<'a>(
199 typer: &mut Typer<'a, '_>,
200 select: &Select<'a>,
201 warn_duplicate: bool,
202) -> SelectType<'a> {
203 let mut guard = typer_stack(
204 typer,
205 |t| {
206 let refs = core::mem::take(&mut t.reference_types);
207 let old_outer = core::mem::take(&mut t.outer_reference_types);
208 let mut new_outer = refs.clone();
211 new_outer.extend(old_outer.iter().cloned());
212 t.outer_reference_types = new_outer;
213 (refs, old_outer)
214 },
215 |t, (refs, old_outer)| {
216 t.reference_types = refs;
217 t.outer_reference_types = old_outer;
218 },
219 );
220 let typer = &mut guard.typer;
221
222 for flag in &select.flags {
223 match &flag {
224 qusql_parse::SelectFlag::All(_) => issue_todo!(typer.issues, flag),
225 qusql_parse::SelectFlag::Distinct(_)
226 | qusql_parse::SelectFlag::DistinctOn(_)
227 | qusql_parse::SelectFlag::DistinctRow(_) => (),
228 qusql_parse::SelectFlag::StraightJoin(_) => issue_todo!(typer.issues, flag),
229 qusql_parse::SelectFlag::HighPriority(_)
230 | qusql_parse::SelectFlag::SqlSmallResult(_)
231 | qusql_parse::SelectFlag::SqlBigResult(_)
232 | qusql_parse::SelectFlag::SqlBufferResult(_)
233 | qusql_parse::SelectFlag::SqlNoCache(_)
234 | qusql_parse::SelectFlag::SqlCalcFoundRows(_) => (),
235 }
236 }
237
238 if let Some(references) = &select.table_references {
239 for reference in references {
240 type_reference(typer, reference, false);
241 }
242 }
243
244 if let Some((where_, _)) = &select.where_ {
245 let t = type_expression(
246 typer,
247 where_,
248 ExpressionFlags::default()
249 .with_not_null(true)
250 .with_true(true),
251 BaseType::Bool,
252 );
253 typer.ensure_base(where_, &t, BaseType::Bool);
254 }
255
256 let result = type_select_exprs(typer, &select.select_exprs, warn_duplicate);
257
258 if let Some((_, group_by)) = &select.group_by {
259 for e in group_by {
260 type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
261 }
262 }
263
264 if let Some((_, order_by)) = &select.order_by {
265 for (e, _) in order_by {
266 type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
267 }
268 }
269
270 if let Some((having, _)) = &select.having {
271 let t = type_expression(
272 typer,
273 having,
274 ExpressionFlags::default()
275 .with_not_null(true)
276 .with_true(true),
277 BaseType::Bool,
278 );
279 typer.ensure_base(having, &t, BaseType::Bool);
280 }
281
282 if let Some((_, offset, count)) = &select.limit {
283 if let Some(offset) = offset {
284 let t = type_expression(typer, offset, ExpressionFlags::default(), BaseType::Integer);
285 if typer
286 .matched_type(&t, &FullType::new(Type::U64, true))
287 .is_none()
288 {
289 typer.err(format!("Expected integer type got {}", t.t), offset);
290 }
291 }
292 let t = type_expression(typer, count, ExpressionFlags::default(), BaseType::Integer);
293 if typer
294 .matched_type(&t, &FullType::new(Type::U64, true))
295 .is_none()
296 {
297 typer.err(format!("Expected integer type got {}", t.t), count);
298 }
299 }
300
301 SelectType {
302 columns: result
303 .into_iter()
304 .map(|(name, type_, span)| SelectTypeColumn { name, type_, span })
305 .collect(),
306 select_span: select.span(),
307 }
308}
309
310pub(crate) fn type_select_exprs<'a, 'b>(
311 typer: &mut Typer<'a, 'b>,
312 select_exprs: &[SelectExpr<'a>],
313 warn_duplicate: bool,
314) -> Vec<(Option<Identifier<'a>>, FullType<'a>, Span)> {
315 let mut result = Vec::new();
316 let mut select_reference = ReferenceType {
317 name: None,
318 span: select_exprs.opt_span().expect("select_exprs span"),
319 columns: Vec::new(),
320 };
321
322 for e in select_exprs {
323 let mut add_result = |issues: &mut Issues<'a>,
324 name: Option<Identifier<'a>>,
325 type_: FullType<'a>,
326 span: Span,
327 as_: bool| {
328 if let Some(name) = name.clone() {
329 if as_ {
330 select_reference.columns.push((name.clone(), type_.clone()));
331 }
332 for (on, _, os) in &result {
333 if Some(name.clone()) == *on && warn_duplicate {
334 issues
335 .warn("Also defined here", &span)
336 .frag(format!("Multiple columns with the name '{name}'"), os);
337 }
338 }
339 }
340 result.push((name, type_, span));
341 };
342 if let Expression::Identifier(parts) = &e.expr {
343 resolve_kleene_identifier(typer, &parts.parts, &e.as_, add_result);
344 } else {
345 let type_ = type_expression(typer, &e.expr, ExpressionFlags::default(), BaseType::Any);
346 if let Some(as_) = &e.as_ {
347 add_result(typer.issues, Some(as_.clone()), type_, as_.span(), true);
348 } else {
349 if typer.options.warn_unnamed_column_in_select {
350 typer.issues.warn("Unnamed column in select", e);
351 }
352 add_result(typer.issues, None, type_, 0..0, false);
353 };
354 }
355 }
356
357 typer.reference_types.push(select_reference);
358
359 result
360}
361
362pub(crate) fn type_compound_query<'a>(
363 typer: &mut Typer<'a, '_>,
364 query: &CompoundQuery<'a>,
365) -> SelectType<'a> {
366 let mut t = type_union_select(typer, &query.left, true);
367 let mut left = query.left.span();
368 for w in &query.with {
369 if w.operator != CompoundOperator::Union {
370 issue_todo!(typer.issues, w);
371 }
372
373 let t2 = type_union_select(typer, &w.statement, true);
374
375 for i in 0..usize::max(t.columns.len(), t2.columns.len()) {
376 if let Some(l) = t.columns.get_mut(i) {
377 if let Some(r) = t2.columns.get(i) {
378 if l.name != r.name {
379 if let Some(ln) = &l.name {
380 if let Some(rn) = &r.name {
381 typer
382 .err("Incompatible names in union", &w.operator_span)
383 .frag(format!("Column {i} is named {ln}"), &left)
384 .frag(format!("Column {i} is named {rn}"), &w.statement);
385 } else {
386 typer
387 .err("Incompatible names in union", &w.operator_span)
388 .frag(format!("Column {i} is named {ln}"), &left)
389 .frag(format!("Column {i} has no name"), &w.statement);
390 }
391 } else {
392 typer
393 .err("Incompatible names in union", &w.operator_span)
394 .frag(format!("Column {i} has no name"), &left)
395 .frag(
396 format!(
397 "Column {} is named {}",
398 i,
399 r.name.as_ref().expect("name")
400 ),
401 &w.statement,
402 );
403 }
404 }
405 if l.type_.t == r.type_.t {
406 l.type_ =
407 FullType::new(l.type_.t.clone(), l.type_.not_null && r.type_.not_null);
408 } else if let Some(t) = typer.matched_type(&l.type_, &r.type_) {
409 l.type_ = FullType::new(t, l.type_.not_null && r.type_.not_null);
410 } else {
411 typer
412 .err("Incompatible types in union", &w.operator_span)
413 .frag(format!("Column {} is of type {}", i, l.type_.t), &left)
414 .frag(
415 format!("Column {} is of type {}", i, r.type_.t),
416 &w.statement,
417 );
418 }
419 } else if let Some(n) = &l.name {
420 typer
421 .err("Incompatible types in union", &w.operator_span)
422 .frag(format!("Column {i} ({n}) only on this side"), &left);
423 } else {
424 typer
425 .err("Incompatible types in union", &w.operator_span)
426 .frag(format!("Column {i} only on this side"), &left);
427 }
428 } else if let Some(n) = &t2.columns[i].name {
429 typer
430 .err("Incompatible types in union", &w.operator_span)
431 .frag(format!("Column {i} ({n}) only on this side"), &w.statement);
432 } else {
433 typer
434 .err("Incompatible types in union", &w.operator_span)
435 .frag(format!("Column {i} only on this side"), &w.statement);
436 }
437 }
438 left = left.join_span(&w.statement);
439 }
440
441 typer.reference_types.push(ReferenceType {
442 name: None,
443 span: t.span(),
444 columns: t
445 .columns
446 .iter()
447 .filter_map(|v| v.name.as_ref().map(|name| (name.clone(), v.type_.clone())))
448 .collect(),
449 });
450
451 if let Some((_, order_by)) = &query.order_by {
452 for (e, _) in order_by {
453 type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
454 }
455 }
456
457 if let Some((_, offset, count)) = &query.limit {
458 if let Some(offset) = offset {
459 let t = type_expression(typer, offset, ExpressionFlags::default(), BaseType::Integer);
460 if typer
461 .matched_type(&t, &FullType::new(Type::U64, true))
462 .is_none()
463 {
464 typer.err(format!("Expected integer type got {}", t.t), offset);
465 }
466 }
467 let t = type_expression(typer, count, ExpressionFlags::default(), BaseType::Integer);
468 if typer
469 .matched_type(&t, &FullType::new(Type::U64, true))
470 .is_none()
471 {
472 typer.err(format!("Expected integer type got {}", t.t), count);
473 }
474 }
475
476 typer.reference_types.pop();
477
478 t
479}
480
481pub(crate) fn type_union_select<'a>(
482 typer: &mut Typer<'a, '_>,
483 statement: &Statement<'a>,
484 warn_duplicate: bool,
485) -> SelectType<'a> {
486 match statement {
487 Statement::Select(s) => type_select(typer, s, warn_duplicate),
488 Statement::CompoundQuery(q) => type_compound_query(typer, q),
489 s => {
490 issue_ice!(typer.issues, s);
491 SelectType {
492 columns: Vec::new(),
493 select_span: s.span(),
494 }
495 }
496 }
497}