1use std::{collections::HashMap, fmt, rc::Rc};
2
3use crate::{
4 DefaultCost, Doc, DocId, DocKind,
5 cost::{Cost, CostFactory, DefaultCostFactory},
6 measure::{Measure, MeasureSet},
7 non_empty::NonEmptyVecBuilder,
8};
9
10pub type Result<T, E = Error> = std::result::Result<T, E>;
12
13#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
19#[error("document was not printable")]
20pub struct Error;
21
22pub(crate) struct Printer<C: CostFactory> {
23 cost: C,
24 memo: HashMap<(DocId, usize, usize, bool, bool), MeasureSet<C>>,
25}
26
27impl<C: CostFactory + 'static> Printer<C> {
28 fn new(cost: C) -> Self {
29 Self {
30 cost,
31 memo: HashMap::new(),
32 }
33 }
34
35 fn validate(&mut self, d: Doc<C::CostType>, c: usize) -> Result<PrintResult<C::CostType>> {
36 let result = self
37 .resolve(d.clone(), c, 0, false, false)
38 .merge(self.resolve(d, c, 0, false, true));
39 let is_tainted = matches!(&result, MeasureSet::Tainted(_, _));
40 let measure = self.extract_at_most_one(result).ok_or(Error)?;
41
42 Ok(PrintResult {
43 is_tainted,
44 measure,
45 })
46 }
47
48 fn resolve(
49 &mut self,
50 d: Doc<C::CostType>,
51 c: usize,
52 i: usize,
53 begin_full: bool,
54 end_full: bool,
55 ) -> MeasureSet<C> {
56 if d.0.kind.fails(begin_full, end_full) {
57 MeasureSet::Failed
58 } else if c <= self.cost.limit() && i <= self.cost.limit() && d.0.memo_weight == 0 {
59 let id = d.0.id;
60 let key = (id, c, i, begin_full, end_full);
61 if let Some(ms) = self.memo.get(&key) {
62 ms.clone()
63 } else {
64 let result = self.resolve_inner(d, c, i, begin_full, end_full, false);
65 self.memo.insert(key, result.clone());
66 result
67 }
68 } else {
69 self.resolve_inner(d, c, i, begin_full, end_full, false)
70 }
71 }
72
73 fn resolve_inner(
74 &mut self,
75 d: Doc<C::CostType>,
76 c: usize,
77 i: usize,
78 begin_full: bool,
79 end_full: bool,
80 allow_exceeds: bool,
81 ) -> MeasureSet<C> {
82 use DocKind::*;
83
84 let exceeds = if let Text(_, len) = &d.0.kind {
85 c + len > self.cost.limit() || i > self.cost.limit()
86 } else {
87 c > self.cost.limit() || i > self.cost.limit()
88 };
89
90 if exceeds && !allow_exceeds {
91 let d = d.clone();
92 return MeasureSet::Tainted(
93 d.0.newline_count,
94 Rc::new(move |this| {
95 let resolved = this.resolve_inner(d.clone(), c, i, begin_full, end_full, true);
96 this.extract_at_most_one(resolved)
98 }),
99 );
100 }
101
102 match &d.0.kind {
103 Text(s, len) => {
104 let s = s.clone();
105 MeasureSet::new(len + c, self.cost.text(c, *len), move |w| {
106 write!(w, "{}", s)
107 })
108 }
109 Newline(_) => MeasureSet::new(i, self.cost.newline(i), move |w| {
110 writeln!(w)?;
111 write!(w, "{}", " ".repeat(i))
112 }),
113 Concat(d1, d2) => {
114 let mut analyze_left =
115 |mid_full| match self.resolve(d1.clone(), c, i, begin_full, mid_full) {
116 MeasureSet::Failed => MeasureSet::Failed,
117 MeasureSet::Tainted(_, thunk) => {
118 let d2 = d2.clone();
119 MeasureSet::tainted(&d, move |this| {
120 let m1 = thunk(this)?;
121 let resolved =
122 this.resolve(d2.clone(), m1.last, i, mid_full, end_full);
123 this.extract_at_most_one(resolved).map(|m2| m1.concat(m2))
124 })
125 }
126 MeasureSet::Valid(m1, ms1) => {
127 let first = self.analyze_right(m1, &d, d2, i, mid_full, end_full);
128 ms1.into_iter().rfold(first, |ms, m| {
129 self.analyze_right(m.clone(), &d, d2, i, mid_full, end_full)
130 .merge(ms)
131 })
132 }
133 };
134
135 analyze_left(false).merge(analyze_left(true))
136 }
137 Alt(d1, d2) => {
138 let r1 = self.resolve(d1.clone(), c, i, begin_full, end_full);
139 let r2 = self.resolve(d2.clone(), c, i, begin_full, end_full);
140 if d1.0.newline_count < d2.0.newline_count {
141 r2.merge(r1)
142 } else {
143 r1.merge(r2)
144 }
145 }
146 Nest(n, d) => self.resolve(d.clone(), c, i + n, begin_full, end_full),
147 Align(d) => self.resolve(d.clone(), c, c, begin_full, end_full),
148 Reset(d) => self.resolve(d.clone(), c, 0, begin_full, end_full),
149 Cost(co, d) => {
150 let co = co.clone();
151 let add_cost = move |mut m: Measure<C::CostType>| {
152 m.cost = co.clone() + m.cost;
153 m
154 };
155
156 match self.resolve(d.clone(), c, i, begin_full, end_full) {
157 MeasureSet::Failed => MeasureSet::Failed,
158 MeasureSet::Valid(m, ms) => {
159 MeasureSet::Valid(add_cost(m), ms.into_iter().map(add_cost).collect())
160 }
161 MeasureSet::Tainted(_, thunk) => {
162 MeasureSet::tainted(d, move |this| thunk(this).map(&add_cost))
163 }
164 }
165 }
166 Full(d) => self
167 .resolve(d.clone(), c, i, begin_full, false)
168 .merge(self.resolve(d.clone(), c, i, begin_full, true)),
169 Fail => MeasureSet::Failed,
170 }
171 }
172
173 fn analyze_right(
174 &mut self,
175 m: Measure<C::CostType>,
176 d: &Doc<C::CostType>,
177 d2: &Doc<C::CostType>,
178 i: usize,
179 begin_full: bool,
180 end_full: bool,
181 ) -> MeasureSet<C> {
182 match self.resolve(d2.clone(), m.last, i, begin_full, end_full) {
183 MeasureSet::Failed => MeasureSet::Failed,
184 MeasureSet::Tainted(_, thunk) => MeasureSet::tainted(d, move |this| {
185 let m2 = thunk(this)?;
186 Some(m.clone().concat(m2))
187 }),
188 MeasureSet::Valid(m2, ms2) => {
189 let mut result = NonEmptyVecBuilder::new();
190 let mut current_best = m.clone().concat(m2);
191
192 for m2 in ms2.into_iter() {
193 let current = m.clone().concat(m2);
194 if current.cost > current_best.cost {
195 result.push(current_best);
196 }
197 current_best = current;
198 }
199
200 result.push(current_best);
201 let (first, rest) = result.finish();
202 MeasureSet::Valid(first, rest)
203 }
204 }
205 }
206
207 fn extract_at_most_one(&mut self, ms: MeasureSet<C>) -> Option<Measure<C::CostType>> {
208 match ms {
209 MeasureSet::Failed => None,
210 MeasureSet::Tainted(_, thunk) => thunk(self),
211 MeasureSet::Valid(m, _) => Some(m),
212 }
213 }
214}
215
216#[derive(Debug)]
225pub struct PrintResult<C: Cost> {
226 is_tainted: bool,
227 measure: Measure<C>,
228}
229
230impl<C: Cost> PrintResult<C> {
231 pub fn is_tainted(&self) -> bool {
239 self.is_tainted
240 }
241
242 pub fn cost(&self) -> C {
244 self.measure.cost.clone()
245 }
246}
247
248impl<C: Cost> fmt::Display for PrintResult<C> {
249 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
250 (self.measure.layout)(f)
251 }
252}
253
254impl fmt::Display for Doc<DefaultCost> {
255 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
256 let page_width = f.width().unwrap_or(80);
257 let result = self
258 .validate(page_width)
259 .expect("couldn't validate print result");
260 write!(f, "{result}")
261 }
262}
263
264impl Doc<DefaultCost> {
265 pub fn validate(&self, page_width: usize) -> Result<PrintResult<DefaultCost>> {
299 self.validate_with_cost(DefaultCostFactory::new(page_width, None))
300 }
301}
302
303impl<C: Cost> Doc<C> {
304 pub fn validate_with_cost<CF: CostFactory<CostType = C> + 'static>(
328 &self,
329 cost: CF,
330 ) -> Result<PrintResult<C>> {
331 Printer::new(cost).validate(self.clone(), 0)
332 }
333}
334
#[cfg(test)]
mod tests {
    use crate::*;

    // In these tests, `format!("{doc:N}")` renders with page width N (via the `Display`
    // impl's use of `f.width()`), and `doc.to_string()` uses the default width of 80.

    /// A Lisp `defn` rendered at widths 120, 30, 20 and 10, checking which alternative wins.
    #[test]
    fn s_exp() {
        let fn_name = text("(defn") & space() & text("my-fn");
        let args = [text("arg1"), text("arg2"), text("arg3")];
        // Arguments either on one line or stacked vertically at the aligned column.
        let arg_list = lparen() & align(us_concat(args.clone()) | v_concat(args)) & rparen();
        let body_forms = [text("(println 'hello)"), text("(+ 1 2 3)")];
        // Body either follows on the same line or starts on a new line, stacked.
        let body = (space() & us_concat(body_forms.clone())) | (hard_nl() & v_concat(body_forms));

        let doc = fn_name.clone() & nest(2, hard_nl() & v_append(arg_list.clone(), body.clone()))
            | us_append(fn_name, nest(2, arg_list & body));

        assert_eq!(
            r#"(defn my-fn (arg1 arg2 arg3) (println 'hello) (+ 1 2 3)"#,
            format!("{doc:120}")
        );

        assert_eq!(
            r#"(defn my-fn (arg1 arg2 arg3)
 (println 'hello)
 (+ 1 2 3)"#,
            format!("{doc:30}")
        );

        assert_eq!(
            r#"(defn my-fn
 (arg1 arg2 arg3)

 (println 'hello)
 (+ 1 2 3)"#,
            format!("{doc:20}")
        );

        assert_eq!(
            r#"(defn my-fn
 (arg1
 arg2
 arg3)

 (println 'hello)
 (+ 1 2 3)"#,
            format!("{doc:10}")
        );
    }

    /// Documents containing `full(...)` comments.
    /// In every expected output, each such comment is followed by a line break.
    #[test]
    fn full_comments() {
        let doc = lparen()
            & text("println")
            & group(nl())
            & full(text("; this is a comment"))
            & nl()
            & text("\"my text\"")
            & rparen();

        assert_eq!(
            r#"(println ; this is a comment
"my text")"#,
            doc.to_string(),
        );

        let args = [
            text("a"),
            full(text("; the first one")),
            text("b"),
            full(text("; the second one")),
        ];
        let doc = align(v_concat(args.to_vec()));
        assert_eq!(
            r#"a
; the first one
b
; the second one"#,
            doc.to_string(),
        );

        let doc = group(nl()) & doc & group(brk());
        assert_eq!(
            r#" a
 ; the first one
 b
 ; the second one"#,
            doc.to_string(),
        );

        // The single-line alternative contains `full` comments, so the stacked `doc` is used.
        let doc = lparen() & ((space() & align(us_concat(args.to_vec()))) | doc) & rparen();
        assert_eq!(
            r#"( a
 ; the first one
 b
 ; the second one
)"#,
            doc.to_string(),
        );
    }

    /// Minimal s-expression tree used to build test documents.
    enum Node {
        Str(String),
        List(Vec<Node>),
    }

    /// Builds a document for `node` with three alternatives per non-empty list:
    /// - head and arguments all stacked vertically at the aligned column;
    /// - head, a space, then the arguments stacked at their own aligned column;
    /// - everything flattened onto one line.
    fn pretty(node: &Node) -> Doc {
        match node {
            Node::List(children) => {
                if let Some((first, rest)) = children.split_first() {
                    let fp = pretty(first);
                    let args: Vec<_> = rest.iter().map(pretty).collect();
                    (lparen() & align(v_append(fp.clone(), v_concat(args.to_vec()))) & rparen())
                        | (lparen()
                            & align(fp.clone())
                            & space()
                            & align(v_concat(args.to_vec()))
                            & rparen())
                        | flatten(
                            lparen()
                                & align(us_append(fp.clone(), us_concat(args.to_vec())))
                                & rparen(),
                        )
                } else {
                    text("()")
                }
            }
            Node::Str(s) => text(s),
        }
    }

    /// Same as `pretty`, except the third alternative is not wrapped in `flatten`,
    /// so child layouts inside the "one line" form may still contain line breaks.
    fn pretty2(node: &Node) -> Doc {
        match node {
            Node::List(children) => {
                if let Some((first, rest)) = children.split_first() {
                    let fp = pretty2(first);
                    let args: Vec<_> = rest.iter().map(pretty2).collect();
                    (lparen() & align(v_append(fp.clone(), v_concat(args.to_vec()))) & rparen())
                        | (lparen()
                            & align(fp.clone())
                            & space()
                            & align(v_concat(args.to_vec()))
                            & rparen())
                        | (lparen()
                            & align(us_append(fp.clone(), us_concat(args.to_vec())))
                            & rparen())
                } else {
                    text("()")
                }
            }
            Node::Str(s) => text(s),
        }
    }

    /// `(+ (foo 1 2) (bar 2 3) (baz 3 4))` as a `Node` tree.
    fn test_doc() -> Node {
        use Node::*;

        List(vec![
            Str("+".to_string()),
            List(vec![
                Str("foo".to_string()),
                Str("1".to_string()),
                Str("2".to_string()),
            ]),
            List(vec![
                Str("bar".to_string()),
                Str("2".to_string()),
                Str("3".to_string()),
            ]),
            List(vec![
                Str("baz".to_string()),
                Str("3".to_string()),
                Str("4".to_string()),
            ]),
        ])
    }

    #[test]
    fn check_pretty() {
        let doc = pretty(&test_doc());

        assert_eq!(
            format!("{doc:31}"),
            r#"(+ (foo 1 2)
 (bar 2 3)
 (baz 3 4))"#
        );
    }

    /// Without `flatten`, a multi-line child can sit inside the "single-line" alternative.
    #[test]
    fn check_pretty2() {
        let doc = pretty2(&test_doc());

        assert_eq!(
            format!("{doc:31}"),
            r#"(+ (foo 1
 2) (bar 2 3) (baz 3 4))"#
        );
    }

    /// The same list rendered at progressively narrower widths.
    #[test]
    fn smush_it() {
        let doc = pretty(&Node::List(vec![
            Node::Str("+".to_string()),
            Node::Str("123".to_string()),
            Node::Str("456".to_string()),
            Node::Str("789".to_string()),
        ]));

        assert_eq!(format!("{doc:15}"), "(+ 123 456 789)");
        assert_eq!(format!("{doc:14}"), "(+ 123\n 456\n 789)");
        assert_eq!(format!("{doc:5}"), "(+\n 123\n 456\n 789)");
        assert_eq!(format!("{doc:1}"), "(+\n 123\n 456\n 789)");
    }

    /// Covers `reset` vs `nest` indentation and `flatten` with soft vs hard newlines.
    #[test]
    fn more_checks() {
        let doc = nest(4, reset(text("abc") & hard_nl() & text("def")));
        assert_eq!(doc.to_string(), "abc\ndef");

        let doc = nest(4, text("abc") & hard_nl() & text("def"));
        assert_eq!(doc.to_string(), "abc\n def");

        let doc = flatten(text("abc") & nl() & text("def")) | text("something");
        assert_eq!(doc.to_string(), "abc def");

        // Expected output shows a flattened hard newline is not chosen; the fallback is used.
        let doc = flatten(text("abc") & hard_nl() & text("def")) | text("something");
        assert_eq!(doc.to_string(), "something");
    }
}
563}