use crate::{
id,
memory_region::{LocalMemoryRegion, MemoryRegionToken, RemoteMemoryRegion},
mr_allocator::MRAllocator,
queue_pair::QueuePair,
};
use clippy_utilities::{Cast, OverflowArithmetic};
use lockfree_cuckoohash::{pin, LockFreeCuckooHash};
use serde::{Deserialize, Serialize};
use std::{
alloc::Layout,
any::Any,
collections::HashMap,
fmt::Debug,
io::{self, Cursor},
sync::Arc,
};
use tokio::{
sync::{
mpsc::{channel, Receiver, Sender},
Mutex,
},
task::JoinHandle,
};
use tracing::{debug, trace};
/// Front-end handle of the agent.
///
/// It owns the join handle of the background task started in `Agent::new`
/// and the receiving halves of the channels that task writes to.
#[derive(Debug)]
pub(crate) struct Agent {
/// State shared with the background task (a clone of this `Arc` is passed to `AgentThread::run`).
inner: Arc<AgentInner>,
/// Receiving half of the channel carrying local memory regions; read by `receive_local_mr`.
local_mr_recv: Mutex<Receiver<Arc<LocalMemoryRegion>>>,
/// Receiving half of the channel carrying remote memory regions; read by `receive_remote_mr`.
remote_mr_recv: Mutex<Receiver<RemoteMemoryRegion>>,
/// Receiving half of the channel carrying received data buffers; read by `receive_data`.
data_recv: Mutex<Receiver<LocalMemoryRegion>>,
/// Join handle of the background task; it is aborted when the `Agent` is dropped.
handle: JoinHandle<io::Result<()>>,
/// Maximum length of a single message; `send_data` uses it to size its chunks.
max_message_len: usize,
}
impl Drop for Agent {
    /// Aborts the background task when the handle goes away.
    fn drop(&mut self) {
        let worker = &self.handle;
        worker.abort();
    }
}
impl Agent {
/// Creates the agent and spawns its background task.
///
/// Three channels of capacity 1024 are created; their sending halves are
/// handed to the background task and their receiving halves are kept here.
///
/// # Errors
///
/// Returns whatever error `AgentThread::run` reports.
pub(crate) fn new(
    qp: Arc<QueuePair>,
    allocator: Arc<MRAllocator>,
    max_message_len: usize,
) -> io::Result<Self> {
    let (local_mr_send, local_mr_rx) = channel(1024);
    let (remote_mr_send, remote_mr_rx) = channel(1024);
    let (data_send, data_rx) = channel(1024);

    let inner = Arc::new(AgentInner {
        qp,
        response_waits: Arc::new(LockFreeCuckooHash::new()),
        mr_own: Arc::new(Mutex::new(HashMap::new())),
        allocator,
        max_message_len,
    });

    let handle = AgentThread::run(
        Arc::<AgentInner>::clone(&inner),
        local_mr_send,
        remote_mr_send,
        data_send,
        max_message_len,
    )?;

    Ok(Self {
        inner,
        local_mr_recv: Mutex::new(local_mr_rx),
        remote_mr_recv: Mutex::new(remote_mr_rx),
        data_recv: Mutex::new(data_rx),
        handle,
        max_message_len,
    })
}
/// Requests a remote memory region with the given `layout`.
///
/// Thin wrapper that forwards to `AgentInner::request_remote_mr`.
pub(crate) async fn request_remote_mr(&self, layout: Layout) -> io::Result<RemoteMemoryRegion> {
    let shared = &self.inner;
    shared.request_remote_mr(layout).await
}
/// Sends a memory region (local or remote) to the peer by token.
///
/// * If `mr` is a `LocalMemoryRegion`, it is recorded in `mr_own` under its
///   token and a `SendMRKind::Local` request is sent.
/// * Otherwise `mr` must be a `RemoteMemoryRegion`, and a
///   `SendMRKind::Remote` request carrying its token is sent.
///
/// # Errors
///
/// * `InvalidInput` if a local MR with the same token is already recorded.
/// * `InvalidData` if `mr` is neither of the two supported types.
/// * Any error returned while sending the request.
pub(crate) async fn send_mr(&self, mr: Arc<dyn Any + Send + Sync>) -> io::Result<()> {
    // `downcast` hands the original `Arc` back on failure, so the second
    // attempt needs neither a prior `is::<T>()` check nor an `unwrap`.
    let kind = match mr.downcast::<LocalMemoryRegion>() {
        Ok(local) => {
            let token = local.token();
            if self
                .inner
                .mr_own
                .lock()
                .await
                .insert(token, local)
                .is_some()
            {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!("the MR {:?} should not be sent multiple times", token),
                ));
            }
            SendMRKind::Local(token)
        }
        Err(other) => {
            let remote = other.downcast::<RemoteMemoryRegion>().map_err(|m| {
                io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!(
                        "this Mr {:?} can not be downcasted to RemoteMemoryRegion",
                        m
                    ),
                )
            })?;
            SendMRKind::Remote(remote.token())
        }
    };
    let request = Request {
        request_id: AgentRequestId::new(),
        kind: RequestKind::SendMR(SendMRRequest { kind }),
    };
    // The response carries no payload; only success or failure matters.
    let _ = self.inner.send_request(request).await?;
    Ok(())
}
/// Waits for the next local memory region delivered by the background task.
///
/// # Errors
///
/// Returns an error of kind `Other` once the channel is closed.
pub(crate) async fn receive_local_mr(&self) -> io::Result<Arc<LocalMemoryRegion>> {
    let mut receiver = self.local_mr_recv.lock().await;
    match receiver.recv().await {
        Some(mr) => Ok(mr),
        None => Err(io::Error::new(io::ErrorKind::Other, "mr channel closed")),
    }
}
/// Waits for the next remote memory region delivered by the background task.
///
/// # Errors
///
/// Returns an error of kind `Other` once the channel is closed.
pub(crate) async fn receive_remote_mr(&self) -> io::Result<RemoteMemoryRegion> {
    let mut receiver = self.remote_mr_recv.lock().await;
    match receiver.recv().await {
        Some(mr) => Ok(mr),
        None => Err(io::Error::new(io::ErrorKind::Other, "mr channel closed")),
    }
}
/// Sends the contents of `lm` to the peer, split into chunks.
///
/// Each chunk is sent as one `SendData` request whose payload is appended
/// after the serialized request header. A chunk holds at most
/// `max_message_len - SEND_DATA_OFFSET` bytes. The next chunk is sent only
/// after the previous one has been acknowledged with status `0`.
///
/// # Errors
///
/// * `InvalidInput` if `lm` is non-empty but the configured maximum message
///   length leaves no room for payload (previously this looped forever).
/// * `Other` if the peer answers with a non-zero status or an unexpected
///   response kind.
/// * Any error returned by slicing `lm` or by sending the request.
pub(crate) async fn send_data(&self, lm: &LocalMemoryRegion) -> io::Result<()> {
    let lm_len = lm.length();
    let max_content_len = self.max_message_len.overflow_sub(*SEND_DATA_OFFSET);
    // With a zero-sized chunk `start` would never advance, so the loop
    // below would send empty messages forever.
    if lm_len > 0 && max_content_len == 0 {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            format!(
                "max message length {} leaves no room for data after the {}-byte header",
                self.max_message_len, *SEND_DATA_OFFSET
            ),
        ));
    }
    let mut start = 0;
    while start < lm_len {
        let end = start.overflow_add(max_content_len).min(lm_len);
        let request = Request {
            request_id: AgentRequestId::new(),
            kind: RequestKind::SendData(SendDataRequest {
                len: end.overflow_sub(start),
            }),
        };
        let response = self
            .inner
            .send_request_append_data(request, &[&lm.slice(start..end)?])
            .await?;
        match response {
            ResponseKind::SendData(ref ack) if ack.status == 0 => {}
            ResponseKind::SendData(ack) => {
                return Err(io::Error::new(
                    io::ErrorKind::Other,
                    format!("send data failed, response status is {}", ack.status),
                ));
            }
            other => {
                return Err(io::Error::new(
                    io::ErrorKind::Other,
                    format!(
                        "send data failed, due to unexpected response type {:?}",
                        other
                    ),
                ));
            }
        }
        start = end;
    }
    Ok(())
}
/// Waits for the next data buffer delivered by the background task.
///
/// # Errors
///
/// Returns an error of kind `Other` once the channel is closed.
pub(crate) async fn receive_data(&self) -> io::Result<LocalMemoryRegion> {
    let mut receiver = self.data_recv.lock().await;
    match receiver.recv().await {
        Some(data) => Ok(data),
        None => Err(io::Error::new(io::ErrorKind::Other, "data channel closed")),
    }
}
}
/// State of the background task that receives messages from the queue pair
/// and dispatches them to the request/response handlers.
struct AgentThread {
/// State shared with the `Agent` front-end.
inner: Arc<AgentInner>,
/// Sending half for local memory regions (fed by `SendMRKind::Remote` requests).
local_mr_send: Sender<Arc<LocalMemoryRegion>>,
/// Sending half for remote memory regions (fed by `SendMRKind::Local` requests).
remote_mr_send: Sender<RemoteMemoryRegion>,
/// Sending half for data buffers (fed by `SendData` requests).
data_send: Sender<LocalMemoryRegion>,
/// Size of the receive buffer allocated for each incoming message.
max_message_len: usize,
}
impl AgentThread {
/// Validates `max_message_len` and spawns the receive loop on the tokio runtime.
///
/// # Errors
///
/// Returns `InvalidInput` when `max_message_len` is smaller than
/// `SEND_DATA_OFFSET`, the serialized size of a `SendData` request header.
fn run(
    inner: Arc<AgentInner>,
    local_mr_send: Sender<Arc<LocalMemoryRegion>>,
    remote_mr_send: Sender<RemoteMemoryRegion>,
    data_send: Sender<LocalMemoryRegion>,
    max_message_len: usize,
) -> io::Result<JoinHandle<io::Result<()>>> {
    // Check the configuration before building any state.
    let min_len = *SEND_DATA_OFFSET;
    if max_message_len < min_len {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            format!(
                "max message length is {:?}, it should be at least {:?}",
                max_message_len, min_len
            ),
        ));
    }
    let agent = Arc::new(Self {
        inner,
        local_mr_send,
        remote_mr_send,
        data_send,
        max_message_len,
    });
    Ok(tokio::spawn(agent.main()))
}
/// Receive loop of the background task.
///
/// Repeatedly receives one message into a buffer of `max_message_len`
/// bytes, deserializes it and spawns a task to handle it:
///
/// * `SendData` requests take ownership of the receive buffer (the payload
///   lives in it), so a fresh buffer is allocated for the next message.
/// * Other requests and all responses are handled from the deserialized
///   message alone and the buffer is reused.
///
/// The results of the spawned handler tasks are not observed here.
///
/// # Errors
///
/// Ends with an error when the layout is invalid, or when allocation,
/// receiving, or deserialization fails.
async fn main(self: Arc<Self>) -> io::Result<()> {
    // Checked constructor: an oversized `max_message_len` becomes an error
    // instead of violating the contract of `from_size_align_unchecked`.
    let layout = Layout::from_size_align(self.max_message_len, 1)
        .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
    let mut buf = self.inner.allocator.alloc(&layout)?;
    loop {
        debug!("receiving message");
        let sz = self.inner.qp.receive(&buf).await?;
        debug!("received message, size = {}", sz);
        let message = bincode::deserialize(buf.as_slice().get(0..sz).ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::Other,
                format!(
                    "{:?} is out of range, the length is {:?}",
                    0..sz,
                    buf.length()
                ),
            )
        })?)
        .map_err(|e| {
            io::Error::new(
                io::ErrorKind::Other,
                format!("failed to deserialize {:?}", e),
            )
        })?;
        match message {
            Message::Request(request) => match request.kind {
                RequestKind::SendData(_) => {
                    // The handler keeps the buffer, so replace it afterwards.
                    let _task =
                        tokio::spawn(Arc::<Self>::clone(&self).handle_send_req(request, buf));
                    buf = self.inner.allocator.alloc(&layout)?;
                }
                RequestKind::AllocMR(_)
                | RequestKind::ReleaseMR(_)
                | RequestKind::SendMR(_) => {
                    let _task = tokio::spawn(Arc::<Self>::clone(&self).handle_request(request));
                }
            },
            Message::Response(response) => {
                let _task = tokio::spawn(Arc::<Self>::clone(&self).handle_response(response));
            }
        };
    }
}
/// Handles an `AllocMR`, `ReleaseMR` or `SendMR` request and sends the response.
///
/// * `AllocMR`: allocates a region with the requested size/alignment,
///   records it in `mr_own` and answers with its token.
/// * `ReleaseMR`: removes the token from `mr_own` and answers with status `0`.
/// * `SendMR`: for `Local`, builds a `RemoteMemoryRegion` from the token and
///   pushes it to `remote_mr_send`; for `Remote`, looks the token up in
///   `mr_own` and pushes the region to `local_mr_send`.
///
/// # Errors
///
/// Returns an error if the layout is invalid, allocation fails, a `Remote`
/// token is not registered, the request is a `SendData` (handled by
/// `handle_send_req`), or the response cannot be sent.
///
/// # Panics
///
/// Panics if a `ReleaseMR` token is not present in `mr_own`, or if one of
/// the MR channels is closed.
async fn handle_request(self: Arc<Self>, request: Request) -> io::Result<()> {
    debug!("handle request");
    let response = match request.kind {
        RequestKind::AllocMR(param) => {
            let layout = Layout::from_size_align(param.size, param.align)
                .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
            let mr = Arc::new(self.inner.allocator.alloc(&layout)?);
            let token = mr.token();
            let response = AllocMRResponse { token };
            let _old = self.inner.mr_own.lock().await.insert(token, mr);
            ResponseKind::AllocMR(response)
        }
        RequestKind::ReleaseMR(param) => {
            assert!(self
                .inner
                .mr_own
                .lock()
                .await
                .remove(&param.token)
                .is_some());
            ResponseKind::ReleaseMR(ReleaseMRResponse { status: 0 })
        }
        RequestKind::SendMR(param) => {
            match param.kind {
                SendMRKind::Local(token) => {
                    assert!(self
                        .remote_mr_send
                        .send(RemoteMemoryRegion::new_from_token(
                            token,
                            Arc::<AgentInner>::clone(&self.inner)
                        ))
                        .await
                        .is_ok());
                }
                SendMRKind::Remote(token) => {
                    let mr = Arc::clone(
                        self.inner.mr_own.lock().await.get(&token).ok_or_else(|| {
                            io::Error::new(
                                io::ErrorKind::Other,
                                format!("the token {:?} is not registered", token),
                            )
                        })?,
                    );
                    assert!(self.local_mr_send.send(mr).await.is_ok());
                }
            }
            ResponseKind::SendMR(SendMRResponse {})
        }
        RequestKind::SendData(_) => {
            return Err(io::Error::new(
                io::ErrorKind::Other,
                "Should not reach here, SendData is handled separately",
            ));
        }
    };
    let response = Response {
        request_id: request.request_id,
        kind: response,
    };
    self.inner.send_response(response).await?;
    trace!("handle request done");
    Ok(())
}
async fn handle_response(self: Arc<Self>, response: Response) -> io::Result<()> {
trace!("handle response");
let guard = pin();
let sender = self
.inner
.response_waits
.remove_with_guard(&response.request_id, &guard)
.ok_or_else(|| {
io::Error::new(
io::ErrorKind::Other,
format!(
"request id {:?} is missing in waiting list",
&response.request_id
),
)
})?;
match sender.try_send(Ok(response.kind)) {
Ok(_) => Ok(()),
Err(_) => Err(io::Error::new(
io::ErrorKind::Other,
"The waiting task has dropped",
)),
}
}
/// Handles a `SendData` request.
///
/// The payload occupies `param.len` bytes of `buf` starting at
/// `SEND_DATA_OFFSET`. That slice is pushed to `data_send`, then a
/// `SendData` response with status `0` is sent back.
///
/// # Errors
///
/// Returns an error if `request` is not a `SendData` request, if slicing
/// `buf` fails, if the data channel is closed, or if the response cannot be
/// sent.
async fn handle_send_req(
    self: Arc<Self>,
    request: Request,
    buf: LocalMemoryRegion,
) -> io::Result<()> {
    let param = match request.kind {
        RequestKind::SendData(param) => param,
        RequestKind::AllocMR(_) | RequestKind::ReleaseMR(_) | RequestKind::SendMR(_) => {
            return Err(io::Error::new(
                io::ErrorKind::Other,
                "This function only handles send request",
            ));
        }
    };
    let payload_start = *SEND_DATA_OFFSET;
    let payload_end = payload_start.overflow_add(param.len);
    let payload = buf.slice(payload_start..payload_end)?;
    if let Err(e) = self.data_send.send(payload).await {
        return Err(io::Error::new(
            io::ErrorKind::Other,
            format!("Data receiver has stopped, {:?}", e),
        ));
    }
    let ack = Response {
        request_id: request.request_id,
        kind: ResponseKind::SendData(SendDataResponse { status: 0 }),
    };
    self.inner.send_response(ack).await?;
    Ok(())
}
}
/// State shared between the `Agent` front-end and the background task.
#[derive(Debug)]
pub(crate) struct AgentInner {
/// Queue pair used to send and receive serialized messages.
qp: Arc<QueuePair>,
/// Map from the id of an in-flight request to the sender that delivers its response.
response_waits: ResponseWaitsMap,
/// Local memory regions recorded by token (inserted by `send_mr` and `AllocMR`, removed by `ReleaseMR`).
mr_own: Arc<Mutex<HashMap<MemoryRegionToken, Arc<LocalMemoryRegion>>>>,
/// Allocator used for message buffers and for regions requested via `AllocMR`.
allocator: Arc<MRAllocator>,
/// Size of the buffer allocated for each outgoing message.
max_message_len: usize,
}
impl AgentInner {
/// Asks the peer to allocate a memory region with the given `layout`.
///
/// Sends an `AllocMR` request and wraps the returned token in a
/// `RemoteMemoryRegion` tied to this agent.
///
/// # Errors
///
/// Returns any error from sending the request, or `InvalidData` when the
/// response is not an `AllocMR` response.
pub(crate) async fn request_remote_mr(
    self: &Arc<Self>,
    layout: Layout,
) -> io::Result<RemoteMemoryRegion> {
    let alloc = AllocMRRequest {
        size: layout.size(),
        align: layout.align(),
    };
    let request = Request {
        request_id: AgentRequestId::new(),
        kind: RequestKind::AllocMR(alloc),
    };
    match self.send_request(request).await? {
        ResponseKind::AllocMR(granted) => Ok(RemoteMemoryRegion::new_from_token(
            granted.token,
            Arc::<Self>::clone(self),
        )),
        ResponseKind::ReleaseMR(_) | ResponseKind::SendMR(_) | ResponseKind::SendData(_) => {
            Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Should not be here, we're expecting AllocMR response",
            ))
        }
    }
}
/// Sends a `ReleaseMR` request for `token` and waits for the response.
///
/// The content of the response is not inspected.
///
/// # Errors
///
/// Returns any error from sending the request.
pub(crate) async fn release_mr(&self, token: MemoryRegionToken) -> io::Result<()> {
    let request_id = AgentRequestId::new();
    let kind = RequestKind::ReleaseMR(ReleaseMRRequest { token });
    self.send_request(Request { request_id, kind })
        .await
        .map(|_response| ())
}
async fn send_request(&self, request: Request) -> io::Result<ResponseKind> {
self.send_request_append_data(request, &[]).await
}
async fn send_request_append_data(
&self,
request: Request,
lm: &[&LocalMemoryRegion],
) -> io::Result<ResponseKind> {
let (tx, mut rx) = channel(2);
let mut req = request;
while !self
.response_waits
.insert_if_not_exists(req.request_id, tx.clone())
{
req = Request {
request_id: AgentRequestId::new(),
kind: request.kind,
};
}
let mut buf = self
.allocator
.alloc(unsafe { &Layout::from_size_align_unchecked(self.max_message_len, 1) })?;
let cursor = Cursor::new(buf.as_mut_slice());
let message = Message::Request(req);
let msz = bincode::serialized_size(&message)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
.cast();
bincode::serialize_into(cursor, &message)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
let buf = buf.slice(0..msz)?;
let mut lms = vec![&buf];
lms.extend(lm);
let lms_len: usize = lms.iter().map(|l| l.length()).sum();
assert!(lms_len <= self.max_message_len);
self.qp.send_sge(&lms).await?;
rx.recv()
.await
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "agent is dropped"))?
}
/// Serializes `response` into a freshly allocated buffer and sends it.
///
/// # Errors
///
/// Returns an error if the buffer layout is invalid, or if allocation,
/// serialization, slicing or sending fails. Allocation errors are passed
/// through unchanged, as at the other allocation sites in this file.
async fn send_response(&self, response: Response) -> io::Result<()> {
    // Checked constructor: an oversized length becomes an error.
    let layout = Layout::from_size_align(self.max_message_len, 1)
        .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
    let mut buf = self.allocator.alloc(&layout)?;
    let message = Message::Response(response);
    let msz = bincode::serialized_size(&message)
        .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
        .cast();
    bincode::serialize_into(Cursor::new(buf.as_mut_slice()), &message)
        .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
    // Send only the bytes that were actually written.
    let buf = buf.slice(0..msz)?;
    self.qp.send(&buf).await?;
    Ok(())
}
}
lazy_static! {
    /// Serialized size of a `SendData` request message, measured once on a
    /// sample message. The payload of a `SendData` message starts at this offset.
    static ref SEND_DATA_OFFSET: usize = {
        let sample = Message::Request(Request {
            request_id: AgentRequestId::new(),
            kind: RequestKind::SendData(SendDataRequest { len: 0 }),
        });
        #[allow(clippy::unwrap_used)]
        let header_len = bincode::serialized_size(&sample).unwrap();
        header_len.cast()
    };
}
/// Shared map from the id of an in-flight request to the channel sender
/// through which its response is delivered to the waiting task.
type ResponseWaitsMap = Arc<LockFreeCuckooHash<AgentRequestId, Sender<io::Result<ResponseKind>>>>;
/// Identifier that pairs a response with the request it answers.
#[derive(Serialize, Deserialize, Clone, Copy, PartialEq, Eq, Hash, Debug)]
struct AgentRequestId(u64);
impl AgentRequestId {
    /// Creates an id from a value returned by `id::random_u64`.
    fn new() -> Self {
        let raw = id::random_u64();
        Self(raw)
    }
}
/// Payload of an `AllocMR` request: the layout of the region to allocate.
#[derive(Serialize, Deserialize, Clone, Copy)]
struct AllocMRRequest {
/// Requested size (taken from `Layout::size`).
size: usize,
/// Requested alignment (taken from `Layout::align`).
align: usize,
}
/// Payload of an `AllocMR` response.
#[derive(Debug, Serialize, Deserialize)]
struct AllocMRResponse {
/// Token of the memory region that was allocated for the request.
token: MemoryRegionToken,
}
/// Payload of a `ReleaseMR` request.
#[derive(Serialize, Deserialize, Clone, Copy)]
struct ReleaseMRRequest {
/// Token of the memory region to release.
token: MemoryRegionToken,
}
/// Payload of a `ReleaseMR` response.
#[derive(Debug, Serialize, Deserialize)]
struct ReleaseMRResponse {
/// Status code; the request handler in this file always sends `0`.
status: usize,
}
/// Which kind of memory region a `SendMR` request refers to, from the
/// sender's point of view.
#[derive(Serialize, Deserialize, Clone, Copy)]
enum SendMRKind {
/// Token of a region local to the sender; the receiver builds a `RemoteMemoryRegion` from it.
Local(MemoryRegionToken),
/// Token of a region remote to the sender; the receiver looks it up among the regions it has recorded.
Remote(MemoryRegionToken),
}
/// Payload of a `SendMR` request.
#[derive(Serialize, Deserialize, Clone, Copy)]
struct SendMRRequest {
/// The memory region being sent, identified by kind and token.
kind: SendMRKind,
}
/// Payload of a `SendMR` response; it carries no data.
#[derive(Debug, Serialize, Deserialize)]
struct SendMRResponse {}
/// Payload of a `SendData` request; the data itself follows the serialized message.
#[derive(Serialize, Deserialize, Clone, Copy)]
struct SendDataRequest {
/// Number of data bytes appended after the serialized request.
len: usize,
}
/// Payload of a `SendData` response.
#[derive(Debug, Serialize, Deserialize)]
struct SendDataResponse {
/// Status code; `send_data` treats any value greater than `0` as a failure.
status: usize,
}
/// The kinds of request the agent can send, with their payloads.
#[derive(Serialize, Deserialize, Clone, Copy)]
enum RequestKind {
/// Ask the peer to allocate a memory region.
AllocMR(AllocMRRequest),
/// Ask the peer to release a memory region.
ReleaseMR(ReleaseMRRequest),
/// Hand a memory region over to the peer by token.
SendMR(SendMRRequest),
/// Send data appended after the serialized request.
SendData(SendDataRequest),
}
/// A request message: an id plus the kind-specific payload.
#[derive(Serialize, Deserialize, Clone, Copy)]
struct Request {
/// Id echoed back in the matching `Response`.
request_id: AgentRequestId,
/// What is being requested.
kind: RequestKind,
}
/// The kinds of response, one per `RequestKind` variant.
#[derive(Serialize, Deserialize, Debug)]
enum ResponseKind {
/// Answer to an `AllocMR` request.
AllocMR(AllocMRResponse),
/// Answer to a `ReleaseMR` request.
ReleaseMR(ReleaseMRResponse),
/// Answer to a `SendMR` request.
SendMR(SendMRResponse),
/// Answer to a `SendData` request.
SendData(SendDataResponse),
}
/// A response message: the id of the request it answers plus the payload.
#[derive(Serialize, Deserialize)]
struct Response {
/// Id copied from the request being answered.
request_id: AgentRequestId,
/// The kind-specific answer.
kind: ResponseKind,
}
/// Top-level message exchanged over the queue pair.
#[derive(Serialize, Deserialize)]
enum Message {
/// A request from the peer, to be handled and answered.
Request(Request),
/// A response to a request sent earlier.
Response(Response),
}