use super::*;
use derivative::*;
#[cfg(feature = "am")]
use std::collections::HashMap;
use std::net::SocketAddr;
use std::os::unix::io::AsRawFd;
#[cfg(feature = "am")]
use std::sync::RwLock;
#[cfg(feature = "event")]
use tokio::io::unix::AsyncFd;
/// A UCP worker created from a [`Context`].
///
/// The underlying handle is destroyed with `ucp_worker_destroy` when the
/// `Worker` is dropped.
#[derive(Derivative)]
#[derivative(Debug)]
pub struct Worker {
// Raw UCP worker handle (created by `ucp_worker_create` in `Worker::new`).
pub(super) handle: ucp_worker_h,
// Keeps the parent context alive while the worker exists. Fields are dropped
// after `Drop::drop` runs, so this `Arc` is released only after
// `ucp_worker_destroy` has been called on `handle`.
context: Arc<Context>,
// Active-message stream state keyed by a `u16` id (presumably the AM id —
// TODO confirm); skipped in the derived `Debug` output.
#[cfg(feature = "am")]
#[derivative(Debug = "ignore")]
pub(crate) am_streams: RwLock<HashMap<u16, Rc<AmStreamInner>>>,
}
impl Drop for Worker {
    /// Destroys the underlying UCP worker.
    fn drop(&mut self) {
        let handle = self.handle;
        // SAFETY: `handle` is expected to be the worker produced by
        // `ucp_worker_create` in `Worker::new`; it is destroyed exactly once, here.
        unsafe {
            ucp_worker_destroy(handle);
        }
    }
}
impl Worker {
    /// Creates a new UCP worker on `context`, requesting
    /// `UCS_THREAD_MODE_SINGLE`.
    ///
    /// # Errors
    ///
    /// Returns the error converted from the status of `ucp_worker_create`.
    pub(super) fn new(context: &Arc<Context>) -> Result<Rc<Self>, Error> {
        let mut params = MaybeUninit::<ucp_worker_params_t>::uninit();
        // SAFETY: only the fields enabled in `field_mask` are written; UCX only
        // reads the fields selected by `field_mask`, so the rest may stay
        // uninitialized.
        unsafe {
            (*params.as_mut_ptr()).field_mask =
                ucp_worker_params_field::UCP_WORKER_PARAM_FIELD_THREAD_MODE.0 as _;
            (*params.as_mut_ptr()).thread_mode = ucs_thread_mode_t::UCS_THREAD_MODE_SINGLE;
        }
        let mut handle = MaybeUninit::uninit();
        let status =
            unsafe { ucp_worker_create(context.handle, params.as_ptr(), handle.as_mut_ptr()) };
        Error::from_status(status)?;
        Ok(Rc::new(Worker {
            // SAFETY: `ucp_worker_create` succeeded, so it wrote the handle.
            handle: unsafe { handle.assume_init() },
            context: context.clone(),
            #[cfg(feature = "am")]
            am_streams: RwLock::new(HashMap::new()),
        }))
    }

    /// Drives the worker by busy-polling.
    ///
    /// It drains [`Worker::progress`] until it reports no work, yields to the
    /// executor, and repeats. It returns once this `Rc` is the only strong
    /// reference left to the worker.
    pub async fn polling(self: Rc<Self>) {
        while Rc::strong_count(&self) > 1 {
            while self.progress() != 0 {}
            futures_lite::future::yield_now().await;
        }
    }

    /// Drives the worker by waiting on its event file descriptor instead of
    /// busy-polling.
    ///
    /// Each round drains [`Worker::progress`] and then arms the worker. If
    /// arming succeeds, the task sleeps until the event fd becomes readable. It
    /// returns once this `Rc` is the only strong reference left to the worker.
    ///
    /// # Errors
    ///
    /// Returns an error if the event fd cannot be obtained or if arming the
    /// worker fails.
    ///
    /// # Panics
    ///
    /// Panics if the fd cannot be registered with, or polled by, the tokio
    /// reactor.
    #[cfg(feature = "event")]
    pub async fn event_poll(self: Rc<Self>) -> Result<(), Error> {
        let fd = self.event_fd()?;
        let wait_fd =
            AsyncFd::new(fd).expect("failed to register the UCX worker event fd with tokio");
        while Rc::strong_count(&self) > 1 {
            while self.progress() != 0 {}
            // `arm` yields `false` on UCS_ERR_BUSY (events still pending): loop
            // back and progress them instead of sleeping. Arming errors are
            // propagated to the caller rather than panicking.
            if self.arm()? {
                let mut ready = wait_fd
                    .readable()
                    .await
                    .expect("failed to poll the UCX worker event fd");
                // Clear readiness so the next `readable()` waits for a new event.
                ready.clear_ready();
            }
        }
        Ok(())
    }

    /// Prints information about this worker to `stderr` via
    /// `ucp_worker_print_info`.
    pub fn print_to_stderr(&self) {
        unsafe { ucp_worker_print_info(self.handle, stderr) };
    }

    /// Returns the worker's thread mode as reported by `ucp_worker_query`.
    ///
    /// # Panics
    ///
    /// Panics if `ucp_worker_query` does not return `UCS_OK`.
    pub fn thread_mode(&self) -> ucs_thread_mode_t {
        let mut attr = MaybeUninit::<ucp_worker_attr>::uninit();
        // SAFETY: write the mask through the raw pointer. This neither creates a
        // reference to the rest of the uninitialized struct nor assumes it is
        // initialized.
        unsafe {
            (*attr.as_mut_ptr()).field_mask =
                ucp_worker_attr_field::UCP_WORKER_ATTR_FIELD_THREAD_MODE.0 as u64;
        }
        let status = unsafe { ucp_worker_query(self.handle, attr.as_mut_ptr()) };
        assert_eq!(status, ucs_status_t::UCS_OK);
        // SAFETY: on success `ucp_worker_query` filled in `thread_mode`, the only
        // attribute requested. Every other field may still be uninitialized, so
        // read just this one field instead of `assume_init` on the whole struct.
        unsafe { std::ptr::addr_of!((*attr.as_ptr()).thread_mode).read() }
    }

    /// Returns this worker's address.
    ///
    /// The returned [`WorkerAddress`] borrows the worker and releases the
    /// address buffer with `ucp_worker_release_address` when dropped.
    ///
    /// # Errors
    ///
    /// Returns the error converted from the status of `ucp_worker_get_address`.
    pub fn address(&self) -> Result<WorkerAddress<'_>, Error> {
        let mut handle = MaybeUninit::uninit();
        let mut length = MaybeUninit::uninit();
        let status = unsafe {
            ucp_worker_get_address(self.handle, handle.as_mut_ptr(), length.as_mut_ptr())
        };
        Error::from_status(status)?;
        // SAFETY: both out-parameters were written by the successful call above.
        Ok(WorkerAddress {
            handle: unsafe { handle.assume_init() },
            length: unsafe { length.assume_init() } as usize,
            worker: self,
        })
    }

    /// Creates a [`Listener`] on this worker for `addr` (delegates to
    /// `Listener::new`).
    pub fn create_listener(self: &Rc<Self>, addr: SocketAddr) -> Result<Listener, Error> {
        Listener::new(self, addr)
    }

    /// Creates an [`Endpoint`] connected to the worker whose address is `addr`.
    pub fn connect_addr(self: &Rc<Self>, addr: &WorkerAddress) -> Result<Endpoint, Error> {
        Endpoint::connect_addr(self, addr.handle)
    }

    /// Creates an [`Endpoint`] connected to the socket address `addr`.
    pub async fn connect_socket(self: &Rc<Self>, addr: SocketAddr) -> Result<Endpoint, Error> {
        Endpoint::connect_socket(self, addr).await
    }

    /// Accepts `connection` on this worker, producing an [`Endpoint`].
    pub async fn accept(self: &Rc<Self>, connection: ConnectionRequest) -> Result<Endpoint, Error> {
        Endpoint::accept(self, connection).await
    }

    /// Calls `ucp_worker_wait`, which blocks the calling thread until the
    /// worker has an event to process.
    ///
    /// # Errors
    ///
    /// Returns the error converted from the status of `ucp_worker_wait`.
    pub fn wait(&self) -> Result<(), Error> {
        let status = unsafe { ucp_worker_wait(self.handle) };
        Error::from_status(status)
    }

    /// Arms the worker for event notification.
    ///
    /// It returns `Ok(true)` when the worker was armed (`UCS_OK`). It returns
    /// `Ok(false)` on `UCS_ERR_BUSY`, which per UCX means unprocessed events
    /// remain; the caller should progress and retry.
    ///
    /// # Errors
    ///
    /// Returns any other status as an error.
    pub fn arm(&self) -> Result<bool, Error> {
        let status = unsafe { ucp_worker_arm(self.handle) };
        match status {
            ucs_status_t::UCS_OK => Ok(true),
            ucs_status_t::UCS_ERR_BUSY => Ok(false),
            status => Err(Error::from_error(status)),
        }
    }

    /// Runs one round of `ucp_worker_progress` and returns its result.
    ///
    /// Callers in this file loop while the result is non-zero.
    pub fn progress(&self) -> u32 {
        unsafe { ucp_worker_progress(self.handle) }
    }

    /// Returns the worker's event file descriptor from `ucp_worker_get_efd`.
    ///
    /// # Errors
    ///
    /// Returns the error converted from the status of `ucp_worker_get_efd`.
    pub fn event_fd(&self) -> Result<i32, Error> {
        let mut fd = MaybeUninit::uninit();
        let status = unsafe { ucp_worker_get_efd(self.handle, fd.as_mut_ptr()) };
        Error::from_status(status)?;
        // SAFETY: the call succeeded, so the fd out-parameter was written.
        unsafe { Ok(fd.assume_init()) }
    }

    /// Calls `ucp_worker_flush` on this worker.
    ///
    /// # Panics
    ///
    /// Panics if `ucp_worker_flush` does not return `UCS_OK`.
    pub fn flush(&self) {
        let status = unsafe { ucp_worker_flush(self.handle) };
        assert_eq!(status, ucs_status_t::UCS_OK);
    }
}
impl AsRawFd for Worker {
fn as_raw_fd(&self) -> i32 {
self.event_fd().unwrap()
}
}
/// A worker's UCP address, as returned by [`Worker::address`].
///
/// The address buffer is released with `ucp_worker_release_address` when this
/// value is dropped. The `'a` borrow prevents it from outliving its worker.
#[derive(Debug)]
pub struct WorkerAddress<'a> {
// Address buffer written by `ucp_worker_get_address`.
handle: *mut ucp_address_t,
// Size of the address buffer in bytes (exposed through `AsRef<[u8]>`).
length: usize,
// Worker that produced the address; its handle is needed to release it.
worker: &'a Worker,
}
impl AsRef<[u8]> for WorkerAddress<'_> {
    /// Views the worker address as a byte slice of `length` bytes.
    fn as_ref(&self) -> &[u8] {
        let bytes: *const u8 = self.handle.cast();
        // SAFETY: `handle` and `length` come from a successful
        // `ucp_worker_get_address`. The buffer is only released when `self` is
        // dropped, which the returned borrow prevents.
        unsafe { std::slice::from_raw_parts(bytes, self.length) }
    }
}
impl Drop for WorkerAddress<'_> {
    /// Hands the address buffer back to the worker that produced it.
    fn drop(&mut self) {
        let (worker_handle, address) = (self.worker.handle, self.handle);
        // SAFETY: `address` was obtained from `worker_handle` via
        // `ucp_worker_get_address` and is released exactly once, here.
        unsafe { ucp_worker_release_address(worker_handle, address) }
    }
}