use bimap::{BiMap, Overwritten};
use rand::{rngs::StdRng, seq::IteratorRandom, SeedableRng};
use serde::{Deserialize, Serialize};
use serde_with::{hex::Hex, serde_as};
use std::{collections::HashMap, fmt::Formatter};
use crate::{flo::realm::RealmID, nodeconfig::NodeInfo, timer::Timer};
use flarch::{
broker::{Broker, SubsystemHandler, TranslateFrom, TranslateInto},
nodeids::{NodeID, U256},
platform_async_trait,
web_rtc::{
messages::PeerInfo,
websocket::{BrokerWSServer, WSServerIn, WSServerOut},
},
};
/// Broker through which the signalling server receives [`SignalIn`] messages and
/// emits [`SignalOut`] messages.
pub type BrokerSignal = Broker<SignalIn, SignalOut>;
/// Configuration the signalling server sends to a node (as
/// `WSSignalMessageToNode::SystemConfig`) after the node has successfully announced itself.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct FledgerConfig {
    /// Copied from `SignalConfig::system_realm`; `None` when the server has none configured.
    pub system_realm: Option<RealmID>,
}
/// Configuration of the signalling server.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct SignalConfig {
    /// Number of `SignalIn::Timer` ticks a connection may stay silent before it is
    /// removed as idle. The counter is reset by every websocket message from that connection.
    pub ttl_minutes: u16,
    /// Realm forwarded to each node inside a `FledgerConfig` after it announces itself.
    pub system_realm: Option<RealmID>,
    /// Maximum number of nodes returned in reply to a `ListIDsRequest`; `None` means
    /// no limit. The list broadcast to other nodes after an announcement is not limited by it.
    pub max_list_len: Option<usize>,
}
/// Messages handled by the signalling server.
#[derive(Clone, PartialEq)]
pub enum SignalIn {
    /// Periodic tick: decrements every connection's TTL and removes idle connections.
    Timer,
    /// Event coming from the websocket server.
    WSServer(WSServerOut),
    /// Request to stop; forwarded to the websocket server as `WSServerIn::Stop`.
    Stop,
}
/// Messages emitted by the signalling server.
#[derive(Clone, Debug, PartialEq)]
pub enum SignalOut {
    /// Statistics reported by a node, passed on unchanged.
    NodeStats(Vec<NodeStat>),
    /// A node with this ID has announced itself with a valid signature.
    NewNode(NodeID),
    /// Message destined for the websocket server.
    WSServer(WSServerIn),
    /// The websocket server reported that it has stopped.
    Stopped,
}
/// State of the signalling server; it runs as the handler of a [`BrokerSignal`].
///
/// Connections are identified by the `usize` index used in `WSServerOut` / `WSServerIn`.
pub struct SignalServer {
    /// Challenge sent to each connection that has not yet announced itself.
    challenges: BiMap<U256, usize>,
    /// Node ID of each announced connection (one connection per node ID).
    connection_ids: BiMap<NodeID, usize>,
    /// Info of each announced node, keyed by node ID.
    info: HashMap<U256, NodeInfo>,
    /// Remaining timer ticks before each connection is removed as idle.
    ttl: HashMap<usize, u16>,
    config: SignalConfig,
}
/// Protocol version sent to every new connection together with its challenge.
pub const SIGNAL_VERSION: u64 = 3;
impl SignalServer {
pub async fn start(
ws_server: BrokerWSServer,
config: SignalConfig,
) -> anyhow::Result<BrokerSignal> {
let mut broker = Self::new(config).await?;
broker.link_bi(ws_server).await?;
Timer::start()
.await?
.tick_minute(broker.clone(), SignalIn::Timer)
.await?;
Ok(broker)
}
pub async fn new(config: SignalConfig) -> anyhow::Result<BrokerSignal> {
let mut broker = Broker::new();
broker
.add_handler(Box::new(SignalServer {
challenges: BiMap::new(),
connection_ids: BiMap::new(),
info: HashMap::new(),
ttl: HashMap::new(),
config,
}))
.await?;
Ok(broker)
}
fn msg_in(&mut self, msg_in: SignalIn) -> Vec<SignalOut> {
match msg_in {
SignalIn::Timer => {
self.msg_in_timer();
vec![]
}
SignalIn::WSServer(msg_wss) => self.msg_wss(msg_wss),
SignalIn::Stop => vec![SignalOut::WSServer(WSServerIn::Stop)],
}
}
fn msg_wss(&mut self, msg: WSServerOut) -> Vec<SignalOut> {
match msg {
WSServerOut::Message(index, msg_s) => {
self.ttl
.entry(index.clone())
.and_modify(|ttl| *ttl = self.config.ttl_minutes);
if let Ok(msg_ws) = serde_json::from_str::<WSSignalMessageFromNode>(&msg_s) {
return self.msg_ws_process(index, msg_ws);
}
}
WSServerOut::NewConnection(index) => return self.msg_ws_connect(index),
WSServerOut::Disconnection(id) => self.remove_node(id),
WSServerOut::Stopped => return vec![SignalOut::Stopped],
}
vec![]
}
fn msg_in_timer(&mut self) {
let mut to_remove = Vec::new();
for (index, ttl) in self.ttl.iter_mut() {
*ttl -= 1;
if *ttl == 0 {
log::info!("Removing idle node {index}");
to_remove.push(*index);
}
}
for id in to_remove {
self.remove_node(id);
}
}
fn msg_ws_process(&mut self, index: usize, msg: WSSignalMessageFromNode) -> Vec<SignalOut> {
match msg {
WSSignalMessageFromNode::Announce(ann) => self.ws_announce(index, ann),
WSSignalMessageFromNode::ListIDsRequest => self.ws_list_ids(index),
WSSignalMessageFromNode::PeerSetup(pi) => self.ws_peer_setup(index, pi),
WSSignalMessageFromNode::NodeStats(ns) => self.ws_node_stats(ns),
}
}
fn msg_ws_connect(&mut self, index: usize) -> Vec<SignalOut> {
log::trace!("Sending challenge to new connection");
let challenge = U256::rnd();
self.challenges.insert(challenge, index);
self.ttl.insert(index, self.config.ttl_minutes);
let challenge_msg =
serde_json::to_string(&WSSignalMessageToNode::Challenge(SIGNAL_VERSION, challenge))
.unwrap();
vec![SignalOut::WSServer(WSServerIn::Message(
index,
challenge_msg,
))]
}
fn ws_announce(&mut self, index: usize, msg: MessageAnnounce) -> Vec<SignalOut> {
let challenge = match self.challenges.get_by_right(&index) {
Some(id) => id.clone(),
None => {
log::warn!("Got an announcement message without challenge.");
return vec![];
}
};
if !msg.node_info.verify(&challenge.to_bytes(), &msg.signature) {
log::warn!("Got node with wrong signature");
return vec![];
}
let id = msg.node_info.get_id();
let mut msgs = vec![];
if let Overwritten::Left(_, old) = self.connection_ids.insert(id, index) {
log::warn!("The same ID is already connected to this signalling server - sending kill signal to previous connection");
msgs.append(&mut self.send_msg_node(
old,
WSSignalMessageToNode::Error("New Connection with same ID".into()),
))
}
log::info!("Registration of node-id {}: {}", id, msg.node_info.name);
self.info.insert(id, msg.node_info);
self.challenges.remove_by_left(&challenge);
msgs.append(&mut self.send_msg_node(
index,
WSSignalMessageToNode::SystemConfig(FledgerConfig {
system_realm: self.config.system_realm.clone(),
}),
));
let list = self
.info
.iter()
.map(|(_, info)| info.clone())
.collect::<Vec<_>>();
for id in self.connection_ids.iter() {
if id.1 != &index {
msgs.append(
&mut self
.send_msg_node(*id.1, WSSignalMessageToNode::ListIDsReply(list.clone())),
)
}
}
msgs.push(SignalOut::NewNode(id));
msgs
}
fn ws_list_ids(&mut self, id: usize) -> Vec<SignalOut> {
let mut rng = StdRng::seed_from_u64(id as u64);
let max_size = self.config.max_list_len.unwrap_or(self.info.len());
let list = self
.info
.values()
.cloned()
.choose_multiple(&mut rng, max_size);
self.send_msg_node(id, WSSignalMessageToNode::ListIDsReply(list))
}
fn ws_peer_setup(&mut self, index: usize, pi: PeerInfo) -> Vec<SignalOut> {
let id = match self.connection_ids.get_by_right(&index) {
Some(id) => id,
None => {
log::warn!("Got a peer-setup message without challenge.");
return vec![];
}
};
log::trace!("Node {} sent peer setup: {:?}", id, pi);
if let Some(dst) = pi.get_remote(id) {
if let Some(dst_index) = self.connection_ids.get_by_left(&dst) {
return self.send_msg_node(*dst_index, WSSignalMessageToNode::PeerSetup(pi));
}
}
vec![]
}
fn ws_node_stats(&mut self, ns: Vec<NodeStat>) -> Vec<SignalOut> {
vec![SignalOut::NodeStats(ns)]
}
fn send_msg_node(&self, index: usize, msg: WSSignalMessageToNode) -> Vec<SignalOut> {
vec![SignalOut::WSServer(WSServerIn::Message(
index,
serde_json::to_string(&msg).unwrap(),
))]
}
fn remove_node(&mut self, index: usize) {
log::info!("Removing node {index} from {:?}", self.info);
self.challenges.remove_by_right(&index);
if let Some((id, _)) = self.connection_ids.remove_by_right(&index) {
self.info.remove(&id);
}
log::info!("Info is now: {:?}", self.info);
self.ttl.remove(&index);
}
}
#[platform_async_trait()]
impl SubsystemHandler<SignalIn, SignalOut> for SignalServer {
    /// Handles a batch of input messages in order, concatenating all their outputs.
    async fn messages(&mut self, msgs: Vec<SignalIn>) -> Vec<SignalOut> {
        let mut outputs = Vec::new();
        for incoming in msgs {
            outputs.extend(self.msg_in(incoming));
        }
        outputs
    }
}
impl TranslateFrom<WSServerOut> for SignalIn {
    /// Every websocket-server event is forwarded to the signalling server.
    fn translate(msg: WSServerOut) -> Option<Self> {
        let wrapped = Self::WSServer(msg);
        Some(wrapped)
    }
}
impl TranslateInto<WSServerIn> for SignalOut {
    /// Only [`SignalOut::WSServer`] messages are handed to the websocket server;
    /// all other outputs are dropped on this link.
    fn translate(self) -> Option<WSServerIn> {
        match self {
            SignalOut::WSServer(msg_wss) => Some(msg_wss),
            SignalOut::NodeStats(_) | SignalOut::NewNode(_) | SignalOut::Stopped => None,
        }
    }
}
// NOTE(review): lint allowed without a recorded reason; presumably some variants are much
// larger than others (e.g. `PeerInfo` vs `Error(String)`) and boxing was not wanted — confirm.
#[allow(clippy::large_enum_variant)]
/// Messages the signalling server sends to a node, serialised as JSON over the websocket.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub enum WSSignalMessageToNode {
    /// Sent on connection: the protocol version and a random challenge to sign.
    Challenge(u64, U256),
    /// List of known nodes' info.
    ListIDsReply(Vec<NodeInfo>),
    /// Peer-setup message forwarded from another node.
    PeerSetup(PeerInfo),
    /// Configuration sent after a successful announcement.
    SystemConfig(FledgerConfig),
    /// Error notice; sent to an older connection when its node ID announces on a new one.
    Error(String),
}
// NOTE(review): lint allowed without a recorded reason; presumably `Announce`/`PeerSetup`
// are much larger than `ListIDsRequest` and boxing was not wanted — confirm.
#[allow(clippy::large_enum_variant)]
/// Messages a node sends to the signalling server, serialised as JSON over the websocket.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub enum WSSignalMessageFromNode {
    /// Answer to the server's challenge, signed by the node.
    Announce(MessageAnnounce),
    /// Request for a (possibly size-limited) list of known nodes.
    ListIDsRequest,
    /// Peer-setup message to forward to another node connected to this server.
    PeerSetup(PeerInfo),
    /// Node statistics, re-emitted by the server as `SignalOut::NodeStats`.
    NodeStats(Vec<NodeStat>),
}
impl std::fmt::Display for WSSignalMessageToNode {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
WSSignalMessageToNode::Challenge(_, _) => write!(f, "Challenge"),
WSSignalMessageToNode::ListIDsReply(_) => write!(f, "ListIDsReply"),
WSSignalMessageToNode::PeerSetup(_) => write!(f, "PeerSetup"),
WSSignalMessageToNode::SystemConfig(_) => write!(f, "SystemConfig"),
WSSignalMessageToNode::Error(_) => write!(f, "Error"),
}
}
}
impl std::fmt::Display for WSSignalMessageFromNode {
    /// Writes only the variant name; the payload is never printed.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let variant_name = match self {
            WSSignalMessageFromNode::Announce(..) => "Announce",
            WSSignalMessageFromNode::ListIDsRequest => "ListIDsRequest",
            WSSignalMessageFromNode::PeerSetup(..) => "PeerSetup",
            WSSignalMessageFromNode::NodeStats(..) => "NodeStats",
        };
        f.write_str(variant_name)
    }
}
/// A node's answer to the signalling server's challenge.
#[serde_as]
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct MessageAnnounce {
    /// Protocol version of the node; `SignalServer` does not check it.
    pub version: u64,
    /// Challenge echoed by the node; `SignalServer` verifies against its own stored copy instead.
    pub challenge: U256,
    /// Info of the announcing node; `node_info.verify` is used to check `signature`.
    pub node_info: NodeInfo,
    /// Signature over the challenge bytes, (de)serialised as a hex string.
    #[serde_as(as = "Hex")]
    pub signature: Vec<u8>,
}
/// Statistics entry reported by a node and re-emitted as `SignalOut::NodeStats`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct NodeStat {
    /// ID of the node this entry describes.
    pub id: NodeID,
    /// Version string reported for that node.
    pub version: String,
    /// NOTE(review): presumably a ping round-trip time in milliseconds — confirm.
    pub ping_ms: u32,
    /// NOTE(review): presumably a count of received pings — confirm.
    pub ping_rx: u32,
}
impl std::fmt::Debug for SignalIn {
    /// Prints a short label per variant; the websocket payload is deliberately omitted.
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
        let label = match self {
            SignalIn::Timer => "Timer",
            SignalIn::WSServer(_) => "WebSocket",
            SignalIn::Stop => "Stop",
        };
        f.write_str(label)
    }
}