use parking_lot::RwLock;
use rayon::{Scope, ScopeFifo};
use std::collections::VecDeque;
use std::iter::FusedIterator;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
/// A node of a tree that gives read-only access to its direct children.
///
/// Children are stored boxed, and a node hands them out as a slice.
pub trait TreeNode {
/// Returns the direct children of this node as a slice of boxed nodes.
fn children(&self) -> &[Box<Self>];
}
/// Extension of [`TreeNode`] that additionally gives mutable access to the direct children.
pub trait TreeNodeMut: TreeNode {
/// Returns the direct children of this node as a mutable slice of boxed nodes.
///
/// The slice itself cannot grow or shrink; only the child nodes it holds can be mutated
/// or replaced through it.
fn children_mut(&mut self) -> &mut [Box<Self>];
}
/// Sequential, read-only iteration over a tree of [`TreeNode`]s.
pub trait VisitableTree: TreeNode {
/// Returns an iterator over this node and all of its descendants in depth-first order.
///
/// The iterator starts at `self`; a node is yielded before any of its children.
fn dfs_iter(&self) -> DfsIter<'_, Self> {
DfsIter::new(self)
}
/// Returns an iterator over this node and all of its descendants in breadth-first order.
///
/// The iterator starts at `self`; a node is yielded before any of its children.
fn bfs_iter(&self) -> BfsIter<'_, Self> {
BfsIter::new(self)
}
}
pub trait MutVisitableTree: TreeNodeMut {
fn visit_mut_dfs<F: FnMut(&mut Self)>(&mut self, mut visitor: F) {
let mut stack = Vec::new();
stack.push(self);
while let Some(current_node) = stack.pop() {
visitor(current_node);
stack.extend(
current_node
.children_mut()
.iter_mut()
.rev()
.map(DerefMut::deref_mut),
);
}
}
fn visit_mut_bfs<F: FnMut(&mut Self)>(&mut self, mut visitor: F) {
let mut queue_down = VecDeque::new();
queue_down.push_back(self);
while let Some(current_node) = queue_down.pop_front() {
visitor(current_node);
queue_down.extend(
current_node
.children_mut()
.iter_mut()
.map(DerefMut::deref_mut),
);
}
}
}
/// Depth-first iterator over shared references to the nodes of a tree.
pub struct DfsIter<'a, T: ?Sized> {
// Nodes that still have to be yielded; the last element is yielded next.
stack: Vec<&'a T>,
}
impl<'a, T: ?Sized> DfsIter<'a, T> {
    /// Creates an iterator whose stack initially holds only `start`.
    fn new(start: &'a T) -> Self {
        let mut stack = Vec::with_capacity(1);
        stack.push(start);
        Self { stack }
    }
}
impl<'a, T: TreeNode + ?Sized> Iterator for DfsIter<'a, T> {
type Item = &'a T;
fn next(&mut self) -> Option<Self::Item> {
if let Some(current_node) = self.stack.pop() {
self.stack
.extend(current_node.children().iter().rev().map(Deref::deref));
Some(current_node)
} else {
None
}
}
}
// `next` only pushes nodes after a successful pop, so once the stack is empty it keeps
// returning `None`.
impl<T: TreeNode + ?Sized> FusedIterator for DfsIter<'_, T> {}
/// Breadth-first iterator over shared references to the nodes of a tree.
pub struct BfsIter<'a, T: ?Sized> {
// Nodes that still have to be yielded; the front element is yielded next.
queue: VecDeque<&'a T>,
}
impl<'a, T: ?Sized> BfsIter<'a, T> {
fn new(start: &'a T) -> Self {
Self {
queue: vec![start].into(),
}
}
}
impl<'a, T: TreeNode + ?Sized> Iterator for BfsIter<'a, T> {
    type Item = &'a T;

    /// Yields the node at the front of the queue and schedules its children.
    ///
    /// Children are appended to the back of the queue in the order in which `children`
    /// returns them, so siblings are yielded left-to-right within each level. No reversal
    /// is needed here: unlike a stack, the FIFO queue already preserves insertion order.
    fn next(&mut self) -> Option<Self::Item> {
        let node = self.queue.pop_front()?;
        self.queue
            .extend(node.children().iter().map(Deref::deref));
        Some(node)
    }
}
// `next` only enqueues nodes after a successful pop, so once the queue is empty it keeps
// returning `None`.
impl<T: TreeNode + ?Sized> FusedIterator for BfsIter<'_, T> {}
/// Parallel, read-only traversals of a tree, implemented with tasks spawned on rayon
/// FIFO scopes.
pub trait ParVisitableTree: TreeNode {
/// Calls `visitor` on this node and all of its descendants from tasks spawned inside a
/// `rayon::scope_fifo` scope.
///
/// The visitor call for a node and the traversal of each of its children are spawned as
/// separate tasks, so the visitor may run concurrently on different nodes.
/// NOTE(review): the relative order of visitor calls depends on rayon's scheduling;
/// do not rely on a strict breadth-first order.
fn par_visit_bfs<F>(&self, visitor: F)
where
Self: Sync,
F: Fn(&Self) + Sync,
{
// Recursive helper: spawns one task that visits `node` and one task per child that
// continues the traversal below that child.
fn par_visit_bfs_impl<'scope, T, F>(
node: &'scope T,
s: &ScopeFifo<'scope>,
visitor: &'scope F,
) where
T: TreeNode + Sync + ?Sized,
F: Fn(&T) + Sync,
{
s.spawn_fifo(move |_| visitor(node));
for child in node.children().iter().map(Deref::deref) {
s.spawn_fifo(move |s| par_visit_bfs_impl(child, s, visitor));
}
}
// The tasks share the visitor through a reference instead of taking ownership.
let v = &visitor;
rayon::scope_fifo(move |s| par_visit_bfs_impl(self, s, v));
}
/// Fallible variant of `par_visit_bfs`.
///
/// All tasks share one error slot. A task checks the slot before it calls the visitor
/// or descends further and does nothing if an error is already stored. Visitor calls
/// that are already running are not interrupted. Only the first error written to the
/// slot is kept.
///
/// # Errors
///
/// Returns the error stored in the shared slot if a visitor call failed.
///
/// # Panics
///
/// Panics if an error was stored but the shared slot still has other owners when it is
/// unwrapped after the scope has finished.
fn try_par_visit_bfs<E, F>(&self, visitor: F) -> Result<(), E>
where
Self: Sync,
E: Send + Sync,
F: Fn(&Self) -> Result<(), E> + Sync,
{
// Shared error slot; starts out as `Ok(())` and is overwritten by the first failure.
let error = Arc::new(RwLock::new(Ok(())));
// Recursive helper: like the infallible traversal, but every task first checks the
// shared error slot.
fn try_par_visit_bfs_impl<'scope, T, E, F>(
node: &'scope T,
s: &ScopeFifo<'scope>,
error: Arc<RwLock<Result<(), E>>>,
visitor: &'scope F,
) where
T: TreeNode + Sync + ?Sized,
E: Send + Sync + 'scope,
F: Fn(&T) -> Result<(), E> + Sync,
{
// Do not spawn any further work below this node once an error was recorded.
if error.read().is_err() {
return;
}
{
let error = error.clone();
s.spawn_fifo(move |_| {
// Check again: an error may have been recorded between spawning and running.
if error.read().is_ok() {
let res = visitor(node);
if res.is_err() {
let mut error_guard = error.write();
// Keep the error that was stored first; later ones are dropped.
if !error_guard.is_err() {
*error_guard = res;
}
}
}
});
}
for child in node.children().iter().map(Deref::deref) {
let error = error.clone();
s.spawn_fifo(move |s| try_par_visit_bfs_impl(child, s, error, visitor));
}
}
{
let v = &visitor;
let e = error.clone();
rayon::scope_fifo(move |s| try_par_visit_bfs_impl(self, s, e, v));
}
// Take the stored error out of the `Arc`; this needs `error` to be the only owner left.
if !error.read().is_ok() {
match Arc::try_unwrap(error) {
Ok(e) => e.into_inner(),
Err(_) => panic!("Unable to unwrap Arc that stores error of tree visitation"),
}
} else {
Ok(())
}
}
}
/// Parallel traversals that hand every node of the tree to a visitor by mutable
/// reference, implemented with rayon scopes.
pub trait ParMutVisitableTree: TreeNodeMut {
/// Calls `visitor` on this node and all of its descendants, visiting each node before
/// its children.
///
/// The visitor runs on a node first; afterwards one task per child is spawned on the
/// FIFO scope to continue below that child. Subtrees of different children may
/// therefore be processed concurrently.
fn par_visit_mut_bfs<F>(&mut self, visitor: F)
where
Self: Send + Sync,
F: Fn(&mut Self) + Sync,
{
// Recursive helper: visits `node` in the current task, then spawns one task per child.
fn par_visit_mut_bfs_impl<'scope, T, F>(
node: &'scope mut T,
s: &ScopeFifo<'scope>,
visitor: &'scope F,
) where
T: TreeNodeMut + Send + Sync + ?Sized,
F: Fn(&mut T) + Sync,
{
// Visit first, so the children are collected from the node as the visitor left it.
visitor(node);
for child in node.children_mut().iter_mut().map(DerefMut::deref_mut) {
s.spawn_fifo(move |s| par_visit_mut_bfs_impl(child, s, visitor));
}
}
// The tasks share the visitor through a reference instead of taking ownership.
let v = &visitor;
rayon::scope_fifo(move |s| par_visit_mut_bfs_impl(self, s, v));
}
/// Calls `visitor` on this node and all of its descendants in post-order: a node is
/// visited only after the traversal of all of its children has finished.
///
/// The children of a node are processed by tasks spawned inside a nested
/// `rayon::scope`, and the visitor is called on the node after that scope has returned.
fn par_visit_mut_dfs_post<F>(&mut self, visitor: F)
where
Self: Send + Sync,
F: Fn(&mut Self) + Sync,
{
// Recursive helper. The scope passed in as `_s` is not used; a nested scope is opened
// for the children instead.
fn par_visit_mut_dfs_post_impl<'scope, T, F>(
node: &'scope mut T,
_s: &Scope<'scope>,
visitor: &'scope F,
) where
T: TreeNodeMut + Send + Sync + ?Sized,
F: Fn(&mut T) + Sync,
{
rayon::scope(|s| {
for child in node.children_mut().iter_mut().map(DerefMut::deref_mut) {
s.spawn(move |s| par_visit_mut_dfs_post_impl(child, s, visitor));
}
});
// All child tasks were spawned on the nested scope above, which has returned here.
visitor(node);
}
let v = &visitor;
rayon::scope(move |s| par_visit_mut_dfs_post_impl(self, s, v));
}
/// Fallible variant of `par_visit_mut_dfs_post`.
///
/// All tasks share one error slot. A task does not descend into a node if an error is
/// already stored, and it does not call the visitor on a node if an error is stored
/// after the node's children were processed. Visitor calls that are already running are
/// not interrupted. Only the first error written to the slot is kept.
///
/// # Errors
///
/// Returns the error stored in the shared slot if a visitor call failed.
///
/// # Panics
///
/// Panics if an error was stored but the shared slot still has other owners when it is
/// unwrapped after the scope has finished.
fn try_par_visit_mut_dfs_post<E, F>(&mut self, visitor: F) -> Result<(), E>
where
Self: Send + Sync,
E: Send + Sync,
F: Fn(&mut Self) -> Result<(), E> + Sync,
{
// Shared error slot; starts out as `Ok(())` and is overwritten by the first failure.
let error = Arc::new(RwLock::new(Ok(())));
// Recursive helper. The scope passed in as `_s` is not used; a nested scope is opened
// for the children instead.
fn try_par_visit_mut_dfs_post_impl<'scope, T, E, F>(
node: &'scope mut T,
_s: &Scope<'scope>,
error: Arc<RwLock<Result<(), E>>>,
visitor: &'scope F,
) where
T: TreeNodeMut + Send + Sync + ?Sized,
E: Send + Sync,
F: Fn(&mut T) -> Result<(), E> + Sync,
{
// Skip this whole subtree once an error was recorded.
if error.read().is_err() {
return;
}
rayon::scope(|s| {
for child in node.children_mut().iter_mut().map(DerefMut::deref_mut) {
let error = error.clone();
s.spawn(move |s| try_par_visit_mut_dfs_post_impl(child, s, error, visitor));
}
});
// Visit the node itself only if no error was recorded so far, e.g. by a child.
if error.read().is_ok() {
let res = visitor(node);
if res.is_err() {
let mut error_guard = error.write();
// Keep the error that was stored first; later ones are dropped.
if !error_guard.is_err() {
*error_guard = res;
}
}
}
}
{
let v = &visitor;
let e = error.clone();
rayon::scope(move |s| try_par_visit_mut_dfs_post_impl(self, s, e, v));
}
// Take the stored error out of the `Arc`; this needs `error` to be the only owner left.
if !error.read().is_ok() {
match Arc::try_unwrap(error) {
Ok(e) => e.into_inner(),
Err(_) => panic!("Unable to unwrap Arc that stores error of tree visitation"),
}
} else {
Ok(())
}
}
}
// Blanket implementations: the traversal traits only have provided methods, so every
// type that implements the node traits (plus `Send + Sync` for the parallel variants)
// gets them automatically.
// Sequential read-only iteration for every `TreeNode`.
impl<T: TreeNode> VisitableTree for T {}
// Sequential mutable visitation for every `TreeNodeMut`.
impl<T: TreeNodeMut> MutVisitableTree for T {}
// Parallel read-only visitation for every thread-safe `TreeNode`.
impl<T: TreeNode + Send + Sync> ParVisitableTree for T {}
// Parallel mutable visitation for every thread-safe `TreeNodeMut`.
impl<T: TreeNodeMut + Send + Sync> ParMutVisitableTree for T {}