use core::panic;
use std::cmp::Ordering;
use derive_more::derive::{Display, Error};
use hugr::{Direction, HugrView, IncomingPort, OutgoingPort, Port, core::HugrNode};
use itertools::{Either, Itertools};
use crate::resource::{Position, ResourceId, ResourceScope};
/// A run of one or more nodes, identified by its two extremal nodes and the
/// ports through which they are attached.
///
/// An interval stores only nodes and ports; every query that needs a position
/// or a resource id takes a [`ResourceScope`] to resolve it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Interval<N> {
    /// An interval made of a single node.
    Singleton {
        /// The node together with one of its ports (either direction).
        ///
        /// [`Interval::new_singleton`] stores the incoming port when one
        /// exists and the outgoing port otherwise.
        start_or_end: (N, Port),
    },
    /// An interval with distinct first and last nodes.
    Span {
        /// The first node and its outgoing port.
        start: (N, OutgoingPort),
        /// The last node and its incoming port.
        end: (N, IncomingPort),
    },
}
impl<N: HugrNode> Interval<N> {
    /// Creates the interval that contains only `node`.
    ///
    /// The port of `node` for `resource_id` is looked up in `scope` in both
    /// directions; the incoming port is stored when it exists, the outgoing
    /// port otherwise.
    ///
    /// Returns `None` when `scope` reports no port for `resource_id` on `node`
    /// in either direction.
    pub fn new_singleton(
        resource_id: ResourceId,
        node: N,
        scope: &ResourceScope<impl HugrView<Node = N>>,
    ) -> Option<Self> {
        let incoming = scope.get_port(node, resource_id, Direction::Incoming);
        let outgoing = scope.get_port(node, resource_id, Direction::Outgoing);
        incoming.or(outgoing).map(|port| Self::Singleton {
            start_or_end: (node, port),
        })
    }

    /// Builds a [`Interval::Span`] from its two extrema, given in any order.
    ///
    /// The extremum carrying an outgoing port becomes the start, the one
    /// carrying an incoming port becomes the end.
    ///
    /// # Panics
    ///
    /// Panics unless `extrema` holds exactly one incoming and one outgoing
    /// port. In debug builds, also panics if both extrema are the same node.
    fn new_span(extrema: [(N, Port); 2]) -> Self {
        let mut first: Option<(N, OutgoingPort)> = None;
        let mut last: Option<(N, IncomingPort)> = None;
        for (node, port) in extrema {
            match port.as_directed() {
                Either::Left(incoming) => {
                    assert!(last.is_none(), "multiple incoming ports in span extrema");
                    last = Some((node, incoming));
                }
                Either::Right(outgoing) => {
                    assert!(first.is_none(), "multiple outgoing ports in span extrema");
                    first = Some((node, outgoing));
                }
            }
        }
        debug_assert!(
            first.map(|(node, _)| node) != last.map(|(node, _)| node),
            "start and end nodes must differ"
        );
        let start = first.expect("exactly one outgoing port in span extrema");
        let end = last.expect("exactly one incoming port in span extrema");
        Self::Span { start, end }
    }

    /// Creates the interval from `start_node` to `end_node` on the path of
    /// `resource_id`.
    ///
    /// This is the panicking counterpart of [`Interval::try_new`].
    ///
    /// # Panics
    ///
    /// Panics if [`Interval::try_new`] returns an error for these arguments.
    pub fn new(
        resource_id: ResourceId,
        start_node: N,
        end_node: N,
        scope: &ResourceScope<impl HugrView<Node = N>>,
    ) -> Self {
        let interval = Self::try_new(resource_id, start_node, end_node, scope);
        interval.unwrap()
    }

    /// Creates the interval from `start_node` to `end_node` on the path of
    /// `resource_id`.
    ///
    /// When both nodes are equal the result is a [`Interval::Singleton`],
    /// otherwise a [`Interval::Span`].
    ///
    /// # Errors
    ///
    /// - [`InvalidInterval::NotOnResourcePath`] if `scope` has no position for
    ///   one of the nodes, or no port for `resource_id` on it in the required
    ///   direction (outgoing for the start, incoming for the end, either for a
    ///   singleton).
    /// - [`InvalidInterval::StartAfterEnd`] if the position of `start_node` is
    ///   greater than the position of `end_node`.
    pub fn try_new(
        resource_id: ResourceId,
        start_node: N,
        end_node: N,
        scope: &ResourceScope<impl HugrView<Node = N>>,
    ) -> Result<Self, InvalidInterval<N>> {
        let position_of = |node: N| {
            scope
                .get_position(node)
                .ok_or(InvalidInterval::NotOnResourcePath(node))
        };
        let start_pos = position_of(start_node)?;
        let end_pos = position_of(end_node)?;
        if start_pos > end_pos {
            return Err(InvalidInterval::StartAfterEnd(
                start_node,
                end_node,
                resource_id,
            ));
        }

        if start_node == end_node {
            return Interval::new_singleton(resource_id, start_node, scope)
                .ok_or(InvalidInterval::NotOnResourcePath(start_node));
        }

        let resource_port = |node: N, direction: Direction| {
            scope
                .get_port(node, resource_id, direction)
                .ok_or(InvalidInterval::NotOnResourcePath(node))
        };
        let start_port = resource_port(start_node, Direction::Outgoing)?;
        let end_port = resource_port(end_node, Direction::Incoming)?;
        Ok(Self::new_span([
            (start_node, start_port),
            (end_node, end_port),
        ]))
    }

    /// Returns the resource id that `scope` associates with the port stored in
    /// this interval.
    ///
    /// # Panics
    ///
    /// Panics if `scope` has no resource id for that port.
    pub fn resource_id(&self, scope: &ResourceScope<impl HugrView<Node = N>>) -> ResourceId {
        let (node, port) = self.any_port();
        let resource_id = scope.get_resource_id(node, port);
        resource_id.expect("interval port is a resource port in scope")
    }

    /// Returns the outgoing port on the end node of the interval, if any.
    ///
    /// See [`Interval::boundary_port`].
    #[inline]
    pub fn outgoing_boundary_port(
        &self,
        scope: &ResourceScope<impl HugrView<Node = N>>,
    ) -> Option<(N, OutgoingPort)> {
        let (node, port) = self.boundary_port(Direction::Outgoing, scope)?;
        Some((node, port.as_outgoing().expect("outgoing port")))
    }

    /// Returns the incoming port on the start node of the interval, if any.
    ///
    /// See [`Interval::boundary_port`].
    #[inline]
    pub fn incoming_boundary_port(
        &self,
        scope: &ResourceScope<impl HugrView<Node = N>>,
    ) -> Option<(N, IncomingPort)> {
        let (node, port) = self.boundary_port(Direction::Incoming, scope)?;
        Some((node, port.as_incoming().expect("incoming port")))
    }

    /// Returns the port through which the interval's resource crosses the
    /// interval boundary in the given `direction`.
    ///
    /// For [`Direction::Incoming`] this is a port on the start node, for
    /// [`Direction::Outgoing`] a port on the end node. The port is obtained by
    /// resolving the resource id of the stored port in `scope` and asking
    /// `scope` for that resource's port in `direction`.
    ///
    /// Returns `None` if `scope` has no port for the resource on the boundary
    /// node in `direction`.
    ///
    /// # Panics
    ///
    /// Panics if `scope` has no resource id for the port stored in the
    /// interval.
    pub fn boundary_port(
        &self,
        direction: Direction,
        scope: &ResourceScope<impl HugrView<Node = N>>,
    ) -> Option<(N, Port)> {
        // Pick the extremal node on the requested side, along with the port
        // stored for it (which points the other way for spans).
        let (node, stored_port): (N, Port) = match *self {
            Interval::Singleton { start_or_end } => start_or_end,
            Interval::Span { start, end } => match direction {
                Direction::Incoming => (start.0, start.1.into()),
                Direction::Outgoing => (end.0, end.1.into()),
            },
        };
        let resource_id = scope
            .get_resource_id(node, stored_port)
            .expect("interval port is a resource port in scope");
        let boundary = scope.get_port(node, resource_id, direction)?;
        Some((node, boundary))
    }

    /// Returns the first node of the interval.
    pub fn start_node(&self) -> N {
        match *self {
            Interval::Singleton {
                start_or_end: (node, _),
            }
            | Interval::Span {
                start: (node, _), ..
            } => node,
        }
    }

    /// Returns the last node of the interval.
    pub fn end_node(&self) -> N {
        match *self {
            Interval::Singleton {
                start_or_end: (node, _),
            }
            | Interval::Span { end: (node, _), .. } => node,
        }
    }

    /// Extends the interval by `node`, provided `node` is directly linked to
    /// one of the interval's boundary ports.
    ///
    /// Returns the direction in which the interval grew
    /// ([`Direction::Incoming`] when `node` becomes the new start,
    /// [`Direction::Outgoing`] when it becomes the new end), or `None` when the
    /// position of `node` already lies within the interval, in which case the
    /// interval is left untouched.
    ///
    /// NOTE(review): membership is decided by comparing positions only; a node
    /// whose position falls inside the interval's range yields `Ok(None)`
    /// without a check that it lies on this interval's resource — confirm
    /// this is intended.
    ///
    /// # Errors
    ///
    /// - [`InvalidInterval::NotOnResourcePath`] if `scope` has no position for
    ///   `node`.
    /// - [`InvalidInterval::NotContiguous`] if the interval has no boundary
    ///   port on the relevant side, or that port is not singly linked to
    ///   `node`.
    ///
    /// The interval is not modified when an error is returned.
    pub fn try_extend(
        &mut self,
        node: N,
        scope: &ResourceScope<impl HugrView<Node = N>>,
    ) -> Result<Option<Direction>, InvalidInterval<N>> {
        let pos = scope
            .get_position(node)
            .ok_or(InvalidInterval::NotOnResourcePath(node))?;

        // The node directly linked to the interval on the side `node` lies on.
        let neighbour = match self.position_in_interval(pos, scope) {
            Ordering::Equal => return Ok(None),
            Ordering::Less => {
                let Some((in_node, in_port)) = self.incoming_boundary_port(scope) else {
                    return Err(InvalidInterval::NotContiguous(node));
                };
                scope
                    .hugr()
                    .single_linked_output(in_node, in_port)
                    .map(|(prev_node, _)| prev_node)
            }
            Ordering::Greater => {
                let Some((out_node, out_port)) = self.outgoing_boundary_port(scope) else {
                    return Err(InvalidInterval::NotContiguous(node));
                };
                scope
                    .hugr()
                    .single_linked_input(out_node, out_port)
                    .map(|(next_node, _)| next_node)
            }
        };
        if neighbour != Some(node) {
            return Err(InvalidInterval::NotContiguous(node));
        }

        Ok(self.add_node_unchecked(node, scope))
    }

    /// Extends the interval so that `node` becomes its new start or end,
    /// without checking that `node` is adjacent to the interval.
    ///
    /// Returns the direction in which the interval grew, or `None` (leaving
    /// the interval untouched) when the position of `node` already lies within
    /// the interval.
    ///
    /// # Panics
    ///
    /// Panics if `scope` has no position for `node`, no resource id for the
    /// interval's port, no port for the interval's resource on `node` in the
    /// required direction, or if a singleton interval lacks the port needed to
    /// turn it into the opposite extremum of a span.
    pub(crate) fn add_node_unchecked(
        &mut self,
        node: N,
        scope: &ResourceScope<impl HugrView<Node = N>>,
    ) -> Option<Direction> {
        let pos = scope
            .get_position(node)
            .expect("node must be on resource path");
        let resource_id = self.resource_id(scope);
        let extension_dir = match self.position_in_interval(pos, scope) {
            Ordering::Equal => return None,
            Ordering::Less => Direction::Incoming,
            Ordering::Greater => Direction::Outgoing,
        };

        // The new extremum faces the interval, i.e. against the extension
        // direction.
        let new_extrema_node = node;
        let new_extrema_port = scope
            .get_port(new_extrema_node, resource_id, extension_dir.reverse())
            .expect("node is on interval resource path");

        // The extremum on the far side of the extension is kept.
        let existing_extrema: (N, Port) = match *self {
            Interval::Singleton { start_or_end } => {
                ensure_direction_resource_port(start_or_end, extension_dir, scope)
                    .expect("not a resource path end")
            }
            Interval::Span { start, end } => match extension_dir {
                Direction::Incoming => (end.0, end.1.into()),
                Direction::Outgoing => (start.0, start.1.into()),
            },
        };
        debug_assert!(new_extrema_node != existing_extrema.0);

        *self = Self::new_span([(new_extrema_node, new_extrema_port), existing_extrema]);
        Some(extension_dir)
    }

    /// Compares `pos` against the positions spanned by the interval.
    ///
    /// Returns `Less` if `pos` precedes the start position, `Greater` if it
    /// follows the end position, and `Equal` otherwise.
    fn position_in_interval(
        &self,
        pos: Position,
        scope: &ResourceScope<impl HugrView<Node = N>>,
    ) -> Ordering {
        if pos < self.start_pos(scope) {
            return Ordering::Less;
        }
        if pos > self.end_pos(scope) {
            return Ordering::Greater;
        }
        Ordering::Equal
    }

    /// Returns a node and port stored in the interval: the singleton's port,
    /// or the start node's outgoing port for a span.
    fn any_port(&self) -> (N, Port) {
        match *self {
            Interval::Singleton { start_or_end } => start_or_end,
            Interval::Span {
                start: (node, outgoing),
                ..
            } => (node, outgoing.into()),
        }
    }

    /// Returns the position of the start node in `scope`.
    ///
    /// # Panics
    ///
    /// Panics if `scope` has no position for the start node.
    #[inline]
    fn start_pos(&self, scope: &ResourceScope<impl HugrView<Node = N>>) -> Position {
        let position = scope.get_position(self.start_node());
        position.expect("valid interval start node")
    }

    /// Returns the position of the end node in `scope`.
    ///
    /// # Panics
    ///
    /// Panics if `scope` has no position for the end node.
    #[inline]
    fn end_pos(&self, scope: &ResourceScope<impl HugrView<Node = N>>) -> Position {
        let position = scope.get_position(self.end_node());
        position.expect("valid interval end node")
    }
}
/// Returns the port of `node` in direction `dir` that belongs to the same
/// resource as `port`.
///
/// If `port` already points in direction `dir` it is returned unchanged.
/// Otherwise the resource id of `port` is resolved in `scope` and the port of
/// that resource in direction `dir` is looked up; `None` is returned when
/// `scope` has no such port.
///
/// # Panics
///
/// Panics if the direction differs and `scope` has no resource id for
/// `(node, port)`.
fn ensure_direction_resource_port<N: HugrNode>(
    (node, port): (N, Port),
    dir: Direction,
    scope: &ResourceScope<impl HugrView<Node = N>>,
) -> Option<(N, Port)> {
    if port.direction() == dir {
        return Some((node, port));
    }
    let resource_id = scope
        .get_resource_id(node, port)
        .expect("interval port is a resource port in scope");
    scope
        .get_port(node, resource_id, dir)
        .map(|flipped| (node, flipped))
}
/// Errors returned when constructing or extending an [`Interval`].
#[derive(Debug, Clone, PartialEq, Display, Error)]
pub enum InvalidInterval<N> {
    /// The node is not directly linked to a boundary port of the interval.
    #[display("node {_0:?} is not contiguous with the interval")]
    NotContiguous(N),
    /// The scope has no position for the node, or no port for the requested
    /// resource on it in the required direction.
    #[display("node {_0:?} is not on the interval's resource path")]
    NotOnResourcePath(N),
    /// The position of the start node (first field) is greater than the
    /// position of the end node (second field); the third field is the
    /// resource id the interval was requested for.
    #[display("start node {_0:?} is after end node {_1:?} on resource path {_2:?}")]
    StartAfterEnd(N, N, ResourceId),
}
impl<H: HugrView> ResourceScope<H> {
    /// Iterates over the nodes of `interval`, from its start node up to and
    /// including its end node.
    ///
    /// The nodes are produced by walking the resource path of the interval's
    /// resource from the start node in the outgoing direction, stopping right
    /// after the end node has been yielded.
    ///
    /// # Panics
    ///
    /// Panics if this scope has no resource id for the port stored in
    /// `interval`.
    pub fn nodes_in_interval(
        &self,
        interval: Interval<H::Node>,
    ) -> impl Iterator<Item = H::Node> + '_ {
        let first = interval.start_node();
        let last = interval.end_node();
        let resource_id = interval.resource_id(self);
        let path = self.resource_path_iter(resource_id, first, Direction::Outgoing);
        // Inclusive: the first node equal to `last` is still yielded.
        path.take_while_inclusive(move |&node| node != last)
    }
}
#[cfg(test)]
mod tests {
    use super::{ResourceScope, *};
    use std::ops::RangeInclusive;

    use crate::{Circuit, resource::tests::cx_circuit};
    use itertools::Itertools;
    use rstest::{fixture, rstest};

    /// An interval built from the nodes at indices 1 and 3 contains exactly
    /// the nodes at indices 1 to 3, for both resource ids 0 and 1.
    #[test]
    fn test_nodes_in_interval() {
        let circ = cx_circuit(5);
        let subgraph = Circuit::from(&circ).subgraph().unwrap();
        let cx_nodes = subgraph.nodes().to_owned();
        let scope = ResourceScope::new(&circ, subgraph);
        assert_eq!(cx_nodes.len(), 5);

        let pos_interval = 1usize..4;
        let first = cx_nodes[pos_interval.start];
        let last = cx_nodes[pos_interval.end - 1];
        for resource_id in [0, 1].map(ResourceId::new) {
            let build = || Interval::new(resource_id, first, last, &scope);
            let interval = build();
            // Building the same interval a second time gives an equal value.
            assert_eq!(build(), interval);
            assert_eq!(
                scope.nodes_in_interval(interval).collect_vec(),
                cx_nodes[pos_interval.clone()]
            );
        }
    }

    /// Scope built from `cx_circuit(5)`, shared by the `rstest` cases below.
    #[fixture]
    fn cx_circuit_scope() -> ResourceScope {
        ResourceScope::from_circuit(Circuit::from(cx_circuit(5)))
    }

    /// Extending the span over nodes 2..=3 with an adjacent or contained node.
    #[rstest]
    #[case::extend_left(1, Some(Direction::Incoming), 1..=3)]
    #[case::extend_right(4, Some(Direction::Outgoing), 2..=4)]
    #[case::node_already_in_interval_start(2, None, 2..=3)]
    #[case::node_already_in_interval_end(3, None, 2..=3)]
    fn test_try_extend_success(
        cx_circuit_scope: ResourceScope,
        #[case] node_to_extend: usize,
        #[case] expected_direction: Option<Direction>,
        #[case] expected_range: RangeInclusive<usize>,
    ) {
        let scope = &cx_circuit_scope;
        let cx_nodes = scope.nodes();
        let (expected_start, expected_end) = expected_range.into_inner();
        let mut interval = Interval::new(ResourceId::new(0), cx_nodes[2], cx_nodes[3], scope);

        let extended = interval
            .try_extend(cx_nodes[node_to_extend], scope)
            .unwrap();

        assert_eq!(extended, expected_direction);
        assert_eq!(interval.start_node(), cx_nodes[expected_start]);
        assert_eq!(interval.end_node(), cx_nodes[expected_end]);
    }

    /// Extending the singleton interval on node 2 with an adjacent node or
    /// with node 2 itself.
    #[rstest]
    #[case::extend_left(1, Some(Direction::Incoming), 1..=2)]
    #[case::extend_right(3, Some(Direction::Outgoing), 2..=3)]
    #[case::node_already_in_interval_start(2, None, 2..=2)]
    fn test_try_extend_singleton(
        cx_circuit_scope: ResourceScope,
        #[case] node_to_extend: usize,
        #[case] expected_direction: Option<Direction>,
        #[case] expected_range: RangeInclusive<usize>,
    ) {
        let scope = &cx_circuit_scope;
        let cx_nodes = scope.nodes();
        let (expected_start, expected_end) = expected_range.into_inner();
        let resource_id = ResourceId::new(0);
        let mut interval = Interval::new(resource_id, cx_nodes[2], cx_nodes[2], scope);
        // Equal start and end nodes produce the singleton interval.
        assert_eq!(
            interval,
            Interval::new_singleton(ResourceId::new(0), cx_nodes[2], scope).unwrap()
        );

        let extended = interval
            .try_extend(cx_nodes[node_to_extend], scope)
            .unwrap();

        assert_eq!(extended, expected_direction);
        assert_eq!(interval.start_node(), cx_nodes[expected_start]);
        assert_eq!(interval.end_node(), cx_nodes[expected_end]);
    }

    /// Extending the span over nodes 2..=3 with node 0 is rejected and leaves
    /// the interval unchanged.
    #[rstest]
    fn test_try_extend_error(cx_circuit_scope: ResourceScope) {
        let scope = &cx_circuit_scope;
        let cx_nodes = scope.nodes();
        let mut interval = Interval::new(ResourceId::new(0), cx_nodes[2], cx_nodes[3], scope);

        let error = interval.try_extend(cx_nodes[0], scope).unwrap_err();

        assert_eq!(error, InvalidInterval::NotContiguous(cx_nodes[0]));
        assert_eq!(interval.start_node(), cx_nodes[2]);
        assert_eq!(interval.end_node(), cx_nodes[3]);
    }

    /// An interval yields the same nodes and boundary ports when queried
    /// against a clone of the scope whose positions were remapped with
    /// `increment`.
    #[rstest]
    fn interval_is_scope_independent(cx_circuit_scope: ResourceScope) {
        let mut other_scope = cx_circuit_scope.clone();
        other_scope.map_positions(|pos| pos.increment());
        let scope = cx_circuit_scope;

        let cx_nodes = scope.nodes();
        let interval = Interval::new(ResourceId::new(0), cx_nodes[2], cx_nodes[4], &scope);

        assert_eq!(
            scope.nodes_in_interval(interval).collect_vec(),
            other_scope.nodes_in_interval(interval).collect_vec()
        );
        assert_eq!(
            interval.incoming_boundary_port(&scope),
            interval.incoming_boundary_port(&other_scope)
        );
        assert_eq!(
            interval.outgoing_boundary_port(&scope),
            interval.outgoing_boundary_port(&other_scope)
        );
    }
}