// koibumi_node/node_manager.rs

1use std::{
2    collections::{HashMap, HashSet},
3    convert::TryInto,
4    fmt,
5    iter::FromIterator,
6    sync::atomic::Ordering,
7    time::Duration as StdDuration,
8};
9
10use async_std::{stream::interval, sync::Arc};
11use futures::{channel::mpsc::Receiver, select, sink::SinkExt, stream::StreamExt, FutureExt};
12use log::{debug, error};
13use rand::seq::SliceRandom;
14use rand_distr::{Binomial, Distribution};
15
16use koibumi_core::{
17    message::{self, NetAddr, Pack, Services, StreamNumber, UserAgent},
18    net::SocketAddrExt,
19    time::Time,
20};
21
22use crate::{
23    connection::Direction,
24    connection_loop::{Context, Event as BrokerEvent},
25    db,
26    manager::Event as BmEvent,
27    net::SocketAddrNode,
28};
29
30#[derive(Clone, PartialEq, Eq, Hash, Debug)]
31pub struct Entry {
32    stream: StreamNumber,
33    addr: SocketAddrNode,
34    last_seen: Time,
35}
36
37impl Entry {
38    pub fn new(stream: StreamNumber, addr: SocketAddrNode, last_seen: Time) -> Self {
39        Self {
40            stream,
41            addr,
42            last_seen,
43        }
44    }
45}
46
47#[derive(Debug)]
48pub enum Event {
49    Add(Vec<Entry>),
50    ConnectionSucceeded(SocketAddrNode, UserAgent),
51    ConnectionFailed(SocketAddrNode),
52    Disconnected(SocketAddrNode),
53    Send(SocketAddrNode, bool),
54}
55
/// A rating of the connectivity of a node.
///
/// [`Rating::increment`] and [`Rating::decrement`] move the value one step and
/// clamp it to `-10..=10`. Values passed to [`Rating::new`] or [`From<i8>`] are
/// stored unchanged and may lie outside that range.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct Rating(i8);

impl Rating {
    /// Upper bound enforced by [`Rating::increment`].
    const MAX: i8 = 10;
    /// Lower bound enforced by [`Rating::decrement`].
    const MIN: i8 = -10;

    /// Constructs a rating from a value (not clamped).
    pub fn new(value: i8) -> Self {
        Self(value)
    }

    /// Returns the value as `i8`.
    pub fn as_i8(&self) -> i8 {
        self.0
    }

    /// Increments the rating.
    /// The maximum is `10`.
    pub fn increment(&mut self) {
        // Saturating add: a stored `i8::MAX` must not overflow (panics in debug builds).
        self.0 = self.0.saturating_add(1).min(Self::MAX);
    }

    /// Decrements the rating.
    /// The minimum is `-10`.
    pub fn decrement(&mut self) {
        // Clamp to MIN (not -MIN, which is +10); saturate so `i8::MIN` cannot overflow.
        self.0 = self.0.saturating_sub(1).max(Self::MIN);
    }
}

impl fmt::Display for Rating {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}

impl From<i8> for Rating {
    fn from(value: i8) -> Self {
        Self(value)
    }
}
99struct Info {
100    #[allow(dead_code)]
101    stream: StreamNumber,
102    last_seen: Time,
103    rating: Rating,
104}
105
106#[derive(sqlx::FromRow)]
107struct Record {
108    stream: i32,
109    address: String,
110    last_seen: i64,
111    rating: i32,
112}
113
114struct Nodes {
115    ctx: Arc<Context>,
116    pool: db::SqlitePool,
117    map: HashMap<SocketAddrNode, Info>,
118    used_addrs: HashSet<SocketAddrNode>,
119}
120
121impl Nodes {
122    async fn new(ctx: Arc<Context>, pool: db::SqlitePool) -> Self {
123        if let Err(err) = sqlx::query(
124            "CREATE TABLE IF NOT EXISTS nodes (
125                stream INTEGER NOT NULL,
126                address TEXT NOT NULL,
127                last_seen INTEGER NOT NULL,
128                rating INTEGER NOT NULL,
129                PRIMARY KEY(stream, address)
130            )",
131        )
132        .execute(pool.write())
133        .await
134        {
135            error!("{}", err);
136        }
137
138        let mut map = HashMap::new();
139        if let Ok(list) = sqlx::query_as::<sqlx::Sqlite, Record>(
140            "SELECT stream, address, last_seen, rating FROM nodes",
141        )
142        .fetch_all(pool.read())
143        .await
144        {
145            for record in list {
146                if record.stream < 0 {
147                    continue;
148                }
149                let stream: StreamNumber = (record.stream as u32).into();
150
151                let addr = record.address.parse::<SocketAddrExt>();
152                if addr.is_err() {
153                    continue;
154                }
155                let addr = addr.unwrap();
156                let addr: SocketAddrNode = addr.into();
157
158                if record.last_seen < 0 {
159                    continue;
160                }
161                let last_seen: Time = (record.last_seen as u64).into();
162                if record.rating < -128 || record.rating > 127 {
163                    continue;
164                }
165                let rating: Rating = (record.rating as i8).into();
166
167                map.insert(
168                    addr,
169                    Info {
170                        stream,
171                        last_seen,
172                        rating,
173                    },
174                );
175            }
176        }
177
178        let used_addrs = HashSet::new();
179        let mut nodes = Self {
180            ctx,
181            pool,
182            map,
183            used_addrs,
184        };
185
186        nodes.retain().await;
187
188        nodes
189    }
190
191    fn len(&self) -> usize {
192        self.map.len()
193    }
194
195    async fn retain(&mut self) -> Option<()> {
196        if self.map.len() > self.ctx.config().max_nodes() {
197            let keys: HashSet<SocketAddrNode> = self.map.keys().cloned().collect();
198            let mut list: Vec<SocketAddrNode> =
199                keys.difference(&self.used_addrs).cloned().collect();
200            list.sort_unstable_by(|a, b| {
201                let a_info = &self.map[a];
202                let b_info = &self.map[b];
203                if a_info.rating == b_info.rating {
204                    a_info.last_seen.cmp(&b_info.last_seen)
205                } else {
206                    a_info.rating.cmp(&b_info.rating)
207                }
208            });
209            let mut trunc_amount = self.ctx.config().max_nodes() / 10;
210            if trunc_amount == 0 {
211                trunc_amount = usize::min(1, list.len());
212            }
213            list.truncate(trunc_amount);
214            for addr in &list {
215                self.map.remove(addr);
216                if let SocketAddrNode::AddrExt(addr) = addr {
217                    if let Err(err) = sqlx::query("DELETE FROM nodes WHERE address=?1")
218                        .bind(addr.to_string())
219                        .execute(self.pool.write())
220                        .await
221                    {
222                        error!("{}", err);
223                    }
224                }
225            }
226            return Some(());
227        }
228        None
229    }
230
231    async fn insert(&mut self, entry: Entry, own_node: bool) -> Option<()> {
232        match self.map.get_mut(&entry.addr) {
233            Some(info) => {
234                if entry.last_seen > info.last_seen {
235                    info.last_seen = entry.last_seen;
236                }
237                if entry.stream.as_u32() <= i32::MAX as u32
238                    && entry.last_seen.as_secs() <= i64::MAX as u64
239                {
240                    if let SocketAddrNode::AddrExt(addr) = entry.addr {
241                        if let Err(err) = sqlx::query(
242                            "UPDATE nodes SET last_seen=?1 WHERE stream=?2 and address=?3",
243                        )
244                        .bind(entry.last_seen.as_secs() as i64)
245                        .bind(entry.stream.as_u32() as i32)
246                        .bind(addr.to_string())
247                        .execute(self.pool.write())
248                        .await
249                        {
250                            error!("{}", err);
251                        }
252                    }
253                }
254                None
255            }
256            None => {
257                if self.ctx.config().is_connectable_to(&entry.addr)
258                    && self.ctx.config().stream_numbers().contains(entry.stream)
259                {
260                    debug!("addr: {}", entry.addr);
261                    let rating: Rating = if own_node {
262                        Rating::MAX.into()
263                    } else {
264                        0.into()
265                    };
266                    self.map.insert(
267                        entry.addr.clone(),
268                        Info {
269                            stream: entry.stream,
270                            last_seen: entry.last_seen,
271                            rating: rating.clone(),
272                        },
273                    );
274                    if entry.stream.as_u32() <= i32::MAX as u32
275                        && entry.last_seen.as_secs() <= i64::MAX as u64
276                    {
277                        if let SocketAddrNode::AddrExt(addr) = entry.addr {
278                            if let Err(err) = sqlx::query(
279                                "INSERT INTO nodes (
280                                        stream, address, last_seen, rating
281                                    ) VALUES (?1, ?2, ?3, ?4)",
282                            )
283                            .bind(entry.stream.as_u32() as i32)
284                            .bind(addr.to_string())
285                            .bind(entry.last_seen.as_secs() as i64)
286                            .bind(rating.as_i8() as i32)
287                            .execute(self.pool.write())
288                            .await
289                            {
290                                error!("{}", err);
291                            }
292                        }
293                    }
294                    return Some(());
295                }
296                None
297            }
298        }
299    }
300
301    async fn increment(&mut self, addr: &SocketAddrNode) -> Option<Rating> {
302        let now = Time::now();
303        if let Some(info) = self.map.get_mut(addr) {
304            info.last_seen = now;
305            info.rating.increment();
306
307            if info.stream.as_u32() <= i32::MAX as u32
308                && info.last_seen.as_secs() <= i64::MAX as u64
309            {
310                if let SocketAddrNode::AddrExt(addr) = addr {
311                    if let Err(err) =
312                        sqlx::query("UPDATE nodes SET rating=?1 WHERE stream=?2 and address=?3")
313                            .bind(info.rating.as_i8() as i64)
314                            .bind(info.stream.as_u32() as i32)
315                            .bind(addr.to_string())
316                            .execute(self.pool.write())
317                            .await
318                    {
319                        error!("{}", err);
320                    }
321                }
322            }
323            return Some(info.rating.clone());
324        }
325        None
326    }
327
328    async fn decrement(&mut self, addr: &SocketAddrNode) -> Option<Rating> {
329        if let Some(info) = self.map.get_mut(addr) {
330            info.rating.decrement();
331
332            if info.stream.as_u32() <= i32::MAX as u32
333                && info.last_seen.as_secs() <= i64::MAX as u64
334            {
335                if let SocketAddrNode::AddrExt(addr) = addr {
336                    if let Err(err) =
337                        sqlx::query("UPDATE nodes SET rating=?1 WHERE stream=?2 and address=?3")
338                            .bind(info.rating.as_i8() as i64)
339                            .bind(info.stream.as_u32() as i32)
340                            .bind(addr.to_string())
341                            .execute(self.pool.write())
342                            .await
343                    {
344                        error!("{}", err);
345                    }
346                }
347            }
348            return Some(info.rating.clone());
349        }
350        None
351    }
352
353    fn reclaim(&mut self, addr: &SocketAddrNode) {
354        self.used_addrs.remove(addr);
355    }
356
357    fn sample_list(&self) -> Vec<NetAddr> {
358        let mut list: Vec<SocketAddrNode> = self.map.keys().cloned().collect();
359        list.sort_unstable_by(|a, b| {
360            let a_info = &self.map[a];
361            let b_info = &self.map[b];
362            if a_info.rating == b_info.rating {
363                b_info.last_seen.cmp(&a_info.last_seen)
364            } else {
365                b_info.rating.cmp(&a_info.rating)
366            }
367        });
368        list.truncate(1000);
369        list.shuffle(&mut rand::thread_rng());
370        let mut addr_list = Vec::with_capacity(list.len());
371        for addr in list {
372            let info = &self.map[&addr];
373            let addr = addr.try_into();
374            if let Err(err) = addr {
375                error!("{}", err);
376                continue;
377            }
378            let addr: SocketAddrExt = addr.unwrap();
379            if let Ok(addr) = addr.try_into() {
380                addr_list.push(NetAddr::new(
381                    info.last_seen,
382                    info.stream,
383                    Services::NETWORK,
384                    addr,
385                ));
386            }
387        }
388        addr_list
389    }
390
391    fn sample(&mut self, own_nodes: &[SocketAddrExt]) -> Option<SocketAddrNode> {
392        let keys: HashSet<SocketAddrNode> = self.map.keys().cloned().collect();
393        let list: HashSet<SocketAddrNode> = keys.difference(&self.used_addrs).cloned().collect();
394        let own_nodes: HashSet<SocketAddrNode> =
395            HashSet::from_iter(own_nodes.iter().cloned().map(|a| a.into()));
396        let mut list: Vec<SocketAddrNode> = list.difference(&own_nodes).cloned().collect();
397        list.sort_unstable_by(|a, b| {
398            let a_info = &self.map[a];
399            let b_info = &self.map[b];
400            if a_info.rating == b_info.rating {
401                b_info.last_seen.cmp(&a_info.last_seen)
402            } else {
403                b_info.rating.cmp(&a_info.rating)
404            }
405        });
406        if !list.is_empty() {
407            let bin = Binomial::new(list.len() as u64 * 2 - 1, 0.5).unwrap();
408            let v = bin.sample(&mut rand::thread_rng());
409            let i = if (v as usize) < list.len() {
410                list.len() - 1 - v as usize
411            } else {
412                v as usize - list.len()
413            };
414            let sa = &list[i];
415            self.used_addrs.insert(sa.clone());
416            return Some(sa.clone());
417        }
418        None
419    }
420}
421
422pub async fn manage(ctx: Arc<Context>, mut receiver: Receiver<Event>) {
423    let mut broker_sender = ctx.broker_sender().clone();
424    let mut bm_event_sender = ctx.bm_event_sender().clone();
425
426    let mut nodes = Nodes::new(Arc::clone(&ctx), ctx.pool().clone()).await;
427
428    if let Err(err) = bm_event_sender.send(BmEvent::AddrCount(nodes.len())).await {
429        error!("{}", err);
430    }
431
432    let mut interval = interval(StdDuration::from_secs(4));
433
434    loop {
435        if ctx.aborted().load(Ordering::SeqCst) {
436            break;
437        }
438        select! {
439            v = receiver.next().fuse() => match v {
440                Some(event) => match event {
441                    Event::Add(entries) => {
442                        for entry in entries {
443                            if nodes.retain().await.is_some() {
444                                if let Err(err) = bm_event_sender.send(BmEvent::AddrCount(nodes.len())).await {
445                                    error!("{}", err);
446                                }
447                            }
448
449                            let own_node = if let SocketAddrNode::AddrExt(addr) = &entry.addr {
450                                ctx.config().own_nodes().contains(addr)
451                            } else {
452                                false
453                            };
454                            if nodes.insert(entry, own_node).await.is_some() {
455                                if let Err(err) = bm_event_sender.send(BmEvent::AddrCount(nodes.len())).await {
456                                    error!("{}", err);
457                                }
458                            }
459                        }
460                    }
461                    Event::ConnectionSucceeded(addr, user_agent) => {
462                        if let Some(rating) = nodes.increment(&addr).await {
463                            if let Err(err) = bm_event_sender
464                                .send(BmEvent::Established {
465                                    addr: addr.clone(),
466                                    user_agent,
467                                    rating,
468                                })
469                                .await
470                            {
471                                error!("{}", err)
472                            }
473                        }
474                    }
475                    Event::ConnectionFailed(addr) => {
476                        let own_node = if let SocketAddrNode::AddrExt(addr) = &addr {
477                            ctx.config().own_nodes().contains(addr)
478                        } else {
479                            false
480                        };
481                        if !own_node {
482                            nodes.decrement(&addr).await;
483                        }
484                    }
485                    Event::Disconnected(addr) => {
486                        nodes.reclaim(&addr);
487                    }
488                    Event::Send(addr, close) => {
489                        let addr_list = nodes.sample_list();
490                        if !addr_list.is_empty() {
491                            let message = message::Addr::new(addr_list).unwrap();
492                            let packet = message.pack(ctx.config().core()).unwrap();
493                            if let Err(err) = broker_sender
494                                .send(BrokerEvent::Write { addr: addr.clone(), list: vec![packet] })
495                                .await
496                            {
497                                error!("{}", err);
498                            }
499                        }
500                        if close {
501                            let error = message::Error::new(2.into(),
502                                "Server full, please try again later.".as_bytes().to_vec().into());
503                            let packet = error.pack(ctx.config().core()).unwrap();
504                            if let Err(err) = broker_sender
505                                .send(BrokerEvent::Write { addr: addr.clone(), list: vec![packet] })
506                                .await
507                            {
508                                error!("{}", err);
509                            }
510                            if let Err(err) = broker_sender
511                                .send(BrokerEvent::Close { addr })
512                                .await
513                            {
514                                error!("{}", err);
515                            }
516                        }
517                    }
518                },
519                None => break,
520            },
521            v = interval.next().fuse() => match v {
522                Some(_) => {
523                    let initiated = ctx.initiated(Direction::Outgoing).load(Ordering::SeqCst);
524                    let connected = ctx.connected(Direction::Outgoing).load(Ordering::SeqCst);
525                    let established = ctx.established(Direction::Outgoing).load(Ordering::SeqCst);
526                    if established >= ctx.config().max_outgoing_established() {
527                        if initiated > connected {
528                            if let Err(err) = broker_sender
529                                .send(BrokerEvent::AbortPendings)
530                                .await
531                            {
532                                error!("{}", err);
533                            }
534                        }
535                    } else if initiated < ctx.config().max_outgoing_initiated()
536                            && initiated + ctx.config().own_nodes().len() < nodes.len() {
537                        if let Some(addr) = nodes.sample(ctx.config().own_nodes()) {
538                            if let Err(err) = broker_sender
539                                .send(BrokerEvent::Outgoing { addr })
540                                .await
541                            {
542                                error!("{}", err);
543                            }
544                        }
545                    }
546                },
547                None => break,
548            },
549        };
550    }
551    if let Err(err) = bm_event_sender.send(BmEvent::AddrCount(0)).await {
552        error!("{}", err);
553    }
554}