pub mod filter;
use std::iter::FusedIterator;
use std::marker::PhantomData;
use hugr::core::HugrNode;
use hugr::types::{EdgeKind, Type, TypeRow};
use hugr::{CircuitUnit, HugrView, IncomingPort, OutgoingPort};
use hugr::{Direction, Node, Port, Wire};
use crate::utils::type_is_linear;
use super::Circuit;
/// Index-based identifier of a linear unit.
///
/// The wrapped `usize` is the value carried by `CircuitUnit::Linear` when the
/// unit is converted into a `CircuitUnit`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct LinearUnit(usize);

impl LinearUnit {
    /// Creates the linear unit identified by `index`.
    pub fn new(index: usize) -> Self {
        LinearUnit(index)
    }

    /// Returns the index this unit was created with.
    pub fn index(&self) -> usize {
        let LinearUnit(index) = *self;
        index
    }
}
impl From<LinearUnit> for CircuitUnit {
fn from(lu: LinearUnit) -> Self {
CircuitUnit::Linear(lu.index())
}
}
impl TryFrom<CircuitUnit> for LinearUnit {
type Error = ();
fn try_from(cu: CircuitUnit) -> Result<Self, Self::Error> {
match cu {
CircuitUnit::Wire(_) => Err(()),
CircuitUnit::Linear(i) => Ok(LinearUnit(i)),
}
}
}
/// Iterator over the units attached to the ports of one node, in one direction.
///
/// Each item is a `(CircuitUnit, port, Type)` triple. `P` is the port type
/// (`IncomingPort` or `OutgoingPort`) and `UL` decides how units are labelled.
#[derive(Clone, Debug)]
pub struct Units<P, N = Node, UL = DefaultUnitLabeller> {
/// Node whose ports are being iterated.
node: N,
/// Types of the node's ports in the iterated direction, indexed by port offset.
types: TypeRow,
/// Offset of the next port to visit.
pos: usize,
/// Number of linear ports visited so far.
linear_count: usize,
/// Strategy used to label linear units and wires.
unit_labeller: UL,
/// Marker tying the iterator to its port type `P`.
_port: PhantomData<P>,
}
impl<N: HugrNode> Units<OutgoingPort, N, DefaultUnitLabeller> {
    /// Builds the iterator over the outgoing ports of the circuit's input node,
    /// labelling units with [`DefaultUnitLabeller`].
    #[inline]
    pub(super) fn new_circ_input<T: HugrView<Node = N>>(circuit: &Circuit<T>) -> Self {
        let input = circuit.input_node();
        Self::new_outgoing(circuit, input, DefaultUnitLabeller)
    }
}
impl<N: HugrNode, UL: UnitLabeller<N>> Units<OutgoingPort, N, UL> {
    /// Builds the iterator over the outgoing ports of `node`, labelling the
    /// units with `unit_labeller`.
    #[inline]
    pub(super) fn new_outgoing<T: HugrView<Node = N>>(
        circuit: &Circuit<T>,
        node: N,
        unit_labeller: UL,
    ) -> Self {
        let direction = Direction::Outgoing;
        Self::new_with_dir(circuit, node, direction, unit_labeller)
    }
}
impl<N: HugrNode, UL: UnitLabeller<N>> Units<IncomingPort, N, UL> {
    /// Builds the iterator over the incoming ports of `node`, labelling the
    /// units with `unit_labeller`.
    #[inline]
    pub(super) fn new_incoming<T: HugrView<Node = N>>(
        circuit: &Circuit<T>,
        node: N,
        unit_labeller: UL,
    ) -> Self {
        let direction = Direction::Incoming;
        Self::new_with_dir(circuit, node, direction, unit_labeller)
    }
}
impl<P, N: HugrNode, UL> Units<P, N, UL>
where
P: Into<Port> + Copy,
UL: UnitLabeller<N>,
{
#[inline]
fn new_with_dir<T: HugrView<Node = N>>(
circuit: &Circuit<T>,
node: N,
direction: Direction,
unit_labeller: UL,
) -> Self {
Self {
node,
types: Self::init_types(circuit, node, direction),
pos: 0,
linear_count: 0,
unit_labeller,
_port: PhantomData,
}
}
fn init_types<T: HugrView>(
circuit: &Circuit<T>,
node: T::Node,
direction: Direction,
) -> TypeRow {
let hugr = circuit.hugr();
let optype = hugr.get_optype(node);
let sig = hugr.signature(node).unwrap_or_default().into_owned();
let mut types = match direction {
Direction::Outgoing => sig.output,
Direction::Incoming => sig.input,
};
if let Some(EdgeKind::Const(static_type)) = optype.static_port_kind(direction) {
types.to_mut().push(static_type);
};
if let Some(EdgeKind::Const(other)) = optype.other_port_kind(direction) {
types.to_mut().push(other);
}
types
}
#[inline]
fn make_value(&self, typ: &Type, port: P) -> Option<(CircuitUnit<N>, P, Type)> {
let unit = if type_is_linear(typ) {
let linear_unit =
self.unit_labeller
.assign_linear(self.node, port.into(), self.linear_count - 1);
CircuitUnit::Linear(linear_unit.index())
} else {
let wire = self.unit_labeller.assign_wire(self.node, port.into())?;
CircuitUnit::Wire(wire)
};
Some((unit, port, typ.clone()))
}
fn next_generic(&mut self) -> Option<(CircuitUnit<N>, P, Type)>
where
P: From<usize>,
{
loop {
let typ = self.types.get(self.pos)?;
let port = P::from(self.pos);
self.pos += 1;
if type_is_linear(typ) {
self.linear_count += 1;
}
if let Some(val) = self.make_value(typ, port) {
return Some(val);
}
}
}
}
impl<N: HugrNode, UL> Iterator for Units<OutgoingPort, N, UL>
where
    UL: UnitLabeller<N>,
{
    type Item = (CircuitUnit<N>, OutgoingPort, Type);

    /// Returns the unit, port and type of the next outgoing port that the
    /// labeller assigns a unit to, or `None` when all ports have been visited.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        self.next_generic()
    }

    /// Returns `(0, Some(remaining))`, where `remaining` is the number of
    /// ports not yet visited.
    ///
    /// The lower bound is zero because the labeller may return `None` from
    /// `UnitLabeller::assign_wire`, in which case the port is skipped. The
    /// number of remaining ports is therefore only an upper bound on the
    /// number of items still to be yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.types.len() - self.pos;
        (0, Some(remaining))
    }
}
impl<N: HugrNode, UL: UnitLabeller<N>> Iterator for Units<IncomingPort, N, UL> {
    type Item = (CircuitUnit<N>, IncomingPort, Type);

    /// Returns the unit, port and type of the next incoming port that the
    /// labeller assigns a unit to, or `None` when all ports have been visited.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        self.next_generic()
    }

    /// Returns zero as the lower bound, since ports may be skipped, and the
    /// number of ports not yet visited as the upper bound.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let upper = self.types.len() - self.pos;
        (0, Some(upper))
    }
}
// `next_generic` returns `None` without advancing `pos` once `pos` is past the
// end of `types`, so every later call to `next` also returns `None`.
impl<N: HugrNode, UL> FusedIterator for Units<OutgoingPort, N, UL> where UL: UnitLabeller<N> {}
impl<N: HugrNode, UL> FusedIterator for Units<IncomingPort, N, UL> where UL: UnitLabeller<N> {}
/// Strategy for labelling the units yielded by [`Units`].
pub trait UnitLabeller<N> {
/// Returns the linear unit to report for the linear `port` of `node`.
///
/// When called by [`Units`], `linear_count` is the number of linear ports of
/// `node` that the iterator visited before this one.
fn assign_linear(&self, node: N, port: Port, linear_count: usize) -> LinearUnit;
/// Returns the wire to report for the non-linear `port` of `node`.
///
/// Returning `None` makes [`Units`] skip the port.
fn assign_wire(&self, node: N, port: Port) -> Option<Wire<N>>;
}
/// The labeller used by [`Units`] when none is specified.
///
/// Linear units are numbered by their `linear_count`, and wires are only
/// assigned to outgoing ports.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct DefaultUnitLabeller;
impl<N: HugrNode> UnitLabeller<N> for DefaultUnitLabeller {
#[inline]
fn assign_linear(&self, _: N, _: Port, linear_count: usize) -> LinearUnit {
LinearUnit(linear_count)
}
#[inline]
fn assign_wire(&self, node: N, port: Port) -> Option<Wire<N>> {
let port = port.as_outgoing().ok()?;
Some(Wire::new(node, port))
}
}