Skip to main content

qusql_type/
type_select.rs

1// Licensed under the Apache License, Version 2.0 (the "License");
2// you may not use this file except in compliance with the License.
3// You may obtain a copy of the License at
4//
5// http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13use alloc::{format, vec::Vec};
14use qusql_parse::{
15    CompoundOperator, CompoundQuery, Expression, Identifier, IdentifierPart, Issues, OptSpanned,
16    Select, SelectExpr, Span, Spanned, Statement, issue_ice, issue_todo,
17};
18
19use crate::{
20    Type,
21    type_::{BaseType, FullType},
22    type_expression::{ExpressionFlags, type_expression},
23    type_reference::type_reference,
24    typer::{ReferenceType, Typer, typer_stack},
25};
26
27/// A column in select
28#[derive(Debug, Clone)]
29pub struct SelectTypeColumn<'a> {
30    /// The name of the column if one is specified or can be computed
31    pub name: Option<Identifier<'a>>,
32    /// The type of the data
33    pub type_: FullType<'a>,
34    /// A span of the expression yielding the column
35    pub span: Span,
36}
37
38impl<'a> Spanned for SelectTypeColumn<'a> {
39    fn span(&self) -> Span {
40        self.span.span()
41    }
42}
43
44#[derive(Debug, Clone)]
45pub(crate) struct SelectType<'a> {
46    pub columns: Vec<SelectTypeColumn<'a>>,
47    pub select_span: Span,
48}
49
50impl<'a> Spanned for SelectType<'a> {
51    fn span(&self) -> Span {
52        self.columns
53            .opt_span()
54            .unwrap_or_else(|| self.select_span.clone())
55    }
56}
57
58pub(crate) fn resolve_kleene_identifier<'a, 'b>(
59    typer: &mut Typer<'a, 'b>,
60    parts: &[IdentifierPart<'a>],
61    as_: &Option<Identifier<'a>>,
62    mut cb: impl FnMut(&mut Issues<'a>, Option<Identifier<'a>>, FullType<'a>, Span, bool),
63) {
64    match parts {
65        [qusql_parse::IdentifierPart::Name(col)] => {
66            let mut cnt = 0;
67            let mut t = None;
68            for r in &typer.reference_types {
69                for c in &r.columns {
70                    if c.0 == *col {
71                        cnt += 1;
72                        t = Some(c);
73                    }
74                }
75            }
76            let name = as_.as_ref().unwrap_or(col);
77            if cnt > 1 {
78                let mut issue = typer.issues.err("Ambigious reference", col);
79                for r in &typer.reference_types {
80                    for c in &r.columns {
81                        if c.0 == *col {
82                            issue.frag("Defined here", &r.span);
83                        }
84                    }
85                }
86                cb(
87                    typer.issues,
88                    Some(name.clone()),
89                    FullType::invalid(),
90                    name.span(),
91                    as_.is_some(),
92                );
93            } else if let Some(t) = t {
94                cb(
95                    typer.issues,
96                    Some(name.clone()),
97                    t.1.clone(),
98                    name.span(),
99                    as_.is_some(),
100                );
101            } else {
102                typer.err("Unknown identifier", col);
103                cb(
104                    typer.issues,
105                    Some(name.clone()),
106                    FullType::invalid(),
107                    name.span(),
108                    as_.is_some(),
109                );
110            }
111        }
112        [qusql_parse::IdentifierPart::Star(v)] => {
113            if let Some(as_) = as_ {
114                typer.err("As not supported for *", as_);
115            }
116            for r in &typer.reference_types {
117                for c in &r.columns {
118                    cb(
119                        typer.issues,
120                        Some(c.0.clone()),
121                        c.1.clone(),
122                        v.clone(),
123                        false,
124                    );
125                }
126            }
127        }
128        [
129            qusql_parse::IdentifierPart::Name(tbl),
130            qusql_parse::IdentifierPart::Name(col),
131        ] => {
132            let mut t = None;
133            for r in &typer.reference_types {
134                if r.name == Some(tbl.clone()) {
135                    for c in &r.columns {
136                        if c.0 == *col {
137                            t = Some(c);
138                        }
139                    }
140                }
141            }
142            let name = as_.as_ref().unwrap_or(col);
143            if let Some(t) = t {
144                cb(
145                    typer.issues,
146                    Some(name.clone()),
147                    t.1.clone(),
148                    name.span(),
149                    as_.is_some(),
150                );
151            } else {
152                typer.err("Unknown identifier", col);
153                cb(
154                    typer.issues,
155                    Some(name.clone()),
156                    FullType::invalid(),
157                    name.span(),
158                    as_.is_some(),
159                );
160            }
161        }
162        [
163            qusql_parse::IdentifierPart::Name(tbl),
164            qusql_parse::IdentifierPart::Star(v),
165        ] => {
166            if let Some(as_) = as_ {
167                typer.err("As not supported for *", as_);
168            }
169            let mut t = None;
170            for r in &typer.reference_types {
171                if r.name == Some(tbl.clone()) {
172                    t = Some(r);
173                }
174            }
175            if let Some(t) = t {
176                for c in &t.columns {
177                    cb(
178                        typer.issues,
179                        Some(c.0.clone()),
180                        c.1.clone(),
181                        v.clone(),
182                        false,
183                    );
184                }
185            } else {
186                typer.err("Unknown table", tbl);
187            }
188        }
189        [qusql_parse::IdentifierPart::Star(v), _] => {
190            typer.err("Not supported here", v);
191        }
192        _ => {
193            typer.err("Invalid identifier", &parts.opt_span().expect("parts span"));
194        }
195    }
196}
197
198pub(crate) fn type_select<'a>(
199    typer: &mut Typer<'a, '_>,
200    select: &Select<'a>,
201    warn_duplicate: bool,
202) -> SelectType<'a> {
203    let mut guard = typer_stack(
204        typer,
205        |t| t.reference_types.clone(),
206        |t, v| t.reference_types = v,
207    );
208    let typer = &mut guard.typer;
209
210    for flag in &select.flags {
211        match &flag {
212            qusql_parse::SelectFlag::All(_) => issue_todo!(typer.issues, flag),
213            qusql_parse::SelectFlag::Distinct(_)
214            | qusql_parse::SelectFlag::DistinctOn(_)
215            | qusql_parse::SelectFlag::DistinctRow(_) => (),
216            qusql_parse::SelectFlag::StraightJoin(_) => issue_todo!(typer.issues, flag),
217            qusql_parse::SelectFlag::HighPriority(_)
218            | qusql_parse::SelectFlag::SqlSmallResult(_)
219            | qusql_parse::SelectFlag::SqlBigResult(_)
220            | qusql_parse::SelectFlag::SqlBufferResult(_)
221            | qusql_parse::SelectFlag::SqlNoCache(_)
222            | qusql_parse::SelectFlag::SqlCalcFoundRows(_) => (),
223        }
224    }
225
226    if let Some(references) = &select.table_references {
227        for reference in references {
228            type_reference(typer, reference, false);
229        }
230    }
231
232    if let Some((where_, _)) = &select.where_ {
233        let t = type_expression(
234            typer,
235            where_,
236            ExpressionFlags::default()
237                .with_not_null(true)
238                .with_true(true),
239            BaseType::Bool,
240        );
241        typer.ensure_base(where_, &t, BaseType::Bool);
242    }
243
244    let result = type_select_exprs(typer, &select.select_exprs, warn_duplicate);
245
246    if let Some((_, group_by)) = &select.group_by {
247        for e in group_by {
248            type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
249        }
250    }
251
252    if let Some((_, order_by)) = &select.order_by {
253        for (e, _) in order_by {
254            type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
255        }
256    }
257
258    if let Some((having, _)) = &select.having {
259        let t = type_expression(
260            typer,
261            having,
262            ExpressionFlags::default()
263                .with_not_null(true)
264                .with_true(true),
265            BaseType::Bool,
266        );
267        typer.ensure_base(having, &t, BaseType::Bool);
268    }
269
270    if let Some((_, offset, count)) = &select.limit {
271        if let Some(offset) = offset {
272            let t = type_expression(typer, offset, ExpressionFlags::default(), BaseType::Integer);
273            if typer
274                .matched_type(&t, &FullType::new(Type::U64, true))
275                .is_none()
276            {
277                typer.err(format!("Expected integer type got {}", t.t), offset);
278            }
279        }
280        let t = type_expression(typer, count, ExpressionFlags::default(), BaseType::Integer);
281        if typer
282            .matched_type(&t, &FullType::new(Type::U64, true))
283            .is_none()
284        {
285            typer.err(format!("Expected integer type got {}", t.t), count);
286        }
287    }
288
289    SelectType {
290        columns: result
291            .into_iter()
292            .map(|(name, type_, span)| SelectTypeColumn { name, type_, span })
293            .collect(),
294        select_span: select.span(),
295    }
296}
297
298pub(crate) fn type_select_exprs<'a, 'b>(
299    typer: &mut Typer<'a, 'b>,
300    select_exprs: &[SelectExpr<'a>],
301    warn_duplicate: bool,
302) -> Vec<(Option<Identifier<'a>>, FullType<'a>, Span)> {
303    let mut result = Vec::new();
304    let mut select_reference = ReferenceType {
305        name: None,
306        span: select_exprs.opt_span().expect("select_exprs span"),
307        columns: Vec::new(),
308    };
309
310    for e in select_exprs {
311        let mut add_result = |issues: &mut Issues<'a>,
312                              name: Option<Identifier<'a>>,
313                              type_: FullType<'a>,
314                              span: Span,
315                              as_: bool| {
316            if let Some(name) = name.clone() {
317                if as_ {
318                    select_reference.columns.push((name.clone(), type_.clone()));
319                }
320                for (on, _, os) in &result {
321                    if Some(name.clone()) == *on && warn_duplicate {
322                        issues
323                            .warn("Also defined here", &span)
324                            .frag(format!("Multiple columns with the name '{name}'"), os);
325                    }
326                }
327            }
328            result.push((name, type_, span));
329        };
330        if let Expression::Identifier(parts) = &e.expr {
331            resolve_kleene_identifier(typer, &parts.parts, &e.as_, add_result);
332        } else {
333            let type_ = type_expression(typer, &e.expr, ExpressionFlags::default(), BaseType::Any);
334            if let Some(as_) = &e.as_ {
335                add_result(typer.issues, Some(as_.clone()), type_, as_.span(), true);
336            } else {
337                if typer.options.warn_unnamed_column_in_select {
338                    typer.issues.warn("Unnamed column in select", e);
339                }
340                add_result(typer.issues, None, type_, 0..0, false);
341            };
342        }
343    }
344
345    typer.reference_types.push(select_reference);
346
347    result
348}
349
350pub(crate) fn type_compound_query<'a>(
351    typer: &mut Typer<'a, '_>,
352    query: &CompoundQuery<'a>,
353) -> SelectType<'a> {
354    let mut t = type_union_select(typer, &query.left, true);
355    let mut left = query.left.span();
356    for w in &query.with {
357        if w.operator != CompoundOperator::Union {
358            issue_todo!(typer.issues, w);
359        }
360
361        let t2 = type_union_select(typer, &w.statement, true);
362
363        for i in 0..usize::max(t.columns.len(), t2.columns.len()) {
364            if let Some(l) = t.columns.get_mut(i) {
365                if let Some(r) = t2.columns.get(i) {
366                    if l.name != r.name {
367                        if let Some(ln) = &l.name {
368                            if let Some(rn) = &r.name {
369                                typer
370                                    .err("Incompatible names in union", &w.operator_span)
371                                    .frag(format!("Column {i} is named {ln}"), &left)
372                                    .frag(format!("Column {i} is named {rn}"), &w.statement);
373                            } else {
374                                typer
375                                    .err("Incompatible names in union", &w.operator_span)
376                                    .frag(format!("Column {i} is named {ln}"), &left)
377                                    .frag(format!("Column {i} has no name"), &w.statement);
378                            }
379                        } else {
380                            typer
381                                .err("Incompatible names in union", &w.operator_span)
382                                .frag(format!("Column {i} has no name"), &left)
383                                .frag(
384                                    format!(
385                                        "Column {} is named {}",
386                                        i,
387                                        r.name.as_ref().expect("name")
388                                    ),
389                                    &w.statement,
390                                );
391                        }
392                    }
393                    if l.type_.t == r.type_.t {
394                        l.type_ =
395                            FullType::new(l.type_.t.clone(), l.type_.not_null && r.type_.not_null);
396                    } else if let Some(t) = typer.matched_type(&l.type_, &r.type_) {
397                        l.type_ = FullType::new(t, l.type_.not_null && r.type_.not_null);
398                    } else {
399                        typer
400                            .err("Incompatible types in union", &w.operator_span)
401                            .frag(format!("Column {} is of type {}", i, l.type_.t), &left)
402                            .frag(
403                                format!("Column {} is of type {}", i, r.type_.t),
404                                &w.statement,
405                            );
406                    }
407                } else if let Some(n) = &l.name {
408                    typer
409                        .err("Incompatible types in union", &w.operator_span)
410                        .frag(format!("Column {i} ({n}) only on this side"), &left);
411                } else {
412                    typer
413                        .err("Incompatible types in union", &w.operator_span)
414                        .frag(format!("Column {i} only on this side"), &left);
415                }
416            } else if let Some(n) = &t2.columns[i].name {
417                typer
418                    .err("Incompatible types in union", &w.operator_span)
419                    .frag(format!("Column {i} ({n}) only on this side"), &w.statement);
420            } else {
421                typer
422                    .err("Incompatible types in union", &w.operator_span)
423                    .frag(format!("Column {i} only on this side"), &w.statement);
424            }
425        }
426        left = left.join_span(&w.statement);
427    }
428
429    typer.reference_types.push(ReferenceType {
430        name: None,
431        span: t.span(),
432        columns: t
433            .columns
434            .iter()
435            .filter_map(|v| v.name.as_ref().map(|name| (name.clone(), v.type_.clone())))
436            .collect(),
437    });
438
439    if let Some((_, order_by)) = &query.order_by {
440        for (e, _) in order_by {
441            type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
442        }
443    }
444
445    if let Some((_, offset, count)) = &query.limit {
446        if let Some(offset) = offset {
447            let t = type_expression(typer, offset, ExpressionFlags::default(), BaseType::Integer);
448            if typer
449                .matched_type(&t, &FullType::new(Type::U64, true))
450                .is_none()
451            {
452                typer.err(format!("Expected integer type got {}", t.t), offset);
453            }
454        }
455        let t = type_expression(typer, count, ExpressionFlags::default(), BaseType::Integer);
456        if typer
457            .matched_type(&t, &FullType::new(Type::U64, true))
458            .is_none()
459        {
460            typer.err(format!("Expected integer type got {}", t.t), count);
461        }
462    }
463
464    typer.reference_types.pop();
465
466    t
467}
468
469pub(crate) fn type_union_select<'a>(
470    typer: &mut Typer<'a, '_>,
471    statement: &Statement<'a>,
472    warn_duplicate: bool,
473) -> SelectType<'a> {
474    match statement {
475        Statement::Select(s) => type_select(typer, s, warn_duplicate),
476        Statement::CompoundQuery(q) => type_compound_query(typer, q),
477        s => {
478            issue_ice!(typer.issues, s);
479            SelectType {
480                columns: Vec::new(),
481                select_span: s.span(),
482            }
483        }
484    }
485}