use super::{DirWire, ModifierResolver, ModifierResolverErrors};
use hugr::{
IncomingPort, OutgoingPort,
builder::Dataflow,
core::HugrNode,
extension::simple_op::MakeExtensionOp,
hugr::hugrmut::HugrMut,
ops::OpType,
std_extensions::collections::{
array::{
ArrayKind, ArrayOp, GenericArrayOp,
GenericArrayOpDef::{self, *},
},
borrow_array::{
BArrayFromArray, BArrayOp, BArrayToArray, BArrayUnsafeOp, BArrayUnsafeOpDef,
},
},
};
impl<N: HugrNode> ModifierResolver<N> {
    /// Dispatches node `n` to the matching array-operation handler.
    ///
    /// `optype` is tried, in order, as an `ArrayOp`, a `BArrayOp` and a
    /// `BArrayUnsafeOp`.
    ///
    /// # Returns
    /// * `Ok(true)` once a plain or borrow array op has been processed by
    ///   `generic_modify_array_op`.
    /// * Whatever `modify_borrow_array_unsafe_op` reports for an unsafe
    ///   borrow-array op.
    /// * `Ok(false)` when `optype` is none of the three.
    ///
    /// # Errors
    /// Propagates any error raised by the selected handler.
    pub(super) fn modify_array_op(
        &mut self,
        h: &impl HugrMut<Node = N>,
        n: N,
        optype: &OpType,
        new_dfg: &mut impl Dataflow,
    ) -> Result<bool, ModifierResolverErrors<N>> {
        if let Some(array_op) = ArrayOp::from_optype(optype) {
            self.generic_modify_array_op(h, n, array_op, new_dfg)
                .map(|()| true)
        } else if let Some(borrow_op) = BArrayOp::from_optype(optype) {
            self.generic_modify_array_op(h, n, borrow_op, new_dfg)
                .map(|()| true)
        } else if let Some(unsafe_op) = BArrayUnsafeOp::from_optype(optype) {
            self.modify_borrow_array_unsafe_op(h, n, unsafe_op, new_dfg)
        } else {
            Ok(false)
        }
    }

    /// Handles an unsafe borrow-array op, exchanging `borrow` and `return`
    /// when running under dagger.
    ///
    /// The op is copied through `add_node_no_modification` (returning
    /// `Ok(true)`) unless the dagger modifier is set *and*
    /// `qubit_finder.contains_element_type` accepts the element type.
    /// Otherwise `borrow` is replaced by `return` and vice versa; any other
    /// unsafe op yields `Ok(false)` without adding a node.
    ///
    /// # Errors
    /// Propagates errors from `add_node_no_modification` and `map_insert`.
    fn modify_borrow_array_unsafe_op(
        &mut self,
        h: &impl HugrMut<Node = N>,
        n: N,
        op: BArrayUnsafeOp,
        new_dfg: &mut impl Dataflow,
    ) -> Result<bool, ModifierResolverErrors<N>> {
        let needs_inversion =
            self.modifiers().dagger && self.qubit_finder.contains_element_type(&op.elem_ty);
        if !needs_inversion {
            self.add_node_no_modification(h, n, op, new_dfg)?;
            return Ok(true);
        }

        // Pick the opposite op and remember which one we started from, so the
        // op-specific port below can be wired without a second match.
        let (inverse_def, was_borrow) = match op.def {
            BArrayUnsafeOpDef::borrow => (BArrayUnsafeOpDef::r#return, true),
            BArrayUnsafeOpDef::r#return => (BArrayUnsafeOpDef::borrow, false),
            _ => return Ok(false),
        };
        let node = new_dfg.add_child_node(inverse_def.to_concrete(op.elem_ty, op.size));

        // Wiring shared by both directions: ports 0 swap direction, while
        // input 1 stays an input.
        // NOTE(review): presumably port 0 carries the array and input 1 the
        // index — confirm against the `BArrayUnsafeOpDef` signatures.
        self.map_insert(
            (n, IncomingPort::from(0)).into(),
            (node, OutgoingPort::from(0)).into(),
        )?;
        self.map_insert(
            (n, IncomingPort::from(1)).into(),
            (node, IncomingPort::from(1)).into(),
        )?;
        self.map_insert(
            (n, OutgoingPort::from(0)).into(),
            (node, IncomingPort::from(0)).into(),
        )?;

        // The remaining port is output 1 on `borrow` and input 2 on `return`.
        if was_borrow {
            self.map_insert(
                (n, OutgoingPort::from(1)).into(),
                (node, IncomingPort::from(2)).into(),
            )?;
        } else {
            self.map_insert(
                (n, IncomingPort::from(2)).into(),
                (node, OutgoingPort::from(1)).into(),
            )?;
        }
        Ok(true)
    }

    /// Handles an array op of any `ArrayKind`, replacing it by its
    /// counterpart when running under dagger.
    ///
    /// The op is copied through `add_node_no_modification` unless the dagger
    /// modifier is set *and* `qubit_finder.contains_element_type` accepts the
    /// element type. Otherwise `new_array` and `unpack` are exchanged, the
    /// other listed ops map to themselves, and every port of `n` is mapped to
    /// the same port of the new node with `DirWire::reverse` applied.
    ///
    /// # Errors
    /// Returns `ModifierResolverErrors::unresolvable` for an op definition
    /// that has no listed counterpart, and propagates errors from
    /// `add_node_no_modification` and `map_insert`.
    fn generic_modify_array_op<AK: ArrayKind>(
        &mut self,
        h: &impl HugrMut<Node = N>,
        n: N,
        op: GenericArrayOp<AK>,
        new_dfg: &mut impl Dataflow,
    ) -> Result<(), ModifierResolverErrors<N>> {
        let needs_inversion =
            self.modifiers().dagger && self.qubit_finder.contains_element_type(&op.elem_ty);
        if !needs_inversion {
            self.add_node_no_modification(h, n, op, new_dfg)?;
            return Ok(());
        }

        // The bare identifiers below are the glob-imported
        // `GenericArrayOpDef` variants, not fresh bindings.
        let op_def = &op.def;
        let inverse_def: GenericArrayOpDef<AK> = match op_def {
            new_array => unpack,
            unpack => new_array,
            swap => swap,
            get => get,
            set => set,
            pop_left => pop_left,
            pop_right => pop_right,
            discard_empty => discard_empty,
            _ => {
                return Err(ModifierResolverErrors::unresolvable(
                    n,
                    format!("Cannot modify array operation {op_def:?} under dagger"),
                    OpType::from(op.clone()),
                ));
            }
        };

        let node = new_dfg.add_child_node(inverse_def.to_concrete(op.elem_ty, op.size));
        for port in h.all_node_ports(n) {
            self.map_insert(DirWire(n, port), DirWire::new(node, port).reverse())?;
        }
        Ok(())
    }

    /// Handles conversions between borrow arrays and plain arrays,
    /// exchanging `BArrayToArray` and `BArrayFromArray` under dagger.
    ///
    /// # Returns
    /// * `Ok(true)` when a node was added to `new_dfg`.
    /// * `Ok(false)` when, under dagger, `optype` is not an extension op or
    ///   is neither of the two conversions.
    ///
    /// # Errors
    /// Propagates errors from `add_node_no_modification` and `map_insert`.
    pub(super) fn try_array_convert(
        &mut self,
        h: &impl HugrMut<Node = N>,
        n: N,
        optype: &OpType,
        new_dfg: &mut impl Dataflow,
    ) -> Result<bool, ModifierResolverErrors<N>> {
        // NOTE(review): without dagger this copies *any* `optype` and reports
        // `Ok(true)` before checking that it is an array conversion, whereas
        // under dagger a non-conversion op yields `Ok(false)`. Confirm that
        // callers rely on this asymmetry.
        if !self.modifiers().dagger {
            self.add_node_no_modification(h, n, optype.clone(), new_dfg)?;
            return Ok(true);
        }

        let Some(ext_op) = optype.as_extension_op() else {
            return Ok(false);
        };

        // Decode the conversion first, remembering its direction.
        let (elem_ty, size, was_to_array) =
            if let Ok(conv) = BArrayToArray::from_extension_op(ext_op) {
                (conv.elem_ty, conv.size, true)
            } else if let Ok(conv) = BArrayFromArray::from_extension_op(ext_op) {
                (conv.elem_ty, conv.size, false)
            } else {
                return Ok(false);
            };

        // Element types rejected by the qubit finder are copied unchanged.
        if !self.qubit_finder.contains_element_type(&elem_ty) {
            self.add_node_no_modification(h, n, optype.clone(), new_dfg)?;
            return Ok(true);
        }

        let node = if was_to_array {
            new_dfg.add_child_node(BArrayFromArray::new(elem_ty, size))
        } else {
            new_dfg.add_child_node(BArrayToArray::new(elem_ty, size))
        };
        for port in h.all_node_ports(n) {
            self.map_insert(DirWire(n, port), DirWire::new(node, port).reverse())?;
        }
        Ok(true)
    }
}