sql_type/
type_reference.rs

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
12
13use crate::{
14    schema::IndexKey,
15    type_::BaseType,
16    type_expression::{ExpressionFlags, type_expression},
17    type_select::type_union_select,
18    typer::{ReferenceType, Typer, unqualified_name},
19};
20use alloc::vec::Vec;
21use sql_parse::{OptSpanned, Spanned, TableReference, issue_todo};
22
23pub(crate) fn type_reference<'a>(
24    typer: &mut Typer<'a, '_>,
25    reference: &TableReference<'a>,
26    force_null: bool,
27) {
28    let mut given_refs = core::mem::take(&mut typer.reference_types);
29    match reference {
30        sql_parse::TableReference::Table {
31            identifier,
32            as_,
33            index_hints,
34            ..
35        } => {
36            let identifier = unqualified_name(typer.issues, identifier);
37            if let Some(s) = typer.get_schema(identifier.value) {
38                let mut columns = Vec::new();
39                for c in &s.columns {
40                    let mut type_ = c.type_.clone();
41                    type_.not_null = type_.not_null && !force_null;
42                    columns.push((c.identifier.clone(), type_));
43                }
44                let name = as_.as_ref().unwrap_or(identifier).clone();
45                for v in &typer.reference_types {
46                    if v.name == Some(name.clone()) {
47                        typer
48                            .issues
49                            .err("Duplicate definitions", &name)
50                            .frag("Already defined here", &v.span);
51                    }
52                }
53                for index_hint in index_hints {
54                    if matches!(index_hint.type_, sql_parse::IndexHintType::Index(_)) {
55                        for index in &index_hint.index_list {
56                            if !typer.schemas.indices.contains_key(&IndexKey {
57                                table: Some(identifier.clone()),
58                                index: index.clone(),
59                            }) {
60                                typer.err("Unknown index", index);
61                            }
62                        }
63                    }
64                }
65
66                typer.reference_types.push(ReferenceType {
67                    name: Some(name.clone()),
68                    span: name.span(),
69                    columns,
70                });
71            } else {
72                typer.issues.err("Unknown table or view", identifier);
73            }
74        }
75        sql_parse::TableReference::Query { query, as_, .. } => {
76            let select = type_union_select(typer, query, true);
77
78            let span = if let Some(as_) = as_ {
79                as_.span.clone()
80            } else {
81                select.columns.opt_span().unwrap_or_else(|| query.span())
82            };
83
84            typer.reference_types.push(ReferenceType {
85                name: as_.clone(),
86                span,
87                columns: select
88                    .columns
89                    .iter()
90                    .filter_map(|v| v.name.as_ref().map(|name| (name.clone(), v.type_.clone())))
91                    .collect(),
92            });
93        }
94        sql_parse::TableReference::Join {
95            join,
96            left,
97            right,
98            specification,
99        } => {
100            let (left_force_null, right_force_null) = match join {
101                sql_parse::JoinType::Left(_) => (force_null, true),
102                sql_parse::JoinType::Right(_) => (true, force_null),
103                sql_parse::JoinType::Inner(_)
104                | sql_parse::JoinType::Cross(_)
105                | sql_parse::JoinType::Normal(_) => (force_null, force_null),
106                _ => {
107                    issue_todo!(typer.issues, join);
108                    (force_null, force_null)
109                }
110            };
111            type_reference(typer, left, left_force_null);
112            type_reference(typer, right, right_force_null);
113            match &specification {
114                Some(sql_parse::JoinSpecification::On(e, _)) => {
115                    let t = type_expression(typer, e, ExpressionFlags::default(), BaseType::Bool);
116                    typer.ensure_base(e, &t, BaseType::Bool);
117                }
118                Some(s @ sql_parse::JoinSpecification::Using(_, _)) => {
119                    issue_todo!(typer.issues, s);
120                }
121                None => (),
122            }
123        }
124    }
125
126    core::mem::swap(&mut typer.reference_types, &mut given_refs);
127    typer.reference_types.extend(given_refs);
128}