use crate::error::{QuantumLogError, Result};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Once;
/// Guards the one-time availability probe performed by `is_mpi_available`.
static MPI_CHECKED: Once = Once::new();
/// Cached outcome of that probe; meaningful only once `MPI_CHECKED` has completed.
static MPI_AVAILABLE: AtomicBool = AtomicBool::new(false);

// The items below are referenced only when an MPI feature is enabled, so a
// build without MPI support would otherwise warn about dead code.

/// Status code MPI functions return on success.
#[allow(dead_code)]
const MPI_SUCCESS: i32 = 0;

/// MPICH's numeric value for the `MPI_ERR_OTHER` error class (not referenced in this module).
#[allow(dead_code)]
const MPI_ERR_OTHER: i32 = 15;

/// C signature of `int MPI_Initialized(int *flag)`.
#[allow(dead_code)]
type MpiInitializedFn = unsafe extern "C" fn(*mut i32) -> i32;

/// C signature of `int MPI_Comm_rank(MPI_Comm comm, int *rank)`, with `MPI_Comm`
/// modelled as a 32-bit integer handle.
#[allow(dead_code)]
type MpiCommRankFn = unsafe extern "C" fn(i32, *mut i32) -> i32;
#[cfg(feature = "dynamic_mpi")]
use once_cell::sync::OnceCell;
/// A dynamically loaded MPI library together with the entry points resolved from it.
#[cfg(feature = "dynamic_mpi")]
struct MpiDynamic {
// Never read directly. It is held so the library stays loaded while the
// function pointers below (copied out of symbols of this library) are in use.
#[allow(dead_code)]
lib: libloading::Library,
/// `MPI_Initialized`, resolved from `lib`.
initialized: MpiInitializedFn,
/// `MPI_Comm_rank`, resolved from `lib`.
comm_rank: MpiCommRankFn,
}
/// Process-wide cache of the loaded MPI library; filled by `load_mpi_library`.
#[cfg(feature = "dynamic_mpi")]
static MPI_DYNAMIC: OnceCell<MpiDynamic> = OnceCell::new();
/// MPI C entry points linked at build time, used when `mpi_support` is enabled
/// without `dynamic_mpi`.
// NOTE(review): `comm: i32` matches ABIs where `MPI_Comm` is an `int`
// (MPICH-style). Implementations with pointer handles (e.g. Open MPI) would
// need a pointer-sized parameter here; confirm which MPI this links against.
#[cfg(all(feature = "mpi_support", not(feature = "dynamic_mpi")))]
extern "C" {
/// `int MPI_Initialized(int *flag)`: stores a nonzero flag once `MPI_Init` has run.
fn MPI_Initialized(flag: *mut i32) -> i32;
/// `int MPI_Comm_rank(MPI_Comm comm, int *rank)`: stores the caller's rank within `comm`.
fn MPI_Comm_rank(comm: i32, rank: *mut i32) -> i32;
}
/// Handle value of `MPI_COMM_WORLD` in the MPICH ABI, where `MPI_Comm` is a plain
/// `int`. MPI implementations that use pointer handles represent it differently.
// Only referenced when an MPI feature is enabled.
#[allow(dead_code)]
const MPI_COMM_WORLD: i32 = 0x4400_0000;
/// Reports whether MPI support is usable in this process.
///
/// The probe (`check_mpi_availability`) runs at most once per process; every
/// subsequent call returns the cached answer.
pub fn is_mpi_available() -> bool {
    MPI_CHECKED.call_once(|| MPI_AVAILABLE.store(check_mpi_availability(), Ordering::Relaxed));
    // `Once::call_once` establishes a happens-before edge between the closure
    // and everything after it returns, so a relaxed load observes the store.
    MPI_AVAILABLE.load(Ordering::Relaxed)
}
/// Performs the one-time probe whose result `is_mpi_available` caches.
///
/// "Available" means the MPI entry points can be called. It does not mean
/// `MPI_Init` has already run. Initialization changes over the process
/// lifetime, so `get_mpi_rank` re-checks it on every call.
///
/// - `dynamic_mpi`: `true` if a supported MPI shared library could be loaded.
/// - `mpi_support` (static linking): always `true`, because the symbols are linked in.
/// - neither feature: always `false`.
fn check_mpi_availability() -> bool {
    #[cfg(feature = "dynamic_mpi")]
    {
        load_mpi_library().is_ok()
    }
    #[cfg(all(feature = "mpi_support", not(feature = "dynamic_mpi")))]
    {
        // Do not probe `MPI_Initialized` here. This result is cached for the
        // rest of the process, so a probe made before `MPI_Init` (e.g. when the
        // logger is set up first) would permanently hide the rank.
        true
    }
    #[cfg(all(not(feature = "mpi_support"), not(feature = "dynamic_mpi")))]
    {
        false
    }
}
/// Loads an MPI shared library at runtime and caches its entry points in
/// `MPI_DYNAMIC`.
///
/// Candidate names are tried in order. A candidate is accepted when it:
/// - loads,
/// - is not Open MPI (see below), and
/// - exports both `MPI_Initialized` and `MPI_Comm_rank`.
///
/// Returns `Ok(())` immediately if a library is already cached.
///
/// # Errors
///
/// Returns the failure of the last candidate tried (load error, missing
/// symbol, or rejected Open MPI build). If no candidate was attempted, it
/// returns a generic MPI error.
#[cfg(feature = "dynamic_mpi")]
fn load_mpi_library() -> Result<()> {
    if MPI_DYNAMIC.get().is_some() {
        return Ok(());
    }
    let lib_names = [
        "libmpi.so",
        "libmpi.so.12",
        "libmpi.so.40",
        "mpi.dll",
        "libmpi.dylib",
    ];
    let mut last_error = None;
    for lib_name in &lib_names {
        // SAFETY: loading a library runs its initialisation routines; only
        // well-known MPI library names are attempted.
        let lib = match unsafe { libloading::Library::new(lib_name) } {
            Ok(lib) => lib,
            Err(e) => {
                last_error = Some(QuantumLogError::LibLoadingError { source: e });
                continue;
            }
        };

        // Open MPI defines `MPI_Comm` as a pointer: its `MPI_COMM_WORLD` is the
        // address of `ompi_mpi_comm_world`. Calling its `MPI_Comm_rank` with the
        // MPICH-style integer handle `MPI_COMM_WORLD` would be undefined
        // behaviour, so such a library is skipped instead of cached.
        // SAFETY: only the symbol's presence is checked; its value is never read.
        let is_open_mpi = unsafe { lib.get::<*const u8>(b"ompi_mpi_comm_world") }.is_ok();
        if is_open_mpi {
            last_error = Some(QuantumLogError::mpi(format!(
                "{} 是 Open MPI,其指针型 MPI_Comm 句柄不受支持",
                lib_name
            )));
            continue;
        }

        // SAFETY: the requested types match the C prototypes of
        // `MPI_Initialized` and (for integer-handle ABIs) `MPI_Comm_rank`.
        // The copied function pointers stay valid because `lib` is stored next
        // to them in `MpiDynamic`, so it is not unloaded while they are cached.
        let resolved = unsafe {
            match (
                lib.get::<MpiInitializedFn>(b"MPI_Initialized"),
                lib.get::<MpiCommRankFn>(b"MPI_Comm_rank"),
            ) {
                (Ok(init_fn), Ok(rank_fn)) => Ok((*init_fn, *rank_fn)),
                (Err(e), _) | (_, Err(e)) => Err(e),
            }
        };

        match resolved {
            Ok((initialized, comm_rank)) => {
                // If the cell was filled concurrently, the existing entry is
                // kept and this duplicate handle is dropped.
                let _ = MPI_DYNAMIC.set(MpiDynamic {
                    lib,
                    initialized,
                    comm_rank,
                });
                return Ok(());
            }
            Err(e) => last_error = Some(QuantumLogError::LibLoadingError { source: e }),
        }
    }
    Err(last_error.unwrap_or_else(|| QuantumLogError::mpi("无法加载任何 MPI 库")))
}
/// Asks the MPI runtime whether `MPI_Init` has been called yet.
///
/// Without any MPI feature this always returns `Ok(false)`.
///
/// # Errors
///
/// Fails when:
/// - `MPI_Initialized` returns a status other than `MPI_SUCCESS`, or
/// - `dynamic_mpi` is enabled and no MPI library has been loaded.
fn check_mpi_initialized() -> Result<bool> {
    #[cfg(feature = "dynamic_mpi")]
    {
        match MPI_DYNAMIC.get() {
            None => Err(QuantumLogError::mpi("MPI_Initialized 函数未加载")),
            Some(dynlib) => {
                let mut flag: i32 = 0;
                // SAFETY: `initialized` was resolved from `dynlib.lib`, which is
                // kept loaded alongside it; `flag` is a valid, writable `i32`.
                let status = unsafe { (dynlib.initialized)(&mut flag) };
                if status != MPI_SUCCESS {
                    return Err(QuantumLogError::mpi(format!(
                        "MPI_Initialized 调用失败,错误码: {}",
                        status
                    )));
                }
                Ok(flag != 0)
            }
        }
    }
    #[cfg(all(feature = "mpi_support", not(feature = "dynamic_mpi")))]
    {
        let mut flag: i32 = 0;
        // SAFETY: the MPI standard allows `MPI_Initialized` to be called
        // before `MPI_Init`; `flag` is a valid, writable `i32`.
        let status = unsafe { MPI_Initialized(&mut flag) };
        if status != MPI_SUCCESS {
            return Err(QuantumLogError::mpi(format!(
                "MPI_Initialized 调用失败,错误码: {}",
                status
            )));
        }
        Ok(flag != 0)
    }
    #[cfg(all(not(feature = "mpi_support"), not(feature = "dynamic_mpi")))]
    {
        Ok(false)
    }
}
/// Returns this process's rank in `MPI_COMM_WORLD`.
///
/// Returns `None` when MPI is unavailable, not initialized yet, or the rank
/// query fails. Errors are deliberately discarded because this feeds log output.
pub fn get_mpi_rank() -> Option<i32> {
    let ready = is_mpi_available() && matches!(check_mpi_initialized(), Ok(true));
    if ready {
        get_mpi_rank_internal().ok()
    } else {
        None
    }
}
/// Queries this process's rank in `MPI_COMM_WORLD` through `MPI_Comm_rank`.
///
/// # Errors
///
/// Fails when:
/// - `MPI_Comm_rank` returns a status other than `MPI_SUCCESS`,
/// - `dynamic_mpi` is enabled and no MPI library has been loaded, or
/// - no MPI feature is enabled at all.
fn get_mpi_rank_internal() -> Result<i32> {
    #[cfg(feature = "dynamic_mpi")]
    {
        match MPI_DYNAMIC.get() {
            None => Err(QuantumLogError::mpi("MPI_Comm_rank 函数未加载")),
            Some(dynlib) => {
                let mut rank: i32 = -1;
                // SAFETY: `comm_rank` was resolved from `dynlib.lib`, which is
                // kept loaded alongside it; `rank` is a valid, writable `i32`.
                let status = unsafe { (dynlib.comm_rank)(MPI_COMM_WORLD, &mut rank) };
                if status != MPI_SUCCESS {
                    return Err(QuantumLogError::mpi(format!(
                        "MPI_Comm_rank 调用失败,错误码: {}",
                        status
                    )));
                }
                Ok(rank)
            }
        }
    }
    #[cfg(all(feature = "mpi_support", not(feature = "dynamic_mpi")))]
    {
        let mut rank: i32 = -1;
        // SAFETY: `rank` is a valid, writable `i32`; `MPI_COMM_WORLD` is passed
        // as the integer handle the declared prototype expects.
        let status = unsafe { MPI_Comm_rank(MPI_COMM_WORLD, &mut rank) };
        if status != MPI_SUCCESS {
            return Err(QuantumLogError::mpi(format!(
                "MPI_Comm_rank 调用失败,错误码: {}",
                status
            )));
        }
        Ok(rank)
    }
    #[cfg(all(not(feature = "mpi_support"), not(feature = "dynamic_mpi")))]
    {
        Err(QuantumLogError::mpi("MPI 支持未启用"))
    }
}
/// Formats the MPI rank for log output.
///
/// Returns the decimal rank, or `"N/A"` when no rank is available.
pub fn get_mpi_rank_string() -> String {
    get_mpi_rank().map_or_else(|| "N/A".to_string(), |rank| rank.to_string())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Repeated availability queries agree, since the probe result is cached.
    #[test]
    fn test_mpi_availability_check() {
        // The first call performs the one-time probe; the rest hit the cache.
        let _ = is_mpi_available();
        let (first, second) = (is_mpi_available(), is_mpi_available());
        assert_eq!(first, second);
    }

    /// Without MPI, both the numeric and the string rank report "no rank".
    #[test]
    fn test_mpi_rank_when_unavailable() {
        if is_mpi_available() {
            return;
        }
        assert_eq!(get_mpi_rank(), None);
        assert_eq!(get_mpi_rank_string(), "N/A");
    }

    /// The rank string is either the placeholder or a parseable integer.
    #[test]
    fn test_mpi_rank_string_format() {
        let rank_str = get_mpi_rank_string();
        let well_formed = rank_str == "N/A" || rank_str.parse::<i32>().is_ok();
        assert!(well_formed);
    }

    /// Consecutive rank queries return the same value.
    #[test]
    fn test_multiple_rank_calls() {
        assert_eq!(get_mpi_rank(), get_mpi_rank());
    }

    /// The string form always mirrors the numeric rank.
    #[test]
    fn test_rank_string_consistency() {
        let rank = get_mpi_rank();
        let rank_str = get_mpi_rank_string();
        let expected = rank.map_or_else(|| "N/A".to_string(), |r| r.to_string());
        assert_eq!(rank_str, expected);
    }
}