use crate::child_ref::ChildRef;
use crate::envelope::SignedMessage;
use anyhow::Result as AnyResult;
use lever::prelude::*;
use std::fmt::{self, Debug};
use std::hash::{Hash, Hasher};
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
use tracing::{debug, trace, warn};
/// Table of the children registered with a dispatcher: each child reference
/// maps to the module name it was registered with.
pub type DispatcherMap = LOTable<ChildRef, String>;
/// Kind of membership change reported to a [`DispatcherHandler`].
///
/// Implements `PartialEq`/`Eq` so handlers can branch on the kind with `==`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NotificationType {
    /// Reported after a child was inserted into a dispatcher's table.
    Register,
    /// Reported after a child was removed from a dispatcher's table.
    Remove,
}
/// Selects which dispatchers a global broadcast is delivered to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BroadcastTarget {
    /// Every dispatcher currently registered.
    All,
    /// The single dispatcher whose type is obtained by converting this name
    /// into a `DispatcherType`.
    Group(String),
}
/// Identity of a dispatcher: either the anonymous one or a named group.
///
/// `Hash` is implemented by hand (by hashing the dispatcher's name) rather
/// than derived.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DispatcherType {
    /// The unnamed dispatcher.
    Anonymous,
    /// A dispatcher identified by a group name.
    Named(String),
}
/// Handler installed by `Dispatcher::with_type` and `Dispatcher::default`
/// when no handler is supplied explicitly.
pub type DefaultDispatcherHandler = RoundRobinHandler;
/// Handler that sends each broadcast to one public child, rotating through
/// the public children on successive broadcasts.
#[derive(Debug, Default)]
pub struct RoundRobinHandler {
    /// Rotation counter; it is reduced modulo the number of public children
    /// when a target is picked.
    index: AtomicUsize,
}
impl DispatcherHandler for RoundRobinHandler {
fn notify(
&self,
_from_child: &ChildRef,
_entries: &DispatcherMap,
_notification_type: NotificationType,
) {
}
fn broadcast_message(&self, entries: &DispatcherMap, message: &Arc<SignedMessage>) {
let entries = entries
.iter()
.filter(|entry| entry.0.is_public())
.collect::<Vec<_>>();
if entries.is_empty() {
debug!("no public children to broadcast message to");
return;
}
let current_index = self.index.load(Ordering::SeqCst) % entries.len();
if let Some(entry) = entries.get(current_index) {
warn!(
"sending message to child {}/{} - {}",
current_index + 1,
entries.len(),
entry.0.path()
);
entry.0.tell_anonymously(message.clone()).unwrap();
self.index.store(current_index + 1, Ordering::SeqCst);
};
}
}
/// Strategy a [`Dispatcher`] delegates to for delivering broadcasts and
/// reacting to membership changes.
///
/// A dispatcher stores its handler as
/// `Box<dyn DispatcherHandler + Send + Sync + 'static>`, so an implementation
/// must be `Send + Sync` to be installed.
pub trait DispatcherHandler {
    /// Delivers `message` to some subset of the children in `entries`; which
    /// children receive it is up to the implementation.
    fn broadcast_message(&self, entries: &DispatcherMap, message: &Arc<SignedMessage>);

    /// Invoked after `from_child` was registered or removed (as indicated by
    /// `notification_type`), and whenever a notification is forwarded
    /// explicitly. `entries` is the dispatcher's table at the time of the call.
    fn notify(
        &self,
        from_child: &ChildRef,
        entries: &DispatcherMap,
        notification_type: NotificationType,
    );
}
/// A group of registered children together with the handler that decides how
/// notifications and broadcasts reach them.
pub struct Dispatcher {
    // Identity of this dispatcher (anonymous or a named group).
    dispatcher_type: DispatcherType,
    // Strategy called on register/remove notifications and on broadcasts.
    handler: Box<dyn DispatcherHandler + Send + Sync + 'static>,
    // Registered children, each mapped to the module name given at registration.
    actors: DispatcherMap,
}
impl Dispatcher {
    /// Returns a copy of this dispatcher's type.
    pub fn dispatcher_type(&self) -> DispatcherType {
        self.dispatcher_type.clone()
    }

    /// Returns the handler that processes notifications and broadcasts.
    pub fn handler(&self) -> &(dyn DispatcherHandler + Send + Sync + 'static) {
        &*self.handler
    }

    /// Builder-style setter that replaces the dispatcher type.
    #[must_use]
    pub fn with_dispatcher_type(mut self, dispatcher_type: DispatcherType) -> Self {
        trace!("Setting dispatcher the {:?} type.", dispatcher_type);
        self.dispatcher_type = dispatcher_type;
        self
    }

    /// Creates a dispatcher of the given type, using the default handler and
    /// an empty children table.
    pub fn with_type(dispatcher_type: DispatcherType) -> Self {
        trace!(
            "Instanciating a dispatcher with type {:?}.",
            dispatcher_type
        );
        Self {
            dispatcher_type,
            handler: Box::new(DefaultDispatcherHandler::default()),
            actors: Default::default(),
        }
    }

    /// Builder-style setter that replaces the handler.
    #[must_use]
    pub fn with_handler(
        mut self,
        handler: Box<dyn DispatcherHandler + Send + Sync + 'static>,
    ) -> Self {
        trace!(
            "Setting handler for the {:?} dispatcher.",
            self.dispatcher_type
        );
        self.handler = handler;
        self
    }

    /// Adds `key` to the children table under `module_name`, then notifies
    /// the handler with [`NotificationType::Register`].
    ///
    /// # Errors
    /// Propagates any error returned by the table insertion. In that case the
    /// handler is not notified.
    pub(crate) fn register(&self, key: &ChildRef, module_name: String) -> AnyResult<()> {
        self.actors.insert(key.clone(), module_name)?;
        self.handler
            .notify(key, &self.actors, NotificationType::Register);
        Ok(())
    }

    /// Removes `key` from the children table. If the table reports success,
    /// the handler is notified with [`NotificationType::Remove`].
    pub(crate) fn remove(&self, key: &ChildRef) {
        if self.actors.remove(key).is_ok() {
            self.handler
                .notify(key, &self.actors, NotificationType::Remove);
        }
    }

    /// Forwards a notification about `from_child` straight to the handler.
    pub fn notify(&self, from_child: &ChildRef, notification_type: NotificationType) {
        self.handler
            .notify(from_child, &self.actors, notification_type);
    }

    /// Asks the handler to deliver `message` to this dispatcher's children.
    pub fn broadcast_message(&self, message: &Arc<SignedMessage>) {
        // `message` is already a reference; pass it through unchanged.
        self.handler.broadcast_message(&self.actors, message);
    }
}
impl Debug for Dispatcher {
    /// Prints the dispatcher type and the number of registered children.
    /// The handler is not printed, because it is a trait object without a
    /// `Debug` bound.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let actor_count = self.actors.len();
        f.write_fmt(format_args!(
            "Dispatcher(type: {:?}, actors: {:?})",
            self.dispatcher_type, actor_count
        ))
    }
}
impl DispatcherType {
    /// Returns the name identifying this dispatcher type.
    ///
    /// The anonymous dispatcher uses the reserved name `__Anonymous__`; a
    /// named dispatcher returns its own name.
    pub(crate) fn name(&self) -> String {
        let name: &str = match self {
            DispatcherType::Anonymous => "__Anonymous__",
            DispatcherType::Named(value) => value.as_str(),
        };
        name.to_owned()
    }
}
impl Default for Dispatcher {
fn default() -> Self {
Dispatcher {
dispatcher_type: DispatcherType::default(),
handler: Box::new(DefaultDispatcherHandler::default()),
actors: LOTable::new(),
}
}
}
impl Default for DispatcherType {
fn default() -> Self {
DispatcherType::Anonymous
}
}
// `Eq` is derived while `Hash` is written by hand. The two stay consistent:
// equal values always yield equal names, and therefore equal hashes.
#[allow(clippy::derive_hash_xor_eq)]
impl Hash for DispatcherType {
    /// Hashes the dispatcher's name (see `DispatcherType::name`).
    fn hash<H: Hasher>(&self, state: &mut H) {
        let name = self.name();
        name.as_str().hash(state);
    }
}
/// Converts a group name into a dispatcher type.
///
/// The reserved name returned by `DispatcherType::Anonymous.name()` maps back
/// to `Anonymous`; every other string becomes `Named`. `From` is implemented
/// instead of `Into`, so the standard blanket impl still provides
/// `String: Into<DispatcherType>` and `DispatcherType::from` works as well.
impl From<String> for DispatcherType {
    fn from(name: String) -> Self {
        if name == DispatcherType::Anonymous.name() {
            DispatcherType::Anonymous
        } else {
            DispatcherType::Named(name)
        }
    }
}
/// Registry of all dispatchers, keyed by their type.
#[derive(Debug)]
pub(crate) struct GlobalDispatcher {
    // Each registered dispatcher, looked up by the `DispatcherType` it reports.
    pub dispatchers: LOTable<DispatcherType, Arc<Box<Dispatcher>>>,
}
impl GlobalDispatcher {
    /// Creates an empty registry.
    pub(crate) fn new() -> Self {
        GlobalDispatcher {
            dispatchers: LOTable::new(),
        }
    }

    /// Registers `child_ref` (under `module_name`) with every dispatcher in
    /// `dispatchers` that this registry knows. Unknown types are skipped.
    ///
    /// # Errors
    /// Returns the first error reported by a dispatcher's `register`.
    /// Dispatchers later in the slice are not visited after that.
    pub(crate) fn register(
        &self,
        dispatchers: &[DispatcherType],
        child_ref: &ChildRef,
        module_name: String,
    ) -> AnyResult<()> {
        // `get` already answers "is it registered?", so a separate
        // `contains_key` pass would only double the lookups.
        dispatchers
            .iter()
            .filter_map(|key| self.dispatchers.get(key))
            .try_for_each(|dispatcher| dispatcher.register(child_ref, module_name.clone()))
    }

    /// Removes `child_ref` from every known dispatcher listed in `dispatchers`.
    /// Unknown types are skipped.
    pub(crate) fn remove(&self, dispatchers: &[DispatcherType], child_ref: &ChildRef) {
        dispatchers
            .iter()
            .filter_map(|key| self.dispatchers.get(key))
            .for_each(|dispatcher| dispatcher.remove(child_ref));
    }

    /// Forwards a notification about `from_actor` to every registered
    /// dispatcher whose type appears in `dispatchers`.
    pub(crate) fn notify(
        &self,
        from_actor: &ChildRef,
        dispatchers: &[DispatcherType],
        notification_type: NotificationType,
    ) {
        self.dispatchers
            .iter()
            .filter(|pair| dispatchers.contains(&pair.0))
            .for_each(|pair| pair.1.notify(from_actor, notification_type.clone()));
    }

    /// Broadcasts `message` to the dispatchers selected by `target`.
    ///
    /// For `BroadcastTarget::Group`, a warning is logged when no dispatcher is
    /// registered under that name.
    pub(crate) fn broadcast_message(&self, target: BroadcastTarget, message: &Arc<SignedMessage>) {
        match target {
            // Use the dispatchers being iterated directly. Converting each type
            // to its name and back would misroute `Named("__Anonymous__")` and
            // repeat every lookup.
            BroadcastTarget::All => self
                .dispatchers
                .iter()
                .for_each(|pair| pair.1.broadcast_message(message)),
            BroadcastTarget::Group(name) => {
                let dispatcher_type: DispatcherType = name.into();
                match self.dispatchers.get(&dispatcher_type) {
                    Some(dispatcher) => dispatcher.broadcast_message(message),
                    None => {
                        let name = dispatcher_type.name();
                        warn!(
                            "The message can't be delivered to the group with the '{}' name.",
                            name
                        );
                    }
                }
            }
        }
    }

    /// Adds `dispatcher` to the registry under its own type.
    ///
    /// If a named dispatcher of the same type is already registered, a warning
    /// is logged and the registry is left unchanged. Anonymous dispatchers are
    /// always passed to `insert`.
    ///
    /// # Errors
    /// Propagates any error returned by the table insertion.
    pub(crate) fn register_dispatcher(&self, dispatcher: &Arc<Box<Dispatcher>>) -> AnyResult<()> {
        let dispatcher_type = dispatcher.dispatcher_type();
        let is_registered = self.dispatchers.contains_key(&dispatcher_type);
        if is_registered && dispatcher_type != DispatcherType::Anonymous {
            warn!(
                "The dispatcher with the '{:?}' name already registered in the cluster.",
                dispatcher_type
            );
            return Ok(());
        }
        self.dispatchers
            .insert(dispatcher_type, Arc::clone(dispatcher))?;
        Ok(())
    }

    /// Removes the entry keyed by `dispatcher`'s type from the registry.
    ///
    /// # Errors
    /// Propagates any error returned by the table removal.
    pub(crate) fn remove_dispatcher(&self, dispatcher: &Arc<Box<Dispatcher>>) -> AnyResult<()> {
        self.dispatchers.remove(&dispatcher.dispatcher_type())?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use crate::child_ref::ChildRef;
    use crate::context::BastionId;
    use crate::dispatcher::*;
    use crate::envelope::{RefAddr, SignedMessage};
    use crate::message::Msg;
    use crate::path::BastionPath;
    use futures::channel::mpsc;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    /// Test double that records whether either handler callback has run.
    ///
    /// Clones share the same flag, so one clone can be handed to a dispatcher
    /// while the test keeps another to inspect.
    #[derive(Clone)]
    struct CustomHandler {
        called: Arc<AtomicBool>,
    }

    impl CustomHandler {
        pub fn new(value: bool) -> Self {
            CustomHandler {
                called: Arc::new(AtomicBool::new(value)),
            }
        }

        pub fn was_called(&self) -> bool {
            self.called.load(Ordering::SeqCst)
        }

        /// Clears the flag, so a later assertion only observes calls made after
        /// this point. This is needed because registering a child already
        /// triggers `notify`.
        pub fn reset(&self) {
            self.called.store(false, Ordering::SeqCst);
        }

        fn mark_called(&self) {
            self.called.store(true, Ordering::SeqCst);
        }
    }

    impl DispatcherHandler for CustomHandler {
        fn notify(
            &self,
            _from_child: &ChildRef,
            _entries: &DispatcherMap,
            _notification_type: NotificationType,
        ) {
            self.mark_called();
        }

        fn broadcast_message(&self, _entries: &DispatcherMap, _message: &Arc<SignedMessage>) {
            self.mark_called();
        }
    }

    /// Builds a child reference whose mailbox receiver is dropped immediately.
    fn make_child_ref() -> ChildRef {
        let bastion_id = BastionId::new();
        let (sender, _) = mpsc::unbounded();
        let path = Arc::new(BastionPath::root());
        let name = "test_name".to_string();
        ChildRef::new(bastion_id, sender, name, path)
    }

    /// Builds a broadcast message whose sender address has a dropped receiver.
    fn make_broadcast_message() -> Arc<SignedMessage> {
        const DATA: &str = "A message containing data (ask).";
        let (sender, _) = mpsc::unbounded();
        let path = Arc::new(BastionPath::root());
        Arc::new(SignedMessage::new(
            Msg::broadcast(DATA),
            RefAddr::new(path, sender),
        ))
    }

    #[test]
    fn test_get_dispatcher_type_as_anonymous() {
        let instance = Dispatcher::default();
        assert_eq!(instance.dispatcher_type(), DispatcherType::Anonymous);
    }

    #[test]
    fn test_get_dispatcher_type_as_named() {
        let dispatcher_type = DispatcherType::Named("test_group".to_string());
        let instance = Dispatcher::with_type(dispatcher_type.clone());
        assert_eq!(instance.dispatcher_type(), dispatcher_type);
    }

    #[test]
    fn test_local_dispatcher_append_child_ref() {
        let instance = Dispatcher::default();
        let child_ref = make_child_ref();
        assert!(!instance.actors.contains_key(&child_ref));
        instance
            .register(&child_ref, "my::test::module".to_string())
            .unwrap();
        assert!(instance.actors.contains_key(&child_ref));
    }

    #[test]
    fn test_dispatcher_remove_child_ref() {
        let instance = Dispatcher::default();
        let child_ref = make_child_ref();
        instance
            .register(&child_ref, "my::test::module".to_string())
            .unwrap();
        assert!(instance.actors.contains_key(&child_ref));
        instance.remove(&child_ref);
        assert!(!instance.actors.contains_key(&child_ref));
    }

    #[test]
    fn test_local_dispatcher_notify() {
        let handler = Box::new(CustomHandler::new(false));
        let instance = Dispatcher::default().with_handler(handler.clone());
        let child_ref = make_child_ref();
        instance.notify(&child_ref, NotificationType::Register);
        assert!(handler.was_called());
    }

    #[test]
    fn test_local_dispatcher_broadcast_message() {
        let handler = Box::new(CustomHandler::new(false));
        let instance = Dispatcher::default().with_handler(handler.clone());
        let message = make_broadcast_message();
        instance.broadcast_message(&message);
        assert!(handler.was_called());
    }

    #[test]
    fn test_global_dispatcher_add_local_dispatcher() {
        let dispatcher_type = DispatcherType::Named("test".to_string());
        let local_dispatcher = Arc::new(Box::new(Dispatcher::with_type(dispatcher_type.clone())));
        let global_dispatcher = GlobalDispatcher::new();
        assert!(!global_dispatcher.dispatchers.contains_key(&dispatcher_type));
        global_dispatcher
            .register_dispatcher(&local_dispatcher)
            .unwrap();
        assert!(global_dispatcher.dispatchers.contains_key(&dispatcher_type));
    }

    #[test]
    fn test_global_dispatcher_remove_local_dispatcher() {
        let dispatcher_type = DispatcherType::Named("test".to_string());
        let local_dispatcher = Arc::new(Box::new(Dispatcher::with_type(dispatcher_type.clone())));
        let global_dispatcher = GlobalDispatcher::new();
        global_dispatcher
            .register_dispatcher(&local_dispatcher)
            .unwrap();
        assert!(global_dispatcher.dispatchers.contains_key(&dispatcher_type));
        global_dispatcher
            .remove_dispatcher(&local_dispatcher)
            .unwrap();
        assert!(!global_dispatcher.dispatchers.contains_key(&dispatcher_type));
    }

    #[test]
    fn test_global_dispatcher_register_actor() {
        let child_ref = make_child_ref();
        let dispatcher_type = DispatcherType::Named("test".to_string());
        let local_dispatcher = Arc::new(Box::new(Dispatcher::with_type(dispatcher_type.clone())));
        let actor_groups = vec![dispatcher_type];
        let module_name = "my::test::module".to_string();
        let global_dispatcher = GlobalDispatcher::new();
        global_dispatcher
            .register_dispatcher(&local_dispatcher)
            .unwrap();
        assert!(!local_dispatcher.actors.contains_key(&child_ref));
        global_dispatcher
            .register(&actor_groups, &child_ref, module_name)
            .unwrap();
        assert!(local_dispatcher.actors.contains_key(&child_ref));
    }

    #[test]
    fn test_global_dispatcher_remove_actor() {
        let child_ref = make_child_ref();
        let dispatcher_type = DispatcherType::Named("test".to_string());
        let local_dispatcher = Arc::new(Box::new(Dispatcher::with_type(dispatcher_type.clone())));
        let actor_groups = vec![dispatcher_type];
        let module_name = "my::test::module".to_string();
        let global_dispatcher = GlobalDispatcher::new();
        global_dispatcher
            .register_dispatcher(&local_dispatcher)
            .unwrap();
        global_dispatcher
            .register(&actor_groups, &child_ref, module_name)
            .unwrap();
        assert!(local_dispatcher.actors.contains_key(&child_ref));
        global_dispatcher.remove(&actor_groups, &child_ref);
        assert!(!local_dispatcher.actors.contains_key(&child_ref));
    }

    #[test]
    fn test_global_dispatcher_notify() {
        let child_ref = make_child_ref();
        let dispatcher_type = DispatcherType::Named("test".to_string());
        let handler = Box::new(CustomHandler::new(false));
        let local_dispatcher = Arc::new(Box::new(
            Dispatcher::with_type(dispatcher_type.clone()).with_handler(handler.clone()),
        ));
        let actor_groups = vec![dispatcher_type];
        let module_name = "my::test::module".to_string();
        let global_dispatcher = GlobalDispatcher::new();
        global_dispatcher
            .register_dispatcher(&local_dispatcher)
            .unwrap();
        global_dispatcher
            .register(&actor_groups, &child_ref, module_name)
            .unwrap();
        // Registration already ran `notify`; clear the flag so that only the
        // explicit notification below can set it again.
        handler.reset();
        global_dispatcher.notify(&child_ref, &actor_groups, NotificationType::Register);
        assert!(handler.was_called());
    }

    #[test]
    fn test_global_dispatcher_broadcast_message() {
        let child_ref = make_child_ref();
        let dispatcher_type = DispatcherType::Named("test".to_string());
        let handler = Box::new(CustomHandler::new(false));
        let local_dispatcher = Arc::new(Box::new(
            Dispatcher::with_type(dispatcher_type.clone()).with_handler(handler.clone()),
        ));
        let actor_groups = vec![dispatcher_type];
        let module_name = "my::test::module".to_string();
        let global_dispatcher = GlobalDispatcher::new();
        global_dispatcher
            .register_dispatcher(&local_dispatcher)
            .unwrap();
        global_dispatcher
            .register(&actor_groups, &child_ref, module_name)
            .unwrap();
        // Registration already ran `notify`; clear the flag so that only the
        // broadcast below can set it again.
        handler.reset();
        let message = make_broadcast_message();
        global_dispatcher.broadcast_message(BroadcastTarget::Group("test".to_string()), &message);
        assert!(handler.was_called());
    }

    #[test]
    fn test_global_dispatcher_broadcast_to_unknown_group() {
        let dispatcher_type = DispatcherType::Named("test".to_string());
        let handler = Box::new(CustomHandler::new(false));
        let local_dispatcher = Arc::new(Box::new(
            Dispatcher::with_type(dispatcher_type).with_handler(handler.clone()),
        ));
        let global_dispatcher = GlobalDispatcher::new();
        global_dispatcher
            .register_dispatcher(&local_dispatcher)
            .unwrap();
        let message = make_broadcast_message();
        global_dispatcher
            .broadcast_message(BroadcastTarget::Group("unknown".to_string()), &message);
        assert!(!handler.was_called());
    }

    #[test]
    fn test_global_dispatcher_broadcast_to_all() {
        let dispatcher_type = DispatcherType::Named("test".to_string());
        let handler = Box::new(CustomHandler::new(false));
        let local_dispatcher = Arc::new(Box::new(
            Dispatcher::with_type(dispatcher_type).with_handler(handler.clone()),
        ));
        let global_dispatcher = GlobalDispatcher::new();
        global_dispatcher
            .register_dispatcher(&local_dispatcher)
            .unwrap();
        let message = make_broadcast_message();
        global_dispatcher.broadcast_message(BroadcastTarget::All, &message);
        assert!(handler.was_called());
    }
}