use banshee_syntax::{Parse, ParseError, SyntaxKind};
use crate::config::FormatConfig;
/// Failures that can be reported when formatter output is checked against
/// the SQL it was produced from.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FormatError {
/// The original SQL produced parse errors; carries those errors.
InvalidInput(Vec<ParseError>),
/// The formatted SQL produced parse errors; carries those errors.
FormatterBrokeQuery(Vec<ParseError>),
/// Both inputs parsed, but their normalized token strings differ.
SemanticMismatch {
/// Normalized token string of the original SQL.
expected: String,
/// Normalized token string of the formatted SQL.
actual: String,
},
}
impl std::fmt::Display for FormatError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
FormatError::InvalidInput(errors) => {
write!(f, "invalid input SQL: ")?;
for (i, e) in errors.iter().enumerate() {
if i > 0 {
write!(f, "; ")?;
}
write!(f, "{}", e)?;
}
Ok(())
}
FormatError::FormatterBrokeQuery(errors) => {
write!(f, "formatter produced invalid SQL: ")?;
for (i, e) in errors.iter().enumerate() {
if i > 0 {
write!(f, "; ")?;
}
write!(f, "{}", e)?;
}
Ok(())
}
FormatError::SemanticMismatch { expected, actual } => {
write!(
f,
"semantic mismatch: expected '{}', got '{}'",
expected, actual
)
}
}
}
}
// Empty impl: no trait methods are overridden here, so the trait's defaults apply.
impl std::error::Error for FormatError {}
/// Formatter output bundled with the outcome of validating it.
#[derive(Debug, Clone)]
pub struct FormatResult {
/// The text produced by the formatter (kept even when validation failed).
pub output: String,
/// `true` when no validation error was recorded for `output`.
pub is_valid: bool,
/// Validation errors associated with `output`; empty on success.
pub errors: Vec<FormatError>,
}
impl FormatResult {
#[must_use]
pub fn ok(output: String) -> Self {
Self {
output,
is_valid: true,
errors: vec![],
}
}
#[must_use]
pub fn with_error(output: String, error: FormatError) -> Self {
Self {
output,
is_valid: false,
errors: vec![error],
}
}
pub fn into_result(self) -> Result<String, Vec<FormatError>> {
if self.is_valid {
Ok(self.output)
} else {
Err(self.errors)
}
}
}
/// Checks that `formatted` is a faithful re-rendering of `original`.
///
/// Both strings are parsed with `banshee_parser::parse`, and the normalized
/// token strings of the two parses (see [`normalize`]) are compared.
///
/// # Errors
///
/// - [`FormatError::InvalidInput`] if `original` has parse errors. This is
///   checked first, before `formatted` is parsed.
/// - [`FormatError::FormatterBrokeQuery`] if `formatted` has parse errors.
/// - [`FormatError::SemanticMismatch`] if both parse but their normalized
///   token strings differ.
pub fn validate_format(original: &str, formatted: &str) -> Result<(), FormatError> {
    let orig_parse = banshee_parser::parse(original);
    if !orig_parse.errors().is_empty() {
        return Err(FormatError::InvalidInput(orig_parse.errors().to_vec()));
    }

    let fmt_parse = banshee_parser::parse(formatted);
    if !fmt_parse.errors().is_empty() {
        return Err(FormatError::FormatterBrokeQuery(
            fmt_parse.errors().to_vec(),
        ));
    }

    // Normalize each tree exactly once; on a mismatch the strings are moved
    // into the error instead of being recomputed.
    let expected = normalize(&orig_parse);
    let actual = normalize(&fmt_parse);
    if expected != actual {
        return Err(FormatError::SemanticMismatch { expected, actual });
    }
    Ok(())
}
/// Returns `true` when the two parses have identical normalized token
/// strings, as produced by [`normalize`].
#[must_use]
pub fn semantically_equal(a: &Parse, b: &Parse) -> bool {
    let lhs = normalize(a);
    let rhs = normalize(b);
    lhs == rhs
}
/// Flattens a parse into a canonical, single-space-separated token string.
///
/// Trivia tokens are dropped. Tokens whose kind is classified as a keyword by
/// `is_keyword_kind` are upper-cased; every other token keeps its exact text.
#[must_use]
pub fn normalize(parse: &Parse) -> String {
    parse
        .syntax()
        .descendants_with_tokens()
        .filter_map(|element| match element {
            cstree::util::NodeOrToken::Token(token) if !token.kind().is_trivia() => {
                let rendered = if is_keyword_kind(token.kind()) {
                    token.text().to_uppercase()
                } else {
                    token.text().to_string()
                };
                Some(rendered)
            }
            _ => None,
        })
        .collect::<Vec<String>>()
        .join(" ")
}
/// Reports whether `kind` is a keyword, judged by whether its `Debug`
/// rendering ends in `_KW`.
fn is_keyword_kind(kind: SyntaxKind) -> bool {
    format!("{kind:?}").ends_with("_KW")
}
#[must_use]
pub fn format_validated(source: &str, config: &FormatConfig) -> FormatResult {
let formatted = crate::format::format(source, config);
match validate_format(source, &formatted) {
Ok(()) => FormatResult::ok(formatted),
Err(e) => FormatResult::with_error(formatted, e),
}
}
#[must_use]
pub fn format_pgformatter_validated(
source: &str,
config: &crate::pg_formatter::PgFormatterConfig,
) -> FormatResult {
let formatted = crate::pg_format::format(source, config);
match validate_format(source, &formatted) {
Ok(()) => FormatResult::ok(formatted),
Err(e) => FormatResult::with_error(formatted, e),
}
}
/// Parses `source` and counts the tokens in the resulting tree that are not
/// trivia. Nodes are not counted.
#[must_use]
pub fn count_tokens(source: &str) -> usize {
    let parse = banshee_parser::parse(source);
    parse
        .syntax()
        .descendants_with_tokens()
        .filter(|element| {
            matches!(
                element,
                cstree::util::NodeOrToken::Token(token) if !token.kind().is_trivia()
            )
        })
        .count()
}
/// Formats `source`, then formats that output again, and compares the two.
///
/// # Errors
///
/// Returns `(first_pass, second_pass)` when the second formatting pass
/// changed the output of the first.
pub fn check_idempotent(source: &str, config: &FormatConfig) -> Result<String, (String, String)> {
    let first_pass = crate::format::format(source, config);
    let second_pass = crate::format::format(&first_pass, config);
    if first_pass != second_pass {
        return Err((first_pass, second_pass));
    }
    Ok(first_pass)
}
/// Collected results of every check run by `validate_comprehensive`.
#[derive(Debug, Clone)]
pub struct ValidationReport {
/// The SQL text that was given to the formatter.
pub original: String,
/// The text the formatter produced.
pub formatted: String,
/// `true` when neither the original nor the formatted text had parse errors.
pub both_parse: bool,
/// `true` when both texts parsed and their normalized token strings matched.
pub semantically_equal: bool,
/// `true` when formatting the formatted output again left it unchanged.
pub is_idempotent: bool,
/// Number of non-trivia tokens counted in the original text.
pub original_tokens: usize,
/// Number of non-trivia tokens counted in the formatted text.
pub formatted_tokens: usize,
/// Every validation error encountered while building the report.
pub errors: Vec<FormatError>,
}
impl ValidationReport {
    /// Returns `true` only when every recorded check passed and no error was
    /// collected.
    #[must_use]
    pub fn is_valid(&self) -> bool {
        let checks_passed = self.both_parse && self.semantically_equal && self.is_idempotent;
        checks_passed && self.errors.is_empty()
    }
}
/// Formats `source` and runs every available check on the result, collecting
/// the outcomes into a [`ValidationReport`] instead of stopping at the first
/// failure.
///
/// Checks performed:
/// - both the original and the formatted text parse without errors;
/// - when both parse, their normalized token strings are equal (otherwise
///   `semantically_equal` is reported as `false` without comparing);
/// - formatting is idempotent according to [`check_idempotent`];
/// - non-trivia token counts of both texts are recorded.
#[must_use]
pub fn validate_comprehensive(source: &str, config: &FormatConfig) -> ValidationReport {
    let formatted = crate::format::format(source, config);
    let mut errors = Vec::new();

    let orig_parse = banshee_parser::parse(source);
    let fmt_parse = banshee_parser::parse(&formatted);
    let orig_ok = orig_parse.errors().is_empty();
    let fmt_ok = fmt_parse.errors().is_empty();
    let both_parse = orig_ok && fmt_ok;

    if !orig_ok {
        errors.push(FormatError::InvalidInput(orig_parse.errors().to_vec()));
    }
    if !fmt_ok {
        errors.push(FormatError::FormatterBrokeQuery(
            fmt_parse.errors().to_vec(),
        ));
    }

    // A token-level comparison is only meaningful when both sides parsed.
    let semantically_equal = both_parse && {
        let expected = normalize(&orig_parse);
        let actual = normalize(&fmt_parse);
        let matches = expected == actual;
        if !matches {
            // The strings are not needed afterwards, so move them into the
            // error rather than cloning.
            errors.push(FormatError::SemanticMismatch { expected, actual });
        }
        matches
    };

    let is_idempotent = check_idempotent(source, config).is_ok();
    let original_tokens = count_tokens(source);
    let formatted_tokens = count_tokens(&formatted);

    ValidationReport {
        original: source.to_string(),
        formatted,
        both_parse,
        semantically_equal,
        is_idempotent,
        original_tokens,
        formatted_tokens,
        errors,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `crate::format_sqlstyle` on `sql` and asserts that the output
    /// validates against the input.
    #[track_caller]
    fn assert_sqlstyle_validates(sql: &str) {
        let formatted = crate::format_sqlstyle(sql);
        assert!(validate_format(sql, &formatted).is_ok());
    }

    /// Parses `sql` and returns its normalized token string.
    fn normalized(sql: &str) -> String {
        normalize(&banshee_parser::parse(sql))
    }

    #[test]
    fn test_validate_simple_select() {
        assert_sqlstyle_validates("SELECT id FROM users");
    }

    #[test]
    fn test_validate_with_where() {
        assert_sqlstyle_validates("SELECT id, name FROM users WHERE active = true");
    }

    #[test]
    fn test_validate_join() {
        assert_sqlstyle_validates(
            "SELECT u.id, o.amount FROM users u LEFT JOIN orders o ON u.id = o.user_id",
        );
    }

    #[test]
    fn test_validate_cte() {
        assert_sqlstyle_validates(
            "WITH active AS (SELECT * FROM users WHERE active) SELECT * FROM active",
        );
    }

    #[test]
    fn test_validate_invalid_input() {
        // The same unparsable text is used on both sides; the input check
        // must be the one that fires.
        let broken = "SELECT FROM";
        let outcome = validate_format(broken, broken);
        assert!(matches!(outcome, Err(FormatError::InvalidInput(_))));
    }

    #[test]
    fn test_normalize_ignores_whitespace() {
        let compact = normalized("SELECT id FROM users");
        let spread_out = normalized("SELECT\n id\n FROM\n users");
        assert_eq!(compact, spread_out);
    }

    #[test]
    fn test_normalize_ignores_case() {
        let upper = normalized("SELECT id FROM users");
        let lower = normalized("select id from users");
        assert_eq!(upper, lower);
    }

    #[test]
    fn test_count_tokens() {
        let count = count_tokens("SELECT id, name FROM users");
        assert!((5..=7).contains(&count));
    }

    #[test]
    fn test_idempotent() {
        let config = FormatConfig::sqlstyle();
        assert!(check_idempotent("SELECT id FROM users", &config).is_ok());
    }

    #[test]
    fn test_format_validated() {
        let config = FormatConfig::sqlstyle();
        let result = format_validated("SELECT id, name FROM users WHERE active = true", &config);
        assert!(result.is_valid);
        assert!(result.errors.is_empty());
    }

    #[test]
    fn test_comprehensive_validation() {
        let config = FormatConfig::sqlstyle();
        let report = validate_comprehensive("SELECT id, name FROM users", &config);
        assert!(report.is_valid());
        assert!(report.both_parse);
        assert!(report.semantically_equal);
        assert!(report.is_idempotent);
    }

    #[test]
    fn test_format_result_into_result() {
        let success = FormatResult::ok("SELECT id FROM users".to_string());
        assert!(success.into_result().is_ok());

        let mismatch = FormatError::SemanticMismatch {
            expected: "a".to_string(),
            actual: "b".to_string(),
        };
        let failure = FormatResult::with_error("broken".to_string(), mismatch);
        assert!(failure.into_result().is_err());
    }
}