use std::cell::Cell;
use std::convert::TryInto;
use std::marker::PhantomData;
use std::mem::{self, MaybeUninit};
use std::ptr;
use crate::ffi;
use crate::ffi::{MPI_Request, MPI_Status};
use crate::point_to_point::Status;
use crate::raw::traits::AsRaw;
use crate::with_uninitialized;
/// Returns `true` if `request` is the MPI null request handle.
fn is_null(request: MPI_Request) -> bool {
    // SAFETY: only reads the null-request handle exported by the ffi layer.
    let null_request = unsafe { ffi::RSMPI_REQUEST_NULL };
    request == null_request
}
/// An in-flight MPI operation that must be completed before it is dropped.
///
/// A request is completed with `wait`, `wait_without_status`, a successful `test`, or
/// `wait_any`. Dropping it any other way panics (see the `Drop` impl).
#[must_use]
#[derive(Debug)]
pub struct Request<'a, S: Scope<'a> = StaticScope> {
    // Raw MPI handle owned by this wrapper.
    request: MPI_Request,
    // Scope the request is registered with (`from_raw`) and unregistered from (`into_raw`).
    scope: S,
    // Ties the request to `'a`; `Cell<&'a ()>` makes the type invariant in `'a`.
    phantom: PhantomData<Cell<&'a ()>>,
}
unsafe impl<'a, S: Scope<'a>> AsRaw for Request<'a, S> {
    type Raw = MPI_Request;

    /// Returns a copy of the raw handle; ownership stays with `self`.
    fn as_raw(&self) -> Self::Raw {
        let Self { request, .. } = self;
        *request
    }
}
/// Completing a request goes through `into_raw`, which `mem::forget`s `self`, so this
/// destructor only runs for a request that was never completed.
impl<'a, S: Scope<'a>> Drop for Request<'a, S> {
    fn drop(&mut self) {
        panic!("request was dropped without being completed");
    }
}
pub fn wait_any<'a, S: Scope<'a>>(requests: &mut Vec<Request<'a, S>>) -> Option<(usize, Status)> {
let mut mpi_requests: Vec<_> = requests.iter().map(AsRaw::as_raw).collect();
let mut index: i32 = crate::mpi_sys::MPI_UNDEFINED;
let size: i32 = mpi_requests
.len()
.try_into()
.expect("Error while casting usize to i32");
let status;
unsafe {
status = Status::from_raw(
with_uninitialized(|s| {
ffi::MPI_Waitany(size, mpi_requests.as_mut_ptr(), &mut index, s);
s
})
.1,
);
}
if index == crate::mpi_sys::MPI_UNDEFINED {
None
} else {
let u_index: usize = index.try_into().expect("Error while casting i32 to usize");
assert!(is_null(mpi_requests[u_index]));
let r = requests.remove(u_index);
unsafe {
r.into_raw();
}
Some((u_index, status))
}
}
impl<'a, S: Scope<'a>> Request<'a, S> {
#[allow(clippy::default_trait_access)]
pub unsafe fn from_raw(request: MPI_Request, scope: S) -> Self {
debug_assert!(!is_null(request));
scope.register();
Self {
request,
scope,
phantom: Default::default(),
}
}
pub unsafe fn into_raw(self) -> (MPI_Request, S) {
let request = ptr::read(&self.request);
let scope = ptr::read(&self.scope);
let _ = ptr::read(&self.phantom);
mem::forget(self);
scope.unregister();
(request, scope)
}
fn wait_with(self, status: *mut MPI_Status) {
unsafe {
let (mut request, _) = self.into_raw();
ffi::MPI_Wait(&mut request, status);
assert!(is_null(request)); }
}
pub fn wait(self) -> Status {
unsafe { Status::from_raw(with_uninitialized(|status| self.wait_with(status)).1) }
}
pub fn wait_without_status(self) {
self.wait_with(unsafe { ffi::RSMPI_STATUS_IGNORE as *mut _ });
}
pub fn test(self) -> Result<Status, Self> {
unsafe {
let mut status = MaybeUninit::uninit();
let mut request = self.as_raw();
let (_, flag) =
with_uninitialized(|flag| ffi::MPI_Test(&mut request, flag, status.as_mut_ptr()));
if flag == 0 {
Err(self)
} else {
assert!(is_null(request)); self.into_raw();
Ok(Status::from_raw(status.assume_init()))
}
}
}
pub fn cancel(&self) {
let mut request = self.as_raw();
unsafe {
ffi::MPI_Cancel(&mut request);
}
}
pub fn shrink_scope_to<'b, S2>(self, scope: S2) -> Request<'b, S2>
where
'a: 'b,
S2: Scope<'b>,
{
unsafe {
let (request, _) = self.into_raw();
Request::from_raw(request, scope)
}
}
}
/// Guard that waits for its request to complete when it is dropped.
///
/// Build it from a `Request` with `From`/`Into`, and convert it back the same way.
/// The `Option` is only emptied when the request is taken back out of the guard.
#[derive(Debug)]
pub struct WaitGuard<'a, S: Scope<'a> = StaticScope>(Option<Request<'a, S>>);
impl<'a, S: Scope<'a>> Drop for WaitGuard<'a, S> {
    /// Waits for the guarded request to complete.
    ///
    /// An empty slot means the request has already been taken out of the guard (for example by
    /// converting the guard back into a `Request`), so there is nothing left to wait for.
    /// Panicking here would instead turn a legitimate conversion into a panic, or even an abort
    /// once the extracted request is dropped during unwinding.
    fn drop(&mut self) {
        if let Some(req) = self.0.take() {
            // The status is discarded anyway, so let MPI skip filling it in.
            req.wait_without_status();
        }
    }
}
unsafe impl<'a, S: Scope<'a>> AsRaw for WaitGuard<'a, S> {
    type Raw = MPI_Request;

    /// Returns a copy of the guarded request's raw handle.
    ///
    /// Panics with "invalid WaitGuard" if the guard no longer holds a request.
    fn as_raw(&self) -> Self::Raw {
        match self.0 {
            Some(ref req) => req.as_raw(),
            None => panic!("invalid WaitGuard"),
        }
    }
}
impl<'a, S: Scope<'a>> From<WaitGuard<'a, S>> for Request<'a, S> {
fn from(mut guard: WaitGuard<'a, S>) -> Self {
guard.0.take().expect("invalid WaitGuard")
}
}
impl<'a, S: Scope<'a>> From<Request<'a, S>> for WaitGuard<'a, S> {
    /// Puts `req` under a guard that waits for it on drop.
    fn from(req: Request<'a, S>) -> Self {
        let slot = Some(req);
        Self(slot)
    }
}
impl<'a, S: Scope<'a>> WaitGuard<'a, S> {
    /// Requests cancellation of the guarded operation, if the guard still holds one.
    fn cancel(&self) {
        let guarded = self.0.as_ref();
        if let Some(req) = guarded {
            req.cancel();
        }
    }
}
/// Guard that cancels its request when dropped, then waits for it to complete.
///
/// `CancelGuard::drop` calls cancel. The inner `WaitGuard` is dropped afterwards, and that drop
/// waits for the request.
#[derive(Debug)]
pub struct CancelGuard<'a, S: Scope<'a> = StaticScope>(WaitGuard<'a, S>);
impl<'a, S: Scope<'a>> Drop for CancelGuard<'a, S> {
fn drop(&mut self) {
self.0.cancel();
}
}
impl<'a, S: Scope<'a>> From<CancelGuard<'a, S>> for WaitGuard<'a, S> {
    /// Unwraps the inner `WaitGuard` without triggering a cancellation.
    fn from(guard: CancelGuard<'a, S>) -> Self {
        // Skip `CancelGuard::drop` so the request is not cancelled.
        let guard = mem::ManuallyDrop::new(guard);
        // SAFETY: `guard` is never dropped or used again, so the inner guard is moved out once.
        unsafe { ptr::read(&guard.0) }
    }
}
impl<'a, S: Scope<'a>> From<WaitGuard<'a, S>> for CancelGuard<'a, S> {
    /// Upgrades a wait-on-drop guard into a cancel-then-wait guard.
    fn from(guard: WaitGuard<'a, S>) -> Self {
        Self(guard)
    }
}
impl<'a, S: Scope<'a>> From<Request<'a, S>> for CancelGuard<'a, S> {
    /// Puts `req` under a guard that cancels and then waits for it on drop.
    fn from(req: Request<'a, S>) -> Self {
        Self(req.into())
    }
}
/// A scope that requests of lifetime `'a` are registered with.
///
/// # Safety
///
/// NOTE(review): implementors appear to be trusted to account for every registered request until
/// it is unregistered. `LocalScope` counts them and aborts if any remain when it is dropped.
/// Confirm the exact contract before adding new implementations.
pub unsafe trait Scope<'a> {
    /// Called when a `Request` is created in this scope (from `Request::from_raw`).
    fn register(&self);
    /// Called when a `Request` is dismantled (from `Request::into_raw`).
    ///
    /// # Safety
    ///
    /// Each call must match an earlier `register` call on the same scope.
    unsafe fn unregister(&self);
}
/// Scope tied to the `'static` lifetime; it does not track requests (both hooks are no-ops).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct StaticScope;
unsafe impl Scope<'static> for StaticScope {
    fn register(&self) {}
    unsafe fn unregister(&self) {}
}
/// Scope, created by `scope`, that counts its outstanding requests.
///
/// Dropping it while any request is still registered aborts the process.
#[derive(Debug)]
pub struct LocalScope<'a> {
    // Number of requests registered but not yet unregistered.
    num_requests: Cell<usize>,
    // Invariant in `'a` (via `Cell<&'a ()>`), matching `Request`.
    phantom: PhantomData<Cell<&'a ()>>, }
/// Reports a leaked request and aborts the process.
///
/// The panic is raised only so the panic hook prints the message. It is caught immediately, and
/// the process is then aborted instead of unwinding.
#[cold]
fn abort_on_unhandled_request() {
    // Bound to a name (not `_`) so the caught payload stays alive until the abort below, which
    // never returns; the payload is therefore never dropped.
    let _panic_payload = std::panic::catch_unwind(|| {
        panic!("at least one request was dropped without being completed");
    });
    std::process::abort();
}
impl<'a> Drop for LocalScope<'a> {
    /// Aborts the process if any request registered with this scope is still outstanding.
    fn drop(&mut self) {
        let outstanding = self.num_requests.get();
        if outstanding > 0 {
            abort_on_unhandled_request();
        }
    }
}
/// Borrowed `LocalScope`s act as scopes by counting registered requests.
unsafe impl<'a, 'b> Scope<'a> for &'b LocalScope<'a> {
    /// Increments the outstanding-request counter.
    fn register(&self) {
        // Checked like `unregister`: a plain `+ 1` wraps silently in release builds. That would
        // reset the counter and let `LocalScope` be dropped with requests still outstanding.
        self.num_requests.set(
            self.num_requests
                .get()
                .checked_add(1)
                .expect("register has been called more times than the counter can hold"),
        )
    }

    /// Decrements the outstanding-request counter, panicking on underflow.
    unsafe fn unregister(&self) {
        self.num_requests.set(
            self.num_requests
                .get()
                .checked_sub(1)
                .expect("unregister has been called more times than register"),
        )
    }
}
#[allow(clippy::default_trait_access)]
pub fn scope<'a, F, R>(f: F) -> R
where
F: FnOnce(&LocalScope<'a>) -> R,
{
f(&LocalScope {
num_requests: Default::default(),
phantom: Default::default(),
})
}