koibumi-node-sync 0.0.0

A Bitmessage node implementation as a library for Koibumi (sync version), an experimental Bitmessage client
Documentation
use std::{
    collections::{HashMap, HashSet},
    convert::TryInto,
    fmt,
    iter::FromIterator,
    sync::{atomic::Ordering, Arc},
    time::Duration as StdDuration,
};

use crossbeam_channel::{select, Receiver};
use log::{debug, error};
use rand::seq::SliceRandom;
use rand_distr::{Binomial, Distribution};

use koibumi_core::{
    message::{self, NetAddr, Pack, Services, StreamNumber, UserAgent},
    net::SocketAddrExt,
    time::Time,
};

use crate::{
    connection::Direction,
    connection_loop::{Context, Event as BrokerEvent, ShutdownCommand},
    manager::Event as BmEvent,
    net::SocketAddrNode,
};

/// A node address reported for a stream, together with when it was last seen.
///
/// Entries are handed to the node manager through `Event::Add`.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct Entry {
    /// Stream number the address was reported for.
    stream: StreamNumber,
    /// Address of the node.
    addr: SocketAddrNode,
    /// Timestamp at which the node was last seen.
    last_seen: Time,
}

impl Entry {
    /// Builds an entry from a stream number, an address and a last-seen time.
    pub fn new(stream: StreamNumber, addr: SocketAddrNode, last_seen: Time) -> Self {
        Entry {
            addr,
            last_seen,
            stream,
        }
    }
}

/// Messages processed by the node manager loop (`manage`).
#[derive(Debug)]
pub enum Event {
    /// Learn about new node entries or refresh known ones.
    Add(Vec<Entry>),
    /// A connection to the address was established, carrying the peer's user
    /// agent; `manage` raises the node's rating and reports it.
    ConnectionSucceeded(SocketAddrNode, UserAgent),
    /// Connecting to the address failed; `manage` lowers the node's rating
    /// unless it is one of the configured own nodes.
    ConnectionFailed(SocketAddrNode),
    /// The connection to the address ended; the address becomes eligible for
    /// sampling again.
    Disconnected(SocketAddrNode),
    /// Send a list of known addresses to the peer at the address. If the flag
    /// is `true`, also send a "server full" error and close the connection.
    Send(SocketAddrNode, bool),
}

/// A rating of the connectivity of a node.
///
/// `increment` and `decrement` keep the value within `-10..=10`, clamping
/// values that are already outside that range. `new` and `From<i8>` accept
/// any `i8` unchanged.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct Rating(i8);

impl Rating {
    /// Upper bound enforced by `increment`.
    const MAX: i8 = 10;
    /// Lower bound enforced by `decrement`.
    const MIN: i8 = -10;

    /// Constructs a rating from a value.
    pub fn new(value: i8) -> Self {
        Self(value)
    }

    /// Returns the value as `i8`.
    pub fn as_i8(&self) -> i8 {
        self.0
    }

    /// Increments the rating.
    /// The maximum is `10`.
    pub fn increment(&mut self) {
        // Saturating so that `i8::MAX` cannot overflow before the clamp.
        self.0 = self.0.saturating_add(1).min(Self::MAX);
    }

    /// Decrements the rating.
    /// The minimum is `-10`.
    pub fn decrement(&mut self) {
        // Saturating so that `i8::MIN` cannot overflow before the clamp.
        self.0 = self.0.saturating_sub(1).max(Self::MIN);
    }
}

impl fmt::Display for Rating {
    /// Formats the rating as its plain integer value, honoring formatter flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}

impl From<i8> for Rating {
    /// Wraps `value` without clamping, same as `Rating::new`.
    fn from(value: i8) -> Self {
        Self::new(value)
    }
}

/// In-memory bookkeeping for one known node, keyed by address in `Nodes::map`.
struct Info {
    /// Stream number the node was stored under.
    // NOTE(review): `stream` is read by several `Nodes` methods, so this
    // `dead_code` allowance looks like a leftover — confirm before removing.
    #[allow(dead_code)]
    stream: StreamNumber,
    /// Most recent time the node was seen.
    last_seen: Time,
    /// Current connectivity rating of the node.
    rating: Rating,
}

/// One row of the `nodes` table as read from SQLite, before any range checks.
///
/// Columns use the signed types they are read with. Rows whose values do not
/// fit the in-memory types (negative stream or timestamp, rating outside `i8`,
/// unparsable address) are skipped when the table is loaded.
struct Record {
    stream: i32,
    address: String,
    last_seen: i64,
    rating: i32,
}

/// The set of known nodes, mirrored in the `nodes` SQLite table.
struct Nodes {
    /// Shared node context; supplies configuration such as `max_nodes`.
    ctx: Arc<Context>,
    /// Connection to the database holding the `nodes` table.
    conn: rusqlite::Connection,
    /// Known nodes keyed by address.
    map: HashMap<SocketAddrNode, Info>,
    /// Addresses handed out by `sample` and not yet returned via `reclaim`.
    /// They are excluded from further sampling and from eviction in `retain`.
    used_addrs: HashSet<SocketAddrNode>,
}

impl Nodes {
    fn new(ctx: Arc<Context>, conn: rusqlite::Connection) -> Self {
        if let Err(err) = conn.execute(
            "CREATE TABLE IF NOT EXISTS nodes (
                stream INTEGER NOT NULL,
                address TEXT NOT NULL,
                last_seen INTEGER NOT NULL,
                rating INTEGER NOT NULL,
                PRIMARY KEY(stream, address)
            )",
            rusqlite::params![],
        ) {
            error!("{}", err);
        }

        let mut map = HashMap::new();
        if let Ok(mut stmt) = conn.prepare("SELECT stream, address, last_seen, rating FROM nodes") {
            if let Ok(list) = stmt.query_map(rusqlite::params![], |row| {
                Ok(Record {
                    stream: row.get::<usize, i32>(0)?,
                    address: row.get::<usize, String>(1)?,
                    last_seen: row.get::<usize, i64>(2)?,
                    rating: row.get::<usize, i32>(3)?,
                })
            }) {
                for record in list {
                    if let Err(err) = record {
                        error!("{}", err);
                        continue;
                    }
                    let record = record.unwrap();
                    if record.stream < 0 {
                        continue;
                    }
                    let stream: StreamNumber = (record.stream as u32).into();
                    let addr = record.address.parse::<SocketAddrExt>();
                    if addr.is_err() {
                        continue;
                    }
                    let addr = addr.unwrap();
                    let addr: SocketAddrNode = addr.into();
                    if record.last_seen < 0 {
                        continue;
                    }
                    let last_seen: Time = (record.last_seen as u64).into();
                    if record.rating < -128 || record.rating > 127 {
                        continue;
                    }
                    let rating: Rating = (record.rating as i8).into();
                    map.insert(
                        addr,
                        Info {
                            stream,
                            last_seen,
                            rating,
                        },
                    );
                }
            }
        }

        let used_addrs = HashSet::new();
        let mut nodes = Self {
            ctx,
            conn,
            map,
            used_addrs,
        };

        nodes.retain();

        nodes
    }

    fn len(&self) -> usize {
        self.map.len()
    }

    fn retain(&mut self) -> Option<()> {
        if self.map.len() > self.ctx.config().max_nodes() {
            let keys: HashSet<SocketAddrNode> = self.map.keys().cloned().collect();
            let mut list: Vec<SocketAddrNode> =
                keys.difference(&self.used_addrs).cloned().collect();
            list.sort_unstable_by(|a, b| {
                let a_info = &self.map[a];
                let b_info = &self.map[b];
                if a_info.rating == b_info.rating {
                    a_info.last_seen.cmp(&b_info.last_seen)
                } else {
                    a_info.rating.cmp(&b_info.rating)
                }
            });
            let mut trunc_amount = self.ctx.config().max_nodes() / 10;
            if trunc_amount == 0 {
                trunc_amount = usize::min(1, list.len());
            }
            list.truncate(trunc_amount);
            for addr in &list {
                self.map.remove(addr);
                if let SocketAddrNode::AddrExt(addr) = addr {
                    if let Err(err) = self.conn.execute(
                        "DELETE FROM nodes WHERE address=?1",
                        rusqlite::params![addr.to_string()],
                    ) {
                        error!("{}", err);
                    }
                }
            }
            return Some(());
        }
        None
    }

    fn insert(&mut self, entry: Entry, own_node: bool) -> Option<()> {
        match self.map.get_mut(&entry.addr) {
            Some(info) => {
                if entry.last_seen > info.last_seen {
                    info.last_seen = entry.last_seen;
                }
                if entry.stream.as_u32() <= i32::MAX as u32
                    && entry.last_seen.as_secs() <= i64::MAX as u64
                {
                    if let SocketAddrNode::AddrExt(addr) = entry.addr {
                        if let Err(err) = self.conn.execute(
                            "UPDATE nodes SET last_seen=?1 WHERE stream=?2 and address=?3",
                            rusqlite::params![
                                entry.last_seen.as_secs() as i64,
                                entry.stream.as_u32() as i32,
                                addr.to_string()
                            ],
                        ) {
                            error!("{}", err);
                        }
                    }
                }
                None
            }
            None => {
                if self.ctx.config().is_connectable_to(&entry.addr)
                    && self.ctx.config().stream_numbers().contains(entry.stream)
                {
                    debug!("addr: {}", entry.addr);
                    let rating: Rating = if own_node {
                        Rating::MAX.into()
                    } else {
                        0.into()
                    };
                    self.map.insert(
                        entry.addr.clone(),
                        Info {
                            stream: entry.stream,
                            last_seen: entry.last_seen,
                            rating: rating.clone(),
                        },
                    );
                    if entry.stream.as_u32() <= i32::MAX as u32
                        && entry.last_seen.as_secs() <= i64::MAX as u64
                    {
                        if let SocketAddrNode::AddrExt(addr) = entry.addr {
                            if let Err(err) = self.conn.execute(
                                "INSERT INTO nodes (
                                        stream, address, last_seen, rating
                                    ) VALUES (?1, ?2, ?3, ?4)",
                                rusqlite::params![
                                    entry.stream.as_u32() as i32,
                                    addr.to_string(),
                                    entry.last_seen.as_secs() as i64,
                                    rating.as_i8() as i32
                                ],
                            ) {
                                error!("{}", err);
                            }
                        }
                    }
                    return Some(());
                }
                None
            }
        }
    }

    fn increment(&mut self, addr: &SocketAddrNode) -> Option<Rating> {
        let now = Time::now();
        if let Some(info) = self.map.get_mut(addr) {
            info.last_seen = now;
            info.rating.increment();

            if info.stream.as_u32() <= i32::MAX as u32
                && info.last_seen.as_secs() <= i64::MAX as u64
            {
                if let SocketAddrNode::AddrExt(addr) = addr {
                    if let Err(err) = self.conn.execute(
                        "UPDATE nodes SET rating=?1 WHERE stream=?2 and address=?3",
                        rusqlite::params![
                            info.rating.as_i8() as i64,
                            info.stream.as_u32() as i32,
                            addr.to_string()
                        ],
                    ) {
                        error!("{}", err);
                    }
                }
            }
            return Some(info.rating.clone());
        }
        None
    }

    fn decrement(&mut self, addr: &SocketAddrNode) -> Option<Rating> {
        if let Some(info) = self.map.get_mut(addr) {
            info.rating.decrement();

            if info.stream.as_u32() <= i32::MAX as u32
                && info.last_seen.as_secs() <= i64::MAX as u64
            {
                if let SocketAddrNode::AddrExt(addr) = addr {
                    if let Err(err) = self.conn.execute(
                        "UPDATE nodes SET rating=?1 WHERE stream=?2 and address=?3",
                        rusqlite::params![
                            info.rating.as_i8() as i64,
                            info.stream.as_u32() as i32,
                            addr.to_string()
                        ],
                    ) {
                        error!("{}", err);
                    }
                }
            }
            return Some(info.rating.clone());
        }
        None
    }

    fn reclaim(&mut self, addr: &SocketAddrNode) {
        self.used_addrs.remove(addr);
    }

    fn sample_list(&self) -> Vec<NetAddr> {
        let mut list: Vec<SocketAddrNode> = self.map.keys().cloned().collect();
        list.sort_unstable_by(|a, b| {
            let a_info = &self.map[a];
            let b_info = &self.map[b];
            if a_info.rating == b_info.rating {
                b_info.last_seen.cmp(&a_info.last_seen)
            } else {
                b_info.rating.cmp(&a_info.rating)
            }
        });
        list.truncate(1000);
        list.shuffle(&mut rand::thread_rng());
        let mut addr_list = Vec::with_capacity(list.len());
        for addr in list {
            let info = &self.map[&addr];
            let addr = addr.try_into();
            if let Err(err) = addr {
                error!("{}", err);
                continue;
            }
            let addr: SocketAddrExt = addr.unwrap();
            if let Ok(addr) = addr.try_into() {
                addr_list.push(NetAddr::new(
                    info.last_seen,
                    info.stream,
                    Services::NETWORK,
                    addr,
                ));
            }
        }
        addr_list
    }

    fn sample(&mut self, own_nodes: &[SocketAddrExt]) -> Option<SocketAddrNode> {
        let keys: HashSet<SocketAddrNode> = self.map.keys().cloned().collect();
        let list: HashSet<SocketAddrNode> = keys.difference(&self.used_addrs).cloned().collect();
        let own_nodes: HashSet<SocketAddrNode> =
            HashSet::from_iter(own_nodes.iter().cloned().map(|a| a.into()));
        let mut list: Vec<SocketAddrNode> = list.difference(&own_nodes).cloned().collect();
        list.sort_unstable_by(|a, b| {
            let a_info = &self.map[a];
            let b_info = &self.map[b];
            if a_info.rating == b_info.rating {
                b_info.last_seen.cmp(&a_info.last_seen)
            } else {
                b_info.rating.cmp(&a_info.rating)
            }
        });
        if !list.is_empty() {
            let bin = Binomial::new(list.len() as u64 * 2 - 1, 0.5).unwrap();
            let v = bin.sample(&mut rand::thread_rng());
            let i = if (v as usize) < list.len() {
                list.len() - 1 - v as usize
            } else {
                v as usize - list.len()
            };
            let sa = &list[i];
            self.used_addrs.insert(sa.clone());
            return Some(sa.clone());
        }
        None
    }
}

/// Runs the node manager until it is told to stop.
///
/// Opens the SQLite database at `ctx.db_path()` and loads the known nodes from
/// it (`Nodes::new`). It reports the node count via `BmEvent::AddrCount` and
/// then services three channels:
///
/// - `receiver`: `Event`s about nodes and connections;
/// - `shutdown_receiver`: any message, or its disconnection, ends the loop;
/// - a 4-second ticker. It starts a new outgoing connection when below the
///   configured limits. Once `max_outgoing_established` outgoing connections
///   exist, it asks the broker to abort pending attempts instead.
///
/// The loop also ends when `ctx.aborted()` is set (checked before each wait)
/// or when `receiver` or the ticker is disconnected. `AddrCount(0)` is then
/// sent.
///
/// If the database cannot be opened, the error is logged and the function
/// returns at once without sending any event. Send failures on the broker and
/// event channels are logged and do not stop the loop.
pub fn manage(
    ctx: Arc<Context>,
    receiver: Receiver<Event>,
    shutdown_receiver: Receiver<ShutdownCommand>,
) {
    let broker_sender = ctx.broker_sender().clone();
    let bm_event_sender = ctx.bm_event_sender().clone();

    let conn = rusqlite::Connection::open(ctx.db_path());
    if let Err(err) = conn {
        error!("{}", err);
        return;
    }
    let conn = conn.unwrap();

    let mut nodes = Nodes::new(Arc::clone(&ctx), conn);

    if let Err(err) = bm_event_sender.send(BmEvent::AddrCount(nodes.len())) {
        error!("{}", err);
    }

    // Period of the outgoing-connection maintenance done in the `interval` arm.
    let interval = crossbeam_channel::tick(StdDuration::from_secs(4));

    loop {
        if ctx.aborted().load(Ordering::SeqCst) {
            break;
        }
        select! {
            recv(receiver) -> v => match v {
                Ok(event) => match event {
                    Event::Add(entries) => {
                        for entry in entries {
                            // Make room before each insertion; report the count
                            // whenever an eviction ran.
                            if nodes.retain().is_some() {
                                if let Err(err) = bm_event_sender.send(BmEvent::AddrCount(nodes.len())) {
                                    error!("{}", err);
                                }
                            }

                            // Configured own nodes start with the maximum rating.
                            let own_node = if let SocketAddrNode::AddrExt(addr) = &entry.addr {
                                ctx.config().own_nodes().contains(addr)
                            } else {
                                false
                            };
                            if nodes.insert(entry, own_node).is_some() {
                                if let Err(err) = bm_event_sender.send(BmEvent::AddrCount(nodes.len())) {
                                    error!("{}", err);
                                }
                            }
                        }
                    }
                    Event::ConnectionSucceeded(addr, user_agent) => {
                        // Unknown addresses have no rating and are not reported.
                        if let Some(rating) = nodes.increment(&addr) {
                            if let Err(err) = bm_event_sender
                                .send(BmEvent::Established {
                                    addr: addr.clone(),
                                    user_agent,
                                    rating,
                                })
                            {
                                error!("{}", err)
                            }
                        }
                    }
                    Event::ConnectionFailed(addr) => {
                        // Own nodes are never penalised for failures.
                        let own_node = if let SocketAddrNode::AddrExt(addr) = &addr {
                            ctx.config().own_nodes().contains(addr)
                        } else {
                            false
                        };
                        if !own_node {
                            nodes.decrement(&addr);
                        }
                    }
                    Event::Disconnected(addr) => {
                        // Make the address eligible for sampling and eviction again.
                        nodes.reclaim(&addr);
                    }
                    Event::Send(addr, close) => {
                        // Share a sample of known addresses with the peer at `addr`.
                        let addr_list = nodes.sample_list();
                        if !addr_list.is_empty() {
                            let message = message::Addr::new(addr_list).unwrap();
                            let packet = message.pack(ctx.config().core()).unwrap();
                            if let Err(err) = broker_sender
                                .send(BrokerEvent::Write { addr: addr.clone(), list: vec![packet] })
                            {
                                error!("{}", err);
                            }
                        }
                        // Refuse the peer: send an error message, then ask the
                        // broker to close the connection.
                        // NOTE(review): `2` is presumably the "fatal" error level — confirm.
                        if close {
                            let error = message::Error::new(2.into(),
                                "Server full, please try again later.".as_bytes().to_vec().into());
                            let packet = error.pack(ctx.config().core()).unwrap();
                            if let Err(err) = broker_sender
                                .send(BrokerEvent::Write { addr: addr.clone(), list: vec![packet] })
                            {
                                error!("{}", err);
                            }
                            if let Err(err) = broker_sender
                                .send(BrokerEvent::Close { addr })
                            {
                                error!("{}", err);
                            }
                        }
                    }
                },
                Err(_err) => break,
            },
            recv(shutdown_receiver) -> _v => break,
            recv(interval) -> v => match v {
                Ok(_) => {
                    let initiated = ctx.initiated(Direction::Outgoing).load(Ordering::SeqCst);
                    let connected = ctx.connected(Direction::Outgoing).load(Ordering::SeqCst);
                    let established = ctx.established(Direction::Outgoing).load(Ordering::SeqCst);
                    // Enough outgoing connections: abort attempts that were
                    // initiated but have not connected yet.
                    if established >= ctx.config().max_outgoing_established() {
                        if initiated > connected {
                            if let Err(err) = broker_sender
                                .send(BrokerEvent::AbortPendings)
                            {
                                error!("{}", err);
                            }
                        }
                    // Otherwise start one more attempt while under the initiation
                    // limit and while more nodes are known than own nodes plus
                    // attempts already initiated.
                    } else if initiated < ctx.config().max_outgoing_initiated()
                            && initiated + ctx.config().own_nodes().len() < nodes.len() {
                        if let Some(addr) = nodes.sample(ctx.config().own_nodes()) {
                            if let Err(err) = broker_sender
                                .send(BrokerEvent::Outgoing { addr })
                            {
                                error!("{}", err);
                            }
                        }
                    }
                },
                Err(_err) => break,
            },
        };
    }
    if let Err(err) = bm_event_sender.send(BmEvent::AddrCount(0)) {
        error!("{}", err);
    }
}