use crate::passes::composable::WithScope;
use crate::passes::const_fold::{ConstFoldError, ConstantFoldPass};
use crate::passes::dead_funcs::RemoveDeadFuncsError;
use crate::passes::inline_dfgs::InlineDFGsPass;
use crate::passes::normalize_cfgs::{NormalizeCFGError, NormalizeCFGPass};
use crate::passes::redundant_order_edges::RedundantOrderEdgesPass;
use crate::passes::untuple::UntupleError;
use crate::passes::{ComposablePass, PassScope, RemoveDeadFuncsPass, UntuplePass};
use hugr::Node;
use hugr::hugr::HugrError;
use hugr::hugr::hugrmut::HugrMut;
use hugr::hugr::patch::inline_dfg::InlineDFGError;
use crate::passes::BorrowSquashPass;
/// A composite normalisation pass (intended, per its name, for Guppy-generated HUGRs).
///
/// It runs a fixed sequence of sub-passes. Each sub-pass can be switched on or off
/// with the builder-style setters on this type. Every sub-pass is enabled by
/// default, and all of them are restricted to the configured [`PassScope`].
#[derive(Debug, Clone)]
pub struct NormalizeGuppy {
    /// Whether to run [`NormalizeCFGPass`].
    simplify_cfgs: bool,
    /// Whether to run [`UntuplePass`].
    untuple: bool,
    /// Whether to run [`ConstantFoldPass`].
    constant_fold: bool,
    /// Whether to run [`RemoveDeadFuncsPass`].
    dead_funcs: bool,
    /// Whether to run [`InlineDFGsPass`].
    inline_dfgs: bool,
    /// Whether to run [`BorrowSquashPass`].
    squash_borrows: bool,
    /// Whether to run [`RedundantOrderEdgesPass`].
    remove_redundant_order_edges: bool,
    /// Region of the HUGR that every sub-pass is restricted to.
    scope: PassScope,
}
impl NormalizeGuppy {
    /// Stores `enabled` into the flag chosen by `flag` and hands `self` back so
    /// that the public setters can be chained.
    fn set_flag(&mut self, flag: impl FnOnce(&mut Self) -> &mut bool, enabled: bool) -> &mut Self {
        *flag(self) = enabled;
        self
    }

    /// Enables or disables CFG simplification ([`NormalizeCFGPass`]).
    pub fn simplify_cfgs(&mut self, simplify_cfgs: bool) -> &mut Self {
        self.set_flag(|this| &mut this.simplify_cfgs, simplify_cfgs)
    }

    /// Enables or disables [`UntuplePass`].
    pub fn remove_tuple_untuple(&mut self, untuple: bool) -> &mut Self {
        self.set_flag(|this| &mut this.untuple, untuple)
    }

    /// Enables or disables constant folding ([`ConstantFoldPass`]).
    pub fn constant_folding(&mut self, constant_fold: bool) -> &mut Self {
        self.set_flag(|this| &mut this.constant_fold, constant_fold)
    }

    /// Enables or disables dead-function removal ([`RemoveDeadFuncsPass`]).
    pub fn remove_dead_funcs(&mut self, dead_funcs: bool) -> &mut Self {
        self.set_flag(|this| &mut this.dead_funcs, dead_funcs)
    }

    /// Enables or disables DFG inlining ([`InlineDFGsPass`]).
    pub fn inline_dfgs(&mut self, inline: bool) -> &mut Self {
        self.set_flag(|this| &mut this.inline_dfgs, inline)
    }

    /// Enables or disables borrow squashing ([`BorrowSquashPass`]).
    pub fn squash_borrows(&mut self, squash: bool) -> &mut Self {
        self.set_flag(|this| &mut this.squash_borrows, squash)
    }

    /// Enables or disables redundant order-edge removal ([`RedundantOrderEdgesPass`]).
    pub fn remove_redundant_order_edges(&mut self, remove: bool) -> &mut Self {
        self.set_flag(|this| &mut this.remove_redundant_order_edges, remove)
    }
}
impl Default for NormalizeGuppy {
    /// Builds a configuration with every sub-pass switched on and the default
    /// [`PassScope`].
    fn default() -> Self {
        let enabled = true;
        Self {
            simplify_cfgs: enabled,
            untuple: enabled,
            constant_fold: enabled,
            dead_funcs: enabled,
            inline_dfgs: enabled,
            squash_borrows: enabled,
            remove_redundant_order_edges: enabled,
            scope: PassScope::default(),
        }
    }
}
impl WithScope for NormalizeGuppy {
    /// Returns this configuration with its scope replaced by `scope`. All
    /// enable/disable flags are carried over unchanged.
    fn with_scope(self, scope: impl Into<crate::passes::PassScope>) -> Self {
        Self {
            scope: scope.into(),
            ..self
        }
    }
}
impl<H: HugrMut<Node = Node> + 'static> ComposablePass<H> for NormalizeGuppy {
    type Error = NormalizeGuppyErrors;
    type Result = ();

    /// Runs every enabled sub-pass over `hugr`, each restricted to the configured
    /// scope, in this fixed order:
    ///
    /// 1. CFG simplification ([`NormalizeCFGPass`])
    /// 2. [`UntuplePass`]
    /// 3. constant folding ([`ConstantFoldPass`])
    /// 4. dead-function removal ([`RemoveDeadFuncsPass`])
    /// 5. DFG inlining ([`InlineDFGsPass`])
    /// 6. borrow squashing ([`BorrowSquashPass`])
    /// 7. redundant order-edge removal ([`RedundantOrderEdgesPass`])
    ///
    /// # Errors
    ///
    /// Returns the first sub-pass error, wrapped in the matching
    /// [`NormalizeGuppyErrors`] variant. Sub-passes after the failing one are not
    /// run. DFG inlining and borrow squashing have uninhabited error types here,
    /// so they cannot fail.
    fn run(&self, hugr: &mut H) -> Result<Self::Result, Self::Error> {
        if self.simplify_cfgs {
            NormalizeCFGPass::default()
                .with_scope(self.scope.clone())
                .run(hugr)?;
        }
        if self.untuple {
            UntuplePass::default_with_scope(self.scope.clone()).run(hugr)?;
        }
        if self.constant_fold {
            ConstantFoldPass::default()
                .with_scope(self.scope.clone())
                .run(hugr)?;
        }
        if self.dead_funcs {
            RemoveDeadFuncsPass::default()
                .with_scope(self.scope.clone())
                .run(hugr)?;
        }
        if self.inline_dfgs {
            // The error type is uninhabited: the empty `match` proves at compile
            // time that this pass cannot fail.
            InlineDFGsPass::default()
                .with_scope(self.scope.clone())
                .run(hugr)
                .unwrap_or_else(|e| match e {});
        }
        if self.squash_borrows {
            // Same uninhabited-error pattern as DFG inlining above.
            BorrowSquashPass::default()
                .with_scope(self.scope.clone())
                .run(hugr)
                .unwrap_or_else(|e| match e {});
        }
        if self.remove_redundant_order_edges {
            // `HugrError` has no `From` impl into the error enum (it is marked
            // `#[from(ignore)]`), so the wrapping is explicit.
            RedundantOrderEdgesPass::default()
                .with_scope(self.scope.clone())
                .run(hugr)
                .map_err(NormalizeGuppyErrors::RedundantOrderEdges)?;
        }
        Ok(())
    }
}
#[derive(derive_more::Error, Debug, derive_more::Display, derive_more::From)]
pub enum NormalizeGuppyErrors {
SimplifyCFG(NormalizeCFGError),
Untuple(UntupleError),
ConstantFold(ConstFoldError),
DeadFuncs(RemoveDeadFuncsError),
Inline(InlineDFGError),
#[from(ignore)]
RedundantOrderEdges(HugrError),
}
#[cfg(test)]
mod test {
    use hugr::builder::{Dataflow, DataflowHugr, FunctionBuilder};
    use hugr::extension::prelude::qb_t;
    use hugr::types::Signature;

    use crate::TketOp;

    use super::*;

    /// With every sub-pass disabled, running [`NormalizeGuppy`] must leave the
    /// HUGR exactly as it was.
    #[test]
    fn test_guppy_pass_noop() {
        let mut b = FunctionBuilder::new("main", Signature::new_endo(vec![qb_t()])).unwrap();
        let [q] = b.input_wires_arr();
        let [q] = b.add_dataflow_op(TketOp::H, [q]).unwrap().outputs_arr();
        let hugr = b.finish_hugr_with_outputs([q]).unwrap();
        let mut hugr2 = hugr.clone();
        NormalizeGuppy::default()
            .simplify_cfgs(false)
            .remove_tuple_untuple(false)
            .constant_folding(false)
            .remove_dead_funcs(false)
            .inline_dfgs(false)
            // Every toggle must be off for this to be a true no-op, borrow
            // squashing included.
            .squash_borrows(false)
            .remove_redundant_order_edges(false)
            .run(&mut hugr2)
            .unwrap();
        assert_eq!(hugr2, hugr);
    }
}