use std::cell::RefCell;
use std::cmp::Ordering;
use std::collections::{BTreeMap, BTreeSet, VecDeque};
use itertools::Itertools;
use smallvec::SmallVec;
use crate::algorithms::{toposort, TopoSort};
use crate::boundary::Boundary;
use crate::{Direction, LinkView, NodeIndex, PortIndex, PortView, SecondaryMap, UnmanagedDenseMap};
use super::{ConvexChecker, CreateConvexChecker};
/// Inline capacity of the per-node line list in `LinePositions`.
///
/// This is not a hard limit: a node on more lines than this simply spills
/// its `SmallVec` onto the heap.
const MAX_LINES_ON_NODE: usize = 4;
/// Inline capacity of per-subgraph line collections (`LineIntervals` and the
/// scratch space used when computing intervals).
///
/// Like `MAX_LINES_ON_NODE`, exceeding it only causes a heap spill.
const MAX_LINES: usize = 8;
/// A convexity checker that partitions the nodes of a graph into "lines".
///
/// Every node is given a [`Position`] and the list of lines that pass
/// through it. Along each line, positions strictly increase. A set of nodes
/// is then described by one [`LineInterval`] per line it touches, and
/// convexity is decided from those intervals.
#[derive(Debug, Clone, PartialEq)]
pub struct LineConvexChecker<G: PortView> {
    /// The graph being checked.
    graph: G,
    /// Position and line membership of every node, filled in by `new`.
    node_to_pos: UnmanagedDenseMap<NodeIndex<G::NodeIndexBase>, LinePositions>,
    /// The nodes of each line (indexed by `LineIndex`), in the order they
    /// were visited during construction.
    lines: Vec<Vec<NodeIndex<G::NodeIndexBase>>>,
    /// Buffer reused by `get_intervals_from_nodes` so that each call does not
    /// need a fresh allocation.
    get_intervals_scratch_space: RefCell<SmallVec<[(LineIndex, LineIntervalWithCount); MAX_LINES]>>,
}
/// The lines and position associated with a node.
///
/// Also used by the construction frontier in `extend_line_ends_frontier`
/// to record the lines leaving an output port together with the position of
/// that port's node.
///
/// Field order matters: it determines the derived `PartialOrd`/`Ord` and the
/// `Debug` output.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Default)]
struct LinePositions {
    /// Lines passing through the node (or leaving through the port).
    line_indices: SmallVec<[LineIndex; MAX_LINES_ON_NODE]>,
    /// The node's position; the same for every line in `line_indices`.
    position: Position,
}
/// Identifier of a line. Lines are numbered from zero as they are created.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct LineIndex(pub u32);

impl LineIndex {
    /// The raw index widened to `usize`, for indexing per-line storage.
    fn as_usize(self) -> usize {
        let Self(raw) = self;
        raw as usize
    }
}
/// Position of a node along its lines.
///
/// A node's position is shared by all the lines passing through it, and
/// `Position::default()` is the starting position `Position(0)`.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default, Hash)]
pub struct Position(pub u32);

impl Position {
    /// The position one step after `self`.
    fn next(self) -> Self {
        let Position(current) = self;
        Position(current + 1)
    }
}
/// For each line touched by a set of nodes, the interval of positions they
/// occupy on it.
///
/// Stored as a small unsorted list of `(line, interval)` pairs; lookups are
/// linear scans and return the first matching entry.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct LineIntervals(SmallVec<[(LineIndex, LineInterval); MAX_LINES]>);

impl LineIntervals {
    /// The interval on `line`, or `None` if the set has no node on it.
    pub fn get(&self, line: LineIndex) -> Option<LineInterval> {
        self.iter()
            .find_map(|(l, interval)| (l == line).then_some(interval))
    }

    /// Mutable access to the interval on `line`, if there is one.
    fn get_mut(&mut self, line: LineIndex) -> Option<&mut LineInterval> {
        self.0
            .iter_mut()
            .find_map(|(l, interval)| (*l == line).then_some(interval))
    }

    /// All `(line, interval)` pairs, in storage order.
    pub fn iter(&self) -> impl Iterator<Item = (LineIndex, LineInterval)> + '_ {
        self.0.iter().copied()
    }

    /// All intervals, in storage order, without their line.
    pub fn values(&self) -> impl Iterator<Item = LineInterval> + '_ {
        self.0.iter().map(|&(_, interval)| interval)
    }

    /// Whether no line has an interval.
    #[inline(always)]
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }
}
/// The closed range of positions `[min, max]` covered on a single line.
///
/// Field order matters: it determines the derived `PartialOrd`/`Ord`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct LineInterval {
    /// Smallest position in the interval (inclusive).
    pub min: Position,
    /// Largest position in the interval (inclusive).
    pub max: Position,
}
/// A line interval paired with the number of nodes seen on it; used while
/// building intervals in `get_intervals_from_nodes`.
type LineIntervalWithCount = (LineInterval, usize);
impl<G: LinkView> LineConvexChecker<G> {
pub fn new(graph: G) -> Self {
let inputs = graph
.nodes_iter()
.filter(|&n| graph.input_neighbours(n).count() == 0);
let topsort: TopoSort<_> = toposort(&graph, inputs, Direction::Outgoing);
let mut extend_frontier = extend_line_ends_frontier(&graph);
let mut node_to_pos =
UnmanagedDenseMap::<_, LinePositions>::with_capacity(graph.node_count());
let mut lines = Vec::new();
for node in topsort {
let new_pos = extend_frontier(node);
for &line_index in &new_pos.line_indices {
if lines.len() <= line_index.as_usize() {
lines.extend(vec![vec![]; line_index.as_usize() - lines.len() + 1]);
}
lines[line_index.as_usize()].push(node);
}
node_to_pos.set(node, new_pos);
}
drop(extend_frontier);
Self {
graph,
node_to_pos,
lines,
get_intervals_scratch_space: RefCell::new(SmallVec::new()),
}
}
#[inline(always)]
pub fn is_node_convex(
&self,
nodes: impl IntoIterator<Item = NodeIndex<G::NodeIndexBase>>,
) -> bool {
let Some(intervals) = self.get_intervals_from_nodes(nodes) else {
return false;
};
self.is_convex_by_intervals(&intervals)
}
pub fn is_convex_by_intervals(&self, intervals: &LineIntervals) -> bool {
let max_start_pos = intervals
.values()
.map(|LineInterval { min, .. }| min)
.max()
.unwrap();
let mut future_nodes = VecDeque::from_iter(intervals.iter().filter_map(
|(line, LineInterval { max, .. })| {
let ind = self.find_index(line, max).expect("max not on line");
self.lines[line.as_usize()].get(ind + 1).copied()
},
));
let mut visited = BTreeSet::new();
while let Some(node) = future_nodes.pop_front() {
if self.get_position(node) > max_start_pos {
continue;
}
if !visited.insert(node) {
continue; }
for &line in self.get_lines(node) {
if let Some(LineInterval { min, max, .. }) = intervals.get(line) {
let pos = self.get_position(node);
debug_assert!(
pos < min || pos > max,
"node cannot be in interval [min, max]"
);
if pos < min {
return false;
}
}
}
future_nodes.extend(self.graph.output_neighbours(node));
}
true
}
#[must_use]
pub fn try_extend_intervals(
&self,
intervals: &mut LineIntervals,
node: NodeIndex<G::NodeIndexBase>,
) -> bool {
let old_intervals = intervals.clone();
let pos = self.get_position(node);
let lines = self.get_lines(node);
for &line in lines {
if let Some(interval) = intervals.get_mut(line) {
if pos < interval.min {
if self.line_positions_from(pos, line).nth(1) != Some(interval.min) {
*intervals = old_intervals;
return false;
}
interval.min = pos;
} else if pos > interval.max {
if self.line_positions_from(interval.max, line).nth(1) != Some(pos) {
*intervals = old_intervals;
return false;
}
interval.max = pos;
}
} else {
intervals
.0
.push((line, LineInterval { min: pos, max: pos }));
}
}
true
}
#[inline(always)]
pub fn get_lines(&self, node: NodeIndex<G::NodeIndexBase>) -> &[LineIndex] {
&self.node_to_pos.get(node).line_indices
}
#[inline(always)]
pub fn get_position(&self, node: NodeIndex<G::NodeIndexBase>) -> Position {
self.node_to_pos.get(node).position
}
pub fn get_intervals_from_nodes(
&self,
nodes: impl IntoIterator<Item = NodeIndex<G::NodeIndexBase>>,
) -> Option<LineIntervals> {
let nodes = nodes.into_iter();
let mut line_to_pos = self.get_intervals_scratch_space.borrow_mut();
line_to_pos.clear();
for node in nodes {
let pos = self.get_position(node);
for &l in self.get_lines(node) {
small_map_add(&mut line_to_pos, l, pos);
}
}
let mut intervals = LineIntervals::default();
for &(l, (interval, count)) in line_to_pos.iter() {
let min_index = self.find_index(l, interval.min).expect("min on line");
let max_index = min_index + count - 1;
let &max_node = self.lines[l.as_usize()]
.get(max_index)
.expect("count <= number of nodes in interval");
if self.get_position(max_node) != interval.max {
return None;
}
intervals.0.push((l, interval));
}
Some(intervals)
}
#[inline(always)]
pub fn get_intervals_from_boundary_ports(
&self,
ports: impl IntoIterator<Item = PortIndex<G::PortIndexBase>>,
) -> Option<LineIntervals> {
let boundary = Boundary::from_ports(&self.graph, ports);
self.get_intervals_from_boundary(&boundary)
}
#[inline(always)]
pub fn get_intervals_from_boundary(
&self,
boundary: &Boundary<G::PortIndexBase>,
) -> Option<LineIntervals> {
let nodes = boundary.internal_nodes(&self.graph);
self.get_intervals_from_nodes(nodes)
}
pub fn nodes_in_intervals<'a>(
&'a self,
intervals: &'a LineIntervals,
) -> impl Iterator<Item = NodeIndex<G::NodeIndexBase>> + 'a {
intervals
.iter()
.map(|(line, interval)| self.line_nodes_between(&interval, line))
.kmerge_by(|&n1, &n2| (self.get_position(n1), n1) < (self.get_position(n2), n2))
.dedup()
}
pub fn shrink_to_fit(&mut self) {
let mut line_to_pos = self.get_intervals_scratch_space.borrow_mut();
line_to_pos.shrink_to_fit();
}
pub fn lines_at_port(&self, port: PortIndex<G::PortIndexBase>) -> &[LineIndex] {
let node = self.graph.port_node(port).expect("valid port");
let dir = self.graph.port_direction(port).expect("valid port");
let port_offset = self.graph.port_offset(port).expect("valid port").index();
let mut links_per_port = self
.graph
.ports(node, dir)
.map(|p| self.graph.port_links(p).count());
let start_pos = (&mut links_per_port).take(port_offset).sum::<usize>();
let end_pos = start_pos + links_per_port.next().expect("valid offset");
let lines = self.get_lines(node);
&lines[start_pos..end_pos]
}
#[inline(always)]
fn line_nodes_from(
&self,
start_pos: Position,
line_index: LineIndex,
) -> impl Iterator<Item = NodeIndex<G::NodeIndexBase>> + '_ {
let start = self
.find_index(line_index, start_pos)
.expect("start not on line");
let line = &self.lines[line_index.as_usize()];
line[start..].iter().copied()
}
#[inline(always)]
fn line_nodes_between(
&self,
&LineInterval { min, max }: &LineInterval,
line_index: LineIndex,
) -> impl Iterator<Item = NodeIndex<G::NodeIndexBase>> + '_ {
self.line_nodes_from(min, line_index)
.take_while(move |&n| self.get_position(n) <= max)
}
#[inline(always)]
fn line_positions_from(
&self,
start_pos: Position,
line_index: LineIndex,
) -> impl Iterator<Item = Position> + '_ {
self.line_nodes_from(start_pos, line_index)
.map(|n| self.get_position(n))
}
fn find_index(&self, line: LineIndex, Position(pos): Position) -> Option<usize> {
let line = &self.lines[line.as_usize()];
if line.is_empty() {
return None;
}
let mut low = 0;
let mut high = line.len() - 1;
let Position(low_pos) = self.get_position(line[low]);
let Position(high_pos) = self.get_position(line[high]);
if low_pos == pos {
return Some(low);
} else if high_pos == pos {
return Some(high);
} else if low_pos > pos || high_pos < pos {
return None;
}
loop {
let Position(low_pos) = self.get_position(line[low]);
let Position(high_pos) = self.get_position(line[high]);
let alpha = (pos - low_pos) as f64 / (high_pos - low_pos) as f64;
let mut guess = low + (alpha * (high - low) as f64).round() as usize;
if guess == low {
guess += 1;
} else if guess == high {
guess -= 1;
}
let Position(guess_pos) = self.get_position(line[guess]);
match guess_pos.cmp(&pos) {
Ordering::Equal => return Some(guess),
Ordering::Less => low = guess,
Ordering::Greater => high = guess,
}
}
}
}
#[inline(always)]
fn small_map_add(
small_map: &mut SmallVec<[(LineIndex, LineIntervalWithCount); 8]>,
key: LineIndex,
value: Position,
) {
let new_interval = |line: LineIndex, value: Position| {
let min = value;
let max = value;
(line, (LineInterval { min, max }, 0))
};
let ind = small_map
.iter()
.position(|&(l, _)| l == key)
.unwrap_or_else(|| {
small_map.push(new_interval(key, value));
small_map.len() - 1
});
let (LineInterval { min, max }, count) = &mut small_map[ind].1;
if *min > value {
*min = value;
}
if *max < value {
*max = value;
}
*count += 1;
}
fn extend_line_ends_frontier<G>(
graph: &G,
) -> impl FnMut(NodeIndex<G::NodeIndexBase>) -> LinePositions + '_
where
G: LinkView,
{
let mut frontier: BTreeMap<PortIndex<G::PortIndexBase>, LinePositions> = BTreeMap::new();
fn pop_frontier<P: crate::index::IndexBase>(
frontier: &mut BTreeMap<PortIndex<P>, LinePositions>,
port: PortIndex<P>,
) -> Option<(LineIndex, Position)> {
let positions = frontier.get_mut(&port)?;
let Some(line_index) = positions.line_indices.pop() else {
frontier.remove(&port);
return None;
};
let position = positions.position;
Some((line_index, position))
}
fn push_frontier<P: crate::index::IndexBase>(
frontier: &mut BTreeMap<PortIndex<P>, LinePositions>,
port: PortIndex<P>,
line_index: LineIndex,
position: Position,
) {
let entry = frontier.entry(port).or_default();
entry.line_indices.push(line_index);
entry.position = position;
}
let mut n_lines = 0;
let mut create_new_line = move || {
let new_line = LineIndex(n_lines);
n_lines += 1;
new_line
};
move |node: NodeIndex<G::NodeIndexBase>| {
let prev_outgoing_ports = graph.inputs(node).flat_map(|ip| graph.port_links(ip));
let mut max_pos: Option<Position> = None;
let mut lines = SmallVec::with_capacity(graph.num_inputs(node));
for (_, out_port) in prev_outgoing_ports {
let (line_index, position) =
pop_frontier(&mut frontier, out_port.into()).expect("unknown frontier port");
lines.push(line_index);
max_pos = max_pos.map(|p| p.max(position)).or(Some(position));
}
let position = max_pos.map(|p| p.next()).unwrap_or_default();
let n_in_lines = lines.len();
let mut free_line = 0;
for out_port in graph.outputs(node) {
for _ in 0..graph.port_links(out_port).count() {
if free_line < n_in_lines {
let line_index = lines[free_line];
free_line += 1;
push_frontier(&mut frontier, out_port, line_index, position);
} else {
let new_line = create_new_line();
push_frontier(&mut frontier, out_port, new_line, position);
lines.push(new_line);
}
}
}
if lines.is_empty() {
lines.push(create_new_line());
}
LinePositions {
line_indices: lines,
position,
}
}
}
impl<G: LinkView> ConvexChecker for LineConvexChecker<G> {
    type NodeIndexBase = G::NodeIndexBase;
    type PortIndexBase = G::PortIndexBase;

    /// Whether the subgraph given by `nodes`, `inputs` and `outputs` is
    /// convex.
    ///
    /// The subgraph is rejected first if any of its `inputs` is directly
    /// linked to one of its `outputs`. Otherwise the decision is delegated to
    /// `is_node_convex`.
    fn is_convex(
        &self,
        nodes: impl IntoIterator<Item = NodeIndex<G::NodeIndexBase>>,
        inputs: impl IntoIterator<Item = PortIndex<G::PortIndexBase>>,
        outputs: impl IntoIterator<Item = PortIndex<G::PortIndexBase>>,
    ) -> bool {
        // The ports at the far end of each linked output.
        let linked_to_outputs: BTreeSet<PortIndex<G::PortIndexBase>> = outputs
            .into_iter()
            .filter_map(|port| self.graph.port_link(port).map(Into::into))
            .collect();
        let input_linked_to_output = inputs
            .into_iter()
            .any(|port| linked_to_outputs.contains(&port));
        !input_linked_to_output && self.is_node_convex(nodes)
    }
}
impl<G: LinkView> CreateConvexChecker<G> for LineConvexChecker<G> {
fn new_convex_checker(graph: G) -> Self {
Self::new(graph)
}
fn graph(&self) -> &G {
&self.graph
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::boundary::HasBoundary;
    use crate::view::Subgraph;
    use crate::{LinkMut, MultiPortGraph, NodeIndex, PortMut, PortView};
    use rstest::{fixture, rstest};

    /// Eleven nodes with one input and one output port each.
    ///
    /// Links: 0 → 1; node 1 fans out through its single output port to 2, 4
    /// and 6; then 2 → 3, 4 → 5, 6 → 7; plus a separate chain 8 → 9 → 10.
    #[fixture]
    fn two_lines_ish_graph() -> (MultiPortGraph, [NodeIndex; 11]) {
        let mut graph = MultiPortGraph::new();
        let nodes: Vec<NodeIndex> = (0..11).map(|_| graph.add_node(1, 1)).collect();
        let mut link = |n1, n2| graph.link_nodes(nodes[n1], 0, nodes[n2], 0).unwrap();
        link(0, 1);
        // i = 2, 4, 6: link 1 → i and i → i + 1.
        for i in (2..7).step_by(2) {
            link(1, i);
            link(i, i + 1);
        }
        link(8, 9);
        link(9, 10);
        (graph, nodes.try_into().unwrap())
    }

    /// Nodes 0 and 1 both link into the single input port of node 2.
    #[fixture]
    fn two_lines_merging_graph() -> (MultiPortGraph, [NodeIndex; 3]) {
        let mut graph = MultiPortGraph::new();
        let nodes: Vec<NodeIndex> = (0..3).map(|_| graph.add_node(1, 1)).collect();
        let mut link = |n1, n2| graph.link_nodes(nodes[n1], 0, nodes[n2], 0).unwrap();
        link(0, 2);
        link(1, 2);
        (graph, nodes.try_into().unwrap())
    }

    /// Pins the exact line numbering and positions produced for
    /// `two_lines_ish_graph`.
    #[rstest]
    fn test_line_partition(two_lines_ish_graph: (MultiPortGraph, [NodeIndex; 11])) {
        let (graph, nodes) = two_lines_ish_graph;
        let checker = LineConvexChecker::new(graph);
        // Assert that `n` has the given position and lies on the given line.
        let node_n_is_at_position = |n: NodeIndex, (line_index, position): (usize, usize)| {
            assert_eq!(
                checker.get_position(n),
                Position(position as u32),
                "{n:?} is at position {:?}",
                Position(position as u32)
            );
            assert!(
                checker.get_lines(n).contains(&LineIndex(line_index as u32)),
                "{n:?} is on line {:?}",
                LineIndex(line_index as u32)
            );
        };
        // Line 0 continues through node 1 into the 6 → 7 branch; the other
        // two branches out of node 1 are lines 2 and 3.
        let line0 = vec![nodes[0], nodes[1], nodes[6], nodes[7]];
        let line1 = vec![nodes[8], nodes[9], nodes[10]];
        let line2 = vec![nodes[1], nodes[4], nodes[5]];
        let line3 = vec![nodes[1], nodes[2], nodes[3]];
        for (&n, pos) in line0.iter().zip(0..=3) {
            node_n_is_at_position(n, (0, pos));
        }
        for (&n, pos) in line1.iter().zip(0..=2) {
            node_n_is_at_position(n, (1, pos));
        }
        // Lines 2 and 3 start at node 1, which is at position 1.
        for (&n, pos) in line2.iter().zip(1..=3) {
            node_n_is_at_position(n, (2, pos));
        }
        for (&n, pos) in line3.iter().zip(1..=3) {
            node_n_is_at_position(n, (3, pos));
        }
        assert_eq!(checker.lines, [line0, line1, line2, line3]);
    }

    /// Two lines that merge into node 2 both end on it.
    #[rstest]
    fn test_line_partition_merging(two_lines_merging_graph: (MultiPortGraph, [NodeIndex; 3])) {
        let (graph, nodes) = two_lines_merging_graph;
        let checker = LineConvexChecker::new(graph);
        let line0 = vec![nodes[0], nodes[2]];
        let line1 = vec![nodes[1], nodes[2]];
        assert_eq!(checker.lines, [line0, line1]);
    }

    /// Growing intervals one node at a time must give the same result as
    /// computing them from the whole node set.
    #[rstest]
    fn test_try_extend_intervals(two_lines_ish_graph: (MultiPortGraph, [NodeIndex; 11])) {
        let (graph, nodes) = two_lines_ish_graph;
        let checker = LineConvexChecker::new(graph);
        let subgraph = (1..=4).map(|i| nodes[i]);
        let intervals = checker.get_intervals_from_nodes(subgraph.clone()).unwrap();
        let mut extended_intervals = LineIntervals::default();
        for node in subgraph {
            assert!(checker.try_extend_intervals(&mut extended_intervals, node));
        }
        assert_eq!(intervals, extended_intervals);
    }

    /// Intervals from a node set and from its subgraph's port boundary must
    /// agree. Entries are sorted by line first, since storage order may differ.
    #[test]
    fn test_get_intervals_convex() {
        let (g, [i1, i2, i3, n1, n2, o1, o2]) = super::super::tests::graph();
        let checker = LineConvexChecker::new(g.clone());
        let convex_node_sets: &[&[NodeIndex]] = &[
            &[i1, i2, i3],
            &[i1, n2],
            &[i1, n2, o1, n1],
            &[i1, n2, o2, n1],
            &[i1, i3, n2],
        ];
        for nodes in convex_node_sets {
            let mut intervals = checker
                .get_intervals_from_nodes(nodes.iter().copied())
                .unwrap();
            let subgraph = Subgraph::with_nodes(&g, nodes.iter().copied());
            let boundary = subgraph.port_boundary();
            let mut intervals2 = checker.get_intervals_from_boundary(&boundary).unwrap();
            intervals.0.sort_by_key(|&(l, _)| l);
            intervals2.0.sort_by_key(|&(l, _)| l);
            assert_eq!(intervals, intervals2);
        }
    }

    /// `nodes_in_intervals` must yield exactly the original nodes, ordered by
    /// `(position, node)`.
    #[test]
    fn test_nodes_in_intervals() {
        let (g, [i1, i2, i3, n1, n2, o1, o2]) = super::super::tests::graph();
        let checker = LineConvexChecker::new(g.clone());
        let convex_node_sets: &[&[NodeIndex]] = &[
            &[i1, i2, i3],
            &[i1, n2],
            &[i1, n2, o1, n1],
            &[i1, n2, o2, n1],
            &[i1, i3, n2],
        ];
        for nodes in convex_node_sets {
            let intervals = checker
                .get_intervals_from_nodes(nodes.iter().copied())
                .unwrap();
            let nodes_in_intervals = checker.nodes_in_intervals(&intervals);
            let nodes_sorted = {
                let mut nodes_sorted = nodes.to_vec();
                nodes_sorted.sort_by_key(|&n| (checker.get_position(n), n));
                nodes_sorted
            };
            assert_eq!(nodes_in_intervals.collect_vec(), nodes_sorted);
        }
    }

    /// In this fixture each port is expected to carry exactly one line, and
    /// the i-th port in either direction carries the node's i-th line.
    #[test]
    fn test_lines_at_port() {
        let (g, nodes) = super::super::tests::graph();
        let checker = LineConvexChecker::new(g.clone());
        for n in nodes {
            let lines = checker.get_lines(n);
            for dir in Direction::BOTH {
                for (i, p) in g.ports(n, dir).enumerate() {
                    let &line_at_port = checker.lines_at_port(p).iter().exactly_one().unwrap();
                    assert_eq!(line_at_port, lines[i]);
                }
            }
        }
    }

    /// Several parallel links on a single port: that port carries all of the
    /// node's lines, one per link.
    #[test]
    fn test_lines_at_port_multigraph() {
        const NUM_LINKS: usize = 3;
        let mut g: MultiPortGraph = MultiPortGraph::new();
        let out = g.add_node(0, 1);
        let in_ = g.add_node(1, 0);
        for _ in 0..NUM_LINKS {
            g.link_nodes(out, 0, in_, 0).unwrap();
        }
        let checker = LineConvexChecker::new(g.clone());
        for n in [out, in_] {
            let lines = checker.get_lines(n);
            assert_eq!(lines.len(), NUM_LINKS);
            let p = g.all_ports(n).exactly_one().ok().unwrap();
            let lines_at_port = checker.lines_at_port(p);
            assert_eq!(lines_at_port, lines);
        }
    }
}