1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
use crate::bootstrap::BootstrapHttpsClient;
use crate::cache::Cache;
use crate::error::UpstreamError;
use reqwest::{header, Client};
use std::{net::IpAddr, time::Duration};
use trust_dns_proto::op::message::Message;
/// DNS-over-HTTPS client that forwards DNS queries to a single upstream
/// resolver and keeps a cache of the responses it has received.
#[derive(Clone)]
pub struct UpstreamHttpsClient {
    /// Upstream hostname or IP literal, used when building the request URL.
    host: String,
    /// Upstream HTTPS port, used when building the request URL.
    port: u16,
    /// Response cache, consulted before and filled after each upstream query.
    cache: Cache,
    /// Shared `reqwest` client carrying the default headers, HTTPS-only policy
    /// and request timeout.
    https_client: Client,
}
impl UpstreamHttpsClient {
    /// Media type for wire-format DNS messages exchanged over DoH (RFC 8484).
    const DNS_MESSAGE_MEDIA_TYPE: &str = "application/dns-message";

    /// Overall per-request timeout applied to every upstream query.
    const REQUEST_TIMEOUT: Duration = Duration::from_secs(10);

    /// Builds a DNS-over-HTTPS client for the upstream resolver at `host:port`.
    ///
    /// If `host` is not an IP literal, it is resolved once through
    /// [`BootstrapHttpsClient`]. The resulting address is then pinned into the
    /// HTTP client with `ClientBuilder::resolve`, so requests to `host` use
    /// that address.
    ///
    /// # Panics
    ///
    /// Panics if bootstrapping `host` fails or if the `reqwest` client cannot
    /// be built.
    pub async fn new(host: String, port: u16) -> Self {
        let mut headers = header::HeaderMap::new();
        // POST bodies are raw DNS wire format, and we only understand the same
        // format in the reply (RFC 8484 section 4.1 asks clients to send Accept).
        headers.insert(
            header::CONTENT_TYPE,
            header::HeaderValue::from_static(Self::DNS_MESSAGE_MEDIA_TYPE),
        );
        headers.insert(
            header::ACCEPT,
            header::HeaderValue::from_static(Self::DNS_MESSAGE_MEDIA_TYPE),
        );

        let mut client_builder = Client::builder()
            .default_headers(headers)
            .https_only(true)
            .gzip(true)
            .brotli(true)
            .timeout(Self::REQUEST_TIMEOUT);

        // IP literals need no lookup. Names are resolved via the bootstrap
        // client and the answer is pinned into the client.
        if host.parse::<IpAddr>().is_err() {
            let ip_addr = BootstrapHttpsClient::new()
                .bootstrap(host.clone())
                .await
                .unwrap_or_else(|_| {
                    panic!("[upstream] failed to bootstrap the DNS-over-HTTPS client")
                });
            client_builder = client_builder.resolve(&host, ip_addr);
        }

        let https_client = client_builder.build().unwrap_or_else(|error| {
            panic!("[upstream] failed to build the HTTPS client: {}", error)
        });
        println!("[upstream] connected to {}", Self::origin(&host, port));

        UpstreamHttpsClient {
            host,
            port,
            https_client,
            cache: Cache::new(),
        }
    }

    /// Resolves `request_message` through the upstream DoH server.
    ///
    /// A cached response is returned when one exists. Otherwise the query is
    /// POSTed to `/dns-query`, and the parsed reply is cached and returned.
    ///
    /// # Errors
    ///
    /// Returns an [`UpstreamError`] if the query cannot be serialized, if the
    /// HTTP request fails or the server answers with a non-success status, if
    /// the response body cannot be read, or if the body is not a valid DNS
    /// message.
    pub async fn process(&mut self, request_message: Message) -> Result<Message, UpstreamError> {
        // NOTE(review): whether `Cache::get` rewrites the cached response's ID to
        // match `request_message` is not visible here; confirm in the cache module.
        if let Some(response_message) = self.cache.get(&request_message) {
            return Ok(response_message);
        }

        let raw_request_message = request_message.to_vec()?;
        let url = format!("{}/dns-query", Self::origin(&self.host, self.port));
        let response = self
            .https_client
            .post(url)
            .body(raw_request_message)
            .send()
            .await?
            // An error page (4xx/5xx) is not a DNS message. Reject it here rather
            // than risk parsing and caching garbage.
            .error_for_status()?;
        let raw_response_message = response.bytes().await?;
        let message = Message::from_vec(&raw_response_message)?;

        self.cache.put(message.clone());
        Ok(message)
    }

    /// Formats the `https://host:port` origin for `host`.
    ///
    /// IPv6 literals are bracketed as RFC 3986 requires, e.g.
    /// `https://[2001:db8::1]:443`. Without brackets, the port is ambiguous.
    fn origin(host: &str, port: u16) -> String {
        match host.parse::<IpAddr>() {
            Ok(IpAddr::V6(_)) => format!("https://[{}]:{}", host, port),
            _ => format!("https://{}:{}", host, port),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::BootstrapHttpsClient;
    use std::net::{Ipv4Addr, SocketAddr};

    /// Bootstraps `dns.google` and checks that the address returned is one of
    /// Google Public DNS's published IPv4 addresses, with port 0.
    ///
    /// NOTE(review): this depends on whatever the bootstrap client queries
    /// being reachable, so it is effectively an online test. It also exercises
    /// `BootstrapHttpsClient` rather than this module.
    #[tokio::test]
    async fn test_bootstrap() {
        let resolved = BootstrapHttpsClient::new()
            .bootstrap(String::from("dns.google"))
            .await
            .unwrap_or_else(|_| panic!("[test] failed to bootstrap the DNS-over-HTTPS service"));

        let google_addresses = [Ipv4Addr::new(8, 8, 8, 8), Ipv4Addr::new(8, 8, 4, 4)];
        let is_known_address = google_addresses
            .iter()
            .any(|&ip| SocketAddr::new(ip.into(), 0) == resolved);

        assert!(is_known_address);
    }
}