Skip to main content

pkarr_client/extra/endpoints/
mod.rs

//! Implementation of the [Endpoints](https://pkarr.org/endpoints) spec.
2//!
3
4mod endpoint;
5
6pub use endpoint::Endpoint;
7
8use futures_lite::{pin, Stream, StreamExt};
9use genawaiter::sync::Gen;
10
11use crate::PublicKey;
12
13impl crate::Client {
14    /// Returns an async stream of [HTTPS][crate::dns::rdata::RData::HTTPS] [Endpoint]s
15    pub fn resolve_https_endpoints<'a>(
16        &'a self,
17        qname: &'a str,
18    ) -> impl Stream<Item = Endpoint> + 'a {
19        self.resolve_endpoints(qname, true)
20    }
21
22    /// Returns an async stream of [SVCB][crate::dns::rdata::RData::SVCB] [Endpoint]s
23    pub fn resolve_svcb_endpoints<'a>(
24        &'a self,
25        qname: &'a str,
26    ) -> impl Stream<Item = Endpoint> + 'a {
27        self.resolve_endpoints(qname, false)
28    }
29
30    /// Helper method that returns the first [HTTPS][crate::dns::rdata::RData::HTTPS] [Endpoint] in the Async stream from [Self::resolve_https_endpoints]
31    pub async fn resolve_https_endpoint(
32        &self,
33        qname: &str,
34    ) -> Result<Endpoint, CouldNotResolveEndpoint> {
35        let stream = self.resolve_https_endpoints(qname);
36
37        pin!(stream);
38
39        match stream.next().await {
40            Some(endpoint) => Ok(endpoint),
41            None => {
42                #[cfg(not(target_arch = "wasm32"))]
43                tracing::trace!(?qname, "failed to resolve endpoint");
44                #[cfg(target_arch = "wasm32")]
45                log::trace!("failed to resolve endpoint {qname}");
46
47                Err(CouldNotResolveEndpoint)
48            }
49        }
50    }
51
52    /// Helper method that returns the first [SVCB][crate::dns::rdata::RData::SVCB] [Endpoint] in the Async stream from [Self::resolve_svcb_endpoints]
53    pub async fn resolve_svcb_endpoint(
54        &self,
55        qname: &str,
56    ) -> Result<Endpoint, CouldNotResolveEndpoint> {
57        let stream = self.resolve_https_endpoints(qname);
58
59        pin!(stream);
60
61        match stream.next().await {
62            Some(endpoint) => Ok(endpoint),
63            None => Err(CouldNotResolveEndpoint),
64        }
65    }
66
67    /// Returns an async stream of either [HTTPS][crate::dns::rdata::RData::HTTPS] or [SVCB][crate::dns::rdata::RData::SVCB] [Endpoint]s
68    pub fn resolve_endpoints<'a>(
69        &'a self,
70        qname: &'a str,
71        https: bool,
72    ) -> impl Stream<Item = Endpoint> + 'a {
73        Gen::new(|co| async move {
74            let mut depth = 0;
75            let mut stack: Vec<Endpoint> = Vec::new();
76
77            // Initialize the stack with endpoints from the starting domain.
78            if let Ok(tld) = PublicKey::try_from(qname) {
79                if let Some(signed_packet) = self.resolve(&tld).await {
80                    depth += 1;
81                    stack.extend(Endpoint::parse(&signed_packet, qname, https));
82                }
83            }
84
85            while let Some(next) = stack.pop() {
86                let current = next.target();
87
88                // Attempt to resolve the domain as a public key.
89                match PublicKey::try_from(current) {
90                    Ok(tld) => match self.resolve(&tld).await {
91                        Some(signed_packet) if depth < self.0.max_recursion_depth => {
92                            depth += 1;
93                            let endpoints = Endpoint::parse(&signed_packet, current, https);
94
95                            #[cfg(not(target_arch = "wasm32"))]
96                            tracing::trace!(?qname, ?depth, ?endpoints, "resolved endpoints");
97                            #[cfg(target_arch = "wasm32")]
98                            log::trace!("resolved endpoints qname: {qname}, depth: {depth}, endpoints: {:?}", endpoints);
99
100                            stack.extend(endpoints);
101                        }
102                        _ => break, // Stop on resolution failure or recursion depth exceeded.
103                    },
104                    // Yield if the domain is not pointing to another Pkarr TLD domain.
105                    Err(_) => co.yield_(next).await,
106                }
107            }
108        })
109    }
110}
111
/// Error signalling that pkarr could not resolve an endpoint.
///
/// Returned by `Client::resolve_https_endpoint` and `Client::resolve_svcb_endpoint`
/// when the endpoint stream ends without yielding anything.
#[derive(Debug)]
pub struct CouldNotResolveEndpoint;

// No underlying cause is carried, so the trait's default methods are sufficient.
impl std::error::Error for CouldNotResolveEndpoint {}

impl std::fmt::Display for CouldNotResolveEndpoint {
    /// Writes the fixed, human-readable error message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("pkarr could not resolve endpoint")
    }
}
123
124#[cfg(all(test, not(target_arch = "wasm32")))]
125mod tests {
126
127    use crate::dns::rdata::SVCB;
128    use crate::mainline::Testnet;
129    use crate::{Client, Keypair};
130    use crate::{PublicKey, SignedPacket};
131
132    use std::future::Future;
133    use std::net::IpAddr;
134    use std::pin::Pin;
135    use std::str::FromStr;
136
137    fn generate_subtree(
138        client: Client,
139        depth: u8,
140        branching: u8,
141        domain: Option<String>,
142        ips: Vec<IpAddr>,
143        port: Option<u16>,
144    ) -> Pin<Box<dyn Future<Output = PublicKey>>> {
145        Box::pin(async move {
146            let keypair = Keypair::random();
147
148            let mut builder = SignedPacket::builder();
149
150            for _ in 0..branching {
151                let mut svcb = SVCB::new(0, ".".try_into().unwrap());
152
153                if depth == 0 {
154                    svcb.priority = 1;
155
156                    if let Some(port) = port {
157                        svcb.set_port(port);
158                    }
159
160                    if let Some(target) = &domain {
161                        let target: &'static str = Box::leak(target.clone().into_boxed_str());
162                        svcb.target = target.try_into().unwrap()
163                    }
164
165                    for ip in ips.clone() {
166                        builder = builder.address(".".try_into().unwrap(), ip, 3600);
167                    }
168                } else {
169                    let target = generate_subtree(
170                        client.clone(),
171                        depth - 1,
172                        branching,
173                        domain.clone(),
174                        ips.clone(),
175                        port,
176                    )
177                    .await
178                    .to_string();
179                    let target: &'static str = Box::leak(target.into_boxed_str());
180                    svcb.target = target.try_into().unwrap();
181                };
182
183                builder = builder.https(".".try_into().unwrap(), svcb, 3600);
184            }
185
186            let signed_packet = builder.sign(&keypair).unwrap();
187
188            client.publish(&signed_packet, None).await.unwrap();
189
190            keypair.public_key()
191        })
192    }
193
194    /// depth of (3): A -> B -> C
195    /// branch of (2): A -> B0,  A ->  B1
196    /// domain, ips, and port are all at the end (C, or B1)
197    fn generate(
198        client: &Client,
199        depth: u8,
200        branching: u8,
201        domain: Option<String>,
202        ips: Vec<IpAddr>,
203        port: Option<u16>,
204    ) -> Pin<Box<dyn Future<Output = PublicKey>>> {
205        generate_subtree(client.clone(), depth - 1, branching, domain, ips, port)
206    }
207
208    #[tokio::test]
209    async fn direct_endpoint_resolution() {
210        let testnet = Testnet::new_async(5).await.unwrap();
211        let client = Client::builder()
212            .no_default_network()
213            .bootstrap(&testnet.bootstrap)
214            .build()
215            .unwrap();
216
217        let tld = generate(&client, 1, 1, Some("example.com".to_string()), vec![], None).await;
218
219        let endpoint = client
220            .resolve_https_endpoint(&tld.to_string())
221            .await
222            .unwrap();
223
224        assert_eq!(endpoint.domain(), Some("example.com"));
225        assert_eq!(endpoint.public_key(), &tld);
226    }
227
228    #[tokio::test]
229    async fn resolve_endpoints() {
230        let testnet = Testnet::new_async(5).await.unwrap();
231        let client = Client::builder()
232            .no_default_network()
233            .bootstrap(&testnet.bootstrap)
234            .build()
235            .unwrap();
236
237        let tld = generate(&client, 3, 3, Some("example.com".to_string()), vec![], None).await;
238
239        let endpoint = client
240            .resolve_https_endpoint(&tld.to_string())
241            .await
242            .unwrap();
243
244        assert_eq!(endpoint.domain(), Some("example.com"));
245    }
246
247    #[tokio::test]
248    async fn empty() {
249        let testnet = Testnet::new_async(5).await.unwrap();
250        let client = Client::builder()
251            .no_default_network()
252            .bootstrap(&testnet.bootstrap)
253            .build()
254            .unwrap();
255
256        let public_key = Keypair::random().public_key();
257
258        let endpoint = client.resolve_https_endpoint(&public_key.to_string()).await;
259
260        assert!(endpoint.is_err());
261    }
262
263    #[tokio::test]
264    async fn max_recursion_exceeded() {
265        let testnet = Testnet::new_async(5).await.unwrap();
266        let client = Client::builder()
267            .no_default_network()
268            .bootstrap(&testnet.bootstrap)
269            .max_recursion_depth(3)
270            .build()
271            .unwrap();
272
273        let tld = generate(&client, 4, 3, Some("example.com".to_string()), vec![], None).await;
274
275        let endpoint = client.resolve_https_endpoint(&tld.to_string()).await;
276
277        assert!(endpoint.is_err());
278    }
279
280    #[tokio::test]
281    async fn resolve_addresses() {
282        let testnet = Testnet::new_async(5).await.unwrap();
283        let client = Client::builder()
284            .no_default_network()
285            .bootstrap(&testnet.bootstrap)
286            .build()
287            .unwrap();
288
289        let tld = generate(
290            &client,
291            3,
292            3,
293            None,
294            vec![IpAddr::from_str("0.0.0.10").unwrap()],
295            Some(3000),
296        )
297        .await;
298
299        let endpoint = client
300            .resolve_https_endpoint(&tld.to_string())
301            .await
302            .unwrap();
303
304        assert_eq!(endpoint.target(), ".");
305        assert_eq!(endpoint.domain(), None);
306        assert_eq!(
307            endpoint
308                .to_socket_addrs()
309                .into_iter()
310                .map(|s| s.to_string())
311                .collect::<Vec<String>>(),
312            vec!["0.0.0.10:3000"]
313        );
314    }
315}