Skip to main content

rigsql_rules/
utils.rs

1use rigsql_core::{Segment, SegmentType, Span};
2
3use crate::violation::{LintViolation, SourceEdit};
4
5/// Check if an AliasExpression's children contain an explicit AS keyword.
6pub fn has_as_keyword(children: &[Segment]) -> bool {
7    children.iter().any(|child| {
8        if let Segment::Token(t) = child {
9            t.segment_type == SegmentType::Keyword && t.token.text.eq_ignore_ascii_case("AS")
10        } else {
11            false
12        }
13    })
14}
15
16/// Return the first non-trivia child segment.
17pub fn first_non_trivia(children: &[Segment]) -> Option<&Segment> {
18    children.iter().find(|c| !c.segment_type().is_trivia())
19}
20
21/// Return the last non-trivia child segment.
22pub fn last_non_trivia(children: &[Segment]) -> Option<&Segment> {
23    children
24        .iter()
25        .rev()
26        .find(|c| !c.segment_type().is_trivia())
27}
28
/// SQL keywords that must never be accepted as an alias name.
///
/// `is_false_alias` looks entries up with a binary search, so the table must
/// stay in ascending byte order. Every entry is ASCII upper case, so that is
/// plain alphabetical order. Insert new keywords at their sorted position.
/// Entries are grouped by initial letter purely for readability.
#[rustfmt::skip]
const NOT_ALIAS_KEYWORDS: &[&str] = &[
    "ALTER", "AND",
    "BEGIN", "BREAK",
    "CATCH", "CLOSE", "COMMIT", "CONTINUE", "CREATE", "CROSS", "CURSOR",
    "DEALLOCATE", "DECLARE", "DELETE", "DROP",
    "ELSE", "END", "EXCEPT", "EXEC", "EXECUTE",
    "FETCH", "FOR", "FROM", "FULL",
    "GO", "GOTO", "GROUP",
    "HAVING",
    "IF", "INNER", "INSERT", "INTERSECT", "INTO",
    "JOIN",
    "LEFT", "LIMIT",
    "MERGE",
    "NATURAL", "NEXT",
    "OFFSET", "ON", "OPEN", "OR", "ORDER", "OUTPUT", "OVER",
    "PRINT",
    "RAISERROR", "RETURN", "RETURNING", "RIGHT", "ROLLBACK",
    "SELECT", "SET",
    "TABLE", "THEN", "THROW", "TRUNCATE", "TRY",
    "UNION", "UPDATE",
    "VALUES",
    "WHEN", "WHERE", "WHILE", "WITH",
];
99
100/// Check if the "alias name" in an AliasExpression is actually a misidentified
101/// SQL keyword (e.g. OVER in window functions). Returns true if the alias
102/// looks like a false positive.
103pub fn is_false_alias(children: &[Segment]) -> bool {
104    // The alias name is the last non-trivia child
105    if let Some(Segment::Token(t)) = last_non_trivia(children) {
106        let upper = t.token.text.to_ascii_uppercase();
107        return NOT_ALIAS_KEYWORDS.binary_search(&upper.as_str()).is_ok();
108    }
109    false
110}
111
112/// Generate a fix that inserts "AS " before the last non-trivia child (the alias name).
113/// Used by AL01 and AL02.
114pub fn insert_as_keyword_fix(children: &[Segment]) -> Vec<SourceEdit> {
115    last_non_trivia(children)
116        .map(|alias| vec![SourceEdit::insert(alias.span().start, "AS ")])
117        .unwrap_or_default()
118}
119
120/// Check capitalisation of a token and return a violation if it doesn't match.
121/// Shared by CP01, CP04, CP05 to avoid duplicating violation creation.
122pub fn check_capitalisation(
123    rule_code: &'static str,
124    category: &str,
125    text: &str,
126    expected: &str,
127    policy_name: &str,
128    span: Span,
129) -> Option<LintViolation> {
130    if text != expected {
131        let message = format!(
132            "{} must be {} case. Found '{}' instead of '{}'.",
133            category, policy_name, text, expected
134        );
135        let msg_key = format!("rules.{rule_code}.msg");
136        let params = vec![
137            ("category".to_string(), category.to_string()),
138            ("policy".to_string(), policy_name.to_string()),
139            ("found".to_string(), text.to_string()),
140            ("expected".to_string(), expected.to_string()),
141        ];
142        Some(LintViolation::with_fix_and_msg_key(
143            rule_code,
144            message,
145            span,
146            vec![SourceEdit::replace(span, expected.to_string())],
147            msg_key,
148            params,
149        ))
150    } else {
151        None
152    }
153}
154
155/// Extract the alias name from an AliasExpression.
156/// The alias name is the last Identifier or QuotedIdentifier before any
157/// non-trivia, non-keyword segment (scanning from the end).
158pub fn extract_alias_name(children: &[Segment]) -> Option<String> {
159    for child in children.iter().rev() {
160        let st = child.segment_type();
161        if st == SegmentType::Identifier || st == SegmentType::QuotedIdentifier {
162            if let Segment::Token(t) = child {
163                return Some(t.token.text.to_string());
164            }
165        }
166        if st.is_trivia() {
167            continue;
168        }
169        if st != SegmentType::Keyword {
170            break;
171        }
172    }
173    None
174}
175
176/// Check if a segment ends with a Newline (possibly preceded by Whitespace).
177/// Used by layout rules (LT07, LT14) to detect newlines absorbed into clause bodies.
178pub fn has_trailing_newline(segment: &Segment) -> bool {
179    for child in segment.children().iter().rev() {
180        let st = child.segment_type();
181        if st == SegmentType::Newline {
182            return true;
183        }
184        if st == SegmentType::Whitespace {
185            continue;
186        }
187        return false;
188    }
189    false
190}
191
192/// Check if the current rule context is a table alias (parent is FROM or JOIN clause).
193pub fn is_in_table_context(ctx: &crate::rule::RuleContext) -> bool {
194    ctx.parent.is_some_and(|p| {
195        let pt = p.segment_type();
196        pt == SegmentType::FromClause || pt == SegmentType::JoinClause
197    })
198}
199
200/// Find a keyword by case-insensitive name in children. Returns (index, segment).
201pub fn find_keyword_in_children<'a>(
202    children: &'a [Segment],
203    name: &str,
204) -> Option<(usize, &'a Segment)> {
205    children.iter().enumerate().find(|(_, c)| {
206        if let Segment::Token(t) = c {
207            t.segment_type == SegmentType::Keyword && t.token.text.eq_ignore_ascii_case(name)
208        } else {
209            false
210        }
211    })
212}