pretty_expressive/
print.rs

1use std::{collections::HashMap, fmt, rc::Rc};
2
3use crate::{
4    DefaultCost, Doc, DocId, DocKind,
5    cost::{Cost, CostFactory, DefaultCostFactory},
6    measure::{Measure, MeasureSet},
7    non_empty::NonEmptyVecBuilder,
8};
9
10/// Result type for pretty-printing operations.
11pub type Result<T, E = Error> = std::result::Result<T, E>;
12
13/// Error to signal that no printable layout could be found for a document.
14///
15/// This means that every choice path the printer explored resulted in a
16/// [`fail`](crate::fail). Some constraints in the document will need to be relaxed for it to
17/// be printable.
18#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
19#[error("document was not printable")]
20pub struct Error;
21
22pub(crate) struct Printer<C: CostFactory> {
23    cost: C,
24    memo: HashMap<(DocId, usize, usize, bool, bool), MeasureSet<C>>,
25}
26
27impl<C: CostFactory + 'static> Printer<C> {
28    fn new(cost: C) -> Self {
29        Self {
30            cost,
31            memo: HashMap::new(),
32        }
33    }
34
35    fn validate(&mut self, d: Doc<C::CostType>, c: usize) -> Result<PrintResult<C::CostType>> {
36        let result = self
37            .resolve(d.clone(), c, 0, false, false)
38            .merge(self.resolve(d, c, 0, false, true));
39        let is_tainted = matches!(&result, MeasureSet::Tainted(_, _));
40        let measure = self.extract_at_most_one(result).ok_or(Error)?;
41
42        Ok(PrintResult {
43            is_tainted,
44            measure,
45        })
46    }
47
48    fn resolve(
49        &mut self,
50        d: Doc<C::CostType>,
51        c: usize,
52        i: usize,
53        begin_full: bool,
54        end_full: bool,
55    ) -> MeasureSet<C> {
56        if d.0.kind.fails(begin_full, end_full) {
57            MeasureSet::Failed
58        } else if c <= self.cost.limit() && i <= self.cost.limit() && d.0.memo_weight == 0 {
59            let id = d.0.id;
60            let key = (id, c, i, begin_full, end_full);
61            if let Some(ms) = self.memo.get(&key) {
62                ms.clone()
63            } else {
64                let result = self.resolve_inner(d, c, i, begin_full, end_full, false);
65                self.memo.insert(key, result.clone());
66                result
67            }
68        } else {
69            self.resolve_inner(d, c, i, begin_full, end_full, false)
70        }
71    }
72
73    fn resolve_inner(
74        &mut self,
75        d: Doc<C::CostType>,
76        c: usize,
77        i: usize,
78        begin_full: bool,
79        end_full: bool,
80        allow_exceeds: bool,
81    ) -> MeasureSet<C> {
82        use DocKind::*;
83
84        let exceeds = if let Text(_, len) = &d.0.kind {
85            c + len > self.cost.limit() || i > self.cost.limit()
86        } else {
87            c > self.cost.limit() || i > self.cost.limit()
88        };
89
90        if exceeds && !allow_exceeds {
91            let d = d.clone();
92            return MeasureSet::Tainted(
93                d.0.newline_count,
94                Rc::new(move |this| {
95                    let resolved = this.resolve_inner(d.clone(), c, i, begin_full, end_full, true);
96                    // TODO maybe figure out if we need to store the failure
97                    this.extract_at_most_one(resolved)
98                }),
99            );
100        }
101
102        match &d.0.kind {
103            Text(s, len) => {
104                let s = s.clone();
105                MeasureSet::new(len + c, self.cost.text(c, *len), move |w| {
106                    write!(w, "{}", s)
107                })
108            }
109            Newline(_) => MeasureSet::new(i, self.cost.newline(i), move |w| {
110                writeln!(w)?;
111                write!(w, "{}", " ".repeat(i))
112            }),
113            Concat(d1, d2) => {
114                let mut analyze_left =
115                    |mid_full| match self.resolve(d1.clone(), c, i, begin_full, mid_full) {
116                        MeasureSet::Failed => MeasureSet::Failed,
117                        MeasureSet::Tainted(_, thunk) => {
118                            let d2 = d2.clone();
119                            MeasureSet::tainted(&d, move |this| {
120                                let m1 = thunk(this)?;
121                                let resolved =
122                                    this.resolve(d2.clone(), m1.last, i, mid_full, end_full);
123                                this.extract_at_most_one(resolved).map(|m2| m1.concat(m2))
124                            })
125                        }
126                        MeasureSet::Valid(m1, ms1) => {
127                            let first = self.analyze_right(m1, &d, d2, i, mid_full, end_full);
128                            ms1.into_iter().rfold(first, |ms, m| {
129                                self.analyze_right(m.clone(), &d, d2, i, mid_full, end_full)
130                                    .merge(ms)
131                            })
132                        }
133                    };
134
135                analyze_left(false).merge(analyze_left(true))
136            }
137            Alt(d1, d2) => {
138                let r1 = self.resolve(d1.clone(), c, i, begin_full, end_full);
139                let r2 = self.resolve(d2.clone(), c, i, begin_full, end_full);
140                if d1.0.newline_count < d2.0.newline_count {
141                    r2.merge(r1)
142                } else {
143                    r1.merge(r2)
144                }
145            }
146            Nest(n, d) => self.resolve(d.clone(), c, i + n, begin_full, end_full),
147            Align(d) => self.resolve(d.clone(), c, c, begin_full, end_full),
148            Reset(d) => self.resolve(d.clone(), c, 0, begin_full, end_full),
149            Cost(co, d) => {
150                let co = co.clone();
151                let add_cost = move |mut m: Measure<C::CostType>| {
152                    m.cost = co.clone() + m.cost;
153                    m
154                };
155
156                match self.resolve(d.clone(), c, i, begin_full, end_full) {
157                    MeasureSet::Failed => MeasureSet::Failed,
158                    MeasureSet::Valid(m, ms) => {
159                        MeasureSet::Valid(add_cost(m), ms.into_iter().map(add_cost).collect())
160                    }
161                    MeasureSet::Tainted(_, thunk) => {
162                        MeasureSet::tainted(d, move |this| thunk(this).map(&add_cost))
163                    }
164                }
165            }
166            Full(d) => self
167                .resolve(d.clone(), c, i, begin_full, false)
168                .merge(self.resolve(d.clone(), c, i, begin_full, true)),
169            Fail => MeasureSet::Failed,
170        }
171    }
172
173    fn analyze_right(
174        &mut self,
175        m: Measure<C::CostType>,
176        d: &Doc<C::CostType>,
177        d2: &Doc<C::CostType>,
178        i: usize,
179        begin_full: bool,
180        end_full: bool,
181    ) -> MeasureSet<C> {
182        match self.resolve(d2.clone(), m.last, i, begin_full, end_full) {
183            MeasureSet::Failed => MeasureSet::Failed,
184            MeasureSet::Tainted(_, thunk) => MeasureSet::tainted(d, move |this| {
185                let m2 = thunk(this)?;
186                Some(m.clone().concat(m2))
187            }),
188            MeasureSet::Valid(m2, ms2) => {
189                let mut result = NonEmptyVecBuilder::new();
190                let mut current_best = m.clone().concat(m2);
191
192                for m2 in ms2.into_iter() {
193                    let current = m.clone().concat(m2);
194                    if current.cost > current_best.cost {
195                        result.push(current_best);
196                    }
197                    current_best = current;
198                }
199
200                result.push(current_best);
201                let (first, rest) = result.finish();
202                MeasureSet::Valid(first, rest)
203            }
204        }
205    }
206
207    fn extract_at_most_one(&mut self, ms: MeasureSet<C>) -> Option<Measure<C::CostType>> {
208        match ms {
209            MeasureSet::Failed => None,
210            MeasureSet::Tainted(_, thunk) => thunk(self),
211            MeasureSet::Valid(m, _) => Some(m),
212        }
213    }
214}
215
216/// The resolved optimal layout for a successful print attempt.
217///
218/// A `PrintResult` is returned from [`Doc::validate`] and
219/// [`Doc::validate_with_cost`] when the printer is able to successfully produce
220/// a layout for a document.
221///
222/// `PrintResult` implements [`Display`](std::fmt::Display), so the chosen layout
223/// can be rendered through any means that allows.
224#[derive(Debug)]
225pub struct PrintResult<C: Cost> {
226    is_tainted: bool,
227    measure: Measure<C>,
228}
229
230impl<C: Cost> PrintResult<C> {
231    /// Indicates if the layout chosen was tainted.
232    ///
233    /// A tainted layout is one that exceeds the [`computation width
234    /// limit`](CostFactory::limit) imposed by the cost factory. Such a layout
235    /// won't be chosen unless there are no valid untainted layouts available.
236    /// If a tainted layout is chosen, it is not guaranteed to be optimal
237    /// according to the cost factory.
238    pub fn is_tainted(&self) -> bool {
239        self.is_tainted
240    }
241
242    /// The cost of the chosen layout.
243    pub fn cost(&self) -> C {
244        self.measure.cost.clone()
245    }
246}
247
248impl<C: Cost> fmt::Display for PrintResult<C> {
249    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
250        (self.measure.layout)(f)
251    }
252}
253
254impl fmt::Display for Doc<DefaultCost> {
255    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
256        let page_width = f.width().unwrap_or(80);
257        let result = self
258            .validate(page_width)
259            .expect("couldn't validate print result");
260        write!(f, "{result}")
261    }
262}
263
264impl Doc<DefaultCost> {
265    /// Run the printer to produce a layout with the default cost factory.
266    ///
267    /// This can be used instead of printing the document directly to be able
268    /// to check if the printer failed to find a layout for the document.
269    ///
270    /// If this returns `Ok`, then the resulting [`PrintResult`] can be printed
271    /// to produce the chosen layout.
272    ///
273    /// # Example
274    ///
275    /// ```
276    /// # use pretty_expressive::*;
277    /// let doc = text("hello") & space() & text("world");
278    ///
279    /// let result = doc.validate(80)?;
280    /// assert_eq!(result.cost(), DefaultCost(0, 0));
281    /// assert!(!result.is_tainted());
282    ///
283    /// assert_eq!(result.to_string(), "hello world");
284    ///
285    /// // the only valid layout exceeds the computation width limit
286    /// // (6 in this case) so the layout is still produced, but tainted
287    /// let result = doc.validate(5)?;
288    /// assert_eq!(result.cost(), DefaultCost(36, 0));
289    /// assert!(result.is_tainted());
290    ///
291    /// assert_eq!(result.to_string(), "hello world");
292    ///
293    /// let doc = fail();
294    /// assert!(matches!(doc.validate(80), Err(pretty_expressive::Error)));
295    ///
296    /// # Ok::<(), Error>(())
297    /// ```
298    pub fn validate(&self, page_width: usize) -> Result<PrintResult<DefaultCost>> {
299        self.validate_with_cost(DefaultCostFactory::new(page_width, None))
300    }
301}
302
303impl<C: Cost> Doc<C> {
304    /// Run the printer to produce a layout with a custom cost factory.
305    ///
306    /// This function behaves just like [`validate`](Self::validate), but you
307    /// provide it the entire cost factory rather than the page width.
308    ///
309    /// This is the only way to print documents that use a [`Cost`] that is not
310    /// [`DefaultCost`].
311    ///
312    /// # Example
313    ///
314    /// ```
315    /// # use pretty_expressive::*;
316    /// let cf = DefaultCostFactory::new(20, Some(40));
317    ///
318    /// let doc = text("hello") & space() & text("world");
319    ///
320    /// let result = doc.validate_with_cost(cf)?;
321    /// assert_eq!(result.cost(), DefaultCost(0, 0));
322    /// assert!(!result.is_tainted());
323    ///
324    /// assert_eq!(result.to_string(), "hello world");
325    /// # Ok::<(), Error>(())
326    /// ```
327    pub fn validate_with_cost<CF: CostFactory<CostType = C> + 'static>(
328        &self,
329        cost: CF,
330    ) -> Result<PrintResult<C>> {
331        Printer::new(cost).validate(self.clone(), 0)
332    }
333}
334
#[cfg(test)]
mod tests {
    use crate::*;

    /// Joins expected output lines with `\n`.
    ///
    /// Used for expectations that contain whitespace-only lines: writing each
    /// line as its own literal keeps the trailing spaces visible. An editor that
    /// trims trailing whitespace inside a raw string would otherwise silently
    /// break the test.
    fn lines(parts: &[&str]) -> String {
        parts.join("\n")
    }

    #[test]
    fn s_exp() {
        let fn_name = text("(defn") & space() & text("my-fn");
        let args = [text("arg1"), text("arg2"), text("arg3")];
        let arg_list = lparen() & align(us_concat(args.clone()) | v_concat(args)) & rparen();
        let body_forms = [text("(println 'hello)"), text("(+ 1 2 3)")];
        let body = (space() & us_concat(body_forms.clone())) | (hard_nl() & v_concat(body_forms));

        let doc = fn_name.clone() & nest(2, hard_nl() & v_append(arg_list.clone(), body.clone()))
            | us_append(fn_name, nest(2, arg_list & body));

        assert_eq!(
            r#"(defn my-fn (arg1 arg2 arg3) (println 'hello) (+ 1 2 3)"#,
            format!("{doc:120}")
        );

        assert_eq!(
            r#"(defn my-fn (arg1 arg2 arg3)
  (println 'hello)
  (+ 1 2 3)"#,
            format!("{doc:30}")
        );

        // Note the whitespace-only line (two spaces of indentation).
        assert_eq!(
            lines(&[
                "(defn my-fn",
                "  (arg1 arg2 arg3)",
                "  ",
                "  (println 'hello)",
                "  (+ 1 2 3)",
            ]),
            format!("{doc:20}")
        );

        assert_eq!(
            lines(&[
                "(defn my-fn",
                "  (arg1",
                "   arg2",
                "   arg3)",
                "  ",
                "  (println 'hello)",
                "  (+ 1 2 3)",
            ]),
            format!("{doc:10}")
        );
    }

    #[test]
    fn full_comments() {
        let doc = lparen()
            & text("println")
            & group(nl())
            & full(text("; this is a comment"))
            & nl()
            & text("\"my text\"")
            & rparen();

        assert_eq!(
            r#"(println ; this is a comment
"my text")"#,
            doc.to_string(),
        );

        let args = [
            text("a"),
            full(text("; the first one")),
            text("b"),
            full(text("; the second one")),
        ];
        let doc = align(v_concat(args.to_vec()));
        assert_eq!(
            r#"a
; the first one
b
; the second one"#,
            doc.to_string(),
        );

        let doc = group(nl()) & doc & group(brk());
        assert_eq!(
            r#" a
 ; the first one
 b
 ; the second one"#,
            doc.to_string(),
        );

        let doc = lparen() & ((space() & align(us_concat(args.to_vec()))) | doc) & rparen();
        assert_eq!(
            r#"( a
  ; the first one
  b
  ; the second one
)"#,
            doc.to_string(),
        );
    }

    // tests adapted from ones in the racket version
    enum Node {
        Str(String),
        List(Vec<Node>),
    }

    /// Builds a leaf node from a string slice.
    fn atom(s: &str) -> Node {
        Node::Str(s.to_string())
    }

    /// Renders `node` as an s-expression offering three alternatives per
    /// non-empty list:
    /// - the head `v_append`ed to the vertically concatenated arguments;
    /// - the head followed by a space and the aligned `v_concat` of the arguments;
    /// - the `us_append` of head and arguments.
    ///
    /// The third alternative is wrapped in `flatten` when `flatten_inline` is set.
    /// The flag is passed down to every child.
    fn pretty_with(node: &Node, flatten_inline: bool) -> Doc {
        match node {
            Node::List(children) => {
                if let Some((first, rest)) = children.split_first() {
                    let fp = pretty_with(first, flatten_inline);
                    let args: Vec<_> = rest
                        .iter()
                        .map(|child| pretty_with(child, flatten_inline))
                        .collect();

                    let stacked =
                        lparen() & align(v_append(fp.clone(), v_concat(args.clone()))) & rparen();
                    let hanging = lparen()
                        & align(fp.clone())
                        & space()
                        & align(v_concat(args.clone()))
                        & rparen();
                    let inline = lparen() & align(us_append(fp, us_concat(args))) & rparen();
                    let inline = if flatten_inline {
                        flatten(inline)
                    } else {
                        inline
                    };

                    stacked | hanging | inline
                } else {
                    text("()")
                }
            }
            Node::Str(s) => text(s),
        }
    }

    fn pretty(node: &Node) -> Doc {
        pretty_with(node, true)
    }

    fn pretty2(node: &Node) -> Doc {
        pretty_with(node, false)
    }

    fn test_doc() -> Node {
        Node::List(vec![
            atom("+"),
            Node::List(vec![atom("foo"), atom("1"), atom("2")]),
            Node::List(vec![atom("bar"), atom("2"), atom("3")]),
            Node::List(vec![atom("baz"), atom("3"), atom("4")]),
        ])
    }

    #[test]
    fn check_pretty() {
        let doc = pretty(&test_doc());

        assert_eq!(
            format!("{doc:31}"),
            r#"(+ (foo 1 2)
   (bar 2 3)
   (baz 3 4))"#
        );
    }

    #[test]
    fn check_pretty2() {
        let doc = pretty2(&test_doc());

        assert_eq!(
            format!("{doc:31}"),
            r#"(+ (foo 1
        2) (bar 2 3) (baz 3 4))"#
        );
    }

    #[test]
    fn smush_it() {
        let doc = pretty(&Node::List(vec![
            atom("+"),
            atom("123"),
            atom("456"),
            atom("789"),
        ]));

        assert_eq!(format!("{doc:15}"), "(+ 123 456 789)");
        assert_eq!(format!("{doc:14}"), "(+ 123\n   456\n   789)");
        assert_eq!(format!("{doc:5}"), "(+\n 123\n 456\n 789)");
        assert_eq!(format!("{doc:1}"), "(+\n 123\n 456\n 789)");
    }

    #[test]
    fn more_checks() {
        let doc = nest(4, reset(text("abc") & hard_nl() & text("def")));
        assert_eq!(doc.to_string(), "abc\ndef");

        let doc = nest(4, text("abc") & hard_nl() & text("def"));
        assert_eq!(doc.to_string(), "abc\n    def");

        let doc = flatten(text("abc") & nl() & text("def")) | text("something");
        assert_eq!(doc.to_string(), "abc def");

        let doc = flatten(text("abc") & hard_nl() & text("def")) | text("something");
        assert_eq!(doc.to_string(), "something");
    }
}