use crate::ibverbs::access_config::AccessFlags;
use crate::ibverbs::error::{IbvError, IbvResult};
use crate::ibverbs::memory::{
GatherElement, RemoteMemoryRegion, ScatterElement, ScatterGatherElementError,
};
use crate::ibverbs::protection_domain::ProtectionDomain;
use ibverbs_sys::*;
use std::ffi::c_void;
use std::io;
/// Owned handle to a memory region registered with an RDMA device.
///
/// Wraps the raw `ibv_mr` pointer produced by the registration verbs and
/// deregisters it with `ibv_dereg_mr` when the value is dropped.
#[doc(alias = "ibv_mr")]
#[doc(alias = "ibv_reg_mr")]
pub struct MemoryRegion {
/// Clone of the protection domain the region was registered in.
/// NOTE(review): presumably held so the protection domain outlives the region — confirm.
pd: ProtectionDomain,
/// Raw registration handle. The constructors in this file only build a
/// `MemoryRegion` after checking that this pointer is non-null.
mr: *mut ibv_mr,
}
// The raw `*mut ibv_mr` field makes `MemoryRegion` neither `Send` nor `Sync` by default;
// these impls opt back in manually.
// NOTE(review): soundness presumably relies on libibverbs verbs being callable from any
// thread and on `ProtectionDomain` being shareable across threads — confirm.
unsafe impl Sync for MemoryRegion {}
unsafe impl Send for MemoryRegion {}
impl Drop for MemoryRegion {
fn drop(&mut self) {
log::debug!("MemoryRegion deregistered");
let errno = unsafe { ibv_dereg_mr(self.mr) };
if errno != 0 {
let error = IbvError::from_errno_with_msg(errno, "Failed to deregister memory region");
log::error!("{error}");
}
}
}
impl std::fmt::Debug for MemoryRegion {
    /// Formats the region as a struct showing the raw `ibv_mr` fields
    /// (address, length, handle, lkey, rkey) followed by the protection domain.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // SAFETY: `self.mr` is non-null for every constructed `MemoryRegion`
        // and stays registered until `Drop` runs.
        let (addr, length, handle, lkey, rkey) = unsafe {
            (
                (*self.mr).addr,
                (*self.mr).length,
                (*self.mr).handle,
                (*self.mr).lkey,
                (*self.mr).rkey,
            )
        };

        let mut out = f.debug_struct("MemoryRegion");
        out.field("address", &addr);
        out.field("length", &length);
        out.field("handle", &handle);
        out.field("lkey", &lkey);
        out.field("rkey", &rkey);
        out.field("pd", &self.pd);
        out.finish()
    }
}
impl MemoryRegion {
pub unsafe fn register_mr_with_access(
pd: &ProtectionDomain,
address: *mut u8,
length: usize,
access_flags: AccessFlags,
) -> IbvResult<MemoryRegion> {
#[allow(clippy::cast_possible_wrap)]
let mr = unsafe {
ibv_reg_mr(
pd.inner.pd,
address as *mut c_void,
length,
access_flags.code() as i32,
)
};
if mr.is_null() {
Err(IbvError::from_errno_with_msg(
io::Error::last_os_error()
.raw_os_error()
.expect("ibv_reg_mr should set errno on error"),
"Failed to register memory region",
))
} else {
log::debug!("MemoryRegion registered");
Ok(MemoryRegion { pd: pd.clone(), mr })
}
}
#[allow(clippy::not_unsafe_ptr_arg_deref)]
pub fn register_local_mr(
pd: &ProtectionDomain,
address: *mut u8,
length: usize,
) -> IbvResult<MemoryRegion> {
unsafe {
Self::register_mr_with_access(
pd,
address,
length,
AccessFlags::new().with_local_write(),
)
}
}
pub unsafe fn register_shared_mr(
pd: &ProtectionDomain,
address: *mut u8,
length: usize,
) -> IbvResult<MemoryRegion> {
unsafe {
Self::register_mr_with_access(
pd,
address,
length,
AccessFlags::new()
.with_local_write()
.with_remote_read()
.with_remote_write(),
)
}
}
pub unsafe fn register_dmabuf_mr_with_access(
pd: &ProtectionDomain,
fd: i32,
offset: u64,
length: usize,
iova: u64,
access_flags: AccessFlags,
) -> IbvResult<MemoryRegion> {
#[allow(clippy::cast_possible_wrap)]
let mr = unsafe {
ibv_reg_dmabuf_mr(
pd.inner.pd,
offset,
length,
iova,
fd,
access_flags.code() as i32,
)
};
if mr.is_null() {
Err(IbvError::from_errno_with_msg(
io::Error::last_os_error()
.raw_os_error()
.expect("ibv_reg_dmabuf_mr should set errno on error"),
"Failed to register memory region",
))
} else {
log::debug!("IbvMemoryRegion registered");
Ok(MemoryRegion { pd: pd.clone(), mr })
}
}
pub fn register_local_dmabuf_mr(
pd: &ProtectionDomain,
fd: i32,
offset: u64,
length: usize,
iova: u64,
) -> IbvResult<MemoryRegion> {
unsafe {
Self::register_dmabuf_mr_with_access(
pd,
fd,
offset,
length,
iova,
AccessFlags::new().with_local_write(),
)
}
}
pub unsafe fn register_shared_dmabuf_mr(
pd: &ProtectionDomain,
fd: i32,
offset: u64,
length: usize,
iova: u64,
) -> IbvResult<MemoryRegion> {
unsafe {
Self::register_dmabuf_mr_with_access(
pd,
fd,
offset,
length,
iova,
AccessFlags::new()
.with_local_write()
.with_remote_read()
.with_remote_write(),
)
}
}
}
impl MemoryRegion {
    /// Remote key (`rkey`) stored in the underlying `ibv_mr`.
    pub fn rkey(&self) -> u32 {
        // SAFETY: `self.mr` is non-null and stays registered until `Drop` runs.
        unsafe { (*self.mr).rkey }
    }

    /// Start address of the region (the `addr` field of `ibv_mr`) as an integer.
    pub fn address(&self) -> usize {
        // SAFETY: `self.mr` is non-null and stays registered until `Drop` runs.
        let start = unsafe { (*self.mr).addr };
        start as usize
    }

    /// Length of the region in bytes, as stored in the underlying `ibv_mr`.
    pub fn length(&self) -> usize {
        // SAFETY: `self.mr` is non-null and stays registered until `Drop` runs.
        unsafe { (*self.mr).length }
    }

    /// Local key (`lkey`) stored in the underlying `ibv_mr`.
    pub fn lkey(&self) -> u32 {
        // SAFETY: `self.mr` is non-null and stays registered until `Drop` runs.
        unsafe { (*self.mr).lkey }
    }

    /// Builds a [`RemoteMemoryRegion`] from this region's address, length and rkey.
    pub fn remote(&self) -> RemoteMemoryRegion {
        let base = self.address() as u64;
        let len = self.length();
        let key = self.rkey();
        RemoteMemoryRegion::new(base, len, key)
    }
}
impl MemoryRegion {
    /// Creates a [`GatherElement`] for `data` tied to this region.
    ///
    /// Thin wrapper around `GatherElement::new`.
    /// NOTE(review): how `new` reacts to `data` lying outside the region is not
    /// visible here — see the `_checked` / `_unchecked` variants and confirm.
    pub fn gather_element<'a>(&'a self, data: &'a [u8]) -> GatherElement<'a> {
        GatherElement::new(self, data)
    }

    /// Fallible variant of [`MemoryRegion::gather_element`].
    ///
    /// Thin wrapper around `GatherElement::new_checked`.
    ///
    /// # Errors
    ///
    /// Returns whatever [`ScatterGatherElementError`] `new_checked` reports.
    pub fn gather_element_checked<'a>(
        &'a self,
        data: &'a [u8],
    ) -> Result<GatherElement<'a>, ScatterGatherElementError> {
        GatherElement::new_checked(self, data)
    }

    /// Thin wrapper around `GatherElement::new_unchecked`.
    pub fn gather_element_unchecked<'a>(&'a self, data: &'a [u8]) -> GatherElement<'a> {
        GatherElement::new_unchecked(self, data)
    }

    /// Creates a [`ScatterElement`] for the mutable buffer `data` tied to this region.
    ///
    /// Thin wrapper around `ScatterElement::new`.
    /// NOTE(review): how `new` reacts to `data` lying outside the region is not
    /// visible here — see the `_checked` / `_unchecked` variants and confirm.
    pub fn scatter_element<'a>(&'a self, data: &'a mut [u8]) -> ScatterElement<'a> {
        ScatterElement::new(self, data)
    }

    /// Fallible variant of [`MemoryRegion::scatter_element`].
    ///
    /// Thin wrapper around `ScatterElement::new_checked`.
    ///
    /// # Errors
    ///
    /// Returns whatever [`ScatterGatherElementError`] `new_checked` reports.
    pub fn scatter_element_checked<'a>(
        &'a self,
        data: &'a mut [u8],
    ) -> Result<ScatterElement<'a>, ScatterGatherElementError> {
        ScatterElement::new_checked(self, data)
    }

    /// Thin wrapper around `ScatterElement::new_unchecked`.
    pub fn scatter_element_unchecked<'a>(&'a self, data: &'a mut [u8]) -> ScatterElement<'a> {
        ScatterElement::new_unchecked(self, data)
    }

    /// Returns `true` when the byte range `address..address + length` lies
    /// entirely inside this region.
    ///
    /// A zero-length range is enclosed only if `address` itself lies within the
    /// region or exactly at its end; a start address before or beyond the region
    /// is never enclosed. The check is done with offsets, so `address + length`
    /// is never computed and cannot overflow.
    pub fn encloses(&self, address: *const u8, length: usize) -> bool {
        let region_len = self.length();
        // `checked_sub` yields `None` when the range starts before the region.
        match (address as usize).checked_sub(self.address()) {
            // The `offset <= region_len` guard rejects ranges starting past the end
            // and makes the subtraction below safe.
            Some(offset) if offset <= region_len => length <= region_len - offset,
            _ => false,
        }
    }

    /// Returns `true` when `slice` lies entirely inside this region
    /// (see [`MemoryRegion::encloses`]).
    pub fn encloses_slice(&self, slice: &[u8]) -> bool {
        self.encloses(slice.as_ptr(), slice.len())
    }
}