use std::{
collections::{HashMap, HashSet, hash_map::RandomState},
sync::LazyLock,
};
use hugr_core::builder::{
Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder,
HugrBuilder, ModuleBuilder, SubContainer, endo_sig, inout_sig,
};
use hugr_core::extension::prelude::{
ConstError, ConstString, MakeTuple, UnpackTuple, bool_t, const_ok, error_type, string_type,
sum_with_error,
};
use hugr_core::hugr::hugrmut::HugrMut;
use hugr_core::ops::constant::{CustomConst, CustomSerialized};
use hugr_core::ops::{Const, OpTag, OpTrait, OpType, Value, handle::NodeHandle};
use hugr_core::std_extensions::arithmetic::{
conversions::ConvertOpDef,
float_ops::FloatOps,
float_types::{ConstF64, float64_type},
int_ops::IntOpDef,
int_types::{ConstInt, INT_TYPES},
};
use hugr_core::std_extensions::collections::list::ListOp;
use hugr_core::std_extensions::logic::LogicOp;
use hugr_core::types::{Signature, SumType, Type, TypeBound, TypeRow, TypeRowRV};
use hugr_core::{Hugr, HugrView, IncomingPort, Node, Visibility, type_row};
use itertools::Itertools;
use rstest::rstest;
use crate::passes::ComposablePass as _;
use crate::passes::composable::{PassScope, Preserve, ValidatingPass, WithScope};
use crate::passes::dataflow::{DFContext, PartialValue, partial_from_const};
use super::{ConstFoldContext, ConstantFoldPass, ValueHandle};
/// Runs a default-configured [`ConstantFoldPass`], wrapped in a [`ValidatingPass`],
/// over `h`, panicking if the wrapped pass returns an error.
fn constant_fold_pass(h: &mut (impl HugrMut<Node = Node> + 'static)) {
    ValidatingPass::new(ConstantFoldPass::default())
        .run(h)
        .unwrap();
}
/// Compares the partial value obtained for a custom constant `k` when it is the
/// single field of a sum value against the partial value obtained for `k` on its
/// own: they must compare equal exactly when `eq` is set.
#[rstest]
#[case(ConstInt::new_u(4, 2).unwrap(), true)]
#[case(ConstF64::new(std::f64::consts::PI), false)]
fn value_handling(#[case] k: impl CustomConst + Clone, #[case] eq: bool) {
    let node = Node::from(portgraph::NodeIndex::new(7));
    let ctx = ConstFoldContext;
    // `k` becomes the only field of variant 0 of a two-variant sum.
    let sum_ty = SumType::new([vec![k.get_type()], vec![]]);
    let wrapped = Value::sum(0, [k.clone().into()], sum_ty).unwrap();
    let PartialValue::PartialSum(partial_sum) = partial_from_const(&ctx, node, &wrapped) else {
        panic!()
    };
    let only_variant_fields = partial_sum.0.into_iter().exactly_one().unwrap().1;
    let nested = only_variant_fields.into_iter().exactly_one().unwrap();
    let direct = partial_from_const(&ctx, node, &k.into());
    if eq {
        assert_eq!(nested, direct);
    } else {
        assert_ne!(nested, direct);
    }
}
/// Asserts that the entrypoint of `h` contains only an `Input`, an `Output`, a
/// `LoadConstant` and a `Const` whose value equals `expected_value`.
pub fn assert_fully_folded(h: &impl HugrView, expected_value: &Value) {
    let matches_expected = |v: &Value| v == expected_value;
    assert_fully_folded_with(h, matches_expected)
}
/// Asserts that the children of `h`'s entrypoint are what a complete constant fold
/// leaves behind. That means `Input`, `Output` and `LoadConstant` nodes, plus `Const`
/// nodes whose value satisfies `check_value`, totalling exactly 4 nodes.
///
/// # Panics
/// Panics if any other child is found (including a `Const` rejected by
/// `check_value`), or if the number of accepted children is not 4. Both failure
/// messages include the Mermaid rendering of `h` to aid debugging.
fn assert_fully_folded_with(h: &impl HugrView, check_value: impl Fn(&Value) -> bool) {
    let mut node_count = 0;
    for node in h.children(h.entrypoint()) {
        let op = h.get_optype(node);
        match op {
            OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => node_count += 1,
            OpType::Const(c) if check_value(c.value()) => node_count += 1,
            _ => panic!("unexpected op: {}\n{}", op, h.mermaid_string()),
        }
    }
    // The message arguments are only evaluated if the assertion fails.
    assert_eq!(
        node_count,
        4,
        "expected exactly Input, Output, LoadConstant and Const under the entrypoint\n{}",
        h.mermaid_string()
    );
}
/// Builds an unsigned integer constant of log-width 5 holding `b`.
///
/// # Panics
/// Panics if `ConstInt::new_u` rejects `b` (presumably when it does not fit the width).
fn i2c(b: u64) -> Value {
    let int_const = ConstInt::new_u(5, b).unwrap();
    Value::extension(int_const)
}
fn f2c(f: f64) -> Value {
ConstF64::new(f).into()
}
/// Interprets a single `fadd` leaf op directly through [`ConstFoldContext`] and
/// checks that the partial-value inputs `a` and `b` produce `c`.
#[rstest]
#[case(0.0, 0.0, 0.0)]
#[case(0.0, 1.0, 1.0)]
#[case(23.5, 435.5, 459.0)]
fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) {
    // Extracts the `f64` from a partial value that must be a concrete float constant.
    fn unwrap_float(pv: PartialValue<ValueHandle>) -> f64 {
        let concrete: Value = pv.try_into_concrete(&float64_type()).unwrap();
        let float_const = concrete.get_custom_value::<ConstF64>().unwrap();
        float_const.value()
    }
    let [add_node, lhs_node, rhs_node] =
        [0, 1, 2].map(|i| Node::from(portgraph::NodeIndex::new(i)));
    let mut ctx = ConstFoldContext;
    let lhs = partial_from_const(&ctx, lhs_node, &f2c(a));
    let rhs = partial_from_const(&ctx, rhs_node, &f2c(b));
    // Sanity check: each operand survives the round trip through `PartialValue`.
    assert_eq!(unwrap_float(lhs.clone()), a);
    assert_eq!(unwrap_float(rhs.clone()), b);
    let mut outs = [PartialValue::Bottom];
    let OpType::ExtensionOp(add_op) = OpType::from(FloatOps::fadd) else {
        panic!()
    };
    ctx.interpret_leaf_op(add_node, &add_op, &[lhs, rhs], &mut outs);
    assert_eq!(unwrap_float(outs[0].clone()), c);
}
/// Builds a signature that takes no inputs and produces `outputs`.
fn noargfn(outputs: impl Into<TypeRow>) -> Signature {
    let no_inputs: TypeRow = type_row![];
    inout_sig(no_inputs, outputs)
}
/// Folds a chain that unpacks a constant float tuple, subtracts, then truncates with
/// `trunc_u`. The whole chain should reduce to a single `Ok` constant holding 2.
#[test]
fn test_big() {
    let sum_type = sum_with_error([INT_TYPES[5].clone()]);
    let mut build = DFGBuilder::new(noargfn(vec![sum_type.into()])).unwrap();
    let tup = build.add_load_const(Value::tuple([f2c(5.6), f2c(3.2)]));
    let unpack = build
        .add_dataflow_op(
            UnpackTuple::new(vec![float64_type(), float64_type()].into()),
            [tup],
        )
        .unwrap();
    let sub = build
        .add_dataflow_op(FloatOps::fsub, unpack.outputs())
        .unwrap();
    let to_int = build
        .add_dataflow_op(ConvertOpDef::trunc_u.with_log_width(5), sub.outputs())
        .unwrap();
    let mut h = build.finish_hugr_with_outputs(to_int.outputs()).unwrap();
    assert_eq!(h.entry_descendants().count(), 8);
    constant_fold_pass(&mut h);
    // 5.6 - 3.2 = 2.4, which truncates to 2.
    let expected = const_ok(i2c(2), [error_type()]);
    assert_fully_folded(&h, &expected);
}
/// Pops from, then pushes back onto, a constant one-element list. The round trip
/// is expected to fold back to the original list.
#[test]
#[ignore = "Waiting for `unwrap` operation"]
fn test_list_ops() -> Result<(), Box<dyn std::error::Error>> {
    // `ListOp` is already imported at file scope; only `ListValue` is needed here.
    use hugr_core::std_extensions::collections::list::ListValue;
    let base_list: Value = ListValue::new(bool_t(), [Value::false_val()]).into();
    let mut build = DFGBuilder::new(noargfn(vec![base_list.get_type()]))?;
    let list = build.add_load_const(base_list.clone());
    let [list, maybe_elem] = build
        .add_dataflow_op(
            ListOp::pop.with_type(bool_t()).to_extension_op().unwrap(),
            [list],
        )?
        .outputs_arr();
    // TODO: `maybe_elem` presumably needs unwrapping before it can be pushed; until an
    // `unwrap` operation exists it is passed through unchanged (hence the `#[ignore]`).
    let elem = maybe_elem;
    let [list] = build
        .add_dataflow_op(
            ListOp::push.with_type(bool_t()).to_extension_op().unwrap(),
            [list, elem],
        )?
        .outputs_arr();
    let mut h = build.finish_hugr_with_outputs([list])?;
    constant_fold_pass(&mut h);
    assert_fully_folded(&h, &base_list);
    Ok(())
}
/// `true AND true` folds to `true`.
#[test]
fn test_fold_and() {
    let mut builder = DFGBuilder::new(noargfn([bool_t()])).unwrap();
    let [lhs, rhs] = [Value::true_val(), Value::true_val()].map(|v| builder.add_load_const(v));
    let conjunction = builder.add_dataflow_op(LogicOp::And, [lhs, rhs]).unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(conjunction.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::true_val());
}
/// `true OR false` folds to `true`.
#[test]
fn test_fold_or() {
    let mut builder = DFGBuilder::new(noargfn([bool_t()])).unwrap();
    let [lhs, rhs] = [Value::true_val(), Value::false_val()].map(|v| builder.add_load_const(v));
    let disjunction = builder.add_dataflow_op(LogicOp::Or, [lhs, rhs]).unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(disjunction.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::true_val());
}
/// `NOT true` folds to `false`.
#[test]
fn test_fold_not() {
    let mut builder = DFGBuilder::new(noargfn([bool_t()])).unwrap();
    let input = builder.add_load_const(Value::true_val());
    let negation = builder.add_dataflow_op(LogicOp::Not, [input]).unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(negation.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::false_val());
}
/// Builds `Or(true, Not(true))`, then edits the finished HUGR. The `Not` feeding
/// the `Or` is replaced by a freshly added `Not` node, and the original `Not` is
/// removed. Folding must still reduce everything to `true`.
///
/// NOTE(review): the new node is added after the builder has finished, so it
/// presumably sits after the `Output` node in child order; that appears to be the
/// case being exercised. Confirm.
#[test]
fn orphan_output() {
    // `NodeHandle` (for `.node()` on build handles) is imported at file scope.
    let mut build = DFGBuilder::new(noargfn(vec![bool_t()])).unwrap();
    let true_wire = build.add_load_value(Value::true_val());
    let orig_not = build.add_dataflow_op(LogicOp::Not, [true_wire]).unwrap();
    let r = build
        .add_dataflow_op(LogicOp::Or, [true_wire, orig_not.out_wire(0)])
        .unwrap();
    let or_node = r.node();
    let parent = build.container_node();
    let mut h = build.finish_hugr_with_outputs(r.outputs()).unwrap();
    let new_not = h.add_node_with_parent(parent, LogicOp::Not);
    h.connect(true_wire.node(), true_wire.source(), new_not, 0);
    h.disconnect(or_node, IncomingPort::from(1));
    h.connect(new_not, 0, or_node, 1);
    h.remove_node(orig_not.node());
    constant_fold_pass(&mut h);
    assert_fully_folded(&h, &Value::true_val());
}
/// Regression test for issue 996: `(3.0 != 4.0 && 3.0 < 4.0) || 3.0 < -10.0` folds to `true`.
#[test]
fn test_folding_pass_issue_996() {
    let mut builder = DFGBuilder::new(noargfn(vec![bool_t()])).unwrap();
    let [three, four] =
        [3.0, 4.0].map(|f| builder.add_load_const(Value::extension(ConstF64::new(f))));
    let not_equal = builder.add_dataflow_op(FloatOps::fne, [three, four]).unwrap();
    let less = builder.add_dataflow_op(FloatOps::flt, [three, four]).unwrap();
    let both = builder
        .add_dataflow_op(LogicOp::And, not_equal.outputs().chain(less.outputs()))
        .unwrap();
    let minus_ten = builder.add_load_const(Value::extension(ConstF64::new(-10.0)));
    let less_than_minus_ten = builder
        .add_dataflow_op(FloatOps::flt, [three, minus_ten])
        .unwrap();
    let either = builder
        .add_dataflow_op(
            LogicOp::Or,
            both.outputs().chain(less_than_minus_ten.outputs()),
        )
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(either.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::true_val());
}
/// `1.0 / 1.0` folds to the constant `1.0`. `1.0 / 0.0` has a non-finite quotient
/// and is left unfolded.
#[test]
fn test_const_fold_to_nonfinite() {
    // Builds a DFG computing `numerator / denominator` from two loaded float
    // constants, runs constant folding over it and returns the result.
    fn folded_division(numerator: f64, denominator: f64) -> Hugr {
        let mut builder = DFGBuilder::new(noargfn(vec![float64_type()])).unwrap();
        let num = builder.add_load_const(Value::extension(ConstF64::new(numerator)));
        let den = builder.add_load_const(Value::extension(ConstF64::new(denominator)));
        let quotient = builder.add_dataflow_op(FloatOps::fdiv, [num, den]).unwrap();
        let mut hugr = builder.finish_hugr_with_outputs(quotient.outputs()).unwrap();
        constant_fold_pass(&mut hugr);
        hugr
    }

    let finite = folded_division(1.0, 1.0);
    assert_fully_folded_with(&finite, |v| {
        v.get_custom_value::<ConstF64>().unwrap().value() == 1.0
    });
    assert_eq!(finite.entry_descendants().count(), 5);

    // All 8 nodes (entrypoint, Input, Output, two Consts, two LoadConstants, fdiv) survive.
    let non_finite = folded_division(1.0, 0.0);
    assert_eq!(non_finite.entry_descendants().count(), 8);
}
/// `iwiden_u` from log-width 4 to 5 on the unsigned constant 13 folds to 13 at log-width 5.
#[test]
fn test_fold_iwiden_u() {
    let narrow = ConstInt::new_u(4, 13).unwrap();
    let mut builder = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap();
    let input = builder.add_load_const(Value::extension(narrow));
    let widened = builder
        .add_dataflow_op(IntOpDef::iwiden_u.with_two_log_widths(4, 5), [input])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(widened.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::extension(ConstInt::new_u(5, 13).unwrap()));
}
/// `iwiden_s` from log-width 4 to 5 on the signed constant -3 folds to -3 at log-width 5.
#[test]
fn test_fold_iwiden_s() {
    let narrow = ConstInt::new_s(4, -3).unwrap();
    let mut builder = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap();
    let input = builder.add_load_const(Value::extension(narrow));
    let widened = builder
        .add_dataflow_op(IntOpDef::iwiden_s.with_two_log_widths(4, 5), [input])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(widened.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::extension(ConstInt::new_s(5, -3).unwrap()));
}
#[rstest]
#[case(ConstInt::new_s, IntOpDef::inarrow_s, 5, 4, -3, true)]
#[case(ConstInt::new_s, IntOpDef::inarrow_s, 5, 5, -3, true)]
#[case(ConstInt::new_s, IntOpDef::inarrow_s, 5, 1, -3, false)]
#[case(ConstInt::new_u, IntOpDef::inarrow_u, 5, 4, 13, true)]
#[case(ConstInt::new_u, IntOpDef::inarrow_u, 5, 5, 13, true)]
#[case(ConstInt::new_u, IntOpDef::inarrow_u, 5, 0, 3, false)]
fn test_fold_inarrow<I: Copy, C: Into<Value>, E: std::fmt::Debug>(
#[case] mk_const: impl Fn(u8, I) -> Result<C, E>,
#[case] op_def: IntOpDef,
#[case] from_log_width: u8,
#[case] to_log_width: u8,
#[case] val: I,
#[case] succeeds: bool,
) {
use hugr_core::extension::prelude::const_ok;
let elem_type = INT_TYPES[to_log_width as usize].clone();
let sum_type = sum_with_error([elem_type.clone()]);
let mut build = DFGBuilder::new(noargfn(vec![sum_type.clone().into()])).unwrap();
let x0 = build.add_load_const(mk_const(from_log_width, val).unwrap().into());
let x1 = build
.add_dataflow_op(
op_def.with_two_log_widths(from_log_width, to_log_width),
[x0],
)
.unwrap();
let mut h = build.finish_hugr_with_outputs(x1.outputs()).unwrap();
constant_fold_pass(&mut h);
static INARROW_ERROR_VALUE: LazyLock<ConstError> = LazyLock::new(|| ConstError {
signal: 0,
message: "Integer too large to narrow".to_string(),
});
let expected = if succeeds {
const_ok(mk_const(to_log_width, val).unwrap().into(), [error_type()])
} else {
INARROW_ERROR_VALUE.clone().as_either([elem_type])
};
assert_fully_folded(&h, &expected);
}
/// `itobool` on the log-width-0 constant 1 folds to `true`.
#[test]
fn test_fold_itobool() {
    let one_bit = ConstInt::new_u(0, 1).unwrap();
    let mut builder = DFGBuilder::new(noargfn(vec![bool_t()])).unwrap();
    let input = builder.add_load_const(Value::extension(one_bit));
    let converted = builder
        .add_dataflow_op(ConvertOpDef::itobool.without_log_width(), [input])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(converted.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::true_val());
}
/// `ifrombool` on `false` folds to the log-width-0 integer 0.
#[test]
fn test_fold_ifrombool() {
    let mut builder = DFGBuilder::new(noargfn(vec![INT_TYPES[0].clone()])).unwrap();
    let input = builder.add_load_const(Value::false_val());
    let converted = builder
        .add_dataflow_op(ConvertOpDef::ifrombool.without_log_width(), [input])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(converted.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::extension(ConstInt::new_u(0, 0).unwrap()));
}
/// At log-width 3, signed -1 and unsigned 255 compare equal under `ieq`.
#[test]
fn test_fold_ieq() {
    let operands = [
        ConstInt::new_s(3, -1).unwrap(),
        ConstInt::new_u(3, 255).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![bool_t()])).unwrap();
    let [lhs, rhs] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let cmp = builder
        .add_dataflow_op(IntOpDef::ieq.with_log_width(3), [lhs, rhs])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(cmp.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::true_val());
}
/// `ine(3, 4)` folds to `true`.
#[test]
fn test_fold_ine() {
    let operands = [
        ConstInt::new_u(5, 3).unwrap(),
        ConstInt::new_u(5, 4).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![bool_t()])).unwrap();
    let [lhs, rhs] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let cmp = builder
        .add_dataflow_op(IntOpDef::ine.with_log_width(5), [lhs, rhs])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(cmp.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::true_val());
}
/// Unsigned `ilt_u(3, 4)` folds to `true`.
#[test]
fn test_fold_ilt_u() {
    let operands = [
        ConstInt::new_u(5, 3).unwrap(),
        ConstInt::new_u(5, 4).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![bool_t()])).unwrap();
    let [lhs, rhs] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let cmp = builder
        .add_dataflow_op(IntOpDef::ilt_u.with_log_width(5), [lhs, rhs])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(cmp.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::true_val());
}
/// Signed `ilt_s(3, -4)` folds to `false`.
#[test]
fn test_fold_ilt_s() {
    let operands = [
        ConstInt::new_s(5, 3).unwrap(),
        ConstInt::new_s(5, -4).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![bool_t()])).unwrap();
    let [lhs, rhs] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let cmp = builder
        .add_dataflow_op(IntOpDef::ilt_s.with_log_width(5), [lhs, rhs])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(cmp.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::false_val());
}
/// Unsigned `igt_u(3, 4)` folds to `false`.
#[test]
fn test_fold_igt_u() {
    let operands = [
        ConstInt::new_u(5, 3).unwrap(),
        ConstInt::new_u(5, 4).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![bool_t()])).unwrap();
    let [lhs, rhs] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let cmp = builder
        .add_dataflow_op(IntOpDef::igt_u.with_log_width(5), [lhs, rhs])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(cmp.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::false_val());
}
/// Signed `igt_s(3, -4)` folds to `true`.
#[test]
fn test_fold_igt_s() {
    let operands = [
        ConstInt::new_s(5, 3).unwrap(),
        ConstInt::new_s(5, -4).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![bool_t()])).unwrap();
    let [lhs, rhs] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let cmp = builder
        .add_dataflow_op(IntOpDef::igt_s.with_log_width(5), [lhs, rhs])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(cmp.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::true_val());
}
/// Unsigned `ile_u(3, 3)` folds to `true`.
#[test]
fn test_fold_ile_u() {
    let operands = [
        ConstInt::new_u(5, 3).unwrap(),
        ConstInt::new_u(5, 3).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![bool_t()])).unwrap();
    let [lhs, rhs] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let cmp = builder
        .add_dataflow_op(IntOpDef::ile_u.with_log_width(5), [lhs, rhs])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(cmp.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::true_val());
}
/// Signed `ile_s(-4, -4)` folds to `true`.
#[test]
fn test_fold_ile_s() {
    let operands = [
        ConstInt::new_s(5, -4).unwrap(),
        ConstInt::new_s(5, -4).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![bool_t()])).unwrap();
    let [lhs, rhs] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let cmp = builder
        .add_dataflow_op(IntOpDef::ile_s.with_log_width(5), [lhs, rhs])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(cmp.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::true_val());
}
/// Unsigned `ige_u(3, 4)` folds to `false`.
#[test]
fn test_fold_ige_u() {
    let operands = [
        ConstInt::new_u(5, 3).unwrap(),
        ConstInt::new_u(5, 4).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![bool_t()])).unwrap();
    let [lhs, rhs] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let cmp = builder
        .add_dataflow_op(IntOpDef::ige_u.with_log_width(5), [lhs, rhs])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(cmp.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::false_val());
}
/// Signed `ige_s(3, -4)` folds to `true`.
#[test]
fn test_fold_ige_s() {
    let operands = [
        ConstInt::new_s(5, 3).unwrap(),
        ConstInt::new_s(5, -4).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![bool_t()])).unwrap();
    let [lhs, rhs] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let cmp = builder
        .add_dataflow_op(IntOpDef::ige_s.with_log_width(5), [lhs, rhs])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(cmp.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::true_val());
}
/// Unsigned `imax_u(7, 11)` folds to 11.
#[test]
fn test_fold_imax_u() {
    let operands = [
        ConstInt::new_u(5, 7).unwrap(),
        ConstInt::new_u(5, 11).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap();
    let [lhs, rhs] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let result = builder
        .add_dataflow_op(IntOpDef::imax_u.with_log_width(5), [lhs, rhs])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(result.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::extension(ConstInt::new_u(5, 11).unwrap()));
}
/// Signed `imax_s(-2, 1)` folds to 1.
#[test]
fn test_fold_imax_s() {
    let operands = [
        ConstInt::new_s(5, -2).unwrap(),
        ConstInt::new_s(5, 1).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap();
    let [lhs, rhs] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let result = builder
        .add_dataflow_op(IntOpDef::imax_s.with_log_width(5), [lhs, rhs])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(result.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::extension(ConstInt::new_s(5, 1).unwrap()));
}
/// Unsigned `imin_u(7, 11)` folds to 7.
#[test]
fn test_fold_imin_u() {
    let operands = [
        ConstInt::new_u(5, 7).unwrap(),
        ConstInt::new_u(5, 11).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap();
    let [lhs, rhs] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let result = builder
        .add_dataflow_op(IntOpDef::imin_u.with_log_width(5), [lhs, rhs])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(result.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::extension(ConstInt::new_u(5, 7).unwrap()));
}
/// Signed `imin_s(-2, 1)` folds to -2.
#[test]
fn test_fold_imin_s() {
    let operands = [
        ConstInt::new_s(5, -2).unwrap(),
        ConstInt::new_s(5, 1).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap();
    let [lhs, rhs] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let result = builder
        .add_dataflow_op(IntOpDef::imin_s.with_log_width(5), [lhs, rhs])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(result.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::extension(ConstInt::new_s(5, -2).unwrap()));
}
/// `iadd(-2, 1)` folds to -1.
#[test]
fn test_fold_iadd() {
    let operands = [
        ConstInt::new_s(5, -2).unwrap(),
        ConstInt::new_s(5, 1).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap();
    let [lhs, rhs] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let sum = builder
        .add_dataflow_op(IntOpDef::iadd.with_log_width(5), [lhs, rhs])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(sum.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::extension(ConstInt::new_s(5, -1).unwrap()));
}
/// `isub(-2, 1)` folds to -3.
#[test]
fn test_fold_isub() {
    let operands = [
        ConstInt::new_s(5, -2).unwrap(),
        ConstInt::new_s(5, 1).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap();
    let [lhs, rhs] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let difference = builder
        .add_dataflow_op(IntOpDef::isub.with_log_width(5), [lhs, rhs])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(difference.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::extension(ConstInt::new_s(5, -3).unwrap()));
}
/// `ineg(-2)` folds to 2.
#[test]
fn test_fold_ineg() {
    let operand = ConstInt::new_s(5, -2).unwrap();
    let mut builder = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap();
    let input = builder.add_load_const(Value::extension(operand));
    let negated = builder
        .add_dataflow_op(IntOpDef::ineg.with_log_width(5), [input])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(negated.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::extension(ConstInt::new_s(5, 2).unwrap()));
}
/// `imul(-2, 7)` folds to -14.
#[test]
fn test_fold_imul() {
    let operands = [
        ConstInt::new_s(5, -2).unwrap(),
        ConstInt::new_s(5, 7).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap();
    let [lhs, rhs] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let product = builder
        .add_dataflow_op(IntOpDef::imul.with_log_width(5), [lhs, rhs])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(product.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::extension(ConstInt::new_s(5, -14).unwrap()));
}
/// `idivmod_checked_u(20, 0)` folds to the "Division by zero" error variant of the
/// `(int, int)` result sum.
#[test]
fn test_fold_idivmod_checked_u() {
    let int32 = INT_TYPES[5].clone();
    let quot_rem: TypeRowRV = vec![int32.clone(), int32].into();
    let elem_type = Type::new_tuple(quot_rem);
    let result_type = sum_with_error([elem_type.clone()]);
    let operands = [
        ConstInt::new_u(5, 20).unwrap(),
        ConstInt::new_u(5, 0).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![result_type.into()])).unwrap();
    let [dividend, divisor] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let divmod = builder
        .add_dataflow_op(
            IntOpDef::idivmod_checked_u.with_log_width(5),
            [dividend, divisor],
        )
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(divmod.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    let division_by_zero = ConstError {
        signal: 0,
        message: "Division by zero".to_string(),
    };
    assert_fully_folded(&hugr, &division_by_zero.as_either([elem_type]));
}
/// `idivmod_u(20, 3)` at log-width 3 gives 6 and 2; adding the two outputs folds to 8.
#[test]
fn test_fold_idivmod_u() {
    let operands = [
        ConstInt::new_u(3, 20).unwrap(),
        ConstInt::new_u(3, 3).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![INT_TYPES[3].clone()])).unwrap();
    let [dividend, divisor] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let [quotient, remainder] = builder
        .add_dataflow_op(IntOpDef::idivmod_u.with_log_width(3), [dividend, divisor])
        .unwrap()
        .outputs_arr();
    let total = builder
        .add_dataflow_op(IntOpDef::iadd.with_log_width(3), [quotient, remainder])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(total.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::extension(ConstInt::new_u(3, 8).unwrap()));
}
/// `idivmod_checked_s(-20, 0)` folds to the "Division by zero" error variant of the
/// `(int, int)` result sum.
#[test]
fn test_fold_idivmod_checked_s() {
    let int32 = INT_TYPES[5].clone();
    let quot_rem: TypeRowRV = vec![int32.clone(), int32].into();
    let elem_type = Type::new_tuple(quot_rem);
    let result_type = sum_with_error([elem_type.clone()]);
    let operands = [
        ConstInt::new_s(5, -20).unwrap(),
        ConstInt::new_u(5, 0).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![result_type.into()])).unwrap();
    let [dividend, divisor] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let divmod = builder
        .add_dataflow_op(
            IntOpDef::idivmod_checked_s.with_log_width(5),
            [dividend, divisor],
        )
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(divmod.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    let division_by_zero = ConstError {
        signal: 0,
        message: "Division by zero".to_string(),
    };
    assert_fully_folded(&hugr, &division_by_zero.as_either([elem_type]));
}
/// `idivmod_s(a, b)` on log-width-6 operands, followed by an `iadd` of its two
/// outputs, folds to `c`. The cases include extreme values around `i64::MIN`.
#[rstest]
#[case(20, 3, 8)]
#[case(-20, 3, -6)]
#[case(-20, 4, -5)]
#[case(i64::MIN, 1, i64::MIN)]
#[case(i64::MIN, 2, -(1i64 << 62))]
#[case(i64::MIN, 1u64 << 63, -1)]
fn test_fold_idivmod_s(#[case] a: i64, #[case] b: u64, #[case] c: i64) {
    let operands = [
        ConstInt::new_s(6, a).unwrap(),
        ConstInt::new_u(6, b).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![INT_TYPES[6].clone()])).unwrap();
    let [dividend, divisor] = operands.map(|k| builder.add_load_const(Value::extension(k)));
    let [quotient, remainder] = builder
        .add_dataflow_op(IntOpDef::idivmod_s.with_log_width(6), [dividend, divisor])
        .unwrap()
        .outputs_arr();
    let total = builder
        .add_dataflow_op(IntOpDef::iadd.with_log_width(6), [quotient, remainder])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(total.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::extension(ConstInt::new_s(6, c).unwrap()));
}
/// `idiv_checked_u(20, 0)` folds to the "Division by zero" error variant.
#[test]
fn test_fold_idiv_checked_u() {
    let int32 = INT_TYPES[5].clone();
    let result_type = sum_with_error([int32.clone()]);
    let operands = [
        ConstInt::new_u(5, 20).unwrap(),
        ConstInt::new_u(5, 0).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![result_type.into()])).unwrap();
    let [dividend, divisor] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let quotient = builder
        .add_dataflow_op(IntOpDef::idiv_checked_u.with_log_width(5), [dividend, divisor])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(quotient.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    let division_by_zero = ConstError {
        signal: 0,
        message: "Division by zero".to_string(),
    };
    assert_fully_folded(&hugr, &division_by_zero.as_either([int32]));
}
/// Unsigned `idiv_u(20, 3)` folds to 6.
#[test]
fn test_fold_idiv_u() {
    let operands = [
        ConstInt::new_u(5, 20).unwrap(),
        ConstInt::new_u(5, 3).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap();
    let [dividend, divisor] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let quotient = builder
        .add_dataflow_op(IntOpDef::idiv_u.with_log_width(5), [dividend, divisor])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(quotient.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::extension(ConstInt::new_u(5, 6).unwrap()));
}
/// `imod_checked_u(20, 0)` folds to the "Division by zero" error variant.
#[test]
fn test_fold_imod_checked_u() {
    let int32 = INT_TYPES[5].clone();
    let result_type = sum_with_error([int32.clone()]);
    let operands = [
        ConstInt::new_u(5, 20).unwrap(),
        ConstInt::new_u(5, 0).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![result_type.into()])).unwrap();
    let [dividend, divisor] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let remainder = builder
        .add_dataflow_op(IntOpDef::imod_checked_u.with_log_width(5), [dividend, divisor])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(remainder.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    let division_by_zero = ConstError {
        signal: 0,
        message: "Division by zero".to_string(),
    };
    assert_fully_folded(&hugr, &division_by_zero.as_either([int32]));
}
/// Unsigned `imod_u(20, 3)` folds to 2.
#[test]
fn test_fold_imod_u() {
    let operands = [
        ConstInt::new_u(5, 20).unwrap(),
        ConstInt::new_u(5, 3).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap();
    let [dividend, divisor] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let remainder = builder
        .add_dataflow_op(IntOpDef::imod_u.with_log_width(5), [dividend, divisor])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(remainder.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::extension(ConstInt::new_u(5, 2).unwrap()));
}
/// `idiv_checked_s(-20, 0)` folds to the "Division by zero" error variant.
#[test]
fn test_fold_idiv_checked_s() {
    let int32 = INT_TYPES[5].clone();
    let result_type = sum_with_error([int32.clone()]);
    let operands = [
        ConstInt::new_s(5, -20).unwrap(),
        ConstInt::new_u(5, 0).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![result_type.into()])).unwrap();
    let [dividend, divisor] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let quotient = builder
        .add_dataflow_op(IntOpDef::idiv_checked_s.with_log_width(5), [dividend, divisor])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(quotient.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    let division_by_zero = ConstError {
        signal: 0,
        message: "Division by zero".to_string(),
    };
    assert_fully_folded(&hugr, &division_by_zero.as_either([int32]));
}
/// Signed `idiv_s(-20, 3)` folds to -7.
#[test]
fn test_fold_idiv_s() {
    let operands = [
        ConstInt::new_s(5, -20).unwrap(),
        ConstInt::new_u(5, 3).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap();
    let [dividend, divisor] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let quotient = builder
        .add_dataflow_op(IntOpDef::idiv_s.with_log_width(5), [dividend, divisor])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(quotient.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::extension(ConstInt::new_s(5, -7).unwrap()));
}
/// `imod_checked_s(-20, 0)` folds to the "Division by zero" error variant.
#[test]
fn test_fold_imod_checked_s() {
    let int32 = INT_TYPES[5].clone();
    let result_type = sum_with_error([int32.clone()]);
    let operands = [
        ConstInt::new_s(5, -20).unwrap(),
        ConstInt::new_u(5, 0).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![result_type.into()])).unwrap();
    let [dividend, divisor] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let remainder = builder
        .add_dataflow_op(IntOpDef::imod_checked_s.with_log_width(5), [dividend, divisor])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(remainder.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    let division_by_zero = ConstError {
        signal: 0,
        message: "Division by zero".to_string(),
    };
    assert_fully_folded(&hugr, &division_by_zero.as_either([int32]));
}
/// Signed `imod_s(-20, 3)` folds to the non-negative remainder 1.
#[test]
fn test_fold_imod_s() {
    let operands = [
        ConstInt::new_s(5, -20).unwrap(),
        ConstInt::new_u(5, 3).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap();
    let [dividend, divisor] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let remainder = builder
        .add_dataflow_op(IntOpDef::imod_s.with_log_width(5), [dividend, divisor])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(remainder.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::extension(ConstInt::new_u(5, 1).unwrap()));
}
/// `iabs(-2)` folds to 2.
#[test]
fn test_fold_iabs() {
    let operand = ConstInt::new_s(5, -2).unwrap();
    let mut builder = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap();
    let input = builder.add_load_const(Value::extension(operand));
    let magnitude = builder
        .add_dataflow_op(IntOpDef::iabs.with_log_width(5), [input])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(magnitude.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::extension(ConstInt::new_u(5, 2).unwrap()));
}
/// Bitwise `iand(14, 20)` (0b01110 & 0b10100) folds to 4.
#[test]
fn test_fold_iand() {
    let operands = [
        ConstInt::new_u(5, 14).unwrap(),
        ConstInt::new_u(5, 20).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap();
    let [lhs, rhs] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let masked = builder
        .add_dataflow_op(IntOpDef::iand.with_log_width(5), [lhs, rhs])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(masked.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::extension(ConstInt::new_u(5, 4).unwrap()));
}
/// Bitwise `ior(14, 20)` (0b01110 | 0b10100) folds to 30.
#[test]
fn test_fold_ior() {
    let operands = [
        ConstInt::new_u(5, 14).unwrap(),
        ConstInt::new_u(5, 20).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap();
    let [lhs, rhs] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let combined = builder
        .add_dataflow_op(IntOpDef::ior.with_log_width(5), [lhs, rhs])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(combined.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::extension(ConstInt::new_u(5, 30).unwrap()));
}
/// Bitwise `ixor(14, 20)` (0b01110 ^ 0b10100) folds to 26.
#[test]
fn test_fold_ixor() {
    let operands = [
        ConstInt::new_u(5, 14).unwrap(),
        ConstInt::new_u(5, 20).unwrap(),
    ];
    let mut builder = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap();
    let [lhs, rhs] = operands.map(|c| builder.add_load_const(Value::extension(c)));
    let toggled = builder
        .add_dataflow_op(IntOpDef::ixor.with_log_width(5), [lhs, rhs])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(toggled.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::extension(ConstInt::new_u(5, 26).unwrap()));
}
/// Bitwise NOT of 14 at 32 bits folds to `(2^32 - 1) - 14`.
#[test]
fn test_fold_inot() {
    let mut builder = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap();
    let fourteen = builder.add_load_const(Value::extension(ConstInt::new_u(5, 14).unwrap()));
    let not = builder
        .add_dataflow_op(IntOpDef::inot.with_log_width(5), [fourteen])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(not.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    // All 32 bits set, with the bits of 14 cleared.
    let all_ones = (1u64 << 32) - 1;
    assert_fully_folded(
        &hugr,
        &Value::extension(ConstInt::new_u(5, all_ones - 14).unwrap()),
    );
}
/// Left shift: 14 << 3 folds to 112.
#[test]
fn test_fold_ishl() {
    let mut builder = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap();
    let value = builder.add_load_const(Value::extension(ConstInt::new_s(5, 14).unwrap()));
    let amount = builder.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap()));
    let shifted = builder
        .add_dataflow_op(IntOpDef::ishl.with_log_width(5), [value, amount])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(shifted.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::extension(ConstInt::new_u(5, 14 << 3).unwrap()));
}
/// Right shift: 14 >> 3 folds to 1.
#[test]
fn test_fold_ishr() {
    let mut builder = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap();
    let value = builder.add_load_const(Value::extension(ConstInt::new_s(5, 14).unwrap()));
    let amount = builder.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap()));
    let shifted = builder
        .add_dataflow_op(IntOpDef::ishr.with_log_width(5), [value, amount])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(shifted.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::extension(ConstInt::new_u(5, 14 >> 3).unwrap()));
}
/// Left rotation of 14 by 61 at 32 bits.
///
/// 61 is 29 mod 32, and rotating left by 29 equals rotating right by 3. So the
/// low bits `0b110` wrap to the top, giving `0b11 << 30 | 1`.
#[test]
fn test_fold_irotl() {
    let mut builder = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap();
    let value = builder.add_load_const(Value::extension(ConstInt::new_s(5, 14).unwrap()));
    let amount = builder.add_load_const(Value::extension(ConstInt::new_u(5, 61).unwrap()));
    let rotated = builder
        .add_dataflow_op(IntOpDef::irotl.with_log_width(5), [value, amount])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(rotated.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    let expected_bits = 3 * (1u64 << 30) + 1;
    assert_fully_folded(
        &hugr,
        &Value::extension(ConstInt::new_u(5, expected_bits).unwrap()),
    );
}
/// Right rotation of 14 by 3 at 32 bits.
///
/// The low bits `0b110` wrap to the top, giving `0b11 << 30 | 1`.
#[test]
fn test_fold_irotr() {
    let mut builder = DFGBuilder::new(noargfn(vec![INT_TYPES[5].clone()])).unwrap();
    let value = builder.add_load_const(Value::extension(ConstInt::new_s(5, 14).unwrap()));
    let amount = builder.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap()));
    let rotated = builder
        .add_dataflow_op(IntOpDef::irotr.with_log_width(5), [value, amount])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(rotated.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    let expected_bits = 3 * (1u64 << 30) + 1;
    assert_fully_folded(
        &hugr,
        &Value::extension(ConstInt::new_u(5, expected_bits).unwrap()),
    );
}
/// `itostring_u` of the unsigned constant 17 folds to the string `"17"`.
#[test]
fn test_fold_itostring_u() {
    let mut builder = DFGBuilder::new(noargfn(vec![string_type()])).unwrap();
    let seventeen = builder.add_load_const(Value::extension(ConstInt::new_u(5, 17).unwrap()));
    let converted = builder
        .add_dataflow_op(ConvertOpDef::itostring_u.with_log_width(5), [seventeen])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(converted.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::extension(ConstString::new("17".into())));
}
/// `itostring_s` of the signed constant -17 folds to the string `"-17"`.
#[test]
fn test_fold_itostring_s() {
    let mut builder = DFGBuilder::new(noargfn(vec![string_type()])).unwrap();
    let minus_seventeen =
        builder.add_load_const(Value::extension(ConstInt::new_s(5, -17).unwrap()));
    let converted = builder
        .add_dataflow_op(ConvertOpDef::itostring_s.with_log_width(5), [minus_seventeen])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(converted.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::extension(ConstString::new("-17".into())));
}
/// Folds a small boolean circuit over integer comparisons:
/// `(3 != 4 && 3 <u 4) || (3 <s -10)`, which evaluates to `true`.
#[test]
fn test_fold_int_ops() {
    let mut builder = DFGBuilder::new(noargfn(vec![bool_t()])).unwrap();
    let three = builder.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap()));
    let four = builder.add_load_const(Value::extension(ConstInt::new_u(5, 4).unwrap()));
    let not_equal = builder
        .add_dataflow_op(IntOpDef::ine.with_log_width(5), [three, four])
        .unwrap();
    let less_unsigned = builder
        .add_dataflow_op(IntOpDef::ilt_u.with_log_width(5), [three, four])
        .unwrap();
    let both = builder
        .add_dataflow_op(
            LogicOp::And,
            not_equal.outputs().chain(less_unsigned.outputs()),
        )
        .unwrap();
    let minus_ten = builder.add_load_const(Value::extension(ConstInt::new_s(5, -10).unwrap()));
    let less_signed = builder
        .add_dataflow_op(IntOpDef::ilt_s.with_log_width(5), [three, minus_ten])
        .unwrap();
    let either = builder
        .add_dataflow_op(LogicOp::Or, both.outputs().chain(less_signed.outputs()))
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(either.outputs()).unwrap();
    constant_fold_pass(&mut hugr);
    assert_fully_folded(&hugr, &Value::true_val());
}
/// A tuple whose middle element is unknown still lets its known outer
/// elements (4 and 5) be unpacked and added.
///
/// Afterwards only the root DFG, its I/O nodes, and a single `Const` (9) with
/// its load remain.
#[test]
fn test_via_part_unknown_tuple() {
    let int_ty = INT_TYPES[3].clone();
    let mut builder = DFGBuilder::new(endo_sig([int_ty.clone()])).unwrap();
    let [unknown] = builder.input_wires_arr();
    let four = builder.add_load_value(ConstInt::new_u(3, 4).unwrap());
    let five = builder.add_load_value(ConstInt::new_u(3, 5).unwrap());
    let tuple_row = TypeRow::from(vec![int_ty; 3]);
    let packed = builder
        .add_dataflow_op(MakeTuple::new(tuple_row.clone()), [four, unknown, five])
        .unwrap();
    let [first, _middle, last] = builder
        .add_dataflow_op(UnpackTuple::new(tuple_row), packed.outputs())
        .unwrap()
        .outputs_arr();
    let sum = builder
        .add_dataflow_op(IntOpDef::iadd.with_log_width(3), [first, last])
        .unwrap();
    let mut hugr = builder.finish_hugr_with_outputs(sum.outputs()).unwrap();
    constant_fold_pass(&mut hugr);

    // Each remaining node must match a distinct tag from this set, and every tag must be used.
    let mut remaining: HashSet<String, RandomState> = [
        OpTag::Dfg,
        OpTag::Input,
        OpTag::Output,
        OpTag::Const,
        OpTag::LoadConst,
    ]
    .iter()
    .map(ToString::to_string)
    .collect();
    for node in hugr.entry_descendants() {
        let optype = hugr.get_optype(node);
        assert!(remaining.remove(&optype.tag().to_string()));
        if let Some(konst) = optype.as_const() {
            assert_eq!(konst.value, ConstInt::new_u(3, 9).unwrap().into());
        }
    }
    assert!(remaining.is_empty());
}
/// Builds a DFG `bool -> int` around a tail loop.
///
/// `int_cst` is loaded once and passed into the loop, which hands it straight
/// back. The loop's control predicate is the DFG's boolean input (an edge from
/// outside the loop). The loop's result is then added to `int_cst` again to
/// form the DFG's output.
fn tail_loop_hugr(int_cst: ConstInt) -> Hugr {
    let int_ty = int_cst.get_type();
    let log_width = int_cst.log_width();
    let mut dfg = DFGBuilder::new(inout_sig([bool_t()], [int_ty.clone()])).unwrap();
    let [control] = dfg.input_wires_arr();
    let loaded = dfg.add_load_value(int_cst);
    let looped = {
        let body = dfg
            .tail_loop_builder([], [(int_ty, loaded)], type_row![])
            .unwrap();
        let [carried] = body.input_wires_arr();
        let [out] = body
            .finish_with_outputs(control, [carried])
            .unwrap()
            .outputs_arr();
        out
    };
    let sum = dfg
        .add_dataflow_op(IntOpDef::iadd.with_log_width(log_width), [loaded, looped])
        .unwrap();
    dfg.finish_hugr_with_outputs(sum.outputs()).unwrap()
}
/// Folds [`tail_loop_hugr`] of the constant 5 with the loop's control input
/// unknown.
///
/// The tail loop must be kept. It is fed the constant 5 from outside. Inside,
/// it returns a freshly loaded constant 5 rather than its own input. The root
/// output is folded to the constant 10 (5 + 5).
#[test]
fn test_tail_loop_unknown() {
    let cst5 = ConstInt::new_u(3, 5).unwrap();
    let mut h = tail_loop_hugr(cst5.clone());
    constant_fold_pass(&mut h);
    assert_eq!(h.entry_descendants().count(), 12);
    let tl = h
        .entry_descendants()
        .filter(|n| h.get_optype(*n).is_tail_loop())
        .exactly_one()
        .ok()
        .unwrap();

    // Split every non-root descendant by parent: the root DFG, or else the loop.
    let root = h.entrypoint();
    let mut dfg_nodes = Vec::new();
    let mut loop_nodes = Vec::new();
    for n in h.entry_descendants().filter(|&n| n != root) {
        let parent = h.get_parent(n).unwrap();
        if parent == root {
            dfg_nodes.push(n);
        } else {
            assert_eq!(parent, tl);
            loop_nodes.push(n);
        }
    }
    let tag_string = |n: &Node| format!("{:?}", h.get_optype(*n).tag());
    assert_eq!(
        dfg_nodes
            .iter()
            .map(tag_string)
            .sorted()
            .collect::<Vec<_>>(),
        vec![
            "Const",
            "Const",
            "Input",
            "LoadConst",
            "LoadConst",
            "Output",
            "TailLoop"
        ]
    );
    assert_eq!(
        loop_nodes.iter().map(tag_string).collect::<Vec<_>>(),
        Vec::from(["Input", "Output", "Const", "LoadConst"])
    );

    let [loop_in, loop_out] = h.get_io(tl).unwrap();
    // An Input node never has incoming edges, so checking its input neighbours
    // proves nothing. What folding should achieve is that the loop's input is
    // no longer consumed, because the carried value is now a constant.
    assert!(h.output_neighbours(loop_in).next().is_none());
    let (loop_cst, v) = loop_nodes
        .into_iter()
        .filter_map(|n| h.get_optype(n).as_const().map(|c| (n, c.value())))
        .exactly_one()
        .unwrap();
    assert_eq!(v, &cst5.clone().into());
    // The in-loop constant is loaded straight into the loop's carried output (port 1).
    let loop_lcst = h.output_neighbours(loop_cst).exactly_one().ok().unwrap();
    assert_eq!(h.get_parent(loop_lcst), Some(tl));
    assert_eq!(
        h.all_linked_inputs(loop_lcst).collect::<Vec<_>>(),
        vec![(loop_out, IncomingPort::from(1))]
    );

    // Of the two root-level constants, 5 feeds the loop and 10 feeds the output.
    let [_, root_out] = h.get_io(root).unwrap();
    let mut cst5 = Some(cst5.into());
    for n in dfg_nodes {
        let Some(cst) = h.get_optype(n).as_const() else {
            continue;
        };
        let lcst = h.output_neighbours(n).exactly_one().ok().unwrap();
        let target = h.output_neighbours(lcst).exactly_one().ok().unwrap();
        if Some(cst.value()) == cst5.as_ref() {
            cst5 = None;
            assert_eq!(target, tl);
        } else {
            assert_eq!(cst.value(), &ConstInt::new_u(3, 10).unwrap().into());
            assert_eq!(target, root_out);
        }
    }
    assert!(cst5.is_none());
}
/// With the loop's control input fixed to `true` (so the loop never iterates),
/// the whole DFG folds to 6 + 6 = 12.
#[test]
fn test_tail_loop_never_iterates() {
    let mut h = tail_loop_hugr(ConstInt::new_u(4, 6).unwrap());
    let entry = h.entrypoint();
    let pass = ConstantFoldPass::default().with_inputs(entry, [(0, Value::true_val())]);
    pass.run(&mut h).unwrap();
    assert_fully_folded(&h, &ConstInt::new_u(4, 12).unwrap().into());
}
/// Nothing is known about the loop's control input. Once the pass may increase
/// termination, the loop can still be folded away and the DFG becomes 12.
#[test]
fn test_tail_loop_increase_termination() {
    let mut h = tail_loop_hugr(ConstInt::new_u(4, 6).unwrap());
    let pass = ConstantFoldPass::default().allow_increase_termination();
    pass.run(&mut h).unwrap();
    assert_fully_folded(&h, &ConstInt::new_u(4, 12).unwrap().into());
}
/// Builds a DFG `(bool p, bool q) -> int`.
///
/// The DFG holds a nested DFG seeded with the constant 1. The nested DFG wraps
/// a CFG with two blocks:
/// * the entry block computes `7 + i` and branches on `p`: true goes to block
///   `a`, false goes to the exit;
/// * block `a` computes `3 + i` and branches on `q`: true goes back to the
///   entry, false goes to the exit.
///
/// Statement order here fixes the CFG's child order (entry, exit, `a`), which
/// callers rely on.
fn cfg_hugr() -> Hugr {
    let int_ty = INT_TYPES[4].clone();
    let mut outer = DFGBuilder::new(inout_sig(vec![bool_t(); 2], [int_ty.clone()])).unwrap();
    let [p, q] = outer.input_wires_arr();
    let one = outer.add_load_value(ConstInt::new_u(4, 1).unwrap());
    let mut inner = outer
        .dfg_builder_endo([(int_ty.clone(), one)])
        .unwrap();
    let [inner_in] = inner.input_wires_arr();
    let mut cfg = inner
        .cfg_builder([(int_ty.clone(), inner_in)], [int_ty.clone()].into())
        .unwrap();

    // Entry block: 7 + i, predicate p.
    let mut entry_b = cfg
        .simple_entry_builder([int_ty.clone()].into(), 2)
        .unwrap();
    let [entry_in] = entry_b.input_wires_arr();
    let seven = entry_b.add_load_value(ConstInt::new_u(4, 7).unwrap());
    let entry_sum = entry_b
        .add_dataflow_op(IntOpDef::iadd.with_log_width(4), [seven, entry_in])
        .unwrap();
    let entry = entry_b.finish_with_outputs(p, entry_sum.outputs()).unwrap();

    // Block a: 3 + i, predicate q.
    let mut a_b = cfg
        .simple_block_builder(endo_sig([int_ty.clone()]), 2)
        .unwrap();
    let [a_in] = a_b.input_wires_arr();
    let three = a_b.add_load_value(ConstInt::new_u(4, 3).unwrap());
    let a_sum = a_b
        .add_dataflow_op(IntOpDef::iadd.with_log_width(4), [three, a_in])
        .unwrap();
    let blk_a = a_b.finish_with_outputs(q, a_sum.outputs()).unwrap();

    let exit = cfg.exit_block();
    // Successor index 1 is the "true" branch, 0 the "false" branch.
    let [on_true, on_false] = [1, 0];
    cfg.branch(&entry, on_true, &blk_a).unwrap();
    cfg.branch(&entry, on_false, &exit).unwrap();
    cfg.branch(&blk_a, on_true, &entry).unwrap();
    cfg.branch(&blk_a, on_false, &exit).unwrap();
    let cfg = cfg.finish_sub_container().unwrap();
    let inner = inner.finish_with_outputs(cfg.outputs()).unwrap();
    outer.finish_hugr_with_outputs(inner.outputs()).unwrap()
}
/// Runs constant folding over [`cfg_hugr`] with some of its two boolean inputs
/// fixed, then checks what was folded.
///
/// * `inputs`: `(port, value)` pairs fixed on the entrypoint's inputs.
/// * `fold_entry` / `fold_blk`: whether the entry block / block `a` should be
///   fully folded (to 8 / 11). Otherwise the block must keep its original
///   constant (7 / 3) feeding an `iadd`.
/// * `fold_res`: the constant expected on the root output, if the CFG's result
///   becomes known. In that case the whole Hugr must also fold to that value
///   once the pass may increase termination.
#[rstest]
#[case(&[(0,false)], true, false, Some(8))]
#[case(&[(0,true), (1,false)], true, true, Some(11))]
#[case(&[(1,false)], true, true, None)]
#[case(&[], false, false, None)]
fn test_cfg(
    #[case] inputs: &[(usize, bool)],
    #[case] fold_entry: bool,
    #[case] fold_blk: bool,
    #[case] fold_res: Option<u16>,
) {
    let backup = cfg_hugr();
    let mut hugr = backup.clone();
    let pass = ConstantFoldPass::default().with_inputs(
        hugr.entrypoint(),
        inputs.iter().map(|(p, b)| (*p, Value::from_bool(*b))),
    );
    pass.run(&mut hugr).unwrap();
    let nested = hugr
        .children(hugr.entrypoint())
        .filter(|n| hugr.get_optype(*n).is_dfg())
        .exactly_one()
        .ok()
        .unwrap();
    let cfg = hugr
        .entry_descendants()
        .filter(|n| hugr.get_optype(*n).is_cfg())
        .exactly_one()
        .ok()
        .unwrap();
    assert_eq!(hugr.get_parent(cfg), Some(nested));
    let [entry, exit, a] = hugr.children(cfg).collect::<Vec<_>>().try_into().unwrap();
    assert!(hugr.get_optype(exit).is_exit_block());
    for (blk, is_folded, folded_cst, unfolded_cst) in
        [(entry, fold_entry, 8, 7), (a, fold_blk, 11, 3)]
    {
        if is_folded {
            assert_fully_folded(
                &hugr.with_entrypoint(blk),
                &ConstInt::new_u(4, folded_cst).unwrap().into(),
            );
        } else {
            // An unfolded block keeps exactly its original nodes: I/O, the
            // constant and its load, and the `iadd` (tagged `Leaf`).
            let mut expected_tags =
                HashSet::from(["Input", "Output", "Leaf", "Const", "LoadConst"]);
            for ch in hugr.children(blk) {
                let tag = format!("{:?}", hugr.get_optype(ch).tag());
                assert!(expected_tags.remove(tag.as_str()), "Not found: {tag}");
                if let Some(cst) = hugr.get_optype(ch).as_const() {
                    assert_eq!(
                        cst.value(),
                        &ConstInt::new_u(4, unfolded_cst).unwrap().into()
                    );
                } else if let Some(op) = hugr.get_optype(ch).as_extension_op() {
                    assert_eq!(op.unqualified_id(), "iadd");
                }
            }
            // The removals above only reject unexpected or duplicate nodes.
            // Also require that none of the expected nodes went missing.
            assert!(
                expected_tags.is_empty(),
                "Missing from unfolded block: {expected_tags:?}"
            );
        }
    }
    let output_src = hugr
        .input_neighbours(hugr.get_io(hugr.entrypoint()).unwrap()[1])
        .exactly_one()
        .ok()
        .unwrap();
    if let Some(res_int) = fold_res {
        let res_v = ConstInt::new_u(4, res_int.into()).unwrap().into();
        assert!(hugr.get_optype(output_src).is_load_constant());
        let output_cst = hugr
            .input_neighbours(output_src)
            .exactly_one()
            .ok()
            .unwrap();
        let cst = hugr.get_optype(output_cst).as_const().unwrap();
        assert_eq!(cst.value(), &res_v);
        let mut hugr2 = backup;
        pass.allow_increase_termination().run(&mut hugr2).unwrap();
        assert_fully_folded(&hugr2, &res_v);
    } else {
        // Result unknown: the root output still comes from the nested DFG.
        assert_eq!(output_src, nested);
    }
}
/// Folds a module-rooted Hugr.
///
/// `main`'s `7 + 17` becomes a local constant 24. The module-level constant 17
/// is no longer needed and disappears. All other module children are kept in
/// their original order.
#[test]
fn test_module() -> Result<(), Box<dyn std::error::Error>> {
    let mut module = ModuleBuilder::new();
    let c7 = module.add_constant(Value::from(ConstInt::new_u(5, 7)?));
    let c17 = module.add_constant(Value::from(ConstInt::new_u(5, 17)?));
    let ad1 = module.add_alias_declare("unused", TypeBound::Linear)?;
    let ad2 = module.add_alias_def("unused2", INT_TYPES[3].clone())?;
    let sig = Signature::new(type_row![], vec![INT_TYPES[5].clone(); 2]);
    let func_decl = module.declare_vis("unused3", sig.clone().into(), Visibility::Public)?;
    let mut main_fb = module.define_function_vis("main", sig, Visibility::Public)?;
    let lc7 = main_fb.load_const(&c7);
    let lc17 = main_fb.load_const(&c17);
    let [sum] = main_fb
        .add_dataflow_op(IntOpDef::iadd.with_log_width(5), [lc7, lc17])?
        .outputs_arr();
    let main = main_fb.finish_with_outputs([lc7, sum])?;
    let mut hugr = module.finish_hugr()?;
    constant_fold_pass(&mut hugr);

    let root = hugr.entrypoint();
    assert!(hugr.get_optype(root).is_module());
    let expected_children = [
        c7.node(),
        ad1.node(),
        ad2.node(),
        func_decl.node(),
        main.node(),
    ];
    assert_eq!(hugr.children(root).collect_vec(), expected_children);

    let main_node = main.node();
    assert_eq!(
        get_child_tags(&hugr, main_node),
        HashMap::from([
            (OpTag::Input, 1),
            (OpTag::Output, 1),
            (OpTag::Const, 1),
            (OpTag::LoadConst, 2),
        ])
    );
    let local_const = hugr
        .children(main_node)
        .find_map(|n| hugr.get_optype(n).as_const());
    assert_eq!(
        local_const,
        Some(&Const::new(ConstInt::new_u(5, 24).unwrap().into()))
    );
    Ok(())
}
#[rstest]
#[case::float_fsub(FloatOps::fsub)]
#[case::float_fadd(FloatOps::fadd)]
#[case::float_fmul(FloatOps::fmul)]
#[case::float_fdiv(FloatOps::fdiv)]
#[case::float_fneg(FloatOps::fneg)]
#[case::float_fabs(FloatOps::fabs)]
#[case::logic_and(LogicOp::And)]
#[case::logic_or(LogicOp::Or)]
#[case::logic_not(LogicOp::Not)]
#[case::int_iadd(IntOpDef::iadd.with_log_width(5))]
#[case::int_isub(IntOpDef::isub.with_log_width(5))]
#[case::int_ineg(IntOpDef::ineg.with_log_width(5))]
#[case::convert_trunc_u(ConvertOpDef::trunc_u.with_log_width(5))]
#[case::convert_trunc_s(ConvertOpDef::trunc_s.with_log_width(5))]
#[case::convert_convert_u(ConvertOpDef::convert_u.with_log_width(5))]
#[case::convert_convert_s(ConvertOpDef::convert_s.with_log_width(5))]
#[case::convert_itobool(ConvertOpDef::itobool.without_log_width())]
#[case::convert_ifrombool(ConvertOpDef::ifrombool.without_log_width())]
#[case::convert_itostring_u(ConvertOpDef::itostring_u.with_log_width(5))]
#[case::convert_itostring_s(ConvertOpDef::itostring_s.with_log_width(5))]
#[case::list_pop(ListOp::pop.with_type(bool_t()).to_extension_op().unwrap())]
#[case::list_push(ListOp::push.with_type(bool_t()).to_extension_op().unwrap())]
#[case::list_get(ListOp::get.with_type(bool_t()).to_extension_op().unwrap())]
#[case::list_set(ListOp::set.with_type(bool_t()).to_extension_op().unwrap())]
#[case::list_insert(ListOp::insert.with_type(bool_t()).to_extension_op().unwrap())]
#[case::list_length(ListOp::length.with_type(bool_t()).to_extension_op().unwrap())]
fn test_opaque_consts(#[case] op: impl Into<OpType>) {
let op = op.into();
let sig = op.dataflow_signature().unwrap();
let mut build = FunctionBuilder::new("fn", noargfn(sig.output.clone())).unwrap();
let inputs = sig
.input
.iter()
.enumerate()
.map(|(i, typ)| {
let opaque = CustomSerialized::new(typ.clone(), format!("opaque{i}").into());
build.add_load_const(Value::extension(opaque))
})
.collect_vec();
let x2 = build.add_dataflow_op(op.clone(), inputs).unwrap();
let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap();
constant_fold_pass(&mut h);
assert!(h.entry_descendants().any(|n| h.get_optype(n) == &op));
}
/// Counts the immediate children of `node` in `hugr`, keyed by their [`OpTag`].
fn get_child_tags(hugr: &Hugr, node: Node) -> HashMap<OpTag, usize> {
    hugr.children(node)
        .map(|child| hugr.get_optype(child).tag())
        .counts()
}
/// Shorthand for an unsigned integer constant of log width 5 (32 bits).
///
/// # Panics
/// Panics if `v` does not fit in 32 bits.
fn int_cst(v: u64) -> Value {
    ConstInt::new_u(5, v).unwrap().into()
}
/// Builds a module with two functions:
/// * `callee: int -> int`, returning `(7 + 11) + x`;
/// * public `main: () -> int`, returning `callee(3 + 5)`.
///
/// `entrypoint_is_main` selects the entrypoint: `Some(true)` for `main`,
/// `Some(false)` for `callee`. `None` leaves the builder's default (the module
/// root). Returns the Hugr together with the `main` and `callee` nodes.
fn two_funcs_hugr(entrypoint_is_main: Option<bool>) -> (Hugr, Node, Node) {
    let int_ty = INT_TYPES[5].clone();
    let iadd = || IntOpDef::iadd.with_log_width(5);
    let mut module = ModuleBuilder::new();

    let mut callee_fb = module
        .define_function("callee", Signature::new_endo([int_ty.clone()]))
        .unwrap();
    let [arg] = callee_fb.input_wires_arr();
    let [lc7, lc11] = [7, 11].map(|v| callee_fb.add_load_value(int_cst(v)));
    let [const_sum] = callee_fb
        .add_dataflow_op(iadd(), [lc7, lc11])
        .unwrap()
        .outputs_arr();
    let result = callee_fb.add_dataflow_op(iadd(), [const_sum, arg]).unwrap();
    let callee = callee_fb.finish_with_outputs(result.outputs()).unwrap();

    let mut main_fb = module
        .define_function_vis(
            "main",
            Signature::new(type_row![], vec![int_ty]),
            Visibility::Public,
        )
        .unwrap();
    let [lc3, lc5] = [3, 5].map(|v| main_fb.add_load_value(int_cst(v)));
    let arg_sum = main_fb.add_dataflow_op(iadd(), [lc3, lc5]).unwrap();
    let call = main_fb
        .call(callee.handle(), &[], arg_sum.outputs())
        .unwrap();
    let main = main_fb.finish_with_outputs(call.outputs()).unwrap();

    let mut hugr = module.finish_hugr().unwrap();
    match entrypoint_is_main {
        Some(true) => {
            hugr.set_entrypoint(main.node());
        }
        Some(false) => {
            hugr.set_entrypoint(callee.node());
        }
        None => {}
    }
    (hugr, main.node(), callee.node())
}
/// For these scope and entrypoint combinations, `main` is fully folded.
///
/// `callee` is only ever called with 8, and it collapses to the single
/// constant 26 (3 + 5 + 7 + 11). Presumably this is because these scopes do
/// not require `callee` to stay valid for other arguments.
#[rstest]
fn two_funcs_fully_folded(
    #[values(Preserve::Public, Preserve::Entrypoint)] scope: impl Into<PassScope>,
    #[values(Some(true), None)] entrypoint_is_main: Option<bool>,
) {
    let (mut hugr, main, callee) = two_funcs_hugr(entrypoint_is_main);
    let pass = ConstantFoldPass::default_with_scope(scope.into());
    pass.run(&mut hugr).unwrap();
    two_funcs_check_main_fully_folded(&hugr, main);

    assert_eq!(
        get_child_tags(&hugr, callee),
        HashMap::from([
            (OpTag::Input, 1),
            (OpTag::Output, 1),
            (OpTag::Const, 1),
            (OpTag::LoadConst, 1)
        ])
    );
    let only_const = hugr
        .children(callee)
        .filter_map(|n| hugr.get_optype(n).as_const())
        .exactly_one()
        .ok()
        .unwrap();
    assert_eq!(only_const.value(), &int_cst(3 + 5 + 7 + 11));
}
/// Asserts that `main` from [`two_funcs_hugr`] has been fully folded.
///
/// * Its output is a loaded constant 26 (3 + 5 + 7 + 11).
/// * The call to `callee` remains, with a loaded constant 8 (3 + 5) as its
///   argument.
/// * Nothing consumes the call's result any more.
fn two_funcs_check_main_fully_folded(hugr: &Hugr, main: Node) {
    let expected_tags = HashMap::from([
        (OpTag::Input, 1),
        (OpTag::Output, 1),
        (OpTag::FnCall, 1),
        (OpTag::Const, 2),
        (OpTag::LoadConst, 2),
    ]);
    assert_eq!(get_child_tags(hugr, main), expected_tags);

    // Output <- LoadConstant <- Const(26)
    let output = hugr
        .children(main)
        .filter(|n| hugr.get_optype(*n).is_output())
        .exactly_one()
        .ok()
        .unwrap();
    let result_load = hugr.input_neighbours(output).exactly_one().ok().unwrap();
    assert!(hugr.get_optype(result_load).is_load_constant());
    let result_const = hugr.input_neighbours(result_load).exactly_one().ok().unwrap();
    assert_eq!(
        hugr.get_optype(result_const).as_const().unwrap().value(),
        &int_cst(3 + 5 + 7 + 11)
    );

    // Call port 0 <- LoadConstant <- Const(8)
    let call = hugr
        .children(main)
        .filter(|n| hugr.get_optype(*n).is_call())
        .exactly_one()
        .ok()
        .unwrap();
    let (arg_load, _) = hugr.single_linked_output(call, 0).unwrap();
    assert!(hugr.get_optype(arg_load).is_load_constant());
    let arg_const = hugr.input_neighbours(arg_load).exactly_one().ok().unwrap();
    assert_eq!(
        hugr.get_optype(arg_const).as_const().unwrap().value(),
        &int_cst(3 + 5)
    );
    let call_result_unused = hugr.output_neighbours(call).next().is_none();
    assert!(call_result_unused);
}
/// In these cases `callee` is the entrypoint or everything is preserved.
///
/// `callee` must therefore keep honouring its argument. In `main`, the
/// argument 3 + 5 is still folded, but the call stays and its result still
/// feeds `main`'s Output.
#[rstest]
#[case(Preserve::Public, Some(false))]
#[case(Preserve::Entrypoint, Some(false))]
#[case(Preserve::All, Some(false))]
#[case(Preserve::All, Some(true))]
#[case(Preserve::All, None)]
fn two_funcs_preserve_f(
    #[case] scope: impl Into<PassScope>,
    #[case] entrypoint_is_main: Option<bool>,
) {
    let (mut hugr, main, callee) = two_funcs_hugr(entrypoint_is_main);
    let pass = ConstantFoldPass::default_with_scope(scope.into());
    pass.run(&mut hugr).unwrap();
    two_funcs_check_f_respects_argument(&hugr, callee);

    assert_eq!(
        get_child_tags(&hugr, main),
        HashMap::from([
            (OpTag::Input, 1),
            (OpTag::Output, 1),
            (OpTag::FnCall, 1),
            (OpTag::Const, 1),
            (OpTag::LoadConst, 1)
        ])
    );
    let call = hugr
        .children(main)
        .filter(|n| hugr.get_optype(*n).is_call())
        .exactly_one()
        .ok()
        .unwrap();
    let (arg_load, _) = hugr.single_linked_output(call, 0).unwrap();
    assert!(hugr.get_optype(arg_load).is_load_constant());
    let arg_const = hugr.input_neighbours(arg_load).exactly_one().ok().unwrap();
    assert_eq!(
        hugr.get_optype(arg_const).as_const().unwrap().value(),
        &int_cst(3 + 5)
    );
    let consumer = hugr.output_neighbours(call).exactly_one().ok().unwrap();
    assert!(hugr.get_optype(consumer).is_output());
}
/// Asserts that `callee` from [`two_funcs_hugr`] was folded only as far as its
/// unknown argument allows.
///
/// `7 + 11` became the single constant 18. One `Leaf` op (the `iadd` with the
/// input) remains alongside the I/O nodes and the constant's load.
fn two_funcs_check_f_respects_argument(hugr: &Hugr, callee: Node) {
    let tags = get_child_tags(hugr, callee);
    assert_eq!(
        tags,
        HashMap::from([
            (OpTag::Input, 1),
            (OpTag::Output, 1),
            (OpTag::Const, 1),
            (OpTag::LoadConst, 1),
            (OpTag::Leaf, 1)
        ])
    );
    let only_const = hugr
        .children(callee)
        .filter_map(|n| hugr.get_optype(n).as_const())
        .exactly_one()
        .ok()
        .unwrap();
    assert_eq!(only_const.value(), &int_cst(7 + 11));
}
/// With an entrypoint-only scope, the pass must leave everything outside the
/// entrypoint's subtree alone:
/// * module root as entrypoint: the Hugr is unchanged;
/// * `main` as entrypoint: `main` is fully folded and `callee` is untouched;
/// * `callee` as entrypoint: `main` is untouched and `callee` still respects
///   its argument.
#[rstest]
fn two_funcs_entrypoint(
    #[values(PassScope::EntrypointFlat, PassScope::EntrypointRecursive)] scope: PassScope,
) {
    /// Asserts that the subtree under `node` matches between `hugr` and
    /// `backup`: the same nodes, op types, incoming links and child order.
    fn check_identical(hugr: &Hugr, backup: &Hugr, node: Node) {
        assert_eq!(
            hugr.descendants(node).collect_vec(),
            backup.descendants(node).collect_vec()
        );
        for n in hugr.descendants(node) {
            assert_eq!(hugr.get_optype(n), backup.get_optype(n));
            let in_ports = hugr.node_inputs(n).collect_vec();
            assert_eq!(in_ports, backup.node_inputs(n).collect_vec());
            for port in in_ports {
                assert_eq!(
                    hugr.linked_outputs(n, port).collect_vec(),
                    backup.linked_outputs(n, port).collect_vec()
                );
            }
            assert_eq!(
                hugr.children(n).collect_vec(),
                backup.children(n).collect_vec()
            );
        }
    }

    // Module root as entrypoint: the pass changes nothing.
    let (backup, main, callee) = two_funcs_hugr(None);
    let mut untouched = backup.clone();
    ConstantFoldPass::default_with_scope(scope.clone())
        .run(&mut untouched)
        .unwrap();
    assert_eq!(backup, untouched);

    // `main` as entrypoint.
    let (mut main_ep, _, _) = two_funcs_hugr(Some(true));
    ConstantFoldPass::default_with_scope(scope.clone())
        .run(&mut main_ep)
        .unwrap();
    two_funcs_check_main_fully_folded(&main_ep, main);
    check_identical(&main_ep, &backup, callee);

    // `callee` as entrypoint.
    let (mut callee_ep, _, _) = two_funcs_hugr(Some(false));
    ConstantFoldPass::default_with_scope(scope)
        .run(&mut callee_ep)
        .unwrap();
    check_identical(&callee_ep, &backup, main);
    two_funcs_check_f_respects_argument(&callee_ep, callee);
}