use rigsql_core::{Segment, SegmentType};
use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
use crate::utils::first_non_trivia;
use crate::violation::{LintViolation, SourceEdit};
/// Lint rule CV02: prefer the ANSI-standard `COALESCE` over `IFNULL` / `NVL`.
///
/// The type carries no state; all of its behaviour lives in its [`Rule`] impl.
#[derive(Default, Debug)]
pub struct RuleCV02;
impl Rule for RuleCV02 {
    fn code(&self) -> &'static str {
        "CV02"
    }

    fn name(&self) -> &'static str {
        "convention.coalesce"
    }

    fn description(&self) -> &'static str {
        "Use COALESCE instead of IFNULL or NVL."
    }

    fn explanation(&self) -> &'static str {
        "COALESCE is the ANSI SQL standard function for handling NULL values. \
         IFNULL (MySQL) and NVL (Oracle) are database-specific alternatives. \
         Using COALESCE improves portability and consistency."
    }

    fn groups(&self) -> &[RuleGroup] {
        &[RuleGroup::Convention]
    }

    fn is_fixable(&self) -> bool {
        true
    }

    /// The rule is only evaluated on function-call segments.
    fn crawl_type(&self) -> CrawlType {
        CrawlType::Segment(vec![SegmentType::FunctionCall])
    }

    /// Reports a function call named `IFNULL` or `NVL` (ASCII case-insensitive)
    /// and proposes replacing just the name token with `COALESCE`; the argument
    /// list is left untouched. Returns an empty vector when nothing matches.
    fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
        // Function names rewritten by this rule, compared case-insensitively.
        const NON_STANDARD: [&str; 2] = ["IFNULL", "NVL"];

        // The function name is taken from the first non-trivia child; if that
        // is not a plain token there is nothing this rule can rename.
        let kids = ctx.segment.children();
        let Some(Segment::Token(name_tok)) = first_non_trivia(kids) else {
            return Vec::new();
        };

        let fn_name = name_tok.token.text.as_str();
        let is_non_standard = NON_STANDARD
            .iter()
            .any(|&alias| fn_name.eq_ignore_ascii_case(alias));
        if !is_non_standard {
            return Vec::new();
        }

        // The violation and its fix both cover exactly the name token.
        let span = name_tok.token.span;
        vec![LintViolation::with_fix_and_msg_key(
            self.code(),
            format!("Use COALESCE instead of '{fn_name}'."),
            span,
            vec![SourceEdit::replace(span, "COALESCE")],
            "rules.CV02.msg",
            vec![("name".to_string(), fn_name.to_string())],
        )]
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    #[test]
    fn test_cv02_flags_ifnull() {
        let violations = lint_sql("SELECT IFNULL(a, 0) FROM t", RuleCV02);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "COALESCE");
    }

    #[test]
    fn test_cv02_flags_nvl() {
        let violations = lint_sql("SELECT NVL(a, 0) FROM t", RuleCV02);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "COALESCE");
    }

    #[test]
    fn test_cv02_accepts_coalesce() {
        let violations = lint_sql("SELECT COALESCE(a, 0) FROM t", RuleCV02);
        assert_eq!(violations.len(), 0);
    }

    /// Name matching is ASCII case-insensitive; the fix always writes upper-case
    /// `COALESCE` regardless of the original spelling.
    #[test]
    fn test_cv02_flags_lowercase_ifnull() {
        let violations = lint_sql("SELECT ifnull(a, 0) FROM t", RuleCV02);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "COALESCE");
    }

    /// Names are compared whole (not by prefix), so `NVL2`, which has different
    /// semantics, must not be flagged.
    #[test]
    fn test_cv02_ignores_nvl2() {
        let violations = lint_sql("SELECT NVL2(a, 1, 0) FROM t", RuleCV02);
        assert_eq!(violations.len(), 0);
    }
}