1use std::net::SocketAddr;
2
3use thiserror::Error;
4use tokio::sync::mpsc;
5use tracing::warn;
6
7use crate::{config::PoolPeerConfig, spawn::SpawnAction};
8
9use super::{BasicSpawner, PeerId, PeerRemovedEvent, SpawnEvent, SpawnerId};
10
11struct PoolPeer {
12 id: PeerId,
13 addr: SocketAddr,
14}
15
16pub struct PoolSpawner {
17 config: PoolPeerConfig,
18 network_wait_period: std::time::Duration,
19 id: SpawnerId,
20 current_peers: Vec<PoolPeer>,
21 known_ips: Vec<SocketAddr>,
22}
23
24#[derive(Error, Debug)]
25pub enum PoolSpawnError {}
26
27impl PoolSpawner {
28 pub fn new(config: PoolPeerConfig, network_wait_period: std::time::Duration) -> PoolSpawner {
29 PoolSpawner {
30 config,
31 network_wait_period,
32 id: Default::default(),
33 current_peers: Default::default(),
34 known_ips: Default::default(),
35 }
36 }
37
38 pub async fn fill_pool(
39 &mut self,
40 action_tx: &mpsc::Sender<SpawnEvent>,
41 ) -> Result<(), PoolSpawnError> {
42 let mut wait_period = self.network_wait_period;
43
44 if self.current_peers.len() >= self.config.max_peers {
46 return Ok(());
47 }
48
49 loop {
50 if self.known_ips.len() < self.config.max_peers - self.current_peers.len() {
51 match self.config.addr.lookup_host().await {
52 Ok(addresses) => {
53 self.known_ips.append(&mut addresses.collect());
55 self.known_ips
57 .retain(|ip| !self.current_peers.iter().any(|p| p.addr == *ip))
58 }
59 Err(e) => {
60 warn!(error = ?e, "error while resolving peer address, retrying");
61 tokio::time::sleep(wait_period).await;
62 continue;
63 }
64 }
65 }
66
67 while self.current_peers.len() < self.config.max_peers {
69 if let Some(addr) = self.known_ips.pop() {
70 let id = PeerId::new();
71 self.current_peers.push(PoolPeer { id, addr });
72 let action = SpawnAction::create(id, addr, self.config.addr.clone(), None);
73 tracing::debug!(?action, "intending to spawn new pool peer at");
74
75 action_tx
76 .send(SpawnEvent::new(self.id, action))
77 .await
78 .expect("Channel was no longer connected");
79 } else {
80 break;
81 }
82 }
83
84 let wait_period_max = if cfg!(test) {
85 std::time::Duration::default()
86 } else {
87 std::time::Duration::from_secs(60)
88 };
89
90 wait_period = Ord::min(2 * wait_period, wait_period_max);
91 let peers_needed = self.config.max_peers - self.current_peers.len();
92 if peers_needed > 0 {
93 warn!(peers_needed, "could not fully fill pool");
94 tokio::time::sleep(wait_period).await;
95 } else {
96 return Ok(());
97 }
98 }
99 }
100}
101
102#[async_trait::async_trait]
103impl BasicSpawner for PoolSpawner {
104 type Error = PoolSpawnError;
105
106 async fn handle_init(
107 &mut self,
108 action_tx: &mpsc::Sender<SpawnEvent>,
109 ) -> Result<(), PoolSpawnError> {
110 self.fill_pool(action_tx).await?;
111 Ok(())
112 }
113
114 async fn handle_peer_removed(
115 &mut self,
116 removed_peer: PeerRemovedEvent,
117 action_tx: &mpsc::Sender<SpawnEvent>,
118 ) -> Result<(), PoolSpawnError> {
119 self.current_peers.retain(|p| p.id != removed_peer.id);
120 self.fill_pool(action_tx).await?;
121 Ok(())
122 }
123
124 fn get_id(&self) -> SpawnerId {
125 self.id
126 }
127
128 fn get_addr_description(&self) -> String {
129 format!("{} ({})", self.config.addr, self.config.max_peers)
130 }
131
132 fn get_description(&self) -> &str {
133 "pool"
134 }
135}
136
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::sync::mpsc::{self, error::TryRecvError};

    use crate::{
        config::{NormalizedAddress, PoolPeerConfig},
        spawn::{
            pool::PoolSpawner, tests::get_create_params, PeerRemovalReason, Spawner, SystemEvent,
        },
        system::{MESSAGE_BUFFER_SIZE, NETWORK_WAIT_PERIOD},
    };

    // These tests run the spawner via `Spawner::run` on a background task and
    // observe it by sleeping briefly and then polling the action channel with
    // `try_recv`. They therefore assume the background task makes progress within
    // each sleep.

    /// With three hardcoded addresses and `max_peers: 2`, the pool emits exactly
    /// two spawn actions, for two distinct addresses from that set, and no third.
    #[tokio::test]
    async fn creates_multiple_peers() {
        let address_strings = ["127.0.0.1:123", "127.0.0.2:123", "127.0.0.3:123"];
        let addresses = address_strings.map(|addr| addr.parse().unwrap());

        let pool = PoolSpawner::new(
            PoolPeerConfig {
                addr: NormalizedAddress::with_hardcoded_dns("example.com", 123, addresses.to_vec()),
                max_peers: 2,
            },
            NETWORK_WAIT_PERIOD,
        );
        let spawner_id = pool.get_id();
        let (action_tx, mut action_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
        // The notify sender is kept alive (but unused) for the whole test.
        let (_notify_tx, notify_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
        tokio::spawn(async move { pool.run(action_tx, notify_rx).await });
        tokio::time::sleep(Duration::from_millis(10)).await;
        let res = action_rx.try_recv().unwrap();
        assert_eq!(spawner_id, res.id);
        let params = get_create_params(res);
        let addr1 = params.addr;

        tokio::time::sleep(Duration::from_millis(10)).await;
        let res = action_rx.try_recv().unwrap();
        assert_eq!(spawner_id, res.id);
        let params = get_create_params(res);
        let addr2 = params.addr;

        assert_ne!(addr1, addr2);
        assert!(addresses.contains(&addr1));
        assert!(addresses.contains(&addr2));

        // The pool is full, so no further spawn action may be pending.
        tokio::time::sleep(Duration::from_millis(10)).await;
        let res = action_rx.try_recv().unwrap_err();
        assert_eq!(res, TryRecvError::Empty);
    }

    /// After one of the two spawned peers is reported as removed, the pool spawns
    /// a replacement at an address different from both previously used ones.
    #[tokio::test]
    async fn refills_peers_upto_limit() {
        let address_strings = ["127.0.0.1:123", "127.0.0.2:123", "127.0.0.3:123"];
        let addresses = address_strings.map(|addr| addr.parse().unwrap());

        let pool = PoolSpawner::new(
            PoolPeerConfig {
                addr: NormalizedAddress::with_hardcoded_dns("example.com", 123, addresses.to_vec()),
                max_peers: 2,
            },
            NETWORK_WAIT_PERIOD,
        );
        let (action_tx, mut action_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
        let (notify_tx, notify_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
        tokio::spawn(async move { pool.run(action_tx, notify_rx).await });
        tokio::time::sleep(Duration::from_millis(10)).await;
        let res = action_rx.try_recv().unwrap();
        let params = get_create_params(res);
        let addr1 = params.addr;
        tokio::time::sleep(Duration::from_millis(10)).await;
        let res = action_rx.try_recv().unwrap();
        let params = get_create_params(res);
        let addr2 = params.addr;
        tokio::time::sleep(Duration::from_millis(10)).await;

        // Remove the second spawned peer (the one at `addr2`).
        notify_tx
            .send(SystemEvent::peer_removed(
                params.id,
                PeerRemovalReason::NetworkIssue,
            ))
            .await
            .unwrap();
        tokio::time::sleep(Duration::from_millis(10)).await;

        let res = action_rx.try_recv().unwrap();
        let params = get_create_params(res);
        let addr3 = params.addr;

        assert_ne!(addr1, addr2);
        assert_ne!(addr2, addr3);
        assert_ne!(addr3, addr1);

        assert!(addresses.contains(&addr3));
    }

    /// With a hardcoded, empty address list for the pool name, no spawn action is
    /// emitted within one second.
    // NOTE(review): whether `lookup_host` yields an error or an empty result for an
    // empty hardcoded list is not visible here — confirm in `NormalizedAddress`.
    #[tokio::test]
    async fn works_if_address_does_not_resolve() {
        let pool = PoolSpawner::new(
            PoolPeerConfig {
                addr: NormalizedAddress::with_hardcoded_dns("does.not.resolve", 123, vec![]),
                max_peers: 2,
            },
            NETWORK_WAIT_PERIOD,
        );
        let (action_tx, mut action_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
        let (_notify_tx, notify_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
        tokio::spawn(async move { pool.run(action_tx, notify_rx).await });
        tokio::time::sleep(Duration::from_millis(1000)).await;
        let res = action_rx.try_recv().unwrap_err();
        assert_eq!(res, TryRecvError::Empty);
    }
}