squawk-ide 2.47.0

Linter for Postgres migrations & SQL
Documentation
use std::fmt;

use squawk_syntax::{
    SyntaxKind,
    ast::{self, AstNode},
};

/// A coarse SQL type produced by the syntax-only inference in this module.
///
/// The `Display` impl renders each variant as a lowercase SQL-style type name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum Type {
    /// Produced for integer number literals; displayed as `integer`.
    Integer,
    /// Produced for float number literals; displayed as `numeric`.
    Numeric,
    /// Produced for string literals and character types; displayed as `text`.
    Text,
    /// Produced for bit/byte string literals and bit types; displayed as `bit`.
    Bit,
    /// Produced for `true` / `false`; displayed as `boolean`.
    Boolean,
    /// Produced for a bare `null` literal; displayed as `unknown`.
    Unknown,
    /// Produced for tuple expressions; displayed as `record`.
    Record,
    /// An array of the boxed element type; displayed as `<element>[]`.
    Array(Box<Type>),
    /// Any other named type, stored as the source text of its name and
    /// displayed verbatim.
    Other(String),
}

/// Renders a [`Type`] as the lowercase name shown to users.
impl fmt::Display for Type {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Resolve every non-composite variant to a plain string first so there
        // is a single write at the end; arrays recurse into their element type.
        let name = match self {
            Type::Array(inner) => return write!(f, "{inner}[]"),
            Type::Other(custom) => custom.as_str(),
            Type::Integer => "integer",
            Type::Numeric => "numeric",
            Type::Text => "text",
            Type::Bit => "bit",
            Type::Boolean => "boolean",
            Type::Unknown => "unknown",
            Type::Record => "record",
        };
        f.write_str(name)
    }
}

/// Infers the type of `expr` from its syntax alone.
///
/// Handled expression kinds:
/// - casts: the type of the cast target, via [`infer_type_from_ty`]
/// - array expressions: an array of the *first* element's type (an array
///   with no elements, or whose first element cannot be inferred, is `None`)
/// - literals: classified by their token kind
/// - parenthesized expressions: the type of the inner expression
/// - tuple expressions: [`Type::Record`]
///
/// Returns `None` for every other expression kind, including binary
/// expressions, and whenever a required child node is missing.
pub(crate) fn infer_type_from_expr(expr: &ast::Expr) -> Option<Type> {
    match expr {
        ast::Expr::CastExpr(cast_expr) => infer_type_from_ty(&cast_expr.ty()?),
        ast::Expr::ArrayExpr(array_expr) => {
            let first_elem = array_expr.exprs().next()?;
            let elem_ty = infer_type_from_expr(&first_elem)?;
            Some(Type::Array(Box::new(elem_ty)))
        }
        // Operator result types are not inferred yet. Report "no type" rather
        // than panicking: callers already handle `None`, and a panic here
        // would fire on any ordinary expression such as `1 + 1`.
        ast::Expr::BinExpr(_) => None,
        ast::Expr::Literal(literal) => infer_type_from_literal(literal),
        ast::Expr::ParenExpr(paren) => paren.expr().and_then(|e| infer_type_from_expr(&e)),
        ast::Expr::TupleExpr(_) => Some(Type::Record),
        _ => None,
    }
}

/// Maps a syntactic type node to a [`Type`].
///
/// Character types become [`Type::Text`], bit types become [`Type::Bit`], and
/// path types become [`Type::Other`] carrying the source text of the name
/// reference found on the path's segment. Any other type node, or a path type
/// missing one of those children, yields `None`.
pub(crate) fn infer_type_from_ty(ty: &ast::Type) -> Option<Type> {
    let inferred = match ty {
        ast::Type::CharType(_) => Type::Text,
        ast::Type::BitType(_) => Type::Bit,
        ast::Type::PathType(path_type) => {
            let name_ref = path_type
                .path()
                .and_then(|path| path.segment())
                .and_then(|segment| segment.name_ref())?;
            Type::Other(name_ref.syntax().text().to_string())
        }
        _ => return None,
    };
    Some(inferred)
}

/// Classifies a literal by the kind of its first token.
///
/// Returns `None` when the literal has no tokens or the token kind is not one
/// of the literal kinds listed below.
fn infer_type_from_literal(literal: &ast::Literal) -> Option<Type> {
    let kind = literal.syntax().first_token().map(|tok| tok.kind())?;
    let ty = match kind {
        SyntaxKind::INT_NUMBER => Type::Integer,
        SyntaxKind::FLOAT_NUMBER => Type::Numeric,
        // Every string flavour is reported as text.
        SyntaxKind::STRING
        | SyntaxKind::DOLLAR_QUOTED_STRING
        | SyntaxKind::ESC_STRING
        | SyntaxKind::UNICODE_ESC_STRING => Type::Text,
        SyntaxKind::BIT_STRING | SyntaxKind::BYTE_STRING => Type::Bit,
        SyntaxKind::TRUE_KW | SyntaxKind::FALSE_KW => Type::Boolean,
        // A bare `null` carries no type information of its own.
        SyntaxKind::NULL_KW => Type::Unknown,
        _ => return None,
    };
    Some(ty)
}

#[cfg(test)]
mod tests {
    use super::*;
    use insta::assert_snapshot;

    /// Parses `sql`, takes the first target of the first `select` that has
    /// one, and returns the display form of its inferred type.
    ///
    /// Panics if a non-select statement is encountered, if the select is
    /// missing its clause/target list/expression, or if no type is inferred.
    fn infer(sql: &str) -> String {
        let parse = ast::SourceFile::parse(sql);
        for stmt in parse.tree().stmts() {
            let ast::Stmt::Select(select) = stmt else {
                unreachable!("unexpected stmt type");
            };
            let target_list = select
                .select_clause()
                .expect("expected select clause")
                .target_list()
                .expect("expected target list");

            // A select without targets is skipped in favour of the next statement.
            let Some(first_target) = target_list.targets().next() else {
                continue;
            };
            let expr = first_target.expr().expect("expected expr");
            return infer_type_from_expr(&expr)
                .expect("expected type")
                .to_string();
        }
        unreachable!("should always have at least one target")
    }

    #[test]
    fn integer_literal() {
        assert_snapshot!(infer("select 1"), @"integer");
    }

    #[test]
    fn float_literal() {
        assert_snapshot!(infer("select 1.5"), @"numeric");
    }

    #[test]
    fn string_literal() {
        assert_snapshot!(infer("select 'hello'"), @"text");
    }

    #[test]
    fn dollar_quoted_string() {
        assert_snapshot!(infer("select $$hello$$"), @"text");
    }

    #[test]
    fn escape_string() {
        assert_snapshot!(infer("select E'hello'"), @"text");
    }

    #[test]
    fn boolean_true() {
        assert_snapshot!(infer("select true"), @"boolean");
    }

    #[test]
    fn boolean_false() {
        assert_snapshot!(infer("select false"), @"boolean");
    }

    #[test]
    fn null_literal() {
        assert_snapshot!(infer("select null"), @"unknown");
    }

    #[test]
    fn cast_expr() {
        assert_snapshot!(infer("select 1::text"), @"text");
    }

    #[test]
    fn cast_expr_varchar() {
        assert_snapshot!(infer("select 1::varchar(255)"), @"text");
    }

    #[test]
    fn bit_string() {
        assert_snapshot!(infer("select b'100'"), @"bit");
    }

    #[test]
    fn bit_varying() {
        assert_snapshot!(infer("select b'100'::bit varying"), @"bit");
    }

    #[test]
    fn array() {
        assert_snapshot!(infer("select array['foo', 'bar']"), @"text[]");
    }

    #[test]
    fn record() {
        assert_snapshot!(infer("select (1, 2)"), @"record");
    }

    #[test]
    fn paren_expr() {
        assert_snapshot!(infer("select (1)"), @"integer");
    }

    #[test]
    fn nested_paren_expr() {
        assert_snapshot!(infer("select ((1.5))"), @"numeric");
    }
}