use rayon::scope;
/// Communicator handle for the rayon-thread-pool backend (no message passing).
///
/// This is a zero-sized marker type: its `Comm` implementation always reports
/// rank 0. Collective operations either copy local data or return their
/// input unchanged, and point-to-point requests are no-op `()` values.
#[derive(Clone)]
pub struct RayonComm;
impl Default for RayonComm {
fn default() -> Self {
Self::new()
}
}
impl RayonComm {
pub fn new() -> Self {
#[cfg(feature = "rayon")]
{
crate::parallel::threads::init_global_rayon_pool(1);
}
RayonComm
}
pub fn congruent(&self, other: &RayonComm) -> bool {
super::Comm::size(self) == super::Comm::size(other)
}
}
impl super::Comm for RayonComm {
    /// Local vector storage for this backend.
    type Vec = Vec<f64>;
    /// Point-to-point operations complete immediately, so a request carries no state.
    type Request<'a> = ();

    /// This backend only ever acts as rank 0.
    fn rank(&self) -> usize {
        0
    }

    /// Reports the thread count of the current rayon pool.
    ///
    /// NOTE(review): `rank()` is always 0 even when this is greater than 1;
    /// confirm callers don't expect one rank per reported unit of `size()`.
    fn size(&self) -> usize {
        crate::parallel::threads::current_rayon_threads()
    }

    /// Opens and immediately closes an empty rayon scope; nothing is spawned,
    /// so there is no other participant to wait for.
    fn barrier(&self) {
        scope(|_| {});
    }

    /// Copies this rank's chunk of `global` into `out`.
    ///
    /// The chunk length is `out.len()`, and a rank's chunk starts at
    /// `rank() * out.len()`. `root` only identifies which rank owns `global`;
    /// it does not select the chunk. Because this communicator is always
    /// rank 0, the first `out.len()` elements of `global` are copied.
    ///
    /// # Panics
    /// Panics if `global` is too short to contain this rank's chunk.
    fn scatter<T: Clone>(&self, global: &[T], out: &mut [T], _root: usize) {
        let n = out.len();
        // Offset by our own rank, not the root's: every rank receives its own chunk.
        let start = super::Comm::rank(self) * n;
        let end = start + n;
        let chunk = global.get(start..end).unwrap_or_else(|| {
            panic!(
                "scatter: global has {} elements but rank chunk needs range {}..{}",
                global.len(),
                start,
                end
            )
        });
        out.clone_from_slice(chunk);
    }

    /// With a single rank, gathering just replaces `out` with a copy of `local`.
    fn gather<T: Clone>(&self, local: &[T], out: &mut Vec<T>, _root: usize) {
        out.clear();
        out.extend_from_slice(local);
    }

    /// Single-rank reduction: the result is the local contribution itself.
    fn all_reduce(&self, x: f64) -> f64 {
        x
    }

    /// Single-rank reduction: the result is the local contribution itself.
    fn all_reduce_f64(&self, local: f64) -> f64 {
        local
    }

    /// Returns a fresh rayon communicator; `color` and `key` are ignored
    /// because there is only one rank to place.
    fn split(&self, _color: i32, _key: i32) -> super::UniverseComm {
        super::UniverseComm::Rayon(RayonComm::new())
    }

    /// No-op: `_buf` is left untouched and an empty request is returned.
    fn irecv_from<'a>(&'a self, _buf: &'a mut [f64], _src: i32) -> Self::Request<'a> {}

    /// No-op: nothing is sent and an empty request is returned.
    fn isend_to<'a>(&'a self, _buf: &'a [f64], _dest: i32) -> Self::Request<'a> {}

    /// No-op: `_buf` is left untouched and an empty request is returned.
    fn irecv_from_u64<'a>(&'a self, _buf: &'a mut [u64], _src: i32) -> Self::Request<'a> {}

    /// No-op: nothing is sent and an empty request is returned.
    fn isend_to_u64<'a>(&'a self, _buf: &'a [u64], _dest: i32) -> Self::Request<'a> {}

    /// No-op: every request is already complete.
    fn wait_all<'a>(&self, _reqs: &mut [Self::Request<'a>]) {}
}