use std::{
cell::RefCell,
collections::HashSet,
fmt,
ops::{Deref, DerefMut},
rc::Rc,
};
use haloumi_ir::Prime;
use crate::pcl::{
display::{ListItem, TextRepresentable, TextRepresentation},
expr::{self, Expr, traits::ConstraintEmitter},
ident::Ident,
stmt::{
self, Stmt,
traits::{ConstraintLike as _, FreeVars as _, StmtConstantFolding as _},
},
vars::{VarAllocator, VarKind, VarStr, Vars},
};
pub type ModuleRef<K> = Rc<RefCell<Module<K>>>;
impl<Key: VarKind + Default + Clone> VarAllocator for ModuleRef<Key> {
type Kind = Key;
fn allocate<K: Into<Self::Kind> + Into<VarStr> + Clone>(&self, kind: K) -> VarStr {
let mut r = self.borrow_mut();
r.deref_mut().add_var(kind)
}
}
struct ModuleSummary {
input_count: usize,
output_count: usize,
temp_count: usize,
constraint_count: usize,
}
type TR<'a> = TextRepresentation<'a>;
#[derive(Clone, Debug)]
pub struct ModuleHeader(Ident);
impl From<String> for ModuleHeader {
fn from(value: String) -> Self {
Self(value.into())
}
}
impl Deref for ModuleHeader {
type Target = String;
fn deref(&self) -> &Self::Target {
self.0.value()
}
}
impl DerefMut for ModuleHeader {
fn deref_mut(&mut self) -> &mut Self::Target {
self.0.value_mut()
}
}
impl TextRepresentable for ModuleHeader {
fn to_repr(&self) -> TextRepresentation<'_> {
owned_list!("begin-module", &self.0).break_line()
}
fn width_hint(&self) -> usize {
15 + self.0.width_hint()
}
}
#[derive(Debug)]
pub struct Module<K: VarKind> {
pub(crate) name: ModuleHeader,
pub(crate) stmts: Vec<Stmt>,
pub(crate) vars: Vars<K>,
}
impl<K: VarKind + Clone> Clone for Module<K> {
fn clone(&self) -> Self {
Self {
name: self.name.clone(),
stmts: self.stmts.clone(),
vars: self.vars.clone(),
}
}
}
impl<K: VarKind + Clone> From<ModuleRef<K>> for Module<K> {
fn from(value: ModuleRef<K>) -> Self {
value.borrow().clone()
}
}
impl<K: VarKind + Default> From<String> for Module<K> {
fn from(name: String) -> Self {
Self {
name: name.into(),
stmts: Default::default(),
vars: Default::default(),
}
}
}
impl<K: VarKind> ConstraintEmitter for Module<K> {
fn emit(&mut self, lhs: Expr, rhs: Expr) {
self.stmts.push(stmt::constrain(expr::eq(&lhs, &rhs)))
}
}
pub trait ModuleLike<K> {
fn fold_stmts(&mut self, prime: Prime);
fn add_constraint(&mut self, constraint: Expr) {
log::debug!("Adding constraint {constraint:?}");
self.add_stmt(stmt::constrain(constraint))
}
fn constraints_len(&self) -> usize;
fn add_stmt(&mut self, stmt: Stmt);
}
pub trait ModuleWithVars<K> {
fn add_var<I: Into<K> + Into<VarStr> + Clone>(&mut self, k: I) -> VarStr;
#[allow(dead_code)]
fn add_vars<I: Into<K> + Into<VarStr> + Clone>(&mut self, it: impl Iterator<Item = I>) {
it.for_each(|k| {
self.add_var(k);
});
}
}
impl<K: VarKind> ModuleLike<K> for Module<K> {
fn fold_stmts(&mut self, prime: Prime) {
self.stmts = self
.stmts()
.iter()
.map(|s| s.fold(prime).unwrap_or(s.clone()))
.collect();
}
fn constraints_len(&self) -> usize {
self.stmts.iter().filter(|s| s.is_constraint()).count()
}
fn add_stmt(&mut self, stmt: Stmt) {
self.stmts.push(stmt)
}
}
impl<K: VarKind> ModuleLike<K> for ModuleRef<K> {
fn fold_stmts(&mut self, prime: Prime) {
self.borrow_mut().fold_stmts(prime)
}
fn constraints_len(&self) -> usize {
self.borrow().constraints_len()
}
fn add_stmt(&mut self, stmt: Stmt) {
self.borrow_mut().add_stmt(stmt)
}
}
impl<K: VarKind> Module<K> {
pub fn name(&self) -> &str {
self.name.as_str()
}
pub fn vars(&self) -> &Vars<K> {
&self.vars
}
pub fn stmts(&self) -> &[Stmt] {
&self.stmts
}
pub fn stmts_mut(&mut self) -> &mut [Stmt] {
&mut self.stmts
}
pub fn remove_stmt_if<F>(&mut self, mut f: F)
where
F: FnMut(&mut Stmt) -> bool,
{
self.stmts.retain_mut(|s| !f(s));
}
pub fn add_stmts(&mut self, stmts: &[Stmt]) {
self.stmts.extend_from_slice(stmts)
}
pub fn inputs_as_exprs(&self) -> anyhow::Result<Vec<Expr>> {
self.vars
.inputs()
.map(|input| Ok(expr::known_var(&VarStr::try_from(input.to_owned())?)))
.collect()
}
fn summarize(&self) -> ModuleSummary {
let input_count = self.vars.inputs().count();
let output_count = self.vars.outputs().count();
let free_vars = self
.stmts
.iter()
.flat_map(|s| s.free_vars())
.map(|fv| fv.as_ref())
.collect::<HashSet<_>>();
let temps = self.vars.temporaries().collect::<HashSet<_>>();
let used_temps = temps.intersection(&free_vars);
let temp_count = used_temps.count();
let constraint_count = self.stmts.iter().filter(|s| s.is_constraint()).count();
ModuleSummary {
input_count,
output_count,
temp_count,
constraint_count,
}
}
}
impl<K: VarKind + Default + Clone + fmt::Debug> Module<K> {
pub fn new<S: Into<K> + Into<VarStr> + Clone>(
name: String,
inputs: impl Iterator<Item = S>,
outputs: impl Iterator<Item = S>,
) -> Self {
let mut m = Self::from(name);
for k in inputs.chain(outputs) {
m.add_var(k);
}
m
}
pub fn shared<S: Into<K> + Into<VarStr> + Clone>(
name: String,
inputs: impl Iterator<Item = S>,
outputs: impl Iterator<Item = S>,
) -> ModuleRef<K> {
Rc::new(Self::new(name, inputs, outputs).into())
}
}
impl<K: VarKind + Default + Clone + fmt::Debug> ModuleWithVars<K> for Module<K> {
fn add_var<I: Into<K> + Into<VarStr> + Clone>(&mut self, k: I) -> VarStr {
self.vars.insert(k)
}
#[allow(dead_code)]
fn add_vars<I: Into<K> + Into<VarStr> + Clone>(&mut self, it: impl Iterator<Item = I>) {
self.vars.extend(it)
}
}
impl<K: VarKind> TextRepresentable for Module<K> {
fn to_repr(&self) -> TextRepresentation<'_> {
let summary = self.summarize();
let sorted_inputs = self.vars.inputs().collect::<Vec<_>>();
let sorted_outputs = self.vars.outputs().collect::<Vec<_>>();
owned_list!(&self.name)
+ [
format!("Number of inputs: {}", summary.input_count),
format!("Number of outputs: {}", summary.output_count),
format!("Number of temporaries: {}", summary.temp_count),
format!("Number of constraints: {}", summary.constraint_count),
]
.into_iter()
.map(TR::owned_comment)
.sum()
+ TR::owned_list(
&sorted_inputs
.into_iter()
.map(|i: &str| owned_list!("input", i).break_line().into())
.collect::<Vec<ListItem>>(),
)
+ TR::owned_list(
&sorted_outputs
.into_iter()
.map(|o: &str| owned_list!("output", o).break_line().into())
.collect::<Vec<ListItem>>(),
)
+ self.stmts.to_repr()
+ owned_list!(owned_list!("end-module"))
+ TR::comment(self.name())
}
fn width_hint(&self) -> usize {
self.name.width_hint()
}
}