pub mod value_handle;
use std::{collections::HashMap, sync::Arc};
use thiserror::Error;
use hugr_core::{
HugrView, IncomingPort, Node, NodeIndex, OutgoingPort, PortIndex, Wire,
hugr::hugrmut::HugrMut,
ops::{
Const, DataflowOpTrait, ExtensionOp, LoadConstant, OpType, Value, constant::OpaqueValue,
},
types::EdgeKind,
};
use value_handle::ValueHandle;
use crate::passes::composable::{ComposablePass, PassScope, WithScope};
use crate::passes::dataflow::{
ConstLoader, ConstLocation, DFContext, Machine, PartialValue, TailLoopTermination,
partial_from_const,
};
use crate::passes::dead_code::{DeadCodeElimError, DeadCodeElimPass, PreserveNode};
/// Configuration for the Constant Folding pass.
///
/// Start from [`Default`] and customise with [`ConstantFoldPass::allow_increase_termination`],
/// [`ConstantFoldPass::with_inputs`] and [`WithScope::with_scope`].
#[derive(Debug, Clone, Default)]
pub struct ConstantFoldPass {
    // When true, the dead-code-elimination step treats every node as removable
    // regardless of its children; when false, only tail loops the analysis proved
    // never continue are relaxed (see the `run` implementation).
    allow_increase_termination: bool,
    // Region of the Hugr the pass analyses and rewrites.
    scope: PassScope,
    // Constant values supplied for input ports of analysis entry points,
    // keyed by node and then by port.
    inputs: HashMap<Node, HashMap<IncomingPort, Value>>,
}
/// Errors that can be raised by [`ConstantFoldPass`].
#[derive(Clone, Debug, Error, PartialEq)]
#[non_exhaustive]
pub enum ConstFoldError {
    /// The dataflow analysis refused to prepopulate the inputs of `node`: either a node
    /// given to [`ConstantFoldPass::with_inputs`], or one whose interface the pass scope
    /// preserves.
    #[error("{node} has OpType {op} which cannot be an entry-point")]
    InvalidEntryPoint {
        /// The rejected node.
        node: Node,
        /// The operation type reported by the analysis for `node`.
        op: OpType,
    },
    /// A node given to [`ConstantFoldPass::with_inputs`] does not exist in the Hugr.
    #[error("Entry-point {node} is not part of the Hugr")]
    MissingEntryPoint {
        /// The node that was not found.
        node: Node,
    },
}
impl ConstantFoldPass {
    /// Allows the dead-code-elimination step to remove any unused node regardless of its
    /// children.
    ///
    /// This can remove loops whose results are unused even if they might never have
    /// terminated, i.e. it may make the program terminate in more cases. By default only
    /// tail loops that the analysis proved never continue are relaxed this way.
    #[must_use]
    pub fn allow_increase_termination(mut self) -> Self {
        self.allow_increase_termination = true;
        self
    }

    /// Supplies constant values for some input ports of `node`, which then becomes an
    /// entry point of the analysis.
    ///
    /// Calling this repeatedly for the same node merges the port/value pairs. A later
    /// value for a port replaces an earlier one.
    ///
    /// When the pass runs it fails with [`ConstFoldError::MissingEntryPoint`] if `node` is
    /// not in the Hugr, and with [`ConstFoldError::InvalidEntryPoint`] if the analysis
    /// cannot accept inputs for it.
    #[must_use]
    pub fn with_inputs(
        mut self,
        node: Node,
        inputs: impl IntoIterator<Item = (impl Into<IncomingPort>, Value)>,
    ) -> Self {
        self.inputs
            .entry(node)
            .or_default()
            .extend(inputs.into_iter().map(|(p, v)| (p.into(), v)));
        self
    }
}
impl<H: HugrMut<Node = Node> + 'static> ComposablePass<H> for ConstantFoldPass {
type Error = ConstFoldError;
type Result = ();
fn run(&self, hugr: &mut H) -> Result<(), ConstFoldError> {
let Some(root) = self.scope.root(hugr) else {
return Ok(()); };
let fresh_node = Node::from(portgraph::NodeIndex::new(
hugr.nodes().max().map_or(0, |n| n.index() + 1),
));
let mut m = Machine::new(&hugr);
for (&n, in_vals) in &self.inputs {
if !hugr.contains_node(n) {
return Err(ConstFoldError::MissingEntryPoint { node: n });
}
m.prepopulate_inputs(
n,
in_vals.iter().map(|(p, v)| {
let const_with_dummy_loc = partial_from_const(
&ConstFoldContext,
ConstLocation::Field(p.index(), &fresh_node.into()),
v,
);
(*p, const_with_dummy_loc)
}),
)
.map_err(|op| ConstFoldError::InvalidEntryPoint { node: n, op })?;
}
for node in self.scope.preserve_interface(hugr) {
if node == hugr.module_root() || self.inputs.contains_key(&node) {
continue;
}
if hugr.children(node).next().is_none() {
continue;
}
const NO_INPUTS: [(IncomingPort, PartialValue<ValueHandle>); 0] = [];
m.prepopulate_inputs(node, NO_INPUTS)
.map_err(|op| ConstFoldError::InvalidEntryPoint { node, op })?;
}
let results = m.run_subtree(ConstFoldContext, root);
let mb_root_inp = hugr.get_io(hugr.entrypoint()).map(|[i, _]| i);
let wires_to_break = hugr
.descendants(root)
.flat_map(|n| hugr.node_inputs(n).map(move |ip| (n, ip)))
.filter(|(n, ip)| {
*n != root && matches!(hugr.get_optype(*n).port_kind(*ip), Some(EdgeKind::Value(_)))
})
.filter_map(|(n, ip)| {
let (src, outp) = hugr.single_linked_output(n, ip).unwrap();
(!hugr.get_optype(src).is_load_constant() && Some(src) != mb_root_inp).then_some((
n,
ip,
results
.try_read_wire_concrete::<Value>(Wire::new(src, outp))
.ok()?,
))
})
.collect::<Vec<_>>();
let terminating_tail_loops = hugr
.descendants(root)
.filter(|n| {
results.tail_loop_terminates(*n) == Some(TailLoopTermination::NeverContinues)
})
.collect::<Vec<_>>();
for (n, inport, v) in wires_to_break {
let parent = hugr.get_parent(n).unwrap();
let datatype = v.get_type();
let cst = hugr.add_node_with_parent(parent, Const::new(v));
let lcst = hugr.add_node_with_parent(parent, LoadConstant { datatype });
hugr.connect(cst, OutgoingPort::from(0), lcst, IncomingPort::from(0));
hugr.disconnect(n, inport);
hugr.connect(lcst, OutgoingPort::from(0), n, inport);
}
let dce = DeadCodeElimPass::<H>::default_with_scope(self.scope.clone());
dce.with_entry_points(self.inputs.keys().copied())
.set_preserve_callback(if self.allow_increase_termination {
Arc::new(|_, _| PreserveNode::CanRemoveIgnoringChildren)
} else {
Arc::new(move |h, n| {
if terminating_tail_loops.contains(&n) {
PreserveNode::DeferToChildren
} else {
PreserveNode::default_for(h, n)
}
})
})
.run(hugr)
.map_err(|e| match e {
DeadCodeElimError::NodeNotFound(_) => {
panic!("ConstFoldError::MissingEntrypoint not raised above")
}
})?;
Ok(())
}
}
impl WithScope for ConstantFoldPass {
    /// Returns this configuration with its scope replaced by `scope`. All other
    /// settings are kept unchanged.
    fn with_scope(self, scope: impl Into<PassScope>) -> Self {
        Self {
            scope: scope.into(),
            ..self
        }
    }
}
/// Stateless context for the constant-folding dataflow analysis.
///
/// It wraps opaque constants as [`ValueHandle`]s and evaluates extension ops through
/// their own constant-folding implementations.
struct ConstFoldContext;
impl ConstLoader<ValueHandle<Node>> for ConstFoldContext {
type Node = Node;
fn value_from_opaque(
&self,
loc: ConstLocation<Node>,
val: &OpaqueValue,
) -> Option<ValueHandle<Node>> {
Some(ValueHandle::new_opaque(loc, val.clone()))
}
}
impl DFContext<ValueHandle<Node>> for ConstFoldContext {
fn interpret_leaf_op(
&mut self,
node: Node,
op: &ExtensionOp,
ins: &[PartialValue<ValueHandle<Node>>],
outs: &mut [PartialValue<ValueHandle<Node>>],
) {
let sig = op.signature();
let known_ins = sig
.input_types()
.iter()
.enumerate()
.zip(ins.iter())
.filter_map(|((i, ty), pv)| {
pv.clone()
.try_into_concrete(ty)
.ok()
.map(|v| (IncomingPort::from(i), v))
})
.collect::<Vec<_>>();
for (p, v) in op.constant_fold(&known_ins).unwrap_or_default() {
outs[p.index()] =
partial_from_const(self, ConstLocation::Field(p.index(), &node.into()), &v);
}
}
}
#[cfg(test)]
mod test;