1mod endpoint;
5
6pub use endpoint::Endpoint;
7
8use futures_lite::{pin, Stream, StreamExt};
9use genawaiter::sync::Gen;
10
11use crate::PublicKey;
12
13impl crate::Client {
14 pub fn resolve_https_endpoints<'a>(
16 &'a self,
17 qname: &'a str,
18 ) -> impl Stream<Item = Endpoint> + 'a {
19 self.resolve_endpoints(qname, true)
20 }
21
22 pub fn resolve_svcb_endpoints<'a>(
24 &'a self,
25 qname: &'a str,
26 ) -> impl Stream<Item = Endpoint> + 'a {
27 self.resolve_endpoints(qname, false)
28 }
29
30 pub async fn resolve_https_endpoint(
32 &self,
33 qname: &str,
34 ) -> Result<Endpoint, CouldNotResolveEndpoint> {
35 let stream = self.resolve_https_endpoints(qname);
36
37 pin!(stream);
38
39 match stream.next().await {
40 Some(endpoint) => Ok(endpoint),
41 None => {
42 #[cfg(not(target_arch = "wasm32"))]
43 tracing::trace!(?qname, "failed to resolve endpoint");
44 #[cfg(target_arch = "wasm32")]
45 log::trace!("failed to resolve endpoint {qname}");
46
47 Err(CouldNotResolveEndpoint)
48 }
49 }
50 }
51
52 pub async fn resolve_svcb_endpoint(
54 &self,
55 qname: &str,
56 ) -> Result<Endpoint, CouldNotResolveEndpoint> {
57 let stream = self.resolve_https_endpoints(qname);
58
59 pin!(stream);
60
61 match stream.next().await {
62 Some(endpoint) => Ok(endpoint),
63 None => Err(CouldNotResolveEndpoint),
64 }
65 }
66
67 pub fn resolve_endpoints<'a>(
69 &'a self,
70 qname: &'a str,
71 https: bool,
72 ) -> impl Stream<Item = Endpoint> + 'a {
73 Gen::new(|co| async move {
74 let mut depth = 0;
75 let mut stack: Vec<Endpoint> = Vec::new();
76
77 if let Ok(tld) = PublicKey::try_from(qname) {
79 if let Some(signed_packet) = self.resolve(&tld).await {
80 depth += 1;
81 stack.extend(Endpoint::parse(&signed_packet, qname, https));
82 }
83 }
84
85 while let Some(next) = stack.pop() {
86 let current = next.target();
87
88 match PublicKey::try_from(current) {
90 Ok(tld) => match self.resolve(&tld).await {
91 Some(signed_packet) if depth < self.0.max_recursion_depth => {
92 depth += 1;
93 let endpoints = Endpoint::parse(&signed_packet, current, https);
94
95 #[cfg(not(target_arch = "wasm32"))]
96 tracing::trace!(?qname, ?depth, ?endpoints, "resolved endpoints");
97 #[cfg(target_arch = "wasm32")]
98 log::trace!("resolved endpoints qname: {qname}, depth: {depth}, endpoints: {:?}", endpoints);
99
100 stack.extend(endpoints);
101 }
102 _ => break, },
104 Err(_) => co.yield_(next).await,
106 }
107 }
108 })
109 }
110}
111
/// Error returned when no endpoint could be resolved for a queried name.
///
/// Carries no data; its [`Display`](std::fmt::Display) form is
/// `"pkarr could not resolve endpoint"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CouldNotResolveEndpoint;

impl std::error::Error for CouldNotResolveEndpoint {}

impl std::fmt::Display for CouldNotResolveEndpoint {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("pkarr could not resolve endpoint")
    }
}
123
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {

    use crate::dns::rdata::SVCB;
    use crate::{Client, Keypair};
    use crate::{PublicKey, SignedPacket};

    use std::future::Future;
    use std::net::IpAddr;
    use std::pin::Pin;
    use std::str::FromStr;
    use std::time::Duration;

    use mainline::Testnet;

    /// Publishes a tree of HTTPS records under a fresh key and returns that key.
    ///
    /// Each packet holds `branching` HTTPS records named `"."`:
    /// - when `depth > 0`, every record targets the key of a freshly published child subtree
    ///   of `depth - 1`;
    /// - when `depth == 0` (leaf), every record has priority 1, the optional `port`, and
    ///   `domain` as target (`"."` when `None`). Address records for `ips` are also added,
    ///   once per HTTPS record.
    ///
    /// Boxed because an `async fn` cannot call itself recursively.
    fn generate_subtree(
        client: Client,
        depth: u8,
        branching: u8,
        domain: Option<String>,
        ips: Vec<IpAddr>,
        port: Option<u16>,
    ) -> Pin<Box<dyn Future<Output = PublicKey>>> {
        Box::pin(async move {
            let keypair = Keypair::random();

            let mut builder = SignedPacket::builder();

            for _ in 0..branching {
                let mut svcb = SVCB::new(0, ".".try_into().unwrap());

                if depth == 0 {
                    svcb.priority = 1;

                    if let Some(port) = port {
                        svcb.set_port(port);
                    }

                    if let Some(target) = &domain {
                        // Leaked to obtain a `'static` name; acceptable in short-lived tests.
                        let target: &'static str = Box::leak(target.clone().into_boxed_str());
                        svcb.target = target.try_into().unwrap();
                    }

                    for &ip in &ips {
                        builder = builder.address(".".try_into().unwrap(), ip, 3600);
                    }
                } else {
                    let target = generate_subtree(
                        client.clone(),
                        depth - 1,
                        branching,
                        domain.clone(),
                        ips.clone(),
                        port,
                    )
                    .await
                    .to_string();
                    // Leaked to obtain a `'static` name; acceptable in short-lived tests.
                    let target: &'static str = Box::leak(target.into_boxed_str());
                    svcb.target = target.try_into().unwrap();
                }

                builder = builder.https(".".try_into().unwrap(), svcb, 3600);
            }

            let signed_packet = builder.sign(&keypair).unwrap();

            client.publish(&signed_packet, None).await.unwrap();

            keypair.public_key()
        })
    }

    /// Publishes a tree whose resolution chain is `depth` packets long (root included).
    ///
    /// # Panics
    ///
    /// Panics if `depth` is 0, since the root packet itself counts as one level. This is
    /// explicit so that release builds do not silently wrap to a 255-level tree.
    fn generate(
        client: &Client,
        depth: u8,
        branching: u8,
        domain: Option<String>,
        ips: Vec<IpAddr>,
        port: Option<u16>,
    ) -> Pin<Box<dyn Future<Output = PublicKey>>> {
        let subtree_depth = depth
            .checked_sub(1)
            .expect("generate: depth must be at least 1 (the root packet counts as one level)");

        generate_subtree(client.clone(), subtree_depth, branching, domain, ips, port)
    }

    #[tokio::test]
    async fn direct_endpoint_resolution() {
        let testnet = Testnet::new(3).unwrap();
        let client = Client::builder()
            .no_default_network()
            .bootstrap(&testnet.bootstrap)
            .build()
            .unwrap();

        let tld = generate(&client, 1, 1, Some("example.com".to_string()), vec![], None).await;

        let endpoint = client
            .resolve_https_endpoint(&tld.to_string())
            .await
            .unwrap();

        assert_eq!(endpoint.domain(), Some("example.com"));
        assert_eq!(endpoint.public_key(), &tld);
    }

    #[tokio::test]
    async fn resolve_endpoints() {
        let testnet = Testnet::new(3).unwrap();
        let client = Client::builder()
            .no_default_network()
            .bootstrap(&testnet.bootstrap)
            .build()
            .unwrap();

        let tld = generate(&client, 3, 3, Some("example.com".to_string()), vec![], None).await;

        let endpoint = client
            .resolve_https_endpoint(&tld.to_string())
            .await
            .unwrap();

        assert_eq!(endpoint.domain(), Some("example.com"));
    }

    #[tokio::test]
    async fn empty() {
        let testnet = Testnet::new(3).unwrap();
        let client = Client::builder()
            .no_default_network()
            .bootstrap(&testnet.bootstrap)
            .request_timeout(Duration::from_millis(20))
            .build()
            .unwrap();

        // A key that never published anything.
        let pubky = Keypair::random().public_key();

        let endpoint = client.resolve_https_endpoint(&pubky.to_string()).await;

        assert!(endpoint.is_err());
    }

    #[tokio::test]
    async fn max_recursion_exceeded() {
        let testnet = Testnet::new(3).unwrap();
        let client = Client::builder()
            .no_default_network()
            .bootstrap(&testnet.bootstrap)
            .max_recursion_depth(3)
            .build()
            .unwrap();

        // A 4-packet chain cannot be followed with a budget of 3 resolved packets.
        let tld = generate(&client, 4, 3, Some("example.com".to_string()), vec![], None).await;

        let endpoint = client.resolve_https_endpoint(&tld.to_string()).await;

        assert!(endpoint.is_err());
    }

    #[tokio::test]
    async fn resolve_addresses() {
        let testnet = Testnet::new(3).unwrap();
        let client = Client::builder()
            .no_default_network()
            .bootstrap(&testnet.bootstrap)
            .build()
            .unwrap();

        let tld = generate(
            &client,
            3,
            3,
            None,
            vec![IpAddr::from_str("0.0.0.10").unwrap()],
            Some(3000),
        )
        .await;

        let endpoint = client
            .resolve_https_endpoint(&tld.to_string())
            .await
            .unwrap();

        assert_eq!(endpoint.target(), ".");
        assert_eq!(endpoint.domain(), None);
        assert_eq!(
            endpoint
                .to_socket_addrs()
                .into_iter()
                .map(|s| s.to_string())
                .collect::<Vec<String>>(),
            vec!["0.0.0.10:3000"]
        );
    }
}