pkarr/extra/endpoints/
mod.rs

1//! implementation of [Endpoints](https://pkarr.org/endpoints) spec.
2//!
3
4mod endpoint;
5
6pub use endpoint::Endpoint;
7
8use futures_lite::{pin, Stream, StreamExt};
9use genawaiter::sync::Gen;
10
11use crate::PublicKey;
12
13impl crate::Client {
14    /// Returns an async stream of [HTTPS][crate::dns::rdata::RData::HTTPS] [Endpoint]s
15    pub fn resolve_https_endpoints<'a>(
16        &'a self,
17        qname: &'a str,
18    ) -> impl Stream<Item = Endpoint> + 'a {
19        self.resolve_endpoints(qname, true)
20    }
21
22    /// Returns an async stream of [SVCB][crate::dns::rdata::RData::SVCB] [Endpoint]s
23    pub fn resolve_svcb_endpoints<'a>(
24        &'a self,
25        qname: &'a str,
26    ) -> impl Stream<Item = Endpoint> + 'a {
27        self.resolve_endpoints(qname, false)
28    }
29
30    /// Helper method that returns the first [HTTPS][crate::dns::rdata::RData::HTTPS] [Endpoint] in the Async stream from [Self::resolve_https_endpoints]
31    pub async fn resolve_https_endpoint(
32        &self,
33        qname: &str,
34    ) -> Result<Endpoint, CouldNotResolveEndpoint> {
35        let stream = self.resolve_https_endpoints(qname);
36
37        pin!(stream);
38
39        match stream.next().await {
40            Some(endpoint) => Ok(endpoint),
41            None => {
42                #[cfg(not(target_arch = "wasm32"))]
43                tracing::trace!(?qname, "failed to resolve endpoint");
44                #[cfg(target_arch = "wasm32")]
45                log::trace!("failed to resolve endpoint {qname}");
46
47                Err(CouldNotResolveEndpoint)
48            }
49        }
50    }
51
52    /// Helper method that returns the first [SVCB][crate::dns::rdata::RData::SVCB] [Endpoint] in the Async stream from [Self::resolve_svcb_endpoints]
53    pub async fn resolve_svcb_endpoint(
54        &self,
55        qname: &str,
56    ) -> Result<Endpoint, CouldNotResolveEndpoint> {
57        let stream = self.resolve_https_endpoints(qname);
58
59        pin!(stream);
60
61        match stream.next().await {
62            Some(endpoint) => Ok(endpoint),
63            None => Err(CouldNotResolveEndpoint),
64        }
65    }
66
67    /// Returns an async stream of either [HTTPS][crate::dns::rdata::RData::HTTPS] or [SVCB][crate::dns::rdata::RData::SVCB] [Endpoint]s
68    pub fn resolve_endpoints<'a>(
69        &'a self,
70        qname: &'a str,
71        https: bool,
72    ) -> impl Stream<Item = Endpoint> + 'a {
73        Gen::new(|co| async move {
74            let mut depth = 0;
75            let mut stack: Vec<Endpoint> = Vec::new();
76
77            // Initialize the stack with endpoints from the starting domain.
78            if let Ok(tld) = PublicKey::try_from(qname) {
79                if let Some(signed_packet) = self.resolve(&tld).await {
80                    depth += 1;
81                    stack.extend(Endpoint::parse(&signed_packet, qname, https));
82                }
83            }
84
85            while let Some(next) = stack.pop() {
86                let current = next.target();
87
88                // Attempt to resolve the domain as a public key.
89                match PublicKey::try_from(current) {
90                    Ok(tld) => match self.resolve(&tld).await {
91                        Some(signed_packet) if depth < self.0.max_recursion_depth => {
92                            depth += 1;
93                            let endpoints = Endpoint::parse(&signed_packet, current, https);
94
95                            #[cfg(not(target_arch = "wasm32"))]
96                            tracing::trace!(?qname, ?depth, ?endpoints, "resolved endpoints");
97                            #[cfg(target_arch = "wasm32")]
98                            log::trace!("resolved endpoints qname: {qname}, depth: {depth}, endpoints: {:?}", endpoints);
99
100                            stack.extend(endpoints);
101                        }
102                        _ => break, // Stop on resolution failure or recursion depth exceeded.
103                    },
104                    // Yield if the domain is not pointing to another Pkarr TLD domain.
105                    Err(_) => co.yield_(next).await,
106                }
107            }
108        })
109    }
110}
111
/// Error returned when pkarr could not resolve an endpoint, i.e. the endpoint stream for the
/// queried name ended without yielding anything.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CouldNotResolveEndpoint;

impl std::error::Error for CouldNotResolveEndpoint {}

impl std::fmt::Display for CouldNotResolveEndpoint {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("pkarr could not resolve endpoint")
    }
}
123
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {

    use crate::dns::rdata::SVCB;
    use crate::{Client, Keypair};
    use crate::{PublicKey, SignedPacket};

    use std::future::Future;
    use std::net::IpAddr;
    use std::pin::Pin;
    use std::str::FromStr;
    use std::time::Duration;

    use mainline::Testnet;

    /// Publishes a tree of `depth + 1` packet levels and returns the root packet's public key.
    ///
    /// Every packet holds `branching` HTTPS records. At `depth > 0` each record targets a
    /// freshly published child subtree. At `depth == 0` the records are leaves carrying the
    /// optional `port` and `domain` target, and the leaf packet also gets one address record
    /// per entry of `ips`.
    ///
    /// Returns a boxed future because an `async` block cannot recurse directly.
    fn generate_subtree(
        client: Client,
        depth: u8,
        branching: u8,
        domain: Option<String>,
        ips: Vec<IpAddr>,
        port: Option<u16>,
    ) -> Pin<Box<dyn Future<Output = PublicKey>>> {
        Box::pin(async move {
            let keypair = Keypair::random();

            let mut builder = SignedPacket::builder();

            // Address records describe the leaf packet as a whole, so add each one once
            // instead of once per branch.
            if depth == 0 {
                for &ip in &ips {
                    builder = builder.address(".".try_into().unwrap(), ip, 3600);
                }
            }

            for _ in 0..branching {
                let mut svcb = SVCB::new(0, ".".try_into().unwrap());

                if depth == 0 {
                    svcb.priority = 1;

                    if let Some(port) = port {
                        svcb.set_port(port);
                    }

                    if let Some(target) = &domain {
                        // The SVCB target borrows its name; leaking is acceptable in tests.
                        let target: &'static str = Box::leak(target.clone().into_boxed_str());
                        svcb.target = target.try_into().unwrap()
                    }
                } else {
                    let target = generate_subtree(
                        client.clone(),
                        depth - 1,
                        branching,
                        domain.clone(),
                        ips.clone(),
                        port,
                    )
                    .await
                    .to_string();
                    let target: &'static str = Box::leak(target.into_boxed_str());
                    svcb.target = target.try_into().unwrap();
                };

                builder = builder.https(".".try_into().unwrap(), svcb, 3600);
            }

            let signed_packet = builder.sign(&keypair).unwrap();

            client.publish(&signed_packet, None).await.unwrap();

            keypair.public_key()
        })
    }

    /// Publishes a tree of `depth` packets per path and returns the root's public key.
    ///
    /// depth of (3): A -> B -> C
    /// branch of (2): A -> B0,  A ->  B1
    /// domain, ips, and port are all at the end (C, or B1)
    ///
    /// # Panics
    ///
    /// Panics if `depth` is 0, since at least one packet has to be published.
    fn generate(
        client: &Client,
        depth: u8,
        branching: u8,
        domain: Option<String>,
        ips: Vec<IpAddr>,
        port: Option<u16>,
    ) -> Pin<Box<dyn Future<Output = PublicKey>>> {
        assert!(depth > 0, "generate() needs depth >= 1");
        generate_subtree(client.clone(), depth - 1, branching, domain, ips, port)
    }

    #[tokio::test]
    async fn direct_endpoint_resolution() {
        let testnet = Testnet::new(3).unwrap();
        let client = Client::builder()
            .no_default_network()
            .bootstrap(&testnet.bootstrap)
            .build()
            .unwrap();

        let tld = generate(&client, 1, 1, Some("example.com".to_string()), vec![], None).await;

        let endpoint = client
            .resolve_https_endpoint(&tld.to_string())
            .await
            .unwrap();

        assert_eq!(endpoint.domain(), Some("example.com"));
        assert_eq!(endpoint.public_key(), &tld);
    }

    #[tokio::test]
    async fn resolve_endpoints() {
        let testnet = Testnet::new(3).unwrap();
        let client = Client::builder()
            .no_default_network()
            .bootstrap(&testnet.bootstrap)
            .build()
            .unwrap();

        let tld = generate(&client, 3, 3, Some("example.com".to_string()), vec![], None).await;

        let endpoint = client
            .resolve_https_endpoint(&tld.to_string())
            .await
            .unwrap();

        assert_eq!(endpoint.domain(), Some("example.com"));
    }

    #[tokio::test]
    async fn empty() {
        let testnet = Testnet::new(3).unwrap();
        let client = Client::builder()
            .no_default_network()
            .bootstrap(&testnet.bootstrap)
            .request_timeout(Duration::from_millis(20))
            .build()
            .unwrap();

        // Nothing was published for this key, so resolution must fail.
        let pubky = Keypair::random().public_key();

        let endpoint = client.resolve_https_endpoint(&pubky.to_string()).await;

        assert!(endpoint.is_err());
    }

    /// A chain of 4 packets cannot be followed with a budget of 3 resolutions.
    #[tokio::test]
    async fn max_recursion_exceeded() {
        let testnet = Testnet::new(3).unwrap();
        let client = Client::builder()
            .no_default_network()
            .bootstrap(&testnet.bootstrap)
            .max_recursion_depth(3)
            .build()
            .unwrap();

        let tld = generate(&client, 4, 3, Some("example.com".to_string()), vec![], None).await;

        let endpoint = client.resolve_https_endpoint(&tld.to_string()).await;

        assert!(endpoint.is_err());
    }

    #[tokio::test]
    async fn resolve_addresses() {
        let testnet = Testnet::new(3).unwrap();
        let client = Client::builder()
            .no_default_network()
            .bootstrap(&testnet.bootstrap)
            .build()
            .unwrap();

        let tld = generate(
            &client,
            3,
            3,
            None,
            vec![IpAddr::from_str("0.0.0.10").unwrap()],
            Some(3000),
        )
        .await;

        let endpoint = client
            .resolve_https_endpoint(&tld.to_string())
            .await
            .unwrap();

        assert_eq!(endpoint.target(), ".");
        assert_eq!(endpoint.domain(), None);
        assert_eq!(
            endpoint
                .to_socket_addrs()
                .into_iter()
                .map(|s| s.to_string())
                .collect::<Vec<String>>(),
            vec!["0.0.0.10:3000"]
        );
    }
}