#![allow(clippy::needless_borrow)]
pub mod expressions;
pub mod statements;
use banshee_syntax::{Parse, SyntaxKind, SyntaxNode, SyntaxToken};
use crate::config::FormatConfig;
use crate::printer::Printer;
use crate::rules;
/// Parses `source` with `banshee_parser` and returns it reformatted according
/// to `config`.
///
/// This is a convenience wrapper around [`format_parse`]; call that directly if
/// you already hold a [`Parse`].
#[must_use]
pub fn format(source: &str, config: &FormatConfig) -> String {
    format_parse(&banshee_parser::parse(source), config)
}
/// Formats an already-parsed tree and returns the rendered text.
///
/// The printer is built from a clone of `config`. Its river width is set to
/// `rules::standard_river_width()`. The whole tree is then walked starting at
/// the root node.
///
/// Marked `#[must_use]` like [`format`]: the returned string is the only
/// output, so discarding it is always a mistake.
#[must_use]
pub fn format_parse(parse: &Parse, config: &FormatConfig) -> String {
    let mut printer = Printer::new(config.clone());
    printer.set_river_width(rules::standard_river_width());
    format_node(&parse.syntax(), &mut printer);
    printer.finish()
}
fn format_node(node: &SyntaxNode, printer: &mut Printer) {
match node.kind() {
SyntaxKind::SOURCE_FILE => {
format_source_file(node, printer);
}
SyntaxKind::SELECT_STMT => {
statements::format_select(node, printer);
}
SyntaxKind::INSERT_STMT => {
statements::format_insert(node, printer);
}
SyntaxKind::UPDATE_STMT => {
statements::format_update(node, printer);
}
SyntaxKind::DELETE_STMT => {
statements::format_delete(node, printer);
}
SyntaxKind::CREATE_TABLE_STMT => {
statements::format_create_table(node, printer);
}
SyntaxKind::CREATE_INDEX_STMT => {
statements::format_create_index(node, printer);
}
SyntaxKind::ALTER_STMT => {
statements::format_alter(node, printer);
}
SyntaxKind::ERROR => {
format_error_node(node, printer);
}
_ => {
format_children(node, printer);
}
}
}
fn format_source_file(node: &SyntaxNode, printer: &mut Printer) {
let mut first_stmt = true;
for child in node.children() {
if child.kind() == SyntaxKind::WHITESPACE || child.kind() == SyntaxKind::NEWLINE {
continue;
}
if !first_stmt {
printer.ensure_blank_line();
}
format_node(&child, printer);
first_stmt = false;
}
printer.ensure_newline();
}
/// Generic fallback formatter: visits every child of `node` in source order.
///
/// Child tokens go to [`format_token`]; child nodes are re-dispatched through
/// [`format_node`], so nested constructs still reach their dedicated formatter.
pub fn format_children(node: &SyntaxNode, printer: &mut Printer) {
    use cstree::util::NodeOrToken;

    for element in node.children_with_tokens() {
        match element {
            NodeOrToken::Token(tok) => {
                format_token(&tok, printer);
            }
            NodeOrToken::Node(sub) => {
                format_node(&sub, printer);
            }
        }
    }
}
pub fn format_token(token: &SyntaxToken, printer: &mut Printer) {
let kind = token.kind();
let text = token.text();
match kind {
SyntaxKind::WHITESPACE => {
}
SyntaxKind::NEWLINE => {
}
SyntaxKind::LINE_COMMENT | SyntaxKind::BLOCK_COMMENT => {
printer.write(text);
}
k if k.is_keyword() => {
if rules::keywords::needs_space_before(k) {
printer.space();
}
printer.write_keyword(text);
if rules::keywords::needs_space_after(k) {
printer.space();
}
}
SyntaxKind::IDENT => {
printer.write_identifier(text);
}
SyntaxKind::QUOTED_IDENT => {
printer.write(text);
}
SyntaxKind::INTEGER
| SyntaxKind::FLOAT
| SyntaxKind::STRING
| SyntaxKind::DOLLAR_STRING
| SyntaxKind::BIT_STRING
| SyntaxKind::HEX_STRING
| SyntaxKind::PARAM => {
printer.write(text);
}
SyntaxKind::SEMICOLON => {
printer.write(";");
}
SyntaxKind::COMMA => {
printer.write_comma();
}
SyntaxKind::DOT => {
printer.write(".");
}
SyntaxKind::COLON => {
printer.write(":");
}
SyntaxKind::DOUBLE_COLON => {
printer.write("::");
}
SyntaxKind::L_PAREN => {
printer.write("(");
if printer.config().parentheses_spacing {
printer.space();
}
}
SyntaxKind::R_PAREN => {
if printer.config().parentheses_spacing {
printer.space();
}
printer.write(")");
}
SyntaxKind::L_BRACKET => {
printer.write("[");
}
SyntaxKind::R_BRACKET => {
printer.write("]");
}
SyntaxKind::L_BRACE => {
printer.write("{");
}
SyntaxKind::R_BRACE => {
printer.write("}");
}
SyntaxKind::EQ
| SyntaxKind::NE
| SyntaxKind::LT
| SyntaxKind::LE
| SyntaxKind::GT
| SyntaxKind::GE
| SyntaxKind::PLUS
| SyntaxKind::MINUS
| SyntaxKind::STAR
| SyntaxKind::SLASH
| SyntaxKind::PERCENT
| SyntaxKind::CARET
| SyntaxKind::PIPE_PIPE => {
printer.write_operator(text);
}
SyntaxKind::ARROW | SyntaxKind::ARROW_TEXT => {
printer.write_operator_compact(text);
}
SyntaxKind::HASH_ARROW | SyntaxKind::HASH_ARROW_TEXT => {
printer.write_operator_compact(text);
}
SyntaxKind::AT_GT
| SyntaxKind::LT_AT
| SyntaxKind::QUESTION
| SyntaxKind::QUESTION_PIPE
| SyntaxKind::QUESTION_AMP
| SyntaxKind::HASH_MINUS
| SyntaxKind::AT_QUESTION
| SyntaxKind::AT_AT => {
printer.write_operator(text);
}
_ => {
printer.write(text);
}
}
}
/// Emits an `ERROR` subtree without reformatting it.
///
/// Every token in the subtree, including whitespace, newlines and comments, is
/// passed to `Printer::write_raw` in source order. Nested nodes are handled by
/// recursing into this same function rather than through [`format_node`], so
/// no formatting rules apply anywhere inside the error region.
fn format_error_node(node: &SyntaxNode, printer: &mut Printer) {
    use cstree::util::NodeOrToken;

    for element in node.children_with_tokens() {
        match element {
            NodeOrToken::Token(raw) => {
                printer.write_raw(raw.text());
            }
            NodeOrToken::Node(nested) => {
                format_error_node(&nested, printer);
            }
        }
    }
}
/// Returns an owned clone of the first direct child node of `node` whose kind
/// equals `kind`, or `None` if there is none.
///
/// Only immediate child nodes are inspected; tokens and deeper descendants are
/// not.
pub fn find_child(node: &SyntaxNode, kind: SyntaxKind) -> Option<SyntaxNode> {
    for candidate in node.children() {
        if candidate.kind() == kind {
            return Some(candidate.clone());
        }
    }
    None
}
/// Returns an owned clone of the first direct child token of `node` whose kind
/// equals `kind`, or `None` if there is none.
///
/// Child nodes are skipped, not searched; only immediate tokens are considered.
pub fn find_token(node: &SyntaxNode, kind: SyntaxKind) -> Option<SyntaxToken> {
    node.children_with_tokens().find_map(|element| match element {
        cstree::util::NodeOrToken::Token(tok) if tok.kind() == kind => Some(tok.clone()),
        _ => None,
    })
}
/// Collects owned clones of every direct child node of `node` whose kind equals
/// `kind`, preserving source order. Returns an empty `Vec` when nothing matches.
pub fn children_of_kind(node: &SyntaxNode, kind: SyntaxKind) -> Vec<SyntaxNode> {
    let mut matching = Vec::new();
    for candidate in node.children() {
        if candidate.kind() == kind {
            matching.push(candidate.clone());
        }
    }
    matching
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Lower-case keywords in the input appear upper-cased in the compact output.
    #[test]
    fn test_format_simple_select() {
        let output = format("select id from users", &FormatConfig::compact());
        for keyword in ["SELECT", "FROM"] {
            assert!(output.contains(keyword));
        }
    }

    /// Malformed input must still produce output. Whether `???` survives
    /// verbatim depends on how the parser recovers, so either marker is
    /// accepted. NOTE(review): this assertion is weak; tighten it once the
    /// parser's error-recovery tree shape is pinned down.
    #[test]
    fn test_format_preserves_error_nodes() {
        let output = format("SELECT ??? FROM", &FormatConfig::compact());
        let kept_error_text = output.contains("???");
        assert!(kept_error_text || output.contains("SELECT"));
    }
}