use crate::*;
use rayon::{ThreadPool, prelude::*};
/// A collection of `T` values stored as a list of `Vec<T>` buckets.
///
/// The constructors in this file allocate one bucket per thread, and the
/// `push`/`pop` methods pick the bucket from the calling thread's rayon
/// thread index, so that each worker appends to its own bucket.
///
/// The derived `Clone` copies every bucket (and therefore requires
/// `T: Clone`); the optional pool reference is copied as-is.
#[derive(Debug, Clone)]
pub struct Frontier<'a, T> {
/// The buckets; the index into this vector is the thread slot.
data: Vec<Vec<T>>,
/// Pool whose thread indices select the bucket, when one was supplied via
/// `with_threads`; `None` means the ambient rayon thread index is used.
threads: Option<&'a ThreadPool>,
}
/// Exposes the per-thread buckets as a shared slice.
impl<T> AsRef<[Vec<T>]> for Frontier<'_, T> {
    /// Returns the buckets in slot order, one `Vec<T>` per thread slot.
    fn as_ref(&self) -> &[Vec<T>] {
        // Deref coercion turns `&Vec<Vec<T>>` into `&[Vec<T>]`.
        &self.data
    }
}
/// Exposes the per-thread buckets as a mutable slice.
impl<T> AsMut<[Vec<T>]> for Frontier<'_, T> {
    /// Returns the buckets in slot order; the number of buckets cannot be
    /// changed through the returned slice, only their contents.
    fn as_mut(&mut self) -> &mut [Vec<T>] {
        // Deref coercion turns `&mut Vec<Vec<T>>` into `&mut [Vec<T>]`.
        &mut self.data
    }
}
/// Element-wise equality between two frontiers.
///
/// The `threads` field takes no part in the comparison.
impl<T> PartialEq for Frontier<'_, T>
where
    T: PartialEq,
{
    /// Returns `true` when both frontiers hold the same number of elements
    /// and the items yielded by `iter()` compare equal pairwise.
    ///
    /// NOTE(review): the visiting order is whatever `FrontierIter` defines
    /// (presumably bucket by bucket) — confirm against its implementation.
    fn eq(&self, other: &Self) -> bool {
        // Cheap rejection first: `zip` alone would ignore a longer tail.
        if self.len() != other.len() {
            return false;
        }
        self.iter()
            .zip(other.iter())
            .all(|(lhs, rhs)| lhs.eq(rhs))
    }
}
impl<T> From<Vec<T>> for Frontier<'_, T> {
fn from(value: Vec<T>) -> Self {
let mut frontier = Frontier::default();
frontier.data[0] = value;
frontier
}
}
impl<'a, T> From<Frontier<'a, T>> for Vec<Vec<T>> {
fn from(val: Frontier<'a, T>) -> Self {
val.data
}
}
/// Flattens a frontier into a single vector.
///
/// The frontier is consumed, so the elements are moved out of the buckets
/// rather than cloned; no `Clone` bound is required on `T`.
impl<'a, T> From<Frontier<'a, T>> for Vec<T> {
    /// Returns every element of `val`, bucket 0 first, then bucket 1, and so
    /// on, preserving the order inside each bucket.
    fn from(val: Frontier<'a, T>) -> Self {
        // Reserve once for the total so the appends below never reallocate.
        let mut flattened = Vec::with_capacity(val.len());
        for mut bucket in val.data {
            // `append` moves the elements and leaves `bucket` empty.
            flattened.append(&mut bucket);
        }
        flattened
    }
}
impl<T> core::default::Default for Frontier<'_, T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Clone> Frontier<'_, T> {
    /// Returns a new vector holding a clone of every element, bucket 0
    /// first, then bucket 1, and so on. The frontier itself is unchanged.
    #[inline]
    pub fn concat(&self) -> Vec<T> {
        let buckets: &[Vec<T>] = &self.data;
        buckets.concat()
    }
}
impl<'a, T> Frontier<'a, T> {
    /// Creates an empty frontier with no attached thread pool.
    ///
    /// One empty bucket is created per thread reported by
    /// `Frontier::system_number_of_threads`.
    #[inline]
    pub fn new() -> Self {
        let n_threads = Self::system_number_of_threads();
        Frontier {
            data: (0..n_threads).map(|_| Vec::new()).collect(),
            threads: None,
        }
    }

    /// Creates an empty frontier, with no attached thread pool, whose buckets
    /// together reserve room for at least `capacity` elements.
    ///
    /// The capacity is split evenly between the buckets and rounded up, so
    /// that a request smaller than the number of threads still reserves
    /// space instead of rounding down to zero.
    #[inline]
    pub fn with_capacity(capacity: usize) -> Self {
        let n_threads = Self::system_number_of_threads();
        let per_thread = capacity.div_ceil(n_threads);
        Frontier {
            data: (0..n_threads)
                .map(|_| Vec::with_capacity(per_thread))
                .collect(),
            threads: None,
        }
    }

    /// Creates an empty frontier bound to `thread_pool`.
    ///
    /// One bucket is created per thread of the pool. When `capacity` is
    /// `Some`, it is split evenly between the buckets (rounded up); `None`
    /// reserves nothing.
    #[inline]
    pub fn with_threads(thread_pool: &'a ThreadPool, capacity: Option<usize>) -> Self {
        let n_threads = thread_pool.current_num_threads();
        let per_thread = capacity.unwrap_or(0).div_ceil(n_threads);
        Frontier {
            data: (0..n_threads)
                .map(|_| Vec::with_capacity(per_thread))
                .collect(),
            threads: Some(thread_pool),
        }
    }

    /// Picks the bucket index for the calling thread.
    ///
    /// * Without an attached pool: the ambient rayon thread index, or `0`
    ///   when the caller is not a rayon worker.
    /// * With an attached pool: the caller's index inside that pool, or `0`
    ///   when the caller is not a rayon worker at all.
    ///
    /// # Panics
    ///
    /// Panics when a pool is attached and the caller is a rayon worker that
    /// does not belong to it.
    #[inline(always)]
    fn get_current_thread_index(&self) -> usize {
        let Some(thread_pool) = self.threads else {
            return rayon::current_thread_index().unwrap_or(0);
        };
        match thread_pool.current_thread_index() {
            Some(index) => index,
            // A rayon worker, but not one of ours: its index is meaningless here.
            None if rayon::current_thread_index().is_some() => {
                panic!("Parallel frontier called from external thread pool")
            }
            None => 0,
        }
    }

    /// Appends `element` to the bucket at `thread_id`.
    ///
    /// # Safety
    ///
    /// The bucket is mutated through a shared reference, so the caller must
    /// guarantee that nothing else reads or writes bucket `thread_id` (via
    /// this or any other method, on any thread) for the duration of the call.
    ///
    /// # Panics
    ///
    /// Panics if `thread_id` is not smaller than `number_of_threads()`.
    #[inline]
    pub unsafe fn push_on_thread(&self, element: T, thread_id: usize) {
        // SAFETY: relies entirely on the caller's guarantee of exclusive
        // access to bucket `thread_id`.
        // NOTE(review): the bucket is reached through `&self` without an
        // `UnsafeCell`; confirm this against Rust's aliasing rules or move
        // the buckets behind `UnsafeCell`.
        unsafe { (*((&self.data[thread_id]) as *const Vec<T> as *mut Vec<T>)).push(element) };
    }

    /// Appends `element` to the calling thread's bucket.
    ///
    /// # Panics
    ///
    /// Panics under the conditions described on `push_on_thread`, and when
    /// called from a rayon worker that does not belong to the attached pool.
    #[inline]
    pub fn push(&self, element: T) {
        let thread_id = self.get_current_thread_index();
        // SAFETY: distinct rayon workers resolve to distinct indices and so
        // touch distinct buckets. Assumes callers that are not rayon workers
        // (all of which map to bucket 0) never run concurrently with each
        // other or with worker 0 — TODO confirm with callers.
        unsafe { self.push_on_thread(element, thread_id) };
    }

    /// Removes and returns the last element of the calling thread's bucket,
    /// or `None` when that bucket is empty.
    ///
    /// # Panics
    ///
    /// Panics under the conditions described on `pop_from_thread`, and when
    /// called from a rayon worker that does not belong to the attached pool.
    #[inline]
    pub fn pop(&self) -> Option<T> {
        let thread_id = self.get_current_thread_index();
        // SAFETY: same reasoning (and same open assumption about callers
        // outside rayon) as in `push`.
        unsafe { self.pop_from_thread(thread_id) }
    }

    /// Removes and returns the last element of the bucket at `thread_id`,
    /// or `None` when that bucket is empty.
    ///
    /// # Safety
    ///
    /// The bucket is mutated through a shared reference, so the caller must
    /// guarantee that nothing else reads or writes bucket `thread_id` (via
    /// this or any other method, on any thread) for the duration of the call.
    ///
    /// # Panics
    ///
    /// Panics if `thread_id` is not smaller than `number_of_threads()`.
    #[inline]
    pub unsafe fn pop_from_thread(&self, thread_id: usize) -> Option<T> {
        // SAFETY: relies entirely on the caller's guarantee of exclusive
        // access to bucket `thread_id`.
        // NOTE(review): same `UnsafeCell` concern as in `push_on_thread`.
        unsafe { (*((&self.data[thread_id]) as *const Vec<T> as *mut Vec<T>)).pop() }
    }

    /// Returns the number of buckets held by this frontier.
    #[inline]
    pub fn number_of_threads(&self) -> usize {
        self.data.len()
    }

    /// Returns the thread count reported by `rayon::current_num_threads`,
    /// clamped to at least one so a frontier always has a bucket.
    #[inline]
    pub fn system_number_of_threads() -> usize {
        rayon::current_num_threads().max(1)
    }

    /// Returns the total number of elements across all buckets.
    #[inline]
    pub fn len(&self) -> usize {
        self.data.iter().map(Vec::len).sum()
    }

    /// Returns `true` when every bucket is empty.
    #[inline]
    pub fn is_empty(&self) -> bool {
        // Stops at the first non-empty bucket instead of summing all lengths.
        self.data.iter().all(|bucket| bucket.is_empty())
    }

    /// Removes every element from every bucket; the buckets themselves and
    /// their allocated capacity are kept.
    #[inline]
    pub fn clear(&mut self) {
        for bucket in &mut self.data {
            bucket.clear();
        }
    }

    /// Shrinks the capacity of every bucket as much as possible.
    #[inline]
    pub fn shrink_to_fit(&mut self) {
        for bucket in &mut self.data {
            bucket.shrink_to_fit();
        }
    }

    /// Returns a `FrontierIter` built over this frontier.
    #[inline]
    pub fn iter(&self) -> FrontierIter<'_, T> {
        FrontierIter::new(self)
    }

    /// Returns an iterator over the buckets, in slot order.
    #[inline]
    pub fn iter_vectors(&self) -> impl Iterator<Item = &Vec<T>> + '_ {
        self.data.iter()
    }

    /// Returns the length of each bucket, in slot order.
    #[inline]
    pub fn vector_sizes(&self) -> Vec<usize> {
        self.data.iter().map(Vec::len).collect()
    }

    /// Returns a `FrontierParIter` built over this frontier.
    #[inline]
    pub fn par_iter(&self) -> FrontierParIter<'_, T> {
        FrontierParIter::new(self)
    }
}
impl<T> Frontier<'_, T>
where
T: Send + Sync,
{
#[inline]
pub fn par_iter_vectors(&self) -> impl IndexedParallelIterator<Item = &Vec<T>> + '_ {
self.data.par_iter()
}
#[inline]
pub fn par_iter_vectors_mut(
&mut self,
) -> impl IndexedParallelIterator<Item = &mut Vec<T>> + '_ {
self.data.par_iter_mut()
}
#[inline]
pub fn into_par_iter_vectors(self) -> impl IndexedParallelIterator<Item = Vec<T>> {
self.data.into_par_iter()
}
}