use std::{
ffi::{CStr, CString},
mem::MaybeUninit,
os::raw::{c_char, c_int, c_void},
process,
};
use conv::ConvUtil;
#[cfg(not(msmpi))]
use crate::Tag;
use crate::{
attribute::CommAttribute,
datatype::traits::*,
ffi,
ffi::{MPI_Comm, MPI_Group},
raw::traits::*,
with_uninitialized, Count, IntArray,
};
mod cartesian;
pub mod traits {
pub use super::{AnyCommunicator, AsCommunicator, Communicator, Group};
}
pub use self::cartesian::*;
/// Types that can be viewed as a [`Communicator`].
pub trait AsCommunicator {
    /// The communicator type returned by [`AsCommunicator::as_communicator`].
    type Out: Communicator;
    /// Borrows `self` as its underlying communicator.
    fn as_communicator(&self) -> &Self::Out;
}
/// A process rank inside a communicator or group (an alias for `c_int`).
pub type Rank = c_int;
// Crate-private module providing `CommunicatorHandle` and the `AsHandle` trait used throughout
// this file.
pub(crate) mod sealed;
/// An intra-communicator wrapper around a `sealed::CommunicatorHandle` — for example the world
/// or self communicator, or one produced by splitting/duplicating another communicator.
pub struct SimpleCommunicator(pub(crate) sealed::CommunicatorHandle);
impl SimpleCommunicator {
pub fn world() -> SimpleCommunicator {
SimpleCommunicator(sealed::CommunicatorHandle::World)
}
pub fn self_comm() -> SimpleCommunicator {
SimpleCommunicator(sealed::CommunicatorHandle::SelfComm)
}
unsafe fn try_from_raw(raw: MPI_Comm) -> Option<SimpleCommunicator> {
let handle = sealed::CommunicatorHandle::try_from_raw(raw)?;
if let sealed::CommunicatorHandle::User(_) = handle {
Some(SimpleCommunicator(handle))
} else {
None
}
}
pub fn topology(&self) -> Topology {
unsafe {
let (_, topology) =
with_uninitialized(|topology| ffi::MPI_Topo_test(self.as_raw(), topology));
if topology == ffi::RSMPI_GRAPH {
Topology::Graph
} else if topology == ffi::RSMPI_CART {
Topology::Cartesian
} else if topology == ffi::RSMPI_DIST_GRAPH {
Topology::DistributedGraph
} else if topology == ffi::RSMPI_UNDEFINED {
Topology::Undefined
} else {
panic!("Unexpected Topology type!")
}
}
}
pub fn into_topology(self) -> IntoTopology {
match self.topology() {
Topology::Graph => unimplemented!(),
Topology::Cartesian => IntoTopology::Cartesian(CartesianCommunicator(self)),
Topology::DistributedGraph => unimplemented!(),
Topology::Undefined => IntoTopology::Undefined(self),
}
}
}
unsafe impl AsRaw for SimpleCommunicator {
type Raw = MPI_Comm;
fn as_raw(&self) -> Self::Raw {
self.0.as_raw()
}
}
impl sealed::AsHandle for SimpleCommunicator {
fn as_handle(&self) -> &sealed::CommunicatorHandle {
&self.0
}
}
impl FromRaw for SimpleCommunicator {
unsafe fn from_raw(handle: <Self as AsRaw>::Raw) -> Self {
let handle = sealed::CommunicatorHandle::simple_comm_from_raw(handle);
SimpleCommunicator(handle)
}
}
impl Communicator for SimpleCommunicator {
fn target_size(&self) -> Rank {
self.size()
}
}
impl AsCommunicator for SimpleCommunicator {
type Out = SimpleCommunicator;
fn as_communicator(&self) -> &Self::Out {
self
}
}
/// The topology attached to a communicator, as reported by `MPI_Topo_test`
/// (see `SimpleCommunicator::topology`).
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum Topology {
    /// Corresponds to `RSMPI_GRAPH`.
    Graph,
    /// Corresponds to `RSMPI_CART`.
    Cartesian,
    /// Corresponds to `RSMPI_DIST_GRAPH`.
    DistributedGraph,
    /// Corresponds to `RSMPI_UNDEFINED` (no topology attached).
    Undefined,
}
/// Result of `SimpleCommunicator::into_topology`: the communicator wrapped in the type that
/// matches its topology.
pub enum IntoTopology {
    /// Graph topology. `into_topology` currently reaches `unimplemented!()` for this case.
    Graph(GraphCommunicator),
    /// Cartesian topology.
    Cartesian(CartesianCommunicator),
    /// Distributed-graph topology. `into_topology` currently reaches `unimplemented!()` for
    /// this case.
    DistributedGraph(DistributedGraphCommunicator),
    /// No topology; the original communicator is handed back unchanged.
    Undefined(SimpleCommunicator),
}
/// A communicator with separate local and remote groups.
///
/// Ranks passed to [`Communicator::process_at_rank`] are bounded by
/// [`InterCommunicator::remote_size`] (its `target_size`).
pub struct InterCommunicator(pub(crate) sealed::CommunicatorHandle);
impl InterCommunicator {
    /// Wraps `raw` if `CommunicatorHandle::try_from_raw` classifies it as an inter-communicator,
    /// otherwise returns `None`.
    ///
    /// # Safety
    /// `raw` must be a communicator handle acceptable to `CommunicatorHandle::try_from_raw`.
    pub unsafe fn try_from_raw(raw: MPI_Comm) -> Option<Self> {
        sealed::CommunicatorHandle::try_from_raw(raw).and_then(|handle| match handle {
            sealed::CommunicatorHandle::InterComm(_) => Some(InterCommunicator(handle)),
            _ => None,
        })
    }

    /// Number of processes in the remote group (`MPI_Comm_remote_size`).
    pub fn remote_size(&self) -> Rank {
        let mut size = Rank::min_value();
        unsafe {
            ffi::MPI_Comm_remote_size(self.as_raw(), &mut size);
        }
        size
    }

    /// The remote group of this inter-communicator (`MPI_Comm_remote_group`).
    pub fn remote_group(&self) -> UserGroup {
        unsafe {
            let (_, g) = with_uninitialized(|g| ffi::MPI_Comm_remote_group(self.as_raw(), g));
            UserGroup(g)
        }
    }

    /// Merges both groups into one intra-communicator with `MPI_Intercomm_merge`.
    ///
    /// `merge_order` is forwarded as the ordering flag (`Low` → 0, `High` → 1).
    ///
    /// # Panics
    /// If MPI does not return a user intra-communicator (e.g. it returns `MPI_COMM_NULL`).
    pub fn merge(&self, merge_order: MergeOrder) -> SimpleCommunicator {
        let (_, raw) = unsafe {
            with_uninitialized(|raw| {
                ffi::MPI_Intercomm_merge(self.as_raw(), merge_order.as_raw(), raw)
            })
        };
        unsafe { SimpleCommunicator::try_from_raw(raw) }.expect(
            "rsmpi internal error: MPI implementation returned MPI_COMM_NULL from MPI_Intercomm_merge()",
        )
    }
}
impl AsCommunicator for InterCommunicator {
type Out = InterCommunicator;
fn as_communicator(&self) -> &Self::Out {
self
}
}
unsafe impl AsRaw for InterCommunicator {
type Raw = MPI_Comm;
fn as_raw(&self) -> Self::Raw {
self.0.as_raw()
}
}
impl sealed::AsHandle for InterCommunicator {
fn as_handle(&self) -> &sealed::CommunicatorHandle {
&self.0
}
}
impl FromRaw for InterCommunicator {
unsafe fn from_raw(handle: <Self as AsRaw>::Raw) -> Self {
Self(sealed::CommunicatorHandle::inter_comm_from_raw(handle))
}
}
impl Communicator for InterCommunicator {
fn target_size(&self) -> Rank {
self.remote_size()
}
}
// Unit placeholders: `SimpleCommunicator::into_topology` does not construct these yet (it hits
// `unimplemented!()` for both graph topologies).
// NOTE(review): the allow presumably silences a crate-level `missing_copy_implementations`
// lint so these types do not have to promise `Copy` before they hold a real handle — confirm.
#[allow(missing_copy_implementations)]
/// Placeholder for a communicator with a graph topology.
pub struct GraphCommunicator;
#[allow(missing_copy_implementations)]
/// Placeholder for a communicator with a distributed-graph topology.
pub struct DistributedGraphCommunicator;
/// The colour argument of [`Communicator::split_by_color`] / `MPI_Comm_split`.
#[derive(Copy, Clone, Debug)]
pub struct Color(c_int);
impl Color {
    /// The special colour `RSMPI_UNDEFINED`. The split functions return `None` for processes
    /// that pass it (presumably because MPI hands them `MPI_COMM_NULL`).
    pub fn undefined() -> Color {
        let raw = unsafe { ffi::RSMPI_UNDEFINED };
        Color(raw)
    }

    /// A concrete colour value.
    ///
    /// # Panics
    /// If `value` is negative.
    pub fn with_value(value: c_int) -> Color {
        assert!(value >= 0, "Value of color must be non-negative.");
        Color(value)
    }

    /// The raw integer passed to MPI.
    fn as_raw(self) -> c_int {
        let Color(raw) = self;
        raw
    }
}
/// The ordering key used by [`Communicator::split_by_color_with_key`].
pub type Key = c_int;
/// Operations shared by every communicator-like type in this module.
pub trait Communicator: sealed::AsHandle {
    /// Exclusive upper bound on the ranks accepted by [`Communicator::process_at_rank`].
    ///
    /// In this module, intra-communicator types return [`Communicator::size`], while
    /// [`InterCommunicator`] returns its remote group size.
    fn target_size(&self) -> Rank;

    /// Number of processes in this communicator's (local) group, from `MPI_Comm_size`.
    fn size(&self) -> Rank {
        unsafe { with_uninitialized(|size| ffi::MPI_Comm_size(self.as_raw(), size)).1 }
    }

    /// Rank of the calling process in this communicator, from `MPI_Comm_rank`.
    fn rank(&self) -> Rank {
        unsafe { with_uninitialized(|rank| ffi::MPI_Comm_rank(self.as_raw(), rank)).1 }
    }

    /// The process with rank `r` in this communicator.
    ///
    /// # Panics
    /// If `r` is negative or `r >= self.target_size()`.
    fn process_at_rank(&self, r: Rank) -> Process<'_> {
        assert!(0 <= r && r < self.target_size());
        Process::by_rank_unchecked(self, r)
    }

    /// A handle on this communicator that is not tied to any particular rank.
    fn any_process(&self) -> AnyProcess<'_> {
        AnyProcess(self.as_handle())
    }

    /// The calling process, identified by [`Communicator::rank`].
    fn this_process(&self) -> Process<'_> {
        let rank = self.rank();
        Process::by_rank_unchecked(self, rank)
    }

    /// Compares this communicator with `other` using `MPI_Comm_compare`.
    ///
    /// # Panics
    /// If MPI reports a result that `CommunicatorRelation::from` does not recognise.
    fn compare(&self, other: &dyn Communicator) -> CommunicatorRelation {
        unsafe {
            with_uninitialized(|cmp| ffi::MPI_Comm_compare(self.as_raw(), other.as_raw(), cmp))
                .1
                .into()
        }
    }

    /// Duplicates this communicator with `MPI_Comm_dup`.
    fn duplicate(&self) -> SimpleCommunicator {
        unsafe {
            SimpleCommunicator::from_raw(
                with_uninitialized(|newcomm| ffi::MPI_Comm_dup(self.as_raw(), newcomm)).1,
            )
        }
    }

    /// Splits by `color` with `Key::default()` as the key; see
    /// [`Communicator::split_by_color_with_key`].
    fn split_by_color(&self, color: Color) -> Option<SimpleCommunicator> {
        self.split_by_color_with_key(color, Key::default())
    }

    /// Splits this communicator with `MPI_Comm_split` using `color` and `key`.
    ///
    /// Returns `None` when MPI does not hand back a user communicator (presumably
    /// `MPI_COMM_NULL`, e.g. for processes passing [`Color::undefined`]).
    fn split_by_color_with_key(&self, color: Color, key: Key) -> Option<SimpleCommunicator> {
        unsafe {
            SimpleCommunicator::try_from_raw(
                with_uninitialized(|newcomm| {
                    ffi::MPI_Comm_split(self.as_raw(), color.as_raw(), key, newcomm)
                })
                .1,
            )
        }
    }

    /// Splits with `MPI_Comm_split_type(..., MPI_COMM_TYPE_SHARED, key, MPI_INFO_NULL, ...)`.
    ///
    /// # Panics
    /// If MPI returns something other than a user communicator.
    fn split_shared(&self, key: c_int) -> SimpleCommunicator {
        unsafe {
            SimpleCommunicator::try_from_raw(
                with_uninitialized(|newcomm| {
                    ffi::MPI_Comm_split_type(
                        self.as_raw(),
                        ffi::RSMPI_COMM_TYPE_SHARED,
                        key,
                        ffi::RSMPI_INFO_NULL,
                        newcomm,
                    )
                })
                .1,
            ).expect("rsmpi internal error: MPI implementation incorrectly returned MPI_COMM_NULL from MPI_Comm_split_type(..., MPI_COMM_TYPE_SHARED, ...)")
        }
    }

    /// Creates a communicator for `group` with `MPI_Comm_create`; `None` when MPI does not
    /// return a user communicator.
    fn split_by_subgroup_collective(&self, group: &dyn Group) -> Option<SimpleCommunicator> {
        unsafe {
            SimpleCommunicator::try_from_raw(
                with_uninitialized(|newcomm| {
                    ffi::MPI_Comm_create(self.as_raw(), group.as_raw(), newcomm)
                })
                .1,
            )
        }
    }

    /// Like [`Communicator::split_by_subgroup_with_tag`] with `Tag::default()`.
    #[cfg(not(msmpi))]
    fn split_by_subgroup(&self, group: &dyn Group) -> Option<SimpleCommunicator> {
        self.split_by_subgroup_with_tag(group, Tag::default())
    }

    /// Creates a communicator for `group` with `MPI_Comm_create_group` using `tag`; `None` when
    /// MPI does not return a user communicator. Not available on MS-MPI.
    #[cfg(not(msmpi))]
    fn split_by_subgroup_with_tag(
        &self,
        group: &dyn Group,
        tag: Tag,
    ) -> Option<SimpleCommunicator> {
        unsafe {
            SimpleCommunicator::try_from_raw(
                with_uninitialized(|newcomm| {
                    ffi::MPI_Comm_create_group(self.as_raw(), group.as_raw(), tag, newcomm)
                })
                .1,
            )
        }
    }

    /// The (local) group of this communicator, from `MPI_Comm_group`.
    fn group(&self) -> UserGroup {
        unsafe {
            UserGroup(with_uninitialized(|group| ffi::MPI_Comm_group(self.as_raw(), group)).1)
        }
    }

    /// Calls `MPI_Abort` with `errorcode`; if that ever returns, falls back to
    /// `std::process::abort()`.
    fn abort(&self, errorcode: c_int) -> ! {
        unsafe {
            ffi::MPI_Abort(self.as_raw(), errorcode);
        }
        process::abort();
    }

    /// Whether this is an inter-communicator (`MPI_Comm_test_inter`).
    fn test_inter(&self) -> bool {
        unsafe { comm_is_inter(self.as_raw()) }
    }

    /// Sets the communicator's name with `MPI_Comm_set_name`.
    ///
    /// # Panics
    /// If `name` contains an interior NUL byte.
    fn set_name(&self, name: &str) {
        let c_name = CString::new(name).expect("Failed to convert the Rust string to a C string");
        unsafe {
            ffi::MPI_Comm_set_name(self.as_raw(), c_name.as_ptr());
        }
    }

    /// The communicator's name from `MPI_Comm_get_name`, decoded lossily as UTF-8.
    fn get_name(&self) -> String {
        type BufType = [c_char; ffi::MPI_MAX_OBJECT_NAME as usize];
        // Owned, zero-filled buffer: MPI may write only part of it, so it must be initialised
        // up front, and it must outlive the `CStr` borrowed from it below.
        let mut buf: BufType = [0; ffi::MPI_MAX_OBJECT_NAME as usize];
        unsafe {
            let (_, _resultlen) = with_uninitialized(|resultlen| {
                ffi::MPI_Comm_get_name(self.as_raw(), buf.as_mut_ptr(), resultlen)
            });
        }
        // Force a terminator so `CStr::from_ptr` can never scan past the end of `buf`.
        if let Some(last) = buf.last_mut() {
            *last = 0;
        }
        // SAFETY: `buf` is fully initialised, NUL-terminated and alive for the whole borrow.
        let name = unsafe { CStr::from_ptr(buf.as_ptr()) };
        name.to_string_lossy().into_owned()
    }

    /// Creates a Cartesian communicator with `MPI_Cart_create`.
    ///
    /// # Panics
    /// If `dims` and `periods` differ in length.
    fn create_cartesian_communicator(
        &self,
        dims: &[Count],
        periods: &[bool],
        reorder: bool,
    ) -> Option<CartesianCommunicator> {
        assert_eq!(
            dims.len(),
            periods.len(),
            "dims and periods must be parallel, equal-sized arrays"
        );
        let periods: IntArray = periods.iter().map(|x| *x as i32).collect();
        unsafe {
            let mut comm_cart = ffi::RSMPI_COMM_NULL;
            ffi::MPI_Cart_create(
                self.as_raw(),
                dims.count(),
                dims.as_ptr(),
                periods.as_ptr(),
                reorder as Count,
                &mut comm_cart,
            );
            CartesianCommunicator::try_from_raw(comm_cart)
        }
    }

    /// Suggested rank of the calling process in a Cartesian layout (`MPI_Cart_map`), or `None`
    /// when MPI reports `RSMPI_UNDEFINED`.
    ///
    /// # Panics
    /// If `dims` and `periods` differ in length.
    fn cartesian_map(&self, dims: &[Count], periods: &[bool]) -> Option<Rank> {
        assert_eq!(
            dims.len(),
            periods.len(),
            "dims and periods must be parallel, equal-sized arrays"
        );
        let periods: IntArray = periods.iter().map(|x| *x as i32).collect();
        unsafe {
            // Use the same undefined sentinel as the rest of this module.
            let undefined = ffi::RSMPI_UNDEFINED;
            let mut new_rank = undefined;
            ffi::MPI_Cart_map(
                self.as_raw(),
                dims.count(),
                dims.as_ptr(),
                periods.as_ptr(),
                &mut new_rank,
            );
            if new_rank == undefined {
                None
            } else {
                Some(new_rank)
            }
        }
    }

    /// Upper bound, in bytes, on the space needed to pack `incount` items of `datatype`
    /// (`MPI_Pack_size`).
    fn pack_size<Dt>(&self, incount: Count, datatype: &Dt) -> Count
    where
        Dt: Datatype,
        Self: Sized,
    {
        unsafe {
            with_uninitialized(|size| {
                ffi::MPI_Pack_size(incount, datatype.as_raw(), self.as_raw(), size)
            })
            .1
        }
    }

    /// Packs `inbuf` into a freshly allocated byte vector, truncated to the packed length.
    ///
    /// # Panics
    /// If MPI reports a negative size or position.
    fn pack<Buf>(&self, inbuf: &Buf) -> Vec<u8>
    where
        Buf: ?Sized + Buffer,
        Self: Sized,
    {
        let inbuf_dt = inbuf.as_datatype();
        let mut outbuf = vec![
            0;
            self.pack_size(inbuf.count(), &inbuf_dt)
                .value_as::<usize>()
                .expect("MPI_Pack_size returned a negative buffer size!")
        ];
        let position = self.pack_into(inbuf, &mut outbuf[..], 0);
        outbuf.resize(
            position
                .value_as()
                .expect("MPI_Pack returned a negative position!"),
            0,
        );
        outbuf
    }

    /// Packs `inbuf` into `outbuf` starting at byte `position` (`MPI_Pack`); returns the
    /// position after the packed data.
    fn pack_into<Buf>(&self, inbuf: &Buf, outbuf: &mut [u8], position: Count) -> Count
    where
        Buf: ?Sized + Buffer,
        Self: Sized,
    {
        let inbuf_dt = inbuf.as_datatype();
        let mut position: Count = position;
        unsafe {
            ffi::MPI_Pack(
                inbuf.pointer(),
                inbuf.count(),
                inbuf_dt.as_raw(),
                outbuf.as_mut_ptr() as *mut _,
                outbuf.count(),
                &mut position,
                self.as_raw(),
            );
        }
        position
    }

    /// Unpacks bytes from `inbuf`, starting at `position`, into `outbuf` (`MPI_Unpack`);
    /// returns the position after the consumed data.
    ///
    /// # Safety
    /// MPI writes into `outbuf` through its raw pointer using `outbuf`'s datatype; presumably
    /// `inbuf` must hold data packed for a compatible datatype — confirm against callers.
    unsafe fn unpack_into<Buf>(&self, inbuf: &[u8], outbuf: &mut Buf, position: Count) -> Count
    where
        Buf: ?Sized + BufferMut,
        Self: Sized,
    {
        let outbuf_dt = outbuf.as_datatype();
        let mut position: Count = position;
        ffi::MPI_Unpack(
            inbuf.as_ptr() as *const _,
            inbuf.count(),
            &mut position,
            outbuf.pointer_mut(),
            outbuf.count(),
            outbuf_dt.as_raw(),
            self.as_raw(),
        );
        position
    }

    /// The parent inter-communicator from `MPI_Comm_get_parent`, or `None` when MPI returns
    /// `RSMPI_COMM_NULL`.
    fn parent(&self) -> Option<InterCommunicator> {
        unsafe {
            let mut comm = ffi::RSMPI_COMM_NULL;
            ffi::MPI_Comm_get_parent(&mut comm);
            if comm == ffi::RSMPI_COMM_NULL {
                return None;
            }
            Some(InterCommunicator::from_raw(comm))
        }
    }
}
/// Typed attribute caching on communicators.
pub trait AnyCommunicator: Communicator {
    /// The value of type `A` cached on this communicator under `A::get_key()`, or `None` if
    /// no attribute is attached for that key.
    fn get_attr<A: CommAttribute>(&self) -> Option<&A>;
    /// Caches `val` on this communicator under `A::get_key()`.
    fn set_attr<A: CommAttribute>(&mut self, val: A);
}
impl<C: ?Sized + Communicator> AnyCommunicator for C {
    fn get_attr<A: CommAttribute>(&self) -> Option<&A> {
        let key = A::get_key();
        // Start from a defined null pointer: when no attribute is attached, MPI reports
        // flag == 0 and leaves `attribute_val` untouched, so it must never be read as if MPI
        // had initialised it.
        let mut ptr: *mut A = std::ptr::null_mut();
        let ptr_slot: *mut *mut A = &mut ptr;
        let (_, flag) = unsafe {
            with_uninitialized(|flag| {
                ffi::MPI_Comm_get_attr(
                    self.as_raw(),
                    key.as_raw(),
                    ptr_slot as *mut c_void,
                    flag,
                )
            })
        };
        if flag == 0 {
            None
        } else {
            // SAFETY (assumed): the stored value is the pointer produced by `Box::into_raw` in
            // `set_attr`, so it is either null or points to a live `A`.
            unsafe { ptr.as_ref() }
        }
    }

    fn set_attr<A: CommAttribute>(&mut self, val: A) {
        let key = A::get_key();
        let val = Box::new(val);
        // Ownership of the box moves to MPI; NOTE(review): reclaiming it presumably relies on a
        // delete callback registered with the key — confirm in the attribute module.
        unsafe {
            ffi::MPI_Comm_set_attr(
                self.as_raw(),
                key.as_raw(),
                Box::into_raw(val) as *mut c_void,
            )
        };
    }
}
/// Result of comparing two communicators with `MPI_Comm_compare`.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum CommunicatorRelation {
    /// Corresponds to `RSMPI_IDENT`.
    Identical,
    /// Corresponds to `RSMPI_CONGRUENT`.
    Congruent,
    /// Corresponds to `RSMPI_SIMILAR`.
    Similar,
    /// Corresponds to `RSMPI_UNEQUAL`.
    Unequal,
}
impl From<c_int> for CommunicatorRelation {
    /// Maps an MPI comparison result to a relation.
    ///
    /// # Panics
    /// If `i` matches none of the known constants.
    fn from(i: c_int) -> CommunicatorRelation {
        // The constants are statics that need `unsafe` to read, so they cannot be match patterns.
        let known = unsafe {
            [
                (ffi::RSMPI_IDENT, CommunicatorRelation::Identical),
                (ffi::RSMPI_CONGRUENT, CommunicatorRelation::Congruent),
                (ffi::RSMPI_SIMILAR, CommunicatorRelation::Similar),
                (ffi::RSMPI_UNEQUAL, CommunicatorRelation::Unequal),
            ]
        };
        match known.iter().find(|&&(code, _)| code == i) {
            Some(&(_, relation)) => relation,
            None => panic!("Unknown communicator relation: {}", i),
        }
    }
}
/// Group ordering used by `InterCommunicator::merge`.
#[derive(Copy, Clone)]
pub enum MergeOrder {
    /// Encoded as 0.
    Low,
    /// Encoded as 1.
    High,
}
impl MergeOrder {
    /// The integer flag handed to `MPI_Intercomm_merge`: `Low` → 0, `High` → 1.
    fn as_raw(&self) -> c_int {
        if let MergeOrder::High = *self {
            1
        } else {
            0
        }
    }
}
#[derive(Copy, Clone)]
pub struct Process<'a> {
comm: AnyProcess<'a>,
rank: Rank,
}
impl<'a> Process<'a> {
#[allow(dead_code)]
fn by_rank<C: Communicator + ?Sized>(c: &'a C, r: Rank) -> Option<Self> {
if r != unsafe { ffi::RSMPI_PROC_NULL } {
Some(Process::by_rank_unchecked(c, r))
} else {
None
}
}
fn by_rank_unchecked<C: Communicator + ?Sized>(c: &'a C, r: Rank) -> Self {
Process {
comm: AnyProcess(c.as_handle()),
rank: r,
}
}
pub fn rank(&self) -> Rank {
self.rank
}
pub fn is_self(&self) -> bool {
self.as_communicator().rank() == self.rank
}
}
unsafe impl<'a> AsRaw for Process<'a> {
type Raw = MPI_Comm;
fn as_raw(&self) -> Self::Raw {
self.comm.as_raw()
}
}
impl<'a> sealed::AsHandle for Process<'a> {
fn as_handle(&self) -> &sealed::CommunicatorHandle {
self.comm.as_handle()
}
}
impl<'a> Communicator for Process<'a> {
fn target_size(&self) -> Rank {
self.size()
}
}
impl<'a> AsCommunicator for Process<'a> {
type Out = AnyProcess<'a>;
fn as_communicator(&self) -> &Self::Out {
&self.comm
}
}
#[derive(Copy, Clone)]
pub struct AnyProcess<'a>(&'a sealed::CommunicatorHandle);
unsafe impl<'a> AsRaw for AnyProcess<'a> {
type Raw = MPI_Comm;
fn as_raw(&self) -> Self::Raw {
self.0.as_raw()
}
}
impl<'a> sealed::AsHandle for AnyProcess<'a> {
fn as_handle(&self) -> &sealed::CommunicatorHandle {
self.0
}
}
impl<'a> Communicator for AnyProcess<'a> {
fn target_size(&self) -> Rank {
self.size()
}
}
impl<'a> AsCommunicator for AnyProcess<'a> {
type Out = Self;
fn as_communicator(&self) -> &Self::Out {
self
}
}
/// A predefined group that is not freed on drop (e.g. the empty group).
#[derive(Copy, Clone)]
pub struct SystemGroup(MPI_Group);
impl SystemGroup {
    /// The predefined empty group (`RSMPI_GROUP_EMPTY`).
    pub fn empty() -> SystemGroup {
        let raw = unsafe { ffi::RSMPI_GROUP_EMPTY };
        SystemGroup(raw)
    }
}
unsafe impl AsRaw for SystemGroup {
    type Raw = MPI_Group;
    /// The raw `MPI_Group` handle.
    fn as_raw(&self) -> Self::Raw {
        let SystemGroup(raw) = *self;
        raw
    }
}
impl Group for SystemGroup {}
/// An owned group handle, released with `MPI_Group_free` when dropped.
pub struct UserGroup(MPI_Group);
impl Drop for UserGroup {
    /// Frees the group.
    ///
    /// # Panics
    /// If `MPI_Group_free` does not reset the handle to `RSMPI_GROUP_NULL`.
    fn drop(&mut self) {
        let UserGroup(ref mut raw) = *self;
        unsafe {
            ffi::MPI_Group_free(raw);
        }
        let group_null = unsafe { ffi::RSMPI_GROUP_NULL };
        assert_eq!(self.0, group_null);
    }
}
unsafe impl AsRaw for UserGroup {
    type Raw = MPI_Group;
    /// The raw `MPI_Group` handle.
    fn as_raw(&self) -> Self::Raw {
        let UserGroup(raw) = self;
        *raw
    }
}
impl Group for UserGroup {}
pub trait Group: AsRaw<Raw = MPI_Group> {
fn union<G>(&self, other: &G) -> UserGroup
where
G: Group,
Self: Sized,
{
unsafe {
UserGroup(
with_uninitialized(|newgroup| {
ffi::MPI_Group_union(self.as_raw(), other.as_raw(), newgroup)
})
.1,
)
}
}
fn intersection<G>(&self, other: &G) -> UserGroup
where
G: Group,
Self: Sized,
{
unsafe {
UserGroup(
with_uninitialized(|newgroup| {
ffi::MPI_Group_intersection(self.as_raw(), other.as_raw(), newgroup)
})
.1,
)
}
}
fn difference<G>(&self, other: &G) -> UserGroup
where
G: Group,
Self: Sized,
{
unsafe {
UserGroup(
with_uninitialized(|newgroup| {
ffi::MPI_Group_difference(self.as_raw(), other.as_raw(), newgroup)
})
.1,
)
}
}
fn include(&self, ranks: &[Rank]) -> UserGroup {
unsafe {
UserGroup(
with_uninitialized(|newgroup| {
ffi::MPI_Group_incl(self.as_raw(), ranks.count(), ranks.as_ptr(), newgroup)
})
.1,
)
}
}
fn exclude(&self, ranks: &[Rank]) -> UserGroup {
unsafe {
UserGroup(
with_uninitialized(|newgroup| {
ffi::MPI_Group_excl(self.as_raw(), ranks.count(), ranks.as_ptr(), newgroup)
})
.1,
)
}
}
fn size(&self) -> Rank {
unsafe { with_uninitialized(|size| ffi::MPI_Group_size(self.as_raw(), size)).1 }
}
fn rank(&self) -> Option<Rank> {
unsafe {
let (_, rank) = with_uninitialized(|rank| ffi::MPI_Group_rank(self.as_raw(), rank));
if rank == ffi::RSMPI_UNDEFINED {
None
} else {
Some(rank)
}
}
}
fn translate_rank<G>(&self, rank: Rank, other: &G) -> Option<Rank>
where
G: Group,
Self: Sized,
{
unsafe {
let (_, translated) = with_uninitialized(|translated| {
ffi::MPI_Group_translate_ranks(self.as_raw(), 1, &rank, other.as_raw(), translated)
});
if translated == ffi::RSMPI_UNDEFINED {
None
} else {
Some(translated)
}
}
}
fn translate_ranks<G>(&self, ranks: &[Rank], other: &G) -> Vec<Option<Rank>>
where
G: Group,
Self: Sized,
{
ranks
.iter()
.map(|&r| self.translate_rank(r, other))
.collect()
}
fn compare<G>(&self, other: &G) -> GroupRelation
where
G: Group,
Self: Sized,
{
unsafe {
with_uninitialized(|relation| {
ffi::MPI_Group_compare(self.as_raw(), other.as_raw(), relation)
})
.1
.into()
}
}
}
/// Result of comparing two groups with `MPI_Group_compare`.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum GroupRelation {
    /// Corresponds to `RSMPI_IDENT`.
    Identical,
    /// Corresponds to `RSMPI_SIMILAR`.
    Similar,
    /// Corresponds to `RSMPI_UNEQUAL`.
    Unequal,
}
impl From<c_int> for GroupRelation {
    /// Maps an MPI comparison result to a relation.
    ///
    /// # Panics
    /// If `i` matches none of the known constants.
    fn from(i: c_int) -> GroupRelation {
        // Statics that need `unsafe` to read cannot be match patterns; scan a table instead.
        let known = unsafe {
            [
                (ffi::RSMPI_IDENT, GroupRelation::Identical),
                (ffi::RSMPI_SIMILAR, GroupRelation::Similar),
                (ffi::RSMPI_UNEQUAL, GroupRelation::Unequal),
            ]
        };
        for &(code, relation) in known.iter() {
            if code == i {
                return relation;
            }
        }
        panic!("Unknown group relation: {}", i)
    }
}
/// Reports whether `raw_comm` is an inter-communicator, via `MPI_Comm_test_inter`.
///
/// # Safety
/// `raw_comm` must be a valid communicator handle.
unsafe fn comm_is_inter(raw_comm: MPI_Comm) -> bool {
    // Non-zero sentinel; MPI is expected to overwrite it with 0 (intra) or non-zero (inter).
    let mut is_inter: c_int = c_int::min_value();
    let flag_ptr: *mut c_int = &mut is_inter;
    unsafe {
        ffi::MPI_Comm_test_inter(raw_comm, flag_ptr);
    }
    is_inter != 0
}