#![cfg_attr(not(test), forbid(unsafe_code))]
#![warn(
clippy::cargo,
clippy::pedantic,
clippy::nursery,
clippy::missing_docs_in_private_items
)]
#![deny(missing_docs)]
#[cfg(feature = "rayon")]
use rayon_crate as rayon;
use std::cmp::Ordering;
use std::fmt::Display;
use std::hash::{Hash, Hasher};
use std::io::Cursor;
use std::iter::FromIterator;
use std::num::NonZeroUsize;
use std::ops::BitAnd;
use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
#[cfg(feature = "rayon")]
use rayon::iter::{
FromParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelExtend,
ParallelIterator,
};
#[cfg(feature = "rayon")]
use rayon::slice::ParallelSliceMut;
use dashmap::mapref::multiple::RefMulti;
use dashmap::mapref::one::RefMut;
use dashmap::{DashMap, DashSet};
use murmur3::murmur3_x64_128;
mod error;
pub use dashmap;
pub use error::*;
/// A concurrent collection of [`Node`]s that maps items to nodes by weighted
/// rendezvous (highest-random-weight) hashing.
///
/// Every node scores an item with its own seeded hash scaled by its weight; the
/// lookup methods pick the highest-scoring eligible node(s).
#[derive(Clone, Debug)]
pub struct NodeSelection<ExclusionTags: Hash + Eq, Metadata> {
/// All registered nodes, keyed by their id.
nodes: DashMap<NodeId, Node<ExclusionTags, Metadata>>,
}
impl<ExclusionTags, Metadata> NodeSelection<ExclusionTags, Metadata>
where
    ExclusionTags: Hash + Eq,
{
    /// Returns the node with the highest score for `item`.
    ///
    /// Returns `None` only when the selection is empty. The returned guard keeps
    /// its shard read-locked until dropped.
    #[inline]
    #[must_use]
    pub fn get(&self, item: &[u8]) -> Option<RefMulti<NodeId, Node<ExclusionTags, Metadata>>> {
        self.get_internal(item, None)
    }

    /// Returns the highest-scoring node for `item` among nodes that carry none
    /// of the tags in `tags`.
    ///
    /// Returns `None` when every node is excluded (or the selection is empty).
    #[inline]
    #[must_use]
    pub fn get_with_exclusions(
        &self,
        item: &[u8],
        tags: &DashSet<ExclusionTags>,
    ) -> Option<RefMulti<NodeId, Node<ExclusionTags, Metadata>>> {
        self.get_internal(item, Some(tags))
    }

    /// Shared implementation of [`Self::get`] and [`Self::get_with_exclusions`].
    fn get_internal(
        &self,
        item: &[u8],
        tags: Option<&DashSet<ExclusionTags>>,
    ) -> Option<RefMulti<NodeId, Node<ExclusionTags, Metadata>>> {
        self.nodes
            .iter()
            .filter(|entry| Self::is_allowed(entry.value(), tags))
            // Hash each node once instead of twice per comparison.
            .map(|entry| (entry.score(item), entry))
            .max_by(|(left, _), (right, _)| f64_total_ordering(*left, *right))
            .map(|(_, entry)| entry)
    }

    /// Returns up to `n` nodes for `item`, best (highest score) first.
    ///
    /// The first element, if any, is the node [`Self::get`] would return.
    #[inline]
    #[must_use]
    pub fn get_n(
        &self,
        item: &[u8],
        n: usize,
    ) -> Vec<RefMulti<NodeId, Node<ExclusionTags, Metadata>>> {
        self.get_n_internal(item, n, None)
    }

    /// Returns up to `n` nodes for `item`, best first, skipping every node that
    /// carries at least one of the tags in `tags`.
    #[inline]
    #[must_use]
    pub fn get_n_with_exclusions(
        &self,
        item: &[u8],
        n: usize,
        tags: &DashSet<ExclusionTags>,
    ) -> Vec<RefMulti<NodeId, Node<ExclusionTags, Metadata>>> {
        self.get_n_internal(item, n, Some(tags))
    }

    /// Shared implementation of [`Self::get_n`] and [`Self::get_n_with_exclusions`].
    fn get_n_internal(
        &self,
        item: &[u8],
        n: usize,
        tags: Option<&DashSet<ExclusionTags>>,
    ) -> Vec<RefMulti<NodeId, Node<ExclusionTags, Metadata>>> {
        if n == 0 {
            return Vec::new();
        }
        let mut ranked: Vec<_> = self
            .nodes
            .iter()
            .filter(|entry| Self::is_allowed(entry.value(), tags))
            .map(|entry| (entry.score(item), entry))
            .collect();
        // Descending order: the previous ascending sort + truncate returned the
        // *worst* n nodes, disagreeing with `get`.
        ranked.sort_unstable_by(|(left, _), (right, _)| f64_total_ordering(*right, *left));
        ranked.into_iter().take(n).map(|(_, entry)| entry).collect()
    }

    /// Returns `true` when `node` carries none of the tags in `tags`; always
    /// `true` when no tags are given.
    fn is_allowed(
        node: &Node<ExclusionTags, Metadata>,
        tags: Option<&DashSet<ExclusionTags>>,
    ) -> bool {
        tags.map_or(true, |excluded| {
            !node.exclusions.iter().any(|tag| excluded.contains(&*tag))
        })
    }
}

#[cfg(test)]
mod node_selection_ranking {
    use super::*;

    /// `get_n` must rank nodes best-first, so its head agrees with `get`.
    #[test]
    fn get_n_is_ranked_highest_first() {
        let selection = NodeSelection::<(), ()>::new();
        for weight in 1..=5 {
            let weight = NonZeroUsize::new(weight).expect("weights start at 1");
            let _ = selection.add(Node::with_default(weight));
        }
        for round in 0_u32..64 {
            let item = round.to_le_bytes();
            let best = selection
                .get(&item)
                .map(|entry| *entry.key())
                .expect("selection is non-empty");
            let ranked = selection.get_n(&item, 3);
            assert_eq!(ranked.len(), 3);
            assert_eq!(*ranked[0].key(), best);
            assert!(ranked
                .windows(2)
                .all(|pair| pair[0].score(&item) >= pair[1].score(&item)));
        }
    }
}
impl<ExclusionTags, Metadata> NodeSelection<ExclusionTags, Metadata>
where
    ExclusionTags: Send + Sync + Hash + Eq,
    Metadata: Send + Sync,
{
    /// Parallel version of [`Self::get`] / [`Self::get_with_exclusions`].
    ///
    /// With `tags = Some(set)`, nodes carrying any tag in `set` are skipped.
    /// Returns `None` when no node is eligible.
    #[must_use]
    #[cfg(feature = "rayon")]
    pub fn par_get(
        &self,
        item: &[u8],
        tags: Option<&DashSet<ExclusionTags>>,
    ) -> Option<RefMulti<NodeId, Node<ExclusionTags, Metadata>>> {
        self.nodes
            .par_iter()
            // A node's own exclusion set is small; scanning it sequentially avoids
            // spawning nested parallel work per node.
            .filter(|entry| {
                tags.map_or(true, |excluded| {
                    !entry
                        .value()
                        .exclusions
                        .iter()
                        .any(|tag| excluded.contains(&*tag))
                })
            })
            .map(|entry| (entry.score(item), entry))
            .max_by(|(left, _), (right, _)| f64_total_ordering(*left, *right))
            .map(|(_, entry)| entry)
    }

    /// Parallel version of [`Self::get_n`] / [`Self::get_n_with_exclusions`]:
    /// up to `n` eligible nodes for `item`, highest score first.
    #[must_use]
    #[cfg(feature = "rayon")]
    pub fn par_get_n(
        &self,
        item: &[u8],
        n: usize,
        tags: Option<&DashSet<ExclusionTags>>,
    ) -> Vec<RefMulti<NodeId, Node<ExclusionTags, Metadata>>> {
        let mut ranked: Vec<_> = self
            .nodes
            .par_iter()
            .filter(|entry| {
                tags.map_or(true, |excluded| {
                    !entry
                        .value()
                        .exclusions
                        .iter()
                        .any(|tag| excluded.contains(&*tag))
                })
            })
            // Score once per node instead of on every comparison.
            .map(|entry| (entry.score(item), entry))
            .collect();
        // Descending: the best nodes must survive the truncation below.
        ranked.par_sort_unstable_by(|(left, _), (right, _)| f64_total_ordering(*right, *left));
        ranked.truncate(n);
        ranked.into_iter().map(|(_, entry)| entry).collect()
    }
}
/// A concurrent collection of [`BitNode`]s selected by weighted rendezvous
/// hashing, where exclusions are bit flags tested with `&`.
///
/// A node is eligible when `node_exclusions & requested_tags` equals
/// `ExclusionTags::default()` (zero for the integer flag types used in tests).
#[derive(Clone, Debug)]
pub struct BitNodeSelection<ExclusionTags: Hash + Eq + BitAnd, Metadata> {
/// All registered nodes, keyed by their id.
nodes: DashMap<NodeId, BitNode<ExclusionTags, Metadata>>,
}
impl<ExclusionTags, Metadata> BitNodeSelection<ExclusionTags, Metadata>
where
    ExclusionTags: Hash + Eq + BitAnd<Output = ExclusionTags> + Copy + Default,
{
    /// Returns the node with the highest score for `item`, ignoring exclusions.
    ///
    /// Returns `None` only when the selection is empty.
    #[inline]
    #[must_use]
    pub fn get(&self, item: &[u8]) -> Option<RefMulti<NodeId, BitNode<ExclusionTags, Metadata>>> {
        self.get_with_exclusions(item, ExclusionTags::default())
    }

    /// Returns the highest-scoring node for `item` whose exclusion bits do not
    /// overlap `tags`.
    #[inline]
    #[must_use]
    pub fn get_with_exclusions(
        &self,
        item: &[u8],
        tags: ExclusionTags,
    ) -> Option<RefMulti<NodeId, BitNode<ExclusionTags, Metadata>>> {
        self.nodes
            .iter()
            .filter(|entry| Self::is_allowed(entry.value(), tags))
            // Hash each node once instead of twice per comparison.
            .map(|entry| (entry.score(item), entry))
            .max_by(|(left, _), (right, _)| f64_total_ordering(*left, *right))
            .map(|(_, entry)| entry)
    }

    /// Returns up to `n` nodes for `item`, best (highest score) first.
    ///
    /// The first element, if any, is the node [`Self::get`] would return.
    #[inline]
    #[must_use]
    pub fn get_n(
        &self,
        item: &[u8],
        n: usize,
    ) -> Vec<RefMulti<NodeId, BitNode<ExclusionTags, Metadata>>> {
        self.get_n_with_exclusions(item, n, ExclusionTags::default())
    }

    /// Returns up to `n` nodes for `item`, best first, skipping nodes whose
    /// exclusion bits overlap `tags`.
    #[inline]
    #[must_use]
    pub fn get_n_with_exclusions(
        &self,
        item: &[u8],
        n: usize,
        tags: ExclusionTags,
    ) -> Vec<RefMulti<NodeId, BitNode<ExclusionTags, Metadata>>> {
        if n == 0 {
            return Vec::new();
        }
        let mut ranked: Vec<_> = self
            .nodes
            .iter()
            .filter(|entry| Self::is_allowed(entry.value(), tags))
            .map(|entry| (entry.score(item), entry))
            .collect();
        // Descending order: the previous ascending sort + truncate returned the
        // *worst* n nodes, disagreeing with `get`.
        ranked.sort_unstable_by(|(left, _), (right, _)| f64_total_ordering(*right, *left));
        ranked.into_iter().take(n).map(|(_, entry)| entry).collect()
    }

    /// Returns `true` when `node`'s exclusion bits and `tags` share no set bit.
    fn is_allowed(node: &BitNode<ExclusionTags, Metadata>, tags: ExclusionTags) -> bool {
        node.exclusions & tags == ExclusionTags::default()
    }
}

#[cfg(test)]
mod bit_node_selection_ranking {
    use super::*;

    /// `get_n` must rank nodes best-first, so its head agrees with `get`.
    #[test]
    fn get_n_is_ranked_highest_first() {
        let selection = BitNodeSelection::<u8, ()>::new();
        for weight in 1..=5 {
            let weight = NonZeroUsize::new(weight).expect("weights start at 1");
            let _ = selection.add(BitNode::with_default(weight));
        }
        for round in 0_u32..64 {
            let item = round.to_le_bytes();
            let best = selection
                .get(&item)
                .map(|entry| *entry.key())
                .expect("selection is non-empty");
            let ranked = selection.get_n(&item, 3);
            assert_eq!(ranked.len(), 3);
            assert_eq!(*ranked[0].key(), best);
            assert!(ranked
                .windows(2)
                .all(|pair| pair[0].score(&item) >= pair[1].score(&item)));
        }
    }
}
impl<ExclusionTags, Metadata> BitNodeSelection<ExclusionTags, Metadata>
where
    ExclusionTags: Send + Sync + Hash + Eq + BitAnd<Output = ExclusionTags> + Default + Copy,
    Metadata: Send + Sync,
{
    /// Parallel version of [`Self::get_with_exclusions`]: the highest-scoring
    /// node for `item` whose exclusion bits do not overlap `tags`.
    #[must_use]
    #[cfg(feature = "rayon")]
    pub fn par_get(
        &self,
        item: &[u8],
        tags: ExclusionTags,
    ) -> Option<RefMulti<NodeId, BitNode<ExclusionTags, Metadata>>> {
        self.nodes
            .par_iter()
            .filter(|entry| entry.value().exclusions & tags == ExclusionTags::default())
            .map(|entry| (entry.score(item), entry))
            .max_by(|(left, _), (right, _)| f64_total_ordering(*left, *right))
            .map(|(_, entry)| entry)
    }

    /// Parallel version of [`Self::get_n_with_exclusions`]: up to `n` eligible
    /// nodes for `item`, highest score first.
    #[must_use]
    #[cfg(feature = "rayon")]
    pub fn par_get_n(
        &self,
        item: &[u8],
        n: usize,
        tags: ExclusionTags,
    ) -> Vec<RefMulti<NodeId, BitNode<ExclusionTags, Metadata>>> {
        let mut ranked: Vec<_> = self
            .nodes
            .par_iter()
            .filter(|entry| entry.value().exclusions & tags == ExclusionTags::default())
            // Score once per node instead of on every comparison.
            .map(|entry| (entry.score(item), entry))
            .collect();
        // Descending: the best nodes must survive the truncation below.
        ranked.par_sort_unstable_by(|(left, _), (right, _)| f64_total_ordering(*right, *left));
        ranked.truncate(n);
        ranked.into_iter().map(|(_, entry)| entry).collect()
    }
}
macro_rules! impl_node_selection {
($struct_name:ident, $node:ident $(: $bounds:path )?) => {
impl<ExclusionTags, Metadata> $struct_name<ExclusionTags, Metadata>
where
ExclusionTags: Hash + Eq + $($bounds)*,
{
#[inline]
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[inline]
#[must_use]
pub fn with_capacity(capacity: usize) -> Self {
Self {
nodes: DashMap::with_capacity(capacity),
}
}
#[inline]
#[must_use]
pub fn add(&self, node: $node<ExclusionTags, Metadata>) -> NodeId {
let id = NodeId::new_opaque();
self.add_with_id(id, node);
id
}
#[inline]
pub fn add_with_id(&self, id: NodeId, node: $node<ExclusionTags, Metadata>) {
if self.nodes.insert(id, node).is_some() {
panic!("Node with duplicate id added: {}", id);
}
}
#[inline]
pub fn try_add(
&self,
node: $node<ExclusionTags, Metadata>,
) -> Result<NodeId, DuplicateIdError> {
let id = NodeId::new_opaque();
self.try_add_with_id(id, node)?;
Ok(id)
}
#[inline]
pub fn try_add_with_id(
&self,
id: NodeId,
node: $node<ExclusionTags, Metadata>,
) -> Result<(), DuplicateIdError> {
if self.nodes.contains_key(&id) {
Err(DuplicateIdError(id))
} else {
self.nodes.insert(id, node);
Ok(())
}
}
#[inline]
#[must_use]
pub fn get_mut(
&self,
id: NodeId,
) -> Option<RefMut<NodeId, $node<ExclusionTags, Metadata>>> {
self.nodes.get_mut(&id)
}
#[inline]
#[must_use]
pub fn remove(&self, id: NodeId) -> Option<$node<ExclusionTags, Metadata>> {
self.nodes.remove(&id).map(|(_id, node)| node)
}
#[inline]
#[must_use]
pub fn contains(&self, id: NodeId) -> bool {
self.nodes.contains_key(&id)
}
}
impl<ExclusionTags, Metadata> Default for $struct_name<ExclusionTags, Metadata>
where
ExclusionTags: Hash + Eq + $($bounds)*,
{
#[inline]
fn default() -> Self {
Self {
nodes: DashMap::new(),
}
}
}
impl<ExclusionTags, Metadata> Extend<(NodeId, $node<ExclusionTags, Metadata>)>
for $struct_name<ExclusionTags, Metadata>
where
ExclusionTags: Hash + Eq + $($bounds)*,
{
#[inline]
fn extend<T: IntoIterator<Item = (NodeId, $node<ExclusionTags, Metadata>)>>(
&mut self,
iter: T,
) {
self.nodes.extend(iter);
}
}
impl<ExclusionTags, Metadata> FromIterator<(NodeId, $node<ExclusionTags, Metadata>)>
for $struct_name<ExclusionTags, Metadata>
where
ExclusionTags: Hash + Eq + $($bounds)*,
{
#[inline]
fn from_iter<T: IntoIterator<Item = (NodeId, $node<ExclusionTags, Metadata>)>>(
iter: T,
) -> Self {
Self {
nodes: DashMap::from_iter(iter),
}
}
}
#[cfg(feature = "rayon")]
impl<ExclusionTags, Metadata> FromParallelIterator<(NodeId, $node<ExclusionTags, Metadata>)>
for $struct_name<ExclusionTags, Metadata>
where
ExclusionTags: Send + Sync + Hash + Eq + $($bounds)*,
Metadata: Send + Sync,
{
#[inline]
fn from_par_iter<I>(into_iter: I) -> Self
where
I: IntoParallelIterator<Item = (NodeId, $node<ExclusionTags, Metadata>)>,
{
Self {
nodes: DashMap::from_par_iter(into_iter),
}
}
}
impl<ExclusionTags, Metadata> IntoIterator for $struct_name<ExclusionTags, Metadata>
where
ExclusionTags: Hash + Eq + $($bounds)*,
{
type Item = (NodeId, $node<ExclusionTags, Metadata>);
type IntoIter =
<DashMap<NodeId, $node<ExclusionTags, Metadata>> as IntoIterator>::IntoIter;
#[inline]
fn into_iter(self) -> Self::IntoIter {
self.nodes.into_iter()
}
}
#[cfg(feature = "rayon")]
impl<ExclusionTags, Metadata> IntoParallelIterator for $struct_name<ExclusionTags, Metadata>
where
ExclusionTags: Send + Sync + Hash + Eq + $($bounds)*,
Metadata: Send + Sync,
{
type Item = <Self as IntoIterator>::Item;
type Iter =
<DashMap<NodeId, $node<ExclusionTags, Metadata>> as IntoParallelIterator>::Iter;
#[inline]
fn into_par_iter(self) -> <Self as IntoParallelIterator>::Iter {
self.nodes.into_par_iter()
}
}
#[cfg(feature = "rayon")]
impl<ExclusionTags, Metadata> ParallelExtend<(NodeId, $node<ExclusionTags, Metadata>)>
for $struct_name<ExclusionTags, Metadata>
where
ExclusionTags: Send + Sync + Hash + Eq + $($bounds)*,
Metadata: Send + Sync,
{
#[inline]
fn par_extend<I>(&mut self, extendable: I)
where
I: IntoParallelIterator<Item = (NodeId, $node<ExclusionTags, Metadata>)>,
{
self.nodes.par_extend(extendable);
}
}
};
}
impl_node_selection!(NodeSelection, Node);
impl_node_selection!(BitNodeSelection, BitNode: BitAnd);
/// Generates the constructors, scoring routine and comparison traits shared by
/// [`Node`] and [`BitNode`].
///
/// `$excludes_type` is the type of the node's `exclusions` field; the optional
/// `$bounds` is an extra trait bound required of `ExclusionTags`.
macro_rules! impl_node {
    ($struct_name:ident, $excludes_type:ty $(: $bounds:path )?) => {
        impl<ExclusionTags, Metadata> $struct_name<ExclusionTags, Metadata>
        where
            ExclusionTags: Hash + Eq + $($bounds)*,
        {
            /// Creates a node with a randomly drawn seed, the given weight and
            /// metadata, and the given exclusion tags.
            #[inline]
            #[must_use]
            pub fn with_exclusions(
                weight: NonZeroUsize,
                metadata: Metadata,
                exclusions: $excludes_type,
            ) -> Self {
                let seed = rand::random();
                Self {
                    weight,
                    seed,
                    exclusions,
                    metadata,
                }
            }

            /// Creates a node whose seed is drawn from `rng`, so a deterministic
            /// generator yields deterministic seeds.
            #[inline]
            #[must_use]
            pub fn from_parts<Rng: rand::Rng>(
                rng: &mut Rng,
                weight: NonZeroUsize,
                exclusions: $excludes_type,
                metadata: Metadata,
            ) -> Self {
                let seed = rng.gen();
                Self {
                    weight,
                    seed,
                    exclusions,
                    metadata,
                }
            }

            /// Weighted rendezvous score of this node for `item`:
            /// `weight * (1 / -ln(u))`, where `u` in `[0, 1)` comes from a murmur3
            /// hash of `item` seeded with this node's seed.
            fn score(&self, item: &[u8]) -> f64 {
                // Reading from an in-memory cursor cannot produce an I/O error.
                let digest = murmur3_x64_128(&mut Cursor::new(item), self.seed).unwrap();
                // Only the low 64 bits of the 128-bit digest feed the score.
                #[allow(clippy::cast_possible_truncation)]
                let uniform = Hash64(digest as u64).as_normalized_float();
                let inverse_log = 1.0 / -uniform.ln();
                #[allow(clippy::cast_precision_loss)]
                let weight = self.weight.get() as f64;
                weight * inverse_log
            }

            /// Replaces the node's weight.
            #[inline]
            pub fn set_weight(&mut self, weight: NonZeroUsize) {
                self.weight = weight;
            }

            /// Borrows the caller-supplied metadata.
            #[inline]
            pub const fn data(&self) -> &Metadata {
                &self.metadata
            }
        }

        // Exclusions are deliberately not part of identity: Hash, Eq and the
        // orderings below only look at weight, seed and metadata.
        impl<ExclusionTags, Metadata> Hash for $struct_name<ExclusionTags, Metadata>
        where
            ExclusionTags: Hash + Eq + $($bounds)*,
            Metadata: Hash,
        {
            #[inline]
            fn hash<H: Hasher>(&self, state: &mut H) {
                (&self.weight, &self.seed, &self.metadata).hash(state);
            }
        }

        impl<ExclusionTags, Metadata> PartialEq for $struct_name<ExclusionTags, Metadata>
        where
            ExclusionTags: Hash + Eq + $($bounds)*,
            Metadata: PartialEq,
        {
            #[inline]
            fn eq(&self, other: &Self) -> bool {
                (self.weight, self.seed) == (other.weight, other.seed)
                    && self.metadata == other.metadata
            }
        }

        impl<ExclusionTags, Metadata> Eq for $struct_name<ExclusionTags, Metadata>
        where
            ExclusionTags: Hash + Eq + $($bounds)*,
            Metadata: Eq,
        {
        }

        impl<ExclusionTags, Metadata> PartialOrd for $struct_name<ExclusionTags, Metadata>
        where
            ExclusionTags: Hash + Eq + $($bounds)*,
            Metadata: PartialOrd,
        {
            /// Orders by metadata, then weight, then seed.
            #[inline]
            fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
                // NOTE(review): incomparable metadata (`None`) falls through to the
                // weight/seed tie-break, so this can report `Some(Equal)` for values
                // that `eq` says differ (e.g. NaN metadata) — confirm this is intended.
                match self.metadata.partial_cmp(&other.metadata) {
                    Some(Ordering::Equal) | None => Some(
                        self.weight
                            .cmp(&other.weight)
                            .then_with(|| self.seed.cmp(&other.seed)),
                    ),
                    decided => decided,
                }
            }
        }

        impl<ExclusionTags, Metadata> Ord for $struct_name<ExclusionTags, Metadata>
        where
            ExclusionTags: Hash + Eq + $($bounds)*,
            Metadata: Ord,
        {
            /// Orders by metadata, then weight, then seed.
            #[inline]
            fn cmp(&self, other: &Self) -> Ordering {
                self.metadata
                    .cmp(&other.metadata)
                    .then_with(|| self.weight.cmp(&other.weight))
                    .then_with(|| self.seed.cmp(&other.seed))
            }
        }
    };
}
#[derive(Clone, Debug)]
pub struct Node<ExclusionTags: Hash + Eq, Metadata> {
weight: NonZeroUsize,
seed: u32,
exclusions: DashSet<ExclusionTags>,
metadata: Metadata,
}
impl<ExclusionTags, Metadata> Node<ExclusionTags, Metadata>
where
ExclusionTags: Hash + Eq,
{
#[inline]
#[must_use]
pub fn new(weight: NonZeroUsize, metadata: Metadata) -> Self {
Self {
seed: rand::random(),
weight,
exclusions: DashSet::new(),
metadata,
}
}
}
impl<ExclusionTags, Metadata> Node<ExclusionTags, Metadata>
where
ExclusionTags: Hash + Eq,
Metadata: Default,
{
#[inline]
#[must_use]
pub fn with_default(weight: NonZeroUsize) -> Self {
Self::new(weight, Metadata::default())
}
}
impl_node!(Node, DashSet<ExclusionTags>);
#[derive(Clone, Debug)]
pub struct BitNode<ExclusionTags: Hash + Eq + BitAnd, Metadata> {
weight: NonZeroUsize,
seed: u32,
exclusions: ExclusionTags,
metadata: Metadata,
}
impl<ExclusionTags, Metadata> BitNode<ExclusionTags, Metadata>
where
ExclusionTags: Hash + Eq + BitAnd + Default,
{
#[inline]
#[must_use]
pub fn new(weight: NonZeroUsize, metadata: Metadata) -> Self {
Self {
seed: rand::random(),
weight,
exclusions: Default::default(),
metadata,
}
}
}
impl<ExclusionTags, Metadata> BitNode<ExclusionTags, Metadata>
where
ExclusionTags: Hash + Eq + BitAnd + Default,
Metadata: Default,
{
#[inline]
#[must_use]
pub fn with_default(weight: NonZeroUsize) -> Self {
Self::new(weight, Metadata::default())
}
}
impl_node!(BitNode, ExclusionTags: BitAnd);
/// Newtype over a 64-bit hash value used to derive a uniform float.
struct Hash64(u64);

impl Hash64 {
    /// Maps the low 53 bits of the hash onto `[0, 1)` as `bits / 2^53`.
    ///
    /// Both operands are below 2^53 (exactly representable in an `f64`) and the
    /// divisor is a power of two, so the result is exact.
    fn as_normalized_float(&self) -> f64 {
        let precision = f64::MANTISSA_DIGITS; // 53 significant bits
        let scale = 1_u64 << precision;
        #[allow(clippy::cast_precision_loss)]
        let numerator = (self.0 & (scale - 1)) as f64;
        #[allow(clippy::cast_precision_loss)]
        let denominator = scale as f64;
        numerator / denominator
    }
}
/// Monotonic source of the ids handed out by [`NodeId::new_opaque`].
static NODE_ID_COUNTER: AtomicUsize = AtomicUsize::new(0);

/// Identifier of a node inside a selection.
#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct NodeId(usize);

impl NodeId {
    /// Wraps a caller-chosen raw id.
    #[inline]
    #[must_use]
    pub const fn new(id: usize) -> Self {
        Self(id)
    }

    /// Returns the next id from a global counter that starts at 0.
    ///
    /// Ids are unique among `new_opaque` calls (until the counter wraps), but may
    /// coincide with ids built by hand through [`NodeId::new`].
    #[inline]
    #[must_use]
    pub fn new_opaque() -> Self {
        // Relaxed is enough: `fetch_add` is a single atomic read-modify-write, so
        // every caller observes a distinct value regardless of memory ordering.
        let next = NODE_ID_COUNTER.fetch_add(1, AtomicOrdering::Relaxed);
        Self::new(next)
    }
}

impl Display for NodeId {
    /// Writes the raw numeric id.
    #[inline]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let NodeId(raw) = self;
        write!(f, "{}", raw)
    }
}
#[cfg(test)]
mod hash64 {
    use crate::Hash64;

    /// Known reference values.
    #[test]
    fn normalized_float() {
        assert_eq!(
            Hash64(0x1234567890abcdef).as_normalized_float(),
            0.6355555368049276
        );
        assert_eq!(
            Hash64(0xffffffffffffff).as_normalized_float(),
            0.9999999999999999
        );
    }

    /// Endpoints of the 53-bit range and masking of the high bits.
    #[test]
    fn normalized_float_boundaries() {
        assert_eq!(Hash64(0).as_normalized_float(), 0.0);
        // Only bit 52 set: exactly one half.
        assert_eq!(Hash64(1 << 52).as_normalized_float(), 0.5);
        // Bits at or above position 53 do not contribute.
        assert_eq!(Hash64(1 << 53).as_normalized_float(), 0.0);
        // Largest possible result is the float just below 1.0.
        assert_eq!(
            Hash64(u64::MAX).as_normalized_float(),
            1.0 - f64::EPSILON / 2.0
        );
    }
}
/// Compares two `f64`s by the IEEE 754 `totalOrder` predicate (the same order as
/// `f64::total_cmp`): `-NaN < -inf < … < -0.0 < +0.0 < … < +inf < +NaN`.
#[allow(clippy::cast_sign_loss, clippy::cast_possible_wrap)]
fn f64_total_ordering(left: f64, right: f64) -> Ordering {
    /// Maps a float's bit pattern to an `i64` whose natural order is `totalOrder`.
    #[allow(clippy::cast_sign_loss, clippy::cast_possible_wrap)]
    fn ordered_key(value: f64) -> i64 {
        let raw = value.to_bits() as i64;
        // All ones for negative values, zero otherwise; shifted right once so the
        // sign bit itself is preserved. Flipping the magnitude bits of negatives
        // makes larger magnitudes sort lower.
        let magnitude_flip = (((raw >> 63) as u64) >> 1) as i64;
        raw ^ magnitude_flip
    }

    ordered_key(left).cmp(&ordered_key(right))
}
#[cfg(test)]
mod node_selection {
    use super::*;

    /// Adding a second node under an id that is already taken must panic.
    #[test]
    #[should_panic(expected = "Node with duplicate id added: 0")]
    fn duplicate_id_will_panic() {
        let selector = NodeSelection::<(), ()>::new();
        // Safe constructor: no `unsafe` needed for a literal non-zero weight.
        let weight = NonZeroUsize::new(1).expect("1 is non-zero");
        selector.add_with_id(NodeId::new(0), Node::with_default(weight));
        selector.add_with_id(NodeId::new(0), Node::with_default(weight));
    }

    /// `try_add_with_id` reports a duplicate id as an error and keeps the first node.
    #[test]
    fn duplicate_id_is_rejected_by_try_add() {
        let selector = NodeSelection::<(), u8>::new();
        let weight = NonZeroUsize::new(1).expect("1 is non-zero");
        assert!(selector
            .try_add_with_id(NodeId::new(0), Node::new(weight, 1))
            .is_ok());
        assert!(selector
            .try_add_with_id(NodeId::new(0), Node::new(weight, 2))
            .is_err());
        assert_eq!(
            *selector
                .get_mut(NodeId::new(0))
                .expect("first node is still present")
                .data(),
            1
        );
    }
}
#[cfg(test)]
mod node_selection_no_exclusions {
    use std::collections::BTreeMap;

    use super::*;

    /// Safe replacement for `NonZeroUsize::new_unchecked` in these tests.
    fn weight(value: usize) -> NonZeroUsize {
        NonZeroUsize::new(value).expect("test weights are non-zero")
    }

    /// Selection frequency should track weight (100:200:300 over 600 draws).
    #[test]
    fn sanity_check_weighted() {
        let node_selection = NodeSelection::<(), ()>::new();
        let mut nodes: BTreeMap<_, usize> = {
            let mut map = BTreeMap::new();
            map.insert(Node::with_default(weight(100)), 0);
            map.insert(Node::with_default(weight(200)), 0);
            map.insert(Node::with_default(weight(300)), 0);
            map
        };
        for node in nodes.keys() {
            let _ = node_selection.add(node.clone());
        }
        for _ in 0..600 {
            let node_selected = node_selection.get(&(rand::random::<f64>()).to_le_bytes());
            if let Some(node) = node_selected {
                let node = nodes.get_mut(node.value()).unwrap();
                *node += 1;
            }
        }
        for (node, counts) in nodes {
            let anchor = node.weight.get();
            let range = (anchor - 50)..(anchor + 50);
            assert!(
                range.contains(&counts),
                "expected about {} selections, got {}",
                anchor,
                counts
            );
        }
    }

    /// Equal weights should split 600 draws roughly evenly.
    #[test]
    fn sanity_check_unweighted() {
        let node_selection = NodeSelection::<(), ()>::new();
        let mut nodes: BTreeMap<_, usize> = {
            let unit = weight(1);
            let mut map = BTreeMap::new();
            map.insert(Node::with_default(unit), 0);
            map.insert(Node::with_default(unit), 0);
            map.insert(Node::with_default(unit), 0);
            map
        };
        for node in nodes.keys() {
            let _ = node_selection.add(node.clone());
        }
        for _ in 0..600 {
            let node_selected = node_selection.get(&(rand::random::<f64>()).to_le_bytes());
            if let Some(node) = node_selected {
                let node = nodes.get_mut(node.value()).unwrap();
                *node += 1;
            }
        }
        let range = (200 - 50)..(200 + 50);
        for counts in nodes.values() {
            assert!(range.contains(counts), "got {} selections", counts);
        }
    }

    /// Picking 2 of 3 equal-weight nodes should include each about 400 times.
    #[test]
    fn sanity_check_unweighted_n() {
        let node_selection = NodeSelection::<(), ()>::new();
        let mut nodes: BTreeMap<_, usize> = {
            let unit = weight(1);
            let mut map = BTreeMap::new();
            map.insert(Node::with_default(unit), 0);
            map.insert(Node::with_default(unit), 0);
            map.insert(Node::with_default(unit), 0);
            map
        };
        for node in nodes.keys() {
            let _ = node_selection.add(node.clone());
        }
        for _ in 0..600 {
            let nodes_selected = node_selection.get_n(&(rand::random::<f64>()).to_le_bytes(), 2);
            for node in nodes_selected {
                let node = nodes.get_mut(node.value()).unwrap();
                *node += 1;
            }
        }
        let range = (400 - 30)..(400 + 30);
        for counts in nodes.values() {
            assert!(range.contains(counts), "got {} selections", counts);
        }
    }
}
#[cfg(test)]
mod node_selection_exclusions {
    use std::collections::BTreeMap;
    use std::iter::FromIterator;

    use super::*;

    #[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Debug)]
    enum Exclusions {
        A,
    }

    /// Safe replacement for `NonZeroUsize::new_unchecked` in these tests.
    fn weight(value: usize) -> NonZeroUsize {
        NonZeroUsize::new(value).expect("test weights are non-zero")
    }

    /// With the weight-200 node excluded, 600 draws split 150:450 between the
    /// remaining weight-100 and weight-300 nodes.
    #[test]
    fn sanity_check_weighted() {
        let node_selection = NodeSelection::<Exclusions, ()>::new();
        let exclusions = DashSet::from_iter([Exclusions::A]);
        let mut nodes: BTreeMap<_, usize> = {
            let mut map = BTreeMap::new();
            map.insert(Node::with_default(weight(100)), 0);
            map.insert(
                Node::with_exclusions(weight(200), (), exclusions.clone()),
                0,
            );
            map.insert(Node::with_default(weight(300)), 0);
            map
        };
        for node in nodes.keys() {
            let _ = node_selection.add(node.clone());
        }
        for _ in 0..600 {
            let node_selected = node_selection
                .get_with_exclusions(&(rand::random::<f64>()).to_le_bytes(), &exclusions);
            if let Some(node) = node_selected {
                let node = nodes.get_mut(node.value()).unwrap();
                *node += 1;
            }
        }
        for (node, counts) in nodes {
            let anchor = node.weight.get() * 3 / 2;
            if node.exclusions.is_empty() {
                let range = (anchor - 50)..(anchor + 50);
                assert!(
                    range.contains(&counts),
                    "expected about {} selections, got {}",
                    anchor,
                    counts
                );
            } else {
                assert_eq!(counts, 0);
            }
        }
    }

    /// Two eligible equal-weight nodes should split 600 draws evenly.
    #[test]
    fn sanity_check_unweighted() {
        let node_selection = NodeSelection::<Exclusions, ()>::new();
        let exclusions = DashSet::from_iter([Exclusions::A]);
        let mut nodes: BTreeMap<_, usize> = {
            let mut map = BTreeMap::new();
            let unit = weight(1);
            map.insert(Node::with_default(unit), 0);
            map.insert(Node::with_exclusions(unit, (), exclusions.clone()), 0);
            map.insert(Node::with_default(unit), 0);
            map
        };
        for node in nodes.keys() {
            let _ = node_selection.add(node.clone());
        }
        for _ in 0..600 {
            let node_selected = node_selection
                .get_with_exclusions(&(rand::random::<f64>()).to_le_bytes(), &exclusions);
            if let Some(node) = node_selected {
                let node = nodes.get_mut(node.value()).unwrap();
                *node += 1;
            }
        }
        let range = (300 - 50)..(300 + 50);
        for (node, counts) in nodes {
            if node.exclusions.is_empty() {
                assert!(range.contains(&counts), "got {} selections", counts);
            } else {
                assert_eq!(counts, 0);
            }
        }
    }

    /// Picking 2 with one of 3 nodes excluded must return both eligible nodes
    /// on every draw.
    #[test]
    fn sanity_check_unweighted_n() {
        let node_selection = NodeSelection::<Exclusions, ()>::new();
        let exclusions = DashSet::from_iter([Exclusions::A]);
        let mut nodes: BTreeMap<_, usize> = {
            let mut map = BTreeMap::new();
            let unit = weight(1);
            map.insert(Node::with_default(unit), 0);
            map.insert(Node::with_exclusions(unit, (), exclusions.clone()), 0);
            map.insert(Node::with_default(unit), 0);
            map
        };
        for node in nodes.keys() {
            let _ = node_selection.add(node.clone());
        }
        for _ in 0..600 {
            let node_selected = node_selection.get_n_with_exclusions(
                &(rand::random::<f64>()).to_le_bytes(),
                2,
                &exclusions,
            );
            for node in node_selected {
                let node = nodes.get_mut(node.value()).unwrap();
                *node += 1;
            }
        }
        let range = (600 - 50)..(600 + 50);
        for (node, counts) in nodes {
            if node.exclusions.is_empty() {
                assert!(range.contains(&counts), "got {} selections", counts);
            } else {
                assert_eq!(counts, 0);
            }
        }
    }
}
#[cfg(test)]
mod bit_node_selection {
    use std::collections::BTreeMap;

    use super::*;

    /// Safe replacement for `NonZeroUsize::new_unchecked` in these tests.
    fn weight(value: usize) -> NonZeroUsize {
        NonZeroUsize::new(value).expect("test weights are non-zero")
    }

    /// The node flagged `0b01` is never chosen; the rest split by weight.
    #[test]
    fn exclusions() {
        let node_selection = BitNodeSelection::<u8, ()>::new();
        let exclusions = 0b01;
        let mut nodes: BTreeMap<_, usize> = {
            let mut map = BTreeMap::new();
            map.insert(BitNode::with_default(weight(100)), 0);
            map.insert(BitNode::with_exclusions(weight(200), (), exclusions), 0);
            map.insert(BitNode::with_default(weight(300)), 0);
            map
        };
        for node in nodes.keys() {
            let _ = node_selection.add(node.clone());
        }
        for _ in 0..600 {
            let node_selected = node_selection
                .get_with_exclusions(&(rand::random::<f64>()).to_le_bytes(), exclusions);
            if let Some(node) = node_selected {
                let node = nodes.get_mut(node.value()).unwrap();
                *node += 1;
            }
        }
        for (node, counts) in nodes {
            let anchor = node.weight.get() * 3 / 2;
            if node.exclusions == 0 {
                let range = (anchor - 50)..(anchor + 50);
                assert!(
                    range.contains(&counts),
                    "expected about {} selections, got {}",
                    anchor,
                    counts
                );
            } else {
                assert_eq!(counts, 0);
            }
        }
    }

    /// Without exclusions, selection frequency tracks weight.
    #[test]
    fn no_exclusions() {
        let node_selection = BitNodeSelection::<u8, ()>::new();
        let mut nodes: BTreeMap<_, usize> = {
            let mut map = BTreeMap::new();
            map.insert(BitNode::with_default(weight(100)), 0);
            map.insert(BitNode::with_default(weight(200)), 0);
            map.insert(BitNode::with_default(weight(300)), 0);
            map
        };
        for node in nodes.keys() {
            let _ = node_selection.add(node.clone());
        }
        for _ in 0..600 {
            let node_selected = node_selection.get(&(rand::random::<f64>()).to_le_bytes());
            if let Some(node) = node_selected {
                let node = nodes.get_mut(node.value()).unwrap();
                *node += 1;
            }
        }
        for (node, counts) in nodes {
            let anchor = node.weight.get();
            let range = (anchor - 50)..(anchor + 50);
            assert!(
                range.contains(&counts),
                "expected about {} selections, got {}",
                anchor,
                counts
            );
        }
    }

    /// Picking 2 of 3 equal-weight nodes should include each about 400 times.
    #[test]
    fn get_n() {
        let node_selection = BitNodeSelection::<u8, ()>::new();
        let mut nodes: BTreeMap<_, usize> = {
            let mut map = BTreeMap::new();
            map.insert(BitNode::with_default(weight(200)), 0);
            map.insert(BitNode::with_default(weight(200)), 0);
            map.insert(BitNode::with_default(weight(200)), 0);
            map
        };
        for node in nodes.keys() {
            let _ = node_selection.add(node.clone());
        }
        for _ in 0..600 {
            let node_selected = node_selection.get_n(&(rand::random::<f64>()).to_le_bytes(), 2);
            for node in node_selected {
                let node = nodes.get_mut(node.value()).unwrap();
                *node += 1;
            }
        }
        for (node, counts) in nodes {
            let anchor = node.weight.get() * 2;
            let range = (anchor - 50)..(anchor + 50);
            assert!(
                range.contains(&counts),
                "expected about {} selections, got {}",
                anchor,
                counts
            );
        }
    }
}
#[cfg(all(test, feature = "rayon"))]
mod par_tests {
    use std::collections::BTreeMap;

    use super::*;

    /// Safe replacement for `NonZeroUsize::new_unchecked` in these tests.
    fn weight(value: usize) -> NonZeroUsize {
        NonZeroUsize::new(value).expect("test weights are non-zero")
    }

    /// `par_get` never returns the excluded node; the rest split by weight.
    #[test]
    fn bit_node_selection_get() {
        let node_selection = BitNodeSelection::<u8, ()>::new();
        let exclusions = 0b01;
        let mut nodes: BTreeMap<_, usize> = {
            let mut map = BTreeMap::new();
            map.insert(BitNode::with_default(weight(100)), 0);
            map.insert(BitNode::with_exclusions(weight(200), (), exclusions), 0);
            map.insert(BitNode::with_default(weight(300)), 0);
            map
        };
        for node in nodes.keys() {
            let _ = node_selection.add(node.clone());
        }
        for _ in 0..600 {
            let node_selected =
                node_selection.par_get(&(rand::random::<f64>()).to_le_bytes(), exclusions);
            if let Some(node) = node_selected {
                let node = nodes.get_mut(node.value()).unwrap();
                *node += 1;
            }
        }
        for (node, counts) in nodes {
            let anchor = node.weight.get() * 3 / 2;
            if node.exclusions == 0 {
                let range = (anchor - 50)..(anchor + 50);
                assert!(
                    range.contains(&counts),
                    "expected about {} selections, got {}",
                    anchor,
                    counts
                );
            } else {
                assert_eq!(counts, 0);
            }
        }
    }

    /// With one of three nodes excluded, `par_get_n(.., 2, ..)` returns both
    /// eligible nodes on every draw.
    #[test]
    fn bit_node_selection_get_n() {
        let node_selection = BitNodeSelection::<u8, ()>::new();
        let exclusions = 0b01;
        let mut nodes: BTreeMap<_, usize> = {
            let mut map = BTreeMap::new();
            map.insert(BitNode::with_default(weight(100)), 0);
            map.insert(BitNode::with_exclusions(weight(200), (), exclusions), 0);
            map.insert(BitNode::with_default(weight(300)), 0);
            map
        };
        for node in nodes.keys() {
            let _ = node_selection.add(node.clone());
        }
        for _ in 0..600 {
            let node_selected =
                node_selection.par_get_n(&(rand::random::<f64>()).to_le_bytes(), 2, exclusions);
            for node in node_selected {
                let node = nodes.get_mut(node.value()).unwrap();
                *node += 1;
            }
        }
        for (node, counts) in nodes {
            let anchor: usize = 600;
            if node.exclusions == 0 {
                let range = (anchor - 50)..(anchor + 50);
                assert!(range.contains(&counts), "got {} selections", counts);
            } else {
                assert_eq!(counts, 0);
            }
        }
    }

    #[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Debug)]
    enum Exclusions {
        A,
    }

    /// `par_get` with a tag set skips the tagged node and splits evenly otherwise.
    #[test]
    fn node_selection_get() {
        let node_selection = NodeSelection::<Exclusions, ()>::new();
        let exclusions = DashSet::from_iter([Exclusions::A]);
        let mut nodes: BTreeMap<_, usize> = {
            let mut map = BTreeMap::new();
            let unit = weight(1);
            map.insert(Node::with_default(unit), 0);
            map.insert(Node::with_exclusions(unit, (), exclusions.clone()), 0);
            map.insert(Node::with_default(unit), 0);
            map
        };
        for node in nodes.keys() {
            let _ = node_selection.add(node.clone());
        }
        for _ in 0..600 {
            let node_selected =
                node_selection.par_get(&(rand::random::<f64>()).to_le_bytes(), Some(&exclusions));
            if let Some(node) = node_selected {
                let node = nodes.get_mut(node.value()).unwrap();
                *node += 1;
            }
        }
        let range = (300 - 50)..(300 + 50);
        for (node, counts) in nodes {
            if node.exclusions.is_empty() {
                assert!(range.contains(&counts), "got {} selections", counts);
            } else {
                assert_eq!(counts, 0);
            }
        }
    }

    /// `par_get_n(.., 2, ..)` with one tagged node returns both others every draw.
    #[test]
    fn node_selection_get_n() {
        let node_selection = NodeSelection::<Exclusions, ()>::new();
        let exclusions = DashSet::from_iter([Exclusions::A]);
        let mut nodes: BTreeMap<_, usize> = {
            let mut map = BTreeMap::new();
            let unit = weight(1);
            map.insert(Node::with_default(unit), 0);
            map.insert(Node::with_exclusions(unit, (), exclusions.clone()), 0);
            map.insert(Node::with_default(unit), 0);
            map
        };
        for node in nodes.keys() {
            let _ = node_selection.add(node.clone());
        }
        for _ in 0..600 {
            let node_selected = node_selection.par_get_n(
                &(rand::random::<f64>()).to_le_bytes(),
                2,
                Some(&exclusions),
            );
            for node in node_selected {
                let node = nodes.get_mut(node.value()).unwrap();
                *node += 1;
            }
        }
        let range = (600 - 50)..(600 + 50);
        for (node, counts) in nodes {
            if node.exclusions.is_empty() {
                assert!(range.contains(&counts), "got {} selections", counts);
            } else {
                assert_eq!(counts, 0);
            }
        }
    }
}