use crate::rdma::{RdmaConnection, RdmaMemoryPool};
use futures::future::BoxFuture;
use nexar::PeerConnection;
#[cfg(feature = "gpudirect")]
use nexar::error::NexarError;
use nexar::error::Result;
use nexar::transport::BulkTransport;
use std::sync::Arc;
/// Newtype under which the shared [`RdmaState`] is registered as a
/// `PeerConnection` extension (stored by `set_rdma`, looked up by `get_rdma`).
pub(crate) struct RdmaStateHolder(pub Arc<RdmaState>);
/// RDMA resources attached to a single peer.
pub(crate) struct RdmaState {
/// The RDMA connection, behind an async mutex; the send and receive paths in
/// this file hold the lock across the awaited `send_async` / `recv_async` call.
pub conn: tokio::sync::Mutex<RdmaConnection>,
/// Pool from which a buffer is checked out for every send and receive.
pub pool: Arc<RdmaMemoryPool>,
}
/// Looks up the RDMA state registered on `peer`.
///
/// Returns `Ok(None)` when no [`RdmaStateHolder`] extension is present, and
/// propagates any error reported by the extension lookup itself.
fn get_rdma(peer: &PeerConnection) -> Result<Option<Arc<RdmaState>>> {
    let Some(holder) = peer.extension::<RdmaStateHolder>()? else {
        return Ok(None);
    };
    Ok(Some(Arc::clone(&holder.0)))
}
/// RDMA-related extension methods for a peer connection.
pub trait PeerConnectionRdmaExt {
/// Attaches an RDMA connection and its memory pool to this peer.
///
/// The implementation in this file registers the state as peer extensions
/// and ignores the result of each registration.
fn set_rdma(&self, rdma_conn: RdmaConnection, pool: Arc<RdmaMemoryPool>);
/// Sends `data` to the peer.
///
/// The implementation in this file uses the RDMA path when RDMA state has
/// been attached and otherwise falls back to the peer's `send_raw`.
fn send_raw_rdma(&self, data: &[u8]) -> impl std::future::Future<Output = Result<()>> + Send;
}
/// [`BulkTransport`] implementation backed by a peer's shared [`RdmaState`].
struct RdmaBulkTransport(Arc<RdmaState>);
impl BulkTransport for RdmaBulkTransport {
fn send_bulk<'a>(&'a self, data: &'a [u8]) -> BoxFuture<'a, Result<()>> {
let rdma = Arc::clone(&self.0);
Box::pin(async move { send_via_rdma(rdma, data).await })
}
fn recv_bulk<'a>(&'a self, expected_size: usize) -> BoxFuture<'a, Result<Vec<u8>>> {
let rdma = Arc::clone(&self.0);
Box::pin(async move {
let mut pooled = rdma.pool.checkout()?;
let mut conn = rdma.conn.lock().await;
conn.recv_async(pooled.mr_mut(), 0).await?;
Ok(pooled[..expected_size].to_vec())
})
}
}
impl PeerConnectionRdmaExt for PeerConnection {
    /// Wraps the connection and pool in a shared [`RdmaState`] and registers
    /// it twice: once as an [`RdmaStateHolder`] and once as an
    /// `Arc<dyn BulkTransport>`. Registration results are deliberately
    /// discarded because this method has no way to report them.
    fn set_rdma(&self, rdma_conn: RdmaConnection, pool: Arc<RdmaMemoryPool>) {
        let conn = tokio::sync::Mutex::new(rdma_conn);
        let shared = Arc::new(RdmaState { conn, pool });

        let holder = RdmaStateHolder(Arc::clone(&shared));
        let _ = self.add_extension(holder);

        let transport: Arc<dyn BulkTransport> = Arc::new(RdmaBulkTransport(shared));
        let _ = self.add_extension(transport);
    }

    /// Sends `data` over RDMA when RDMA state is attached to this peer,
    /// otherwise through the peer's regular `send_raw` path.
    async fn send_raw_rdma(&self, data: &[u8]) -> Result<()> {
        match get_rdma(self)? {
            Some(rdma) => send_via_rdma(rdma, data).await,
            None => self.send_raw(data).await,
        }
    }
}
async fn send_via_rdma(rdma: Arc<RdmaState>, data: &[u8]) -> Result<()> {
let mut pooled = rdma.pool.checkout()?;
let len = data.len();
pooled[..len].copy_from_slice(data);
let mut conn = rdma.conn.lock().await;
conn.send_async(pooled.mr_mut(), 0).await
}
#[cfg(feature = "gpudirect")]
mod gpudirect_ext {
    use super::*;
    use crate::gpudirect::{GpuDirectPool, GpuDirectQp, PooledGpuMr};

    /// Newtype under which the shared [`GpuDirectState`] is registered as a
    /// `PeerConnection` extension.
    pub(crate) struct GpuDirectStateHolder(pub Arc<GpuDirectState>);

    /// GPUDirect resources attached to a single peer.
    pub(crate) struct GpuDirectState {
        /// The queue pair. It is only ever locked from inside
        /// `spawn_blocking` and never across an `.await`, so a synchronous
        /// mutex is the right tool (and is what the poisoning-aware
        /// `lock().map_err(..)` call sites require).
        pub qp: std::sync::Mutex<GpuDirectQp>,
        /// Pool of GPU memory regions used as staging buffers.
        pub pool: Arc<GpuDirectPool>,
    }

    /// Looks up the GPUDirect state registered on `peer`, if any.
    fn get_gpudirect(peer: &PeerConnection) -> Result<Option<Arc<GpuDirectState>>> {
        Ok(peer
            .extension::<GpuDirectStateHolder>()?
            .map(|holder| Arc::clone(&holder.0)))
    }

    /// GPUDirect extension methods for a peer connection.
    pub trait PeerConnectionGpuDirectExt: PeerConnectionRdmaExt {
        /// Attaches a GPUDirect queue pair and its memory-region pool to this
        /// peer.
        fn set_gpudirect(&self, qp: GpuDirectQp, pool: Arc<GpuDirectPool>);

        /// Sends `size` bytes starting at device address `gpu_ptr`.
        ///
        /// Uses a pooled GPUDirect memory region when one is available
        /// (chunking when the region is smaller than `size`); otherwise the
        /// data is staged through host memory and sent with `send_raw_rdma`.
        ///
        /// The caller must pass a device pointer that is valid for reads of
        /// `size` bytes; this method cannot verify it.
        fn send_raw_gpu(
            &self,
            gpu_ptr: u64,
            size: usize,
        ) -> impl std::future::Future<Output = Result<()>> + Send;

        /// Receives `size` bytes into device memory starting at `gpu_ptr`.
        ///
        /// Requires a pooled GPUDirect memory region; there is no host
        /// fallback and an error is returned when none is available.
        ///
        /// The caller must pass a device pointer that is valid for writes of
        /// `size` bytes; this method cannot verify it.
        fn recv_raw_gpu(
            &self,
            gpu_ptr: u64,
            size: usize,
        ) -> impl std::future::Future<Output = Result<()>> + Send;
    }

    impl PeerConnectionGpuDirectExt for PeerConnection {
        fn set_gpudirect(&self, qp: GpuDirectQp, pool: Arc<GpuDirectPool>) {
            // The registration result is discarded: this method returns `()`.
            let _ = self.add_extension(GpuDirectStateHolder(Arc::new(GpuDirectState {
                qp: std::sync::Mutex::new(qp),
                pool,
            })));
        }

        async fn send_raw_gpu(&self, gpu_ptr: u64, size: usize) -> Result<()> {
            if let Some(gd) = get_gpudirect(self)? {
                if let Some(mut pooled) = gd.pool.checkout() {
                    let mr_size = pooled.mr().size();
                    let mr_gpu_ptr = pooled.mr().gpu_ptr();

                    if mr_size >= size {
                        // Single shot: the payload fits in one MR.
                        if mr_gpu_ptr != gpu_ptr {
                            // SAFETY: the MR spans `mr_size >= size` bytes and
                            // is owned by `pooled`; validity of `gpu_ptr` for
                            // `size` bytes is the caller's contract.
                            unsafe { copy_dtod(mr_gpu_ptr, gpu_ptr, size, None)? };
                        }
                        post_on_qp(Arc::clone(&gd), pooled, QpVerb::Send).await?;
                        return Ok(());
                    }

                    // A zero-sized MR can never make progress; without this
                    // guard the loop below would spin forever. Fall through
                    // to host staging instead.
                    if mr_size > 0 {
                        let mut offset = 0usize;
                        while offset < size {
                            let chunk = std::cmp::min(mr_size, size - offset);
                            // SAFETY: `chunk <= mr_size`, and
                            // `offset + chunk <= size` stays inside the
                            // caller-provided range.
                            unsafe {
                                copy_dtod(
                                    mr_gpu_ptr,
                                    gpu_ptr + offset as u64,
                                    chunk,
                                    Some(offset),
                                )?
                            };
                            // NOTE(review): the chunk length is not passed to
                            // the QP, so a short final chunk presumably posts
                            // the whole MR — confirm against `GpuDirectQp`
                            // and the receiving side.
                            pooled = post_on_qp(Arc::clone(&gd), pooled, QpVerb::Send).await?;
                            offset += chunk;
                        }
                        return Ok(());
                    }
                }
            }
            let host_data = crate::gpudirect::stage_gpu_to_host(gpu_ptr, size)?;
            self.send_raw_rdma(&host_data).await
        }

        async fn recv_raw_gpu(&self, gpu_ptr: u64, size: usize) -> Result<()> {
            if let Some(gd) = get_gpudirect(self)? {
                if let Some(mut pooled) = gd.pool.checkout() {
                    let mr_size = pooled.mr().size();
                    let mr_gpu_ptr = pooled.mr().gpu_ptr();

                    if mr_size >= size {
                        // Keep the MR checked out until the copy below is
                        // done; releasing it earlier would let another user
                        // of the pool overwrite the received bytes.
                        let pooled = post_on_qp(Arc::clone(&gd), pooled, QpVerb::Recv).await?;
                        if mr_gpu_ptr != gpu_ptr {
                            // SAFETY: the MR spans `mr_size >= size` bytes and
                            // is still owned by `pooled`; validity of
                            // `gpu_ptr` for `size` bytes is the caller's
                            // contract.
                            unsafe { copy_dtod(gpu_ptr, mr_gpu_ptr, size, None)? };
                        }
                        drop(pooled);
                        return Ok(());
                    }

                    // See `send_raw_gpu`: a zero-sized MR cannot make
                    // progress, so report "no suitable MR" below instead of
                    // looping forever.
                    if mr_size > 0 {
                        let mut offset = 0usize;
                        while offset < size {
                            let chunk = std::cmp::min(mr_size, size - offset);
                            pooled = post_on_qp(Arc::clone(&gd), pooled, QpVerb::Recv).await?;
                            // SAFETY: `chunk <= mr_size`, and
                            // `offset + chunk <= size` stays inside the
                            // caller-provided range.
                            unsafe {
                                copy_dtod(
                                    gpu_ptr + offset as u64,
                                    mr_gpu_ptr,
                                    chunk,
                                    Some(offset),
                                )?
                            };
                            offset += chunk;
                        }
                        return Ok(());
                    }
                }
            }
            Err(NexarError::device(
                "GPUDirect recv_raw_gpu: no suitable GPUDirect MR available; \
                 use recv_bytes() + stage_host_to_gpu() at the application layer",
            ))
        }
    }

    /// Which operation to run on the queue pair.
    #[derive(Debug, Clone, Copy)]
    enum QpVerb {
        Send,
        Recv,
    }

    /// Synchronous device-to-device copy of `len` bytes from `src` to `dst`.
    ///
    /// `chunk_offset` only selects the wording of the error message.
    ///
    /// # Safety
    ///
    /// `src` must be valid for reads and `dst` for writes of `len` bytes of
    /// device memory.
    unsafe fn copy_dtod(
        dst: u64,
        src: u64,
        len: usize,
        chunk_offset: Option<usize>,
    ) -> Result<()> {
        // SAFETY: forwarded verbatim from this function's own contract.
        let outcome = unsafe { cudarc::driver::result::memcpy_dtod_sync(dst, src, len) };
        outcome.map_err(|e| {
            NexarError::device(match chunk_offset {
                Some(offset) => {
                    format!("GPUDirect D2D copy (chunk at offset {offset}) failed: {e}")
                }
                None => format!("GPUDirect D2D copy failed: {e}"),
            })
        })
    }

    /// Runs `verb` on the peer's queue pair against `pooled`'s memory region
    /// on the blocking thread pool, then hands the region back to the caller.
    ///
    /// The region is moved into the blocking task rather than borrowed: if
    /// the awaiting future is dropped while the task is still running, the
    /// task keeps the region alive instead of racing with its return to the
    /// pool.
    async fn post_on_qp(
        gd: Arc<GpuDirectState>,
        pooled: PooledGpuMr,
        verb: QpVerb,
    ) -> Result<PooledGpuMr> {
        let joined = tokio::task::spawn_blocking(move || -> Result<PooledGpuMr> {
            let qp = gd
                .qp
                .lock()
                .map_err(|e| NexarError::device(format!("GPUDirect lock poisoned: {e}")))?;
            match verb {
                QpVerb::Send => qp.send(pooled.mr(), 0)?,
                QpVerb::Recv => qp.recv(pooled.mr(), 0)?,
            }
            drop(qp);
            Ok(pooled)
        })
        .await;
        joined.map_err(|e| NexarError::device(format!("GPUDirect spawn_blocking: {e}")))?
    }
}
#[cfg(feature = "gpudirect")]
pub use gpudirect_ext::PeerConnectionGpuDirectExt;