1mod endpoint;
5
6pub use endpoint::Endpoint;
7
8use futures_lite::{pin, Stream, StreamExt};
9use genawaiter::sync::Gen;
10
11use crate::PublicKey;
12
13impl crate::Client {
14 pub fn resolve_https_endpoints<'a>(
16 &'a self,
17 qname: &'a str,
18 ) -> impl Stream<Item = Endpoint> + 'a {
19 self.resolve_endpoints(qname, true)
20 }
21
22 pub fn resolve_svcb_endpoints<'a>(
24 &'a self,
25 qname: &'a str,
26 ) -> impl Stream<Item = Endpoint> + 'a {
27 self.resolve_endpoints(qname, false)
28 }
29
30 pub async fn resolve_https_endpoint(
32 &self,
33 qname: &str,
34 ) -> Result<Endpoint, CouldNotResolveEndpoint> {
35 let stream = self.resolve_https_endpoints(qname);
36
37 pin!(stream);
38
39 match stream.next().await {
40 Some(endpoint) => Ok(endpoint),
41 None => {
42 #[cfg(not(target_arch = "wasm32"))]
43 tracing::trace!(?qname, "failed to resolve endpoint");
44 #[cfg(target_arch = "wasm32")]
45 log::trace!("failed to resolve endpoint {qname}");
46
47 Err(CouldNotResolveEndpoint)
48 }
49 }
50 }
51
52 pub async fn resolve_svcb_endpoint(
54 &self,
55 qname: &str,
56 ) -> Result<Endpoint, CouldNotResolveEndpoint> {
57 let stream = self.resolve_https_endpoints(qname);
58
59 pin!(stream);
60
61 match stream.next().await {
62 Some(endpoint) => Ok(endpoint),
63 None => Err(CouldNotResolveEndpoint),
64 }
65 }
66
67 pub fn resolve_endpoints<'a>(
69 &'a self,
70 qname: &'a str,
71 https: bool,
72 ) -> impl Stream<Item = Endpoint> + 'a {
73 Gen::new(|co| async move {
74 let mut depth = 0;
75 let mut stack: Vec<Endpoint> = Vec::new();
76
77 if let Ok(tld) = PublicKey::try_from(qname) {
79 if let Some(signed_packet) = self.resolve(&tld).await {
80 depth += 1;
81 stack.extend(Endpoint::parse(&signed_packet, qname, https));
82 }
83 }
84
85 while let Some(next) = stack.pop() {
86 let current = next.target();
87
88 match PublicKey::try_from(current) {
90 Ok(tld) => match self.resolve(&tld).await {
91 Some(signed_packet) if depth < self.0.max_recursion_depth => {
92 depth += 1;
93 let endpoints = Endpoint::parse(&signed_packet, current, https);
94
95 #[cfg(not(target_arch = "wasm32"))]
96 tracing::trace!(?qname, ?depth, ?endpoints, "resolved endpoints");
97 #[cfg(target_arch = "wasm32")]
98 log::trace!("resolved endpoints qname: {qname}, depth: {depth}, endpoints: {:?}", endpoints);
99
100 stack.extend(endpoints);
101 }
102 _ => break, },
104 Err(_) => co.yield_(next).await,
106 }
107 }
108 })
109 }
110}
111
/// Error returned by the single-endpoint resolvers (`Client::resolve_https_endpoint` and
/// `Client::resolve_svcb_endpoint`) when the endpoint stream ends without yielding any
/// endpoint.
///
/// It carries no data, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CouldNotResolveEndpoint;

impl std::error::Error for CouldNotResolveEndpoint {}

impl std::fmt::Display for CouldNotResolveEndpoint {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Constant message: no formatting machinery needed.
        f.write_str("pkarr could not resolve endpoint")
    }
}
123
// Integration tests for endpoint resolution. Each test starts a 5-node local DHT testnet,
// publishes packets through `generate`, and resolves them with a client that bootstraps
// only from that testnet.
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {

    use crate::dns::rdata::SVCB;
    use crate::mainline::Testnet;
    use crate::{Client, Keypair};
    use crate::{PublicKey, SignedPacket};

    use std::future::Future;
    use std::net::{IpAddr, Ipv4Addr};
    use std::pin::Pin;
    use std::str::FromStr;
    use std::time::Duration;

    /// Publishes a tree of HTTPS records rooted at a fresh random key and returns that key.
    ///
    /// The packet holds `branching` HTTPS records at the apex (`.`):
    /// - if `depth > 0`, each record targets the public key of a newly generated and
    ///   published subtree of `depth - 1`, so level `k` below this packet has
    ///   `branching^k` packets;
    /// - if `depth == 0`, each record is a leaf with priority 1, the optional `port`, and the
    ///   optional `domain` as target (otherwise the target stays `.`). The packet also gets
    ///   address records for `ips`.
    ///
    /// The future is boxed because an `async` block cannot recurse into itself without
    /// indirection.
    fn generate_subtree(
        client: Client,
        depth: u8,
        branching: u8,
        domain: Option<String>,
        ips: Vec<IpAddr>,
        port: Option<u16>,
    ) -> Pin<Box<dyn Future<Output = PublicKey>>> {
        Box::pin(async move {
            let keypair = Keypair::random();

            let mut builder = SignedPacket::builder();

            for _ in 0..branching {
                let mut svcb = SVCB::new(0, ".".try_into().unwrap());

                if depth == 0 {
                    svcb.priority = 1;

                    if let Some(port) = port {
                        svcb.set_port(port);
                    }

                    if let Some(target) = &domain {
                        // Leaked to obtain a `&'static str`; presumably the record target must
                        // outlive this scope — TODO confirm the required lifetime.
                        let target: &'static str = Box::leak(target.clone().into_boxed_str());
                        svcb.target = target.try_into().unwrap()
                    }

                    // NOTE(review): this runs once per branch, so with `branching > 1` the
                    // same address records are added repeatedly. `resolve_addresses` still
                    // expects a single socket address, so duplicates are presumably
                    // collapsed downstream — confirm.
                    for ip in ips.clone() {
                        builder = builder.address(".".try_into().unwrap(), ip, 3600);
                    }
                } else {
                    // Publish the child subtree first, then point this record at its key.
                    let target = generate_subtree(
                        client.clone(),
                        depth - 1,
                        branching,
                        domain.clone(),
                        ips.clone(),
                        port,
                    )
                    .await
                    .to_string();
                    let target: &'static str = Box::leak(target.into_boxed_str());
                    svcb.target = target.try_into().unwrap();
                };

                builder = builder.https(".".try_into().unwrap(), svcb, 3600);
            }

            let signed_packet = builder.sign(&keypair).unwrap();

            client.publish(&signed_packet, None).await.unwrap();

            keypair.public_key()
        })
    }

    /// Publishes a tree with `depth` levels of packets (the root counts as one level) and
    /// returns the root key. See `generate_subtree` for the record layout.
    ///
    /// NOTE(review): `depth - 1` underflows `u8` for `depth == 0`, which panics when
    /// overflow checks are enabled (as in default test builds). Callers must pass
    /// `depth >= 1`.
    fn generate(
        client: &Client,
        depth: u8,
        branching: u8,
        domain: Option<String>,
        ips: Vec<IpAddr>,
        port: Option<u16>,
    ) -> Pin<Box<dyn Future<Output = PublicKey>>> {
        generate_subtree(client.clone(), depth - 1, branching, domain, ips, port)
    }

    // A single packet whose HTTPS record targets `example.com` resolves to that domain, and
    // the endpoint reports the key it was found under.
    #[tokio::test]
    async fn direct_endpoint_resolution() {
        let testnet = Testnet::builder(5).build().unwrap();
        let client = Client::builder()
            .no_default_network()
            .bootstrap(&testnet.bootstrap)
            .dht(|b| b.bind_address(Ipv4Addr::LOCALHOST))
            .build()
            .unwrap();

        let tld = generate(&client, 1, 1, Some("example.com".to_string()), vec![], None).await;

        let endpoint = client
            .resolve_https_endpoint(&tld.to_string())
            .await
            .unwrap();

        assert_eq!(endpoint.domain(), Some("example.com"));
        assert_eq!(endpoint.public_key(), &tld);
    }

    // Three levels of packets with three records each: resolution follows the nested keys
    // down to a leaf that targets `example.com`.
    #[tokio::test]
    async fn resolve_endpoints() {
        let testnet = Testnet::builder(5).build().unwrap();
        let client = Client::builder()
            .no_default_network()
            .bootstrap(&testnet.bootstrap)
            .dht(|b| b.bind_address(Ipv4Addr::LOCALHOST))
            .request_timeout(Duration::from_millis(200))
            .build()
            .unwrap();

        let tld = generate(&client, 3, 3, Some("example.com".to_string()), vec![], None).await;

        let endpoint = client
            .resolve_https_endpoint(&tld.to_string())
            .await
            .unwrap();

        assert_eq!(endpoint.domain(), Some("example.com"));
    }

    // A key that was never published yields an error.
    #[tokio::test]
    async fn empty() {
        let testnet = Testnet::builder(5).build().unwrap();
        let client = Client::builder()
            .no_default_network()
            .bootstrap(&testnet.bootstrap)
            .dht(|b| b.bind_address(Ipv4Addr::LOCALHOST))
            .request_timeout(Duration::from_millis(20))
            .build()
            .unwrap();

        let public_key = Keypair::random().public_key();

        let endpoint = client.resolve_https_endpoint(&public_key.to_string()).await;

        assert!(endpoint.is_err());
    }

    // Reaching the leaves needs four packet resolutions, but the client allows at most three
    // (`max_recursion_depth(3)`), so no endpoint is produced.
    #[tokio::test]
    async fn max_recursion_exceeded() {
        let testnet = Testnet::builder(5).build().unwrap();
        let client = Client::builder()
            .no_default_network()
            .bootstrap(&testnet.bootstrap)
            .dht(|b| b.bind_address(Ipv4Addr::LOCALHOST))
            .request_timeout(Duration::from_millis(100))
            .max_recursion_depth(3)
            .build()
            .unwrap();

        let tld = generate(&client, 4, 3, Some("example.com".to_string()), vec![], None).await;

        let endpoint = client.resolve_https_endpoint(&tld.to_string()).await;

        assert!(endpoint.is_err());
    }

    // Leaves without a domain keep the target `.`. The resolved endpoint then exposes the
    // leaf packet's address record together with the record's port as a socket address.
    #[tokio::test]
    async fn resolve_addresses() {
        let testnet = Testnet::builder(5).build().unwrap();
        let client = Client::builder()
            .no_default_network()
            .bootstrap(&testnet.bootstrap)
            .dht(|b| b.bind_address(Ipv4Addr::LOCALHOST))
            .request_timeout(Duration::from_millis(200))
            .build()
            .unwrap();

        let tld = generate(
            &client,
            3,
            3,
            None,
            vec![IpAddr::from_str("0.0.0.10").unwrap()],
            Some(3000),
        )
        .await;

        let endpoint = client
            .resolve_https_endpoint(&tld.to_string())
            .await
            .unwrap();

        assert_eq!(endpoint.target(), ".");
        assert_eq!(endpoint.domain(), None);
        assert_eq!(
            endpoint
                .to_socket_addrs()
                .into_iter()
                .map(|s| s.to_string())
                .collect::<Vec<String>>(),
            vec!["0.0.0.10:3000"]
        );
    }
}