use crate::{
core::{GetRawSocket, Period, RawSocket},
error::{msg_from_errno, Error, ErrorKind},
old::OldSocket,
socket::*,
};
use libzmq_sys as sys;
use sys::errno;
use bitflags::bitflags;
use std::os::{
raw::{c_long, c_short, c_void},
unix::io::{AsRawFd, RawFd},
};
bitflags! {
    /// Conditions a pollable can be registered for.
    ///
    /// The bits are passed straight through to libzmq as the `events` mask of
    /// `zmq_poller_add` / `zmq_poller_add_fd` / `zmq_poller_modify*`, and
    /// mirror `ZMQ_POLLIN` / `ZMQ_POLLOUT`.
    pub struct Trigger: c_short {
        /// No condition at all.
        const EMPTY = 0i16;
        /// Request readability notifications (`ZMQ_POLLIN`).
        const READABLE = sys::ZMQ_POLLIN as c_short;
        /// Request writability notifications (`ZMQ_POLLOUT`).
        const WRITABLE = sys::ZMQ_POLLOUT as c_short;
    }
}
/// Shorthand for [`Trigger::EMPTY`].
pub const EMPTY: Trigger = Trigger::EMPTY;
/// Shorthand for [`Trigger::READABLE`].
pub const READABLE: Trigger = Trigger::READABLE;
/// Shorthand for [`Trigger::WRITABLE`].
pub const WRITABLE: Trigger = Trigger::WRITABLE;
bitflags! {
    /// Conditions reported back by libzmq for a pollable after a poll.
    ///
    /// Decoded from the `events` field of each `zmq_poller_event_t` filled in
    /// by `zmq_poller_wait_all`.
    struct Cause: c_short {
        /// `ZMQ_POLLIN`: the pollable is readable.
        const READABLE = sys::ZMQ_POLLIN as c_short;
        /// `ZMQ_POLLOUT`: the pollable is writable.
        const WRITABLE = sys::ZMQ_POLLOUT as c_short;
        /// `ZMQ_POLLERR`: an error condition was reported.
        const ERROR = sys::ZMQ_POLLERR as c_short;
        /// `ZMQ_POLLPRI`: a priority condition was reported.
        const PRIORITY = sys::ZMQ_POLLPRI as c_short;
    }
}
/// Caller-chosen identifier attached to a pollable when it is registered.
///
/// The value is stored as the pollable's libzmq user data and is handed back
/// in every [`Event`] produced for it, so callers can tell which pollable
/// became ready.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct PollId(pub usize);

impl From<usize> for PollId {
    /// Wraps a raw integer as an id.
    fn from(raw: usize) -> Self {
        Self(raw)
    }
}

impl From<PollId> for usize {
    /// Unwraps an id back into its raw integer.
    fn from(id: PollId) -> Self {
        let PollId(raw) = id;
        raw
    }
}
/// Something a [`Poller`] can watch: a ØMQ socket or a raw file descriptor.
pub enum Pollable<'a> {
    /// A ØMQ socket, borrowed for the lifetime `'a`.
    Socket(&'a RawSocket),
    /// A raw OS file descriptor.
    Fd(RawFd),
}

/// Any type that exposes a raw file descriptor is polled through that fd.
impl<'a, T: AsRawFd> From<&'a T> for Pollable<'a> {
    fn from(source: &'a T) -> Self {
        let fd = source.as_raw_fd();
        Pollable::Fd(fd)
    }
}
impl<'a> From<&'a Client> for Pollable<'a> {
fn from(client: &'a Client) -> Self {
Pollable::Socket(client.raw_socket())
}
}
impl<'a> From<&'a Server> for Pollable<'a> {
fn from(server: &'a Server) -> Self {
Pollable::Socket(server.raw_socket())
}
}
impl<'a> From<&'a Radio> for Pollable<'a> {
fn from(radio: &'a Radio) -> Self {
Pollable::Socket(radio.raw_socket())
}
}
impl<'a> From<&'a Dish> for Pollable<'a> {
fn from(dish: &'a Dish) -> Self {
Pollable::Socket(dish.raw_socket())
}
}
impl<'a> From<&'a Gather> for Pollable<'a> {
fn from(gather: &'a Gather) -> Self {
Pollable::Socket(gather.raw_socket())
}
}
impl<'a> From<&'a Scatter> for Pollable<'a> {
fn from(scatter: &'a Scatter) -> Self {
Pollable::Socket(scatter.raw_socket())
}
}
#[doc(hidden)]
impl<'a> From<&'a OldSocket> for Pollable<'a> {
fn from(old: &'a OldSocket) -> Self {
Pollable::Socket(old.raw_socket())
}
}
/// Borrowing iterator over the [`Event`]s stored in an [`Events`] buffer.
///
/// Created by [`Events::iter`]. Raw slots whose `events` mask is zero carry
/// no readiness information and are skipped.
#[derive(Clone, Debug)]
pub struct Iter<'a> {
    inner: &'a Events,
    pos: usize,
}

impl<'a> Iterator for Iter<'a> {
    type Item = Event;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        // Loop instead of recursing so a long run of empty slots cannot grow
        // the stack.
        while let Some(raw) = self.inner.inner.get(self.pos) {
            self.pos += 1;
            if raw.events != 0 {
                return Some(Event {
                    // The id was stored as the user-data pointer at `add`.
                    id: PollId(raw.user_data as usize),
                    // Drop any bit we do not model instead of panicking.
                    cause: Cause::from_bits_truncate(raw.events),
                });
            }
        }
        None
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        // Empty slots are skipped, so only an upper bound is known, and it
        // shrinks as the iterator advances.
        let remaining = self.inner.inner.len().saturating_sub(self.pos);
        (0, Some(remaining))
    }
}
/// Owning iterator over the [`Event`]s of an [`Events`] buffer.
///
/// Created by `Events::into_iter`. Like the borrowing [`Iter`], it skips raw
/// slots whose `events` mask is zero, so it never yields placeholder events.
#[derive(Debug)]
pub struct IntoIter {
    inner: Events,
    pos: usize,
}

impl Iterator for IntoIter {
    type Item = Event;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        while let Some(raw) = self.inner.inner.get(self.pos) {
            self.pos += 1;
            // A zeroed slot was never filled by libzmq; yielding it would
            // produce a bogus event with `PollId(0)` and no cause.
            if raw.events != 0 {
                return Some(Event {
                    id: PollId(raw.user_data as usize),
                    cause: Cause::from_bits_truncate(raw.events),
                });
            }
        }
        None
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        // Empty slots are skipped, so only an upper bound is known.
        let remaining = self.inner.inner.len().saturating_sub(self.pos);
        (0, Some(remaining))
    }
}
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub struct Event {
cause: Cause,
id: PollId,
}
impl Event {
pub fn is_readable(&self) -> bool {
self.cause.contains(Cause::READABLE)
}
pub fn is_writable(&self) -> bool {
self.cause.contains(Cause::WRITABLE)
}
pub fn is_error(&self) -> bool {
self.cause.contains(Cause::ERROR)
}
pub fn is_priority(&self) -> bool {
self.cause.contains(Cause::PRIORITY)
}
pub fn id(&self) -> PollId {
self.id
}
}
/// Reusable buffer that [`Poller::poll`] and [`Poller::try_poll`] fill with
/// the raw events of a single poll.
#[derive(Default, Clone, Eq, PartialEq, Hash, Debug)]
pub struct Events {
    inner: Vec<sys::zmq_poller_event_t>,
}

impl Events {
    /// Creates an empty buffer.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an empty buffer with room for `capacity` raw events.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            inner: Vec::with_capacity(capacity),
        }
    }

    /// Number of raw events the buffer can hold without reallocating.
    pub fn capacity(&self) -> usize {
        self.inner.capacity()
    }

    /// Returns `true` if iterating the buffer would yield no [`Event`].
    ///
    /// Raw slots whose `events` mask is zero carry no readiness information
    /// and are skipped by [`Events::iter`], so they do not count here either.
    pub fn is_empty(&self) -> bool {
        self.inner.iter().all(|event| event.events == 0)
    }

    /// Iterates over the events without consuming the buffer.
    pub fn iter(&self) -> Iter {
        Iter {
            inner: self,
            pos: 0,
        }
    }

    /// Removes all events, keeping the allocated capacity.
    pub fn clear(&mut self) {
        self.inner.clear();
    }
}
impl<'a> IntoIterator for &'a Events {
type Item = Event;
type IntoIter = Iter<'a>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
impl IntoIterator for Events {
type Item = Event;
type IntoIter = IntoIter;
fn into_iter(self) -> Self::IntoIter {
IntoIter {
inner: self,
pos: 0,
}
}
}
/// Waits for readiness events on a set of ØMQ sockets and raw file
/// descriptors.
///
/// Thin wrapper over a libzmq `zmq_poller_*` handle.
#[derive(Eq, PartialEq, Debug)]
pub struct Poller {
    // Handle returned by `zmq_poller_new`.
    poller: *mut c_void,
    // Number of pollables currently registered; sizes the buffer in `wait`.
    count: usize,
}

impl Poller {
    /// Creates an empty poller. Equivalent to `Poller::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `pollable` under `id`, watching for the conditions in
    /// `trigger`.
    ///
    /// `id` is reported back in every [`Event`] produced for this pollable.
    ///
    /// # Errors
    /// `InvalidInput` if the pollable is already registered, or if an fd is
    /// the retired fd.
    ///
    /// # Panics
    /// On an invalid socket or any other unexpected libzmq error.
    pub fn add<'a, P>(
        &mut self,
        pollable: P,
        id: PollId,
        trigger: Trigger,
    ) -> Result<(), Error>
    where
        P: Into<Pollable<'a>>,
    {
        match pollable.into() {
            Pollable::Socket(raw_socket) => {
                self.add_raw_socket(raw_socket, id, trigger)
            }
            Pollable::Fd(fd) => self.add_fd(fd, id, trigger),
        }
    }

    // Registers a raw fd via `zmq_poller_add_fd`, storing `id` as user data.
    fn add_fd(
        &mut self,
        fd: RawFd,
        id: PollId,
        trigger: Trigger,
    ) -> Result<(), Error> {
        let user_data: usize = id.into();
        let user_data = user_data as *mut usize as *mut c_void;
        let rc = unsafe {
            sys::zmq_poller_add_fd(self.poller, fd, user_data, trigger.bits())
        };
        if rc == -1 {
            let errno = unsafe { sys::zmq_errno() };
            let err = match errno {
                errno::EINVAL => {
                    Error::new(ErrorKind::InvalidInput("cannot add fd twice"))
                }
                errno::EBADF => Error::new(ErrorKind::InvalidInput(
                    "specified fd was the retired fd",
                )),
                _ => panic!(msg_from_errno(errno)),
            };
            Err(err)
        } else {
            self.count += 1;
            Ok(())
        }
    }

    // Registers a ØMQ socket via `zmq_poller_add`, storing `id` as user data.
    fn add_raw_socket(
        &mut self,
        raw_socket: &RawSocket,
        id: PollId,
        trigger: Trigger,
    ) -> Result<(), Error> {
        let socket_mut_ptr = raw_socket.as_mut_ptr();
        let user_data: usize = id.into();
        let user_data = user_data as *mut usize as *mut c_void;
        let rc = unsafe {
            sys::zmq_poller_add(
                self.poller,
                socket_mut_ptr,
                user_data,
                trigger.bits(),
            )
        };
        if rc == -1 {
            let errno = unsafe { sys::zmq_errno() };
            let err = match errno {
                errno::EINVAL => Error::new(ErrorKind::InvalidInput(
                    "cannot add socket twice",
                )),
                errno::ENOTSOCK => panic!("invalid socket"),
                _ => panic!(msg_from_errno(errno)),
            };
            Err(err)
        } else {
            self.count += 1;
            Ok(())
        }
    }

    /// Unregisters `pollable`.
    ///
    /// # Errors
    /// `InvalidInput` if the pollable is not registered, or if an fd is the
    /// retired fd.
    ///
    /// # Panics
    /// On an invalid socket or any other unexpected libzmq error.
    pub fn remove<'a, P>(&mut self, pollable: P) -> Result<(), Error>
    where
        P: Into<Pollable<'a>>,
    {
        match pollable.into() {
            Pollable::Socket(raw_socket) => self.remove_raw_socket(raw_socket),
            Pollable::Fd(fd) => self.remove_fd(fd),
        }
    }

    // Unregisters a raw fd via `zmq_poller_remove_fd`.
    fn remove_fd(&mut self, fd: RawFd) -> Result<(), Error> {
        let rc = unsafe { sys::zmq_poller_remove_fd(self.poller, fd) };
        if rc == -1 {
            let errno = unsafe { sys::zmq_errno() };
            let err = match errno {
                errno::EINVAL => Error::new(ErrorKind::InvalidInput(
                    "cannot remove absent fd",
                )),
                errno::EBADF => Error::new(ErrorKind::InvalidInput(
                    "specified fd was the retired fd",
                )),
                _ => panic!(msg_from_errno(errno)),
            };
            Err(err)
        } else {
            self.count -= 1;
            Ok(())
        }
    }

    // Unregisters a ØMQ socket via `zmq_poller_remove`.
    fn remove_raw_socket(
        &mut self,
        raw_socket: &RawSocket,
    ) -> Result<(), Error> {
        let socket_mut_ptr = raw_socket.as_mut_ptr();
        let rc = unsafe { sys::zmq_poller_remove(self.poller, socket_mut_ptr) };
        if rc == -1 {
            let errno = unsafe { sys::zmq_errno() };
            let err = match errno {
                errno::ENOTSOCK => panic!("invalid socket"),
                errno::EINVAL => Error::new(ErrorKind::InvalidInput(
                    "cannot remove absent socket",
                )),
                _ => panic!(msg_from_errno(errno)),
            };
            Err(err)
        } else {
            self.count -= 1;
            Ok(())
        }
    }

    /// Replaces the conditions watched for an already registered `pollable`.
    ///
    /// # Errors
    /// `InvalidInput` if the pollable is not registered, or if an fd is the
    /// retired fd.
    ///
    /// # Panics
    /// On an invalid socket or any other unexpected libzmq error.
    pub fn modify<'a, P>(
        &mut self,
        pollable: P,
        trigger: Trigger,
    ) -> Result<(), Error>
    where
        P: Into<Pollable<'a>>,
    {
        match pollable.into() {
            Pollable::Socket(raw_socket) => {
                self.modify_raw_socket(raw_socket, trigger)
            }
            Pollable::Fd(fd) => self.modify_fd(fd, trigger),
        }
    }

    // Updates the trigger of a raw fd via `zmq_poller_modify_fd`.
    fn modify_fd(&mut self, fd: RawFd, trigger: Trigger) -> Result<(), Error> {
        let rc = unsafe {
            sys::zmq_poller_modify_fd(self.poller, fd, trigger.bits())
        };
        if rc == -1 {
            let errno = unsafe { sys::zmq_errno() };
            let err = match errno {
                errno::EINVAL => Error::new(ErrorKind::InvalidInput(
                    "cannot modify absent fd",
                )),
                errno::EBADF => Error::new(ErrorKind::InvalidInput(
                    "specified fd is the retired fd",
                )),
                _ => panic!(msg_from_errno(errno)),
            };
            Err(err)
        } else {
            Ok(())
        }
    }

    // Updates the trigger of a ØMQ socket via `zmq_poller_modify`.
    fn modify_raw_socket(
        &mut self,
        raw_socket: &RawSocket,
        trigger: Trigger,
    ) -> Result<(), Error> {
        let socket_mut_ptr = raw_socket.as_mut_ptr();
        let rc = unsafe {
            sys::zmq_poller_modify(self.poller, socket_mut_ptr, trigger.bits())
        };
        if rc == -1 {
            let errno = unsafe { sys::zmq_errno() };
            let err = match errno {
                errno::EINVAL => Error::new(ErrorKind::InvalidInput(
                    "cannot modify absent socket",
                )),
                errno::ENOTSOCK => panic!("invalid socket"),
                _ => panic!(msg_from_errno(errno)),
            };
            Err(err)
        } else {
            Ok(())
        }
    }

    // Runs `zmq_poller_wait_all` with `timeout` milliseconds (`poll` passes
    // `-1` for an infinite period, `try_poll` passes `0`) and leaves exactly
    // the signalled events in `events`.
    fn wait(&mut self, events: &mut Events, timeout: i64) -> Result<(), Error> {
        // One slot per registered pollable is enough to receive every event.
        events.clear();
        events
            .inner
            .resize(self.count, sys::zmq_poller_event_t::default());
        let rc = unsafe {
            sys::zmq_poller_wait_all(
                self.poller,
                events.inner.as_mut_ptr(),
                events.inner.len() as i32,
                timeout as c_long,
            )
        };
        if rc == -1 {
            let errno = unsafe { sys::zmq_errno() };
            // Nothing was signalled; drop the zeroed placeholder slots so the
            // buffer reads as empty.
            events.clear();
            let err = match errno {
                errno::EINVAL => panic!("invalid poller"),
                errno::ETERM => Error::new(ErrorKind::InvalidCtx),
                errno::EINTR => Error::new(ErrorKind::Interrupted),
                errno::EAGAIN => Error::new(ErrorKind::WouldBlock),
                _ => panic!(msg_from_errno(errno)),
            };
            Err(err)
        } else {
            // libzmq fills only the first `rc` slots; discard the untouched
            // zeroed tail so it is never reported as an event.
            events.inner.truncate(rc as usize);
            Ok(())
        }
    }

    /// Polls without blocking, storing any signalled events in `events`.
    ///
    /// # Errors
    /// `WouldBlock` (EAGAIN), `Interrupted` (EINTR) or `InvalidCtx` (ETERM)
    /// as reported by libzmq.
    pub fn try_poll(&mut self, events: &mut Events) -> Result<(), Error> {
        self.wait(events, 0)
    }

    /// Blocks until an event is signalled or `timeout` elapses, storing the
    /// signalled events in `events`.
    ///
    /// # Errors
    /// `InvalidInput` if a finite timeout in milliseconds does not fit in a C
    /// `long`; otherwise the same errors as [`Poller::try_poll`].
    pub fn poll(
        &mut self,
        events: &mut Events,
        timeout: Period,
    ) -> Result<(), Error> {
        match timeout {
            Period::Finite(duration) => {
                let ms = duration.as_millis();
                // `wait` hands the timeout to libzmq as a C `long`, which is
                // only 32 bits on some targets; a larger value would wrap in
                // that cast, possibly to a negative (unbounded) timeout.
                if ms > c_long::max_value() as u128 {
                    return Err(Error::new(ErrorKind::InvalidInput(
                        "ms in timeout must not exceed c_long::MAX",
                    )));
                }
                self.wait(events, ms as i64)
            }
            Period::Infinite => self.wait(events, -1),
        }
    }
}
impl Default for Poller {
    /// Allocates a fresh libzmq poller with no registered pollables.
    ///
    /// # Panics
    /// If `zmq_poller_new` returns a null handle.
    fn default() -> Self {
        let handle = unsafe { sys::zmq_poller_new() };
        if !handle.is_null() {
            return Self {
                poller: handle,
                count: 0,
            };
        }
        let code = unsafe { sys::zmq_errno() };
        panic!(msg_from_errno(code))
    }
}
impl Drop for Poller {
    /// Releases the libzmq poller handle via `zmq_poller_destroy`.
    ///
    /// # Panics
    /// If `zmq_poller_destroy` reports an error.
    fn drop(&mut self) {
        let rc = unsafe { sys::zmq_poller_destroy(&mut self.poller) };
        if rc == 0 {
            return;
        }
        let code = unsafe { sys::zmq_errno() };
        if code == errno::EFAULT {
            panic!("invalid poller");
        }
        panic!(msg_from_errno(code));
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// The public trigger constants must carry libzmq's own poll bits.
    #[test]
    fn test_trigger() {
        let expected = [
            (READABLE, sys::ZMQ_POLLIN as c_short),
            (WRITABLE, sys::ZMQ_POLLOUT as c_short),
        ];
        for &(trigger, bits) in expected.iter() {
            assert_eq!(trigger.bits(), bits);
        }
    }
}