use async_rdma::{LocalMr, LocalMrReadAccess, LocalMrWriteAccess, Rdma, RdmaListener};
use std::{alloc::Layout, sync::Arc, time::Duration};
use tokio::net::ToSocketAddrs;
/// Message travelling from the client to the server.
///
/// Values of this type are stored by value inside a memory region and read
/// back through a pointer cast on the other side, so the enum is used as a
/// raw in-memory record rather than a serialized format.
#[derive(Debug, Clone)]
enum Request {
    /// Ask the server to return `msg` unchanged.
    Echo {
        /// Text to be echoed back.
        msg: String,
    },
    /// Payload-free synchronisation marker; `process_request` maps it to
    /// `Response::Sync`.
    Sync,
}
/// Message travelling from the server to the client.
///
/// Like the request type, it is placed by value into a memory region and
/// reinterpreted through a pointer cast by the reader.
#[derive(Debug, Clone)]
enum Response {
    /// Reply to an echo request, carrying the echoed text.
    Echo {
        /// The echoed text.
        msg: String,
    },
    /// Payload-free synchronisation marker; the client's `sync_with_server`
    /// accepts exactly this variant and panics on anything else.
    Sync,
}
/// Server half of the RPC demo.
///
/// The type carries no state; every method is an associated function that
/// works on an `Rdma` handle passed in explicitly.
struct Server {}

impl Server {
    /// Binds an RDMA listener on `addr`, accepts a single connection and
    /// serves it with two concurrently running tasks.
    ///
    /// Because of `#[tokio::main]` this is a synchronous function from the
    /// caller's point of view: it drives its own runtime and only returns
    /// when both serving tasks have finished (they loop until they panic).
    ///
    /// # Panics
    ///
    /// Panics if binding, accepting, or either serving task fails.
    #[tokio::main]
    async fn start<A: ToSocketAddrs>(addr: A) {
        let rdmalistener = RdmaListener::bind(addr)
            .await
            .map_err(|err| println!("{}", err))
            .unwrap();
        // NOTE(review): the numeric arguments mirror the ones the client
        // passes to `Rdma::connect`; their meaning is defined by async_rdma.
        let rdma = Arc::new(rdmalistener.accept(1, 1, 128).await.unwrap());
        let sr_handler = tokio::spawn(Self::sr_task(Arc::clone(&rdma)));
        let wr_handler = tokio::spawn(Self::wr_task(rdma));
        sr_handler.await.unwrap();
        wr_handler.await.unwrap();
    }

    /// Sends one synchronisation marker to the client.
    ///
    /// The client decodes this message as a `Response` and insists on
    /// `Response::Sync` (see `Client::sync_with_server`), so the marker is
    /// encoded as a `Response` here instead of relying on `Request` and
    /// `Response` happening to share a memory layout.
    ///
    /// # Panics
    ///
    /// Panics if the allocation or the send fails (the error is printed
    /// first).
    async fn sync_with_client(rdma: &Rdma) {
        let mut lmr_sync = rdma
            .alloc_local_mr(Layout::new::<Response>())
            .map_err(|err| println!("{}", &err))
            .unwrap();
        // SAFETY: the region was requested with `Layout::new::<Response>()`,
        // so (assuming `alloc_local_mr` honours that layout) the pointer is
        // valid and aligned for one `Response`. `write` is used instead of
        // `*ptr = ..` because the buffer does not hold a live `Response`;
        // plain assignment would run drop glue on whatever bytes are there.
        unsafe { (lmr_sync.as_mut_ptr() as *mut Response).write(Response::Sync) };
        rdma.send(&lmr_sync)
            .await
            .map_err(|err| println!("{}", &err))
            .unwrap();
    }

    /// Serves requests that arrive as a local memory region handed back by
    /// the peer, answering each with a memory region the peer can read.
    ///
    /// Per iteration: receive the request region, send a sync marker, build
    /// the response, store it in a fresh region, send a second sync marker,
    /// then hand the response region to the peer. The order of these steps
    /// matches the client's `echo_req_wr` and must not be rearranged.
    ///
    /// # Panics
    ///
    /// Panics on any RDMA error (the error is printed first); otherwise
    /// never returns.
    async fn wr_task(rdma: Arc<Rdma>) {
        loop {
            let lmr_req = rdma
                .receive_local_mr()
                .await
                .map_err(|err| println!("{}", &err))
                .unwrap();
            Self::sync_with_client(&rdma).await;
            let resp = {
                // SAFETY: relies on the peer having stored a valid `Request`
                // in this region before handing it back; nothing here can
                // verify that.
                let req = unsafe { &*(lmr_req.as_ptr() as *const Request) };
                Self::process_request(req)
            };
            let mut lmr_resp = rdma
                .alloc_local_mr(Layout::new::<Response>())
                .map_err(|err| println!("{}", &err))
                .unwrap();
            // SAFETY: region allocated for exactly one `Response`; `write`
            // avoids dropping the non-`Response` bytes already in the buffer.
            unsafe { (lmr_resp.as_mut_ptr() as *mut Response).write(resp) };
            Self::sync_with_client(&rdma).await;
            rdma.send_local_mr(lmr_resp)
                .await
                .map_err(|err| println!("{}", &err))
                .unwrap();
        }
    }

    /// Serves requests that arrive through plain send/receive, answering each
    /// with a sent `Response`.
    ///
    /// # Panics
    ///
    /// Panics on any RDMA error (the error is printed first); otherwise
    /// never returns.
    async fn sr_task(rdma: Arc<Rdma>) {
        loop {
            let resp = {
                let lmr_req = rdma
                    .receive()
                    .await
                    .map_err(|err| println!("{}", &err))
                    .unwrap();
                // SAFETY: relies on the peer having sent the bytes of a valid
                // `Request`; nothing here can verify that.
                let req = unsafe { &*(lmr_req.as_ptr() as *const Request) };
                Self::process_request(req)
                // The received region is released here, before the response
                // buffer is allocated.
            };
            let mut lmr_resp = rdma
                .alloc_local_mr(Layout::new::<Response>())
                .map_err(|err| println!("{}", &err))
                .unwrap();
            // SAFETY: region allocated for exactly one `Response`; `write`
            // avoids dropping the non-`Response` bytes already in the buffer.
            unsafe { (lmr_resp.as_mut_ptr() as *mut Response).write(resp) };
            rdma.send(&lmr_resp)
                .await
                .map_err(|err| println!("{}", &err))
                .unwrap();
        }
    }

    /// Builds the reply to an echo request: the same text, wrapped in
    /// `Response::Echo`.
    fn echo(msg: String) -> Response {
        Response::Echo { msg }
    }

    /// Maps a request to its response. The request is only borrowed; an echo
    /// message is cloned into the response.
    fn process_request(req: &Request) -> Response {
        match req {
            Request::Echo { msg } => Self::echo(msg.clone()),
            Request::Sync => Response::Sync,
        }
    }
}
/// Reads the start of `lmr` as a `Response` and returns a copy of the echoed
/// text.
///
/// # Panics
///
/// Panics with `invalid input : ..` if the decoded response is not
/// `Response::Echo`.
fn transmute_lmr_to_string(lmr: &LocalMr) -> String {
    // SAFETY: relies on the caller passing a region that holds a valid
    // `Response`; this function cannot check that.
    let decoded = unsafe { &*(lmr.as_ptr() as *const Response) };
    if let Response::Echo { msg } = decoded {
        msg.clone()
    } else {
        panic!("invalid input : {:?}", decoded)
    }
}
/// Client half of the RPC demo.
struct Client {
    /// Connected RDMA endpoint used for every request.
    rdma_stub: Rdma,
}

impl Client {
    /// Connects to the server listening on `addr`.
    ///
    /// # Panics
    ///
    /// Panics if the connection cannot be established (the error is printed
    /// first).
    async fn new<A: ToSocketAddrs>(addr: A) -> Self {
        let rdma_stub = Rdma::connect(addr, 1, 1, 128)
            .await
            .map_err(|err| println!("{}", &err))
            .unwrap();
        Client { rdma_stub }
    }

    /// Echo round trip using send/receive: sends a `Request::Echo` holding
    /// `msg` and returns the text of the `Response::Echo` that comes back.
    ///
    /// NOTE(review): the `String` is moved into the memory region by value,
    /// so what travels is its heap pointer, which is only meaningful to a
    /// peer in the same address space. Nothing visible here runs its
    /// destructor afterwards, so the buffer appears to be leaked.
    ///
    /// # Panics
    ///
    /// Panics on any RDMA error or if the reply is not `Response::Echo`.
    async fn echo_req_sr(&self, msg: String) -> String {
        let mut lmr_req = self
            .rdma_stub
            .alloc_local_mr(Layout::new::<Request>())
            .map_err(|err| println!("{}", &err))
            .unwrap();
        // SAFETY: the region was requested with `Layout::new::<Request>()`,
        // so (assuming `alloc_local_mr` honours that layout) the pointer is
        // valid and aligned for one `Request`. `write` is used instead of
        // `*ptr = ..` because the buffer does not hold a live `Request`;
        // plain assignment would run drop glue on whatever bytes are there.
        unsafe { (lmr_req.as_mut_ptr() as *mut Request).write(Request::Echo { msg }) };
        self.rdma_stub
            .send(&lmr_req)
            .await
            .map_err(|err| println!("{}", &err))
            .unwrap();
        self.rdma_stub
            .receive()
            .await
            .map(|lmr_resp| transmute_lmr_to_string(&lmr_resp))
            .map_err(|err| println!("{}", &err))
            .unwrap()
    }

    /// Echo round trip using one-sided operations: writes the request into a
    /// memory region obtained from the server, then reads the response out of
    /// a memory region the server hands back.
    ///
    /// The two `sync_with_server` calls pair with the two sync markers the
    /// server sends in `wr_task`; the statement order here is part of the
    /// protocol and must not be rearranged.
    ///
    /// NOTE(review): as in `echo_req_sr`, the `String` moved into the region
    /// is shared by pointer and never dropped here.
    ///
    /// # Panics
    ///
    /// Panics on any RDMA error, on an unexpected sync message, or if the
    /// reply is not `Response::Echo`.
    async fn echo_req_wr(&self, msg: String) -> String {
        let mut lmr_req = self
            .rdma_stub
            .alloc_local_mr(Layout::new::<Request>())
            .map_err(|err| println!("{}", &err))
            .unwrap();
        // SAFETY: region allocated for exactly one `Request`; `write` avoids
        // dropping the non-`Request` bytes already in the buffer.
        unsafe { (lmr_req.as_mut_ptr() as *mut Request).write(Request::Echo { msg }) };
        let mut rmr_req = self
            .rdma_stub
            .request_remote_mr(Layout::new::<Request>())
            .await
            .map_err(|err| println!("{}", &err))
            .unwrap();
        self.rdma_stub
            .write(&lmr_req, &mut rmr_req)
            .await
            .map_err(|err| println!("{}", &err))
            .unwrap();
        // Hand the filled region back so the server can process it.
        self.rdma_stub
            .send_remote_mr(rmr_req)
            .await
            .map_err(|err| println!("{}", &err))
            .unwrap();
        self.sync_with_server().await;
        let rmr_resp = self
            .rdma_stub
            .receive_remote_mr()
            .await
            .map_err(|err| println!("{}", &err))
            .unwrap();
        let mut lmr_resp = self
            .rdma_stub
            .alloc_local_mr(Layout::new::<Response>())
            .map_err(|err| println!("{}", &err))
            .unwrap();
        self.sync_with_server().await;
        self.rdma_stub
            .read(&mut lmr_resp, &rmr_resp)
            .await
            .map_err(|err| println!("{}", &err))
            .unwrap();
        transmute_lmr_to_string(&lmr_resp)
    }

    /// Waits for one message from the server and checks that it is the
    /// synchronisation marker.
    ///
    /// # Panics
    ///
    /// Panics with `invalid response` if the message decodes to anything
    /// other than `Response::Sync`, or if the receive itself fails (the error
    /// is printed first).
    async fn sync_with_server(&self) {
        let lmr_resp = self
            .rdma_stub
            .receive()
            .await
            .map_err(|err| println!("{}", &err))
            .unwrap();
        // SAFETY: relies on the server having sent the bytes of a valid
        // `Response`; nothing here can verify that.
        let resp = unsafe { &*(lmr_resp.as_ptr() as *const Response) };
        if !matches!(resp, Response::Sync) {
            panic!("invalid response");
        }
    }
}
/// Demo driver: starts the server on a background thread, then performs one
/// echo over send/receive and one over the write/read path, printing each
/// request and response.
#[tokio::main]
async fn main() {
    tracing_subscriber::fmt::init();

    // `Server::start` is itself a `#[tokio::main]` function, so it is run on
    // its own OS thread; the join handle is intentionally discarded.
    std::thread::spawn(|| Server::start("127.0.0.1:5555"));
    println!("rpc server started");

    // Presumably gives the listener time to bind before the client connects.
    tokio::time::sleep(Duration::from_secs(1)).await;

    let client = Client::new("127.0.0.1:5555").await;

    let hello = String::from("hello");
    println!("request: {}", hello);
    let reply = client.echo_req_sr(hello).await;
    println!("response: {}", reply);

    let world = String::from("world");
    println!("request: {}", world);
    let reply = client.echo_req_wr(world).await;
    println!("response: {}", reply);
}