koibumi-node 0.0.8

A Bitmessage node implementation as a library for Koibumi, an experimental Bitmessage client
Documentation
use std::{
    collections::{hash_map::Keys, HashMap},
    convert::{TryFrom, TryInto},
    io::Cursor,
    sync::atomic::Ordering,
    time::Duration as StdDuration,
};

use async_std::{stream::interval, sync::Arc, task};
use futures::{
    channel::mpsc::{self, Receiver, Sender},
    select,
    sink::SinkExt,
    stream::StreamExt,
    FutureExt,
};
use log::error;
use rand::seq::SliceRandom;

use koibumi_core::{
    io::{ReadFrom, SizedReadFrom, SizedReadFromExact, WriteTo},
    message::{self, InvHash, Pack, StreamNumber},
    net::SocketAddrExt,
    object::{self, Header, ObjectKind},
    time::{Duration, Time},
};

use crate::{
    connection_loop::{Context, Event as BrokerEvent},
    constant::{OBJECT_FUTURE_LIFETIME, OBJECT_PAST_LIFETIME, REQUEST_EXPIRES},
    db,
    manager::Event as BmEvent,
    net::SocketAddrNode,
    object_processor::{process as process_objects, Event as ObjectEvent},
    pow_manager::Event as PowEvent,
};

/// Requests served by [`manage`], the task that owns the local object store.
#[derive(Debug)]
pub enum Event {
    /// `addr` announced the inventory hashes in `list`. Hashes that are not
    /// stored yet are marked as missing and sent back to the broker as
    /// `ObjectsNewToMe` for `addr`.
    Inv {
        addr: SocketAddrNode,
        list: Vec<InvHash>,
    },
    /// `addr` delivered `object`. It is stored only if its hash is currently
    /// in the missing set; the hash is then passed to the broker as
    /// `ObjectsNewToHer` together with `addr`. Other objects are ignored.
    Object {
        addr: SocketAddrNode,
        object: message::Object,
    },
    /// Removes this inventory hash from the set of missing objects.
    Drop(InvHash),
    /// Sends the inventory hashes of all stored objects to `addr` (if any are
    /// stored), shuffled and split into `inv` messages of at most
    /// `message::Inv::MAX_COUNT_FIXED` entries.
    SendBigInv {
        addr: SocketAddrNode,
    },
    /// Sends the stored objects whose hashes are in `list` to `addr`; hashes
    /// that are not stored are skipped.
    SendObjects {
        addr: SocketAddrNode,
        list: Vec<InvHash>,
    },
    /// Stores the object and asks the broker to publish its hash
    /// (`PublishObjects`). NOTE(review): presumably used for objects created
    /// by this node — confirm with the callers.
    Insert(message::Object),
}

/// In-memory advertisement state for one of this node's own addresses.
#[derive(Clone, Debug)]
struct OwnNodeInfo {
    /// Stream number the address's onionpeer advertisement was recorded on.
    // NOTE(review): `stream` is read by `Objects` methods, so this allow looks
    // stale — confirm and remove.
    #[allow(dead_code)]
    stream: StreamNumber,
    /// Expiration time of the latest known onionpeer advertisement.
    expires: Time,
}

/// One row of the `own_nodes` table as decoded by `sqlx`.
///
/// The field names mirror the selected column names.
#[derive(sqlx::FromRow)]
struct OwnNode {
    /// Stream number, stored as a signed SQLite integer.
    stream: i32,
    /// Textual form of the address; parsed back into a `SocketAddrExt` on load.
    address: String,
    /// Expiration time in seconds, stored as a signed SQLite integer.
    expires: i64,
}

/// Object store: headers are cached in memory, while the full messages are
/// kept in the SQLite `objects` table.
struct Objects {
    /// Shared node context; its configuration lists this node's own addresses.
    ctx: Arc<Context>,
    /// Database pool holding the `objects` and `own_nodes` tables.
    pool: db::SqlitePool,
    /// Inventory hash -> header of every stored object.
    map: HashMap<InvHash, Header>,
    /// Channel to the object processor; newly inserted objects are sent here.
    object_sender: Sender<ObjectEvent>,
    /// Latest advertisement state of each configured own node, keyed by address.
    own_nodes: HashMap<SocketAddrExt, OwnNodeInfo>,
}

impl Objects {
    async fn new(
        ctx: Arc<Context>,
        pool: db::SqlitePool,
        object_sender: Sender<ObjectEvent>,
    ) -> Self {
        if let Err(err) = sqlx::query(
            "CREATE TABLE IF NOT EXISTS objects (
                hash BLOB NOT NULL PRIMARY KEY,
                message BLOB NOT NULL,
                header BLOB NOT NULL
            )",
        )
        .execute(pool.write())
        .await
        {
            error!("{}", err);
        }

        let mut map = HashMap::new();
        let mut onionpeers = Vec::new();
        if let Ok(list) =
            sqlx::query_as::<sqlx::Sqlite, (Vec<u8>, Vec<u8>)>("SELECT hash, header FROM objects")
                .fetch_all(pool.read())
                .await
        {
            for elem in list {
                if elem.0.len() != 32 {
                    continue;
                }
                let bytes: [u8; 32] = elem.0[..].try_into().unwrap();
                let hash = InvHash::new(bytes);
                let mut bytes = Cursor::new(elem.1);
                let header = Header::read_from(&mut bytes);
                if header.is_err() {
                    continue;
                }
                let header = header.unwrap();
                map.insert(hash.clone(), header.clone());

                if let Ok(kind) = ObjectKind::try_from(header.object_type()) {
                    if kind == ObjectKind::Onionpeer {
                        onionpeers.push(hash);
                    }
                }
            }
        }

        if let Err(err) = sqlx::query(
            "CREATE TABLE IF NOT EXISTS own_nodes (
                stream INTEGER NOT NULL,
                address TEXT NOT NULL,
                expires INTEGER NOT NULL,
                PRIMARY KEY(stream, address)
            )",
        )
        .execute(pool.write())
        .await
        {
            error!("{}", err);
        }

        for hash in onionpeers {
            let bytes = sqlx::query_as::<sqlx::Sqlite, (Vec<u8>,)>(
                "SELECT message FROM objects WHERE hash=?1",
            )
            .bind(hash.as_ref())
            .fetch_all(pool.read())
            .await;
            if bytes.is_err() {
                continue;
            }
            let bytes = bytes.unwrap();
            if bytes.is_empty() {
                continue;
            }
            let bytes = &bytes[0].0;

            let len = bytes.len();
            let mut r = Cursor::new(bytes);
            let message = message::Object::sized_read_from(&mut r, len);
            if message.is_err() {
                continue;
            }
            let message = message.unwrap();
            let mut r = Cursor::new(message.object_payload());
            let onionpeer =
                object::Onionpeer::sized_read_from(&mut r, message.object_payload().len());
            if onionpeer.is_err() {
                continue;
            }
            let addr = SocketAddrExt::try_from(onionpeer.unwrap());
            if addr.is_err() {
                continue;
            }
            let addr = addr.unwrap();
            if !ctx.config().own_nodes().contains(&addr) {
                continue;
            }

            let expires = sqlx::query_as::<sqlx::Sqlite, (i64,)>(
                "SELECT expires FROM own_nodes WHERE stream=?1 AND address=?2",
            )
            .bind(message.header().stream_number().as_u32() as i64)
            .bind(addr.to_string())
            .fetch_all(pool.read())
            .await;
            let expires = if expires.is_err() {
                None
            } else {
                let expires = expires.unwrap();
                if expires.is_empty() {
                    None
                } else {
                    Some(expires[0].0)
                }
            };
            if let Some(expires) = expires {
                if expires >= 0
                    && message.header().expires_time().as_secs() > expires as u64
                    && message.header().expires_time().as_secs() <= i64::MAX as u64
                {
                    if let Err(err) = sqlx::query(
                        "UPDATE own_nodes SET expires=?1 WHERE stream=?2 and address=?3",
                    )
                    .bind(message.header().expires_time().as_secs() as i64)
                    .bind(message.header().stream_number().as_u32() as i64)
                    .bind(addr.to_string())
                    .execute(pool.write())
                    .await
                    {
                        error!("{}", err);
                    }
                }
            } else if message.header().expires_time().as_secs() <= i64::MAX as u64 {
                if let Err(err) = sqlx::query(
                    "INSERT INTO own_nodes (
                            stream, address, expires
                        ) VALUES (?1, ?2, ?3)",
                )
                .bind(message.header().stream_number().as_u32() as i64)
                .bind(addr.to_string())
                .bind(message.header().expires_time().as_secs() as i64)
                .execute(pool.write())
                .await
                {
                    error!("{}", err);
                }
            }
        }

        let mut own_nodes: HashMap<SocketAddrExt, OwnNodeInfo> = HashMap::new();
        if let Ok(list) = sqlx::query_as::<sqlx::Sqlite, OwnNode>(
            "SELECT stream, address, expires FROM own_nodes",
        )
        .fetch_all(pool.read())
        .await
        {
            for record in list {
                if record.stream < 0 {
                    continue;
                }
                let stream: StreamNumber = (record.stream as u32).into();

                let addr = record.address.parse::<SocketAddrExt>();
                if addr.is_err() {
                    continue;
                }
                let addr = addr.unwrap();

                if record.expires < 0 {
                    continue;
                }
                let expires: Time = (record.expires as u64).into();

                if let Some(info) = own_nodes.get(&addr) {
                    if info.expires < expires {
                        own_nodes.insert(addr, OwnNodeInfo { stream, expires });
                    }
                } else {
                    own_nodes.insert(addr, OwnNodeInfo { stream, expires });
                }
            }
        }

        let mut objects = Self {
            ctx,
            pool,
            map,
            object_sender,
            own_nodes,
        };

        objects.retain().await;

        objects
    }

    fn is_empty(&self) -> bool {
        self.map.is_empty()
    }

    fn len(&self) -> usize {
        self.map.len()
    }

    fn contains_key(&self, k: &InvHash) -> bool {
        self.map.contains_key(k)
    }

    fn keys(&self) -> Keys<InvHash, Header> {
        self.map.keys()
    }

    async fn get(&self, k: &InvHash) -> Option<message::Object> {
        if !self.map.contains_key(k) {
            return None;
        }
        let list =
            sqlx::query_as::<sqlx::Sqlite, (Vec<u8>,)>("SELECT message FROM objects WHERE hash=?1")
                .bind(k.as_ref())
                .fetch_all(self.pool.read())
                .await;
        if let Err(err) = list {
            error!("{}", err);
            return None;
        }
        let list = list.unwrap();
        if list.is_empty() {
            return None;
        }
        message::Object::sized_read_from_exact(&list[0].0).ok()
    }

    async fn insert(&mut self, v: message::Object) -> Option<()> {
        let hash = v.inv_hash();
        if self.map.insert(hash.clone(), v.header().clone()).is_none() {
            let mut bytes = Vec::new();
            v.write_to(&mut bytes).unwrap();
            let mut header_bytes = Vec::new();
            v.header().write_to(&mut header_bytes).unwrap();
            if let Err(err) = sqlx::query(
                "INSERT INTO objects (
                        hash, message, header
                    ) VALUES (?1, ?2, ?3)",
            )
            .bind(hash.as_ref())
            .bind(bytes)
            .bind(header_bytes)
            .execute(self.pool.write())
            .await
            {
                error!("{}", err);
            }

            if let Ok(kind) = ObjectKind::try_from(v.header().object_type()) {
                if kind == ObjectKind::Onionpeer {
                    let onionpeer = {
                        let mut r = Cursor::new(v.object_payload());
                        object::Onionpeer::sized_read_from(&mut r, v.object_payload().len())
                    };
                    if let Ok(onionpeer) = onionpeer {
                        if let Ok(addr) = SocketAddrExt::try_from(onionpeer) {
                            if self.ctx.config().own_nodes().contains(&addr) {
                                match self.own_nodes.get_mut(&addr) {
                                    Some(info) => {
                                        if v.header().expires_time() > info.expires {
                                            info.expires = v.header().expires_time();
                                            if info.expires.as_secs() <= i64::MAX as u64 {
                                                if let Err(err) =
                                                    sqlx::query("UPDATE own_nodes SET expires=?1 WHERE stream=?2 and address=?3")
                                                        .bind(info.expires.as_secs() as i64)
                                                        .bind(info.stream.as_u32() as i64)
                                                        .bind(addr.to_string())
                                                        .execute(self.pool.write())
                                                        .await
                                                {
                                                    error!("{}", err);
                                                }
                                            }
                                        }
                                    }
                                    None => {
                                        let info = OwnNodeInfo {
                                            stream: v.header().stream_number(),
                                            expires: v.header().expires_time(),
                                        };
                                        self.own_nodes.insert(addr.clone(), info.clone());
                                        if info.expires.as_secs() <= i64::MAX as u64 {
                                            if let Err(err) = sqlx::query(
                                                "INSERT INTO own_node (
                                                        stream, address, expires
                                                    ) VALUES (?1, ?2, ?3)",
                                            )
                                            .bind(info.stream.as_u32() as i64)
                                            .bind(addr.to_string())
                                            .bind(info.expires.as_secs() as i64)
                                            .execute(self.pool.write())
                                            .await
                                            {
                                                error!("{}", err);
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }

            if let Err(err) = self.object_sender.send(ObjectEvent::Process(v)).await {
                error!("{}", err);
            }

            return None;
        }
        Some(())
    }

    async fn retain(&mut self) {
        let now = Time::now();
        let mut to_del = Vec::new();
        for (hash, header) in &self.map {
            let expires = header.expires_time();
            match expires.checked_add(OBJECT_PAST_LIFETIME) {
                Some(target) => {
                    if now > target {
                        to_del.push(hash.clone());
                        continue;
                    }
                }
                None => {
                    to_del.push(hash.clone());
                    continue;
                }
            }
            match now.checked_add(OBJECT_FUTURE_LIFETIME) {
                Some(target) => {
                    if expires > target {
                        to_del.push(hash.clone());
                        continue;
                    }
                }
                None => {
                    to_del.push(hash.clone());
                    continue;
                }
            }
        }
        for hash in &to_del {
            self.map.remove(&hash);

            if let Err(err) = sqlx::query("DELETE FROM objects WHERE hash=?1")
                .bind(hash.as_ref())
                .execute(self.pool.write())
                .await
            {
                error!("{}", err);
            }
        }
    }

    async fn check_own_nodes_expiration(&mut self, pow_sender: &mut Sender<PowEvent>) {
        let now = Time::now();
        for addr in self.ctx.config().own_nodes() {
            if let SocketAddrExt::OnionV3(_) = addr {
                match self.own_nodes.get_mut(&addr) {
                    Some(info) => {
                        if let Some(target) = now.checked_add(Duration::new(600)) {
                            if target >= info.expires {
                                if let Some(expires) =
                                    now.checked_add(Duration::new(7 * 24 * 60 * 60))
                                {
                                    advertise(pow_sender, info.stream, addr.clone(), expires).await;
                                    info.expires = expires;

                                    if info.expires.as_secs() <= i64::MAX as u64 {
                                        if let Err(err) =
                                            sqlx::query("UPDATE own_nodes SET expires=?1 WHERE stream=?2 and address=?3")
                                                .bind(info.expires.as_secs() as i64)
                                                .bind(info.stream.as_u32() as i64)
                                                .bind(addr.to_string())
                                                .execute(self.pool.write())
                                                .await
                                        {
                                            error!("{}", err);
                                        }
                                    }
                                }
                            }
                        }
                    }
                    None => {
                        if let Some(expires) = now.checked_add(Duration::new(7 * 24 * 60 * 60)) {
                            let stream = 1.into();
                            advertise(pow_sender, stream, addr.clone(), expires).await;
                            let info = OwnNodeInfo { stream, expires };
                            self.own_nodes.insert(addr.clone(), info.clone());

                            if info.expires.as_secs() <= i64::MAX as u64 {
                                if let Err(err) = sqlx::query(
                                    "INSERT INTO own_nodes (
                                            stream, address, expires
                                        ) VALUES (?1, ?2, ?3)",
                                )
                                .bind(info.stream.as_u32() as i64)
                                .bind(addr.to_string())
                                .bind(info.expires.as_secs() as i64)
                                .execute(self.pool.write())
                                .await
                                {
                                    error!("{}", err);
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}

/// Asks the PoW manager to build an onionpeer object that advertises `addr`.
///
/// The object is placed on `stream` and expires at `expires`. Onion v2
/// addresses use object version 2; all other address kinds use version 3.
/// A failure to hand the request to the PoW manager is logged.
async fn advertise(
    pow_sender: &mut Sender<PowEvent>,
    stream: StreamNumber,
    addr: SocketAddrExt,
    expires: Time,
) {
    // Pick the version before `addr` is consumed by the payload conversion.
    let object_version: u64 = match &addr {
        SocketAddrExt::OnionV2(_) => 2,
        SocketAddrExt::Ipv4(_) | SocketAddrExt::Ipv6(_) | SocketAddrExt::OnionV3(_) => 3,
    };

    let payload = {
        let peer: object::Onionpeer = addr.into();
        let mut buf = Vec::new();
        peer.write_to(&mut buf).unwrap();
        buf
    };

    let header = object::Header::new(
        expires,
        object::ObjectKind::Onionpeer.into(),
        object::ObjectVersion::new(object_version),
        stream,
    );

    let request = PowEvent::Perform { header, payload };
    if let Err(err) = pow_sender.send(request).await {
        error!("{}", err);
    }
}

pub async fn manage(ctx: Arc<Context>, mut receiver: Receiver<Event>) {
    let mut broker_sender = ctx.broker_sender().clone();
    let mut bm_event_sender = ctx.bm_event_sender().clone();
    let mut pow_sender = ctx.pow_sender().clone();

    let (object_sender, object_receiver) = mpsc::channel(ctx.config().channel_buffer());
    task::spawn(process_objects(Arc::clone(&ctx), object_receiver));

    let mut clean_missing_objects_interval = interval(StdDuration::from_secs(60));
    let mut clean_objects_interval = interval(StdDuration::from_secs(7380));
    let mut objects = Objects::new(Arc::clone(&ctx), ctx.pool().clone(), object_sender).await;
    let mut missing_objects: HashMap<InvHash, Time> = HashMap::new();
    let mut uploaded: usize = 0;

    let mut own_nodes_interval = interval(StdDuration::from_secs(60));

    if let Err(err) = bm_event_sender
        .send(BmEvent::Objects {
            missing: missing_objects.len(),
            loaded: objects.len(),
            uploaded,
        })
        .await
    {
        error!("{}", err);
    }

    loop {
        if ctx.aborted().load(Ordering::SeqCst) {
            break;
        }
        select! {
            tick = clean_missing_objects_interval.next().fuse() => match tick {
                Some(_) => {
                    let now = Time::now();
                    missing_objects.retain(|_hash, time| {
                        match time.checked_add(REQUEST_EXPIRES) {
                            Some(target) => now < target,
                            None => false,
                        }
                    });

                    if let Err(err) = bm_event_sender
                        .send(BmEvent::Objects {
                            missing: missing_objects.len(),
                            loaded: objects.len(),
                            uploaded,
                        })
                        .await
                    {
                        error!("{}", err);
                    }
                },
                None => break,
            },
            tick = clean_objects_interval.next().fuse() => match tick {
                Some(_) => {
                    objects.retain().await;

                    if let Err(err) = bm_event_sender
                        .send(BmEvent::Objects {
                            missing: missing_objects.len(),
                            loaded: objects.len(),
                            uploaded,
                        })
                        .await
                    {
                        error!("{}", err);
                    }
                },
                None => break,
            },
            event = receiver.next().fuse() => match event {
                Some(event) => match event {
                    Event::Inv { addr, list } => {
                        let now = Time::now();
                        let mut new_list = Vec::new();
                        for hash in list {
                            if !objects.contains_key(&hash) {
                                missing_objects.insert(hash.clone(), now);
                                new_list.push(hash);
                            }
                        }
                        if let Err(err) = broker_sender
                            .send(BrokerEvent::ObjectsNewToMe { addr, list: new_list })
                            .await
                        {
                            error!("{}", err);
                        }

                        if let Err(err) = bm_event_sender
                            .send(BmEvent::Objects {
                                missing: missing_objects.len(),
                                loaded: objects.len(),
                                uploaded,
                            })
                            .await
                        {
                            error!("{}", err);
                        }
                    }
                    Event::Object { addr, object } => {
                        let hash = object.inv_hash();
                        if missing_objects.contains_key(&hash) {
                            objects.insert(object).await;
                            missing_objects.remove(&hash);
                            if let Err(err) = broker_sender
                                .send(BrokerEvent::ObjectsNewToHer { addr, list: vec![hash] })
                                .await
                            {
                                error!("{}", err);
                            }
                            if let Err(err) = bm_event_sender
                                .send(BmEvent::Objects {
                                    missing: missing_objects.len(),
                                    loaded: objects.len(),
                                    uploaded,
                                })
                                .await
                            {
                                error!("{}", err);
                            }
                        }
                    }
                    Event::Drop(hash) => {
                        missing_objects.remove(&hash);
                        if let Err(err) = bm_event_sender
                            .send(BmEvent::Objects {
                                missing: missing_objects.len(),
                                loaded: objects.len(),
                                uploaded,
                            })
                            .await
                        {
                            error!("{}", err);
                        }
                    }
                    Event::SendBigInv { addr } => {
                        if !objects.is_empty() {
                            let mut vec: Vec<InvHash> = objects.keys().cloned().collect();
                            vec.shuffle(&mut rand::thread_rng());
                            for chunk in vec.chunks(message::Inv::MAX_COUNT_FIXED) {
                                let inv = message::Inv::new(chunk.to_vec()).unwrap();
                                let packet = inv.pack(ctx.config().core()).unwrap();
                                if let Err(err) = broker_sender
                                    .send(BrokerEvent::Write {
                                        addr: addr.clone(),
                                        list: vec![packet],
                                    })
                                    .await
                                {
                                    error!("{}", err);
                                }
                            }
                        }
                    }
                    Event::SendObjects { addr, list } => {
                        let mut packets = Vec::new();
                        for hash in list {
                            if let Some(object) = objects.get(&hash).await {
                                let packet = object.pack(ctx.config().core()).unwrap();
                                packets.push(packet);
                            }
                        }
                        let count = packets.len();
                        if let Err(err) = broker_sender
                            .send(BrokerEvent::Write {
                                addr: addr.clone(),
                                list: packets,
                            })
                            .await
                        {
                            error!("{}", err);
                        } else {
                            uploaded += count;
                            if let Err(err) = bm_event_sender
                                .send(BmEvent::Objects {
                                    missing: missing_objects.len(),
                                    loaded: objects.len(),
                                    uploaded,
                                })
                                .await
                            {
                                error!("{}", err);
                            }
                        }
                    }
                    Event::Insert(object) => {
                        let hash = object.inv_hash();
                        objects.insert(object).await;
                        if let Err(err) = broker_sender
                            .send(BrokerEvent::PublishObjects(vec![hash]))
                            .await
                        {
                            error!("{}", err);
                        }
                        if let Err(err) = bm_event_sender
                            .send(BmEvent::Objects {
                                missing: missing_objects.len(),
                                loaded: objects.len(),
                                uploaded,
                            })
                            .await
                        {
                            error!("{}", err);
                        }
                    }
                },
                None => break,
            },
            tick = own_nodes_interval.next().fuse() => match tick {
                Some(_) => {
                    objects.check_own_nodes_expiration(&mut pow_sender).await;
                },
                None => break,
            },
        };
    }
    if let Err(err) = bm_event_sender
        .send(BmEvent::Objects {
            missing: 0,
            loaded: 0,
            uploaded: 0,
        })
        .await
    {
        error!("{}", err);
    }
}