use std::iter;
use itertools::Itertools;
use typst_syntax::{SyntaxKind, SyntaxNode};
use crate::{
ext::StrExt,
pretty::{Context, PrettyPrinter, prelude::*, util::is_comment_node},
};
/// One element of a flattened chain, in source order, as collected by `ChainStylist::process`.
enum ChainItem<'a> {
    /// A chain operator (already converted to a doc).
    Op(ArenaDoc<'a>),
    /// An operand / right-hand side, possibly with fallback docs appended to it.
    Body(ArenaDoc<'a>),
    /// A comment found between chain elements.
    Comment(ArenaDoc<'a>),
    /// A source line break that sits next to a comment and must be preserved.
    Linebreak,
}
/// Collects the pieces of an operator chain and renders them as a single grouped doc.
pub struct ChainStylist<'a> {
    /// Flattened chain elements, in the order they will be printed.
    items: Vec<ChainItem<'a>>,
    /// Number of nodes accepted by the operand predicate (one per chain operation).
    chain_op_num: usize,
    /// Whether any comment was encountered while collecting items.
    has_comment: bool,
    /// Printer used for its arena, comment conversion and indentation.
    printer: &'a PrettyPrinter<'a>,
}
/// Layout options for `ChainStylist::print_doc`.
///
/// Both flags default to `false`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ChainStyle {
    /// Print the chain without indentation or breakable separators when it consists of
    /// exactly one chain operation and contains no comments.
    pub no_break_single: bool,
    /// Put whitespace around operators: a space is printed after each operator, and the
    /// separator before an operator renders as a space rather than nothing.
    pub space_around_op: bool,
}
impl<'a> ChainStylist<'a> {
    /// Creates an empty stylist that renders through `printer`.
    pub fn new(printer: &'a PrettyPrinter<'a>) -> Self {
        Self {
            printer,
            items: Default::default(),
            chain_op_num: 0,
            has_comment: false,
        }
    }

    /// Collects `nodes`, reverses their order, and forwards everything to [`Self::process`].
    ///
    /// Use this when the iterator yields the chain in the opposite order from the one
    /// `process` consumes. NOTE(review): presumably that means outermost node first
    /// (e.g. the output of `iterate_deep_nodes`); confirm against callers.
    pub fn process_resolved(
        self,
        ctx: Context,
        nodes: impl Iterator<Item = &'a SyntaxNode>,
        operand_pred: impl Fn(&'a SyntaxNode) -> bool,
        op_converter: impl Fn(&'a SyntaxNode) -> Option<ArenaDoc<'a>>,
        rhs_converter: impl Fn(Context, &'a SyntaxNode) -> Option<ArenaDoc<'a>>,
        fallback_converter: impl Fn(Context, &'a SyntaxNode) -> Option<ArenaDoc<'a>>,
    ) -> Self {
        let mut nodes = nodes.collect_vec();
        nodes.reverse();
        self.process(
            ctx,
            nodes,
            operand_pred,
            op_converter,
            rhs_converter,
            fallback_converter,
        )
    }

    /// Flattens `nodes` into chain items.
    ///
    /// For each node:
    /// - If `operand_pred(node)` holds, it counts as one chain operation, and its children
    ///   are scanned in order:
    ///   - children accepted by `op_converter` become `Op` items;
    ///   - whitespace containing a line break becomes a `Linebreak` item, but only when it
    ///     is adjacent to a comment (the last recorded item is a comment, or the next
    ///     sibling is one);
    ///   - comments become `Comment` items;
    ///   - once an operator has been seen, children accepted by `rhs_converter` become
    ///     `Body` items.
    ///
    ///   Other children before the operator are ignored. NOTE(review): presumably this is
    ///   the left-hand side, which an earlier node already produced.
    /// - Otherwise, `fallback_converter`'s doc (if any) is appended to a trailing `Body`,
    ///   or pushed as a new `Body`.
    pub fn process(
        mut self,
        ctx: Context,
        nodes: Vec<&'a SyntaxNode>,
        operand_pred: impl Fn(&'a SyntaxNode) -> bool,
        op_converter: impl Fn(&'a SyntaxNode) -> Option<ArenaDoc<'a>>,
        rhs_converter: impl Fn(Context, &'a SyntaxNode) -> Option<ArenaDoc<'a>>,
        fallback_converter: impl Fn(Context, &'a SyntaxNode) -> Option<ArenaDoc<'a>>,
    ) -> Self {
        for node in nodes {
            if operand_pred(node) {
                self.chain_op_num += 1;
                let children = node.children().as_slice();
                let mut seen_op = false;
                // Iterate the slice we already hold; it is also used for one-child look-ahead.
                for (i, child) in children.iter().enumerate() {
                    if let Some(op) = op_converter(child) {
                        seen_op = true;
                        self.items.push(ChainItem::Op(op));
                    } else if child.kind() == SyntaxKind::Space {
                        // Only line breaks that touch a comment are preserved; others are
                        // left to the layout engine.
                        if child.leaf_text().has_linebreak()
                            && (matches!(self.items.last(), Some(ChainItem::Comment(_)))
                                || children.get(i + 1).is_some_and(is_comment_node))
                        {
                            self.items.push(ChainItem::Linebreak);
                        }
                    } else if is_comment_node(child) {
                        let doc = self.printer.convert_comment(ctx, child);
                        self.items.push(ChainItem::Comment(doc));
                        self.has_comment = true;
                    } else if seen_op && let Some(rhs) = rhs_converter(ctx, child) {
                        self.items.push(ChainItem::Body(rhs));
                    }
                }
            } else if let Some(fallback) = fallback_converter(ctx, node) {
                if let Some(ChainItem::Body(body)) = self.items.last_mut() {
                    *body += fallback;
                } else {
                    self.items.push(ChainItem::Body(fallback));
                }
            }
        }
        self
    }

    /// Renders the collected items as one grouped doc.
    ///
    /// The first item is printed as-is. The remaining items are indented via the printer,
    /// unless the simple layout applies: exactly one chain operation, `sty.no_break_single`
    /// set, and no comments. In the simple layout there is no indentation, and operators are
    /// preceded by a non-breaking separator.
    ///
    /// # Panics
    ///
    /// Panics if no items were collected or the first item is not a `Body`.
    pub fn print_doc(self, sty: ChainStyle) -> ArenaDoc<'a> {
        let arena = &self.printer.arena;
        let use_simple_layout = self.chain_op_num == 1 && sty.no_break_single && !self.has_comment;
        // Separator emitted before an operator (unless it already starts a fresh line).
        // The simple layout must not break, so it uses the flat rendering of `line()` /
        // `line_()`. Previously it emitted nothing at all, which printed `a+ b` when
        // `space_around_op` was set (a space after the operator but none before it).
        let op_sep = match (use_simple_layout, sty.space_around_op) {
            (false, true) => arena.line(),
            (false, false) => arena.line_(),
            (true, true) => arena.space(),
            (true, false) => arena.nil(),
        };
        let mut iter = self.items.into_iter();
        let Some(ChainItem::Body(first_doc)) = iter.next() else {
            panic!("Chain must starts with a body");
        };
        let mut follow_docs = arena.nil();
        // `leading`: the output currently sits right after a hard line break.
        let mut leading = false;
        // `space_after`: a following comment needs a separating space.
        let mut space_after = true;
        for item in iter {
            match item {
                ChainItem::Body(body) => {
                    follow_docs += body;
                    leading = false;
                    space_after = true;
                }
                ChainItem::Op(op) => {
                    if !leading {
                        follow_docs += op_sep.clone();
                    }
                    follow_docs += op;
                    if sty.space_around_op {
                        follow_docs += arena.space();
                    }
                    leading = false;
                    space_after = false;
                }
                ChainItem::Comment(cmt) => {
                    if space_after {
                        follow_docs += arena.space();
                    }
                    follow_docs += cmt;
                    leading = false;
                    space_after = true;
                }
                ChainItem::Linebreak => {
                    // Collapse consecutive line breaks into a single hard line.
                    if !leading {
                        leading = true;
                        space_after = false;
                        follow_docs += arena.hardline();
                    }
                }
            }
        }
        if use_simple_layout {
            first_doc + follow_docs
        } else {
            first_doc + self.printer.indent(follow_docs)
        }
        .group()
    }
}
/// Yields `node`, then the node `accepter` returns for the previously yielded one, and so
/// on, stopping as soon as `accepter` returns `None`.
///
/// `accepter` is invoked on each node at the moment that node is yielded. It is never
/// invoked again once it has returned `None`.
pub fn iterate_deep_nodes<'a>(
    node: &'a SyntaxNode,
    accepter: impl Fn(&SyntaxNode) -> Option<&SyntaxNode> + 'a,
) -> impl Iterator<Item = &'a SyntaxNode> {
    // Destructure `&&SyntaxNode` so the argument keeps the `'a` lifetime.
    iter::successors(Some(node), move |&prev| accepter(prev))
}