use std::{
collections::{HashMap, HashSet},
convert::TryInto,
fmt,
iter::FromIterator,
sync::atomic::Ordering,
time::Duration as StdDuration,
};
use async_std::{stream::interval, sync::Arc};
use futures::{channel::mpsc::Receiver, select, sink::SinkExt, stream::StreamExt, FutureExt};
use log::{debug, error};
use rand::seq::SliceRandom;
use rand_distr::{Binomial, Distribution};
use koibumi_core::{
message::{self, NetAddr, Pack, Services, StreamNumber, UserAgent},
net::SocketAddrExt,
time::Time,
};
use crate::{
connection::Direction,
connection_loop::{Context, Event as BrokerEvent},
db,
manager::Event as BmEvent,
};
/// A node address announced to the manager via `Event::Add`: the address, the
/// stream it belongs to and when it was last seen.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct Entry {
    stream: StreamNumber,
    addr: SocketAddrExt,
    last_seen: Time,
}

impl Entry {
    /// Creates an entry for `addr` on `stream`, last seen at `last_seen`.
    pub fn new(stream: StreamNumber, addr: SocketAddrExt, last_seen: Time) -> Self {
        Entry {
            addr,
            last_seen,
            stream,
        }
    }
}
/// Requests handled by [`manage`], the node-address manager loop.
#[derive(Debug)]
pub enum Event {
    /// Learn these node addresses, or refresh the last-seen time of ones
    /// already known.
    Add(Vec<Entry>),
    /// A connection to the address was established. The payload carries the
    /// peer's user agent. The node's rating is raised, and an `Established`
    /// event is forwarded if the address is known.
    ConnectionSucceeded(SocketAddrExt, UserAgent),
    /// A connection attempt to the address failed. The node's rating is
    /// lowered unless it is one of the configured own nodes.
    ConnectionFailed(SocketAddrExt),
    /// The connection to the address ended. The address becomes eligible for
    /// sampling again.
    Disconnected(SocketAddrExt),
    /// Send an `addr` message listing known nodes to the peer at the address.
    /// If the flag is `true`, follow up with a "server full" error message and
    /// ask the broker to close the connection.
    Send(SocketAddrExt, bool),
}
/// Reputation score of a known node; higher is better.
///
/// `increment` and `decrement` keep the value within `[MIN, MAX]`
/// (`-10..=10`). Ordering follows the numeric value, which node ranking
/// relies on.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct Rating(i8);

impl Rating {
    /// Highest rating reachable by incrementing.
    const MAX: i8 = 10;
    /// Lowest rating reachable by decrementing.
    const MIN: i8 = -10;

    /// Wraps a raw value.
    ///
    /// The value is not clamped here. The next `increment`/`decrement` pulls
    /// an out-of-range value back into `[MIN, MAX]`.
    pub fn new(value: i8) -> Self {
        Self(value)
    }

    /// Returns the raw score.
    pub fn as_i8(&self) -> i8 {
        self.0
    }

    /// Raises the rating by one, capped at `MAX`.
    ///
    /// Saturating arithmetic avoids overflow for raw values near `i8::MAX`.
    pub fn increment(&mut self) {
        self.0 = self.0.saturating_add(1).min(Self::MAX);
    }

    /// Lowers the rating by one, floored at `MIN`.
    ///
    /// Saturating arithmetic avoids overflow for raw values near `i8::MIN`.
    pub fn decrement(&mut self) {
        // The floor is `MIN` itself. The previous `-Self::MIN` evaluated to
        // +10 and turned every decrement into a jump to the top rating.
        self.0 = self.0.saturating_sub(1).max(Self::MIN);
    }
}

impl fmt::Display for Rating {
    /// Formats the underlying number, honouring width, fill and sign flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}

impl From<i8> for Rating {
    /// Equivalent to [`Rating::new`]; the value is not clamped.
    fn from(value: i8) -> Self {
        Self::new(value)
    }
}
/// In-memory state kept for each known node address.
struct Info {
    /// Stream the address was recorded under. It is used as part of the
    /// database key and in outgoing `addr` entries.
    // NOTE(review): `Nodes` reads this field, so the `dead_code` allowance
    // appears to be stale — confirm and remove.
    #[allow(dead_code)]
    stream: StreamNumber,
    /// Most recent time the node was reported or connected.
    last_seen: Time,
    /// Current reputation; used to rank addresses for sampling and pruning.
    rating: Rating,
}
/// One row of the `nodes` table as read back by `sqlx`.
///
/// Columns are stored as signed SQLite integers. They are range-checked
/// before being converted to the in-memory types, and rows that do not fit
/// are skipped.
#[derive(sqlx::FromRow)]
struct Record {
    /// Stream number; must be non-negative to be loaded.
    stream: i32,
    /// Address in its string form, parsed as `SocketAddrExt`.
    address: String,
    /// Last-seen timestamp; must be non-negative to be loaded.
    last_seen: i64,
    /// Rating; must fit in an `i8` to be loaded.
    rating: i32,
}
/// Known node addresses, kept in memory and mirrored in the `nodes` table.
struct Nodes {
    /// Shared context; provides configuration such as `max_nodes`.
    ctx: Arc<Context>,
    /// Database pool used to persist the node table.
    pool: db::SqlitePool,
    /// Per-address state. It is keyed by address only, while the database
    /// primary key is `(stream, address)`.
    map: HashMap<SocketAddrExt, Info>,
    /// Addresses handed out by `sample` and not yet returned via `reclaim`.
    /// Sampling and pruning skip them.
    used_addrs: HashSet<SocketAddrExt>,
}
impl Nodes {
async fn new(ctx: Arc<Context>, pool: db::SqlitePool) -> Self {
if let Err(err) = sqlx::query(
"CREATE TABLE IF NOT EXISTS nodes (
stream INTEGER NOT NULL,
address TEXT NOT NULL,
last_seen INTEGER NOT NULL,
rating INTEGER NOT NULL,
PRIMARY KEY(stream, address)
)",
)
.execute(pool.write())
.await
{
error!(target: "koibumi", "{}", err);
}
let mut map = HashMap::new();
if let Ok(list) = sqlx::query_as::<sqlx::Sqlite, Record>(
"SELECT stream, address, last_seen, rating FROM nodes",
)
.fetch_all(pool.read())
.await
{
for record in list {
if record.stream < 0 {
continue;
}
let stream: StreamNumber = (record.stream as u32).into();
let addr = record.address.parse::<SocketAddrExt>();
if addr.is_err() {
continue;
}
let addr = addr.unwrap();
if record.last_seen < 0 {
continue;
}
let last_seen: Time = (record.last_seen as u64).into();
if record.rating < -128 || record.rating > 127 {
continue;
}
let rating: Rating = (record.rating as i8).into();
map.insert(
addr,
Info {
stream,
last_seen,
rating,
},
);
}
}
let used_addrs = HashSet::new();
let mut nodes = Self {
ctx,
pool,
map,
used_addrs,
};
nodes.retain().await;
nodes
}
fn len(&self) -> usize {
self.map.len()
}
async fn retain(&mut self) -> Option<()> {
if self.map.len() > self.ctx.config().max_nodes() {
let keys: HashSet<SocketAddrExt> = self.map.keys().cloned().collect();
let mut list: Vec<SocketAddrExt> = keys.difference(&self.used_addrs).cloned().collect();
list.sort_unstable_by(|a, b| {
let a_info = &self.map[a];
let b_info = &self.map[b];
if a_info.rating == b_info.rating {
a_info.last_seen.cmp(&b_info.last_seen)
} else {
a_info.rating.cmp(&b_info.rating)
}
});
let mut trunc_amount = self.ctx.config().max_nodes() / 10;
if trunc_amount == 0 {
trunc_amount = usize::min(1, list.len());
}
list.truncate(trunc_amount);
for addr in &list {
self.map.remove(addr);
if let Err(err) = sqlx::query("DELETE FROM nodes WHERE address=?1")
.bind(addr.to_string())
.execute(self.pool.write())
.await
{
error!(target: "koibumi", "{}", err);
}
}
return Some(());
}
None
}
async fn insert(&mut self, entry: Entry, own_node: bool) -> Option<()> {
match self.map.get_mut(&entry.addr) {
Some(info) => {
if entry.last_seen > info.last_seen {
info.last_seen = entry.last_seen;
}
if entry.stream.as_u32() <= i32::MAX as u32
&& entry.last_seen.as_secs() <= i64::MAX as u64
{
if let Err(err) =
sqlx::query("UPDATE nodes SET last_seen=?1 WHERE stream=?2 and address=?3")
.bind(entry.last_seen.as_secs() as i64)
.bind(entry.stream.as_u32() as i32)
.bind(entry.addr.to_string())
.execute(self.pool.write())
.await
{
error!(target: "koibumi", "{}", err);
}
}
None
}
None => {
if self.ctx.config().is_connectable_to(&entry.addr)
&& self.ctx.config().stream_numbers().contains(entry.stream)
{
debug!(target: "koibumi", "addr: {}", entry.addr);
let rating: Rating = if own_node {
Rating::MAX.into()
} else {
0.into()
};
self.map.insert(
entry.addr.clone(),
Info {
stream: entry.stream,
last_seen: entry.last_seen,
rating: rating.clone(),
},
);
if entry.stream.as_u32() <= i32::MAX as u32
&& entry.last_seen.as_secs() <= i64::MAX as u64
{
if let Err(err) = sqlx::query(
"INSERT INTO nodes (
stream, address, last_seen, rating
) VALUES (?1, ?2, ?3, ?4)",
)
.bind(entry.stream.as_u32() as i32)
.bind(entry.addr.to_string())
.bind(entry.last_seen.as_secs() as i64)
.bind(rating.as_i8() as i32)
.execute(self.pool.write())
.await
{
error!(target: "koibumi", "{}", err);
}
}
return Some(());
}
None
}
}
}
async fn increment(&mut self, addr: &SocketAddrExt) -> Option<Rating> {
let now = Time::now();
if let Some(info) = self.map.get_mut(addr) {
info.last_seen = now;
info.rating.increment();
if info.stream.as_u32() <= i32::MAX as u32
&& info.last_seen.as_secs() <= i64::MAX as u64
{
if let Err(err) =
sqlx::query("UPDATE nodes SET rating=?1 WHERE stream=?2 and address=?3")
.bind(info.rating.as_i8() as i64)
.bind(info.stream.as_u32() as i32)
.bind(addr.to_string())
.execute(self.pool.write())
.await
{
error!(target: "koibumi", "{}", err);
}
}
return Some(info.rating.clone());
}
None
}
async fn decrement(&mut self, addr: &SocketAddrExt) -> Option<Rating> {
if let Some(info) = self.map.get_mut(addr) {
info.rating.decrement();
if info.stream.as_u32() <= i32::MAX as u32
&& info.last_seen.as_secs() <= i64::MAX as u64
{
if let Err(err) =
sqlx::query("UPDATE nodes SET rating=?1 WHERE stream=?2 and address=?3")
.bind(info.rating.as_i8() as i64)
.bind(info.stream.as_u32() as i32)
.bind(addr.to_string())
.execute(self.pool.write())
.await
{
error!(target: "koibumi", "{}", err);
}
}
return Some(info.rating.clone());
}
None
}
fn reclaim(&mut self, addr: &SocketAddrExt) {
self.used_addrs.remove(addr);
}
fn sample_list(&self) -> Vec<NetAddr> {
let mut list: Vec<SocketAddrExt> = self.map.keys().cloned().collect();
list.sort_unstable_by(|a, b| {
let a_info = &self.map[a];
let b_info = &self.map[b];
if a_info.rating == b_info.rating {
b_info.last_seen.cmp(&a_info.last_seen)
} else {
b_info.rating.cmp(&a_info.rating)
}
});
list.truncate(1000);
list.shuffle(&mut rand::thread_rng());
let mut addr_list = Vec::with_capacity(list.len());
for addr in list {
let info = &self.map[&addr];
if let Ok(addr) = addr.try_into() {
addr_list.push(NetAddr::new(
info.last_seen,
info.stream,
Services::NETWORK,
addr,
));
}
}
addr_list
}
fn sample(&mut self, own_nodes: &[SocketAddrExt]) -> Option<SocketAddrExt> {
let keys: HashSet<SocketAddrExt> = self.map.keys().cloned().collect();
let list: HashSet<SocketAddrExt> = keys.difference(&self.used_addrs).cloned().collect();
let own_nodes: HashSet<SocketAddrExt> = HashSet::from_iter(own_nodes.iter().cloned());
let mut list: Vec<SocketAddrExt> = list.difference(&own_nodes).cloned().collect();
list.sort_unstable_by(|a, b| {
let a_info = &self.map[a];
let b_info = &self.map[b];
if a_info.rating == b_info.rating {
b_info.last_seen.cmp(&a_info.last_seen)
} else {
b_info.rating.cmp(&a_info.rating)
}
});
if !list.is_empty() {
let bin = Binomial::new(list.len() as u64 * 2 - 1, 0.5).unwrap();
let v = bin.sample(&mut rand::thread_rng());
let i = if (v as usize) < list.len() {
list.len() - 1 - v as usize
} else {
v as usize - list.len()
};
let sa = &list[i];
self.used_addrs.insert(sa.clone());
return Some(sa.clone());
}
None
}
}
/// Runs the node-address manager loop.
///
/// The loop stops when `receiver` closes, the interval stream ends, or
/// `ctx.aborted()` is observed set. It handles each [`Event`]:
/// - learning addresses;
/// - adjusting ratings on connection success or failure;
/// - returning addresses for sampling;
/// - answering peers with `addr` messages.
///
/// Every four seconds it does one of two things:
/// - if enough outgoing connections are established, it tells the broker to
///   abort pending ones;
/// - otherwise it asks the broker to connect to a sampled address, while the
///   initiated count is below its limit.
///
/// The number of known addresses is sent as `BmEvent::AddrCount` whenever it
/// may have changed, and `0` is sent on exit. Send failures are logged, not
/// propagated.
pub async fn manage(ctx: Arc<Context>, mut receiver: Receiver<Event>) {
    let mut broker_sender = ctx.broker_sender().clone();
    let mut bm_event_sender = ctx.bm_event_sender().clone();
    let mut nodes = Nodes::new(Arc::clone(&ctx), ctx.pool().clone()).await;
    // Report the count loaded from the database.
    if let Err(err) = bm_event_sender.send(BmEvent::AddrCount(nodes.len())).await {
        error!(target: "koibumi", "{}", err);
    }
    let mut interval = interval(StdDuration::from_secs(4));
    loop {
        // The abort flag is only checked between events and ticks.
        if ctx.aborted().load(Ordering::SeqCst) {
            break;
        }
        select! {
            v = receiver.next().fuse() => match v {
                Some(event) => match event {
                    Event::Add(entries) => {
                        for entry in entries {
                            // Prune before each insert so the set stays near `max_nodes`.
                            if nodes.retain().await.is_some() {
                                if let Err(err) = bm_event_sender.send(BmEvent::AddrCount(nodes.len())).await {
                                    error!(target: "koibumi", "{}", err);
                                }
                            }
                            let own_node = ctx.config().own_nodes().contains(&entry.addr);
                            if nodes.insert(entry, own_node).await.is_some() {
                                if let Err(err) = bm_event_sender.send(BmEvent::AddrCount(nodes.len())).await {
                                    error!(target: "koibumi", "{}", err);
                                }
                            }
                        }
                    }
                    Event::ConnectionSucceeded(addr, user_agent) => {
                        // Only addresses present in the node set produce an `Established` event.
                        if let Some(rating) = nodes.increment(&addr).await {
                            if let Err(err) = bm_event_sender
                                .send(BmEvent::Established {
                                    addr: addr.clone(),
                                    user_agent,
                                    rating,
                                })
                                .await
                            {
                                error!(target: "koibumi", "{}", err)
                            }
                        }
                    }
                    Event::ConnectionFailed(addr) => {
                        // Configured own nodes are never penalised.
                        if !ctx.config().own_nodes().contains(&addr) {
                            nodes.decrement(&addr).await;
                        }
                    }
                    Event::Disconnected(addr) => {
                        nodes.reclaim(&addr);
                    }
                    Event::Send(addr, close) => {
                        let addr_list = nodes.sample_list();
                        if !addr_list.is_empty() {
                            // NOTE(review): `sample_list` caps the list at 1000 entries, which
                            // presumably keeps `Addr::new` from failing. Either unwrap would
                            // still panic this task on an encoding error.
                            let message = message::Addr::new(addr_list).unwrap();
                            let packet = message.pack(ctx.config().core()).unwrap();
                            if let Err(err) = broker_sender
                                .send(BrokerEvent::Write { addr: addr.clone(), list: vec![packet] })
                                .await
                            {
                                error!(target: "koibumi", "{}", err);
                            }
                        }
                        if close {
                            // NOTE(review): `2` is presumably the protocol's fatality level
                            // (fatal) — confirm against `message::Error`.
                            let error = message::Error::new(2.into(),
                                "Server full, please try again later.".as_bytes().to_vec().into());
                            let packet = error.pack(ctx.config().core()).unwrap();
                            if let Err(err) = broker_sender
                                .send(BrokerEvent::Write { addr: addr.clone(), list: vec![packet] })
                                .await
                            {
                                error!(target: "koibumi", "{}", err);
                            }
                            if let Err(err) = broker_sender
                                .send(BrokerEvent::Close { addr })
                                .await
                            {
                                error!(target: "koibumi", "{}", err);
                            }
                        }
                    }
                },
                None => break,
            },
            v = interval.next().fuse() => match v {
                Some(_) => {
                    let initiated = ctx.initiated(Direction::Outgoing).load(Ordering::SeqCst);
                    let connected = ctx.connected(Direction::Outgoing).load(Ordering::SeqCst);
                    let established = ctx.established(Direction::Outgoing).load(Ordering::SeqCst);
                    if established >= ctx.config().max_outgoing_established() {
                        // Enough outgoing peers: drop attempts that have not connected yet.
                        if initiated > connected {
                            if let Err(err) = broker_sender
                                .send(BrokerEvent::AbortPendings)
                                .await
                            {
                                error!(target: "koibumi", "{}", err);
                            }
                        }
                    } else if initiated < ctx.config().max_outgoing_initiated()
                        && initiated + ctx.config().own_nodes().len() < nodes.len() {
                        // Own nodes are excluded from sampling, so they are
                        // discounted when checking whether a candidate can remain.
                        if let Some(addr) = nodes.sample(ctx.config().own_nodes()) {
                            if let Err(err) = broker_sender
                                .send(BrokerEvent::Outgoing { addr })
                                .await
                            {
                                error!(target: "koibumi", "{}", err);
                            }
                        }
                    }
                },
                None => break,
            },
        };
    }
    if let Err(err) = bm_event_sender.send(BmEvent::AddrCount(0)).await {
        error!(target: "koibumi", "{}", err);
    }
}