use crate::markers;
use bytes::BytesMut;
use failure::Fallible;
use lazy_static::lazy_static;
use std::collections::HashMap;
use std::fmt;
use std::sync::{Arc, Mutex};
use zmq;
/// Shared, mutex-guarded handle to a ZeroMQ socket, as stored in socket caches.
pub type ZMQSocketArc = Arc<Mutex<zmq::Socket>>;

/// Conversion between a message type and the list of ZeroMQ frames it travels as.
pub trait ZMQCodec {
    /// Serialises `self` into an ordered list of ZeroMQ message frames.
    fn zmq_encode(&self) -> Fallible<Vec<zmq::Message>>;

    /// Rebuilds a boxed value from an ordered list of ZeroMQ message frames
    /// (presumably frames produced by `zmq_encode` — implementors define the format).
    fn zmq_decode(from: Vec<zmq::Message>) -> Fallible<Box<Self>>;
}
/// A message kept as its raw ZeroMQ frames, one `BytesMut` buffer per frame.
#[derive(Clone, Debug)]
pub struct RawMessage {
    /// Frame payloads in wire order.
    pub raw_parts: Vec<BytesMut>,
}

impl markers::ZMQMessageMarker for RawMessage {}

impl ZMQCodec for RawMessage {
    /// Copies each stored frame into a new `zmq::Message`, keeping frame order.
    fn zmq_encode(&self) -> Fallible<Vec<zmq::Message>> {
        let frames = self
            .raw_parts
            .iter()
            .map(|frame| zmq::Message::from(&frame[..]))
            .collect::<Vec<zmq::Message>>();
        Ok(frames)
    }

    /// Copies each incoming frame into its own `BytesMut`, keeping frame order.
    fn zmq_decode(from: Vec<zmq::Message>) -> Fallible<Box<Self>> {
        let raw_parts = from
            .iter()
            .map(|frame| BytesMut::from(&**frame))
            .collect::<Vec<BytesMut>>();
        Ok(Box::new(Self { raw_parts }))
    }
}
/// Implements `TryFrom<T>` and `TryFrom<&T>` for `$totyp`, once per listed `$fromtyp`.
///
/// The conversion works by round-tripping through the ZeroMQ wire format. The
/// source value is encoded with `zmq_encode`, and the resulting frames are decoded
/// with `$totyp::zmq_decode`. Any encode or decode error is returned as
/// `failure::Error`.
///
/// Requirements at the invocation site:
/// - Both types must implement `ZMQCodec`, and that trait must be in scope,
///   because the expansion calls `zmq_encode` as a method.
/// - The `failure` crate must be resolvable.
///
/// `TryFrom` and `Result` are referenced by absolute `::std` paths, so callers do
/// not need to import `TryFrom` (it is not in the 2018 prelude). A local `Result`
/// alias cannot break the expansion either.
///
/// Usage: `naive_tryfrom!(TargetType, [SourceA, SourceB]);`
#[macro_export]
macro_rules! naive_tryfrom (
    ( $totyp: ident, [ $( $fromtyp: ident ),* ]) => {
        $(
            impl ::std::convert::TryFrom<&$fromtyp> for $totyp {
                type Error = failure::Error;
                fn try_from(rmsg: &$fromtyp) -> ::std::result::Result<Self, Self::Error> {
                    let msgparts = rmsg.zmq_encode()?;
                    let msg = *$totyp::zmq_decode(msgparts)?;
                    Ok(msg)
                }
            }
            impl ::std::convert::TryFrom<$fromtyp> for $totyp {
                type Error = failure::Error;
                fn try_from(rmsg: $fromtyp) -> ::std::result::Result<Self, Self::Error> {
                    // Encoding only needs a borrow, so reuse the by-reference conversion.
                    <Self as ::std::convert::TryFrom<&$fromtyp>>::try_from(&rmsg)
                }
            }
        )*
    }
);
/// The ZeroMQ socket patterns that socket handlers can open.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ZMQSocketType {
    /// Publisher socket.
    PUB,
    /// Subscriber socket.
    SUB,
    /// Request socket.
    REQ,
    /// Reply socket.
    REP,
}

/// Identifies one socket: its pattern plus the endpoint URIs it binds or connects to.
///
/// Implements `Eq` + `Hash` so it can serve as a cache key.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct ZMQSocketDescription {
    /// Endpoint URIs, e.g. `ipc://...` or `tcp://...`.
    pub socketuris: Vec<String>,
    /// Socket pattern to create.
    pub sockettype: ZMQSocketType,
}
pub type SocketHandlerArc = Arc<Mutex<dyn SocketHandler>>;
pub trait SocketHandler {
fn _get_cached_socket(&self, desc: &ZMQSocketDescription) -> Option<ZMQSocketArc>;
fn _set_cached_socket(
&mut self,
desc: &ZMQSocketDescription,
sock: zmq::Socket,
) -> Fallible<()>;
fn get_open_sockets(&self) -> Vec<ZMQSocketArc>;
fn close_all_sockets(&self) -> Fallible<()> {
Ok(())
}
fn get_socket(&mut self, desc: &ZMQSocketDescription) -> Fallible<ZMQSocketArc> {
match self._get_cached_socket(desc) {
Some(socket) => {
log::debug!("Returning socket from cache");
return Ok(socket);
}
None => {}
}
let zctx = zmq::Context::new();
let rawsocket = match desc.sockettype {
ZMQSocketType::PUB => zctx.socket(zmq::PUB)?,
ZMQSocketType::SUB => zctx.socket(zmq::SUB)?,
ZMQSocketType::REQ => zctx.socket(zmq::REQ)?,
ZMQSocketType::REP => zctx.socket(zmq::REP)?,
};
match desc.sockettype {
ZMQSocketType::PUB | ZMQSocketType::REP => {
for uri in desc.socketuris.iter() {
log::debug!("Binding to {}", uri);
rawsocket.bind(uri.as_str())?;
}
}
ZMQSocketType::SUB | ZMQSocketType::REQ => {
for uri in desc.socketuris.iter() {
log::debug!("connecting to {}", uri);
rawsocket.connect(uri.as_str())?;
}
}
}
self._set_cached_socket(desc, rawsocket)?;
Ok(self._get_cached_socket(desc).unwrap())
}
}
/// Default `SocketHandler` implementation that caches sockets in a map keyed by description.
#[derive(Default)]
pub struct BaseSocketHandler {
    /// One shared socket handle per distinct description.
    sockets_by_desc: HashMap<ZMQSocketDescription, ZMQSocketArc>,
}

impl fmt::Debug for BaseSocketHandler {
    /// Prints the struct name with a placeholder; cached sockets are not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let placeholder = "<hidden>";
        f.debug_struct("BaseSocketHandler")
            .field("sockets_by_desc", &placeholder)
            .finish()
    }
}
/// Shared, lockable handle to a [`BaseSocketHandler`].
type BaseSocketHandlerArc = Arc<Mutex<BaseSocketHandler>>;

lazy_static! {
    /// Process-wide handler returned by `BaseSocketHandler::instance`.
    static ref BASESOCKETHANDLER_SINGLETON: BaseSocketHandlerArc =
        Arc::new(Mutex::new(BaseSocketHandler::new()));
}

impl BaseSocketHandler {
    /// Returns a handle to the process-wide singleton handler.
    ///
    /// Every call clones the same `Arc`, so all callers share one socket cache.
    pub fn instance() -> BaseSocketHandlerArc {
        Arc::clone(&*BASESOCKETHANDLER_SINGLETON)
    }

    /// Creates a standalone handler with an empty cache, independent of the singleton.
    pub fn new() -> BaseSocketHandler {
        Self::default()
    }
}
impl SocketHandler for BaseSocketHandler {
    /// Returns a new handle to the cached socket for `desc`, if present.
    fn _get_cached_socket(&self, desc: &ZMQSocketDescription) -> Option<ZMQSocketArc> {
        // Single lookup; the original `contains_key` + index hashed `desc` twice.
        self.sockets_by_desc.get(desc).cloned()
    }

    /// Wraps `sock` in an `Arc<Mutex<_>>` and caches it under `desc`.
    ///
    /// # Errors
    /// Returns an error if a socket is already cached for `desc`. In that case
    /// `sock` is dropped and the existing entry is left untouched.
    fn _set_cached_socket(
        &mut self,
        desc: &ZMQSocketDescription,
        sock: zmq::Socket,
    ) -> Fallible<()> {
        // The entry API does the occupancy check and the insert with one lookup,
        // so the "overwrote an existing key" case cannot happen.
        match self.sockets_by_desc.entry(desc.clone()) {
            std::collections::hash_map::Entry::Occupied(_) => {
                Err(failure::err_msg("Described socket is already in cache"))
            }
            std::collections::hash_map::Entry::Vacant(slot) => {
                slot.insert(Arc::new(Mutex::new(sock)));
                Ok(())
            }
        }
    }

    /// Returns a handle to every cached socket, in the map's (unspecified) iteration order.
    fn get_open_sockets(&self) -> Vec<ZMQSocketArc> {
        self.sockets_by_desc.values().cloned().collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::env::temp_dir;

    /// Both `instance()` handles must refer to one handler, so they share a socket cache.
    #[test]
    fn test_sockethandler_singleton() {
        let sh1 = BaseSocketHandler::instance();
        let sh2 = BaseSocketHandler::instance();
        log::debug!("sh1 is {:?}, sh2 is {:?}", sh1, sh2);
        assert!(Arc::ptr_eq(&sh1, &sh2));

        let mut tmppath1 = temp_dir();
        tmppath1.push("d709d495-f587-4f9f-9566-f3c66721d48f_pub.sock");
        let sockpath1 = "ipc://".to_string() + &tmppath1.to_string_lossy();
        let desc1 = ZMQSocketDescription {
            socketuris: vec![sockpath1],
            sockettype: ZMQSocketType::PUB,
        };
        let sock1 = sh1.lock().unwrap().get_socket(&desc1).unwrap();
        let sock2 = sh2.lock().unwrap().get_socket(&desc1).unwrap();
        // The second request for the same description must be served from the cache.
        assert!(Arc::ptr_eq(&sock1, &sock2));

        let mut tmppath2 = temp_dir();
        tmppath2.push("7a4f4f4f-0016-420e-ae88-38e0c581ea29_pub.sock");
        let sockpath2 = "ipc://".to_string() + &tmppath2.to_string_lossy();
        let desc2 = ZMQSocketDescription {
            socketuris: vec![sockpath2],
            sockettype: ZMQSocketType::PUB,
        };
        let _sock3 = sh2.lock().unwrap().get_socket(&desc2).unwrap();

        // Read through both handles: a socket added via sh2 must be visible via sh1.
        let svec1 = sh1.lock().unwrap().get_open_sockets();
        let svec2 = sh2.lock().unwrap().get_open_sockets();
        assert_eq!(svec1.len(), 2);
        assert_eq!(svec2.len(), svec1.len());
    }

    #[test]
    fn test_rawmessage_encode() {
        let mut raw_parts: Vec<BytesMut> = Vec::with_capacity(3);
        raw_parts.push(BytesMut::from(String::from("hellotopic").as_bytes()));
        raw_parts.push(BytesMut::from(String::from("datapart1").as_bytes()));
        raw_parts.push(BytesMut::from(String::from("datapart2").as_bytes()));
        let msg = RawMessage { raw_parts };
        log::debug!("msg is {:?}", msg);
        let msgparts = msg.zmq_encode().unwrap();
        assert_eq!(msgparts[0].as_str().unwrap(), String::from("hellotopic"));
        assert_eq!(msgparts[1].as_str().unwrap(), String::from("datapart1"));
        assert_eq!(msgparts[2].as_str().unwrap(), String::from("datapart2"));
    }

    #[test]
    fn test_rawmessage_decode() {
        let mut msgparts: Vec<zmq::Message> = Vec::with_capacity(3);
        msgparts.push(zmq::Message::from(String::from("hellotopic").as_bytes()));
        msgparts.push(zmq::Message::from(String::from("datapart1").as_bytes()));
        msgparts.push(zmq::Message::from(String::from("datapart2").as_bytes()));
        let msg = *RawMessage::zmq_decode(msgparts).unwrap();
        assert_eq!(msg.raw_parts[0], String::from("hellotopic").as_bytes());
        assert_eq!(msg.raw_parts[1], String::from("datapart1").as_bytes());
        assert_eq!(msg.raw_parts[2], String::from("datapart2").as_bytes());
    }
}