use crate::message::{Message, Put, Get};
use crate::actor::{Actor, ActorContext, Addr};
use crate::{Config};
use crate::utils::{BoundedHashSet, BoundedHashMap};
use crate::adapters::{SledStorage, MemoryStorage, WsServer, OutgoingWebsocketManager, Multicast};
use std::sync::atomic::{AtomicUsize, Ordering};
use async_trait::async_trait;
use std::collections::{HashMap, HashSet};
use log::{debug, error, info};
use rand::{seq::IteratorRandom, thread_rng};
/// Capacity passed to the router's `BoundedHashSet` of seen message ids and to its
/// `BoundedHashMap` of seen `Get` requests (see `Router::new`).
///
/// A `const` rather than a `static`: the value is only ever read, never needs a fixed
/// address, and can be inlined at each use.
const SEEN_MSGS_MAX_SIZE: usize = 10_000;
/// Routing state kept for a `Get` request the router has already accepted.
///
/// Entries are created in `Router::handle_get` and consulted by
/// `Router::handle_put` when a `Put` names the request in `in_response_to`.
struct SeenGetMessage {
    /// Checksum of the most recent reply `Put` forwarded for this request. It starts
    /// as `None`. A later reply with the same (non-`None`) checksum is not forwarded
    /// again.
    last_reply_checksum: Option<i32>,
    /// Address the `Get` arrived from; reply `Put`s are sent back to it.
    from: Addr,
}
/// Actor that starts the configured storage/network adapters and routes `Get` and
/// `Put` messages between them and topic subscribers.
pub struct Router {
    /// Settings that decide, in `pre_start`, which adapter actors are started.
    config: Config,
    /// Storage actors (sled / in-memory). They receive every new `Get` and every
    /// `Put` that is not a reply.
    storage_adapters: HashSet<Addr>,
    /// Network-facing actors (websocket server, outgoing websocket manager).
    /// NOTE(review): only inserted into in the code visible here — confirm whether
    /// it is read elsewhere.
    network_adapters: HashSet<Addr>,
    /// Peer-facing actors (multicast, outgoing websocket manager). They receive every
    /// new `Get` and every `Put` that is not a reply.
    server_peers: HashSet<Addr>,
    /// Ids of `Get`/`Put` messages already processed, used to drop duplicates.
    /// Capacity is bounded; the eviction policy lives in `BoundedHashSet`.
    seen_messages: BoundedHashSet,
    /// `Get` id -> requester and last reply checksum, used to route replies back.
    seen_get_messages: BoundedHashMap<String, SeenGetMessage>,
    /// Topic (first `/`-separated segment of a node id) -> addresses that requested
    /// nodes under it.
    subscribers_by_topic: HashMap<String, HashSet<Addr>>,
    /// Incremented on every duplicate check. It is not read by the code visible here.
    msg_counter: AtomicUsize,
}
#[async_trait]
impl Actor for Router {
    /// Starts the adapter actors enabled in `self.config` and records their addresses:
    ///
    /// - `multicast` starts `Multicast`, tracked as a server peer;
    /// - `websocket_server` starts `WsServer`, tracked as a network adapter;
    /// - `sled_storage` / `memory_storage` start `SledStorage` / `MemoryStorage`,
    ///   tracked as storage adapters;
    /// - a non-empty `outgoing_websocket_peers` starts `OutgoingWebsocketManager`,
    ///   tracked both as a server peer and as a network adapter.
    ///
    /// Actors are started in the order listed. `update_stats` runs last when
    /// `stats` is set.
    async fn pre_start(&mut self, ctx: &ActorContext) {
        if self.config.multicast {
            let multicast_addr = ctx.start_actor(Box::new(Multicast::new()));
            self.server_peers.insert(multicast_addr);
        }

        if self.config.websocket_server {
            let ws_server = WsServer::new(self.config.clone());
            let ws_server_addr = ctx.start_actor(Box::new(ws_server));
            self.network_adapters.insert(ws_server_addr);
        }

        if self.config.sled_storage {
            let sled = SledStorage::new(self.config.clone());
            let sled_addr = ctx.start_actor(Box::new(sled));
            self.storage_adapters.insert(sled_addr);
        }

        if self.config.memory_storage {
            let memory = MemoryStorage::new(self.config.clone());
            let memory_addr = ctx.start_actor(Box::new(memory));
            self.storage_adapters.insert(memory_addr);
        }

        if self.config.outgoing_websocket_peers.len() > 0 {
            let manager = OutgoingWebsocketManager::new(self.config.clone());
            let manager_addr = ctx.start_actor(Box::new(manager));
            // The manager acts both as a peer link and as a network adapter.
            self.network_adapters.insert(manager_addr.clone());
            self.server_peers.insert(manager_addr);
        }

        if self.config.stats {
            self.update_stats();
        }
    }

    /// Logs that the router is shutting down.
    async fn stopping(&mut self, _ctx: &ActorContext) {
        info!("Router stopping");
    }

    /// Dispatches `Put` and `Get` messages to their handlers. Every other message
    /// variant is ignored.
    async fn handle(&mut self, msg: Message, _ctx: &ActorContext) {
        debug!("incoming message");
        match msg {
            Message::Get(get) => self.handle_get(get),
            Message::Put(put) => self.handle_put(put),
            _ => (),
        }
    }
}
impl Router {
    /// Maximum number of topic subscribers a `Get` is forwarded to, chosen at random.
    /// The requester itself is skipped if it is among them.
    const GET_SUBSCRIBER_SAMPLE_SIZE: usize = 4;

    /// Creates a router for `config` with empty adapter sets and empty caches.
    ///
    /// Adapters are not started here; `pre_start` does that based on `config`.
    /// Both seen-message caches are bounded to `SEEN_MSGS_MAX_SIZE`.
    pub fn new(config: Config) -> Self {
        Self {
            config,
            storage_adapters: HashSet::new(),
            network_adapters: HashSet::new(),
            server_peers: HashSet::new(),
            seen_messages: BoundedHashSet::new(SEEN_MSGS_MAX_SIZE),
            seen_get_messages: BoundedHashMap::new(SEEN_MSGS_MAX_SIZE),
            subscribers_by_topic: HashMap::new(),
            msg_counter: AtomicUsize::new(0),
        }
    }

    /// Stats hook, called from `pre_start` when `config.stats` is set.
    ///
    /// Currently a no-op.
    fn update_stats(&self) {}

    /// Routes an incoming `Get` request.
    ///
    /// A request whose `id` was already seen is dropped. Otherwise the router:
    /// 1. remembers the requester, so reply `Put`s can be routed back;
    /// 2. subscribes the requester to the request's topic (the first `/`-separated
    ///    segment of `node_id`);
    /// 3. forwards the request to every storage adapter and every server peer
    ///    (best-effort; send errors are ignored);
    /// 4. forwards it to up to `GET_SUBSCRIBER_SAMPLE_SIZE` randomly sampled
    ///    subscribers of the topic, skipping the requester. Subscribers whose send
    ///    fails are unsubscribed from the topic.
    fn handle_get(&mut self, get: Get) {
        if !get.id.chars().all(char::is_alphanumeric) {
            // Diagnostic only: the request is still processed.
            error!("id {}", get.id);
        }
        if self.is_message_seen(&get.id) {
            return;
        }

        let seen_get_message = SeenGetMessage { from: get.from.clone(), last_reply_checksum: None };
        self.seen_get_messages.insert(get.id.clone(), seen_get_message);

        let topic = get.node_id.split('/').next().unwrap_or("");
        debug!("{} subscribed to {}", get.from, topic);
        self.subscribers_by_topic
            .entry(topic.to_string())
            .or_default()
            .insert(get.from.clone());

        for addr in self.storage_adapters.iter() {
            let _ = addr.sender.send(Message::Get(get.clone()));
        }
        for addr in self.server_peers.iter() {
            let _ = addr.sender.send(Message::Get(get.clone()));
        }

        let mut errored = HashSet::new();
        if let Some(topic_subscribers) = self.subscribers_by_topic.get(topic) {
            let mut rng = thread_rng();
            let sample = topic_subscribers
                .iter()
                .choose_multiple(&mut rng, Self::GET_SUBSCRIBER_SAMPLE_SIZE);
            debug!("sending get to a random sample of subscribers of size {}", sample.len());
            for addr in sample {
                if get.from == *addr {
                    continue;
                }
                if addr.sender.send(Message::Get(get.clone())).is_err() {
                    // The send failed (presumably the receiving end is gone). Collect
                    // the address here; it cannot be removed while the set is borrowed.
                    errored.insert(addr.clone());
                }
            }
        }
        if !errored.is_empty() {
            if let Some(topic_subscribers) = self.subscribers_by_topic.get_mut(topic) {
                for addr in errored {
                    topic_subscribers.remove(&addr);
                }
            }
        }
    }

    /// Routes an incoming `Put`.
    ///
    /// A put whose `id` was already seen is dropped. There are two cases:
    ///
    /// - **Reply (`in_response_to` is set).** The put is forwarded only to the
    ///   requester of that `Get`. It is not forwarded if its checksum is set and
    ///   equals the checksum of the previous reply forwarded for the same `Get`. If
    ///   the `Get` is not in `seen_get_messages`, the put is dropped.
    /// - **Other puts.** The put is forwarded to every storage adapter except the
    ///   sender, to every server peer, and to each subscriber of the topics touched
    ///   by `updated_nodes`. Each subscriber gets it at most once, and never the
    ///   sender. Subscribers whose send fails are unsubscribed.
    fn handle_put(&mut self, put: Put) {
        if self.is_message_seen(&put.id) {
            return;
        }
        match &put.in_response_to {
            Some(in_response_to) => {
                if let Some(seen_get_message) = self.seen_get_messages.get_mut(in_response_to) {
                    if put.checksum.is_some() && put.checksum == seen_get_message.last_reply_checksum {
                        debug!("same reply already sent");
                        return;
                    }
                    seen_get_message.last_reply_checksum = put.checksum;
                    let _ = seen_get_message.from.sender.send(Message::Put(put));
                }
            }
            None => {
                for addr in self.storage_adapters.iter() {
                    if put.from == *addr {
                        continue;
                    }
                    let _ = addr.sender.send(Message::Put(put.clone()));
                }
                for addr in self.server_peers.iter() {
                    let _ = addr.sender.send(Message::Put(put.clone()));
                }

                // Several updated nodes may share a topic or subscriber; send each once.
                let mut already_sent_to = HashSet::new();
                for node_id in put.updated_nodes.keys() {
                    let topic = node_id.split('/').next().unwrap_or("");
                    if let Some(topic_subscribers) = self.subscribers_by_topic.get_mut(topic) {
                        // `retain` keeps skipped subscribers and drops those whose send fails.
                        topic_subscribers.retain(|addr| {
                            if put.from == *addr || already_sent_to.contains(addr) {
                                return true;
                            }
                            already_sent_to.insert(addr.clone());
                            addr.sender.send(Message::Put(put.clone())).is_ok()
                        });
                    }
                }
            }
        };
    }

    /// Returns `true` if `id` was already recorded. Otherwise records it and
    /// returns `false`.
    ///
    /// Also increments `msg_counter` on every call, whether or not the id is new.
    fn is_message_seen(&mut self, id: &String) -> bool {
        self.msg_counter.fetch_add(1, Ordering::Relaxed);
        if self.seen_messages.contains(id) {
            debug!("already seen message {}", id);
            return true;
        }
        self.seen_messages.insert(id.clone());
        false
    }
}