use crate::channel::pending_work::PendingWork;
use crate::channel::{Channel, TransportError, TransportResult};
use crate::ibverbs::error::IbvResult;
use crate::ibverbs::protection_domain::ProtectionDomain;
use crate::ibverbs::work::{
ReadWorkRequest, ReceiveWorkRequest, SendWorkRequest, WorkSuccess, WriteWorkRequest,
};
use std::cell::RefCell;
use std::marker::PhantomData;
use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
use std::rc::Rc;
use thiserror::Error;
impl Channel {
    /// Runs `f` inside a [`PollingScope`] that mutably borrows this channel.
    ///
    /// Work requests posted through the scope may be polled by `f` itself. Any request
    /// that `f` did not poll to completion is spin-polled automatically once `f`
    /// returns. This also happens if `f` panics, in which case the panic is re-raised
    /// afterwards.
    ///
    /// # Errors
    /// - [`ScopeError::ClosureError`] if `f` returns an error.
    /// - [`ScopeError::AutoPollError`] if `f` succeeds but at least one automatically
    ///   polled work request fails.
    pub fn scope<'env, F, T, E>(&'env mut self, f: F) -> Result<T, ScopeError<E>>
    where
        F: for<'scope> FnOnce(&mut PollingScope<'scope, 'env, Channel>) -> Result<T, E>,
    {
        PollingScope::<Channel>::run(self, f)
    }

    /// Like [`Channel::scope`], but the closure is expected to poll every work request
    /// it posts.
    ///
    /// If `f` returns an error, leftover requests are still spin-polled and `f`'s error
    /// is returned as-is.
    ///
    /// # Panics
    /// - Panics if `f` returns `Ok` while at least one posted work request was not
    ///   polled to completion.
    /// - Re-raises any panic from `f`, after leftover requests have been spin-polled.
    pub fn manual_scope<'env, F, T, E>(&'env mut self, f: F) -> Result<T, E>
    where
        F: for<'scope> FnOnce(&mut PollingScope<'scope, 'env, Channel>) -> Result<T, E>,
    {
        PollingScope::<Channel>::run_manual(self, f)
    }
}
impl<'scope, 'env> PollingScope<'scope, 'env, Channel> {
    /// Returns the [`ProtectionDomain`] reported by the borrowed channel's `pd()`.
    pub fn pd(&self) -> &ProtectionDomain {
        // Shared reborrow of the `&mut Channel`; reading through `&self` is all we need.
        let channel: &Channel = self.inner;
        channel.pd()
    }
}
/// Result type of a polling scope.
///
/// `E` is the error type of the scope closure. It defaults to [`TransportError`], so
/// the existing `ScopeResult<T>` spelling keeps its meaning, while scopes whose closure
/// uses another error type can write `ScopeResult<T, MyError>`.
pub type ScopeResult<T, E = TransportError> = Result<T, ScopeError<E>>;

/// Error returned by [`Channel::scope`].
///
/// `E` is the error type of the scope closure and defaults to [`TransportError`].
#[derive(Debug, Error)]
pub enum ScopeError<E = TransportError> {
    /// The scope closure returned an error. Failures from auto-polling the work
    /// requests it left unpolled are discarded in this case.
    #[error("Closure error: {0}")]
    ClosureError(#[from] E),
    /// The closure succeeded, but at least one work request it left unpolled failed
    /// when the scope auto-polled it. Holds one error per failed request.
    #[error("Auto poll error: {0:?}")]
    AutoPollError(Vec<TransportError>),
}
impl<'a, 'b, C> PollingScope<'a, 'b, C> {
pub(crate) fn run<'env, F, T, E>(inner: &'env mut C, f: F) -> Result<T, ScopeError<E>>
where
F: for<'scope> FnOnce(&mut PollingScope<'scope, 'env, C>) -> Result<T, E>,
{
let mut scope = PollingScope::new(inner);
let scope_result = catch_unwind(AssertUnwindSafe(|| f(&mut scope)));
let auto_poll_result = scope.auto_poll();
match scope_result {
Ok(closure_result) => match closure_result {
Err(closure_error) => Err(ScopeError::ClosureError(closure_error)),
Ok(closure_output) => match auto_poll_result {
Ok(_) => Ok(closure_output),
Err(error) => Err(ScopeError::AutoPollError(error)),
},
},
Err(panic) => resume_unwind(panic),
}
}
pub(crate) fn run_manual<'env, F, T, E>(inner: &'env mut C, f: F) -> Result<T, E>
where
F: for<'scope> FnOnce(&mut PollingScope<'scope, 'env, C>) -> Result<T, E>,
{
let mut scope = PollingScope::new(inner);
let scope_result = catch_unwind(AssertUnwindSafe(|| f(&mut scope)));
let auto_poll_result = scope.auto_poll();
match scope_result {
Ok(closure_result) => {
let closure_output = closure_result?;
match auto_poll_result {
Ok(AutoPollSuccess::NoPendingWorks) => Ok(closure_output),
Ok(AutoPollSuccess::PendingWorksSucceeded) | Err(_) => {
panic!("Unpolled wrs in PollingScope::run_manual")
}
}
}
Err(panic) => resume_unwind(panic),
}
}
}
/// A scope in which work requests can be posted on a borrowed `C` (e.g. a [`Channel`])
/// without being polled immediately.
///
/// Every work request posted through the scope is recorded in `wrs`. When the scope is
/// finished, any request the user has not polled to completion is spin-polled.
/// NOTE(review): presumably this guarantees nothing is left in flight once the `'env`
/// borrow ends; confirm against `PendingWork::spin_poll`.
pub struct PollingScope<'scope, 'env: 'scope, C> {
    /// The borrowed resource that work requests are posted on.
    pub(crate) inner: &'env mut C,
    /// Shared handles to every work request posted in this scope.
    wrs: Vec<ScopedPendingWork<'scope>>,
    // `&'a mut &'a ()` makes each lifetime invariant (the same markers
    // `std::thread::Scope` uses), so neither can be adjusted by subtyping.
    scope: PhantomData<&'scope mut &'scope ()>,
    env: PhantomData<&'env mut &'env ()>,
}

impl<'scope, 'env, C> PollingScope<'scope, 'env, C> {
    /// Creates a scope over `inner` with no work requests posted yet.
    pub(super) fn new(inner: &'env mut C) -> Self {
        PollingScope {
            inner,
            wrs: vec![],
            scope: PhantomData,
            env: PhantomData,
        }
    }

    /// Consumes the scope and spin-polls every work request the user did not poll to
    /// completion.
    ///
    /// Returns:
    /// - `Ok(NoPendingWorks)` if nothing had to be auto-polled.
    /// - `Ok(PendingWorksSucceeded)` if every auto-polled request succeeded.
    /// - `Err` with one error per auto-polled request that failed.
    fn auto_poll(self) -> AutoPollResult {
        let mut auto_polled = false;
        let mut transport_errors = Vec::new();
        // Iterating by value drops each handle as soon as it has been processed.
        for handle in self.wrs {
            let mut wr = handle.inner.borrow_mut();
            if wr.user_polled_to_completion {
                continue;
            }
            auto_polled = true;
            // Keep going after a failure so every outstanding request is still polled.
            if let Err(transport_error) = wr.wr.spin_poll() {
                transport_errors.push(transport_error);
            }
        }
        if !auto_polled {
            Ok(AutoPollSuccess::NoPendingWorks)
        } else if transport_errors.is_empty() {
            Ok(AutoPollSuccess::PendingWorksSucceeded)
        } else {
            Err(transport_errors)
        }
    }
}
/// Outcome of `PollingScope::auto_poll`. The error vector holds one `TransportError`
/// per auto-polled work request that failed.
type AutoPollResult = Result<AutoPollSuccess, Vec<TransportError>>;
/// Successful outcomes of `PollingScope::auto_poll`.
enum AutoPollSuccess {
    /// No request needed auto-polling: every posted request had already been polled
    /// to completion by the user, or none were posted.
    NoPendingWorks,
    /// At least one request had to be auto-polled, and all of them succeeded.
    PendingWorksSucceeded,
}
impl<'scope, 'env, C> PollingScope<'scope, 'env, C> {
pub(crate) fn channel_post_send<F>(
&mut self,
channel_selector: F,
wr: SendWorkRequest<'_, 'env>,
) -> IbvResult<ScopedPendingWork<'scope>>
where
F: FnOnce(&mut C) -> IbvResult<&mut Channel>,
{
let channel = channel_selector(self.inner)?;
let wr = ScopedPendingWork::new(unsafe { channel.send_unpolled(wr)? });
self.wrs.push(wr.clone());
Ok(wr)
}
pub(crate) fn channel_post_receive<F>(
&mut self,
channel_selector: F,
wr: ReceiveWorkRequest<'_, 'env>,
) -> IbvResult<ScopedPendingWork<'scope>>
where
F: FnOnce(&mut C) -> IbvResult<&mut Channel>,
{
let channel = channel_selector(self.inner)?;
let wr = ScopedPendingWork::new(unsafe { channel.receive_unpolled(wr)? });
self.wrs.push(wr.clone());
Ok(wr)
}
pub(crate) fn channel_post_write<F>(
&mut self,
channel_selector: F,
wr: WriteWorkRequest<'_, 'env>,
) -> IbvResult<ScopedPendingWork<'scope>>
where
F: FnOnce(&mut C) -> IbvResult<&mut Channel>,
{
let channel = channel_selector(self.inner)?;
let wr = ScopedPendingWork::new(unsafe { channel.write_unpolled(wr)? });
self.wrs.push(wr.clone());
Ok(wr)
}
pub(crate) fn channel_post_read<F>(
&mut self,
channel_selector: F,
wr: ReadWorkRequest<'_, 'env>,
) -> IbvResult<ScopedPendingWork<'scope>>
where
F: FnOnce(&mut C) -> IbvResult<&mut Channel>,
{
let channel = channel_selector(self.inner)?;
let wr = ScopedPendingWork::new(unsafe { channel.read_unpolled(wr)? });
self.wrs.push(wr.clone());
Ok(wr)
}
}
/// Handle to a work request posted inside a [`PollingScope`].
///
/// Clones share the same underlying request. The scope keeps one clone so it can tell,
/// when it finishes, whether the user polled the request to completion.
#[derive(Debug, Clone)]
pub struct ScopedPendingWork<'scope> {
    inner: Rc<RefCell<ScopedPendingWorkInner<'scope>>>,
    // Keeps `'scope` invariant, mirroring the lifetime markers on `PollingScope`.
    env: PhantomData<&'scope mut &'scope ()>,
}

/// Shared state behind a [`ScopedPendingWork`].
#[derive(Debug)]
struct ScopedPendingWorkInner<'scope> {
    /// Set once `poll` has produced a result or `spin_poll` has been called.
    user_polled_to_completion: bool,
    /// The posted work request itself.
    wr: PendingWork<'scope>,
}

impl<'scope> ScopedPendingWork<'scope> {
    /// Wraps a freshly posted work request, which starts out not yet polled by the user.
    fn new(wr: PendingWork<'scope>) -> Self {
        let state = ScopedPendingWorkInner {
            user_polled_to_completion: false,
            wr,
        };
        Self {
            inner: Rc::new(RefCell::new(state)),
            env: PhantomData,
        }
    }

    /// Polls the underlying work request once.
    ///
    /// Returns `None` while the request reports no result yet. As soon as a result is
    /// returned, the request counts as polled to completion and the scope will not
    /// auto-poll it.
    pub fn poll(&self) -> Option<TransportResult<WorkSuccess>> {
        let mut guard = self.inner.borrow_mut();
        let ScopedPendingWorkInner {
            user_polled_to_completion,
            wr,
        } = &mut *guard;
        let completion = wr.poll();
        *user_polled_to_completion |= completion.is_some();
        completion
    }

    /// Spin-polls the underlying work request and returns its result.
    ///
    /// Afterwards the request always counts as polled to completion, whether it
    /// succeeded or failed.
    pub fn spin_poll(&self) -> TransportResult<WorkSuccess> {
        let mut guard = self.inner.borrow_mut();
        let ScopedPendingWorkInner {
            user_polled_to_completion,
            wr,
        } = &mut *guard;
        let completion = wr.spin_poll();
        *user_polled_to_completion = true;
        completion
    }
}