use annotate_snippets::{AnnotationKind, Level, Patch, Renderer, Snippet, renderer::DecorStyle};
use anyhow::Result;
use console::style;
use line_index::LineIndex;
use line_index::TextRange;
use log::info;
use serde::Serialize;
use squawk_linter::{Fix, Linter, Rule, Version};
use squawk_syntax::SourceFile;
use std::hash::DefaultHasher;
use std::hash::Hash;
use std::hash::Hasher;
use std::io;
use std::process::ExitCode;
use crate::cmd::Input;
use crate::{
Reporter,
file::{sql_from_path, sql_from_stdin},
};
fn check_sql(
sql: &str,
path: &str,
excluded_rules: &[Rule],
pg_version: Option<Version>,
assume_in_transaction: bool,
) -> CheckReport {
let mut linter = Linter::without_rules(excluded_rules);
if let Some(pg_version) = pg_version {
linter.settings.pg_version = pg_version;
}
linter.settings.assume_in_transaction = assume_in_transaction;
let parse = SourceFile::parse(sql);
let parse_errors = parse.errors();
let line_index = LineIndex::new(sql);
if !parse_errors.is_empty() {
let violations = parse_errors
.into_iter()
.map(|e| {
let range_start = e.range().start();
let line_col = line_index.line_col(range_start);
let range_end = e.range().end();
let line_end = line_index.line_col(range_end);
ReportViolation {
file: path.to_string(),
line: line_col.line as usize,
line_end: line_end.line as usize,
column: line_col.col as usize,
column_end: line_end.col as usize,
level: ViolationLevel::Error,
help: None,
range: e.range(),
message: e.message().to_string(),
rule_name: "syntax-error".to_string(),
fix: None,
}
})
.collect();
return CheckReport {
filename: path.into(),
sql: sql.into(),
violations,
};
}
let errors = linter.lint(&parse, sql);
let violations = errors
.into_iter()
.map(|e| {
let range_start = e.text_range.start();
let line_col = line_index.line_col(range_start);
let range_end = e.text_range.end();
let line_end = line_index.line_col(range_end);
ReportViolation {
file: path.to_string(),
line: line_col.line as usize,
line_end: line_end.line as usize,
column: line_col.col as usize,
column_end: line_end.col as usize,
range: e.text_range,
help: e.help,
level: ViolationLevel::Warning,
message: e.message,
rule_name: e.code.to_string(),
fix: e.fix,
}
})
.collect();
CheckReport {
filename: path.into(),
sql: sql.into(),
violations,
}
}
/// Renders one violation as a styled, Unicode-decorated diagnostic and writes
/// it to `f` followed by a newline.
///
/// The diagnostic contains the annotated source snippet, the violation's help
/// text when present, and — when the violation carries a fix — a second
/// snippet with every edit of that fix applied as a patch.
fn render_lint_error<W: std::io::Write>(
    f: &mut W,
    err: &ReportViolation,
    filename: &str,
    sql: &str,
) -> Result<()> {
    let severity = match err.level {
        ViolationLevel::Error => Level::ERROR,
        ViolationLevel::Warning => Level::WARNING,
    };

    let primary = AnnotationKind::Primary.span(err.range.into());
    let annotated_source = Snippet::source(sql)
        .path(filename)
        .fold(true)
        .annotation(primary);

    let mut report = severity
        .primary_title(&err.message)
        .id(&err.rule_name)
        .element(annotated_source);

    if let Some(help_text) = &err.help {
        report = report.element(Level::HELP.message(help_text));
    }

    if let Some(fix) = &err.fix {
        let mut patched = Snippet::source(sql).path(filename).fold(true);
        for edit in &fix.edits {
            let span = usize::from(edit.text_range.start())..usize::from(edit.text_range.end());
            // An edit without replacement text removes the covered range.
            let new_text = edit.text.as_deref().unwrap_or("");
            patched = patched.patch(Patch::new(span, new_text));
        }
        report = report.element(patched);
    }

    let renderer = Renderer::styled().decor_style(DecorStyle::Unicode);
    writeln!(f, "{}", renderer.render(&[report]))?;
    Ok(())
}
/// Options for a lint run, consumed by `lint_files` and `lint_and_report`.
pub(crate) struct LintArgs {
/// Source of the SQL to check: stdin or a list of file paths.
pub(crate) input: Input,
/// Rules the linter is built without.
pub(crate) excluded_rules: Vec<Rule>,
/// Postgres version copied into the linter settings; `None` keeps the linter's own default.
pub(crate) pg_version: Option<Version>,
/// Copied into the linter settings' `assume_in_transaction` flag.
pub(crate) assume_in_transaction: bool,
/// Output format used when printing the reports.
pub(crate) reporter: Reporter,
/// When true, GitHub annotation lines are written before the regular report.
pub(crate) github_annotations: bool,
}
/// Lints every input described by `args` and returns one [`CheckReport`] per
/// SQL source that was checked.
///
/// * `Input::Stdin` — reads the SQL from stdin. Input that is empty after
///   trimming is skipped, so the returned vector is empty. The report is
///   filed under the configured stdin path, or `"stdin"` when none is set.
/// * `Input::Paths` — reads and checks each path in order.
///
/// # Errors
///
/// Returns the error from `sql_from_stdin` / `sql_from_path` if reading an
/// input fails; reports collected up to that point are discarded.
pub fn lint_files(args: &LintArgs) -> Result<Vec<CheckReport>> {
    let mut reports = vec![];
    match &args.input {
        Input::Stdin(stdin) => {
            info!("reading content from stdin");
            let sql = sql_from_stdin()?;
            // ignore stdin if it's empty.
            if sql.trim().is_empty() {
                info!("ignoring empty stdin");
            } else {
                let path = stdin.path.clone().unwrap_or_else(|| "stdin".into());
                reports.push(check_sql(
                    &sql,
                    &path,
                    &args.excluded_rules,
                    args.pg_version,
                    args.assume_in_transaction,
                ));
            }
        }
        Input::Paths(path_bufs) => {
            for path in path_bufs {
                info!("checking file path: {}", path.display());
                let sql = sql_from_path(path)?;
                // The file has already been read successfully, so a path that
                // is not valid UTF-8 must not abort the run: fall back to a
                // lossy rendering for the name used in reports.
                let report_path = path.to_string_lossy();
                reports.push(check_sql(
                    &sql,
                    &report_path,
                    &args.excluded_rules,
                    args.pg_version,
                    args.assume_in_transaction,
                ));
            }
        }
    }
    Ok(reports)
}
/// Lints the inputs in `args`, prints the result to `f` with the configured
/// reporter, and maps the outcome to a process exit code.
///
/// Returns [`ExitCode::SUCCESS`] when no report contains a violation and
/// [`ExitCode::FAILURE`] otherwise. Errors from linting or printing are
/// propagated.
pub fn lint_and_report<W: io::Write>(f: &mut W, args: LintArgs) -> Result<ExitCode> {
    let reports = lint_files(&args)?;
    // Decide the exit code before the reports are handed to the printer.
    let clean = reports.iter().all(|report| report.violations.is_empty());
    print_violations(f, reports, &args.reporter, args.github_annotations)?;
    let exit_code = if clean {
        ExitCode::SUCCESS
    } else {
        ExitCode::FAILURE
    };
    Ok(exit_code)
}
/// Aggregate counts over a set of reports, printed at the end of the TTY output.
struct Summary {
/// Sum of the violation counts of all reports.
total_violations: usize,
/// Number of reports, i.e. checked SQL sources.
files_checked: usize,
/// Number of reports that contain at least one violation.
files_with_violations: usize,
}
impl Summary {
    /// Computes the totals for `violations`, where each element is the report
    /// of one checked SQL source.
    fn from(violations: &[CheckReport]) -> Summary {
        let mut total_violations = 0;
        let mut files_with_violations = 0;
        for report in violations {
            let count = report.violations.len();
            total_violations += count;
            if count > 0 {
                files_with_violations += 1;
            }
        }
        Summary {
            total_violations,
            files_checked: violations.len(),
            files_with_violations,
        }
    }
}
/// Writes the closing summary of a TTY report to `f`.
///
/// With no violations a single "Found 0 issues" line is printed. Otherwise a
/// pointer to the rule documentation is printed, followed by the issue and
/// file counts with singular/plural wording.
fn print_summary<W: io::Write>(f: &mut W, summary: &Summary) -> Result<()> {
    let files_checked = summary.files_checked;

    if summary.total_violations == 0 {
        let files = if files_checked == 1 { "file" } else { "files" };
        writeln!(f, "\nFound 0 issues in {files_checked} {files} 🎉")?;
        return Ok(());
    }

    // "s" unless exactly one item is being counted.
    let suffix = |count: usize| if count == 1 { "" } else { "s" };

    let total_violations = summary.total_violations;
    let files_with_violations = summary.files_with_violations;
    let plural = suffix(total_violations);
    let files_plural = suffix(files_with_violations);
    let files_checked_plural = if files_checked == 1 {
        "source file"
    } else {
        "source files"
    };

    writeln!(
        f,
        "\nFind detailed examples and solutions for each rule at {}",
        style("https://squawkhq.com/docs/rules").underlined()
    )?;
    writeln!(
        f,
        "Found {total_violations} issue{plural} in {files_with_violations} file{files_plural} (checked {files_checked} {files_checked_plural})"
    )?;
    Ok(())
}
/// Severity of a reported violation.
///
/// Serialized under the variant name (`"Warning"` / `"Error"`), while the
/// `Display` impl yields the lowercase form.
#[derive(Debug, Serialize)]
pub enum ViolationLevel {
/// Assigned by `check_sql` to findings produced by lint rules.
Warning,
/// Assigned by `check_sql` to parser (syntax) errors.
Error,
}
impl std::fmt::Display for ViolationLevel {
    /// Writes the lowercase name of the level: `warning` or `error`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::Error => "error",
            Self::Warning => "warning",
        })
    }
}
/// A single finding (lint violation or syntax error) located in one file.
///
/// The serialized form of this struct is what the JSON reporter emits; the
/// `range` and `fix` fields are left out of it.
// TODO: don't use this for json dumps
#[derive(Debug, Serialize)]
pub struct ReportViolation {
/// Path under which the checked SQL is reported.
pub file: String,
/// Line of the start of the range as returned by `LineIndex::line_col`
/// (zero-based: the GitHub annotation output adds 1 to it).
pub line: usize,
/// Column of the start of the range as returned by `LineIndex::line_col`.
pub column: usize,
/// Range of the finding within the SQL text; not serialized.
#[serde(skip_serializing)]
pub range: TextRange,
/// Severity of the finding.
pub level: ViolationLevel,
/// Description of the problem.
pub message: String,
/// Optional advice on how to resolve the finding.
pub help: Option<String>,
/// Rule code, or `syntax-error` for parser errors.
pub rule_name: String,
/// Column of the end of the range, same convention as `column`.
pub column_end: usize,
/// Line of the end of the range, same convention as `line`.
pub line_end: usize,
/// Suggested fix, if the rule provides one; not serialized.
#[serde(skip_serializing)]
pub fix: Option<Fix>,
}
fn fmt_gcc<W: io::Write>(f: &mut W, reports: &[CheckReport]) -> Result<()> {
for report in reports {
for violation in &report.violations {
writeln!(
f,
"{}:{}:{}: {}: {} {}",
violation.file,
violation.line,
violation.column,
violation.level,
violation.rule_name,
violation.message,
)?;
}
}
Ok(())
}
/// Writes a single violation to `f` as a rendered terminal diagnostic.
///
/// `filename` and `sql` identify and supply the source text the violation's
/// range refers to.
pub fn fmt_tty_violation<W: io::Write>(
    f: &mut W,
    violation: &ReportViolation,
    filename: &str,
    sql: &str,
) -> Result<()> {
    render_lint_error(f, violation, filename, sql)
}
pub fn fmt_tty<W: io::Write>(f: &mut W, reports: &[CheckReport]) -> Result<()> {
let summary = Summary::from(reports);
for report in reports {
for violation in &report.violations {
fmt_tty_violation(f, violation, &report.filename, &report.sql)?;
}
}
print_summary(f, &summary)?;
Ok(())
}
fn fmt_json<W: io::Write>(f: &mut W, reports: Vec<CheckReport>) -> Result<()> {
let violations = reports
.into_iter()
.flat_map(|x| x.violations)
.collect::<Vec<_>>();
let json_str = serde_json::to_string(&violations)?;
writeln!(f, "{json_str}")?;
Ok(())
}
/// Escapes the message part of a GitHub workflow command.
///
/// `%`, carriage return and line feed are percent-encoded so a message can
/// neither terminate the command line early nor be mistaken for an escape.
fn escape_annotation_data(value: &str) -> String {
    // `%` must be replaced first so the escapes added below are not re-escaped.
    value
        .replace('%', "%25")
        .replace('\r', "%0D")
        .replace('\n', "%0A")
}

/// Escapes a property value (`file=`, `title=`) of a GitHub workflow command.
///
/// In addition to the message escapes, `:` and `,` are encoded because they
/// delimit the property list.
fn escape_annotation_property(value: &str) -> String {
    escape_annotation_data(value)
        .replace(':', "%3A")
        .replace(',', "%2C")
}

/// Writes one GitHub Actions annotation command (`::warning …::message` or
/// `::error …::message`) per violation.
///
/// Lines are emitted with 1 added to the stored values. File, title and
/// message are escaped so that special characters cannot break the command.
pub fn fmt_github_annotations<W: io::Write>(f: &mut W, reports: &[CheckReport]) -> Result<()> {
    for report in reports {
        for violation in &report.violations {
            // NOTE(review): columns are passed through unchanged while lines
            // get +1 — confirm which column base GitHub expects.
            writeln!(
                f,
                "::{level} file={file},line={line},col={col},endLine={line_end},endColumn={col_end},title={title}::{message}",
                // `Display` yields exactly the `warning` / `error` command names.
                level = violation.level,
                file = escape_annotation_property(&violation.file),
                line = violation.line + 1,
                line_end = violation.line_end + 1,
                col = violation.column,
                col_end = violation.column_end,
                title = escape_annotation_property(&violation.rule_name),
                message = escape_annotation_data(&violation.message),
            )?;
        }
    }
    Ok(())
}
/// Line span of an issue in the GitLab report.
#[derive(Serialize)]
struct GitLabLines {
/// First line of the issue; filled from the violation's `line`.
begin: usize,
/// Last line of the issue; never smaller than `begin`.
end: usize,
}
/// Location of an issue in the GitLab report: a file plus a line span.
#[derive(Serialize)]
struct GitLabLocation {
/// Path of the file the issue was found in.
path: String,
/// Lines covered by the issue.
lines: GitLabLines,
}
/// One entry of the JSON array written by the GitLab reporter.
// NOTE(review): field names presumably follow GitLab's Code Quality report
// format — confirm against the GitLab documentation before renaming any.
#[derive(Serialize)]
struct GitLabIssue {
/// Violation message, with the trimmed help text appended as a suggestion when present.
description: String,
/// `"minor"` for warnings, `"major"` for errors.
severity: String,
/// Hex hash identifying the issue; see `make_fingerprint`.
fingerprint: String,
/// File and lines of the issue.
location: GitLabLocation,
/// Name of the rule that produced the issue.
check_name: String,
}
impl From<&ViolationLevel> for String {
fn from(level: &ViolationLevel) -> Self {
match level {
ViolationLevel::Warning => "minor".to_string(),
ViolationLevel::Error => "major".to_string(),
}
}
}
/// Derives the fingerprint of a violation for the GitLab report.
///
/// File, start/end line, rule name, message and level are joined into one
/// string, hashed with [`DefaultHasher`], and the 64-bit result is returned
/// as lowercase hex without zero padding. Columns do not take part.
fn make_fingerprint(v: &ReportViolation) -> String {
    let identity = format!(
        "{file}:{line}-{line_end}:{rule}:{message}:{level}",
        file = v.file,
        line = v.line,
        line_end = v.line_end,
        rule = v.rule_name,
        message = v.message,
        level = v.level,
    );
    let mut state = DefaultHasher::new();
    identity.as_str().hash(&mut state);
    let digest = state.finish();
    format!("{digest:x}")
}
/// Converts a violation into its GitLab report entry.
///
/// The description is the violation message; when the help text is non-blank
/// it is appended after `" Suggestion: "`. The end line is clamped so it is
/// never before the begin line.
fn to_gitlab_issue(v: &ReportViolation) -> GitLabIssue {
    let suggestion = v
        .help
        .as_deref()
        .map(str::trim)
        .filter(|text| !text.is_empty());
    let description = match suggestion {
        Some(text) => format!("{} Suggestion: {}", v.message, text),
        None => v.message.clone(),
    };

    let lines = GitLabLines {
        begin: v.line,
        end: v.line.max(v.line_end),
    };

    GitLabIssue {
        description,
        severity: String::from(&v.level),
        fingerprint: make_fingerprint(v),
        location: GitLabLocation {
            path: v.file.clone(),
            lines,
        },
        check_name: v.rule_name.clone(),
    }
}
fn fmt_gitlab<W: io::Write>(f: &mut W, reports: Vec<CheckReport>) -> Result<()> {
let issues: Vec<GitLabIssue> = reports
.into_iter()
.flat_map(|r| r.violations.into_iter())
.map(|v| to_gitlab_issue(&v))
.collect();
let json = serde_json::to_string(&issues)?;
writeln!(f, "{json}")?;
Ok(())
}
/// Result of checking one SQL source.
#[derive(Debug)]
pub struct CheckReport {
/// Path under which the source is reported.
pub filename: String,
/// Full SQL text that was checked; used when rendering snippets.
pub sql: String,
/// Findings for this source: either its syntax errors or its lint violations.
pub violations: Vec<ReportViolation>,
}
/// Writes `reports` to `writer` in the format selected by `reporter`.
///
/// When `github_annotations` is set, the GitHub annotation lines are written
/// first and the regular report follows.
pub fn print_violations<W: io::Write>(
    writer: &mut W,
    reports: Vec<CheckReport>,
    reporter: &Reporter,
    github_annotations: bool,
) -> Result<()> {
    if github_annotations {
        fmt_github_annotations(writer, &reports)?;
    }

    match reporter {
        // These formats only need to borrow the reports.
        Reporter::Tty => fmt_tty(writer, reports.as_slice()),
        Reporter::Gcc => fmt_gcc(writer, reports.as_slice()),
        // These formats take ownership of the reports.
        Reporter::Json => fmt_json(writer, reports),
        Reporter::Gitlab => fmt_gitlab(writer, reports),
    }
}
// Snapshot tests for `check_sql` on SQL that does not parse.
#[cfg(test)]
mod test_check_files {
use super::check_sql;
use crate::reporter::fmt_json;
use insta::assert_snapshot;
use serde_json::Value;
// Invalid SQL must still produce a report that serializes to valid JSON.
#[test]
fn check_files_invalid_syntax() {
let sql = r"
select \;
";
let mut buff = Vec::new();
let res = check_sql(sql, "test.sql", &[], None, false);
fmt_json(&mut buff, vec![res]).unwrap();
// Round-trip through `Value` to prove the output parses as JSON.
let val: Value = serde_json::from_slice(&buff).unwrap();
assert_snapshot!(val);
}
// When parsing fails, only the syntax error is reported (level `Error`,
// rule `syntax-error`); lint rules are not run on the statement.
#[test]
fn skip_lint_on_syntax_error() {
let error_sql = "ALTER TABLE foo ALTER CONSTRAINT bar RENAME TO quux;";
let mut buff = vec![];
let res = check_sql(error_sql, "test.sql", &[], None, false);
fmt_json(&mut buff, vec![res]).unwrap();
assert_snapshot!(String::from_utf8_lossy(&buff), @r#"[{"file":"test.sql","line":0,"column":36,"level":"Error","message":"missing comma","help":null,"rule_name":"syntax-error","column_end":36,"line_end":0}]"#);
}
}
#[cfg(test)]
mod test_reporter {
use super::check_sql;
use crate::reporter::{Reporter, print_violations};
use console::strip_ansi_codes;
use insta::{assert_debug_snapshot, assert_snapshot};
// The GCC reporter prints one `file:line:column: level: rule message` line
// per violation and nothing for the clean `SELECT 1` statement.
#[test]
fn display_violations_gcc() {
let sql = r#"
ALTER TABLE "core_recipe" ADD COLUMN "foo" integer NOT NULL;
ALTER TABLE "core_foo" ADD COLUMN "bar" integer NOT NULL;
SELECT 1;
"#;
let filename = "main.sql";
let mut buff = Vec::new();
let res = print_violations(
&mut buff,
vec![check_sql(sql, filename, &[], None, false)],
&Reporter::Gcc,
false,
);
assert!(res.is_ok());
assert_snapshot!(String::from_utf8_lossy(&buff), @r"
main.sql:1:3: warning: require-timeout-settings Missing `set lock_timeout` before potentially slow operations
main.sql:1:3: warning: require-timeout-settings Missing `set statement_timeout` before potentially slow operations
main.sql:1:29: warning: adding-required-field Adding a new column that is `NOT NULL` and has no default value to an existing table effectively makes it required.
main.sql:1:29: warning: prefer-robust-stmts Missing `IF NOT EXISTS`, the migration can't be rerun if it fails part way through.
main.sql:1:46: warning: prefer-bigint-over-int Using 32-bit integer fields can result in hitting the max `int` limit.
main.sql:2:23: warning: adding-required-field Adding a new column that is `NOT NULL` and has no default value to an existing table effectively makes it required.
main.sql:2:23: warning: prefer-robust-stmts Missing `IF NOT EXISTS`, the migration can't be rerun if it fails part way through.
main.sql:2:40: warning: prefer-bigint-over-int Using 32-bit integer fields can result in hitting the max `int` limit.
");
}
// With `github_annotations` enabled the TTY reporter output is preceded by
// the GitHub annotation lines; the SQL includes a statement spanning several
// lines so multi-line ranges are covered.
#[test]
fn display_violations_tty_and_github_annotations() {
let sql = r#"
ALTER TABLE "core_recipe" ADD COLUMN "foo" integer NOT NULL;
-- multi line example
ALTER TABLE "core_foo"
ADD COLUMN "bar"
integer NOT NULL;
SELECT 1;
"#;
let filename = "main.sql";
let mut buff = Vec::new();
let res = print_violations(
&mut buff,
vec![check_sql(sql, filename, &[], None, false)],
&Reporter::Tty,
true,
);
assert!(res.is_ok());
// remove the color codes so tests behave in CI as they do locally
assert_snapshot!(strip_ansi_codes(&String::from_utf8_lossy(&buff)));
}
// Snapshot of the TTY reporter output (diagnostics plus summary) for SQL
// that triggers several lint rules, without GitHub annotations.
#[test]
fn display_violations_tty() {
let sql = r#"
ALTER TABLE "core_recipe" ADD COLUMN "foo" integer NOT NULL;
ALTER TABLE "core_foo" ADD COLUMN "bar" integer NOT NULL;
SELECT 1;
"#;
let filename = "main.sql";
let mut buff = Vec::new();
let res = print_violations(
&mut buff,
vec![check_sql(sql, filename, &[], None, false)],
&Reporter::Tty,
false,
);
assert!(res.is_ok());
// remove the color codes so tests behave in CI as they do locally
assert_snapshot!(strip_ansi_codes(&String::from_utf8_lossy(&buff)));
}
// Snapshot of the TTY reporter output for a single statement (`select 1;`)
// that is expected to be clean, i.e. the summary-only case.
#[test]
fn display_no_violations_tty() {
let mut buff = Vec::new();
let sql = "select 1;";
let res = print_violations(
&mut buff,
vec![check_sql(sql, "main.sql", &[], None, false)],
&Reporter::Tty,
false,
);
assert!(res.is_ok());
// remove the color codes so tests behave in CI as they do locally
assert_snapshot!(strip_ansi_codes(&String::from_utf8_lossy(&buff)));
}
// The JSON reporter emits one flat array of serialized violations; `range`
// and `fix` are absent because they are skipped during serialization.
#[test]
fn display_violations_json() {
let sql = r#"
ALTER TABLE "core_recipe" ADD COLUMN "foo" integer NOT NULL;
ALTER TABLE "core_foo" ADD COLUMN "bar" integer NOT NULL;
SELECT 1;
"#;
let filename = "main.sql";
let mut buff = Vec::new();
let res = print_violations(
&mut buff,
vec![check_sql(sql, filename, &[], None, false)],
&Reporter::Json,
false,
);
assert!(res.is_ok());
assert_snapshot!(String::from_utf8_lossy(&buff), @r#"[{"file":"main.sql","line":1,"column":3,"level":"Warning","message":"Missing `set lock_timeout` before potentially slow operations","help":"Configure a `lock_timeout` before this statement.","rule_name":"require-timeout-settings","column_end":62,"line_end":1},{"file":"main.sql","line":1,"column":3,"level":"Warning","message":"Missing `set statement_timeout` before potentially slow operations","help":"Configure a `statement_timeout` before this statement","rule_name":"require-timeout-settings","column_end":62,"line_end":1},{"file":"main.sql","line":1,"column":29,"level":"Warning","message":"Adding a new column that is `NOT NULL` and has no default value to an existing table effectively makes it required.","help":"Make the field nullable or add a non-VOLATILE DEFAULT","rule_name":"adding-required-field","column_end":62,"line_end":1},{"file":"main.sql","line":1,"column":29,"level":"Warning","message":"Missing `IF NOT EXISTS`, the migration can't be rerun if it fails part way through.","help":null,"rule_name":"prefer-robust-stmts","column_end":62,"line_end":1},{"file":"main.sql","line":1,"column":46,"level":"Warning","message":"Using 32-bit integer fields can result in hitting the max `int` limit.","help":"Use 64-bit integer values instead to prevent hitting this limit.","rule_name":"prefer-bigint-over-int","column_end":53,"line_end":1},{"file":"main.sql","line":2,"column":23,"level":"Warning","message":"Adding a new column that is `NOT NULL` and has no default value to an existing table effectively makes it required.","help":"Make the field nullable or add a non-VOLATILE DEFAULT","rule_name":"adding-required-field","column_end":56,"line_end":2},{"file":"main.sql","line":2,"column":23,"level":"Warning","message":"Missing `IF NOT EXISTS`, the migration can't be rerun if it fails part way 
through.","help":null,"rule_name":"prefer-robust-stmts","column_end":56,"line_end":2},{"file":"main.sql","line":2,"column":40,"level":"Warning","message":"Using 32-bit integer fields can result in hitting the max `int` limit.","help":"Use 64-bit integer values instead to prevent hitting this limit.","rule_name":"prefer-bigint-over-int","column_end":47,"line_end":2}]"#);
}
// The GitLab reporter emits one JSON array of issues; warnings map to
// severity "minor" and help text is appended to the description after
// " Suggestion: ".
#[test]
fn display_violations_gitlab() {
let sql = r#"
ALTER TABLE "core_recipe" ADD COLUMN "foo" integer NOT NULL;
ALTER TABLE "core_foo" ADD COLUMN "bar" integer NOT NULL;
SELECT 1;
"#;
let filename = "main.sql";
let mut buff = Vec::new();
let res = print_violations(
&mut buff,
vec![check_sql(sql, filename, &[], None, false)],
&Reporter::Gitlab,
false,
);
assert!(res.is_ok());
assert_snapshot!(String::from_utf8_lossy(&buff), @r#"[{"description":"Missing `set lock_timeout` before potentially slow operations Suggestion: Configure a `lock_timeout` before this statement.","severity":"minor","fingerprint":"106496a54f0e1a20","location":{"path":"main.sql","lines":{"begin":1,"end":1}},"check_name":"require-timeout-settings"},{"description":"Missing `set statement_timeout` before potentially slow operations Suggestion: Configure a `statement_timeout` before this statement","severity":"minor","fingerprint":"2904632bf27251d2","location":{"path":"main.sql","lines":{"begin":1,"end":1}},"check_name":"require-timeout-settings"},{"description":"Adding a new column that is `NOT NULL` and has no default value to an existing table effectively makes it required. Suggestion: Make the field nullable or add a non-VOLATILE DEFAULT","severity":"minor","fingerprint":"87fbb54d93cdb8c9","location":{"path":"main.sql","lines":{"begin":1,"end":1}},"check_name":"adding-required-field"},{"description":"Missing `IF NOT EXISTS`, the migration can't be rerun if it fails part way through.","severity":"minor","fingerprint":"21df0ee3817ad84","location":{"path":"main.sql","lines":{"begin":1,"end":1}},"check_name":"prefer-robust-stmts"},{"description":"Using 32-bit integer fields can result in hitting the max `int` limit. Suggestion: Use 64-bit integer values instead to prevent hitting this limit.","severity":"minor","fingerprint":"3d0e81dc13bc8757","location":{"path":"main.sql","lines":{"begin":1,"end":1}},"check_name":"prefer-bigint-over-int"},{"description":"Adding a new column that is `NOT NULL` and has no default value to an existing table effectively makes it required. 
Suggestion: Make the field nullable or add a non-VOLATILE DEFAULT","severity":"minor","fingerprint":"4bdd655ad8e102ad","location":{"path":"main.sql","lines":{"begin":2,"end":2}},"check_name":"adding-required-field"},{"description":"Missing `IF NOT EXISTS`, the migration can't be rerun if it fails part way through.","severity":"minor","fingerprint":"1b2e8c81e717c442","location":{"path":"main.sql","lines":{"begin":2,"end":2}},"check_name":"prefer-robust-stmts"},{"description":"Using 32-bit integer fields can result in hitting the max `int` limit. Suggestion: Use 64-bit integer values instead to prevent hitting this limit.","severity":"minor","fingerprint":"2bed2a431803b811","location":{"path":"main.sql","lines":{"begin":2,"end":2}},"check_name":"prefer-bigint-over-int"}]"#);
}
// Debug snapshot of the whole `CheckReport`, which pins the line/column
// positions and text ranges computed for every violation.
#[test]
fn span_offsets() {
let sql = r#"
ALTER TABLE "core_recipe" ADD COLUMN "foo" integer NOT NULL;
ALTER TABLE "core_foo" ADD COLUMN "bar" integer NOT NULL;
SELECT 1;
"#;
let filename = "main.sql";
assert_debug_snapshot!(check_sql(sql, filename, &[], None, false));
}
}