use std::any::Any;
use std::borrow::Borrow;
use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use crate::channel::{Receiver, RecvTimeoutError, SendError, Sender};
use crate::network::sender::SendRequest;
/// Number of seconds `DispatchLoop::run` waits in each `recv_timeout` call before looping
/// back to re-check its `running` flag.
const TIMEOUT_SEC: u64 = 2;
/// Metadata about a message being dispatched, handed to the handler together with the
/// decoded message.
#[derive(Clone, Debug)]
pub struct MessageContext<MT: Hash + Eq + Debug + Clone> {
    /// Id of the peer the message came from.
    source_peer_id: String,
    /// Message type the message was dispatched under.
    message_type: MT,
    /// The raw, undecoded message bytes.
    message_bytes: Vec<u8>,
}

impl<MT: Hash + Eq + Debug + Clone> MessageContext<MT> {
    /// Returns the id of the peer that sent the message.
    pub fn source_peer_id(&self) -> &str {
        self.source_peer_id.as_str()
    }

    /// Returns the message type used to select the handler.
    pub fn message_type(&self) -> &MT {
        let MessageContext { message_type, .. } = self;
        message_type
    }

    /// Returns the raw bytes the handler's message was decoded from.
    pub fn message_bytes(&self) -> &[u8] {
        self.message_bytes.as_slice()
    }
}
/// Processes messages of payload type `T` that were dispatched under a message type `MT`.
///
/// Implementors must be `Send`. Any closure with the matching signature is also a `Handler`
/// via the blanket implementation in this module.
pub trait Handler<MT: Hash + Eq + Debug + Clone, T: FromMessageBytes>: Send {
    /// Handles one decoded `message`.
    ///
    /// `message_context` carries the sender's peer id, the message type and the raw bytes;
    /// `network_sender` may be used to send messages in response. Returning an error causes
    /// `Dispatcher::dispatch` to return that error.
    fn handle(
        &self,
        message: T,
        message_context: &MessageContext<MT>,
        network_sender: &dyn Sender<SendRequest>,
    ) -> Result<(), DispatchError>;
}
/// Lets any `Send` closure or function with the handler signature be registered directly as
/// a `Handler`.
impl<MT, T, F> Handler<MT, T> for F
where
    F: Fn(T, &MessageContext<MT>, &dyn Sender<SendRequest>) -> Result<(), DispatchError> + Send,
    T: FromMessageBytes,
    MT: Hash + Eq + Debug + Clone,
{
    /// Forwards all arguments to the wrapped callable unchanged.
    fn handle(
        &self,
        message: T,
        message_context: &MessageContext<MT>,
        network_sender: &dyn Sender<SendRequest>,
    ) -> Result<(), DispatchError> {
        let callback = self;
        callback(message, message_context, network_sender)
    }
}
/// Decodes a handler's payload type from raw message bytes.
///
/// `Dispatcher::set_handler` uses this to turn incoming bytes into the type its handler expects.
pub trait FromMessageBytes: Sized + Any {
    /// Builds `Self` from `message_bytes`, or returns a `DispatchError` if decoding fails.
    fn from_message_bytes(message_bytes: &[u8]) -> Result<Self, DispatchError>;
}
/// A payload type that performs no decoding: it simply owns a copy of the message bytes.
#[derive(Debug, Clone)]
pub struct RawBytes {
    bytes: Vec<u8>,
}

impl RawBytes {
    /// Consumes the wrapper and hands back the owned byte buffer.
    pub fn into_inner(self) -> Vec<u8> {
        let RawBytes { bytes } = self;
        bytes
    }

    /// Borrows the wrapped bytes.
    pub fn bytes(&self) -> &[u8] {
        self.bytes.as_slice()
    }
}

impl From<&[u8]> for RawBytes {
    /// Copies `source` into a freshly allocated buffer.
    fn from(source: &[u8]) -> Self {
        let bytes = Vec::from(source);
        RawBytes { bytes }
    }
}

impl AsRef<[u8]> for RawBytes {
    fn as_ref(&self) -> &[u8] {
        self.bytes()
    }
}
impl FromMessageBytes for RawBytes {
    /// Never fails: the bytes are copied verbatim into a `RawBytes`.
    fn from_message_bytes(message_bytes: &[u8]) -> Result<Self, DispatchError> {
        let raw = RawBytes::from(message_bytes);
        Ok(raw)
    }
}
/// Errors that can occur while dispatching a message to its handler.
#[derive(Debug, PartialEq)]
pub enum DispatchError {
    /// The message bytes could not be decoded into the handler's payload type.
    DeserializationError(String),
    /// A message could not be serialized to bytes.
    SerializationError(String),
    /// No handler is registered for the message type; the string describes the type.
    UnknownMessageType(String),
    /// Sending a message through the network sender failed.
    NetworkSendError(SendError),
    /// A handler failed to process the message.
    HandleError(String),
}
impl std::error::Error for DispatchError {
    /// Exposes the wrapped `SendError` for network-send failures. The other variants carry
    /// only a message string and therefore report no source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let DispatchError::NetworkSendError(send_err) = self {
            Some(send_err)
        } else {
            None
        }
    }
}
impl std::fmt::Display for DispatchError {
    /// Writes a lower-case description of what was being attempted, followed by the detail
    /// carried by the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            DispatchError::DeserializationError(detail) => {
                write!(f, "unable to deserialize message: {}", detail)
            }
            DispatchError::SerializationError(detail) => {
                write!(f, "unable to serialize message: {}", detail)
            }
            DispatchError::UnknownMessageType(detail) => {
                write!(f, "unknown message type: {}", detail)
            }
            DispatchError::NetworkSendError(cause) => {
                write!(f, "unable to send message: {}", cause)
            }
            DispatchError::HandleError(detail) => {
                write!(f, "unable to handle message: {}", detail)
            }
        }
    }
}
impl From<SendError> for DispatchError {
    /// Wraps a network `SendError`, which lets handlers use `?` on `network_sender.send(..)`.
    fn from(send_err: SendError) -> Self {
        let wrapped = DispatchError::NetworkSendError(send_err);
        wrapped
    }
}
/// Routes raw message bytes to the handler registered for their message type.
///
/// Handlers are keyed by `MT`. The boxed `network_sender` is passed to every handler
/// invocation so handlers can send network messages.
pub struct Dispatcher<MT: Any + Hash + Eq + Debug + Clone> {
    /// Type-erased handlers, keyed by the message type each one accepts.
    handlers: HashMap<MT, HandlerWrapper<MT>>,
    /// Sender handed to each handler invocation.
    network_sender: Box<dyn Sender<SendRequest>>,
}
impl<MT: Any + Hash + Eq + Debug + Clone> Dispatcher<MT> {
    /// Creates a dispatcher with no registered handlers.
    ///
    /// `network_sender` is shared by reference with every handler invocation.
    pub fn new(network_sender: Box<dyn Sender<SendRequest>>) -> Self {
        Dispatcher {
            handlers: HashMap::new(),
            network_sender,
        }
    }

    /// Registers `handler` for `message_type`, replacing any handler previously registered
    /// for that type.
    ///
    /// The payload type `T` is erased here. When a message arrives, its bytes are decoded
    /// with `T::from_message_bytes`. If decoding fails, that error is returned from
    /// `dispatch` and the handler is not called.
    pub fn set_handler<T>(&mut self, message_type: MT, handler: Box<dyn Handler<MT, T>>)
    where
        T: FromMessageBytes,
    {
        self.handlers.insert(
            message_type,
            HandlerWrapper {
                inner: Box::new(move |message_bytes, message_context, network_sender| {
                    let message = T::from_message_bytes(message_bytes)?;
                    handler.handle(message, message_context, network_sender)
                }),
            },
        );
    }

    /// Dispatches `message_bytes` from `source_peer_id` to the handler registered for
    /// `message_type`.
    ///
    /// # Errors
    ///
    /// - `DispatchError::UnknownMessageType` if no handler is registered for `message_type`.
    /// - Any error from decoding the bytes or from the handler itself.
    pub fn dispatch(
        &self,
        source_peer_id: &str,
        message_type: &MT,
        message_bytes: Vec<u8>,
    ) -> Result<(), DispatchError> {
        // Resolve the handler first so unknown message types fail fast, without cloning the
        // message type or allocating the peer-id string for a context nobody will read.
        let handler = self.handlers.get(message_type).ok_or_else(|| {
            DispatchError::UnknownMessageType(format!("No handler for type {:?}", message_type))
        })?;

        let message_context = MessageContext {
            message_type: message_type.clone(),
            message_bytes,
            source_peer_id: source_peer_id.into(),
        };

        handler.handle(
            &message_context.message_bytes,
            &message_context,
            self.network_sender.borrow(),
        )
    }
}
/// Type-erased handler stored by `Dispatcher`.
///
/// It receives the raw message bytes (and decodes them itself), the message context and the
/// network sender. This erasure lets handlers for different payload types share one map.
type InnerHandler<MT> = Box<
    dyn Fn(&[u8], &MessageContext<MT>, &dyn Sender<SendRequest>) -> Result<(), DispatchError>
        + Send,
>;
/// Holder for one type-erased handler, stored as a value in `Dispatcher::handlers`.
struct HandlerWrapper<MT: Hash + Eq + Debug + Clone> {
    /// The decode-then-handle closure built by `Dispatcher::set_handler`.
    inner: InnerHandler<MT>,
}
impl<MT: Hash + Eq + Debug + Clone> HandlerWrapper<MT> {
    /// Invokes the wrapped closure with the raw bytes, the message context and the network
    /// sender, and returns whatever it returns.
    fn handle(
        &self,
        message_bytes: &[u8],
        message_context: &MessageContext<MT>,
        network_sender: &dyn Sender<SendRequest>,
    ) -> Result<(), DispatchError> {
        let callback = &*self.inner;
        callback(message_bytes, message_context, network_sender)
    }
}
/// A message queued for `DispatchLoop`, which hands its type, bytes and sender's peer id to
/// `Dispatcher::dispatch`.
///
/// Derives `Debug`, like the sibling `MessageContext`, so queued messages can be logged and
/// compared in assertions.
#[derive(Clone, Debug)]
pub struct DispatchMessage<MT: Any + Hash + Eq + Debug + Clone> {
    /// Message type used to select the handler.
    message_type: MT,
    /// Raw, undecoded message bytes.
    message_bytes: Vec<u8>,
    /// Id of the peer the message came from.
    source_peer_id: String,
}

impl<MT: Any + Hash + Eq + Debug + Clone> DispatchMessage<MT> {
    /// Creates a message of `message_type` with payload `message_bytes`, received from
    /// `source_peer_id`.
    pub fn new(message_type: MT, message_bytes: Vec<u8>, source_peer_id: String) -> Self {
        DispatchMessage {
            message_type,
            message_bytes,
            source_peer_id,
        }
    }

    /// Returns the message type.
    pub fn message_type(&self) -> &MT {
        &self.message_type
    }

    /// Returns the raw message bytes.
    pub fn message_bytes(&self) -> &[u8] {
        &self.message_bytes
    }

    /// Returns the id of the peer that sent the message.
    pub fn source_peer_id(&self) -> &str {
        &self.source_peer_id
    }
}
/// Error returned by `DispatchLoop::run` when the loop has to stop abnormally, for example
/// because its receiver reported a disconnect.
///
/// Implements `Display` and `std::error::Error`, so it can be boxed as `dyn Error` or
/// converted with `?` like other error types.
#[derive(Debug)]
pub struct DispatchLoopError(String);

impl std::fmt::Display for DispatchLoopError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl std::error::Error for DispatchLoopError {}
/// Pulls `DispatchMessage`s from a channel and feeds them to a `Dispatcher` until the shared
/// `running` flag is cleared or the channel disconnects.
pub struct DispatchLoop<MT: Any + Hash + Eq + Debug + Clone> {
    /// Source of messages to dispatch.
    receiver: Box<dyn Receiver<DispatchMessage<MT>>>,
    /// Routes each received message to its handler.
    dispatcher: Dispatcher<MT>,
    /// Shared run flag; `run` keeps receiving while this is `true`.
    running: Arc<AtomicBool>,
}
impl<MT: Any + Hash + Eq + Debug + Clone> DispatchLoop<MT> {
    /// Creates a loop that reads from `receiver` and dispatches with `dispatcher` while
    /// `running` is `true`.
    pub fn new(
        receiver: Box<dyn Receiver<DispatchMessage<MT>>>,
        dispatcher: Dispatcher<MT>,
        running: Arc<AtomicBool>,
    ) -> Self {
        DispatchLoop {
            receiver,
            dispatcher,
            running,
        }
    }

    /// Runs the loop on the current thread until `running` becomes `false` or the receiver
    /// disconnects.
    ///
    /// Receives wait at most `TIMEOUT_SEC` seconds each, so a cleared `running` flag is
    /// noticed even when no messages arrive. After `running` is cleared, messages still
    /// queued in the receiver are drained and dispatched before returning. Dispatch failures
    /// are logged and do not stop the loop.
    ///
    /// # Errors
    ///
    /// Returns `DispatchLoopError` if the receiver reports `RecvTimeoutError::Disconnected`.
    pub fn run(&self) -> Result<(), DispatchLoopError> {
        let timeout = Duration::from_secs(TIMEOUT_SEC);
        while self.running.load(Ordering::SeqCst) {
            match self.receiver.recv_timeout(timeout) {
                Ok(dispatch_msg) => self.dispatch_message(dispatch_msg),
                // Nothing arrived in time; loop back to re-check the running flag.
                Err(RecvTimeoutError::Timeout) => {}
                Err(RecvTimeoutError::Disconnected) => {
                    error!("Received Disconnected Error from receiver");
                    return Err(DispatchLoopError(String::from(
                        "Received Disconnected Error from receiver",
                    )));
                }
            }
        }

        // Shutdown was requested: dispatch whatever is already queued so accepted messages
        // are not silently dropped.
        while let Ok(dispatch_msg) = self.receiver.try_recv() {
            self.dispatch_message(dispatch_msg);
        }

        Ok(())
    }

    /// Dispatches a single message. Failures are logged as warnings instead of propagated,
    /// so one bad message cannot stop the loop.
    fn dispatch_message(&self, dispatch_msg: DispatchMessage<MT>) {
        if let Err(err) = self.dispatcher.dispatch(
            &dispatch_msg.source_peer_id,
            &dispatch_msg.message_type,
            dispatch_msg.message_bytes,
        ) {
            warn!("Unable to dispatch message: {:?}", err);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::{Arc, Mutex};

    use protobuf::Message;

    use crate::channel::mock::MockSender;
    use crate::network::sender::SendRequest;
    use crate::protos::network::{NetworkEcho, NetworkMessageType};

    // A closure with the handler signature can be registered directly. Dispatching an
    // unregistered type must fail without invoking it; dispatching the registered type must
    // invoke it.
    #[test]
    fn dispatch_to_closure() {
        let flag = Arc::new(AtomicBool::new(false));
        let mut dispatcher = Dispatcher::new(Box::new(MockSender::default()));

        let handler_flag = flag.clone();
        dispatcher.set_handler(
            NetworkMessageType::NETWORK_ECHO,
            Box::new(
                move |_: NetworkEcho,
                      _: &MessageContext<NetworkMessageType>,
                      _: &dyn Sender<SendRequest>| {
                    handler_flag.store(true, Ordering::SeqCst);
                    Ok(())
                },
            ),
        );

        assert_eq!(
            Err(DispatchError::UnknownMessageType(format!(
                "No handler for type {:?}",
                NetworkMessageType::CIRCUIT
            ))),
            dispatcher.dispatch("TestPeer", &NetworkMessageType::CIRCUIT, Vec::new())
        );
        assert!(!flag.load(Ordering::SeqCst));

        assert_eq!(
            Ok(()),
            dispatcher.dispatch("TestPeer", &NetworkMessageType::NETWORK_ECHO, Vec::new())
        );
        assert!(flag.load(Ordering::SeqCst));
    }

    // A struct implementing `Handler` receives the protobuf-decoded message.
    #[test]
    fn dispatch_to_handler() {
        let mut dispatcher = Dispatcher::new(Box::new(MockSender::default()));

        let handler = NetworkEchoHandler::default();
        let echos = handler.echos.clone();
        dispatcher.set_handler(NetworkMessageType::NETWORK_ECHO, Box::new(handler));

        let mut outgoing_message = NetworkEcho::new();
        outgoing_message.set_payload(b"test_dispatcher".to_vec());
        let outgoing_message_bytes = outgoing_message.write_to_bytes().unwrap();

        assert_eq!(
            Ok(()),
            dispatcher.dispatch(
                "TestPeer",
                &NetworkMessageType::NETWORK_ECHO,
                outgoing_message_bytes
            )
        );
        assert_eq!(
            vec!["test_dispatcher".to_string()],
            echos.lock().unwrap().clone()
        );
    }

    // A plain `fn` can be registered, and it can use the network sender it is given.
    #[test]
    fn dispatch_to_fn() {
        let network_sender: MockSender<SendRequest> = MockSender::default();
        let mut dispatcher = Dispatcher::new(Box::new(network_sender.clone()));

        dispatcher.set_handler(NetworkMessageType::NETWORK_ECHO, Box::new(handle_echo));

        assert_eq!(
            Ok(()),
            dispatcher.dispatch("TestPeer", &NetworkMessageType::NETWORK_ECHO, Vec::new())
        );
        assert_eq!(
            &vec![SendRequest::new("TestPeer".into(), vec![])],
            &network_sender.sent()
        );
    }

    // A fully configured dispatcher can be moved to, and used from, another thread.
    #[test]
    fn move_dispatcher_to_thread() {
        let mut dispatcher = Dispatcher::new(Box::new(MockSender::default()));

        let handler = NetworkEchoHandler::default();
        let echos = handler.echos.clone();
        dispatcher.set_handler(NetworkMessageType::NETWORK_ECHO, Box::new(handler));

        std::thread::spawn(move || {
            let mut outgoing_message = NetworkEcho::new();
            outgoing_message.set_payload(b"thread_echo".to_vec());
            let outgoing_message_bytes = outgoing_message.write_to_bytes().unwrap();

            assert_eq!(
                Ok(()),
                dispatcher.dispatch(
                    "TestPeer",
                    &NetworkMessageType::NETWORK_ECHO,
                    outgoing_message_bytes
                )
            );
        })
        .join()
        .unwrap();

        assert_eq!(
            vec!["thread_echo".to_string()],
            echos.lock().unwrap().clone()
        );
    }

    /// Test handler that records each echo payload as a UTF-8 string.
    #[derive(Default)]
    struct NetworkEchoHandler {
        echos: Arc<Mutex<Vec<String>>>,
    }

    impl Handler<NetworkMessageType, NetworkEcho> for NetworkEchoHandler {
        fn handle(
            &self,
            message: NetworkEcho,
            _message_context: &MessageContext<NetworkMessageType>,
            _: &dyn Sender<SendRequest>,
        ) -> Result<(), DispatchError> {
            let echo_string = String::from_utf8(message.get_payload().to_vec()).unwrap();
            self.echos.lock().unwrap().push(echo_string);
            Ok(())
        }
    }

    /// Test handler fn: expects an empty payload and replies to the source peer with an empty
    /// message.
    fn handle_echo(
        message: RawBytes,
        message_context: &MessageContext<NetworkMessageType>,
        network_sender: &dyn Sender<SendRequest>,
    ) -> Result<(), DispatchError> {
        assert!(message.bytes().is_empty());

        network_sender.send(SendRequest::new(
            message_context.source_peer_id().to_string(),
            vec![],
        ))?;

        Ok(())
    }
}