use fixedbitset::FixedBitSet;
use hashbrown::HashSet;
use itertools::Itertools;
use std::{hash::Hash, iter, usize};
use crate::prelude::*;
/// A `K`-partite hypergraph: every node is added to one of `K` partitions, and
/// `add_edge` rejects an edge that contains two nodes whose partitions overlap.
///
/// Graph storage is delegated to a wrapped [`UndirectedHypergraph`]; this type only
/// adds the per-node partition bookkeeping and the partition check on edges.
pub struct PartiteHypergraph<N, E, const K: usize> {
    /// The wrapped hypergraph that node/edge operations are forwarded to.
    inner: UndirectedHypergraph<N, E>,
    /// One `K`-bit set per node, looked up by node index; `add_node` sets exactly
    /// the bit of the node's partition.
    partition_map: Vec<FixedBitSet>,
}
// Generates the basic graph API for `PartiteHypergraph` by wrapping the inner
// `UndirectedHypergraph` (presumably forwarding to the `inner` field).
// NOTE(review): the meaning of the `false` flag and of the `|const K: usize|`
// generic-parameter list depends on the `impl_graph_basics_wrapper!` definition,
// which is not visible here — presumably `false` marks the graph as undirected; confirm.
impl_graph_basics_wrapper!(
    PartiteHypergraph<N, E, K>,
    UndirectedHypergraph<N, E>,
    false,
    |const K: usize|
);
// Generates weight accessors for `PartiteHypergraph`, presumably by forwarding to
// the wrapped hypergraph; the extra `|const K: usize|` supplies the const generic
// the generated impl needs. NOTE(review): confirm against `impl_weights_wrapper!`.
impl_weights_wrapper!(
    PartiteHypergraph<N, E, K>,
    |const K: usize|
);
impl<N, E, const K: usize> PartiteHypergraph<N, E, K>
where
    N: Clone + Eq + Hash,
    E: Clone + Eq + Hash,
{
    /// Creates an empty hypergraph with no nodes, no edges and an empty partition map.
    pub fn new() -> Self {
        Self {
            inner: UndirectedHypergraph::new(),
            partition_map: vec![],
        }
    }

    /// Adds a node carrying `weight` to partition `partition` and returns the index
    /// produced by the underlying hypergraph's `add_node`.
    ///
    /// # Panics
    ///
    /// Panics if `partition >= K`.
    pub fn add_node(&mut self, weight: N, partition: usize) -> usize {
        assert!(partition < K);
        // A K-bit set with only this node's partition bit set.
        let mut one_shot = FixedBitSet::with_capacity(K);
        // Checked insert: the assert above guarantees `partition < K`, so this cannot
        // panic, and a one-off insert does not justify `unsafe`.
        one_shot.insert(partition);
        // NOTE(review): `partition_map` is looked up by node index, so this assumes
        // `inner.add_node` returns `partition_map.len() - 1` (sequential, never-reused
        // indices) — confirm against `UndirectedHypergraph`.
        self.partition_map.push(one_shot);
        self.inner.add_node(weight)
    }

    /// Adds an edge carrying `weight` over the nodes in `indices`.
    ///
    /// # Errors
    ///
    /// - [`HypergraphErrors::EdgeTooSmall`] if `indices` has fewer than two entries.
    /// - [`HypergraphErrors::InvariantViolation`] if any two listed nodes share a
    ///   partition (this includes the same node index listed twice).
    /// - Any error returned by the underlying hypergraph's `add_edge`.
    ///
    /// Indices that have no recorded partition are not checked here; they are passed
    /// through so the underlying `add_edge` can report them.
    pub fn add_edge(&mut self, weight: E, indices: Vec<usize>) -> Result<usize, HypergraphErrors> {
        if indices.len() < 2 {
            return Err(HypergraphErrors::EdgeTooSmall);
        }
        // Pairwise check: every pair of nodes must have disjoint partition sets.
        let partitions_disjoint = indices.iter().tuple_combinations().all(|(&a, &b)| {
            match (self.partition_map.get(a), self.partition_map.get(b)) {
                (Some(pa), Some(pb)) => pa.is_disjoint(pb),
                // Unknown index: defer to `inner.add_edge` instead of panicking on an
                // out-of-bounds lookup.
                // NOTE(review): assumes `inner.add_edge` rejects unknown nodes — confirm.
                _ => true,
            }
        });
        if !partitions_disjoint {
            return Err(HypergraphErrors::InvariantViolation {
                err: "Nodes that share a partition".to_string(),
            });
        }
        self.inner.add_edge(weight, indices)
    }

    /// Adds every `(weight, partition)` pair from `weights` via [`Self::add_node`],
    /// discarding the returned indices.
    ///
    /// # Panics
    ///
    /// Panics if any partition is `>= K`.
    pub fn add_nodes(&mut self, weights: impl Iterator<Item = (N, usize)>) {
        weights.for_each(|(x, p)| {
            self.add_node(x, p);
        });
    }

    /// Adds every `(weight, indices)` pair from `edges` via [`Self::add_edge`].
    ///
    /// # Errors
    ///
    /// Stops at and returns the first error. Edges added before that error are
    /// kept, so a failure is not rolled back.
    pub fn add_edges(
        &mut self,
        edges: impl Iterator<Item = (E, Vec<usize>)>,
    ) -> Result<(), HypergraphErrors> {
        for (w, s) in edges {
            self.add_edge(w, s)?;
        }
        Ok(())
    }

    /// Removes the node at `node_index` via the underlying hypergraph.
    ///
    /// NOTE(review): `partition_map` is not updated, so the removed node's partition
    /// entry remains; whether that is correct depends on how the underlying
    /// hypergraph assigns indices after removals — confirm.
    pub fn remove_node(&mut self, node_index: usize) -> Option<UndirectedNode<N>> {
        self.inner.remove_node(node_index)
    }

    /// Removes the nodes at `node_indices` via the underlying hypergraph.
    ///
    /// NOTE(review): like [`Self::remove_node`], this leaves `partition_map` untouched.
    pub fn remove_nodes(&mut self, node_indices: Vec<usize>) -> Vec<UndirectedNode<N>> {
        self.inner.remove_nodes(node_indices)
    }

    /// Removes the edge at `edge_index` via the underlying hypergraph.
    pub fn remove_edge(&mut self, edge_index: usize) -> Option<UndirectedEdge<E>> {
        self.inner.remove_edge(edge_index)
    }

    /// Removes the edges at `edge_indices` via the underlying hypergraph.
    pub fn remove_edges(&mut self, edge_indices: Vec<usize>) -> Vec<UndirectedEdge<E>> {
        self.inner.remove_edges(edge_indices)
    }

    /// Returns what the underlying hypergraph's `get_neighbours` reports for
    /// `node_index`.
    pub fn get_neighbours(&self, node_index: usize) -> Option<HashSet<&usize>> {
        self.inner.get_neighbours(node_index)
    }

    /// Returns what the underlying hypergraph's `get_incident_edges` reports for
    /// `node_index` (presumably the indices of edges containing that node).
    pub fn get_edges(&self, node_index: usize) -> Option<&Vec<usize>> {
        self.inner.get_incident_edges(node_index)
    }
}
/// Builds a `FixedBitSet` of length `K` with bits `0..K` all set.
fn all_the_bits<const K: usize>() -> FixedBitSet {
    // `K / BITS + 1` blocks always cover `K` bits. Per the fixedbitset docs,
    // `with_capacity_and_blocks` truncates any surplus block to the requested length.
    let block_count = K / usize::BITS as usize + 1;
    let saturated: Vec<_> = iter::repeat_n(usize::MAX, block_count).collect();
    FixedBitSet::with_capacity_and_blocks(K, saturated)
}
/// A `K`-partite, uniform hypergraph: every edge is given as exactly `ORDER` node
/// indices (`[usize; ORDER]`), every node is added to one of `K` partitions, and
/// `add_edge` rejects edges whose nodes' partitions overlap.
///
/// Graph storage is delegated to a wrapped `undirected_uniform::UniformHypergraph`.
pub struct UniformPartiteHypergraph<N, E, const K: usize, const ORDER: usize> {
    /// The wrapped uniform hypergraph that node/edge operations are forwarded to.
    inner: undirected_uniform::UniformHypergraph<N, E, ORDER>,
    /// One `K`-bit set per node, looked up by node index; `add_node` sets exactly
    /// the bit of the node's partition.
    partition_map: Vec<FixedBitSet>,
}
// Generates the basic graph API for `UniformPartiteHypergraph` by wrapping the inner
// `UniformHypergraph` (presumably forwarding to the `inner` field).
// NOTE(review): the meaning of the `false` flag and of the const-generic parameter
// list depends on the `impl_graph_basics_wrapper!` definition, which is not visible
// here — presumably `false` marks the graph as undirected; confirm.
impl_graph_basics_wrapper!(
    UniformPartiteHypergraph<N, E, K, ORDER>,
    undirected_uniform::UniformHypergraph<N, E, ORDER>,
    false,
    |const K: usize, const ORDER: usize|
);
impl<N, E, const K: usize, const ORDER: usize> UniformPartiteHypergraph<N, E, K, ORDER>
where
    N: Clone + Eq + Hash,
    E: Clone + Eq + Hash,
{
    /// Creates an empty hypergraph with no nodes, no edges and an empty partition map.
    pub fn new() -> Self {
        Self {
            inner: undirected_uniform::UniformHypergraph::new(),
            partition_map: vec![],
        }
    }

    /// Adds a node carrying `weight` to partition `partition` and returns the index
    /// produced by the underlying hypergraph's `add_node`.
    ///
    /// # Panics
    ///
    /// Panics if `partition >= K`.
    pub fn add_node(&mut self, weight: N, partition: usize) -> usize {
        assert!(partition < K);
        // A K-bit set with only this node's partition bit set.
        let mut one_shot = FixedBitSet::with_capacity(K);
        // Checked insert: the assert above guarantees `partition < K`, so this cannot
        // panic, and a one-off insert does not justify `unsafe`.
        one_shot.insert(partition);
        // NOTE(review): `partition_map` is looked up by node index, so this assumes
        // `inner.add_node` returns `partition_map.len() - 1` (sequential, never-reused
        // indices) — confirm against `UniformHypergraph`.
        self.partition_map.push(one_shot);
        self.inner.add_node(weight)
    }

    /// Adds an edge carrying `weight` over the `ORDER` nodes in `indices`.
    ///
    /// # Errors
    ///
    /// - [`HypergraphErrors::EdgeTooSmall`] if `ORDER < 2`.
    /// - [`HypergraphErrors::InvariantViolation`] if any two listed nodes share a
    ///   partition (this includes the same node index listed twice).
    /// - Any error returned by the underlying hypergraph's `add_edge`.
    ///
    /// Indices that have no recorded partition are not checked here; they are passed
    /// through so the underlying `add_edge` can report them.
    pub fn add_edge(
        &mut self,
        weight: E,
        indices: [usize; ORDER],
    ) -> Result<usize, HypergraphErrors> {
        if indices.len() < 2 {
            return Err(HypergraphErrors::EdgeTooSmall);
        }
        // Each node claims its partition(s) from the set of partitions not yet used by
        // an earlier node of this edge. A node whose partition is no longer free shares
        // it with an earlier node. Intersecting *all* nodes' sets at once is not
        // enough: for ORDER >= 3 it only detects the case where every node is in the
        // same partition, and misses e.g. partitions [0, 0, 1].
        let mut free = all_the_bits::<K>();
        for &node in indices.iter() {
            // Unknown index: defer to `inner.add_edge` instead of panicking on an
            // out-of-bounds lookup.
            // NOTE(review): assumes `inner.add_edge` rejects unknown nodes — confirm.
            let Some(partitions) = self.partition_map.get(node) else {
                continue;
            };
            // `add_node` only sets bits below K, so "subset of free" is the same as
            // "disjoint from every partition claimed so far".
            if !partitions.is_subset(&free) {
                return Err(HypergraphErrors::InvariantViolation {
                    err: "Nodes that share a partition".to_string(),
                });
            }
            free.difference_with(partitions);
        }
        self.inner.add_edge(weight, indices)
    }

    /// Adds every `(weight, partition)` pair from `weights` via [`Self::add_node`],
    /// discarding the returned indices.
    ///
    /// # Panics
    ///
    /// Panics if any partition is `>= K`.
    pub fn add_nodes(&mut self, weights: impl Iterator<Item = (N, usize)>) {
        weights.for_each(|(x, p)| {
            self.add_node(x, p);
        });
    }

    /// Adds every `(weight, indices)` pair from `edges` via [`Self::add_edge`].
    ///
    /// # Errors
    ///
    /// Stops at and returns the first error. Edges added before that error are
    /// kept, so a failure is not rolled back.
    pub fn add_edges(
        &mut self,
        edges: impl Iterator<Item = (E, [usize; ORDER])>,
    ) -> Result<(), HypergraphErrors> {
        for (w, s) in edges {
            self.add_edge(w, s)?;
        }
        Ok(())
    }

    /// Removes the node at `node_index` via the underlying hypergraph.
    ///
    /// NOTE(review): `partition_map` is not updated, so the removed node's partition
    /// entry remains; whether that is correct depends on how the underlying
    /// hypergraph assigns indices after removals — confirm.
    pub fn remove_node(&mut self, node_index: usize) -> HypergraphResult<UndirectedNode<N>> {
        self.inner.remove_node(node_index)
    }

    /// Removes the nodes at `node_indices` via the underlying hypergraph.
    ///
    /// NOTE(review): like [`Self::remove_node`], this leaves `partition_map` untouched.
    pub fn remove_nodes(&mut self, node_indices: Vec<usize>) -> Vec<UndirectedNode<N>> {
        self.inner.remove_nodes(node_indices)
    }

    /// Removes the edge at `edge_index` via the underlying hypergraph.
    pub fn remove_edge(
        &mut self,
        edge_index: usize,
    ) -> HypergraphResult<undirected_uniform::UniformEdge<E, ORDER>> {
        self.inner.remove_edge(edge_index)
    }

    /// Removes the edges at `edge_indices` via the underlying hypergraph.
    pub fn remove_edges(
        &mut self,
        edge_indices: Vec<usize>,
    ) -> Vec<undirected_uniform::UniformEdge<E, ORDER>> {
        self.inner.remove_edges(edge_indices)
    }

    /// Returns what the underlying hypergraph's `get_neighbours` reports for
    /// `node_index`.
    pub fn get_neighbours(&self, node_index: usize) -> Option<HashSet<&usize>> {
        self.inner.get_neighbours(node_index)
    }

    /// Returns what the underlying hypergraph's `get_incident_edges` reports for
    /// `node_index` (presumably the indices of edges containing that node).
    pub fn get_edges(&self, node_index: usize) -> Option<&Vec<usize>> {
        self.inner.get_incident_edges(node_index)
    }
}