#![warn(missing_docs)]
#![warn(clippy::all)]
use std::ffi::{c_char, CString};
mod comm;
mod datatype;
mod datatype_builder;
pub mod doc;
mod error;
mod ffi;
mod group;
mod info;
mod op;
mod persistent;
mod request;
#[cfg(feature = "numa")]
pub mod slurm;
mod status;
mod topology;
#[cfg(feature = "rma")]
mod window;
pub use comm::{Communicator, SplitType};
#[cfg(feature = "rma")]
pub use datatype::AtomicMpiDatatype;
pub use datatype::{
BytePermutable, DatatypeTag, DoubleInt, FloatInt, Int2, LongDoubleInt, LongInt, MpiDatatype,
MpiIndexedDatatype, ShortInt,
};
pub use datatype_builder::{CustomDatatype, StructField};
pub use error::{Error, MpiErrorClass, Result};
pub use group::{Group, GroupComparison, RankRange};
pub use info::Info;
pub use op::UserOp;
pub use persistent::PersistentRequest;
pub use request::Request;
pub use status::Status;
#[cfg(feature = "numa")]
pub use topology::SlurmInfo;
pub use topology::{HostEntry, TopologyInfo};
#[cfg(feature = "rma")]
pub use window::{
LockAllGuard, LockGuard, LockType, PendingFetchResult, SharedWindow, Win, WinFenceAssert,
WinKind, WinLockAllGuard, WinLockGuard, WinPscwAssert,
};
use std::marker::PhantomData;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Mutex, OnceLock};
/// Process-wide guard for MPI initialisation.
///
/// `Mpi::init_thread` claims it with an atomic swap (a second claim yields
/// `Error::AlreadyInitialized`) and resets it if the init shim fails; the
/// `Drop` impl of `Mpi` resets it after calling the finalize shim.
static MPI_INITIALIZED: AtomicBool = AtomicBool::new(false);
/// Owner of the buffer handed to `Mpi::buffer_attach`.
///
/// The boxed slice is parked here until `Mpi::buffer_detach` takes it back, so
/// the allocation stays alive while the attach shim holds a raw pointer to it.
/// `None` means nothing is attached through this API.
static ATTACHED_BUFFER: Mutex<Option<Box<[u8]>>> = Mutex::new(None);
/// Thread-support level requested from, and reported back by, MPI initialisation.
///
/// The derived ordering follows declaration order, so
/// `Single < Funneled < Serialized < Multiple`. The `i32` discriminant of each
/// variant is the value passed to the C shim by [`Mpi::init_thread`], and the
/// value the shim reports back is mapped onto these variants again.
///
/// NOTE(review): the variants are presumably the counterparts of MPI's
/// `MPI_THREAD_SINGLE` / `FUNNELED` / `SERIALIZED` / `MULTIPLE` — confirm the
/// integer mapping against the C shim before renumbering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(i32)]
pub enum ThreadLevel {
/// Lowest level, encoded as `0`; this is what [`Mpi::init`] requests.
Single = 0,
/// Encoded as `1`.
Funneled = 1,
/// Encoded as `2`.
Serialized = 2,
/// Highest level, encoded as `3`; any provided value other than `0`, `1` or
/// `2` is also reported as this variant by [`Mpi::init_thread`].
Multiple = 3,
}
/// Identifier of a predefined reduction operation.
///
/// Each variant carries a fixed `i32` discriminant; the unit tests in this
/// file pin every value. NOTE(review): the codes are presumably decoded by a
/// `switch` in the C shim — confirm before changing any of them.
///
/// `Replace` and `NoOp` exist only when the `rma` feature is enabled, which
/// the doctest below checks in both configurations.
///
#[cfg_attr(not(feature = "rma"), doc = "```compile_fail")]
#[cfg_attr(
not(feature = "rma"),
doc = "// This must not compile without --features rma."
)]
#[cfg_attr(not(feature = "rma"), doc = "let _ = ferrompi::ReduceOp::Replace;")]
#[cfg_attr(not(feature = "rma"), doc = "```")]
#[cfg_attr(feature = "rma", doc = "```no_run")]
#[cfg_attr(
feature = "rma",
doc = "// With --features rma, ReduceOp::Replace is available."
)]
#[cfg_attr(feature = "rma", doc = "let _ = ferrompi::ReduceOp::Replace;")]
#[cfg_attr(feature = "rma", doc = "```")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum ReduceOp {
/// Sum (code `0`).
Sum = 0,
/// Maximum (code `1`).
Max = 1,
/// Minimum (code `2`).
Min = 2,
/// Product (code `3`).
Prod = 3,
/// Bitwise OR (code `4`).
BitwiseOr = 4,
/// Bitwise AND (code `5`).
BitwiseAnd = 5,
/// Bitwise XOR (code `6`).
BitwiseXor = 6,
/// Logical OR (code `7`).
LogicalOr = 7,
/// Logical AND (code `8`).
LogicalAnd = 8,
/// Logical XOR (code `9`).
LogicalXor = 9,
/// Maximum with location (code `10`); presumably MPI's `MPI_MAXLOC` — confirm.
MaxLoc = 10,
/// Minimum with location (code `11`); presumably MPI's `MPI_MINLOC` — confirm.
MinLoc = 11,
/// Replace (code `12`); only available with the `rma` feature.
#[cfg(feature = "rma")]
Replace = 12,
/// No-op (code `13`); only available with the `rma` feature.
#[cfg(feature = "rma")]
NoOp = 13,
}
/// Handle for this process's MPI session.
///
/// Normally obtained from [`Mpi::init`] or [`Mpi::init_thread`]. Dropping the
/// handle calls the finalize shim if the process-wide initialisation flag is
/// still set. The raw-pointer `PhantomData` field makes the type neither
/// `Send` nor `Sync`.
pub struct Mpi {
// Thread-support level recorded when the handle was created.
thread_level: ThreadLevel,
// Zero-sized marker; `*const ()` opts the struct out of `Send` and `Sync`.
_marker: PhantomData<*const ()>,
}
impl Mpi {
/// Initialises MPI, requesting [`ThreadLevel::Single`].
///
/// This is a shorthand for [`Mpi::init_thread`] with the lowest thread level
/// and returns exactly what that call returns.
pub fn init() -> Result<Self> {
    let required = ThreadLevel::Single;
    Self::init_thread(required)
}
pub fn init_thread(required: ThreadLevel) -> Result<Self> {
if MPI_INITIALIZED.swap(true, Ordering::SeqCst) {
return Err(Error::AlreadyInitialized);
}
let mut provided: i32 = 0;
let ret = unsafe { ffi::ferrompi_init_thread(required as i32, &mut provided) };
if ret != 0 {
MPI_INITIALIZED.store(false, Ordering::SeqCst);
return Err(Error::Mpi {
class: MpiErrorClass::Raw(ret),
code: ret,
message: format!("MPI_Init_thread failed with code {ret}"),
operation: Some("init_thread"),
});
}
let thread_level = match provided {
0 => ThreadLevel::Single,
1 => ThreadLevel::Funneled,
2 => ThreadLevel::Serialized,
_ => ThreadLevel::Multiple,
};
Ok(Mpi {
thread_level,
_marker: PhantomData,
})
}
/// Returns the thread-support level recorded on this handle when it was created.
pub fn thread_level(&self) -> ThreadLevel {
self.thread_level
}
/// Returns the communicator produced by [`Communicator::world`].
///
/// NOTE(review): presumably this is the wrapper for `MPI_COMM_WORLD` — confirm
/// in the `comm` module.
pub fn world(&self) -> Communicator {
Communicator::world()
}
/// Returns the value reported by the `ferrompi_wtime` shim.
///
/// NOTE(review): presumably wall-clock seconds as defined by `MPI_Wtime` —
/// confirm against the C shim, including whether it may be called before
/// initialisation (this function does not require an `Mpi` handle).
pub fn wtime() -> f64 {
// The shim takes no arguments, so no pointers are passed across the boundary.
unsafe { ffi::ferrompi_wtime() }
}
pub fn library_version() -> Result<String> {
let mut buf = [0u8; 8192];
let mut len: i32 = 0;
let ret = unsafe {
ffi::ferrompi_get_library_version(buf.as_mut_ptr().cast::<c_char>(), &mut len)
};
Error::check_with_op(ret, "get_library_version")?;
let len = (len.max(0) as usize).min(buf.len());
let s = std::str::from_utf8(&buf[..len])
.map_err(|_| Error::Internal("Invalid UTF-8 in library version string".into()))?;
Ok(s.trim_end().to_string())
}
pub fn version() -> Result<String> {
let mut buf = [0u8; 256];
let mut len: i32 = 0;
let ret = unsafe { ffi::ferrompi_get_version(buf.as_mut_ptr().cast::<c_char>(), &mut len) };
if ret != 0 {
return Err(Error::from_code_with_op(ret, "get_version"));
}
let len = (len.max(0) as usize).min(buf.len());
let s = std::str::from_utf8(&buf[..len])
.map_err(|_| Error::Internal("Invalid UTF-8 in version string".into()))?;
Ok(s.to_string())
}
/// Reports whether the `ferrompi_initialized` shim set a non-zero flag.
///
/// The flag starts at zero and the shim's own return value is ignored, so a
/// shim that leaves the flag untouched yields `false`.
pub fn is_initialized() -> bool {
    let mut raw_flag = 0_i32;
    unsafe { ffi::ferrompi_initialized(&mut raw_flag) };
    raw_flag != 0
}
/// Reports whether the `ferrompi_finalized` shim set a non-zero flag.
///
/// The flag starts at zero and the shim's own return value is ignored, so a
/// shim that leaves the flag untouched yields `false`.
pub fn is_finalized() -> bool {
    let mut raw_flag = 0_i32;
    unsafe { ffi::ferrompi_finalized(&mut raw_flag) };
    raw_flag != 0
}
/// Decides whether `create_from_group` may be attempted, based on the major
/// number in the string returned by `Mpi::version`.
///
/// The major number is read from the part of the second whitespace-separated
/// token that precedes the first `.`; anything unparsable counts as `0`.
/// Support requires a major number of at least `4`.
///
/// A successful probe is cached for the life of the process. A failed
/// `Mpi::version` call returns `false` WITHOUT caching, so a later call can
/// probe again.
fn supports_create_from_group() -> bool {
    static SUPPORTED: OnceLock<bool> = OnceLock::new();

    match SUPPORTED.get() {
        Some(&cached) => cached,
        None => {
            let version_text = match Mpi::version() {
                Ok(text) => text,
                // Probe failure: answer "no" for now but leave the cache empty.
                Err(_) => return false,
            };

            let second_token = version_text.split_whitespace().nth(1);
            let major_digits = second_token.and_then(|token| token.split('.').next());
            let major: u32 = major_digits
                .and_then(|digits| digits.parse().ok())
                .unwrap_or(0);

            let supported = major >= 4;
            // A concurrent caller may have filled the cell first; that is fine.
            let _ = SUPPORTED.set(supported);
            supported
        }
    }
}
pub fn create_from_group(&self, group: &group::Group, stringtag: &str) -> Result<Communicator> {
let c_tag = CString::new(stringtag)
.map_err(|_| Error::Internal("stringtag contains null byte".into()))?;
if !Self::supports_create_from_group() {
return Err(Error::NotSupported(
"MPI_Comm_create_from_group".to_string(),
));
}
let mut new_handle: i32 = -1;
let ret = unsafe {
ffi::ferrompi_comm_create_from_group(group.handle, c_tag.as_ptr(), &mut new_handle)
};
Error::check_with_op(ret, "comm_create_from_group")?;
Communicator::from_handle(new_handle)
}
/// Attaches `buffer` through the `ferrompi_buffer_attach` shim.
///
/// Ownership of the buffer moves into a process-wide slot so the allocation
/// outlives the call; [`Mpi::buffer_detach`] hands it back. At most one
/// buffer can be attached through this API at a time.
///
/// # Errors
///
/// * [`Error::Internal`] if the slot's mutex is poisoned.
/// * [`Error::InvalidOp`] if a buffer is already attached.
/// * [`Error::InvalidBuffer`] if the buffer is longer than `i32::MAX` bytes.
/// * The error built by [`Error::from_code_with_op`] if the shim fails; the
///   buffer is dropped in that case.
pub fn buffer_attach(&self, buffer: Box<[u8]>) -> Result<()> {
    let mut guard = ATTACHED_BUFFER
        .lock()
        .map_err(|_| Error::Internal("ATTACHED_BUFFER mutex poisoned".into()))?;
    if guard.is_some() {
        return Err(Error::InvalidOp);
    }
    if buffer.len() > i32::MAX as usize {
        return Err(Error::InvalidBuffer);
    }

    // Park the box in its final home first and only then derive the raw
    // pointer, from a mutable borrow of the stored slice. The pointer is
    // written through by the callee, so it must come from `as_mut_ptr`
    // (a pointer obtained via `as_ptr` must never be written through), and it
    // must not be taken before the box is moved into the slot.
    let stored = guard.insert(buffer);
    let size = stored.len() as i64; // lossless: length was checked against i32::MAX
    let ptr = stored.as_mut_ptr().cast::<std::ffi::c_void>();

    let ret = unsafe { ffi::ferrompi_buffer_attach(ptr, size) };
    if ret != 0 {
        // Attach failed, so nothing refers to the buffer; release it.
        guard.take();
        return Err(Error::from_code_with_op(ret, "buffer_attach"));
    }
    Ok(())
}
pub fn buffer_detach(&self) -> Result<Box<[u8]>> {
let mut guard = ATTACHED_BUFFER
.lock()
.map_err(|_| Error::Internal("ATTACHED_BUFFER mutex poisoned".into()))?;
if guard.is_none() {
return Err(Error::InvalidOp);
}
let mut out_ptr: *mut std::ffi::c_void = std::ptr::null_mut();
let mut out_size: i64 = 0;
let ret = unsafe { ffi::ferrompi_buffer_detach(&mut out_ptr, &mut out_size) };
if ret != 0 {
return Err(Error::from_code_with_op(ret, "buffer_detach"));
}
let buf = guard.take().expect("guard was Some; take() must succeed");
Ok(buf)
}
}
impl Drop for Mpi {
    /// Calls the finalize shim if the process-wide initialisation flag is set,
    /// then clears the flag. Handles dropped while the flag is clear do nothing.
    fn drop(&mut self) {
        if !MPI_INITIALIZED.load(Ordering::SeqCst) {
            return;
        }
        unsafe {
            ffi::ferrompi_finalize();
        }
        // Cleared only after the shim has returned.
        MPI_INITIALIZED.store(false, Ordering::SeqCst);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Serialises the tests that read or write the process-global
    // `ATTACHED_BUFFER` slot. The test harness runs tests on parallel threads,
    // and without this lock one test could flip the slot between another
    // test's setup and its call into `buffer_attach` / `buffer_detach`, which
    // would then reach the FFI shim with MPI uninitialised.
    static BUFFER_SLOT_TESTS: std::sync::Mutex<()> = std::sync::Mutex::new(());

    // Takes the serialisation lock, ignoring poisoning so that one failed
    // test does not cascade into the others.
    fn lock_buffer_slot_tests() -> std::sync::MutexGuard<'static, ()> {
        BUFFER_SLOT_TESTS
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    // Builds a handle without calling into MPI; dropping it is a no-op while
    // `MPI_INITIALIZED` is false.
    fn mpi_without_runtime() -> Mpi {
        Mpi {
            thread_level: ThreadLevel::Single,
            _marker: PhantomData,
        }
    }

    #[test]
    fn thread_level_ordering() {
        assert!(ThreadLevel::Single < ThreadLevel::Funneled);
        assert!(ThreadLevel::Funneled < ThreadLevel::Serialized);
        assert!(ThreadLevel::Serialized < ThreadLevel::Multiple);
    }

    #[test]
    fn thread_level_equality() {
        assert_eq!(ThreadLevel::Single, ThreadLevel::Single);
        assert_eq!(ThreadLevel::Funneled, ThreadLevel::Funneled);
        assert_eq!(ThreadLevel::Serialized, ThreadLevel::Serialized);
        assert_eq!(ThreadLevel::Multiple, ThreadLevel::Multiple);
        assert_ne!(ThreadLevel::Single, ThreadLevel::Multiple);
        assert_ne!(ThreadLevel::Funneled, ThreadLevel::Serialized);
    }

    #[test]
    fn thread_level_repr_values() {
        assert_eq!(ThreadLevel::Single as i32, 0);
        assert_eq!(ThreadLevel::Funneled as i32, 1);
        assert_eq!(ThreadLevel::Serialized as i32, 2);
        assert_eq!(ThreadLevel::Multiple as i32, 3);
    }

    #[test]
    fn thread_level_debug_clone() {
        let level = ThreadLevel::Funneled;
        let cloned = level;
        assert_eq!(format!("{cloned:?}"), "Funneled");
        assert_eq!(format!("{:?}", ThreadLevel::Single), "Single");
        assert_eq!(format!("{:?}", ThreadLevel::Serialized), "Serialized");
        assert_eq!(format!("{:?}", ThreadLevel::Multiple), "Multiple");
    }

    #[test]
    fn reduce_op_repr_values() {
        let ops = [
            (ReduceOp::Sum, 0),
            (ReduceOp::Max, 1),
            (ReduceOp::Min, 2),
            (ReduceOp::Prod, 3),
            (ReduceOp::BitwiseOr, 4),
            (ReduceOp::BitwiseAnd, 5),
            (ReduceOp::BitwiseXor, 6),
            (ReduceOp::LogicalOr, 7),
            (ReduceOp::LogicalAnd, 8),
            (ReduceOp::LogicalXor, 9),
            (ReduceOp::MaxLoc, 10),
            (ReduceOp::MinLoc, 11),
        ];
        for (op, expected) in ops {
            assert_eq!(op as i32, expected);
        }
        #[cfg(feature = "rma")]
        {
            assert_eq!(ReduceOp::Replace as i32, 12);
            assert_eq!(ReduceOp::NoOp as i32, 13);
        }
    }

    #[test]
    fn reduce_op_equality() {
        assert_eq!(ReduceOp::Sum, ReduceOp::Sum);
        assert_eq!(ReduceOp::Max, ReduceOp::Max);
        assert_eq!(ReduceOp::Min, ReduceOp::Min);
        assert_eq!(ReduceOp::Prod, ReduceOp::Prod);
        assert_eq!(ReduceOp::BitwiseOr, ReduceOp::BitwiseOr);
        assert_eq!(ReduceOp::BitwiseAnd, ReduceOp::BitwiseAnd);
        assert_eq!(ReduceOp::BitwiseXor, ReduceOp::BitwiseXor);
        assert_eq!(ReduceOp::LogicalOr, ReduceOp::LogicalOr);
        assert_eq!(ReduceOp::LogicalAnd, ReduceOp::LogicalAnd);
        assert_eq!(ReduceOp::LogicalXor, ReduceOp::LogicalXor);
        assert_eq!(ReduceOp::MaxLoc, ReduceOp::MaxLoc);
        assert_eq!(ReduceOp::MinLoc, ReduceOp::MinLoc);
        assert_ne!(ReduceOp::Sum, ReduceOp::Max);
        assert_ne!(ReduceOp::Min, ReduceOp::Prod);
        assert_ne!(ReduceOp::Sum, ReduceOp::Prod);
        assert_ne!(ReduceOp::BitwiseOr, ReduceOp::BitwiseAnd);
        assert_ne!(ReduceOp::LogicalOr, ReduceOp::LogicalAnd);
        assert_ne!(ReduceOp::Sum, ReduceOp::BitwiseOr);
        assert_ne!(ReduceOp::MaxLoc, ReduceOp::MinLoc);
        assert_ne!(ReduceOp::MaxLoc, ReduceOp::Max);
    }

    #[test]
    fn reduce_op_debug_clone() {
        let op = ReduceOp::Sum;
        let cloned = op;
        assert_eq!(format!("{cloned:?}"), "Sum");
        assert_eq!(format!("{:?}", ReduceOp::Max), "Max");
        assert_eq!(format!("{:?}", ReduceOp::Min), "Min");
        assert_eq!(format!("{:?}", ReduceOp::Prod), "Prod");
        assert_eq!(format!("{:?}", ReduceOp::BitwiseOr), "BitwiseOr");
        assert_eq!(format!("{:?}", ReduceOp::BitwiseAnd), "BitwiseAnd");
        assert_eq!(format!("{:?}", ReduceOp::BitwiseXor), "BitwiseXor");
        assert_eq!(format!("{:?}", ReduceOp::LogicalOr), "LogicalOr");
        assert_eq!(format!("{:?}", ReduceOp::LogicalAnd), "LogicalAnd");
        assert_eq!(format!("{:?}", ReduceOp::LogicalXor), "LogicalXor");
        assert_eq!(format!("{:?}", ReduceOp::MaxLoc), "MaxLoc");
        assert_eq!(format!("{:?}", ReduceOp::MinLoc), "MinLoc");
    }

    #[test]
    fn reduce_op_all_variants_match_c_switch() {
        let variants = [
            (ReduceOp::Sum, 0i32),
            (ReduceOp::Max, 1),
            (ReduceOp::Min, 2),
            (ReduceOp::Prod, 3),
            (ReduceOp::BitwiseOr, 4),
            (ReduceOp::BitwiseAnd, 5),
            (ReduceOp::BitwiseXor, 6),
            (ReduceOp::LogicalOr, 7),
            (ReduceOp::LogicalAnd, 8),
            (ReduceOp::LogicalXor, 9),
            (ReduceOp::MaxLoc, 10),
            (ReduceOp::MinLoc, 11),
        ];
        for (op, expected) in variants {
            assert_eq!(op as i32, expected);
        }
        #[cfg(feature = "rma")]
        {
            assert_eq!(ReduceOp::Replace as i32, 12);
            assert_eq!(ReduceOp::NoOp as i32, 13);
        }
    }

    #[cfg(feature = "rma")]
    #[test]
    fn replace_noop_discriminants() {
        assert_eq!(ReduceOp::Replace as i32, 12);
        assert_eq!(ReduceOp::NoOp as i32, 13);
    }

    #[test]
    fn buffer_attach_invalid_buffer_variant_exists() {
        let err = Error::InvalidBuffer;
        assert!(matches!(err, Error::InvalidBuffer));
    }

    // Compile-time check of the attach/detach signatures; nothing is executed.
    #[test]
    fn buffer_attach_signature_compiles() {
        fn _check(mpi: &Mpi, buf: Box<[u8]>) -> Result<()> {
            mpi.buffer_attach(buf)
        }
        fn _check_detach(mpi: &Mpi) -> Result<Box<[u8]>> {
            mpi.buffer_detach()
        }
    }

    #[test]
    fn buffer_attach_double_attach_returns_invalid_op() {
        // Held for the whole test so no other test can empty the slot between
        // the setup below and the `buffer_attach` call.
        let _serial = lock_buffer_slot_tests();
        {
            let mut g = ATTACHED_BUFFER.lock().unwrap();
            if g.is_none() {
                *g = Some(vec![0u8; 4].into_boxed_slice());
            }
        }
        let mpi = mpi_without_runtime();
        let buf2 = vec![0u8; 8].into_boxed_slice();
        let result = mpi.buffer_attach(buf2);
        assert!(
            matches!(result, Err(Error::InvalidOp)),
            "expected Err(InvalidOp) on double attach, got: {result:?}"
        );
        // Leave the slot empty for whichever test runs next.
        ATTACHED_BUFFER.lock().unwrap().take();
    }

    #[test]
    fn buffer_detach_without_attach_returns_invalid_op() {
        // Held for the whole test so no other test can fill the slot between
        // the reset below and the `buffer_detach` call.
        let _serial = lock_buffer_slot_tests();
        {
            let mut g = ATTACHED_BUFFER.lock().unwrap();
            *g = None;
        }
        let mpi = mpi_without_runtime();
        let result = mpi.buffer_detach();
        assert!(
            matches!(result, Err(Error::InvalidOp)),
            "expected Err(InvalidOp) on detach without attach, got: {result:?}"
        );
    }

    // Only pins the helper's signature; the probe itself needs a live MPI
    // runtime and is not exercised here.
    #[test]
    fn supports_create_from_group_does_not_cache_probe_failures() {
        let _: fn() -> bool = Mpi::supports_create_from_group;
    }

    // Relies on the tag being validated before any MPI call is made.
    #[test]
    fn create_from_group_null_byte_in_tag() {
        let mpi = mpi_without_runtime();
        let g = group::Group { handle: 0 };
        let result = mpi.create_from_group(&g, "bad\0tag");
        match result {
            Err(Error::Internal(msg)) => {
                assert!(
                    msg.contains("null byte"),
                    "expected 'null byte' in error message, got: {msg}"
                );
            }
            Ok(_) => panic!("expected Err(Error::Internal(_)), got Ok(_)"),
            Err(e) => panic!("expected Err(Error::Internal(_)), got Err({e})"),
        }
    }
}