1use std::{
2 collections::{HashMap, HashSet},
3 convert::TryInto,
4 fmt,
5 iter::FromIterator,
6 sync::atomic::Ordering,
7 time::Duration as StdDuration,
8};
9
10use async_std::{stream::interval, sync::Arc};
11use futures::{channel::mpsc::Receiver, select, sink::SinkExt, stream::StreamExt, FutureExt};
12use log::{debug, error};
13use rand::seq::SliceRandom;
14use rand_distr::{Binomial, Distribution};
15
16use koibumi_core::{
17 message::{self, NetAddr, Pack, Services, StreamNumber, UserAgent},
18 net::SocketAddrExt,
19 time::Time,
20};
21
22use crate::{
23 connection::Direction,
24 connection_loop::{Context, Event as BrokerEvent},
25 db,
26 manager::Event as BmEvent,
27 net::SocketAddrNode,
28};
29
30#[derive(Clone, PartialEq, Eq, Hash, Debug)]
31pub struct Entry {
32 stream: StreamNumber,
33 addr: SocketAddrNode,
34 last_seen: Time,
35}
36
37impl Entry {
38 pub fn new(stream: StreamNumber, addr: SocketAddrNode, last_seen: Time) -> Self {
39 Self {
40 stream,
41 addr,
42 last_seen,
43 }
44 }
45}
46
/// Requests handled by [`manage`].
#[derive(Debug)]
pub enum Event {
    /// Addresses to add to the node list (known ones only get their
    /// last-seen time refreshed).
    Add(Vec<Entry>),
    /// A connection to the address succeeded. `manage` raises the address's
    /// rating and reports `BmEvent::Established` with this user agent.
    ConnectionSucceeded(SocketAddrNode, UserAgent),
    /// A connection to the address failed. `manage` lowers its rating unless
    /// it is one of the configured own nodes.
    ConnectionFailed(SocketAddrNode),
    /// The address is no longer in use. `manage` makes it eligible for
    /// outgoing sampling again.
    Disconnected(SocketAddrNode),
    /// Send our address list to the peer. When the flag is `true`, also send
    /// a "server full" error message and close the connection.
    Send(SocketAddrNode, bool),
}
55
/// Reputation score of a node address.
///
/// `increment` and `decrement` keep the score within `MIN..=MAX`. Values
/// outside that range can still be created through `new`/`From`, e.g. when
/// loaded from the database.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct Rating(i8);

impl Rating {
    /// Upper bound enforced by `increment`.
    const MAX: i8 = 10;
    /// Lower bound enforced by `decrement`.
    const MIN: i8 = -10;

    /// Creates a rating holding exactly `value`.
    pub fn new(value: i8) -> Self {
        Self(value)
    }

    /// Returns the raw score.
    pub fn as_i8(&self) -> i8 {
        self.0
    }

    /// Raises the score by one; the result is never above `MAX`.
    pub fn increment(&mut self) {
        // `saturating_add` avoids an overflow panic for scores at `i8::MAX`.
        self.0 = self.0.saturating_add(1).min(Self::MAX);
    }

    /// Lowers the score by one; the result is never below `MIN`.
    pub fn decrement(&mut self) {
        // The floor must be `MIN` itself (-10). Clamping against `-MIN` (+10)
        // would push every decremented score up to 10.
        // `saturating_sub` avoids an overflow panic for scores at `i8::MIN`.
        self.0 = self.0.saturating_sub(1).max(Self::MIN);
    }
}

impl fmt::Display for Rating {
    /// Formats the raw score as a plain integer.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.0.fmt(f)
    }
}

impl From<i8> for Rating {
    /// Wraps the raw score without clamping.
    fn from(value: i8) -> Self {
        Self(value)
    }
}
98
99struct Info {
100 #[allow(dead_code)]
101 stream: StreamNumber,
102 last_seen: Time,
103 rating: Rating,
104}
105
/// Raw row of the `nodes` table, as fetched by `Nodes::new`.
///
/// Columns are kept in their SQLite representation. `Nodes::new` skips rows
/// whose values cannot be converted to the in-memory types.
#[derive(sqlx::FromRow)]
struct Record {
    /// Stream number; rows with a negative value are skipped on load.
    stream: i32,
    /// Textual socket address, parsed as `SocketAddrExt` on load.
    address: String,
    /// Last-seen time in seconds; rows with a negative value are skipped on load.
    last_seen: i64,
    /// Rating; rows whose value does not fit in an `i8` are skipped on load.
    rating: i32,
}
113
/// In-memory list of known node addresses, mirrored to the `nodes` SQLite table.
struct Nodes {
    /// Shared context; used here for configuration lookups.
    ctx: Arc<Context>,
    /// Database pool used to read and persist the `nodes` table.
    pool: db::SqlitePool,
    /// Known nodes keyed by address.
    map: HashMap<SocketAddrNode, Info>,
    /// Addresses handed out by `sample` and not yet reclaimed; `sample` skips
    /// them and `retain` never removes them.
    used_addrs: HashSet<SocketAddrNode>,
}
120
121impl Nodes {
122 async fn new(ctx: Arc<Context>, pool: db::SqlitePool) -> Self {
123 if let Err(err) = sqlx::query(
124 "CREATE TABLE IF NOT EXISTS nodes (
125 stream INTEGER NOT NULL,
126 address TEXT NOT NULL,
127 last_seen INTEGER NOT NULL,
128 rating INTEGER NOT NULL,
129 PRIMARY KEY(stream, address)
130 )",
131 )
132 .execute(pool.write())
133 .await
134 {
135 error!("{}", err);
136 }
137
138 let mut map = HashMap::new();
139 if let Ok(list) = sqlx::query_as::<sqlx::Sqlite, Record>(
140 "SELECT stream, address, last_seen, rating FROM nodes",
141 )
142 .fetch_all(pool.read())
143 .await
144 {
145 for record in list {
146 if record.stream < 0 {
147 continue;
148 }
149 let stream: StreamNumber = (record.stream as u32).into();
150
151 let addr = record.address.parse::<SocketAddrExt>();
152 if addr.is_err() {
153 continue;
154 }
155 let addr = addr.unwrap();
156 let addr: SocketAddrNode = addr.into();
157
158 if record.last_seen < 0 {
159 continue;
160 }
161 let last_seen: Time = (record.last_seen as u64).into();
162 if record.rating < -128 || record.rating > 127 {
163 continue;
164 }
165 let rating: Rating = (record.rating as i8).into();
166
167 map.insert(
168 addr,
169 Info {
170 stream,
171 last_seen,
172 rating,
173 },
174 );
175 }
176 }
177
178 let used_addrs = HashSet::new();
179 let mut nodes = Self {
180 ctx,
181 pool,
182 map,
183 used_addrs,
184 };
185
186 nodes.retain().await;
187
188 nodes
189 }
190
191 fn len(&self) -> usize {
192 self.map.len()
193 }
194
195 async fn retain(&mut self) -> Option<()> {
196 if self.map.len() > self.ctx.config().max_nodes() {
197 let keys: HashSet<SocketAddrNode> = self.map.keys().cloned().collect();
198 let mut list: Vec<SocketAddrNode> =
199 keys.difference(&self.used_addrs).cloned().collect();
200 list.sort_unstable_by(|a, b| {
201 let a_info = &self.map[a];
202 let b_info = &self.map[b];
203 if a_info.rating == b_info.rating {
204 a_info.last_seen.cmp(&b_info.last_seen)
205 } else {
206 a_info.rating.cmp(&b_info.rating)
207 }
208 });
209 let mut trunc_amount = self.ctx.config().max_nodes() / 10;
210 if trunc_amount == 0 {
211 trunc_amount = usize::min(1, list.len());
212 }
213 list.truncate(trunc_amount);
214 for addr in &list {
215 self.map.remove(addr);
216 if let SocketAddrNode::AddrExt(addr) = addr {
217 if let Err(err) = sqlx::query("DELETE FROM nodes WHERE address=?1")
218 .bind(addr.to_string())
219 .execute(self.pool.write())
220 .await
221 {
222 error!("{}", err);
223 }
224 }
225 }
226 return Some(());
227 }
228 None
229 }
230
231 async fn insert(&mut self, entry: Entry, own_node: bool) -> Option<()> {
232 match self.map.get_mut(&entry.addr) {
233 Some(info) => {
234 if entry.last_seen > info.last_seen {
235 info.last_seen = entry.last_seen;
236 }
237 if entry.stream.as_u32() <= i32::MAX as u32
238 && entry.last_seen.as_secs() <= i64::MAX as u64
239 {
240 if let SocketAddrNode::AddrExt(addr) = entry.addr {
241 if let Err(err) = sqlx::query(
242 "UPDATE nodes SET last_seen=?1 WHERE stream=?2 and address=?3",
243 )
244 .bind(entry.last_seen.as_secs() as i64)
245 .bind(entry.stream.as_u32() as i32)
246 .bind(addr.to_string())
247 .execute(self.pool.write())
248 .await
249 {
250 error!("{}", err);
251 }
252 }
253 }
254 None
255 }
256 None => {
257 if self.ctx.config().is_connectable_to(&entry.addr)
258 && self.ctx.config().stream_numbers().contains(entry.stream)
259 {
260 debug!("addr: {}", entry.addr);
261 let rating: Rating = if own_node {
262 Rating::MAX.into()
263 } else {
264 0.into()
265 };
266 self.map.insert(
267 entry.addr.clone(),
268 Info {
269 stream: entry.stream,
270 last_seen: entry.last_seen,
271 rating: rating.clone(),
272 },
273 );
274 if entry.stream.as_u32() <= i32::MAX as u32
275 && entry.last_seen.as_secs() <= i64::MAX as u64
276 {
277 if let SocketAddrNode::AddrExt(addr) = entry.addr {
278 if let Err(err) = sqlx::query(
279 "INSERT INTO nodes (
280 stream, address, last_seen, rating
281 ) VALUES (?1, ?2, ?3, ?4)",
282 )
283 .bind(entry.stream.as_u32() as i32)
284 .bind(addr.to_string())
285 .bind(entry.last_seen.as_secs() as i64)
286 .bind(rating.as_i8() as i32)
287 .execute(self.pool.write())
288 .await
289 {
290 error!("{}", err);
291 }
292 }
293 }
294 return Some(());
295 }
296 None
297 }
298 }
299 }
300
301 async fn increment(&mut self, addr: &SocketAddrNode) -> Option<Rating> {
302 let now = Time::now();
303 if let Some(info) = self.map.get_mut(addr) {
304 info.last_seen = now;
305 info.rating.increment();
306
307 if info.stream.as_u32() <= i32::MAX as u32
308 && info.last_seen.as_secs() <= i64::MAX as u64
309 {
310 if let SocketAddrNode::AddrExt(addr) = addr {
311 if let Err(err) =
312 sqlx::query("UPDATE nodes SET rating=?1 WHERE stream=?2 and address=?3")
313 .bind(info.rating.as_i8() as i64)
314 .bind(info.stream.as_u32() as i32)
315 .bind(addr.to_string())
316 .execute(self.pool.write())
317 .await
318 {
319 error!("{}", err);
320 }
321 }
322 }
323 return Some(info.rating.clone());
324 }
325 None
326 }
327
328 async fn decrement(&mut self, addr: &SocketAddrNode) -> Option<Rating> {
329 if let Some(info) = self.map.get_mut(addr) {
330 info.rating.decrement();
331
332 if info.stream.as_u32() <= i32::MAX as u32
333 && info.last_seen.as_secs() <= i64::MAX as u64
334 {
335 if let SocketAddrNode::AddrExt(addr) = addr {
336 if let Err(err) =
337 sqlx::query("UPDATE nodes SET rating=?1 WHERE stream=?2 and address=?3")
338 .bind(info.rating.as_i8() as i64)
339 .bind(info.stream.as_u32() as i32)
340 .bind(addr.to_string())
341 .execute(self.pool.write())
342 .await
343 {
344 error!("{}", err);
345 }
346 }
347 }
348 return Some(info.rating.clone());
349 }
350 None
351 }
352
353 fn reclaim(&mut self, addr: &SocketAddrNode) {
354 self.used_addrs.remove(addr);
355 }
356
357 fn sample_list(&self) -> Vec<NetAddr> {
358 let mut list: Vec<SocketAddrNode> = self.map.keys().cloned().collect();
359 list.sort_unstable_by(|a, b| {
360 let a_info = &self.map[a];
361 let b_info = &self.map[b];
362 if a_info.rating == b_info.rating {
363 b_info.last_seen.cmp(&a_info.last_seen)
364 } else {
365 b_info.rating.cmp(&a_info.rating)
366 }
367 });
368 list.truncate(1000);
369 list.shuffle(&mut rand::thread_rng());
370 let mut addr_list = Vec::with_capacity(list.len());
371 for addr in list {
372 let info = &self.map[&addr];
373 let addr = addr.try_into();
374 if let Err(err) = addr {
375 error!("{}", err);
376 continue;
377 }
378 let addr: SocketAddrExt = addr.unwrap();
379 if let Ok(addr) = addr.try_into() {
380 addr_list.push(NetAddr::new(
381 info.last_seen,
382 info.stream,
383 Services::NETWORK,
384 addr,
385 ));
386 }
387 }
388 addr_list
389 }
390
391 fn sample(&mut self, own_nodes: &[SocketAddrExt]) -> Option<SocketAddrNode> {
392 let keys: HashSet<SocketAddrNode> = self.map.keys().cloned().collect();
393 let list: HashSet<SocketAddrNode> = keys.difference(&self.used_addrs).cloned().collect();
394 let own_nodes: HashSet<SocketAddrNode> =
395 HashSet::from_iter(own_nodes.iter().cloned().map(|a| a.into()));
396 let mut list: Vec<SocketAddrNode> = list.difference(&own_nodes).cloned().collect();
397 list.sort_unstable_by(|a, b| {
398 let a_info = &self.map[a];
399 let b_info = &self.map[b];
400 if a_info.rating == b_info.rating {
401 b_info.last_seen.cmp(&a_info.last_seen)
402 } else {
403 b_info.rating.cmp(&a_info.rating)
404 }
405 });
406 if !list.is_empty() {
407 let bin = Binomial::new(list.len() as u64 * 2 - 1, 0.5).unwrap();
408 let v = bin.sample(&mut rand::thread_rng());
409 let i = if (v as usize) < list.len() {
410 list.len() - 1 - v as usize
411 } else {
412 v as usize - list.len()
413 };
414 let sa = &list[i];
415 self.used_addrs.insert(sa.clone());
416 return Some(sa.clone());
417 }
418 None
419 }
420}
421
422pub async fn manage(ctx: Arc<Context>, mut receiver: Receiver<Event>) {
423 let mut broker_sender = ctx.broker_sender().clone();
424 let mut bm_event_sender = ctx.bm_event_sender().clone();
425
426 let mut nodes = Nodes::new(Arc::clone(&ctx), ctx.pool().clone()).await;
427
428 if let Err(err) = bm_event_sender.send(BmEvent::AddrCount(nodes.len())).await {
429 error!("{}", err);
430 }
431
432 let mut interval = interval(StdDuration::from_secs(4));
433
434 loop {
435 if ctx.aborted().load(Ordering::SeqCst) {
436 break;
437 }
438 select! {
439 v = receiver.next().fuse() => match v {
440 Some(event) => match event {
441 Event::Add(entries) => {
442 for entry in entries {
443 if nodes.retain().await.is_some() {
444 if let Err(err) = bm_event_sender.send(BmEvent::AddrCount(nodes.len())).await {
445 error!("{}", err);
446 }
447 }
448
449 let own_node = if let SocketAddrNode::AddrExt(addr) = &entry.addr {
450 ctx.config().own_nodes().contains(addr)
451 } else {
452 false
453 };
454 if nodes.insert(entry, own_node).await.is_some() {
455 if let Err(err) = bm_event_sender.send(BmEvent::AddrCount(nodes.len())).await {
456 error!("{}", err);
457 }
458 }
459 }
460 }
461 Event::ConnectionSucceeded(addr, user_agent) => {
462 if let Some(rating) = nodes.increment(&addr).await {
463 if let Err(err) = bm_event_sender
464 .send(BmEvent::Established {
465 addr: addr.clone(),
466 user_agent,
467 rating,
468 })
469 .await
470 {
471 error!("{}", err)
472 }
473 }
474 }
475 Event::ConnectionFailed(addr) => {
476 let own_node = if let SocketAddrNode::AddrExt(addr) = &addr {
477 ctx.config().own_nodes().contains(addr)
478 } else {
479 false
480 };
481 if !own_node {
482 nodes.decrement(&addr).await;
483 }
484 }
485 Event::Disconnected(addr) => {
486 nodes.reclaim(&addr);
487 }
488 Event::Send(addr, close) => {
489 let addr_list = nodes.sample_list();
490 if !addr_list.is_empty() {
491 let message = message::Addr::new(addr_list).unwrap();
492 let packet = message.pack(ctx.config().core()).unwrap();
493 if let Err(err) = broker_sender
494 .send(BrokerEvent::Write { addr: addr.clone(), list: vec![packet] })
495 .await
496 {
497 error!("{}", err);
498 }
499 }
500 if close {
501 let error = message::Error::new(2.into(),
502 "Server full, please try again later.".as_bytes().to_vec().into());
503 let packet = error.pack(ctx.config().core()).unwrap();
504 if let Err(err) = broker_sender
505 .send(BrokerEvent::Write { addr: addr.clone(), list: vec![packet] })
506 .await
507 {
508 error!("{}", err);
509 }
510 if let Err(err) = broker_sender
511 .send(BrokerEvent::Close { addr })
512 .await
513 {
514 error!("{}", err);
515 }
516 }
517 }
518 },
519 None => break,
520 },
521 v = interval.next().fuse() => match v {
522 Some(_) => {
523 let initiated = ctx.initiated(Direction::Outgoing).load(Ordering::SeqCst);
524 let connected = ctx.connected(Direction::Outgoing).load(Ordering::SeqCst);
525 let established = ctx.established(Direction::Outgoing).load(Ordering::SeqCst);
526 if established >= ctx.config().max_outgoing_established() {
527 if initiated > connected {
528 if let Err(err) = broker_sender
529 .send(BrokerEvent::AbortPendings)
530 .await
531 {
532 error!("{}", err);
533 }
534 }
535 } else if initiated < ctx.config().max_outgoing_initiated()
536 && initiated + ctx.config().own_nodes().len() < nodes.len() {
537 if let Some(addr) = nodes.sample(ctx.config().own_nodes()) {
538 if let Err(err) = broker_sender
539 .send(BrokerEvent::Outgoing { addr })
540 .await
541 {
542 error!("{}", err);
543 }
544 }
545 }
546 },
547 None => break,
548 },
549 };
550 }
551 if let Err(err) = bm_event_sender.send(BmEvent::AddrCount(0)).await {
552 error!("{}", err);
553 }
554}