use super::PeerDiscoverySettings;
use crate::api::peer_discovery_internals::get_peers_all;
use crate::error::PeerDiscoveryError;
use crate::{api::node::get_info, NodeConf, PeerInfo};
use async_trait::async_trait;
use bounded_integer::BoundedU16;
use bounded_vec::NonEmptyVec;
use ergo_chain_types::PeerAddr;
use std::fmt::Debug;
use std::{collections::HashSet, time::Duration};
use url::Url;
/// Runs peer discovery starting from `seeds` and returns the active peers found (seeds excluded).
///
/// Builds the discovery settings and the platform-specific channels, then hands everything to
/// `peer_discovery_impl`:
/// - the per-node request timeout is fixed at 4 seconds;
/// - both channel buffers are sized to `max_parallel_tasks`;
/// - channels are tokio mpsc channels on native targets and `futures` mpsc channels on wasm32.
pub(crate) async fn peer_discovery_inner(
    seeds: NonEmptyVec<Url>,
    max_parallel_tasks: BoundedU16<1, { u16::MAX }>,
    timeout: Duration,
) -> Result<Vec<Url>, PeerDiscoveryError> {
    let buffer_len = max_parallel_tasks.get() as usize;
    let settings = PeerDiscoverySettings {
        max_parallel_tasks,
        task_2_buffer_length: buffer_len,
        global_timeout: timeout,
        timeout_of_individual_node_request: Duration::from_secs(4),
    };

    // Native: tokio channels, with receivers adapted into `Stream`s.
    #[cfg(not(target_arch = "wasm32"))]
    let (tx_msg, msg_stream, tx_url, url_stream) = {
        use tokio::sync::mpsc::channel;
        use tokio_stream::wrappers::ReceiverStream;
        let (msg_sender, msg_receiver) = channel::<Msg>(settings.task_2_buffer_length);
        let (url_sender, url_receiver) = channel::<Url>(settings.task_2_buffer_length);
        (
            msg_sender,
            ReceiverStream::new(msg_receiver),
            url_sender,
            ReceiverStream::new(url_receiver),
        )
    };

    // wasm32: `futures` channels, whose receivers are already `Stream`s.
    #[cfg(target_arch = "wasm32")]
    let (tx_msg, msg_stream, tx_url, url_stream) = {
        use futures::channel::mpsc::channel;
        let (msg_sender, msg_receiver) = channel::<Msg>(settings.task_2_buffer_length);
        let (url_sender, url_receiver) = channel::<Url>(settings.task_2_buffer_length);
        (msg_sender, msg_receiver, url_sender, url_receiver)
    };

    peer_discovery_impl(seeds, tx_msg, msg_stream, tx_url, url_stream, settings).await
}
async fn peer_discovery_impl<
SendMsg: 'static + ChannelInfallibleSender<Msg> + Clone + Send + Sync,
SendUrl: 'static + ChannelInfallibleSender<Url> + ChannelTrySender<Url> + Clone + Send + Sync,
>(
seeds: NonEmptyVec<Url>,
tx_msg: SendMsg,
msg_stream: impl futures::Stream<Item = Msg> + Send + 'static,
mut tx_url: SendUrl,
url_stream: impl futures::Stream<Item = Url> + Send + 'static,
settings: PeerDiscoverySettings,
) -> Result<Vec<Url>, PeerDiscoveryError> {
use futures::future::FutureExt;
use futures::StreamExt;
let mut seeds_set: HashSet<Url> = HashSet::new();
for mut seed_url in seeds {
#[allow(clippy::unwrap_used)]
seed_url.set_port(None).unwrap();
seeds_set.insert(seed_url);
}
spawn_http_request_task(
tx_msg,
url_stream,
settings.max_parallel_tasks,
settings.timeout_of_individual_node_request,
);
for url in &seeds_set {
tx_url.infallible_send(url.clone()).await;
}
let mut count = seeds_set.len();
let mut visited_active_peers = HashSet::new();
let mut visited_peers = HashSet::new();
let mut peer_stack: Vec<PeerInfo> = vec![];
#[cfg(target_arch = "wasm32")]
let rx_timeout_signal = {
let (tx, rx) = futures::channel::oneshot::channel::<()>();
wasm_bindgen_futures::spawn_local(async move {
let _ = crate::wasm_timer::Delay::new(settings.global_timeout).await;
let _ = tx.send(());
});
rx.into_stream()
};
#[cfg(not(target_arch = "wasm32"))]
let rx_timeout_signal = {
let (tx, rx) = tokio::sync::oneshot::channel::<()>();
tokio::spawn(async move {
tokio::time::sleep(settings.global_timeout).await;
let _ = tx.send(());
});
rx.into_stream()
};
enum C {
RxMsg(Msg),
RxTimeoutSignal,
}
type CombinedStream = std::pin::Pin<Box<dyn futures::stream::Stream<Item = C> + Send>>;
let streams: Vec<CombinedStream> = vec![
msg_stream.map(C::RxMsg).boxed(),
rx_timeout_signal.map(|_| C::RxTimeoutSignal).boxed(),
];
let mut combined_stream = futures::stream::select_all(streams);
let mut add_peers = true;
'loop_: while let Some(n) = combined_stream.next().await {
match n {
C::RxMsg(p) => {
while let Some(peer) = peer_stack.pop() {
let mut url = peer.addr.as_http_url();
#[allow(clippy::unwrap_used)]
url.set_port(None).unwrap();
if !visited_peers.contains(&url) {
match tx_url.try_send(url.clone()) {
Ok(_) => {
visited_peers.insert(url);
count += 1;
}
Err(TrySendError::Full) => {
peer_stack.push(peer);
break;
}
Err(TrySendError::Closed) => {
return Err(PeerDiscoveryError::MpscSender);
}
}
}
}
match p {
Msg::AddActiveNode(mut url) => {
#[allow(clippy::unwrap_used)]
url.set_port(None).unwrap();
visited_active_peers.insert(url.clone());
visited_peers.insert(url);
count -= 1;
if count == 0 {
break 'loop_;
}
}
Msg::AddInactiveNode(mut url) => {
#[allow(clippy::unwrap_used)]
url.set_port(None).unwrap();
visited_peers.insert(url);
count -= 1;
if count == 0 {
break 'loop_;
}
}
Msg::CheckPeers(mut peers) => {
use rand::seq::SliceRandom;
use rand::thread_rng;
peers.shuffle(&mut thread_rng());
if add_peers {
peer_stack.extend(peers);
}
}
}
}
C::RxTimeoutSignal => {
add_peers = false;
peer_stack.clear();
}
}
}
drop(tx_url);
let coll: Vec<_> = visited_active_peers
.difference(&seeds_set)
.cloned()
.collect();
Ok(coll)
}
/// Spawns a background task that consumes `url_stream` and probes every node it receives.
///
/// For each URL:
/// - The port is set to 9053 and the node is queried with `get_info`.
/// - If that succeeds, the node's peer list is fetched with `get_peers_all`.
/// - Each request uses `request_timeout_duration` as its `NodeConf::timeout`.
///
/// The outcome is reported on `tx_peer`:
/// - success: `Msg::CheckPeers(peers)` followed by `Msg::AddActiveNode(url)`;
/// - any failure, including a URL whose socket address cannot be determined:
///   `Msg::AddInactiveNode(url)`.
///
/// Exactly one `AddActiveNode`/`AddInactiveNode` is sent per URL. `peer_discovery_impl` counts
/// these to know when all requests have finished, so a request must never end without
/// reporting. A panic here would make discovery wait forever.
///
/// NOTE(review): each future given to `buffer_unordered` only spawns the request and returns
/// immediately. `max_parallel_requests` therefore does not cap the number of concurrent
/// requests. Confirm whether this is intended before changing it: the current shape also keeps
/// `url_stream` drained, which the caller's blocking seed sends rely on.
fn spawn_http_request_task<
    SendMsg: ChannelInfallibleSender<Msg> + Clone + Send + Sync + 'static,
>(
    tx_peer: SendMsg,
    url_stream: impl futures::Stream<Item = Url> + Send + 'static,
    max_parallel_requests: BoundedU16<1, { u16::MAX }>,
    request_timeout_duration: Duration,
) {
    use futures::StreamExt;

    #[cfg(not(target_arch = "wasm32"))]
    let spawn_fn = tokio::spawn;
    #[cfg(target_arch = "wasm32")]
    let spawn_fn = wasm_bindgen_futures::spawn_local;

    let mapped_stream = url_stream
        .map(move |mut url| {
            let mut tx_peer = tx_peer.clone();
            async move {
                let _handle = spawn_fn(async move {
                    // Work out the node's socket address on port 9053. Failure (bad URL,
                    // resolution error, or no addresses) is reported as an inactive node rather
                    // than panicking.
                    // NOTE(review): for domain hosts this presumably does a blocking DNS
                    // lookup; confirm.
                    let resolved = if url.set_port(Some(9053)).is_ok() {
                        url.socket_addrs(|| Some(9053))
                            .ok()
                            .and_then(|addrs| addrs.into_iter().next())
                    } else {
                        None
                    };
                    let socket_addr = match resolved {
                        Some(addr) => addr,
                        None => {
                            tx_peer.infallible_send(Msg::AddInactiveNode(url)).await;
                            return;
                        }
                    };

                    let node_conf = NodeConf {
                        addr: PeerAddr(socket_addr),
                        api_key: None,
                        timeout: Some(request_timeout_duration),
                    };

                    // Only nodes that answer both `get_info` and `get_peers_all` count as active.
                    let peers = match get_info(node_conf).await {
                        Ok(_) => get_peers_all(node_conf).await.ok(),
                        Err(_) => None,
                    };
                    match peers {
                        Some(peers) => {
                            // Peers are sent before the node is marked active; the discovery
                            // loop queues them before it decrements its outstanding count.
                            tx_peer.infallible_send(Msg::CheckPeers(peers)).await;
                            tx_peer.infallible_send(Msg::AddActiveNode(url)).await;
                        }
                        None => tx_peer.infallible_send(Msg::AddInactiveNode(url)).await,
                    }
                });
            }
        })
        .buffer_unordered(max_parallel_requests.get() as usize);

    #[cfg(not(target_arch = "wasm32"))]
    let spawn_fn_new = tokio::spawn;
    #[cfg(target_arch = "wasm32")]
    let spawn_fn_new = wasm_bindgen_futures::spawn_local;

    // Drive the stream until `url_stream` ends, i.e. until the URL sender is dropped.
    spawn_fn_new(mapped_stream.for_each(|_| async move {}));
}
/// Messages sent from the HTTP request task back to the discovery loop.
#[derive(Debug)]
pub(crate) enum Msg {
    /// The node at this URL answered both `get_info` and `get_peers_all`.
    AddActiveNode(Url),
    /// The node at this URL could not be queried successfully.
    AddInactiveNode(Url),
    /// Peers reported by a node that is about to be marked active; candidates for crawling.
    CheckPeers(Vec<PeerInfo>),
}
/// Channel sender whose `send` never reports failure to the caller.
///
/// Implementations wait for capacity and discard the value if the receiving side is gone.
#[async_trait]
trait ChannelInfallibleSender<T> {
    /// Sends `item`, ignoring a closed channel.
    async fn infallible_send(&mut self, item: T);
}
/// Native implementation backed by a tokio bounded mpsc sender.
#[cfg(not(target_arch = "wasm32"))]
#[async_trait]
impl<T: Debug + Send> ChannelInfallibleSender<T> for tokio::sync::mpsc::Sender<T> {
    async fn infallible_send(&mut self, value: T) {
        // `send` only fails once the receiver has been dropped; the value is then discarded
        // on purpose.
        let _receiver_gone = self.send(value).await.is_err();
    }
}
/// wasm32 implementation backed by a `futures` mpsc sender, used through `SinkExt::send`.
#[cfg(target_arch = "wasm32")]
#[async_trait]
impl<T: Debug + Send> ChannelInfallibleSender<T> for futures::channel::mpsc::Sender<T> {
    async fn infallible_send(&mut self, value: T) {
        use futures::sink::SinkExt;
        // A send error here means the channel is closed; the value is discarded on purpose.
        let _receiver_gone = self.send(value).await.is_err();
    }
}
/// Non-blocking send that distinguishes a full channel from a closed one.
trait ChannelTrySender<T> {
    /// Attempts to enqueue `item` without waiting.
    ///
    /// On failure the item is dropped, and the error says whether the channel was full or
    /// closed.
    fn try_send(&mut self, item: T) -> Result<(), TrySendError>;
}
/// Failure reason of [`ChannelTrySender::try_send`], with the rejected value dropped.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TrySendError {
    /// The channel's buffer is full; retrying later may succeed.
    Full,
    /// The receiving side is gone; further sends will not succeed.
    Closed,
}
/// Native implementation: maps tokio's `TrySendError` onto the local [`TrySendError`].
#[cfg(not(target_arch = "wasm32"))]
impl<T> ChannelTrySender<T> for tokio::sync::mpsc::Sender<T> {
    fn try_send(&mut self, value: T) -> Result<(), TrySendError> {
        use tokio::sync::mpsc::error::TrySendError as TokioTrySendError;
        // Call tokio's inherent `try_send` explicitly to avoid recursing into this trait method.
        tokio::sync::mpsc::Sender::try_send(self, value).map_err(|err| match err {
            TokioTrySendError::Full(_) => TrySendError::Full,
            TokioTrySendError::Closed(_) => TrySendError::Closed,
        })
    }
}
/// wasm32 implementation: maps `futures`' `TrySendError` onto the local [`TrySendError`].
#[cfg(target_arch = "wasm32")]
impl<T> ChannelTrySender<T> for futures::channel::mpsc::Sender<T> {
    fn try_send(&mut self, value: T) -> Result<(), TrySendError> {
        // Call the inherent `try_send` explicitly to avoid recursing into this trait method.
        futures::channel::mpsc::Sender::try_send(self, value).map_err(|err| {
            // Any failure other than "full" is treated as a closed channel.
            if err.is_full() {
                TrySendError::Full
            } else {
                TrySendError::Closed
            }
        })
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// Asks a hard-coded public node for its peer list and expects it to be non-empty.
    ///
    /// NOTE(review): needs network access to 213.239.193.208:9053, so it is inherently flaky
    /// offline.
    #[test]
    fn test_get_peers_all() {
        let node_conf = NodeConf {
            addr: PeerAddr::from_str("213.239.193.208:9053").unwrap(),
            api_key: None,
            timeout: Some(Duration::from_secs(5)),
        };
        let peers = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap()
            .block_on(async { get_peers_all(node_conf).await.unwrap() });
        assert!(!peers.is_empty())
    }
}