use std::collections::HashSet;
use std::hash::Hash as StdHash;
use std::marker::PhantomData;
use rand::Rng;
use rand::seq::IteratorRandom;
use thiserror::Error;
use tokio::sync::{Mutex, RwLock};
use crate::address_book::{AddressBookStore, NodeInfo};
use crate::{DiscoveryResult, DiscoveryStrategy};
/// Tuning knobs for [`RandomWalker`].
#[derive(Debug, Clone, PartialEq)]
pub struct RandomWalkerConfig {
    /// Probability that, on a step where unvisited candidates still remain, the walker
    /// abandons the current walk and starts a new one from a bootstrap node.
    ///
    /// Must lie in `[0.0, 1.0]`: the value is passed to `rand::Rng::random_bool`, which
    /// panics for probabilities outside that range. Defaults to `0.02`.
    pub reset_walk_probability: f64,
}

impl Default for RandomWalkerConfig {
    fn default() -> Self {
        Self {
            reset_walk_probability: 0.02,
        }
    }
}
/// Discovery strategy that performs random walks over the nodes known to an address book.
///
/// Generic over the RNG `R`, the address-book store `S`, the node identifier `ID` and the
/// node-info type `N`. All mutable parts live behind async locks so that the strategy can be
/// driven through `&self`.
pub struct RandomWalker<R, S, ID, N> {
    /// Identifier of the local node; filtered out when gathering candidates in `reset` and
    /// when sampling bootstrap nodes.
    my_id: ID,
    /// Address book the walker reads node infos and bootstrap nodes from.
    store: S,
    /// Random number generator, wrapped in a tokio `Mutex` because sampling needs `&mut R`
    /// while the strategy's methods only take `&self`.
    rng: Mutex<R>,
    /// Walk parameters (e.g. the random reset probability).
    config: RandomWalkerConfig,
    /// Visited / unvisited bookkeeping for the current walk.
    state: RwLock<RandomWalkerState<ID>>,
    /// Ties the node-info type `N` to the struct without storing a value of it.
    _marker: PhantomData<N>,
}
/// Per-walk bookkeeping shared by the walker's methods.
struct RandomWalkerState<ID> {
    /// Known nodes that are candidates for the next step of the current walk.
    unvisited: HashSet<ID>,
    /// Nodes already visited in the current walk.
    visited: HashSet<ID>,
}

// Written by hand on purpose: `#[derive(Default)]` would add an unnecessary `ID: Default`
// bound, while empty hash sets can be built for any `ID`.
impl<ID> Default for RandomWalkerState<ID> {
    fn default() -> Self {
        let (unvisited, visited) = (HashSet::default(), HashSet::default());
        Self { unvisited, visited }
    }
}
impl<R, S, ID, N> RandomWalker<R, S, ID, N>
where
    R: Rng,
    S: AddressBookStore<ID, N>,
    ID: Clone + Eq + StdHash,
    N: NodeInfo<ID>,
{
    /// Creates a walker for the local node `my_id` using [`RandomWalkerConfig::default`].
    pub fn new(my_id: ID, store: S, rng: R) -> Self {
        Self::from_config(my_id, store, rng, RandomWalkerConfig::default())
    }

    /// Creates a walker for the local node `my_id` with an explicit configuration.
    pub fn from_config(my_id: ID, store: S, rng: R, config: RandomWalkerConfig) -> Self {
        Self {
            my_id,
            store,
            rng: Mutex::new(rng),
            state: RwLock::new(RandomWalkerState::default()),
            config,
            _marker: PhantomData,
        }
    }

    /// Starts a new walk.
    ///
    /// Every node in the address book except the local one is added to the unvisited set
    /// (entries already there are kept), and the visited set is reset to contain only the
    /// local node.
    ///
    /// # Errors
    /// Returns [`RandomWalkError::Store`] if the address book cannot be read.
    async fn reset(&self) -> Result<(), RandomWalkError<S, ID, N>> {
        // The store is queried before the state lock is taken, so the write guard is only
        // held for the in-memory update below.
        let all_nodes = self
            .store
            .all_node_infos()
            .await
            .map_err(RandomWalkError::Store)?
            .into_iter()
            .map(|info| info.id())
            .filter(|id| *id != self.my_id);

        let mut state = self.state.write().await;
        state.unvisited.extend(all_nodes);
        state.visited = HashSet::from([self.my_id.clone()]);
        Ok(())
    }

    /// Picks the starting node of a new walk: a random bootstrap node other than ourselves.
    ///
    /// Falls back to [`Self::random_unvisited_node`] both when no bootstrap node is known and
    /// when the local node is the only bootstrap node, so that a bootstrap node can still
    /// walk the rest of the network.
    ///
    /// # Errors
    /// Returns [`RandomWalkError::Store`] if the address book cannot be read.
    async fn random_bootstrap_node(&self) -> Result<Option<ID>, RandomWalkError<S, ID, N>> {
        loop {
            let candidate = self
                .store
                .random_bootstrap_node()
                .await
                .map_err(RandomWalkError::Store)?
                .map(|info| info.id());

            match candidate {
                // No bootstrap node at all.
                None => break,
                Some(node_id) if node_id != self.my_id => return Ok(Some(node_id)),
                // We sampled ourselves. Only retry when another bootstrap node exists;
                // otherwise this loop would never terminate.
                Some(_) => {
                    let bootstrap_nodes_len = self
                        .store
                        .all_bootstrap_nodes_len()
                        .await
                        .map_err(RandomWalkError::Store)?;
                    if bootstrap_nodes_len <= 1 {
                        break;
                    }
                }
            }
        }
        self.random_unvisited_node().await
    }

    /// Samples a random node from the unvisited set without removing it.
    ///
    /// Returns `Ok(None)` when the unvisited set is empty. Takes the state read lock before
    /// the RNG lock.
    async fn random_unvisited_node(&self) -> Result<Option<ID>, RandomWalkError<S, ID, N>> {
        let state = self.state.read().await;
        let mut rng = self.rng.lock().await;
        // Deref the guard explicitly: `choose` is generic over `R: Rng` and does not
        // deref-coerce `&mut MutexGuard<R>`, which itself is not an RNG.
        let sampled = state.unvisited.iter().choose(&mut *rng);
        Ok(sampled.cloned())
    }

    /// Adds the node ids found in `previous.transport_infos` to the unvisited set, skipping
    /// nodes already visited in this walk and the local node.
    async fn merge_previous(&self, previous: &DiscoveryResult<ID, N>) {
        let mut state = self.state.write().await;
        for id in previous.transport_infos.keys() {
            // A previous result may list the local node. `visited` only contains it after
            // the first `reset`, so check `my_id` explicitly to never walk to ourselves.
            if *id != self.my_id && !state.visited.contains(id) {
                state.unvisited.insert(id.clone());
            }
        }
    }

    /// Moves `id` from the unvisited set into the visited set.
    async fn mark_visited(&self, id: &ID) {
        let mut state = self.state.write().await;
        state.visited.insert(id.clone());
        state.unvisited.remove(id);
    }
}
impl<R, S, ID, N> DiscoveryStrategy<ID, N> for RandomWalker<R, S, ID, N>
where
    R: Rng,
    S: AddressBookStore<ID, N>,
    ID: Clone + Eq + StdHash,
    N: NodeInfo<ID>,
{
    type Error = RandomWalkError<S, ID, N>;

    /// Returns the next node to visit, or `None` when there is no candidate.
    ///
    /// Node ids from `previous.transport_infos` are merged into the unvisited set first. A
    /// new walk (`reset` followed by `random_bootstrap_node`) is started when:
    /// - there is no previous result,
    /// - the unvisited set is empty, or
    /// - a draw with probability `config.reset_walk_probability` succeeds.
    ///
    /// Otherwise the walk continues at a random unvisited node. The returned node, if any, is
    /// marked as visited.
    async fn next_node(
        &self,
        previous: Option<&DiscoveryResult<ID, N>>,
    ) -> Result<Option<ID>, Self::Error> {
        let start_new_walk = match previous {
            None => true,
            Some(result) => {
                self.merge_previous(result).await;
                // Bind first so the state read guard is released before the RNG is locked.
                let exhausted = self.state.read().await.unvisited.is_empty();
                // The RNG is only consulted when candidates remain.
                exhausted
                    || self
                        .rng
                        .lock()
                        .await
                        .random_bool(self.config.reset_walk_probability)
            }
        };

        let chosen = if start_new_walk {
            self.reset().await?;
            self.random_bootstrap_node().await?
        } else {
            self.random_unvisited_node().await?
        };

        if let Some(id) = &chosen {
            self.mark_visited(id).await;
        }
        Ok(chosen)
    }
}
/// Errors produced by [`RandomWalker`].
///
/// The only failure source is the address-book store, so the enum is generic over the store
/// type `S`. The `where` bound is required to name `S::Error` in the variant.
#[derive(Debug, Error)]
pub enum RandomWalkError<S, ID, N>
where
    S: AddressBookStore<ID, N>,
{
    /// A store operation failed; displays the underlying store error unchanged.
    #[error("{0}")]
    Store(S::Error),
}
#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};
    use rand::SeedableRng;
    use rand_chacha::ChaCha20Rng;
    use crate::address_book::AddressBookStore;
    use crate::test_utils::{TestId, TestInfo, TestStore};
    use crate::traits::{DiscoveryResult, DiscoveryStrategy};
    use super::{RandomWalker, RandomWalkerConfig};

    /// Walks a small directed graph and checks that `graph.len() - 1` steps return
    /// `graph.len() - 1` distinct nodes.
    ///
    /// Node 0 is the local node and node 1 is the only node inserted as a bootstrap node; all
    /// other nodes are only learned through the neighbour lists fed back as `previous`.
    #[tokio::test]
    async fn explore_full_graph() {
        // Adjacency list: node -> neighbours reported after that node has been visited.
        let graph = HashMap::from([
            (0, vec![]),
            (1, vec![2, 3]),
            (2, vec![4]),
            (3, vec![5]),
            (4, vec![]),
            (5, vec![6, 7, 8]),
            (6, vec![]),
            (7, vec![8]),
            (8, vec![]),
        ]);
        // The store and the walker each get a copy of the same seeded RNG.
        let rng = ChaCha20Rng::from_seed([1; 32]);
        let store = TestStore::new(rng.clone());
        store.insert_node_info(TestInfo::new(0)).await.unwrap();
        store
            .insert_node_info(TestInfo::new_bootstrap(1))
            .await
            .unwrap();
        let strategy = RandomWalker::new(0, store, rng);
        let mut visited: HashSet<TestId> = HashSet::new();
        let mut previous: Option<DiscoveryResult<TestId, TestInfo>> = None;
        for _ in 0..graph.len() - 1 {
            let id = strategy
                .next_node(previous.as_ref())
                .await
                .unwrap()
                .expect("should return a Some value");
            visited.insert(id);
            // Report the visited node's neighbours as the result of this step.
            previous = Some(DiscoveryResult::from_neighbors(id, graph.get(&id).unwrap()));
        }
        assert_eq!(visited.len(), graph.len() - 1);
    }

    /// With random resets disabled, the walker must return every non-local node exactly once
    /// before it starts a new walk; the step after that returns an already visited node.
    #[tokio::test]
    async fn mark_nodes_as_visited() {
        const NUM_NODES: usize = 32;
        let rng = ChaCha20Rng::from_seed([1; 32]);
        let store = TestStore::new(rng.clone());
        for id in 0..NUM_NODES {
            store.insert_node_info(TestInfo::new(id)).await.unwrap();
        }
        let strategy = RandomWalker::from_config(
            0,
            store,
            rng,
            RandomWalkerConfig {
                reset_walk_probability: 0.0,
            },
        );
        let mut visited: HashSet<TestId> = HashSet::new();
        let mut previous: Option<DiscoveryResult<TestId, TestInfo>> = None;
        // NUM_NODES - 1 steps: every node except the local node 0.
        for _ in 0..NUM_NODES - 1 {
            let id = strategy
                .next_node(previous.as_ref())
                .await
                .unwrap()
                .expect("should return a Some value");
            if !visited.insert(id) {
                panic!("should never return duplicates");
            }
            previous = Some(DiscoveryResult::new(id));
        }
        // The unvisited set is now empty, so this step starts a new walk.
        let id = strategy
            .next_node(previous.as_ref())
            .await
            .unwrap()
            .expect("should return a Some value");
        assert!(visited.contains(&id));
    }

    /// When the only node (a bootstrap node) in the store is the local node, the walker has
    /// nothing to return.
    #[tokio::test]
    async fn never_yield_own_node_info() {
        let rng = ChaCha20Rng::from_seed([1; 32]);
        let store = TestStore::new(rng.clone());
        store
            .insert_node_info(TestInfo::new_bootstrap(0))
            .await
            .unwrap();
        let strategy = RandomWalker::new(0, store, rng);
        assert!(strategy.next_node(None).await.unwrap().is_none());
    }
}