use crate::hashmap_extension::HashMapExtension;
use crate::queue_pair::MAX_RECV_WR;
use crate::rmr_manager::RemoteMrManager;
use crate::RemoteMrReadAccess;
use crate::{
id,
memory_region::{
local::{LocalMr, LocalMrReadAccess, LocalMrSlice, LocalMrWriteAccess},
remote::RemoteMr,
MrAccess, MrToken,
},
mr_allocator::MrAllocator,
queue_pair::QueuePair,
};
use clippy_utilities::Cast;
use rdma_sys::ibv_access_flags;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::SystemTime;
use std::{
alloc::Layout,
fmt::Debug,
io::{self, Cursor},
mem,
sync::Arc,
time::Duration,
};
use tokio::{
sync::{
mpsc::{channel, Receiver, Sender},
Mutex,
},
task::JoinHandle,
};
use tracing::{debug, error, trace};
/// How long a request waits for the peer agent's response before failing with `TimedOut`.
static RESPONSE_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Message length limit of 512 bytes.
/// NOTE(review): not referenced in this file; confirm where it is used.
pub(crate) static MAX_MSG_LEN: usize = 512;
/// Message agent layered on a `QueuePair`.
///
/// It sends requests (MR allocation/release/hand-over, data) to the peer's agent and
/// exposes receivers for whatever the peer pushes to us. The receivers are wrapped in
/// a tokio `Mutex` because `Receiver::recv` needs `&mut` while the public methods take `&self`.
#[derive(Debug)]
pub(crate) struct Agent {
    /// State shared with the background receive tasks.
    inner: Arc<AgentInner>,
    /// Local MRs that the peer handed back to us (`SendMR(Remote)` requests).
    local_mr_recv: Mutex<Receiver<LocalMr>>,
    /// Remote MRs that the peer lent to us (`SendMR(Local)` requests).
    remote_mr_recv: Mutex<Receiver<RemoteMr>>,
    /// Received data as `(buffer, valid length, immediate value)`.
    data_recv: Mutex<Receiver<(LocalMr, usize, Option<u32>)>>,
    /// Immediate values carried by writes with immediate data.
    imm_recv: Mutex<Receiver<u32>>,
    /// Join handles of the background receive tasks; aborted in `Drop`.
    handles: Handles,
    /// Held but never read (hence `dead_code`).
    /// NOTE(review): presumably kept so the senders it owns live as long as the agent.
    #[allow(dead_code)]
    agent_thread: Arc<AgentThread>,
}
impl Drop for Agent {
    /// Aborts every background receive task so none of them outlives the agent.
    fn drop(&mut self) {
        self.handles.iter().for_each(|task| task.abort());
    }
}
impl Agent {
    /// Builds an agent over `qp` and spawns its background receive tasks.
    ///
    /// # Errors
    /// Fails when `AgentThread::run` rejects `max_sr_data_len` (zero).
    pub(crate) fn new(
        qp: Arc<QueuePair>,
        allocator: Arc<MrAllocator>,
        max_sr_data_len: usize,
        max_rmr_access: ibv_access_flags,
    ) -> io::Result<Self> {
        // Must be built before `qp` is moved into the shared state.
        let rmr_manager = RemoteMrManager::new(Arc::clone(&qp.pd), max_rmr_access);
        let inner = Arc::new(AgentInner {
            qp,
            response_waits: Arc::new(parking_lot::Mutex::new(HashMap::new())),
            rmr_manager,
            allocator,
            max_sr_data_len,
        });

        let (local_mr_tx, local_mr_rx) = channel(1024);
        let (remote_mr_tx, remote_mr_rx) = channel(1024);
        let (data_tx, data_rx) = channel(1024);
        let (imm_tx, imm_rx) = channel(1024);

        let (agent_thread, handles) = AgentThread::run(
            Arc::clone(&inner),
            local_mr_tx,
            remote_mr_tx,
            data_tx,
            imm_tx,
            max_sr_data_len,
        )?;

        Ok(Self {
            inner,
            local_mr_recv: Mutex::new(local_mr_rx),
            remote_mr_recv: Mutex::new(remote_mr_rx),
            data_recv: Mutex::new(data_rx),
            imm_recv: Mutex::new(imm_rx),
            handles,
            agent_thread,
        })
    }

    /// Asks the peer to allocate a memory region for `layout` whose token is valid for `timeout`.
    ///
    /// # Errors
    /// Propagates request failures, timeouts and unexpected response kinds.
    pub(crate) async fn request_remote_mr_with_timeout(
        &self,
        layout: Layout,
        timeout: Duration,
    ) -> io::Result<RemoteMr> {
        let inner = &self.inner;
        inner.request_remote_mr_with_timeout(layout, timeout).await
    }

    /// Lends the local region `mr` to the peer for `timeout`.
    ///
    /// The region is recorded in the remote-MR manager and its token is announced with a
    /// `SendMR(Local)` request.
    ///
    /// # Errors
    /// * `InvalidInput` if no token can be built for `timeout`.
    /// * Recording or request failures.
    pub(crate) async fn send_local_mr_with_timeout(
        &self,
        mr: LocalMr,
        timeout: Duration,
    ) -> io::Result<()> {
        let token = unsafe { mr.token_with_timeout_unchecked(timeout) }.ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "wrong timeout value, duration is too long",
            )
        })?;
        self.inner.rmr_manager.record_mr(token, mr, timeout).await?;
        let request = RequestKind::SendMR(SendMRRequest {
            kind: SendMRKind::Local(token),
        });
        self.inner.send_request(request).await.map(|_| ())
    }

    /// Hands the remote region `mr` back to the peer (`SendMR(Remote)`).
    ///
    /// Once the peer has answered, `mr` is forgotten so its destructor does not run.
    /// If sending the request itself fails, `mr` is dropped normally.
    ///
    /// # Errors
    /// Request failures, a `Timeout` / `RemoteAgentErr` answer, or a non-SendMR response.
    pub(crate) async fn send_remote_mr(&self, mr: RemoteMr) -> io::Result<()> {
        let token = mr.token();
        let response = self
            .inner
            .send_request(RequestKind::SendMR(SendMRRequest {
                kind: SendMRKind::Remote(token),
            }))
            .await?;
        // The peer has taken the region back; skip `RemoteMr`'s destructor.
        #[allow(clippy::mem_forget)]
        mem::forget(mr);
        let send_mr = match response {
            ResponseKind::SendMR(send_mr) => send_mr,
            ResponseKind::AllocMR(_) | ResponseKind::ReleaseMR(_) | ResponseKind::SendData(_) => {
                return Err(io::Error::new(
                    io::ErrorKind::Other,
                    "received wrong response, expect SendMR response",
                ));
            }
        };
        match send_mr.kind {
            SendMRResponseKind::Success => Ok(()),
            SendMRResponseKind::Timeout => {
                Err(io::Error::new(io::ErrorKind::Other, "this rmr is timeout"))
            }
            SendMRResponseKind::RemoteAgentErr => Err(io::Error::new(
                io::ErrorKind::Other,
                "remote agent is in an error state",
            )),
        }
    }

    /// Waits for the next local MR that the peer handed back.
    ///
    /// # Errors
    /// `Other` once the channel is closed.
    pub(crate) async fn receive_local_mr(&self) -> io::Result<LocalMr> {
        let mut rx = self.local_mr_recv.lock().await;
        match rx.recv().await {
            Some(mr) => Ok(mr),
            None => Err(io::Error::new(io::ErrorKind::Other, "mr channel closed")),
        }
    }

    /// Waits for the next remote MR that the peer lent to us.
    ///
    /// # Errors
    /// `Other` once the channel is closed.
    pub(crate) async fn receive_remote_mr(&self) -> io::Result<RemoteMr> {
        let mut rx = self.remote_mr_recv.lock().await;
        match rx.recv().await {
            Some(mr) => Ok(mr),
            None => Err(io::Error::new(io::ErrorKind::Other, "mr channel closed")),
        }
    }

    /// Sends the contents of `lm` to the peer in chunks of at most `max_msg_len()` bytes.
    ///
    /// Every chunk carries `imm` and waits for the peer's `SendData` acknowledgement before
    /// the next one is sent. An empty `lm` sends nothing.
    ///
    /// # Errors
    /// Request failures, a non-zero status, or an unexpected response kind.
    pub(crate) async fn send_data(&self, lm: &LocalMr, imm: Option<u32>) -> io::Result<()> {
        let total = lm.length();
        let chunk = self.max_msg_len();
        let mut offset = 0;
        while offset < total {
            let stop = offset.saturating_add(chunk).min(total);
            let request = RequestKind::SendData(SendDataRequest {
                len: stop.wrapping_sub(offset),
            });
            // SAFETY: offset < stop <= total == lm.length(), so the range lies inside `lm`.
            let piece = unsafe { lm.get_unchecked(offset..stop) };
            let response = self
                .inner
                .send_request_append_data(request, &[&piece], imm)
                .await?;
            match response {
                ResponseKind::SendData(reply) => {
                    if reply.status > 0 {
                        return Err(io::Error::new(
                            io::ErrorKind::Other,
                            format!("send data failed, response status is {}", reply.status),
                        ));
                    }
                }
                unexpected => {
                    return Err(io::Error::new(
                        io::ErrorKind::Other,
                        format!(
                            "send data failed, due to unexpected response type {:?}",
                            unexpected
                        ),
                    ));
                }
            }
            offset = stop;
        }
        Ok(())
    }

    /// Waits for the next received data chunk.
    ///
    /// Returns it trimmed to its valid length, together with its immediate value.
    ///
    /// # Errors
    /// `Other` once the channel is closed, or if the reported length exceeds the buffer.
    pub(crate) async fn receive_data(&self) -> io::Result<(LocalMr, Option<u32>)> {
        let received = self.data_recv.lock().await.recv().await;
        let (buf, len, imm) = received
            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "data channel closed"))?;
        let buf = buf.take(0..len).ok_or_else(|| {
            io::Error::new(io::ErrorKind::Other, "this is a bug, received wrong len")
        })?;
        Ok((buf, imm))
    }

    /// Waits for the next immediate value received through a write with immediate data.
    ///
    /// # Errors
    /// `Other` once the channel is closed.
    pub(crate) async fn receive_imm(&self) -> io::Result<u32> {
        let mut rx = self.imm_recv.lock().await;
        match rx.recv().await {
            Some(imm) => Ok(imm),
            None => Err(io::Error::new(
                io::ErrorKind::Other,
                "imm data channel closed",
            )),
        }
    }

    /// Maximum number of payload bytes carried by one data message.
    pub(crate) fn max_msg_len(&self) -> usize {
        self.inner.max_sr_data_len
    }

    /// Access flags applied to regions that the peer asks us to allocate.
    pub(crate) fn max_rmr_access(&self) -> ibv_access_flags {
        self.inner.rmr_manager.max_rmr_access
    }
}
/// State shared by the background receive tasks spawned in `AgentThread::run`.
///
/// Each task holds an `Arc` clone of this value. The senders below feed the receivers
/// held by `Agent`.
#[derive(Debug)]
struct AgentThread {
    /// Shared agent state (queue pair, allocator, pending-response table).
    inner: Arc<AgentInner>,
    /// Delivers local MRs reclaimed from `rmr_manager` when the peer returns them (`SendMR(Remote)`).
    local_mr_send: Sender<LocalMr>,
    /// Delivers remote MRs the peer lends to us (`SendMR(Local)`).
    remote_mr_send: Sender<RemoteMr>,
    /// Delivers received data as `(buffer, valid length, immediate value)`.
    data_send: Sender<(LocalMr, usize, Option<u32>)>,
    /// Delivers immediate values from writes with immediate data.
    imm_send: Sender<u32>,
    /// Length of each data receive buffer.
    max_sr_data_len: usize,
}
/// Join handles of the background receive tasks; `Agent` aborts them when dropped.
type Handles = Vec<JoinHandle<io::Result<()>>>;
impl AgentThread {
fn run(
inner: Arc<AgentInner>,
local_mr_send: Sender<LocalMr>,
remote_mr_send: Sender<RemoteMr>,
data_send: Sender<(LocalMr, usize, Option<u32>)>,
imm_send: Sender<u32>,
max_sr_data_len: usize,
) -> io::Result<(Arc<Self>, Handles)> {
let agent = Arc::new(Self {
inner,
local_mr_send,
remote_mr_send,
data_send,
imm_send,
max_sr_data_len,
});
if max_sr_data_len == 0 {
Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!(
"max message length is {:?}, it should be greater than 0",
max_sr_data_len
),
))
} else {
let mut handles = Vec::new();
for _ in 0..MAX_RECV_WR {
handles.push(tokio::spawn(Arc::<AgentThread>::clone(&agent).main()));
}
Ok((Arc::<AgentThread>::clone(&agent), handles))
}
}
async fn main(self: Arc<Self>) -> io::Result<()> {
let mut header_buf = self
.inner
.allocator
.alloc_zeroed_default(unsafe {
&Layout::from_size_align_unchecked(*REQUEST_HEADER_MAX_LEN, 1)
})?;
let mut data_buf = self
.inner
.allocator
.alloc_zeroed_default(unsafe {
&Layout::from_size_align_unchecked(self.max_sr_data_len, 1)
})?;
loop {
unsafe {
std::ptr::write_bytes(header_buf.as_mut_ptr_unchecked(), 0_u8, header_buf.length());
}
let (sz, imm) = self
.inner
.qp
.receive_sge(&[&mut header_buf, &mut data_buf])
.await?;
if imm.is_some() && unsafe { header_buf.as_slice_unchecked() } == CLEAN_STATE {
debug!("write with immediate data : {:?}", imm);
#[allow(clippy::unwrap_used)]
let _task = tokio::spawn(Arc::<Self>::clone(&self).handle_write_imm(imm.unwrap()));
continue;
}
let message =
bincode::deserialize(unsafe {header_buf.as_slice_unchecked()}.get(..).ok_or_else(|| {
io::Error::new(
io::ErrorKind::Other,
format!(
"{:?} is out of range, the length is {:?}",
0..sz,
header_buf.length()
),
)
})?)
.map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!("failed to deserialize {:?}", e),
)
})?;
debug!("message: {:?}", message);
match message {
Message::Request(request) => match request.kind {
RequestKind::SendData(_) => {
let _task = tokio::spawn(
Arc::<Self>::clone(&self).handle_send_req(request, data_buf, imm),
);
data_buf = self.inner.allocator.alloc_zeroed_default(unsafe {
&Layout::from_size_align_unchecked(self.max_sr_data_len, 1)
})?;
}
RequestKind::AllocMR(_)
| RequestKind::ReleaseMR(_)
| RequestKind::SendMR(_) => {
let _task = tokio::spawn(Arc::<Self>::clone(&self).handle_request(request));
}
},
Message::Response(response) => {
let _task = tokio::spawn(Arc::<Self>::clone(&self).handle_response(response));
}
};
}
}
async fn handle_write_imm(self: Arc<Self>, imm: u32) -> io::Result<()> {
self.imm_send.send(imm).await.map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!("Data receiver has stopped, {:?}", e),
)
})
}
async fn handle_request(self: Arc<Self>, request: Request) -> io::Result<()> {
debug!("handle request");
let response = match request.kind {
RequestKind::AllocMR(param) => {
let mr = self.inner.allocator.alloc_zeroed(
&Layout::from_size_align(param.size, param.align)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?,
self.inner.rmr_manager.max_rmr_access,
&self.inner.rmr_manager.pd,
)?;
let token = unsafe { mr.token_with_timeout_unchecked(param.timeout) }.map_or_else(
|| {
Err(io::Error::new(
io::ErrorKind::InvalidInput,
"wrong timeout value, duration is too long",
))
},
Ok,
)?;
#[allow(clippy::unreachable)]
self.inner
.rmr_manager
.record_mr(token, mr, param.timeout)
.await
.map_or_else(
|err| unreachable!("{:?}", err),
|_| debug!("record mr requested by remote end {:?}", token),
);
let response = AllocMRResponse { token };
ResponseKind::AllocMR(response)
}
RequestKind::ReleaseMR(param) => {
self.inner.rmr_manager.release_mr(¶m.token).map_or_else(
|err| {
debug!(
"{:?} already released by `rmr_manager`. {:?}",
param.token, err
);
},
|lmr| debug!("release {:?}", lmr),
);
ResponseKind::ReleaseMR(ReleaseMRResponse { status: 0 })
}
RequestKind::SendMR(param) => {
let kind = match param.kind {
SendMRKind::Local(token) => self
.remote_mr_send
.send(RemoteMr::new_from_token(
token,
Arc::<AgentInner>::clone(&self.inner),
))
.await
.map_or_else(
|err| {
error!("Agent remote_mr channel error {:?}", err);
SendMRResponseKind::RemoteAgentErr
},
|_| SendMRResponseKind::Success,
),
SendMRKind::Remote(token) => match self.inner.rmr_manager.release_mr(&token) {
Ok(mr) => self.local_mr_send.send(mr).await.map_or_else(
|err| {
error!("Agent local_mr channel error {:?}", err);
SendMRResponseKind::RemoteAgentErr
},
|_| SendMRResponseKind::Success,
),
Err(err) => {
debug!("{:?}", err);
SendMRResponseKind::Timeout
}
},
};
ResponseKind::SendMR(SendMRResponse { kind })
}
RequestKind::SendData(_) => {
return Err(io::Error::new(
io::ErrorKind::Other,
"Should not reach here, SendData is handled separately",
));
}
};
let response = Response {
request_id: request.request_id,
kind: response,
};
self.inner.send_response(response).await?;
trace!("handle request done");
Ok(())
}
async fn handle_response(self: Arc<Self>, response: Response) -> io::Result<()> {
trace!("handle response");
let sender = self
.inner
.response_waits
.lock()
.remove(&response.request_id)
.ok_or_else(|| {
io::Error::new(
io::ErrorKind::Other,
format!(
"request id {:?} is missing in waiting list",
&response.request_id
),
)
})?;
match sender.try_send(Ok(response.kind)) {
Ok(_) => Ok(()),
Err(_) => Err(io::Error::new(
io::ErrorKind::Other,
"The waiting task has dropped",
)),
}
}
async fn handle_send_req(
self: Arc<Self>,
request: Request,
buf: LocalMr,
imm: Option<u32>,
) -> io::Result<()> {
if let RequestKind::SendData(param) = request.kind {
self.data_send
.send((buf, param.len, imm))
.await
.map_err(|e| {
io::Error::new(
io::ErrorKind::Other,
format!("Data receiver has stopped, {:?}", e),
)
})?;
let response = Response {
request_id: request.request_id,
kind: ResponseKind::SendData(SendDataResponse { status: 0 }),
};
self.inner.send_response(response).await?;
} else {
return Err(io::Error::new(
io::ErrorKind::Other,
"This function only handles send request",
));
}
Ok(())
}
}
/// State shared by `Agent`, its receive tasks and the `RemoteMr` handles created through
/// `RemoteMr::new_from_token` (which receive an `Arc` clone).
#[derive(Debug)]
pub(crate) struct AgentInner {
    /// Queue pair that every request and response travels over.
    qp: Arc<QueuePair>,
    /// Pending requests by id; `AgentThread::handle_response` removes the entry and forwards the reply.
    response_waits: ResponseWaitsMap,
    /// Tracks local MRs lent to the peer, plus the access flags and PD used to allocate them.
    rmr_manager: RemoteMrManager,
    /// Allocator for header buffers, data buffers and peer-requested MRs.
    allocator: Arc<MrAllocator>,
    /// Maximum payload bytes appended to one request (header excluded).
    max_sr_data_len: usize,
}
impl AgentInner {
pub(crate) async fn request_remote_mr_with_timeout(
self: &Arc<Self>,
layout: Layout,
timeout: Duration,
) -> io::Result<RemoteMr> {
let request = AllocMRRequest {
size: layout.size(),
align: layout.align(),
timeout,
};
let kind = RequestKind::AllocMR(request);
let response = self.send_request(kind).await?;
if let ResponseKind::AllocMR(alloc_mr_response) = response {
Ok(RemoteMr::new_from_token(
alloc_mr_response.token,
Arc::<Self>::clone(self),
))
} else {
Err(io::Error::new(
io::ErrorKind::InvalidData,
"Should not be here, we're expecting AllocMR response",
))
}
}
pub(crate) async fn release_mr(&self, token: MrToken) -> io::Result<()> {
let kind = RequestKind::ReleaseMR(ReleaseMRRequest { token });
let _response = self.send_request(kind).await?;
Ok(())
}
async fn send_request(&self, kind: RequestKind) -> io::Result<ResponseKind> {
self.send_request_append_data(kind, &[], None).await
}
async fn send_request_append_data(
&self,
kind: RequestKind,
data: &[&LocalMrSlice<'_>],
imm: Option<u32>,
) -> io::Result<ResponseKind> {
let data_len: usize = data.iter().map(|l| l.length()).sum();
assert!(data_len <= self.max_sr_data_len);
let (tx, mut rx) = channel(2);
let req_id = self
.response_waits
.lock()
.insert_until_success(tx, AgentRequestId::new);
let req = Request {
request_id: req_id,
kind,
};
let mut header_buf = self
.allocator
.alloc_zeroed_default(unsafe {
&Layout::from_size_align_unchecked(*REQUEST_HEADER_MAX_LEN, 1)
})?;
let cursor = Cursor::new(unsafe { header_buf.as_mut_slice_unchecked() });
let message = Message::Request(req);
bincode::serialize_into(cursor, &message)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
let header_buf = &unsafe { header_buf.get_unchecked(0..header_buf.length()) };
let mut lms: Vec<&LocalMrSlice> = vec![header_buf];
lms.extend(data);
self.qp.send_sge(&lms, imm).await?;
match tokio::time::timeout(RESPONSE_TIMEOUT, rx.recv()).await {
Ok(resp) => {
resp.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "agent is dropped"))?
}
Err(_) => Err(io::Error::new(
io::ErrorKind::TimedOut,
"Timeout for waiting for a response.",
)),
}
}
async fn send_response(&self, response: Response) -> io::Result<()> {
let mut header = self
.allocator
.alloc_zeroed_default(unsafe {
&Layout::from_size_align_unchecked(*RESPONSE_HEADER_MAX_LEN, 1)
})
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
let cursor = Cursor::new(unsafe { header.as_mut_slice_unchecked() });
let message = Message::Response(response);
let msz = bincode::serialized_size(&message)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
.cast();
bincode::serialize_into(cursor, &message)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
let buf = header.get_mut(0..msz).ok_or_else(|| {
io::Error::new(
io::ErrorKind::Other,
"this is a bug, get a wrong serialized size",
)
})?;
self.qp.send(&buf).await?;
Ok(())
}
}
lazy_static! {
    /// Serialized size of a `SendData` request message.
    ///
    /// NOTE(review): not referenced elsewhere in this file; confirm it is still needed.
    static ref SEND_DATA_OFFSET: usize = {
        let probe = Message::Request(Request {
            request_id: AgentRequestId::new(),
            kind: RequestKind::SendData(SendDataRequest { len: 0 }),
        });
        #[allow(clippy::unwrap_used)]
        let encoded_len = bincode::serialized_size(&probe).unwrap();
        encoded_len.cast()
    };
}
/// All-zero header pattern. A received header equal to this (plus an immediate value)
/// marks a write with immediate data. Its length must equal `REQUEST_HEADER_MAX_LEN`,
/// which is asserted where that value is computed.
const CLEAN_STATE: [u8; 56] = [0; 56];
lazy_static! {
    /// Largest serialized size of any request `Message`. This is the fixed header length
    /// used on both the sending and receiving side.
    ///
    /// # Panics
    /// On first access, if the value differs from `CLEAN_STATE.len()`.
    static ref REQUEST_HEADER_MAX_LEN: usize = {
        // Only the encoded size matters, so one placeholder token serves every variant.
        let placeholder = MrToken {
            addr: 0,
            len: 0,
            rkey: 0,
            ddl: SystemTime::now(),
            access: 0,
        };
        let candidates = [
            RequestKind::AllocMR(AllocMRRequest {
                size: 0,
                align: 0,
                timeout: Duration::from_secs(1),
            }),
            RequestKind::ReleaseMR(ReleaseMRRequest { token: placeholder }),
            RequestKind::SendMR(SendMRRequest {
                kind: SendMRKind::Local(placeholder),
            }),
            RequestKind::SendMR(SendMRRequest {
                kind: SendMRKind::Remote(placeholder),
            }),
            RequestKind::SendData(SendDataRequest { len: 0 }),
        ];
        let mut max = 0_usize;
        for kind in candidates.iter().copied() {
            let message = Message::Request(Request {
                request_id: AgentRequestId::new(),
                kind,
            });
            #[allow(clippy::unwrap_used)]
            let size: usize = bincode::serialized_size(&message).unwrap().cast();
            max = max.max(size);
        }
        assert_eq!(max, CLEAN_STATE.len(), "make sure the length of CLEAN_STATE equals to max");
        max
    };
    /// Largest serialized size of any response `Message`; used to size response buffers.
    static ref RESPONSE_HEADER_MAX_LEN: usize = {
        let placeholder = MrToken {
            addr: 0,
            len: 0,
            rkey: 0,
            ddl: SystemTime::now(),
            access: 0,
        };
        let candidates = vec![
            ResponseKind::AllocMR(AllocMRResponse { token: placeholder }),
            ResponseKind::ReleaseMR(ReleaseMRResponse { status: 0 }),
            ResponseKind::SendMR(SendMRResponse { kind: SendMRResponseKind::Success }),
            ResponseKind::SendData(SendDataResponse { status: 0 }),
        ];
        let mut max = 0_usize;
        for kind in candidates {
            let message = Message::Response(Response {
                request_id: AgentRequestId::new(),
                kind,
            });
            #[allow(clippy::unwrap_used)]
            let size: usize = bincode::serialized_size(&message).unwrap().cast();
            max = max.max(size);
        }
        max
    };
}
/// Pending requests: request id -> sender that completes the waiting
/// `AgentInner::send_request_append_data` call with the peer's response.
type ResponseWaitsMap =
    Arc<parking_lot::Mutex<HashMap<AgentRequestId, Sender<io::Result<ResponseKind>>>>>;
/// Identifier that pairs a request with its response.
#[derive(Serialize, Deserialize, Clone, Copy, PartialEq, Eq, Hash, Debug)]
struct AgentRequestId(u64);

impl AgentRequestId {
    /// Draws a random id.
    /// NOTE(review): collisions among pending requests are presumably retried by
    /// `insert_until_success`.
    fn new() -> Self {
        let raw = id::random_u64();
        AgentRequestId(raw)
    }
}
// Wire types below are bincode-encoded in declaration order: reordering fields or
// variants changes the protocol.

/// Asks the peer to allocate a memory region.
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
struct AllocMRRequest {
    /// Size in bytes of the requested layout.
    size: usize,
    /// Alignment of the requested layout.
    align: usize,
    /// How long the returned token stays valid.
    timeout: Duration,
}
/// Token of the region the peer allocated for an `AllocMRRequest`.
#[derive(Debug, Serialize, Deserialize)]
struct AllocMRResponse {
    token: MrToken,
}
/// Asks the peer to release the region it lent us, identified by `token`.
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
struct ReleaseMRRequest {
    token: MrToken,
}
/// Reply to a `ReleaseMRRequest`; this file always sends `status: 0`.
#[derive(Debug, Serialize, Deserialize)]
struct ReleaseMRResponse {
    status: usize,
}
/// Direction of an MR hand-over.
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
enum SendMRKind {
    /// The sender lends one of its local regions; the receiver surfaces it as a `RemoteMr`.
    Local(MrToken),
    /// The sender returns a region it had borrowed; the owner reclaims it as a `LocalMr`.
    Remote(MrToken),
}
/// Hands a memory region over to the peer.
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
struct SendMRRequest {
    kind: SendMRKind,
}
/// Reply to a `SendMRRequest`.
#[derive(Debug, Serialize, Deserialize)]
struct SendMRResponse {
    kind: SendMRResponseKind,
}
/// Outcome of a hand-over.
#[derive(Debug, Serialize, Deserialize)]
enum SendMRResponseKind {
    /// The region was delivered to the receiving agent.
    Success,
    /// A returned region could not be released by the owner's `rmr_manager`.
    /// NOTE(review): presumably because its token already expired.
    Timeout,
    /// The receiving agent's delivery channel was closed.
    RemoteAgentErr,
}
/// Announces `len` bytes of payload sent right after the request header.
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
struct SendDataRequest {
    len: usize,
}
/// Acknowledges a `SendDataRequest`; the sender treats any non-zero status as failure.
#[derive(Debug, Serialize, Deserialize)]
struct SendDataResponse {
    status: usize,
}
/// Every request the agent can send.
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
enum RequestKind {
    AllocMR(AllocMRRequest),
    ReleaseMR(ReleaseMRRequest),
    SendMR(SendMRRequest),
    SendData(SendDataRequest),
}
/// A request tagged with the id its response will carry back.
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
struct Request {
    request_id: AgentRequestId,
    kind: RequestKind,
}
/// Every response the agent can send; one variant per request kind.
#[derive(Serialize, Deserialize, Debug)]
#[allow(variant_size_differences)]
enum ResponseKind {
    AllocMR(AllocMRResponse),
    ReleaseMR(ReleaseMRResponse),
    SendMR(SendMRResponse),
    SendData(SendDataResponse),
}
/// A response tagged with the id of the request it answers.
#[derive(Serialize, Deserialize, Debug)]
struct Response {
    request_id: AgentRequestId,
    kind: ResponseKind,
}
/// Top-level envelope that is serialized into every message header.
#[derive(Serialize, Deserialize, Debug)]
enum Message {
    Request(Request),
    Response(Response),
}