use crate::edge_id;
use paste::paste;
use std::{collections::HashMap, fmt::Debug};
macro_rules! impl_prim {
($node_bits:ty, $node_id:ty, $num:expr) => {
paste! {
#[doc = "Routing graph using `" $node_bits "` as the node bits storage, so it holds at most " $num " nodes."]
///
/// Every edge stores one direction bit per destination node; the bit is read
/// relative to the lower-numbered endpoint of the edge, which is how
#[doc = "[Graph" $num "::neighbor_to] picks the next hop towards a destination."]
#[derive(Debug, Clone)]
pub struct [< Graph $num >] {
    /// Neighbor bitmask of every node, indexed by node id.
    pub nodes: [<Nodes $num>],
    /// Direction bits of every edge, keyed by `edge_id(a, b)`.
    pub edges: HashMap<($node_id, $node_id), $node_bits>,
}
impl [< Graph $num >] {
    #[doc = "Create a [Graph" $num "Builder] with `nodes_len` unconnected nodes."]
    ///
    #[doc = "`nodes_len` must be at most " $num ": debug builds assert this, release"]
    /// builds clamp it.
    pub fn builder(nodes_len: usize) -> [<Graph $num Builder>] {
        debug_assert!(nodes_len <= $num, "Number of nodes must be equal or lower than {}", $num);
        [<Graph $num Builder>]::new(nodes_len.min($num))
    }

    /// Turns the graph back into a builder with the same nodes and edges.
    ///
    /// Every edge is reset to exactly the state `connect` leaves it in (only
    /// its two endpoints resolved). The next `build` therefore recomputes all
    /// directions from scratch instead of OR-ing new results into stale ones,
    /// because the builder's edge storage can only set bits, never clear them.
    pub fn into_builder(self) -> [<Graph $num Builder>] {
        let one: $node_bits = 1;
        let mut edges = HashMap::with_capacity(self.edges.len());
        let mut edge_masks = HashMap::with_capacity(self.edges.len());
        for &(a, b) in self.edges.keys() {
            // Same values `connect(a, b)` stores: the bit of the higher
            // endpoint is set, and both endpoints are marked as resolved.
            edges.insert((a, b), one << a.max(b));
            edge_masks.insert((a, b), (one << a) | (one << b));
        }
        [<Graph $num Builder>] {
            nodes: self.nodes,
            edges: [<Edges $num>] { inner: edges },
            edge_masks: [<Edges $num>] { inner: edge_masks },
        }
    }

    /// Returns the lowest-id neighbor of `curr` whose edge points towards
    /// `dest`, or `None` when there is none or `curr == dest`.
    #[inline]
    pub fn neighbor_to(&self, curr: $node_id, dest: $node_id) -> Option<$node_id> {
        self.neighbors_to(curr, dest).next()
    }

    /// Like `neighbor_to`, but returns the first candidate accepted by `f`.
    #[inline]
    pub fn neighbor_to_with(
        &self,
        curr: $node_id,
        dest: $node_id,
        f: impl Fn($node_id) -> bool,
    ) -> Option<$node_id> {
        self.neighbors_to(curr, dest).find(|&n| f(n))
    }

    /// Iterates, in ascending id order, every neighbor of `curr` whose edge
    /// points towards `dest`. Yields nothing when `curr == dest`.
    #[inline]
    pub fn neighbors_to(&self, curr: $node_id, dest: $node_id) -> [<NextNodesIter $num>]<'_> {
        [<NextNodesIter $num>] {
            graph: self,
            neighbors: self.nodes.neighbors(curr),
            curr,
            dest,
        }
    }

    /// Iterates the path from `curr` towards `dest`: `curr` itself first, then
    /// each next hop chosen by `neighbor_to`, until no further hop exists.
    #[inline]
    pub fn path_to(&self, curr: $node_id, dest: $node_id) -> [<PathIter $num>]<'_> {
        [<PathIter $num>] {
            map: self,
            curr,
            dest,
            init: false,
        }
    }

    /// Returns `true` when some neighbor of `curr` points towards `dest`.
    ///
    /// Note that this is `false` for `curr == dest`.
    #[inline]
    pub fn path_exists(&self, curr: $node_id, dest: $node_id) -> bool {
        self.neighbor_to(curr, dest).is_some()
    }

    /// Iterates all neighbors of `node` in ascending id order.
    #[inline]
    pub fn neighbors(&self, node: $node_id) -> impl Iterator<Item = $node_id> + '_ {
        self.nodes.neighbors(node)
    }

    /// Number of nodes in the graph.
    #[inline]
    pub fn nodes_len(&self) -> usize {
        self.nodes.len()
    }

    /// Number of edges in the graph.
    #[inline]
    pub fn edges_len(&self) -> usize {
        self.edges.len()
    }
}
#[derive(Debug)]
pub struct [<PathIter $num>]<'a> {
map: &'a [<Graph $num>],
curr: $node_id,
dest: $node_id,
init: bool,
}
impl Iterator for [<PathIter $num>]<'_> {
    type Item = $node_id;

    /// First call: the start node. Later calls: the next hop from the node
    /// yielded last, until the graph offers no further hop.
    fn next(&mut self) -> Option<Self::Item> {
        if self.init {
            let hop = self.map.neighbor_to(self.curr, self.dest)?;
            self.curr = hop;
            return Some(hop);
        }
        self.init = true;
        Some(self.curr)
    }
}
#[derive(Debug)]
pub struct [<NextNodesIter $num>]<'a> {
graph: &'a [<Graph $num>],
curr: $node_id,
dest: $node_id,
neighbors: [<NodeBits $num Iter>],
}
impl Iterator for [<NextNodesIter $num>]<'_> {
    type Item = $node_id;

    /// Advances to the next neighbor of `curr` whose edge points towards
    /// `dest`. Yields nothing once `curr == dest`, and returns `None` as soon
    /// as the edge to an examined neighbor is missing from the graph.
    fn next(&mut self) -> Option<Self::Item> {
        if self.curr == self.dest {
            return None;
        }
        for candidate in self.neighbors.by_ref() {
            let edge_bits = self.graph.edges.get(&edge_id(self.curr, candidate))?;
            // The stored bit is oriented from the lower-numbered endpoint, so
            // it is inverted when walking from the higher one.
            let stored = edge_bits & 1 << self.dest != 0;
            let from_higher = self.curr > candidate;
            if stored != from_higher {
                return Some(candidate);
            }
        }
        None
    }
}
#[doc = "Builder for [Graph" $num "]: connect nodes, then call [Graph" $num "Builder::build]."]
#[must_use = "a builder does nothing until `build` is called"]
#[derive(Debug, Clone)]
pub struct [<Graph $num Builder>] {
    /// Neighbor bitmask of every node.
    pub nodes: [<Nodes $num>],
    /// Direction bits of every edge, one bit per destination node.
    pub edges: [<Edges $num>],
    /// For every edge, the destinations whose direction bit has been resolved.
    pub edge_masks: [<Edges $num>],
}
impl [<Graph $num Builder>] {
    #[doc = "Create a builder for a [Graph" $num "] with `nodes_len` unconnected nodes."]
    ///
    #[doc = "Unlike [Graph" $num "::builder], `nodes_len` is not checked against " $num " here."]
    pub fn new(nodes_len: usize) -> Self {
        Self {
            nodes: [<Nodes $num>]::new(nodes_len),
            edges: [<Edges $num>]::new(),
            edge_masks: [<Edges $num>]::new(),
        }
    }

    /// Grows or shrinks the graph to `new_len` nodes.
    ///
    /// When shrinking, every edge that touches a removed node is dropped too.
    pub fn resize(&mut self, new_len: u8) {
        let should_truncate = new_len < self.nodes.len() as u8;
        self.nodes.resize(new_len as usize);
        if should_truncate {
            self.edges.truncate(new_len);
            self.edge_masks.truncate(new_len);
        }
    }

    /// Connects `a` and `b`. A self-loop (`a == b`) is ignored.
    ///
    /// Only the directions towards the two endpoints are resolved here;
    /// `build` resolves every other destination.
    pub fn connect(&mut self, a: $node_id, b: $node_id) {
        if !self.nodes.connect(a, b) {
            return;
        }
        let a_bit = 1 << a;
        let b_bit = 1 << b;
        // Bits are read relative to the lower-numbered endpoint, so only the
        // bit of the higher endpoint is set.
        let val = if a > b { a_bit } else { b_bit };
        let ab = edge_id(a, b);
        self.edges.insert(ab, val);
        self.edge_masks.insert(ab, a_bit | b_bit);
    }

    /// Disconnects `a` and `b` and forgets their edge. A self-loop
    /// (`a == b`) is ignored.
    pub fn disconnect(&mut self, a: $node_id, b: $node_id) {
        // Mirror `connect`: bail out only when `Nodes` rejected the pair.
        if !self.nodes.disconnect(a, b) {
            return;
        }
        let ab = edge_id(a, b);
        if self.edges.inner.remove(&ab).is_some() {
            self.edge_masks.inner.remove(&ab);
        }
    }

    /// Propagates direction bits outwards from every node, one neighborhood
    /// ring per round, until every node is done.
    #[doc = "Returns the finished [Graph" $num "]."]
    pub fn build(self) -> [< Graph $num >] {
        let Self {
            nodes,
            mut edges,
            mut edge_masks,
        } = self;
        // Nothing to resolve; also keeps `len() - 1` below from underflowing.
        if nodes.inner.is_empty() {
            return [< Graph $num >] {
                nodes,
                edges: edges.inner,
            };
        }
        // Per node: (current frontier of its breadth-first expansion, nodes
        // already covered by earlier frontiers).
        let mut neighbors_at_depth: Vec<($node_bits, $node_bits)> =
            nodes.inner.iter().enumerate().map(|(i, e)| (*e, 1 << i)).collect();
        let mut active_neighbors_mask: $node_bits = 0;
        let mut done_mask: $node_bits = 0;
        // Per neighbor slot of the node being processed: (bits to OR into the
        // edge, destinations computed this pass, edge mask read at the start).
        let mut upserts: Vec<($node_bits, $node_bits, $node_bits)> = Vec::new();
        let last_node_bit = 1 << (nodes.inner.len() - 1);
        let full_mask: $node_bits = last_node_bit | (last_node_bit - 1);

        // Seed pass: for every node `a`, each direct neighbor `b` is reached
        // through edge `a-b`, so every other edge `a-c` must point from `c`
        // towards `a` for destination `b`.
        for (a, a_neighbors) in &nodes {
            let a_neighbors_len = a_neighbors.len() as usize;
            upserts.fill((0, 0, 0));
            if upserts.len() < a_neighbors_len {
                upserts.resize(a_neighbors_len, (0, 0, 0));
            }
            for (i, b) in a_neighbors.enumerate() {
                let b_bit = 1 << b;
                let mut val = b_bit;
                if a > b {
                    val = 0;
                }
                for (j, c) in a_neighbors.enumerate() {
                    if i == j {
                        continue;
                    }
                    let upsert = if (a > b) == (a > c) {
                        !val & b_bit
                    } else {
                        val & b_bit
                    };
                    let vals = &mut upserts[j];
                    vals.0 |= upsert;
                    vals.1 |= b_bit;
                }
            }
            for (i, b) in a_neighbors.enumerate() {
                let ab = edge_id(a, b);
                let (upsert, computed, _) = upserts[i];
                if computed != 0 {
                    if upsert != 0 {
                        edges.insert(ab, upsert);
                    }
                    edge_masks.insert(ab, computed);
                }
            }
        }

        'outer: while done_mask != full_mask {
            for a in [<node_bits_ $num _iter>](full_mask ^ done_mask) {
                let a_bit = 1 << a;
                let a_neighbors = nodes.neighbors(a);
                let a_neighbors_len = a_neighbors.len() as usize;
                upserts.fill((0, 0, 0));
                if upserts.len() < a_neighbors_len {
                    upserts.resize(a_neighbors_len, (0, 0, 0));
                }
                let mut a_active_neighbors_mask = 0;
                let mut all_edges_done = true;
                for (i, b) in a_neighbors.enumerate() {
                    let mask = edge_masks.get(edge_id(a, b)).unwrap();
                    upserts[i].2 = mask;
                    if mask != full_mask {
                        all_edges_done = false;
                    }
                }
                if all_edges_done {
                    done_mask |= a_bit;
                    continue;
                }
                for (i, b) in a_neighbors.enumerate() {
                    let neighbors_mask = neighbors_at_depth.get(b as usize).unwrap().0 & !a_bit;
                    if neighbors_mask == 0 {
                        continue;
                    }
                    a_active_neighbors_mask |= 1 << b;
                    let ab = edge_id(a, b);
                    let val = edges.get(ab).unwrap();
                    for (j, c) in a_neighbors.enumerate() {
                        if i == j {
                            continue;
                        }
                        let mask_ac = upserts[j].2;
                        if mask_ac == full_mask {
                            continue;
                        }
                        // Only destinations in `b`'s frontier that edge `a-c`
                        // has not resolved yet.
                        let compute_mask = neighbors_mask & !mask_ac;
                        if compute_mask == 0 {
                            continue;
                        }
                        let upsert = if (a > b) == (a > c) { !val } else { val } & compute_mask;
                        let vals = &mut upserts[j];
                        vals.0 |= upsert;
                        vals.1 |= compute_mask;
                    }
                }
                // No neighbor has an unexplored frontier left: `a` is finished.
                if a_active_neighbors_mask == 0 {
                    done_mask |= a_bit;
                } else {
                    for (i, b) in a_neighbors.enumerate() {
                        let ab = edge_id(a, b);
                        let (upsert, computed, _) = upserts[i];
                        if computed != 0 {
                            if upsert != 0 {
                                edges.insert(ab, upsert);
                            }
                            edge_masks.insert(ab, computed);
                        }
                    }
                }
                if done_mask == full_mask {
                    break 'outer;
                }
                active_neighbors_mask |= a_active_neighbors_mask;
            }
            // Advance the frontier of every node that was used this round by
            // one ring, excluding nodes it has already covered.
            for a in [<node_bits_ $num _iter>](active_neighbors_mask) {
                let a_usize = a as usize;
                let (a_neighbors_at_depth, mut prev_neighbors) = neighbors_at_depth[a_usize];
                if a_neighbors_at_depth == 0 {
                    continue;
                }
                let mut new_neighbors = 0;
                for b in [<node_bits_ $num _iter>](a_neighbors_at_depth) {
                    new_neighbors |= nodes.neighbors(b).node_bits;
                }
                prev_neighbors |= a_neighbors_at_depth;
                new_neighbors &= !prev_neighbors;
                neighbors_at_depth[a_usize] = (new_neighbors, prev_neighbors);
            }
            active_neighbors_mask = 0;
        }
        [< Graph $num >] {
            nodes,
            edges: edges.inner,
        }
    }
}
#[doc = "Neighbor sets of all nodes: bit `j` of `inner[i]` (a `" $node_bits "`) is set when nodes `i` and `j` are connected."]
#[derive(Debug, Clone)]
pub struct [<Nodes $num>] {
    /// One neighbor bitmask per node, indexed by node id.
    pub inner: Vec<$node_bits>,
}
impl [<Nodes $num>] {
    /// Creates `nodes_len` nodes, none of them connected.
    pub fn new(nodes_len: usize) -> Self {
        Self {
            inner: vec![0; nodes_len],
        }
    }

    /// Iterates the ids of `node`'s neighbors in ascending order.
    ///
    /// # Panics
    ///
    /// Panics if `node` is out of range.
    #[inline]
    pub fn neighbors(&self, node: $node_id) -> [<NodeBits $num Iter>] {
        [<node_bits_ $num _iter>](self.inner[node as usize])
    }

    /// Marks `a` and `b` as neighbors of each other.
    ///
    /// Returns `false` without changing anything when `a == b`.
    pub fn connect(&mut self, a: $node_id, b: $node_id) -> bool {
        if a == b {
            return false;
        }
        let b_bit = 1 << b;
        self.inner[a as usize] |= b_bit;
        self.inner[b as usize] |= 1 << a;
        true
    }

    /// Removes `a` and `b` from each other's neighbor sets.
    ///
    /// Returns `false` without changing anything when `a == b`.
    pub fn disconnect(&mut self, a: $node_id, b: $node_id) -> bool {
        if a == b {
            return false;
        }
        let b_bit = 1 << b;
        self.inner[a as usize] &= !b_bit;
        self.inner[b as usize] &= !(1 << a);
        true
    }

    /// Number of neighbors of `node`.
    #[inline]
    pub fn edge_count(&self, node: $node_id) -> u32 {
        self.inner[node as usize].count_ones()
    }

    /// Number of nodes.
    #[inline]
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns `true` when there are no nodes.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Grows the list with unconnected nodes, or shrinks it to `new_len`.
    ///
    /// When shrinking, the bits that remaining nodes hold for removed nodes are
    /// cleared as well, so no neighbor set points past the end of the list.
    #[inline]
    pub fn resize(&mut self, new_len: usize) {
        let shrinking = new_len < self.inner.len();
        self.inner.resize(new_len, 0);
        // When `new_len` reaches the bit width, no bit can refer past the end.
        if shrinking && new_len < <$node_bits>::BITS as usize {
            let keep_mask: $node_bits = (1 << new_len) - 1;
            for bits in &mut self.inner {
                *bits &= keep_mask;
            }
        }
    }
}
/// Per-edge bitmasks with one bit per node, keyed by edge id.
///
/// The builder keeps two of these: `edges` holds the direction bits and
/// `edge_masks` records which destinations have been resolved for each edge.
#[derive(Debug, Clone)]
pub struct [<Edges $num>] {
/// Bitmask per edge, keyed by `edge_id(a, b)`.
inner: HashMap<($node_id, $node_id), $node_bits>,
}
impl [<Edges $num>] {
    /// Creates an empty edge map.
    fn new() -> Self {
        Self {
            inner: HashMap::new(),
        }
    }

    /// Returns the bitmask stored for `edge_id`, if that edge is present.
    #[inline]
    pub fn get(&self, edge_id: ($node_id, $node_id)) -> Option<$node_bits> {
        self.inner.get(&edge_id).copied()
    }

    /// ORs `val` into the bitmask stored for `edge_id`, inserting the edge if
    /// it is missing. Bits are only ever set here, never cleared.
    #[inline]
    pub fn insert(&mut self, edge_id: ($node_id, $node_id), val: $node_bits) {
        *self.inner.entry(edge_id).or_insert(0) |= val;
    }

    /// Drops every edge touching a node id `>= nodes_len` and clears the bits
    /// of those ids from the remaining edges.
    pub fn truncate(&mut self, nodes_len: u8) {
        // `checked_shl` fails once `nodes_len` reaches the bit width; every
        // bit is a valid node id then, so keep them all.
        let one: $node_bits = 1;
        let keep_mask = one
            .checked_shl(u32::from(nodes_len))
            .map_or(<$node_bits>::MAX, |bit| bit - 1);
        self.inner.retain(|&(a, b), bits| {
            let keep = a < nodes_len && b < nodes_len;
            if keep {
                *bits &= keep_mask;
            }
            keep
        });
    }
}
/// Iterating `&Nodes` yields every node id, in order, together with an
/// iterator over that node's neighbors.
impl<'a> IntoIterator for &'a [<Nodes $num>] {
    type Item = ($node_id, [<NodeBits $num Iter>]);
    type IntoIter = [<Neighbors $num Iter>]<'a>;

    fn into_iter(self) -> Self::IntoIter {
        [<Neighbors $num Iter>] { node: 0, neighbors: self }
    }
}
pub struct [<Neighbors $num Iter>]<'a> {
neighbors: &'a [<Nodes $num>],
node: $node_id,
}
impl<'a> Iterator for [<Neighbors $num Iter>]<'a> {
    type Item = ($node_id, [<NodeBits $num Iter>]);

    /// Yields the next node id with an iterator over its neighbors; stops once
    /// every node has been visited.
    fn next(&mut self) -> Option<Self::Item> {
        let id = self.node;
        let bits = *self.neighbors.inner.get(id as usize)?;
        self.node += 1;
        Some((id, [<node_bits_ $num _iter>](bits)))
    }
}
/// Wraps a raw neighbor bitmask in an iterator over the ids of its set bits.
fn [<node_bits_ $num _iter>](bits: $node_bits) -> [<NodeBits $num Iter>] {
    [<NodeBits $num Iter>] { node_bits: bits }
}
/// Iterator over the ids of the set bits of a node bitmask, lowest id first.
#[must_use = "iterators are lazy and do nothing unless consumed"]
#[derive(Clone, Copy)]
pub struct [<NodeBits $num Iter>] {
    /// Bits of the node ids not yielded yet.
    node_bits: $node_bits,
}
impl Debug for [<NodeBits $num Iter>] {
    /// Formats the remaining bits as a zero-padded binary number spanning the
    /// full width of the bitmask type (node 0 is the rightmost digit).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:0width$b}", self.node_bits, width = <$node_bits>::BITS as usize)
    }
}
impl [<NodeBits $num Iter>] {
    /// Returns a copy of this iterator with `node` removed from the remaining
    /// ids. `self` is `Copy`, so ignoring the result does nothing.
    #[must_use]
    pub fn without(self, node: $node_id) -> Self {
        Self {
            node_bits: self.node_bits & !(1 << node),
        }
    }

    /// Number of node ids still to be yielded.
    #[inline]
    pub fn len(&self) -> u32 {
        self.node_bits.count_ones()
    }

    /// Returns `true` when no node ids are left.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.node_bits == 0
    }
}
impl Iterator for [<NodeBits $num Iter>] {
    type Item = $node_id;

    /// Yields the lowest remaining id and removes it from the set.
    fn next(&mut self) -> Option<Self::Item> {
        if self.node_bits == 0 {
            return None;
        }
        let node = self.node_bits.trailing_zeros();
        // `x & (x - 1)` clears the lowest set bit, i.e. the one being yielded.
        self.node_bits &= self.node_bits - 1;
        Some(node as $node_id)
    }

    /// The remaining count is known exactly: one item per set bit.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.node_bits.count_ones() as usize;
        (remaining, Some(remaining))
    }
}

// Once every bit is cleared the iterator keeps returning `None`.
impl std::iter::FusedIterator for [<NodeBits $num Iter>] {}
}
};
}
// One instantiation per supported width: the node-bits type caps the node
// count (16, 32, 64 or 128), and node ids are always `u8`.
impl_prim!(u16, u8, 16);
impl_prim!(u32, u8, 32);
impl_prim!(u64, u8, 64);
impl_prim!(u128, u8, 128);
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a `$x_len` x `$y_len` grid (4-neighborhood) with `$builder` and
    // checks invariants that hold for every grid graph: node/edge counts and
    // that each node routes straight to each of its direct neighbors.
    macro_rules! grid_test {
        ($test_name:ident, $builder:ident, $x_len:expr, $y_len:expr) => {
            #[test]
            fn $test_name() {
                const NODES_X_LEN: usize = $x_len;
                const NODES_Y_LEN: usize = $y_len;
                const NODES_LEN: usize = NODES_X_LEN * NODES_Y_LEN;
                let mut pairs = Vec::new();
                for y in 0..NODES_Y_LEN {
                    for x in 0..NODES_X_LEN {
                        let node_id = y * NODES_X_LEN + x;
                        if x > 0 {
                            pairs.push(((node_id - 1) as u8, node_id as u8));
                        }
                        if y > 0 {
                            pairs.push((node_id as u8, (node_id - NODES_X_LEN) as u8));
                        }
                    }
                }
                let mut builder = $builder::new(NODES_LEN);
                for &(a, b) in &pairs {
                    builder.connect(a, b);
                }
                let graph = builder.build();

                assert_eq!(graph.nodes_len(), NODES_LEN);
                assert_eq!(graph.edges_len(), pairs.len());
                for &(a, b) in &pairs {
                    assert_eq!(graph.neighbor_to(a, b), Some(b));
                    assert_eq!(graph.neighbor_to(b, a), Some(a));
                    assert_eq!(graph.path_to(a, b).collect::<Vec<_>>(), vec![a, b]);
                }
                assert_eq!(graph.path_to(0, 0).collect::<Vec<_>>(), vec![0]);
            }
        };
    }

    grid_test!(test_graph_16, Graph16Builder, 4, 4);
    grid_test!(test_graph_32, Graph32Builder, 4, 8);
    grid_test!(test_graph_64, Graph64Builder, 8, 8);
    grid_test!(test_graph_128, Graph128Builder, 8, 16);
}