dbschema 0.1.1

Define database schemas as HCL files and generate idempotent SQL migrations
Documentation
use super::{LintCheck, LintMessage, LintSeverity};
use crate::ir::Config;

/// Lint check that verifies SQL fragments embedded in the schema config parse as PostgreSQL.
///
/// Parsing is done with `pg_query`, so only syntax is validated; whether referenced
/// tables, columns or functions actually exist is not checked here.
pub struct SqlSyntax;

impl SqlSyntax {
    /// Records a parse failure for the fragment described by `ctx` as an error-level message.
    fn push_err(&self, msgs: &mut Vec<LintMessage>, ctx: &str, err: pg_query::Error) {
        msgs.push(LintMessage {
            check: self.name(),
            message: format!("invalid SQL in {}: {}", ctx, err),
            severity: LintSeverity::Error,
        });
    }

    /// Parses `sql` as one or more complete statements and reports a syntax error, if any.
    fn check_stmt(&self, msgs: &mut Vec<LintMessage>, sql: &str, ctx: &str) {
        if let Err(err) = pg_query::parse(sql) {
            self.push_err(msgs, ctx, err);
        }
    }

    /// Parses `expr` as a single value expression and reports a syntax error, if any.
    ///
    /// The expression is parenthesised (`SELECT (<expr>)`) instead of being spliced
    /// directly after `SELECT`. This mirrors how such fragments appear in DDL
    /// (`CHECK (...)`, `USING (...)`, `WHEN (...)`). It rejects inputs that are not a
    /// lone expression and would otherwise slip through:
    /// - an empty string (a bare `SELECT` is valid PostgreSQL);
    /// - a trailing `;`, or extra statements;
    /// - trailing clauses such as `FROM t`.
    ///
    /// The surrounding newlines keep a trailing `--` comment inside `expr` from
    /// commenting out the closing parenthesis.
    fn check_expr(&self, msgs: &mut Vec<LintMessage>, expr: &str, ctx: &str) {
        if let Err(err) = pg_query::parse(&format!("SELECT (\n{}\n)", expr)) {
            self.push_err(msgs, ctx, err);
        }
    }
}

impl LintCheck for SqlSyntax {
    fn name(&self) -> &'static str {
        "sql-syntax"
    }

    /// Syntax-checks every SQL fragment in `cfg`, returning one message per fragment that fails to parse.
    ///
    /// Whole statements go through `check_stmt`:
    /// - view and materialized-view queries;
    /// - bodies of `sql`-language functions and procedures;
    /// - test setup/assert/assert_fail/teardown statements.
    ///
    /// Expression fragments go through `check_expr`:
    /// - policy USING/CHECK;
    /// - table CHECKs;
    /// - domain CONSTRAINT/CHECK;
    /// - trigger WHEN.
    fn run(&self, cfg: &Config) -> Vec<LintMessage> {
        let mut msgs = Vec::new();

        for view in &cfg.views {
            self.check_stmt(&mut msgs, &view.sql, &format!("view '{}'", view.name));
        }
        for mview in &cfg.materialized {
            self.check_stmt(
                &mut msgs,
                &mview.sql,
                &format!("materialized view '{}'", mview.name),
            );
        }
        for policy in &cfg.policies {
            if let Some(using) = &policy.using {
                self.check_expr(&mut msgs, using, &format!("policy '{}' USING", policy.name));
            }
            if let Some(check) = &policy.check {
                self.check_expr(&mut msgs, check, &format!("policy '{}' CHECK", policy.name));
            }
        }
        for table in &cfg.tables {
            for chk in &table.checks {
                self.check_expr(
                    &mut msgs,
                    &chk.expression,
                    &format!("table '{}' CHECK", table.name),
                );
            }
        }
        for domain in &cfg.domains {
            if let Some(expr) = &domain.constraint {
                self.check_expr(
                    &mut msgs,
                    expr,
                    &format!("domain '{}' CONSTRAINT", domain.name),
                );
            }
            if let Some(expr) = &domain.check {
                self.check_expr(&mut msgs, expr, &format!("domain '{}' CHECK", domain.name));
            }
        }
        for trig in &cfg.triggers {
            if let Some(when) = &trig.when {
                self.check_expr(&mut msgs, when, &format!("trigger '{}' WHEN", trig.name));
            }
        }
        // Only `sql`-language bodies are plain SQL that `pg_query::parse` understands;
        // bodies in other languages (e.g. plpgsql) are skipped. The ASCII case-insensitive
        // comparison avoids allocating a lowercased copy for every routine.
        for func in &cfg.functions {
            if func.language.eq_ignore_ascii_case("sql") {
                self.check_stmt(&mut msgs, &func.body, &format!("function '{}'", func.name));
            }
        }
        for proc in &cfg.procedures {
            if proc.language.eq_ignore_ascii_case("sql") {
                self.check_stmt(&mut msgs, &proc.body, &format!("procedure '{}'", proc.name));
            }
        }
        for test in &cfg.tests {
            // The context is the same for every statement of a test, so build it once.
            let ctx = format!("test '{}'", test.name);
            for stmt in test
                .setup
                .iter()
                .chain(&test.asserts)
                .chain(&test.assert_fail)
                .chain(&test.teardown)
            {
                self.check_stmt(&mut msgs, stmt, &ctx);
            }
        }

        msgs
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{Config, ViewSpec};
    use crate::lint::{LintSettings, run_with_checks};

    /// Builds a config containing a single view named `v` whose query is `sql`.
    fn config_with_view(sql: &str) -> Config {
        let view = ViewSpec {
            name: "v".into(),
            alt_name: None,
            schema: None,
            replace: false,
            sql: sql.into(),
            comment: None,
        };
        Config {
            views: vec![view],
            ..Default::default()
        }
    }

    #[test]
    fn detects_invalid_sql() {
        let cfg = config_with_view("SELEC 1");
        let msgs = run_with_checks(&cfg, vec![Box::new(SqlSyntax)], &LintSettings::default());
        // The message must come from this check and name the offending view.
        assert!(msgs
            .iter()
            .any(|m| m.check == "sql-syntax" && m.message.contains("view 'v'")));
    }

    #[test]
    fn accepts_valid_sql() {
        let cfg = config_with_view("SELECT 1");
        let msgs = run_with_checks(&cfg, vec![Box::new(SqlSyntax)], &LintSettings::default());
        assert!(!msgs.iter().any(|m| m.check == "sql-syntax"));
    }
}