koibumi_node_sync/
node_manager.rs

1use std::{
2    collections::{HashMap, HashSet},
3    convert::TryInto,
4    fmt,
5    iter::FromIterator,
6    sync::{atomic::Ordering, Arc},
7    time::Duration as StdDuration,
8};
9
10use crossbeam_channel::{select, Receiver};
11use log::{debug, error};
12use rand::seq::SliceRandom;
13use rand_distr::{Binomial, Distribution};
14
15use koibumi_core::{
16    message::{self, NetAddr, Pack, Services, StreamNumber, UserAgent},
17    net::SocketAddrExt,
18    time::Time,
19};
20
21use crate::{
22    connection::Direction,
23    connection_loop::{Context, Event as BrokerEvent, ShutdownCommand},
24    manager::Event as BmEvent,
25    net::SocketAddrNode,
26};
27
28#[derive(Clone, PartialEq, Eq, Hash, Debug)]
29pub struct Entry {
30    stream: StreamNumber,
31    addr: SocketAddrNode,
32    last_seen: Time,
33}
34
35impl Entry {
36    pub fn new(stream: StreamNumber, addr: SocketAddrNode, last_seen: Time) -> Self {
37        Self {
38            stream,
39            addr,
40            last_seen,
41        }
42    }
43}
44
45#[derive(Debug)]
46pub enum Event {
47    Add(Vec<Entry>),
48    ConnectionSucceeded(SocketAddrNode, UserAgent),
49    ConnectionFailed(SocketAddrNode),
50    Disconnected(SocketAddrNode),
51    Send(SocketAddrNode, bool),
52}
53
/// A rating of the connectivity of a node.
///
/// [`Rating::increment`] and [`Rating::decrement`] keep the value within
/// `-10..=10`. [`Rating::new`] stores whatever value it is given, so a rating
/// may start outside that range; the first increment or decrement pulls it
/// back to the nearest bound.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct Rating(i8);

impl Rating {
    /// Upper bound enforced by `increment`.
    const MAX: i8 = 10;
    /// Lower bound enforced by `decrement`.
    const MIN: i8 = -10;

    /// Constructs a rating from a value.
    ///
    /// The value is stored as given; it is not clamped.
    pub fn new(value: i8) -> Self {
        Self(value)
    }

    /// Returns the value as `i8`.
    pub fn as_i8(&self) -> i8 {
        self.0
    }

    /// Increments the rating.
    /// The maximum is `10`.
    pub fn increment(&mut self) {
        // Saturating so that a rating constructed at `i8::MAX` cannot overflow.
        self.0 = self.0.saturating_add(1).min(Self::MAX);
    }

    /// Decrements the rating.
    /// The minimum is `-10`.
    pub fn decrement(&mut self) {
        // Clamp against `MIN` itself. Negating it would yield `+10` and turn
        // every decrement into a jump to the top of the scale.
        self.0 = self.0.saturating_sub(1).max(Self::MIN);
    }
}
84
85impl fmt::Display for Rating {
86    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87        self.0.fmt(f)
88    }
89}
90
91impl From<i8> for Rating {
92    fn from(value: i8) -> Self {
93        Self(value)
94    }
95}
96
97struct Info {
98    #[allow(dead_code)]
99    stream: StreamNumber,
100    last_seen: Time,
101    rating: Rating,
102}
103
/// One raw row of the `nodes` table, before its columns have been validated
/// and converted into the domain types.
struct Record {
    /// `stream` column.
    stream: i32,
    /// `address` column, still in textual form.
    address: String,
    /// `last_seen` column.
    last_seen: i64,
    /// `rating` column.
    rating: i32,
}
110
111struct Nodes {
112    ctx: Arc<Context>,
113    conn: rusqlite::Connection,
114    map: HashMap<SocketAddrNode, Info>,
115    used_addrs: HashSet<SocketAddrNode>,
116}
117
118impl Nodes {
119    fn new(ctx: Arc<Context>, conn: rusqlite::Connection) -> Self {
120        if let Err(err) = conn.execute(
121            "CREATE TABLE IF NOT EXISTS nodes (
122                stream INTEGER NOT NULL,
123                address TEXT NOT NULL,
124                last_seen INTEGER NOT NULL,
125                rating INTEGER NOT NULL,
126                PRIMARY KEY(stream, address)
127            )",
128            rusqlite::params![],
129        ) {
130            error!("{}", err);
131        }
132
133        let mut map = HashMap::new();
134        if let Ok(mut stmt) = conn.prepare("SELECT stream, address, last_seen, rating FROM nodes") {
135            if let Ok(list) = stmt.query_map(rusqlite::params![], |row| {
136                Ok(Record {
137                    stream: row.get::<usize, i32>(0)?,
138                    address: row.get::<usize, String>(1)?,
139                    last_seen: row.get::<usize, i64>(2)?,
140                    rating: row.get::<usize, i32>(3)?,
141                })
142            }) {
143                for record in list {
144                    if let Err(err) = record {
145                        error!("{}", err);
146                        continue;
147                    }
148                    let record = record.unwrap();
149                    if record.stream < 0 {
150                        continue;
151                    }
152                    let stream: StreamNumber = (record.stream as u32).into();
153                    let addr = record.address.parse::<SocketAddrExt>();
154                    if addr.is_err() {
155                        continue;
156                    }
157                    let addr = addr.unwrap();
158                    let addr: SocketAddrNode = addr.into();
159                    if record.last_seen < 0 {
160                        continue;
161                    }
162                    let last_seen: Time = (record.last_seen as u64).into();
163                    if record.rating < -128 || record.rating > 127 {
164                        continue;
165                    }
166                    let rating: Rating = (record.rating as i8).into();
167                    map.insert(
168                        addr,
169                        Info {
170                            stream,
171                            last_seen,
172                            rating,
173                        },
174                    );
175                }
176            }
177        }
178
179        let used_addrs = HashSet::new();
180        let mut nodes = Self {
181            ctx,
182            conn,
183            map,
184            used_addrs,
185        };
186
187        nodes.retain();
188
189        nodes
190    }
191
192    fn len(&self) -> usize {
193        self.map.len()
194    }
195
196    fn retain(&mut self) -> Option<()> {
197        if self.map.len() > self.ctx.config().max_nodes() {
198            let keys: HashSet<SocketAddrNode> = self.map.keys().cloned().collect();
199            let mut list: Vec<SocketAddrNode> =
200                keys.difference(&self.used_addrs).cloned().collect();
201            list.sort_unstable_by(|a, b| {
202                let a_info = &self.map[a];
203                let b_info = &self.map[b];
204                if a_info.rating == b_info.rating {
205                    a_info.last_seen.cmp(&b_info.last_seen)
206                } else {
207                    a_info.rating.cmp(&b_info.rating)
208                }
209            });
210            let mut trunc_amount = self.ctx.config().max_nodes() / 10;
211            if trunc_amount == 0 {
212                trunc_amount = usize::min(1, list.len());
213            }
214            list.truncate(trunc_amount);
215            for addr in &list {
216                self.map.remove(addr);
217                if let SocketAddrNode::AddrExt(addr) = addr {
218                    if let Err(err) = self.conn.execute(
219                        "DELETE FROM nodes WHERE address=?1",
220                        rusqlite::params![addr.to_string()],
221                    ) {
222                        error!("{}", err);
223                    }
224                }
225            }
226            return Some(());
227        }
228        None
229    }
230
231    fn insert(&mut self, entry: Entry, own_node: bool) -> Option<()> {
232        match self.map.get_mut(&entry.addr) {
233            Some(info) => {
234                if entry.last_seen > info.last_seen {
235                    info.last_seen = entry.last_seen;
236                }
237                if entry.stream.as_u32() <= i32::MAX as u32
238                    && entry.last_seen.as_secs() <= i64::MAX as u64
239                {
240                    if let SocketAddrNode::AddrExt(addr) = entry.addr {
241                        if let Err(err) = self.conn.execute(
242                            "UPDATE nodes SET last_seen=?1 WHERE stream=?2 and address=?3",
243                            rusqlite::params![
244                                entry.last_seen.as_secs() as i64,
245                                entry.stream.as_u32() as i32,
246                                addr.to_string()
247                            ],
248                        ) {
249                            error!("{}", err);
250                        }
251                    }
252                }
253                None
254            }
255            None => {
256                if self.ctx.config().is_connectable_to(&entry.addr)
257                    && self.ctx.config().stream_numbers().contains(entry.stream)
258                {
259                    debug!("addr: {}", entry.addr);
260                    let rating: Rating = if own_node {
261                        Rating::MAX.into()
262                    } else {
263                        0.into()
264                    };
265                    self.map.insert(
266                        entry.addr.clone(),
267                        Info {
268                            stream: entry.stream,
269                            last_seen: entry.last_seen,
270                            rating: rating.clone(),
271                        },
272                    );
273                    if entry.stream.as_u32() <= i32::MAX as u32
274                        && entry.last_seen.as_secs() <= i64::MAX as u64
275                    {
276                        if let SocketAddrNode::AddrExt(addr) = entry.addr {
277                            if let Err(err) = self.conn.execute(
278                                "INSERT INTO nodes (
279                                        stream, address, last_seen, rating
280                                    ) VALUES (?1, ?2, ?3, ?4)",
281                                rusqlite::params![
282                                    entry.stream.as_u32() as i32,
283                                    addr.to_string(),
284                                    entry.last_seen.as_secs() as i64,
285                                    rating.as_i8() as i32
286                                ],
287                            ) {
288                                error!("{}", err);
289                            }
290                        }
291                    }
292                    return Some(());
293                }
294                None
295            }
296        }
297    }
298
299    fn increment(&mut self, addr: &SocketAddrNode) -> Option<Rating> {
300        let now = Time::now();
301        if let Some(info) = self.map.get_mut(addr) {
302            info.last_seen = now;
303            info.rating.increment();
304
305            if info.stream.as_u32() <= i32::MAX as u32
306                && info.last_seen.as_secs() <= i64::MAX as u64
307            {
308                if let SocketAddrNode::AddrExt(addr) = addr {
309                    if let Err(err) = self.conn.execute(
310                        "UPDATE nodes SET rating=?1 WHERE stream=?2 and address=?3",
311                        rusqlite::params![
312                            info.rating.as_i8() as i64,
313                            info.stream.as_u32() as i32,
314                            addr.to_string()
315                        ],
316                    ) {
317                        error!("{}", err);
318                    }
319                }
320            }
321            return Some(info.rating.clone());
322        }
323        None
324    }
325
326    fn decrement(&mut self, addr: &SocketAddrNode) -> Option<Rating> {
327        if let Some(info) = self.map.get_mut(addr) {
328            info.rating.decrement();
329
330            if info.stream.as_u32() <= i32::MAX as u32
331                && info.last_seen.as_secs() <= i64::MAX as u64
332            {
333                if let SocketAddrNode::AddrExt(addr) = addr {
334                    if let Err(err) = self.conn.execute(
335                        "UPDATE nodes SET rating=?1 WHERE stream=?2 and address=?3",
336                        rusqlite::params![
337                            info.rating.as_i8() as i64,
338                            info.stream.as_u32() as i32,
339                            addr.to_string()
340                        ],
341                    ) {
342                        error!("{}", err);
343                    }
344                }
345            }
346            return Some(info.rating.clone());
347        }
348        None
349    }
350
351    fn reclaim(&mut self, addr: &SocketAddrNode) {
352        self.used_addrs.remove(addr);
353    }
354
355    fn sample_list(&self) -> Vec<NetAddr> {
356        let mut list: Vec<SocketAddrNode> = self.map.keys().cloned().collect();
357        list.sort_unstable_by(|a, b| {
358            let a_info = &self.map[a];
359            let b_info = &self.map[b];
360            if a_info.rating == b_info.rating {
361                b_info.last_seen.cmp(&a_info.last_seen)
362            } else {
363                b_info.rating.cmp(&a_info.rating)
364            }
365        });
366        list.truncate(1000);
367        list.shuffle(&mut rand::thread_rng());
368        let mut addr_list = Vec::with_capacity(list.len());
369        for addr in list {
370            let info = &self.map[&addr];
371            let addr = addr.try_into();
372            if let Err(err) = addr {
373                error!("{}", err);
374                continue;
375            }
376            let addr: SocketAddrExt = addr.unwrap();
377            if let Ok(addr) = addr.try_into() {
378                addr_list.push(NetAddr::new(
379                    info.last_seen,
380                    info.stream,
381                    Services::NETWORK,
382                    addr,
383                ));
384            }
385        }
386        addr_list
387    }
388
389    fn sample(&mut self, own_nodes: &[SocketAddrExt]) -> Option<SocketAddrNode> {
390        let keys: HashSet<SocketAddrNode> = self.map.keys().cloned().collect();
391        let list: HashSet<SocketAddrNode> = keys.difference(&self.used_addrs).cloned().collect();
392        let own_nodes: HashSet<SocketAddrNode> =
393            HashSet::from_iter(own_nodes.iter().cloned().map(|a| a.into()));
394        let mut list: Vec<SocketAddrNode> = list.difference(&own_nodes).cloned().collect();
395        list.sort_unstable_by(|a, b| {
396            let a_info = &self.map[a];
397            let b_info = &self.map[b];
398            if a_info.rating == b_info.rating {
399                b_info.last_seen.cmp(&a_info.last_seen)
400            } else {
401                b_info.rating.cmp(&a_info.rating)
402            }
403        });
404        if !list.is_empty() {
405            let bin = Binomial::new(list.len() as u64 * 2 - 1, 0.5).unwrap();
406            let v = bin.sample(&mut rand::thread_rng());
407            let i = if (v as usize) < list.len() {
408                list.len() - 1 - v as usize
409            } else {
410                v as usize - list.len()
411            };
412            let sa = &list[i];
413            self.used_addrs.insert(sa.clone());
414            return Some(sa.clone());
415        }
416        None
417    }
418}
419
420pub fn manage(
421    ctx: Arc<Context>,
422    receiver: Receiver<Event>,
423    shutdown_receiver: Receiver<ShutdownCommand>,
424) {
425    let broker_sender = ctx.broker_sender().clone();
426    let bm_event_sender = ctx.bm_event_sender().clone();
427
428    let conn = rusqlite::Connection::open(ctx.db_path());
429    if let Err(err) = conn {
430        error!("{}", err);
431        return;
432    }
433    let conn = conn.unwrap();
434
435    let mut nodes = Nodes::new(Arc::clone(&ctx), conn);
436
437    if let Err(err) = bm_event_sender.send(BmEvent::AddrCount(nodes.len())) {
438        error!("{}", err);
439    }
440
441    let interval = crossbeam_channel::tick(StdDuration::from_secs(4));
442
443    loop {
444        if ctx.aborted().load(Ordering::SeqCst) {
445            break;
446        }
447        select! {
448            recv(receiver) -> v => match v {
449                Ok(event) => match event {
450                    Event::Add(entries) => {
451                        for entry in entries {
452                            if nodes.retain().is_some() {
453                                if let Err(err) = bm_event_sender.send(BmEvent::AddrCount(nodes.len())) {
454                                    error!("{}", err);
455                                }
456                            }
457
458                            let own_node = if let SocketAddrNode::AddrExt(addr) = &entry.addr {
459                                ctx.config().own_nodes().contains(addr)
460                            } else {
461                                false
462                            };
463                            if nodes.insert(entry, own_node).is_some() {
464                                if let Err(err) = bm_event_sender.send(BmEvent::AddrCount(nodes.len())) {
465                                    error!("{}", err);
466                                }
467                            }
468                        }
469                    }
470                    Event::ConnectionSucceeded(addr, user_agent) => {
471                        if let Some(rating) = nodes.increment(&addr) {
472                            if let Err(err) = bm_event_sender
473                                .send(BmEvent::Established {
474                                    addr: addr.clone(),
475                                    user_agent,
476                                    rating,
477                                })
478                            {
479                                error!("{}", err)
480                            }
481                        }
482                    }
483                    Event::ConnectionFailed(addr) => {
484                        let own_node = if let SocketAddrNode::AddrExt(addr) = &addr {
485                            ctx.config().own_nodes().contains(addr)
486                        } else {
487                            false
488                        };
489                        if !own_node {
490                            nodes.decrement(&addr);
491                        }
492                    }
493                    Event::Disconnected(addr) => {
494                        nodes.reclaim(&addr);
495                    }
496                    Event::Send(addr, close) => {
497                        let addr_list = nodes.sample_list();
498                        if !addr_list.is_empty() {
499                            let message = message::Addr::new(addr_list).unwrap();
500                            let packet = message.pack(ctx.config().core()).unwrap();
501                            if let Err(err) = broker_sender
502                                .send(BrokerEvent::Write { addr: addr.clone(), list: vec![packet] })
503                            {
504                                error!("{}", err);
505                            }
506                        }
507                        if close {
508                            let error = message::Error::new(2.into(),
509                                "Server full, please try again later.".as_bytes().to_vec().into());
510                            let packet = error.pack(ctx.config().core()).unwrap();
511                            if let Err(err) = broker_sender
512                                .send(BrokerEvent::Write { addr: addr.clone(), list: vec![packet] })
513                            {
514                                error!("{}", err);
515                            }
516                            if let Err(err) = broker_sender
517                                .send(BrokerEvent::Close { addr })
518                            {
519                                error!("{}", err);
520                            }
521                        }
522                    }
523                },
524                Err(_err) => break,
525            },
526            recv(shutdown_receiver) -> _v => break,
527            recv(interval) -> v => match v {
528                Ok(_) => {
529                    let initiated = ctx.initiated(Direction::Outgoing).load(Ordering::SeqCst);
530                    let connected = ctx.connected(Direction::Outgoing).load(Ordering::SeqCst);
531                    let established = ctx.established(Direction::Outgoing).load(Ordering::SeqCst);
532                    if established >= ctx.config().max_outgoing_established() {
533                        if initiated > connected {
534                            if let Err(err) = broker_sender
535                                .send(BrokerEvent::AbortPendings)
536                            {
537                                error!("{}", err);
538                            }
539                        }
540                    } else if initiated < ctx.config().max_outgoing_initiated()
541                            && initiated + ctx.config().own_nodes().len() < nodes.len() {
542                        if let Some(addr) = nodes.sample(ctx.config().own_nodes()) {
543                            if let Err(err) = broker_sender
544                                .send(BrokerEvent::Outgoing { addr })
545                            {
546                                error!("{}", err);
547                            }
548                        }
549                    }
550                },
551                Err(_err) => break,
552            },
553        };
554    }
555    if let Err(err) = bm_event_sender.send(BmEvent::AddrCount(0)) {
556        error!("{}", err);
557    }
558}