Skip to main content

qusql_type/
type_select.rs

1// Licensed under the Apache License, Version 2.0 (the "License");
2// you may not use this file except in compliance with the License.
3// You may obtain a copy of the License at
4//
5// http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13use alloc::{format, vec::Vec};
14use qusql_parse::{
15    CompoundOperator, CompoundQuery, Expression, Identifier, IdentifierPart, Issues, OptSpanned,
16    Select, SelectExpr, Span, Spanned, Statement, issue_ice, issue_todo,
17};
18
19use crate::{
20    Type,
21    type_::{BaseType, FullType},
22    type_expression::{ExpressionFlags, type_expression},
23    type_reference::type_reference,
24    typer::{ReferenceType, Typer, typer_stack},
25};
26
27/// A column in select
28#[derive(Debug, Clone)]
29pub struct SelectTypeColumn<'a> {
30    /// The name of the column if one is specified or can be computed
31    pub name: Option<Identifier<'a>>,
32    /// The type of the data
33    pub type_: FullType<'a>,
34    /// A span of the expression yielding the column
35    pub span: Span,
36}
37
38impl<'a> Spanned for SelectTypeColumn<'a> {
39    fn span(&self) -> Span {
40        self.span.span()
41    }
42}
43
44#[derive(Debug, Clone)]
45pub(crate) struct SelectType<'a> {
46    pub columns: Vec<SelectTypeColumn<'a>>,
47    pub select_span: Span,
48}
49
50impl<'a> Spanned for SelectType<'a> {
51    fn span(&self) -> Span {
52        self.columns
53            .opt_span()
54            .unwrap_or_else(|| self.select_span.clone())
55    }
56}
57
58pub(crate) fn resolve_kleene_identifier<'a, 'b>(
59    typer: &mut Typer<'a, 'b>,
60    parts: &[IdentifierPart<'a>],
61    as_: &Option<Identifier<'a>>,
62    mut cb: impl FnMut(&mut Issues<'a>, Option<Identifier<'a>>, FullType<'a>, Span, bool),
63) {
64    match parts {
65        [qusql_parse::IdentifierPart::Name(col)] => {
66            let mut cnt = 0;
67            let mut t = None;
68            for r in &typer.reference_types {
69                for c in &r.columns {
70                    if c.0 == *col {
71                        cnt += 1;
72                        t = Some(c);
73                    }
74                }
75            }
76            let name = as_.as_ref().unwrap_or(col);
77            if cnt > 1 {
78                let mut issue = typer.issues.err("Ambigious reference", col);
79                for r in &typer.reference_types {
80                    for c in &r.columns {
81                        if c.0 == *col {
82                            issue.frag("Defined here", &r.span);
83                        }
84                    }
85                }
86                cb(
87                    typer.issues,
88                    Some(name.clone()),
89                    FullType::invalid(),
90                    name.span(),
91                    as_.is_some(),
92                );
93            } else if let Some(t) = t {
94                cb(
95                    typer.issues,
96                    Some(name.clone()),
97                    t.1.clone(),
98                    name.span(),
99                    as_.is_some(),
100                );
101            } else {
102                typer.err("Unknown identifier", col);
103                cb(
104                    typer.issues,
105                    Some(name.clone()),
106                    FullType::invalid(),
107                    name.span(),
108                    as_.is_some(),
109                );
110            }
111        }
112        [qusql_parse::IdentifierPart::Star(v)] => {
113            if let Some(as_) = as_ {
114                typer.err("As not supported for *", as_);
115            }
116            for r in &typer.reference_types {
117                for c in &r.columns {
118                    cb(
119                        typer.issues,
120                        Some(c.0.clone()),
121                        c.1.clone(),
122                        v.clone(),
123                        false,
124                    );
125                }
126            }
127        }
128        [
129            qusql_parse::IdentifierPart::Name(tbl),
130            qusql_parse::IdentifierPart::Name(col),
131        ] => {
132            let mut t = None;
133            for r in &typer.reference_types {
134                if r.name == Some(tbl.clone()) {
135                    for c in &r.columns {
136                        if c.0 == *col {
137                            t = Some(c);
138                        }
139                    }
140                }
141            }
142            let name = as_.as_ref().unwrap_or(col);
143            if let Some(t) = t {
144                cb(
145                    typer.issues,
146                    Some(name.clone()),
147                    t.1.clone(),
148                    name.span(),
149                    as_.is_some(),
150                );
151            } else {
152                typer.err("Unknown identifier", col);
153                cb(
154                    typer.issues,
155                    Some(name.clone()),
156                    FullType::invalid(),
157                    name.span(),
158                    as_.is_some(),
159                );
160            }
161        }
162        [
163            qusql_parse::IdentifierPart::Name(tbl),
164            qusql_parse::IdentifierPart::Star(v),
165        ] => {
166            if let Some(as_) = as_ {
167                typer.err("As not supported for *", as_);
168            }
169            let mut t = None;
170            for r in &typer.reference_types {
171                if r.name == Some(tbl.clone()) {
172                    t = Some(r);
173                }
174            }
175            if let Some(t) = t {
176                for c in &t.columns {
177                    cb(
178                        typer.issues,
179                        Some(c.0.clone()),
180                        c.1.clone(),
181                        v.clone(),
182                        false,
183                    );
184                }
185            } else {
186                typer.err("Unknown table", tbl);
187            }
188        }
189        [qusql_parse::IdentifierPart::Star(v), _] => {
190            typer.err("Not supported here", v);
191        }
192        _ => {
193            typer.err("Invalid identifier", &parts.opt_span().expect("parts span"));
194        }
195    }
196}
197
198pub(crate) fn type_select<'a>(
199    typer: &mut Typer<'a, '_>,
200    select: &Select<'a>,
201    warn_duplicate: bool,
202) -> SelectType<'a> {
203    let mut guard = typer_stack(
204        typer,
205        |t| {
206            let refs = core::mem::take(&mut t.reference_types);
207            let old_outer = core::mem::take(&mut t.outer_reference_types);
208            // Make the current scope's refs visible as the outer (correlated) scope
209            // for any subqueries encountered within this SELECT.
210            let mut new_outer = refs.clone();
211            new_outer.extend(old_outer.iter().cloned());
212            t.outer_reference_types = new_outer;
213            (refs, old_outer)
214        },
215        |t, (refs, old_outer)| {
216            t.reference_types = refs;
217            t.outer_reference_types = old_outer;
218        },
219    );
220    let typer = &mut guard.typer;
221
222    for flag in &select.flags {
223        match &flag {
224            qusql_parse::SelectFlag::All(_) => issue_todo!(typer.issues, flag),
225            qusql_parse::SelectFlag::Distinct(_)
226            | qusql_parse::SelectFlag::DistinctOn(_)
227            | qusql_parse::SelectFlag::DistinctRow(_) => (),
228            qusql_parse::SelectFlag::StraightJoin(_) => issue_todo!(typer.issues, flag),
229            qusql_parse::SelectFlag::HighPriority(_)
230            | qusql_parse::SelectFlag::SqlSmallResult(_)
231            | qusql_parse::SelectFlag::SqlBigResult(_)
232            | qusql_parse::SelectFlag::SqlBufferResult(_)
233            | qusql_parse::SelectFlag::SqlNoCache(_)
234            | qusql_parse::SelectFlag::SqlCalcFoundRows(_) => (),
235        }
236    }
237
238    if let Some(references) = &select.table_references {
239        for reference in references {
240            type_reference(typer, reference, false);
241        }
242    }
243
244    if let Some((where_, _)) = &select.where_ {
245        let t = type_expression(
246            typer,
247            where_,
248            ExpressionFlags::default()
249                .with_not_null(true)
250                .with_true(true),
251            BaseType::Bool,
252        );
253        typer.ensure_base(where_, &t, BaseType::Bool);
254    }
255
256    let result = type_select_exprs(typer, &select.select_exprs, warn_duplicate);
257
258    if let Some((_, group_by)) = &select.group_by {
259        for e in group_by {
260            type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
261        }
262    }
263
264    if let Some((_, order_by)) = &select.order_by {
265        for (e, _) in order_by {
266            type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
267        }
268    }
269
270    if let Some((having, _)) = &select.having {
271        let t = type_expression(
272            typer,
273            having,
274            ExpressionFlags::default()
275                .with_not_null(true)
276                .with_true(true),
277            BaseType::Bool,
278        );
279        typer.ensure_base(having, &t, BaseType::Bool);
280    }
281
282    if let Some((_, offset, count)) = &select.limit {
283        if let Some(offset) = offset {
284            let t = type_expression(typer, offset, ExpressionFlags::default(), BaseType::Integer);
285            if typer
286                .matched_type(&t, &FullType::new(Type::U64, true))
287                .is_none()
288            {
289                typer.err(format!("Expected integer type got {}", t.t), offset);
290            }
291        }
292        let t = type_expression(typer, count, ExpressionFlags::default(), BaseType::Integer);
293        if typer
294            .matched_type(&t, &FullType::new(Type::U64, true))
295            .is_none()
296        {
297            typer.err(format!("Expected integer type got {}", t.t), count);
298        }
299    }
300
301    SelectType {
302        columns: result
303            .into_iter()
304            .map(|(name, type_, span)| SelectTypeColumn { name, type_, span })
305            .collect(),
306        select_span: select.span(),
307    }
308}
309
310pub(crate) fn type_select_exprs<'a, 'b>(
311    typer: &mut Typer<'a, 'b>,
312    select_exprs: &[SelectExpr<'a>],
313    warn_duplicate: bool,
314) -> Vec<(Option<Identifier<'a>>, FullType<'a>, Span)> {
315    let mut result = Vec::new();
316    let mut select_reference = ReferenceType {
317        name: None,
318        span: select_exprs.opt_span().expect("select_exprs span"),
319        columns: Vec::new(),
320    };
321
322    for e in select_exprs {
323        let mut add_result = |issues: &mut Issues<'a>,
324                              name: Option<Identifier<'a>>,
325                              type_: FullType<'a>,
326                              span: Span,
327                              as_: bool| {
328            if let Some(name) = name.clone() {
329                if as_ {
330                    select_reference.columns.push((name.clone(), type_.clone()));
331                }
332                for (on, _, os) in &result {
333                    if Some(name.clone()) == *on && warn_duplicate {
334                        issues
335                            .warn("Also defined here", &span)
336                            .frag(format!("Multiple columns with the name '{name}'"), os);
337                    }
338                }
339            }
340            result.push((name, type_, span));
341        };
342        if let Expression::Identifier(parts) = &e.expr {
343            resolve_kleene_identifier(typer, &parts.parts, &e.as_, add_result);
344        } else {
345            let type_ = type_expression(typer, &e.expr, ExpressionFlags::default(), BaseType::Any);
346            if let Some(as_) = &e.as_ {
347                add_result(typer.issues, Some(as_.clone()), type_, as_.span(), true);
348            } else {
349                if typer.options.warn_unnamed_column_in_select {
350                    typer.issues.warn("Unnamed column in select", e);
351                }
352                add_result(typer.issues, None, type_, 0..0, false);
353            };
354        }
355    }
356
357    typer.reference_types.push(select_reference);
358
359    result
360}
361
362pub(crate) fn type_compound_query<'a>(
363    typer: &mut Typer<'a, '_>,
364    query: &CompoundQuery<'a>,
365) -> SelectType<'a> {
366    let mut t = type_union_select(typer, &query.left, true);
367    let mut left = query.left.span();
368    for w in &query.with {
369        if w.operator != CompoundOperator::Union {
370            issue_todo!(typer.issues, w);
371        }
372
373        let t2 = type_union_select(typer, &w.statement, true);
374
375        for i in 0..usize::max(t.columns.len(), t2.columns.len()) {
376            if let Some(l) = t.columns.get_mut(i) {
377                if let Some(r) = t2.columns.get(i) {
378                    if l.name != r.name {
379                        if let Some(ln) = &l.name {
380                            if let Some(rn) = &r.name {
381                                typer
382                                    .err("Incompatible names in union", &w.operator_span)
383                                    .frag(format!("Column {i} is named {ln}"), &left)
384                                    .frag(format!("Column {i} is named {rn}"), &w.statement);
385                            } else {
386                                typer
387                                    .err("Incompatible names in union", &w.operator_span)
388                                    .frag(format!("Column {i} is named {ln}"), &left)
389                                    .frag(format!("Column {i} has no name"), &w.statement);
390                            }
391                        } else {
392                            typer
393                                .err("Incompatible names in union", &w.operator_span)
394                                .frag(format!("Column {i} has no name"), &left)
395                                .frag(
396                                    format!(
397                                        "Column {} is named {}",
398                                        i,
399                                        r.name.as_ref().expect("name")
400                                    ),
401                                    &w.statement,
402                                );
403                        }
404                    }
405                    if l.type_.t == r.type_.t {
406                        l.type_ =
407                            FullType::new(l.type_.t.clone(), l.type_.not_null && r.type_.not_null);
408                    } else if let Some(t) = typer.matched_type(&l.type_, &r.type_) {
409                        l.type_ = FullType::new(t, l.type_.not_null && r.type_.not_null);
410                    } else {
411                        typer
412                            .err("Incompatible types in union", &w.operator_span)
413                            .frag(format!("Column {} is of type {}", i, l.type_.t), &left)
414                            .frag(
415                                format!("Column {} is of type {}", i, r.type_.t),
416                                &w.statement,
417                            );
418                    }
419                } else if let Some(n) = &l.name {
420                    typer
421                        .err("Incompatible types in union", &w.operator_span)
422                        .frag(format!("Column {i} ({n}) only on this side"), &left);
423                } else {
424                    typer
425                        .err("Incompatible types in union", &w.operator_span)
426                        .frag(format!("Column {i} only on this side"), &left);
427                }
428            } else if let Some(n) = &t2.columns[i].name {
429                typer
430                    .err("Incompatible types in union", &w.operator_span)
431                    .frag(format!("Column {i} ({n}) only on this side"), &w.statement);
432            } else {
433                typer
434                    .err("Incompatible types in union", &w.operator_span)
435                    .frag(format!("Column {i} only on this side"), &w.statement);
436            }
437        }
438        left = left.join_span(&w.statement);
439    }
440
441    typer.reference_types.push(ReferenceType {
442        name: None,
443        span: t.span(),
444        columns: t
445            .columns
446            .iter()
447            .filter_map(|v| v.name.as_ref().map(|name| (name.clone(), v.type_.clone())))
448            .collect(),
449    });
450
451    if let Some((_, order_by)) = &query.order_by {
452        for (e, _) in order_by {
453            type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
454        }
455    }
456
457    if let Some((_, offset, count)) = &query.limit {
458        if let Some(offset) = offset {
459            let t = type_expression(typer, offset, ExpressionFlags::default(), BaseType::Integer);
460            if typer
461                .matched_type(&t, &FullType::new(Type::U64, true))
462                .is_none()
463            {
464                typer.err(format!("Expected integer type got {}", t.t), offset);
465            }
466        }
467        let t = type_expression(typer, count, ExpressionFlags::default(), BaseType::Integer);
468        if typer
469            .matched_type(&t, &FullType::new(Type::U64, true))
470            .is_none()
471        {
472            typer.err(format!("Expected integer type got {}", t.t), count);
473        }
474    }
475
476    typer.reference_types.pop();
477
478    t
479}
480
481pub(crate) fn type_union_select<'a>(
482    typer: &mut Typer<'a, '_>,
483    statement: &Statement<'a>,
484    warn_duplicate: bool,
485) -> SelectType<'a> {
486    match statement {
487        Statement::Select(s) => type_select(typer, s, warn_duplicate),
488        Statement::CompoundQuery(q) => type_compound_query(typer, q),
489        s => {
490            issue_ice!(typer.issues, s);
491            SelectType {
492                columns: Vec::new(),
493                select_span: s.span(),
494            }
495        }
496    }
497}