use std::{collections::HashMap, fmt, rc::Rc};
use crate::{
DefaultCost, Doc, DocId, DocKind,
cost::{Cost, CostFactory, DefaultCostFactory},
measure::{Measure, MeasureSet},
non_empty::NonEmptyVecBuilder,
};
pub type Result<T, E = Error> = std::result::Result<T, E>;
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
#[error("document was not printable")]
pub struct Error;
/// Memo-table key: (document id, current column, indentation, begin_full, end_full),
/// mirroring the arguments of `Printer::resolve`.
type MemoKey = (DocId, usize, usize, bool, bool);

/// State for one run of the layout resolution algorithm.
pub(crate) struct Printer<C: CostFactory> {
    /// Cost model: supplies text/newline costs and the width `limit()`.
    cost: C,
    /// Cache of measure sets already computed for a given resolution state.
    memo: HashMap<MemoKey, MeasureSet<C>>,
}
impl<C: CostFactory + 'static> Printer<C> {
    /// Creates a printer that uses `cost` and starts with an empty memo table.
    fn new(cost: C) -> Self {
        Self {
            cost,
            memo: HashMap::new(),
        }
    }

    /// Resolves `d` starting at column `c` with indentation 0 and extracts a
    /// single measure from the result.
    ///
    /// The results for both `end_full` states are merged, so the top-level
    /// call does not commit to how the document's last line ends.
    ///
    /// # Errors
    ///
    /// Returns [`Error`] when no measure can be extracted: the merged set is
    /// `Failed`, or forcing a tainted set's thunk yields `None`.
    fn validate(&mut self, d: Doc<C::CostType>, c: usize) -> Result<PrintResult<C::CostType>> {
        let result = self
            .resolve(d.clone(), c, 0, false, false)
            .merge(self.resolve(d, c, 0, false, true));
        // Record taintedness before `extract_at_most_one` consumes the set.
        let is_tainted = matches!(&result, MeasureSet::Tainted(_, _));
        let measure = self.extract_at_most_one(result).ok_or(Error)?;
        Ok(PrintResult {
            is_tainted,
            measure,
        })
    }

    /// Memoizing front end to [`Self::resolve_inner`].
    ///
    /// * `c` - current column.
    /// * `i` - current indentation level.
    /// * `begin_full` / `end_full` - fullness flags the document must satisfy
    ///   at its start and end.
    ///
    /// The set is `Failed` immediately when `DocKind::fails` rejects the flags.
    /// Results are cached only when both `c` and `i` are within the cost
    /// factory's `limit()` and the document's `memo_weight` is zero.
    /// NOTE(review): `memo_weight` presumably throttles how many nodes get
    /// memoized; confirm in the `Doc` constructors.
    fn resolve(
        &mut self,
        d: Doc<C::CostType>,
        c: usize,
        i: usize,
        begin_full: bool,
        end_full: bool,
    ) -> MeasureSet<C> {
        if d.0.kind.fails(begin_full, end_full) {
            MeasureSet::Failed
        } else if c <= self.cost.limit() && i <= self.cost.limit() && d.0.memo_weight == 0 {
            let id = d.0.id;
            let key = (id, c, i, begin_full, end_full);
            if let Some(ms) = self.memo.get(&key) {
                ms.clone()
            } else {
                // `resolve_inner` needs `&mut self`, so the entry API cannot be
                // held across the call; look up, compute, then insert.
                let result = self.resolve_inner(d, c, i, begin_full, end_full, false);
                self.memo.insert(key, result.clone());
                result
            }
        } else {
            self.resolve_inner(d, c, i, begin_full, end_full, false)
        }
    }

    /// Computes the measure set for `d` without consulting the memo table.
    ///
    /// The layout "exceeds" when text would run past the cost factory's
    /// `limit()`, or when `c` / `i` already do. If it exceeds and
    /// `allow_exceeds` is false, a `Tainted` set is returned instead. Its thunk
    /// re-runs this function with `allow_exceeds = true` and extracts at most
    /// one measure when forced (see [`Self::extract_at_most_one`]).
    fn resolve_inner(
        &mut self,
        d: Doc<C::CostType>,
        c: usize,
        i: usize,
        begin_full: bool,
        end_full: bool,
        allow_exceeds: bool,
    ) -> MeasureSet<C> {
        use DocKind::*;
        let exceeds = if let Text(_, len) = &d.0.kind {
            c + len > self.cost.limit() || i > self.cost.limit()
        } else {
            c > self.cost.limit() || i > self.cost.limit()
        };
        if exceeds && !allow_exceeds {
            // This branch always returns, so `d` can be moved into the thunk
            // directly; the thunk clones it on each invocation because it is
            // stored behind an `Rc` and may be called more than once.
            return MeasureSet::Tainted(
                d.0.newline_count,
                Rc::new(move |this| {
                    let resolved = this.resolve_inner(d.clone(), c, i, begin_full, end_full, true);
                    this.extract_at_most_one(resolved)
                }),
            );
        }
        match &d.0.kind {
            // Text starting at column `c` ends at column `c + len`.
            Text(s, len) => {
                let s = s.clone();
                MeasureSet::new(len + c, self.cost.text(c, *len), move |w| {
                    write!(w, "{}", s)
                })
            }
            // A newline ends at column `i` after writing `i` spaces of indentation.
            Newline(_) => MeasureSet::new(i, self.cost.newline(i), move |w| {
                writeln!(w)?;
                write!(w, "{}", " ".repeat(i))
            }),
            Concat(d1, d2) => {
                // Resolve `d1` for each fullness state at the boundary between
                // the halves (`mid_full`), then resolve `d2` starting from the
                // column where each `d1` measure ends.
                let mut analyze_left =
                    |mid_full| match self.resolve(d1.clone(), c, i, begin_full, mid_full) {
                        MeasureSet::Failed => MeasureSet::Failed,
                        MeasureSet::Tainted(_, thunk) => {
                            let d2 = d2.clone();
                            MeasureSet::tainted(&d, move |this| {
                                let m1 = thunk(this)?;
                                let resolved =
                                    this.resolve(d2.clone(), m1.last, i, mid_full, end_full);
                                this.extract_at_most_one(resolved).map(|m2| m1.concat(m2))
                            })
                        }
                        MeasureSet::Valid(m1, ms1) => {
                            let first = self.analyze_right(m1, &d, d2, i, mid_full, end_full);
                            // `into_iter` yields owned measures, so no clone is needed.
                            ms1.into_iter().rfold(first, |ms, m| {
                                self.analyze_right(m, &d, d2, i, mid_full, end_full)
                                    .merge(ms)
                            })
                        }
                    };
                analyze_left(false).merge(analyze_left(true))
            }
            Alt(d1, d2) => {
                let r1 = self.resolve(d1.clone(), c, i, begin_full, end_full);
                let r2 = self.resolve(d2.clone(), c, i, begin_full, end_full);
                // The receiver of `merge` is the alternative with at least as
                // many newlines. NOTE(review): presumably this matters for
                // tie-breaking between tainted sets; confirm in `MeasureSet::merge`.
                if d1.0.newline_count < d2.0.newline_count {
                    r2.merge(r1)
                } else {
                    r1.merge(r2)
                }
            }
            // Increase the indentation by `n`.
            Nest(n, d) => self.resolve(d.clone(), c, i + n, begin_full, end_full),
            // Set the indentation to the current column.
            Align(d) => self.resolve(d.clone(), c, c, begin_full, end_full),
            // Drop the indentation back to 0.
            Reset(d) => self.resolve(d.clone(), c, 0, begin_full, end_full),
            // Add `co` to the cost of every measure produced by the inner document.
            Cost(co, d) => {
                let co = co.clone();
                let add_cost = move |mut m: Measure<C::CostType>| {
                    m.cost = co.clone() + m.cost;
                    m
                };
                match self.resolve(d.clone(), c, i, begin_full, end_full) {
                    MeasureSet::Failed => MeasureSet::Failed,
                    MeasureSet::Valid(m, ms) => {
                        MeasureSet::Valid(add_cost(m), ms.into_iter().map(add_cost).collect())
                    }
                    MeasureSet::Tainted(_, thunk) => {
                        MeasureSet::tainted(d, move |this| thunk(this).map(&add_cost))
                    }
                }
            }
            // Resolve the inner document under both end states. Any constraint
            // `Full` itself imposes is presumably enforced by `DocKind::fails`
            // in `resolve`.
            Full(d) => self
                .resolve(d.clone(), c, i, begin_full, false)
                .merge(self.resolve(d.clone(), c, i, begin_full, true)),
            Fail => MeasureSet::Failed,
        }
    }

    /// Resolves `d2`, the right operand of the concatenation `d`, starting at
    /// the column where `m` ends. Concatenates `m` with every resulting measure.
    ///
    /// For a valid set, a concatenated candidate is kept only when its
    /// successor costs strictly more. Otherwise the successor replaces it.
    /// NOTE(review): this pruning presumably relies on the ordering invariant
    /// of `MeasureSet::Valid`; confirm in the `measure` module.
    fn analyze_right(
        &mut self,
        m: Measure<C::CostType>,
        d: &Doc<C::CostType>,
        d2: &Doc<C::CostType>,
        i: usize,
        begin_full: bool,
        end_full: bool,
    ) -> MeasureSet<C> {
        match self.resolve(d2.clone(), m.last, i, begin_full, end_full) {
            MeasureSet::Failed => MeasureSet::Failed,
            MeasureSet::Tainted(_, thunk) => MeasureSet::tainted(d, move |this| {
                let m2 = thunk(this)?;
                Some(m.clone().concat(m2))
            }),
            MeasureSet::Valid(m2, ms2) => {
                let mut result = NonEmptyVecBuilder::new();
                let mut current_best = m.clone().concat(m2);
                for m2 in ms2.into_iter() {
                    let current = m.clone().concat(m2);
                    if current.cost > current_best.cost {
                        result.push(current_best);
                    }
                    current_best = current;
                }
                result.push(current_best);
                let (first, rest) = result.finish();
                MeasureSet::Valid(first, rest)
            }
        }
    }

    /// Reduces `ms` to at most one measure:
    /// * `None` for a failed set;
    /// * the head of a valid set;
    /// * whatever a tainted set's thunk produces when forced.
    fn extract_at_most_one(&mut self, ms: MeasureSet<C>) -> Option<Measure<C::CostType>> {
        match ms {
            MeasureSet::Failed => None,
            MeasureSet::Tainted(_, thunk) => thunk(self),
            MeasureSet::Valid(m, _) => Some(m),
        }
    }
}
/// Outcome of a successful print: the chosen measure, plus whether it came
/// from a tainted (limit-exceeding) resolution.
#[derive(Debug)]
pub struct PrintResult<C: Cost> {
    is_tainted: bool,
    measure: Measure<C>,
}

impl<C: Cost> PrintResult<C> {
    /// `true` when the printer had to settle for a tainted measure set, i.e.
    /// a layout that went past the cost factory's width limit.
    pub fn is_tainted(&self) -> bool {
        let tainted = self.is_tainted;
        tainted
    }

    /// Returns a copy of the chosen layout's cost.
    pub fn cost(&self) -> C {
        Clone::clone(&self.measure.cost)
    }
}

impl<C: Cost> fmt::Display for PrintResult<C> {
    /// Writes the chosen layout by invoking the measure's stored renderer.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let render = &self.measure.layout;
        render(f)
    }
}
impl fmt::Display for Doc<DefaultCost> {
    /// Prints the document at the page width given in the format spec
    /// (e.g. `format!("{doc:40}")`), or at 80 columns when none is given.
    ///
    /// # Panics
    ///
    /// Panics with "couldn't validate print result" when the document cannot
    /// be printed at that width.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let width = match f.width() {
            Some(requested) => requested,
            None => 80,
        };
        let result = self.validate(width).expect("couldn't validate print result");
        write!(f, "{result}")
    }
}
impl Doc<DefaultCost> {
    /// Prints the document with the default cost model at `page_width` columns.
    ///
    /// The factory's second constructor argument is left as `None`.
    /// NOTE(review): presumably an optional computation-width limit; confirm
    /// in `DefaultCostFactory::new`.
    ///
    /// # Errors
    ///
    /// Returns [`Error`] if no layout of the document can be extracted.
    pub fn validate(&self, page_width: usize) -> Result<PrintResult<DefaultCost>> {
        let factory = DefaultCostFactory::new(page_width, None);
        self.validate_with_cost(factory)
    }
}
impl<C: Cost> Doc<C> {
    /// Prints the document with a caller-supplied cost factory, starting at
    /// column 0.
    ///
    /// # Errors
    ///
    /// Returns [`Error`] when the printer cannot extract any measure for the
    /// document.
    pub fn validate_with_cost<CF>(&self, cost: CF) -> Result<PrintResult<C>>
    where
        CF: CostFactory<CostType = C> + 'static,
    {
        let mut printer = Printer::new(cost);
        printer.validate(self.clone(), 0)
    }
}
// Layout tests for the printer: each test builds a `Doc` and asserts the
// rendered text at one or more page widths.
#[cfg(test)]
mod tests {
    use crate::*;

    // NOTE(review): several expected strings below look whitespace-collapsed
    // relative to what the printer emits. For example, `more_checks` expects
    // `nest(4, text("abc") & hard_nl() & text("def"))` to render as
    // "abc\n def", but `Nest` adds 4 to the indentation and `Newline` writes
    // that many spaces. Confirm these literals against the upstream file.

    // Lays out a Lisp-style `defn` at widths 120, 30, 20 and 10. Narrower
    // widths are expected to select progressively more vertical layouts.
    #[test]
    fn s_exp() {
        let fn_name = text("(defn") & space() & text("my-fn");
        let args = [text("arg1"), text("arg2"), text("arg3")];
        // Arguments either side by side or stacked, aligned after "(".
        let arg_list = lparen() & align(us_concat(args.clone()) | v_concat(args)) & rparen();
        let body_forms = [text("(println 'hello)"), text("(+ 1 2 3)")];
        // Body either follows on the same line or starts on a forced new line.
        let body = (space() & us_concat(body_forms.clone())) | (hard_nl() & v_concat(body_forms));
        // First alternative breaks with `hard_nl()` right after the name; the
        // second joins the name and argument list with `us_append`.
        let doc = fn_name.clone() & nest(2, hard_nl() & v_append(arg_list.clone(), body.clone()))
            | us_append(fn_name, nest(2, arg_list & body));
        assert_eq!(
            r#"(defn my-fn (arg1 arg2 arg3) (println 'hello) (+ 1 2 3)"#,
            format!("{doc:120}")
        );
        assert_eq!(
            r#"(defn my-fn (arg1 arg2 arg3)
(println 'hello)
(+ 1 2 3)"#,
            format!("{doc:30}")
        );
        assert_eq!(
            r#"(defn my-fn
(arg1 arg2 arg3)
(println 'hello)
(+ 1 2 3)"#,
            format!("{doc:20}")
        );
        assert_eq!(
            r#"(defn my-fn
(arg1
arg2
arg3)
(println 'hello)
(+ 1 2 3)"#,
            format!("{doc:10}")
        );
    }

    // `full(..)` comments: in every assertion, whatever follows a `full`
    // document starts on a new line.
    #[test]
    fn full_comments() {
        let doc = lparen()
            & text("println")
            & group(nl())
            & full(text("; this is a comment"))
            & nl()
            & text("\"my text\"")
            & rparen();
        assert_eq!(
            r#"(println ; this is a comment
"my text")"#,
            doc.to_string(),
        );
        let args = [
            text("a"),
            full(text("; the first one")),
            text("b"),
            full(text("; the second one")),
        ];
        let doc = align(v_concat(args.to_vec()));
        assert_eq!(
            r#"a
; the first one
b
; the second one"#,
            doc.to_string(),
        );
        let doc = group(nl()) & doc & group(brk());
        assert_eq!(
            r#" a
; the first one
b
; the second one"#,
            doc.to_string(),
        );
        // The horizontal alternative (`us_concat`) is offered alongside the
        // vertical `doc`; the expected output is vertical.
        let doc = lparen() & ((space() & align(us_concat(args.to_vec()))) | doc) & rparen();
        assert_eq!(
            r#"( a
; the first one
b
; the second one
)"#,
            doc.to_string(),
        );
    }

    // Minimal s-expression tree used to build test documents.
    enum Node {
        Str(String),
        List(Vec<Node>),
    }

    // Three alternatives for a non-empty list:
    // 1. head and arguments all stacked vertically;
    // 2. head, a space, then an aligned vertical column of arguments;
    // 3. everything joined with `us_append`/`us_concat`, wrapped in `flatten`.
    // An empty list renders as "()".
    fn pretty(node: &Node) -> Doc {
        match node {
            Node::List(children) => {
                if let Some((first, rest)) = children.split_first() {
                    let fp = pretty(first);
                    let args: Vec<_> = rest.iter().map(pretty).collect();
                    (lparen() & align(v_append(fp.clone(), v_concat(args.to_vec()))) & rparen())
                        | (lparen()
                            & align(fp.clone())
                            & space()
                            & align(v_concat(args.to_vec()))
                            & rparen())
                        | flatten(
                            lparen()
                                & align(us_append(fp.clone(), us_concat(args.to_vec())))
                                & rparen(),
                        )
                } else {
                    text("()")
                }
            }
            Node::Str(s) => text(s),
        }
    }

    // Same as `pretty`, but the third alternative is NOT wrapped in `flatten`,
    // so nested children may still break inside it (see `check_pretty2`).
    fn pretty2(node: &Node) -> Doc {
        match node {
            Node::List(children) => {
                if let Some((first, rest)) = children.split_first() {
                    let fp = pretty2(first);
                    let args: Vec<_> = rest.iter().map(pretty2).collect();
                    (lparen() & align(v_append(fp.clone(), v_concat(args.to_vec()))) & rparen())
                        | (lparen()
                            & align(fp.clone())
                            & space()
                            & align(v_concat(args.to_vec()))
                            & rparen())
                        | (lparen()
                            & align(us_append(fp.clone(), us_concat(args.to_vec())))
                            & rparen())
                } else {
                    text("()")
                }
            }
            Node::Str(s) => text(s),
        }
    }

    // (+ (foo 1 2) (bar 2 3) (baz 3 4))
    fn test_doc() -> Node {
        use Node::*;
        List(vec![
            Str("+".to_string()),
            List(vec![
                Str("foo".to_string()),
                Str("1".to_string()),
                Str("2".to_string()),
            ]),
            List(vec![
                Str("bar".to_string()),
                Str("2".to_string()),
                Str("3".to_string()),
            ]),
            List(vec![
                Str("baz".to_string()),
                Str("3".to_string()),
                Str("4".to_string()),
            ]),
        ])
    }

    // At width 31 the flattened layout does not fit, so the arguments are
    // expected in an aligned column after the head.
    #[test]
    fn check_pretty() {
        let doc = pretty(&test_doc());
        assert_eq!(
            format!("{doc:31}"),
            r#"(+ (foo 1
2)
(bar 2 3)
(baz 3 4))"#
        );
    }

    // Without `flatten`, the "single-line" alternative can embed a break
    // inside a nested child.
    #[test]
    fn check_pretty2() {
        let doc = pretty2(&test_doc());
        assert_eq!(
            format!("{doc:31}"),
            r#"(+ (foo 1
2) (bar 2 3) (baz 3 4))"#
        );
    }

    // A flat list of atoms rendered at decreasing widths.
    #[test]
    fn smush_it() {
        let doc = pretty(&Node::List(vec![
            Node::Str("+".to_string()),
            Node::Str("123".to_string()),
            Node::Str("456".to_string()),
            Node::Str("789".to_string()),
        ]));
        assert_eq!(format!("{doc:15}"), "(+ 123 456 789)");
        assert_eq!(format!("{doc:14}"), "(+ 123\n 456\n 789)");
        assert_eq!(format!("{doc:5}"), "(+\n 123\n 456\n 789)");
        assert_eq!(format!("{doc:1}"), "(+\n 123\n 456\n 789)");
    }

    // Assertions covered here:
    // - `reset` drops indentation set by an enclosing `nest`;
    // - `flatten` renders `nl()` as a space;
    // - `flatten` around `hard_nl()` fails, so the alternative is chosen.
    #[test]
    fn more_checks() {
        let doc = nest(4, reset(text("abc") & hard_nl() & text("def")));
        assert_eq!(doc.to_string(), "abc\ndef");
        let doc = nest(4, text("abc") & hard_nl() & text("def"));
        assert_eq!(doc.to_string(), "abc\n def");
        let doc = flatten(text("abc") & nl() & text("def")) | text("something");
        assert_eq!(doc.to_string(), "abc def");
        let doc = flatten(text("abc") & hard_nl() & text("def")) | text("something");
        assert_eq!(doc.to_string(), "something");
    }
}