Skip to main content

pkarr/extra/endpoints/
mod.rs

//! Implementation of the [Endpoints](https://github.com/pubky/pkarr/blob/main/design/endpoints.md) spec.
2//!
3
4mod endpoint;
5
6pub use endpoint::Endpoint;
7
8use futures_lite::{pin, Stream, StreamExt};
9use genawaiter::sync::Gen;
10
11use crate::PublicKey;
12
13impl crate::Client {
14    /// Returns an async stream of [HTTPS][crate::dns::rdata::RData::HTTPS] [Endpoint]s
15    pub fn resolve_https_endpoints<'a>(
16        &'a self,
17        qname: &'a str,
18    ) -> impl Stream<Item = Endpoint> + 'a {
19        self.resolve_endpoints(qname, true)
20    }
21
22    /// Returns an async stream of [SVCB][crate::dns::rdata::RData::SVCB] [Endpoint]s
23    pub fn resolve_svcb_endpoints<'a>(
24        &'a self,
25        qname: &'a str,
26    ) -> impl Stream<Item = Endpoint> + 'a {
27        self.resolve_endpoints(qname, false)
28    }
29
30    /// Helper method that returns the first [HTTPS][crate::dns::rdata::RData::HTTPS] [Endpoint] in the Async stream from [Self::resolve_https_endpoints]
31    pub async fn resolve_https_endpoint(
32        &self,
33        qname: &str,
34    ) -> Result<Endpoint, CouldNotResolveEndpoint> {
35        let stream = self.resolve_https_endpoints(qname);
36
37        pin!(stream);
38
39        match stream.next().await {
40            Some(endpoint) => Ok(endpoint),
41            None => {
42                #[cfg(not(target_arch = "wasm32"))]
43                tracing::trace!(?qname, "failed to resolve endpoint");
44                #[cfg(target_arch = "wasm32")]
45                log::trace!("failed to resolve endpoint {qname}");
46
47                Err(CouldNotResolveEndpoint)
48            }
49        }
50    }
51
52    /// Helper method that returns the first [SVCB][crate::dns::rdata::RData::SVCB] [Endpoint] in the Async stream from [Self::resolve_svcb_endpoints]
53    pub async fn resolve_svcb_endpoint(
54        &self,
55        qname: &str,
56    ) -> Result<Endpoint, CouldNotResolveEndpoint> {
57        let stream = self.resolve_https_endpoints(qname);
58
59        pin!(stream);
60
61        match stream.next().await {
62            Some(endpoint) => Ok(endpoint),
63            None => Err(CouldNotResolveEndpoint),
64        }
65    }
66
67    /// Returns an async stream of either [HTTPS][crate::dns::rdata::RData::HTTPS] or [SVCB][crate::dns::rdata::RData::SVCB] [Endpoint]s
68    pub fn resolve_endpoints<'a>(
69        &'a self,
70        qname: &'a str,
71        https: bool,
72    ) -> impl Stream<Item = Endpoint> + 'a {
73        Gen::new(|co| async move {
74            let mut depth = 0;
75            let mut stack: Vec<Endpoint> = Vec::new();
76
77            // Initialize the stack with endpoints from the starting domain.
78            if let Ok(tld) = PublicKey::try_from(qname) {
79                if let Some(signed_packet) = self.resolve(&tld).await {
80                    depth += 1;
81                    stack.extend(Endpoint::parse(&signed_packet, qname, https));
82                }
83            }
84
85            while let Some(next) = stack.pop() {
86                let current = next.target();
87
88                // Attempt to resolve the domain as a public key.
89                match PublicKey::try_from(current) {
90                    Ok(tld) => match self.resolve(&tld).await {
91                        Some(signed_packet) if depth < self.0.max_recursion_depth => {
92                            depth += 1;
93                            let endpoints = Endpoint::parse(&signed_packet, current, https);
94
95                            #[cfg(not(target_arch = "wasm32"))]
96                            tracing::trace!(?qname, ?depth, ?endpoints, "resolved endpoints");
97                            #[cfg(target_arch = "wasm32")]
98                            log::trace!("resolved endpoints qname: {qname}, depth: {depth}, endpoints: {:?}", endpoints);
99
100                            stack.extend(endpoints);
101                        }
102                        _ => break, // Stop on resolution failure or recursion depth exceeded.
103                    },
104                    // Yield if the domain is not pointing to another Pkarr TLD domain.
105                    Err(_) => co.yield_(next).await,
106                }
107            }
108        })
109    }
110}
111
/// pkarr could not resolve endpoint
///
/// Returned by the endpoint-resolution helpers when the endpoint stream ends without
/// yielding anything. A unit type, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CouldNotResolveEndpoint;

impl std::error::Error for CouldNotResolveEndpoint {}

impl std::fmt::Display for CouldNotResolveEndpoint {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("pkarr could not resolve endpoint")
    }
}
123
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {

    use crate::dns::rdata::SVCB;
    use crate::mainline::Testnet;
    use crate::{Client, Keypair};
    use crate::{PublicKey, SignedPacket};

    use std::future::Future;
    use std::net::{IpAddr, Ipv4Addr};
    use std::pin::Pin;
    use std::str::FromStr;
    use std::time::Duration;

    /// Publishes a tree of Pkarr packets and returns the public key of its root.
    ///
    /// A packet at `depth > 0` holds `branching` HTTPS records. Each record targets the
    /// root key of a freshly generated and published subtree of `depth - 1`.
    ///
    /// A packet at `depth == 0` is a leaf:
    /// - its `branching` HTTPS records have priority 1;
    /// - they target `domain` (or `.` when `domain` is `None`) and carry `port` when set;
    /// - the packet also holds one address record per entry in `ips`.
    ///
    /// The future is boxed because the function recurses.
    fn generate_subtree(
        client: Client,
        depth: u8,
        branching: u8,
        domain: Option<String>,
        ips: Vec<IpAddr>,
        port: Option<u16>,
    ) -> Pin<Box<dyn Future<Output = PublicKey>>> {
        Box::pin(async move {
            let keypair = Keypair::random();

            let mut builder = SignedPacket::builder();

            // Add leaf address records once per packet rather than once per HTTPS
            // record, so each IP is not duplicated `branching` times.
            if depth == 0 {
                for ip in &ips {
                    builder = builder.address(".".try_into().unwrap(), *ip, 3600);
                }
            }

            for _ in 0..branching {
                let mut svcb = SVCB::new(0, ".".try_into().unwrap());

                if depth == 0 {
                    svcb.priority = 1;

                    if let Some(port) = port {
                        svcb.set_port(port);
                    }

                    if let Some(target) = &domain {
                        // Leaked to obtain a `&'static str` for the record's target name;
                        // presumably the builder needs `'static` names. Acceptable in tests.
                        let target: &'static str = Box::leak(target.clone().into_boxed_str());
                        svcb.target = target.try_into().unwrap()
                    }
                } else {
                    let target = generate_subtree(
                        client.clone(),
                        depth - 1,
                        branching,
                        domain.clone(),
                        ips.clone(),
                        port,
                    )
                    .await
                    .to_string();
                    let target: &'static str = Box::leak(target.into_boxed_str());
                    svcb.target = target.try_into().unwrap();
                };

                builder = builder.https(".".try_into().unwrap(), svcb, 3600);
            }

            let signed_packet = builder.sign(&keypair).unwrap();

            client.publish(&signed_packet, None).await.unwrap();

            keypair.public_key()
        })
    }

    /// Publishes a tree of Pkarr packets and returns the root public key.
    ///
    /// - `depth` is the number of packets along each chain, e.g. depth 3: A -> B -> C.
    /// - `branching` is the number of HTTPS records per packet, e.g. branching 2:
    ///   A -> B0, A -> B1.
    /// - `domain`, `ips` and `port` are only applied to the leaf packets (C in the
    ///   depth-3 example).
    ///
    /// # Panics
    ///
    /// Panics if `depth` is 0, since a tree needs at least one packet.
    fn generate(
        client: &Client,
        depth: u8,
        branching: u8,
        domain: Option<String>,
        ips: Vec<IpAddr>,
        port: Option<u16>,
    ) -> Pin<Box<dyn Future<Output = PublicKey>>> {
        assert!(depth >= 1, "generate requires depth >= 1, got {depth}");
        generate_subtree(client.clone(), depth - 1, branching, domain, ips, port)
    }

    #[tokio::test]
    async fn direct_endpoint_resolution() {
        let testnet = Testnet::builder(5).build().unwrap();
        let client = Client::builder()
            .no_default_network()
            .bootstrap(&testnet.bootstrap)
            .dht(|b| b.bind_address(Ipv4Addr::LOCALHOST))
            .build()
            .unwrap();

        let tld = generate(&client, 1, 1, Some("example.com".to_string()), vec![], None).await;

        let endpoint = client
            .resolve_https_endpoint(&tld.to_string())
            .await
            .unwrap();

        assert_eq!(endpoint.domain(), Some("example.com"));
        assert_eq!(endpoint.public_key(), &tld);
    }

    #[tokio::test]
    async fn resolve_endpoints() {
        let testnet = Testnet::builder(5).build().unwrap();
        let client = Client::builder()
            .no_default_network()
            .bootstrap(&testnet.bootstrap)
            .dht(|b| b.bind_address(Ipv4Addr::LOCALHOST))
            .request_timeout(Duration::from_millis(200))
            .build()
            .unwrap();

        let tld = generate(&client, 3, 3, Some("example.com".to_string()), vec![], None).await;

        let endpoint = client
            .resolve_https_endpoint(&tld.to_string())
            .await
            .unwrap();

        assert_eq!(endpoint.domain(), Some("example.com"));
    }

    #[tokio::test]
    async fn empty() {
        let testnet = Testnet::builder(5).build().unwrap();
        let client = Client::builder()
            .no_default_network()
            .bootstrap(&testnet.bootstrap)
            .dht(|b| b.bind_address(Ipv4Addr::LOCALHOST))
            .request_timeout(Duration::from_millis(20))
            .build()
            .unwrap();

        let public_key = Keypair::random().public_key();

        let endpoint = client.resolve_https_endpoint(&public_key.to_string()).await;

        assert!(endpoint.is_err());
    }

    #[tokio::test]
    async fn max_recursion_exceeded() {
        let testnet = Testnet::builder(5).build().unwrap();
        let client = Client::builder()
            .no_default_network()
            .bootstrap(&testnet.bootstrap)
            .dht(|b| b.bind_address(Ipv4Addr::LOCALHOST))
            .request_timeout(Duration::from_millis(100))
            .max_recursion_depth(3)
            .build()
            .unwrap();

        let tld = generate(&client, 4, 3, Some("example.com".to_string()), vec![], None).await;

        let endpoint = client.resolve_https_endpoint(&tld.to_string()).await;

        assert!(endpoint.is_err());
    }

    #[tokio::test]
    async fn resolve_addresses() {
        let testnet = Testnet::builder(5).build().unwrap();
        let client = Client::builder()
            .no_default_network()
            .bootstrap(&testnet.bootstrap)
            .dht(|b| b.bind_address(Ipv4Addr::LOCALHOST))
            .request_timeout(Duration::from_millis(200))
            .build()
            .unwrap();

        let tld = generate(
            &client,
            3,
            3,
            None,
            vec![IpAddr::from_str("0.0.0.10").unwrap()],
            Some(3000),
        )
        .await;

        let endpoint = client
            .resolve_https_endpoint(&tld.to_string())
            .await
            .unwrap();

        assert_eq!(endpoint.target(), ".");
        assert_eq!(endpoint.domain(), None);
        assert_eq!(
            endpoint
                .to_socket_addrs()
                .into_iter()
                .map(|s| s.to_string())
                .collect::<Vec<String>>(),
            vec!["0.0.0.10:3000"]
        );
    }
}