use std::{
hash::{Hash, Hasher},
num::NonZeroU32,
os::raw::c_void,
ptr::{self, NonNull},
sync::{
atomic::{AtomicPtr, AtomicUsize, Ordering},
Arc,
},
time::Duration,
};
use log::error;
use crate::{
ctx::Context,
error::{Error, Result, SendResult},
message::Message,
socket::Socket,
util::{abort_unwind, duration_to_nng, validate_ptr},
};
/// Type-erased completion closure handed to NNG (boxed once more so a thin pointer can cross
/// the FFI boundary). Trait-object boxes default to `'static`, so no explicit bound is needed.
type InnerCallback = Box<dyn Fn() + Sync + Send>;
/// Handle wrapping an NNG `nng_aio` plus the completion callback registered for it.
///
/// Cloning is cheap: every clone shares the same `Inner` through an `Arc`, and clones
/// compare and hash equal because equality and hashing use the shared `nng_aio` pointer.
#[derive(Clone, Debug)]
pub struct Aio
{
/// Shared state: raw aio pointer, current operation state, and boxed callback.
inner: Arc<Inner>,
}
impl Aio
{
pub fn new<F>(callback: F) -> Result<Self>
where
F: Fn(Aio, AioResult) + Sync + Send + 'static,
{
let inner = Arc::new(Inner {
handle: AtomicPtr::new(ptr::null_mut()),
state: AtomicUsize::new(State::Inactive as usize),
callback: AtomicPtr::new(ptr::null_mut()),
});
let weak = Arc::downgrade(&inner);
let bounce = move || {
let cb_aio = match weak.upgrade() {
Some(i) => Aio { inner: i },
None => return,
};
let res = unsafe {
let state = cb_aio.inner.state.load(Ordering::Acquire).into();
let aiop = cb_aio.inner.handle.load(Ordering::Relaxed);
let rv = nng_sys::nng_aio_result(aiop) as u32;
let res = match (state, rv) {
(State::Sending, 0) => AioResult::Send(Ok(())),
(State::Sending, e) => {
let msgp = nng_sys::nng_aio_get_msg(aiop);
let msg = Message::from_ptr(NonNull::new(msgp).unwrap());
AioResult::Send(Err((msg, NonZeroU32::new(e).unwrap().into())))
},
(State::Receiving, 0) => {
let msgp = nng_sys::nng_aio_get_msg(aiop);
let msg = Message::from_ptr(NonNull::new(msgp).unwrap());
AioResult::Recv(Ok(msg))
},
(State::Receiving, e) => {
AioResult::Recv(Err(NonZeroU32::new(e).unwrap().into()))
},
(State::Sleeping, 0) => AioResult::Sleep(Ok(())),
(State::Sleeping, e) => {
AioResult::Sleep(Err(NonZeroU32::new(e).unwrap().into()))
},
(State::Inactive, _) => unreachable!(),
};
cb_aio.inner.state.store(State::Inactive as usize, Ordering::Release);
res
};
callback(cb_aio, res);
};
let boxed: Box<InnerCallback> = Box::new(Box::new(bounce));
let callback_ptr = Box::into_raw(boxed);
let mut aio: *mut nng_sys::nng_aio = ptr::null_mut();
let aiop: *mut *mut nng_sys::nng_aio = &mut aio as _;
let rv = unsafe { nng_sys::nng_aio_alloc(aiop, Some(Aio::trampoline), callback_ptr as _) };
if rv != 0 && !aio.is_null() {
error!("NNG returned a non-null pointer from a failed function");
return Err(Error::Unknown(0));
}
validate_ptr(rv, aio)?;
inner.handle.store(aio, Ordering::Release);
inner.callback.store(callback_ptr, Ordering::Relaxed);
Ok(Self { inner })
}
pub fn set_timeout(&self, dur: Option<Duration>) -> Result<()>
{
let sleeping = State::Sleeping as usize;
let inactive = State::Inactive as usize;
self.inner
.state
.compare_exchange(inactive, sleeping, Ordering::Acquire, Ordering::Acquire)
.map_err(|_| Error::IncorrectState)?;
let ms = duration_to_nng(dur);
let aiop = self.inner.handle.load(Ordering::Relaxed);
unsafe {
nng_sys::nng_aio_set_timeout(aiop, ms);
}
self.inner.state.store(inactive, Ordering::Release);
Ok(())
}
pub fn sleep(&self, dur: Duration) -> Result<()>
{
let sleeping = State::Sleeping as usize;
let inactive = State::Inactive as usize;
self.inner
.state
.compare_exchange(inactive, sleeping, Ordering::AcqRel, Ordering::Acquire)
.map_err(|_| Error::IncorrectState)?;
let ms = duration_to_nng(Some(dur));
let aiop = self.inner.handle.load(Ordering::Relaxed);
unsafe {
nng_sys::nng_sleep_aio(ms, aiop);
}
Ok(())
}
pub fn wait(&self)
{
unsafe {
nng_sys::nng_aio_wait(self.inner.handle.load(Ordering::Relaxed));
}
}
pub fn cancel(&self)
{
unsafe {
nng_sys::nng_aio_cancel(self.inner.handle.load(Ordering::Relaxed));
}
}
pub(crate) fn send_socket(&self, socket: &Socket, msg: Message) -> SendResult<()>
{
let inactive = State::Inactive as usize;
let sending = State::Sending as usize;
if self
.inner
.state
.compare_exchange(inactive, sending, Ordering::AcqRel, Ordering::Acquire)
.is_err()
{
return Err((msg, Error::IncorrectState));
}
let aiop = self.inner.handle.load(Ordering::Relaxed);
unsafe {
nng_sys::nng_aio_set_msg(aiop, msg.into_ptr().as_ptr());
nng_sys::nng_send_aio(socket.handle(), aiop);
}
Ok(())
}
pub(crate) fn recv_socket(&self, socket: &Socket) -> Result<()>
{
let inactive = State::Inactive as usize;
let receiving = State::Receiving as usize;
self.inner
.state
.compare_exchange(inactive, receiving, Ordering::AcqRel, Ordering::Acquire)
.map_err(|_| Error::IncorrectState)?;
let aiop = self.inner.handle.load(Ordering::Relaxed);
unsafe {
nng_sys::nng_recv_aio(socket.handle(), aiop);
}
Ok(())
}
pub(crate) fn send_ctx(&self, ctx: &Context, msg: Message) -> SendResult<()>
{
let inactive = State::Inactive as usize;
let sending = State::Sending as usize;
if self
.inner
.state
.compare_exchange(inactive, sending, Ordering::AcqRel, Ordering::Acquire)
.is_err()
{
return Err((msg, Error::IncorrectState));
}
let aiop = self.inner.handle.load(Ordering::Relaxed);
unsafe {
nng_sys::nng_aio_set_msg(aiop, msg.into_ptr().as_ptr());
nng_sys::nng_ctx_send(ctx.handle(), aiop);
}
Ok(())
}
pub(crate) fn recv_ctx(&self, ctx: &Context) -> Result<()>
{
let inactive = State::Inactive as usize;
let receiving = State::Receiving as usize;
self.inner
.state
.compare_exchange(inactive, receiving, Ordering::AcqRel, Ordering::Acquire)
.map_err(|_| Error::IncorrectState)?;
let aiop = self.inner.handle.load(Ordering::Relaxed);
unsafe {
nng_sys::nng_ctx_recv(ctx.handle(), aiop);
}
Ok(())
}
extern "C" fn trampoline(arg: *mut c_void)
{
abort_unwind(|| unsafe {
let callback_ptr = arg as *const InnerCallback;
assert!(
!callback_ptr.is_null(),
"Null argument given to trampoline function - please open an issue"
);
(*callback_ptr)();
});
}
}
#[cfg(feature = "ffi-module")]
impl Aio
{
    /// Returns the raw `nng_aio` pointer owned by this handle.
    ///
    /// # Safety
    ///
    /// The pointer is stopped and freed when the last clone of this `Aio` is dropped, so
    /// it must not be used after that point, and the caller must never free it. Any
    /// operation started on it directly must leave the tracked state consistent (see
    /// [`Aio::set_state`]), because the completion callback uses that state to interpret
    /// the result.
    pub unsafe fn nng_aio(&self) -> *mut nng_sys::nng_aio
    {
        self.inner.handle.load(Ordering::Relaxed)
    }

    /// Returns the currently tracked operation state, loaded with `ordering`.
    pub fn state(&self, ordering: Ordering) -> State { self.inner.state.load(ordering).into() }

    /// Overwrites the tracked operation state with `state` using `ordering`.
    ///
    /// # Safety
    ///
    /// The completion callback reads this state to decide how to interpret the finished
    /// operation, for example whether to reclaim a message from the aio. The caller must
    /// store the state that matches the operation actually in flight. Storing `Inactive`
    /// while an operation is pending makes the callback panic when that operation
    /// completes. Storing the wrong kind of operation makes it mis-handle the result.
    pub unsafe fn set_state(&self, state: State, ordering: Ordering)
    {
        self.inner.state.store(state as usize, ordering);
    }
}
impl Hash for Aio
{
    /// Hashes the address of the shared `nng_aio`, which is the same pointer that
    /// `PartialEq` compares, so equal handles hash equally.
    fn hash<S: Hasher>(&self, hasher: &mut S)
    {
        let raw = self.inner.handle.load(Ordering::Relaxed);
        Hash::hash(&raw, hasher);
    }
}
impl PartialEq for Aio
{
    /// Two handles are equal when they wrap the same `nng_aio` pointer.
    fn eq(&self, other: &Self) -> bool
    {
        let lhs = self.inner.handle.load(Ordering::Relaxed);
        let rhs = other.inner.handle.load(Ordering::Relaxed);
        lhs == rhs
    }
}

impl Eq for Aio {}
/// State shared by every clone of an [`Aio`].
#[derive(Debug)]
struct Inner
{
/// Raw NNG aio. It stays null unless `Aio::new` allocated it successfully.
handle: AtomicPtr<nng_sys::nng_aio>,
/// `State` discriminant of the operation currently in flight.
state: AtomicUsize,
/// Boxed completion closure passed to NNG as the trampoline argument. It is reclaimed in `Drop`.
callback: AtomicPtr<InnerCallback>,
}
impl Drop for Inner
{
    /// Stops and frees the NNG aio, then reclaims the boxed completion callback.
    ///
    /// The callback is freed only after `nng_aio_free`, so the trampoline can no longer be
    /// invoked with a pointer to it.
    fn drop(&mut self)
    {
        // `&mut self` guarantees exclusive access, so the atomics are read without ordering.
        let aiop = *self.handle.get_mut();
        if aiop.is_null() {
            // `Aio::new` returned before storing a handle, so there is no aio to free and no
            // callback was stored.
            return;
        }

        // NOTE(review): if the last strong reference is released inside the user callback,
        // this runs on the callback's thread while that callback is still on the stack.
        // Confirm that NNG permits `nng_aio_stop` there and that freeing the running closure
        // below is sound.
        unsafe {
            nng_sys::nng_aio_stop(aiop);
            nng_sys::nng_aio_free(aiop);
        }

        let callback = *self.callback.get_mut();
        if !callback.is_null() {
            // SAFETY: `callback` was produced by `Box::into_raw` in `Aio::new`, and the only
            // aio that referenced it has just been freed. An explicit `drop` avoids a lint
            // suppression.
            drop(unsafe { Box::from_raw(callback) });
        }
    }
}
/// Outcome of a completed operation, delivered to the callback given to `Aio::new`.
#[derive(Clone, Debug)]
#[must_use]
pub enum AioResult
{
/// Result of a send. On failure, the unsent message is returned alongside the error.
Send(SendResult<()>),
/// Result of a receive, carrying the received message on success.
Recv(Result<Message>),
/// Result of a sleep started with `Aio::sleep`.
Sleep(Result<()>),
}
impl From<AioResult> for Result<Option<Message>>
{
fn from(aio_res: AioResult) -> Result<Option<Message>>
{
use self::AioResult::*;
match aio_res {
Recv(Ok(m)) => Ok(Some(m)),
Send(Ok(_)) | Sleep(Ok(_)) => Ok(None),
Send(Err((_, e))) | Recv(Err(e)) | Sleep(Err(e)) => Err(e),
}
}
}
mod state
{
    /// Which operation an `Aio` currently has in flight, stored as a `usize`.
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    #[repr(usize)]
    pub enum State
    {
        /// Nothing is in progress; a new operation may be started.
        Inactive,
        /// A send is in progress.
        Sending,
        /// A receive is in progress.
        Receiving,
        /// A sleep is in progress, or the timeout is briefly being updated.
        Sleeping,
    }

    #[cfg_attr(feature = "ffi-module", doc(hidden))]
    impl From<usize> for State
    {
        /// Decodes a value previously produced by `state as usize`.
        ///
        /// # Panics
        ///
        /// Panics via `unreachable!` if `atm` is not the discriminant of any variant.
        fn from(atm: usize) -> State
        {
            const VARIANTS: [State; 4] =
                [State::Inactive, State::Sending, State::Receiving, State::Sleeping];

            match VARIANTS.iter().copied().find(|&variant| variant as usize == atm) {
                Some(decoded) => decoded,
                None => unreachable!(),
            }
        }
    }
}
#[cfg(not(feature = "ffi-module"))]
use self::state::State;
#[cfg(feature = "ffi-module")]
pub use self::state::State;