Skip to main content

rigsql_rules/
utils.rs

1use rigsql_core::{Segment, SegmentType, Span};
2
3use crate::violation::{LintViolation, SourceEdit};
4
5/// Check if an AliasExpression's children contain an explicit AS keyword.
6pub fn has_as_keyword(children: &[Segment]) -> bool {
7    children.iter().any(|child| {
8        if let Segment::Token(t) = child {
9            t.segment_type == SegmentType::Keyword && t.token.text.eq_ignore_ascii_case("AS")
10        } else {
11            false
12        }
13    })
14}
15
16/// Return the first non-trivia child segment.
17pub fn first_non_trivia(children: &[Segment]) -> Option<&Segment> {
18    children.iter().find(|c| !c.segment_type().is_trivia())
19}
20
21/// Return the last non-trivia child segment.
22pub fn last_non_trivia(children: &[Segment]) -> Option<&Segment> {
23    children
24        .iter()
25        .rev()
26        .find(|c| !c.segment_type().is_trivia())
27}
28
/// Keywords that should NOT be treated as alias names.
///
/// Entries are upper-case and MUST stay in strictly ascending byte order:
/// lookups go through `binary_search`, which gives unspecified results on an
/// unsorted slice. Rows are grouped by initial letter purely for readability.
#[rustfmt::skip]
const NOT_ALIAS_KEYWORDS: &[&str] = &[
    "ALTER", "AND",
    "BEGIN", "BREAK",
    "CATCH", "CLOSE", "COMMIT", "CONTINUE", "CREATE", "CROSS", "CURSOR",
    "DEALLOCATE", "DECLARE", "DELETE", "DROP",
    "ELSE", "END", "EXCEPT", "EXEC", "EXECUTE",
    "FETCH", "FOR", "FROM", "FULL",
    "GO", "GOTO", "GROUP",
    "HAVING",
    "IF", "INNER", "INSERT", "INTERSECT", "INTO",
    "JOIN",
    "LEFT", "LIMIT",
    "MERGE",
    "NATURAL", "NEXT",
    "OFFSET", "ON", "OPEN", "OR", "ORDER", "OUTPUT", "OVER",
    "PRINT",
    "RAISERROR", "RETURN", "RETURNING", "RIGHT", "ROLLBACK",
    "SELECT", "SET",
    "TABLE", "THEN", "THROW", "TRUNCATE", "TRY",
    "UNION", "UPDATE",
    "VALUES",
    "WHEN", "WHERE", "WHILE", "WITH",
];
99
100/// Check if the "alias name" in an AliasExpression is actually a misidentified
101/// SQL keyword (e.g. OVER in window functions). Returns true if the alias
102/// looks like a false positive.
103pub fn is_false_alias(children: &[Segment]) -> bool {
104    // The alias name is the last non-trivia child
105    if let Some(Segment::Token(t)) = last_non_trivia(children) {
106        let upper = t.token.text.to_ascii_uppercase();
107        return NOT_ALIAS_KEYWORDS.binary_search(&upper.as_str()).is_ok();
108    }
109    false
110}
111
112/// Generate a fix that inserts "AS " before the last non-trivia child (the alias name).
113/// Used by AL01 and AL02.
114pub fn insert_as_keyword_fix(children: &[Segment]) -> Vec<SourceEdit> {
115    last_non_trivia(children)
116        .map(|alias| vec![SourceEdit::insert(alias.span().start, "AS ")])
117        .unwrap_or_default()
118}
119
/// Upper-case the first character of `s` and lower-case everything after it.
///
/// Case mapping is Unicode-aware, so the first character may expand to more
/// than one character (e.g. `ß` becomes `SS`). An empty input yields an empty
/// string. Used by CP01 and CP03 for the `Capitalise` policy.
pub fn capitalise(s: &str) -> String {
    let mut rest = s.chars();
    let Some(head) = rest.next() else {
        return String::new();
    };
    let mut out: String = head.to_uppercase().collect();
    // `rest` has already consumed the first character; lower-case the remainder.
    out.push_str(&rest.as_str().to_lowercase());
    out
}
129
130/// Check capitalisation of a token and return a violation if it doesn't match.
131/// Shared by CP01, CP04, CP05 to avoid duplicating violation creation.
132pub fn check_capitalisation(
133    rule_code: &'static str,
134    category: &str,
135    text: &str,
136    expected: &str,
137    policy_name: &str,
138    span: Span,
139) -> Option<LintViolation> {
140    if text != expected {
141        let message = format!(
142            "{} must be {} case. Found '{}' instead of '{}'.",
143            category, policy_name, text, expected
144        );
145        let msg_key = format!("rules.{rule_code}.msg");
146        let params = vec![
147            ("category".to_string(), category.to_string()),
148            ("policy".to_string(), policy_name.to_string()),
149            ("found".to_string(), text.to_string()),
150            ("expected".to_string(), expected.to_string()),
151        ];
152        Some(LintViolation::with_fix_and_msg_key(
153            rule_code,
154            message,
155            span,
156            vec![SourceEdit::replace(span, expected.to_string())],
157            msg_key,
158            params,
159        ))
160    } else {
161        None
162    }
163}
164
165/// Extract the alias name from an AliasExpression.
166/// The alias name is the last Identifier or QuotedIdentifier before any
167/// non-trivia, non-keyword segment (scanning from the end).
168pub fn extract_alias_name(children: &[Segment]) -> Option<String> {
169    for child in children.iter().rev() {
170        let st = child.segment_type();
171        if st == SegmentType::Identifier || st == SegmentType::QuotedIdentifier {
172            if let Segment::Token(t) = child {
173                return Some(t.token.text.to_string());
174            }
175        }
176        if st.is_trivia() {
177            continue;
178        }
179        if st != SegmentType::Keyword {
180            break;
181        }
182    }
183    None
184}
185
186/// Check if a segment ends with a Newline (possibly preceded by Whitespace).
187/// Used by layout rules (LT07, LT14) to detect newlines absorbed into clause bodies.
188pub fn has_trailing_newline(segment: &Segment) -> bool {
189    for child in segment.children().iter().rev() {
190        let st = child.segment_type();
191        if st == SegmentType::Newline {
192            return true;
193        }
194        if st == SegmentType::Whitespace {
195            continue;
196        }
197        return false;
198    }
199    false
200}
201
202/// Check if the current rule context is a table alias (parent is FROM or JOIN clause).
203pub fn is_in_table_context(ctx: &crate::rule::RuleContext) -> bool {
204    ctx.parent.is_some_and(|p| {
205        let pt = p.segment_type();
206        pt == SegmentType::FromClause || pt == SegmentType::JoinClause
207    })
208}
209
210/// Find a keyword by case-insensitive name in children. Returns (index, segment).
211pub fn find_keyword_in_children<'a>(
212    children: &'a [Segment],
213    name: &str,
214) -> Option<(usize, &'a Segment)> {
215    children.iter().enumerate().find(|(_, c)| {
216        if let Segment::Token(t) = c {
217            t.segment_type == SegmentType::Keyword && t.token.text.eq_ignore_ascii_case(name)
218        } else {
219            false
220        }
221    })
222}
223
224/// Collect all leaf tokens from a CST that match a filter predicate.
225/// Used by CP rules in `consistent` mode to gather all tokens of a category.
226pub fn collect_matching_tokens<F>(segment: &Segment, filter: &F, out: &mut Vec<(String, Span)>)
227where
228    F: Fn(&Segment) -> Option<(String, Span)>,
229{
230    if let Some(pair) = filter(segment) {
231        out.push(pair);
232    }
233    for child in segment.children() {
234        collect_matching_tokens(child, filter, out);
235    }
236}
237
238/// Determine the majority case from a list of token texts.
239/// Returns `"upper"` or `"lower"`. Mixed-case tokens are skipped (always violations).
240/// On tie, defaults to `"upper"`.
241pub fn determine_majority_case(tokens: &[(String, Span)]) -> &'static str {
242    let mut upper_count = 0u32;
243    let mut lower_count = 0u32;
244    for (text, _) in tokens {
245        let is_all_upper = text
246            .chars()
247            .all(|c| !c.is_ascii_alphabetic() || c.is_ascii_uppercase());
248        let is_all_lower = text
249            .chars()
250            .all(|c| !c.is_ascii_alphabetic() || c.is_ascii_lowercase());
251        if is_all_upper {
252            upper_count += 1;
253        } else if is_all_lower {
254            lower_count += 1;
255        }
256        // mixed-case: skip (they'll be flagged regardless)
257    }
258    if lower_count > upper_count {
259        "lower"
260    } else {
261        "upper"
262    }
263}