use std::fmt::{self, Debug, Display, Formatter};
use std::str::FromStr;
use amplify::{Bipolar, Wrapper};
use inet2_addr::ServiceAddr;
use super::{DuplexConnection, RecvFrame, RoutedFrame, SendFrame};
use crate::transport;
/// Messaging pattern formed by a pair of ZMQ socket types.
///
/// `socket_in_type` returns the socket type of one side of the pattern and
/// `socket_out_type` the socket type of the other side.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Display)]
#[repr(u8)]
pub enum ZmqConnectionType {
    /// `PULL` on the "in" side, `PUSH` on the "out" side.
    // NOTE(review): the variant is named `PullPush` but displays as "PushPull";
    // confirm the display string is intentional before relying on it.
    #[display("PushPull")]
    PullPush,
    /// `REP` on the "in" side, `REQ` on the "out" side.
    #[display("ReqRep")]
    ReqRep,
    /// `PUB` on the "in" side, `SUB` on the "out" side.
    #[display("PubSub")]
    PubSub,
    /// `ROUTER` on both sides: `RouterBind` "in", `RouterConnect` "out".
    #[display("Router")]
    Router,
}
impl ZmqConnectionType {
    /// Both socket types making up this pattern, as `(in, out)`.
    ///
    /// The "in" member is always a type that `Connection::connect` binds, and
    /// the "out" member a type that it connects.
    fn socket_pair(self) -> (ZmqSocketType, ZmqSocketType) {
        match self {
            Self::PullPush => (ZmqSocketType::Pull, ZmqSocketType::Push),
            Self::ReqRep => (ZmqSocketType::Rep, ZmqSocketType::Req),
            Self::PubSub => (ZmqSocketType::Pub, ZmqSocketType::Sub),
            Self::Router => {
                (ZmqSocketType::RouterBind, ZmqSocketType::RouterConnect)
            }
        }
    }

    /// Socket type of the "in" (binding) side of this pattern.
    pub fn socket_in_type(self) -> ZmqSocketType { self.socket_pair().0 }

    /// Socket type of the "out" (connecting) side of this pattern.
    pub fn socket_out_type(self) -> ZmqSocketType { self.socket_pair().1 }
}
/// Concrete ZMQ socket type used by one side of a connection.
///
/// In `Connection::connect`, `Pull`, `Rep`, `Pub` and `RouterBind` sockets are
/// bound to the remote endpoint, while `Push`, `Req`, `Sub` and
/// `RouterConnect` sockets connect to it. Both router variants map to
/// `zmq::ROUTER` in `socket_type`; they differ only in that bind/connect choice.
///
/// `FromStr` matches input against the `Display` strings given below.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Display)]
#[repr(u8)]
#[non_exhaustive]
pub enum ZmqSocketType {
    /// libzmq `PULL` socket.
    #[display("PULL")]
    Pull = 0,
    /// libzmq `PUSH` socket.
    #[display("PUSH")]
    Push = 1,
    /// libzmq `REQ` socket.
    #[display("REQ")]
    Req = 2,
    /// libzmq `REP` socket.
    #[display("REP")]
    Rep = 3,
    /// libzmq `PUB` socket.
    #[display("PUB")]
    Pub = 4,
    /// libzmq `SUB` socket.
    #[display("SUB")]
    Sub = 5,
    /// libzmq `ROUTER` socket that binds to the remote endpoint.
    #[display("ROUTER(bind)")]
    RouterBind = 6,
    /// libzmq `ROUTER` socket that connects to the remote endpoint.
    #[display("ROUTER(connect)")]
    RouterConnect = 7,
}
/// Error returned by `ZmqSocketType::from_str` when the input matches none of
/// the known socket type names.
///
/// Its `Display` output reuses the `Debug` representation.
#[derive(Clone, Copy, PartialEq, Eq, Debug, Display, Error)]
#[display(Debug)]
pub struct UnknownApiType;
impl ZmqSocketType {
    /// The libzmq socket type to create for this variant.
    ///
    /// Both router variants yield `zmq::ROUTER`; whether the socket binds or
    /// connects is decided by the caller, not by the libzmq type.
    pub fn socket_type(&self) -> zmq::SocketType {
        match *self {
            ZmqSocketType::Pull => zmq::PULL,
            ZmqSocketType::Push => zmq::PUSH,
            ZmqSocketType::Req => zmq::REQ,
            ZmqSocketType::Rep => zmq::REP,
            ZmqSocketType::Pub => zmq::PUB,
            ZmqSocketType::Sub => zmq::SUB,
            ZmqSocketType::RouterBind | ZmqSocketType::RouterConnect => {
                zmq::ROUTER
            }
        }
    }

    /// Short API family name shared by both sides of a pattern:
    /// `"p2p"` (push/pull), `"rpc"` (req/rep), `"sub"` (pub/sub) or `"esb"`
    /// (router).
    pub fn api_name(&self) -> String {
        let family = match *self {
            ZmqSocketType::Pull | ZmqSocketType::Push => "p2p",
            ZmqSocketType::Req | ZmqSocketType::Rep => "rpc",
            ZmqSocketType::Pub | ZmqSocketType::Sub => "sub",
            ZmqSocketType::RouterBind | ZmqSocketType::RouterConnect => "esb",
        };
        family.to_owned()
    }
}
impl FromStr for ZmqSocketType {
    type Err = UnknownApiType;

    /// Parses a socket type from its `Display` name, ignoring ASCII case.
    ///
    /// Accepted names are `PULL`, `PUSH`, `REQ`, `REP`, `PUB`, `SUB`,
    /// `ROUTER(bind)` and `ROUTER(connect)`, in any letter case.
    ///
    /// # Errors
    ///
    /// Returns `UnknownApiType` if `s` matches none of those names.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // The previous implementation lowercased `s` and compared it with the
        // uppercase display strings, so no input could ever match.
        // Compare case-insensitively instead.
        [
            ZmqSocketType::Push,
            ZmqSocketType::Pull,
            ZmqSocketType::Req,
            ZmqSocketType::Rep,
            ZmqSocketType::Pub,
            ZmqSocketType::Sub,
            ZmqSocketType::RouterBind,
            ZmqSocketType::RouterConnect,
        ]
        .iter()
        .copied()
        .find(|api| api.to_string().eq_ignore_ascii_case(s))
        .ok_or(UnknownApiType)
    }
}
/// Either a service address (`Locator`) or an already-constructed ZMQ socket
/// (`Socket`).
#[derive(Display)]
pub enum Carrier {
    /// A service address; displayed as the address itself.
    #[display(inner)]
    Locator(ServiceAddr),
    /// An existing ZMQ socket; displayed as the fixed text `zmq_socket(..)`.
    #[display("zmq_socket(..)")]
    Socket(zmq::Socket),
}
/// Copyable wrapper around a raw libzmq error code, as produced by
/// `zmq::Error::to_raw`.
///
/// `Debug` and `Display` are implemented by hand: they convert the code back
/// into a `zmq::Error` and format that.
#[derive(
    Wrapper, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Error, From
)]
pub struct Error(i32);
impl From<zmq::Error> for Error {
#[inline]
fn from(err: zmq::Error) -> Self { Self(err.to_raw()) }
}
impl From<Error> for zmq::Error {
#[inline]
fn from(err: Error) -> Self { zmq::Error::from_raw(err.into_inner()) }
}
impl From<zmq::Error> for transport::Error {
#[inline]
fn from(err: zmq::Error) -> Self {
match err {
zmq::Error::EHOSTUNREACH => transport::Error::ServiceOffline,
err => transport::Error::Zmq(err.into()),
}
}
}
impl From<Error> for transport::Error {
#[inline]
fn from(err: Error) -> Self {
transport::Error::from(zmq::Error::from(err))
}
}
impl Debug for Error {
#[inline]
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
Debug::fmt(&zmq::Error::from(*self), f)
}
}
impl Display for Error {
#[inline]
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
Display::fmt(&zmq::Error::from(*self), f)
}
}
/// A single ZMQ socket tagged with the API type of the connection it belongs
/// to.
pub struct WrappedSocket {
    /// API type of the owning connection. In `Connection::connect`, both the
    /// input and the output socket are tagged with the connection's type.
    api_type: ZmqSocketType,
    /// The wrapped libzmq socket.
    socket: zmq::Socket,
}
/// ZMQ connection made of a primary `input` socket plus an optional separate
/// `output` socket.
pub struct Connection {
    /// Socket type this connection was created with.
    api_type: ZmqSocketType,
    /// Remote endpoint. It is `Some` only for connections built by `connect`,
    /// and `set_identity` requires it to re-attach the socket.
    remote_addr: Option<ServiceAddr>,
    /// Primary socket. For connections built by `connect`, it is
    /// bound/connected to `remote_addr`.
    input: WrappedSocket,
    /// Extra socket used by `Pull`/`Push` connections. When present, it is
    /// used for sending instead of `input`.
    output: Option<WrappedSocket>,
}
impl Connection {
pub fn connect(
api_type: ZmqSocketType,
remote: &ServiceAddr,
local: Option<&ServiceAddr>,
identity: Option<impl AsRef<[u8]>>,
context: &zmq::Context,
) -> Result<Self, transport::Error> {
let socket = context.socket(api_type.socket_type())?;
if let Some(identity) = identity {
socket.set_identity(identity.as_ref())?;
}
let endpoint = remote.zmq_connect_string();
match api_type {
ZmqSocketType::Pull
| ZmqSocketType::Rep
| ZmqSocketType::Pub
| ZmqSocketType::RouterBind => socket.bind(&endpoint)?,
ZmqSocketType::Push
| ZmqSocketType::Req
| ZmqSocketType::Sub
| ZmqSocketType::RouterConnect => socket.connect(&endpoint)?,
}
let output = match (api_type, local) {
(ZmqSocketType::Pull, Some(local)) => {
let socket = context.socket(zmq::SocketType::PUSH)?;
socket.connect(&local.zmq_connect_string())?;
Some(socket)
}
(ZmqSocketType::Push, Some(local)) => {
let socket = context.socket(zmq::SocketType::PULL)?;
socket.bind(&local.zmq_connect_string())?;
Some(socket)
}
(ZmqSocketType::Pull, None) | (ZmqSocketType::Push, None) => {
return Err(transport::Error::RequiresLocalSocket)
}
(_, _) => None,
}
.map(|s| WrappedSocket::with_socket(api_type, s));
Ok(Self {
api_type,
remote_addr: Some(remote.clone()),
input: WrappedSocket::with_socket(api_type, socket),
output,
})
}
pub fn with_socket(api_type: ZmqSocketType, socket: zmq::Socket) -> Self {
Self {
api_type,
remote_addr: None,
input: WrappedSocket::with_socket(api_type, socket),
output: None,
}
}
#[inline]
pub(crate) fn as_socket(&self) -> &zmq::Socket { self.input.as_socket() }
#[inline]
pub(crate) fn as_socket_mut(&mut self) -> &mut zmq::Socket {
self.input.as_socket_mut()
}
#[inline]
pub fn set_identity(
&mut self,
identity: &impl AsRef<[u8]>,
context: &zmq::Context,
) -> Result<(), Error> {
let addr = if let Some(addr) = &self.remote_addr {
addr
} else {
return Err(Error::from(zmq::Error::EINVAL));
};
let socket = self.input.as_socket_mut();
let endpoint = addr.zmq_connect_string();
socket.disconnect(&endpoint)?;
*socket = context.socket(self.api_type.socket_type())?;
socket
.set_identity(identity.as_ref())
.map_err(Error::from)?;
match self.api_type {
ZmqSocketType::Pull
| ZmqSocketType::Rep
| ZmqSocketType::Pub
| ZmqSocketType::RouterBind => socket.bind(&endpoint)?,
ZmqSocketType::Push
| ZmqSocketType::Req
| ZmqSocketType::Sub
| ZmqSocketType::RouterConnect => socket.connect(&endpoint)?,
}
Ok(())
}
}
impl WrappedSocket {
#[inline]
fn with_socket(api_type: ZmqSocketType, socket: zmq::Socket) -> Self {
Self { api_type, socket }
}
#[inline]
pub(crate) fn as_socket(&self) -> &zmq::Socket { &self.socket }
#[inline]
pub(crate) fn as_socket_mut(&mut self) -> &mut zmq::Socket {
&mut self.socket
}
}
impl DuplexConnection for Connection {
    /// Receiving always goes through the primary (`input`) socket.
    #[inline]
    fn as_receiver(&mut self) -> &mut dyn RecvFrame { &mut self.input }

    /// Sending uses the dedicated `output` socket when there is one, and falls
    /// back to `input` otherwise.
    #[inline]
    fn as_sender(&mut self) -> &mut dyn SendFrame {
        match self.output {
            Some(ref mut out) => out,
            None => &mut self.input,
        }
    }

    /// Splits a `Push`/`Pull` connection into boxed receive and send halves.
    ///
    /// # Panics
    ///
    /// Panics for any other socket type, and when a `Push`/`Pull` connection
    /// has no `output` socket.
    fn split(self) -> (Box<dyn RecvFrame + Send>, Box<dyn SendFrame + Send>) {
        match self.api_type {
            ZmqSocketType::Push | ZmqSocketType::Pull => {
                let receiver: Box<dyn RecvFrame + Send> = Box::new(self.input);
                let sender = self.output.expect(
                    "Splittable types always have output part present",
                );
                (receiver, Box::new(sender))
            }
            other => panic!(
                "Split operation is impossible for ZMQ stream type {}",
                other
            ),
        }
    }
}
impl Bipolar for Connection {
    type Left = WrappedSocket;
    type Right = WrappedSocket;

    /// Joins an input and an output socket into a connection without a remote
    /// address.
    ///
    /// # Panics
    ///
    /// Panics if the halves carry different API types, or if that type is
    /// neither `Push` nor `Pull`. These are the same types `split` accepts.
    fn join(input: Self::Left, output: Self::Right) -> Self {
        if input.api_type != output.api_type {
            panic!("ZMQ streams of different type can't be joined");
        }
        // The previous `!= Push || == Pull` guard rejected `Pull`, even though
        // `split` produces `Pull` halves. Accept exactly Push and Pull.
        if !matches!(input.api_type, ZmqSocketType::Push | ZmqSocketType::Pull)
        {
            panic!("ZMQ streams of {} type can't be joined", input.api_type);
        }
        Self {
            api_type: input.api_type,
            remote_addr: None,
            input,
            output: Some(output),
        }
    }

    /// Splits a `Push`/`Pull` connection into its `(input, output)` sockets.
    ///
    /// # Panics
    ///
    /// Panics for any other socket type. Also panics when the connection has
    /// no output socket, e.g. one created with `Connection::with_socket`.
    fn split(self) -> (Self::Left, Self::Right) {
        match self.api_type {
            ZmqSocketType::Push | ZmqSocketType::Pull => {
                let output = self.output.expect(
                    "ZMQ Push/Pull connection has no output socket to split off",
                );
                (self.input, output)
            }
            other => panic!(
                "Split operation is impossible for ZMQ stream type {}",
                other
            ),
        }
    }
}
impl RecvFrame for WrappedSocket {
#[inline]
fn recv_frame(&mut self) -> Result<Vec<u8>, transport::Error> {
Ok(self.socket.recv_bytes(0)?)
}
fn recv_raw(&mut self, _len: usize) -> Result<Vec<u8>, transport::Error> {
Ok(self.socket.recv_bytes(0)?)
}
fn recv_routed(&mut self) -> Result<RoutedFrame, transport::Error> {
let mut multipart = self.socket.recv_multipart(0)?.into_iter();
let hop = multipart.next().ok_or(transport::Error::FrameBroken(
"zero frame parts in ZMQ multipart routed frame",
))?;
let src = multipart.next().ok_or(transport::Error::FrameBroken(
"no source part ZMQ multipart routed frame",
))?;
let dst = multipart.next().ok_or(transport::Error::FrameBroken(
"no destination part ZMQ multipart routed frame",
))?;
let msg = multipart.next().ok_or(transport::Error::FrameBroken(
"no message part in ZMQ multipart routed frame",
))?;
if multipart.count() > 0 {
return Err(transport::Error::FrameBroken(
"excessive parts in ZMQ multipart routed frame",
));
}
Ok(RoutedFrame { hop, src, dst, msg })
}
}
impl SendFrame for WrappedSocket {
#[inline]
fn send_frame(&mut self, data: &[u8]) -> Result<usize, transport::Error> {
self.socket.send(data, 0)?;
Ok(data.len())
}
fn send_raw(&mut self, data: &[u8]) -> Result<usize, transport::Error> {
self.socket.send(data, 0)?;
Ok(data.len())
}
fn send_routed(
&mut self,
source: &[u8],
route: &[u8],
dest: &[u8],
data: &[u8],
) -> Result<usize, transport::Error> {
self.socket
.send_multipart(&[route, source, dest, data], 0)?;
Ok(data.len())
}
}