use std::net::SocketAddr;
use thiserror::Error;
use tokio::sync::mpsc;
use tracing::warn;
use crate::{config::PoolPeerConfig, spawn::SpawnAction};
use super::{BasicSpawner, PeerId, PeerRemovedEvent, SpawnEvent, SpawnerId};
struct PoolPeer {
id: PeerId,
addr: SocketAddr,
}
pub struct PoolSpawner {
config: PoolPeerConfig,
network_wait_period: std::time::Duration,
id: SpawnerId,
current_peers: Vec<PoolPeer>,
known_ips: Vec<SocketAddr>,
}
#[derive(Error, Debug)]
pub enum PoolSpawnError {}
impl PoolSpawner {
pub fn new(config: PoolPeerConfig, network_wait_period: std::time::Duration) -> PoolSpawner {
PoolSpawner {
config,
network_wait_period,
id: Default::default(),
current_peers: Default::default(),
known_ips: Default::default(),
}
}
pub async fn fill_pool(
&mut self,
action_tx: &mpsc::Sender<SpawnEvent>,
) -> Result<(), PoolSpawnError> {
let mut wait_period = self.network_wait_period;
if self.current_peers.len() >= self.config.max_peers {
return Ok(());
}
loop {
if self.known_ips.len() < self.config.max_peers - self.current_peers.len() {
match self.config.addr.lookup_host().await {
Ok(addresses) => {
self.known_ips.append(&mut addresses.collect());
self.known_ips
.retain(|ip| !self.current_peers.iter().any(|p| p.addr == *ip))
}
Err(e) => {
warn!(error = ?e, "error while resolving peer address, retrying");
tokio::time::sleep(wait_period).await;
continue;
}
}
}
while self.current_peers.len() < self.config.max_peers {
if let Some(addr) = self.known_ips.pop() {
let id = PeerId::new();
self.current_peers.push(PoolPeer { id, addr });
let action = SpawnAction::create(id, addr, self.config.addr.clone(), None);
tracing::debug!(?action, "intending to spawn new pool peer at");
action_tx
.send(SpawnEvent::new(self.id, action))
.await
.expect("Channel was no longer connected");
} else {
break;
}
}
let wait_period_max = if cfg!(test) {
std::time::Duration::default()
} else {
std::time::Duration::from_secs(60)
};
wait_period = Ord::min(2 * wait_period, wait_period_max);
let peers_needed = self.config.max_peers - self.current_peers.len();
if peers_needed > 0 {
warn!(peers_needed, "could not fully fill pool");
tokio::time::sleep(wait_period).await;
} else {
return Ok(());
}
}
}
}
#[async_trait::async_trait]
impl BasicSpawner for PoolSpawner {
type Error = PoolSpawnError;
async fn handle_init(
&mut self,
action_tx: &mpsc::Sender<SpawnEvent>,
) -> Result<(), PoolSpawnError> {
self.fill_pool(action_tx).await?;
Ok(())
}
async fn handle_peer_removed(
&mut self,
removed_peer: PeerRemovedEvent,
action_tx: &mpsc::Sender<SpawnEvent>,
) -> Result<(), PoolSpawnError> {
self.current_peers.retain(|p| p.id != removed_peer.id);
self.fill_pool(action_tx).await?;
Ok(())
}
fn get_id(&self) -> SpawnerId {
self.id
}
fn get_addr_description(&self) -> String {
format!("{} ({})", self.config.addr, self.config.max_peers)
}
fn get_description(&self) -> &str {
"pool"
}
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use tokio::sync::mpsc::{self, error::TryRecvError};
use crate::{
config::{NormalizedAddress, PoolPeerConfig},
spawn::{
pool::PoolSpawner, tests::get_create_params, PeerRemovalReason, Spawner, SystemEvent,
},
system::{MESSAGE_BUFFER_SIZE, NETWORK_WAIT_PERIOD},
};
#[tokio::test]
async fn creates_multiple_peers() {
let pool = PoolSpawner::new(
PoolPeerConfig {
addr: NormalizedAddress::with_hardcoded_dns(
"example.com",
123,
vec![
"127.0.0.1:123".parse().unwrap(),
"127.0.0.2:123".parse().unwrap(),
"127.0.0.3:123".parse().unwrap(),
],
),
max_peers: 2,
},
NETWORK_WAIT_PERIOD,
);
let spawner_id = pool.get_id();
let (action_tx, mut action_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
let (_notify_tx, notify_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
tokio::spawn(async move { pool.run(action_tx, notify_rx).await });
tokio::time::sleep(Duration::from_millis(10)).await;
let res = action_rx.try_recv().unwrap();
assert_eq!(spawner_id, res.id);
let params = get_create_params(res);
assert_eq!(params.addr.to_string(), "127.0.0.3:123");
tokio::time::sleep(Duration::from_millis(10)).await;
let res = action_rx.try_recv().unwrap();
assert_eq!(spawner_id, res.id);
let params = get_create_params(res);
assert_eq!(params.addr.to_string(), "127.0.0.2:123");
tokio::time::sleep(Duration::from_millis(10)).await;
let res = action_rx.try_recv().unwrap_err();
assert_eq!(res, TryRecvError::Empty);
}
#[tokio::test]
async fn refills_peers_upto_limit() {
let pool = PoolSpawner::new(
PoolPeerConfig {
addr: NormalizedAddress::with_hardcoded_dns(
"example.com",
123,
vec![
"127.0.0.1:123".parse().unwrap(),
"127.0.0.2:123".parse().unwrap(),
"127.0.0.3:123".parse().unwrap(),
],
),
max_peers: 2,
},
NETWORK_WAIT_PERIOD,
);
let (action_tx, mut action_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
let (notify_tx, notify_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
tokio::spawn(async move { pool.run(action_tx, notify_rx).await });
tokio::time::sleep(Duration::from_millis(10)).await;
let res = action_rx.try_recv().unwrap();
let params = get_create_params(res);
tokio::time::sleep(Duration::from_millis(10)).await;
action_rx.try_recv().unwrap();
tokio::time::sleep(Duration::from_millis(10)).await;
notify_tx
.send(SystemEvent::peer_removed(
params.id,
PeerRemovalReason::NetworkIssue,
))
.await
.unwrap();
tokio::time::sleep(Duration::from_millis(10)).await;
let res = action_rx.try_recv().unwrap();
let params = get_create_params(res);
assert_eq!(params.addr.to_string(), "127.0.0.1:123");
}
#[tokio::test]
async fn works_if_address_does_not_resolve() {
let pool = PoolSpawner::new(
PoolPeerConfig {
addr: NormalizedAddress::with_hardcoded_dns("does.not.resolve", 123, vec![]),
max_peers: 2,
},
NETWORK_WAIT_PERIOD,
);
let (action_tx, mut action_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
let (_notify_tx, notify_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
tokio::spawn(async move { pool.run(action_tx, notify_rx).await });
tokio::time::sleep(Duration::from_millis(1000)).await;
let res = action_rx.try_recv().unwrap_err();
assert_eq!(res, TryRecvError::Empty);
}
}