use std::collections::{HashMap, HashSet};
use ascent::Lattice;
use ascent::lattice::BoundedLattice;
use itertools::Itertools;
use hugr_core::extension::prelude::{MakeTuple, UnpackTuple};
use hugr_core::ops::{DataflowOpTrait, OpTag, OpTrait, OpType, TailLoop};
use hugr_core::{HugrView, IncomingPort, OutgoingPort, PortIndex as _, Wire};
use super::value_row::ValueRow;
use super::{
AbstractValue, AnalysisResults, DFContext, LoadedFunction, PartialValue, partial_from_const,
row_contains_bottom,
};
/// Shorthand for [`PartialValue`], used throughout the Datalog rules.
type PV<V, N> = PartialValue<V, N>;
/// Values to be seeded onto (some of) the input ports of a single node.
type NodeInputs<V, N> = Vec<(IncomingPort, PartialValue<V, N>)>;
/// Values to be seeded onto (some of) the output ports of a single node.
type NodeOutputs<V, N> = Vec<(OutgoingPort, PartialValue<V, N>)>;
/// Driver for the dataflow analysis over a Hugr.
///
/// Create one with [`Machine::new`], optionally seed known values with
/// `prepopulate_wire` / `prepopulate_inputs`, then run it with `run_subtree`.
pub struct Machine<H: HugrView, V: AbstractValue> {
    /// The Hugr being analysed.
    pub(super) hugr: H,
    /// Seed values for the *input* ports of particular nodes, keyed by node.
    in_wire_proto: HashMap<H::Node, NodeInputs<V, H::Node>>,
    /// Seed values for the *output* ports of particular nodes, keyed by node.
    out_wire_proto: HashMap<H::Node, NodeOutputs<V, H::Node>>,
}

impl<H: HugrView, V: AbstractValue> Machine<H, V> {
    /// Create a `Machine` over `hugr` with no values prepopulated.
    pub fn new(hugr: H) -> Self {
        let in_wire_proto = HashMap::new();
        let out_wire_proto = HashMap::new();
        Self {
            hugr,
            in_wire_proto,
            out_wire_proto,
        }
    }
}
impl<H: HugrView, V: AbstractValue> Machine<H, V> {
    /// Provide an initial value for a wire (an output port of a node).
    ///
    /// The value is joined with whatever the analysis computes for that wire, and with
    /// any other value prepopulated for the same wire.
    pub fn prepopulate_wire(&mut self, w: Wire<H::Node>, v: PartialValue<V, H::Node>) {
        self.out_wire_proto
            .entry(w.node())
            .or_default()
            .push((w.source(), v));
    }

    /// Provide initial values for the inputs of a container node `parent`.
    ///
    /// Every input not mentioned in `in_values` is given [`PartialValue::Top`].
    /// * For `DataflowBlock`, `Case` and `FuncDefn`, the values are recorded on the
    ///   output wires of `parent`'s Input child.
    /// * For `DFG`, `TailLoop`, `CFG` and `Conditional`, the values are recorded on the
    ///   input ports of `parent` itself.
    ///
    /// If `parent` is not in the Hugr, this does nothing and returns `Ok(())`.
    ///
    /// # Errors
    /// Returns (a clone of) `parent`'s [`OpType`] if it is any other kind of node.
    ///
    /// # Panics
    /// If a port in `in_values` is out of range for `parent`'s inputs.
    #[expect(
        clippy::result_large_err,
        reason = "Not called recursively and not a performance bottleneck"
    )]
    pub fn prepopulate_inputs(
        &mut self,
        parent: H::Node,
        in_values: impl IntoIterator<Item = (IncomingPort, PartialValue<V, H::Node>)>,
    ) -> Result<(), OpType> {
        if !self.hugr.contains_node(parent) {
            return Ok(());
        }
        match self.hugr.get_optype(parent) {
            OpType::DataflowBlock(_) | OpType::Case(_) | OpType::FuncDefn(_) => {
                // Seed the outputs of the Input child, one wire per input value.
                let [inp, _] = self.hugr.get_io(parent).unwrap();
                let mut vals =
                    vec![PartialValue::Top; self.hugr.signature(inp).unwrap().output_types().len()];
                for (ip, v) in in_values {
                    vals[ip.index()] = v;
                }
                for (i, v) in vals.into_iter().enumerate() {
                    self.prepopulate_wire(Wire::new(inp, i), v);
                }
            }
            OpType::DFG(_) | OpType::TailLoop(_) | OpType::CFG(_) | OpType::Conditional(_) => {
                // Seed the container's own input ports.
                let mut vals = vec![
                    PartialValue::Top;
                    self.hugr.signature(parent).unwrap().input_types().len()
                ];
                for (ip, v) in in_values {
                    vals[ip.index()] = v;
                }
                self.in_wire_proto
                    .entry(parent)
                    .or_default()
                    .extend(vals.into_iter().enumerate().map(|(i, v)| (i.into(), v)));
            }
            op => return Err(op.clone()),
        }
        Ok(())
    }

    /// Run the analysis from the Hugr's entrypoint, giving `in_values` to its inputs.
    ///
    /// If `in_values` is empty and no inputs were prepopulated for the entrypoint, its
    /// inputs default to [`PartialValue::Top`].
    ///
    /// # Panics
    /// If the entrypoint is a Module and `in_values` is non-empty, or if the entrypoint
    /// is a kind of node rejected by [`Self::prepopulate_inputs`].
    #[deprecated(
        note = "Use `run_subtree` and `prepopulate_wire` / `prepopulate_inputs` instead",
        since = "0.18.0"
    )]
    pub fn run(
        mut self,
        context: impl DFContext<V, Node = H::Node>,
        in_values: impl IntoIterator<Item = (IncomingPort, PartialValue<V, H::Node>)>,
    ) -> AnalysisResults<V, H> {
        let ep = self.hugr.entrypoint();
        if self.hugr.entrypoint_optype().is_module() {
            assert!(
                in_values.into_iter().next().is_none(),
                "No inputs possible for Module"
            );
        } else {
            // `prepopulate_inputs` records values either on the entrypoint's own input
            // ports (in `in_wire_proto`), or - for dataflow parents such as FuncDefn /
            // DataflowBlock / Case - on the output wires of its *Input child* (in
            // `out_wire_proto`). Look in both places to see if inputs were already given.
            let have_value_for_entry = self.in_wire_proto.contains_key(&ep)
                || (self.hugr.entrypoint_optype().tag() <= OpTag::DataflowParent
                    && self
                        .hugr
                        .get_io(ep)
                        .is_some_and(|[inp, _]| self.out_wire_proto.contains_key(&inp)));
            let mut p = in_values.into_iter().peekable();
            // Unless inputs were supplied (now or earlier), default them to Top so the
            // entrypoint's body is not left at Bottom (i.e. considered unreachable).
            if p.peek().is_some() || !have_value_for_entry {
                self.prepopulate_inputs(ep, p).unwrap();
            }
        }
        self.run_subtree(context, ep)
    }

    /// Run the analysis over the whole Hugr, using all prepopulated values.
    ///
    /// Results are reported only for `root` and its descendants, or for every node if
    /// `root` is the module root.
    pub fn run_subtree(
        self,
        context: impl DFContext<V, Node = H::Node>,
        root: H::Node,
    ) -> AnalysisResults<V, H> {
        run_datalog(
            context,
            self.hugr,
            root,
            self.in_wire_proto
                .into_iter()
                .flat_map(|(n, vals)| vals.into_iter().map(move |(ip, v)| (n, ip, v)))
                .collect(),
            self.out_wire_proto
                .into_iter()
                .flat_map(|(n, vals)| vals.into_iter().map(move |(op, v)| (n, op, v)))
                .collect(),
        )
    }
}
pub(super) type InWire<V, N> = (N, IncomingPort, PartialValue<V, N>);
type OutWire<V, N> = (N, OutgoingPort, PartialValue<V, N>);
/// Run the dataflow analysis, expressed as an `ascent` Datalog program, over all of `hugr`.
///
/// `in_wire_value_proto` / `out_wire_value_proto` give initial values for input/output
/// ports of particular nodes. Entries whose node is not in `hugr`, or whose port is not
/// in that node's signature, are ignored. All values are joined with those the analysis
/// derives.
///
/// The returned results cover only `result_root` and its descendants, unless
/// `result_root` is the module root, in which case they cover every node.
fn run_datalog<V: AbstractValue, H: HugrView>(
    mut ctx: impl DFContext<V, Node = H::Node>,
    hugr: H,
    result_root: H::Node,
    in_wire_value_proto: Vec<InWire<V, H::Node>>,
    out_wire_value_proto: Vec<OutWire<V, H::Node>>,
) -> AnalysisResults<V, H> {
    // Lints triggered by code generated by the `ascent_run!` macro.
    #![allow(
        clippy::clone_on_copy,
        clippy::unused_enumerate_index,
        clippy::collapsible_if
    )]
    let all_results = ascent::ascent_run! {
        pub(super) struct AscentProgram<V: AbstractValue, H: HugrView>;

        // Structural relations, read directly from the Hugr.
        relation node(H::Node); // <Node> is a node of the Hugr
        relation in_wire(H::Node, IncomingPort); // <Node> has value input <IncomingPort>
        relation out_wire(H::Node, OutgoingPort); // <Node> has value output <OutgoingPort>
        relation parent_of_node(H::Node, H::Node); // <Node> is the parent of <Node>
        relation input_child(H::Node, H::Node); // <Node> has Input child <Node>
        relation output_child(H::Node, H::Node); // <Node> has Output child <Node>

        // Lattice relations: each key holds the join of every value derived for it.
        lattice out_wire_value(H::Node, OutgoingPort, PV<V, H::Node>);
        lattice in_wire_value(H::Node, IncomingPort, PV<V, H::Node>);
        lattice node_in_value_row(H::Node, ValueRow<V, H::Node>);

        node(n) <-- for n in hugr.nodes();

        in_wire(n, p) <-- node(n), for (p,_) in hugr.in_value_types(*n);
        out_wire(n, p) <-- node(n), for (p,_) in hugr.out_value_types(*n);

        parent_of_node(parent, child) <--
            node(child), if let Some(parent) = hugr.get_parent(*child);

        input_child(parent, input) <-- node(parent), if let Some([input, _output]) = hugr.get_io(*parent);
        output_child(parent, output) <-- node(parent), if let Some([_input, output]) = hugr.get_io(*parent);

        // Every value port starts at Bottom.
        out_wire_value(n, p, PV::bottom()) <-- out_wire(n, p);
        in_wire_value(n, p, PV::bottom()) <-- in_wire(n, p);

        // An input port receives the value of the single output port linked to it.
        in_wire_value(n, ip, v) <-- in_wire(n, ip),
            if let Some((m, op)) = hugr.single_linked_output(*n, *ip),
            out_wire_value(m, op, v);

        // Prepopulated values, kept only if the node and port actually exist.
        in_wire_value(n, p, v) <-- for (n, p, v) in &in_wire_value_proto,
            node(n),
            if let Some(sig) = hugr.signature(*n),
            if sig.input_ports().contains(p);

        out_wire_value(n, p, v) <-- for (n, p, v) in &out_wire_value_proto,
            node(n),
            if let Some(sig) = hugr.signature(*n),
            if sig.output_ports().contains(p);

        // Gather all the input values of each node into one row.
        node_in_value_row(n, ValueRow::new(sig.input_count())) <-- node(n), if let Some(sig) = hugr.signature(*n);
        node_in_value_row(n, ValueRow::new(hugr.signature(*n).unwrap().input_count()).set(p.index(), v.clone())) <-- in_wire_value(n, p, v);

        // Leaf (non-container) dataflow ops: compute outputs from the input row.
        out_wire_value(n, p, v) <--
            node(n),
            let op_t = hugr.get_optype(*n),
            if !op_t.is_container(),
            if let Some(sig) = op_t.dataflow_signature(),
            node_in_value_row(n, vs),
            if let Some(outs) = propagate_leaf_op(&mut ctx, &hugr, *n, &vs[..], sig.output_count()),
            for (p, v) in (0..).map(OutgoingPort::from).zip(outs);

        // DFG --------------------
        relation dfg_node(H::Node); // <Node> is a `DFG`
        dfg_node(n) <-- node(n), if hugr.get_optype(*n).is_dfg();

        // DFG inputs flow to its Input child; its Output child's inputs flow to DFG outputs.
        out_wire_value(i, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg),
            input_child(dfg, i), in_wire_value(dfg, p, v);

        out_wire_value(dfg, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg),
            output_child(dfg, o), in_wire_value(o, p, v);

        // TailLoop --------------------
        // Inputs of the loop flow to the Input child.
        out_wire_value(i, OutgoingPort::from(p.index()), v) <-- node(tl),
            if hugr.get_optype(*tl).is_tail_loop(),
            input_child(tl, i),
            in_wire_value(tl, p, v);

        // The CONTINUE variant (plus trailing values) at the Output child flows back
        // round to the Input child.
        out_wire_value(in_n, OutgoingPort::from(out_p), v) <-- node(tl),
            if let Some(tailloop) = hugr.get_optype(*tl).as_tail_loop(),
            input_child(tl, in_n),
            output_child(tl, out_n),
            node_in_value_row(out_n, out_in_row),
            if let Some(fields) = out_in_row.unpack_first(TailLoop::CONTINUE_TAG, tailloop.just_inputs.len()),
            for (out_p, v) in fields.enumerate();

        // The BREAK variant (plus trailing values) at the Output child flows to the
        // outputs of the loop.
        out_wire_value(tl, OutgoingPort::from(out_p), v) <-- node(tl),
            if let Some(tailloop) = hugr.get_optype(*tl).as_tail_loop(),
            output_child(tl, out_n),
            node_in_value_row(out_n, out_in_row),
            if let Some(fields) = out_in_row.unpack_first(TailLoop::BREAK_TAG, tailloop.just_outputs.len()),
            for (out_p, v) in fields.enumerate();

        // Conditional --------------------
        // <Node> is a `Conditional` whose <usize>'th child is the `Case` <Node>.
        relation case_node(H::Node, usize, H::Node);
        case_node(cond, i, case) <-- node(cond),
            if hugr.get_optype(*cond).is_conditional(),
            for (i, case) in hugr.children(*cond).enumerate(),
            if hugr.get_optype(case).is_case();

        // The i'th variant's fields (plus other inputs) flow into the i'th case.
        out_wire_value(i_node, OutgoingPort::from(out_p), v) <--
            case_node(cond, case_index, case),
            input_child(case, i_node),
            node_in_value_row(cond, in_row),
            let conditional = hugr.get_optype(*cond).as_conditional().unwrap(),
            if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()),
            for (out_p, v) in fields.enumerate();

        // Outputs of a case flow to the Conditional's outputs only if that case is reachable.
        out_wire_value(cond, OutgoingPort::from(o_p.index()), v) <--
            case_node(cond, _i, case),
            case_reachable(cond, case),
            output_child(case, o),
            in_wire_value(o, o_p, v);

        // Within `Conditional` <Node>, `Case` <Node> may be selected by the predicate.
        relation case_reachable(H::Node, H::Node);
        case_reachable(cond, case) <-- case_node(cond, i, case),
            in_wire_value(cond, IncomingPort::from(0), v),
            if v.supports_tag(*i);

        // CFG --------------------
        relation cfg_node(H::Node); // <Node> is a `CFG`
        cfg_node(n) <-- node(n), if hugr.get_optype(*n).is_cfg();

        // Within `CFG` <Node>, basic block <Node> may be reached: the entry block
        // (first child) always, and any successor whose tag the predicate supports.
        relation bb_reachable(H::Node, H::Node);
        bb_reachable(cfg, entry) <-- cfg_node(cfg), if let Some(entry) = hugr.children(*cfg).next();
        bb_reachable(cfg, bb) <-- cfg_node(cfg),
            bb_reachable(cfg, pred),
            output_child(pred, pred_out),
            in_wire_value(pred_out, IncomingPort::from(0), predicate),
            for (tag, bb) in hugr.output_neighbours(*pred).enumerate(),
            if predicate.supports_tag(tag);

        // Inputs of the CFG flow to the Input child of the entry block.
        out_wire_value(i_node, OutgoingPort::from(p.index()), v) <--
            cfg_node(cfg),
            if let Some(entry) = hugr.children(*cfg).next(),
            input_child(entry, i_node),
            in_wire_value(cfg, p, v);

        // Within `CFG` <Node>, values sent to successor <Node> appear on the outputs of
        // <Node>: the CFG itself for the exit block (second child), else the block's Input.
        relation _cfg_succ_dest(H::Node, H::Node, H::Node);
        _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = hugr.children(*cfg).nth(1);
        _cfg_succ_dest(cfg, blk, inp) <-- cfg_node(cfg),
            for blk in hugr.children(*cfg),
            if hugr.get_optype(blk).is_dataflow_block(),
            input_child(blk, inp);

        // Outputs of each reachable block flow to each successor selected by the
        // corresponding variant of its predicate.
        out_wire_value(dest, OutgoingPort::from(out_p), v) <--
            bb_reachable(cfg, pred),
            if let Some(df_block) = hugr.get_optype(*pred).as_dataflow_block(),
            for (succ_n, succ) in hugr.output_neighbours(*pred).enumerate(),
            output_child(pred, out_n),
            _cfg_succ_dest(cfg, succ, dest),
            node_in_value_row(out_n, out_in_row),
            if let Some(fields) = out_in_row.unpack_first(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()),
            for (out_p, v) in fields.enumerate();

        // Call --------------------
        relation func_call(H::Node, H::Node); // <Node> is a `Call` to function <Node>
        func_call(call, func_defn) <--
            node(call),
            if hugr.get_optype(*call).is_call(),
            if let Some(func_defn) = hugr.static_source(*call);

        out_wire_value(inp, OutgoingPort::from(p.index()), v) <--
            func_call(call, func),
            input_child(func, inp),
            in_wire_value(call, p, v);

        out_wire_value(call, OutgoingPort::from(p.index()), v) <--
            func_call(call, func),
            output_child(func, outp),
            in_wire_value(outp, p, v);

        // CallIndirect --------------------
        // <Node> is a `CallIndirect` whose target is the join of all targets seen so far.
        lattice indirect_call(H::Node, LatticeWrapper<H::Node>);
        indirect_call(call, tgt) <--
            node(call),
            if let OpType::CallIndirect(_) = hugr.get_optype(*call),
            in_wire_value(call, IncomingPort::from(0), v),
            let tgt = load_func(v);

        // Port 0 is the function itself, so argument port p feeds the callee's output p-1.
        out_wire_value(inp, OutgoingPort::from(p.index()-1), v) <--
            indirect_call(call, lv),
            if let LatticeWrapper::Value(func) = lv,
            input_child(func, inp),
            in_wire_value(call, p, v)
            if p.index() > 0;

        out_wire_value(call, OutgoingPort::from(p.index()), v) <--
            indirect_call(call, lv),
            if let LatticeWrapper::Value(func) = lv,
            output_child(func, outp),
            in_wire_value(outp, p, v);

        // If the target cannot be determined, all outputs are Top.
        out_wire_value(call, p, PV::Top) <--
            node(call),
            if let OpType::CallIndirect(ci) = hugr.get_optype(*call),
            in_wire_value(call, IncomingPort::from(0), v),
            if matches!(v, PartialValue::Top | PartialValue::Value(_)),
            for p in ci.signature().output_ports();
    };

    // Restrict results to `result_root`'s subtree unless that is the whole module. The
    // node set is built lazily so the whole-module case does not walk the Hugr for nothing.
    let filter_nodes = (result_root != hugr.module_root())
        .then(|| hugr.descendants(result_root).collect::<HashSet<_>>());
    let in_scope = |n: &H::Node| filter_nodes.as_ref().is_none_or(|f| f.contains(n));

    let out_wire_values = all_results
        .out_wire_value
        .iter()
        .filter(|(n, _, _)| in_scope(n))
        .map(|(n, p, v)| (Wire::new(*n, *p), v.clone()))
        .collect();
    AnalysisResults {
        hugr,
        out_wire_values,
        in_wire_value: all_results
            .in_wire_value
            .into_iter()
            .filter(|(n, _, _)| in_scope(n))
            .collect(),
        case_reachable: all_results
            .case_reachable
            .into_iter()
            .filter(|(_, n)| in_scope(n))
            .collect(),
        bb_reachable: all_results
            .bb_reachable
            .into_iter()
            .filter(|(_, n)| in_scope(n))
            .collect(),
    }
}
/// A flat lattice over `T`: `Bottom` is below every `Value(_)`, every `Value(_)` is
/// below `Top`, and distinct `Value`s are incomparable (their join is `Top`, their meet
/// is `Bottom`).
///
/// Used to track the single function (if any) that a `CallIndirect` may target.
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
enum LatticeWrapper<T> {
    Bottom,
    Value(T),
    Top,
}

// Implemented by hand rather than derived so that the ordering agrees with the lattice
// used by `join_mut`/`meet_mut`: a derived impl would order `Value(a)` against
// `Value(b)` using `T`'s own ordering, making distinct values comparable.
impl<T: PartialEq> PartialOrd for LatticeWrapper<T> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        use std::cmp::Ordering;
        use LatticeWrapper::{Bottom, Top, Value};
        match (self, other) {
            (Bottom, Bottom) | (Top, Top) => Some(Ordering::Equal),
            (Value(a), Value(b)) => (a == b).then_some(Ordering::Equal),
            (Bottom, _) | (_, Top) => Some(Ordering::Less),
            (_, Bottom) | (Top, _) => Some(Ordering::Greater),
        }
    }
}
impl<N: PartialEq + PartialOrd> Lattice for LatticeWrapper<N> {
    /// Greatest lower bound in the flat lattice. Returns whether `self` changed.
    fn meet_mut(&mut self, other: Self) -> bool {
        // Already at or below `other`: nothing to do.
        let unchanged = *self == other
            || matches!(*self, LatticeWrapper::Bottom)
            || matches!(other, LatticeWrapper::Top);
        if unchanged {
            return false;
        }
        // `other` is strictly below `self`: adopt it. Otherwise the two are distinct
        // values, whose meet is Bottom.
        let adopt_other =
            matches!(*self, LatticeWrapper::Top) || matches!(other, LatticeWrapper::Bottom);
        *self = if adopt_other {
            other
        } else {
            LatticeWrapper::Bottom
        };
        true
    }

    /// Least upper bound in the flat lattice. Returns whether `self` changed.
    fn join_mut(&mut self, other: Self) -> bool {
        // Already at or above `other`: nothing to do.
        let unchanged = *self == other
            || matches!(*self, LatticeWrapper::Top)
            || matches!(other, LatticeWrapper::Bottom);
        if unchanged {
            return false;
        }
        // `other` is strictly above `self`: adopt it. Otherwise the two are distinct
        // values, whose join is Top.
        let adopt_other =
            matches!(*self, LatticeWrapper::Bottom) || matches!(other, LatticeWrapper::Top);
        *self = if adopt_other {
            other
        } else {
            LatticeWrapper::Top
        };
        true
    }
}
fn load_func<V, N: Copy>(v: &PV<V, N>) -> LatticeWrapper<N> {
match v {
PartialValue::Bottom | PartialValue::PartialSum(_) => LatticeWrapper::Bottom,
PartialValue::LoadedFunction(LoadedFunction { func_node, .. }) => {
LatticeWrapper::Value(*func_node)
}
PartialValue::Value(_) | PartialValue::Top => LatticeWrapper::Top,
}
}
/// Compute the output values of the leaf (non-container) op `n` from its input values `ins`.
///
/// Returns `None` for ops whose outputs are derived by other rules (Input, Output,
/// ExitBlock, Call, CallIndirect), or when an `UnpackTuple` input cannot be the tuple.
/// `ExtensionOp`s are delegated to `ctx`.
///
/// Ops not otherwise handled cannot be interpreted. They get `Top` outputs (a sound
/// over-approximation) instead of aborting the whole analysis.
///
/// # Panics
/// If a `LoadConstant`/`LoadFunction` has value inputs or lacks its static source, or if
/// an `UnpackTuple` does not have exactly one input.
fn propagate_leaf_op<V: AbstractValue, H: HugrView>(
    ctx: &mut impl DFContext<V, Node = H::Node>,
    hugr: &H,
    n: H::Node,
    ins: &[PV<V, H::Node>],
    num_outs: usize,
) -> Option<ValueRow<V, H::Node>> {
    match hugr.get_optype(n) {
        // Tuple construction/destruction are handled directly, as variant 0 of a sum.
        op if op.cast::<MakeTuple>().is_some() => Some(ValueRow::from_iter([PV::new_variant(
            0,
            ins.iter().cloned(),
        )])),
        op if op.cast::<UnpackTuple>().is_some() => {
            let elem_tys = op.cast::<UnpackTuple>().unwrap().0;
            let tup = ins.iter().exactly_one().unwrap();
            tup.variant_values(0, elem_tys.len())
                .map(ValueRow::from_iter)
        }
        OpType::Tag(t) => Some(ValueRow::from_iter([PV::new_variant(
            t.tag,
            ins.iter().cloned(),
        )])),
        // Outputs of these are produced by the container/call rules instead.
        OpType::Input(_) | OpType::Output(_) | OpType::ExitBlock(_) => None,
        OpType::Call(_) | OpType::CallIndirect(_) => None,
        OpType::LoadConstant(load_op) => {
            // The constant arrives on a static edge, not a value input.
            assert!(ins.is_empty());
            let const_node = hugr
                .single_linked_output(n, load_op.constant_port())
                .unwrap()
                .0;
            let const_val = hugr.get_optype(const_node).as_const().unwrap().value();
            Some(ValueRow::singleton(partial_from_const(ctx, n, const_val)))
        }
        OpType::LoadFunction(load_op) => {
            // The function arrives on a static edge, not a value input.
            assert!(ins.is_empty());
            let func_node = hugr
                .single_linked_output(n, load_op.function_port())
                .unwrap()
                .0;
            Some(ValueRow::singleton(PartialValue::new_load(
                func_node,
                load_op.type_args.clone(),
            )))
        }
        OpType::ExtensionOp(e) => {
            Some(ValueRow::from_iter(if row_contains_bottom(ins) {
                // Some input cannot (yet) happen: keep outputs at Bottom rather than
                // polluting them with Top before the inputs are better known.
                vec![PartialValue::Bottom; num_outs]
            } else {
                // Default to Top (nothing known); the context may refine any output.
                let mut outs = vec![PartialValue::Top; num_outs];
                ctx.interpret_leaf_op(n, e, ins, &mut outs[..]);
                outs
            }))
        }
        // Any other leaf op (e.g. an unresolved opaque op) cannot be interpreted. Follow
        // the ExtensionOp policy without interpretation: Bottom while any input is
        // Bottom, otherwise Top, which is always a sound approximation.
        _ => Some(ValueRow::from_iter(if row_contains_bottom(ins) {
            vec![PartialValue::Bottom; num_outs]
        } else {
            vec![PartialValue::Top; num_outs]
        })),
    }
}
#[cfg(test)]
mod test {
    use ascent::Lattice;

    use super::LatticeWrapper;

    /// Joining climbs Bottom -> value -> Top, and is idempotent.
    #[test]
    fn latwrap_join() {
        let targets = [
            LatticeWrapper::Value(3),
            LatticeWrapper::Value(5),
            LatticeWrapper::Top,
        ];
        for target in targets {
            let mut acc = LatticeWrapper::Bottom;
            // Bottom joined with anything is that thing...
            assert!(acc.join_mut(target.clone()));
            assert_eq!(acc, target);
            // ...and joining the same thing again is a no-op.
            assert!(!acc.join_mut(target.clone()));
            assert_eq!(acc, target);
            // A different value pushes us to Top (no change if already there).
            let changed = acc.join_mut(LatticeWrapper::Value(11));
            assert_eq!(changed, target != LatticeWrapper::Top);
            assert_eq!(acc, LatticeWrapper::Top);
        }
    }

    /// Meeting descends Top -> value -> Bottom, and is idempotent.
    #[test]
    fn latwrap_meet() {
        let targets = [
            LatticeWrapper::Bottom,
            LatticeWrapper::Value(3),
            LatticeWrapper::Value(5),
        ];
        for target in targets {
            let mut acc = LatticeWrapper::Top;
            // Top met with anything is that thing...
            assert!(acc.meet_mut(target.clone()));
            assert_eq!(acc, target);
            // ...and meeting the same thing again is a no-op.
            assert!(!acc.meet_mut(target.clone()));
            assert_eq!(acc, target);
            // A different value pushes us to Bottom (no change if already there).
            let changed = acc.meet_mut(LatticeWrapper::Value(11));
            assert_eq!(changed, target != LatticeWrapper::Bottom);
            assert_eq!(acc, LatticeWrapper::Bottom);
        }
    }
}