use rigsql_core::{Segment, SegmentType, TokenKind};
use super::CapitalisationPolicy;
use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
use crate::utils::{check_capitalisation, determine_majority_case};
use crate::violation::LintViolation;
/// Upper-case names of the built-in SQL functions whose capitalisation CP03 checks.
///
/// Function names that do not appear here (for example user-defined functions)
/// are ignored by the rule.
///
/// Invariant: callers upper-case the candidate name with `to_ascii_uppercase`
/// and then look it up with `binary_search`, so this list MUST stay in strictly
/// ascending byte order with every entry fully upper-case. An entry that breaks
/// either property may silently fail to match.
const BUILTIN_FUNCTIONS: &[&str] = &[
"ABS",
"ACOS",
"APP_NAME",
"ASCII",
"ASIN",
"ATAN",
"ATAN2",
"AVG",
"CAST",
"CEILING",
"CHAR",
"CHARINDEX",
"CHOOSE",
"COALESCE",
"CONCAT",
"CONCAT_WS",
"CONVERT",
"COS",
"COT",
"COUNT",
"COUNT_BIG",
"CUME_DIST",
"CURRENT_TIMESTAMP",
"CURRENT_USER",
"CURSOR_STATUS",
"DATALENGTH",
"DATEADD",
"DATEDIFF",
"DATEDIFF_BIG",
"DATEFROMPARTS",
"DATENAME",
"DATEPART",
"DATETIME2FROMPARTS",
"DATETIMEFROMPARTS",
"DAY",
"DB_ID",
"DB_NAME",
"DENSE_RANK",
"DIFFERENCE",
"EOMONTH",
"ERROR_LINE",
"ERROR_MESSAGE",
"ERROR_NUMBER",
"ERROR_PROCEDURE",
"ERROR_SEVERITY",
"ERROR_STATE",
"EXP",
"FIRST_VALUE",
"FLOOR",
"FORMAT",
"GETDATE",
"GETUTCDATE",
"GREATEST",
"GROUPING",
"GROUPING_ID",
"HAS_PERMS_BY_NAME",
"HOST_NAME",
"IDENTITY",
"IDENT_CURRENT",
"IFNULL",
"IIF",
"ISJSON",
"ISNULL",
"ISNUMERIC",
"JSON_ARRAY",
"JSON_MODIFY",
"JSON_OBJECT",
"JSON_QUERY",
"JSON_VALUE",
"LAG",
"LAST_VALUE",
"LEAD",
"LEAST",
"LEFT",
"LEN",
"LENGTH",
"LOG",
"LOG10",
"LOWER",
"LTRIM",
"MAX",
"MIN",
"MONTH",
"NCHAR",
"NEWID",
"NTILE",
"NULLIF",
"NVL",
"NVL2",
"OBJECT_ID",
"OBJECT_NAME",
"PARSENAME",
"PATINDEX",
"PERCENT_RANK",
"PI",
"POWER",
"QUOTENAME",
"RAND",
"RANK",
"REPLACE",
"REPLICATE",
"REVERSE",
"RIGHT",
"ROUND",
"ROW_NUMBER",
"RTRIM",
"SCHEMA_NAME",
"SCOPE_IDENTITY",
"SIGN",
"SIN",
"SOUNDEX",
"SPACE",
"SQRT",
"SQUARE",
"STR",
"STRING_AGG",
"STRING_SPLIT",
"STUFF",
"SUBSTRING",
"SUM",
"SUSER_SNAME",
"SWITCHOFFSET",
"SYSDATETIME",
"SYSUTCDATETIME",
"TAN",
"TODATETIMEOFFSET",
"TRANSLATE",
"TRIM",
"TRY_CAST",
"TRY_CONVERT",
"TRY_PARSE",
"TYPE_NAME",
"UNICODE",
"UPPER",
"USER_NAME",
"YEAR",
];
/// CP03 (`capitalisation.functions`): enforces a capitalisation convention on
/// the names of built-in SQL functions listed in `BUILTIN_FUNCTIONS`.
#[derive(Debug)]
pub struct RuleCP03 {
    /// Which capitalisation convention to enforce.
    pub policy: CapitalisationPolicy,
}

impl Default for RuleCP03 {
    /// Creates the rule with the upper-case policy.
    fn default() -> Self {
        let policy = CapitalisationPolicy::Upper;
        RuleCP03 { policy }
    }
}
impl Rule for RuleCP03 {
    fn code(&self) -> &'static str {
        "CP03"
    }

    fn name(&self) -> &'static str {
        "capitalisation.functions"
    }

    fn description(&self) -> &'static str {
        "Function names must be consistently capitalised."
    }

    fn explanation(&self) -> &'static str {
        "Function names like COUNT, SUM, COALESCE should be consistently capitalised. \
         Whether upper or lower depends on your team's convention."
    }

    fn groups(&self) -> &[RuleGroup] {
        &[RuleGroup::Capitalisation]
    }

    fn is_fixable(&self) -> bool {
        true
    }

    /// The `Consistent` policy must compare every function name in the tree,
    /// so it is evaluated once from the root. Every other policy judges each
    /// `FunctionCall` segment on its own.
    fn crawl_type(&self) -> CrawlType {
        match self.policy {
            CapitalisationPolicy::Consistent => CrawlType::RootOnly,
            _ => CrawlType::Segment(vec![SegmentType::FunctionCall]),
        }
    }

    /// Reads `capitalisation_policy` from the rule settings, if present.
    fn configure(&mut self, settings: &std::collections::HashMap<String, String>) {
        if let Some(raw) = settings.get("capitalisation_policy") {
            self.policy = CapitalisationPolicy::from_config(raw);
        }
    }

    /// Checks one `FunctionCall` segment against a fixed policy.
    ///
    /// This reports at most one violation, for the function-name token, and
    /// only when that token is a plain word matching an entry in
    /// `BUILTIN_FUNCTIONS` (case-insensitively). The `Consistent` policy is
    /// delegated to `eval_consistent`.
    fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
        if let CapitalisationPolicy::Consistent = self.policy {
            return self.eval_consistent(ctx);
        }

        // An empty child list naturally yields `None` here.
        let token = match Self::find_function_name(ctx.segment.children()) {
            Some(Segment::Token(tok)) if tok.token.kind == TokenKind::Word => tok,
            _ => return Vec::new(),
        };

        let name = token.token.text.as_str();
        let name_upper = name.to_ascii_uppercase();
        let is_builtin = BUILTIN_FUNCTIONS
            .binary_search(&name_upper.as_str())
            .is_ok();
        if !is_builtin {
            return Vec::new();
        }

        let (expected, policy_name) = match self.policy {
            CapitalisationPolicy::Upper => (name_upper, "upper"),
            CapitalisationPolicy::Lower => (name.to_ascii_lowercase(), "lower"),
            CapitalisationPolicy::Capitalise => (crate::utils::capitalise(name), "capitalised"),
            CapitalisationPolicy::Consistent => unreachable!(),
        };

        let outcome = check_capitalisation(
            self.code(),
            "Function names",
            name,
            &expected,
            policy_name,
            token.token.span,
        );
        match outcome {
            Some(violation) => vec![violation],
            None => Vec::new(),
        }
    }
}
impl RuleCP03 {
fn eval_consistent(&self, ctx: &RuleContext) -> Vec<LintViolation> {
let mut tokens = Vec::new();
Self::collect_builtin_function_names(ctx.root, &mut tokens);
if tokens.is_empty() {
return vec![];
}
let majority = determine_majority_case(&tokens);
let mut violations = Vec::new();
for (text, span) in &tokens {
let expected = match majority {
"upper" => text.to_ascii_uppercase(),
_ => text.to_ascii_lowercase(),
};
if let Some(v) = check_capitalisation(
self.code(),
"Function names",
text,
&expected,
majority,
*span,
) {
violations.push(v);
}
}
violations
}
fn collect_builtin_function_names(
segment: &Segment,
out: &mut Vec<(String, rigsql_core::Span)>,
) {
if segment.segment_type() == SegmentType::FunctionCall {
if let Some(Segment::Token(t)) = Self::find_function_name(segment.children()) {
if t.token.kind == TokenKind::Word {
let upper = t.token.text.to_ascii_uppercase();
if BUILTIN_FUNCTIONS.binary_search(&upper.as_str()).is_ok() {
out.push((t.token.text.to_string(), t.token.span));
}
}
}
}
for child in segment.children() {
Self::collect_builtin_function_names(child, out);
}
}
fn find_function_name(children: &[Segment]) -> Option<&Segment> {
for child in children {
match child.segment_type() {
SegmentType::Identifier => return Some(child),
SegmentType::ColumnRef => {
let inner = child.children();
return inner
.iter()
.rev()
.find(|s| s.segment_type() == SegmentType::Identifier);
}
_ if child.segment_type().is_trivia() => continue,
_ => break,
}
}
None
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    /// Guards the lookup precondition for `BUILTIN_FUNCTIONS`.
    ///
    /// Both lookup paths upper-case the name and `binary_search` the table.
    /// The table must therefore be strictly ascending, which also rules out
    /// duplicates, and fully upper-case. Otherwise entries can silently stop
    /// matching.
    #[test]
    fn test_builtin_functions_sorted_and_uppercase() {
        for pair in BUILTIN_FUNCTIONS.windows(2) {
            assert!(
                pair[0] < pair[1],
                "BUILTIN_FUNCTIONS must be strictly ascending: {:?} is not before {:?}",
                pair[0],
                pair[1]
            );
        }
        for name in BUILTIN_FUNCTIONS {
            assert!(
                !name.bytes().any(|b| b.is_ascii_lowercase()),
                "BUILTIN_FUNCTIONS entry must be upper-case: {name}"
            );
        }
    }

    #[test]
    fn test_cp03_flags_lowercase_function() {
        let violations = lint_sql("SELECT count(*) FROM t", RuleCP03::default());
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "COUNT");
    }

    #[test]
    fn test_cp03_flags_mixed_case() {
        let violations = lint_sql("SELECT Count(*) FROM t", RuleCP03::default());
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "COUNT");
    }

    #[test]
    fn test_cp03_accepts_all_upper() {
        let violations = lint_sql("SELECT COUNT(*) FROM t", RuleCP03::default());
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cp03_lower_policy_flags_upper() {
        let rule = RuleCP03 {
            policy: CapitalisationPolicy::Lower,
        };
        let violations = lint_sql("SELECT COUNT(*) FROM t", rule);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "count");
    }

    #[test]
    fn test_cp03_lower_policy_accepts_lower() {
        let rule = RuleCP03 {
            policy: CapitalisationPolicy::Lower,
        };
        let violations = lint_sql("SELECT count(*) FROM t", rule);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cp03_capitalise_policy() {
        let rule = RuleCP03 {
            policy: CapitalisationPolicy::Capitalise,
        };
        let violations = lint_sql("SELECT count(*) FROM t", rule);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "Count");
    }

    #[test]
    fn test_cp03_skips_user_defined_function() {
        let violations = lint_sql(
            "SELECT GetDropdownOptions('a', 'b') FROM t",
            RuleCP03::default(),
        );
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cp03_consistent_flags_minority() {
        let rule = RuleCP03 {
            policy: CapitalisationPolicy::Consistent,
        };
        let violations = lint_sql("SELECT COUNT(*), SUM(x), avg(y) FROM t", rule);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "AVG");
    }

    #[test]
    fn test_cp03_consistent_all_same_no_violation() {
        let rule = RuleCP03 {
            policy: CapitalisationPolicy::Consistent,
        };
        let violations = lint_sql("SELECT COUNT(*), SUM(x) FROM t", rule);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cp03_consistent_majority_lower() {
        let rule = RuleCP03 {
            policy: CapitalisationPolicy::Consistent,
        };
        let violations = lint_sql("SELECT count(*), sum(x), AVG(y) FROM t", rule);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "avg");
    }

    #[test]
    fn test_cp03_flags_replace_function() {
        let violations = lint_sql("SELECT replace(col, 'a', 'b') FROM t", RuleCP03::default());
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "REPLACE");
    }
}