1use std::net::SocketAddr;
2
3use thiserror::Error;
4use tokio::sync::mpsc;
5use tracing::warn;
6
7use crate::config::StandardPeerConfig;
8
9use super::{
10 BasicSpawner, PeerId, PeerRemovalReason, PeerRemovedEvent, SpawnAction, SpawnEvent, SpawnerId,
11};
12
13pub struct StandardSpawner {
14 id: SpawnerId,
15 config: StandardPeerConfig,
16 network_wait_period: std::time::Duration,
17 resolved: Option<SocketAddr>,
18}
19
20#[derive(Error, Debug)]
21pub enum StandardSpawnError {
22 #[error("Channel send error: {0}")]
23 SendError(#[from] mpsc::error::SendError<SpawnEvent>),
24}
25
26impl StandardSpawner {
27 pub fn new(
28 config: StandardPeerConfig,
29 network_wait_period: std::time::Duration,
30 ) -> StandardSpawner {
31 StandardSpawner {
32 id: Default::default(),
33 config,
34 network_wait_period,
35 resolved: None,
36 }
37 }
38
39 async fn do_resolve(&mut self, force_resolve: bool) -> SocketAddr {
40 if let (false, Some(addr)) = (force_resolve, self.resolved) {
41 addr
42 } else {
43 let addr = loop {
44 match self.config.addr.lookup_host().await {
45 Ok(mut addresses) => match addresses.next() {
46 None => {
47 warn!("Could not resolve peer address, retrying");
48 tokio::time::sleep(self.network_wait_period).await
49 }
50 Some(first) => {
51 break first;
52 }
53 },
54 Err(e) => {
55 warn!(error = ?e, "error while resolving peer address, retrying");
56 tokio::time::sleep(self.network_wait_period).await
57 }
58 }
59 };
60 self.resolved = Some(addr);
61 addr
62 }
63 }
64
65 async fn spawn(
66 &mut self,
67 action_tx: &mpsc::Sender<SpawnEvent>,
68 ) -> Result<(), StandardSpawnError> {
69 let addr = self.do_resolve(false).await;
70 action_tx
71 .send(SpawnEvent::new(
72 self.id,
73 SpawnAction::create(PeerId::new(), addr, self.config.addr.clone(), None),
74 ))
75 .await?;
76 Ok(())
77 }
78}
79
80#[async_trait::async_trait]
81impl BasicSpawner for StandardSpawner {
82 type Error = StandardSpawnError;
83
84 async fn handle_init(
85 &mut self,
86 action_tx: &mpsc::Sender<SpawnEvent>,
87 ) -> Result<(), StandardSpawnError> {
88 self.spawn(action_tx).await
89 }
90
91 async fn handle_peer_removed(
92 &mut self,
93 removed_peer: PeerRemovedEvent,
94 action_tx: &mpsc::Sender<SpawnEvent>,
95 ) -> Result<(), StandardSpawnError> {
96 if removed_peer.reason == PeerRemovalReason::Unreachable {
97 self.resolved = None;
99 }
100 if removed_peer.reason != PeerRemovalReason::Demobilized {
101 self.spawn(action_tx).await
102 } else {
103 Ok(())
104 }
105 }
106
107 fn get_id(&self) -> SpawnerId {
108 self.id
109 }
110
111 fn get_addr_description(&self) -> String {
112 self.config.addr.to_string()
113 }
114
115 fn get_description(&self) -> &str {
116 "standard"
117 }
118}
119
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::sync::mpsc::{self, error::TryRecvError};

    use crate::{
        config::{NormalizedAddress, StandardPeerConfig},
        spawn::{
            standard::StandardSpawner, tests::get_create_params, PeerRemovalReason, Spawner,
            SystemEvent,
        },
        system::{MESSAGE_BUFFER_SIZE, NETWORK_WAIT_PERIOD},
    };

    /// Pause that gives the spawned task time to act before the channel is polled.
    const SETTLE: Duration = Duration::from_millis(10);

    /// After start-up exactly one create action is emitted, carrying the
    /// spawner's id and the hardcoded address.
    #[tokio::test]
    async fn creates_a_peer() {
        let addr = NormalizedAddress::with_hardcoded_dns(
            "example.com",
            123,
            vec!["127.0.0.1:123".parse().unwrap()],
        );
        let spawner = StandardSpawner::new(StandardPeerConfig { addr }, NETWORK_WAIT_PERIOD);
        let expected_id = spawner.get_id();
        let (action_tx, mut action_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
        let (_notify_tx, notify_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);

        tokio::spawn(async move { spawner.run(action_tx, notify_rx).await });
        tokio::time::sleep(SETTLE).await;

        let event = action_rx.try_recv().unwrap();
        assert_eq!(event.id, expected_id);
        let created = get_create_params(event);
        assert_eq!(created.addr.to_string(), "127.0.0.1:123");

        // Nothing else may have been queued after the first spawn.
        tokio::time::sleep(SETTLE).await;
        assert_eq!(action_rx.try_recv().unwrap_err(), TryRecvError::Empty);
    }

    /// A peer removed for a network issue is replaced by a new create action
    /// for the same address.
    #[tokio::test]
    async fn recreates_a_peer() {
        let addr = NormalizedAddress::with_hardcoded_dns(
            "example.com",
            123,
            vec!["127.0.0.1:123".parse().unwrap()],
        );
        let spawner = StandardSpawner::new(StandardPeerConfig { addr }, NETWORK_WAIT_PERIOD);
        let (action_tx, mut action_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
        let (notify_tx, notify_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);

        tokio::spawn(async move { spawner.run(action_tx, notify_rx).await });
        tokio::time::sleep(SETTLE).await;
        let first = get_create_params(action_rx.try_recv().unwrap());

        let removal = SystemEvent::peer_removed(first.id, PeerRemovalReason::NetworkIssue);
        notify_tx.send(removal).await.unwrap();
        tokio::time::sleep(SETTLE).await;

        let second = get_create_params(action_rx.try_recv().unwrap());
        assert_eq!(second.addr.to_string(), "127.0.0.1:123");
    }

    /// Repeated `Unreachable` removals trigger fresh lookups, which should
    /// eventually produce an address other than the initial one.
    #[tokio::test]
    async fn reresolves_on_unreachable() {
        let address_strings = ["127.0.0.1:123", "127.0.0.2:123", "127.0.0.3:123"];
        let addresses = address_strings.map(|addr| addr.parse().unwrap());

        let addr =
            NormalizedAddress::with_hardcoded_dns("europe.pool.ntp.org", 123, addresses.to_vec());
        let spawner = StandardSpawner::new(StandardPeerConfig { addr }, NETWORK_WAIT_PERIOD);
        let (action_tx, mut action_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
        let (notify_tx, notify_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);

        tokio::spawn(async move { spawner.run(action_tx, notify_rx).await });
        let first = get_create_params(action_rx.recv().await.unwrap());
        let initial_addr = first.addr;

        // NOTE(review): this relies on the hardcoded resolver not always
        // yielding the same first address across lookups -- confirm in
        // `NormalizedAddress::with_hardcoded_dns`.
        let mut seen_addresses = vec![];
        for _ in 0..5 {
            let removal = SystemEvent::peer_removed(first.id, PeerRemovalReason::Unreachable);
            notify_tx.send(removal).await.unwrap();
            let respawned = get_create_params(action_rx.recv().await.unwrap());
            seen_addresses.push(respawned.addr);
        }
        let seen_addresses = seen_addresses;

        for addr in &seen_addresses {
            assert!(
                addresses.contains(addr),
                "{:?} should have been drawn from {:?}",
                addr,
                addresses
            );
        }

        assert!(
            seen_addresses.iter().any(|seen| seen != &initial_addr),
            "Re-resolved\n\n\t{:?}\n\n should contain at least one address that isn't the original\n\n\t{:?}",
            seen_addresses,
            initial_addr,
        );
    }

    /// With an address that never resolves, the spawner keeps running and
    /// emits no action at all.
    #[tokio::test]
    async fn works_if_address_does_not_resolve() {
        let addr = NormalizedAddress::with_hardcoded_dns("does.not.resolve", 123, vec![]);
        let spawner = StandardSpawner::new(StandardPeerConfig { addr }, NETWORK_WAIT_PERIOD);
        let (action_tx, mut action_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
        let (_notify_tx, notify_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
        tokio::spawn(async move { spawner.run(action_tx, notify_rx).await });

        tokio::time::sleep(Duration::from_millis(1000)).await;
        assert_eq!(action_rx.try_recv().unwrap_err(), TryRecvError::Empty);
    }
}