use std::sync::OnceLock;
use crate::ffi;
/// Returns the raw error class values for the file, info and window error
/// classes, in that order: `(file, info, win)`.
///
/// The values are read from the FFI layer on the first call and stored in a
/// process-wide [`OnceLock`]; every later call returns the stored tuple
/// without touching the FFI layer again.
fn impl_error_classes() -> (i32, i32, i32) {
    static CLASSES: OnceLock<(i32, i32, i32)> = OnceLock::new();

    let cached = CLASSES.get_or_init(|| {
        // SAFETY: NOTE(review): assumes these argument-less FFI getters have
        // no preconditions (e.g. do not require MPI to be initialised) —
        // confirm against the C shim.
        unsafe {
            let file = ffi::ferrompi_err_file();
            let info = ffi::ferrompi_err_info();
            let win = ffi::ferrompi_err_win();
            (file, info, win)
        }
    });
    *cached
}
/// Convenience alias for `std::result::Result` with this module's [`Error`]
/// as the error type.
pub type Result<T> = std::result::Result<T, Error>;
/// Classification of an MPI error, decoded from a raw integer class value.
///
/// Variant names presumably mirror the MPI standard's `MPI_ERR_*` error
/// classes — TODO confirm against the MPI implementation in use. The raw value
/// noted on each variant is the one [`MpiErrorClass::from_raw`] maps to it,
/// and the quoted name is what the `Display` impl prints.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MpiErrorClass {
/// Raw class `0`; displayed as `SUCCESS`.
Success,
/// Raw class `1`; displayed as `ERR_BUFFER`.
Buffer,
/// Raw class `2`; displayed as `ERR_COUNT`.
Count,
/// Raw class `3`; displayed as `ERR_TYPE`.
Type,
/// Raw class `4`; displayed as `ERR_TAG`.
Tag,
/// Raw class `5`; displayed as `ERR_COMM`.
Comm,
/// Raw class `6`; displayed as `ERR_RANK`.
Rank,
/// Raw class `7`; displayed as `ERR_REQUEST`.
Request,
/// Raw class `8`; displayed as `ERR_ROOT`.
Root,
/// Raw class `9`; displayed as `ERR_GROUP`.
Group,
/// Raw class `10`; displayed as `ERR_OP`.
Op,
/// Raw class `11`; displayed as `ERR_TOPOLOGY`.
Topology,
/// Raw class `12`; displayed as `ERR_DIMS`.
Dims,
/// Raw class `13`; displayed as `ERR_ARG`.
Arg,
/// Raw class `14`; displayed as `ERR_UNKNOWN`.
Unknown,
/// Raw class `15`; displayed as `ERR_TRUNCATE`.
Truncate,
/// Raw class `16`; displayed as `ERR_OTHER`.
Other,
/// Raw class `17`; displayed as `ERR_INTERN`.
Intern,
/// Raw class `18`; displayed as `ERR_IN_STATUS`.
InStatus,
/// Raw class `19`; displayed as `ERR_PENDING`.
Pending,
/// Raw class equal to the value reported at runtime by
/// `ffi::ferrompi_err_win()`; displayed as `ERR_WIN`.
Win,
/// Raw class equal to the value reported at runtime by
/// `ffi::ferrompi_err_info()`; displayed as `ERR_INFO`.
Info,
/// Raw class equal to the value reported at runtime by
/// `ffi::ferrompi_err_file()`; displayed as `ERR_FILE`.
File,
/// Any raw class value that `from_raw` does not recognise, carried
/// verbatim; displayed as `ERR_CLASS(<value>)`. `Error::from_code` also
/// stores the raw error *code* here when the class lookup itself fails.
Raw(i32),
}
impl MpiErrorClass {
    /// Variants for the fixed raw class values `0..=19`, indexed by raw value.
    const FIXED_CLASSES: [MpiErrorClass; 20] = [
        MpiErrorClass::Success,
        MpiErrorClass::Buffer,
        MpiErrorClass::Count,
        MpiErrorClass::Type,
        MpiErrorClass::Tag,
        MpiErrorClass::Comm,
        MpiErrorClass::Rank,
        MpiErrorClass::Request,
        MpiErrorClass::Root,
        MpiErrorClass::Group,
        MpiErrorClass::Op,
        MpiErrorClass::Topology,
        MpiErrorClass::Dims,
        MpiErrorClass::Arg,
        MpiErrorClass::Unknown,
        MpiErrorClass::Truncate,
        MpiErrorClass::Other,
        MpiErrorClass::Intern,
        MpiErrorClass::InStatus,
        MpiErrorClass::Pending,
    ];

    /// Decodes a raw integer error class into an [`MpiErrorClass`].
    ///
    /// Values `0..=19` map to fixed variants. Any other value is compared, in
    /// this order, against the file, info and window class values reported by
    /// `impl_error_classes()`; if none matches, the value is preserved as
    /// [`MpiErrorClass::Raw`].
    ///
    /// NOTE(review): the `0..=19` numbering is hard-coded here. MPI libraries
    /// may number their error classes differently — confirm it matches the
    /// MPI implementation this crate links against.
    pub fn from_raw(class: i32) -> Self {
        match class {
            // The pattern guarantees a non-negative, in-bounds value, so the
            // cast to an index is lossless.
            0..=19 => Self::FIXED_CLASSES[class as usize],
            _ => {
                let (file, info, win) = impl_error_classes();
                // `find` returns the first match, preserving the
                // file → info → win precedence.
                [(file, Self::File), (info, Self::Info), (win, Self::Win)]
                    .iter()
                    .find(|(raw, _)| *raw == class)
                    .map_or(Self::Raw(class), |&(_, variant)| variant)
            }
        }
    }
}
/// Prints the class as a short upper-case label (`SUCCESS`, `ERR_BUFFER`, …);
/// unrecognised classes print as `ERR_CLASS(<raw value>)`.
impl std::fmt::Display for MpiErrorClass {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            // The only variant that needs formatting; handle it and leave.
            MpiErrorClass::Raw(value) => return write!(f, "ERR_CLASS({value})"),
            MpiErrorClass::Success => "SUCCESS",
            MpiErrorClass::Buffer => "ERR_BUFFER",
            MpiErrorClass::Count => "ERR_COUNT",
            MpiErrorClass::Type => "ERR_TYPE",
            MpiErrorClass::Tag => "ERR_TAG",
            MpiErrorClass::Comm => "ERR_COMM",
            MpiErrorClass::Rank => "ERR_RANK",
            MpiErrorClass::Request => "ERR_REQUEST",
            MpiErrorClass::Root => "ERR_ROOT",
            MpiErrorClass::Group => "ERR_GROUP",
            MpiErrorClass::Op => "ERR_OP",
            MpiErrorClass::Topology => "ERR_TOPOLOGY",
            MpiErrorClass::Dims => "ERR_DIMS",
            MpiErrorClass::Arg => "ERR_ARG",
            MpiErrorClass::Unknown => "ERR_UNKNOWN",
            MpiErrorClass::Truncate => "ERR_TRUNCATE",
            MpiErrorClass::Other => "ERR_OTHER",
            MpiErrorClass::Intern => "ERR_INTERN",
            MpiErrorClass::InStatus => "ERR_IN_STATUS",
            MpiErrorClass::Pending => "ERR_PENDING",
            MpiErrorClass::Win => "ERR_WIN",
            MpiErrorClass::Info => "ERR_INFO",
            MpiErrorClass::File => "ERR_FILE",
        };
        f.write_str(label)
    }
}
/// Errors reported by this module.
///
/// The `Display` impl renders each variant as a one-line, human-readable
/// message; the text for each is quoted on the variant below.
#[derive(Debug)]
pub enum Error {
/// Displayed as `MPI has already been initialized`.
AlreadyInitialized,
/// An error described by an MPI error code. Built by `Error::from_code` and
/// `Error::from_code_with_op`.
Mpi {
/// Error class decoded for `code`.
class: MpiErrorClass,
/// The raw error code this error was built from.
code: i32,
/// Human-readable message for `code`.
message: String,
/// Label of the operation that failed; `Some` when built through
/// `from_code_with_op`, `None` when built through `from_code`.
operation: Option<&'static str>,
},
/// Displayed as `Invalid buffer`.
InvalidBuffer,
/// Displayed as `Invalid reduction operation for this method`.
InvalidOp,
/// Displayed as `Operation not supported: <text>`, where `<text>` is the
/// carried string.
NotSupported(String),
/// Displayed as `Internal error: <text>`, where `<text>` is the carried
/// string.
Internal(String),
}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Error::AlreadyInitialized => write!(f, "MPI has already been initialized"),
Error::Mpi {
class,
code,
message,
operation: Some(op),
} => write!(
f,
"MPI error in {op}: {message} (class={class}, code={code})"
),
Error::Mpi {
class,
code,
message,
operation: None,
} => write!(f, "MPI error: {message} (class={class}, code={code})"),
Error::InvalidBuffer => write!(f, "Invalid buffer"),
Error::InvalidOp => write!(f, "Invalid reduction operation for this method"),
Error::NotSupported(s) => write!(f, "Operation not supported: {s}"),
Error::Internal(s) => write!(f, "Internal error: {s}"),
}
}
}
// Marks `Error` as a standard error type. The impl body is empty, so every
// trait method keeps its default implementation.
impl std::error::Error for Error {}
impl Error {
pub fn from_code(code: i32) -> Self {
assert!(code != 0, "from_code called with success code 0");
let mut class: i32 = 0;
let mut msg_buf = [0u8; 512];
let mut msg_len: i32 = 0;
let ret = unsafe {
ffi::ferrompi_error_info(
code,
&mut class,
msg_buf.as_mut_ptr().cast::<std::ffi::c_char>(),
&mut msg_len,
)
};
if ret == 0 {
let len = msg_len.max(0) as usize;
let message = std::str::from_utf8(&msg_buf[..len])
.unwrap_or("unknown error")
.to_string();
Error::Mpi {
class: MpiErrorClass::from_raw(class),
code,
message,
operation: None,
}
} else {
Error::Mpi {
class: MpiErrorClass::Raw(code),
code,
message: format!("MPI error code {code}"),
operation: None,
}
}
}
pub fn from_code_with_op(code: i32, operation: &'static str) -> Self {
match Error::from_code(code) {
Error::Mpi {
class,
code,
message,
operation: _,
} => Error::Mpi {
class,
code,
message,
operation: Some(operation),
},
other => other,
}
}
pub fn check(code: i32) -> Result<()> {
if code == 0 {
Ok(())
} else {
Err(Error::from_code(code))
}
}
pub fn check_with_op(code: i32, operation: &'static str) -> Result<()> {
if code == 0 {
Ok(())
} else {
Err(Error::from_code_with_op(code, operation))
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `check` treats a zero return code as success.
    #[test]
    fn check_success_returns_ok() {
        assert!(Error::check(0).is_ok());
    }

    /// `check_with_op` treats a zero return code as success, like `check`.
    #[test]
    fn check_with_op_success_returns_ok() {
        assert!(Error::check_with_op(0, "barrier").is_ok());
    }

    /// Raw class values `0..=19` decode to the fixed variants, in order.
    #[test]
    fn error_class_from_known_values() {
        let expected = [
            MpiErrorClass::Success,
            MpiErrorClass::Buffer,
            MpiErrorClass::Count,
            MpiErrorClass::Type,
            MpiErrorClass::Tag,
            MpiErrorClass::Comm,
            MpiErrorClass::Rank,
            MpiErrorClass::Request,
            MpiErrorClass::Root,
            MpiErrorClass::Group,
            MpiErrorClass::Op,
            MpiErrorClass::Topology,
            MpiErrorClass::Dims,
            MpiErrorClass::Arg,
            MpiErrorClass::Unknown,
            MpiErrorClass::Truncate,
            MpiErrorClass::Other,
            MpiErrorClass::Intern,
            MpiErrorClass::InStatus,
            MpiErrorClass::Pending,
        ];
        for (raw, class) in (0_i32..).zip(expected.iter()) {
            assert_eq!(
                MpiErrorClass::from_raw(raw),
                *class,
                "wrong class for raw value {raw}"
            );
        }
    }

    /// Values outside every known class are preserved verbatim in `Raw`.
    #[test]
    fn error_class_unknown_raw_value() {
        assert_eq!(MpiErrorClass::from_raw(999), MpiErrorClass::Raw(999));
        assert_eq!(MpiErrorClass::from_raw(-1), MpiErrorClass::Raw(-1));
    }

    /// Spot-check of a few `Display` labels.
    #[test]
    fn error_class_display_formats() {
        assert_eq!(format!("{}", MpiErrorClass::Success), "SUCCESS");
        assert_eq!(format!("{}", MpiErrorClass::Buffer), "ERR_BUFFER");
        assert_eq!(format!("{}", MpiErrorClass::Comm), "ERR_COMM");
        assert_eq!(format!("{}", MpiErrorClass::Rank), "ERR_RANK");
        assert_eq!(format!("{}", MpiErrorClass::Raw(42)), "ERR_CLASS(42)");
    }

    /// Every `Error` variant renders its documented `Display` text.
    #[test]
    fn error_display_formats_correctly() {
        let err = Error::InvalidBuffer;
        assert_eq!(format!("{err}"), "Invalid buffer");

        let err = Error::InvalidOp;
        assert_eq!(
            format!("{err}"),
            "Invalid reduction operation for this method"
        );

        let err = Error::AlreadyInitialized;
        assert_eq!(format!("{err}"), "MPI has already been initialized");

        let err = Error::NotSupported("persistent collectives".to_string());
        assert_eq!(
            format!("{err}"),
            "Operation not supported: persistent collectives"
        );

        let err = Error::Internal("test failure".to_string());
        assert_eq!(format!("{err}"), "Internal error: test failure");

        let err = Error::Mpi {
            class: MpiErrorClass::Rank,
            code: 6,
            message: "invalid rank".to_string(),
            operation: None,
        };
        assert_eq!(
            format!("{err}"),
            "MPI error: invalid rank (class=ERR_RANK, code=6)"
        );
    }

    /// `MpiErrorClass` is usable as a hash-set key and clones to an equal value.
    #[test]
    // The point of this test is to call `Clone::clone` on a `Copy` type.
    #[allow(clippy::clone_on_copy)]
    fn error_class_hash_and_clone() {
        use std::collections::HashSet;

        let mut set = HashSet::new();
        set.insert(MpiErrorClass::Success);
        set.insert(MpiErrorClass::Buffer);
        set.insert(MpiErrorClass::Raw(42));
        // Duplicate insert must not grow the set.
        set.insert(MpiErrorClass::Raw(42));
        assert_eq!(set.len(), 3);
        assert!(set.contains(&MpiErrorClass::Success));
        assert!(set.contains(&MpiErrorClass::Buffer));
        assert!(set.contains(&MpiErrorClass::Raw(42)));
        assert!(!set.contains(&MpiErrorClass::Comm));

        let original = MpiErrorClass::Comm;
        let cloned = original.clone();
        assert_eq!(cloned, MpiErrorClass::Comm);
        assert_eq!(original, cloned);

        let raw_original = MpiErrorClass::Raw(77);
        let raw_cloned = raw_original.clone();
        assert_eq!(raw_cloned, MpiErrorClass::Raw(77));
    }

    /// Exhaustive check of the `Display` label of every class variant.
    #[test]
    fn error_class_display_all_variants() {
        let cases = [
            (MpiErrorClass::Success, "SUCCESS"),
            (MpiErrorClass::Buffer, "ERR_BUFFER"),
            (MpiErrorClass::Count, "ERR_COUNT"),
            (MpiErrorClass::Type, "ERR_TYPE"),
            (MpiErrorClass::Tag, "ERR_TAG"),
            (MpiErrorClass::Comm, "ERR_COMM"),
            (MpiErrorClass::Rank, "ERR_RANK"),
            (MpiErrorClass::Request, "ERR_REQUEST"),
            (MpiErrorClass::Root, "ERR_ROOT"),
            (MpiErrorClass::Group, "ERR_GROUP"),
            (MpiErrorClass::Op, "ERR_OP"),
            (MpiErrorClass::Topology, "ERR_TOPOLOGY"),
            (MpiErrorClass::Dims, "ERR_DIMS"),
            (MpiErrorClass::Arg, "ERR_ARG"),
            (MpiErrorClass::Unknown, "ERR_UNKNOWN"),
            (MpiErrorClass::Truncate, "ERR_TRUNCATE"),
            (MpiErrorClass::Other, "ERR_OTHER"),
            (MpiErrorClass::Intern, "ERR_INTERN"),
            (MpiErrorClass::InStatus, "ERR_IN_STATUS"),
            (MpiErrorClass::Pending, "ERR_PENDING"),
            (MpiErrorClass::Win, "ERR_WIN"),
            (MpiErrorClass::Info, "ERR_INFO"),
            (MpiErrorClass::File, "ERR_FILE"),
            (MpiErrorClass::Raw(100), "ERR_CLASS(100)"),
        ];
        for (class, expected) in &cases {
            assert_eq!(
                format!("{class}"),
                *expected,
                "Display mismatch for {class:?}"
            );
        }
    }

    /// The derived `Debug` output names the variant (and nested class).
    #[test]
    fn error_debug_format() {
        let err = Error::InvalidBuffer;
        let debug = format!("{err:?}");
        assert!(
            debug.contains("InvalidBuffer"),
            "Debug output should contain 'InvalidBuffer', got: {debug}"
        );

        let mpi_err = Error::Mpi {
            class: MpiErrorClass::Arg,
            code: 13,
            message: "invalid argument".to_string(),
            operation: None,
        };
        let debug = format!("{mpi_err:?}");
        assert!(
            debug.contains("Mpi"),
            "Debug output should contain 'Mpi', got: {debug}"
        );
        assert!(
            debug.contains("Arg"),
            "Debug output should contain 'Arg', got: {debug}"
        );

        let err = Error::AlreadyInitialized;
        let debug = format!("{err:?}");
        assert!(debug.contains("AlreadyInitialized"));

        let err = Error::InvalidOp;
        let debug = format!("{err:?}");
        assert!(debug.contains("InvalidOp"));

        let err = Error::NotSupported("test op".to_string());
        let debug = format!("{err:?}");
        assert!(debug.contains("NotSupported"));

        let err = Error::Internal("internal msg".to_string());
        let debug = format!("{err:?}");
        assert!(debug.contains("Internal"));
    }

    /// The fields of `Error::Mpi` can be destructured and show up in `Display`.
    #[test]
    fn error_mpi_fields_accessible() {
        let err = Error::Mpi {
            class: MpiErrorClass::Topology,
            code: 11,
            message: "invalid topology".to_string(),
            operation: None,
        };
        if let Error::Mpi {
            class,
            code,
            message,
            operation,
        } = &err
        {
            assert_eq!(*class, MpiErrorClass::Topology);
            assert_eq!(*code, 11);
            assert_eq!(message, "invalid topology");
            assert_eq!(*operation, None);
        } else {
            panic!("Expected Error::Mpi variant");
        }

        let display = format!("{err}");
        assert!(display.contains("invalid topology"));
        assert!(display.contains("ERR_TOPOLOGY"));
        assert!(display.contains("11"));
    }

    /// With an operation label, `Display` inserts `in <operation>`.
    #[test]
    fn error_mpi_display_with_operation_some() {
        let err = Error::Mpi {
            class: MpiErrorClass::Rank,
            code: 6,
            message: "invalid rank".to_string(),
            operation: Some("allreduce"),
        };
        assert_eq!(
            format!("{err}"),
            "MPI error in allreduce: invalid rank (class=ERR_RANK, code=6)"
        );
    }

    /// Without an operation label, `Display` omits the `in <operation>` part.
    #[test]
    fn error_mpi_display_with_operation_none() {
        let err = Error::Mpi {
            class: MpiErrorClass::Rank,
            code: 6,
            message: "invalid rank".to_string(),
            operation: None,
        };
        assert_eq!(
            format!("{err}"),
            "MPI error: invalid rank (class=ERR_RANK, code=6)"
        );
    }

    /// Checks the shape of the value `from_code_with_op` is meant to produce.
    ///
    /// The variant is built by hand rather than through
    /// `Error::from_code_with_op`, which goes through the FFI layer —
    /// presumably to keep this test independent of an MPI runtime.
    #[test]
    fn from_code_with_op_sets_operation_field() {
        let err = Error::Mpi {
            class: MpiErrorClass::Comm,
            code: 5,
            message: "invalid communicator".to_string(),
            operation: Some("broadcast"),
        };
        if let Error::Mpi { operation, .. } = &err {
            assert_eq!(*operation, Some("broadcast"));
        } else {
            panic!("Expected Error::Mpi variant");
        }

        let display = format!("{err}");
        assert_eq!(
            display,
            "MPI error in broadcast: invalid communicator (class=ERR_COMM, code=5)"
        );
    }
}