1use alloc::{format, vec::Vec};
14use qusql_parse::{
15 CompoundOperator, CompoundQuery, Expression, Identifier, IdentifierPart, Issues, OptSpanned,
16 Select, SelectExpr, Span, Spanned, Statement, issue_ice, issue_todo,
17};
18
19use crate::{
20 Type,
21 type_::{BaseType, FullType},
22 type_expression::{ExpressionFlags, type_expression},
23 type_reference::type_reference,
24 typer::{ReferenceType, Typer, typer_stack},
25};
26
27#[derive(Debug, Clone)]
29pub struct SelectTypeColumn<'a> {
30 pub name: Option<Identifier<'a>>,
32 pub type_: FullType<'a>,
34 pub span: Span,
36}
37
38impl<'a> Spanned for SelectTypeColumn<'a> {
39 fn span(&self) -> Span {
40 self.span.span()
41 }
42}
43
44#[derive(Debug, Clone)]
45pub(crate) struct SelectType<'a> {
46 pub columns: Vec<SelectTypeColumn<'a>>,
47 pub select_span: Span,
48}
49
50impl<'a> Spanned for SelectType<'a> {
51 fn span(&self) -> Span {
52 self.columns
53 .opt_span()
54 .unwrap_or_else(|| self.select_span.clone())
55 }
56}
57
58pub(crate) fn resolve_kleene_identifier<'a, 'b>(
59 typer: &mut Typer<'a, 'b>,
60 parts: &[IdentifierPart<'a>],
61 as_: &Option<Identifier<'a>>,
62 mut cb: impl FnMut(&mut Issues<'a>, Option<Identifier<'a>>, FullType<'a>, Span, bool),
63) {
64 match parts {
65 [qusql_parse::IdentifierPart::Name(col)] => {
66 let mut cnt = 0;
67 let mut t = None;
68 for r in &typer.reference_types {
69 for c in &r.columns {
70 if c.0 == *col {
71 cnt += 1;
72 t = Some(c);
73 }
74 }
75 }
76 let name = as_.as_ref().unwrap_or(col);
77 if cnt > 1 {
78 let mut issue = typer.issues.err("Ambigious reference", col);
79 for r in &typer.reference_types {
80 for c in &r.columns {
81 if c.0 == *col {
82 issue.frag("Defined here", &r.span);
83 }
84 }
85 }
86 cb(
87 typer.issues,
88 Some(name.clone()),
89 FullType::invalid(),
90 name.span(),
91 as_.is_some(),
92 );
93 } else if let Some(t) = t {
94 cb(
95 typer.issues,
96 Some(name.clone()),
97 t.1.clone(),
98 name.span(),
99 as_.is_some(),
100 );
101 } else {
102 typer.err("Unknown identifier", col);
103 cb(
104 typer.issues,
105 Some(name.clone()),
106 FullType::invalid(),
107 name.span(),
108 as_.is_some(),
109 );
110 }
111 }
112 [qusql_parse::IdentifierPart::Star(v)] => {
113 if let Some(as_) = as_ {
114 typer.err("As not supported for *", as_);
115 }
116 for r in &typer.reference_types {
117 for c in &r.columns {
118 cb(
119 typer.issues,
120 Some(c.0.clone()),
121 c.1.clone(),
122 v.clone(),
123 false,
124 );
125 }
126 }
127 }
128 [
129 qusql_parse::IdentifierPart::Name(tbl),
130 qusql_parse::IdentifierPart::Name(col),
131 ] => {
132 let mut t = None;
133 for r in &typer.reference_types {
134 if r.name == Some(tbl.clone()) {
135 for c in &r.columns {
136 if c.0 == *col {
137 t = Some(c);
138 }
139 }
140 }
141 }
142 let name = as_.as_ref().unwrap_or(col);
143 if let Some(t) = t {
144 cb(
145 typer.issues,
146 Some(name.clone()),
147 t.1.clone(),
148 name.span(),
149 as_.is_some(),
150 );
151 } else {
152 typer.err("Unknown identifier", col);
153 cb(
154 typer.issues,
155 Some(name.clone()),
156 FullType::invalid(),
157 name.span(),
158 as_.is_some(),
159 );
160 }
161 }
162 [
163 qusql_parse::IdentifierPart::Name(tbl),
164 qusql_parse::IdentifierPart::Star(v),
165 ] => {
166 if let Some(as_) = as_ {
167 typer.err("As not supported for *", as_);
168 }
169 let mut t = None;
170 for r in &typer.reference_types {
171 if r.name == Some(tbl.clone()) {
172 t = Some(r);
173 }
174 }
175 if let Some(t) = t {
176 for c in &t.columns {
177 cb(
178 typer.issues,
179 Some(c.0.clone()),
180 c.1.clone(),
181 v.clone(),
182 false,
183 );
184 }
185 } else {
186 typer.err("Unknown table", tbl);
187 }
188 }
189 [qusql_parse::IdentifierPart::Star(v), _] => {
190 typer.err("Not supported here", v);
191 }
192 _ => {
193 typer.err("Invalid identifier", &parts.opt_span().expect("parts span"));
194 }
195 }
196}
197
198pub(crate) fn type_select<'a>(
199 typer: &mut Typer<'a, '_>,
200 select: &Select<'a>,
201 warn_duplicate: bool,
202) -> SelectType<'a> {
203 let mut guard = typer_stack(
204 typer,
205 |t| t.reference_types.clone(),
206 |t, v| t.reference_types = v,
207 );
208 let typer = &mut guard.typer;
209
210 for flag in &select.flags {
211 match &flag {
212 qusql_parse::SelectFlag::All(_) => issue_todo!(typer.issues, flag),
213 qusql_parse::SelectFlag::Distinct(_)
214 | qusql_parse::SelectFlag::DistinctOn(_)
215 | qusql_parse::SelectFlag::DistinctRow(_) => (),
216 qusql_parse::SelectFlag::StraightJoin(_) => issue_todo!(typer.issues, flag),
217 qusql_parse::SelectFlag::HighPriority(_)
218 | qusql_parse::SelectFlag::SqlSmallResult(_)
219 | qusql_parse::SelectFlag::SqlBigResult(_)
220 | qusql_parse::SelectFlag::SqlBufferResult(_)
221 | qusql_parse::SelectFlag::SqlNoCache(_)
222 | qusql_parse::SelectFlag::SqlCalcFoundRows(_) => (),
223 }
224 }
225
226 if let Some(references) = &select.table_references {
227 for reference in references {
228 type_reference(typer, reference, false);
229 }
230 }
231
232 if let Some((where_, _)) = &select.where_ {
233 let t = type_expression(
234 typer,
235 where_,
236 ExpressionFlags::default()
237 .with_not_null(true)
238 .with_true(true),
239 BaseType::Bool,
240 );
241 typer.ensure_base(where_, &t, BaseType::Bool);
242 }
243
244 let result = type_select_exprs(typer, &select.select_exprs, warn_duplicate);
245
246 if let Some((_, group_by)) = &select.group_by {
247 for e in group_by {
248 type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
249 }
250 }
251
252 if let Some((_, order_by)) = &select.order_by {
253 for (e, _) in order_by {
254 type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
255 }
256 }
257
258 if let Some((having, _)) = &select.having {
259 let t = type_expression(
260 typer,
261 having,
262 ExpressionFlags::default()
263 .with_not_null(true)
264 .with_true(true),
265 BaseType::Bool,
266 );
267 typer.ensure_base(having, &t, BaseType::Bool);
268 }
269
270 if let Some((_, offset, count)) = &select.limit {
271 if let Some(offset) = offset {
272 let t = type_expression(typer, offset, ExpressionFlags::default(), BaseType::Integer);
273 if typer
274 .matched_type(&t, &FullType::new(Type::U64, true))
275 .is_none()
276 {
277 typer.err(format!("Expected integer type got {}", t.t), offset);
278 }
279 }
280 let t = type_expression(typer, count, ExpressionFlags::default(), BaseType::Integer);
281 if typer
282 .matched_type(&t, &FullType::new(Type::U64, true))
283 .is_none()
284 {
285 typer.err(format!("Expected integer type got {}", t.t), count);
286 }
287 }
288
289 SelectType {
290 columns: result
291 .into_iter()
292 .map(|(name, type_, span)| SelectTypeColumn { name, type_, span })
293 .collect(),
294 select_span: select.span(),
295 }
296}
297
298pub(crate) fn type_select_exprs<'a, 'b>(
299 typer: &mut Typer<'a, 'b>,
300 select_exprs: &[SelectExpr<'a>],
301 warn_duplicate: bool,
302) -> Vec<(Option<Identifier<'a>>, FullType<'a>, Span)> {
303 let mut result = Vec::new();
304 let mut select_reference = ReferenceType {
305 name: None,
306 span: select_exprs.opt_span().expect("select_exprs span"),
307 columns: Vec::new(),
308 };
309
310 for e in select_exprs {
311 let mut add_result = |issues: &mut Issues<'a>,
312 name: Option<Identifier<'a>>,
313 type_: FullType<'a>,
314 span: Span,
315 as_: bool| {
316 if let Some(name) = name.clone() {
317 if as_ {
318 select_reference.columns.push((name.clone(), type_.clone()));
319 }
320 for (on, _, os) in &result {
321 if Some(name.clone()) == *on && warn_duplicate {
322 issues
323 .warn("Also defined here", &span)
324 .frag(format!("Multiple columns with the name '{name}'"), os);
325 }
326 }
327 }
328 result.push((name, type_, span));
329 };
330 if let Expression::Identifier(parts) = &e.expr {
331 resolve_kleene_identifier(typer, &parts.parts, &e.as_, add_result);
332 } else {
333 let type_ = type_expression(typer, &e.expr, ExpressionFlags::default(), BaseType::Any);
334 if let Some(as_) = &e.as_ {
335 add_result(typer.issues, Some(as_.clone()), type_, as_.span(), true);
336 } else {
337 if typer.options.warn_unnamed_column_in_select {
338 typer.issues.warn("Unnamed column in select", e);
339 }
340 add_result(typer.issues, None, type_, 0..0, false);
341 };
342 }
343 }
344
345 typer.reference_types.push(select_reference);
346
347 result
348}
349
350pub(crate) fn type_compound_query<'a>(
351 typer: &mut Typer<'a, '_>,
352 query: &CompoundQuery<'a>,
353) -> SelectType<'a> {
354 let mut t = type_union_select(typer, &query.left, true);
355 let mut left = query.left.span();
356 for w in &query.with {
357 if w.operator != CompoundOperator::Union {
358 issue_todo!(typer.issues, w);
359 }
360
361 let t2 = type_union_select(typer, &w.statement, true);
362
363 for i in 0..usize::max(t.columns.len(), t2.columns.len()) {
364 if let Some(l) = t.columns.get_mut(i) {
365 if let Some(r) = t2.columns.get(i) {
366 if l.name != r.name {
367 if let Some(ln) = &l.name {
368 if let Some(rn) = &r.name {
369 typer
370 .err("Incompatible names in union", &w.operator_span)
371 .frag(format!("Column {i} is named {ln}"), &left)
372 .frag(format!("Column {i} is named {rn}"), &w.statement);
373 } else {
374 typer
375 .err("Incompatible names in union", &w.operator_span)
376 .frag(format!("Column {i} is named {ln}"), &left)
377 .frag(format!("Column {i} has no name"), &w.statement);
378 }
379 } else {
380 typer
381 .err("Incompatible names in union", &w.operator_span)
382 .frag(format!("Column {i} has no name"), &left)
383 .frag(
384 format!(
385 "Column {} is named {}",
386 i,
387 r.name.as_ref().expect("name")
388 ),
389 &w.statement,
390 );
391 }
392 }
393 if l.type_.t == r.type_.t {
394 l.type_ =
395 FullType::new(l.type_.t.clone(), l.type_.not_null && r.type_.not_null);
396 } else if let Some(t) = typer.matched_type(&l.type_, &r.type_) {
397 l.type_ = FullType::new(t, l.type_.not_null && r.type_.not_null);
398 } else {
399 typer
400 .err("Incompatible types in union", &w.operator_span)
401 .frag(format!("Column {} is of type {}", i, l.type_.t), &left)
402 .frag(
403 format!("Column {} is of type {}", i, r.type_.t),
404 &w.statement,
405 );
406 }
407 } else if let Some(n) = &l.name {
408 typer
409 .err("Incompatible types in union", &w.operator_span)
410 .frag(format!("Column {i} ({n}) only on this side"), &left);
411 } else {
412 typer
413 .err("Incompatible types in union", &w.operator_span)
414 .frag(format!("Column {i} only on this side"), &left);
415 }
416 } else if let Some(n) = &t2.columns[i].name {
417 typer
418 .err("Incompatible types in union", &w.operator_span)
419 .frag(format!("Column {i} ({n}) only on this side"), &w.statement);
420 } else {
421 typer
422 .err("Incompatible types in union", &w.operator_span)
423 .frag(format!("Column {i} only on this side"), &w.statement);
424 }
425 }
426 left = left.join_span(&w.statement);
427 }
428
429 typer.reference_types.push(ReferenceType {
430 name: None,
431 span: t.span(),
432 columns: t
433 .columns
434 .iter()
435 .filter_map(|v| v.name.as_ref().map(|name| (name.clone(), v.type_.clone())))
436 .collect(),
437 });
438
439 if let Some((_, order_by)) = &query.order_by {
440 for (e, _) in order_by {
441 type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
442 }
443 }
444
445 if let Some((_, offset, count)) = &query.limit {
446 if let Some(offset) = offset {
447 let t = type_expression(typer, offset, ExpressionFlags::default(), BaseType::Integer);
448 if typer
449 .matched_type(&t, &FullType::new(Type::U64, true))
450 .is_none()
451 {
452 typer.err(format!("Expected integer type got {}", t.t), offset);
453 }
454 }
455 let t = type_expression(typer, count, ExpressionFlags::default(), BaseType::Integer);
456 if typer
457 .matched_type(&t, &FullType::new(Type::U64, true))
458 .is_none()
459 {
460 typer.err(format!("Expected integer type got {}", t.t), count);
461 }
462 }
463
464 typer.reference_types.pop();
465
466 t
467}
468
469pub(crate) fn type_union_select<'a>(
470 typer: &mut Typer<'a, '_>,
471 statement: &Statement<'a>,
472 warn_duplicate: bool,
473) -> SelectType<'a> {
474 match statement {
475 Statement::Select(s) => type_select(typer, s, warn_duplicate),
476 Statement::CompoundQuery(q) => type_compound_query(typer, q),
477 s => {
478 issue_ice!(typer.issues, s);
479 SelectType {
480 columns: Vec::new(),
481 select_span: s.span(),
482 }
483 }
484 }
485}