mod heartbeat;
mod raw;
mod recv;
mod send;
pub(crate) mod sockopt;
pub(crate) use raw::*;
pub use heartbeat::*;
pub use recv::*;
pub use send::*;
mod private {
use super::*;
use crate::socket::*;
pub trait Sealed {}
impl Sealed for SocketConfig {}
impl Sealed for SendConfig {}
impl Sealed for RecvConfig {}
impl Sealed for HeartbeatingConfig {}
impl Sealed for Client {}
impl Sealed for ClientConfig {}
impl Sealed for ClientBuilder {}
impl Sealed for Server {}
impl Sealed for ServerConfig {}
impl Sealed for ServerBuilder {}
impl Sealed for Radio {}
impl Sealed for RadioConfig {}
impl Sealed for RadioBuilder {}
impl Sealed for Dish {}
impl Sealed for DishConfig {}
impl Sealed for DishBuilder {}
impl Sealed for Scatter {}
impl Sealed for ScatterConfig {}
impl Sealed for ScatterBuilder {}
impl Sealed for Gather {}
impl Sealed for GatherConfig {}
impl Sealed for GatherBuilder {}
impl Sealed for SocketType {}
use crate::old::OldSocket;
impl Sealed for OldSocket {}
}
use crate::{addr::Endpoint, auth::*, Error, ErrorKind};
use humantime_serde::Serde;
use serde::{Deserialize, Serialize};
use std::{sync::MutexGuard, time::Duration};
/// Value returned by `HighWaterMark::default`.
const DEFAULT_HWM: i32 = 1_000;
/// Value returned by `BatchSize::default`.
const DEFAULT_BATCH_SIZE: i32 = 8_192;
/// Newtype around an `i32` high-water-mark value.
///
/// Serde goes through `Option<i32>`: `None` deserializes to the default
/// (`DEFAULT_HWM`), and serialization always produces `Some`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(from = "Option<i32>", into = "Option<i32>")]
pub(crate) struct HighWaterMark(i32);

impl Default for HighWaterMark {
    /// Returns `DEFAULT_HWM`.
    fn default() -> Self {
        Self(DEFAULT_HWM)
    }
}

impl From<i32> for HighWaterMark {
    fn from(value: i32) -> Self {
        HighWaterMark(value)
    }
}

impl From<HighWaterMark> for i32 {
    fn from(HighWaterMark(value): HighWaterMark) -> Self {
        value
    }
}

impl From<Option<i32>> for HighWaterMark {
    /// `Some(v)` wraps `v`; `None` falls back to the default.
    fn from(value: Option<i32>) -> Self {
        value.map(HighWaterMark).unwrap_or_default()
    }
}

impl From<HighWaterMark> for Option<i32> {
    fn from(hwm: HighWaterMark) -> Self {
        Some(i32::from(hwm))
    }
}
/// Newtype around an `i32` batch-size value.
///
/// Serde goes through `Option<i32>`: `None` deserializes to the default
/// (`DEFAULT_BATCH_SIZE`), and serialization always produces `Some`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Deserialize, Serialize)]
#[serde(from = "Option<i32>", into = "Option<i32>")]
pub(crate) struct BatchSize(i32);

impl Default for BatchSize {
    /// Returns `DEFAULT_BATCH_SIZE`.
    fn default() -> Self {
        Self(DEFAULT_BATCH_SIZE)
    }
}

impl From<i32> for BatchSize {
    fn from(value: i32) -> Self {
        Self(value)
    }
}

impl From<BatchSize> for i32 {
    fn from(BatchSize(value): BatchSize) -> Self {
        value
    }
}

impl From<Option<i32>> for BatchSize {
    /// `Some(v)` wraps `v`; `None` falls back to the default.
    fn from(value: Option<i32>) -> Self {
        value.map(BatchSize).unwrap_or_default()
    }
}

impl From<BatchSize> for Option<i32> {
    fn from(size: BatchSize) -> Self {
        Some(i32::from(size))
    }
}
/// A span of time that is either unbounded (`Infinite`) or bounded
/// (`Finite`).
///
/// Serde goes through `humantime_serde::Serde<Option<Duration>>`, where
/// `None` corresponds to `Infinite`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(from = "Serde<Option<Duration>>", into = "Serde<Option<Duration>>")]
pub enum Period {
    /// An unbounded period.
    Infinite,
    /// A bounded period lasting the given `Duration`.
    Finite(Duration),
}

pub use Period::*;

impl Default for Period {
    /// Returns `Infinite`.
    fn default() -> Self {
        Period::Infinite
    }
}

#[doc(hidden)]
impl From<Period> for Option<Duration> {
    fn from(period: Period) -> Self {
        if let Finite(duration) = period {
            Some(duration)
        } else {
            None
        }
    }
}

#[doc(hidden)]
impl From<Option<Duration>> for Period {
    fn from(option: Option<Duration>) -> Self {
        option.map_or(Infinite, Finite)
    }
}

#[doc(hidden)]
impl From<Serde<Option<Duration>>> for Period {
    // Unwrap the serde adapter and reuse the `Option<Duration>` conversion.
    fn from(serde: Serde<Option<Duration>>) -> Self {
        Period::from(serde.into_inner())
    }
}

#[doc(hidden)]
impl From<Period> for Serde<Option<Duration>> {
    // Reuse the `Option<Duration>` conversion, then wrap for serde.
    fn from(period: Period) -> Self {
        Serde::from(Option::<Duration>::from(period))
    }
}
pub trait Socket: GetRawSocket {
fn connect<E>(&self, endpoint: E) -> Result<(), Error>
where
E: Into<Endpoint>,
{
self.raw_socket().connect(&endpoint.into())
}
fn bind<E>(&self, endpoint: E) -> Result<(), Error>
where
E: Into<Endpoint>,
{
self.raw_socket().bind(&endpoint.into())
}
fn disconnect<E>(&self, endpoint: E) -> Result<(), Error>
where
E: Into<Endpoint>,
{
self.raw_socket().disconnect(&endpoint.into())
}
fn unbind<I, E>(&self, endpoint: E) -> Result<(), Error>
where
E: Into<Endpoint>,
{
self.raw_socket().unbind(&endpoint.into())
}
fn last_endpoint(&self) -> Result<Endpoint, Error> {
match self.raw_socket().last_endpoint()? {
Some(endpoint) => Ok(endpoint),
None => Err(Error::new(ErrorKind::NotFound(
"no endpoint previously bound or connected to",
))),
}
}
fn mechanism(&self) -> Mechanism {
self.raw_socket().mechanism().lock().unwrap().to_owned()
}
fn set_mechanism<M>(&self, mechanism: M) -> Result<(), Error>
where
M: Into<Mechanism>,
{
let raw_socket = self.raw_socket();
let mechanism = mechanism.into();
let mutex = raw_socket.mechanism().lock().unwrap();
set_mechanism(raw_socket, mechanism, mutex)
}
}
/// Replaces the security mechanism configured on `raw_socket`.
///
/// `mutex` is the guard obtained from `raw_socket.mechanism()` by the caller.
/// It is owned by this function, so the lock is held until it returns.
///
/// If the requested mechanism equals the recorded one, nothing happens.
/// Otherwise the options set by the recorded mechanism are cleared and the
/// options of the new one are applied. A `CurveClient` mechanism without a
/// client certificate gets one from `CurveCert::new_unique()`. On success the
/// guard records the applied mechanism.
///
/// # Errors
/// Returns the first socket-option error. If clearing the old options
/// succeeded but applying the new ones failed, the guard records
/// `Mechanism::Null`, so it no longer claims the old mechanism is active.
fn set_mechanism(
    raw_socket: &RawSocket,
    mut mechanism: Mechanism,
    mut mutex: MutexGuard<Mechanism>,
) -> Result<(), Error> {
    if *mutex == mechanism {
        return Ok(());
    }

    clear_mechanism_options(raw_socket, &mutex)?;
    // The old mechanism's options are now gone from the socket. Record that,
    // so a failure below cannot leave the guard reporting the old mechanism.
    // Otherwise a retry with the old mechanism would hit the early return
    // above without re-applying anything.
    *mutex = Mechanism::Null;

    if let Mechanism::CurveClient(creds) = &mut mechanism {
        // A CURVE client needs a keypair; generate one when none was given.
        if creds.client.is_none() {
            creds.client = Some(CurveCert::new_unique());
        }
    }

    apply_mechanism_options(raw_socket, &mechanism)?;
    *mutex = mechanism;
    Ok(())
}

/// Resets the socket options that `apply_mechanism_options` sets for
/// `mechanism`.
fn clear_mechanism_options(
    raw_socket: &RawSocket,
    mechanism: &Mechanism,
) -> Result<(), Error> {
    match mechanism {
        Mechanism::Null => (),
        Mechanism::PlainClient(_) => {
            raw_socket.set_username(None)?;
            raw_socket.set_password(None)?;
        }
        Mechanism::PlainServer => {
            raw_socket.set_plain_server(false)?;
        }
        Mechanism::CurveClient(_) => {
            raw_socket.set_curve_server_key(None)?;
            raw_socket.set_curve_public_key(None)?;
            raw_socket.set_curve_secret_key(None)?;
        }
        Mechanism::CurveServer(_) => {
            raw_socket.set_curve_secret_key(None)?;
            raw_socket.set_curve_server(false)?;
        }
    }
    Ok(())
}

/// Sets the socket options required by `mechanism`.
///
/// For `CurveClient`, `creds.client` must already be `Some`; `set_mechanism`
/// guarantees this before calling.
fn apply_mechanism_options(
    raw_socket: &RawSocket,
    mechanism: &Mechanism,
) -> Result<(), Error> {
    match mechanism {
        Mechanism::Null => (),
        Mechanism::PlainClient(creds) => {
            raw_socket.set_username(Some(&creds.username))?;
            raw_socket.set_password(Some(&creds.password))?;
        }
        Mechanism::PlainServer => {
            raw_socket.set_plain_server(true)?;
        }
        Mechanism::CurveClient(creds) => {
            let server_key: BinCurveKey = (&creds.server).into();
            raw_socket.set_curve_server_key(Some(&server_key))?;
            let cert = creds
                .client
                .as_ref()
                .expect("set_mechanism fills in a missing client certificate");
            let public_key: BinCurveKey = cert.public().into();
            raw_socket.set_curve_public_key(Some(&public_key))?;
            let secret_key: BinCurveKey = cert.secret().into();
            raw_socket.set_curve_secret_key(Some(&secret_key))?;
        }
        Mechanism::CurveServer(creds) => {
            let secret_key: BinCurveKey = (&creds.secret).into();
            raw_socket.set_curve_secret_key(Some(&secret_key))?;
            raw_socket.set_curve_server(true)?;
        }
    }
    Ok(())
}
/// Socket-level settings: endpoints to connect and bind to, plus the
/// security mechanism. Each field is `None` when unset.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
#[doc(hidden)]
pub struct SocketConfig {
    /// Endpoints passed to `Socket::connect` by `apply`.
    pub(crate) connect: Option<Vec<Endpoint>>,
    /// Endpoints passed to `Socket::bind` by `apply`.
    pub(crate) bind: Option<Vec<Endpoint>>,
    /// Mechanism passed to `Socket::set_mechanism` by `apply`.
    pub(crate) mechanism: Option<Mechanism>,
}

impl SocketConfig {
    /// Applies this configuration to `socket`.
    ///
    /// The mechanism is set first. Then every `connect` endpoint is connected
    /// and every `bind` endpoint is bound, each in list order.
    ///
    /// # Errors
    /// Stops at and returns the first error; steps already taken are not
    /// undone.
    pub(crate) fn apply<S: Socket>(&self, socket: &S) -> Result<(), Error> {
        if let Some(mechanism) = &self.mechanism {
            socket.set_mechanism(mechanism)?;
        }

        self.connect
            .iter()
            .flatten()
            .try_for_each(|endpoint| socket.connect(endpoint))?;

        self.bind
            .iter()
            .flatten()
            .try_for_each(|endpoint| socket.bind(endpoint))
    }
}
/// Access to the `SocketConfig` held by a configuration or builder type.
///
/// The `private::Sealed` supertrait prevents implementations outside this
/// crate.
#[doc(hidden)]
pub trait GetSocketConfig: private::Sealed {
    /// Borrows the held `SocketConfig`.
    fn socket_config(&self) -> &SocketConfig;

    /// Mutably borrows the held `SocketConfig`.
    fn socket_config_mut(&mut self) -> &mut SocketConfig;
}

// A `SocketConfig` trivially holds itself.
impl GetSocketConfig for SocketConfig {
    fn socket_config_mut(&mut self) -> &mut SocketConfig {
        self
    }

    fn socket_config(&self) -> &SocketConfig {
        self
    }
}
/// Getters and setters for the socket-level settings stored in a
/// `SocketConfig` (endpoints and security mechanism).
pub trait ConfigureSocket: GetSocketConfig {
    /// Returns the endpoints to connect to, or `None` if unset.
    fn connect(&self) -> Option<&[Endpoint]> {
        let config = self.socket_config();
        config.connect.as_deref()
    }

    /// Sets the endpoints to connect to, or clears them when given `None`.
    fn set_connect<I, E>(&mut self, maybe: Option<I>)
    where
        I: IntoIterator<Item = E>,
        E: Into<Endpoint>,
    {
        // The right-hand side (which drains the caller's iterator) is
        // evaluated before `socket_config_mut` is called, as before.
        self.socket_config_mut().connect = maybe.map(|endpoints| {
            endpoints
                .into_iter()
                .map(Into::<Endpoint>::into)
                .collect::<Vec<Endpoint>>()
        });
    }

    /// Returns the endpoints to bind to, or `None` if unset.
    fn bind(&self) -> Option<&[Endpoint]> {
        let config = self.socket_config();
        config.bind.as_deref()
    }

    /// Sets the endpoints to bind to, or clears them when given `None`.
    fn set_bind<I, E>(&mut self, maybe: Option<I>)
    where
        I: IntoIterator<Item = E>,
        E: Into<Endpoint>,
    {
        self.socket_config_mut().bind = maybe.map(|endpoints| {
            endpoints
                .into_iter()
                .map(Into::<Endpoint>::into)
                .collect::<Vec<Endpoint>>()
        });
    }

    /// Returns the security mechanism, or `None` if unset.
    fn mechanism(&self) -> Option<&Mechanism> {
        let config = self.socket_config();
        config.mechanism.as_ref()
    }

    /// Sets the security mechanism, or clears it when given `None`.
    fn set_mechanism(&mut self, maybe: Option<Mechanism>) {
        let config = self.socket_config_mut();
        config.mechanism = maybe;
    }
}

impl ConfigureSocket for SocketConfig {}
/// Chainable builder setters for the socket-level settings.
///
/// Each method overwrites the corresponding `SocketConfig` field and returns
/// `self` so calls can be chained.
pub trait BuildSocket: GetSocketConfig + Sized {
    /// Sets the endpoints to connect to, replacing any previous list.
    fn connect<I, E>(&mut self, endpoints: I) -> &mut Self
    where
        I: IntoIterator<Item = E>,
        E: Into<Endpoint>,
    {
        let config = self.socket_config_mut();
        config.set_connect(Some(endpoints));
        self
    }

    /// Sets the endpoints to bind to, replacing any previous list.
    fn bind<I, E>(&mut self, endpoints: I) -> &mut Self
    where
        I: IntoIterator<Item = E>,
        E: Into<Endpoint>,
    {
        let config = self.socket_config_mut();
        config.set_bind(Some(endpoints));
        self
    }

    /// Sets the security mechanism, replacing any previous one.
    fn mechanism<M>(&mut self, mechanism: M) -> &mut Self
    where
        M: Into<Mechanism>,
    {
        let mechanism: Mechanism = mechanism.into();
        let config = self.socket_config_mut();
        config.set_mechanism(Some(mechanism));
        self
    }
}
#[cfg(test)]
mod test {
    use crate::{prelude::*, *};
    use std::{convert::TryInto, thread, time::Duration};

    /// The client disconnects from the server's endpoint. Afterwards the
    /// server's reply is no longer receivable by the client, and the server
    /// still has exactly one queued message left.
    #[test]
    fn test_disconnect_connection() {
        let addr: TcpAddr = "127.0.0.1:*".try_into().unwrap();
        let server =
            ServerBuilder::new().bind(addr).recv_hwm(1).build().unwrap();
        let endpoint = server.last_endpoint().unwrap();
        let client = ClientBuilder::new().connect(&endpoint).build().unwrap();

        for _attempt in 0..3 {
            client.send("").unwrap();
        }

        let mut message = server.recv_msg().unwrap();
        let routing_id = message.routing_id().unwrap();
        server.route("", routing_id).unwrap();

        client.disconnect(endpoint).unwrap();
        thread::sleep(Duration::from_millis(200));

        client.try_recv(&mut message).unwrap_err();
        server.recv(&mut message).unwrap();
        server.try_recv(&mut message).unwrap_err();
    }

    /// The server disconnects from its own bound endpoint. Messages already
    /// exchanged still arrive, but new ones do not, and routing to the old
    /// peer fails with `HostUnreachable`.
    #[test]
    fn test_disconnect_bind() {
        let addr: TcpAddr = "127.0.0.1:*".try_into().unwrap();
        let server = ServerBuilder::new().bind(addr).build().unwrap();
        let endpoint = server.last_endpoint().unwrap();
        let client = ClientBuilder::new().connect(&endpoint).build().unwrap();

        for _attempt in 0..3 {
            client.send("").unwrap();
        }

        let mut message = server.recv_msg().unwrap();
        let routing_id = message.routing_id().unwrap();
        server.route("", routing_id).unwrap();

        server.disconnect(endpoint).unwrap();
        thread::sleep(Duration::from_millis(50));

        client.recv(&mut message).unwrap();
        for _remaining in 0..2 {
            server.recv(&mut message).unwrap();
        }

        client.send("").unwrap();
        server.try_recv(&mut message).unwrap_err();

        let err = server.route("", routing_id).unwrap_err();
        match err.kind() {
            ErrorKind::HostUnreachable => (),
            _ => panic!(),
        }
    }
}