#![warn(missing_docs)]
use futures::future::BoxFuture;
use futures::Future;
use log::debug;
use rand::distributions::Alphanumeric;
use rand::{thread_rng, Rng};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Debug;
use std::net::SocketAddr;
use std::net::ToSocketAddrs;
use std::sync::Arc;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::oneshot::{channel as os_channel, Sender as OsSender};
use tokio::time::sleep;
use tokio::time::Duration;
/// Requirements on the error type returned by user-supplied handlers.
///
/// Implemented automatically for every type that is `Send + Sync + Debug + Clone + 'static`.
pub trait HandlerError: Send + Sync + Debug + 'static + Clone {}
impl<T> HandlerError for T where T: Send + Sync + Debug + 'static + Clone {}

/// Category of an [`Error`].
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[doc(hidden)]
pub enum ErrorKind<E: HandlerError> {
    /// No reply arrived before the deadline.
    Timeout,
    /// The peer broke the expected protocol (not constructed anywhere in this file).
    ProtocolBreak,
    /// Internal or otherwise uncategorised failure.
    Unspecified,
    /// A message could not be serialized, deserialized, or decoded as UTF-8.
    SerializationError,
    /// Something that was looked up was not found (not constructed anywhere in this file).
    NotFound,
    /// The directory service could not translate a name into a socket address.
    DirectoryService,
    /// A user-supplied handler returned this error.
    HandlerError(E),
    /// A socket operation (bind, connect, read, or write) failed.
    // Appended last so the derived ordering of the pre-existing variants is unchanged.
    Network,
}

/// Error returned by every fallible operation of the networker.
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct Error<E: HandlerError> {
    /// Category of the failure.
    pub kind: ErrorKind<E>,
    /// Human-readable description of the failure.
    pub msg: String,
}

impl<E: HandlerError> std::fmt::Display for Error<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let ErrorKind::HandlerError(inner) = &self.kind {
            // The handler's error is only known to be `Debug`, so show it that way.
            write!(f, "{}: {:?}", self.msg, inner)
        } else {
            f.write_str(&self.msg)
        }
    }
}

impl<E: HandlerError> std::error::Error for Error<E> {}

/// Wraps a handler's error so handlers can use `?` on their own error type.
impl<T: HandlerError> From<T> for Error<T> {
    fn from(e: T) -> Self {
        Self::handler_error(e)
    }
}

impl<E: HandlerError> Error<E> {
    /// Error carrying a value returned by a user-supplied handler.
    fn handler_error(e: E) -> Self {
        Self {
            kind: ErrorKind::HandlerError(e),
            msg: "Handler returned an error".to_string(),
        }
    }

    /// A reply did not arrive in time.
    fn timeout<S: ToString>(msg: S) -> Self {
        Self {
            kind: ErrorKind::Timeout,
            msg: msg.to_string(),
        }
    }

    /// Internal or uncategorised failure.
    fn custom<S: ToString>(msg: S) -> Self {
        Self {
            kind: ErrorKind::Unspecified,
            msg: msg.to_string(),
        }
    }

    /// Encoding or decoding of a message failed.
    fn serialization_error<S: ToString>(msg: S) -> Self {
        Self {
            kind: ErrorKind::SerializationError,
            msg: msg.to_string(),
        }
    }

    /// Name-to-address translation failed.
    fn directory_service_error<S: ToString>(msg: S) -> Self {
        Self {
            kind: ErrorKind::DirectoryService,
            msg: msg.to_string(),
        }
    }

    /// A socket operation failed.
    fn network_error<S: ToString>(msg: S) -> Self {
        Self {
            // Previously mis-tagged as `DirectoryService`.
            kind: ErrorKind::Network,
            msg: msg.to_string(),
        }
    }
}

/// Shorthand for results whose error type is this module's [`Error`].
type Result<T, E> = std::result::Result<T, Error<E>>;
/// Requirements on the payload type carried in a [`NetworkMessage`].
///
/// Payloads are serialized with serde (the networker encodes messages as JSON) and moved
/// between tasks. Every type meeting these bounds implements this trait automatically.
pub trait NetworkContent
where
    Self: Serialize + DeserializeOwned + Send + Sync + Eq + PartialEq + 'static + Debug,
{
}

impl<C> NetworkContent for C where
    C: Serialize + DeserializeOwned + Send + Sync + Eq + PartialEq + 'static + Debug
{
}
/// Size in bytes of the buffer for a single socket read. A message is decoded from the bytes
/// returned by one read, so this also bounds the size of a message.
const MAX_SOCKET_BUF_SIZE: usize = 1_500;
/// Reply timeout used when the caller does not supply one.
const DEFAULT_MESSAGE_TIMEOUT_MILLIS: u64 = 5_000;
/// Capacity of every bounded mpsc channel created in this module.
const CHANNEL_SIZE: usize = 100;
/// Translates a peer name of type `N` into the socket address used to connect to that peer.
///
/// Must be `Send + Sync`: the networker moves its directory service into a spawned task.
pub trait DirectoryService<N, E>: Send + Sync
where
    N: Send + Sync,
    E: HandlerError,
{
    /// Returns the address for `name`, or an error if it cannot be translated.
    fn translate(&self, name: &N) -> Result<SocketAddr, E>;
}
pub struct SimpleDirectoryService<
S: ToSocketAddrs<Iter = std::vec::IntoIter<SocketAddr>> + Send + Sync,
> {
_pd: std::marker::PhantomData<S>,
}
impl<S: ToSocketAddrs<Iter = std::vec::IntoIter<SocketAddr>> + Send + Sync>
SimpleDirectoryService<S>
{
pub fn new() -> Self {
Self {
_pd: std::marker::PhantomData::<S>,
}
}
}
impl<S: ToSocketAddrs<Iter = std::vec::IntoIter<SocketAddr>> + Send + Sync, E: HandlerError>
DirectoryService<S, E> for SimpleDirectoryService<S>
{
fn translate(&self, name: &S) -> Result<SocketAddr, E> {
let mut sockets = name.to_socket_addrs().map_err(|_| {
Error::directory_service_error("Could not get socket address from directory service")
})?;
let socket = sockets.next().ok_or_else(|| {
Error::directory_service_error("Could not get socket address from directory service")
})?;
Ok(socket)
}
}
/// Request sent from a `Networker` handle to the networker task, and forwarded from there to
/// the socket task of the target peer.
type NetPack<T, E> = (
// The message to deliver.
NetworkMessage<T>,
// Optional reaction; when present, a reply is awaited and the action is run on it.
Option<Action<T, E>>,
// Optional reply timeout; only consulted when a reply is awaited.
Option<Duration>,
// Receives the final outcome of the request.
OsSender<Result<(), E>>,
);
/// Request from a `Connection` to whichever task owns the socket: write the message and,
/// optionally, wait for a reply.
type ConnectionPackage<T, E> = (
// The message to write.
NetworkMessage<T>,
// Optional reply timeout (a default is used when `None`).
Option<Duration>,
// `true` if the sender wants to wait for a reply.
bool,
// Receives `Ok(Some(reply))` when a reply was awaited, `Ok(None)` once the message was
// written without awaiting a reply, or the error that occurred.
OsSender<Result<Option<NetworkMessage<T>>, E>>,
);
/// Cloneable handle to a networker started with [`Networker::new`].
///
/// Every method sends a request over a channel to the background task spawned by `new`;
/// all clones of a handle share that task.
#[derive(Debug)]
pub struct Networker<T: NetworkContent + 'static, E: HandlerError> {
    tx: Sender<NetPack<T, E>>,
    command_tx: Sender<(bool, OsSender<Result<(), E>>)>,
}

// Written by hand: `#[derive(Clone)]` would add a `T: Clone` bound although only the channel
// senders are cloned, making the handle uncloneable for non-`Clone` payload types.
impl<T: NetworkContent + 'static, E: HandlerError> Clone for Networker<T, E> {
    fn clone(&self) -> Self {
        Self {
            tx: self.tx.clone(),
            command_tx: self.command_tx.clone(),
        }
    }
}
/// A message exchanged between named peers.
///
/// The networker encodes messages as JSON (via `serde_json`) when writing them to a socket.
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
pub struct NetworkMessage<T: NetworkContent> {
/// Name of the recipient. `Networker::send_message` uses it to pick an existing connection
/// or to resolve an address through the directory service.
pub to: Arc<String>,
/// Name of the sender.
pub from: Arc<String>,
/// Random identifier of this message, generated by `NetworkMessage::new` / `reply`.
pub id: Arc<String>,
/// `Some(id)` of the message this one answers, or `None` if it is not a reply. Incoming
/// messages with `Some(id)` are routed to whoever is waiting for a reply to `id`.
pub reply: Option<Arc<String>>,
/// Application payload.
// The explicit serde bound replaces the `T: Deserialize<'de>` bound that
// `#[derive(Deserialize)]` would otherwise infer for this field.
#[serde(bound(deserialize = "T: DeserializeOwned"))]
pub content: T,
}
/// Generates a random 30-character alphanumeric message identifier.
fn new_id() -> String {
    const ID_LEN: usize = 30;
    thread_rng()
        .sample_iter(&Alphanumeric)
        .take(ID_LEN)
        .map(char::from)
        .collect()
}
impl<T: NetworkContent> NetworkMessage<T> {
    /// Builds a new message from `from` to `to` with a freshly generated id; it is not a
    /// reply to anything.
    pub fn new(to: Arc<String>, from: Arc<String>, content: T) -> Self {
        let id = Arc::new(new_id());
        Self {
            to,
            from,
            id,
            reply: None,
            content,
        }
    }

    /// Builds a reply to `self`: addressed back to `self.from`, sent from `self.to`, marked as
    /// answering `self.id`, with a freshly generated id of its own.
    pub fn reply(&self, content: T) -> Self {
        let answered = Some(Arc::clone(&self.id));
        Self {
            to: Arc::clone(&self.from),
            from: Arc::clone(&self.to),
            id: Arc::new(new_id()),
            reply: answered,
            content,
        }
    }
}
/// Spawns the task that owns one established peer connection and returns the channel through
/// which the networker submits outgoing messages for that peer.
///
/// The spawned task loops over three event sources:
/// * **Write requests** from [`Connection`] handles (`reaction_rx`): the message is serialized
///   to JSON and written to `socket`. If a reply is wanted, a waiter is registered under the
///   message id and a helper task resolves it with the reply, or with a timeout error after
///   `timeout` (default `DEFAULT_MESSAGE_TIMEOUT_MILLIS` ms).
/// * **Incoming messages** read from `socket`: a message with `reply: Some(id)` goes to the
///   waiter registered for `id` (and is dropped if there is none); a message with
///   `reply: None` is passed, with a new [`Connection`], to a clone of `handle_message` on its
///   own task.
/// * **Networker requests** (`rx`): forwarded through the write path above; afterwards the
///   optional [`Action`] is run on the reply and the outcome is sent back on the request's
///   return channel.
///
/// # Errors
/// Never fails at present; the `Result` return type is kept for interface stability.
async fn create_socket_task<T: NetworkContent, M, F, E: HandlerError>(
    mut socket: TcpStream,
    handle_message: M,
) -> Result<Sender<NetPack<T, E>>, E>
where
    F: Future<Output = HandlerResult<(), E>> + Send,
    M: FnMut(NetworkMessage<T>, Connection<T, E>) -> F + Send + Sync + Clone + 'static,
{
    let (tx, mut rx): (Sender<NetPack<T, E>>, Receiver<NetPack<T, E>>) = channel(CHANNEL_SIZE);
    tokio::spawn(async move {
        debug!("Starting up a new socket task");
        // Channel through which `Connection` handles ask this task to write on the socket.
        let (tx, mut reaction_rx): (
            Sender<ConnectionPackage<T, E>>,
            Receiver<ConnectionPackage<T, E>>,
        ) = channel(CHANNEL_SIZE);
        // Outstanding requests that expect a reply, keyed by the id of the sent message.
        let mut awaiting_reply: HashMap<Arc<String>, OsSender<NetworkMessage<T>>> =
            HashMap::new();
        loop {
            tokio::select! {
                Some((msg, timeout, want_reply, os_tx)) = reaction_rx.recv() => {
                    debug!("Socket task - Received a request to send a message on reaction thread");
                    let msg_s = match serde_json::to_string(&msg) {
                        Ok(r) => r,
                        Err(_) => {
                            let e = Error::serialization_error(format!("Could not serialize: {:?}", msg));
                            if let Err(e) = os_tx.send(Err(e)) {
                                debug!("Oneshot return channel did not stay open: {:?}", e);
                            }
                            continue;
                        }
                    };
                    if socket.write_all(msg_s.as_bytes()).await.is_err() {
                        let e = Error::network_error("Could not write on socket");
                        if let Err(e) = os_tx.send(Err(e)) {
                            debug!("Oneshot return channel did not stay open: {:?}", e);
                        }
                        continue;
                    }
                    if want_reply {
                        let timeout_time = timeout
                            .unwrap_or(Duration::from_millis(DEFAULT_MESSAGE_TIMEOUT_MILLIS));
                        // A waiter whose timeout task has already fired has a dropped receiver.
                        // Nothing else removes such entries, so prune them here; otherwise
                        // every unanswered request stays in the map for the connection's life.
                        awaiting_reply.retain(|_, waiter| !waiter.is_closed());
                        let (hm_tx, hm_rx) = os_channel();
                        awaiting_reply.insert(msg.id, hm_tx);
                        tokio::spawn(async move {
                            let timeout = tokio::time::sleep(timeout_time);
                            tokio::pin!(timeout);
                            tokio::select! {
                                Ok(msg) = hm_rx => {
                                    if let Err(e) = os_tx.send(Ok(Some(msg))) {
                                        debug!("Oneshot return channel did not stay open: {:?}", e);
                                    }
                                },
                                _ = timeout => {
                                    if let Err(e) = os_tx.send(Err(Error::timeout("Did not recieve a response in time!"))) {
                                        debug!("Oneshot return channel did not stay open: {:?}", e);
                                    }
                                }
                            }
                        });
                    } else if let Err(e) = os_tx.send(Ok(None)) {
                        debug!("Oneshot return channel did not stay open: {:?}", e);
                    }
                },
                Result::<NetworkMessage<T>, E>::Ok(msg) = read_message(&mut socket) => {
                    match &msg.reply {
                        Some(id) => {
                            match awaiting_reply.remove(id) {
                                Some(waiter) => {
                                    debug!("Sending message to waiter {:?}", msg);
                                    if let Err(e) = waiter.send(msg) {
                                        debug!("Discarded the message due to: {:?}", e);
                                    }
                                }
                                None => {
                                    debug!("Could not find channel to pass message on");
                                }
                            };
                        },
                        None => {
                            debug!("Did not find any waiters for {:?}", msg);
                            let con = Connection {
                                sender: tx.clone(),
                            };
                            let mut handle_message = handle_message.clone();
                            tokio::spawn(async move {
                                if let Err(e) = handle_message(msg, con).await {
                                    debug!("Handle message returned an error {:?}", e);
                                }
                            });
                        },
                    };
                },
                Some((msg, react, timeout, return_tx)) = rx.recv() => {
                    debug!("On socket task - Received a message to send {:?}", msg);
                    let (os_tx, os_rx) = os_channel();
                    let want_reply = react.is_some();
                    // NOTE(review): this awaits capacity on `reaction_rx`, which only this task
                    // drains; if that channel is ever full here the task cannot make progress.
                    tx.send((msg, timeout, want_reply, os_tx))
                        .await
                        .expect("Networker internal error due to reaction channel being closed");
                    let tx = tx.clone();
                    tokio::spawn(async move {
                        let r = os_rx
                            .await
                            .expect("Networker internal error, awaiting os_rx channel but transmitter closed");
                        let outcome = match (r, react) {
                            (Ok(reply), Some(react)) => {
                                let msg = reply.expect("Network unreachable state - received no message while expecting reply");
                                let con = Connection { sender: tx };
                                react.0(msg, con).await
                            }
                            (Ok(_), None) => Ok(()),
                            (Err(e), _) => Err(e),
                        };
                        // The caller may have stopped waiting (e.g. its future was dropped);
                        // that is not a reason to panic this task.
                        if let Err(e) = return_tx.send(outcome) {
                            debug!("Networker return channel did not stay open: {:?}", e);
                        }
                    });
                }
            }
        }
    });
    Ok(tx)
}
async fn read_message<T: NetworkContent, E: HandlerError>(
socket: &mut TcpStream,
) -> Result<NetworkMessage<T>, E> {
let mut buf = [0; MAX_SOCKET_BUF_SIZE];
match socket.read(&mut buf).await {
Ok(n) => match String::from_utf8(buf[..n].to_vec()) {
Ok(s) => match serde_json::from_str(&s) {
Ok(s) => return Ok(s),
Err(_) => {
return Err(Error::serialization_error(
"Could not deserialize recieved message",
));
}
},
Err(e) => {
return Err(Error::serialization_error(format!(
"Could not convert to utf-8 - {:?}",
e
)));
}
},
Err(e) => {
return Err(Error::network_error(format!(
"Could not read from socket - {:?}",
e
)));
}
}
}
/// Runs the handshake on a newly accepted connection and, if it succeeds, hands the socket
/// over to [`create_socket_task`].
///
/// The first message read from `socket` is given to `handle_handshake` together with a
/// [`Connection`] that is served inline here. Each message sent through it is serialized to
/// JSON and written to `socket`. When a reply is requested, the *next* message read from
/// `socket` is returned as that reply (its `reply` id is not checked), subject to the timeout
/// (default `DEFAULT_MESSAGE_TIMEOUT_MILLIS` ms).
///
/// Returns the `String` produced by `handle_handshake`, which `Networker::new` records as the
/// peer's name, together with the sender for the new socket task.
///
/// # Errors
/// Fails if the first message cannot be read or decoded, or if `handle_handshake` returns an
/// error. Failures while serving the handshake `Connection` are reported to whoever sent on it.
///
/// NOTE(review): the handshake `Connection`'s channel is dropped when this function returns,
/// so a handler that keeps that `Connection` and sends on it later finds the channel closed.
async fn process_socket<T: NetworkContent, H, M, FH, FM, E: HandlerError>(
mut socket: TcpStream,
mut handle_handshake: H,
handle_message: M,
) -> Result<(String, Sender<NetPack<T, E>>), E>
where
FH: Future<Output = HandlerResult<String, E>> + Send,
FM: Future<Output = HandlerResult<(), E>> + Send,
H: FnMut(NetworkMessage<T>, Connection<T, E>) -> FH + Send + Sync + 'static,
M: FnMut(NetworkMessage<T>, Connection<T, E>) -> FM + Send + Sync + Clone + 'static,
{
// The first message on a new connection is the handshake.
let msg = read_message(&mut socket).await?;
// `handshake_tx` itself stays alive until this function returns, so
// `handshake_rx.recv()` below never yields `None` (which would make the loop spin).
let (handshake_tx, mut handshake_rx) = channel(CHANNEL_SIZE);
let con = Connection {
sender: handshake_tx.clone(),
};
// Serves write requests from the handshake `Connection` directly on `socket`.
// This loop never completes on its own.
let listen_for_messages = async {
loop {
if let Some((msg, timeout, want_reply, os_tx)) = handshake_rx.recv().await {
let msg = match serde_json::to_string(&msg) {
Ok(s) => s,
Err(_) => {
let e =
Error::serialization_error(format!("Could not serialize {:?}", msg));
if let Err(e) = os_tx.send(Err(e)) {
debug!("Oneshot return channel did not stay open: {:?}", e);
}
continue;
}
};
match socket.write_all(msg.as_bytes()).await {
Ok(()) => (),
Err(_) => {
let e = Error::network_error("Could not send over network socket");
if let Err(e) = os_tx.send(Err(e)) {
debug!("Oneshot return channel did not stay open: {:?}", e);
}
continue;
}
};
if !want_reply {
if let Err(e) = os_tx.send(Ok(None)) {
debug!("Oneshot return channel did not stay open: {:?}", e);
}
continue;
} else {
let timeout = sleep(
timeout.unwrap_or(Duration::from_millis(DEFAULT_MESSAGE_TIMEOUT_MILLIS)),
);
tokio::pin!(timeout);
// Whatever message is read next is treated as the reply.
tokio::select! {
s = read_message(&mut socket) => {
match s {
Ok(s) => {
if let Err(e) = os_tx.send(Ok(Some(s))) {
debug!("Oneshot return channel did not stay open: {:?}", e);
}
},
Err(e) => {
if let Err(e) = os_tx.send(Err(e)) {
debug!("Oneshot return channel did not stay open: {:?}", e);
}
}
}
},
_ = timeout => {
if let Err(e) = os_tx.send(Err(Error::timeout("Did not recieve a response in time"))) {
debug!("Oneshot return channel did not stay open: {:?}", e);
}
}
}
}
}
}
};
// Drive the inline connection server and the handshake handler concurrently. When the
// handler finishes, `listen_for_messages` is dropped, which releases its borrow of `socket`.
let r = tokio::select! {
// `listen_for_messages` loops forever, so this arm never fires. The
// `Result::<T, E>` pattern only fixes the async block's output type.
Result::<T, E>::Err(e) = listen_for_messages => return Err(e),
r = handle_handshake(msg, con) => {
match r {
Ok(r) => r,
Err(e) => {
return Err(e);
}
}
}
};
let tx = create_socket_task(socket, handle_message).await?;
Ok((r, tx))
}
/// Result type produced by the futures that user-supplied handlers return. It is the same as
/// this module's `Result<T, E>` alias, i.e. `std::result::Result<T, Error<E>>`.
type HandlerResult<T, E> = Result<T, E>;
impl<T: NetworkContent, E: HandlerError> Networker<T, E> {
    /// Starts a networker bound to `address` and returns a handle to it.
    ///
    /// A background task is spawned that:
    /// * while listening is enabled (see [`Networker::listen`]; it is off initially), accepts
    ///   connections on `address`, runs `handle_handshakes` on the first message of each to
    ///   obtain the peer's name, and then dispatches that peer's messages to
    ///   `handle_messages`;
    /// * delivers messages passed to [`Networker::send_message`], reusing the connection
    ///   already associated with `message.to` or opening a new one at the address returned
    ///   by `directory_service`.
    ///
    /// # Errors
    /// Never fails at present; the `Result` return type is kept for interface stability.
    pub async fn new<H, M, FH, FM>(
        address: SocketAddr,
        directory_service: impl DirectoryService<String, E> + 'static,
        handle_handshakes: H,
        handle_messages: M,
    ) -> Result<Networker<T, E>, E>
    where
        FM: Future<Output = HandlerResult<(), E>> + Send,
        FH: Future<Output = HandlerResult<String, E>> + Send,
        M: FnMut(NetworkMessage<T>, Connection<T, E>) -> FM + Send + Sync + Clone + 'static,
        H: FnMut(NetworkMessage<T>, Connection<T, E>) -> FH + Send + Sync + Clone + 'static,
    {
        let (net_tx, mut thread_rx): (Sender<NetPack<T, E>>, Receiver<NetPack<T, E>>) =
            channel(CHANNEL_SIZE);
        let (command_tx, mut command_rx): (
            Sender<(bool, OsSender<Result<(), E>>)>,
            Receiver<(bool, OsSender<Result<(), E>>)>,
        ) = channel(CHANNEL_SIZE);
        // Peer name -> channel of the socket task serving that peer.
        let mut name_channel_hm = HashMap::new();
        let mut listener: Option<TcpListener> = None;
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    Some((should_be_server, os_tx)) = command_rx.recv() => {
                        if should_be_server {
                            listener = match TcpListener::bind(address).await {
                                Ok(r) => Some(r),
                                Err(e) => {
                                    let err = Error::network_error(format!(
                                        "Could not listen to address: {} due to: {:?}",
                                        address, e
                                    ));
                                    if os_tx.send(Err(err)).is_err() {
                                        debug!("Internal networker error - could not return result from turning on/off server");
                                    }
                                    continue;
                                }
                            };
                        } else {
                            listener = None;
                        }
                        if os_tx.send(Ok(())).is_err() {
                            debug!("Internal networker error - could not send on return channel from request to start listening");
                        }
                    }
                    // Without a listener this future stays pending, which disables the branch.
                    Ok((socket, _)) = async {
                        if let Some(listener) = &listener {
                            listener.accept().await
                        } else {
                            let forever = futures::future::pending();
                            let () = forever.await;
                            unreachable!("Networker unreachable state - tried to listen to a non-existent server");
                        }
                    } => {
                        debug!("Received a TCP connection");
                        // NOTE(review): the handshake is awaited inline, so until it finishes this
                        // task handles no other connection, command, or outgoing message.
                        let (name, tx) = match process_socket(socket, handle_handshakes.clone(), handle_messages.clone()).await {
                            Ok(r) => r,
                            Err(e) => {
                                debug!("Could not establish contact {:?}", e);
                                continue;
                            }
                        };
                        debug!("Handshake finished peer name is: {}", name);
                        name_channel_hm.insert(name, tx);
                    }
                    Some((message, react, timeout, os_tx)) = thread_rx.recv() => {
                        match name_channel_hm.get(&*message.to) {
                            Some(tx) => {
                                // If this fails, `os_tx` is dropped with the message and the
                                // caller of `send_message` observes a closed return channel.
                                if let Err(e) = tx.send((message, react, timeout, os_tx)).await {
                                    debug!("{}", e);
                                }
                            }
                            None => {
                                // NOTE(review): `translate` is synchronous and runs on this task;
                                // a resolver that blocks (e.g. on DNS) stalls the whole networker.
                                let address = match directory_service.translate(&message.to) {
                                    Ok(address) => address,
                                    Err(e) => {
                                        if os_tx.send(Err(e)).is_err() {
                                            debug!("Could not return error send on one-shot channel");
                                        }
                                        continue;
                                    }
                                };
                                let socket = match TcpStream::connect(address).await {
                                    Ok(s) => s,
                                    Err(_) => {
                                        let e = Error::network_error(format!("Could not connect to address {:?}", address));
                                        if os_tx.send(Err(e)).is_err() {
                                            debug!("Could not return error send on one-shot channel");
                                        }
                                        continue;
                                    }
                                };
                                let tx = match create_socket_task(socket, handle_messages.clone()).await {
                                    Ok(tx) => tx,
                                    Err(e) => {
                                        if os_tx.send(Err(e)).is_err() {
                                            debug!("Could not return error send on one-shot channel");
                                        }
                                        continue;
                                    }
                                };
                                let name = Arc::clone(&message.to);
                                if let Err(e) = tx.send((message, react, timeout, os_tx)).await {
                                    debug!("{}", e);
                                }
                                name_channel_hm.insert((*name).clone(), tx);
                            }
                        }
                    }
                }
            }
        });
        Ok(Networker {
            tx: net_tx,
            command_tx,
        })
    }

    /// Sends `message` to the peer named `message.to`.
    ///
    /// Without `react`, resolves once the message has been written. With `react`, waits for
    /// the reply (up to `timeout`, default `DEFAULT_MESSAGE_TIMEOUT_MILLIS` ms; `timeout` is
    /// ignored otherwise), runs the action on it and returns the action's result.
    ///
    /// # Errors
    /// Directory-service, connection, serialization, write, or timeout errors, the error
    /// returned by `react`, or an internal error if the networker stopped handling the
    /// request. Previously that last case panicked.
    pub async fn send_message(
        &self,
        message: NetworkMessage<T>,
        timeout: Option<Duration>,
        react: Option<Action<T, E>>,
    ) -> Result<(), E> {
        let (os_tx, os_rx) = os_channel();
        if self.tx.send((message, react, timeout, os_tx)).await.is_err() {
            debug!("Could not send to networker");
            return Err(Error::custom(
                "Internal error - could not send to networker because its task has stopped",
            ));
        }
        // The return channel can be dropped without an answer, e.g. when forwarding to a
        // peer's socket task fails; report that instead of panicking.
        os_rx.await.unwrap_or_else(|_| {
            Err(Error::custom(
                "Internal error - networker dropped the request before reporting a result",
            ))
        })
    }

    /// Starts (`true`) or stops (`false`) accepting incoming connections on the address
    /// given to [`Networker::new`]. Stopping drops the listener only; established
    /// connections are unaffected.
    ///
    /// # Errors
    /// A network error if binding the address fails, or an internal error if the networker
    /// task is not running.
    pub async fn listen(&self, should_listen: bool) -> Result<(), E> {
        let (tx, rx) = os_channel();
        if self.command_tx.send((should_listen, tx)).await.is_err() {
            return Err(Error::custom(
                "Internal error - Could not change listening status due to channel being down",
            ));
        }
        rx.await.unwrap_or_else(|_| {
            Err(Error::custom("Internal error - Could not get response from listening call due to return channel closing prematurely"))
        })
    }
}
/// One-shot reaction to the reply to a message sent with [`Networker::send_message`].
///
/// The closure receives the reply and a [`Connection`] to the same peer, and returns a boxed
/// future whose result becomes the result of that `send_message` call.
pub struct Action<T, E>(
    Box<
        dyn FnOnce(NetworkMessage<T>, Connection<T, E>) -> BoxFuture<'static, HandlerResult<(), E>>
            + Send
            + Sync,
    >,
)
where
    T: Send + Sync + Serialize + DeserializeOwned + Eq + PartialEq + Debug + 'static,
    E: HandlerError;

impl<T, E> Action<T, E>
where
    T: NetworkContent,
    E: HandlerError,
{
    /// Wraps the closure `f` as an [`Action`].
    pub fn new<F>(f: F) -> Self
    where
        F: FnOnce(NetworkMessage<T>, Connection<T, E>) -> BoxFuture<'static, HandlerResult<(), E>>
            + 'static
            + Send
            + Sync,
    {
        let boxed = Box::new(f);
        Self(boxed)
    }
}
/// Handle that a handler uses to send messages over the connection it was invoked for.
///
/// Messages sent through it are written to the same socket the triggering message arrived on.
pub struct Connection<T: NetworkContent, E: HandlerError> {
    sender: Sender<ConnectionPackage<T, E>>,
}

impl<T: NetworkContent, E: HandlerError> Connection<T, E> {
    /// Sends `msg` over this connection without waiting for a reply; resolves once the message
    /// has been written.
    ///
    /// # Errors
    /// The error reported by the task owning the socket if `msg` could not be serialized or
    /// written. Also returns an internal error instead of panicking (matching
    /// [`Connection::send_message_await_reply`]) if that task is no longer serving this
    /// connection, e.g. a handshake `Connection` used after the handshake finished.
    pub async fn send_message(&mut self, msg: NetworkMessage<T>) -> Result<(), E> {
        let (tx, rx) = os_channel();
        if self.sender.send((msg, None, false, tx)).await.is_err() {
            return Err(Error::custom(
                "Internal error - could not send on internal channel",
            ));
        }
        let outcome = match rx.await {
            Ok(r) => r,
            Err(_) => {
                return Err(Error::custom("Internal error - internal return channel was closed before receiving a message"));
            }
        };
        outcome.map(|reply| {
            // A request made with `want_reply == false` is always acknowledged with `None`.
            if reply.is_some() {
                panic!("Unreachable state - expected None but was provided a network message")
            }
        })
    }

    /// Sends `msg` over this connection and waits for a reply, giving up after `timeout`
    /// (default `DEFAULT_MESSAGE_TIMEOUT_MILLIS` ms).
    ///
    /// # Errors
    /// Serialization, write, or timeout errors reported by the task owning the socket, or an
    /// internal error if that task is no longer serving this connection.
    pub async fn send_message_await_reply(
        &mut self,
        msg: NetworkMessage<T>,
        timeout: Option<Duration>,
    ) -> Result<NetworkMessage<T>, E> {
        let (tx, rx) = os_channel();
        if self.sender.send((msg, timeout, true, tx)).await.is_err() {
            return Err(Error::custom(
                "Internal error - could not send on internal channel",
            ));
        }
        let outcome = match rx.await {
            Ok(r) => r,
            Err(_) => {
                return Err(Error::custom("Internal error - internal return channel was closed before receiving a message"));
            }
        };
        outcome.map(|reply| reply.expect("Expecting a network message as response but None was provided"))
    }
}