use crate::{
critical::{CriticalPoint, CriticalPointKey},
hash::SliceMap,
progress::ProgressBar,
};
use std::thread;
pub fn parallel_map_reduce<State, Item, Init, Map, Reduce>(
items: &[Item],
init_state: Init,
map_op: Map,
reduce_op: Reduce,
threads: usize,
progress_bar: Box<dyn ProgressBar>,
) -> State
where
State: Send,
Init: Fn() -> State + Sync + Send,
Map: Fn(&mut State, &Item) + Sync + Send + Copy,
Reduce: Fn(&mut State, State) + Sync + Send + Copy,
Item: Sync + Clone,
{
let total_length = items.len();
let chunk_size = (total_length / threads) + (total_length % threads).min(1);
thread::scope(|s| {
let mut handles = Vec::with_capacity(threads);
items.chunks(chunk_size).for_each(|chunk| {
let mut local_state = init_state();
handles.push(s.spawn(|| {
chunk.iter().for_each(|index| {
map_op(&mut local_state, index);
progress_bar.tick();
});
local_state
}));
});
let mut handle_iter = handles.into_iter();
let mut global_state = match handle_iter.next() {
Some(h) => h.join().unwrap(),
None => panic!(""),
};
for handle in handle_iter {
reduce_op(&mut global_state, handle.join().unwrap());
}
global_state
})
}
/// Filters critical points with `validator` and keeps one point per key.
///
/// Every point accepted by `validator` is keyed with
/// `CriticalPointKey::from_cp`. When several accepted points share a key, the
/// one whose `position` indexes the largest value in `density` is kept; on a
/// tie the point encountered first wins because the comparison is strict.
/// The work is distributed with `parallel_map_reduce`, and each surviving
/// point is rebuilt via `CriticalPoint::new` from its stored position, its
/// kind and `key.into_box()`.
///
/// An empty `critical_points` slice returns an empty vector, and a `threads`
/// value of zero is treated as one.
///
/// # Panics
///
/// Panics if `position as usize` of a compared point is out of bounds for
/// `density`, or if a worker thread panics.
pub fn parallel_prune<F>(
    critical_points: &[CriticalPoint],
    density: &[f64],
    validator: F,
    threads: usize,
    progress_bar: Box<dyn ProgressBar>,
) -> Vec<CriticalPoint>
where
    F: Fn(&CriticalPoint) -> bool + Sync + Send + Copy,
{
    // Nothing to prune; returning here also avoids asking the map-reduce
    // helper to chunk an empty slice.
    if critical_points.is_empty() {
        return Vec::new();
    }
    // A zero thread count would otherwise reach a division in the helper.
    let threads = threads.max(1);

    // True when `candidate` sits at a strictly higher density than `current`.
    let is_denser = |candidate: &CriticalPoint, current: &CriticalPoint| {
        density[candidate.position as usize] > density[current.position as usize]
    };

    let final_map = parallel_map_reduce(
        critical_points,
        SliceMap::default,
        // Map step: fold one candidate into the chunk-local table.
        |local_state, cp| {
            if !validator(cp) {
                return;
            }
            local_state
                .entry(CriticalPointKey::from_cp(cp.clone()))
                .and_modify(|existing: &mut CriticalPoint| {
                    if is_denser(cp, &*existing) {
                        *existing = cp.clone();
                    }
                })
                // Only clone for insertion when the key was actually vacant.
                .or_insert_with(|| cp.clone());
        },
        // Reduce step: merge a chunk-local table into the accumulated one.
        |global_state, local_state| {
            for (key, cp) in local_state {
                global_state
                    .entry(key)
                    .and_modify(|existing: &mut CriticalPoint| {
                        if is_denser(&cp, &*existing) {
                            *existing = cp.clone();
                        }
                    })
                    // The point is owned here, so a vacant slot takes it by move.
                    .or_insert(cp);
            }
        },
        threads,
        progress_bar,
    );

    final_map
        .into_iter()
        .map(|(key, cp)| CriticalPoint::new(cp.position, cp.kind, key.into_box()))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::critical::{CriticalPoint, CriticalPointKind};
    use crate::progress::HiddenBar;
    use crate::voxel_map::EncodedAtom;

    /// Builds a `Bond` critical point at `pos` from the given atom ids, each
    /// encoded with `EncodedAtom::new_zero_image`.
    fn create_cp(pos: isize, atoms: &[u32]) -> CriticalPoint {
        let encoded: Vec<_> = atoms
            .iter()
            .map(|&id| EncodedAtom::new_zero_image(id))
            .collect();
        CriticalPoint::new(pos, CriticalPointKind::Bond, encoded.into_boxed_slice())
    }

    /// Summing 1..=100 over four threads must give the closed-form total.
    #[test]
    fn test_parallel_map_reduce_sum() {
        let numbers: Vec<i32> = (1..=100).collect();
        let total = parallel_map_reduce(
            &numbers,
            || 0,
            |acc, item| *acc += item,
            |global, local| *global += local,
            4,
            Box::new(HiddenBar {}),
        );
        assert_eq!(total, 5050);
    }

    /// Requesting more threads than there are items must still visit every item.
    #[test]
    fn test_parallel_map_reduce_chunking() {
        let numbers = vec![1, 2, 3];
        let total = parallel_map_reduce(
            &numbers,
            || 0,
            |acc, item| *acc += item,
            |global, local| *global += local,
            8,
            Box::new(HiddenBar {}),
        );
        assert_eq!(total, 6);
    }

    /// Two points built from the same atoms collapse to the one at the
    /// higher-density position.
    #[test]
    fn test_parallel_prune_logic() {
        let candidates = vec![create_cp(10, &[1, 2]), create_cp(20, &[1, 2])];

        let mut density = vec![0.0; 30];
        density[10] = 1.0;
        density[20] = 5.0;

        let pruned = parallel_prune(
            &candidates,
            &density,
            |_| true,
            2,
            Box::new(HiddenBar {}),
        );

        assert_eq!(pruned.len(), 1);
        assert_eq!(pruned[0].position, 20);
    }

    /// Points rejected by the validator never reach the output.
    #[test]
    fn test_parallel_prune_filtering() {
        let bond_cp = create_cp(10, &[1, 2]);
        let mut ring_cp = create_cp(20, &[3, 4]);
        ring_cp.kind = CriticalPointKind::Ring;

        let candidates = vec![bond_cp, ring_cp];
        let density = vec![0.0; 30];

        let pruned = parallel_prune(
            &candidates,
            &density,
            |cp| matches!(cp.kind, CriticalPointKind::Bond),
            1,
            Box::new(HiddenBar {}),
        );

        assert_eq!(pruned.len(), 1);
        assert_eq!(pruned[0].kind, CriticalPointKind::Bond);
    }
}