// https_dns/bootstrap.rs

use crate::error::UpstreamError::{self, Bootstrap, Build};
use crate::utils::build_request_message;
use http::header::{ACCEPT, CONTENT_TYPE};
use reqwest::{
    header::{HeaderMap, HeaderValue},
    Client,
};
use std::{fmt::Display, net::SocketAddr, time::Duration};
use trust_dns_proto::{
    op::message::Message,
    rr::{Name, RData, RecordType},
};

14pub struct BootstrapClient {
15    https_client: Client,
16}
17
18impl BootstrapClient {
19    pub fn new() -> Result<Self, UpstreamError> {
20        let mut headers = HeaderMap::new();
21        headers.insert(
22            CONTENT_TYPE,
23            HeaderValue::from_str("application/dns-message").unwrap(),
24        );
25        headers.insert(
26            ACCEPT,
27            HeaderValue::from_str("application/dns-message").unwrap(),
28        );
29
30        let client_builder = Client::builder()
31            .default_headers(headers)
32            .https_only(true)
33            .gzip(true)
34            .brotli(true)
35            .timeout(Duration::from_secs(10));
36
37        let https_client = match client_builder.build() {
38            Ok(https_client) => https_client,
39            Err(_) => return Err(Build),
40        };
41
42        Ok(BootstrapClient { https_client })
43    }
44
45    pub async fn bootstrap(&self, host: &str) -> Result<SocketAddr, UpstreamError> {
46        let request_name = match host.parse::<Name>() {
47            Ok(request_name) => request_name,
48            Err(error) => return Err(Bootstrap(host.to_string(), error.to_string())),
49        };
50        let request_message = build_request_message(request_name, RecordType::A);
51
52        let raw_request_message = match request_message.to_vec() {
53            Ok(raw_request_message) => raw_request_message,
54            Err(error) => return Err(Bootstrap(host.to_string(), error.to_string())),
55        };
56
57        let url = "https://1.1.1.1/dns-query";
58        let request = self.https_client.post(url).body(raw_request_message);
59        let response = match request.send().await {
60            Ok(response) => response,
61            Err(error) => return Err(Bootstrap(host.to_string(), error.to_string())),
62        };
63
64        let raw_response_message = match response.bytes().await {
65            Ok(response_bytes) => response_bytes,
66            Err(error) => return Err(Bootstrap(host.to_string(), error.to_string())),
67        };
68
69        let response_message = match Message::from_vec(&raw_response_message) {
70            Ok(response_message) => response_message,
71            Err(error) => return Err(Bootstrap(host.to_string(), error.to_string())),
72        };
73
74        if response_message.answers().is_empty() {
75            return Err(Bootstrap(
76                host.to_string(),
77                String::from("the response doesn't contain the answer"),
78            ));
79        }
80        let record = &response_message.answers()[0];
81        let record_data = match record.data() {
82            Some(record_data) => record_data,
83            None => {
84                return Err(Bootstrap(
85                    host.to_string(),
86                    String::from("the response doesn't contain the answer"),
87                ))
88            }
89        };
90
91        match record_data {
92            RData::A(ipv4_address) => Ok(SocketAddr::new((*ipv4_address).into(), 0)),
93            RData::AAAA(ipv6_address) => Ok(SocketAddr::new((*ipv6_address).into(), 0)),
94            _ => Err(Bootstrap(
95                host.to_string(),
96                String::from("unknown record type"),
97            )),
98        }
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::BootstrapClient;
    use std::net::{Ipv4Addr, SocketAddr};

    /// Builds a port-0 IPv4 socket address, the shape `bootstrap` returns for `A` records.
    fn addr(a: u8, b: u8, c: u8, d: u8) -> SocketAddr {
        SocketAddr::new(Ipv4Addr::new(a, b, c, d).into(), 0)
    }

    /// Resolves several public resolver host names and checks each result is one of the
    /// provider's well-known addresses. Requires network access.
    #[tokio::test]
    async fn test_bootstrap() {
        let bootstrap_client = BootstrapClient::new().unwrap();
        let expected = [
            ("dns.google", [addr(8, 8, 8, 8), addr(8, 8, 4, 4)]),
            ("one.one.one.one", [addr(1, 1, 1, 1), addr(1, 0, 0, 1)]),
            ("dns.quad9.net", [addr(9, 9, 9, 9), addr(149, 112, 112, 112)]),
            ("dns.adguard.com", [addr(94, 140, 14, 14), addr(94, 140, 15, 15)]),
        ];

        for (host, candidates) in expected {
            let result = bootstrap_client
                .bootstrap(host)
                .await
                .unwrap_or_else(|error| panic!("bootstrap({}) failed: {:?}", host, error));
            assert!(
                candidates.contains(&result),
                "{} resolved to {}, expected one of {:?}",
                host,
                result,
                candidates
            );
        }
    }
}