qusql-type 0.4.0

Typer for SQL
Documentation
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::{
    schema::IndexKey,
    type_::{BaseType, FullType},
    type_expression::{ExpressionFlags, type_expression},
    type_select::type_union_select,
    typer::{ReferenceType, Typer, unqualified_name},
};
use alloc::vec::Vec;
use qusql_parse::{Identifier, OptSpanned, Spanned, TableReference, issue_todo};

pub(crate) fn type_reference<'a>(
    typer: &mut Typer<'a, '_>,
    reference: &TableReference<'a>,
    force_null: bool,
) {
    let mut given_refs = core::mem::take(&mut typer.reference_types);
    match reference {
        qusql_parse::TableReference::Table {
            identifier,
            as_,
            index_hints,
            ..
        } => {
            let identifier = unqualified_name(typer.issues, identifier);
            if let Some(s) = typer.get_schema(identifier.value) {
                let mut columns = Vec::new();
                for c in &s.columns {
                    let mut type_ = c.type_.clone();
                    type_.not_null = type_.not_null && !force_null;
                    columns.push((c.identifier.clone(), type_));
                }
                let name = as_.as_ref().unwrap_or(identifier).clone();
                for v in &typer.reference_types {
                    if v.name == Some(name.clone()) {
                        typer
                            .issues
                            .err("Duplicate definitions", &name)
                            .frag("Already defined here", &v.span);
                    }
                }
                for index_hint in index_hints {
                    if matches!(index_hint.type_, qusql_parse::IndexHintType::Index(_)) {
                        for index in &index_hint.index_list {
                            if !typer.schemas.indices.contains_key(&IndexKey {
                                table: Some(identifier.clone()),
                                index: index.clone(),
                            }) {
                                typer.err("Unknown index", index);
                            }
                        }
                    }
                }

                typer.reference_types.push(ReferenceType {
                    name: Some(name.clone()),
                    span: name.span(),
                    columns,
                });
            } else {
                typer.issues.err("Unknown table or view", identifier);
            }
        }
        qusql_parse::TableReference::Query { query, as_, .. } => {
            let select = type_union_select(typer, query, true);

            let span = if let Some(as_) = as_ {
                as_.span.clone()
            } else {
                select.columns.opt_span().unwrap_or_else(|| query.span())
            };

            typer.reference_types.push(ReferenceType {
                name: as_.clone(),
                span,
                columns: select
                    .columns
                    .iter()
                    .filter_map(|v| v.name.as_ref().map(|name| (name.clone(), v.type_.clone())))
                    .collect(),
            });
        }
        qusql_parse::TableReference::Join {
            join,
            left,
            right,
            specification,
        } => {
            let (left_force_null, right_force_null) = match join {
                qusql_parse::JoinType::Left(_) => (force_null, true),
                qusql_parse::JoinType::Right(_) => (true, force_null),
                qusql_parse::JoinType::Inner(_)
                | qusql_parse::JoinType::Cross(_)
                | qusql_parse::JoinType::Normal(_) => (force_null, force_null),
                _ => {
                    issue_todo!(typer.issues, join);
                    (force_null, force_null)
                }
            };
            type_reference(typer, left, left_force_null);
            type_reference(typer, right, right_force_null);
            match &specification {
                Some(qusql_parse::JoinSpecification::On(e, _)) => {
                    let t = type_expression(typer, e, ExpressionFlags::default(), BaseType::Bool);
                    typer.ensure_base(e, &t, BaseType::Bool);
                }
                Some(s @ qusql_parse::JoinSpecification::Using(_, _)) => {
                    issue_todo!(typer.issues, s);
                }
                None => (),
            }
        }
        qusql_parse::TableReference::JsonTable { .. } => {
            issue_todo!(typer.issues, reference);
        }
        TableReference::Function {
            name,
            args,
            with_ordinality,
            as_,
            col_list,
            ..
        } => {
            match name {
                qusql_parse::TableFunctionName::Unnest(unnest_span) => {
                    // Each argument to UNNEST expands to one column.
                    // The column type is the element type of the array argument.
                    let mut columns: Vec<(Identifier<'a>, FullType<'a>)> = Vec::new();
                    for (idx, arg) in args.iter().enumerate() {
                        let arr_type =
                            type_expression(typer, arg, ExpressionFlags::default(), BaseType::Any);
                        let elem_type = if let crate::type_::Type::Array(inner) = arr_type.t {
                            FullType::new(*inner, false)
                        } else {
                            // If we can't determine it's an array, use Any/nullable
                            FullType::new(BaseType::Any, false)
                        };
                        // Use col_list alias if provided, otherwise generate "unnest1", "unnest2", ...
                        let col_name = if let Some(alias) = col_list.get(idx) {
                            alias.clone()
                        } else {
                            static UNNEST_NAMES: [&str; 8] = [
                                "unnest1", "unnest2", "unnest3", "unnest4", "unnest5", "unnest6",
                                "unnest7", "unnest8",
                            ];
                            let name = UNNEST_NAMES.get(idx).copied().unwrap_or("unnest");
                            Identifier::new(name, unnest_span.clone())
                        };
                        columns.push((col_name, elem_type));
                    }
                    // WITH ORDINALITY appends a bigint ordinality column
                    if with_ordinality.is_some() {
                        let ord_name = if let Some(alias) = col_list.get(args.len()) {
                            alias.clone()
                        } else {
                            Identifier::new("ordinality", unnest_span.clone())
                        };
                        columns.push((ord_name, FullType::new(BaseType::Integer, true)));
                    }
                    let span = if let Some(as_) = as_ {
                        as_.span.clone()
                    } else {
                        unnest_span.clone()
                    };
                    typer.reference_types.push(ReferenceType {
                        name: as_.clone(),
                        span,
                        columns,
                    });
                }
                qusql_parse::TableFunctionName::GenerateSeries(s)
                | qusql_parse::TableFunctionName::StringToTable(s) => {
                    issue_todo!(typer.issues, s);
                }
                qusql_parse::TableFunctionName::Other(n) => {
                    issue_todo!(typer.issues, n);
                }
            }
        }
    }

    let new_refs = core::mem::take(&mut typer.reference_types);
    // Inner scope refs shadow outer scope refs with the same alias (e.g. a CTE
    // referencing a table with the same name as an outer-scope CTE).
    for new_ref in &new_refs {
        if let Some(name) = &new_ref.name {
            given_refs.retain(|r| r.name.as_ref() != Some(name));
        }
    }
    typer.reference_types = given_refs;
    typer.reference_types.extend(new_refs);
}