use std::borrow::{Borrow, BorrowMut};
use std::ops::Range;
use bytemuck::{AnyBitPattern, TransparentWrapper};
use rapidhash::RapidHashMap;
use rayon::prelude::*;
use sux::bits::BitVec;
use sux::traits::{BitVecOps, BitVecOpsMut};
use crate::NodeId;
/// Storage that associates at most one label with each node.
///
/// The implementations in this file back this interface with a hash map, a
/// dense per-node array, or a flat buffer of fixed-width word strides.
pub trait Labels {
/// Borrowed form of a label, as handed out by [`Labels::get`]. May be unsized;
/// values are passed in and returned by value through its `ToOwned::Owned` type.
type Label: ToOwned + ?Sized;
/// Extra construction parameters consumed by [`Labels::new`].
type Config;
/// Creates a container for `num_nodes` nodes using `config`.
///
/// Every implementation in this file starts out with no node labeled.
fn new(num_nodes: usize, config: Self::Config) -> Self;
/// Stores `label` for `node` and returns the label it replaced, or `None`
/// if `node` had no label (as the implementations in this file do).
fn insert(
&mut self,
node: NodeId,
label: <Self::Label as ToOwned>::Owned,
) -> Option<<Self::Label as ToOwned>::Owned>;
/// Returns a reference to the label of `node`, or `None` if it has none.
fn get(&self, node: NodeId) -> Option<&Self::Label>;
/// Returns `true` if `node` currently has a label.
fn contains_key(&self, node: NodeId) -> bool;
/// Removes the label of `node` and returns it, or `None` if it had none.
fn remove(&mut self, node: NodeId) -> Option<<Self::Label as ToOwned>::Owned>;
/// Returns `true` if no node currently has a label.
fn is_empty(&self) -> bool;
}
/// Label storage backed by a hash map keyed by node.
///
/// Only nodes that have been given a label occupy an entry in the map.
pub struct SparseLabels<Label: ToOwned + Sized> {
/// Owned label of every labeled node.
labels: RapidHashMap<NodeId, <Label as ToOwned>::Owned>,
}
/// Hash-map backed implementation: every operation forwards to the
/// corresponding `RapidHashMap` method.
impl<Label: ToOwned> Labels for SparseLabels<Label> {
    type Label = Label;
    type Config = ();

    /// Creates an empty map; the node count and the (unit) config are ignored.
    fn new(_num_nodes: usize, _config: Self::Config) -> Self {
        let labels = RapidHashMap::default();
        Self { labels }
    }

    /// Inserts `label` for `node`, returning the label previously stored for
    /// that node, if any.
    fn insert(
        &mut self,
        node: NodeId,
        label: <Self::Label as ToOwned>::Owned,
    ) -> Option<<Self::Label as ToOwned>::Owned> {
        self.labels.insert(node, label)
    }

    /// Looks up `node` and exposes its stored owned label in borrowed form.
    fn get(&self, node: NodeId) -> Option<&Self::Label> {
        self.labels.get(&node).map(|owned| owned.borrow())
    }

    /// Returns `true` if the map has an entry for `node`.
    fn contains_key(&self, node: NodeId) -> bool {
        self.labels.contains_key(&node)
    }

    /// Removes the entry for `node`, returning its label if there was one.
    fn remove(&mut self, node: NodeId) -> Option<<Self::Label as ToOwned>::Owned> {
        self.labels.remove(&node)
    }

    /// Returns `true` if the map holds no entries.
    fn is_empty(&self) -> bool {
        self.labels.is_empty()
    }
}
impl<Label: ToOwned> SparseLabels<Label> {
    /// Iterates over `(node, &label)` for every labeled node, in the hash
    /// map's iteration order.
    pub fn iter_labeled(&self) -> impl Iterator<Item = (NodeId, &Label)> {
        self.labels
            .iter()
            .map(|(node, owned)| (*node, owned.borrow()))
    }

    /// Consumes the container and iterates over `(node, label)` for every
    /// labeled node, yielding the labels by value.
    pub fn into_iter_labeled(self) -> impl Iterator<Item = (NodeId, <Label as ToOwned>::Owned)> {
        let Self { labels } = self;
        labels.into_iter()
    }
}
/// Label storage with one preallocated slot per node, plus a bit vector
/// recording which slots currently hold a label.
pub struct DenseLabels<Label: Sized> {
/// One slot per node; slots of unlabeled nodes hold `Label::default()`.
labels: Box<[Label]>,
/// Bit `i` is set while node `i` has a label.
is_set: BitVec,
/// Number of nodes that currently have a label.
len: usize,
}
/// Array-backed implementation: node `i` owns slot `i` of `labels`, and bit
/// `i` of `is_set` tells whether that slot holds a real label or only the
/// default placeholder.
///
/// All node-taking methods expect `node < num_nodes`; an out-of-range node is
/// expected to panic (slice indexing does; the bit-vector access presumably
/// does as well — confirm against the `sux` version in use).
impl<Label: Default + Clone> Labels for DenseLabels<Label> {
    type Label = Label;
    type Config = ();

    /// Allocates `num_nodes` slots filled with `Label::default()`, with no
    /// node marked as labeled.
    fn new(num_nodes: usize, _config: Self::Config) -> Self {
        Self {
            labels: vec![Label::default(); num_nodes].into(),
            is_set: BitVec::new(num_nodes),
            len: 0,
        }
    }

    /// Stores `label` in the slot of `node`.
    ///
    /// Returns the previous label if the node was already labeled; otherwise
    /// returns `None` and counts the node as newly labeled.
    fn insert(
        &mut self,
        node: NodeId,
        label: <Self::Label as ToOwned>::Owned,
    ) -> Option<Self::Label> {
        let was_set = self.contains_key(node);
        // Move the new label in and take whatever occupied the slot before.
        let previous = std::mem::replace(&mut self.labels[node], label);
        self.is_set.set(node, true);
        if was_set {
            Some(previous)
        } else {
            // `previous` was only the default placeholder; it is dropped here.
            self.len += 1;
            None
        }
    }

    /// Returns the label of `node`, or `None` if the node is not labeled.
    fn get(&self, node: NodeId) -> Option<&Self::Label> {
        if !self.contains_key(node) {
            return None;
        }
        Some(&self.labels[node])
    }

    /// Returns `true` if the bit for `node` is set.
    fn contains_key(&self, node: NodeId) -> bool {
        self.is_set.get(node)
    }

    /// Takes the label of `node` out of its slot, leaving `Label::default()`
    /// behind, and returns it. Returns `None` if the node is not labeled.
    fn remove(&mut self, node: NodeId) -> Option<Self::Label> {
        if !self.contains_key(node) {
            return None;
        }
        // Touch the slot first so the bookkeeping below is only updated once
        // the label has actually been taken out.
        let label = std::mem::take(&mut self.labels[node]);
        self.is_set.set(node, false);
        self.len -= 1;
        Some(label)
    }

    /// Returns `true` if no node is labeled.
    fn is_empty(&self) -> bool {
        self.len == 0
    }
}
impl<Label: Default + Clone> DenseLabels<Label> {
    /// Iterates over all node slots in node order, yielding `Some(&label)`
    /// for labeled nodes and `None` for the others.
    pub fn iter(&self) -> impl Iterator<Item = Option<&Label>> {
        let flags = self.is_set.iter();
        let slots = self.labels.iter();
        flags
            .zip(slots)
            .map(|(is_set, value)| is_set.then_some(value))
    }

    /// Iterates over `(node, &label)` for labeled nodes only, in node order.
    pub fn iter_labeled(&self) -> impl Iterator<Item = (NodeId, &Label)> {
        let flags = self.is_set.iter();
        let slots = self.labels.iter();
        flags
            .zip(slots)
            .enumerate()
            .filter_map(|(node, (is_set, value))| is_set.then_some((node, value)))
    }

    /// Parallel counterpart of [`Self::iter`]: one item per node slot,
    /// `Some(&label)` when the node is labeled and `None` otherwise.
    pub fn par_iter(&self) -> impl ParallelIterator<Item = Option<&Label>>
    where
        Label: Sync,
    {
        self.labels
            .par_iter()
            .enumerate()
            .map(move |(node, value)| self.is_set.get(node).then_some(value))
    }

    /// Parallel counterpart of [`Self::iter_labeled`]: `(node, &label)` for
    /// labeled nodes only.
    pub fn par_iter_labeled(&self) -> impl ParallelIterator<Item = (NodeId, &Label)>
    where
        Label: Sync,
    {
        self.labels
            .par_iter()
            .enumerate()
            .filter_map(move |(node, value)| self.is_set.get(node).then_some((node, value)))
    }

    /// Consumes the container and yields, in parallel, one item per node
    /// slot: `Some(label)` by value for labeled nodes, `None` otherwise.
    pub fn into_par_iter(self) -> impl ParallelIterator<Item = Option<Label>>
    where
        Label: Send,
    {
        let Self {
            labels: slots,
            is_set: flags,
            ..
        } = self;
        slots
            .into_vec()
            .into_par_iter()
            .enumerate()
            .map(move |(node, value)| flags.get(node).then_some(value))
    }

    /// Yields `(node, &label)` for labeled nodes in parallel.
    ///
    /// NOTE: despite the `into_` prefix this method borrows `self` and hands
    /// out references; it produces the same items as
    /// [`Self::par_iter_labeled`], to which it forwards.
    pub fn into_par_iter_labeled(&self) -> impl ParallelIterator<Item = (NodeId, &Label)>
    where
        Label: Sync,
    {
        self.par_iter_labeled()
    }
}
/// A label that can live as a run ("stride") of `Word`s inside a shared
/// buffer.
///
/// The owned form of the label must be mutably borrowable as `Self`, so an
/// owned label can exchange its contents with a stride in place.
pub trait StridableLabel: ToOwned<Owned: BorrowMut<Self>> {
/// Element type of the backing buffer; `Default` provides the filler value
/// used for unoccupied strides.
type Word: Default + Copy;
/// Views a stride of words as a borrowed label.
fn from_stride(stride: &[Self::Word]) -> &Self;
/// Exchanges the contents of this label with the contents of `stride`.
fn swap_with_stride(&mut self, stride: &mut [Self::Word]);
}
/// Unsized label that transparently wraps a slice of words.
///
/// Its owned counterpart (see the `ToOwned` impl) is [`BoxLabel`].
#[derive(Debug, PartialEq, Eq, TransparentWrapper, derive_more::AsRef, derive_more::AsMut)]
#[repr(transparent)]
pub struct SliceLabel<Word: AnyBitPattern>(pub [Word]);
/// Owning a [`SliceLabel`] means copying its words into a freshly allocated
/// boxed slice wrapped in a [`BoxLabel`].
impl<Word: AnyBitPattern> ToOwned for SliceLabel<Word> {
    type Owned = BoxLabel<Word>;

    /// Copies the wrapped words into a new [`BoxLabel`].
    fn to_owned(&self) -> Self::Owned {
        let words = self.0.to_vec();
        BoxLabel(words.into_boxed_slice())
    }
}
/// Owned label that transparently wraps a boxed slice of words; the owned
/// counterpart of [`SliceLabel`].
#[derive(Debug, PartialEq, Eq, Default, TransparentWrapper)]
#[repr(transparent)]
pub struct BoxLabel<Word: AnyBitPattern>(pub Box<[Word]>);
/// Lets a [`BoxLabel`] be viewed as the [`SliceLabel`] over its words.
impl<Word: AnyBitPattern> Borrow<SliceLabel<Word>> for BoxLabel<Word> {
    /// Wraps a shared reference to the boxed words as a `SliceLabel`.
    fn borrow(&self) -> &SliceLabel<Word> {
        let words: &[Word] = &self.0;
        SliceLabel::wrap_ref(words)
    }
}
/// Lets a [`BoxLabel`] be viewed mutably as the [`SliceLabel`] over its words.
impl<Word: AnyBitPattern> BorrowMut<SliceLabel<Word>> for BoxLabel<Word> {
    /// Wraps a mutable reference to the boxed words as a `SliceLabel`.
    fn borrow_mut(&mut self) -> &mut SliceLabel<Word> {
        let words: &mut [Word] = &mut self.0;
        SliceLabel::wrap_mut(words)
    }
}
impl<Word: AnyBitPattern + Default> StridableLabel for SliceLabel<Word> {
type Word = Word;
fn from_stride(stride: &[Self::Word]) -> &Self {
Self::wrap_ref(stride)
}
fn swap_with_stride(&mut self, stride: &mut [Self::Word]) {
stride.swap_with_slice(TransparentWrapper::peel_mut(self))
}
}
/// Construction parameters for [`StriddenLabels`].
pub struct StriddenLabelsConfig {
/// Number of words reserved per node, i.e. the length of each node's stride.
pub num_words: usize,
}
/// Label storage that keeps every node's label as a fixed-width stride of
/// words inside one flat buffer, plus a bit vector recording which nodes are
/// labeled.
pub struct StriddenLabels<Label: StridableLabel + ?Sized> {
/// Flat buffer of `num_nodes * num_words` words; node `i` owns the words
/// `i * num_words..(i + 1) * num_words`.
labels: Box<[Label::Word]>,
/// Width of each node's stride, in words.
num_words: usize,
/// Bit `i` is set while node `i` has a label.
is_set: BitVec,
/// Number of nodes that currently have a label.
len: usize,
}
/// Stride-backed implementation: node `i` owns a `num_words`-wide window of
/// the flat `labels` buffer, and bit `i` of `is_set` tells whether that
/// window holds a real label.
///
/// All node-taking methods expect `node < num_nodes`; an out-of-range node is
/// expected to panic (slice indexing does; the bit-vector access presumably
/// does as well — confirm against the `sux` version in use).
impl<Label: StridableLabel + ?Sized> Labels for StriddenLabels<Label> {
    type Label = Label;
    type Config = StriddenLabelsConfig;

    /// Allocates `num_nodes` strides of `num_words` default words each, with
    /// no node marked as labeled.
    ///
    /// # Panics
    ///
    /// Panics if `num_nodes * num_words` does not fit in `usize`. Without
    /// this check the product would wrap in release builds and produce an
    /// undersized buffer.
    fn new(num_nodes: usize, StriddenLabelsConfig { num_words }: Self::Config) -> Self {
        let total_words = num_nodes
            .checked_mul(num_words)
            .expect("StriddenLabels: num_nodes * num_words overflows usize");
        Self {
            labels: vec![Label::Word::default(); total_words].into(),
            num_words,
            is_set: BitVec::new(num_nodes),
            len: 0,
        }
    }

    /// Swaps the words of `label` into the stride of `node`.
    ///
    /// Returns the previous label if the node was already labeled; otherwise
    /// returns `None` and counts the node as newly labeled. How a label of
    /// the wrong width is handled is up to `Label::swap_with_stride`.
    fn insert(
        &mut self,
        node: NodeId,
        mut label: <Self::Label as ToOwned>::Owned,
    ) -> Option<<Self::Label as ToOwned>::Owned> {
        let was_set = self.contains_key(node);
        let range = self.stride_range(node);
        // After the swap `label` holds whatever the stride contained before.
        label.borrow_mut().swap_with_stride(&mut self.labels[range]);
        self.is_set.set(node, true);
        if was_set {
            Some(label)
        } else {
            // The swapped-out words were only filler; drop them.
            self.len += 1;
            None
        }
    }

    /// Returns the stride of `node` viewed as a label, or `None` if the node
    /// is not labeled.
    fn get(&self, node: NodeId) -> Option<&Self::Label> {
        if !self.contains_key(node) {
            return None;
        }
        let stride = &self.labels[self.stride_range(node)];
        Some(Self::Label::from_stride(stride))
    }

    /// Returns `true` if the bit for `node` is set.
    fn contains_key(&self, node: NodeId) -> bool {
        self.is_set.get(node)
    }

    /// Copies the label of `node` out of its stride, resets the stride to
    /// default words, and returns the copy. Returns `None` if the node is not
    /// labeled.
    fn remove(&mut self, node: NodeId) -> Option<<Self::Label as ToOwned>::Owned> {
        if !self.contains_key(node) {
            return None;
        }
        let range = self.stride_range(node);
        // Extract the label before touching the bookkeeping, so a panic while
        // reading the stride cannot leave the counters out of sync.
        let label = Label::from_stride(&self.labels[range.clone()]).to_owned();
        self.labels[range].fill(Label::Word::default());
        self.is_set.set(node, false);
        self.len -= 1;
        Some(label)
    }

    /// Returns `true` if no node is labeled.
    fn is_empty(&self) -> bool {
        self.len == 0
    }
}
impl<Label: StridableLabel + ?Sized> StriddenLabels<Label> {
    /// Returns the index range, within the flat word buffer, of the stride
    /// that belongs to `node`: `node * num_words..(node + 1) * num_words`.
    fn stride_range(&self, node: NodeId) -> Range<usize> {
        let width = self.num_words;
        let start = node * width;
        let end = (node + 1) * width;
        start..end
    }
}
/// Iteration over the strides.
///
/// Every method here splits the flat buffer into chunks of `num_words`
/// words. `slice::chunks` panics when the chunk size is zero (and rayon's
/// `par_chunks` is expected to do the same), so these methods panic on a
/// container built with `num_words == 0`.
impl<Label: StridableLabel + ?Sized> StriddenLabels<Label> {
    /// Iterates over all nodes in node order, yielding `Some(&label)` for
    /// labeled nodes and `None` for the others.
    pub fn iter(&self) -> impl Iterator<Item = Option<&Label>> {
        let flags = self.is_set.iter();
        let strides = self.labels.chunks(self.num_words);
        flags
            .zip(strides)
            .map(|(is_set, stride)| is_set.then(|| Label::from_stride(stride)))
    }

    /// Iterates over `(node, &label)` for labeled nodes only, in node order.
    pub fn iter_labeled(&self) -> impl Iterator<Item = (NodeId, &Label)> {
        let flags = self.is_set.iter();
        let strides = self.labels.chunks(self.num_words);
        flags
            .zip(strides)
            .enumerate()
            .filter_map(|(node, (is_set, stride))| {
                is_set.then(|| (node, Label::from_stride(stride)))
            })
    }

    /// Parallel counterpart of [`Self::iter`]: one item per stride,
    /// `Some(&label)` when the node is labeled and `None` otherwise.
    pub fn par_iter(&self) -> impl ParallelIterator<Item = Option<&Label>>
    where
        Label: StridableLabel<Word: Sync> + Sync,
    {
        self.labels
            .par_chunks(self.num_words)
            .enumerate()
            .map(move |(node, stride)| {
                self.is_set
                    .get(node)
                    .then(|| Label::from_stride(stride))
            })
    }

    /// Parallel counterpart of [`Self::iter_labeled`]: `(node, &label)` for
    /// labeled nodes only.
    pub fn par_iter_labeled(&self) -> impl ParallelIterator<Item = (NodeId, &Label)>
    where
        Label: StridableLabel<Word: Sync> + Sync,
    {
        self.labels
            .par_chunks(self.num_words)
            .enumerate()
            .filter_map(move |(node, stride)| {
                self.is_set
                    .get(node)
                    .then(|| (node, Label::from_stride(stride)))
            })
    }
}