datastreamcorelib 0.5.1

Rust version of https://gitlab.com/advian-oss/python-datastreamcorelib
Documentation
/// ZMQ backend agnostic abstractions
use crate::markers;
use bytes::BytesMut;
use failure::Fallible;
use lazy_static::lazy_static;
use std::collections::HashMap;
use std::fmt;
use std::sync::{Arc, Mutex};
use zmq;

/// Shared, lockable handle to a ZMQ socket; this is the type socket handlers cache and hand out.
pub type ZMQSocketArc = Arc<Mutex<zmq::Socket>>;

/// Trait for things that can be enCOded into ZMQ message parts for sending, and DECoded
/// back from message parts received over ZMQ sockets.
pub trait ZMQCodec {
    /// Encode the abstracted message into ZMQ message parts, ready for sending.
    ///
    /// # Errors
    /// Implementations return an error when the message cannot be represented as ZMQ frames.
    fn zmq_encode(&self) -> Fallible<Vec<zmq::Message>>;
    /// Decode received ZMQ message parts into a boxed abstracted message.
    ///
    /// # Errors
    /// Implementations return an error when the parts do not form a valid message of this type.
    fn zmq_decode(from: Vec<zmq::Message>) -> Fallible<Box<Self>>;
}

/// Untyped multipart message: one `BytesMut` buffer per ZMQ frame.
///
/// Mostly useful for exercising the base traits in tests. It can also serve as an
/// intermediate type when converting between two wire-compatible abstractions that
/// have no direct `TryFrom` implementation between them.
#[derive(Debug, Clone)]
pub struct RawMessage {
    /// Frame payloads, in the order they appear on the wire.
    pub raw_parts: Vec<BytesMut>,
}

impl markers::ZMQMessageMarker for RawMessage {}

impl ZMQCodec for RawMessage {
    /// Turn each raw part into its own ZMQ frame, keeping the original order.
    fn zmq_encode(&self) -> Fallible<Vec<zmq::Message>> {
        let frames: Vec<zmq::Message> = self
            .raw_parts
            .iter()
            .map(|chunk| zmq::Message::from(&chunk[..]))
            .collect();
        Ok(frames)
    }

    /// Copy every received frame into its own `BytesMut` and box the resulting message.
    fn zmq_decode(from: Vec<zmq::Message>) -> Fallible<Box<Self>> {
        let raw_parts: Vec<BytesMut> = from
            .iter()
            .map(|frame| BytesMut::from(&frame[..]))
            .collect();
        Ok(Box::new(RawMessage { raw_parts }))
    }
}

/// Implement naive `TryFrom` conversions between types that implement `ZMQCodec`.
///
/// `naive_tryfrom!(Target, [SourceA, SourceB])` generates `TryFrom<SourceA>`,
/// `TryFrom<&SourceA>`, `TryFrom<SourceB>` and `TryFrom<&SourceB>` for `Target`.
/// Each conversion encodes the source into ZMQ message parts and decodes those parts
/// as the target type, so it only makes sense between wire-compatible types.
///
/// `TryFrom` and `Result` are referenced by their full `std` paths, so the call site
/// does not need to import them. The expansion calls `zmq_encode`/`zmq_decode`,
/// however, so the `ZMQCodec` trait must be in scope at the call site, and the
/// calling crate must depend on `failure`.
#[macro_export]
macro_rules! naive_tryfrom (
    ( $totyp: ident, [ $( $fromtyp: ident ),* ]) => {
        $(
            impl ::std::convert::TryFrom<&$fromtyp> for $totyp {
                type Error = failure::Error;

                fn try_from(rmsg: &$fromtyp) -> ::std::result::Result<Self, Self::Error> {
                    let msgparts = rmsg.zmq_encode()?;
                    let msg = *$totyp::zmq_decode(msgparts)?;
                    Ok(msg)
                }
            }
            impl ::std::convert::TryFrom<$fromtyp> for $totyp {
                type Error = failure::Error;

                fn try_from(rmsg: $fromtyp) -> ::std::result::Result<Self, Self::Error> {
                    // Encoding only needs a borrow, so reuse the by-reference conversion.
                    <$totyp as ::std::convert::TryFrom<&$fromtyp>>::try_from(&rmsg)
                }
            }
        )*
    }
);

/// Kind of ZMQ socket, independent of any concrete ZMQ backend.
///
/// Hashable and comparable so it can be part of a cache key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ZMQSocketType {
    /// Publisher; `SocketHandler::get_socket` binds it to its URIs.
    PUB,
    /// Subscriber; `SocketHandler::get_socket` connects it to its URIs.
    SUB,
    /// Request; `SocketHandler::get_socket` connects it to its URIs.
    REQ,
    /// Reply; `SocketHandler::get_socket` binds it to its URIs.
    REP,
}

/// What socket to open and where: the socket kind plus every URI it binds or connects to.
///
/// Used as the cache key by socket handlers, hence the `Eq`/`Hash` derives; two
/// descriptions are equal only when the URI lists match element by element and the
/// socket types match.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ZMQSocketDescription {
    /// URIs handed to `bind`/`connect`, in order.
    pub socketuris: Vec<String>,
    /// Which kind of socket to create.
    pub sockettype: ZMQSocketType,
}

/// Shorthand type for SocketHandler instances
pub type SocketHandlerArc = Arc<Mutex<dyn SocketHandler>>;

/// Socket tracking, caching etc
pub trait SocketHandler {
    /// Get a socket by desc from cache if it exist, use only in the trait implementation
    fn _get_cached_socket(&self, desc: &ZMQSocketDescription) -> Option<ZMQSocketArc>;
    /// Save a socket to the cache, use only in the trait implementation
    fn _set_cached_socket(
        &mut self,
        desc: &ZMQSocketDescription,
        sock: zmq::Socket,
    ) -> Fallible<()>;
    /// Get all open sockets NOTE: because of the rust library all sockets we have are always
    /// going to be open
    fn get_open_sockets(&self) -> Vec<ZMQSocketArc>;
    /// Close all open sockets, NOTE that for the zmq rust library we don't get
    /// to do this kind of low-level management, the destructor will close things.
    fn close_all_sockets(&self) -> Fallible<()> {
        Ok(())
    }
    /// Get a socket by description (will be created and put to cache if not there already)
    /// bind/connect will be handled automatically by this method.
    fn get_socket(&mut self, desc: &ZMQSocketDescription) -> Fallible<ZMQSocketArc> {
        // Return early if we could get the socket from cache
        match self._get_cached_socket(desc) {
            Some(socket) => {
                log::debug!("Returning socket from cache");
                return Ok(socket);
            }
            None => {}
        }
        // See documentation, this is basically singleton
        let zctx = zmq::Context::new();
        let rawsocket = match desc.sockettype {
            ZMQSocketType::PUB => zctx.socket(zmq::PUB)?,
            ZMQSocketType::SUB => zctx.socket(zmq::SUB)?,
            ZMQSocketType::REQ => zctx.socket(zmq::REQ)?,
            ZMQSocketType::REP => zctx.socket(zmq::REP)?,
        };
        match desc.sockettype {
            ZMQSocketType::PUB | ZMQSocketType::REP => {
                for uri in desc.socketuris.iter() {
                    log::debug!("Binding to {}", uri);
                    rawsocket.bind(uri.as_str())?;
                }
            }
            ZMQSocketType::SUB | ZMQSocketType::REQ => {
                for uri in desc.socketuris.iter() {
                    log::debug!("connecting to {}", uri);
                    rawsocket.connect(uri.as_str())?;
                }
            }
        }
        // Put to cache, consumes the socket and makes the Arc/Mutex wrapper
        self._set_cached_socket(desc, rawsocket)?;
        // Return the Arc wrapped socket from cache
        Ok(self._get_cached_socket(desc).unwrap())
    }
}

/// Default `SocketHandler`: keeps sockets in a `HashMap` keyed by their description.
#[derive(Default)]
pub struct BaseSocketHandler {
    sockets_by_desc: HashMap<ZMQSocketDescription, ZMQSocketArc>,
}

impl fmt::Debug for BaseSocketHandler {
    /// Print the type name only; the socket map is shown as a `"<hidden>"` placeholder.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("BaseSocketHandler");
        builder.field("sockets_by_desc", &"<hidden>");
        builder.finish()
    }
}

type BaseSocketHandlerArc = Arc<Mutex<BaseSocketHandler>>;

lazy_static! {
    /// Process-wide handler shared by every `BaseSocketHandler::instance()` call.
    static ref BASESOCKETHANDLER_SINGLETON: BaseSocketHandlerArc =
        Arc::new(Mutex::new(BaseSocketHandler::new()));
}

impl BaseSocketHandler {
    /// Return a new reference to the process-wide handler; every call shares the same
    /// mutex-protected instance and therefore the same socket cache.
    pub fn instance() -> BaseSocketHandlerArc {
        Arc::clone(&*BASESOCKETHANDLER_SINGLETON)
    }

    /// Create a standalone handler with an empty socket cache (not the singleton).
    pub fn new() -> BaseSocketHandler {
        Self::default()
    }
}

impl SocketHandler for BaseSocketHandler {
    /// Look `desc` up in the cache and return a new reference to the socket, if any.
    fn _get_cached_socket(&self, desc: &ZMQSocketDescription) -> Option<ZMQSocketArc> {
        // A single lookup instead of `contains_key` followed by indexing.
        self.sockets_by_desc.get(desc).cloned()
    }

    /// Wrap `sock` in `Arc<Mutex<_>>` and store it under `desc`.
    ///
    /// # Errors
    /// Fails, leaving the existing entry untouched, if `desc` is already cached.
    fn _set_cached_socket(
        &mut self,
        desc: &ZMQSocketDescription,
        sock: zmq::Socket,
    ) -> Fallible<()> {
        use std::collections::hash_map::Entry;

        // The entry API checks for the key and inserts it with one hash lookup.
        match self.sockets_by_desc.entry(desc.clone()) {
            Entry::Occupied(_) => Err(failure::err_msg("Described socket is already in cache")),
            Entry::Vacant(slot) => {
                let sockwrapper: ZMQSocketArc = Arc::new(Mutex::new(sock));
                slot.insert(sockwrapper);
                Ok(())
            }
        }
    }

    /// Return a reference to every cached socket, in the map's (unspecified) iteration order.
    fn get_open_sockets(&self) -> Vec<ZMQSocketArc> {
        self.sockets_by_desc.values().cloned().collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::env::temp_dir;

    /// Build an `ipc://` URI for a socket file named `filename` in the system temp dir.
    fn ipc_uri(filename: &str) -> String {
        let mut path = temp_dir();
        path.push(filename);
        "ipc://".to_string() + &path.to_string_lossy()
    }

    #[test]
    fn test_sockethandler_singleton() {
        let sh1 = BaseSocketHandler::instance();
        let sh2 = BaseSocketHandler::instance();
        log::debug!("sh1 is {:?}, sh2 is {:?}", sh1, sh2);
        // Both handles must point at the very same handler instance.
        assert!(Arc::ptr_eq(&sh1, &sh2));

        let desc1 = ZMQSocketDescription {
            socketuris: vec![ipc_uri("d709d495-f587-4f9f-9566-f3c66721d48f_pub.sock")],
            sockettype: ZMQSocketType::PUB,
        };
        let sock1 = sh1.lock().unwrap().get_socket(&desc1).unwrap();
        let sock2 = sh2.lock().unwrap().get_socket(&desc1).unwrap();
        // The second request for the same description must come from the shared cache.
        assert!(Arc::ptr_eq(&sock1, &sock2));

        let desc2 = ZMQSocketDescription {
            socketuris: vec![ipc_uri("7a4f4f4f-0016-420e-ae88-38e0c581ea29_pub.sock")],
            sockettype: ZMQSocketType::PUB,
        };
        let _sock3 = sh2.lock().unwrap().get_socket(&desc2).unwrap();

        // Query through both handles: they must report the same cache contents.
        let svec1 = sh1.lock().unwrap().get_open_sockets();
        let svec2 = sh2.lock().unwrap().get_open_sockets();
        assert_eq!(svec1.len(), 2);
        assert_eq!(svec2.len(), svec1.len());
    }

    #[test]
    fn test_rawmessage_encode() {
        let mut raw_parts: Vec<BytesMut> = Vec::with_capacity(3);
        raw_parts.push(BytesMut::from(String::from("hellotopic").as_bytes()));
        raw_parts.push(BytesMut::from(String::from("datapart1").as_bytes()));
        raw_parts.push(BytesMut::from(String::from("datapart2").as_bytes()));
        let msg = RawMessage { raw_parts };
        log::debug!("msg is {:?}", msg);
        let msgparts = msg.zmq_encode().unwrap();
        assert_eq!(msgparts.len(), 3);
        assert_eq!(msgparts[0].as_str().unwrap(), String::from("hellotopic"));
        assert_eq!(msgparts[1].as_str().unwrap(), String::from("datapart1"));
        assert_eq!(msgparts[2].as_str().unwrap(), String::from("datapart2"));
    }

    #[test]
    fn test_rawmessage_decode() {
        let mut msgparts: Vec<zmq::Message> = Vec::with_capacity(3);
        msgparts.push(zmq::Message::from(String::from("hellotopic").as_bytes()));
        msgparts.push(zmq::Message::from(String::from("datapart1").as_bytes()));
        msgparts.push(zmq::Message::from(String::from("datapart2").as_bytes()));
        let msg = *RawMessage::zmq_decode(msgparts).unwrap();
        assert_eq!(msg.raw_parts.len(), 3);
        assert_eq!(msg.raw_parts[0], String::from("hellotopic").as_bytes());
        assert_eq!(msg.raw_parts[1], String::from("datapart1").as_bytes());
        assert_eq!(msg.raw_parts[2], String::from("datapart2").as_bytes());
    }

    #[test]
    fn test_rawmessage_roundtrip() {
        let original = RawMessage {
            raw_parts: vec![
                BytesMut::from(&b"hellotopic"[..]),
                BytesMut::from(&b""[..]),
                BytesMut::from(&b"datapart1"[..]),
            ],
        };
        let decoded = *RawMessage::zmq_decode(original.zmq_encode().unwrap()).unwrap();
        assert_eq!(decoded.raw_parts, original.raw_parts);
    }
}