Skip to main content

qusql_type/
type_select.rs

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
12
13use alloc::{format, vec::Vec};
14use qusql_parse::{
15    CompoundOperator, CompoundQuery, Expression, Identifier, IdentifierPart, Issues, OptSpanned,
16    Select, SelectExpr, Span, Spanned, Statement, issue_ice, issue_todo,
17};
18
19use crate::{
20    Type,
21    type_::{BaseType, FullType},
22    type_expression::{ExpressionFlags, type_expression},
23    type_reference::type_reference,
24    typer::{ReferenceType, Typer, did_you_mean, typer_stack},
25};
26
27/// A column in select
28#[derive(Debug, Clone)]
29pub struct SelectTypeColumn<'a> {
30    /// The name of the column if one is specified or can be computed
31    pub name: Option<Identifier<'a>>,
32    /// The type of the data
33    pub type_: FullType<'a>,
34    /// A span of the expression yielding the column
35    pub span: Span,
36}
37
38impl<'a> Spanned for SelectTypeColumn<'a> {
39    fn span(&self) -> Span {
40        self.span.span()
41    }
42}
43
44#[derive(Debug, Clone)]
45pub(crate) struct SelectType<'a> {
46    pub columns: Vec<SelectTypeColumn<'a>>,
47    pub select_span: Span,
48}
49
50impl<'a> Spanned for SelectType<'a> {
51    fn span(&self) -> Span {
52        self.columns
53            .opt_span()
54            .unwrap_or_else(|| self.select_span.clone())
55    }
56}
57
58pub(crate) fn resolve_kleene_identifier<'a, 'b>(
59    typer: &mut Typer<'a, 'b>,
60    parts: &[IdentifierPart<'a>],
61    as_: &Option<Identifier<'a>>,
62    mut cb: impl FnMut(&mut Issues<'a>, Option<Identifier<'a>>, FullType<'a>, Span, bool),
63) {
64    match parts {
65        [qusql_parse::IdentifierPart::Name(col)] => {
66            let mut cnt = 0;
67            let mut t = None;
68            for r in &typer.reference_types {
69                for c in &r.columns {
70                    if c.0 == *col {
71                        cnt += 1;
72                        t = Some(c);
73                    }
74                }
75            }
76            let name = as_.as_ref().unwrap_or(col);
77            if cnt > 1 {
78                let mut issue = typer.issues.err("Ambigious reference", col);
79                for r in &typer.reference_types {
80                    for c in &r.columns {
81                        if c.0 == *col {
82                            issue.frag("Defined here", &r.span);
83                        }
84                    }
85                }
86                cb(
87                    typer.issues,
88                    Some(name.clone()),
89                    FullType::invalid(),
90                    name.span(),
91                    as_.is_some(),
92                );
93            } else if let Some(t) = t {
94                cb(
95                    typer.issues,
96                    Some(name.clone()),
97                    t.1.clone(),
98                    name.span(),
99                    as_.is_some(),
100                );
101            } else {
102                let suggestion = did_you_mean(
103                    col.value,
104                    typer
105                        .reference_types
106                        .iter()
107                        .flat_map(|r| r.columns.iter().map(|(id, _)| id.value)),
108                );
109                let mut issue = typer.err("Unknown identifier", col);
110                if let Some(s) = suggestion {
111                    issue.help(alloc::format!("did you mean `{s}`?"));
112                }
113                cb(
114                    typer.issues,
115                    Some(name.clone()),
116                    FullType::invalid(),
117                    name.span(),
118                    as_.is_some(),
119                );
120            }
121        }
122        [qusql_parse::IdentifierPart::Star(v)] => {
123            if let Some(as_) = as_ {
124                typer.err("As not supported for *", as_);
125            }
126            for r in &typer.reference_types {
127                for c in &r.columns {
128                    cb(
129                        typer.issues,
130                        Some(c.0.clone()),
131                        c.1.clone(),
132                        v.clone(),
133                        false,
134                    );
135                }
136            }
137        }
138        [
139            qusql_parse::IdentifierPart::Name(tbl),
140            qusql_parse::IdentifierPart::Name(col),
141        ] => {
142            let mut t = None;
143            for r in &typer.reference_types {
144                if r.name == Some(tbl.clone()) {
145                    for c in &r.columns {
146                        if c.0 == *col {
147                            t = Some(c);
148                        }
149                    }
150                }
151            }
152            let name = as_.as_ref().unwrap_or(col);
153            if let Some(t) = t {
154                cb(
155                    typer.issues,
156                    Some(name.clone()),
157                    t.1.clone(),
158                    name.span(),
159                    as_.is_some(),
160                );
161            } else {
162                let suggestion = did_you_mean(
163                    col.value,
164                    typer
165                        .reference_types
166                        .iter()
167                        .filter(|r| r.name.as_deref() == Some(tbl.as_ref()))
168                        .flat_map(|r| r.columns.iter().map(|(id, _)| id.value)),
169                );
170                let mut issue = typer.err("Unknown identifier", col);
171                if let Some(s) = suggestion {
172                    issue.help(alloc::format!("did you mean `{s}`?"));
173                }
174                cb(
175                    typer.issues,
176                    Some(name.clone()),
177                    FullType::invalid(),
178                    name.span(),
179                    as_.is_some(),
180                );
181            }
182        }
183        [
184            qusql_parse::IdentifierPart::Name(tbl),
185            qusql_parse::IdentifierPart::Star(v),
186        ] => {
187            if let Some(as_) = as_ {
188                typer.err("As not supported for *", as_);
189            }
190            let mut t = None;
191            for r in &typer.reference_types {
192                if r.name == Some(tbl.clone()) {
193                    t = Some(r);
194                }
195            }
196            if let Some(t) = t {
197                for c in &t.columns {
198                    cb(
199                        typer.issues,
200                        Some(c.0.clone()),
201                        c.1.clone(),
202                        v.clone(),
203                        false,
204                    );
205                }
206            } else {
207                typer.err("Unknown table", tbl);
208            }
209        }
210        [qusql_parse::IdentifierPart::Star(v), _] => {
211            typer.err("Not supported here", v);
212        }
213        _ => {
214            typer.err("Invalid identifier", &parts.opt_span().expect("parts span"));
215        }
216    }
217}
218
219pub(crate) fn type_select<'a>(
220    typer: &mut Typer<'a, '_>,
221    select: &Select<'a>,
222    warn_duplicate: bool,
223) -> SelectType<'a> {
224    let mut guard = typer_stack(
225        typer,
226        |t| {
227            let refs = core::mem::take(&mut t.reference_types);
228            let old_outer = core::mem::take(&mut t.outer_reference_types);
229            // Make the current scope's refs visible as the outer (correlated) scope
230            // for any subqueries encountered within this SELECT.
231            let mut new_outer = refs.clone();
232            new_outer.extend(old_outer.iter().cloned());
233            t.outer_reference_types = new_outer;
234            (refs, old_outer)
235        },
236        |t, (refs, old_outer)| {
237            t.reference_types = refs;
238            t.outer_reference_types = old_outer;
239        },
240    );
241    let typer = &mut guard.typer;
242
243    for flag in &select.flags {
244        match &flag {
245            qusql_parse::SelectFlag::All(_) => issue_todo!(typer.issues, flag),
246            qusql_parse::SelectFlag::Distinct(_)
247            | qusql_parse::SelectFlag::DistinctOn(_)
248            | qusql_parse::SelectFlag::DistinctRow(_) => (),
249            qusql_parse::SelectFlag::StraightJoin(_) => issue_todo!(typer.issues, flag),
250            qusql_parse::SelectFlag::HighPriority(_)
251            | qusql_parse::SelectFlag::SqlSmallResult(_)
252            | qusql_parse::SelectFlag::SqlBigResult(_)
253            | qusql_parse::SelectFlag::SqlBufferResult(_)
254            | qusql_parse::SelectFlag::SqlNoCache(_)
255            | qusql_parse::SelectFlag::SqlCalcFoundRows(_) => (),
256        }
257    }
258
259    if let Some(references) = &select.table_references {
260        for reference in references {
261            type_reference(typer, reference, false);
262        }
263    }
264
265    if let Some((where_, _)) = &select.where_ {
266        let t = type_expression(
267            typer,
268            where_,
269            ExpressionFlags::default()
270                .with_not_null(true)
271                .with_true(true),
272            BaseType::Bool,
273        );
274        typer.ensure_base(where_, &t, BaseType::Bool);
275    }
276
277    let result = type_select_exprs(typer, &select.select_exprs, warn_duplicate);
278
279    if let Some((_, group_by)) = &select.group_by {
280        for e in group_by {
281            type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
282        }
283    }
284
285    if let Some((_, order_by)) = &select.order_by {
286        for (e, _) in order_by {
287            type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
288        }
289    }
290
291    if let Some((having, _)) = &select.having {
292        let t = type_expression(
293            typer,
294            having,
295            ExpressionFlags::default()
296                .with_not_null(true)
297                .with_true(true),
298            BaseType::Bool,
299        );
300        typer.ensure_base(having, &t, BaseType::Bool);
301    }
302
303    if let Some((_, offset, count)) = &select.limit {
304        if let Some(offset) = offset {
305            let t = type_expression(typer, offset, ExpressionFlags::default(), BaseType::Integer);
306            if typer
307                .matched_type(&t, &FullType::new(Type::U64, true))
308                .is_none()
309            {
310                typer.err(format!("Expected integer type got {}", t.t), offset);
311            }
312        }
313        let t = type_expression(typer, count, ExpressionFlags::default(), BaseType::Integer);
314        if typer
315            .matched_type(&t, &FullType::new(Type::U64, true))
316            .is_none()
317        {
318            typer.err(format!("Expected integer type got {}", t.t), count);
319        }
320    }
321
322    SelectType {
323        columns: result
324            .into_iter()
325            .map(|(name, type_, span)| SelectTypeColumn { name, type_, span })
326            .collect(),
327        select_span: select.span(),
328    }
329}
330
331pub(crate) fn type_select_exprs<'a, 'b>(
332    typer: &mut Typer<'a, 'b>,
333    select_exprs: &[SelectExpr<'a>],
334    warn_duplicate: bool,
335) -> Vec<(Option<Identifier<'a>>, FullType<'a>, Span)> {
336    let mut result = Vec::new();
337    let mut select_reference = ReferenceType {
338        name: None,
339        span: select_exprs.opt_span().expect("select_exprs span"),
340        columns: Vec::new(),
341    };
342
343    for e in select_exprs {
344        let mut add_result = |issues: &mut Issues<'a>,
345                              name: Option<Identifier<'a>>,
346                              type_: FullType<'a>,
347                              span: Span,
348                              as_: bool| {
349            if let Some(name) = name.clone() {
350                if as_ {
351                    select_reference.columns.push((name.clone(), type_.clone()));
352                }
353                for (on, _, os) in &result {
354                    if Some(name.clone()) == *on && warn_duplicate {
355                        issues
356                            .warn("Also defined here", &span)
357                            .frag(format!("Multiple columns with the name '{name}'"), os);
358                    }
359                }
360            }
361            result.push((name, type_, span));
362        };
363        if let Expression::Identifier(parts) = &e.expr {
364            resolve_kleene_identifier(typer, &parts.parts, &e.as_, add_result);
365        } else {
366            let type_ = type_expression(typer, &e.expr, ExpressionFlags::default(), BaseType::Any);
367            if let Some(as_) = &e.as_ {
368                add_result(typer.issues, Some(as_.clone()), type_, as_.span(), true);
369            } else {
370                if typer.options.warn_unnamed_column_in_select {
371                    typer.issues.warn("Unnamed column in select", e);
372                }
373                add_result(typer.issues, None, type_, 0..0, false);
374            };
375        }
376    }
377
378    typer.reference_types.push(select_reference);
379
380    result
381}
382
383pub(crate) fn type_compound_query<'a>(
384    typer: &mut Typer<'a, '_>,
385    query: &CompoundQuery<'a>,
386) -> SelectType<'a> {
387    let mut t = type_union_select(typer, &query.left, true);
388    let mut left = query.left.span();
389    for w in &query.with {
390        if w.operator != CompoundOperator::Union {
391            issue_todo!(typer.issues, w);
392        }
393
394        let t2 = type_union_select(typer, &w.statement, true);
395
396        for i in 0..usize::max(t.columns.len(), t2.columns.len()) {
397            if let Some(l) = t.columns.get_mut(i) {
398                if let Some(r) = t2.columns.get(i) {
399                    if l.name != r.name {
400                        if let Some(ln) = &l.name {
401                            if let Some(rn) = &r.name {
402                                typer
403                                    .err("Incompatible names in union", &w.operator_span)
404                                    .frag(format!("Column {i} is named {ln}"), &left)
405                                    .frag(format!("Column {i} is named {rn}"), &w.statement);
406                            } else {
407                                typer
408                                    .err("Incompatible names in union", &w.operator_span)
409                                    .frag(format!("Column {i} is named {ln}"), &left)
410                                    .frag(format!("Column {i} has no name"), &w.statement);
411                            }
412                        } else {
413                            typer
414                                .err("Incompatible names in union", &w.operator_span)
415                                .frag(format!("Column {i} has no name"), &left)
416                                .frag(
417                                    format!(
418                                        "Column {} is named {}",
419                                        i,
420                                        r.name.as_ref().expect("name")
421                                    ),
422                                    &w.statement,
423                                );
424                        }
425                    }
426                    if l.type_.t == r.type_.t {
427                        l.type_ =
428                            FullType::new(l.type_.t.clone(), l.type_.not_null && r.type_.not_null);
429                    } else if let Some(t) = typer.matched_type(&l.type_, &r.type_) {
430                        l.type_ = FullType::new(t, l.type_.not_null && r.type_.not_null);
431                    } else {
432                        typer
433                            .err("Incompatible types in union", &w.operator_span)
434                            .frag(format!("Column {} is of type {}", i, l.type_.t), &left)
435                            .frag(
436                                format!("Column {} is of type {}", i, r.type_.t),
437                                &w.statement,
438                            );
439                    }
440                } else if let Some(n) = &l.name {
441                    typer
442                        .err("Incompatible types in union", &w.operator_span)
443                        .frag(format!("Column {i} ({n}) only on this side"), &left);
444                } else {
445                    typer
446                        .err("Incompatible types in union", &w.operator_span)
447                        .frag(format!("Column {i} only on this side"), &left);
448                }
449            } else if let Some(n) = &t2.columns[i].name {
450                typer
451                    .err("Incompatible types in union", &w.operator_span)
452                    .frag(format!("Column {i} ({n}) only on this side"), &w.statement);
453            } else {
454                typer
455                    .err("Incompatible types in union", &w.operator_span)
456                    .frag(format!("Column {i} only on this side"), &w.statement);
457            }
458        }
459        left = left.join_span(&w.statement);
460    }
461
462    typer.reference_types.push(ReferenceType {
463        name: None,
464        span: t.span(),
465        columns: t
466            .columns
467            .iter()
468            .filter_map(|v| v.name.as_ref().map(|name| (name.clone(), v.type_.clone())))
469            .collect(),
470    });
471
472    if let Some((_, order_by)) = &query.order_by {
473        for (e, _) in order_by {
474            type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
475        }
476    }
477
478    if let Some((_, offset, count)) = &query.limit {
479        if let Some(offset) = offset {
480            let t = type_expression(typer, offset, ExpressionFlags::default(), BaseType::Integer);
481            if typer
482                .matched_type(&t, &FullType::new(Type::U64, true))
483                .is_none()
484            {
485                typer.err(format!("Expected integer type got {}", t.t), offset);
486            }
487        }
488        let t = type_expression(typer, count, ExpressionFlags::default(), BaseType::Integer);
489        if typer
490            .matched_type(&t, &FullType::new(Type::U64, true))
491            .is_none()
492        {
493            typer.err(format!("Expected integer type got {}", t.t), count);
494        }
495    }
496
497    typer.reference_types.pop();
498
499    t
500}
501
502pub(crate) fn type_union_select<'a>(
503    typer: &mut Typer<'a, '_>,
504    statement: &Statement<'a>,
505    warn_duplicate: bool,
506) -> SelectType<'a> {
507    match statement {
508        Statement::Select(s) => type_select(typer, s, warn_duplicate),
509        Statement::CompoundQuery(q) => type_compound_query(typer, q),
510        s => {
511            issue_ice!(typer.issues, s);
512            SelectType {
513                columns: Vec::new(),
514                select_span: s.span(),
515            }
516        }
517    }
518}