// pkarr_client/extra/endpoints/mod.rs
mod endpoint;

pub use endpoint::Endpoint;

use futures_lite::{pin, Stream, StreamExt};
use genawaiter::sync::Gen;

use crate::PublicKey;

13impl crate::Client {
14 pub fn resolve_https_endpoints<'a>(
16 &'a self,
17 qname: &'a str,
18 ) -> impl Stream<Item = Endpoint> + 'a {
19 self.resolve_endpoints(qname, true)
20 }
21
22 pub fn resolve_svcb_endpoints<'a>(
24 &'a self,
25 qname: &'a str,
26 ) -> impl Stream<Item = Endpoint> + 'a {
27 self.resolve_endpoints(qname, false)
28 }
29
30 pub async fn resolve_https_endpoint(
32 &self,
33 qname: &str,
34 ) -> Result<Endpoint, CouldNotResolveEndpoint> {
35 let stream = self.resolve_https_endpoints(qname);
36
37 pin!(stream);
38
39 match stream.next().await {
40 Some(endpoint) => Ok(endpoint),
41 None => {
42 #[cfg(not(target_arch = "wasm32"))]
43 tracing::trace!(?qname, "failed to resolve endpoint");
44 #[cfg(target_arch = "wasm32")]
45 log::trace!("failed to resolve endpoint {qname}");
46
47 Err(CouldNotResolveEndpoint)
48 }
49 }
50 }
51
52 pub async fn resolve_svcb_endpoint(
54 &self,
55 qname: &str,
56 ) -> Result<Endpoint, CouldNotResolveEndpoint> {
57 let stream = self.resolve_https_endpoints(qname);
58
59 pin!(stream);
60
61 match stream.next().await {
62 Some(endpoint) => Ok(endpoint),
63 None => Err(CouldNotResolveEndpoint),
64 }
65 }
66
67 pub fn resolve_endpoints<'a>(
69 &'a self,
70 qname: &'a str,
71 https: bool,
72 ) -> impl Stream<Item = Endpoint> + 'a {
73 Gen::new(|co| async move {
74 let mut depth = 0;
75 let mut stack: Vec<Endpoint> = Vec::new();
76
77 if let Ok(tld) = PublicKey::try_from(qname) {
79 if let Some(signed_packet) = self.resolve(&tld).await {
80 depth += 1;
81 stack.extend(Endpoint::parse(&signed_packet, qname, https));
82 }
83 }
84
85 while let Some(next) = stack.pop() {
86 let current = next.target();
87
88 match PublicKey::try_from(current) {
90 Ok(tld) => match self.resolve(&tld).await {
91 Some(signed_packet) if depth < self.0.max_recursion_depth => {
92 depth += 1;
93 let endpoints = Endpoint::parse(&signed_packet, current, https);
94
95 #[cfg(not(target_arch = "wasm32"))]
96 tracing::trace!(?qname, ?depth, ?endpoints, "resolved endpoints");
97 #[cfg(target_arch = "wasm32")]
98 log::trace!("resolved endpoints qname: {qname}, depth: {depth}, endpoints: {:?}", endpoints);
99
100 stack.extend(endpoints);
101 }
102 _ => break, },
104 Err(_) => co.yield_(next).await,
106 }
107 }
108 })
109 }
110}
111
/// Error returned by `resolve_https_endpoint` and `resolve_svcb_endpoint` when the
/// endpoint stream finishes without yielding a single [`Endpoint`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CouldNotResolveEndpoint;

impl std::error::Error for CouldNotResolveEndpoint {}

impl std::fmt::Display for CouldNotResolveEndpoint {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("pkarr could not resolve endpoint")
    }
}

#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use crate::dns::rdata::SVCB;
    use crate::mainline::Testnet;
    use crate::{Client, Keypair};
    use crate::{PublicKey, SignedPacket};

    use std::future::Future;
    use std::net::IpAddr;
    use std::pin::Pin;
    use std::str::FromStr;

    /// Leaks `s` to obtain the `&'static str` that the SVCB targets below are built
    /// from. Acceptable in tests, where only a handful of short strings are leaked.
    fn leak(s: String) -> &'static str {
        Box::leak(s.into_boxed_str())
    }

    /// Publishes a complete tree of signed packets and returns the root's public key.
    ///
    /// Every packet holds `branching` HTTPS records on ".".
    ///
    /// At `depth == 0` (leaves):
    /// - the records have priority 1 and the optional `port`;
    /// - their target is `domain` when given;
    /// - the packet also carries one address record per entry in `ips`.
    ///
    /// Above the leaves, each record targets the public key of a freshly published
    /// subtree of `depth - 1`, so every chain is `depth + 1` packets long.
    fn generate_subtree(
        client: Client,
        depth: u8,
        branching: u8,
        domain: Option<String>,
        ips: Vec<IpAddr>,
        port: Option<u16>,
    ) -> Pin<Box<dyn Future<Output = PublicKey>>> {
        // Boxed because the function recurses through an async block.
        Box::pin(async move {
            let keypair = Keypair::random();
            let mut builder = SignedPacket::builder();

            if depth == 0 {
                // Address records belong to the packet, not to an individual HTTPS
                // record, so add each one once rather than once per branch.
                for ip in &ips {
                    builder = builder.address(".".try_into().unwrap(), *ip, 3600);
                }

                let target = domain.map(leak);

                for _ in 0..branching {
                    let mut svcb = SVCB::new(0, ".".try_into().unwrap());
                    svcb.priority = 1;

                    if let Some(port) = port {
                        svcb.set_port(port);
                    }

                    if let Some(target) = target {
                        svcb.target = target.try_into().unwrap();
                    }

                    builder = builder.https(".".try_into().unwrap(), svcb, 3600);
                }
            } else {
                for _ in 0..branching {
                    let child = generate_subtree(
                        client.clone(),
                        depth - 1,
                        branching,
                        domain.clone(),
                        ips.clone(),
                        port,
                    )
                    .await;

                    let mut svcb = SVCB::new(0, ".".try_into().unwrap());
                    svcb.target = leak(child.to_string()).try_into().unwrap();

                    builder = builder.https(".".try_into().unwrap(), svcb, 3600);
                }
            }

            let signed_packet = builder.sign(&keypair).unwrap();
            client.publish(&signed_packet, None).await.unwrap();

            keypair.public_key()
        })
    }

    /// Publishes a tree whose chains are `depth` packets long (see `generate_subtree`)
    /// and returns the root's public key.
    ///
    /// # Panics
    ///
    /// Panics if `depth` is 0, since a chain contains at least one packet.
    fn generate(
        client: &Client,
        depth: u8,
        branching: u8,
        domain: Option<String>,
        ips: Vec<IpAddr>,
        port: Option<u16>,
    ) -> Pin<Box<dyn Future<Output = PublicKey>>> {
        let subtree_depth = depth
            .checked_sub(1)
            .expect("generate needs depth >= 1: a chain has at least one packet");

        generate_subtree(client.clone(), subtree_depth, branching, domain, ips, port)
    }

    #[tokio::test]
    async fn direct_endpoint_resolution() {
        let testnet = Testnet::new_async(5).await.unwrap();
        let client = Client::builder()
            .no_default_network()
            .bootstrap(&testnet.bootstrap)
            .build()
            .unwrap();

        let tld = generate(&client, 1, 1, Some("example.com".to_string()), vec![], None).await;

        let endpoint = client
            .resolve_https_endpoint(&tld.to_string())
            .await
            .unwrap();

        assert_eq!(endpoint.domain(), Some("example.com"));
        assert_eq!(endpoint.public_key(), &tld);
    }

    #[tokio::test]
    async fn resolve_endpoints() {
        let testnet = Testnet::new_async(5).await.unwrap();
        let client = Client::builder()
            .no_default_network()
            .bootstrap(&testnet.bootstrap)
            .build()
            .unwrap();

        let tld = generate(&client, 3, 3, Some("example.com".to_string()), vec![], None).await;

        let endpoint = client
            .resolve_https_endpoint(&tld.to_string())
            .await
            .unwrap();

        assert_eq!(endpoint.domain(), Some("example.com"));
    }

    #[tokio::test]
    async fn empty() {
        let testnet = Testnet::new_async(5).await.unwrap();
        let client = Client::builder()
            .no_default_network()
            .bootstrap(&testnet.bootstrap)
            .build()
            .unwrap();

        let public_key = Keypair::random().public_key();

        let endpoint = client.resolve_https_endpoint(&public_key.to_string()).await;

        assert!(endpoint.is_err());
    }

    #[tokio::test]
    async fn max_recursion_exceeded() {
        let testnet = Testnet::new_async(5).await.unwrap();
        let client = Client::builder()
            .no_default_network()
            .bootstrap(&testnet.bootstrap)
            .max_recursion_depth(3)
            .build()
            .unwrap();

        let tld = generate(&client, 4, 3, Some("example.com".to_string()), vec![], None).await;

        let endpoint = client.resolve_https_endpoint(&tld.to_string()).await;

        assert!(endpoint.is_err());
    }

    #[tokio::test]
    async fn resolve_addresses() {
        let testnet = Testnet::new_async(5).await.unwrap();
        let client = Client::builder()
            .no_default_network()
            .bootstrap(&testnet.bootstrap)
            .build()
            .unwrap();

        let tld = generate(
            &client,
            3,
            3,
            None,
            vec![IpAddr::from_str("0.0.0.10").unwrap()],
            Some(3000),
        )
        .await;

        let endpoint = client
            .resolve_https_endpoint(&tld.to_string())
            .await
            .unwrap();

        assert_eq!(endpoint.target(), ".");
        assert_eq!(endpoint.domain(), None);
        assert_eq!(
            endpoint
                .to_socket_addrs()
                .into_iter()
                .map(|s| s.to_string())
                .collect::<Vec<String>>(),
            vec!["0.0.0.10:3000"]
        );
    }
}