use std::future::Future;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::task::{Context, Poll};
use rdma_io_sys::ibverbs::*;
use crate::Result;
use crate::comp_channel::CompletionChannel;
use crate::cq::CompletionQueue;
use crate::wc::WorkCompletion;
/// Number of completion-channel events accumulated before they are
/// acknowledged together with a single `ibv_ack_cq_events` call.
const ACK_BATCH_SIZE: u32 = 16;
/// Readiness source for a completion channel's file descriptor.
///
/// [`AsyncCq`] waits on this and then drains the channel itself. Both forms
/// are required: `AsyncCq::poll` awaits [`CqNotifier::readable`], while
/// `AsyncCq::poll_completions` calls [`CqNotifier::poll_readable`].
pub trait CqNotifier: Send + Sync {
    /// Returns a future expected to complete once the channel fd is readable,
    /// or with the I/O error that prevented waiting.
    fn readable(&self) -> Pin<Box<dyn Future<Output = io::Result<()>> + Send + '_>>;

    /// Poll-based form of [`CqNotifier::readable`].
    ///
    /// NOTE(review): callers assume that returning `Poll::Pending` registers
    /// `cx`'s waker, as usual for `poll_*` methods — implementations must do so.
    fn poll_readable(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
}
/// Progress marker for [`AsyncCq::poll_completions`], owned by the caller and
/// carried across successive polls.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum CqPollState {
    /// No wait in progress: the next poll re-arms and polls the CQ directly.
    #[default]
    Idle,
    /// The previous poll returned `Pending` while waiting for the channel fd
    /// to become readable; the next poll checks readiness first.
    WaitingFd,
}
/// Completion queue wrapper that waits for completions through its completion
/// channel and a [`CqNotifier`] instead of busy-polling.
///
/// Field order is significant because Rust drops fields in declaration order:
/// - `notifier` watches `channel`'s file descriptor, so it is dropped first,
///   while that fd is still open.
/// - `cq` is dropped before `channel` because `ibv_destroy_comp_channel` fails
///   while a CQ is still attached to the channel. This only helps if this is
///   the last `Arc` to the CQ; `cq()` exposes the `Arc` to other holders.
pub struct AsyncCq {
    /// Readiness source for the channel fd; must not outlive `channel`.
    notifier: Box<dyn CqNotifier>,
    /// The wrapped completion queue.
    cq: Arc<CompletionQueue>,
    /// Completion channel the CQ delivers events on; dropped last.
    channel: CompletionChannel,
    /// Events returned by `channel` that have not yet been passed to
    /// `ibv_ack_cq_events`.
    unacked_events: AtomicU32,
}
impl AsyncCq {
    /// Builds an `AsyncCq` from an existing CQ, its completion channel and a
    /// readiness notifier for that channel.
    ///
    /// Nothing here checks that `cq` was created on `channel`, or that
    /// `notifier` watches `channel`'s fd. The caller must ensure both;
    /// [`AsyncCq::create_tokio`] wires them together.
    pub fn new(
        cq: Arc<CompletionQueue>,
        channel: CompletionChannel,
        notifier: Box<dyn CqNotifier>,
    ) -> Self {
        Self {
            cq,
            channel,
            notifier,
            unacked_events: AtomicU32::new(0),
        }
    }

    /// Creates a completion channel on `ctx`, a CQ with the requested `depth`
    /// bound to that channel, and a Tokio notifier on the channel's fd.
    ///
    /// # Errors
    ///
    /// Returns the error from channel or CQ creation. A failure to build the
    /// Tokio notifier is wrapped in [`crate::Error::Verbs`].
    #[cfg(feature = "tokio")]
    pub fn create_tokio(ctx: Arc<crate::device::Context>, depth: i32) -> crate::Result<Self> {
        let ch = CompletionChannel::new(&ctx)?;
        let cq = CompletionQueue::with_comp_channel(ctx, depth, &ch)?;
        let notifier =
            crate::tokio_notifier::TokioCqNotifier::new(ch.fd()).map_err(crate::Error::Verbs)?;
        Ok(Self::new(cq, ch, Box::new(notifier)))
    }

    /// Waits until the CQ yields at least one work completion into `wc_buf`,
    /// and returns the count reported by `CompletionQueue::poll`.
    ///
    /// An empty `wc_buf` returns `Ok(0)` immediately. Nothing can ever be
    /// written into it, so waiting would never terminate.
    ///
    /// # Errors
    ///
    /// Fails if arming or polling the CQ fails, if the notifier reports an I/O
    /// error (wrapped in [`crate::Error::Verbs`]), or if reading events from
    /// the channel fails.
    pub async fn poll(&self, wc_buf: &mut [WorkCompletion]) -> Result<usize> {
        if wc_buf.is_empty() {
            return Ok(0);
        }
        loop {
            // Arm before polling: a completion that lands after the poll below
            // still raises a channel event, so the wait cannot miss it.
            self.cq.req_notify(false)?;
            let n = self.cq.poll(wc_buf)?;
            if n > 0 {
                return Ok(n);
            }
            self.notifier
                .readable()
                .await
                .map_err(crate::Error::Verbs)?;
            self.drain_channel_events()?;
        }
    }

    /// Waits for the completion whose `wr_id` equals `expected` and returns it.
    ///
    /// Completions are polled in batches of up to 4. Any completion with a
    /// different `wr_id` is **discarded**: it is consumed from the CQ and
    /// cannot be observed afterwards. Only use this when nothing else needs
    /// the other completions on this CQ. The returned completion's status is
    /// not checked.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`AsyncCq::poll`].
    pub async fn poll_wr_id(&self, expected: u64) -> Result<WorkCompletion> {
        let mut wc = [WorkCompletion::default(); 4];
        loop {
            let n = self.poll(&mut wc).await?;
            for item in &wc[..n] {
                if item.wr_id() == expected {
                    return Ok(*item);
                }
            }
        }
    }

    /// Returns the wrapped completion queue.
    pub fn cq(&self) -> &Arc<CompletionQueue> {
        &self.cq
    }

    /// Poll-based counterpart of [`AsyncCq::poll`], for hand-written
    /// `Future`/`Stream` implementations.
    ///
    /// `state` must be kept by the caller across calls, starting from
    /// [`CqPollState::Idle`]. It records that an earlier call returned
    /// `Pending` while waiting on the notifier, so the next call checks
    /// readiness before re-arming and polling the CQ.
    ///
    /// Returns:
    /// - `Poll::Ready(Ok(n))` with `n > 0` when completions were written.
    /// - `Poll::Pending` when the notifier's `poll_readable` is pending.
    /// - An error from the CQ, the notifier or the channel.
    ///
    /// An empty `wc_buf` returns `Poll::Ready(Ok(0))` immediately, because no
    /// completion could ever be written into it.
    pub fn poll_completions(
        &self,
        cx: &mut Context<'_>,
        state: &mut CqPollState,
        wc_buf: &mut [WorkCompletion],
    ) -> Poll<Result<usize>> {
        if wc_buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        loop {
            if *state == CqPollState::WaitingFd {
                // Resume the wait started by an earlier call that returned Pending.
                match self.notifier.poll_readable(cx) {
                    Poll::Pending => return Poll::Pending,
                    Poll::Ready(Err(e)) => return Poll::Ready(Err(crate::Error::Verbs(e))),
                    Poll::Ready(Ok(())) => {
                        self.drain_channel_events()?;
                        *state = CqPollState::Idle;
                    }
                }
            }
            // Arm before polling so a completion arriving after the poll below
            // still raises a channel event and wakes the notifier.
            self.cq.req_notify(false)?;
            let n = self.cq.poll(wc_buf)?;
            if n > 0 {
                return Poll::Ready(Ok(n));
            }
            match self.notifier.poll_readable(cx) {
                Poll::Pending => {
                    *state = CqPollState::WaitingFd;
                    return Poll::Pending;
                }
                Poll::Ready(Err(e)) => return Poll::Ready(Err(crate::Error::Verbs(e))),
                // Already readable: consume the events, then retry the CQ.
                Poll::Ready(Ok(())) => self.drain_channel_events()?,
            }
        }
    }

    /// Reads every event currently queued on the completion channel and
    /// records each one through `ack_event`, which acknowledges them in batches.
    ///
    /// Returns `Ok(())` once the channel reports `WouldBlock`. Any other error
    /// is returned; events read before it remain counted for acknowledgement.
    ///
    /// NOTE(review): the value returned by `get_cq_event` is ignored, and every
    /// event is acknowledged against `self.cq`. That is only correct if no
    /// other CQ shares `channel` — confirm.
    fn drain_channel_events(&self) -> Result<()> {
        loop {
            match self.channel.get_cq_event() {
                Ok(_) => self.ack_event(),
                Err(crate::Error::Verbs(ref e)) if e.kind() == io::ErrorKind::WouldBlock => {
                    return Ok(());
                }
                Err(e) => return Err(e),
            }
        }
    }

    /// Records one event obtained from the channel.
    ///
    /// Once [`ACK_BATCH_SIZE`] events have accumulated, acknowledges all of
    /// them with a single `ibv_ack_cq_events` call to amortise the call's cost.
    fn ack_event(&self) {
        let prev = self.unacked_events.fetch_add(1, Ordering::Relaxed);
        if prev + 1 >= ACK_BATCH_SIZE {
            // `swap` rather than `store(0)`: if several callers cross the
            // threshold concurrently, each increment is claimed by exactly one
            // swap, so no event is acknowledged twice or lost.
            let unacked = self.unacked_events.swap(0, Ordering::Relaxed);
            if unacked > 0 {
                // SAFETY: `self.cq` keeps the CQ alive for this call, and
                // `as_raw()` is assumed to yield its underlying `ibv_cq`
                // pointer. `unacked` only counts events returned by
                // `get_cq_event` on this wrapper's channel that have not been
                // acknowledged yet.
                unsafe {
                    ibv_ack_cq_events(self.cq.as_raw(), unacked);
                }
            }
        }
    }
}
impl Drop for AsyncCq {
    /// Acknowledges every outstanding completion event before the CQ is released.
    ///
    /// libibverbs makes CQ destruction wait until all events obtained with
    /// `ibv_get_cq_event` are acknowledged. Events still queued on the channel
    /// are therefore pulled off and added to those already batched by
    /// `ack_event`, then all are acknowledged in one call.
    fn drop(&mut self) {
        // `&mut self` guarantees exclusive access, so no atomic
        // read-modify-write is needed.
        let mut unacked = *self.unacked_events.get_mut();
        // Best-effort drain: stop at the first error, e.g. the `WouldBlock`
        // that `drain_channel_events` treats as "channel empty".
        while self.channel.get_cq_event().is_ok() {
            unacked += 1;
        }
        if unacked > 0 {
            // SAFETY: `self.cq` has not been dropped yet (fields drop after
            // this method returns), and `as_raw()` is assumed to yield its
            // `ibv_cq` pointer. `unacked` counts exactly the events taken from
            // this wrapper's channel that have not been acknowledged.
            unsafe {
                ibv_ack_cq_events(self.cq.as_raw(), unacked);
            }
        }
    }
}