use std::mem::ManuallyDrop;
use std::net::SocketAddr;
use std::os::unix::io::{BorrowedFd, OwnedFd, RawFd};
use std::sync::Arc;

use tokio::io::unix::AsyncFd;

use crate::Result;
use crate::cm::{CmEventType, CmId, ConnParam, EventChannel, PortSpace};
use crate::cq::CompletionQueue;
use crate::device::Context;
use crate::pd::ProtectionDomain;
use crate::qp::QpInitAttr;
pub(crate) struct AsyncEventChannel {
async_fd: AsyncFd<RawFd>,
}
impl AsyncEventChannel {
pub(crate) fn new(ch: &EventChannel) -> Result<Self> {
let async_fd = AsyncFd::new(ch.fd()).map_err(crate::Error::Verbs)?;
Ok(Self { async_fd })
}
pub(crate) async fn get_event(&self, ch: &EventChannel) -> Result<crate::cm::CmEvent> {
loop {
let mut guard = self
.async_fd
.readable()
.await
.map_err(crate::Error::Verbs)?;
match ch.try_get_event() {
Ok(ev) => return Ok(ev),
Err(crate::Error::WouldBlock) => {
guard.clear_ready();
continue;
}
Err(e) => return Err(e),
}
}
}
pub(crate) async fn expect_event(
&self,
ch: &EventChannel,
expected: CmEventType,
) -> Result<()> {
let ev = self.get_event(ch).await?;
let actual = ev.event_type();
if actual != expected {
ev.ack();
return Err(crate::Error::InvalidArg(format!(
"expected {expected:?}, got {actual:?}"
)));
}
ev.ack();
Ok(())
}
#[allow(dead_code)]
pub(crate) fn poll_expect_event(
&self,
cx: &mut std::task::Context<'_>,
ch: &EventChannel,
expected: CmEventType,
) -> std::task::Poll<Result<()>> {
loop {
let mut guard = match self.async_fd.poll_read_ready(cx) {
std::task::Poll::Ready(Ok(g)) => g,
std::task::Poll::Ready(Err(e)) => {
return std::task::Poll::Ready(Err(crate::Error::Verbs(e)));
}
std::task::Poll::Pending => return std::task::Poll::Pending,
};
match ch.try_get_event() {
Ok(ev) => {
let actual = ev.event_type();
ev.ack();
if actual != expected {
return std::task::Poll::Ready(Err(crate::Error::InvalidArg(format!(
"expected {expected:?}, got {actual:?}"
))));
}
return std::task::Poll::Ready(Ok(()));
}
Err(crate::Error::WouldBlock) => {
guard.clear_ready();
continue;
}
Err(e) => return std::task::Poll::Ready(Err(e)),
}
}
}
}
/// An rdma_cm identifier together with the private event channel it reports on, driven
/// asynchronously through tokio.
///
/// Both resources are wrapped in `ManuallyDrop` so that `into_parts` can move them out, and
/// so that `Drop` releases them in a fixed order: the id first, then the channel.
pub struct AsyncCmId {
    cm_id: ManuallyDrop<CmId>,
    event_channel: ManuallyDrop<EventChannel>,
}

// SAFETY: NOTE(review): this asserts that an owned `CmId` + `EventChannel` pair may be moved to
// another thread. That relies on those types having no thread affinity; confirm against
// `crate::cm`.
unsafe impl Send for AsyncCmId {}

impl Drop for AsyncCmId {
    fn drop(&mut self) {
        // The id was created on (or migrated to) `event_channel`, so it goes first and the
        // channel must outlive it.
        // SAFETY: `self` is being destroyed and neither field is touched afterwards;
        // `into_parts` bypasses this destructor, so each field is dropped at most once.
        unsafe {
            ManuallyDrop::drop(&mut self.cm_id);
        }
        // SAFETY: same reasoning as above, for the channel.
        unsafe {
            ManuallyDrop::drop(&mut self.event_channel);
        }
    }
}
impl AsyncCmId {
pub fn new(port_space: PortSpace) -> Result<Self> {
let ch = EventChannel::new()?;
ch.set_nonblocking()?;
let cm_id = CmId::new(&ch, port_space)?;
Ok(Self {
cm_id: ManuallyDrop::new(cm_id),
event_channel: ManuallyDrop::new(ch),
})
}
pub async fn resolve_addr(
&self,
src: Option<&SocketAddr>,
dst: &SocketAddr,
timeout_ms: i32,
) -> Result<()> {
let async_ch = AsyncEventChannel::new(&self.event_channel)?;
self.cm_id.resolve_addr(src, dst, timeout_ms)?;
async_ch
.expect_event(&self.event_channel, CmEventType::AddrResolved)
.await
}
pub async fn resolve_route(&self, timeout_ms: i32) -> Result<()> {
let async_ch = AsyncEventChannel::new(&self.event_channel)?;
self.cm_id.resolve_route(timeout_ms)?;
async_ch
.expect_event(&self.event_channel, CmEventType::RouteResolved)
.await
}
pub async fn connect(&self, param: &ConnParam) -> Result<()> {
let async_ch = AsyncEventChannel::new(&self.event_channel)?;
self.cm_id.connect(param)?;
async_ch
.expect_event(&self.event_channel, CmEventType::Established)
.await
}
pub async fn connect_to(addr: &SocketAddr) -> Result<Self> {
let cm = Self::new(PortSpace::Tcp)?;
cm.resolve_addr(None, addr, 2000).await?;
cm.resolve_route(2000).await?;
cm.connect(&ConnParam::default()).await?;
Ok(cm)
}
pub fn cm_id(&self) -> &CmId {
&self.cm_id
}
pub fn event_channel(&self) -> &EventChannel {
&self.event_channel
}
pub fn verbs_context(&self) -> Option<Arc<Context>> {
self.cm_id.verbs_context()
}
pub fn alloc_pd(&self) -> Result<Arc<ProtectionDomain>> {
self.cm_id.alloc_pd()
}
pub fn create_qp_with_cq(
&self,
pd: &Arc<ProtectionDomain>,
init_attr: &QpInitAttr,
send_cq: Option<&Arc<CompletionQueue>>,
recv_cq: Option<&Arc<CompletionQueue>>,
) -> Result<crate::cm::CmQueuePair> {
self.cm_id
.create_qp_with_cq(pd, init_attr, send_cq, recv_cq)
}
pub fn qp_raw(&self) -> *mut rdma_io_sys::ibverbs::ibv_qp {
self.cm_id.qp_raw()
}
pub fn disconnect(&self) -> Result<()> {
self.cm_id.disconnect()
}
pub async fn disconnect_async(&self) -> Result<()> {
let async_ch = AsyncEventChannel::new(&self.event_channel)?;
self.cm_id.disconnect()?;
async_ch
.expect_event(&self.event_channel, CmEventType::Disconnected)
.await
}
pub async fn next_event(&self) -> Result<crate::cm::CmEvent> {
let async_ch = AsyncEventChannel::new(&self.event_channel)?;
async_ch.get_event(&self.event_channel).await
}
pub fn into_parts(self) -> (EventChannel, CmId) {
let mut this = ManuallyDrop::new(self);
unsafe {
let cm_id = ManuallyDrop::take(&mut this.cm_id);
let event_channel = ManuallyDrop::take(&mut this.event_channel);
(event_channel, cm_id)
}
}
}
/// Asynchronous rdma_cm listener (TCP port space) that hands out connected [`AsyncCmId`]s.
///
/// Connection requests that arrive while an accept is waiting for its `Established` event
/// are parked in `pending_requests` and returned by later `get_request`/`accept` calls.
pub struct AsyncCmListener {
    /// Listening CM id, created on `event_channel`.
    _cm_id: ManuallyDrop<CmId>,
    /// Channel on which the listener, and accepted ids until they are migrated, report events.
    event_channel: ManuallyDrop<EventChannel>,
    /// Tokio readiness registration for `event_channel`. As a plain field it is dropped
    /// *after* `Drop::drop` has destroyed `event_channel`, so it must not depend on the
    /// channel's fd still being open at that point.
    async_ch: AsyncEventChannel,
    /// Connection requests received while waiting for an `Established` event.
    pending_requests: std::sync::Mutex<std::collections::VecDeque<CmId>>,
}

// SAFETY: NOTE(review): these assert that the listener may be moved to and shared across
// threads. The only shared mutable Rust state is the `Mutex`-guarded request queue; the CM id
// and channel are used only through `&self`. Confirm that `CmId`/`EventChannel` have no thread
// affinity and that concurrent `try_get_event` calls on one channel are sound.
unsafe impl Send for AsyncCmListener {}
unsafe impl Sync for AsyncCmListener {}

impl Drop for AsyncCmListener {
    fn drop(&mut self) {
        // Stashed connection ids came from events on this listener's channel and were never
        // migrated, so destroy them while that channel still exists. rdma_cm requires every
        // id on a channel to be destroyed before the channel itself.
        self.pending_requests
            .get_mut()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clear();
        // SAFETY: `self` is being destroyed and neither field is used afterwards, so each is
        // dropped exactly once. The listening id goes before the channel it was created on.
        unsafe {
            ManuallyDrop::drop(&mut self._cm_id);
            ManuallyDrop::drop(&mut self.event_channel);
        }
    }
}
impl AsyncCmListener {
pub fn bind(addr: &SocketAddr) -> Result<Self> {
Self::bind_with_backlog(addr, 128)
}
pub fn bind_with_backlog(addr: &SocketAddr, backlog: i32) -> Result<Self> {
let ch = EventChannel::new()?;
ch.set_nonblocking()?;
let async_ch = AsyncEventChannel::new(&ch)?;
let cm_id = CmId::new(&ch, PortSpace::Tcp)?;
cm_id.listen(addr, backlog)?;
Ok(Self {
_cm_id: ManuallyDrop::new(cm_id),
event_channel: ManuallyDrop::new(ch),
async_ch,
pending_requests: std::sync::Mutex::new(std::collections::VecDeque::new()),
})
}
pub fn local_addr(&self) -> Option<std::net::SocketAddr> {
self._cm_id.local_addr()
}
pub async fn accept(&self) -> Result<AsyncCmId> {
self.accept_with_param(&ConnParam::default()).await
}
pub async fn accept_with_param(&self, param: &ConnParam) -> Result<AsyncCmId> {
let conn_id = self.get_request().await?;
conn_id.accept(param)?;
loop {
let ev = self.async_ch.get_event(&self.event_channel).await?;
let etype = ev.event_type();
match etype {
CmEventType::Established => {
ev.ack();
break;
}
CmEventType::ConnectRequest => {
let stashed_id = unsafe { CmId::from_raw(ev.cm_id_raw(), true) };
ev.ack();
self.pending_requests.lock().unwrap().push_back(stashed_id);
}
_ => {
ev.ack();
return Err(crate::Error::InvalidArg(format!(
"expected Established, got {etype:?}"
)));
}
}
}
let conn_ch = EventChannel::new()?;
conn_ch.set_nonblocking()?;
conn_id.migrate(&conn_ch)?;
Ok(AsyncCmId {
cm_id: ManuallyDrop::new(conn_id),
event_channel: ManuallyDrop::new(conn_ch),
})
}
pub async fn get_request(&self) -> Result<CmId> {
if let Some(conn_id) = self.pending_requests.lock().unwrap().pop_front() {
return Ok(conn_id);
}
let ev = self.async_ch.get_event(&self.event_channel).await?;
let etype = ev.event_type();
if etype != CmEventType::ConnectRequest {
ev.ack();
return Err(crate::Error::InvalidArg(format!(
"expected ConnectRequest, got {etype:?}"
)));
}
let conn_id = unsafe { CmId::from_raw(ev.cm_id_raw(), true) };
ev.ack();
Ok(conn_id)
}
pub fn poll_get_request(
&self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<CmId>> {
if let Some(conn_id) = self.pending_requests.lock().unwrap().pop_front() {
return std::task::Poll::Ready(Ok(conn_id));
}
loop {
let mut guard = match self.async_ch.async_fd.poll_read_ready(cx) {
std::task::Poll::Ready(Ok(g)) => g,
std::task::Poll::Ready(Err(e)) => {
return std::task::Poll::Ready(Err(crate::Error::Verbs(e)));
}
std::task::Poll::Pending => return std::task::Poll::Pending,
};
match self.event_channel.try_get_event() {
Ok(ev) => {
let etype = ev.event_type();
if etype != CmEventType::ConnectRequest {
ev.ack();
return std::task::Poll::Ready(Err(crate::Error::InvalidArg(format!(
"expected ConnectRequest, got {etype:?}"
))));
}
let conn_id = unsafe { CmId::from_raw(ev.cm_id_raw(), true) };
ev.ack();
return std::task::Poll::Ready(Ok(conn_id));
}
Err(crate::Error::WouldBlock) => {
guard.clear_ready();
continue;
}
Err(e) => return std::task::Poll::Ready(Err(e)),
}
}
}
pub async fn complete_accept(&self, conn_id: CmId, param: &ConnParam) -> Result<AsyncCmId> {
conn_id.accept(param)?;
loop {
let ev = self.async_ch.get_event(&self.event_channel).await?;
let etype = ev.event_type();
match etype {
CmEventType::Established => {
ev.ack();
break;
}
CmEventType::ConnectRequest => {
let stashed_id = unsafe { CmId::from_raw(ev.cm_id_raw(), true) };
ev.ack();
self.pending_requests.lock().unwrap().push_back(stashed_id);
}
_ => {
ev.ack();
return Err(crate::Error::InvalidArg(format!(
"expected Established, got {etype:?}"
)));
}
}
}
let conn_ch = EventChannel::new()?;
conn_ch.set_nonblocking()?;
conn_id.migrate(&conn_ch)?;
Ok(AsyncCmId {
cm_id: ManuallyDrop::new(conn_id),
event_channel: ManuallyDrop::new(conn_ch),
})
}
pub async fn next_event(&self) -> Result<crate::cm::CmEvent> {
self.async_ch.get_event(&self.event_channel).await
}
}