1use crate::error::UpstreamError::{self, Bootstrap, Build};
2use crate::utils::build_request_message;
3use http::header::{ACCEPT, CONTENT_TYPE};
4use reqwest::{
5 header::{HeaderMap, HeaderValue},
6 Client,
7};
8use std::{net::SocketAddr, time::Duration};
9use trust_dns_proto::{
10 op::message::Message,
11 rr::{Name, RData, RecordType},
12};
13
14pub struct BootstrapClient {
15 https_client: Client,
16}
17
18impl BootstrapClient {
19 pub fn new() -> Result<Self, UpstreamError> {
20 let mut headers = HeaderMap::new();
21 headers.insert(
22 CONTENT_TYPE,
23 HeaderValue::from_str("application/dns-message").unwrap(),
24 );
25 headers.insert(
26 ACCEPT,
27 HeaderValue::from_str("application/dns-message").unwrap(),
28 );
29
30 let client_builder = Client::builder()
31 .default_headers(headers)
32 .https_only(true)
33 .gzip(true)
34 .brotli(true)
35 .timeout(Duration::from_secs(10));
36
37 let https_client = match client_builder.build() {
38 Ok(https_client) => https_client,
39 Err(_) => return Err(Build),
40 };
41
42 Ok(BootstrapClient { https_client })
43 }
44
45 pub async fn bootstrap(&self, host: &str) -> Result<SocketAddr, UpstreamError> {
46 let request_name = match host.parse::<Name>() {
47 Ok(request_name) => request_name,
48 Err(error) => return Err(Bootstrap(host.to_string(), error.to_string())),
49 };
50 let request_message = build_request_message(request_name, RecordType::A);
51
52 let raw_request_message = match request_message.to_vec() {
53 Ok(raw_request_message) => raw_request_message,
54 Err(error) => return Err(Bootstrap(host.to_string(), error.to_string())),
55 };
56
57 let url = "https://1.1.1.1/dns-query";
58 let request = self.https_client.post(url).body(raw_request_message);
59 let response = match request.send().await {
60 Ok(response) => response,
61 Err(error) => return Err(Bootstrap(host.to_string(), error.to_string())),
62 };
63
64 let raw_response_message = match response.bytes().await {
65 Ok(response_bytes) => response_bytes,
66 Err(error) => return Err(Bootstrap(host.to_string(), error.to_string())),
67 };
68
69 let response_message = match Message::from_vec(&raw_response_message) {
70 Ok(response_message) => response_message,
71 Err(error) => return Err(Bootstrap(host.to_string(), error.to_string())),
72 };
73
74 if response_message.answers().is_empty() {
75 return Err(Bootstrap(
76 host.to_string(),
77 String::from("the response doesn't contain the answer"),
78 ));
79 }
80 let record = &response_message.answers()[0];
81 let record_data = match record.data() {
82 Some(record_data) => record_data,
83 None => {
84 return Err(Bootstrap(
85 host.to_string(),
86 String::from("the response doesn't contain the answer"),
87 ))
88 }
89 };
90
91 match record_data {
92 RData::A(ipv4_address) => Ok(SocketAddr::new((*ipv4_address).into(), 0)),
93 RData::AAAA(ipv6_address) => Ok(SocketAddr::new((*ipv6_address).into(), 0)),
94 _ => Err(Bootstrap(
95 host.to_string(),
96 String::from("unknown record type"),
97 )),
98 }
99 }
100}
101
#[cfg(test)]
mod tests {
    use super::BootstrapClient;
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    /// Builds an IPv4 socket address with port 0, the form `bootstrap` returns.
    fn v4(a: u8, b: u8, c: u8, d: u8) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(a, b, c, d)), 0)
    }

    /// Resolves well-known public resolvers and checks each result is one of
    /// their published addresses. Requires network access.
    #[tokio::test]
    async fn test_bootstrap() {
        let client = BootstrapClient::new().unwrap();
        let cases = [
            ("dns.google", [v4(8, 8, 8, 8), v4(8, 8, 4, 4)]),
            ("one.one.one.one", [v4(1, 1, 1, 1), v4(1, 0, 0, 1)]),
            ("dns.quad9.net", [v4(9, 9, 9, 9), v4(149, 112, 112, 112)]),
            ("dns.adguard.com", [v4(94, 140, 14, 14), v4(94, 140, 15, 15)]),
        ];

        for (host, expected) in cases {
            let resolved = client.bootstrap(host).await.unwrap();
            assert!(expected.contains(&resolved));
        }
    }
}