use crate::{agent::AgentInner, protection_domain::ProtectionDomain};
use clippy_utilities::{Cast, OverflowArithmetic};
use rdma_sys::{ibv_access_flags, ibv_dereg_mr, ibv_mr, ibv_reg_mr};
use serde::{Deserialize, Serialize};
use std::{
    alloc::{alloc, dealloc, Layout},
    fmt::Debug,
    io,
    ops::Range,
    ptr::NonNull,
    slice,
    sync::{Arc, Mutex, MutexGuard},
};
use tracing::error;
/// Handle to an RDMA memory region.
///
/// Sub-regions created through `slice`/`alloc` share the underlying
/// `InnerMr` tree; each child keeps its parent alive via `Arc`, so a root
/// region is only released after every descendant has been dropped.
#[derive(Debug)]
pub struct MemoryRegion<T: RemoteKey> {
    // Shared node in the region tree.
    inner: Arc<InnerMr<T>>,
}
impl<T: RemoteKey> MemoryRegion<T> {
    /// Build a root region covering `addr..addr + len`, owning `t`.
    fn new_root(addr: usize, len: usize, t: T) -> Self {
        let inner = Arc::new(InnerMr::new_root(addr, len, t));
        Self { inner }
    }

    /// Starting address of the region.
    pub fn as_ptr(&self) -> *const u8 {
        self.inner.as_ptr()
    }

    /// Length of the region in bytes.
    pub fn length(&self) -> usize {
        self.inner.length()
    }

    /// Remote key of the underlying root region.
    pub(crate) fn rkey(&self) -> u32 {
        self.inner.rkey()
    }

    /// Serializable token describing this region for a remote peer.
    pub(crate) fn token(&self) -> MemoryRegionToken {
        MemoryRegionToken {
            addr: self.inner.addr,
            len: self.inner.len,
            rkey: self.inner.rkey(),
        }
    }

    /// Carve out the sub-region `range` (relative to this region).
    ///
    /// # Errors
    /// Fails when the range is empty, out of bounds, or already in use.
    pub fn slice(&self, range: Range<usize>) -> io::Result<Self> {
        let inner = self.inner.slice(range)?;
        Ok(Self {
            inner: Arc::new(inner),
        })
    }

    /// Allocate a sub-region satisfying `layout` from this region.
    ///
    /// # Errors
    /// Fails when no free span large enough (after alignment) remains.
    pub(crate) fn alloc(&self, layout: &Layout) -> io::Result<Self> {
        let inner = self.inner.alloc(layout)?;
        Ok(Self {
            inner: Arc::new(inner),
        })
    }
}
/// Serializable description of a memory region that a peer can use for
/// remote access.
#[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Clone, Copy, Debug)]
pub(crate) struct MemoryRegionToken {
    /// Absolute start address of the region.
    pub(crate) addr: usize,
    /// Length of the region in bytes.
    pub(crate) len: usize,
    /// Remote key for RDMA access to the region.
    pub(crate) rkey: u32,
}
/// Types that expose an RDMA remote key.
pub trait RemoteKey {
    /// Remote key of the underlying memory region.
    fn rkey(&self) -> u32;
}
/// A non-root region: links to its direct parent (whose `AllocManager`
/// owns this node's range and is notified on drop) and to the tree root
/// (which holds the actual registered keys).
#[derive(Debug)]
struct Node<T: RemoteKey> {
    // Direct parent; kept alive so its allocator outlives this node.
    parent: Arc<InnerMr<T>>,
    // Root of the region tree; source of rkey/lkey.
    root: Arc<InnerMr<T>>,
}
/// Distinguishes the registered root region from sub-regions of it.
#[derive(Debug)]
enum MemoryRegionKind<T: RemoteKey> {
    /// The registered region itself, owning the key provider.
    Root(T),
    /// A sub-region, pointing back at its parent and root.
    Node(Node<T>),
}
impl<T: RemoteKey> MemoryRegionKind<T> {
    /// Remote key, always resolved at the tree root.
    fn rkey(&self) -> u32 {
        match self {
            Self::Root(root) => root.rkey(),
            Self::Node(node) => node.root.rkey(),
        }
    }
}
/// Shared state of one node in the memory-region tree.
#[derive(Debug)]
pub(crate) struct InnerMr<T: RemoteKey> {
    // Absolute start address of this (sub-)region.
    addr: usize,
    // Length of this (sub-)region in bytes.
    len: usize,
    // Root payload, or link back to parent/root for sub-regions.
    kind: MemoryRegionKind<T>,
    // Tracks the ranges handed out to children of this region.
    sub: AllocManager,
}
impl<T: RemoteKey> InnerMr<T> {
    /// Starting address as a raw pointer.
    #[allow(clippy::as_conversions)]
    fn as_ptr(&self) -> *const u8 {
        self.addr as _
    }

    /// Length of the region in bytes.
    fn length(&self) -> usize {
        self.len
    }

    /// Remote key, resolved at the tree root.
    fn rkey(&self) -> u32 {
        self.kind.rkey()
    }

    /// Create a root region spanning `addr..addr + len`, owning `t`.
    fn new_root(addr: usize, len: usize, t: T) -> Self {
        let kind = MemoryRegionKind::Root(t);
        Self {
            addr,
            len,
            kind,
            sub: AllocManager::new(addr, len),
        }
    }

    /// Create a child region, keeping both `self` and the tree root alive.
    fn new_node(self: &Arc<Self>, addr: usize, len: usize) -> Self {
        let kind = MemoryRegionKind::Node(Node {
            parent: Arc::clone(self),
            root: self.root(),
        });
        Self {
            addr,
            len,
            kind,
            sub: AllocManager::new(addr, len),
        }
    }

    /// The root of this region's tree (`self` if it already is the root).
    fn root(self: &Arc<Self>) -> Arc<Self> {
        if let MemoryRegionKind::Node(ref node) = self.kind {
            Arc::clone(&node.root)
        } else {
            Arc::clone(self)
        }
    }

    /// Reserve `range` (relative to this region) and wrap it as a child.
    fn slice(self: &Arc<Self>, range: Range<usize>) -> io::Result<Self> {
        self.sub.slice(&range)?;
        let start = self.addr.overflow_add(range.start);
        Ok(self.new_node(start, range.len()))
    }

    /// Allocate a child region satisfying `layout` from the free space.
    fn alloc(self: &Arc<Self>, layout: &Layout) -> io::Result<Self> {
        let range = self.sub.alloc(layout)?;
        let start = self.addr.overflow_add(range.start);
        Ok(self.new_node(start, range.len()))
    }
}
impl<T: RemoteKey> Drop for InnerMr<T> {
fn drop(&mut self) {
if let MemoryRegionKind::Node(ref node) = self.kind {
if let Err(e) = node.parent.sub.free(
self.addr.overflow_sub(node.parent.addr)
..self
.len
.overflow_add(self.addr)
.overflow_sub(node.parent.addr),
) {
error!("Faild to drop a memory region, {:?}", e);
}
}
}
}
/// Key provider for a locally registered memory region.
pub struct Local {
    // Non-null pointer returned by `ibv_reg_mr`; deregistered on drop.
    inner_mr: NonNull<ibv_mr>,
    // Keeps the protection domain alive while the MR is registered.
    _pd: Arc<ProtectionDomain>,
}
impl Debug for Local {
    /// Shows only the MR pointer; the protection domain is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Local");
        builder.field("inner_mr", &self.inner_mr);
        builder.finish()
    }
}
impl Drop for Local {
    /// Deregister the MR when the handle goes away.
    fn drop(&mut self) {
        // NOTE(review): a non-zero return panics inside `drop`, which aborts
        // the process if already unwinding — consider logging instead.
        let ret = unsafe { ibv_dereg_mr(self.inner_mr.as_ptr()) };
        assert_eq!(ret, 0_i32);
    }
}
impl Local {
    /// Local key of the registered MR.
    fn lkey(&self) -> u32 {
        let mr = unsafe { self.inner_mr.as_ref() };
        mr.lkey
    }
}
impl RemoteKey for Local {
    /// Remote key of the registered MR.
    fn rkey(&self) -> u32 {
        let mr = unsafe { self.inner_mr.as_ref() };
        mr.rkey
    }
}
// SAFETY: `Local` only holds the MR pointer returned by `ibv_reg_mr` plus an
// `Arc` to the protection domain; neither is mutated after construction.
// NOTE(review): this additionally assumes the ibverbs MR handle itself may be
// used from multiple threads — confirm against the ibverbs threading rules.
unsafe impl Sync for Local {}
unsafe impl Send for Local {}
impl InnerMr<Local> {
    /// Local key, fetched from the root's `Local` payload.
    fn lkey(&self) -> u32 {
        match self.kind {
            MemoryRegionKind::Root(ref local) => local.lkey(),
            MemoryRegionKind::Node(ref node) => node.root.lkey(),
        }
    }
}
pub type LocalMemoryRegion = MemoryRegion<Local>;
impl LocalMemoryRegion {
#[allow(clippy::as_conversions)]
pub fn as_mut_ptr(&mut self) -> *mut u8 {
self.as_ptr() as _
}
pub fn as_slice(&self) -> &[u8] {
unsafe { slice::from_raw_parts(self.as_ptr(), self.length()) }
}
pub fn as_mut_slice(&mut self) -> &mut [u8] {
unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), self.length()) }
}
pub(crate) fn lkey(&self) -> u32 {
self.inner.lkey()
}
#[allow(clippy::as_conversions)] pub(crate) fn new_from_pd(
pd: &Arc<ProtectionDomain>,
layout: Layout,
access: ibv_access_flags,
) -> io::Result<Self> {
let addr = unsafe { alloc(layout) };
let inner_mr = NonNull::new(unsafe {
ibv_reg_mr(pd.as_ptr(), addr.cast(), layout.size(), access.0.cast())
})
.ok_or_else(io::Error::last_os_error)?;
let len = layout.size();
let local = Local {
inner_mr,
_pd: Arc::<ProtectionDomain>::clone(pd),
};
Ok(MemoryRegion::new_root(addr as usize, len, local))
}
}
/// Key provider for a peer's memory region, described by a token.
#[derive(Debug)]
pub struct Remote {
    // Token received from the remote agent (addr, len, rkey).
    token: MemoryRegionToken,
    // Agent used to ask the peer to release the region on drop.
    agent: Arc<AgentInner>,
}
impl RemoteKey for Remote {
fn rkey(&self) -> u32 {
self.token.rkey
}
}
impl Drop for Remote {
fn drop(&mut self) {
let agent = Arc::<AgentInner>::clone(&self.agent);
let token = self.token;
let _task = tokio::spawn(async move { AgentInner::release_mr(&agent, token).await });
}
}
pub type RemoteMemoryRegion = MemoryRegion<Remote>;

impl RemoteMemoryRegion {
    /// Build a remote region handle from a token received over the wire.
    pub(crate) fn new_from_token(token: MemoryRegionToken, agent: Arc<AgentInner>) -> Self {
        let (addr, len) = (token.addr, token.len);
        let remote = Remote { token, agent };
        Self {
            inner: Arc::new(InnerMr::new_root(addr, len, remote)),
        }
    }
}
/// First-fit range allocator over a contiguous span of memory.
#[derive(Debug)]
struct AllocManager {
    // Absolute start address of the managed span (used for alignment math).
    addr: usize,
    // Total length of the span in bytes.
    length: usize,
    // Sorted, non-overlapping occupied ranges, relative to the span start.
    occupy: Mutex<Vec<Range<usize>>>,
}
impl AllocManager {
    /// Create a manager for the span of `length` bytes starting at `addr`.
    fn new(addr: usize, length: usize) -> Self {
        Self {
            addr,
            occupy: Mutex::new(Vec::new()),
            length,
        }
    }

    /// Insert `range` into the occupied list, keeping it sorted by start.
    ///
    /// # Errors
    /// Fails if a range with the same start is already present.
    fn insert_slice(
        mut locked_sub: MutexGuard<Vec<Range<usize>>>,
        range: &Range<usize>,
    ) -> io::Result<()> {
        let pos = locked_sub
            .binary_search_by(|r| r.start.cmp(&range.start))
            .err()
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "The slice is there"))?;
        locked_sub.insert(pos, range.clone());
        Ok(())
    }

    /// Reserve the exact `range` (relative to the span) if it is free.
    ///
    /// # Errors
    /// Fails for empty or out-of-bounds ranges, and for ranges that overlap
    /// an already occupied one.
    fn slice(&self, range: &Range<usize>) -> io::Result<()> {
        #[allow(clippy::suspicious_operation_groupings)]
        if range.start >= range.end || range.end > self.length {
            return Err(io::Error::new(io::ErrorKind::Other, "Invalid Range"));
        }
        let locked_sub = self.occupy.lock().map_err(|e| {
            io::Error::new(
                io::ErrorKind::Other,
                format!("Cannot lock occupy in alloc manager, {:?}", e),
            )
        })?;
        // A candidate is acceptable only if it is disjoint from every
        // occupied range (entirely before or entirely after each one).
        if !locked_sub
            .iter()
            .all(|sub_range| range.end <= sub_range.start || range.start >= sub_range.end)
        {
            return Err(io::Error::new(
                io::ErrorKind::Other,
                "Memory slice Has been used",
            ));
        }
        Self::insert_slice(locked_sub, range)?;
        Ok(())
    }

    /// First-fit allocation of a range satisfying `layout`'s size and
    /// alignment, relative to the span start.
    ///
    /// # Errors
    /// Fails when no gap between occupied ranges (or after the last one)
    /// can hold the aligned allocation.
    fn alloc(&self, layout: &Layout) -> io::Result<Range<usize>> {
        let mut last = 0;
        let locked_sub = self.occupy.lock().map_err(|e| {
            io::Error::new(
                io::ErrorKind::Other,
                format!("Cannot lock occupy in alloc manager, {:?}", e),
            )
        })?;
        let mut ans = Err(io::Error::new(io::ErrorKind::Other, "No Enough Memory"));
        // Walk the sorted occupied ranges, testing the gap before each one.
        for range in locked_sub.iter() {
            last = self.ceiling(last, layout);
            if last.overflow_add(layout.size()) <= range.start {
                ans = Ok(last..last.overflow_add(layout.size()));
                break;
            }
            last = range.end;
        }
        // No gap fit; try the tail after the last occupied range.
        if ans.is_err() {
            last = self.ceiling(last, layout);
            if last.overflow_add(layout.size()) <= self.length {
                ans = Ok(last..last.overflow_add(layout.size()));
            }
        }
        let ans = ans?;
        Self::insert_slice(locked_sub, &ans)?;
        Ok(ans)
    }

    /// Round `offset` up so that `addr + offset` satisfies `layout`'s
    /// alignment, returning the adjusted offset.
    /// (Renamed from the misspelled `celling`; private, so no callers
    /// outside this impl are affected.)
    fn ceiling(&self, offset: usize, layout: &Layout) -> usize {
        let align = layout.align();
        // ceil((addr + offset) / align) * align - addr
        self.addr
            .overflow_add(offset)
            .overflow_add(align)
            .overflow_sub(1)
            .overflow_div(align)
            .overflow_mul(align)
            .overflow_sub(self.addr)
    }

    /// Release a previously reserved range (matched by its start offset).
    ///
    /// # Errors
    /// Fails when no occupied range starts at `range.start`.
    fn free(&self, range: Range<usize>) -> io::Result<()> {
        let mut locked_sub = self.occupy.lock().map_err(|e| {
            io::Error::new(
                io::ErrorKind::Other,
                format!("Cannot lock occupy in alloc manager, {:?}", e),
            )
        })?;
        let pos = locked_sub
            .binary_search_by(|r| r.start.cmp(&range.start))
            .map_err(|e| {
                io::Error::new(
                    io::ErrorKind::Other,
                    format!("Free an invalid range, {:?}", e),
                )
            })?;
        let _ = locked_sub.remove(pos);
        Ok(())
    }
}
#[cfg(test)]
mod test {
    use super::AllocManager;
    use std::io;

    /// Exercises `slice`: accepted ranges land in the sorted occupied list,
    /// while empty, overlapping, and out-of-bounds ranges are rejected.
    #[test]
    #[allow(clippy::indexing_slicing)]
    fn test_sub_mr_slice() -> io::Result<()> {
        let sub_mr = AllocManager::new(0, 1024);
        // A fresh manager accepts a slice at the very beginning.
        assert!(sub_mr.slice(&(0..1)).is_ok());
        // The same range again overlaps and must be rejected.
        assert!(sub_mr.slice(&(0..1)).is_err());
        // Empty ranges are invalid.
        assert!(sub_mr.slice(&(0..0)).is_err());
        // Ranges past the end of the span are invalid.
        assert!(sub_mr.slice(&(1024..1025)).is_err());
        // An adjacent, disjoint range is still fine.
        assert!(sub_mr.slice(&(1..2)).is_ok());
        let occupy_list = sub_mr.occupy.lock().map_err(|e| {
            io::Error::new(
                io::ErrorKind::Other,
                format!("Cannot lock occupy in alloc manager, {:?}", e),
            )
        })?;
        // Both accepted ranges are present, kept sorted by start.
        assert_eq!(occupy_list.len(), 2);
        assert_eq!(occupy_list[0], (0..1));
        assert_eq!(occupy_list[1], (1..2));
        Ok(())
    }
}