use std::{
collections::{HashMap, HashSet},
convert::TryInto,
fmt,
iter::FromIterator,
sync::{atomic::Ordering, Arc},
time::Duration as StdDuration,
};
use crossbeam_channel::{select, Receiver};
use log::{debug, error};
use rand::seq::SliceRandom;
use rand_distr::{Binomial, Distribution};
use koibumi_core::{
message::{self, NetAddr, Pack, Services, StreamNumber, UserAgent},
net::SocketAddrExt,
time::Time,
};
use crate::{
connection::Direction,
connection_loop::{Context, Event as BrokerEvent, ShutdownCommand},
manager::Event as BmEvent,
net::SocketAddrNode,
};
/// One address advertisement: the stream it belongs to, the address itself,
/// and the time it was last seen.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct Entry {
    stream: StreamNumber,
    addr: SocketAddrNode,
    last_seen: Time,
}
impl Entry {
    /// Bundles the three parts of an advertisement into an `Entry`.
    pub fn new(stream: StreamNumber, addr: SocketAddrNode, last_seen: Time) -> Self {
        Entry {
            addr,
            last_seen,
            stream,
        }
    }
}
/// Requests handled by the node manager loop (`manage`).
#[derive(Debug)]
pub enum Event {
    /// Entries to record. For each one the table is first trimmed if it is over
    /// capacity, then the entry is inserted (own nodes start at the maximum rating).
    Add(Vec<Entry>),
    /// A connection to the address was established; the peer advertised the
    /// given user agent. Raises the address's rating and reports `Established`.
    ConnectionSucceeded(SocketAddrNode, UserAgent),
    /// A connection attempt to the address failed. Lowers its rating unless it
    /// is one of our own configured nodes.
    ConnectionFailed(SocketAddrNode),
    /// The address is no longer in use and becomes eligible for sampling again.
    Disconnected(SocketAddrNode),
    /// Send an `addr` message to the peer at the address; when the flag is
    /// `true`, also send a "Server full" error and close the connection.
    Send(SocketAddrNode, bool),
}
/// Reputation score of a known node.
///
/// [`Rating::increment`] and [`Rating::decrement`] keep the value within
/// `[Rating::MIN, Rating::MAX]`; construction via `new`/`From` does not clamp.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct Rating(i8);
impl Rating {
    /// Ceiling reached by repeated increments.
    const MAX: i8 = 10;
    /// Floor reached by repeated decrements.
    const MIN: i8 = -10;

    /// Wraps `value` as-is (no clamping).
    pub fn new(value: i8) -> Self {
        Self(value)
    }

    /// Returns the raw score.
    pub fn as_i8(&self) -> i8 {
        self.0
    }

    /// Raises the score by one, stopping at `MAX`.
    ///
    /// Saturating arithmetic avoids overflow for out-of-range values (e.g. 127
    /// loaded from storage); such values are pulled back down to `MAX`.
    pub fn increment(&mut self) {
        self.0 = i8::min(self.0.saturating_add(1), Self::MAX);
    }

    /// Lowers the score by one, stopping at `MIN`.
    ///
    /// Saturating arithmetic avoids overflow for out-of-range values (e.g.
    /// -128 loaded from storage); such values are pulled back up to `MIN`.
    pub fn decrement(&mut self) {
        // The floor is `MIN` itself (-10); negating it would clamp every
        // decremented rating up to +10.
        self.0 = i8::max(self.0.saturating_sub(1), Self::MIN);
    }
}
impl fmt::Display for Rating {
    /// Formats the raw score, honouring width/alignment flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
impl From<i8> for Rating {
    /// Wraps `value` as-is (no clamping), like [`Rating::new`].
    fn from(value: i8) -> Self {
        Self(value)
    }
}
/// In-memory state kept for each known address.
struct Info {
    /// Stream the address was recorded under; part of the database row key.
    stream: StreamNumber,
    /// Latest time the address was advertised or successfully connected to.
    last_seen: Time,
    /// Current reputation of the address.
    rating: Rating,
}
/// Raw row of the `nodes` table, before range checks and parsing.
struct Record {
    /// `address` column: textual socket address.
    address: String,
    /// `stream` column.
    stream: i32,
    /// `rating` column.
    rating: i32,
    /// `last_seen` column, in seconds.
    last_seen: i64,
}
/// Known-node table: an in-memory map mirrored into the SQLite `nodes` table.
struct Nodes {
    // Shared context; its configuration supplies capacity limits and filters.
    ctx: Arc<Context>,
    // Connection holding the persistent `nodes` table.
    conn: rusqlite::Connection,
    // Every known address with its stream, last-seen time and rating.
    map: HashMap<SocketAddrNode, Info>,
    // Addresses handed out by `sample` and not yet released by `reclaim`;
    // they are skipped by both sampling and eviction.
    used_addrs: HashSet<SocketAddrNode>,
}
impl Nodes {
fn new(ctx: Arc<Context>, conn: rusqlite::Connection) -> Self {
if let Err(err) = conn.execute(
"CREATE TABLE IF NOT EXISTS nodes (
stream INTEGER NOT NULL,
address TEXT NOT NULL,
last_seen INTEGER NOT NULL,
rating INTEGER NOT NULL,
PRIMARY KEY(stream, address)
)",
rusqlite::params![],
) {
error!("{}", err);
}
let mut map = HashMap::new();
if let Ok(mut stmt) = conn.prepare("SELECT stream, address, last_seen, rating FROM nodes") {
if let Ok(list) = stmt.query_map(rusqlite::params![], |row| {
Ok(Record {
stream: row.get::<usize, i32>(0)?,
address: row.get::<usize, String>(1)?,
last_seen: row.get::<usize, i64>(2)?,
rating: row.get::<usize, i32>(3)?,
})
}) {
for record in list {
if let Err(err) = record {
error!("{}", err);
continue;
}
let record = record.unwrap();
if record.stream < 0 {
continue;
}
let stream: StreamNumber = (record.stream as u32).into();
let addr = record.address.parse::<SocketAddrExt>();
if addr.is_err() {
continue;
}
let addr = addr.unwrap();
let addr: SocketAddrNode = addr.into();
if record.last_seen < 0 {
continue;
}
let last_seen: Time = (record.last_seen as u64).into();
if record.rating < -128 || record.rating > 127 {
continue;
}
let rating: Rating = (record.rating as i8).into();
map.insert(
addr,
Info {
stream,
last_seen,
rating,
},
);
}
}
}
let used_addrs = HashSet::new();
let mut nodes = Self {
ctx,
conn,
map,
used_addrs,
};
nodes.retain();
nodes
}
fn len(&self) -> usize {
self.map.len()
}
fn retain(&mut self) -> Option<()> {
if self.map.len() > self.ctx.config().max_nodes() {
let keys: HashSet<SocketAddrNode> = self.map.keys().cloned().collect();
let mut list: Vec<SocketAddrNode> =
keys.difference(&self.used_addrs).cloned().collect();
list.sort_unstable_by(|a, b| {
let a_info = &self.map[a];
let b_info = &self.map[b];
if a_info.rating == b_info.rating {
a_info.last_seen.cmp(&b_info.last_seen)
} else {
a_info.rating.cmp(&b_info.rating)
}
});
let mut trunc_amount = self.ctx.config().max_nodes() / 10;
if trunc_amount == 0 {
trunc_amount = usize::min(1, list.len());
}
list.truncate(trunc_amount);
for addr in &list {
self.map.remove(addr);
if let SocketAddrNode::AddrExt(addr) = addr {
if let Err(err) = self.conn.execute(
"DELETE FROM nodes WHERE address=?1",
rusqlite::params![addr.to_string()],
) {
error!("{}", err);
}
}
}
return Some(());
}
None
}
fn insert(&mut self, entry: Entry, own_node: bool) -> Option<()> {
match self.map.get_mut(&entry.addr) {
Some(info) => {
if entry.last_seen > info.last_seen {
info.last_seen = entry.last_seen;
}
if entry.stream.as_u32() <= i32::MAX as u32
&& entry.last_seen.as_secs() <= i64::MAX as u64
{
if let SocketAddrNode::AddrExt(addr) = entry.addr {
if let Err(err) = self.conn.execute(
"UPDATE nodes SET last_seen=?1 WHERE stream=?2 and address=?3",
rusqlite::params![
entry.last_seen.as_secs() as i64,
entry.stream.as_u32() as i32,
addr.to_string()
],
) {
error!("{}", err);
}
}
}
None
}
None => {
if self.ctx.config().is_connectable_to(&entry.addr)
&& self.ctx.config().stream_numbers().contains(entry.stream)
{
debug!("addr: {}", entry.addr);
let rating: Rating = if own_node {
Rating::MAX.into()
} else {
0.into()
};
self.map.insert(
entry.addr.clone(),
Info {
stream: entry.stream,
last_seen: entry.last_seen,
rating: rating.clone(),
},
);
if entry.stream.as_u32() <= i32::MAX as u32
&& entry.last_seen.as_secs() <= i64::MAX as u64
{
if let SocketAddrNode::AddrExt(addr) = entry.addr {
if let Err(err) = self.conn.execute(
"INSERT INTO nodes (
stream, address, last_seen, rating
) VALUES (?1, ?2, ?3, ?4)",
rusqlite::params![
entry.stream.as_u32() as i32,
addr.to_string(),
entry.last_seen.as_secs() as i64,
rating.as_i8() as i32
],
) {
error!("{}", err);
}
}
}
return Some(());
}
None
}
}
}
fn increment(&mut self, addr: &SocketAddrNode) -> Option<Rating> {
let now = Time::now();
if let Some(info) = self.map.get_mut(addr) {
info.last_seen = now;
info.rating.increment();
if info.stream.as_u32() <= i32::MAX as u32
&& info.last_seen.as_secs() <= i64::MAX as u64
{
if let SocketAddrNode::AddrExt(addr) = addr {
if let Err(err) = self.conn.execute(
"UPDATE nodes SET rating=?1 WHERE stream=?2 and address=?3",
rusqlite::params![
info.rating.as_i8() as i64,
info.stream.as_u32() as i32,
addr.to_string()
],
) {
error!("{}", err);
}
}
}
return Some(info.rating.clone());
}
None
}
fn decrement(&mut self, addr: &SocketAddrNode) -> Option<Rating> {
if let Some(info) = self.map.get_mut(addr) {
info.rating.decrement();
if info.stream.as_u32() <= i32::MAX as u32
&& info.last_seen.as_secs() <= i64::MAX as u64
{
if let SocketAddrNode::AddrExt(addr) = addr {
if let Err(err) = self.conn.execute(
"UPDATE nodes SET rating=?1 WHERE stream=?2 and address=?3",
rusqlite::params![
info.rating.as_i8() as i64,
info.stream.as_u32() as i32,
addr.to_string()
],
) {
error!("{}", err);
}
}
}
return Some(info.rating.clone());
}
None
}
fn reclaim(&mut self, addr: &SocketAddrNode) {
self.used_addrs.remove(addr);
}
fn sample_list(&self) -> Vec<NetAddr> {
let mut list: Vec<SocketAddrNode> = self.map.keys().cloned().collect();
list.sort_unstable_by(|a, b| {
let a_info = &self.map[a];
let b_info = &self.map[b];
if a_info.rating == b_info.rating {
b_info.last_seen.cmp(&a_info.last_seen)
} else {
b_info.rating.cmp(&a_info.rating)
}
});
list.truncate(1000);
list.shuffle(&mut rand::thread_rng());
let mut addr_list = Vec::with_capacity(list.len());
for addr in list {
let info = &self.map[&addr];
let addr = addr.try_into();
if let Err(err) = addr {
error!("{}", err);
continue;
}
let addr: SocketAddrExt = addr.unwrap();
if let Ok(addr) = addr.try_into() {
addr_list.push(NetAddr::new(
info.last_seen,
info.stream,
Services::NETWORK,
addr,
));
}
}
addr_list
}
fn sample(&mut self, own_nodes: &[SocketAddrExt]) -> Option<SocketAddrNode> {
let keys: HashSet<SocketAddrNode> = self.map.keys().cloned().collect();
let list: HashSet<SocketAddrNode> = keys.difference(&self.used_addrs).cloned().collect();
let own_nodes: HashSet<SocketAddrNode> =
HashSet::from_iter(own_nodes.iter().cloned().map(|a| a.into()));
let mut list: Vec<SocketAddrNode> = list.difference(&own_nodes).cloned().collect();
list.sort_unstable_by(|a, b| {
let a_info = &self.map[a];
let b_info = &self.map[b];
if a_info.rating == b_info.rating {
b_info.last_seen.cmp(&a_info.last_seen)
} else {
b_info.rating.cmp(&a_info.rating)
}
});
if !list.is_empty() {
let bin = Binomial::new(list.len() as u64 * 2 - 1, 0.5).unwrap();
let v = bin.sample(&mut rand::thread_rng());
let i = if (v as usize) < list.len() {
list.len() - 1 - v as usize
} else {
v as usize - list.len()
};
let sa = &list[i];
self.used_addrs.insert(sa.clone());
return Some(sa.clone());
}
None
}
}
pub fn manage(
ctx: Arc<Context>,
receiver: Receiver<Event>,
shutdown_receiver: Receiver<ShutdownCommand>,
) {
let broker_sender = ctx.broker_sender().clone();
let bm_event_sender = ctx.bm_event_sender().clone();
let conn = rusqlite::Connection::open(ctx.db_path());
if let Err(err) = conn {
error!("{}", err);
return;
}
let conn = conn.unwrap();
let mut nodes = Nodes::new(Arc::clone(&ctx), conn);
if let Err(err) = bm_event_sender.send(BmEvent::AddrCount(nodes.len())) {
error!("{}", err);
}
let interval = crossbeam_channel::tick(StdDuration::from_secs(4));
loop {
if ctx.aborted().load(Ordering::SeqCst) {
break;
}
select! {
recv(receiver) -> v => match v {
Ok(event) => match event {
Event::Add(entries) => {
for entry in entries {
if nodes.retain().is_some() {
if let Err(err) = bm_event_sender.send(BmEvent::AddrCount(nodes.len())) {
error!("{}", err);
}
}
let own_node = if let SocketAddrNode::AddrExt(addr) = &entry.addr {
ctx.config().own_nodes().contains(addr)
} else {
false
};
if nodes.insert(entry, own_node).is_some() {
if let Err(err) = bm_event_sender.send(BmEvent::AddrCount(nodes.len())) {
error!("{}", err);
}
}
}
}
Event::ConnectionSucceeded(addr, user_agent) => {
if let Some(rating) = nodes.increment(&addr) {
if let Err(err) = bm_event_sender
.send(BmEvent::Established {
addr: addr.clone(),
user_agent,
rating,
})
{
error!("{}", err)
}
}
}
Event::ConnectionFailed(addr) => {
let own_node = if let SocketAddrNode::AddrExt(addr) = &addr {
ctx.config().own_nodes().contains(addr)
} else {
false
};
if !own_node {
nodes.decrement(&addr);
}
}
Event::Disconnected(addr) => {
nodes.reclaim(&addr);
}
Event::Send(addr, close) => {
let addr_list = nodes.sample_list();
if !addr_list.is_empty() {
let message = message::Addr::new(addr_list).unwrap();
let packet = message.pack(ctx.config().core()).unwrap();
if let Err(err) = broker_sender
.send(BrokerEvent::Write { addr: addr.clone(), list: vec![packet] })
{
error!("{}", err);
}
}
if close {
let error = message::Error::new(2.into(),
"Server full, please try again later.".as_bytes().to_vec().into());
let packet = error.pack(ctx.config().core()).unwrap();
if let Err(err) = broker_sender
.send(BrokerEvent::Write { addr: addr.clone(), list: vec![packet] })
{
error!("{}", err);
}
if let Err(err) = broker_sender
.send(BrokerEvent::Close { addr })
{
error!("{}", err);
}
}
}
},
Err(_err) => break,
},
recv(shutdown_receiver) -> _v => break,
recv(interval) -> v => match v {
Ok(_) => {
let initiated = ctx.initiated(Direction::Outgoing).load(Ordering::SeqCst);
let connected = ctx.connected(Direction::Outgoing).load(Ordering::SeqCst);
let established = ctx.established(Direction::Outgoing).load(Ordering::SeqCst);
if established >= ctx.config().max_outgoing_established() {
if initiated > connected {
if let Err(err) = broker_sender
.send(BrokerEvent::AbortPendings)
{
error!("{}", err);
}
}
} else if initiated < ctx.config().max_outgoing_initiated()
&& initiated + ctx.config().own_nodes().len() < nodes.len() {
if let Some(addr) = nodes.sample(ctx.config().own_nodes()) {
if let Err(err) = broker_sender
.send(BrokerEvent::Outgoing { addr })
{
error!("{}", err);
}
}
}
},
Err(_err) => break,
},
};
}
if let Err(err) = bm_event_sender.send(BmEvent::AddrCount(0)) {
error!("{}", err);
}
}