1use rigsql_core::{Segment, SegmentType, Span};
2
3use crate::violation::{LintViolation, SourceEdit};
4
5pub fn has_as_keyword(children: &[Segment]) -> bool {
7 children.iter().any(|child| {
8 if let Segment::Token(t) = child {
9 t.segment_type == SegmentType::Keyword && t.token.text.eq_ignore_ascii_case("AS")
10 } else {
11 false
12 }
13 })
14}
15
16pub fn first_non_trivia(children: &[Segment]) -> Option<&Segment> {
18 children.iter().find(|c| !c.segment_type().is_trivia())
19}
20
21pub fn last_non_trivia(children: &[Segment]) -> Option<&Segment> {
23 children
24 .iter()
25 .rev()
26 .find(|c| !c.segment_type().is_trivia())
27}
28
/// Upper-case keywords that `is_false_alias` treats as "not an alias".
///
/// If the last non-trivia token of a candidate alias position matches one of these
/// words (ASCII case-insensitively), `is_false_alias` reports it as a keyword rather
/// than a real alias.
///
/// Invariant: entries MUST be upper-case and sorted in strictly ascending byte order
/// (no duplicates), because `is_false_alias` looks them up with `binary_search`, which
/// silently returns wrong answers on unsorted input. The `const` assertion below
/// enforces this at compile time.
const NOT_ALIAS_KEYWORDS: &[&str] = &[
    "ALTER",
    "AND",
    "BEGIN",
    "BREAK",
    "CATCH",
    "CLOSE",
    "COMMIT",
    "CONTINUE",
    "CREATE",
    "CROSS",
    "CURSOR",
    "DEALLOCATE",
    "DECLARE",
    "DELETE",
    "DROP",
    "ELSE",
    "END",
    "EXCEPT",
    "EXEC",
    "EXECUTE",
    "FETCH",
    "FOR",
    "FROM",
    "FULL",
    "GO",
    "GOTO",
    "GROUP",
    "HAVING",
    "IF",
    "INNER",
    "INSERT",
    "INTERSECT",
    "INTO",
    "JOIN",
    "LEFT",
    "LIMIT",
    "MERGE",
    "NATURAL",
    "NEXT",
    "OFFSET",
    "ON",
    "OPEN",
    "OR",
    "ORDER",
    "OUTPUT",
    "OVER",
    "PRINT",
    "RAISERROR",
    "RETURN",
    "RETURNING",
    "RIGHT",
    "ROLLBACK",
    "SELECT",
    "SET",
    "TABLE",
    "THEN",
    "THROW",
    "TRUNCATE",
    "TRY",
    "UNION",
    "UPDATE",
    "VALUES",
    "WHEN",
    "WHERE",
    "WHILE",
    "WITH",
];

/// Const-evaluable strict byte-wise `a < b` for string slices.
///
/// Matches the ordering of `Ord for str` (lexicographic over bytes, with a proper
/// prefix sorting first). It is used only to validate `NOT_ALIAS_KEYWORDS` at
/// compile time.
const fn const_str_lt(a: &str, b: &str) -> bool {
    let (a, b) = (a.as_bytes(), b.as_bytes());
    let mut i = 0;
    while i < a.len() && i < b.len() {
        if a[i] != b[i] {
            return a[i] < b[i];
        }
        i += 1;
    }
    // Equal common prefix: the shorter string sorts first; equal strings are not `<`.
    a.len() < b.len()
}

// Fail the build if a keyword is inserted out of order or twice, which would
// otherwise silently break the binary search in `is_false_alias`.
const _: () = {
    let mut i = 1;
    while i < NOT_ALIAS_KEYWORDS.len() {
        assert!(
            const_str_lt(NOT_ALIAS_KEYWORDS[i - 1], NOT_ALIAS_KEYWORDS[i]),
            "NOT_ALIAS_KEYWORDS must be sorted in ascending order without duplicates"
        );
        i += 1;
    }
};
99
100pub fn is_false_alias(children: &[Segment]) -> bool {
104 if let Some(Segment::Token(t)) = last_non_trivia(children) {
106 let upper = t.token.text.to_ascii_uppercase();
107 return NOT_ALIAS_KEYWORDS.binary_search(&upper.as_str()).is_ok();
108 }
109 false
110}
111
112pub fn insert_as_keyword_fix(children: &[Segment]) -> Vec<SourceEdit> {
115 last_non_trivia(children)
116 .map(|alias| vec![SourceEdit::insert(alias.span().start, "AS ")])
117 .unwrap_or_default()
118}
119
/// Returns `s` with its first character upper-cased and the rest lower-cased.
///
/// Full Unicode case mapping is used, so the first character may expand into
/// several (e.g. `ß` becomes `SS`). An empty input yields an empty string.
pub fn capitalise(s: &str) -> String {
    let mut rest = s.chars();
    let Some(head) = rest.next() else {
        return String::new();
    };
    let mut result: String = head.to_uppercase().collect();
    result.push_str(&rest.as_str().to_lowercase());
    result
}
129
130pub fn check_capitalisation(
133 rule_code: &'static str,
134 category: &str,
135 text: &str,
136 expected: &str,
137 policy_name: &str,
138 span: Span,
139) -> Option<LintViolation> {
140 if text != expected {
141 let message = format!(
142 "{} must be {} case. Found '{}' instead of '{}'.",
143 category, policy_name, text, expected
144 );
145 let msg_key = format!("rules.{rule_code}.msg");
146 let params = vec![
147 ("category".to_string(), category.to_string()),
148 ("policy".to_string(), policy_name.to_string()),
149 ("found".to_string(), text.to_string()),
150 ("expected".to_string(), expected.to_string()),
151 ];
152 Some(LintViolation::with_fix_and_msg_key(
153 rule_code,
154 message,
155 span,
156 vec![SourceEdit::replace(span, expected.to_string())],
157 msg_key,
158 params,
159 ))
160 } else {
161 None
162 }
163}
164
165pub fn extract_alias_name(children: &[Segment]) -> Option<String> {
169 for child in children.iter().rev() {
170 let st = child.segment_type();
171 if st == SegmentType::Identifier || st == SegmentType::QuotedIdentifier {
172 if let Segment::Token(t) = child {
173 return Some(t.token.text.to_string());
174 }
175 }
176 if st.is_trivia() {
177 continue;
178 }
179 if st != SegmentType::Keyword {
180 break;
181 }
182 }
183 None
184}
185
186pub fn has_trailing_newline(segment: &Segment) -> bool {
189 for child in segment.children().iter().rev() {
190 let st = child.segment_type();
191 if st == SegmentType::Newline {
192 return true;
193 }
194 if st == SegmentType::Whitespace {
195 continue;
196 }
197 return false;
198 }
199 false
200}
201
202pub fn is_in_table_context(ctx: &crate::rule::RuleContext) -> bool {
204 ctx.parent.is_some_and(|p| {
205 let pt = p.segment_type();
206 pt == SegmentType::FromClause || pt == SegmentType::JoinClause
207 })
208}
209
210pub fn find_keyword_in_children<'a>(
212 children: &'a [Segment],
213 name: &str,
214) -> Option<(usize, &'a Segment)> {
215 children.iter().enumerate().find(|(_, c)| {
216 if let Segment::Token(t) = c {
217 t.segment_type == SegmentType::Keyword && t.token.text.eq_ignore_ascii_case(name)
218 } else {
219 false
220 }
221 })
222}
223
224pub fn collect_matching_tokens<F>(segment: &Segment, filter: &F, out: &mut Vec<(String, Span)>)
227where
228 F: Fn(&Segment) -> Option<(String, Span)>,
229{
230 if let Some(pair) = filter(segment) {
231 out.push(pair);
232 }
233 for child in segment.children() {
234 collect_matching_tokens(child, filter, out);
235 }
236}
237
238pub fn determine_majority_case(tokens: &[(String, Span)]) -> &'static str {
242 let mut upper_count = 0u32;
243 let mut lower_count = 0u32;
244 for (text, _) in tokens {
245 let is_all_upper = text
246 .chars()
247 .all(|c| !c.is_ascii_alphabetic() || c.is_ascii_uppercase());
248 let is_all_lower = text
249 .chars()
250 .all(|c| !c.is_ascii_alphabetic() || c.is_ascii_lowercase());
251 if is_all_upper {
252 upper_count += 1;
253 } else if is_all_lower {
254 lower_count += 1;
255 }
256 }
258 if lower_count > upper_count {
259 "lower"
260 } else {
261 "upper"
262 }
263}