#![cfg(feature = "tokio")]
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread;
use std::time::Duration;
use nodedb_bridge::async_bridge::BridgeChannel;
const MESSAGE_COUNT: u64 = 10_000;
/// Worker-side loop used by the tests: waits on the bridge's request wake fd
/// with epoll, drains pending requests and answers each one with `req * 2`.
///
/// The loop ends once `MESSAGE_COUNT` requests have been answered, after which
/// `done` is set to `true` (with `Release` ordering).
///
/// # Panics
///
/// Panics if any epoll call fails (other than `epoll_wait` being interrupted)
/// or if sending a response fails with anything other than `BridgeError::Full`.
fn tpc_event_loop(
    mut data: nodedb_bridge::async_bridge::DataHandle<u64, u64>,
    done: Arc<AtomicBool>,
) {
    /// Closes the wrapped epoll descriptor on drop, so the fd is released on
    /// the panic paths below as well as on normal return.
    struct EpollGuard(libc::c_int);

    impl Drop for EpollGuard {
        fn drop(&mut self) {
            // SAFETY: the guard is only built from a descriptor that
            // `epoll_create1` returned successfully, and nothing else closes it.
            unsafe { libc::close(self.0) };
        }
    }

    // SAFETY: `epoll_create1` takes no pointer arguments.
    let epoll_fd = unsafe { libc::epoll_create1(0) };
    assert!(epoll_fd >= 0, "epoll_create1 failed");
    let epoll = EpollGuard(epoll_fd);

    let req_fd = data.request_wake_fd();
    let mut event = libc::epoll_event {
        // Edge-triggered readiness on the request wake fd.
        events: (libc::EPOLLIN | libc::EPOLLET) as u32,
        u64: req_fd as u64,
    };
    // SAFETY: `event` is a live, initialised `epoll_event` for the duration of
    // the call and `epoll.0` is a valid epoll descriptor.
    let ret = unsafe { libc::epoll_ctl(epoll.0, libc::EPOLL_CTL_ADD, req_fd, &mut event) };
    assert_eq!(ret, 0, "epoll_ctl failed");

    let mut processed = 0u64;
    let mut events = [libc::epoll_event { events: 0, u64: 0 }; 8];
    // Reused across iterations; emptied by `drain(..)` after each batch.
    let mut batch = Vec::new();

    while processed < MESSAGE_COUNT {
        // SAFETY: `events` is a writable array and the capacity passed is
        // exactly its length, so the kernel cannot write out of bounds.
        let nfds = unsafe {
            libc::epoll_wait(
                epoll.0,
                events.as_mut_ptr(),
                events.len() as libc::c_int,
                10,
            )
        };
        if nfds < 0 {
            let err = std::io::Error::last_os_error();
            if err.kind() == std::io::ErrorKind::Interrupted {
                continue;
            }
            panic!("epoll_wait failed: {err}");
        }
        if nfds > 0 {
            // Presumably consumes the pending wake notification so the next
            // edge can fire; the result is deliberately ignored.
            let _ = data.req_wake.consumer_wake.try_read();
        }

        // Requests are drained even after a timeout (nfds == 0), so progress
        // does not depend solely on observing a wake-up.
        data.drain_requests(&mut batch, 512);
        for req in batch.drain(..) {
            loop {
                match data.try_send_response(req * 2) {
                    Ok(()) => break,
                    // Response queue is full: let the consumer run, then retry.
                    Err(nodedb_bridge::BridgeError::Full { .. }) => {
                        thread::yield_now();
                    }
                    Err(e) => panic!("TPC send error: {e}"),
                }
            }
            processed += 1;
        }
    }

    done.store(true, Ordering::Release);
    // `epoll` is dropped here, closing the epoll descriptor.
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn tokio_to_tpc_via_eventfd() {
let bridge: BridgeChannel<u64, u64> = BridgeChannel::new(1024, 1024).unwrap();
let control = bridge.control;
let data = bridge.data;
let done = Arc::new(AtomicBool::new(false));
let done_clone = Arc::clone(&done);
let tpc_handle = thread::spawn(move || {
tpc_event_loop(data, done_clone);
});
let mut async_ctrl = nodedb_bridge::tokio_fd::AsyncControlHandle::new(control).unwrap();
let mut sent = 0u64;
let mut responses = Vec::with_capacity(MESSAGE_COUNT as usize);
while responses.len() < MESSAGE_COUNT as usize {
while sent < MESSAGE_COUNT {
match async_ctrl.inner.try_send_request(sent + 1) {
Ok(()) => sent += 1,
Err(nodedb_bridge::BridgeError::Full { .. }) => break,
Err(e) => panic!("send error: {e}"),
}
}
loop {
match async_ctrl.try_recv_response() {
Ok(rsp) => responses.push(rsp),
Err(nodedb_bridge::BridgeError::Empty) => break,
Err(nodedb_bridge::BridgeError::Disconnected { .. }) => break,
Err(e) => panic!("recv error: {e}"),
}
}
if responses.len() < MESSAGE_COUNT as usize {
tokio::task::yield_now().await;
}
}
tpc_handle.join().unwrap();
responses.sort();
assert_eq!(responses.len(), MESSAGE_COUNT as usize);
for (i, rsp) in responses.iter().enumerate() {
let expected = (i as u64 + 1) * 2;
assert_eq!(
*rsp, expected,
"response mismatch at index {i}: got {rsp}, expected {expected}"
);
}
assert!(done.load(Ordering::Acquire));
}
/// A task awaiting `recv_response` must complete once the data side sends a
/// response, and must observe the value that was sent.
#[tokio::test]
async fn tpc_wakes_tokio_on_response() {
    let bridge: BridgeChannel<u64, u64> = BridgeChannel::new(16, 16).unwrap();
    let mut data = bridge.data;
    let control = bridge.control;
    let mut async_ctrl = nodedb_bridge::tokio_fd::AsyncControlHandle::new(control).unwrap();

    // The receiver runs in its own task and hands the handle back alongside
    // the value so the handle outlives the await.
    let recv_task = tokio::spawn(async move {
        let received = async_ctrl.recv_response().await.unwrap();
        (async_ctrl, received)
    });

    // Presumably long enough for the receiver to start waiting before the
    // response is published.
    let settle = Duration::from_millis(20);
    tokio::time::sleep(settle).await;
    data.try_send_response(999).unwrap();

    let limit = Duration::from_secs(5);
    let joined = tokio::time::timeout(limit, recv_task)
        .await
        .expect("recv should complete within 5s");
    let (_ctrl, rsp) = joined.unwrap();
    assert_eq!(rsp, 999);
}