1use rigsql_core::{Segment, SegmentType, Span};
2
3use crate::violation::{LintViolation, SourceEdit};
4
5pub fn has_as_keyword(children: &[Segment]) -> bool {
7 children.iter().any(|child| {
8 if let Segment::Token(t) = child {
9 t.segment_type == SegmentType::Keyword && t.token.text.eq_ignore_ascii_case("AS")
10 } else {
11 false
12 }
13 })
14}
15
16pub fn first_non_trivia(children: &[Segment]) -> Option<&Segment> {
18 children.iter().find(|c| !c.segment_type().is_trivia())
19}
20
21pub fn last_non_trivia(children: &[Segment]) -> Option<&Segment> {
23 children
24 .iter()
25 .rev()
26 .find(|c| !c.segment_type().is_trivia())
27}
28
/// Keywords that `is_false_alias` refuses to accept as an alias.
///
/// If one of these words is the last non-trivia token of a segment's
/// children, `is_false_alias` returns `true`. The list includes
/// dialect-specific statement keywords such as `GO`, `RAISERROR` and `THROW`.
///
/// Invariant: every entry must be upper-case, and the list must be sorted in
/// strictly ascending byte order. `is_false_alias` upper-cases the candidate
/// text and looks it up with a binary search, so an out-of-order or
/// lower-case entry would make lookups silently fail. The `const` assertion
/// below enforces both rules at compile time.
const NOT_ALIAS_KEYWORDS: &[&str] = &[
    "ALTER",
    "AND",
    "BEGIN",
    "BREAK",
    "CATCH",
    "CLOSE",
    "COMMIT",
    "CONTINUE",
    "CREATE",
    "CROSS",
    "CURSOR",
    "DEALLOCATE",
    "DECLARE",
    "DELETE",
    "DROP",
    "ELSE",
    "END",
    "EXCEPT",
    "EXEC",
    "EXECUTE",
    "FETCH",
    "FOR",
    "FROM",
    "FULL",
    "GO",
    "GOTO",
    "GROUP",
    "HAVING",
    "IF",
    "INNER",
    "INSERT",
    "INTERSECT",
    "INTO",
    "JOIN",
    "LEFT",
    "LIMIT",
    "MERGE",
    "NATURAL",
    "NEXT",
    "OFFSET",
    "ON",
    "OPEN",
    "OR",
    "ORDER",
    "OUTPUT",
    "OVER",
    "PRINT",
    "RAISERROR",
    "RETURN",
    "RETURNING",
    "RIGHT",
    "ROLLBACK",
    "SELECT",
    "SET",
    "TABLE",
    "THEN",
    "THROW",
    "TRUNCATE",
    "TRY",
    "UNION",
    "UPDATE",
    "VALUES",
    "WHEN",
    "WHERE",
    "WHILE",
    "WITH",
];

/// Compile-time `a < b` for string slices.
///
/// Uses byte-wise lexicographic order, which is the same order as `str`'s
/// `Ord` impl that `binary_search` relies on.
const fn const_str_lt(a: &str, b: &str) -> bool {
    let (a, b) = (a.as_bytes(), b.as_bytes());
    let mut i = 0;
    while i < a.len() && i < b.len() {
        if a[i] != b[i] {
            return a[i] < b[i];
        }
        i += 1;
    }
    // Identical common prefix: the shorter string sorts first.
    a.len() < b.len()
}

// Compile-time guard: fail the build if `NOT_ALIAS_KEYWORDS` is ever edited
// into a state that the binary search in `is_false_alias` cannot handle.
const _: () = {
    let mut i = 0;
    while i < NOT_ALIAS_KEYWORDS.len() {
        let bytes = NOT_ALIAS_KEYWORDS[i].as_bytes();
        let mut j = 0;
        while j < bytes.len() {
            assert!(
                !bytes[j].is_ascii_lowercase(),
                "NOT_ALIAS_KEYWORDS entries must be upper-case"
            );
            j += 1;
        }
        if i > 0 {
            assert!(
                const_str_lt(NOT_ALIAS_KEYWORDS[i - 1], NOT_ALIAS_KEYWORDS[i]),
                "NOT_ALIAS_KEYWORDS must be sorted in strictly ascending order"
            );
        }
        i += 1;
    }
};
99
100pub fn is_false_alias(children: &[Segment]) -> bool {
104 if let Some(Segment::Token(t)) = last_non_trivia(children) {
106 let upper = t.token.text.to_ascii_uppercase();
107 return NOT_ALIAS_KEYWORDS.binary_search(&upper.as_str()).is_ok();
108 }
109 false
110}
111
112pub fn insert_as_keyword_fix(children: &[Segment]) -> Vec<SourceEdit> {
115 last_non_trivia(children)
116 .map(|alias| vec![SourceEdit::insert(alias.span().start, "AS ")])
117 .unwrap_or_default()
118}
119
120pub fn check_capitalisation(
123 rule_code: &'static str,
124 category: &str,
125 text: &str,
126 expected: &str,
127 policy_name: &str,
128 span: Span,
129) -> Option<LintViolation> {
130 if text != expected {
131 Some(LintViolation::with_fix(
132 rule_code,
133 format!(
134 "{} must be {} case. Found '{}' instead of '{}'.",
135 category, policy_name, text, expected
136 ),
137 span,
138 vec![SourceEdit::replace(span, expected.to_string())],
139 ))
140 } else {
141 None
142 }
143}
144
145pub fn extract_alias_name(children: &[Segment]) -> Option<String> {
149 for child in children.iter().rev() {
150 let st = child.segment_type();
151 if st == SegmentType::Identifier || st == SegmentType::QuotedIdentifier {
152 if let Segment::Token(t) = child {
153 return Some(t.token.text.to_string());
154 }
155 }
156 if st.is_trivia() {
157 continue;
158 }
159 if st != SegmentType::Keyword {
160 break;
161 }
162 }
163 None
164}
165
166pub fn has_trailing_newline(segment: &Segment) -> bool {
169 for child in segment.children().iter().rev() {
170 let st = child.segment_type();
171 if st == SegmentType::Newline {
172 return true;
173 }
174 if st == SegmentType::Whitespace {
175 continue;
176 }
177 return false;
178 }
179 false
180}
181
182pub fn is_in_table_context(ctx: &crate::rule::RuleContext) -> bool {
184 ctx.parent.is_some_and(|p| {
185 let pt = p.segment_type();
186 pt == SegmentType::FromClause || pt == SegmentType::JoinClause
187 })
188}
189
190pub fn find_keyword_in_children<'a>(
192 children: &'a [Segment],
193 name: &str,
194) -> Option<(usize, &'a Segment)> {
195 children.iter().enumerate().find(|(_, c)| {
196 if let Segment::Token(t) = c {
197 t.segment_type == SegmentType::Keyword && t.token.text.eq_ignore_ascii_case(name)
198 } else {
199 false
200 }
201 })
202}