use crate::{error::msg_from_errno, GroupSlice};
use libzmq_sys as sys;
use sys::errno;
use libc::size_t;
use log::error;
use serde::{Deserialize, Serialize};
use std::{
ffi::CStr,
fmt,
os::raw::c_void,
ptr, slice,
str::{self, Utf8Error},
};
/// A routing id attached to a message by libzmq (see [`Msg::routing_id`]).
///
/// `#[repr(transparent)]` guarantees that this newtype has exactly the layout
/// of `u32`, which is what makes reinterpreting a `&[u32]` as a
/// `&[RoutingId]` (and vice versa) sound.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[repr(transparent)]
pub struct RoutingId(pub u32);
/// An owned ØMQ message wrapping a raw `zmq_msg_t`.
///
/// The wrapped message is released with `zmq_msg_close` when the `Msg` is
/// dropped.
pub struct Msg {
    // The raw libzmq message handle; only ever accessed through pointers
    // handed to the `zmq_msg_*` functions.
    msg: sys::zmq_msg_t,
}
impl From<RoutingId> for u32 {
fn from(id: RoutingId) -> u32 {
id.0
}
}
impl From<u32> for RoutingId {
fn from(u: u32) -> Self {
Self(u)
}
}
impl Msg {
pub fn new() -> Self {
Self::default()
}
pub fn with_size(size: usize) -> Self {
unsafe {
Self::deferred_alloc(|msg| {
sys::zmq_msg_init_size(msg, size as size_t)
})
}
}
pub fn len(&self) -> usize {
unsafe { sys::zmq_msg_size(self.as_ptr()) }
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn to_str(&self) -> Result<&str, Utf8Error> {
str::from_utf8(self.as_bytes())
}
pub fn as_bytes(&self) -> &[u8] {
unsafe {
let ptr = &self.msg as *const _ as *mut _;
let data = sys::zmq_msg_data(ptr);
slice::from_raw_parts(data as *mut u8, self.len())
}
}
pub fn as_bytes_mut(&mut self) -> &mut [u8] {
unsafe {
let data = sys::zmq_msg_data(self.as_mut_ptr());
slice::from_raw_parts_mut(data as *mut u8, self.len())
}
}
pub fn routing_id(&self) -> Option<RoutingId> {
let rc = unsafe {
let ptr = self.as_ptr() as *mut _;
sys::zmq_msg_routing_id(ptr)
};
if rc == 0 {
None
} else {
Some(RoutingId(rc))
}
}
pub fn set_routing_id(&mut self, routing_id: RoutingId) {
let rc = unsafe {
sys::zmq_msg_set_routing_id(self.as_mut_ptr(), routing_id.0)
};
if rc != 0 {
let errno = unsafe { sys::zmq_errno() };
panic!(msg_from_errno(errno));
}
}
pub fn group(&self) -> Option<&GroupSlice> {
let mut_msg_ptr = self.as_ptr() as *mut _;
let char_ptr = unsafe { sys::zmq_msg_group(mut_msg_ptr) };
if char_ptr.is_null() {
None
} else {
let c_str = unsafe { CStr::from_ptr(char_ptr) };
Some(GroupSlice::from_c_str_unchecked(c_str))
}
}
pub fn set_group<G>(&mut self, group: G)
where
G: AsRef<GroupSlice>,
{
let group = group.as_ref();
let rc = unsafe {
sys::zmq_msg_set_group(self.as_mut_ptr(), group.as_c_str().as_ptr())
};
if rc == -1 {
let errno = unsafe { sys::zmq_errno() };
panic!(msg_from_errno(errno));
}
}
unsafe fn deferred_alloc<F>(f: F) -> Msg
where
F: FnOnce(&mut sys::zmq_msg_t) -> i32,
{
let mut msg = sys::zmq_msg_t::default();
let rc = f(&mut msg);
if rc == -1 {
panic!(msg_from_errno(sys::zmq_errno()));
}
Msg { msg }
}
pub(crate) fn as_mut_ptr(&mut self) -> *mut sys::zmq_msg_t {
&mut self.msg
}
pub(crate) fn as_ptr(&self) -> *const sys::zmq_msg_t {
&self.msg
}
pub(crate) fn has_more(&self) -> bool {
let rc = unsafe { sys::zmq_msg_more(self.as_ptr()) };
rc != 0
}
}
impl PartialEq for Msg {
fn eq(&self, other: &Self) -> bool {
ptr::eq(self.as_ptr(), other.as_ptr())
}
}
impl Eq for Msg {}
impl fmt::Debug for Msg {
    /// Formats the message as its byte slice.
    ///
    /// The formatter is forwarded as-is, so flags such as `{:#?}`
    /// (pretty-printing) apply to the byte list, unlike `write!(f, "{:?}", ..)`
    /// which silently discards them.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Debug::fmt(self.as_bytes(), f)
    }
}
impl Default for Msg {
    /// Builds an empty message via `zmq_msg_init`; this is what
    /// [`Msg::new`] returns.
    fn default() -> Self {
        let init = |raw: &mut sys::zmq_msg_t| unsafe { sys::zmq_msg_init(raw) };
        unsafe { Self::deferred_alloc(init) }
    }
}
impl Clone for Msg {
fn clone(&self) -> Self {
let mut msg = Msg::new();
let rc = unsafe {
let ptr = self.as_ptr() as *mut _;
sys::zmq_msg_copy(msg.as_mut_ptr(), ptr)
};
if rc != 0 {
let errno = unsafe { sys::zmq_errno() };
match errno {
errno::EFAULT => panic!("invalid message"),
_ => panic!(msg_from_errno(errno)),
}
}
msg
}
}
impl Drop for Msg {
fn drop(&mut self) {
let rc = unsafe { sys::zmq_msg_close(self.as_mut_ptr()) };
if rc != 0 {
let errno = unsafe { sys::zmq_errno() };
error!("error while dropping message: {}", msg_from_errno(errno));
}
}
}
impl From<Box<[u8]>> for Msg {
fn from(data: Box<[u8]>) -> Msg {
unsafe extern "C" fn drop_zmq_msg_t(
data: *mut c_void,
_hint: *mut c_void,
) {
Box::from_raw(data as *mut u8);
}
if data.is_empty() {
return Msg::new();
}
let size = data.len() as size_t;
let data = Box::into_raw(data);
unsafe {
Self::deferred_alloc(|msg| {
sys::zmq_msg_init_data(
msg,
data as *mut c_void,
size,
Some(drop_zmq_msg_t),
ptr::null_mut(), )
})
}
}
}
impl From<&[u8]> for Msg {
    /// Copies `slice` into a newly allocated message of the same length.
    ///
    /// # Panics
    ///
    /// Panics if `zmq_msg_init_size` fails.
    fn from(slice: &[u8]) -> Self {
        let mut msg = Msg::with_size(slice.len());
        // The message buffer was sized to `slice.len()`, so the lengths match
        // and this safe copy replaces the former raw `copy_nonoverlapping`.
        msg.as_bytes_mut().copy_from_slice(slice);
        msg
    }
}
/// Implements `From<[u8; N]>` for `Msg` for every listed length `N`.
///
/// Each array is moved onto the heap as a `Box<[u8]>` and handed to the
/// `From<Box<[u8]>>` conversion.
macro_rules! array_impls {
    ($($len:expr)+) => {
        $(
            impl From<[u8; $len]> for Msg {
                fn from(bytes: [u8; $len]) -> Self {
                    Msg::from(Box::new(bytes) as Box<[u8]>)
                }
            }
        )+
    }
}

array_impls! {
    0 1 2 3 4 5 6 7 8 9
    10 11 12 13 14 15 16 17 18 19
    20 21 22 23 24 25 26 27 28 29
    30 31 32
}
impl From<Vec<u8>> for Msg {
fn from(bytes: Vec<u8>) -> Self {
Msg::from(bytes.into_boxed_slice())
}
}
impl<'a> From<&'a str> for Msg {
fn from(text: &str) -> Self {
Msg::from(text.as_bytes())
}
}
impl From<String> for Msg {
fn from(text: String) -> Self {
Msg::from(text.into_bytes())
}
}
impl<'a, T> From<&'a T> for Msg
where
    T: Into<Msg> + Clone,
{
    /// Clones the referenced value and converts the clone into a `Msg`.
    fn from(v: &'a T) -> Self {
        let owned: T = T::clone(v);
        owned.into()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem;

    /// Reinterpreting a `&[u32]` as a `&[RoutingId]` must preserve each value.
    #[test]
    fn test_cast_routing_id_slice() {
        assert_eq!(mem::size_of::<u32>(), mem::size_of::<RoutingId>());

        let raw_ids: &[u32] = &[1, 2, 3, 4];
        // SAFETY: sound only if `RoutingId` has the same layout as `u32`; the
        // size assertion above is a partial guard for that assumption.
        let typed_ids: &[RoutingId] = unsafe {
            slice::from_raw_parts(raw_ids.as_ptr() as *const RoutingId, raw_ids.len())
        };

        raw_ids
            .iter()
            .zip(typed_ids.iter())
            .for_each(|(&raw, &typed)| assert_eq!(raw, typed.0));
    }
}