ntp_daemon/spawn/pool.rs

1use std::net::SocketAddr;
2
3use thiserror::Error;
4use tokio::sync::mpsc;
5use tracing::warn;
6
7use crate::{config::PoolPeerConfig, spawn::SpawnAction};
8
9use super::{BasicSpawner, PeerId, PeerRemovedEvent, SpawnEvent, SpawnerId};
10
11struct PoolPeer {
12    id: PeerId,
13    addr: SocketAddr,
14}
15
16pub struct PoolSpawner {
17    config: PoolPeerConfig,
18    network_wait_period: std::time::Duration,
19    id: SpawnerId,
20    current_peers: Vec<PoolPeer>,
21    known_ips: Vec<SocketAddr>,
22}
23
24#[derive(Error, Debug)]
25pub enum PoolSpawnError {}
26
27impl PoolSpawner {
28    pub fn new(config: PoolPeerConfig, network_wait_period: std::time::Duration) -> PoolSpawner {
29        PoolSpawner {
30            config,
31            network_wait_period,
32            id: Default::default(),
33            current_peers: Default::default(),
34            known_ips: Default::default(),
35        }
36    }
37
38    pub async fn fill_pool(
39        &mut self,
40        action_tx: &mpsc::Sender<SpawnEvent>,
41    ) -> Result<(), PoolSpawnError> {
42        let mut wait_period = self.network_wait_period;
43
44        // early return if there is nothing to do
45        if self.current_peers.len() >= self.config.max_peers {
46            return Ok(());
47        }
48
49        loop {
50            if self.known_ips.len() < self.config.max_peers - self.current_peers.len() {
51                match self.config.addr.lookup_host().await {
52                    Ok(addresses) => {
53                        // add the addresses looked up to our list of known ips
54                        self.known_ips.append(&mut addresses.collect());
55                        // remove known ips that we are already connected to
56                        self.known_ips
57                            .retain(|ip| !self.current_peers.iter().any(|p| p.addr == *ip))
58                    }
59                    Err(e) => {
60                        warn!(error = ?e, "error while resolving peer address, retrying");
61                        tokio::time::sleep(wait_period).await;
62                        continue;
63                    }
64                }
65            }
66
67            // Try and add peers to our pool
68            while self.current_peers.len() < self.config.max_peers {
69                if let Some(addr) = self.known_ips.pop() {
70                    let id = PeerId::new();
71                    self.current_peers.push(PoolPeer { id, addr });
72                    let action = SpawnAction::create(id, addr, self.config.addr.clone(), None);
73                    tracing::debug!(?action, "intending to spawn new pool peer at");
74
75                    action_tx
76                        .send(SpawnEvent::new(self.id, action))
77                        .await
78                        .expect("Channel was no longer connected");
79                } else {
80                    break;
81                }
82            }
83
84            let wait_period_max = if cfg!(test) {
85                std::time::Duration::default()
86            } else {
87                std::time::Duration::from_secs(60)
88            };
89
90            wait_period = Ord::min(2 * wait_period, wait_period_max);
91            let peers_needed = self.config.max_peers - self.current_peers.len();
92            if peers_needed > 0 {
93                warn!(peers_needed, "could not fully fill pool");
94                tokio::time::sleep(wait_period).await;
95            } else {
96                return Ok(());
97            }
98        }
99    }
100}
101
102#[async_trait::async_trait]
103impl BasicSpawner for PoolSpawner {
104    type Error = PoolSpawnError;
105
106    async fn handle_init(
107        &mut self,
108        action_tx: &mpsc::Sender<SpawnEvent>,
109    ) -> Result<(), PoolSpawnError> {
110        self.fill_pool(action_tx).await?;
111        Ok(())
112    }
113
114    async fn handle_peer_removed(
115        &mut self,
116        removed_peer: PeerRemovedEvent,
117        action_tx: &mpsc::Sender<SpawnEvent>,
118    ) -> Result<(), PoolSpawnError> {
119        self.current_peers.retain(|p| p.id != removed_peer.id);
120        self.fill_pool(action_tx).await?;
121        Ok(())
122    }
123
124    fn get_id(&self) -> SpawnerId {
125        self.id
126    }
127
128    fn get_addr_description(&self) -> String {
129        format!("{} ({})", self.config.addr, self.config.max_peers)
130    }
131
132    fn get_description(&self) -> &str {
133        "pool"
134    }
135}
136
#[cfg(test)]
mod tests {
    use std::{net::SocketAddr, time::Duration};

    use tokio::sync::mpsc::{self, error::TryRecvError};

    use crate::{
        config::{NormalizedAddress, PoolPeerConfig},
        spawn::{
            pool::PoolSpawner, tests::get_create_params, PeerRemovalReason, Spawner, SystemEvent,
        },
        system::{MESSAGE_BUFFER_SIZE, NETWORK_WAIT_PERIOD},
    };

    /// Three distinct loopback endpoints, together with an `example.com` address
    /// whose hard-coded DNS answer is exactly those endpoints.
    fn example_com() -> ([SocketAddr; 3], NormalizedAddress) {
        let ips: [SocketAddr; 3] =
            ["127.0.0.1:123", "127.0.0.2:123", "127.0.0.3:123"].map(|s| s.parse().unwrap());
        let normalized = NormalizedAddress::with_hardcoded_dns("example.com", 123, ips.to_vec());
        (ips, normalized)
    }

    /// A pool spawner for `addr` that tries to keep two peers alive.
    fn pool_of_two(addr: NormalizedAddress) -> PoolSpawner {
        let config = PoolPeerConfig { addr, max_peers: 2 };
        PoolSpawner::new(config, NETWORK_WAIT_PERIOD)
    }

    #[tokio::test]
    async fn creates_multiple_peers() {
        let (ips, normalized) = example_com();
        let pool = pool_of_two(normalized);
        let spawner_id = pool.get_id();

        let (action_tx, mut action_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
        let (_notify_tx, notify_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
        tokio::spawn(async move { pool.run(action_tx, notify_rx).await });

        // Exactly two create actions are expected, one per pool slot.
        let mut spawned = Vec::new();
        for _ in 0..2 {
            tokio::time::sleep(Duration::from_millis(10)).await;
            let event = action_rx.try_recv().unwrap();
            assert_eq!(spawner_id, event.id);
            spawned.push(get_create_params(event).addr);
        }

        assert_ne!(spawned[0], spawned[1]);
        assert!(spawned.iter().all(|addr| ips.contains(addr)));

        // The pool is full, so nothing else should be spawned.
        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(action_rx.try_recv().unwrap_err(), TryRecvError::Empty);
    }

    #[tokio::test]
    async fn refills_peers_upto_limit() {
        let (ips, normalized) = example_com();
        let pool = pool_of_two(normalized);

        let (action_tx, mut action_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
        let (notify_tx, notify_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
        tokio::spawn(async move { pool.run(action_tx, notify_rx).await });

        tokio::time::sleep(Duration::from_millis(10)).await;
        let first = get_create_params(action_rx.try_recv().unwrap()).addr;

        tokio::time::sleep(Duration::from_millis(10)).await;
        let second_params = get_create_params(action_rx.try_recv().unwrap());
        let second = second_params.addr;
        tokio::time::sleep(Duration::from_millis(10)).await;

        // Drop the second peer; the pool should replace it with the remaining address.
        let removal = SystemEvent::peer_removed(second_params.id, PeerRemovalReason::NetworkIssue);
        notify_tx.send(removal).await.unwrap();
        tokio::time::sleep(Duration::from_millis(10)).await;

        let third = get_create_params(action_rx.try_recv().unwrap()).addr;

        // no duplicates!
        assert_ne!(first, second);
        assert_ne!(second, third);
        assert_ne!(third, first);

        assert!(ips.contains(&third));
    }

    #[tokio::test]
    async fn works_if_address_does_not_resolve() {
        let unresolvable = NormalizedAddress::with_hardcoded_dns("does.not.resolve", 123, vec![]);
        let pool = pool_of_two(unresolvable);

        let (action_tx, mut action_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
        let (_notify_tx, notify_rx) = mpsc::channel(MESSAGE_BUFFER_SIZE);
        tokio::spawn(async move { pool.run(action_tx, notify_rx).await });

        // With no resolvable addresses, no peer should ever be spawned.
        tokio::time::sleep(Duration::from_millis(1000)).await;
        assert_eq!(action_rx.try_recv().unwrap_err(), TryRecvError::Empty);
    }
}