use crate::passes::{ComposablePass, PassScope, WithScope};
use derive_more::{Display, Error};
use hugr::extension::prelude::ConstUsize;
use hugr::extension::simple_op::MakeExtensionOp;
use hugr::hugr::hugrmut::HugrMut;
use hugr::ops::{OpTag, OpTrait, Value};
use hugr::std_extensions::arithmetic::int_types::ConstInt;
use hugr::std_extensions::collections::borrow_array::{BArrayUnsafeOpDef, BORROW_ARRAY_TYPENAME};
use hugr::types::{EdgeKind, Type};
use hugr::{HugrView, IncomingPort, Node, OutgoingPort, Wire};
use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet, VecDeque};
/// A pass that removes redundant `return`/`borrow` pairs on borrow arrays.
///
/// When an element at a constant index is returned to a borrow array and then borrowed
/// again at the same index (with an earlier borrow of that index known on the same chain of
/// array operations), both operations are removed and the returned value is passed
/// directly to the users of the re-borrowed element.
#[derive(Debug, Default, Clone)]
pub struct BorrowSquashPass {
    /// Which part of the Hugr the pass is applied to.
    scope: PassScope,
}
impl WithScope for BorrowSquashPass {
fn with_scope(mut self, scope: impl Into<crate::passes::PassScope>) -> Self {
self.scope = scope.into();
self
}
}
impl<H: HugrMut<Node = Node>> ComposablePass<H> for BorrowSquashPass {
type Error = BorrowSquashError;
type Result = Vec<(Node, Node)>;
fn run(&self, hugr: &mut H) -> Result<Vec<(Node, Node)>, BorrowSquashError> {
let mut regions = VecDeque::from_iter(self.scope.root(hugr));
let mut results = Vec::new();
let mut seen = HashSet::new();
let mut op_queue = VecDeque::new();
while let Some(region) = regions.pop_front() {
let is_dataflow_region = OpTag::DataflowParent >= hugr.get_optype(region).tag();
seen.clear();
op_queue.clear();
for child in hugr.children(region) {
if is_dataflow_region && hugr.in_value_types(child).next().is_none() {
op_queue.extend(all_outs(hugr, child));
}
if self.scope.recursive() && hugr.children(child).next().is_some() {
regions.push_back(child);
}
}
while let Some(start) = op_queue.pop_front() {
if !seen.insert(start) {
continue;
}
let elided = borrow_squash_traversal(hugr, &mut op_queue, start, true);
results.extend(elided);
}
}
Ok(results)
}
}
#[derive(Clone, Debug, Display, Error, PartialEq)]
#[non_exhaustive]
pub enum BorrowSquashError {}
/// Ports through which the array itself passes a `borrow` or `return` node.
#[derive(Debug, Clone)]
struct BorrowFromPorts {
    /// Input port receiving the array.
    inc: IncomingPort,
    /// Output port producing the updated array.
    out: OutgoingPort,
}

/// Whether a node takes an element out of the array or puts one back, together with the
/// port that carries that element.
#[derive(Debug, PartialEq, Eq, Clone)]
enum BRAction {
    /// A `borrow`; the port outputs the borrowed element.
    Borrow(OutgoingPort),
    /// A `return`; the port receives the element being returned.
    Return(IncomingPort),
}

/// Port layout of a `borrow` or `return` node, as reported by `is_borrow_return`.
#[derive(Debug, Clone)]
struct BorrowReturnPorts {
    /// Borrow vs. return, and the element port.
    action: BRAction,
    /// Input port receiving the element index.
    elem_index: IncomingPort,
    /// Ports carrying the array in and out.
    borrow_from: BorrowFromPorts,
}
fn is_borrow_return<H: HugrView>(node: H::Node, hugr: &H) -> Option<BorrowReturnPorts> {
let op = hugr.get_optype(node);
let ext_op = op.as_extension_op()?;
match BArrayUnsafeOpDef::from_extension_op(ext_op) {
Ok(BArrayUnsafeOpDef::borrow) => {
let sig = op.dataflow_signature().unwrap();
let counts = (sig.input_count(), sig.output_count());
assert_eq!(counts, (2, 2), "Borrow node has incorrect signature");
Some(BorrowReturnPorts {
action: BRAction::Borrow(OutgoingPort::from(1)),
borrow_from: BorrowFromPorts {
inc: IncomingPort::from(0),
out: OutgoingPort::from(0),
},
elem_index: IncomingPort::from(1),
})
}
Ok(BArrayUnsafeOpDef::r#return) => {
let sig = op.dataflow_signature().unwrap();
let counts = (sig.input_count(), sig.output_count());
assert_eq!(counts, (3, 1), "Return node has incorrect signature");
Some(BorrowReturnPorts {
action: BRAction::Return(IncomingPort::from(2)),
borrow_from: BorrowFromPorts {
inc: IncomingPort::from(0),
out: OutgoingPort::from(0),
},
elem_index: IncomingPort::from(1),
})
}
_ => None,
}
}
/// Returns `true` if `ty` is an extension type whose name is the borrow-array type name.
///
/// Only the type name is compared; the defining extension is not checked.
fn is_borrow_array(ty: &Type) -> bool {
    let Some(custom) = ty.as_extension() else {
        return false;
    };
    custom.name() == &BORROW_ARRAY_TYPENAME
}
/// Squashes redundant return-then-borrow pairs along the borrow-array chain starting at
/// `start`.
///
/// Unlike [`BorrowSquashPass`], this only follows the array wire `start`. If `recurse` is
/// set, it also follows the wires of elements that remain borrowed at the end of that
/// chain. Wires at which the traversal stops are not revisited. If `start` does not carry a
/// borrow array, nothing is changed.
///
/// Returns the `(return, borrow)` pairs of nodes that were removed from `hugr`.
pub fn borrow_squash_array<H: HugrMut<Node = Node>>(
    hugr: &mut H,
    start: Wire,
    recurse: bool,
) -> Vec<(Node, Node)> {
    // Wires where a full traversal would resume are not needed here; collect and drop them.
    let mut ignored_candidates: Vec<Wire> = Vec::new();
    borrow_squash_traversal(hugr, &mut ignored_candidates, start, recurse)
}
/// Squashes return-then-borrow pairs along the borrow-array chain that starts at `start`.
///
/// If `start` does not carry a borrow array, nothing is squashed. Instead, the outputs of
/// every consumer of `start` that is not itself a borrow/return op are pushed onto
/// `candidates`, and an empty list is returned.
///
/// Otherwise the array is followed through consecutive borrow/return ops with constant
/// indices (see `next_array_op`). Suppose index `i` is borrowed on this chain, then
/// returned, then borrowed again with no other borrow/return of `i` in between. That
/// return and re-borrow are elided: the users of the re-borrowed element are wired
/// directly to the value that was returned, and both nodes are removed. Returns of an index
/// with no earlier borrow on this chain are left untouched.
///
/// Wires at which the traversal could not continue are pushed onto `candidates` so the
/// caller can resume from them. This covers consumers that are not borrow/return ops and
/// ops with non-constant indices.
///
/// If `recurse` is set, the elements still borrowed at the end of the chain are traversed
/// in turn, in ascending index order. This handles nested borrow arrays.
///
/// Returns the `(return, borrow)` node pairs that were removed.
fn borrow_squash_traversal<H: HugrMut<Node = Node>>(
    hugr: &mut H,
    candidates: &mut impl Extend<Wire>,
    start: Wire,
    recurse: bool,
) -> Vec<(Node, Node)> {
    let array_ty = wire_type(hugr, start);
    if !is_borrow_array(&array_ty) {
        for (n, _) in hugr.linked_inputs(start.node(), start.source()) {
            if is_borrow_return(n, hugr).is_none() {
                candidates.extend(all_outs(hugr, n));
            }
        }
        return vec![];
    }

    struct Borrow(Node, OutgoingPort, BorrowFromPorts);
    struct Return(Node, IncomingPort, BorrowFromPorts);

    // Per constant index: the borrow that first took the element out on this chain, and the
    // most recent return of that index (if the element has been put back since).
    let mut borrowed: HashMap<u64, (Borrow, Option<Return>)> = HashMap::new();
    // Return/re-borrow pairs to squash, in chain order.
    let mut rb_elisions: Vec<(Return, Borrow)> = Vec::new();

    let mut array = start;
    while let Some((node, index, action, borrow_from)) =
        next_array_op(hugr, candidates, array, &array_ty)
    {
        array = Wire::new(node, borrow_from.out);
        match (action, borrowed.entry(index)) {
            // First borrow of this index on the chain.
            (BRAction::Borrow(borrowed_out), Entry::Vacant(ve)) => {
                ve.insert((Borrow(node, borrowed_out, borrow_from), None));
            }
            // The borrowed element is put back: remember the return as a squash candidate.
            (BRAction::Return(inc), Entry::Occupied(mut oe)) => {
                let (_, ret) = oe.get_mut();
                if ret.replace(Return(node, inc, borrow_from)).is_some() {
                    // Returned twice without an intervening borrow: stop tracking this index.
                    oe.remove_entry();
                }
            }
            (BRAction::Borrow(borrowed_out), Entry::Occupied(mut oe)) => {
                let (_, prev_return) = oe.get_mut();
                match prev_return.take() {
                    // Re-borrow right after a return: squash the pair. The original borrow
                    // stays in the map because the element is borrowed again.
                    Some(prev_return) => {
                        rb_elisions.push((prev_return, Borrow(node, borrowed_out, borrow_from)));
                    }
                    // Borrowed twice without an intervening return: stop tracking this index.
                    None => {
                        oe.remove_entry();
                    }
                }
            }
            // Return of an index with no known borrow on this chain: nothing to track.
            (BRAction::Return(_), Entry::Vacant(_)) => (),
        }
    }

    /// Removes the borrow/return node `n`, connecting the producer of its array input
    /// directly to the consumer of its array output.
    fn elide_node<H: HugrMut<Node = Node>>(hugr: &mut H, n: Node, ports: &BorrowFromPorts) {
        let in_array = hugr
            .single_linked_output(n, ports.inc)
            .expect("array is linear");
        let out_array = hugr
            .single_linked_input(n, ports.out)
            .expect("array is linear");
        hugr.connect(in_array.0, in_array.1, out_array.0, out_array.1);
        hugr.remove_node(n);
    }

    let mut elided: Vec<_> = rb_elisions
        .into_iter()
        .map(|(ret, bor)| {
            // Feed the value being returned straight to the users of the re-borrowed element.
            let src = hugr.single_linked_output(ret.0, ret.1).expect("input");
            for tgt in hugr.linked_inputs(bor.0, bor.1).collect::<Vec<_>>() {
                hugr.connect(src.0, src.1, tgt.0, tgt.1);
            }
            elide_node(hugr, bor.0, &bor.2);
            elide_node(hugr, ret.0, &ret.2);
            (ret.0, bor.0)
        })
        .collect();

    if recurse {
        // Visit the still-borrowed elements in ascending index order. HashMap iteration
        // order is randomised, which would make the returned list (and the order of
        // candidates pushed) differ from run to run.
        let mut still_borrowed: Vec<_> = borrowed.into_iter().collect();
        still_borrowed.sort_unstable_by_key(|(index, _)| *index);
        for (_, (bor, _)) in still_borrowed {
            let start = Wire::new(bor.0, bor.1);
            elided.extend(borrow_squash_traversal(hugr, candidates, start, true));
        }
    }
    elided
}
/// Finds the next constant-index borrow/return op that the array on `array` flows through.
///
/// Returns the op's node, its constant index, its action and its array ports. Returns `None`
/// in the following cases:
/// - the consumer is not a borrow/return op: its outputs are pushed onto `candidates`;
/// - the array enters a `return` as the returned *value* (a nested array being put back):
///   nothing is pushed;
/// - the index is not a known constant: the op's outputs are pushed onto `candidates`.
///
/// # Panics
/// - If `array` does not have exactly one consumer.
/// - If the array feeds a non-array port of a `borrow`.
/// - If the op's array output type differs from `array_ty`.
fn next_array_op(
    hugr: &impl HugrView<Node = Node>,
    candidates: &mut impl Extend<Wire>,
    array: Wire,
    array_ty: &Type,
) -> Option<(Node, u64, BRAction, BorrowFromPorts)> {
    let (consumer, consumer_port) = hugr
        .single_linked_input(array.node(), array.source())
        .expect("array is linear");
    let Some(ports) = is_borrow_return(consumer, hugr) else {
        candidates.extend(all_outs(hugr, consumer));
        return None;
    };

    if consumer_port != ports.borrow_from.inc {
        match ports.action {
            BRAction::Borrow(_) => panic!("Array fed into unexpected port of borrow"),
            BRAction::Return(value_port) => assert_eq!(value_port, consumer_port),
        }
        return None;
    }

    let array_out = Wire::new(consumer, ports.borrow_from.out);
    assert_eq!(*array_ty, wire_type(hugr, array_out));

    match find_const(hugr, consumer, ports.elem_index) {
        Some(index) => Some((consumer, index, ports.action, ports.borrow_from)),
        None => {
            candidates.extend(all_outs(hugr, consumer));
            None
        }
    }
}
/// Returns the value type carried by wire `w`.
///
/// # Panics
/// If the source port of `w` is not a value port.
fn wire_type(h: &impl HugrView<Node = Node>, w: Wire) -> Type {
    match h.get_optype(w.node()).port_kind(w.source()) {
        Some(EdgeKind::Value(ty)) => ty,
        _ => panic!("Invalid wire {w}"),
    }
}
/// Iterates over one wire for each value output port of node `n`.
fn all_outs(h: &impl HugrView<Node = Node>, n: Node) -> impl Iterator<Item = Wire> + '_ {
    h.out_value_types(n)
        .map(move |(port, _ty)| Wire::new(n, port))
}
/// Returns the constant value feeding input `inp` of node `n`, if it is statically known.
///
/// The value is known only when the input is produced by a `LoadConstant` whose constant
/// is a [`ConstUsize`] or a [`ConstInt`] (read as unsigned). Any other producer, or a
/// constant of any other kind, yields `None`. Callers then treat the value as dynamic,
/// which is always a safe fallback, rather than aborting the pass.
///
/// # Panics
/// If `inp` is not connected to exactly one output, or a `LoadConstant` is not fed by a
/// `Const` node.
fn find_const<H: HugrView>(hugr: &H, n: H::Node, inp: IncomingPort) -> Option<u64> {
    let (load_const, _) = hugr.single_linked_output(n, inp).expect("dataflow input");
    if !hugr.get_optype(load_const).is_load_constant() {
        return None;
    }
    let const_op = hugr
        .single_linked_output(load_const, 0)
        .and_then(|(n, _)| hugr.get_optype(n).as_const())
        .expect("LoadConstant input is constant");
    let Value::Extension { e } = &const_op.value else {
        return None;
    };
    let value = e.value();
    value
        .downcast_ref::<ConstUsize>()
        .map(|c| c.value())
        .or_else(|| value.downcast_ref::<ConstInt>().map(|c| c.value_u()))
}
#[cfg(test)]
mod test {
    use std::{collections::BTreeSet, io::BufReader};

    use super::{BorrowSquashPass, find_const};
    use crate::extension::REGISTRY;
    use crate::passes::{ComposablePass, const_fold::ConstantFoldPass};
    use hugr::builder::{DFGBuilder, Dataflow, DataflowHugr, FunctionBuilder, endo_sig};
    use hugr::extension::prelude::{ConstUsize, qb_t, usize_t};
    use hugr::extension::simple_op::MakeExtensionOp;
    use hugr::hugr::hugrmut::HugrMut;
    use hugr::ops::{OpTrait, handle::NodeHandle};
    use hugr::std_extensions::collections::{
        array::ArrayKind,
        borrow_array::{BArrayOpBuilder, BArrayUnsafeOpDef, BorrowArray},
    };
    use hugr::types::Signature;
    use hugr::{Hugr, HugrView};
    use itertools::Itertools;
    use rstest::{fixture, rstest};

    /// Runs borrow(1); return(1); borrow(1); return(1) on one array. The middle
    /// return/borrow pair should be squashed, leaving a single borrow feeding a single return.
    #[rstest]
    fn simple() {
        let mut dfb = DFGBuilder::new(endo_sig([BorrowArray::ty(3, qb_t())])).unwrap();
        let [arr] = dfb.input_wires_arr();
        let idx = dfb.add_load_value(ConstUsize::new(1));
        let (arr, q) = dfb.add_borrow_array_borrow(qb_t(), 3, arr, idx).unwrap();
        let arr2 = dfb.add_borrow_array_return(qb_t(), 3, arr, idx, q).unwrap();
        let (arr3, q2) = dfb.add_borrow_array_borrow(qb_t(), 3, arr2, idx).unwrap();
        let arr4 = dfb
            .add_borrow_array_return(qb_t(), 3, arr3, idx, q2)
            .unwrap();
        let mut h = dfb.finish_hugr_with_outputs([arr4]).unwrap();
        h.validate().unwrap();

        // Build the expected result by hand. Remove the first return and the second borrow,
        // then wire the first borrow's element (port 1) and array (port 0) outputs straight
        // into the final return.
        let mut h2 = h.clone();
        h2.remove_node(arr2.node());
        h2.remove_node(arr3.node());
        h2.connect(q.node(), q.source(), arr4.node(), 2);
        h2.connect(q.node(), 0, arr4.node(), 0);
        h2.validate().unwrap();

        let r = BorrowSquashPass::default().run(&mut h).unwrap();
        assert_eq!(r, vec![(arr2.node(), arr3.node())]);
        // NOTE(review): whole-Hugr equality is allowed to fail here; the test then only
        // requires nodes, parents, op types and per-port links to match.
        if h != h2 {
            assert_eq!(h.nodes().collect_vec(), h2.nodes().collect_vec());
            for n in h.nodes() {
                assert_eq!(h.get_parent(n), h2.get_parent(n));
                assert_eq!(h.get_optype(n), h2.get_optype(n));
                for p in h.all_node_ports(n) {
                    let ins_h = h.linked_ports(n, p).collect_vec();
                    let ins_h2 = h2.linked_ports(n, p).collect_vec();
                    assert_eq!(ins_h, ins_h2);
                }
            }
        }
    }

    /// Loads `test_files/guppy_optimization/ranges/ranges.flat.array.hugr`.
    #[fixture]
    fn ranges_array() -> Hugr {
        let reader = BufReader::new(
            include_bytes!("../../../test_files/guppy_optimization/ranges/ranges.flat.array.hugr")
                .as_slice(),
        );
        Hugr::load(reader, Some(&REGISTRY)).unwrap()
    }

    /// On function `f` of the "ranges" example (after constant folding), each of indices 0..4
    /// should end up with exactly one borrow and one return.
    #[rstest]
    fn test_borrow_squash(ranges_array: Hugr) {
        // For each index 0..4: (number of borrows, number of returns).
        let counts = |h: &Hugr| {
            let mut brs = vec![(0, 0); 4];
            for n in find_borrows(h) {
                brs[get_index(h, n) as usize].0 += 1;
            }
            for n in find_returns(h) {
                brs[get_index(h, n) as usize].1 += 1;
            }
            brs
        };
        let f = ranges_array
            .children(ranges_array.module_root())
            .find(|n| {
                ranges_array
                    .get_optype(*n)
                    .as_func_defn()
                    .is_some_and(|fd| fd.func_name() == "f")
            })
            .unwrap();
        let mut h = ranges_array;
        h.set_entrypoint(f);
        // Constant folding makes the indices visible as constants to the pass.
        ConstantFoldPass::default().run(&mut h).unwrap();
        assert_eq!(counts(&h), vec![(4, 4), (6, 6), (6, 6), (4, 4)]);
        let orig_num_nodes = h.num_nodes();
        let res = BorrowSquashPass::default().run(&mut h).unwrap();
        h.validate().unwrap();
        // 20 borrow/return pairs become 4, and each elision removes two nodes.
        let expected_elisions = 16;
        assert_eq!(res.len(), expected_elisions);
        assert_eq!(h.num_nodes(), orig_num_nodes - 2 * expected_elisions);
        assert_eq!(counts(&h), vec![(1, 1); 4]);
    }

    /// All borrow-array `borrow` nodes among the entrypoint's descendants.
    fn find_borrows<H: HugrView>(h: &H) -> impl Iterator<Item = H::Node> + '_ {
        h.entry_descendants().filter(|n| {
            h.get_optype(*n).as_extension_op().is_some_and(|eop| {
                BArrayUnsafeOpDef::from_extension_op(eop) == Ok(BArrayUnsafeOpDef::borrow)
            })
        })
    }

    /// All borrow-array `return` nodes among the entrypoint's descendants.
    fn find_returns<H: HugrView>(h: &H) -> impl Iterator<Item = H::Node> + '_ {
        h.entry_descendants().filter(|n| {
            h.get_optype(*n).as_extension_op().is_some_and(|eop| {
                BArrayUnsafeOpDef::from_extension_op(eop) == Ok(BArrayUnsafeOpDef::r#return)
            })
        })
    }

    /// Constant index (input 1) of a borrow/return node; panics if it is not constant.
    fn get_index<H: HugrView>(h: &H, n: H::Node) -> u64 {
        find_const(h, n, 1.into()).unwrap()
    }

    /// An outer borrow array (3 elements) of inner borrow arrays (5 qubits each).
    /// Redundant pairs on both levels should be squashed so that the two CX gates end up
    /// directly connected.
    #[rstest]
    fn test_nested_array() {
        let inner_array_type = BorrowArray::ty(5, qb_t());
        let outer_array_type = BorrowArray::ty(3, inner_array_type.clone());
        let reader = BufReader::new(
            include_bytes!("../../../test_files/guppy_optimization/nested_array/nested_array.hugr")
                .as_slice(),
        );
        let mut h = Hugr::load(reader, Some(&REGISTRY)).unwrap();
        let array_func = h
            .children(h.module_root())
            .find(|n| {
                h.get_optype(*n)
                    .as_func_defn()
                    .is_some_and(|fd| fd.func_name() == "main")
            })
            .unwrap();
        h.set_entrypoint(array_func);
        ConstantFoldPass::default().run(&mut h).unwrap();
        // Check the shape of the input: ops at index 0 act on the outer array (8 borrows and
        // 8 returns), and ops at indices 1 and 2 act on inner arrays.
        for nodes in [find_borrows(&h).collect_vec(), find_returns(&h).collect()] {
            let mut outer_count = 0;
            for node in nodes {
                let expected_array_type = match get_index(&h, node) {
                    0 => {
                        outer_count += 1;
                        &outer_array_type
                    }
                    1 | 2 => &inner_array_type,
                    idx => panic!("Unexpected index {idx}"),
                };
                assert_eq!(
                    h.get_optype(node)
                        .dataflow_signature()
                        .unwrap()
                        .input_types()[0],
                    *expected_array_type
                );
            }
            assert_eq!(outer_count, 8);
        }
        // Before the pass, both CX gates take their inputs from borrows and feed returns.
        let [cx1, cx2] = h
            .nodes()
            .filter(|n| {
                h.get_optype(*n)
                    .as_extension_op()
                    .is_some_and(|eop| eop.qualified_id() == "tket.quantum.CX")
            })
            .collect_array()
            .unwrap();
        for cx in [cx1, cx2] {
            assert!(
                BTreeSet::from_iter(find_returns(&h))
                    .is_superset(&h.output_neighbours(cx).collect())
            );
            assert!(
                BTreeSet::from_iter(find_borrows(&h))
                    .is_superset(&h.input_neighbours(cx).collect())
            );
        }
        let res = BorrowSquashPass::default().run(&mut h).unwrap();
        h.validate().unwrap();
        assert_eq!(res.len(), 9);
        // Exactly one borrow and one return remain per index.
        assert_eq!(
            find_borrows(&h)
                .map(|n| get_index(&h, n))
                .sorted()
                .collect_vec(),
            [0, 1, 2]
        );
        assert_eq!(
            find_returns(&h)
                .map(|n| get_index(&h, n))
                .sorted()
                .collect_vec(),
            [0, 1, 2]
        );
        for cx in [cx1, cx2] {
            assert!(
                h.get_optype(cx)
                    .as_extension_op()
                    .is_some_and(|eop| eop.qualified_id() == "tket.quantum.CX")
            );
        }
        // The first CX now feeds the second directly.
        assert!(h.output_neighbours(cx1).all(|n| n == cx2));
    }

    /// Constant-index ops on indices 1..=4 of a 10-element array, plus a borrow/return at a
    /// dynamic index `i`. `dyn_pos` selects where the dynamic ops go:
    /// - 0: borrowed first, returned in the middle;
    /// - 1: borrowed in the middle, returned at the end;
    /// - 2: borrowed and returned in the middle.
    ///
    /// A dynamic-index op ends the chain being tracked, so the pairs on indices 2 and 3 that
    /// straddle the middle must not be squashed. Only the re-borrows of 1 and 4 may be.
    #[rstest]
    fn test_dynamic(#[values(0, 1, 2)] dyn_pos: usize) {
        let ins = vec![usize_t(), BorrowArray::ty(10, usize_t())];
        let mut outs = vec![usize_t(); 3];
        outs.extend(ins.clone());
        let mut fb = FunctionBuilder::new("test", Signature::new(ins, outs)).unwrap();
        let [i, arr] = fb.input_wires_arr();
        let borrow = |fb: &mut FunctionBuilder<Hugr>, arr, i| {
            let op = BArrayUnsafeOpDef::borrow.to_concrete(usize_t(), 10);
            let h = fb.add_dataflow_op(op, [arr, i]).unwrap();
            (h.node(), h.outputs_arr::<2>())
        };
        let return_ = |fb: &mut FunctionBuilder<Hugr>, arr, i, val| {
            let op = BArrayUnsafeOpDef::r#return.to_concrete(usize_t(), 10);
            let h = fb.add_dataflow_op(op, [arr, i, val]).unwrap();
            (h.node(), h.outputs_arr::<1>()[0])
        };
        let [one, two, three, four] = [1, 2, 3, 4].map(|i| fb.add_load_value(ConstUsize::new(i)));
        let (arr, xi) = if dyn_pos == 0 {
            let (_, [arr, xi]) = borrow(&mut fb, arr, i);
            (arr, Some(xi))
        } else {
            (arr, None)
        };
        let (_, [arr, x1]) = borrow(&mut fb, arr, one);
        let (ret1, arr) = return_(&mut fb, arr, one, x1);
        let (rebo1, [arr, x1]) = borrow(&mut fb, arr, one);
        let (_, [arr, x2]) = borrow(&mut fb, arr, two);
        let (_, arr) = return_(&mut fb, arr, two, x2);
        let (_, [arr, x3]) = borrow(&mut fb, arr, three);
        // Dynamic-index ops in the middle of the chain.
        let (arr, xi) = if dyn_pos == 0 {
            let (_, arr) = return_(&mut fb, arr, i, xi.unwrap());
            (arr, None)
        } else {
            assert!(xi.is_none());
            let (_, [arr, xi]) = borrow(&mut fb, arr, i);
            if dyn_pos == 1 {
                (arr, Some(xi))
            } else {
                let (_, arr) = return_(&mut fb, arr, i, xi);
                (arr, None)
            }
        };
        let (_, [arr, x2]) = borrow(&mut fb, arr, two);
        let (_, arr) = return_(&mut fb, arr, three, x3);
        let (_, [arr, x3]) = borrow(&mut fb, arr, three);
        let (_, [arr, x4]) = borrow(&mut fb, arr, four);
        let (ret4, arr) = return_(&mut fb, arr, four, x4);
        let (rebo4, [arr, x4]) = borrow(&mut fb, arr, four);
        let arr = xi.map_or(arr, |xi| return_(&mut fb, arr, i, xi).1);
        let mut h = fb.finish_hugr_with_outputs([x1, x2, x3, x4, arr]).unwrap();
        let res = BorrowSquashPass::default().run(&mut h).unwrap();
        assert_eq!(res, [(ret1, rebo1), (ret4, rebo4)]);
        h.validate().unwrap();
    }
}