atm0s_small_p2p/service/
alias_service.rs

1use std::{
2    collections::{HashMap, HashSet, VecDeque},
3    time::Duration,
4};
5
6use anyhow::anyhow;
7use derive_more::derive::{Display, From};
8use lru::LruCache;
9use serde::{Deserialize, Serialize};
10use tokio::{
11    select,
12    sync::{
13        mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
14        oneshot,
15    },
16    time::Interval,
17};
18
19use crate::{
20    stream::P2pQuicStream,
21    utils::{now_ms, ErrorExt, ErrorExt2},
22    PeerId,
23};
24
25use super::{P2pService, P2pServiceEvent, P2pServiceRequester};
26
/// Maximum number of aliases kept in the remote-hint LRU cache.
const LRU_CACHE_SIZE: usize = 1_000_000;

/// How long (in milliseconds) a find request waits for hinted peers to answer
/// a `Check` before it falls back to a broadcast scan.
const HINT_TIMEOUT_MS: u64 = 500;

/// How long (in milliseconds) a broadcast scan may stay unanswered before the
/// find request is resolved as "not found".
const SCAN_TIMEOUT_MS: u64 = 1000;
30
31#[derive(Debug, From, Display, Serialize, Deserialize, Hash, PartialEq, Eq, Clone, Copy)]
32pub struct AliasId(u64);
33
34#[derive(Debug, Clone, Copy, PartialEq, Eq)]
35pub enum AliasFoundLocation {
36    Local,
37    Hint(PeerId),
38    Scan(PeerId),
39}
40
41pub enum AliasStreamLocation {
42    Local,
43    Hint(P2pQuicStream),
44    Scan(P2pQuicStream),
45}
46
47#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
48enum AliasMessage {
49    NotifySet(AliasId),
50    NotifyDel(AliasId),
51    Check(AliasId),
52    Scan(AliasId),
53    Found(AliasId),
54    NotFound(AliasId),
55    // when a node
56    Shutdown,
57}
58
59enum AliasControl {
60    Register(AliasId),
61    Unregister(AliasId),
62    Find(AliasId, oneshot::Sender<Option<AliasFoundLocation>>),
63    Shutdown,
64}
65
66#[derive(Debug)]
67pub struct AliasGuard {
68    alias: AliasId,
69    tx: UnboundedSender<AliasControl>,
70}
71
72impl Drop for AliasGuard {
73    fn drop(&mut self) {
74        log::info!("[AliasGuard] drop {} => auto unregister", self.alias);
75        self.tx.send(AliasControl::Unregister(self.alias)).expect("alias service main channal should work");
76    }
77}
78
79#[derive(Debug, Clone)]
80pub struct AliasServiceRequester {
81    tx: UnboundedSender<AliasControl>,
82}
83
84impl AliasServiceRequester {
85    pub fn register<A: Into<AliasId>>(&self, alias: A) -> AliasGuard {
86        let alias: AliasId = alias.into();
87        log::info!("[AliasServiceRequester] register alias {alias}");
88        self.tx.send(AliasControl::Register(alias)).expect("alias service main channal should work");
89
90        AliasGuard { alias, tx: self.tx.clone() }
91    }
92
93    pub async fn find<A: Into<AliasId>>(&self, alias: A) -> Option<AliasFoundLocation> {
94        let alias: AliasId = alias.into();
95        log::info!("[AliasServiceRequester] find alias {alias}");
96        let (tx, rx) = oneshot::channel();
97        self.tx.send(AliasControl::Find(alias, tx)).expect("alias service main channal should work");
98        let res = rx.await.ok()?;
99        log::info!("[AliasServiceRequester] find alias {alias} => result {res:?}");
100        res
101    }
102
103    pub async fn open_stream<A: Into<AliasId>>(&self, alias: A, over_service: P2pServiceRequester, meta: Vec<u8>) -> anyhow::Result<AliasStreamLocation> {
104        match self.find(alias).await {
105            Some(AliasFoundLocation::Local) => Ok(AliasStreamLocation::Local),
106            Some(AliasFoundLocation::Hint(dest)) => over_service.open_stream(dest, meta).await.map(AliasStreamLocation::Hint),
107            Some(AliasFoundLocation::Scan(dest)) => over_service.open_stream(dest, meta).await.map(AliasStreamLocation::Scan),
108            None => Err(anyhow!("alias not found")),
109        }
110    }
111
112    pub fn shutdown(&self) {
113        log::info!("[AliasServiceRequester] shutdown");
114        self.tx.send(AliasControl::Shutdown).expect("alias service main channal should work");
115    }
116}
117
118enum FindRequestState {
119    CheckHint(u64, HashSet<PeerId>),
120    Scan(u64),
121}
122
123struct FindRequest {
124    state: FindRequestState,
125    waits: Vec<oneshot::Sender<Option<AliasFoundLocation>>>,
126}
127
128#[derive(Debug, PartialEq, Eq)]
129enum InternalOutput {
130    Broadcast(AliasMessage),
131    Unicast(PeerId, AliasMessage),
132}
133
134struct AliasServiceInternal {
135    local: HashMap<AliasId, u8>,
136    cache: LruCache<AliasId, HashSet<PeerId>>,
137    find_reqs: HashMap<AliasId, FindRequest>,
138    outs: VecDeque<InternalOutput>,
139}
140
141pub struct AliasService {
142    service: P2pService,
143    tx: UnboundedSender<AliasControl>,
144    rx: UnboundedReceiver<AliasControl>,
145    internal: AliasServiceInternal,
146    interval: Interval,
147}
148
149impl AliasService {
150    pub fn new(service: P2pService) -> Self {
151        let (tx, rx) = unbounded_channel();
152        Self {
153            service,
154            tx,
155            rx,
156            internal: AliasServiceInternal {
157                cache: LruCache::new(LRU_CACHE_SIZE.try_into().expect("")),
158                find_reqs: HashMap::new(),
159                outs: VecDeque::new(),
160                local: HashMap::new(),
161            },
162            interval: tokio::time::interval(Duration::from_secs(1)),
163        }
164    }
165
166    pub fn requester(&self) -> AliasServiceRequester {
167        AliasServiceRequester { tx: self.tx.clone() }
168    }
169
170    pub async fn run_loop(&mut self) -> anyhow::Result<()> {
171        loop {
172            select! {
173                _ = self.interval.tick() => {
174                    self.on_tick().await;
175                },
176                event = self.service.recv() => match event.expect("service channel should work") {
177                    P2pServiceEvent::Unicast(from, data) => {
178                        if let Ok(msg) = bincode::deserialize::<AliasMessage>(&data) {
179                            self.on_msg(from, msg).await;
180                        }
181                    }
182                    P2pServiceEvent::Broadcast(from, data) => {
183                        if let Ok(msg) = bincode::deserialize::<AliasMessage>(&data) {
184                            self.on_msg(from, msg).await;
185                        }
186                    }
187                    P2pServiceEvent::Stream(..) => {},
188                },
189                control = self.rx.recv() => {
190                    self.on_control(control.expect("service channel should work")).await;
191                }
192            }
193        }
194    }
195
196    async fn on_tick(&mut self) {
197        self.internal.on_tick(now_ms());
198        self.pop_internal().await;
199    }
200
201    async fn on_msg(&mut self, from: PeerId, msg: AliasMessage) {
202        log::debug!("[AliasService] on msg from {from}, {msg:?}");
203        self.internal.on_msg(now_ms(), from, msg);
204        self.pop_internal().await;
205    }
206
207    async fn on_control(&mut self, control: AliasControl) {
208        self.internal.on_control(now_ms(), control);
209        self.pop_internal().await;
210    }
211
212    async fn pop_internal(&mut self) {
213        while let Some(out) = self.internal.pop_output() {
214            match out {
215                InternalOutput::Broadcast(msg) => {
216                    self.service.send_broadcast(bincode::serialize(&msg).expect("should serialie")).await;
217                }
218                InternalOutput::Unicast(dest, msg) => {
219                    self.service
220                        .send_unicast(dest, bincode::serialize(&msg).expect("should serialie"))
221                        .await
222                        .print_on_err("[AliasService] send unicast");
223                }
224            }
225        }
226    }
227}
228
229impl AliasServiceInternal {
230    fn on_tick(&mut self, now: u64) {
231        let mut timeout_reqs = vec![];
232        for (alias_id, req) in self.find_reqs.iter_mut() {
233            match req.state {
234                FindRequestState::CheckHint(requested_at, ref mut _hash_set) => {
235                    if requested_at + HINT_TIMEOUT_MS <= now {
236                        log::info!("[AliasServiceInternal] check hint timeout {alias_id} => switch to scan");
237                        self.outs.push_back(InternalOutput::Broadcast(AliasMessage::Scan(*alias_id)));
238                        req.state = FindRequestState::Scan(now);
239                    }
240                }
241                FindRequestState::Scan(requested_at) => {
242                    if requested_at + SCAN_TIMEOUT_MS <= now {
243                        log::info!("[AliasServiceInternal] find scan timeout {alias_id}");
244                        timeout_reqs.push(*alias_id);
245                        while let Some(tx) = req.waits.pop() {
246                            tx.send(None).print_on_err2("");
247                        }
248                    }
249                }
250            }
251        }
252
253        for alias_id in timeout_reqs {
254            self.find_reqs.remove(&alias_id);
255        }
256    }
257
258    fn on_msg(&mut self, now: u64, from: PeerId, msg: AliasMessage) {
259        log::info!("[AliasServiceInternal] on msg from {from}, {msg:?}");
260        match msg {
261            AliasMessage::NotifySet(alias_id) => {
262                let slot = self.cache.get_or_insert_mut(alias_id, HashSet::new);
263                slot.insert(from);
264            }
265            AliasMessage::NotifyDel(alias_id) => {
266                if let Some(slot) = self.cache.get_mut(&alias_id) {
267                    slot.remove(&from);
268                    if slot.is_empty() {
269                        self.cache.pop(&alias_id);
270                    }
271                }
272            }
273            AliasMessage::Check(alias_id) => {
274                if self.local.contains_key(&alias_id) {
275                    self.outs.push_back(InternalOutput::Unicast(from, AliasMessage::Found(alias_id)));
276                } else {
277                    self.outs.push_back(InternalOutput::Unicast(from, AliasMessage::NotFound(alias_id)));
278                }
279            }
280            AliasMessage::Scan(alias_id) => {
281                if self.local.contains_key(&alias_id) {
282                    self.outs.push_back(InternalOutput::Unicast(from, AliasMessage::Found(alias_id)));
283                }
284            }
285            AliasMessage::Found(alias_id) => {
286                let slot = self.cache.get_or_insert_mut(alias_id, HashSet::new);
287                slot.insert(from);
288
289                if let Some(req) = self.find_reqs.remove(&alias_id) {
290                    let found = if matches!(req.state, FindRequestState::Scan(_)) {
291                        AliasFoundLocation::Scan(from)
292                    } else {
293                        AliasFoundLocation::Hint(from)
294                    };
295                    for tx in req.waits {
296                        tx.send(Some(found)).print_on_err2("[AliasServiceInternal] send query response");
297                    }
298                }
299            }
300            AliasMessage::NotFound(alias_id) => {
301                if let Some(slot) = self.cache.get_mut(&alias_id) {
302                    slot.remove(&from);
303                    if slot.is_empty() {
304                        self.cache.pop(&alias_id);
305                    }
306                }
307
308                if let Some(req) = self.find_reqs.get_mut(&alias_id) {
309                    match req.state {
310                        FindRequestState::CheckHint(_, ref mut hint_peers) => {
311                            hint_peers.remove(&from);
312                            if hint_peers.is_empty() {
313                                //not found => should switch to scan
314                                req.state = FindRequestState::Scan(now);
315                                self.outs.push_back(InternalOutput::Broadcast(AliasMessage::Scan(alias_id)));
316                            }
317                        }
318                        FindRequestState::Scan(_) => {}
319                    }
320                }
321            }
322            AliasMessage::Shutdown => {
323                let mut removed_alias_ids = vec![];
324                for (k, _v) in &mut self.cache {
325                    removed_alias_ids.push(*k);
326                }
327                for alias_id in removed_alias_ids {
328                    self.cache.pop(&alias_id);
329                }
330            }
331        }
332    }
333
334    fn on_control(&mut self, now: u64, control: AliasControl) {
335        match control {
336            AliasControl::Register(alias_id) => {
337                let ref_count = self.local.entry(alias_id).or_default();
338                *ref_count += 1;
339                self.outs.push_back(InternalOutput::Broadcast(AliasMessage::NotifySet(alias_id)));
340            }
341            AliasControl::Unregister(alias_id) => {
342                if let Some(ref_count) = self.local.get_mut(&alias_id) {
343                    *ref_count -= 1;
344                    if *ref_count == 0 {
345                        self.local.remove(&alias_id);
346                        self.outs.push_back(InternalOutput::Broadcast(AliasMessage::NotifyDel(alias_id)));
347                    }
348                }
349            }
350            AliasControl::Find(alias_id, sender) => {
351                if let Some(req) = self.find_reqs.get_mut(&alias_id) {
352                    req.waits.push(sender);
353                    return;
354                }
355
356                if self.local.contains_key(&alias_id) {
357                    sender.send(Some(AliasFoundLocation::Local)).print_on_err2("[AliasServiceInternal] send query response");
358                } else if let Some(slot) = self.cache.get(&alias_id) {
359                    for peer in slot {
360                        self.outs.push_back(InternalOutput::Unicast(*peer, AliasMessage::Check(alias_id)));
361                    }
362                    self.find_reqs.insert(
363                        alias_id,
364                        FindRequest {
365                            state: FindRequestState::CheckHint(now, slot.clone()),
366                            waits: vec![sender],
367                        },
368                    );
369                } else {
370                    self.outs.push_back(InternalOutput::Broadcast(AliasMessage::Scan(alias_id)));
371                    self.find_reqs.insert(
372                        alias_id,
373                        FindRequest {
374                            state: FindRequestState::Scan(now),
375                            waits: vec![sender],
376                        },
377                    );
378                }
379            }
380            AliasControl::Shutdown => {
381                self.outs.push_back(InternalOutput::Broadcast(AliasMessage::Shutdown));
382            }
383        }
384    }
385
386    fn pop_output(&mut self) -> Option<InternalOutput> {
387        self.outs.pop_front()
388    }
389}
390
#[cfg(test)]
mod test {
    use super::*;

    /// Receiving end of a find request issued by a test.
    type FindReply = oneshot::Receiver<Option<AliasFoundLocation>>;

    /// Test harness: the I/O-free state machine plus a manually driven clock.
    struct TestContext {
        internal: AliasServiceInternal,
        now: u64,
    }

    impl TestContext {
        fn new() -> Self {
            let internal = AliasServiceInternal {
                local: HashMap::new(),
                cache: LruCache::new(LRU_CACHE_SIZE.try_into().expect("should create NoneZeroUsize")),
                find_reqs: HashMap::new(),
                outs: VecDeque::new(),
            };
            Self { internal, now: 1000 }
        }

        /// Moves the fake clock forward by `ms` milliseconds.
        fn advance_time(&mut self, ms: u64) {
            self.now += ms;
        }

        /// Drains every queued output, oldest first.
        fn collect_outputs(&mut self) -> Vec<InternalOutput> {
            std::iter::from_fn(|| self.internal.pop_output()).collect()
        }

        /// Issues a control command at the current fake time.
        fn control(&mut self, control: AliasControl) {
            self.internal.on_control(self.now, control);
        }

        /// Delivers a peer message at the current fake time.
        fn message(&mut self, from: PeerId, msg: AliasMessage) {
            self.internal.on_msg(self.now, from, msg);
        }

        /// Runs timeout handling at the current fake time.
        fn tick(&mut self) {
            self.internal.on_tick(self.now);
        }

        /// Starts a find request and returns the channel its answer arrives on.
        fn find(&mut self, alias_id: AliasId) -> FindReply {
            let (tx, rx) = oneshot::channel();
            self.control(AliasControl::Find(alias_id, tx));
            rx
        }
    }

    #[test]
    fn test_register_alias() {
        let alias_id = AliasId(1);
        let mut ctx = TestContext::new();

        ctx.control(AliasControl::Register(alias_id));

        // The alias is tracked locally and announced to the network exactly once.
        assert!(ctx.internal.local.contains_key(&alias_id));
        assert_eq!(ctx.collect_outputs(), vec![InternalOutput::Broadcast(AliasMessage::NotifySet(alias_id))]);
    }

    #[test]
    fn test_unregister_alias() {
        let alias_id = AliasId(1);
        let mut ctx = TestContext::new();

        ctx.control(AliasControl::Register(alias_id));
        ctx.collect_outputs(); // discard the NotifySet broadcast

        ctx.control(AliasControl::Unregister(alias_id));

        // The alias is gone locally and the removal is announced exactly once.
        assert!(!ctx.internal.local.contains_key(&alias_id));
        assert_eq!(ctx.collect_outputs(), vec![InternalOutput::Broadcast(AliasMessage::NotifyDel(alias_id))]);
    }

    #[test]
    fn test_find_local_alias() {
        let alias_id = AliasId(1);
        let mut ctx = TestContext::new();

        ctx.control(AliasControl::Register(alias_id));
        ctx.collect_outputs(); // discard the NotifySet broadcast

        let mut reply = ctx.find(alias_id);

        // A local alias is answered immediately, without any network traffic.
        let response = reply.try_recv().expect("Should have a response");
        assert_eq!(response, Some(AliasFoundLocation::Local));
        assert!(ctx.collect_outputs().is_empty());
    }

    #[test]
    fn test_find_cached_alias_found() {
        let alias_id = AliasId(1);
        let peer_addr = PeerId(1);
        let mut ctx = TestContext::new();

        // The peer announces the alias, which puts it into the hint cache.
        ctx.message(peer_addr, AliasMessage::NotifySet(alias_id));

        let mut reply = ctx.find(alias_id);

        // The hinted peer is asked directly.
        assert_eq!(ctx.collect_outputs(), vec![InternalOutput::Unicast(peer_addr, AliasMessage::Check(alias_id))]);

        // The peer confirms it still owns the alias.
        ctx.message(peer_addr, AliasMessage::Found(alias_id));

        let response = reply.try_recv().expect("Should have a response");
        assert_eq!(response, Some(AliasFoundLocation::Hint(peer_addr)));
    }

    #[test]
    fn test_find_cached_alias_not_found() {
        let alias_id = AliasId(1);
        let peer_addr = PeerId(1);
        let mut ctx = TestContext::new();

        // The peer announces the alias, which puts it into the hint cache.
        ctx.message(peer_addr, AliasMessage::NotifySet(alias_id));

        let _reply = ctx.find(alias_id);

        // The hinted peer is asked directly.
        assert_eq!(ctx.collect_outputs(), vec![InternalOutput::Unicast(peer_addr, AliasMessage::Check(alias_id))]);

        // The only hinted peer denies owning the alias => fall back to scan.
        ctx.message(peer_addr, AliasMessage::NotFound(alias_id));
        assert_eq!(ctx.collect_outputs(), vec![InternalOutput::Broadcast(AliasMessage::Scan(alias_id))]);
    }

    #[test]
    fn test_find_cached_alias_timeout_switch_to_scan() {
        let alias_id = AliasId(1);
        let peer_addr = PeerId(1);
        let mut ctx = TestContext::new();

        // The peer announces the alias, which puts it into the hint cache.
        ctx.message(peer_addr, AliasMessage::NotifySet(alias_id));

        let _reply = ctx.find(alias_id);

        // The hinted peer is asked directly.
        assert_eq!(ctx.collect_outputs(), vec![InternalOutput::Unicast(peer_addr, AliasMessage::Check(alias_id))]);

        // The hinted peer never answers => the check times out into a scan.
        ctx.advance_time(HINT_TIMEOUT_MS + 1);
        ctx.tick();
        assert_eq!(ctx.collect_outputs(), vec![InternalOutput::Broadcast(AliasMessage::Scan(alias_id))]);
    }

    #[test]
    fn test_find_timeout() {
        let alias_id = AliasId(1);
        let mut ctx = TestContext::new();

        // Nothing is known about the alias, so the lookup starts with a scan.
        let mut reply = ctx.find(alias_id);
        assert_eq!(ctx.collect_outputs(), vec![InternalOutput::Broadcast(AliasMessage::Scan(alias_id))]);

        // Nobody answers within the scan timeout.
        ctx.advance_time(SCAN_TIMEOUT_MS + 1);
        ctx.tick();

        let response = reply.try_recv().expect("Should have a response");
        assert_eq!(response, None);
    }

    #[test]
    fn test_shutdown() {
        let alias_id = AliasId(1);
        let peer_addr = PeerId(1);
        let mut ctx = TestContext::new();

        // Seed the hint cache with one alias owned by the peer.
        ctx.internal.cache.put(alias_id, HashSet::from([peer_addr]));

        // A local shutdown is announced to the network.
        ctx.control(AliasControl::Shutdown);
        assert_eq!(ctx.collect_outputs(), vec![InternalOutput::Broadcast(AliasMessage::Shutdown)]);

        // Receiving the peer's shutdown drops the hints that pointed at it.
        ctx.message(peer_addr, AliasMessage::Shutdown);
        assert!(ctx.internal.cache.is_empty());
    }
}
623}