use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
use std::iter::FusedIterator;
use hugr::hugr::NodeMetadata;
use hugr::ops::{OpTag, OpTrait};
use hugr::{HugrView, IncomingPort, OutgoingPort};
use hugr_core::hugr::internal::{HugrInternals, PortgraphNodeMap};
use itertools::Either::{self, Left, Right};
use itertools::{EitherOrBoth, Itertools};
use petgraph::visit as pv;
use portgraph::PortView;
use super::units::{filter, DefaultUnitLabeller, LinearUnit, UnitLabeller, Units};
use super::Circuit;
pub use hugr::ops::OpType;
pub use hugr::types::{EdgeKind, Type, TypeRow};
pub use hugr::{CircuitUnit, Direction, Node, Port, PortIndex, Wire};
pub struct Command<'circ, T: HugrView> {
circ: &'circ Circuit<T>,
node: Node,
input_linear_units: Vec<LinearUnit>,
output_linear_units: Vec<LinearUnit>,
}
impl<'circ, T: HugrView<Node = Node>> Command<'circ, T> {
#[inline]
pub fn node(&self) -> Node {
self.node
}
#[inline]
pub fn optype(&self) -> &'circ OpType {
self.circ.hugr().get_optype(self.node)
}
#[inline]
pub fn units(
&self,
direction: Direction,
) -> impl Iterator<Item = (CircuitUnit, Port, Type)> + '_ {
match direction {
Direction::Incoming => Either::Left(self.inputs().map(|(u, p, t)| (u, p.into(), t))),
Direction::Outgoing => Either::Right(self.outputs().map(|(u, p, t)| (u, p.into(), t))),
}
}
#[inline]
pub fn linear_units(
&self,
direction: Direction,
) -> impl Iterator<Item = (LinearUnit, Port, Type)> + '_ {
match direction {
Direction::Incoming => {
Either::Left(self.linear_inputs().map(|(u, p, t)| (u, p.into(), t)))
}
Direction::Outgoing => {
Either::Right(self.linear_outputs().map(|(u, p, t)| (u, p.into(), t)))
}
}
}
#[inline]
pub fn input_qubits(&self) -> impl Iterator<Item = (LinearUnit, IncomingPort, Type)> + '_ {
self.inputs().filter_map(filter::filter_qubit)
}
#[inline]
pub fn output_qubits(&self) -> impl Iterator<Item = (LinearUnit, OutgoingPort, Type)> + '_ {
self.outputs().filter_map(filter::filter_qubit)
}
#[inline]
pub fn outputs(&self) -> Units<OutgoingPort, Node, &'_ Self> {
Units::new_outgoing(self.circ, self.node, self)
}
#[inline]
pub fn linear_outputs(&self) -> impl Iterator<Item = (LinearUnit, OutgoingPort, Type)> + '_ {
self.outputs().filter_map(filter::filter_linear)
}
#[inline]
pub fn output_wires(&self) -> impl Iterator<Item = (CircuitUnit, Wire)> + '_ {
self.outputs().filter_map(move |(unit, port, _typ)| {
let w = self.assign_wire(self.node, port.into())?;
Some((unit, w))
})
}
#[inline]
pub fn inputs(&self) -> Units<IncomingPort, Node, &'_ Self> {
Units::new_incoming(self.circ, self.node, self)
}
#[inline]
pub fn linear_inputs(&self) -> impl Iterator<Item = (LinearUnit, IncomingPort, Type)> + '_ {
self.inputs().filter_map(filter::filter_linear)
}
#[inline]
pub fn input_wires(&self) -> impl IntoIterator<Item = (CircuitUnit, Wire)> + '_ {
self.inputs().filter_map(move |(unit, port, _typ)| {
let w = self.assign_wire(self.node, port.into())?;
Some((unit, w))
})
}
#[inline]
pub fn input_count(&self) -> usize {
self.optype().value_input_count() + self.optype().static_input_port().is_some() as usize
}
#[inline]
pub fn output_count(&self) -> usize {
self.optype().value_output_count() + self.optype().static_output_port().is_some() as usize
}
#[inline]
pub fn linear_unit_port(&self, unit: LinearUnit, direction: Direction) -> Option<Port> {
self.linear_units(direction)
.find(|(cu, _, _)| *cu == unit)
.map(|(_, port, _)| port)
}
#[inline]
pub fn is_linear_port(&self, port: Port) -> bool {
self.optype()
.port_kind(port)
.is_some_and(|kind| kind.is_linear())
}
#[inline]
pub fn metadata(&self, key: impl AsRef<str>) -> Option<&NodeMetadata> {
self.circ.hugr().get_metadata(self.node, key)
}
}
impl<T: HugrView<Node = Node>> UnitLabeller<Node> for &Command<'_, T> {
#[inline]
fn assign_linear(&self, _: Node, port: Port, _linear_count: usize) -> LinearUnit {
let units = match port.direction() {
Direction::Incoming => &self.input_linear_units,
Direction::Outgoing => &self.output_linear_units,
};
*units.get(port.index()).unwrap_or_else(|| {
panic!(
"Could not assign a linear unit to port {port:?} of node {:?}",
self.node
)
})
}
#[inline]
fn assign_wire(&self, node: Node, port: Port) -> Option<Wire> {
match port.as_directed() {
Left(to_port) => {
let (from, from_port) = self.circ.hugr().linked_outputs(node, to_port).next()?;
Some(Wire::new(from, from_port))
}
Right(from_port) => Some(Wire::new(node, from_port)),
}
}
}
impl<T: HugrView<Node = Node>> std::fmt::Debug for Command<'_, T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Command")
.field("circuit name", &self.circ.name())
.field("node", &self.node)
.field("input_linear_units", &self.input_linear_units)
.field("output_linear_units", &self.output_linear_units)
.finish()
}
}
impl<T: HugrView> PartialEq for Command<'_, T> {
fn eq(&self, other: &Self) -> bool {
self.node == other.node
&& self.input_linear_units == other.input_linear_units
&& self.output_linear_units == other.output_linear_units
}
}
impl<T: HugrView> Eq for Command<'_, T> {}
impl<T: HugrView> Clone for Command<'_, T> {
fn clone(&self) -> Self {
Self {
circ: self.circ,
node: self.node,
input_linear_units: self.input_linear_units.clone(),
output_linear_units: self.output_linear_units.clone(),
}
}
}
impl<T: HugrView> std::hash::Hash for Command<'_, T> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.node.hash(state);
self.input_linear_units.hash(state);
self.output_linear_units.hash(state);
}
}
type NodeWalker<'circ, T> = pv::Topo<
portgraph::NodeIndex,
<portgraph::view::FlatRegion<'circ, <T as HugrInternals>::RegionPortgraph<'circ>> as petgraph::visit::Visitable>::Map,
>;
#[derive(Clone)]
pub struct CommandIterator<'circ, T: HugrView> {
circ: &'circ Circuit<T>,
region: portgraph::view::FlatRegion<'circ, T::RegionPortgraph<'circ>>,
region_node_map: T::RegionPortgraphNodes,
nodes: NodeWalker<'circ, T>,
wire_unit: HashMap<Wire, usize>,
max_remaining: usize,
delayed_consts: HashSet<Node>,
delayed_consumers: HashMap<Node, usize>,
delayed_node: Option<Node>,
}
impl<'circ, T: HugrView<Node = Node>> CommandIterator<'circ, T> {
pub(super) fn new(circ: &'circ Circuit<T>) -> Self {
let wire_unit = circ
.linear_units()
.map(|(linear_unit, port, _)| (Wire::new(circ.input_node(), port), linear_unit.index()))
.collect();
let (region, region_node_map) = circ.hugr().region_portgraph(circ.parent());
let node_count = region.node_count();
let nodes = pv::Topo::new(®ion);
Self {
circ,
region,
region_node_map,
nodes,
wire_unit,
max_remaining: node_count - 2,
delayed_consts: HashSet::new(),
delayed_consumers: HashMap::new(),
delayed_node: None,
}
}
fn next_node(&mut self) -> Option<Node> {
let node = self.delayed_node.take().or_else(|| {
let pg_node = self.nodes.next(&self.region)?;
Some(self.region_node_map.from_portgraph(pg_node))
})?;
if node == self.circ.parent() {
return self.next_node();
}
let tag = self.circ.hugr().get_optype(node).tag();
if tag == OpTag::Const || tag == OpTag::LoadConst {
self.delayed_consts.insert(node);
for consumer in self.circ.hugr().output_neighbours(node) {
*self.delayed_consumers.entry(consumer).or_default() += 1;
}
return self.next_node();
}
match self.delayed_consumers.contains_key(&node) {
true => {
let delayed = self.next_delayed_node(node);
self.delayed_consts.remove(&delayed);
for consumer in self.circ.hugr().output_neighbours(delayed) {
let Entry::Occupied(mut entry) = self.delayed_consumers.entry(consumer) else {
panic!("Delayed node consumer was not in delayed_consumers. Delayed node: {delayed:?}, consumer: {consumer:?}.");
};
*entry.get_mut() -= 1;
if *entry.get() == 0 {
entry.remove();
}
}
self.delayed_node = Some(node);
Some(delayed)
}
false => Some(node),
}
}
fn next_delayed_node(&mut self, consumer: Node) -> Node {
let Some(delayed_pred) = self
.circ
.hugr()
.input_neighbours(consumer)
.find(|k| self.delayed_consts.contains(k))
else {
panic!("Could not find a delayed predecessor for node {consumer:?}.");
};
match self.delayed_consumers.contains_key(&delayed_pred) {
true => self.next_delayed_node(delayed_pred),
false => delayed_pred,
}
}
fn process_node(&mut self, node: Node) -> Option<(Vec<LinearUnit>, Vec<LinearUnit>)> {
if node == self.circ.parent() {
return None;
}
let tag = self.circ.hugr().get_optype(node).tag();
if tag == OpTag::Input || tag == OpTag::Output {
return None;
}
let mut input_linear_units = Vec::new();
let mut output_linear_units = Vec::new();
let input_units = Units::new_incoming(self.circ, node, DefaultUnitLabeller)
.filter_map(filter::filter_linear);
let output_units = Units::new_outgoing(self.circ, node, DefaultUnitLabeller)
.filter_map(filter::filter_linear);
for ports in input_units.zip_longest(output_units) {
let mut terminate_input =
|port: IncomingPort, wire_unit: &mut HashMap<Wire, usize>| -> Option<usize> {
let linear_id = self.circ.hugr().single_linked_output(node, port).and_then(
|(wire_node, wire_port)| wire_unit.remove(&Wire::new(wire_node, wire_port)),
)?;
input_linear_units.push(LinearUnit::new(linear_id));
Some(linear_id)
};
let mut register_output =
|unit: usize, port: OutgoingPort, wire_unit: &mut HashMap<Wire, usize>| {
let wire = Wire::new(node, port);
wire_unit.insert(wire, unit);
output_linear_units.push(LinearUnit::new(unit));
};
match ports {
EitherOrBoth::Right((_, out_port, _)) => {
let new_id = self.wire_unit.len();
register_output(new_id, out_port, &mut self.wire_unit);
}
EitherOrBoth::Left((_, in_port, _)) => {
terminate_input(in_port, &mut self.wire_unit);
}
EitherOrBoth::Both((_, in_port, _), (_, out_port, _)) => {
if let Some(linear_id) = terminate_input(in_port, &mut self.wire_unit) {
register_output(linear_id, out_port, &mut self.wire_unit);
}
}
}
}
Some((input_linear_units, output_linear_units))
}
}
impl<'circ, T: HugrView<Node = Node>> Iterator for CommandIterator<'circ, T> {
type Item = Command<'circ, T>;
#[inline]
fn next(&mut self) -> Option<Self::Item> {
loop {
let node = self.next_node()?;
if let Some((input_linear_units, output_linear_units)) = self.process_node(node) {
self.max_remaining -= 1;
return Some(Command {
circ: self.circ,
node,
input_linear_units,
output_linear_units,
});
}
}
}
#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
(0, Some(self.max_remaining))
}
}
impl<T: HugrView<Node = Node>> FusedIterator for CommandIterator<'_, T> {}
impl<T: HugrView<Node = Node>> std::fmt::Debug for CommandIterator<'_, T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CommandIterator")
.field("circuit name", &self.circ.name())
.field("wire_unit", &self.wire_unit)
.field("max_remaining", &self.max_remaining)
.finish()
}
}
#[cfg(test)]
mod test {
use hugr::builder::{Container, DFGBuilder, Dataflow, DataflowHugr};
use hugr::extension::prelude::qb_t;
use hugr::hugr::hugrmut::HugrMut;
use hugr::ops::handle::NodeHandle;
use hugr::ops::Value;
use hugr::types::Signature;
use itertools::Itertools;
use rstest::{fixture, rstest};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use crate::extension::rotation::ConstRotation;
use crate::utils::build_simple_circuit;
use crate::Tk2Op;
use super::*;
macro_rules! assert_eq_iter {
($iterable:expr, $expected:expr $(,)?) => {
assert_eq!($iterable.collect_vec(), $expected.into_iter().collect_vec());
};
}
#[fixture]
fn simple_circuit() -> Circuit {
build_simple_circuit(2, |circ| {
circ.append(Tk2Op::H, [0])?;
circ.append(Tk2Op::CX, [0, 1])?;
circ.append(Tk2Op::T, [1])?;
Ok(())
})
.unwrap()
}
#[fixture]
fn simple_module() -> Circuit {
build_simple_circuit(2, |circ| {
circ.append(Tk2Op::H, [0])?;
circ.append(Tk2Op::CX, [0, 1])?;
circ.append(Tk2Op::T, [1])?;
Ok(())
})
.unwrap()
}
#[fixture]
fn module_with_circuits() -> Circuit {
let mut module = simple_module();
let other_circ = simple_circuit();
let hugr = module.hugr_mut();
hugr.insert_hugr(hugr.module_root(), other_circ.into_hugr());
return module;
}
#[rstest]
#[case::dfg_rooted(simple_circuit())]
#[case::module_rooted(simple_module())]
#[case::complex_module_rooted(module_with_circuits())]
fn iterate_commands_simple(#[case] circ: Circuit) {
assert_eq!(CommandIterator::new(&circ).count(), 3);
let tk2op_name = |op: Tk2Op| op.exposed_name();
let mut commands = CommandIterator::new(&circ);
assert_eq!(commands.size_hint(), (0, Some(3)));
let hadamard = commands.next().unwrap();
assert_eq!(hadamard.optype().to_string(), tk2op_name(Tk2Op::H));
assert_eq_iter!(
hadamard.inputs().map(|(u, _, _)| u),
[CircuitUnit::Linear(0)],
);
assert_eq_iter!(
hadamard.outputs().map(|(u, _, _)| u),
[CircuitUnit::Linear(0)],
);
let cx = commands.next().unwrap();
assert_eq!(cx.optype().to_string(), tk2op_name(Tk2Op::CX));
assert_eq_iter!(
cx.inputs().map(|(unit, _, _)| unit),
[CircuitUnit::Linear(0), CircuitUnit::Linear(1)],
);
assert_eq_iter!(
cx.outputs().map(|(unit, _, _)| unit),
[CircuitUnit::Linear(0), CircuitUnit::Linear(1)],
);
let t = commands.next().unwrap();
assert_eq!(t.optype().to_string(), tk2op_name(Tk2Op::T));
assert_eq_iter!(
t.inputs().map(|(unit, _, _)| unit),
[CircuitUnit::Linear(1)],
);
assert_eq_iter!(
t.outputs().map(|(unit, _, _)| unit),
[CircuitUnit::Linear(1)],
);
assert_eq!(commands.next(), None);
}
#[test]
fn commands_nonlinear() {
let qb_row = vec![qb_t(); 1];
let mut h = DFGBuilder::new(Signature::new(qb_row.clone(), qb_row)).unwrap();
let [q_in] = h.input_wires_arr();
let constant = h.add_constant(Value::extension(ConstRotation::PI_2));
let loaded_const = h.load_const(&constant);
let rz = h.add_dataflow_op(Tk2Op::Rz, [q_in, loaded_const]).unwrap();
let circ: Circuit = h.finish_hugr_with_outputs(rz.outputs()).unwrap().into();
assert_eq!(CommandIterator::new(&circ).count(), 3);
let mut commands = CommandIterator::new(&circ);
let const_cmd = commands.next().unwrap();
assert_eq!(const_cmd.optype().to_string(), "const:custom:a(Ï€*0.5)");
assert_eq_iter!(const_cmd.inputs().map(|(u, _, _)| u), [],);
assert_eq_iter!(
const_cmd.outputs().map(|(u, _, _)| u),
[CircuitUnit::Wire(Wire::new(constant.node(), 0))],
);
let load_const_cmd = commands.next().unwrap();
let load_const_node = load_const_cmd.node();
assert!(load_const_cmd.optype().is_load_constant());
assert_eq_iter!(
load_const_cmd.inputs().map(|(u, _, _)| u),
[CircuitUnit::Wire(Wire::new(constant.node(), 0))],
);
assert_eq_iter!(
load_const_cmd.outputs().map(|(u, _, _)| u),
[CircuitUnit::Wire(Wire::new(load_const_node, 0))],
);
let rz_cmd = commands.next().unwrap();
assert_eq!(rz_cmd.optype().cast(), Some(Tk2Op::Rz));
assert_eq_iter!(
rz_cmd.inputs().map(|(u, _, _)| u),
[
CircuitUnit::Linear(0),
CircuitUnit::Wire(Wire::new(load_const_node, 0))
],
);
assert_eq_iter!(
rz_cmd.outputs().map(|(u, _, _)| u),
[CircuitUnit::Linear(0)],
);
}
#[test]
fn alloc_free() -> Result<(), Box<dyn std::error::Error>> {
let qb_row = vec![qb_t(); 1];
let mut h = DFGBuilder::new(Signature::new(qb_row.clone(), qb_row))?;
let [q_in] = h.input_wires_arr();
let alloc = h.add_dataflow_op(Tk2Op::QAlloc, [])?;
let [q_new] = alloc.outputs_arr();
let cx = h.add_dataflow_op(Tk2Op::CX, [q_in, q_new])?;
let [q_in, q_new] = cx.outputs_arr();
let free = h.add_dataflow_op(Tk2Op::QFree, [q_in])?;
let circ: Circuit = h.finish_hugr_with_outputs([q_new])?.into();
let mut cmds = circ.commands();
let alloc_cmd = cmds.next().unwrap();
assert_eq!(alloc_cmd.node(), alloc.node());
assert_eq!(
alloc_cmd.inputs().map(|(unit, _, _)| unit).collect_vec(),
[]
);
assert_eq!(
alloc_cmd.outputs().map(|(unit, _, _)| unit).collect_vec(),
[CircuitUnit::Linear(1)]
);
let cx_cmd = cmds.next().unwrap();
assert_eq!(cx_cmd.node(), cx.node());
assert_eq!(
cx_cmd.inputs().map(|(unit, _, _)| unit).collect_vec(),
[CircuitUnit::Linear(0), CircuitUnit::Linear(1)]
);
assert_eq!(
cx_cmd.outputs().map(|(unit, _, _)| unit).collect_vec(),
[CircuitUnit::Linear(0), CircuitUnit::Linear(1)]
);
let free_cmd = cmds.next().unwrap();
assert_eq!(free_cmd.node(), free.node());
assert_eq!(
free_cmd.inputs().map(|(unit, _, _)| unit).collect_vec(),
[CircuitUnit::Linear(0)]
);
assert_eq!(
free_cmd.outputs().map(|(unit, _, _)| unit).collect_vec(),
[]
);
Ok(())
}
#[test]
fn test_impls() -> Result<(), Box<dyn std::error::Error>> {
let qb_row = vec![qb_t(); 1];
let mut h = DFGBuilder::new(Signature::new(qb_row.clone(), vec![]))?;
let [q_in] = h.input_wires_arr();
h.add_dataflow_op(Tk2Op::QFree, [q_in])?;
let circ: Circuit = h.finish_hugr_with_outputs([])?.into();
let cmd1 = circ.commands().next().unwrap();
let cmd2 = circ.commands().next().unwrap();
assert_eq!(cmd1, cmd2);
let mut hasher1 = DefaultHasher::new();
cmd1.hash(&mut hasher1);
let mut hasher2 = DefaultHasher::new();
cmd2.hash(&mut hasher2);
assert_eq!(hasher1.finish(), hasher2.finish());
Ok(())
}
}