use std::{
any::Any,
collections::HashMap,
error::Error,
fmt::{self, Display, Formatter},
hash::Hash,
iter::Iterator,
sync::{mpsc, Arc, Mutex},
vec::Vec,
};
use crossbeam::thread;
use rayon::iter::IntoParallelIterator;
use rayon::prelude::ParallelIterator;
#[cfg(feature = "serde-serialize")]
use serde::{Deserialize, Serialize};
use super::lattice::{LatticeCyclic, LatticeElementToIndex};
/// Error returned by the thread-pool helpers of this module.
///
/// The panic payloads are kept as-is (type-erased), which is why this type is neither
/// `Clone` nor comparable; convert it into [`ThreadError`] for a comparable form.
#[derive(Debug)]
#[non_exhaustive]
pub enum ThreadAnyError {
    /// The requested number of threads is invalid (the pool functions reject `0`).
    ThreadNumberIncorrect,
    /// At least one thread panicked; holds the payload of every collected panic.
    Panic(Vec<Box<dyn Any + Send + 'static>>),
}

impl Display for ThreadAnyError {
    /// Formats the error.
    ///
    /// Panic payloads of type `String` or `&str` are printed quoted, any other payload
    /// through its `Debug` implementation. A single payload is printed bare; several are
    /// printed as a comma-separated list in brackets.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            Self::ThreadNumberIncorrect => write!(f, "number of thread is incorrect"),
            Self::Panic(any) => {
                let n = any.len();
                match n {
                    0 => return write!(f, "0 thread panicked"),
                    // trailing space: the payload follows directly
                    1 => write!(f, "a thread panicked with ")?,
                    _ => write!(f, "{} threads panicked with [", n)?,
                }
                for (index, element_any) in any.iter().enumerate() {
                    if index > 0 {
                        write!(f, ", ")?;
                    }
                    if let Some(string) = element_any.downcast_ref::<String>() {
                        write!(f, "\"{}\"", string)?;
                    } else if let Some(string) = element_any.downcast_ref::<&str>() {
                        write!(f, "\"{}\"", string)?;
                    } else {
                        write!(f, "{:?}", element_any)?;
                    }
                }
                if n > 1 {
                    write!(f, "]")?;
                }
                Ok(())
            }
        }
    }
}

impl Error for ThreadAnyError {}
/// Comparable, cloneable counterpart of [`ThreadAnyError`].
///
/// Each panic payload is reduced to an optional message; with the `serde-serialize`
/// feature the type also derives `Serialize`/`Deserialize`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
#[cfg_attr(feature = "serde-serialize", derive(Serialize, Deserialize))]
#[non_exhaustive]
pub enum ThreadError {
    /// The requested number of threads is invalid (the pool functions reject `0`).
    ThreadNumberIncorrect,
    /// At least one thread panicked; holds the message of each panic, if it had one.
    Panic(Vec<Option<String>>),
}

impl Display for ThreadError {
    /// Formats the error.
    ///
    /// Messages are printed quoted and missing messages as `None`. A single message is
    /// printed bare; several are printed as a comma-separated list in brackets.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            Self::ThreadNumberIncorrect => write!(f, "number of thread is incorrect"),
            Self::Panic(strings) => {
                let n = strings.len();
                match n {
                    0 => return write!(f, "0 thread panicked"),
                    // trailing space: the message follows directly
                    1 => write!(f, "a thread panicked with ")?,
                    _ => write!(f, "{} threads panicked with [", n)?,
                }
                for (index, string) in strings.iter().enumerate() {
                    if index > 0 {
                        write!(f, ", ")?;
                    }
                    match string {
                        Some(string) => write!(f, "\"{}\"", string)?,
                        None => write!(f, "None")?,
                    }
                }
                if n > 1 {
                    write!(f, "]")?;
                }
                Ok(())
            }
        }
    }
}

impl Error for ThreadError {}
impl From<ThreadAnyError> for ThreadError {
#[allow(clippy::manual_map)] fn from(f: ThreadAnyError) -> Self {
match f {
ThreadAnyError::ThreadNumberIncorrect => Self::ThreadNumberIncorrect,
ThreadAnyError::Panic(any) => Self::Panic(
any.iter()
.map(|element| {
if let Some(string) = element.downcast_ref::<String>() {
Some(string.clone())
}
else if let Some(string) = element.downcast_ref::<&str>() {
Some(string.to_string())
}
else {
None
}
})
.collect(),
),
}
}
}
impl From<ThreadError> for ThreadAnyError {
fn from(f: ThreadError) -> Self {
match f {
ThreadError::ThreadNumberIncorrect => Self::ThreadNumberIncorrect,
ThreadError::Panic(strings) => Self::Panic(
strings
.iter()
.map(|string| -> Box<dyn Any + Send + 'static> {
if let Some(string) = string {
Box::new(string.clone())
}
else {
Box::new("".to_string())
}
})
.collect(),
),
}
}
}
/// Evaluates `closure` on every key of `iter` using `number_of_thread` threads and gathers
/// the results in a [`HashMap`] keyed by the input keys.
///
/// This is [`run_pool_parallel_with_initializations_mutable`] without per-thread
/// initialization data. `capacity` is the initial capacity of the returned map.
///
/// # Errors
/// Returns the errors of [`run_pool_parallel_with_initializations_mutable`]:
/// `ThreadNumberIncorrect` when `number_of_thread` is `0`, `Panic` when `closure` panicked.
pub fn run_pool_parallel<Key, Data, CommonData, F>(
    iter: impl Iterator<Item = Key> + Send,
    common_data: &CommonData,
    closure: &F,
    number_of_thread: usize,
    capacity: usize,
) -> Result<HashMap<Key, Data>, ThreadAnyError>
where
    CommonData: Sync,
    Key: Eq + Hash + Send + Clone + Sync,
    Data: Send,
    F: Sync + Clone + Fn(&Key, &CommonData) -> Data,
{
    // No per-thread state is needed: the initializer yields `()` and the adapter ignores it.
    let ignore_init = |_: &mut (), key: &Key, common: &CommonData| closure(key, common);
    let no_init = || ();
    run_pool_parallel_with_initializations_mutable(
        iter,
        common_data,
        &ignore_init,
        &no_init,
        number_of_thread,
        capacity,
    )
}
/// Evaluates `closure` for every key produced by `iter` on `number_of_thread` threads and
/// collects the results in a [`HashMap`] keyed by the input keys.
///
/// Every worker calls `closure_init` once to build its own mutable data. That data is
/// then passed to each `closure` call made by the same worker. With
/// `number_of_thread == 1` everything runs on the calling thread. `capacity` is the
/// initial capacity of the map.
///
/// # Errors
/// - [`ThreadAnyError::ThreadNumberIncorrect`] if `number_of_thread` is `0`.
/// - [`ThreadAnyError::Panic`] if `closure_init` or `closure` panicked, holding the panic
///   payload(s). With several threads the remaining workers still drain `iter`, but the
///   partial results are discarded.
pub fn run_pool_parallel_with_initializations_mutable<Key, Data, CommonData, InitData, F, FInit>(
    iter: impl Iterator<Item = Key> + Send,
    common_data: &CommonData,
    closure: &F,
    closure_init: FInit,
    number_of_thread: usize,
    capacity: usize,
) -> Result<HashMap<Key, Data>, ThreadAnyError>
where
    CommonData: Sync,
    Key: Eq + Hash + Send + Clone + Sync,
    Data: Send,
    F: Sync + Clone + Fn(&mut InitData, &Key, &CommonData) -> Data,
    FInit: Send + Clone + FnOnce() -> InitData,
{
    match number_of_thread {
        0 => Err(ThreadAnyError::ThreadNumberIncorrect),
        1 => {
            let mut hash_map = HashMap::<Key, Data>::with_capacity(capacity);
            // Also catch panics of the initializer, so that this path reports them as
            // `ThreadAnyError::Panic` like the worker threads of the multi-threaded path do.
            let mut init_data =
                std::panic::catch_unwind(std::panic::AssertUnwindSafe(closure_init))
                    .map_err(|err| ThreadAnyError::Panic(vec![err]))?;
            for key in iter {
                let data = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                    closure(&mut init_data, &key, common_data)
                }))
                .map_err(|err| ThreadAnyError::Panic(vec![err]))?;
                // The key is moved into the map, so no clone is needed.
                hash_map.insert(key, data);
            }
            Ok(hash_map)
        }
        _ => thread::scope(|s| {
            // Workers pull keys one at a time from the shared iterator and send the
            // `(key, result)` pairs back to this thread over the channel.
            let mutex_iter = Arc::new(Mutex::new(iter));
            let mut handles = Vec::with_capacity(number_of_thread);
            let (result_tx, result_rx) = mpsc::channel::<(Key, Data)>();
            for _ in 0..number_of_thread {
                let iter_clone = Arc::clone(&mutex_iter);
                let transmitter = result_tx.clone();
                let closure_init_clone = closure_init.clone();
                let handle = s.spawn(move |_| {
                    let mut init_data = closure_init_clone();
                    loop {
                        // The lock is only held while calling `next`. If it is poisoned,
                        // another worker panicked inside `Iterator::next`. That panic is
                        // already collected, so stop rather than add a `PoisonError` one.
                        let next_key = match iter_clone.lock() {
                            Ok(mut guard) => guard.next(),
                            Err(_) => break,
                        };
                        match next_key {
                            Some(key) => {
                                let data = closure(&mut init_data, &key, common_data);
                                transmitter
                                    .send((key, data))
                                    .expect("the receiver outlives every worker");
                            }
                            None => break,
                        }
                    }
                });
                handles.push(handle);
            }
            // Drop the original sender so the channel closes once every worker is done.
            drop(result_tx);
            let mut hash_map = HashMap::<Key, Data>::with_capacity(capacity);
            hash_map.extend(result_rx);
            let panics = handles
                .into_iter()
                .filter_map(|handle| handle.join().err())
                .collect::<Vec<_>>();
            if panics.is_empty() {
                Ok(hash_map)
            } else {
                Err(ThreadAnyError::Panic(panics))
            }
        })
        .unwrap_or_else(|err| {
            // Every spawned handle is joined above, so the scope has no unjoined panics
            // left to report.
            if err
                .downcast_ref::<Vec<Box<dyn Any + 'static + Send>>>()
                .is_some()
            {
                unreachable!("a failing handle is not joined")
            }
            unreachable!("main thread panicked")
        }),
    }
}
/// Evaluates `closure` on every key of `iter` using `number_of_thread` threads and stores
/// each result in a [`Vec`] at the index given by the key's `to_index(l)`.
///
/// This is [`run_pool_parallel_vec_with_initializations_mutable`] without per-thread
/// initialization data. Unfilled slots hold clones of `default_data`; `capacity` is the
/// initial capacity of the vector.
///
/// # Errors
/// Returns the errors of [`run_pool_parallel_vec_with_initializations_mutable`]:
/// `ThreadNumberIncorrect` when `number_of_thread` is `0`, `Panic` when `closure` panicked.
pub fn run_pool_parallel_vec<Key, Data, CommonData, F, const D: usize>(
    iter: impl Iterator<Item = Key> + Send,
    common_data: &CommonData,
    closure: &F,
    number_of_thread: usize,
    capacity: usize,
    l: &LatticeCyclic<D>,
    default_data: &Data,
) -> Result<Vec<Data>, ThreadAnyError>
where
    CommonData: Sync,
    Key: Eq + Send + Clone + Sync + LatticeElementToIndex<D>,
    Data: Send + Clone,
    F: Sync + Clone + Fn(&Key, &CommonData) -> Data,
{
    // No per-thread state is needed: the initializer yields `()` and the adapter ignores it.
    let ignore_init = |_: &mut (), key: &Key, common: &CommonData| closure(key, common);
    let no_init = || ();
    run_pool_parallel_vec_with_initializations_mutable(
        iter,
        common_data,
        &ignore_init,
        &no_init,
        number_of_thread,
        capacity,
        l,
        default_data,
    )
}
/// Evaluates `closure` for every key produced by `iter` on `number_of_thread` threads and
/// stores each result in a [`Vec`] at the index given by the key's `to_index(l)`.
///
/// Every worker calls `closure_init` once to build its own mutable data. That data is
/// then passed to each `closure` call made by the same worker. Results are placed with
/// [`insert_in_vec`], so slots that receive no result hold clones of `default_data`. The
/// vector ends one past the largest index written. `capacity` is its initial capacity.
///
/// # Errors
/// - [`ThreadAnyError::ThreadNumberIncorrect`] if `number_of_thread` is `0`.
/// - [`ThreadAnyError::Panic`] if `closure_init` or `closure` panicked, holding the panic
///   payload(s). With several threads the remaining workers still drain `iter`, but the
///   partial results are discarded.
// The argument list mirrors the HashMap variant plus the lattice and the default value.
#[allow(clippy::too_many_arguments)]
pub fn run_pool_parallel_vec_with_initializations_mutable<
    Key,
    Data,
    CommonData,
    InitData,
    F,
    FInit,
    const D: usize,
>(
    iter: impl Iterator<Item = Key> + Send,
    common_data: &CommonData,
    closure: &F,
    closure_init: FInit,
    number_of_thread: usize,
    capacity: usize,
    l: &LatticeCyclic<D>,
    default_data: &Data,
) -> Result<Vec<Data>, ThreadAnyError>
where
    CommonData: Sync,
    Key: Eq + Send + Clone + Sync,
    Data: Send + Clone,
    F: Sync + Clone + Fn(&mut InitData, &Key, &CommonData) -> Data,
    FInit: Send + Clone + FnOnce() -> InitData,
    Key: LatticeElementToIndex<D>,
{
    match number_of_thread {
        0 => Err(ThreadAnyError::ThreadNumberIncorrect),
        1 => {
            let mut values = Vec::<Data>::with_capacity(capacity);
            // Also catch panics of the initializer, so that this path reports them as
            // `ThreadAnyError::Panic` like the worker threads of the multi-threaded path do.
            let mut init_data =
                std::panic::catch_unwind(std::panic::AssertUnwindSafe(closure_init))
                    .map_err(|err| ThreadAnyError::Panic(vec![err]))?;
            for key in iter {
                let data = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                    closure(&mut init_data, &key, common_data)
                }))
                .map_err(|err| ThreadAnyError::Panic(vec![err]))?;
                // The key is consumed for its index, so no clone is needed.
                insert_in_vec(&mut values, key.to_index(l), data, default_data);
            }
            Ok(values)
        }
        _ => thread::scope(|s| {
            // Workers pull keys one at a time from the shared iterator and send the
            // `(key, result)` pairs back to this thread over the channel.
            let mutex_iter = Arc::new(Mutex::new(iter));
            let mut handles = Vec::with_capacity(number_of_thread);
            let (result_tx, result_rx) = mpsc::channel::<(Key, Data)>();
            for _ in 0..number_of_thread {
                let iter_clone = Arc::clone(&mutex_iter);
                let transmitter = result_tx.clone();
                let closure_init_clone = closure_init.clone();
                let handle = s.spawn(move |_| {
                    let mut init_data = closure_init_clone();
                    loop {
                        // The lock is only held while calling `next`. If it is poisoned,
                        // another worker panicked inside `Iterator::next`. That panic is
                        // already collected, so stop rather than add a `PoisonError` one.
                        let next_key = match iter_clone.lock() {
                            Ok(mut guard) => guard.next(),
                            Err(_) => break,
                        };
                        match next_key {
                            Some(key) => {
                                let data = closure(&mut init_data, &key, common_data);
                                transmitter
                                    .send((key, data))
                                    .expect("the receiver outlives every worker");
                            }
                            None => break,
                        }
                    }
                });
                handles.push(handle);
            }
            // Drop the original sender so the channel closes once every worker is done.
            drop(result_tx);
            let mut values = Vec::<Data>::with_capacity(capacity);
            for (key, data) in result_rx {
                insert_in_vec(&mut values, key.to_index(l), data, default_data);
            }
            let panics = handles
                .into_iter()
                .filter_map(|handle| handle.join().err())
                .collect::<Vec<_>>();
            if panics.is_empty() {
                Ok(values)
            } else {
                Err(ThreadAnyError::Panic(panics))
            }
        })
        .unwrap_or_else(|err| {
            // Every spawned handle is joined above, so the scope has no unjoined panics
            // left to report.
            if err
                .downcast_ref::<Vec<Box<dyn Any + 'static + Send>>>()
                .is_some()
            {
                unreachable!("a failing handle is not joined")
            }
            unreachable!("main thread panicked")
        }),
    }
}
/// Writes `data` at index `pos` of `vec`.
///
/// If `pos` is already in bounds, the existing element is overwritten. Otherwise the
/// vector is first padded with clones of `default_data` up to `pos`, and `data` is then
/// appended, so it ends up at index `pos`.
pub fn insert_in_vec<Data>(vec: &mut Vec<Data>, pos: usize, data: Data, default_data: &Data)
where
    Data: Clone,
{
    match vec.get_mut(pos) {
        Some(slot) => *slot = data,
        None => {
            // One clone per missing slot between the current end and `pos`.
            vec.resize_with(pos, || default_data.clone());
            vec.push(data);
        }
    }
}
/// Evaluates `closure` on every key of `iter` in parallel using rayon's global pool.
///
/// The keys are first collected into a `Vec`, so the returned results are in the same
/// order as the keys produced by `iter`.
pub fn run_pool_parallel_rayon<Key, Data, CommonData, F>(
    iter: impl Iterator<Item = Key> + Send,
    common_data: &CommonData,
    closure: F,
) -> Vec<Data>
where
    CommonData: Sync,
    Key: Eq + Send,
    Data: Send,
    F: Sync + Fn(&Key, &CommonData) -> Data,
{
    let keys: Vec<Key> = iter.collect();
    keys.into_par_iter()
        .map(|key| closure(&key, common_data))
        .collect()
}
#[cfg(test)]
mod test {
    use std::error::Error;

    use super::*;
    use crate::error::ImplementationError;

    #[test]
    fn thread_error() {
        assert_eq!(
            format!("{}", ThreadAnyError::ThreadNumberIncorrect),
            "number of thread is incorrect"
        );
        assert!(
            format!("{}", ThreadAnyError::Panic(vec![Box::new(())])).contains("a thread panicked")
        );
        assert!(
            format!("{}", ThreadAnyError::Panic(vec![Box::new("message 1")])).contains("message 1")
        );
        assert!(format!("{}", ThreadAnyError::Panic(vec![])).contains("0 thread panicked"));
        assert!(ThreadAnyError::ThreadNumberIncorrect.source().is_none());
        assert!(ThreadAnyError::Panic(vec![Box::new(())]).source().is_none());
        assert!(
            ThreadAnyError::Panic(vec![Box::new(ImplementationError::Unreachable)])
                .source()
                .is_none()
        );
        assert!(ThreadAnyError::Panic(vec![Box::new("test")])
            .source()
            .is_none());
        assert_eq!(
            format!("{}", ThreadError::ThreadNumberIncorrect),
            "number of thread is incorrect"
        );
        assert!(format!("{}", ThreadError::Panic(vec![None])).contains("a thread panicked"));
        assert!(format!("{}", ThreadError::Panic(vec![None, None])).contains("2 threads panicked"));
        assert!(format!(
            "{}",
            ThreadError::Panic(vec![Some("message 1".to_string())])
        )
        .contains("message 1"));
        assert!(format!("{}", ThreadError::Panic(vec![])).contains("0 thread panicked"));
        assert!(ThreadError::ThreadNumberIncorrect.source().is_none());
        assert!(ThreadError::Panic(vec![None]).source().is_none());
        assert!(ThreadError::Panic(vec![Some("".to_string())])
            .source()
            .is_none());
        assert!(ThreadError::Panic(vec![Some("test".to_string())])
            .source()
            .is_none());
        let error = ThreadAnyError::Panic(vec![
            Box::new(()),
            Box::new("t1"),
            Box::new("t2".to_string()),
        ]);
        let error2 = ThreadAnyError::Panic(vec![
            Box::new(""),
            Box::new("t1".to_string()),
            Box::new("t2".to_string()),
        ]);
        let error3 = ThreadError::Panic(vec![None, Some("t1".to_string()), Some("t2".to_string())]);
        assert_eq!(ThreadError::from(error), error3);
        assert_eq!(ThreadAnyError::from(error3).to_string(), error2.to_string());
        let error = ThreadAnyError::ThreadNumberIncorrect;
        let error2 = ThreadError::ThreadNumberIncorrect;
        assert_eq!(ThreadError::from(error), error2);
        let error = ThreadAnyError::ThreadNumberIncorrect;
        assert_eq!(ThreadAnyError::from(error2).to_string(), error.to_string());
    }

    /// Zero threads is rejected before any work is done.
    #[test]
    fn pool_rejects_zero_threads() {
        let result = run_pool_parallel(0..4_u32, &(), &|key: &u32, _: &()| *key, 0, 4);
        assert!(matches!(result, Err(ThreadAnyError::ThreadNumberIncorrect)));
    }

    /// Every key is evaluated exactly once, whatever the number of threads.
    #[test]
    fn pool_computes_every_key() {
        for number_of_thread in 1..=4 {
            let result = run_pool_parallel(
                0..20_u32,
                &3_u32,
                &|key: &u32, factor: &u32| key * factor,
                number_of_thread,
                20,
            )
            .unwrap();
            assert_eq!(result.len(), 20);
            for key in 0..20_u32 {
                assert_eq!(result[&key], key * 3);
            }
        }
    }

    /// A panicking closure is reported as `ThreadAnyError::Panic` with one payload.
    #[test]
    fn pool_reports_panics() {
        for number_of_thread in 1..=3 {
            let result = run_pool_parallel(
                0..6_u32,
                &(),
                &|key: &u32, _: &()| {
                    if *key == 2 {
                        panic!("boom");
                    }
                    *key
                },
                number_of_thread,
                6,
            );
            match result {
                Err(ThreadAnyError::Panic(payloads)) => assert_eq!(payloads.len(), 1),
                _ => panic!("expected ThreadAnyError::Panic"),
            }
        }
    }

    /// Per-thread initialization data is available to the closure.
    #[test]
    fn pool_with_initialization() {
        let result = run_pool_parallel_with_initializations_mutable(
            0..10_u32,
            &1_u32,
            &|counter: &mut u32, key: &u32, offset: &u32| {
                *counter += 1;
                key + offset
            },
            || 0_u32,
            3,
            10,
        )
        .unwrap();
        assert_eq!(result.len(), 10);
        for key in 0..10_u32 {
            assert_eq!(result[&key], key + 1);
        }
    }

    /// The rayon helper returns results in the order of the input keys.
    #[test]
    fn rayon_pool_keeps_order() {
        let result =
            run_pool_parallel_rayon(0..5_u32, &10_u32, |key: &u32, offset: &u32| key + offset);
        assert_eq!(result, vec![10, 11, 12, 13, 14]);
    }

    /// Out-of-bounds positions are padded with the default value.
    #[test]
    fn insert_in_vec_pads_with_default() {
        let mut values = vec![1, 2];
        insert_in_vec(&mut values, 0, 9, &0);
        assert_eq!(values, vec![9, 2]);
        insert_in_vec(&mut values, 4, 7, &-1);
        assert_eq!(values, vec![9, 2, -1, -1, 7]);
    }
}