use std::{
cell::Cell,
fmt,
marker::PhantomData,
mem::{self, MaybeUninit},
os::raw::c_int,
ptr,
};
use crate::{
ffi,
ffi::{MPI_Request, MPI_Status},
point_to_point::Status,
raw::traits::*,
with_uninitialized,
};
/// Returns `true` when `request` is the null request handle (`RSMPI_REQUEST_NULL`).
///
/// The completion paths below assert this after MPI reports a request as finished.
fn is_null(request: MPI_Request) -> bool {
    // Reading the FFI static requires `unsafe`; the comparison itself is plain equality.
    unsafe { ffi::RSMPI_REQUEST_NULL == request }
}
/// An outstanding non-blocking MPI request that holds a borrow of associated data `D`
/// for `'a` and is registered with a scope `S`.
///
/// A request must be finished through one of its completing methods (`wait*`,
/// `test*`) or released with `into_raw`; dropping it any other way panics.
#[must_use]
pub struct Request<'a, D: ?Sized, S: Scope<'a> = StaticScope> {
    // Raw MPI handle for the operation.
    request: MPI_Request,
    // Associated data, handed back by `wait_for_data` / `test_with_data`.
    // NOTE(review): presumably the operation's buffer — confirm against callers.
    data: &'a D,
    // Scope this request was registered with in `from_raw`.
    scope: S,
    // `Cell<&'a ()>` makes the type invariant in `'a`.
    phantom: PhantomData<Cell<&'a ()>>,
}
impl<'a, D: ?Sized, S: Scope<'a>> fmt::Debug for Request<'a, D, S>
where
    D: fmt::Debug,
{
    /// Shows the raw handle and the associated data; the scope is not printed.
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        let mut builder = formatter.debug_struct("Request");
        builder.field("request", &self.request);
        builder.field("data", &self.data);
        builder.finish()
    }
}
unsafe impl<'a, D: ?Sized, S: Scope<'a>> AsRaw for Request<'a, D, S> {
type Raw = MPI_Request;
fn as_raw(&self) -> Self::Raw {
self.request
}
}
impl<'a, D: ?Sized, S: Scope<'a>> Drop for Request<'a, D, S> {
    /// Every completing path in this module goes through `into_raw`, which
    /// `mem::forget`s the request, so this only runs when a request is dropped
    /// without being completed or released. That is treated as a bug and panics.
    fn drop(&mut self) {
        panic!("request was dropped without being completed");
    }
}
pub fn wait_any<'a, D, S: Scope<'a>>(
requests: &mut Vec<Request<'a, D, S>>,
) -> Option<(usize, Status)> {
let mut mpi_requests: Vec<_> = requests.iter().map(|r| r.as_raw()).collect();
let mut index: i32 = mpi_sys::MPI_UNDEFINED;
let size: i32 = mpi_requests
.len()
.try_into()
.expect("Error while casting usize to i32");
let status;
unsafe {
status = Status::from_raw(
with_uninitialized(|s| {
ffi::MPI_Waitany(size, mpi_requests.as_mut_ptr(), &mut index, s);
s
})
.1,
);
}
if index != mpi_sys::MPI_UNDEFINED {
let u_index: usize = index.try_into().expect("Error while casting i32 to usize");
assert!(is_null(mpi_requests[u_index]));
let r = requests.remove(u_index);
unsafe {
r.into_raw();
}
Some((u_index, status))
} else {
None
}
}
impl<'a, D: ?Sized, S: Scope<'a>> Request<'a, D, S> {
pub unsafe fn from_raw(request: MPI_Request, data: &'a D, scope: S) -> Self {
debug_assert!(!is_null(request));
scope.register();
Self {
request,
data,
scope,
phantom: Default::default(),
}
}
pub unsafe fn into_raw(self) -> (MPI_Request, &'a D, S) {
let request = ptr::read(&self.request);
let data = ptr::read(&self.data);
let scope = ptr::read(&self.scope);
let _ = ptr::read(&self.phantom);
mem::forget(self);
scope.unregister();
(request, data, scope)
}
fn wait_with(self, status: *mut MPI_Status) -> &'a D {
unsafe {
let (mut request, data, _) = self.into_raw();
ffi::MPI_Wait(&mut request, status);
assert!(is_null(request)); data
}
}
pub fn wait_for_data(self) -> &'a D {
self.wait_with(unsafe { ffi::RSMPI_STATUS_IGNORE })
}
pub fn wait(self) -> Status {
unsafe { Status::from_raw(with_uninitialized(|status| self.wait_with(status)).1) }
}
pub fn wait_without_status(self) {
self.wait_with(unsafe { ffi::RSMPI_STATUS_IGNORE });
}
pub fn test(self) -> Result<Status, Self> {
unsafe {
let mut status = MaybeUninit::uninit();
let mut request = self.as_raw();
let (_, flag) =
with_uninitialized(|flag| ffi::MPI_Test(&mut request, flag, status.as_mut_ptr()));
if flag != 0 {
assert!(is_null(request)); let (_, _data, _) = self.into_raw();
Ok(Status::from_raw(status.assume_init()))
} else {
Err(self)
}
}
}
pub fn test_with_data(self) -> Result<(Status, &'a D), Self> {
unsafe {
let mut status = MaybeUninit::uninit();
let mut request = self.as_raw();
let (_, flag) =
with_uninitialized(|flag| ffi::MPI_Test(&mut request, flag, status.as_mut_ptr()));
if flag != 0 {
assert!(is_null(request)); let (_, data, _) = self.into_raw();
Ok((Status::from_raw(status.assume_init()), data))
} else {
Err(self)
}
}
}
pub fn cancel(&self) {
let mut request = self.as_raw();
unsafe {
ffi::MPI_Cancel(&mut request);
}
}
pub fn shrink_scope_to<'b, S2>(self, scope: S2) -> Request<'b, D, S2>
where
'a: 'b,
S2: Scope<'b>,
{
unsafe {
let (request, data, _) = self.into_raw();
Request::from_raw(request, data, scope)
}
}
}
/// Guard that waits for the wrapped request when dropped.
///
/// The `Option` is `Some` while the guard owns a request. It is emptied when the
/// request is moved back out via `From<WaitGuard> for Request`.
#[derive(Debug)]
pub struct WaitGuard<'a, D: ?Sized, S: Scope<'a> = StaticScope>(Option<Request<'a, D, S>>);
impl<'a, D: ?Sized, S: Scope<'a>> Drop for WaitGuard<'a, D, S> {
    /// Blocks until the guarded request completes.
    ///
    /// The slot is legitimately empty after the request was taken back out via
    /// `From<WaitGuard> for Request`. That conversion still drops the emptied guard.
    /// Panicking on `None` here would make every such conversion panic, so an empty
    /// slot means there is simply nothing to wait for.
    fn drop(&mut self) {
        if let Some(request) = self.0.take() {
            // The status is never observed, so let MPI skip producing it.
            request.wait_without_status();
        }
    }
}
unsafe impl<'a, D: ?Sized, S: Scope<'a>> AsRaw for WaitGuard<'a, D, S> {
    type Raw = MPI_Request;

    /// Returns the raw handle of the guarded request.
    ///
    /// # Panics
    /// Panics with "invalid WaitGuard" if the guard no longer holds a request.
    fn as_raw(&self) -> MPI_Request {
        let guarded = self.0.as_ref().expect("invalid WaitGuard");
        guarded.as_raw()
    }
}
impl<'a, D: ?Sized, S: Scope<'a>> From<WaitGuard<'a, D, S>> for Request<'a, D, S> {
    /// Moves the request back out of the guard.
    ///
    /// NOTE(review): the emptied guard is still dropped when this returns, so this only
    /// succeeds if `WaitGuard`'s `Drop` accepts an empty slot.
    fn from(guard: WaitGuard<'a, D, S>) -> Self {
        let mut guard = guard;
        let inner = guard.0.take();
        inner.expect("invalid WaitGuard")
    }
}
impl<'a, D: ?Sized, S: Scope<'a>> From<Request<'a, D, S>> for WaitGuard<'a, D, S> {
    /// Wraps `req` so that it is waited on when the guard is dropped.
    fn from(req: Request<'a, D, S>) -> Self {
        let slot = Some(req);
        Self(slot)
    }
}
impl<'a, D: ?Sized, S: Scope<'a>> WaitGuard<'a, D, S> {
    /// Requests cancellation of the guarded operation; does nothing if the slot is empty.
    fn cancel(&self) {
        let Some(request) = self.0.as_ref() else {
            return;
        };
        request.cancel();
    }
}
/// Guard that, when dropped, first requests cancellation of the wrapped request and
/// then (through the inner `WaitGuard`) waits for it.
#[derive(Debug)]
pub struct CancelGuard<'a, D: ?Sized, S: Scope<'a> = StaticScope>(WaitGuard<'a, D, S>);
impl<'a, D: ?Sized, S: Scope<'a>> Drop for CancelGuard<'a, D, S> {
    fn drop(&mut self) {
        // Only requests cancellation. The inner `WaitGuard` field is dropped right after
        // this body runs, and its own `Drop` performs the wait.
        self.0.cancel();
    }
}
impl<'a, D: ?Sized, S: Scope<'a>> From<CancelGuard<'a, D, S>> for WaitGuard<'a, D, S> {
    /// Unwraps the inner `WaitGuard` without running `CancelGuard`'s cancelling `Drop`.
    fn from(guard: CancelGuard<'a, D, S>) -> Self {
        let guard = mem::ManuallyDrop::new(guard);
        // SAFETY: `guard` is never dropped, so the inner guard is moved out exactly once.
        unsafe { ptr::read(&guard.0) }
    }
}
impl<'a, D: ?Sized, S: Scope<'a>> From<WaitGuard<'a, D, S>> for CancelGuard<'a, D, S> {
    /// Upgrades a waiting guard into one that also cancels on drop.
    fn from(guard: WaitGuard<'a, D, S>) -> Self {
        Self(guard)
    }
}
impl<'a, D: ?Sized, S: Scope<'a>> From<Request<'a, D, S>> for CancelGuard<'a, D, S> {
fn from(req: Request<'a, D, S>) -> Self {
CancelGuard(WaitGuard::from(req))
}
}
/// Something requests register with; `'a` bounds how long their data may be borrowed.
///
/// `Request::from_raw` calls `register` and `Request::into_raw` calls `unregister`,
/// so each live `Request` corresponds to one unmatched `register`.
///
/// # Safety
/// NOTE(review): the implementor contract is not spelled out here. Presumably a scope
/// must not let `'a` end while registered requests remain (compare `LocalScope`, which
/// aborts in that case) — confirm.
pub unsafe trait Scope<'a> {
    /// Records that a request has been attached to this scope.
    fn register(&self);
    /// Records that a request has left this scope.
    ///
    /// # Safety
    /// Presumably must only be called to balance an earlier `register` — confirm.
    unsafe fn unregister(&self);
}
/// Scope for requests whose data is borrowed for `'static`; it tracks nothing.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct StaticScope;
unsafe impl Scope<'static> for StaticScope {
    // No bookkeeping is needed, so both hooks are no-ops.
    fn register(&self) {}
    unsafe fn unregister(&self) {}
}
/// Scope whose requests may borrow data for `'a`; handed out by `scope` and
/// `multiple_scope`.
///
/// Counts registered requests and aborts the process if dropped while the count is
/// non-zero.
#[derive(Debug)]
pub struct LocalScope<'a> {
    // Number of registered requests not yet unregistered.
    num_requests: Cell<usize>,
    // Makes `LocalScope` invariant in `'a`.
    phantom: PhantomData<Cell<&'a ()>>,
}
/// Reports that a `LocalScope` still had outstanding requests, then aborts.
///
/// The panic is raised inside `catch_unwind` so the panic hook prints the message. The
/// process is then aborted rather than unwound. NOTE(review): presumably because the
/// outstanding requests may still refer to memory about to be freed — confirm.
#[cold]
fn abort_on_unhandled_request() {
    let report = std::panic::catch_unwind(|| {
        panic!("at least one request was dropped without being completed");
    });
    drop(report);
    std::process::abort();
}
impl<'a> Drop for LocalScope<'a> {
    /// Aborts the process if any registered request is still outstanding.
    fn drop(&mut self) {
        let outstanding = self.num_requests.get();
        if outstanding > 0 {
            abort_on_unhandled_request();
        }
    }
}
unsafe impl<'a, 'b> Scope<'a> for &'b LocalScope<'a> {
    /// Counts one more outstanding request.
    fn register(&self) {
        let outstanding = self.num_requests.get();
        self.num_requests.set(outstanding + 1);
    }

    /// Counts one fewer outstanding request.
    ///
    /// # Panics
    /// Panics if the counter is already zero.
    unsafe fn unregister(&self) {
        let remaining = self
            .num_requests
            .get()
            .checked_sub(1)
            .expect("unregister has been called more times than register");
        self.num_requests.set(remaining);
    }
}
pub fn scope<'a, F, R>(f: F) -> R
where
F: FnOnce(&LocalScope<'a>) -> R,
{
f(&LocalScope {
num_requests: Default::default(),
phantom: Default::default(),
})
}
pub fn multiple_scope<'a, F, R, D>(reserve: usize, f: F) -> R
where
D: 'a + ?Sized,
F: FnOnce(&LocalScope<'a>, &mut RequestCollection<'a, D>) -> R,
{
f(
&LocalScope {
num_requests: Default::default(),
phantom: Default::default(),
},
&mut RequestCollection::new(reserve),
)
}
/// A set of raw requests that can be waited on or tested together.
///
/// The four vectors are parallel: `add` pushes one entry onto each, and the returned
/// position indexes all of them.
pub struct RequestCollection<'a, D: ?Sized> {
    // Raw handles passed to the MPI `*any` / `*some` / `*all` calls.
    requests: Vec<MPI_Request>,
    // `Some` until the request's completion has been reported to the caller.
    data: Vec<Option<&'a D>>,
    // Output buffer for MPI statuses.
    statuses: Vec<MaybeUninit<MPI_Status>>,
    // Output buffer for the indices written by `MPI_Waitsome` / `MPI_Testsome`.
    indices: Vec<c_int>,
}
impl<'a, D: ?Sized> RequestCollection<'a, D> {
fn new(reserve: usize) -> RequestCollection<'a, D> {
let mut requests = vec![];
let mut data = vec![];
let mut statuses = vec![];
let mut indices = vec![];
requests.reserve(reserve);
data.reserve(reserve);
statuses.reserve(reserve);
indices.reserve(reserve);
RequestCollection {
requests,
data,
statuses,
indices,
}
}
pub fn incomplete(&self) -> usize {
self.data
.iter()
.map(|data| if data.is_some() { 1 } else { 0 })
.sum()
}
pub fn add<S>(&mut self, req: Request<'a, D, S>) -> usize
where
S: Scope<'a>,
{
let i = self.requests.len();
let (req, data, _) = unsafe { req.into_raw() };
self.requests.push(req);
self.data.push(Some(data));
self.statuses.push(MaybeUninit::<MPI_Status>::uninit());
self.indices.push(0);
i
}
pub fn wait_any(&mut self) -> Option<(usize, Status, &'a D)> {
let mut i: c_int = 0;
let (_res, status) = unsafe {
let count = self.requests.len() as c_int;
with_uninitialized(|status| {
ffi::MPI_Waitany(count, self.requests.as_mut_ptr(), &mut i, status)
})
};
let i: usize = i.try_into().expect("could not cast c_int to usize");
self.data[i]
.take()
.map(|data| (i, Status::from_raw(status), data))
}
pub fn wait_some(&mut self, result: &mut Vec<(usize, Status, &'a D)>) {
result.clear();
let mut count = 0;
unsafe {
let n = self.requests.len() as c_int;
ffi::MPI_Waitsome(
n,
self.requests.as_mut_ptr(),
&mut count,
self.indices.as_mut_ptr(),
self.statuses.as_mut_ptr() as *mut MPI_Status,
);
};
let count: usize = count.try_into().expect("could not cast c_int to usize");
result.reserve(count);
for i in 0..count {
let idx: usize = self.indices[i]
.try_into()
.expect("could not cast c_int to usize");
assert!(is_null(self.requests[idx]));
if let Some(data) = self.data[idx].take() {
let status = unsafe { self.statuses[idx].assume_init() };
let status = Status::from_raw(status);
result.push((idx, status, data));
}
}
}
pub fn wait_all(&mut self, result: &mut Vec<(usize, Status, &'a D)>) {
let _res = unsafe {
ffi::MPI_Waitall(
self.requests
.len()
.try_into()
.expect("could not cast usize to c_int"),
self.requests.as_mut_ptr(),
self.statuses.as_mut_ptr() as *mut MPI_Status,
)
};
result.clear();
result.reserve(self.requests.len());
for i in 0..self.requests.len() {
if let Some(data) = self.data[i].take() {
let status = unsafe { self.statuses[i].assume_init() };
let status = Status::from_raw(status);
result.push((i, status, data));
}
}
}
pub fn test_any(&mut self) -> Option<(usize, Status, &'a D)> {
let n = self.requests.len() as c_int;
let mut i = 0;
let mut flag = 0;
let (_, status) = unsafe {
with_uninitialized(|status| {
ffi::MPI_Testany(n, self.requests.as_mut_ptr(), &mut i, &mut flag, status)
})
};
if flag != 0 {
let i: usize = i.try_into().expect("could not cast c_int to usize");
assert!(is_null(self.requests[i]));
self.data[i]
.take()
.map(|data| (i, Status::from_raw(status), data))
} else {
None
}
}
pub fn test_some(&mut self, result: &mut Vec<(usize, Status, &'a D)>) {
result.clear();
let n = self.requests.len() as c_int;
let mut count = 0;
unsafe {
ffi::MPI_Testsome(
n,
self.requests.as_mut_ptr(),
&mut count,
self.indices.as_mut_ptr(),
self.statuses.as_mut_ptr() as *mut MPI_Status,
);
}
let count: usize = count.try_into().expect("could not cast c_int to usize");
result.reserve(count);
for i in 0..count {
let idx: usize = self.indices[i]
.try_into()
.expect("could not cast c_int to usize");
assert!(is_null(self.requests[idx]));
let status = unsafe { self.statuses[idx].assume_init() };
if let Some(data) = self.data[idx].take() {
result.push((idx, Status::from_raw(status), data));
}
}
}
pub fn test_all(&mut self, result: &mut Vec<(usize, Status, &'a D)>) -> bool {
let n = self.requests.len() as c_int;
let mut flag = 0;
unsafe {
ffi::MPI_Testall(
n,
self.requests.as_mut_ptr(),
&mut flag,
self.statuses.as_mut_ptr() as *mut MPI_Status,
);
}
result.clear();
result.reserve(self.requests.len());
if flag != 0 {
for i in 0..self.requests.len() {
if let Some(data) = self.data[i].take() {
let status = unsafe { self.statuses[i].assume_init() };
result.push((i, Status::from_raw(status), data));
}
}
true
} else {
false
}
}
}
impl<'a, D: ?Sized> Drop for RequestCollection<'a, D> {
    /// Panics if any added request has not yet been reported as completed.
    fn drop(&mut self) {
        let pending = self.data.iter().any(Option::is_some);
        if pending {
            panic!("some requests have not completed");
        }
    }
}