1use crate::dht::{
2 errors::DhtError,
3 processor::DhtProcessor,
4 rpc::{DhtMessageClient, DhtRequest, DhtResponse},
5 types::{DhtRecord, NetworkInfo, Peer},
6 DhtConfig, Validator,
7};
8use libp2p::{identity::Keypair, Multiaddr, PeerId};
9use noosphere_common::channel::message_channel;
10use std::time::Duration;
11use tokio;
12
/// Extracts the expected variant from a [`DhtResponse`]-like value.
///
/// `$response` is tested against `$matcher`. On a match, `$statement` is
/// evaluated and its value becomes the value of the macro. Any other value
/// yields `Err(DhtError::Error("Unexpected".into()))`.
///
/// `DhtError` is resolved where the macro is invoked, so it must be in scope
/// at the call site.
macro_rules! ensure_response {
    ($response:expr, $matcher:pat => $statement:expr) => {
        if let $matcher = $response {
            $statement
        } else {
            Err(DhtError::Error("Unexpected".into()))
        }
    };
}
21
22pub struct DhtNode {
45 config: DhtConfig,
46 client: DhtMessageClient,
47 thread_handle: tokio::task::JoinHandle<Result<(), DhtError>>,
48 peer_id: PeerId,
49}
50
51impl DhtNode {
52 pub fn new<V: Validator + 'static>(
53 keypair: Keypair,
54 config: DhtConfig,
55 validator: Option<V>,
56 ) -> Result<Self, DhtError> {
57 let peer_id = PeerId::from(keypair.public());
58 let channels = message_channel::<DhtRequest, DhtResponse, DhtError>();
59 let thread_handle =
60 DhtProcessor::spawn(&keypair, peer_id, validator, config.clone(), channels.1)?;
61
62 Ok(DhtNode {
63 peer_id,
64 config,
65 client: channels.0,
66 thread_handle,
67 })
68 }
69
70 pub fn config(&self) -> &DhtConfig {
73 &self.config
74 }
75
76 pub fn peer_id(&self) -> &PeerId {
78 &self.peer_id
79 }
80
81 pub async fn addresses(&self) -> Result<Vec<Multiaddr>, DhtError> {
83 let request = DhtRequest::GetAddresses { external: false };
84 let response = self.send_request(request).await?;
85 ensure_response!(response, DhtResponse::GetAddresses(addresses) => Ok(addresses))
86 }
87
88 pub async fn external_addresses(&self) -> Result<Vec<Multiaddr>, DhtError> {
90 let request = DhtRequest::GetAddresses { external: false };
91 let response = self.send_request(request).await?;
92 ensure_response!(response, DhtResponse::GetAddresses(addresses) => Ok(addresses))
93 }
94
95 pub async fn add_peers(&self, peers: Vec<Multiaddr>) -> Result<(), DhtError> {
98 let request = DhtRequest::AddPeers { peers };
99 let response = self.send_request(request).await?;
100 ensure_response!(response, DhtResponse::Success => Ok(()))
101 }
102
103 pub async fn listen(&self, listening_address: Multiaddr) -> Result<Multiaddr, DhtError> {
106 let request = DhtRequest::StartListening {
107 address: listening_address,
108 };
109 let response = self.send_request(request).await?;
110 ensure_response!(response, DhtResponse::Address(addr) => Ok(addr))
111 }
112
113 pub async fn stop_listening(&self) -> Result<(), DhtError> {
115 let request = DhtRequest::StopListening;
116 let response = self.send_request(request).await?;
117 ensure_response!(response, DhtResponse::Success => Ok(()))
118 }
119
120 pub async fn wait_for_peers(&self, requested_peers: usize) -> Result<(), DhtError> {
123 loop {
127 let info = self.network_info().await?;
128 if info.num_peers >= requested_peers {
129 return Ok(());
130 }
131 tokio::time::sleep(Duration::from_secs(1)).await;
132 }
133 }
134
135 pub async fn bootstrap(&self) -> Result<(), DhtError> {
142 let request = DhtRequest::Bootstrap;
143 let response = self.send_request(request).await?;
144 ensure_response!(response, DhtResponse::Success => Ok(()))
145 }
146
147 pub async fn network_info(&self) -> Result<NetworkInfo, DhtError> {
150 let request = DhtRequest::GetNetworkInfo;
151 let response = self.send_request(request).await?;
152 ensure_response!(response, DhtResponse::GetNetworkInfo(info) => Ok(info))
153 }
154
155 pub async fn peers(&self) -> Result<Vec<Peer>, DhtError> {
158 let request = DhtRequest::GetPeers;
159 let response = self.send_request(request).await?;
160 ensure_response!(response, DhtResponse::GetPeers(peers) => Ok(peers))
161 }
162
163 pub async fn put_record(
168 &self,
169 key: &[u8],
170 value: &[u8],
171 quorum: usize,
172 ) -> Result<Vec<u8>, DhtError> {
173 let request = DhtRequest::PutRecord {
174 key: key.to_vec(),
175 value: value.to_vec(),
176 quorum,
177 };
178 let response = self.send_request(request).await?;
179 ensure_response!(response, DhtResponse::PutRecord { key } => Ok(key))
180 }
181
182 pub async fn get_record(&self, key: &[u8]) -> Result<DhtRecord, DhtError> {
187 let request = DhtRequest::GetRecord { key: key.to_vec() };
188 let response = self.send_request(request).await?;
189 ensure_response!(response, DhtResponse::GetRecord(record) => Ok(record))
190 }
191
192 pub async fn start_providing(&self, key: &[u8]) -> Result<(), DhtError> {
196 let request = DhtRequest::StartProviding { key: key.to_vec() };
197 let response = self.send_request(request).await?;
198 ensure_response!(response, DhtResponse::Success => Ok(()))
199 }
200
201 pub async fn get_providers(&self, key: &[u8]) -> Result<Vec<PeerId>, DhtError> {
204 let request = DhtRequest::GetProviders { key: key.to_vec() };
205 let response = self.send_request(request).await?;
206 ensure_response!(response, DhtResponse::GetProviders { providers } => Ok(providers))
207 }
208
209 async fn send_request(&self, request: DhtRequest) -> Result<DhtResponse, DhtError> {
210 self.client
211 .send(request)
212 .await
213 .map_err(DhtError::from)
214 .and_then(|res| res)
215 }
216}
217
218impl Drop for DhtNode {
219 fn drop(&mut self) {
220 self.thread_handle.abort();
221 }
222}
223
#[cfg(not(target_arch = "wasm32"))]
#[cfg(test)]
mod test {
    use super::*;
    use std::fmt::Display;

    use crate::dht::{AllowAllValidator, DhtError, DhtNode, NetworkInfo, Validator};
    use async_trait::async_trait;

    use crate::utils::make_p2p_address;
    use futures::future::try_join_all;
    use libp2p::{self, Multiaddr};
    use std::future::Future;
    use std::time::Duration;

    /// How long `initialize_network` waits for every node to report its
    /// expected peer count before panicking.
    const NETWORK_INIT_TIMEOUT_MS: u64 = 10000;

    /// Sleeps for `ms` milliseconds.
    pub async fn wait_ms(ms: u64) {
        tokio::time::sleep(Duration::from_millis(ms)).await;
    }

    /// Awaits `future`, panicking with `"timed out: {message}"` if it has not
    /// completed within `timeout_ms` milliseconds.
    async fn await_or_timeout<T>(
        timeout_ms: u64,
        future: impl Future<Output = T>,
        message: String,
    ) -> T {
        match tokio::time::timeout(Duration::from_millis(timeout_ms), future).await {
            Ok(result) => result,
            Err(_) => panic!("timed out: {}", message),
        }
    }

    /// Runs `func` on every node concurrently and collects the results in
    /// node order, returning the first error encountered.
    pub async fn swarm_command<'a, TFuture, F, T, E>(
        nodes: &'a mut [DhtNode],
        func: F,
    ) -> Result<Vec<T>, E>
    where
        F: FnMut(&'a mut DhtNode) -> TFuture,
        TFuture: Future<Output = Result<T, E>>,
    {
        let futures: Vec<_> = nodes.iter_mut().map(func).collect();
        try_join_all(futures).await
    }

    /// Creates `node_count` nodes, each listening on an OS-assigned localhost
    /// TCP port.
    ///
    /// The first node's listening address is added as a peer to every later
    /// node before that node starts listening.
    async fn create_network<V: Validator + Clone + 'static>(
        node_count: usize,
        validator: Option<V>,
    ) -> Result<Vec<DhtNode>, anyhow::Error> {
        let mut bootstrap_addresses: Option<Vec<Multiaddr>> = None;
        let mut nodes = Vec::with_capacity(node_count);
        for _ in 0..node_count {
            let node = DhtNode::new(
                Keypair::generate_ed25519(),
                Default::default(),
                validator.clone(),
            )?;

            if let Some(addresses) = bootstrap_addresses.as_ref() {
                node.add_peers(addresses.to_owned()).await?;
                node.listen("/ip4/127.0.0.1/tcp/0".parse().unwrap()).await?;
            } else {
                let address = node.listen("/ip4/127.0.0.1/tcp/0".parse().unwrap()).await?;
                bootstrap_addresses = Some(vec![address]);
            }
            nodes.push(node);
        }
        Ok(nodes)
    }

    /// Bootstraps every node, then waits (bounded by
    /// `NETWORK_INIT_TIMEOUT_MS`) until each one sees all the others.
    async fn initialize_network(nodes: &mut [DhtNode]) -> Result<(), anyhow::Error> {
        // `saturating_sub` avoids a usize underflow panic for an empty slice.
        let expected_peers = nodes.len().saturating_sub(1);
        // Short pause before bootstrapping — presumably lets the listeners
        // settle; TODO confirm whether it is still required.
        wait_ms(700).await;
        swarm_command(nodes, |c| c.bootstrap()).await?;

        await_or_timeout(
            NETWORK_INIT_TIMEOUT_MS,
            swarm_command(nodes, |c| c.wait_for_peers(expected_peers)),
            format!("waiting for {} peers", expected_peers),
        )
        .await?;
        Ok(())
    }

    /// Creates a node whose validator accepts every record.
    fn create_unfiltered_dht_node() -> Result<DhtNode, DhtError> {
        DhtNode::new::<AllowAllValidator>(
            Keypair::generate_ed25519(),
            Default::default(),
            Some(AllowAllValidator {}),
        )
    }

    #[tokio::test]
    async fn test_dhtnode_base_case() -> Result<(), DhtError> {
        let node = create_unfiltered_dht_node()?;
        node.listen("/ip4/127.0.0.1/tcp/0".parse().unwrap()).await?;
        let info = node.network_info().await?;
        assert_eq!(
            info,
            NetworkInfo {
                num_connections: 0,
                num_established: 0,
                num_peers: 0,
                num_pending: 0,
            }
        );

        assert!(
            node.bootstrap().await.is_ok(),
            "bootstrap() should succeed, even without peers to bootstrap."
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_dhtnode_bootstrap() -> Result<(), DhtError> {
        let num_nodes = 5;
        let mut nodes = create_network(num_nodes, Some(AllowAllValidator {})).await?;
        initialize_network(&mut nodes).await?;

        for info in swarm_command(&mut nodes, |c| c.network_info()).await? {
            assert_eq!(info.num_peers, num_nodes - 1);
            assert_eq!(info.num_pending, 0);
        }

        let info = nodes.first().unwrap().network_info().await?;
        assert_eq!(info.num_peers, num_nodes - 1);
        assert_eq!(info.num_pending, 0);

        Ok(())
    }

    #[tokio::test]
    async fn test_dhtnode_simple() -> Result<(), DhtError> {
        let mut nodes = create_network(2, Some(AllowAllValidator {})).await?;
        initialize_network(&mut nodes).await?;
        let (node_a, node_b) = (nodes.pop().unwrap(), nodes.pop().unwrap());

        node_a.put_record(b"foo", b"bar", 1).await?;
        let result = node_b.get_record(b"foo").await?;
        assert_eq!(result.key, b"foo");
        assert_eq!(result.value.expect("has value"), b"bar");
        Ok(())
    }

    #[tokio::test]
    async fn test_dhtnode_providers() -> Result<(), DhtError> {
        let mut nodes = create_network(2, Some(AllowAllValidator {})).await?;
        initialize_network(&mut nodes).await?;
        let (node_a, node_b) = (nodes.pop().unwrap(), nodes.pop().unwrap());

        node_a.start_providing(b"foo").await?;

        let providers = node_b.get_providers(b"foo").await?;
        assert_eq!(providers.len(), 1);
        assert_eq!(&providers[0], node_a.peer_id());
        Ok(())
    }

    #[tokio::test]
    async fn test_dhtnode_validator() -> Result<(), DhtError> {
        /// Accepts only records whose value is exactly `b"VALID"`.
        #[derive(Clone)]
        struct MyValidator {}

        #[async_trait]
        impl Validator for MyValidator {
            async fn validate(&mut self, data: &[u8]) -> bool {
                data == b"VALID"
            }
        }

        impl Display for MyValidator {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "MyValidator")
            }
        }

        let mut nodes = create_network(2, Some(MyValidator {})).await?;
        initialize_network(&mut nodes).await?;
        let (node_a, node_b) = (nodes.pop().unwrap(), nodes.pop().unwrap());
        // A node that accepts anything, used to push invalid records into the network.
        let unfiltered_client = create_unfiltered_dht_node()?;
        unfiltered_client
            .add_peers(vec![make_p2p_address(
                node_a.addresses().await?.pop().unwrap(),
                node_a.peer_id().to_owned(),
            )])
            .await?;

        node_a.put_record(b"foo_1", b"VALID", 1).await?;
        let result = node_b.get_record(b"foo_1").await?;
        assert_eq!(
            result.value.expect("has value"),
            b"VALID",
            "validation allows valid records through"
        );

        assert!(
            node_a.put_record(b"foo_2", b"INVALID", 1).await.is_err(),
            "setting a record validates locally"
        );

        unfiltered_client.put_record(b"foo_3", b"VALID", 1).await?;
        unfiltered_client
            .put_record(b"foo_4", b"INVALID", 1)
            .await?;

        let result = node_b.get_record(b"foo_3").await?;
        assert_eq!(
            result.value.expect("has value"),
            b"VALID",
            "validation allows valid records through"
        );

        assert!(
            node_b.get_record(b"foo_4").await?.value.is_none(),
            "invalid records are not retrieved from the network"
        );

        Ok(())
    }
}