use rigsql_core::{Segment, SegmentType, Span};
use crate::violation::{LintViolation, SourceEdit};
/// Returns `true` if any direct child is a keyword token spelled `AS`
/// (compared ASCII case-insensitively).
///
/// Only the immediate `children` are inspected; nested segments are not searched.
pub fn has_as_keyword(children: &[Segment]) -> bool {
    children.iter().any(|seg| {
        matches!(
            seg,
            Segment::Token(tok)
                if tok.segment_type == SegmentType::Keyword
                    && tok.token.text.eq_ignore_ascii_case("AS")
        )
    })
}
/// Returns the first child (scanning front to back) whose segment type is not trivia.
///
/// Returns `None` when `children` is empty or every child is trivia, as decided
/// by `SegmentType::is_trivia`.
pub fn first_non_trivia(children: &[Segment]) -> Option<&Segment> {
    for seg in children {
        if !seg.segment_type().is_trivia() {
            return Some(seg);
        }
    }
    None
}
/// Returns the last child (scanning back to front) whose segment type is not trivia.
///
/// Returns `None` when `children` is empty or every child is trivia, as decided
/// by `SegmentType::is_trivia`.
pub fn last_non_trivia(children: &[Segment]) -> Option<&Segment> {
    children
        .iter()
        .rfind(|seg| !seg.segment_type().is_trivia())
}
/// Upper-case keywords that are never treated as an alias.
///
/// `is_false_alias` checks whether the last non-trivia token of a segment
/// (upper-cased) appears in this list. If it does, that trailing word is not
/// reported as an implicit alias.
///
/// **Invariant:** the list must stay sorted in strictly ascending byte order
/// (no duplicates), because it is searched with `binary_search`. The
/// compile-time assertion below rejects any edit that breaks this.
const NOT_ALIAS_KEYWORDS: &[&str] = &[
    "ALTER",
    "AND",
    "BEGIN",
    "BREAK",
    "CATCH",
    "CLOSE",
    "COMMIT",
    "CONTINUE",
    "CREATE",
    "CROSS",
    "CURSOR",
    "DEALLOCATE",
    "DECLARE",
    "DELETE",
    "DROP",
    "ELSE",
    "END",
    "EXCEPT",
    "EXEC",
    "EXECUTE",
    "FETCH",
    "FOR",
    "FROM",
    "FULL",
    "GO",
    "GOTO",
    "GROUP",
    "HAVING",
    "IF",
    "INNER",
    "INSERT",
    "INTERSECT",
    "INTO",
    "JOIN",
    "LEFT",
    "LIMIT",
    "MERGE",
    "NATURAL",
    "NEXT",
    "OFFSET",
    "ON",
    "OPEN",
    "OR",
    "ORDER",
    "OUTPUT",
    "OVER",
    "PRINT",
    "RAISERROR",
    "RETURN",
    "RETURNING",
    "RIGHT",
    "ROLLBACK",
    "SELECT",
    "SET",
    "TABLE",
    "THEN",
    "THROW",
    "TRUNCATE",
    "TRY",
    "UNION",
    "UPDATE",
    "VALUES",
    "WHEN",
    "WHERE",
    "WHILE",
    "WITH",
];

/// Returns `true` if `a` sorts strictly before `b` under byte-wise ordering,
/// the same ordering `<str as Ord>` (and therefore `binary_search`) uses.
///
/// This is a `const fn` so it can run in the compile-time check below.
const fn keyword_sorts_before(a: &str, b: &str) -> bool {
    let (a, b) = (a.as_bytes(), b.as_bytes());
    let mut i = 0;
    while i < a.len() && i < b.len() {
        if a[i] != b[i] {
            return a[i] < b[i];
        }
        i += 1;
    }
    // One is a prefix of the other (or they are equal): shorter sorts first.
    a.len() < b.len()
}

// Compile-time guard: fail the build if NOT_ALIAS_KEYWORDS is ever edited out
// of order or gains a duplicate, since `binary_search` would then silently miss entries.
const _: () = {
    let mut i = 1;
    while i < NOT_ALIAS_KEYWORDS.len() {
        assert!(
            keyword_sorts_before(NOT_ALIAS_KEYWORDS[i - 1], NOT_ALIAS_KEYWORDS[i]),
            "NOT_ALIAS_KEYWORDS must be sorted ascending with no duplicates"
        );
        i += 1;
    }
};
/// Returns `true` when the last non-trivia child is a token whose text,
/// upper-cased, is one of `NOT_ALIAS_KEYWORDS`.
///
/// The token's segment type is not checked, only its text. A trailing child
/// that is not a token, or an empty/all-trivia `children`, yields `false`.
pub fn is_false_alias(children: &[Segment]) -> bool {
    match last_non_trivia(children) {
        Some(Segment::Token(tok)) => {
            let upper = tok.token.text.to_ascii_uppercase();
            NOT_ALIAS_KEYWORDS.binary_search(&upper.as_str()).is_ok()
        }
        _ => false,
    }
}
/// Builds the fix that inserts the text `AS ` at the start of the last
/// non-trivia child's span.
///
/// Returns an empty edit list when `children` is empty or all trivia.
pub fn insert_as_keyword_fix(children: &[Segment]) -> Vec<SourceEdit> {
    let Some(alias) = last_non_trivia(children) else {
        return Vec::new();
    };
    vec![SourceEdit::insert(alias.span().start, "AS ")]
}
/// Returns `s` with its first character upper-cased and the rest lower-cased.
///
/// Uses Unicode case mapping (`char::to_uppercase` / `str::to_lowercase`), so a
/// single leading character may expand (e.g. `ß` becomes `SS`). An empty input
/// yields an empty string.
pub fn capitalise(s: &str) -> String {
    let mut rest = s.chars();
    let Some(head) = rest.next() else {
        return String::new();
    };
    let mut out = String::with_capacity(s.len());
    out.extend(head.to_uppercase());
    out.push_str(&rest.as_str().to_lowercase());
    out
}
/// Reports a capitalisation violation when `text` differs from `expected`.
///
/// On a mismatch, returns a `LintViolation` for `rule_code` that:
/// - carries a human-readable message built from `category`, `policy_name`,
///   `text` and `expected`;
/// - carries a single fix replacing `span` with `expected`;
/// - uses the message key `rules.<rule_code>.msg` with the parameters
///   `category`, `policy`, `found` and `expected`.
///
/// Returns `None` when `text` already equals `expected` (byte-for-byte).
pub fn check_capitalisation(
    rule_code: &'static str,
    category: &str,
    text: &str,
    expected: &str,
    policy_name: &str,
    span: Span,
) -> Option<LintViolation> {
    if text == expected {
        return None;
    }

    let message = format!(
        "{} must be {} case. Found '{}' instead of '{}'.",
        category, policy_name, text, expected
    );
    let params: Vec<(String, String)> = [
        ("category", category),
        ("policy", policy_name),
        ("found", text),
        ("expected", expected),
    ]
    .iter()
    .map(|&(key, value)| (key.to_string(), value.to_string()))
    .collect();
    let fix = vec![SourceEdit::replace(span, expected.to_string())];

    Some(LintViolation::with_fix_and_msg_key(
        rule_code,
        message,
        span,
        fix,
        format!("rules.{rule_code}.msg"),
        params,
    ))
}
/// Scans `children` from the end and returns the text of the trailing
/// identifier token (plain or quoted), if there is one.
///
/// Trivia and keyword children are skipped while scanning backwards. The scan
/// stops with `None` at the first child that is neither trivia, a keyword, nor
/// an identifier token.
pub fn extract_alias_name(children: &[Segment]) -> Option<String> {
    for seg in children.iter().rev() {
        let kind = seg.segment_type();
        let is_identifier =
            kind == SegmentType::Identifier || kind == SegmentType::QuotedIdentifier;
        match seg {
            Segment::Token(tok) if is_identifier => return Some(tok.token.text.to_string()),
            _ if kind.is_trivia() || kind == SegmentType::Keyword => {}
            _ => return None,
        }
    }
    None
}
/// Returns `true` if, ignoring trailing `Whitespace` children, the last direct
/// child of `segment` is a `Newline`.
///
/// Only direct children are inspected. A segment with no children, or with
/// only whitespace children, yields `false`.
pub fn has_trailing_newline(segment: &Segment) -> bool {
    segment
        .children()
        .iter()
        .rev()
        .map(|child| child.segment_type())
        .find(|kind| *kind != SegmentType::Whitespace)
        .is_some_and(|kind| kind == SegmentType::Newline)
}
/// Returns `true` when the rule context's parent segment is a `FROM` clause or
/// a `JOIN` clause (`SegmentType::FromClause` / `SegmentType::JoinClause`).
///
/// A context with no parent yields `false`.
pub fn is_in_table_context(ctx: &crate::rule::RuleContext) -> bool {
    match ctx.parent {
        Some(parent) => matches!(
            parent.segment_type(),
            SegmentType::FromClause | SegmentType::JoinClause
        ),
        None => false,
    }
}
/// Finds the first direct child that is a keyword token whose text equals
/// `name` (ASCII case-insensitive).
///
/// Returns the child's index within `children` together with a reference to
/// it, or `None` if no such keyword token exists.
pub fn find_keyword_in_children<'a>(
    children: &'a [Segment],
    name: &str,
) -> Option<(usize, &'a Segment)> {
    let is_wanted_keyword = |seg: &Segment| match seg {
        Segment::Token(tok) => {
            tok.segment_type == SegmentType::Keyword && tok.token.text.eq_ignore_ascii_case(name)
        }
        _ => false,
    };
    children
        .iter()
        .position(is_wanted_keyword)
        .map(|idx| (idx, &children[idx]))
}
/// Walks `segment` and all of its descendants in pre-order (a node before its
/// children, children left to right). It appends to `out` every `(text, span)`
/// pair that `filter` returns `Some` for.
///
/// `filter` is called exactly once per visited segment, in visit order.
pub fn collect_matching_tokens<F>(segment: &Segment, filter: &F, out: &mut Vec<(String, Span)>)
where
    F: Fn(&Segment) -> Option<(String, Span)>,
{
    // Explicit work stack instead of recursion. Children are pushed in reverse
    // so they pop in source order, which reproduces a pre-order depth-first walk.
    let mut pending: Vec<&Segment> = vec![segment];
    while let Some(current) = pending.pop() {
        out.extend(filter(current));
        pending.extend(current.children().iter().rev());
    }
}
/// Decides whether most of `tokens` are written in upper or lower case.
///
/// Each token is classified by its ASCII letters only:
/// - all letters upper-case: one vote for `"upper"`;
/// - all letters lower-case: one vote for `"lower"`;
/// - mixed case (e.g. `Select`): no vote;
/// - no ASCII letters at all (e.g. `_`, `1`): no vote, since it carries no case.
///
/// Returns `"lower"` only when lower-case votes strictly outnumber upper-case
/// votes. Ties, including an empty slice, resolve to `"upper"`.
///
/// The second element of each tuple (a `Span` at existing call sites) is never
/// inspected, so any type is accepted.
pub fn determine_majority_case<S>(tokens: &[(String, S)]) -> &'static str {
    let mut upper_count = 0usize;
    let mut lower_count = 0usize;
    for (text, _) in tokens {
        let letters = || text.chars().filter(char::is_ascii_alphabetic);
        // Without this guard a letter-less token satisfies both `all` checks
        // vacuously and is miscounted as an upper-case vote.
        if letters().next().is_none() {
            continue;
        }
        if letters().all(|c| c.is_ascii_uppercase()) {
            upper_count += 1;
        } else if letters().all(|c| c.is_ascii_lowercase()) {
            lower_count += 1;
        }
    }
    if lower_count > upper_count {
        "lower"
    } else {
        "upper"
    }
}