use banshee_syntax::{Parse, SyntaxKind, SyntaxNode, SyntaxToken};
use crate::pg_formatter::PgFormatterConfig;
pub(crate) mod expressions;
pub(crate) mod statements;
pub(crate) use expressions::*;
pub(crate) use statements::*;
#[must_use]
pub fn format(source: &str, config: &PgFormatterConfig) -> String {
let parse = banshee_parser::parse(source);
format_parse(&parse, config)
}
pub fn format_parse(parse: &Parse, config: &PgFormatterConfig) -> String {
let mut printer = PgPrinter::new(config.clone());
let root = parse.syntax();
format_node(&root, &mut printer);
printer.finish()
}
#[derive(Debug)]
pub struct PgPrinter {
config: PgFormatterConfig,
buffer: String,
line: usize,
column: usize,
indent_level: usize,
at_line_start: bool,
}
impl PgPrinter {
pub fn new(config: PgFormatterConfig) -> Self {
Self {
config,
buffer: String::new(),
line: 0,
column: 0,
indent_level: 0,
at_line_start: true,
}
}
pub fn config(&self) -> &PgFormatterConfig {
&self.config
}
pub fn finish(mut self) -> String {
if self.config.no_extra_line {
while self.buffer.ends_with('\n') {
self.buffer.pop();
}
} else {
if !self.buffer.ends_with('\n') {
self.buffer.push('\n');
}
}
self.buffer
}
fn indent_str(&self) -> String {
if self.config.use_tabs {
"\t".to_string()
} else {
" ".repeat(self.config.spaces)
}
}
fn write_indent(&mut self) {
let indent = self.indent_str().repeat(self.indent_level);
self.buffer.push_str(&indent);
self.column += indent.len();
}
pub fn write(&mut self, text: &str) {
if text.is_empty() {
return;
}
if self.at_line_start && !text.starts_with('\n') {
self.write_indent();
self.at_line_start = false;
}
for ch in text.chars() {
if ch == '\n' {
self.buffer.push(ch);
self.line += 1;
self.column = 0;
self.at_line_start = true;
} else {
if self.at_line_start {
self.write_indent();
self.at_line_start = false;
}
self.buffer.push(ch);
self.column += 1;
}
}
}
pub fn write_raw(&mut self, text: &str) {
for ch in text.chars() {
if ch == '\n' {
self.buffer.push(ch);
self.line += 1;
self.column = 0;
self.at_line_start = true;
} else {
self.buffer.push(ch);
self.column += 1;
self.at_line_start = false;
}
}
}
pub fn newline(&mut self) {
self.buffer.push('\n');
self.line += 1;
self.column = 0;
self.at_line_start = true;
}
pub fn space(&mut self) {
if self.at_line_start {
self.write_indent();
self.at_line_start = false;
}
self.buffer.push(' ');
self.column += 1;
}
pub fn indent(&mut self) {
self.indent_level += 1;
}
pub fn dedent(&mut self) {
self.indent_level = self.indent_level.saturating_sub(1);
}
pub fn ensure_newline(&mut self) {
if !self.at_line_start {
self.newline();
}
}
pub fn ensure_blank_line(&mut self) {
if !self.at_line_start {
self.newline();
}
self.newline();
}
pub fn write_keyword(&mut self, keyword: &str) {
let transformed = self.config.keyword_case.apply(keyword);
self.write(&transformed);
}
pub fn write_function(&mut self, name: &str) {
let transformed = self.config.function_case.apply(name);
self.write(&transformed);
}
pub fn write_type(&mut self, name: &str) {
let transformed = self.config.type_case.apply(name);
self.write(&transformed);
}
pub fn write_identifier(&mut self, ident: &str) {
self.write(ident);
}
pub fn write_comma(&mut self) {
self.write(",");
if self.config.comma_break {
self.newline();
}
}
pub fn write_operator(&mut self, op: &str) {
self.space();
self.write(op);
self.space();
}
pub fn write_operator_compact(&mut self, op: &str) {
self.write(op);
}
pub fn should_remove_comments(&self) -> bool {
self.config.no_comment
}
pub fn no_space_function(&self) -> bool {
self.config.no_space_function
}
pub fn comma_start(&self) -> bool {
self.config.comma_start
}
}
pub(crate) fn format_node(node: &SyntaxNode, printer: &mut PgPrinter) {
match node.kind() {
SyntaxKind::SOURCE_FILE => format_source_file(node, printer),
SyntaxKind::SELECT_STMT => format_select(node, printer),
SyntaxKind::INSERT_STMT => format_insert(node, printer),
SyntaxKind::UPDATE_STMT => format_update(node, printer),
SyntaxKind::DELETE_STMT => format_delete(node, printer),
SyntaxKind::CREATE_TABLE_STMT => format_create_table(node, printer),
SyntaxKind::CREATE_INDEX_STMT => format_create_index(node, printer),
SyntaxKind::ALTER_STMT => format_alter(node, printer),
SyntaxKind::ERROR => format_error_node(node, printer),
_ => format_children(node, printer),
}
}
pub(crate) fn format_source_file(node: &SyntaxNode, printer: &mut PgPrinter) {
let mut first_stmt = true;
for child in node.children() {
if child.kind() == SyntaxKind::WHITESPACE || child.kind() == SyntaxKind::NEWLINE {
continue;
}
if !first_stmt && !printer.config.no_grouping {
printer.ensure_blank_line();
} else if !first_stmt {
printer.ensure_newline();
}
format_node(child, printer);
first_stmt = false;
}
printer.ensure_newline();
}
pub(crate) fn format_children(node: &SyntaxNode, printer: &mut PgPrinter) {
for element in node.children_with_tokens() {
match element {
cstree::util::NodeOrToken::Node(child) => {
format_node(child, printer);
}
cstree::util::NodeOrToken::Token(token) => {
format_token(token, printer);
}
}
}
}
pub(crate) fn format_token(token: &SyntaxToken, printer: &mut PgPrinter) {
let kind = token.kind();
let text = token.text();
match kind {
SyntaxKind::WHITESPACE | SyntaxKind::NEWLINE => {
}
SyntaxKind::LINE_COMMENT | SyntaxKind::BLOCK_COMMENT => {
if !printer.should_remove_comments() {
printer.write(text);
}
}
k if k.is_keyword() => {
if needs_space_before(k) {
printer.space();
}
printer.write_keyword(text);
if needs_space_after(k) {
printer.space();
}
}
SyntaxKind::IDENT => {
printer.write_identifier(text);
}
SyntaxKind::QUOTED_IDENT => {
printer.write(text);
}
SyntaxKind::INTEGER
| SyntaxKind::FLOAT
| SyntaxKind::STRING
| SyntaxKind::DOLLAR_STRING
| SyntaxKind::BIT_STRING
| SyntaxKind::HEX_STRING
| SyntaxKind::PARAM => {
printer.write(text);
}
SyntaxKind::SEMICOLON => {
printer.write(";");
}
SyntaxKind::COMMA => {
printer.write_comma();
}
SyntaxKind::DOT => {
printer.write(".");
}
SyntaxKind::COLON => {
printer.write(":");
}
SyntaxKind::DOUBLE_COLON => {
printer.write("::");
}
SyntaxKind::L_PAREN => {
printer.write("(");
}
SyntaxKind::R_PAREN => {
printer.write(")");
}
SyntaxKind::L_BRACKET => {
printer.write("[");
}
SyntaxKind::R_BRACKET => {
printer.write("]");
}
SyntaxKind::EQ
| SyntaxKind::NE
| SyntaxKind::LT
| SyntaxKind::LE
| SyntaxKind::GT
| SyntaxKind::GE
| SyntaxKind::PLUS
| SyntaxKind::MINUS
| SyntaxKind::STAR
| SyntaxKind::SLASH
| SyntaxKind::PERCENT
| SyntaxKind::CARET
| SyntaxKind::PIPE_PIPE => {
printer.write_operator(text);
}
SyntaxKind::ARROW | SyntaxKind::ARROW_TEXT => {
printer.write_operator_compact(text);
}
SyntaxKind::HASH_ARROW | SyntaxKind::HASH_ARROW_TEXT => {
printer.write_operator_compact(text);
}
SyntaxKind::AT_GT
| SyntaxKind::LT_AT
| SyntaxKind::QUESTION
| SyntaxKind::QUESTION_PIPE
| SyntaxKind::QUESTION_AMP
| SyntaxKind::HASH_MINUS
| SyntaxKind::AT_QUESTION
| SyntaxKind::AT_AT => {
printer.write_operator(text);
}
_ => {
printer.write(text);
}
}
}
pub(crate) fn format_error_node(node: &SyntaxNode, printer: &mut PgPrinter) {
for element in node.children_with_tokens() {
match element {
cstree::util::NodeOrToken::Node(child) => {
format_error_node(child, printer);
}
cstree::util::NodeOrToken::Token(token) => {
printer.write_raw(token.text());
}
}
}
}
pub(crate) fn find_child(node: &SyntaxNode, kind: SyntaxKind) -> Option<SyntaxNode> {
node.children().find(|c| c.kind() == kind).cloned()
}
pub(crate) fn has_token(node: &SyntaxNode, kind: SyntaxKind) -> bool {
for element in node.children_with_tokens() {
if let cstree::util::NodeOrToken::Token(token) = element
&& token.kind() == kind
{
return true;
}
}
false
}
pub(crate) fn needs_space_before(kind: SyntaxKind) -> bool {
matches!(
kind,
SyntaxKind::AND_KW
| SyntaxKind::OR_KW
| SyntaxKind::NOT_KW
| SyntaxKind::IS_KW
| SyntaxKind::IN_KW
| SyntaxKind::LIKE_KW
| SyntaxKind::ILIKE_KW
| SyntaxKind::BETWEEN_KW
| SyntaxKind::AS_KW
| SyntaxKind::ASC_KW
| SyntaxKind::DESC_KW
)
}
pub(crate) fn needs_space_after(kind: SyntaxKind) -> bool {
matches!(
kind,
SyntaxKind::AND_KW
| SyntaxKind::OR_KW
| SyntaxKind::NOT_KW
| SyntaxKind::IS_KW
| SyntaxKind::IN_KW
| SyntaxKind::LIKE_KW
| SyntaxKind::ILIKE_KW
| SyntaxKind::AS_KW
)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::pg_formatter::CaseOption;
#[test]
fn test_format_simple_select() {
let sql = "select id, name from users";
let config = PgFormatterConfig::default();
let formatted = format(sql, &config);
assert!(formatted.contains("SELECT"));
assert!(formatted.contains("FROM"));
}
#[test]
fn test_format_with_function_case() {
let sql = "SELECT COUNT(*) FROM users";
let config = PgFormatterConfig::default().with_function_case(CaseOption::Lower);
let formatted = format(sql, &config);
eprintln!("Formatted output: {:?}", formatted);
assert!(formatted.contains("COUNT") || formatted.contains("count"));
}
#[test]
fn test_format_with_keyword_case_lower() {
let sql = "SELECT id FROM users";
let config = PgFormatterConfig::default().with_keyword_case(CaseOption::Lower);
let formatted = format(sql, &config);
assert!(formatted.contains("select"));
assert!(formatted.contains("from"));
}
#[test]
fn test_format_no_comment() {
let sql = "SELECT id -- comment\nFROM users";
let config = PgFormatterConfig::default().with_no_comment(true);
let formatted = format(sql, &config);
assert!(!formatted.contains("comment"));
}
#[test]
fn test_format_no_extra_line() {
let sql = "SELECT id FROM users";
let config = PgFormatterConfig {
no_extra_line: true,
..Default::default()
};
let formatted = format(sql, &config);
assert!(!formatted.ends_with('\n'));
}
#[test]
fn test_format_comma_start() {
let sql = "SELECT id, name, email FROM users";
let config = PgFormatterConfig::default().with_comma_start(true);
let formatted = format(sql, &config);
assert!(formatted.contains(", "));
}
#[test]
fn test_format_type_case() {
let sql = "SELECT id::INTEGER FROM users";
let config = PgFormatterConfig::default().with_type_case(CaseOption::Upper);
let formatted = format(sql, &config);
assert!(formatted.contains("INTEGER"));
}
}