use async_ucx::ucp;
use bytes::Bytes;
use std::{
cell::RefCell,
collections::HashMap,
io::{self, IoSlice, IoSliceMut},
mem::MaybeUninit,
net::{SocketAddr, ToSocketAddrs},
rc::Rc,
sync::Arc,
};
use tokio::sync::{mpsc, oneshot};
use tracing::*;
// Shared UCX context handle. `lazy_static!` builds it on first access, and the
// `expect` below panics at that point if `ucp::Context::new()` returns an error.
// `worker_thread` clones this `Arc` to create its own worker.
lazy_static::lazy_static! {
static ref CONTEXT: Arc<ucp::Context> = ucp::Context::new().expect("failed to initialize UCX context");
}
/// Handle to a tag-matching endpoint whose UCX work is performed by a
/// background thread running `worker_thread`.
///
/// Cloning is cheap: clones share the same request channels and therefore
/// the same background worker.
#[derive(Clone)]
pub struct Endpoint {
// Address reported by the worker's listener after binding.
addr: SocketAddr,
// Queue of outgoing send requests consumed by the worker thread.
sender: mpsc::Sender<SendMsg>,
// Queue of receive requests consumed by the worker thread. Despite the name
// this is the *sending* half of the channel; the worker owns the receiver.
recver: mpsc::Sender<RecvMsg>,
}
/// A send request handed from `Endpoint::send_to_vectored` to the worker thread.
struct SendMsg {
// Peer the payload is sent to.
dst: SocketAddr,
// Tag passed through to `tag_send_vectored`.
tag: u64,
// Payload buffers. The `'static` lifetime is produced by a `transmute` in
// `send_to_vectored`; the data is really borrowed from that caller, which
// waits on `done` before returning.
bufs: &'static [IoSlice<'static>],
// Signalled by the worker once the send has completed.
done: oneshot::Sender<()>,
}
/// A receive request handed from `Endpoint::recv_from_vectored` to the worker thread.
struct RecvMsg {
// Tag passed through to `tag_recv_vectored`.
tag: u64,
// Destination buffers. The `'static` lifetime is produced by a `transmute`
// in `recv_from_vectored`; the data is really borrowed from that caller,
// which waits on `done` before returning.
bufs: &'static mut [IoSliceMut<'static>],
// Receives `(len, sender_addr)`, where `len` is the received length minus
// the 6-byte address header the worker strips off.
done: oneshot::Sender<(usize, SocketAddr)>,
}
/// Body of the background thread that owns the UCX worker for one `Endpoint`.
///
/// Creates a worker from the shared `CONTEXT`, starts listening on `addr`,
/// reports the actually bound address through `addr_tx`, and then drives
/// three local tasks on a current-thread tokio runtime:
///
/// * an accept loop that stores every incoming connection in the endpoint
///   cache, keyed by the peer address;
/// * a send loop that serves `SendMsg` requests, connecting to the
///   destination on first use and caching the resulting endpoint;
/// * a receive loop that serves `RecvMsg` requests.
///
/// Every message on the wire is prefixed with the sender's 6-byte address
/// (see `socket_addr_to_bytes`); the receive path strips that header and
/// reports the decoded address together with the payload length.
///
/// The function blocks in `worker.event_poll()` and only returns if that
/// future completes.
///
/// # Panics
///
/// Panics if the worker, runtime or listener cannot be created, if the
/// bound address cannot be delivered through `addr_tx`, or if event polling
/// returns an error. UCX errors inside a spawned send/receive task panic
/// that task only.
fn worker_thread(
    addr: SocketAddr,
    addr_tx: oneshot::Sender<SocketAddr>,
    mut send_rx: mpsc::Receiver<SendMsg>,
    mut recv_rx: mpsc::Receiver<RecvMsg>,
) {
    let context = CONTEXT.clone();
    let worker = context.create_worker().unwrap();
    let rt = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();
    let local = tokio::task::LocalSet::new();
    // Cache of established endpoints, shared by the accept and send loops.
    let eps: Rc<RefCell<HashMap<SocketAddr, Rc<ucp::Endpoint>>>> = Default::default();
    let mut listener = worker.create_listener(addr).unwrap();
    // Shadow the requested address with the one the listener actually bound.
    let addr = listener.socket_addr().unwrap();
    addr_tx.send(addr).unwrap();

    // Accept loop: register every incoming connection under its peer address.
    local.spawn_local({
        let worker = worker.clone();
        let eps = eps.clone();
        async move {
            loop {
                let conn = listener.next().await;
                let peer = conn.remote_addr().unwrap();
                let ep = Rc::new(worker.accept(conn).await.unwrap());
                eps.borrow_mut().insert(peer, ep);
            }
        }
    });

    // Send loop: resolve (or establish) the endpoint, then send in a sub-task.
    local.spawn_local({
        let worker = worker.clone();
        async move {
            while let Some(msg) = send_rx.recv().await {
                // The `RefCell` borrow must end before any `.await`: the accept
                // loop above also borrows the map mutably and may run while
                // `connect_socket` is pending, which would otherwise panic
                // with `BorrowMutError`.
                let cached = eps.borrow().get(&msg.dst).cloned();
                let ep = match cached {
                    Some(ep) => ep,
                    None => {
                        let ep = Rc::new(worker.connect_socket(msg.dst).await.unwrap());
                        eps.borrow_mut().insert(msg.dst, ep.clone());
                        ep
                    }
                };
                tokio::task::spawn_local(async move {
                    // Prepend our own address so the receiver can tell who sent this.
                    let addr = socket_addr_to_bytes(addr);
                    let mut bufs = vec![IoSlice::new(&addr)];
                    bufs.extend_from_slice(msg.bufs);
                    ep.tag_send_vectored(msg.tag, &bufs).await.unwrap();
                    // The requester may have gone away; nothing left to notify then.
                    let _ = msg.done.send(());
                });
            }
        }
    });

    // Receive loop: each request is served by its own sub-task.
    local.spawn_local({
        let worker = worker.clone();
        async move {
            while let Some(RecvMsg { tag, bufs, done }) = recv_rx.recv().await {
                let worker = worker.clone();
                tokio::task::spawn_local(async move {
                    let mut addr_buf = [0u8; 6];
                    // iovec = [address header, caller buffers...]; the caller's
                    // slices are reborrowed rather than bit-copied.
                    let mut iov = Vec::with_capacity(bufs.len() + 1);
                    iov.push(IoSliceMut::new(&mut addr_buf));
                    iov.extend(bufs.iter_mut().map(|b| IoSliceMut::new(&mut **b)));
                    let len = worker.tag_recv_vectored(tag, &mut iov).await.unwrap();
                    let addr = socket_addr_from_bytes(addr_buf);
                    // Report the payload length without the 6-byte header. The
                    // requester may have gone away; nothing left to notify then.
                    let _ = done.send((len - 6, addr));
                });
            }
        }
    });

    local.block_on(&rt, worker.event_poll()).unwrap();
}
impl Endpoint {
    /// Resolves `addr` and returns the first resulting socket address.
    ///
    /// Returns an `InvalidInput` error instead of panicking when resolution
    /// succeeds but yields no address.
    fn resolve_first(addr: impl ToSocketAddrs) -> io::Result<SocketAddr> {
        addr.to_socket_addrs()?.next().ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "could not resolve to any addresses",
            )
        })
    }

    /// Error returned when the background worker can no longer be reached.
    fn worker_gone() -> io::Error {
        io::Error::new(
            io::ErrorKind::BrokenPipe,
            "UCX worker thread is no longer running",
        )
    }

    /// Creates an endpoint listening on `addr`.
    ///
    /// Spawns a dedicated OS thread running `worker_thread` and waits for it
    /// to report the address it actually bound.
    ///
    /// # Errors
    ///
    /// Returns an error if `addr` cannot be resolved to any address, or if
    /// the worker thread terminates before reporting its bound address.
    pub async fn bind(addr: impl ToSocketAddrs) -> io::Result<Self> {
        let addr = Self::resolve_first(addr)?;
        let (addr_tx, addr_rx) = oneshot::channel();
        let (sender, send_rx) = mpsc::channel::<SendMsg>(8);
        let (recver, recv_rx) = mpsc::channel::<RecvMsg>(8);
        std::thread::spawn(move || worker_thread(addr, addr_tx, send_rx, recv_rx));
        let addr = addr_rx.await.map_err(|_| {
            io::Error::new(
                io::ErrorKind::Other,
                "UCX worker thread exited before reporting its bound address",
            )
        })?;
        trace!("new ep: {addr}");
        Ok(Endpoint {
            addr,
            sender,
            recver,
        })
    }

    /// Returns the address this endpoint is bound to.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        Ok(self.addr)
    }

    /// Sends `data` to `dst` under `tag`.
    ///
    /// Convenience wrapper around [`Endpoint::send_to_vectored`] for a single
    /// buffer; see that method for errors and cancellation caveats.
    pub async fn send_to(&self, dst: impl ToSocketAddrs, tag: u64, data: &[u8]) -> io::Result<()> {
        self.send_to_vectored(dst, tag, &[IoSlice::new(data)]).await
    }

    /// Sends the concatenation of `bufs` to `dst` under `tag` and waits until
    /// the worker reports completion.
    ///
    /// # Errors
    ///
    /// Returns an error if `dst` cannot be resolved to any address, or if the
    /// worker thread is gone or drops the request without completing it.
    ///
    /// # Cancellation
    ///
    /// The worker reads from `bufs` until it signals completion. Dropping the
    /// returned future before it resolves leaves the worker with dangling
    /// buffers, so the future must be driven to completion.
    pub async fn send_to_vectored(
        &self,
        dst: impl ToSocketAddrs,
        tag: u64,
        bufs: &[IoSlice<'_>],
    ) -> io::Result<()> {
        let dst = Self::resolve_first(dst)?;
        trace!("send: {} -> {}, tag={}", self.addr, dst, tag);
        // SAFETY: only lifetimes change, so layout is identical. The worker
        // stops touching the buffers before it signals `done`, and this
        // function does not return until that signal (or the loss of the
        // worker) is observed. This does NOT hold if the future is dropped
        // early; see the "Cancellation" section above.
        let bufs = unsafe {
            std::mem::transmute::<&[IoSlice<'_>], &'static [IoSlice<'static>]>(bufs)
        };
        let (done, done_recver) = oneshot::channel();
        let msg = SendMsg {
            dst,
            tag,
            bufs,
            done,
        };
        self.sender
            .send(msg)
            .await
            .map_err(|_| Self::worker_gone())?;
        done_recver.await.map_err(|_| Self::worker_gone())?;
        Ok(())
    }

    /// Receives one message with `tag` into `buf`.
    ///
    /// Convenience wrapper around [`Endpoint::recv_from_vectored`] for a
    /// single buffer; see that method for errors and cancellation caveats.
    pub async fn recv_from(&self, tag: u64, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
        self.recv_from_vectored(tag, &mut [IoSliceMut::new(buf)])
            .await
    }

    /// Receives one message with `tag` into `bufs`.
    ///
    /// Returns the payload length (excluding the internal address header)
    /// and the address the sender advertised.
    ///
    /// # Errors
    ///
    /// Returns an error if the worker thread is gone or drops the request
    /// without completing it.
    ///
    /// # Cancellation
    ///
    /// The worker writes into `bufs` until it signals completion. Dropping
    /// the returned future before it resolves leaves the worker with dangling
    /// buffers, so the future must be driven to completion.
    pub async fn recv_from_vectored(
        &self,
        tag: u64,
        bufs: &mut [IoSliceMut<'_>],
    ) -> io::Result<(usize, SocketAddr)> {
        // SAFETY: only lifetimes change, so layout is identical. The worker
        // stops touching the buffers before it signals `done`, and this
        // function does not return until that signal (or the loss of the
        // worker) is observed. This does NOT hold if the future is dropped
        // early; see the "Cancellation" section above.
        let bufs = unsafe {
            std::mem::transmute::<&mut [IoSliceMut<'_>], &'static mut [IoSliceMut<'static>]>(bufs)
        };
        let (done, done_recver) = oneshot::channel();
        let msg = RecvMsg { tag, bufs, done };
        self.recver
            .send(msg)
            .await
            .map_err(|_| Self::worker_gone())?;
        let (len, from) = done_recver.await.map_err(|_| Self::worker_gone())?;
        trace!("recv: {} <- {}, tag={}", self.addr, from, tag);
        Ok((len, from))
    }

    /// Receives one message with `tag` into a freshly allocated 4 KiB buffer
    /// and returns the received bytes together with the sender's address.
    pub(crate) async fn recv_from_raw(&self, tag: u64) -> io::Result<(Bytes, SocketAddr)> {
        // Zero-initialised on purpose: handing out `&mut [u8]` over
        // uninitialised memory is undefined behaviour.
        let mut buf = vec![0u8; 0x1000];
        let (len, from) = self.recv_from(tag, &mut buf).await?;
        buf.truncate(len);
        Ok((Bytes::from(buf), from))
    }
}
/// Encodes an IPv4 socket address as 6 bytes: the four address octets
/// followed by the port in big-endian order.
///
/// IPv6 addresses are not implemented yet and hit `todo!()`.
fn socket_addr_to_bytes(addr: SocketAddr) -> [u8; 6] {
    let v4 = match addr {
        SocketAddr::V4(v4) => v4,
        SocketAddr::V6(_) => todo!(),
    };
    let mut encoded = [0u8; 6];
    encoded[..4].copy_from_slice(&v4.ip().octets());
    encoded[4..].copy_from_slice(&v4.port().to_be_bytes());
    encoded
}
/// Decodes the 6-byte form produced by `socket_addr_to_bytes`: four IPv4
/// octets followed by a big-endian port.
fn socket_addr_from_bytes(bytes: [u8; 6]) -> SocketAddr {
    let octets = [bytes[0], bytes[1], bytes[2], bytes[3]];
    let port = u16::from_be_bytes([bytes[4], bytes[5]]);
    (octets, port).into()
}