pretty-expressive 1.0.0

A pretty expressive printer
Documentation
use std::{fmt, rc::Rc};

use crate::{
    Doc,
    cost::{Cost, CostFactory},
    non_empty::NonEmptyVecBuilder,
    print::Printer,
};

/// A deferred measure computation: called with the printer, it returns the
/// measure it produces, or `None` if it produces none (when two promises are
/// combined by `MeasureSet::merge`, a `None` makes it fall back to the other).
//
// The `type_alias_bounds` lint warns that bounds on type aliases are not
// enforced, yet removing `C: CostFactory` here is a compile error: the bound is
// what lets the shorthand `C::CostType` below resolve. (Spelling the path as
// `<C as CostFactory>::CostType` should avoid the lint.) `#[expect]` silences the
// warning and will itself warn if the lint ever stops firing.
#[expect(type_alias_bounds)]
pub(crate) type MeasurePromise<C: CostFactory> =
    Rc<dyn Fn(&mut Printer<C>) -> Option<Measure<C::CostType>>>;
/// A type-erased renderer that writes a layout's text into a formatter.
pub type Layout = dyn Fn(&mut fmt::Formatter<'_>) -> fmt::Result;

/// The candidate measures computed for a document.
pub(crate) enum MeasureSet<C: CostFactory> {
    /// No measure is available. `merge` treats it as an identity: merging with
    /// `Failed` returns the other set unchanged.
    Failed,
    /// One or more concrete measures, stored as a first element plus the rest,
    /// so this variant can never be empty.
    Valid(Measure<C::CostType>, Vec<Measure<C::CostType>>),
    /// A measure that is only computed on demand by the promise. The `usize` is
    /// the newline count taken from the document (see `MeasureSet::tainted`).
    // NOTE(review): "tainted" presumably marks layouts that violated the width
    // limit, as in the pretty-expressive algorithm — confirm.
    Tainted(usize, MeasurePromise<C>),
}

/// A single candidate layout together with the metrics used to rank it.
pub struct Measure<C: Cost> {
    /// Position at which the layout ends; `concat` takes it from the second
    /// measure. Presumably the column of the layout's final line — confirm.
    pub last: usize,
    /// Cost of the layout; `concat` sums costs and `dominates` compares them.
    pub cost: C,
    /// Renders the layout's text; the `Display` impl for `Measure` just calls it.
    pub layout: Rc<Layout>,
}

impl<C: CostFactory> Clone for MeasureSet<C> {
    /// Copies every measure of a `Valid` set; for `Tainted`, only the reference
    /// count of the shared promise is bumped.
    fn clone(&self) -> Self {
        match self {
            Self::Failed => Self::Failed,
            Self::Valid(head, tail) => Self::Valid(head.clone(), tail.clone()),
            Self::Tainted(newlines, promise) => Self::Tainted(*newlines, Rc::clone(promise)),
        }
    }
}

impl<C: Cost> Clone for Measure<C> {
    fn clone(&self) -> Self {
        Self {
            last: self.last,
            cost: self.cost.clone(),
            layout: Rc::clone(&self.layout),
        }
    }
}

impl<C: CostFactory> fmt::Debug for MeasureSet<C> {
    /// Prints `Failed` as a bare name, `Valid` as a list of its measures, and
    /// `Tainted` with its newline count only (the promise is not printable).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Failed => f.write_str("Failed"),
            Self::Valid(first, rest) => {
                let mut list = f.debug_list();
                list.entry(first);
                list.entries(rest.iter());
                list.finish()
            }
            Self::Tainted(newlines, _promise) => {
                f.debug_tuple("Tainted").field(newlines).finish()
            }
        }
    }
}

impl<C: Cost> fmt::Debug for Measure<C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Measure")
            .field("last", &self.last)
            .field("cost", &self.cost)
            .finish()
    }
}

impl<C: CostFactory + 'static> MeasureSet<C> {
    /// Creates a `Valid` set containing exactly one measure.
    ///
    /// `last` and `cost` become the measure's metrics; `layout` is stored behind
    /// an `Rc` and invoked later to render the measure's text.
    pub fn new(
        last: usize,
        cost: C::CostType,
        layout: impl Fn(&mut fmt::Formatter) -> fmt::Result + 'static,
    ) -> Self {
        let sole = Measure {
            last,
            cost,
            layout: Rc::new(layout),
        };
        Self::Valid(sole, Vec::new())
    }

    /// Creates a `Tainted` set whose measure is produced on demand by `f`.
    ///
    /// The stored count is `d`'s recorded newline count; `merge` uses it to pick
    /// which of two tainted promises is tried first.
    pub fn tainted(
        d: &Doc<C::CostType>,
        f: impl Fn(&mut Printer<C>) -> Option<Measure<C::CostType>> + 'static,
    ) -> Self {
        let newline_count = d.0.newline_count;
        Self::Tainted(newline_count, Rc::new(f))
    }

    /// Combines two measure sets.
    ///
    /// * `Failed` is the identity: merging with it yields the other set.
    /// * A `Valid` set always wins over a `Tainted` one.
    /// * Two `Tainted` sets become one carrying the larger newline count, whose
    ///   promise tries the larger-count side first (the left side on a tie) and
    ///   falls back to the other side when the first yields `None`.
    /// * Two `Valid` lists are walked in step: when one current head dominates
    ///   the other, the dominated head is dropped; otherwise the head with the
    ///   larger `last` is emitted first. Leftovers are appended unchanged.
    pub fn merge(self, other: Self) -> Self {
        match (self, other) {
            (MeasureSet::Failed, kept) | (kept, MeasureSet::Failed) => kept,
            (MeasureSet::Tainted(nl1, p1), MeasureSet::Tainted(nl2, p2)) => {
                let (newlines, primary, fallback) = if nl1 < nl2 {
                    (nl2, p2, p1)
                } else {
                    (nl1, p1, p2)
                };
                // TODO handle pruning
                MeasureSet::Tainted(
                    newlines,
                    Rc::new(move |r| primary(r).or_else(|| fallback(r))),
                )
            }
            (valid, MeasureSet::Tainted(..)) | (MeasureSet::Tainted(..), valid) => valid,
            (MeasureSet::Valid(h1, t1), MeasureSet::Valid(h2, t2)) => {
                let mut left = std::iter::once(h1).chain(t1).peekable();
                let mut right = std::iter::once(h2).chain(t2).peekable();
                let mut frontier = NonEmptyVecBuilder::new();

                // Compare heads only while both lists still have one.
                while let (Some(a), Some(b)) = (left.peek(), right.peek()) {
                    if a.dominates(b) {
                        right.next();
                    } else if b.dominates(a) {
                        left.next();
                    } else {
                        let winner = if a.last > b.last { &mut left } else { &mut right };
                        if let Some(m) = winner.next() {
                            frontier.push(m);
                        }
                    }
                }

                // At most one side still holds measures; keep them in order.
                frontier.extend(left);
                frontier.extend(right);

                let (first, rest) = frontier.finish();
                MeasureSet::Valid(first, rest)
            }
        }
    }
}

impl<C: Cost<Output = C>> Measure<C> {
    /// Returns `true` when `self` is no worse than `other` on both axes: its
    /// `last` is no greater and its cost is no greater.
    fn dominates(&self, other: &Self) -> bool {
        let ends_no_later = self.last <= other.last;
        ends_no_later && self.cost <= other.cost
    }

    /// Places `m2` after `self`.
    ///
    /// The result ends where `m2` ends, costs the sum of both costs, and renders
    /// `self`'s layout followed by `m2`'s.
    pub(crate) fn concat(self, m2: Self) -> Self {
        let Measure {
            cost: head_cost,
            layout: head_layout,
            ..
        } = self;
        let Measure {
            last,
            cost: tail_cost,
            layout: tail_layout,
        } = m2;
        Measure {
            last,
            cost: head_cost + tail_cost,
            layout: Rc::new(move |r| {
                // A formatter error from the head stops rendering before the tail.
                head_layout(r)?;
                tail_layout(r)
            }),
        }
    }
}

impl<C: Cost> fmt::Display for Measure<C> {
    /// Renders the measure's text by running its stored layout closure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let render = &*self.layout;
        render(f)
    }
}