use crate::error::AamlError;
use crate::pipeline::parser::AstNode;
/// Settings that control how a [`Formatter`] lays out its output.
#[derive(Debug, Clone)]
pub struct FormattingOptions {
    /// Spaces per indentation level; unused when `use_tabs` is `true`.
    pub indent_size: usize,
    /// Indent with one tab character per level instead of spaces.
    pub use_tabs: bool,
    /// Preferred maximum line width used when deciding whether to wrap a
    /// construct; `0` disables wrapping.
    pub line_width: usize,
    /// Whether assignment keys should be emitted in sorted order
    /// (support is implementation-specific).
    pub sort_keys: bool,
    /// Ensure non-empty output ends with a newline character.
    pub trailing_newline: bool,
    /// Whether blank-line separation between sections should be emitted.
    pub preserve_blank_lines: bool,
}

impl Default for FormattingOptions {
    /// Four-space indentation, a 100-column width, unsorted keys, and both a
    /// trailing newline and blank-line separation enabled.
    fn default() -> Self {
        FormattingOptions {
            use_tabs: false,
            indent_size: 4,
            line_width: 100,
            sort_keys: false,
            preserve_blank_lines: true,
            trailing_newline: true,
        }
    }
}
/// A span of source lines selected for formatting; both ends are inclusive.
#[derive(Debug, Clone, Copy)]
pub struct FormatRange {
    /// First line of the span (inclusive).
    pub start_line: usize,
    /// Last line of the span (inclusive).
    pub end_line: usize,
}
/// Pretty-printing back end that turns parsed [`AstNode`]s back into text.
///
/// The `Send + Sync` supertraits let one formatter instance be shared
/// across threads.
pub trait Formatter: Send + Sync {
    /// Renders a complete document from its top-level nodes.
    ///
    /// # Errors
    /// Returns an [`AamlError`] if any node cannot be formatted.
    fn format_document(
        &self,
        nodes: &[AstNode],
        options: &FormattingOptions,
    ) -> Result<String, AamlError>;

    /// Renders the nodes that fall within `range`.
    ///
    /// # Errors
    /// Returns an [`AamlError`] if any selected node cannot be formatted.
    fn format_range(
        &self,
        nodes: &[AstNode],
        range: FormatRange,
        options: &FormattingOptions,
    ) -> Result<String, AamlError>;

    /// Renders a single node at the given nesting depth.
    ///
    /// # Errors
    /// Returns an [`AamlError`] if the node cannot be formatted.
    fn format_node(
        &self,
        node: &AstNode,
        indent_level: usize,
        options: &FormattingOptions,
    ) -> Result<String, AamlError>;

    /// Normalizes comment spacing in already-rendered text.
    ///
    /// # Errors
    /// Returns an [`AamlError`] if the text cannot be processed.
    fn normalize_comments(
        &self,
        content: &str,
        options: &FormattingOptions,
    ) -> Result<String, AamlError>;

    /// Normalizes whitespace in already-rendered text.
    ///
    /// # Errors
    /// Returns an [`AamlError`] if the text cannot be processed.
    fn normalize_whitespace(&self, content: &str) -> Result<String, AamlError>;
}
/// The built-in [`Formatter`] implementation.
///
/// A stateless unit struct: every formatting decision comes from the
/// [`FormattingOptions`] passed to each call.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultFormatter;

impl DefaultFormatter {
    /// Creates a formatter; identical to [`DefaultFormatter::default`].
    pub fn new() -> Self {
        Self
    }

    /// Builds the leading whitespace for `level` nesting levels: one tab per
    /// level when `options.use_tabs` is set, otherwise
    /// `level * options.indent_size` spaces.
    fn create_indent(level: usize, options: &FormattingOptions) -> String {
        if options.use_tabs {
            "\t".repeat(level)
        } else {
            " ".repeat(level * options.indent_size)
        }
    }

    /// Whether `node` is an `import` or `derive` directive; the document
    /// formatter moves these ahead of all other nodes.
    fn is_hoistable(node: &AstNode) -> bool {
        if let AstNode::Directive { name, .. } = node {
            matches!(name.as_ref(), "import" | "derive")
        } else {
            false
        }
    }

    /// Renders `key = value` with surrounding whitespace of both sides trimmed.
    fn format_assignment(
        key: &str,
        value: &str,
        indent_level: usize,
        options: &FormattingOptions,
    ) -> String {
        let indent = Self::create_indent(indent_level, options);
        format!("{}{} = {}", indent, key.trim(), value.trim())
    }

    /// Renders `@type name = alias`, or `@type args` when there is no `=`.
    fn format_type_alias(args: &str, indent_level: usize, options: &FormattingOptions) -> String {
        let indent = Self::create_indent(indent_level, options);
        if let Some((name, alias)) = args.split_once('=') {
            format!("{}@type {} = {}", indent, name.trim(), alias.trim())
        } else {
            format!("{}@type {}", indent, args.trim())
        }
    }

    /// Renders `@schema Name { field: Type, ... }`.
    ///
    /// Fields are separated by `,` or newlines at bracket depth zero and each
    /// `key: type` pair is re-spaced. The schema is kept on one line when it
    /// fits in `options.line_width` characters (or the width is `0`);
    /// otherwise every field goes on its own line one level deeper. An empty
    /// body renders as `{}`. Arguments without `{` fall back to plain
    /// directive formatting.
    fn format_schema(args: &str, indent_level: usize, options: &FormattingOptions) -> String {
        let Some((name_part, body_part)) = args.split_once('{') else {
            return Self::format_directive("schema", args, indent_level, options);
        };
        let indent = Self::create_indent(indent_level, options);
        let schema_name = name_part.trim();

        // Trim first so whitespace after the closing brace cannot hide it, and
        // remove exactly one `}` so braces belonging to nested fields survive.
        let body_part = body_part.trim();
        let body = body_part.strip_suffix('}').unwrap_or(body_part).trim();

        let fields: Vec<String> = Self::split_schema_fields(body)
            .into_iter()
            .map(|field| match field.split_once(':') {
                Some((k, v)) => format!("{}: {}", k.trim(), v.trim()),
                None => field.to_string(),
            })
            .collect();

        if fields.is_empty() {
            return format!("{}@schema {} {{}}", indent, schema_name);
        }

        let single_line = format!(
            "{}@schema {} {{ {} }}",
            indent,
            schema_name,
            fields.join(", ")
        );
        // Width is a character budget, so count chars rather than UTF-8 bytes.
        if options.line_width == 0 || single_line.chars().count() <= options.line_width {
            return single_line;
        }

        let inner_indent = Self::create_indent(indent_level + 1, options);
        let mut lines = Vec::with_capacity(fields.len() + 2);
        lines.push(format!("{}@schema {} {{", indent, schema_name));
        lines.extend(fields.iter().map(|field| format!("{}{}", inner_indent, field)));
        lines.push(format!("{}}}", indent));
        lines.join("\n")
    }

    /// Splits a schema body into trimmed, non-empty fields on `,` or newline,
    /// ignoring separators nested inside `()`, `[]` or `{}` so a field such as
    /// `pos: tuple(int, int)` stays intact.
    fn split_schema_fields(body: &str) -> Vec<&str> {
        let mut fields = Vec::new();
        let mut depth: usize = 0;
        let mut start = 0;
        for (idx, ch) in body.char_indices() {
            match ch {
                '(' | '[' | '{' => depth += 1,
                // Saturate so a stray closer cannot underflow and wedge splitting.
                ')' | ']' | '}' => depth = depth.saturating_sub(1),
                ',' | '\n' if depth == 0 => {
                    fields.push(&body[start..idx]);
                    start = idx + ch.len_utf8();
                }
                _ => {}
            }
        }
        fields.push(&body[start..]);
        fields
            .into_iter()
            .map(str::trim)
            .filter(|field| !field.is_empty())
            .collect()
    }

    /// Renders a generic directive as `@name` or `@name args`.
    fn format_directive(
        name: &str,
        args: &str,
        indent_level: usize,
        options: &FormattingOptions,
    ) -> String {
        let indent = Self::create_indent(indent_level, options);
        if args.trim().is_empty() {
            format!("{}@{}", indent, name.trim())
        } else {
            format!("{}@{} {}", indent, name.trim(), args.trim())
        }
    }
}
impl Formatter for DefaultFormatter {
    /// Formats a whole document.
    ///
    /// Hoistable directives (`import` / `derive`) come first, in their
    /// original relative order, followed by every other node. When
    /// `options.sort_keys` is set, each run of consecutive assignments in the
    /// body is stably sorted by key; directives keep their positions so no
    /// key moves across one. A blank line separates header and body when
    /// `options.preserve_blank_lines` is set and both are non-empty.
    ///
    /// # Errors
    /// Propagates the first error returned by `format_node`.
    fn format_document(
        &self,
        nodes: &[AstNode],
        options: &FormattingOptions,
    ) -> Result<String, AamlError> {
        // True for key/value assignment nodes.
        fn is_assignment(node: &AstNode) -> bool {
            matches!(node, AstNode::Assignment { .. })
        }

        // Orders two assignments by key; any other pairing compares equal.
        fn compare_keys(a: &AstNode, b: &AstNode) -> std::cmp::Ordering {
            match (a, b) {
                (
                    AstNode::Assignment { key: left, .. },
                    AstNode::Assignment { key: right, .. },
                ) => {
                    let left: &str = left;
                    let right: &str = right;
                    left.cmp(right)
                }
                _ => std::cmp::Ordering::Equal,
            }
        }

        let (hoistable, mut others): (Vec<_>, Vec<_>) =
            nodes.iter().partition(|n| Self::is_hoistable(n));

        if options.sort_keys {
            let mut start = 0;
            while start < others.len() {
                let run = others[start..]
                    .iter()
                    .take_while(|n| is_assignment(n))
                    .count();
                if run == 0 {
                    start += 1;
                } else {
                    others[start..start + run].sort_by(|a, b| compare_keys(a, b));
                    start += run;
                }
            }
        }

        let mut header = hoistable
            .into_iter()
            .map(|node| self.format_node(node, 0, options))
            .collect::<Result<Vec<_>, _>>()?;
        let body = others
            .into_iter()
            .map(|node| self.format_node(node, 0, options))
            .collect::<Result<Vec<_>, _>>()?;

        if options.preserve_blank_lines && !header.is_empty() && !body.is_empty() {
            header.push(String::new());
        }
        let mut result = [header, body].concat().join("\n");
        if options.trailing_newline && !result.is_empty() && !result.ends_with('\n') {
            result.push('\n');
        }
        Ok(result)
    }

    /// Formats only the nodes whose `line()` lies within `range` (inclusive
    /// on both ends; a reversed range is accepted). The result is the
    /// replacement text for that span. Nodes outside the range are not
    /// emitted, because their original text is unavailable here.
    ///
    /// # Errors
    /// Propagates the first error returned by `format_node`.
    fn format_range(
        &self,
        nodes: &[AstNode],
        range: FormatRange,
        options: &FormattingOptions,
    ) -> Result<String, AamlError> {
        let (first, last) = if range.start_line <= range.end_line {
            (range.start_line, range.end_line)
        } else {
            (range.end_line, range.start_line)
        };
        let mut formatted = Vec::new();
        for node in nodes {
            let line = node.line();
            if line >= first && line <= last {
                formatted.push(self.format_node(node, 0, options)?);
            }
        }
        Ok(formatted.join("\n"))
    }

    /// Formats one node. Assignments render as `key = value`; `schema` and
    /// `type` directives get dedicated layouts; any other directive renders
    /// as `@name args`.
    fn format_node(
        &self,
        node: &AstNode,
        indent_level: usize,
        options: &FormattingOptions,
    ) -> Result<String, AamlError> {
        let formatted = match node {
            AstNode::Assignment { key, value, .. } => {
                Self::format_assignment(key, &value.to_string(), indent_level, options)
            }
            AstNode::Directive { name, args, .. } => match name.as_ref() {
                "schema" => Self::format_schema(args, indent_level, options),
                "type" => Self::format_type_alias(args, indent_level, options),
                _ => Self::format_directive(name.as_ref(), args, indent_level, options),
            },
        };
        Ok(formatted)
    }

    /// Normalizes `#` comments to `# text`.
    ///
    /// A comment is rewritten only when its `#` is outside a quoted string,
    /// is preceded by a space, is not at column 0, is not doubled (`##`) and
    /// has text after it.
    /// - Trailing comments become `code # text`, with exactly one space
    ///   before the `#`.
    /// - Indented full-line comments keep their indentation.
    ///
    /// Other lines pass through unchanged. A final newline in `content` is
    /// preserved.
    fn normalize_comments(
        &self,
        content: &str,
        _options: &FormattingOptions,
    ) -> Result<String, AamlError> {
        // Byte offset of the first `#` outside a `"…"` / `'…'` span, honouring
        // backslash escapes. An unterminated quote hides any later `#`, which
        // leaves the line untouched (the safe failure mode).
        fn comment_start(line: &str) -> Option<usize> {
            let mut quote: Option<char> = None;
            let mut escaped = false;
            for (idx, ch) in line.char_indices() {
                match quote {
                    Some(open) => {
                        if escaped {
                            escaped = false;
                        } else if ch == '\\' {
                            escaped = true;
                        } else if ch == open {
                            quote = None;
                        }
                    }
                    None => match ch {
                        '"' | '\'' => quote = Some(ch),
                        '#' => return Some(idx),
                        _ => {}
                    },
                }
            }
            None
        }

        let mut normalized = content
            .lines()
            .map(|line| {
                let Some(pos) = comment_start(line) else {
                    return line.to_string();
                };
                let (before, after) = (&line[..pos], &line[pos + 1..]);
                if pos == 0 || after.is_empty() || !before.ends_with(' ') || after.starts_with('#')
                {
                    return line.to_string();
                }
                let text = after.trim_start();
                let marker = if text.is_empty() {
                    "#".to_string()
                } else {
                    format!("# {}", text)
                };
                if before.trim().is_empty() {
                    // Full-line comment: keep its indentation.
                    format!("{}{}", before, marker)
                } else {
                    // Trailing comment: exactly one space between code and `#`.
                    format!("{} {}", before.trim_end(), marker)
                }
            })
            .collect::<Vec<_>>()
            .join("\n");
        // `lines()` drops the final terminator; restore it so a trailing
        // newline added by `format_document` is not lost.
        if content.ends_with('\n') {
            normalized.push('\n');
        }
        Ok(normalized)
    }

    /// Strips trailing whitespace (including `\r`) from every line, keeping a
    /// final newline if `content` had one.
    fn normalize_whitespace(&self, content: &str) -> Result<String, AamlError> {
        let mut normalized = content
            .lines()
            .map(str::trim_end)
            .collect::<Vec<_>>()
            .join("\n");
        if content.ends_with('\n') {
            normalized.push('\n');
        }
        Ok(normalized)
    }
}