haloumi-picus 0.5.5

Haloumi backend to the Picus Constraint Language.
Documentation
use std::{
    cell::RefCell,
    collections::HashSet,
    fmt,
    ops::{Deref, DerefMut},
    rc::Rc,
};

use haloumi_ir::Prime;

use crate::pcl::{
    display::{ListItem, TextRepresentable, TextRepresentation},
    expr::{self, Expr, traits::ConstraintEmitter},
    ident::Ident,
    stmt::{
        self, Stmt,
        traits::{ConstraintLike as _, FreeVars as _, StmtConstantFolding as _},
    },
    vars::{VarAllocator, VarKind, VarStr, Vars},
};

pub type ModuleRef<K> = Rc<RefCell<Module<K>>>;

impl<Key: VarKind + Default + Clone> VarAllocator for ModuleRef<Key> {
    type Kind = Key;

    fn allocate<K: Into<Self::Kind> + Into<VarStr> + Clone>(&self, kind: K) -> VarStr {
        let mut r = self.borrow_mut();
        r.deref_mut().add_var(kind)
    }
}

struct ModuleSummary {
    input_count: usize,
    output_count: usize,
    temp_count: usize,
    constraint_count: usize,
}

type TR<'a> = TextRepresentation<'a>;

#[derive(Clone, Debug)]
pub struct ModuleHeader(Ident);

impl From<String> for ModuleHeader {
    fn from(value: String) -> Self {
        Self(value.into())
    }
}

impl Deref for ModuleHeader {
    type Target = String;

    fn deref(&self) -> &Self::Target {
        self.0.value()
    }
}

impl DerefMut for ModuleHeader {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.0.value_mut()
    }
}

impl TextRepresentable for ModuleHeader {
    fn to_repr(&self) -> TextRepresentation<'_> {
        owned_list!("begin-module", &self.0).break_line()
    }

    fn width_hint(&self) -> usize {
        15 + self.0.width_hint()
    }
}

#[derive(Debug)]
pub struct Module<K: VarKind> {
    pub(crate) name: ModuleHeader,
    pub(crate) stmts: Vec<Stmt>,
    pub(crate) vars: Vars<K>,
}

impl<K: VarKind + Clone> Clone for Module<K> {
    fn clone(&self) -> Self {
        Self {
            name: self.name.clone(),
            stmts: self.stmts.clone(),
            vars: self.vars.clone(),
        }
    }
}

impl<K: VarKind + Clone> From<ModuleRef<K>> for Module<K> {
    fn from(value: ModuleRef<K>) -> Self {
        value.borrow().clone()
    }
}

impl<K: VarKind + Default> From<String> for Module<K> {
    fn from(name: String) -> Self {
        Self {
            name: name.into(),
            stmts: Default::default(),
            vars: Default::default(),
        }
    }
}

impl<K: VarKind> ConstraintEmitter for Module<K> {
    fn emit(&mut self, lhs: Expr, rhs: Expr) {
        self.stmts.push(stmt::constrain(expr::eq(&lhs, &rhs)))
    }
}

pub trait ModuleLike<K> {
    fn fold_stmts(&mut self, prime: Prime);

    fn add_constraint(&mut self, constraint: Expr) {
        log::debug!("Adding constraint {constraint:?}");
        self.add_stmt(stmt::constrain(constraint))
    }

    fn constraints_len(&self) -> usize;

    fn add_stmt(&mut self, stmt: Stmt);
}

pub trait ModuleWithVars<K> {
    fn add_var<I: Into<K> + Into<VarStr> + Clone>(&mut self, k: I) -> VarStr;

    #[allow(dead_code)]
    fn add_vars<I: Into<K> + Into<VarStr> + Clone>(&mut self, it: impl Iterator<Item = I>) {
        it.for_each(|k| {
            self.add_var(k);
        });
    }
}

impl<K: VarKind> ModuleLike<K> for Module<K> {
    fn fold_stmts(&mut self, prime: Prime) {
        self.stmts = self
            .stmts()
            .iter()
            .map(|s| s.fold(prime).unwrap_or(s.clone()))
            .collect();
    }

    fn constraints_len(&self) -> usize {
        self.stmts.iter().filter(|s| s.is_constraint()).count()
    }

    fn add_stmt(&mut self, stmt: Stmt) {
        self.stmts.push(stmt)
    }
}

impl<K: VarKind> ModuleLike<K> for ModuleRef<K> {
    fn fold_stmts(&mut self, prime: Prime) {
        self.borrow_mut().fold_stmts(prime)
    }

    fn constraints_len(&self) -> usize {
        self.borrow().constraints_len()
    }

    fn add_stmt(&mut self, stmt: Stmt) {
        self.borrow_mut().add_stmt(stmt)
    }
}

impl<K: VarKind> Module<K> {
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    pub fn vars(&self) -> &Vars<K> {
        &self.vars
    }

    pub fn stmts(&self) -> &[Stmt] {
        &self.stmts
    }

    pub fn stmts_mut(&mut self) -> &mut [Stmt] {
        &mut self.stmts
    }

    pub fn remove_stmt_if<F>(&mut self, mut f: F)
    where
        F: FnMut(&mut Stmt) -> bool,
    {
        self.stmts.retain_mut(|s| !f(s));
    }

    pub fn add_stmts(&mut self, stmts: &[Stmt]) {
        self.stmts.extend_from_slice(stmts)
    }

    pub fn inputs_as_exprs(&self) -> anyhow::Result<Vec<Expr>> {
        self.vars
            .inputs()
            .map(|input| Ok(expr::known_var(&VarStr::try_from(input.to_owned())?)))
            .collect()
    }

    fn summarize(&self) -> ModuleSummary {
        let input_count = self.vars.inputs().count();
        let output_count = self.vars.outputs().count();
        let free_vars = self
            .stmts
            .iter()
            .flat_map(|s| s.free_vars())
            .map(|fv| fv.as_ref())
            .collect::<HashSet<_>>();
        let temps = self.vars.temporaries().collect::<HashSet<_>>();
        let used_temps = temps.intersection(&free_vars);
        let temp_count = used_temps.count();
        let constraint_count = self.stmts.iter().filter(|s| s.is_constraint()).count();

        ModuleSummary {
            input_count,
            output_count,
            temp_count,
            constraint_count,
        }
    }
}

impl<K: VarKind + Default + Clone + fmt::Debug> Module<K> {
    pub fn new<S: Into<K> + Into<VarStr> + Clone>(
        name: String,
        inputs: impl Iterator<Item = S>,
        outputs: impl Iterator<Item = S>,
    ) -> Self {
        let mut m = Self::from(name);
        for k in inputs.chain(outputs) {
            m.add_var(k);
        }
        m
    }
    pub fn shared<S: Into<K> + Into<VarStr> + Clone>(
        name: String,
        inputs: impl Iterator<Item = S>,
        outputs: impl Iterator<Item = S>,
    ) -> ModuleRef<K> {
        Rc::new(Self::new(name, inputs, outputs).into())
    }
}

impl<K: VarKind + Default + Clone + fmt::Debug> ModuleWithVars<K> for Module<K> {
    fn add_var<I: Into<K> + Into<VarStr> + Clone>(&mut self, k: I) -> VarStr {
        self.vars.insert(k)
    }

    #[allow(dead_code)]
    fn add_vars<I: Into<K> + Into<VarStr> + Clone>(&mut self, it: impl Iterator<Item = I>) {
        self.vars.extend(it)
    }
}

impl<K: VarKind> TextRepresentable for Module<K> {
    fn to_repr(&self) -> TextRepresentation<'_> {
        let summary = self.summarize();
        let sorted_inputs = self.vars.inputs().collect::<Vec<_>>();
        let sorted_outputs = self.vars.outputs().collect::<Vec<_>>();
        owned_list!(&self.name)
            + [
                format!("Number of inputs:      {}", summary.input_count),
                format!("Number of outputs:     {}", summary.output_count),
                format!("Number of temporaries: {}", summary.temp_count),
                format!("Number of constraints: {}", summary.constraint_count),
            ]
            .into_iter()
            .map(TR::owned_comment)
            .sum()
            + TR::owned_list(
                &sorted_inputs
                    .into_iter()
                    .map(|i: &str| owned_list!("input", i).break_line().into())
                    .collect::<Vec<ListItem>>(),
            )
            + TR::owned_list(
                &sorted_outputs
                    .into_iter()
                    .map(|o: &str| owned_list!("output", o).break_line().into())
                    .collect::<Vec<ListItem>>(),
            )
            + self.stmts.to_repr()
            + owned_list!(owned_list!("end-module"))
            + TR::comment(self.name())
    }

    fn width_hint(&self) -> usize {
        self.name.width_hint()
    }
}