iroh_net/
test_utils.rs

1//! Internal utilities to support testing.
2use std::net::Ipv4Addr;
3
4use anyhow::Result;
5pub use dns_and_pkarr_servers::DnsPkarrServer;
6pub use dns_server::create_dns_resolver;
7use tokio::sync::oneshot;
8
9use crate::{
10    defaults::DEFAULT_STUN_PORT,
11    relay::{
12        server::{CertConfig, RelayConfig, Server, ServerConfig, StunConfig, TlsConfig},
13        RelayMap, RelayNode, RelayUrl,
14    },
15};
16
17/// A drop guard to clean up test infrastructure.
18///
19/// After dropping the test infrastructure will asynchronously shutdown and release its
20/// resources.
21// Nightly sees the sender as dead code currently, but we only rely on Drop of the
22// sender.
23#[derive(Debug)]
24#[allow(dead_code)]
25pub struct CleanupDropGuard(pub(crate) oneshot::Sender<()>);
26
27/// Runs a relay server with STUN enabled suitable for tests.
28///
29/// The returned `Url` is the url of the relay server in the returned [`RelayMap`].
30/// When dropped, the returned [`Server`] does will stop running.
31pub async fn run_relay_server() -> Result<(RelayMap, RelayUrl, Server)> {
32    run_relay_server_with(Some(StunConfig {
33        bind_addr: (Ipv4Addr::LOCALHOST, 0).into(),
34    }))
35    .await
36}
37
38/// Runs a relay server.
39///
40/// `stun` can be set to `None` to disable stun, or set to `Some` `StunConfig`,
41/// to enable stun on a specific socket.
42///
43/// The return value is similar to [`run_relay_server`].
44pub async fn run_relay_server_with(
45    stun: Option<StunConfig>,
46) -> Result<(RelayMap, RelayUrl, Server)> {
47    let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).unwrap();
48    let rustls_cert = rustls::pki_types::CertificateDer::from(cert.serialize_der().unwrap());
49    let private_key =
50        rustls::pki_types::PrivatePkcs8KeyDer::from(cert.get_key_pair().serialize_der());
51    let private_key = rustls::pki_types::PrivateKeyDer::from(private_key);
52
53    let config = ServerConfig {
54        relay: Some(RelayConfig {
55            http_bind_addr: (Ipv4Addr::LOCALHOST, 0).into(),
56            tls: Some(TlsConfig {
57                cert: CertConfig::<(), ()>::Manual {
58                    private_key,
59                    certs: vec![rustls_cert],
60                },
61                https_bind_addr: (Ipv4Addr::LOCALHOST, 0).into(),
62            }),
63            limits: Default::default(),
64        }),
65        stun,
66        #[cfg(feature = "metrics")]
67        metrics_addr: None,
68    };
69    let server = Server::spawn(config).await.unwrap();
70    let url: RelayUrl = format!("https://localhost:{}", server.https_addr().unwrap().port())
71        .parse()
72        .unwrap();
73    let m = RelayMap::from_nodes([RelayNode {
74        url: url.clone(),
75        stun_only: false,
76        stun_port: server.stun_addr().map_or(DEFAULT_STUN_PORT, |s| s.port()),
77    }])
78    .unwrap();
79    Ok((m, url, server))
80}
81
82pub(crate) mod dns_and_pkarr_servers {
83    use std::{net::SocketAddr, time::Duration};
84
85    use anyhow::Result;
86    use iroh_base::key::{NodeId, SecretKey};
87    use url::Url;
88
89    use super::{create_dns_resolver, CleanupDropGuard};
90    use crate::{
91        discovery::{dns::DnsDiscovery, pkarr::PkarrPublisher, ConcurrentDiscovery},
92        dns::DnsResolver,
93        test_utils::{
94            dns_server::run_dns_server, pkarr_dns_state::State, pkarr_relay::run_pkarr_relay,
95        },
96    };
97
98    /// Handle and drop guard for test DNS and Pkarr servers.
99    ///
100    /// Once the struct is dropped the servers will shut down.
101    #[derive(Debug)]
102    pub struct DnsPkarrServer {
103        /// The node origin domain.
104        pub node_origin: String,
105        /// The shared state of the DNS and Pkarr servers.
106        state: State,
107        /// The socket address of the DNS server.
108        pub nameserver: SocketAddr,
109        /// The HTTP URL of the Pkarr server.
110        pub pkarr_url: Url,
111        _dns_drop_guard: CleanupDropGuard,
112        _pkarr_drop_guard: CleanupDropGuard,
113    }
114
115    impl DnsPkarrServer {
116        /// Run DNS and Pkarr servers on localhost.
117        pub async fn run() -> anyhow::Result<Self> {
118            Self::run_with_origin("dns.iroh.test".to_string()).await
119        }
120
121        /// Run DNS and Pkarr servers on localhost with the specified `node_origin` domain.
122        pub async fn run_with_origin(node_origin: String) -> anyhow::Result<Self> {
123            let state = State::new(node_origin.clone());
124            let (nameserver, dns_drop_guard) = run_dns_server(state.clone()).await?;
125            let (pkarr_url, pkarr_drop_guard) = run_pkarr_relay(state.clone()).await?;
126            Ok(Self {
127                node_origin,
128                nameserver,
129                pkarr_url,
130                state,
131                _dns_drop_guard: dns_drop_guard,
132                _pkarr_drop_guard: pkarr_drop_guard,
133            })
134        }
135
136        /// Create a [`ConcurrentDiscovery`] with [`DnsDiscovery`] and [`PkarrPublisher`]
137        /// configured to use the test servers.
138        pub fn discovery(&self, secret_key: SecretKey) -> Box<ConcurrentDiscovery> {
139            Box::new(ConcurrentDiscovery::from_services(vec![
140                // Enable DNS discovery by default
141                Box::new(DnsDiscovery::new(self.node_origin.clone())),
142                // Enable pkarr publishing by default
143                Box::new(PkarrPublisher::new(secret_key, self.pkarr_url.clone())),
144            ]))
145        }
146
147        /// Create a [`DnsResolver`] configured to use the test DNS server.
148        pub fn dns_resolver(&self) -> DnsResolver {
149            create_dns_resolver(self.nameserver).expect("failed to create DNS resolver")
150        }
151
152        /// Wait until a Pkarr announce for a node is published to the server.
153        ///
154        /// If `timeout` elapses an error is returned.
155        pub async fn on_node(&self, node_id: &NodeId, timeout: Duration) -> Result<()> {
156            self.state.on_node(node_id, timeout).await
157        }
158    }
159}
160
161pub(crate) mod dns_server {
162    use std::{
163        future::Future,
164        net::{Ipv4Addr, SocketAddr},
165    };
166
167    use anyhow::{ensure, Result};
168    use futures_lite::future::Boxed as BoxFuture;
169    use hickory_proto::{
170        op::{header::MessageType, Message},
171        serialize::binary::BinDecodable,
172    };
173    use hickory_resolver::{config::NameServerConfig, TokioAsyncResolver};
174    use tokio::{net::UdpSocket, sync::oneshot};
175    use tracing::{debug, error, warn};
176
177    use super::CleanupDropGuard;
178
179    /// Trait used by [`run_dns_server`] for answering DNS queries.
180    pub trait QueryHandler: Send + Sync + 'static {
181        fn resolve(
182            &self,
183            query: &Message,
184            reply: &mut Message,
185        ) -> impl Future<Output = Result<()>> + Send;
186    }
187
188    pub type QueryHandlerFunction =
189        Box<dyn Fn(&Message, &mut Message) -> BoxFuture<Result<()>> + Send + Sync + 'static>;
190
191    impl QueryHandler for QueryHandlerFunction {
192        fn resolve(
193            &self,
194            query: &Message,
195            reply: &mut Message,
196        ) -> impl Future<Output = Result<()>> + Send {
197            (self)(query, reply)
198        }
199    }
200
201    /// Run a DNS server.
202    ///
203    /// Must pass a [`QueryHandler`] that answers queries.
204    pub async fn run_dns_server(
205        resolver: impl QueryHandler,
206    ) -> Result<(SocketAddr, CleanupDropGuard)> {
207        let bind_addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 0));
208        let socket = UdpSocket::bind(bind_addr).await?;
209        let bound_addr = socket.local_addr()?;
210        let s = TestDnsServer { socket, resolver };
211        let (tx, mut rx) = oneshot::channel();
212        tokio::task::spawn(async move {
213            tokio::select! {
214                _ = &mut rx => {
215                    debug!("shutting down dns server");
216                }
217                res = s.run() => {
218                    if let Err(e) = res {
219                        error!("error running dns server {e:?}");
220                    }
221                }
222            }
223        });
224        Ok((bound_addr, CleanupDropGuard(tx)))
225    }
226
227    /// Create a DNS resolver with a single nameserver.
228    pub fn create_dns_resolver(nameserver: SocketAddr) -> Result<TokioAsyncResolver> {
229        let mut config = hickory_resolver::config::ResolverConfig::new();
230        let nameserver_config =
231            NameServerConfig::new(nameserver, hickory_resolver::config::Protocol::Udp);
232        config.add_name_server(nameserver_config);
233        let resolver = hickory_resolver::AsyncResolver::tokio(config, Default::default());
234        Ok(resolver)
235    }
236
237    struct TestDnsServer<R> {
238        resolver: R,
239        socket: UdpSocket,
240    }
241
242    impl<R: QueryHandler> TestDnsServer<R> {
243        async fn run(self) -> Result<()> {
244            let mut buf = [0; 1450];
245            loop {
246                let res = self.socket.recv_from(&mut buf).await;
247                let (len, from) = res?;
248                if let Err(err) = self.handle_datagram(from, &buf[..len]).await {
249                    warn!(?err, %from, "failed to handle incoming datagram");
250                }
251            }
252        }
253
254        async fn handle_datagram(&self, from: SocketAddr, buf: &[u8]) -> Result<()> {
255            let packet = Message::from_bytes(buf)?;
256            debug!(queries = ?packet.queries(), %from, "received query");
257            let mut reply = packet.clone();
258            reply.set_message_type(MessageType::Response);
259            self.resolver.resolve(&packet, &mut reply).await?;
260            debug!(?reply, %from, "send reply");
261            let buf = reply.to_vec()?;
262            let len = self.socket.send_to(&buf, from).await?;
263            ensure!(len == buf.len(), "failed to send complete packet");
264            Ok(())
265        }
266    }
267}
268
269pub(crate) mod pkarr_relay {
270    use std::{
271        future::IntoFuture,
272        net::{Ipv4Addr, SocketAddr},
273    };
274
275    use anyhow::Result;
276    use axum::{
277        extract::{Path, State},
278        response::IntoResponse,
279        routing::put,
280        Router,
281    };
282    use bytes::Bytes;
283    use tokio::sync::oneshot;
284    use tracing::{debug, error, warn};
285    use url::Url;
286
287    use super::CleanupDropGuard;
288    use crate::test_utils::pkarr_dns_state::State as AppState;
289
290    pub async fn run_pkarr_relay(state: AppState) -> Result<(Url, CleanupDropGuard)> {
291        let bind_addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 0));
292        let app = Router::new()
293            .route("/pkarr/:key", put(pkarr_put))
294            .with_state(state);
295        let listener = tokio::net::TcpListener::bind(bind_addr).await?;
296        let bound_addr = listener.local_addr()?;
297        let url: Url = format!("http://{bound_addr}/pkarr")
298            .parse()
299            .expect("valid url");
300
301        let (tx, mut rx) = oneshot::channel();
302        tokio::spawn(async move {
303            let serve = axum::serve(listener, app);
304            tokio::select! {
305                _ = &mut rx => {
306                    debug!("shutting down pkarr server");
307                }
308                res = serve.into_future() => {
309                    if let Err(e) = res {
310                        error!("pkarr server error: {e:?}");
311                    }
312                }
313            }
314        });
315        Ok((url, CleanupDropGuard(tx)))
316    }
317
318    async fn pkarr_put(
319        State(state): State<AppState>,
320        Path(key): Path<String>,
321        body: Bytes,
322    ) -> Result<impl IntoResponse, AppError> {
323        let key = pkarr::PublicKey::try_from(key.as_str())?;
324        let signed_packet = pkarr::SignedPacket::from_relay_payload(&key, &body)?;
325        let _updated = state.upsert(signed_packet)?;
326        Ok(http::StatusCode::NO_CONTENT)
327    }
328
329    #[derive(Debug)]
330    struct AppError(anyhow::Error);
331    impl<T: Into<anyhow::Error>> From<T> for AppError {
332        fn from(value: T) -> Self {
333            Self(value.into())
334        }
335    }
336    impl IntoResponse for AppError {
337        fn into_response(self) -> axum::response::Response {
338            warn!(err = ?self, "request failed");
339            (http::StatusCode::INTERNAL_SERVER_ERROR, self.0.to_string()).into_response()
340        }
341    }
342}
343
344pub(crate) mod pkarr_dns_state {
345    use std::{
346        collections::{hash_map, HashMap},
347        future::Future,
348        ops::Deref,
349        sync::Arc,
350        time::Duration,
351    };
352
353    use anyhow::{bail, Result};
354    use parking_lot::{Mutex, MutexGuard};
355    use pkarr::SignedPacket;
356
357    use crate::{
358        dns::node_info::{node_id_from_hickory_name, NodeInfo},
359        test_utils::dns_server::QueryHandler,
360        NodeId,
361    };
362
363    #[derive(Debug, Clone)]
364    pub struct State {
365        packets: Arc<Mutex<HashMap<NodeId, SignedPacket>>>,
366        origin: String,
367        notify: Arc<tokio::sync::Notify>,
368    }
369
370    impl State {
371        pub fn new(origin: String) -> Self {
372            Self {
373                packets: Default::default(),
374                origin,
375                notify: Arc::new(tokio::sync::Notify::new()),
376            }
377        }
378
379        pub fn on_update(&self) -> tokio::sync::futures::Notified<'_> {
380            self.notify.notified()
381        }
382
383        pub async fn on_node(&self, node: &NodeId, timeout: Duration) -> Result<()> {
384            let timeout = tokio::time::sleep(timeout);
385            tokio::pin!(timeout);
386            while self.get(node).is_none() {
387                tokio::select! {
388                    _ = &mut timeout => bail!("timeout"),
389                    _ = self.on_update() => {}
390                }
391            }
392            Ok(())
393        }
394
395        pub fn upsert(&self, signed_packet: SignedPacket) -> anyhow::Result<bool> {
396            let node_id = NodeId::from_bytes(&signed_packet.public_key().to_bytes())?;
397            let mut map = self.packets.lock();
398            let updated = match map.entry(node_id) {
399                hash_map::Entry::Vacant(e) => {
400                    e.insert(signed_packet);
401                    true
402                }
403                hash_map::Entry::Occupied(mut e) => {
404                    if signed_packet.more_recent_than(e.get()) {
405                        e.insert(signed_packet);
406                        true
407                    } else {
408                        false
409                    }
410                }
411            };
412            if updated {
413                self.notify.notify_waiters();
414            }
415            Ok(updated)
416        }
417
418        /// Returns a mutex guard, do not hold over await points
419        pub fn get(&self, node_id: &NodeId) -> Option<impl Deref<Target = SignedPacket> + '_> {
420            let map = self.packets.lock();
421            if map.contains_key(node_id) {
422                let guard = MutexGuard::map(map, |state| state.get_mut(node_id).unwrap());
423                Some(guard)
424            } else {
425                None
426            }
427        }
428
429        pub fn resolve_dns(
430            &self,
431            query: &hickory_proto::op::Message,
432            reply: &mut hickory_proto::op::Message,
433            ttl: u32,
434        ) -> Result<()> {
435            for query in query.queries() {
436                let Some(node_id) = node_id_from_hickory_name(query.name()) else {
437                    continue;
438                };
439                let packet = self.get(&node_id);
440                let Some(packet) = packet.as_ref() else {
441                    continue;
442                };
443                let node_info = NodeInfo::from_pkarr_signed_packet(packet)?;
444                for record in node_info.to_hickory_records(&self.origin, ttl)? {
445                    reply.add_answer(record);
446                }
447            }
448            Ok(())
449        }
450    }
451
452    impl QueryHandler for State {
453        fn resolve(
454            &self,
455            query: &hickory_proto::op::Message,
456            reply: &mut hickory_proto::op::Message,
457        ) -> impl Future<Output = Result<()>> + Send {
458            const TTL: u32 = 30;
459            let res = self.resolve_dns(query, reply, TTL);
460            std::future::ready(res)
461        }
462    }
463}