ntp_daemon/spawn/
standard.rs

1use std::net::SocketAddr;
2
3use thiserror::Error;
4use tokio::sync::mpsc;
5use tracing::warn;
6
7use crate::config::StandardPeerConfig;
8
9use super::{
10    BasicSpawner, PeerId, PeerRemovalReason, PeerRemovedEvent, SpawnAction, SpawnEvent, SpawnerId,
11};
12
13pub struct StandardSpawner {
14    id: SpawnerId,
15    config: StandardPeerConfig,
16    network_wait_period: std::time::Duration,
17    resolved: Option<SocketAddr>,
18}
19
20#[derive(Error, Debug)]
21pub enum StandardSpawnError {
22    #[error("Channel send error: {0}")]
23    SendError(#[from] mpsc::error::SendError<SpawnEvent>),
24}
25
26impl StandardSpawner {
27    pub fn new(
28        config: StandardPeerConfig,
29        network_wait_period: std::time::Duration,
30    ) -> StandardSpawner {
31        StandardSpawner {
32            id: Default::default(),
33            config,
34            network_wait_period,
35            resolved: None,
36        }
37    }
38
39    async fn do_resolve(&mut self, force_resolve: bool) -> SocketAddr {
40        if let (false, Some(addr)) = (force_resolve, self.resolved) {
41            addr
42        } else {
43            let addr = loop {
44                match self.config.addr.lookup_host().await {
45                    Ok(mut addresses) => match addresses.next() {
46                        None => {
47                            warn!("Could not resolve peer address, retrying");
48                            tokio::time::sleep(self.network_wait_period).await
49                        }
50                        Some(first) => {
51                            break first;
52                        }
53                    },
54                    Err(e) => {
55                        warn!(error = ?e, "error while resolving peer address, retrying");
56                        tokio::time::sleep(self.network_wait_period).await
57                    }
58                }
59            };
60            self.resolved = Some(addr);
61            addr
62        }
63    }
64
65    async fn spawn(
66        &mut self,
67        action_tx: &mpsc::Sender<SpawnEvent>,
68    ) -> Result<(), StandardSpawnError> {
69        let addr = self.do_resolve(false).await;
70        action_tx
71            .send(SpawnEvent::new(
72                self.id,
73                SpawnAction::create(PeerId::new(), addr, self.config.addr.clone(), None),
74            ))
75            .await?;
76        Ok(())
77    }
78}
79
80#[async_trait::async_trait]
81impl BasicSpawner for StandardSpawner {
82    type Error = StandardSpawnError;
83
84    async fn handle_init(
85        &mut self,
86        action_tx: &mpsc::Sender<SpawnEvent>,
87    ) -> Result<(), StandardSpawnError> {
88        self.spawn(action_tx).await
89    }
90
91    async fn handle_peer_removed(
92        &mut self,
93        removed_peer: PeerRemovedEvent,
94        action_tx: &mpsc::Sender<SpawnEvent>,
95    ) -> Result<(), StandardSpawnError> {
96        if removed_peer.reason == PeerRemovalReason::Unreachable {
97            // force new resolution
98            self.resolved = None;
99        }
100        if removed_peer.reason != PeerRemovalReason::Demobilized {
101            self.spawn(action_tx).await
102        } else {
103            Ok(())
104        }
105    }
106
107    fn get_id(&self) -> SpawnerId {
108        self.id
109    }
110
111    fn get_addr_description(&self) -> String {
112        self.config.addr.to_string()
113    }
114
115    fn get_description(&self) -> &str {
116        "standard"
117    }
118}
119
120#[cfg(test)]
121mod tests {
122    use std::time::Duration;
123
124    use tokio::sync::mpsc::{self, error::TryRecvError};
125
126    use crate::{
127        config::{NormalizedAddress, StandardPeerConfig},
128        spawn::{
129            standard::StandardSpawner, tests::get_create_params, PeerRemovalReason, Spawner,
130            SystemEvent,
131        },
132        system::{MESSAGE_BUFFER_SIZE, NETWORK_WAIT_PERIOD},
133    };
134
135    #[tokio::test]
136    async fn creates_a_peer() {
137        let spawner = StandardSpawner::new(
138            StandardPeerConfig {
139                addr: NormalizedAddress::with_hardcoded_dns(
140                    "example.com",
141                    123,
142                    vec!["127.0.0.1:123".parse().unwrap()],
143                ),
144            },
145            NETWORK_WAIT_PERIOD,
146        );
147        let spawner_id = spawner.get_id();
148        let (action_tx, mut action_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
149        let (_notify_tx, notify_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
150
151        tokio::spawn(async move { spawner.run(action_tx, notify_rx).await });
152        tokio::time::sleep(Duration::from_millis(10)).await;
153        let res = action_rx.try_recv().unwrap();
154        assert_eq!(res.id, spawner_id);
155        let params = get_create_params(res);
156        assert_eq!(params.addr.to_string(), "127.0.0.1:123");
157
158        // and now we should no longer receive anything
159        tokio::time::sleep(Duration::from_millis(10)).await;
160        let res = action_rx.try_recv().unwrap_err();
161        assert_eq!(res, TryRecvError::Empty);
162    }
163
164    #[tokio::test]
165    async fn recreates_a_peer() {
166        let spawner = StandardSpawner::new(
167            StandardPeerConfig {
168                addr: NormalizedAddress::with_hardcoded_dns(
169                    "example.com",
170                    123,
171                    vec!["127.0.0.1:123".parse().unwrap()],
172                ),
173            },
174            NETWORK_WAIT_PERIOD,
175        );
176        let (action_tx, mut action_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
177        let (notify_tx, notify_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
178
179        tokio::spawn(async move { spawner.run(action_tx, notify_rx).await });
180        tokio::time::sleep(Duration::from_millis(10)).await;
181        let res = action_rx.try_recv().unwrap();
182        let params = get_create_params(res);
183
184        notify_tx
185            .send(SystemEvent::peer_removed(
186                params.id,
187                PeerRemovalReason::NetworkIssue,
188            ))
189            .await
190            .unwrap();
191        tokio::time::sleep(Duration::from_millis(10)).await;
192
193        let res = action_rx.try_recv().unwrap();
194        let params = get_create_params(res);
195        assert_eq!(params.addr.to_string(), "127.0.0.1:123");
196    }
197
198    #[tokio::test]
199    async fn reresolves_on_unreachable() {
200        let address_strings = ["127.0.0.1:123", "127.0.0.2:123", "127.0.0.3:123"];
201        let addresses = address_strings.map(|addr| addr.parse().unwrap());
202
203        let spawner = StandardSpawner::new(
204            StandardPeerConfig {
205                addr: NormalizedAddress::with_hardcoded_dns(
206                    "europe.pool.ntp.org",
207                    123,
208                    addresses.to_vec(),
209                ),
210            },
211            NETWORK_WAIT_PERIOD,
212        );
213        let (action_tx, mut action_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
214        let (notify_tx, notify_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
215
216        tokio::spawn(async move { spawner.run(action_tx, notify_rx).await });
217        let res = action_rx.recv().await.unwrap();
218        let params = get_create_params(res);
219        let initial_addr = params.addr;
220
221        // We repeat multiple times and check at least one is different to be less
222        // sensitive to dns resolver giving the same pool ip.
223        let mut seen_addresses = vec![];
224        for _ in 0..5 {
225            notify_tx
226                .send(SystemEvent::peer_removed(
227                    params.id,
228                    PeerRemovalReason::Unreachable,
229                ))
230                .await
231                .unwrap();
232            let res = action_rx.recv().await.unwrap();
233            let params = get_create_params(res);
234            seen_addresses.push(params.addr);
235        }
236        let seen_addresses = seen_addresses;
237
238        for addr in seen_addresses.iter() {
239            assert!(
240                addresses.contains(addr),
241                "{:?} should have been drawn from {:?}",
242                addr,
243                addresses
244            );
245        }
246
247        assert!(
248            seen_addresses.iter().any(|seen| seen != &initial_addr),
249            "Re-resolved\n\n\t{:?}\n\n should contain at least one address that isn't the original\n\n\t{:?}",
250            seen_addresses,
251            initial_addr,
252        );
253    }
254
255    #[tokio::test]
256    async fn works_if_address_does_not_resolve() {
257        let spawner = StandardSpawner::new(
258            StandardPeerConfig {
259                addr: NormalizedAddress::with_hardcoded_dns("does.not.resolve", 123, vec![]),
260            },
261            NETWORK_WAIT_PERIOD,
262        );
263        let (action_tx, mut action_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
264        let (_notify_tx, notify_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
265        tokio::spawn(async move { spawner.run(action_tx, notify_rx).await });
266
267        tokio::time::sleep(Duration::from_millis(1000)).await;
268        let res = action_rx.try_recv().unwrap_err();
269        assert_eq!(res, TryRecvError::Empty);
270    }
271}