Skip to main content

rigsql_rules/
utils.rs

1use rigsql_core::{Segment, SegmentType, Span};
2
3use crate::violation::{LintViolation, SourceEdit};
4
5/// Check if an AliasExpression's children contain an explicit AS keyword.
6pub fn has_as_keyword(children: &[Segment]) -> bool {
7    children.iter().any(|child| {
8        if let Segment::Token(t) = child {
9            t.segment_type == SegmentType::Keyword && t.token.text.eq_ignore_ascii_case("AS")
10        } else {
11            false
12        }
13    })
14}
15
16/// Return the first non-trivia child segment.
17pub fn first_non_trivia(children: &[Segment]) -> Option<&Segment> {
18    children.iter().find(|c| !c.segment_type().is_trivia())
19}
20
21/// Return the last non-trivia child segment.
22pub fn last_non_trivia(children: &[Segment]) -> Option<&Segment> {
23    children
24        .iter()
25        .rev()
26        .find(|c| !c.segment_type().is_trivia())
27}
28
/// Keywords that must never be interpreted as an alias name.
///
/// Entries are uppercase and kept in ascending byte order, one row per initial
/// letter, because lookups go through `binary_search`. Insert new keywords at
/// their sorted position.
// Keep the one-row-per-letter layout instead of rustfmt's one-entry-per-line.
#[rustfmt::skip]
const NOT_ALIAS_KEYWORDS: &[&str] = &[
    "ALTER", "AND",
    "BEGIN", "BREAK",
    "CATCH", "CLOSE", "COMMIT", "CONTINUE", "CREATE", "CROSS", "CURSOR",
    "DEALLOCATE", "DECLARE", "DELETE", "DROP",
    "ELSE", "END", "EXCEPT", "EXEC", "EXECUTE",
    "FETCH", "FOR", "FROM", "FULL",
    "GO", "GOTO", "GROUP",
    "HAVING",
    "IF", "INNER", "INSERT", "INTERSECT", "INTO",
    "JOIN",
    "LEFT", "LIMIT",
    "MERGE",
    "NATURAL", "NEXT",
    "OFFSET", "ON", "OPEN", "OR", "ORDER", "OUTPUT", "OVER",
    "PRINT",
    "RAISERROR", "RETURN", "RETURNING", "RIGHT", "ROLLBACK",
    "SELECT", "SET",
    "TABLE", "THEN", "THROW", "TRUNCATE", "TRY",
    "UNION", "UPDATE",
    "VALUES",
    "WHEN", "WHERE", "WHILE", "WITH",
];
99
100/// Check if the "alias name" in an AliasExpression is actually a misidentified
101/// SQL keyword (e.g. OVER in window functions). Returns true if the alias
102/// looks like a false positive.
103pub fn is_false_alias(children: &[Segment]) -> bool {
104    // The alias name is the last non-trivia child
105    if let Some(Segment::Token(t)) = last_non_trivia(children) {
106        let upper = t.token.text.to_ascii_uppercase();
107        return NOT_ALIAS_KEYWORDS.binary_search(&upper.as_str()).is_ok();
108    }
109    false
110}
111
112/// Generate a fix that inserts "AS " before the last non-trivia child (the alias name).
113/// Used by AL01 and AL02.
114pub fn insert_as_keyword_fix(children: &[Segment]) -> Vec<SourceEdit> {
115    last_non_trivia(children)
116        .map(|alias| vec![SourceEdit::insert(alias.span().start, "AS ")])
117        .unwrap_or_default()
118}
119
120/// Check capitalisation of a token and return a violation if it doesn't match.
121/// Shared by CP01, CP04, CP05 to avoid duplicating violation creation.
122pub fn check_capitalisation(
123    rule_code: &'static str,
124    category: &str,
125    text: &str,
126    expected: &str,
127    policy_name: &str,
128    span: Span,
129) -> Option<LintViolation> {
130    if text != expected {
131        Some(LintViolation::with_fix(
132            rule_code,
133            format!(
134                "{} must be {} case. Found '{}' instead of '{}'.",
135                category, policy_name, text, expected
136            ),
137            span,
138            vec![SourceEdit::replace(span, expected.to_string())],
139        ))
140    } else {
141        None
142    }
143}
144
145/// Extract the alias name from an AliasExpression.
146/// The alias name is the last Identifier or QuotedIdentifier before any
147/// non-trivia, non-keyword segment (scanning from the end).
148pub fn extract_alias_name(children: &[Segment]) -> Option<String> {
149    for child in children.iter().rev() {
150        let st = child.segment_type();
151        if st == SegmentType::Identifier || st == SegmentType::QuotedIdentifier {
152            if let Segment::Token(t) = child {
153                return Some(t.token.text.to_string());
154            }
155        }
156        if st.is_trivia() {
157            continue;
158        }
159        if st != SegmentType::Keyword {
160            break;
161        }
162    }
163    None
164}
165
166/// Check if a segment ends with a Newline (possibly preceded by Whitespace).
167/// Used by layout rules (LT07, LT14) to detect newlines absorbed into clause bodies.
168pub fn has_trailing_newline(segment: &Segment) -> bool {
169    for child in segment.children().iter().rev() {
170        let st = child.segment_type();
171        if st == SegmentType::Newline {
172            return true;
173        }
174        if st == SegmentType::Whitespace {
175            continue;
176        }
177        return false;
178    }
179    false
180}
181
182/// Check if the current rule context is a table alias (parent is FROM or JOIN clause).
183pub fn is_in_table_context(ctx: &crate::rule::RuleContext) -> bool {
184    ctx.parent.is_some_and(|p| {
185        let pt = p.segment_type();
186        pt == SegmentType::FromClause || pt == SegmentType::JoinClause
187    })
188}
189
190/// Find a keyword by case-insensitive name in children. Returns (index, segment).
191pub fn find_keyword_in_children<'a>(
192    children: &'a [Segment],
193    name: &str,
194) -> Option<(usize, &'a Segment)> {
195    children.iter().enumerate().find(|(_, c)| {
196        if let Segment::Token(t) = c {
197            t.segment_type == SegmentType::Keyword && t.token.text.eq_ignore_ascii_case(name)
198        } else {
199            false
200        }
201    })
202}