use std::sync::RwLock;
#[cfg(feature = "distributed")]
use crossbeam_queue::SegQueue;
/// Thread-safe pool of reusable `visited` buffers, bucketed by exact length.
///
/// Each distinct buffer length maps (via a `DashMap`) to its own `SegQueue` of
/// cleared buffers, so callers only contend on the map shard for that length.
#[cfg(feature = "distributed")]
pub struct ConcurrentVisitedPool {
    /// One queue of cleared buffers per buffer length.
    pools: dashmap::DashMap<usize, SegQueue<Vec<bool>>>,
    /// Each per-length queue keeps at most `2 * min_pool_size` buffers.
    min_pool_size: usize,
    /// Buffers longer than this are dropped instead of being pooled.
    max_array_size: usize,
}

#[cfg(feature = "distributed")]
impl ConcurrentVisitedPool {
    /// Creates a pool keeping up to 16 buffers per length, each of at most
    /// 10,000,000 elements.
    pub fn new() -> Self {
        Self::with_config(8, 10_000_000)
    }

    /// Creates a pool with explicit limits, mirroring [`VisitedPool::with_config`].
    ///
    /// * `min_pool_size` - each per-length queue keeps at most `2 * min_pool_size`
    ///   buffers (`0` disables pooling).
    /// * `max_array_size` - buffers longer than this are never pooled.
    pub fn with_config(min_pool_size: usize, max_array_size: usize) -> Self {
        Self {
            pools: dashmap::DashMap::new(),
            min_pool_size,
            max_array_size,
        }
    }

    /// Returns an all-`false` buffer of exactly `size` elements, reusing a pooled
    /// buffer of that exact length when one is available.
    pub fn get(&self, size: usize) -> Vec<bool> {
        if let Some(entry) = self.pools.get(&size) {
            let queue: &SegQueue<Vec<bool>> = entry.value();
            if let Some(mut visited) = queue.pop() {
                // Pooled buffers are cleared on return; clear again defensively.
                visited.fill(false);
                return visited;
            }
        }
        vec![false; size]
    }

    /// Gives `visited` back to the queue for its length.
    ///
    /// The buffer is dropped instead if it is longer than `max_array_size` or if
    /// the queue for its length already holds `2 * min_pool_size` buffers.
    pub fn return_pool(&self, mut visited: Vec<bool>) {
        let size = visited.len();
        if size > self.max_array_size {
            return;
        }
        visited.fill(false);
        let queue = self.pools.entry(size).or_insert_with(SegQueue::new);
        // `saturating_mul` guards against overflow for caller-supplied limits.
        if queue.len() < self.min_pool_size.saturating_mul(2) {
            queue.push(visited);
        }
    }
}

#[cfg(feature = "distributed")]
impl Default for ConcurrentVisitedPool {
    fn default() -> Self {
        Self::new()
    }
}
/// Thread-safe pool of reusable `visited` buffers guarded by a single [`RwLock`].
///
/// Buffers are handed out by value from [`VisitedPool::get`] and given back with
/// [`VisitedPool::return_pool`]; every buffer handed out is all-`false`.
pub struct VisitedPool {
    /// Cleared buffers available for reuse, oldest first.
    pool: RwLock<Vec<Vec<bool>>>,
    /// The pool retains at most `2 * min_pool_size` buffers.
    min_pool_size: usize,
    /// Buffers longer than this are dropped instead of being pooled.
    max_array_size: usize,
}

impl VisitedPool {
    /// Creates an empty pool.
    ///
    /// * `min_pool_size` - the pool retains at most `2 * min_pool_size` buffers;
    ///   `0` disables pooling entirely.
    /// * `max_array_size` - buffers longer than this are never pooled.
    pub fn with_config(min_pool_size: usize, max_array_size: usize) -> Self {
        Self {
            pool: RwLock::new(Vec::with_capacity(min_pool_size)),
            min_pool_size,
            max_array_size,
        }
    }

    /// Creates a pool retaining up to 8 buffers of at most 10,000,000 elements.
    pub fn new() -> Self {
        Self::with_config(4, 10_000_000)
    }

    /// Returns an all-`false` buffer of exactly `size` elements.
    ///
    /// A pooled buffer is moved out of the pool, not copied. An exact-length match
    /// is preferred. Otherwise the first buffer whose length lies in
    /// `size..=2 * size` is truncated to `size`. If nothing fits, a fresh buffer
    /// is allocated.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn get(&self, size: usize) -> Vec<bool> {
        // `saturating_mul` avoids an overflow panic for very large `size`.
        let upper = size.saturating_mul(2);
        let mut pool = self.pool.write().expect("VisitedPool RwLock poisoned: internal bug");
        let slot = pool
            .iter()
            .position(|buf| buf.len() == size)
            .or_else(|| pool.iter().position(|buf| (size..=upper).contains(&buf.len())));
        // Take the buffer out (instead of cloning it) so its allocation is really
        // reused; `remove` keeps the rest oldest-first for eviction.
        let reused = slot.map(|i| pool.remove(i));
        // Release the lock before any fresh allocation.
        drop(pool);
        match reused {
            Some(mut visited) => {
                visited.truncate(size);
                visited.fill(false);
                visited
            }
            None => vec![false; size],
        }
    }

    /// Gives `visited` back to the pool for later reuse.
    ///
    /// Buffers longer than `max_array_size` are dropped, as is every buffer when
    /// pooling is disabled (`min_pool_size == 0`). When the pool is full, its
    /// oldest buffer is evicted to make room.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn return_pool(&self, mut visited: Vec<bool>) {
        let capacity = self.min_pool_size.saturating_mul(2);
        // With zero capacity there is nothing to evict; `remove(0)` would panic.
        if capacity == 0 || visited.len() > self.max_array_size {
            return;
        }
        visited.fill(false);
        let mut pool = self.pool.write().expect("VisitedPool RwLock poisoned: internal bug");
        if pool.len() >= capacity {
            pool.remove(0);
        }
        pool.push(visited);
    }
}

impl Default for VisitedPool {
    fn default() -> Self {
        Self::new()
    }
}
/// Handle to a per-thread pool of reusable `visited` buffers.
///
/// The handle carries no state: `get`/`return_pool` operate on the thread-local
/// cache of whichever thread calls them, so all handles on one thread share it
/// and no locking is needed.
pub struct ThreadLocalVisitedPool {
    _private: (),
}

thread_local! {
    // Per-thread cache backing every `ThreadLocalVisitedPool` handle; entries are
    // always cleared, oldest first.
    static VISITED_POOL: std::cell::RefCell<Vec<Vec<bool>>> = const { std::cell::RefCell::new(Vec::new()) };
}

impl ThreadLocalVisitedPool {
    /// Maximum number of buffers cached per thread.
    const MAX_POOL_SIZE: usize = 8;
    /// Buffers longer than this are never cached.
    const MAX_ARRAY_SIZE: usize = 10_000_000;

    /// Creates a handle; operations use the cache of the calling thread.
    pub fn new() -> Self {
        Self { _private: () }
    }

    /// Returns an all-`false` buffer of exactly `size` elements.
    ///
    /// Prefers a cached buffer of exactly `size`. Otherwise it takes the first
    /// cached buffer whose length lies in `size..=2 * size` and truncates it. It
    /// falls back to a fresh allocation, including when the thread-local cache is
    /// no longer accessible (e.g. during thread teardown).
    pub fn get(&self, size: usize) -> Vec<bool> {
        // `saturating_mul` avoids an overflow panic for very large `size`.
        let upper = size.saturating_mul(2);
        let reused = VISITED_POOL
            .try_with(|cell| {
                let mut pool = cell.borrow_mut();
                let slot = pool
                    .iter()
                    .position(|buf| buf.len() == size)
                    .or_else(|| pool.iter().position(|buf| (size..=upper).contains(&buf.len())));
                // `remove` (not `swap_remove`) keeps the cache oldest-first for eviction.
                slot.map(|i| pool.remove(i))
            })
            .ok()
            .flatten();
        match reused {
            Some(mut visited) => {
                visited.truncate(size);
                visited.fill(false);
                visited
            }
            None => vec![false; size],
        }
    }

    /// Caches `visited` on the calling thread for later reuse.
    ///
    /// Buffers longer than `MAX_ARRAY_SIZE` are dropped. When the cache is full,
    /// the oldest buffer is evicted, so the cache follows the sizes currently in
    /// use instead of being stuck with stale ones.
    pub fn return_pool(&self, mut visited: Vec<bool>) {
        if visited.len() > Self::MAX_ARRAY_SIZE {
            return;
        }
        visited.fill(false);
        // Best effort: if the thread-local cache has already been destroyed
        // (thread teardown), the buffer is simply dropped.
        let _ = VISITED_POOL.try_with(|cell| {
            let mut pool = cell.borrow_mut();
            if pool.len() >= Self::MAX_POOL_SIZE {
                pool.remove(0);
            }
            pool.push(visited);
        });
    }

    /// Runs `f` on a cleared buffer of `size` elements, then returns the buffer to
    /// the pool and yields `f`'s result.
    ///
    /// If `f` panics, the buffer is dropped rather than returned to the pool.
    pub fn with<F, R>(&self, size: usize, f: F) -> R
    where
        F: FnOnce(&mut [bool]) -> R,
    {
        let mut visited = self.get(size);
        let result = f(&mut visited);
        self.return_pool(visited);
        result
    }
}

impl Default for ThreadLocalVisitedPool {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// True when every flag in `visited` is cleared.
    fn all_clear(visited: &[bool]) -> bool {
        visited.iter().all(|&v| !v)
    }

    #[test]
    fn test_visited_pool_basic() {
        let pool = VisitedPool::new();
        let visited1 = pool.get(100);
        assert_eq!(visited1.len(), 100);
        assert!(all_clear(&visited1));
        pool.return_pool(visited1);
        let visited2 = pool.get(100);
        assert_eq!(visited2.len(), 100);
        assert!(all_clear(&visited2));
    }

    #[test]
    fn test_visited_pool_reuse() {
        let pool = VisitedPool::new();
        let mut visited = pool.get(100);
        visited[50] = true;
        pool.return_pool(visited);
        let visited2 = pool.get(100);
        assert_eq!(visited2.len(), 100);
        assert!(all_clear(&visited2));
    }

    #[test]
    fn test_visited_pool_different_sizes() {
        let pool = VisitedPool::new();
        let visited1 = pool.get(100);
        pool.return_pool(visited1);
        let visited2 = pool.get(200);
        assert_eq!(visited2.len(), 200);
        assert!(all_clear(&visited2));
        pool.return_pool(visited2);
        let visited3 = pool.get(100);
        assert_eq!(visited3.len(), 100);
        assert!(all_clear(&visited3));
    }

    #[test]
    fn test_visited_pool_truncates_larger_buffer() {
        let pool = VisitedPool::new();
        pool.return_pool(vec![true; 150]);
        let visited = pool.get(100);
        assert_eq!(visited.len(), 100);
        assert!(all_clear(&visited));
    }

    #[test]
    fn test_thread_local_pool() {
        let pool = ThreadLocalVisitedPool::new();
        let mut visited = pool.get(100);
        visited[50] = true;
        pool.return_pool(visited);
        let visited2 = pool.get(100);
        assert_eq!(visited2.len(), 100);
        assert!(all_clear(&visited2));
    }

    #[test]
    fn test_thread_local_truncates_larger_buffer() {
        let pool = ThreadLocalVisitedPool::new();
        pool.return_pool(vec![true; 150]);
        let visited = pool.get(100);
        assert_eq!(visited.len(), 100);
        assert!(all_clear(&visited));
    }

    #[test]
    fn test_thread_local_with() {
        let pool = ThreadLocalVisitedPool::new();
        let count = pool.with(100, |visited| {
            visited[10] = true;
            visited[20] = true;
            visited.iter().filter(|&&v| v).count()
        });
        assert_eq!(count, 2);
        let visited = pool.get(100);
        assert_eq!(visited.len(), 100);
        assert!(all_clear(&visited));
    }

    #[test]
    #[cfg(feature = "distributed")]
    fn test_concurrent_pool_basic() {
        let pool = ConcurrentVisitedPool::new();
        let visited1 = pool.get(100);
        assert_eq!(visited1.len(), 100);
        assert!(all_clear(&visited1));
        pool.return_pool(visited1);
        let visited2 = pool.get(100);
        assert_eq!(visited2.len(), 100);
        assert!(all_clear(&visited2));
    }

    #[test]
    #[cfg(feature = "distributed")]
    fn test_concurrent_pool_reuse() {
        let pool = ConcurrentVisitedPool::new();
        let mut visited = pool.get(100);
        visited[50] = true;
        pool.return_pool(visited);
        let visited2 = pool.get(100);
        assert_eq!(visited2.len(), 100);
        assert!(all_clear(&visited2));
    }

    #[test]
    #[cfg(feature = "distributed")]
    fn test_concurrent_pool_parallel() {
        use rayon::prelude::*;
        let pool = ConcurrentVisitedPool::new();
        let results: Vec<usize> = (0..10)
            .into_par_iter()
            .map(|_| {
                let mut visited = pool.get(100);
                // Any stale `true` from another task would make this count exceed 1.
                visited[0] = true;
                let count = visited.iter().filter(|&&v| v).count();
                pool.return_pool(visited);
                count
            })
            .collect();
        assert_eq!(results.len(), 10);
        for result in results {
            assert_eq!(result, 1);
        }
    }
}