use crate::error::{LinalgError, LinalgResult};
use super::MPIConfig;
use std::collections::HashMap;
use std::ffi::{c_char, c_int, c_void};
use std::sync::{Arc, Mutex, RwLock};
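/// High-level wrapper around an MPI communicator. It caches derived
/// datatypes and persistent requests, tracks in-flight non-blocking
/// operations, and accumulates communication statistics.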
#[derive(Debug)]
pub struct MPICommunicator {
comm_handle: MPICommHandle,
rank: i32,
size: i32,
derived_types: HashMap<String, super::datatypes::MPIDatatype>,
persistent_requests: HashMap<String, MPIPersistentRequest>,
active_operations: Arc<RwLock<HashMap<String, MPIRequest>>>,
comm_stats: Arc<Mutex<MPICommStats>>,
}
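/// Opaque handle to the underlying MPI communicator object.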
#[derive(Debug)]
pub struct MPICommHandle {
handle: *mut c_void,
}
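// SAFETY: the handle is an opaque token owned by the MPI library; sharing it
// across threads is sound provided MPI was initialized with a thread-support
// level that allows concurrent calls, which callers must ensure.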
unsafe impl Send for MPICommHandle {}
unsafe impl Sync for MPICommHandle {}
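/// A persistent (reusable) MPI request together with the buffer it targets.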
#[derive(Debug)]
pub struct MPIPersistentRequest {
request_handle: *mut c_void,
operation_type: PersistentOperationType,
buffer_info: BufferInfo,
is_active: bool,
}
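/// Communication pattern a persistent request was created for.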
#[derive(Debug, Clone, Copy)]
pub enum PersistentOperationType {
Send,
Recv,
Bcast,
Allreduce,
Allgather,
Scatter,
Gather,
}
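/// Description of a user-supplied communication buffer.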
#[derive(Debug, Clone)]
pub struct BufferInfo {
buffer_ptr: *mut c_void,
buffersize: usize,
element_count: usize,
datatype: String,
}
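/// An in-flight non-blocking operation, tracked for statistics and cleanup.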
#[derive(Debug)]
pub struct MPIRequest {
request_handle: *mut c_void,
operation_id: String,
start_time: std::time::Instant,
expected_bytes: usize,
operation_type: RequestOperationType,
}
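/// Broad classification of operations, used for statistics bookkeeping.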
#[derive(Debug, Clone, Copy)]
pub enum RequestOperationType {
PointToPoint,
Collective,
RMA,
IO,
}
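/// Accumulated communication statistics. `avg_latency` is a running mean of
/// per-operation elapsed times, and `peak_bandwidth` is in bytes per unit of
/// that elapsed time (seconds, assuming callers time operations in seconds).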
#[derive(Debug, Default, Clone)]
pub struct MPICommStats {
pub messages_sent: usize,
pub messages_received: usize,
pub bytes_sent: usize,
pub bytes_received: usize,
pub avg_latency: f64,
pub peak_bandwidth: f64,
pub efficiency: f64,
pub error_count: usize,
}
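/// Completion status of a receive, mirroring the public fields of the C
/// `MPI_Status` structure.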
#[derive(Debug, Default, Clone)]
pub struct MPIStatus {
pub source: i32,
pub tag: i32,
pub error: i32,
pub count: usize,
}
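// Bindings to a C shim layer over MPI. These are not the standard `MPI_*`
// entry points (note e.g. `mpi_comm_rank` returning the rank directly), so
// the matching shim library must be linked in.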
extern "C" {
fn mpi_init(argc: *mut c_int, argv: *mut *mut *mut c_char) -> c_int;
fn mpi_initialized(flag: *mut c_int) -> c_int;
fn mpi_comm_world() -> *mut c_void;
fn mpi_comm_rank(comm: *mut c_void) -> c_int;
fn mpi_commsize(comm: *mut c_void) -> c_int;
fn mpi_isend(buf: *const c_void, count: usize, datatype: c_int, dest: c_int, tag: c_int, comm: *mut c_void) -> *mut c_void;
fn mpi_irecv(buf: *mut c_void, count: usize, datatype: c_int, source: c_int, tag: c_int, comm: *mut c_void) -> *mut c_void;
fn mpi_wait(request: *mut c_void, status: *mut c_void) -> c_int;
fn mpi_bcast(buffer: *mut c_void, count: usize, datatype: c_int, root: c_int, comm: *mut c_void) -> c_int;
fn mpi_allreduce(sendbuf: *const c_void, recvbuf: *mut c_void, count: usize, datatype: c_int, op: c_int, comm: *mut c_void) -> c_int;
fn mpi_gather(sendbuf: *const c_void, sendcount: usize, sendtype: c_int, recvbuf: *mut c_void, recvcount: usize, recvtype: c_int, root: c_int, comm: *mut c_void) -> c_int;
fn mpi_scatter(sendbuf: *const c_void, sendcount: usize, sendtype: c_int, recvbuf: *mut c_void, recvcount: usize, recvtype: c_int, root: c_int, comm: *mut c_void) -> c_int;
fn mpi_barrier(comm: *mut c_void) -> c_int;
fn mpi_finalize() -> c_int;
}
impl MPICommunicator {
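/// Creates a communicator bound to the world communicator, initializing the
/// MPI runtime first if it is not already running.
///
/// A minimal usage sketch; the import path is illustrative, and it assumes
/// `MPIConfig` provides a `Default` implementation, which this module does
/// not show:
///
/// ```ignore
/// use scirs2_linalg::distributed::mpi::{MPIConfig, MPICommunicator};
/// let config = MPIConfig::default();
/// let comm = MPICommunicator::new(&config)?;
/// println!("rank {} of {}", comm.rank(), comm.size());
/// comm.barrier()?;
/// ```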
pub fn new(_config: &MPIConfig) -> LinalgResult<Self> {
unsafe {
// Initialize the MPI runtime unless it is already running. If the
// initialized check itself fails, fall back to attempting init.
let mut flag: c_int = 0;
if mpi_initialized(&mut flag) != 0 || flag == 0 {
let mut argc: c_int = 0;
let mut argv: *mut *mut c_char = std::ptr::null_mut();
if mpi_init(&mut argc, &mut argv) != 0 {
return Err(LinalgError::InitializationError(
"Failed to initialize MPI".to_string()
));
}
}
}
let comm_handle = MPICommHandle {
handle: unsafe { mpi_comm_world() },
};
let rank = unsafe { mpi_comm_rank(comm_handle.handle) };
let size = unsafe { mpi_commsize(comm_handle.handle) };
if rank < 0 || size <= 0 {
return Err(LinalgError::InitializationError(
"Invalid MPI rank or size".to_string()
));
}
Ok(Self {
comm_handle,
rank,
size,
derived_types: HashMap::new(),
persistent_requests: HashMap::new(),
active_operations: Arc::new(RwLock::new(HashMap::new())),
comm_stats: Arc::new(Mutex::new(MPICommStats::default())),
})
}
pub fn rank(&self) -> i32 {
self.rank
}
pub fn size(&self) -> i32 {
self.size
}
pub fn handle(&self) -> &MPICommHandle {
&self.comm_handle
}
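/// Blocks until every rank in the communicator has entered the barrier.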
pub fn barrier(&self) -> LinalgResult<()> {
unsafe {
let result = mpi_barrier(self.comm_handle.handle);
if result != 0 {
return Err(LinalgError::CommunicationError(
format!("MPI barrier failed with code {}", result)
));
}
}
Ok(())
}
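/// Returns a snapshot of the accumulated communication statistics.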
pub fn get_stats(&self) -> MPICommStats {
self.comm_stats.lock().expect("comm_stats mutex poisoned").clone()
}
pub fn reset_stats(&self) {
let mut stats = self.comm_stats.lock().expect("comm_stats mutex poisoned");
*stats = MPICommStats::default();
}
pub fn active_operations_count(&self) -> usize {
self.active_operations.read().expect("active_operations lock poisoned").len()
}
pub fn has_active_operations(&self) -> bool {
!self.active_operations.read().expect("active_operations lock poisoned").is_empty()
}
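/// Registers a derived datatype under a caller-chosen name for later lookup
/// via `get_datatype`.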
pub fn register_datatype(&mut self, name: String, datatype: super::datatypes::MPIDatatype) {
self.derived_types.insert(name, datatype);
}
pub fn get_datatype(&self, name: &str) -> Option<&super::datatypes::MPIDatatype> {
self.derived_types.get(name)
}
pub fn add_persistent_request(&mut self, name: String, request: MPIPersistentRequest) {
self.persistent_requests.insert(name, request);
}
pub fn get_persistent_request(&self, name: &str) -> Option<&MPIPersistentRequest> {
self.persistent_requests.get(name)
}
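/// Records one completed operation of `bytes` bytes that took `elapsed`
/// time units. The average latency is kept as an incremental running mean:
/// `avg_n = (avg_{n-1} * (n - 1) + elapsed) / n`.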
fn update_stats(&self, operation_type: RequestOperationType, bytes: usize, elapsed: f64) {
let mut stats = self.comm_stats.lock().expect("comm_stats mutex poisoned");
match operation_type {
RequestOperationType::PointToPoint | RequestOperationType::Collective => {
stats.messages_sent += 1;
stats.bytes_sent += bytes;
},
// RMA and I/O operations do not affect the message counters.
_ => {}
}
// Incremental running mean over all recorded messages.
let total_messages = stats.messages_sent + stats.messages_received;
if total_messages > 0 {
stats.avg_latency = (stats.avg_latency * (total_messages - 1) as f64 + elapsed) / total_messages as f64;
} else {
stats.avg_latency = elapsed;
}
// Guard against division by zero for timings below clock resolution.
if elapsed > 0.0 {
let bandwidth = bytes as f64 / elapsed;
if bandwidth > stats.peak_bandwidth {
stats.peak_bandwidth = bandwidth;
}
}
}
}
impl MPICommHandle {
pub fn raw_handle(&self) -> *mut c_void {
self.handle
}
pub fn is_valid(&self) -> bool {
!self.handle.is_null()
}
}
impl MPIPersistentRequest {
pub fn new(
request_handle: *mut c_void,
operation_type: PersistentOperationType,
buffer_info: BufferInfo,
) -> Self {
Self {
request_handle,
operation_type,
buffer_info,
is_active: false,
}
}
pub fn is_active(&self) -> bool {
self.is_active
}
pub fn activate(&mut self) {
self.is_active = true;
}
pub fn deactivate(&mut self) {
self.is_active = false;
}
pub fn operation_type(&self) -> PersistentOperationType {
self.operation_type
}
pub fn buffer_info(&self) -> &BufferInfo {
&self.buffer_info
}
}
impl BufferInfo {
pub fn new(buffer_ptr: *mut c_void, buffersize: usize, element_count: usize, datatype: String) -> Self {
Self {
buffer_ptr,
buffersize,
element_count,
datatype,
}
}
pub fn buffersize(&self) -> usize {
self.buffersize
}
pub fn element_count(&self) -> usize {
self.element_count
}
pub fn datatype(&self) -> &str {
&self.datatype
}
pub fn buffer_ptr(&self) -> *mut c_void {
self.buffer_ptr
}
}
impl MPIRequest {
pub fn new(
request_handle: *mut c_void,
operation_id: String,
expected_bytes: usize,
operation_type: RequestOperationType,
) -> Self {
Self {
request_handle,
operation_id,
start_time: std::time::Instant::now(),
expected_bytes,
operation_type,
}
}
pub fn operation_id(&self) -> &str {
&self.operation_id
}
pub fn start_time(&self) -> std::time::Instant {
self.start_time
}
pub fn expected_bytes(&self) -> usize {
self.expected_bytes
}
pub fn operation_type(&self) -> RequestOperationType {
self.operation_type
}
pub fn request_handle(&self) -> *mut c_void {
self.request_handle
}
}
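// SAFETY: these types only carry opaque MPI handles and raw buffer pointers;
// the pointers are not dereferenced by the types themselves, and exclusive
// access to the underlying buffers remains the caller's responsibility.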
unsafe impl Send for MPIPersistentRequest {}
unsafe impl Sync for MPIPersistentRequest {}
unsafe impl Send for MPIRequest {}
unsafe impl Sync for MPIRequest {}
unsafe impl Send for BufferInfo {}
unsafe impl Sync for BufferInfo {}