//! async-ucx 0.1.1
//!
//! Asynchronous Rust bindings to UCX.
//! Documentation
use async_ucx::ucp::*;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::io::Result;
use std::mem::MaybeUninit;
use std::net::SocketAddr;
use std::sync::atomic::*;
use std::sync::Arc;
use tokio::sync::mpsc;

/// Program entry point.
///
/// Runs the client when a server address is passed as the first command-line
/// argument; otherwise runs the server. Both roles are driven inside a
/// `LocalSet` because they spawn tasks with `spawn_local`.
#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<()> {
    env_logger::init();
    let local_set = tokio::task::LocalSet::new();
    match std::env::args().nth(1) {
        Some(server_addr) => {
            local_set.run_until(client(server_addr)).await;
        }
        None => {
            local_set.run_until(server()).await;
        }
    }
    Ok(())
}

/// Client side of the benchmark.
///
/// Connects to `server_addr`, receives an 8-byte tag on the well-known tag
/// `100`, and then sends an 8-byte message with that tag in an endless loop.
/// Any failure panics via `unwrap`.
async fn client(server_addr: String) -> ! {
    println!("client: connect to {:?}", server_addr);

    let ctx = Context::new().unwrap();
    let worker = ctx.create_worker().unwrap();
    // Drive the worker in the background: busy polling by default, or
    // event-based polling when the `event` feature is enabled.
    #[cfg(not(feature = "event"))]
    tokio::task::spawn_local(worker.clone().polling());
    #[cfg(feature = "event")]
    tokio::task::spawn_local(worker.clone().event_poll());

    let endpoint = worker
        .connect_socket(server_addr.parse().unwrap())
        .await
        .unwrap();
    endpoint.print_to_stderr();

    // The server announces the tag to use for the data messages on tag 100.
    let mut tag_bytes = [MaybeUninit::uninit(); 8];
    endpoint
        .worker()
        .tag_recv(100, &mut tag_bytes)
        .await
        .unwrap();
    // SAFETY: assumes `tag_recv` filled all 8 bytes of the buffer; the
    // received length is not checked here. TODO confirm.
    let tag: u64 = unsafe { std::mem::transmute(tag_bytes) };
    println!("client: got tag {:#x}", tag);

    // Payload is the bytes 0, 1, ..., 7.
    let payload: Vec<u8> = (0u8..8).collect();
    loop {
        endpoint.tag_send(tag, &payload).await.unwrap();
        // Short busy-wait between sends.
        (0..500).for_each(|_| std::hint::spin_loop());
        // A per-message acknowledgement (a `tag_recv` on the same tag) is
        // intentionally not awaited here.
    }
}

/// Server side of the benchmark.
///
/// Creates a fixed pool of [`WorkerThread`]s, listens on an OS-assigned port
/// on all interfaces, and hands incoming connection requests to the pool in
/// round-robin order. For every accepted endpoint the handler sends the client
/// a tag (a hash of the peer address) on the well-known tag `100`, then counts
/// every message received on that tag. A background task prints the number of
/// messages received across all threads once per second.
///
/// Never returns; any failure panics via `unwrap`.
async fn server() -> ! {
    /// Number of worker threads that accept and serve connections.
    const WORKER_THREADS: usize = 4;

    println!("server");
    let context = Context::new().unwrap();
    let mut worker_threads = Vec::with_capacity(WORKER_THREADS);
    let mut counters = Vec::with_capacity(WORKER_THREADS);
    for _ in 0..WORKER_THREADS {
        // The closure must not capture anything: it is passed as a plain `fn` pointer.
        let wt = WorkerThread::new(&context, |ep, addr, counter| {
            println!("accept: {:?}", addr);
            tokio::task::spawn_local(async move {
                // Derive a tag for this connection from the peer address.
                let mut hasher = DefaultHasher::new();
                addr.hash(&mut hasher);
                let tag = hasher.finish();
                ep.tag_send(100, &tag.to_ne_bytes()).await.unwrap();

                let mut buf = vec![MaybeUninit::uninit(); 50000];
                loop {
                    ep.worker().tag_recv(tag, &mut buf).await.unwrap();
                    // Atomic increment: the reporting task reads and resets this
                    // counter concurrently from another thread. Only the count
                    // matters, so relaxed ordering is sufficient.
                    counter.fetch_add(1, Ordering::Relaxed);
                }
            });
        });
        counters.push(wt.counter.clone());
        worker_threads.push(wt);
    }

    // The listener gets its own worker, driven on this thread.
    let worker = context.create_worker().unwrap();
    #[cfg(not(feature = "event"))]
    tokio::task::spawn_local(worker.clone().polling());
    #[cfg(feature = "event")]
    tokio::task::spawn_local(worker.clone().event_poll());

    let mut listener = worker
        .create_listener("0.0.0.0:0".parse().unwrap())
        .unwrap();
    // Once per second, drain every per-thread counter and print the total.
    tokio::task::spawn_local(async move {
        loop {
            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
            let count: usize = counters.iter().map(|c| c.swap(0, Ordering::SeqCst)).sum();
            println!("{} IOPS", count);
        }
    });
    println!("listening on {}", listener.socket_addr().unwrap());

    // Distribute incoming connection requests over the worker threads in
    // round-robin order.
    let mut next = 0;
    loop {
        let conn = listener.next().await;
        worker_threads[next].accept(conn);
        next = (next + 1) % worker_threads.len();
    }
}

/// Handle to a background thread that accepts connection requests sent to it
/// over a channel.
struct WorkerThread {
    /// Sending half of the channel used to hand connection requests to the thread.
    sender: mpsc::UnboundedSender<ConnectionRequest>,
    /// Counter shared with the thread; a clone is passed to the endpoint
    /// handler for every accepted connection.
    counter: Arc<AtomicUsize>,
}

impl WorkerThread {
    /// Spawns a thread that owns its own worker and a single-threaded tokio
    /// runtime, and returns a handle to it.
    ///
    /// The thread receives connection requests over a channel, accepts each
    /// one on its worker, and calls `handle_ep` with the resulting endpoint,
    /// the peer's address and a clone of the shared counter. `handle_ep` is
    /// called inside the thread's `LocalSet`, so it may use `spawn_local`.
    ///
    /// The thread exits when the handle (and with it the channel sender) is
    /// dropped. Failures inside the thread panic that thread via `unwrap`.
    fn new(context: &Arc<Context>, handle_ep: fn(Endpoint, SocketAddr, Arc<AtomicUsize>)) -> Self {
        let context = context.clone();
        let (sender, mut recver) = mpsc::unbounded_channel::<ConnectionRequest>();
        let counter = Arc::new(AtomicUsize::new(0));
        let counter1 = counter.clone();
        std::thread::spawn(move || {
            let worker = context.create_worker().unwrap();
            let rt = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap();
            let local = tokio::task::LocalSet::new();
            // Exactly one progress task per worker: busy polling by default,
            // event-based polling when the `event` feature is enabled.
            #[cfg(not(feature = "event"))]
            local.spawn_local(worker.clone().polling());
            #[cfg(feature = "event")]
            local.spawn_local(worker.clone().event_poll());
            local.block_on(&rt, async move {
                // Runs until every sender for the channel has been dropped.
                while let Some(conn) = recver.recv().await {
                    let addr = conn.remote_addr().unwrap();
                    let ep = worker.accept(conn).await.unwrap();
                    handle_ep(ep, addr, counter1.clone());
                }
            });
        });
        WorkerThread { sender, counter }
    }

    /// Hands a connection request to the worker thread for acceptance.
    ///
    /// # Panics
    ///
    /// Panics if the receiving side of the channel is gone, i.e. the worker
    /// thread has terminated.
    fn accept(&mut self, conn: ConnectionRequest) {
        self.sender.send(conn).unwrap();
    }
}