use crate::ffi;
use thiserror::Error;
pub type Result<T> = std::result::Result<T, Error>;
/// Class of an MPI error, decoded from the integer reported by the MPI layer.
///
/// Integers without a dedicated variant are preserved verbatim in
/// [`MpiErrorClass::Raw`], so decoding never loses information.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MpiErrorClass {
    /// `MPI_SUCCESS`: no error.
    Success,
    /// `MPI_ERR_BUFFER`: invalid buffer pointer.
    Buffer,
    /// `MPI_ERR_COUNT`: invalid count argument.
    Count,
    /// `MPI_ERR_TYPE`: invalid datatype argument.
    Type,
    /// `MPI_ERR_TAG`: invalid tag argument.
    Tag,
    /// `MPI_ERR_COMM`: invalid communicator.
    Comm,
    /// `MPI_ERR_RANK`: invalid rank.
    Rank,
    /// `MPI_ERR_REQUEST`: invalid request handle.
    Request,
    /// `MPI_ERR_ROOT`: invalid root.
    Root,
    /// `MPI_ERR_GROUP`: invalid group.
    Group,
    /// `MPI_ERR_OP`: invalid operation.
    Op,
    /// `MPI_ERR_TOPOLOGY`: invalid topology.
    Topology,
    /// `MPI_ERR_DIMS`: invalid dimension argument.
    Dims,
    /// `MPI_ERR_ARG`: invalid argument of some other kind.
    Arg,
    /// `MPI_ERR_UNKNOWN`: unknown error.
    Unknown,
    /// `MPI_ERR_TRUNCATE`: message truncated on receive.
    Truncate,
    /// `MPI_ERR_OTHER`: known error not covered by another class.
    Other,
    /// `MPI_ERR_INTERN`: internal MPI error.
    Intern,
    /// `MPI_ERR_IN_STATUS`: the per-request error is in the status.
    InStatus,
    /// `MPI_ERR_PENDING`: pending request.
    Pending,
    /// `MPI_ERR_WIN`: invalid window.
    Win,
    /// `MPI_ERR_INFO`: invalid info object.
    Info,
    /// `MPI_ERR_FILE`: invalid file handle.
    File,
    /// Any class value without a dedicated variant above.
    Raw(i32),
}

impl MpiErrorClass {
    /// Decodes a raw error-class integer into a variant.
    ///
    /// This never fails: any value without a dedicated variant (negative
    /// values included) comes back as [`MpiErrorClass::Raw`] holding the input.
    ///
    /// NOTE(review): MPI leaves class numbering to the implementation. The
    /// entries for 7..=19 appear to follow Open MPI's `mpi.h` ordering
    /// (`MPI_ERR_REQUEST == 7`), while File = 27, Info = 28 and Win = 45 appear
    /// to follow MPICH's. Confirm this table matches what the C shim reports.
    pub fn from_raw(class: i32) -> Self {
        // Every class value that has a dedicated variant.
        const KNOWN: [(i32, MpiErrorClass); 23] = [
            (0, MpiErrorClass::Success),
            (1, MpiErrorClass::Buffer),
            (2, MpiErrorClass::Count),
            (3, MpiErrorClass::Type),
            (4, MpiErrorClass::Tag),
            (5, MpiErrorClass::Comm),
            (6, MpiErrorClass::Rank),
            (7, MpiErrorClass::Request),
            (8, MpiErrorClass::Root),
            (9, MpiErrorClass::Group),
            (10, MpiErrorClass::Op),
            (11, MpiErrorClass::Topology),
            (12, MpiErrorClass::Dims),
            (13, MpiErrorClass::Arg),
            (14, MpiErrorClass::Unknown),
            (15, MpiErrorClass::Truncate),
            (16, MpiErrorClass::Other),
            (17, MpiErrorClass::Intern),
            (18, MpiErrorClass::InStatus),
            (19, MpiErrorClass::Pending),
            (27, MpiErrorClass::File),
            (28, MpiErrorClass::Info),
            (45, MpiErrorClass::Win),
        ];

        KNOWN
            .iter()
            .find(|&&(raw, _)| raw == class)
            .map_or(Self::Raw(class), |&(_, variant)| variant)
    }
}

impl std::fmt::Display for MpiErrorClass {
    /// Renders the MPI-style class name (e.g. `ERR_RANK`). Classes without a
    /// dedicated variant render as `ERR_CLASS(<n>)`. Width and fill flags are
    /// not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::Success => "SUCCESS",
            Self::Buffer => "ERR_BUFFER",
            Self::Count => "ERR_COUNT",
            Self::Type => "ERR_TYPE",
            Self::Tag => "ERR_TAG",
            Self::Comm => "ERR_COMM",
            Self::Rank => "ERR_RANK",
            Self::Request => "ERR_REQUEST",
            Self::Root => "ERR_ROOT",
            Self::Group => "ERR_GROUP",
            Self::Op => "ERR_OP",
            Self::Topology => "ERR_TOPOLOGY",
            Self::Dims => "ERR_DIMS",
            Self::Arg => "ERR_ARG",
            Self::Unknown => "ERR_UNKNOWN",
            Self::Truncate => "ERR_TRUNCATE",
            Self::Other => "ERR_OTHER",
            Self::Intern => "ERR_INTERN",
            Self::InStatus => "ERR_IN_STATUS",
            Self::Pending => "ERR_PENDING",
            Self::Win => "ERR_WIN",
            Self::Info => "ERR_INFO",
            Self::File => "ERR_FILE",
            // The only variant whose text depends on a payload.
            Self::Raw(c) => return write!(f, "ERR_CLASS({c})"),
        };
        f.write_str(label)
    }
}
/// Error type returned by fallible operations in this crate.
#[derive(Error, Debug)]
pub enum Error {
    /// Reported when MPI has already been initialized.
    #[error("MPI has already been initialized")]
    AlreadyInitialized,
    /// A non-zero MPI return code together with its decoded class and message.
    #[error("MPI error: {message} (class={class}, code={code})")]
    Mpi {
        /// Error class, decoded with [`MpiErrorClass::from_raw`].
        class: MpiErrorClass,
        /// The original non-zero return code.
        code: i32,
        /// Human-readable description of the error.
        message: String,
    },
    /// A buffer argument was rejected as invalid.
    #[error("Invalid buffer")]
    InvalidBuffer,
    /// The requested operation is not supported; the payload describes it.
    #[error("Operation not supported: {0}")]
    NotSupported(String),
    /// An internal error; the payload describes it.
    #[error("Internal error: {0}")]
    Internal(String),
}

impl Error {
    /// Builds an [`Error::Mpi`] describing a non-zero MPI return code.
    ///
    /// The class and message are queried through `ffi::ferrompi_error_info`.
    /// If that query itself fails, a generic message is synthesised and the
    /// class is decoded from `code` directly. MPI error classes are a subset of
    /// error codes, so this is exact whenever `code` is itself a class value,
    /// and otherwise yields [`MpiErrorClass::Raw`].
    ///
    /// # Panics
    ///
    /// Panics if `code` is `0` (success). Use [`Error::check`] for codes that
    /// may indicate success.
    pub fn from_code(code: i32) -> Self {
        assert!(code != 0, "from_code called with success code 0");

        let mut class: i32 = 0;
        // NOTE(review): no capacity is passed to the shim, so it must never write
        // more than this many bytes. Confirm against the C side; MPI's own
        // MPI_MAX_ERROR_STRING differs between implementations.
        let mut msg_buf = [0u8; 512];
        let mut msg_len: i32 = 0;

        // SAFETY: all three out-pointers refer to live, exclusively borrowed
        // locals that outlive the call. The shim is assumed (see NOTE above) to
        // write at most `msg_buf.len()` bytes into the message buffer.
        let ret = unsafe {
            ffi::ferrompi_error_info(
                code,
                &mut class,
                msg_buf.as_mut_ptr().cast::<std::ffi::c_char>(),
                &mut msg_len,
            )
        };

        if ret != 0 {
            // The lookup failed, so we have no class from the shim. Decode the
            // code itself rather than labelling a known class as `Raw`.
            return Error::Mpi {
                class: MpiErrorClass::from_raw(code),
                code,
                message: format!("MPI error code {code}"),
            };
        }

        Error::Mpi {
            class: MpiErrorClass::from_raw(class),
            code,
            message: Self::decode_message(&msg_buf, msg_len),
        }
    }

    /// Converts the message bytes written by the shim into an owned `String`.
    ///
    /// The reported length is clamped to `0..=buf.len()`, so a bogus length
    /// cannot cause an out-of-bounds slice panic. The text is cut at the first
    /// NUL in case the length counts a C terminator. Invalid UTF-8 is replaced
    /// with U+FFFD rather than discarding the whole message.
    fn decode_message(buf: &[u8], reported_len: i32) -> String {
        // After `max(0)` the value is non-negative and fits in `usize` on the
        // 32/64-bit targets MPI runs on.
        let len = (reported_len.max(0) as usize).min(buf.len());
        let text = &buf[..len];
        let text = text
            .iter()
            .position(|&b| b == 0)
            .map_or(text, |nul| &text[..nul]);
        String::from_utf8_lossy(text).into_owned()
    }

    /// Converts an MPI-style return code into a `Result`.
    ///
    /// # Errors
    ///
    /// Returns an [`Error::Mpi`] (built by [`Error::from_code`]) for any
    /// non-zero `code`.
    pub fn check(code: i32) -> Result<()> {
        if code == 0 {
            Ok(())
        } else {
            Err(Error::from_code(code))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn check_success_returns_ok() {
        assert!(Error::check(0).is_ok());
    }

    #[test]
    fn error_class_from_known_values() {
        assert_eq!(MpiErrorClass::from_raw(0), MpiErrorClass::Success);
        assert_eq!(MpiErrorClass::from_raw(1), MpiErrorClass::Buffer);
        assert_eq!(MpiErrorClass::from_raw(2), MpiErrorClass::Count);
        assert_eq!(MpiErrorClass::from_raw(3), MpiErrorClass::Type);
        assert_eq!(MpiErrorClass::from_raw(4), MpiErrorClass::Tag);
        assert_eq!(MpiErrorClass::from_raw(5), MpiErrorClass::Comm);
        assert_eq!(MpiErrorClass::from_raw(6), MpiErrorClass::Rank);
        assert_eq!(MpiErrorClass::from_raw(7), MpiErrorClass::Request);
        assert_eq!(MpiErrorClass::from_raw(8), MpiErrorClass::Root);
        assert_eq!(MpiErrorClass::from_raw(9), MpiErrorClass::Group);
        assert_eq!(MpiErrorClass::from_raw(10), MpiErrorClass::Op);
        assert_eq!(MpiErrorClass::from_raw(11), MpiErrorClass::Topology);
        assert_eq!(MpiErrorClass::from_raw(12), MpiErrorClass::Dims);
        assert_eq!(MpiErrorClass::from_raw(13), MpiErrorClass::Arg);
        assert_eq!(MpiErrorClass::from_raw(14), MpiErrorClass::Unknown);
        assert_eq!(MpiErrorClass::from_raw(15), MpiErrorClass::Truncate);
        assert_eq!(MpiErrorClass::from_raw(16), MpiErrorClass::Other);
        assert_eq!(MpiErrorClass::from_raw(17), MpiErrorClass::Intern);
        assert_eq!(MpiErrorClass::from_raw(18), MpiErrorClass::InStatus);
        assert_eq!(MpiErrorClass::from_raw(19), MpiErrorClass::Pending);
        assert_eq!(MpiErrorClass::from_raw(27), MpiErrorClass::File);
        assert_eq!(MpiErrorClass::from_raw(28), MpiErrorClass::Info);
        assert_eq!(MpiErrorClass::from_raw(45), MpiErrorClass::Win);
    }

    #[test]
    fn error_class_unknown_raw_value() {
        assert_eq!(MpiErrorClass::from_raw(999), MpiErrorClass::Raw(999));
        assert_eq!(MpiErrorClass::from_raw(-1), MpiErrorClass::Raw(-1));
    }

    #[test]
    fn error_class_display_formats() {
        assert_eq!(format!("{}", MpiErrorClass::Success), "SUCCESS");
        assert_eq!(format!("{}", MpiErrorClass::Buffer), "ERR_BUFFER");
        assert_eq!(format!("{}", MpiErrorClass::Comm), "ERR_COMM");
        assert_eq!(format!("{}", MpiErrorClass::Rank), "ERR_RANK");
        assert_eq!(format!("{}", MpiErrorClass::Raw(42)), "ERR_CLASS(42)");
    }

    #[test]
    fn error_display_formats_correctly() {
        let err = Error::InvalidBuffer;
        assert_eq!(format!("{err}"), "Invalid buffer");
        let err = Error::AlreadyInitialized;
        assert_eq!(format!("{err}"), "MPI has already been initialized");
        let err = Error::NotSupported("persistent collectives".to_string());
        assert_eq!(
            format!("{err}"),
            "Operation not supported: persistent collectives"
        );
        let err = Error::Internal("test failure".to_string());
        assert_eq!(format!("{err}"), "Internal error: test failure");
        let err = Error::Mpi {
            class: MpiErrorClass::Rank,
            code: 6,
            message: "invalid rank".to_string(),
        };
        assert_eq!(
            format!("{err}"),
            "MPI error: invalid rank (class=ERR_RANK, code=6)"
        );
    }

    #[test]
    // `.clone()` on a `Copy` type is deliberate: this test exercises the
    // derived `Clone` impl explicitly.
    #[allow(clippy::clone_on_copy)]
    fn error_class_hash_and_clone() {
        use std::collections::HashSet;
        let mut set = HashSet::new();
        set.insert(MpiErrorClass::Success);
        set.insert(MpiErrorClass::Buffer);
        set.insert(MpiErrorClass::Raw(42));
        // Re-inserting an equal value must not grow the set.
        set.insert(MpiErrorClass::Raw(42));
        assert_eq!(set.len(), 3);
        assert!(set.contains(&MpiErrorClass::Success));
        assert!(set.contains(&MpiErrorClass::Buffer));
        assert!(set.contains(&MpiErrorClass::Raw(42)));
        assert!(!set.contains(&MpiErrorClass::Comm));
        let original = MpiErrorClass::Comm;
        let cloned = original.clone();
        assert_eq!(cloned, MpiErrorClass::Comm);
        assert_eq!(original, cloned);
        let raw_original = MpiErrorClass::Raw(77);
        let raw_cloned = raw_original.clone();
        assert_eq!(raw_cloned, MpiErrorClass::Raw(77));
    }

    #[test]
    fn error_class_display_all_variants() {
        let cases = [
            (MpiErrorClass::Success, "SUCCESS"),
            (MpiErrorClass::Buffer, "ERR_BUFFER"),
            (MpiErrorClass::Count, "ERR_COUNT"),
            (MpiErrorClass::Type, "ERR_TYPE"),
            (MpiErrorClass::Tag, "ERR_TAG"),
            (MpiErrorClass::Comm, "ERR_COMM"),
            (MpiErrorClass::Rank, "ERR_RANK"),
            (MpiErrorClass::Request, "ERR_REQUEST"),
            (MpiErrorClass::Root, "ERR_ROOT"),
            (MpiErrorClass::Group, "ERR_GROUP"),
            (MpiErrorClass::Op, "ERR_OP"),
            (MpiErrorClass::Topology, "ERR_TOPOLOGY"),
            (MpiErrorClass::Dims, "ERR_DIMS"),
            (MpiErrorClass::Arg, "ERR_ARG"),
            (MpiErrorClass::Unknown, "ERR_UNKNOWN"),
            (MpiErrorClass::Truncate, "ERR_TRUNCATE"),
            (MpiErrorClass::Other, "ERR_OTHER"),
            (MpiErrorClass::Intern, "ERR_INTERN"),
            (MpiErrorClass::InStatus, "ERR_IN_STATUS"),
            (MpiErrorClass::Pending, "ERR_PENDING"),
            (MpiErrorClass::Win, "ERR_WIN"),
            (MpiErrorClass::Info, "ERR_INFO"),
            (MpiErrorClass::File, "ERR_FILE"),
            (MpiErrorClass::Raw(100), "ERR_CLASS(100)"),
        ];
        for (class, expected) in &cases {
            assert_eq!(
                format!("{class}"),
                *expected,
                "Display mismatch for {class:?}"
            );
        }
    }

    #[test]
    fn error_debug_format() {
        let err = Error::InvalidBuffer;
        let debug = format!("{err:?}");
        assert!(
            debug.contains("InvalidBuffer"),
            "Debug output should contain 'InvalidBuffer', got: {debug}"
        );
        let mpi_err = Error::Mpi {
            class: MpiErrorClass::Arg,
            code: 13,
            message: "invalid argument".to_string(),
        };
        let debug = format!("{mpi_err:?}");
        assert!(
            debug.contains("Mpi"),
            "Debug output should contain 'Mpi', got: {debug}"
        );
        assert!(
            debug.contains("Arg"),
            "Debug output should contain 'Arg', got: {debug}"
        );
        let err = Error::AlreadyInitialized;
        let debug = format!("{err:?}");
        assert!(debug.contains("AlreadyInitialized"));
        let err = Error::NotSupported("test op".to_string());
        let debug = format!("{err:?}");
        assert!(debug.contains("NotSupported"));
        let err = Error::Internal("internal msg".to_string());
        let debug = format!("{err:?}");
        assert!(debug.contains("Internal"));
    }

    #[test]
    fn error_mpi_fields_accessible() {
        let err = Error::Mpi {
            class: MpiErrorClass::Topology,
            code: 11,
            message: "invalid topology".to_string(),
        };
        if let Error::Mpi {
            class,
            code,
            message,
        } = &err
        {
            assert_eq!(*class, MpiErrorClass::Topology);
            assert_eq!(*code, 11);
            assert_eq!(message, "invalid topology");
        } else {
            panic!("Expected Error::Mpi variant");
        }
        let display = format!("{err}");
        assert!(display.contains("invalid topology"));
        assert!(display.contains("ERR_TOPOLOGY"));
        assert!(display.contains("11"));
    }
}