use std::iter::Peekable;
use crate::ast::{command_name, environment_name};
use crate::parser::parse;
use crate::semantic::{ArgKind, ArgSpec, Signatures, scan_definitions};
use crate::syntax::{SyntaxElement, SyntaxKind, SyntaxNode, SyntaxToken};
use super::context::FormatContext;
use super::ir::Ir;
use super::printer::Printer;
use super::style::{FormatStyle, WrapMode};
/// Reasons the formatter refuses to produce output.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FormatError {
/// The parser reported `count` diagnostics for the input text.
ParseErrors { count: usize },
/// The syntax tree contains a token the formatter rejects; `kind` is the
/// token's kind and `snippet` is its text.
UnsupportedConstruct { kind: SyntaxKind, snippet: String },
}
// User-facing messages for each formatter failure.
impl std::fmt::Display for FormatError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnsupportedConstruct { kind, snippet } => f.write_fmt(format_args!(
                "unsupported construct for formatter: {kind:?} near {snippet:?}"
            )),
            Self::ParseErrors { count } => f.write_fmt(format_args!(
                "input contains {count} parser diagnostic(s); formatter only supports parseable input"
            )),
        }
    }
}
// Marker impl: `FormatError` relies on the default `std::error::Error` methods.
impl std::error::Error for FormatError {}
/// Formats `input` with the default [`FormatStyle`].
///
/// # Errors
///
/// Returns whatever [`format_with_style`] returns for the same input.
pub fn format(input: &str) -> Result<String, FormatError> {
    let default_style = FormatStyle::default();
    format_with_style(input, default_style)
}
/// Parses `input` and formats the resulting tree with `style`.
///
/// # Errors
///
/// Returns [`FormatError::ParseErrors`] when the parser reported any
/// diagnostics, otherwise forwards the result of [`format_node`].
pub fn format_with_style(input: &str, style: FormatStyle) -> Result<String, FormatError> {
    let parsed = parse(input);
    match parsed.errors.len() {
        0 => format_node(&parsed.syntax(), style),
        count => Err(FormatError::ParseErrors { count }),
    }
}
/// Formats an already-parsed tree rooted at `root`.
///
/// Trailing spaces, tabs, and line terminators are stripped from the printed
/// text; a non-empty result is then terminated by exactly one `\n`.
///
/// # Errors
///
/// Returns [`FormatError::UnsupportedConstruct`] when the tree contains a
/// token rejected by `validate_supported_tokens`.
pub fn format_node(root: &SyntaxNode, style: FormatStyle) -> Result<String, FormatError> {
    validate_supported_tokens(root)?;
    let mut text = format_root(root, FormatContext::new(style));
    let keep = text
        .trim_end_matches(|c: char| matches!(c, ' ' | '\t' | '\n' | '\r'))
        .len();
    text.truncate(keep);
    if keep > 0 {
        text.push('\n');
    }
    Ok(text)
}
fn validate_supported_tokens(root: &SyntaxNode) -> Result<(), FormatError> {
for element in root.descendants_with_tokens() {
let Some(token) = element.into_token() else {
continue;
};
if token.kind() == SyntaxKind::ERROR {
return Err(FormatError::UnsupportedConstruct {
kind: token.kind(),
snippet: token.text().to_string(),
});
}
}
Ok(())
}
/// Lowers `root` to IR and prints it with the style carried by `ctx`.
fn format_root(root: &SyntaxNode, ctx: FormatContext) -> String {
    // Definitions found in the document feed the signature lookup used while lowering.
    let definitions = scan_definitions(root);
    let lowered = lower_node(
        root,
        LowerCtx {
            wrap: ctx.style().wrap,
            signatures: Signatures::new(&definitions),
        },
    );
    let printer = Printer::new(ctx.style());
    printer.print(&lowered)
}
/// Read-only state passed by value to every lowering function.
#[derive(Clone, Copy)]
struct LowerCtx<'a> {
/// Wrap mode copied from the active format style; `WrapMode::Reflow` turns on
/// paragraph and prose-argument reflow in `lower_node`.
wrap: WrapMode,
/// Command/environment signature lookup, built from the definitions scanned
/// in the document being formatted.
signatures: Signatures<'a>,
}
/// Lowers a syntax node to IR, dispatching to a specialised routine when one
/// applies and otherwise lowering the node's children as a plain stream.
fn lower_node(node: &SyntaxNode, cx: LowerCtx<'_>) -> Ir {
    let reflow = cx.wrap == WrapMode::Reflow;
    match node.kind() {
        SyntaxKind::PARAGRAPH if reflow => lower_paragraph_reflow(node, cx),
        SyntaxKind::ENVIRONMENT if !has_verbatim_body(node) => lower_environment(node, cx),
        SyntaxKind::COMMAND if reflow && command_has_prose_arg(node, cx) => {
            lower_command(node, cx)
        }
        SyntaxKind::GROUP if spans_multiple_lines(node) => {
            lower_bracketed(node, SyntaxKind::L_BRACE, SyntaxKind::R_BRACE, cx)
        }
        SyntaxKind::OPTIONAL if spans_multiple_lines(node) => {
            lower_bracketed(node, SyntaxKind::L_BRACKET, SyntaxKind::R_BRACKET, cx)
        }
        // Everything else keeps its children in source order.
        _ => Ir::concat(lower_element_stream(node.children_with_tokens(), cx)),
    }
}
/// Reflows the direct children of a paragraph node.
fn lower_paragraph_reflow(node: &SyntaxNode, cx: LowerCtx<'_>) -> Ir {
    let children = node.children_with_tokens();
    reflow_elements(children, cx)
}
/// Lowers a stream of elements into reflowable prose.
///
/// The stream is cut into segments. Ordinary text is gathered into `Ir::fill`
/// runs; comments, line-break nodes, and child nodes whose IR contains a
/// forced break each terminate the current run. Segments are joined with a
/// hard line, or with an empty line where the source had a whitespace run
/// containing two or more newlines.
fn reflow_elements(elements: impl Iterator<Item = SyntaxElement>, cx: LowerCtx<'_>) -> Ir {
// `atom`: pieces seen since the last whitespace; they become one fill item.
let mut atom: Vec<Ir> = Vec::new();
// `run`: fill items collected for the segment currently being built.
let mut run: Vec<Ir> = Vec::new();
// `lines` and `seps` are parallel: `seps[i]` is the separator recorded when
// `lines[i]` was pushed. `seps[0]` is never emitted (see the join loop below).
let mut lines: Vec<Ir> = Vec::new();
let mut seps: Vec<Ir> = Vec::new();
// Separator to record with the next segment; reset to a hard line after each use.
let mut pending_sep: Ir = Ir::hard_line();
// Moves the accumulated atom pieces into `run` as a single concatenated item.
fn flush_atom(atom: &mut Vec<Ir>, run: &mut Vec<Ir>) {
if !atom.is_empty() {
run.push(Ir::concat(atom.drain(..)));
}
}
// Records `content` as a new segment together with the pending separator.
fn push_segment(content: Ir, lines: &mut Vec<Ir>, seps: &mut Vec<Ir>, pending_sep: &mut Ir) {
seps.push(std::mem::replace(pending_sep, Ir::hard_line()));
lines.push(content);
}
// Finishes the current run (if any) and records it as a fill segment.
fn end_line(
atom: &mut Vec<Ir>,
run: &mut Vec<Ir>,
lines: &mut Vec<Ir>,
seps: &mut Vec<Ir>,
pending_sep: &mut Ir,
) {
flush_atom(atom, run);
if !run.is_empty() {
push_segment(Ir::fill(run.drain(..)), lines, seps, pending_sep);
}
}
let mut iter = elements.peekable();
while let Some(element) = iter.next() {
match element {
// Whitespace/newline run: two or more newlines end the segment and request
// an empty-line separator; anything less is just a boundary between fill items.
SyntaxElement::Token(token) if is_collapsible_trivia(token.kind()) => {
let (newlines, _) = consume_trivia_run(&token, &mut iter);
if newlines >= 2 {
end_line(&mut atom, &mut run, &mut lines, &mut seps, &mut pending_sep);
pending_sep = Ir::empty_line();
} else {
flush_atom(&mut atom, &mut run);
}
}
// A comment is attached to the current atom and then ends the segment.
SyntaxElement::Token(token) if token.kind() == SyntaxKind::COMMENT => {
atom.push(Ir::verbatim(token.text()));
end_line(&mut atom, &mut run, &mut lines, &mut seps, &mut pending_sep);
}
// Any other token whose text contains a newline: only the text before the
// first `\n` is kept, then the segment ends.
// NOTE(review): everything after the first newline in the token is dropped
// here; confirm that such tokens never carry content past it.
SyntaxElement::Token(token) if token.text().contains('\n') => {
let before = token.text().split_once('\n').map(|(b, _)| b).unwrap_or("");
if !before.is_empty() {
atom.push(Ir::verbatim(before));
}
end_line(&mut atom, &mut run, &mut lines, &mut seps, &mut pending_sep);
}
SyntaxElement::Token(token) => atom.push(Ir::verbatim(token.text())),
// A line-break node stays attached to the preceding text and ends the segment.
SyntaxElement::Node(child) if child.kind() == SyntaxKind::LINE_BREAK => {
atom.push(lower_node(&child, cx));
end_line(&mut atom, &mut run, &mut lines, &mut seps, &mut pending_sep);
}
SyntaxElement::Node(child) => {
let ir = lower_node(&child, cx);
// IR with a forced break cannot sit inside a fill: close the current
// segment and give the child a segment of its own.
if ir.contains_forced_break() {
end_line(&mut atom, &mut run, &mut lines, &mut seps, &mut pending_sep);
push_segment(ir, &mut lines, &mut seps, &mut pending_sep);
} else {
atom.push(ir);
}
}
}
}
// Flush whatever is left after the last element.
end_line(&mut atom, &mut run, &mut lines, &mut seps, &mut pending_sep);
// Interleave segments with their separators; the first segment gets none.
let mut result: Vec<Ir> = Vec::with_capacity(lines.len().saturating_mul(2));
for (i, line) in lines.into_iter().enumerate() {
if i > 0 {
result.push(seps[i].clone());
}
result.push(line);
}
Ir::concat(result)
}
/// Lowers a flat sequence of elements, one IR item per node, per non-trivia
/// token, and per maximal run of whitespace/newline tokens.
fn lower_element_stream(
    elements: impl Iterator<Item = SyntaxElement>,
    cx: LowerCtx<'_>,
) -> Vec<Ir> {
    let mut pending = elements.peekable();
    let mut lowered = Vec::new();
    while let Some(element) = pending.next() {
        let piece = match element {
            SyntaxElement::Token(token) => {
                if is_collapsible_trivia(token.kind()) {
                    // Swallow the whole trivia run so it collapses to one item.
                    let (newlines, trailing_ws) = consume_trivia_run(&token, &mut pending);
                    classify_trivia(newlines, trailing_ws)
                } else {
                    Ir::verbatim(token.text())
                }
            }
            SyntaxElement::Node(child) => lower_node(&child, cx),
        };
        lowered.push(piece);
    }
    lowered
}
/// Lowers an environment as `begin`, an indented body, and `end`, each on
/// its own line. An empty body yields just `begin`, a hard line, and `end`.
fn lower_environment(node: &SyntaxNode, cx: LowerCtx<'_>) -> Ir {
    let mut begin = Ir::Nil;
    let mut end = Ir::Nil;
    let mut inner: Vec<SyntaxElement> = Vec::new();
    for element in node.children_with_tokens() {
        if let SyntaxElement::Node(child) = &element {
            let kind = child.kind();
            if kind == SyntaxKind::BEGIN {
                begin = lower_begin(child, cx);
                continue;
            }
            if kind == SyntaxKind::END {
                end = lower_node(child, cx);
                continue;
            }
        }
        inner.push(element);
    }

    // Breaks adjacent to the delimiters are replaced by the ones emitted below.
    let lowered_body = Ir::concat(lower_element_stream(inner.into_iter(), cx));
    let body = trim_trailing_break(trim_leading_break(lowered_body));

    let mut parts = vec![begin];
    if !matches!(body, Ir::Nil) {
        parts.push(Ir::indent(Ir::concat([Ir::hard_line(), body])));
    }
    parts.push(Ir::hard_line());
    parts.push(end);
    Ir::concat(parts)
}
/// Lowers a `BEGIN` node, dropping whitespace between the pieces up to and
/// including the last argument declared by the environment's signature.
///
/// Falls back to plain `lower_node` when the environment has no known
/// arguments or when the node directly contains a comment token.
fn lower_begin(begin: &SyntaxNode, cx: LowerCtx<'_>) -> Ir {
    let arity = match environment_name(begin).and_then(|name| cx.signatures.environment(&name)) {
        Some(sig) => sig.args.len(),
        None => 0,
    };
    let has_comment = begin
        .children_with_tokens()
        .any(|element| matches!(&element, SyntaxElement::Token(t) if t.kind() == SyntaxKind::COMMENT));
    if has_comment || arity == 0 {
        return lower_node(begin, cx);
    }

    let mut children = begin.children_with_tokens();
    let mut head: Vec<Ir> = Vec::new();
    // Counts down the group/optional children still owed to the signature.
    let mut remaining = arity;
    for element in children.by_ref() {
        match element {
            SyntaxElement::Node(child) => {
                let is_arg = matches!(child.kind(), SyntaxKind::GROUP | SyntaxKind::OPTIONAL);
                head.push(lower_node(&child, cx));
                if is_arg {
                    remaining -= 1;
                    if remaining == 0 {
                        break;
                    }
                }
            }
            SyntaxElement::Token(token) => {
                if !is_collapsible_trivia(token.kind()) {
                    head.push(Ir::verbatim(token.text()));
                }
            }
        }
    }

    // Whatever follows the final declared argument is lowered as an ordinary stream.
    let tail: Vec<SyntaxElement> = children.collect();
    if !tail.is_empty() {
        head.extend(lower_element_stream(tail.into_iter(), cx));
    }
    Ir::concat(head)
}
/// Lowers a delimited node with its body indented on separate lines between
/// the `open` and `close` tokens. An empty body yields the two delimiters
/// back to back.
fn lower_bracketed(node: &SyntaxNode, open: SyntaxKind, close: SyntaxKind, cx: LowerCtx<'_>) -> Ir {
    let mut opener = Ir::Nil;
    let mut closer = Ir::Nil;
    let mut inner: Vec<SyntaxElement> = Vec::new();
    for element in node.children_with_tokens() {
        if let SyntaxElement::Token(t) = &element {
            let kind = t.kind();
            // Only the first opening token is the delimiter; later ones stay in the body.
            if kind == open && matches!(opener, Ir::Nil) {
                opener = Ir::verbatim(t.text());
                continue;
            }
            if kind == close {
                closer = Ir::verbatim(t.text());
                continue;
            }
        }
        inner.push(element);
    }

    let lowered_body = Ir::concat(lower_element_stream(inner.into_iter(), cx));
    let body = trim_trailing_break(trim_leading_break(lowered_body));
    if matches!(body, Ir::Nil) {
        return Ir::concat([opener, closer]);
    }
    let indented = Ir::indent(Ir::concat([Ir::hard_line(), body]));
    Ir::concat([opener, indented, Ir::hard_line(), closer])
}
/// Returns `true` when the command has a known signature with at least one
/// argument flagged as prose.
fn command_has_prose_arg(command: &SyntaxNode, cx: LowerCtx<'_>) -> bool {
    let Some(sig) = command_name(command).and_then(|name| cx.signatures.command(&name)) else {
        return false;
    };
    sig.args.iter().any(|spec| spec.prose)
}
/// Lowers a command, reflowing the arguments its signature marks as prose
/// and lowering everything else normally.
///
/// Commands without a known signature are lowered as a plain element stream.
fn lower_command(node: &SyntaxNode, cx: LowerCtx<'_>) -> Ir {
    let signature = command_name(node).and_then(|name| cx.signatures.command(&name));
    let Some(sig) = signature else {
        return Ir::concat(lower_element_stream(node.children_with_tokens(), cx));
    };

    let mut pieces: Vec<Ir> = Vec::new();
    // Index of the next signature slot an argument may bind to.
    let mut next_slot = 0usize;
    let mut children = node.children_with_tokens().peekable();
    while let Some(element) = children.next() {
        let piece = match element {
            SyntaxElement::Token(token) => {
                if is_collapsible_trivia(token.kind()) {
                    let (newlines, trailing_ws) = consume_trivia_run(&token, &mut children);
                    classify_trivia(newlines, trailing_ws)
                } else {
                    Ir::verbatim(token.text())
                }
            }
            SyntaxElement::Node(child) => {
                let kind = child.kind();
                let is_bracket = kind == SyntaxKind::OPTIONAL;
                let is_arg = is_bracket || kind == SyntaxKind::GROUP;
                // Slots are only consumed by group/optional children.
                let prose = is_arg
                    && match_arg_slot(&sig.args, &mut next_slot, is_bracket)
                        .is_some_and(|spec| spec.prose);
                if !prose {
                    lower_node(&child, cx)
                } else if is_bracket {
                    lower_prose_group(&child, SyntaxKind::L_BRACKET, SyntaxKind::R_BRACKET, cx)
                } else {
                    lower_prose_group(&child, SyntaxKind::L_BRACE, SyntaxKind::R_BRACE, cx)
                }
            }
        };
        pieces.push(piece);
    }
    Ir::concat(pieces)
}
/// Binds an actual argument to the next compatible signature slot.
///
/// Starting at `*slot`, bracket-kind specs that do not match a non-bracket
/// argument are skipped. On a match the spec is returned and `*slot` moves
/// past it. A bracket argument facing a non-bracket spec binds to nothing and
/// leaves `*slot` where it is.
fn match_arg_slot(args: &[ArgSpec], slot: &mut usize, is_bracket: bool) -> Option<ArgSpec> {
    while let Some(&spec) = args.get(*slot) {
        let spec_is_bracket = matches!(spec.kind, ArgKind::Bracket);
        match (spec_is_bracket, is_bracket) {
            (true, true) | (false, false) => {
                *slot += 1;
                return Some(spec);
            }
            // Bracket spec with no bracket argument supplied: skip it.
            (true, false) => *slot += 1,
            (false, true) => return None,
        }
    }
    None
}
/// Lowers a prose argument: the body between `open` and `close` is reflowed
/// and wrapped in a group with soft lines after the opener and before the
/// closer.
fn lower_prose_group(
    node: &SyntaxNode,
    open: SyntaxKind,
    close: SyntaxKind,
    cx: LowerCtx<'_>,
) -> Ir {
    let mut opener = Ir::Nil;
    let mut closer = Ir::Nil;
    let mut inner: Vec<SyntaxElement> = Vec::new();
    for element in node.children_with_tokens() {
        if let SyntaxElement::Token(t) = &element {
            let kind = t.kind();
            // Only the first opening token is the delimiter; later ones stay in the body.
            if kind == open && matches!(opener, Ir::Nil) {
                opener = Ir::verbatim(t.text());
                continue;
            }
            if kind == close {
                closer = Ir::verbatim(t.text());
                continue;
            }
        }
        inner.push(element);
    }

    let body = reflow_elements(inner.into_iter(), cx);
    if matches!(body, Ir::Nil) {
        return Ir::concat([opener, closer]);
    }
    let indented = Ir::indent(Ir::concat([Ir::soft_line(), body]));
    Ir::group(Ir::concat([opener, indented, Ir::soft_line(), closer]))
}
fn spans_multiple_lines(node: &SyntaxNode) -> bool {
node.children_with_tokens()
.filter_map(|e| e.into_token())
.any(|t| t.kind() == SyntaxKind::NEWLINE)
}
fn has_verbatim_body(node: &SyntaxNode) -> bool {
node.children_with_tokens()
.filter_map(|e| e.into_token())
.any(|t| t.kind() == SyntaxKind::VERBATIM_BODY)
}
/// Returns `true` for the token kinds that are merged into trivia runs:
/// whitespace and newlines.
fn is_collapsible_trivia(kind: SyntaxKind) -> bool {
    kind == SyntaxKind::WHITESPACE || kind == SyntaxKind::NEWLINE
}
/// Consumes the trivia tokens that immediately follow `first` and summarises
/// the whole run.
///
/// Returns the number of newline tokens in the run (including `first`) and
/// the whitespace text that follows the last newline.
fn consume_trivia_run(
    first: &SyntaxToken,
    iter: &mut Peekable<impl Iterator<Item = SyntaxElement>>,
) -> (usize, String) {
    let mut newline_count = 0;
    let mut tail_ws = String::new();
    absorb(first, &mut newline_count, &mut tail_ws);
    // `next_if` only advances past elements accepted by the predicate, so the
    // first non-trivia element is left in place for the caller.
    while let Some(SyntaxElement::Token(tok)) = iter.next_if(|upcoming| {
        matches!(upcoming, SyntaxElement::Token(t) if is_collapsible_trivia(t.kind()))
    }) {
        absorb(&tok, &mut newline_count, &mut tail_ws);
    }
    (newline_count, tail_ws)
}
/// Folds one trivia token into the running summary: a newline bumps the
/// counter and discards the whitespace gathered so far, anything else is
/// appended to the whitespace buffer.
fn absorb(tok: &SyntaxToken, newlines: &mut usize, trailing_ws: &mut String) {
    if tok.kind() != SyntaxKind::NEWLINE {
        trailing_ws.push_str(tok.text());
        return;
    }
    *newlines += 1;
    trailing_ws.clear();
}
/// Turns a trivia-run summary into IR: the literal whitespace when the run
/// had no newline, a hard line for exactly one, an empty line for more.
fn classify_trivia(newlines: usize, trailing_ws: String) -> Ir {
    if newlines == 0 {
        Ir::verbatim(trailing_ws)
    } else if newlines == 1 {
        Ir::hard_line()
    } else {
        Ir::empty_line()
    }
}
/// Returns `true` for IR that may be dropped at the edge of a body: hard
/// lines, empty lines, `Nil`, and verbatim text made only of spaces/tabs
/// that does not force a break.
fn is_trimmable_break(ir: &Ir) -> bool {
    if let Ir::Verbatim { text, force_break } = ir {
        return !*force_break && text.chars().all(|c| matches!(c, ' ' | '\t'));
    }
    matches!(ir, Ir::HardLine | Ir::EmptyLine | Ir::Nil)
}
/// Removes trimmable breaks (see `is_trimmable_break`) from the front of `ir`.
///
/// A trimmable `ir` collapses to `Ir::Nil`. For a `Concat`, leading items are
/// trimmed recursively and dropped while they reduce to `Nil`; the first item
/// that survives is kept in trimmed form and everything after it is kept
/// unchanged. Any other IR is returned as is.
fn trim_leading_break(ir: Ir) -> Ir {
    if is_trimmable_break(&ir) {
        return Ir::Nil;
    }
    match ir {
        Ir::Concat(items) => {
            // Single forward pass: avoids the repeated front removal/insertion
            // on a Vec, which shifts every remaining element each time.
            let mut rest = items.iter().cloned();
            let mut kept: Vec<Ir> = Vec::new();
            for item in rest.by_ref() {
                let head = trim_leading_break(item);
                if !matches!(head, Ir::Nil) {
                    kept.push(head);
                    break;
                }
            }
            kept.extend(rest);
            Ir::concat(kept)
        }
        other => other,
    }
}
/// Removes trimmable breaks (see `is_trimmable_break`) from the back of `ir`.
///
/// A trimmable `ir` collapses to `Ir::Nil`. For a `Concat`, trailing items are
/// trimmed recursively and dropped while they reduce to `Nil`; the last item
/// that survives is kept in trimmed form. Any other IR is returned as is.
fn trim_trailing_break(ir: Ir) -> Ir {
    if is_trimmable_break(&ir) {
        return Ir::Nil;
    }
    match ir {
        Ir::Concat(items) => {
            let mut kept: Vec<Ir> = items.iter().cloned().collect();
            loop {
                let Some(last) = kept.pop() else {
                    break;
                };
                match trim_trailing_break(last) {
                    Ir::Nil => {}
                    trimmed => {
                        kept.push(trimmed);
                        break;
                    }
                }
            }
            Ir::concat(kept)
        }
        other => other,
    }
}