1use crate::error::{Error, Result};
19use crate::netstack::{NetStack, TcpConnection};
20use parking_lot::Mutex;
21use std::collections::HashMap;
22use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
23use std::sync::Arc;
24use std::time::{Duration, Instant};
25use tokio::io::{AsyncReadExt, AsyncWriteExt};
26use tokio::net::TcpStream;
27use tokio_rustls::TlsConnector;
28
/// A DNS-over-HTTPS server: the hostname presented for TLS verification and in the
/// HTTP `Host` header, plus the IPv4 addresses at which the server can be reached.
#[derive(Debug, Clone)]
pub struct DohServerConfig {
    /// Hostname used as the TLS server name and as the `Host` header of each query.
    pub hostname: String,
    /// Server addresses; the resolver tries them in order until one succeeds.
    pub ips: Vec<Ipv4Addr>,
}

impl DohServerConfig {
    /// Builds a configuration for an arbitrary DoH server.
    pub fn new(hostname: impl Into<String>, ips: Vec<Ipv4Addr>) -> Self {
        let hostname = hostname.into();
        DohServerConfig { hostname, ips }
    }

    /// Cloudflare public DNS (1.1.1.1 / 1.0.0.1).
    pub fn cloudflare() -> Self {
        Self::new(
            "1dot1dot1dot1.cloudflare-dns.com",
            vec![Ipv4Addr::new(1, 1, 1, 1), Ipv4Addr::new(1, 0, 0, 1)],
        )
    }

    /// Google public DNS (8.8.8.8 / 8.8.4.4).
    pub fn google() -> Self {
        Self::new(
            "dns.google",
            vec![Ipv4Addr::new(8, 8, 8, 8), Ipv4Addr::new(8, 8, 4, 4)],
        )
    }

    /// Quad9 (9.9.9.9 / 149.112.112.112).
    pub fn quad9() -> Self {
        Self::new(
            "dns.quad9.net",
            vec![Ipv4Addr::new(9, 9, 9, 9), Ipv4Addr::new(149, 112, 112, 112)],
        )
    }

    /// AdGuard DNS (94.140.14.14 / 94.140.15.15).
    pub fn adguard() -> Self {
        Self::new(
            "dns.adguard-dns.com",
            vec![Ipv4Addr::new(94, 140, 14, 14), Ipv4Addr::new(94, 140, 15, 15)],
        )
    }

    /// NextDNS with a per-user configuration id, served as `<config_id>.dns.nextdns.io`.
    pub fn nextdns(config_id: &str) -> Self {
        let hostname = format!("{}.dns.nextdns.io", config_id);
        Self::new(
            hostname,
            vec![Ipv4Addr::new(45, 90, 28, 0), Ipv4Addr::new(45, 90, 30, 0)],
        )
    }
}

impl Default for DohServerConfig {
    /// Same as [`DohServerConfig::cloudflare`].
    fn default() -> Self {
        DohServerConfig::cloudflare()
    }
}
95
96#[derive(Debug, Clone)]
101pub struct DnsConfig {
102 pub pre_connection: DohServerConfig,
105 pub post_connection: DohServerConfig,
108}
109
110impl DnsConfig {
111 pub fn new(server: DohServerConfig) -> Self {
113 Self {
114 pre_connection: server.clone(),
115 post_connection: server,
116 }
117 }
118
119 pub fn with_different_servers(pre_connection: DohServerConfig, post_connection: DohServerConfig) -> Self {
121 Self {
122 pre_connection,
123 post_connection,
124 }
125 }
126
127 pub fn cloudflare() -> Self {
129 Self::new(DohServerConfig::cloudflare())
130 }
131
132 pub fn google() -> Self {
134 Self::new(DohServerConfig::google())
135 }
136
137 pub fn quad9() -> Self {
139 Self::new(DohServerConfig::quad9())
140 }
141}
142
143impl Default for DnsConfig {
144 fn default() -> Self {
145 Self::cloudflare()
146 }
147}
148
149#[derive(Clone)]
151struct CacheEntry {
152 addresses: Vec<Ipv4Addr>,
153 expires_at: Instant,
154}
155
156#[derive(Clone)]
158enum Transport {
159 Direct,
161 Tunnel(Arc<NetStack>),
163}
164
165pub struct DohResolver {
183 transport: Transport,
184 tls_connector: TlsConnector,
185 cache: Mutex<HashMap<String, CacheEntry>>,
187 cache_ttl: Duration,
189 server_config: DohServerConfig,
191}
192
193impl DohResolver {
194 pub fn new_tunneled(netstack: Arc<NetStack>) -> Self {
196 Self::new_tunneled_with_config(netstack, DohServerConfig::default())
197 }
198
199 pub fn new_tunneled_with_config(netstack: Arc<NetStack>, config: DohServerConfig) -> Self {
201 Self::new_with_transport(Transport::Tunnel(netstack), config)
202 }
203
204 pub fn new_direct() -> Self {
207 Self::new_direct_with_config(DohServerConfig::default())
208 }
209
210 pub fn new_direct_with_config(config: DohServerConfig) -> Self {
213 Self::new_with_transport(Transport::Direct, config)
214 }
215
216 fn new_with_transport(transport: Transport, server_config: DohServerConfig) -> Self {
218 let _ = rustls::crypto::ring::default_provider().install_default();
220
221 let root_store =
223 rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
224
225 let tls_config = rustls::ClientConfig::builder()
226 .with_root_certificates(root_store)
227 .with_no_client_auth();
228
229 let tls_connector = TlsConnector::from(Arc::new(tls_config));
230
231 Self {
232 transport,
233 tls_connector,
234 cache: Mutex::new(HashMap::new()),
235 cache_ttl: Duration::from_secs(300), server_config,
237 }
238 }
239
240 pub fn server_config(&self) -> &DohServerConfig {
242 &self.server_config
243 }
244
245 pub async fn resolve(&self, hostname: &str) -> Result<Vec<Ipv4Addr>> {
247 if let Ok(ip) = hostname.parse::<Ipv4Addr>() {
249 return Ok(vec![ip]);
250 }
251
252 {
254 let cache = self.cache.lock();
255 if let Some(entry) = cache.get(hostname) {
256 if entry.expires_at > Instant::now() {
257 log::debug!("DNS cache hit for {}", hostname);
258 return Ok(entry.addresses.clone());
259 }
260 }
261 }
262
263 let mode = match &self.transport {
264 Transport::Direct => "direct",
265 Transport::Tunnel(_) => "tunneled",
266 };
267 log::info!("Resolving {} via DoH ({})", hostname, mode);
268
269 let mut last_error = None;
271 for doh_ip in &self.server_config.ips {
272 match self.query_doh(*doh_ip, hostname).await {
273 Ok(addrs) => {
274 {
276 let mut cache = self.cache.lock();
277 cache.insert(
278 hostname.to_string(),
279 CacheEntry {
280 addresses: addrs.clone(),
281 expires_at: Instant::now() + self.cache_ttl,
282 },
283 );
284 }
285 return Ok(addrs);
286 }
287 Err(e) => {
288 log::warn!("DoH query to {} failed: {}", doh_ip, e);
289 last_error = Some(e);
290 }
291 }
292 }
293
294 Err(last_error.unwrap_or(Error::DnsAllServersFailed))
295 }
296
297 pub async fn resolve_addr(&self, hostname: &str, port: u16) -> Result<SocketAddr> {
299 let addrs = self.resolve(hostname).await?;
300 let ip = addrs
301 .into_iter()
302 .next()
303 .ok_or_else(|| Error::DnsNoRecords(hostname.to_string()))?;
304 Ok(SocketAddr::V4(SocketAddrV4::new(ip, port)))
305 }
306
307 async fn query_doh(&self, doh_ip: Ipv4Addr, hostname: &str) -> Result<Vec<Ipv4Addr>> {
309 let addr = SocketAddr::V4(SocketAddrV4::new(doh_ip, 443));
310
311 let dns_query = build_dns_query(hostname)?;
313
314 let http_request = format!(
316 "POST /dns-query HTTP/1.1\r\n\
317 Host: {}\r\n\
318 Content-Type: application/dns-message\r\n\
319 Accept: application/dns-message\r\n\
320 Content-Length: {}\r\n\
321 Connection: close\r\n\
322 \r\n",
323 self.server_config.hostname,
324 dns_query.len()
325 );
326
327 let response = match &self.transport {
329 Transport::Direct => {
330 self.query_direct(addr, &http_request, &dns_query).await?
331 }
332 Transport::Tunnel(netstack) => {
333 self.query_tunneled(netstack.clone(), addr, &http_request, &dns_query)
334 .await?
335 }
336 };
337
338 log::debug!("Received {} bytes from DoH server", response.len());
339
340 parse_doh_response(&response, hostname)
342 }
343
344 async fn query_direct(
346 &self,
347 addr: SocketAddr,
348 http_request: &str,
349 dns_query: &[u8],
350 ) -> Result<Vec<u8>> {
351 let tcp_stream = TcpStream::connect(addr).await?;
353
354 let server_name = rustls::pki_types::ServerName::try_from(self.server_config.hostname.clone())
356 .map_err(|e| Error::TlsHandshake(format!("Invalid server name: {}", e)))?;
357
358 log::debug!("Starting TLS handshake with DoH server {} (direct)", addr);
359 let mut tls_stream = self
360 .tls_connector
361 .connect(server_name, tcp_stream)
362 .await
363 .map_err(|e| Error::TlsHandshake(e.to_string()))?;
364
365 log::debug!("TLS handshake completed, sending DNS query");
366
367 tls_stream.write_all(http_request.as_bytes()).await?;
369 tls_stream.write_all(dns_query).await?;
370 tls_stream.flush().await?;
371
372 log::debug!("DNS query sent, waiting for response");
373
374 let mut response = Vec::new();
376 let mut buf = [0u8; 4096];
377 loop {
378 match tls_stream.read(&mut buf).await {
379 Ok(0) => break,
380 Ok(n) => {
381 response.extend_from_slice(&buf[..n]);
382 if response.len() > 4 && response_complete(&response) {
383 break;
384 }
385 }
386 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
387 Err(e) => return Err(e.into()),
388 }
389 }
390
391 Ok(response)
392 }
393
394 async fn query_tunneled(
396 &self,
397 netstack: Arc<NetStack>,
398 addr: SocketAddr,
399 http_request: &str,
400 dns_query: &[u8],
401 ) -> Result<Vec<u8>> {
402 let tcp_conn = TcpConnection::connect(netstack, addr).await?;
404
405 let tcp_stream = TunnelTcpStream {
406 conn: Arc::new(tcp_conn),
407 };
408
409 let server_name = rustls::pki_types::ServerName::try_from(self.server_config.hostname.clone())
411 .map_err(|e| Error::TlsHandshake(format!("Invalid server name: {}", e)))?;
412
413 log::debug!(
414 "Starting TLS handshake with DoH server {} (tunneled)",
415 addr
416 );
417 let mut tls_stream = self
418 .tls_connector
419 .connect(server_name, tcp_stream)
420 .await
421 .map_err(|e| Error::TlsHandshake(e.to_string()))?;
422
423 log::debug!("TLS handshake completed, sending DNS query");
424
425 tls_stream.write_all(http_request.as_bytes()).await?;
427 tls_stream.write_all(dns_query).await?;
428 tls_stream.flush().await?;
429
430 log::debug!("DNS query sent, waiting for response");
431
432 let mut response = Vec::new();
434 let mut buf = [0u8; 4096];
435 loop {
436 match tls_stream.read(&mut buf).await {
437 Ok(0) => break,
438 Ok(n) => {
439 response.extend_from_slice(&buf[..n]);
440 if response.len() > 4 && response_complete(&response) {
441 break;
442 }
443 }
444 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
445 Err(e) => return Err(e.into()),
446 }
447 }
448
449 Ok(response)
450 }
451}
452
/// Reports whether `data` already holds a complete HTTP/1.1 response, so the read loop
/// can stop without waiting for the server to close the connection.
///
/// Body framing follows RFC 9112 §6.3:
/// - `Transfer-Encoding: chunked` takes precedence over `Content-Length`.
/// - Otherwise `Content-Length` gives the body size.
/// - With neither, the body is delimited by connection close. In that case this returns
///   `false`, and the caller keeps reading until EOF.
fn response_complete(data: &[u8]) -> bool {
    let Some(header_end) = find_header_end(data) else {
        return false;
    };
    let headers = &data[..header_end];
    let body = &data[header_end + 4..];

    if headers_use_chunked(headers) {
        chunked_body_is_complete(body)
    } else if let Some(content_length) = parse_content_length(headers) {
        body.len() >= content_length
    } else {
        false
    }
}

/// Returns the offset of the `\r\n\r\n` that ends the header section, if present.
fn find_header_end(data: &[u8]) -> Option<usize> {
    data.windows(4).position(|window| window == b"\r\n\r\n")
}

/// Parses the first `Content-Length` header (name matched case-insensitively).
///
/// Returns `None` if the headers are not UTF-8, the header is absent, or its value
/// is not a decimal integer.
fn parse_content_length(headers: &[u8]) -> Option<usize> {
    let text = std::str::from_utf8(headers).ok()?;
    let (_, value) = text.lines().find_map(|line| {
        line.split_once(':')
            .filter(|(name, _)| name.eq_ignore_ascii_case("content-length"))
    })?;
    value.trim().parse().ok()
}

/// Whether the header section declares a `chunked` transfer coding.
fn headers_use_chunked(headers: &[u8]) -> bool {
    let Ok(text) = std::str::from_utf8(headers) else {
        return false;
    };
    text.lines().any(|line| match line.split_once(':') {
        Some((name, value)) => {
            name.eq_ignore_ascii_case("transfer-encoding")
                && value.to_ascii_lowercase().contains("chunked")
        }
        None => false,
    })
}

/// Whether `body` contains a full chunked message: every chunk, the zero-size last
/// chunk, and the empty line ending the (possibly empty) trailer section.
///
/// Malformed framing returns `true` so the caller stops reading and the response
/// parser reports the error, instead of waiting for data that will not fix it.
fn chunked_body_is_complete(mut body: &[u8]) -> bool {
    loop {
        let Some(line_end) = body.windows(2).position(|w| w == b"\r\n") else {
            return false;
        };
        // Chunk extensions (";name=value") may follow the hex size; ignore them.
        let size = std::str::from_utf8(&body[..line_end])
            .ok()
            .and_then(|line| line.split(';').next())
            .and_then(|hex| usize::from_str_radix(hex.trim(), 16).ok());
        let Some(size) = size else {
            return true;
        };
        body = &body[line_end + 2..];

        if size == 0 {
            // Trailer section: either empty (just CRLF) or header lines ending in CRLFCRLF.
            return body.starts_with(b"\r\n") || body.windows(4).any(|w| w == b"\r\n\r\n");
        }

        // Chunk data is followed by its own CRLF.
        let Some(needed) = size.checked_add(2) else {
            return true;
        };
        if body.len() < needed {
            return false;
        }
        body = &body[needed..];
    }
}
491
492fn build_dns_query(hostname: &str) -> Result<Vec<u8>> {
494 let mut query = Vec::new();
495
496 let id: u16 = rand::random();
498 query.extend_from_slice(&id.to_be_bytes());
499
500 query.extend_from_slice(&[0x01, 0x00]); query.extend_from_slice(&[0x00, 0x01]);
505 query.extend_from_slice(&[0x00, 0x00]);
507 query.extend_from_slice(&[0x00, 0x00]);
509 query.extend_from_slice(&[0x00, 0x00]);
511
512 for label in hostname.split('.') {
515 if label.len() > 63 {
516 return Err(Error::DnsLabelTooLong(label.to_string()));
517 }
518 query.push(label.len() as u8);
519 query.extend_from_slice(label.as_bytes());
520 }
521 query.push(0); query.extend_from_slice(&[0x00, 0x01]);
525 query.extend_from_slice(&[0x00, 0x01]);
527
528 Ok(query)
529}
530
531fn parse_doh_response(response: &[u8], hostname: &str) -> Result<Vec<Ipv4Addr>> {
533 let header_end = find_header_end(response)
535 .ok_or_else(|| Error::InvalidHttpResponse("no header end found".into()))?;
536
537 let body_start = header_end + 4;
538 if body_start >= response.len() {
539 return Err(Error::InvalidHttpResponse("empty body".into()));
540 }
541
542 let headers =
544 std::str::from_utf8(&response[..header_end]).map_err(|_| Error::InvalidHttpResponse("invalid headers".into()))?;
545
546 let status_line = headers.lines().next().unwrap_or("");
547 if !status_line.contains("200") {
548 return Err(Error::DohServerError(status_line.to_string()));
549 }
550
551 let dns_response = &response[body_start..];
552 parse_dns_response(dns_response, hostname)
553}
554
555fn parse_dns_response(data: &[u8], hostname: &str) -> Result<Vec<Ipv4Addr>> {
557 if data.len() < 12 {
558 return Err(Error::DnsResponseTooShort);
559 }
560
561 let flags = u16::from_be_bytes([data[2], data[3]]);
563 let rcode = flags & 0x000F;
564
565 if rcode != 0 {
566 return Err(Error::DnsError(rcode));
567 }
568
569 let ancount = u16::from_be_bytes([data[6], data[7]]) as usize;
570 if ancount == 0 {
571 return Err(Error::DnsNoRecords(hostname.to_string()));
572 }
573
574 log::debug!("DNS response has {} answers", ancount);
575
576 let mut pos = 12;
578
579 let qdcount = u16::from_be_bytes([data[4], data[5]]) as usize;
581 for _ in 0..qdcount {
582 pos = skip_dns_name(data, pos)?;
583 pos += 4; }
585
586 let mut addresses = Vec::new();
588 for _ in 0..ancount {
589 if pos >= data.len() {
590 break;
591 }
592
593 pos = skip_dns_name(data, pos)?;
595
596 if pos + 10 > data.len() {
597 break;
598 }
599
600 let rtype = u16::from_be_bytes([data[pos], data[pos + 1]]);
601 let _rclass = u16::from_be_bytes([data[pos + 2], data[pos + 3]]);
602 let _ttl = u32::from_be_bytes([data[pos + 4], data[pos + 5], data[pos + 6], data[pos + 7]]);
603 let rdlength = u16::from_be_bytes([data[pos + 8], data[pos + 9]]) as usize;
604
605 pos += 10;
606
607 if pos + rdlength > data.len() {
608 break;
609 }
610
611 if rtype == 1 && rdlength == 4 {
613 let ip = Ipv4Addr::new(data[pos], data[pos + 1], data[pos + 2], data[pos + 3]);
614 log::debug!("Resolved {} -> {}", hostname, ip);
615 addresses.push(ip);
616 }
617
618 pos += rdlength;
619 }
620
621 if addresses.is_empty() {
622 return Err(Error::DnsNoRecords(hostname.to_string()));
623 }
624
625 Ok(addresses)
626}
627
628fn skip_dns_name(data: &[u8], mut pos: usize) -> Result<usize> {
630 loop {
631 if pos >= data.len() {
632 return Err(Error::DnsNameTooLong);
633 }
634
635 let len = data[pos] as usize;
636
637 if len & 0xC0 == 0xC0 {
639 return Ok(pos + 2);
641 }
642
643 if len == 0 {
645 return Ok(pos + 1);
646 }
647
648 pos += 1 + len;
650 }
651}
652
653pub(crate) struct TunnelTcpStream {
655 conn: Arc<TcpConnection>,
656}
657
658impl tokio::io::AsyncRead for TunnelTcpStream {
659 fn poll_read(
660 self: std::pin::Pin<&mut Self>,
661 cx: &mut std::task::Context<'_>,
662 buf: &mut tokio::io::ReadBuf<'_>,
663 ) -> std::task::Poll<std::io::Result<()>> {
664 let conn = self.conn.clone();
665 let unfilled = buf.initialize_unfilled();
666
667 conn.netstack.poll();
668
669 if conn.netstack.can_recv(conn.handle) {
670 match conn.netstack.recv(conn.handle, unfilled) {
671 Ok(n) if n > 0 => {
672 buf.advance(n);
673 return std::task::Poll::Ready(Ok(()));
674 }
675 Ok(_) => {}
676 Err(e) => {
677 return std::task::Poll::Ready(Err(std::io::Error::new(
678 std::io::ErrorKind::Other,
679 e.to_string(),
680 )));
681 }
682 }
683 }
684
685 if !conn.netstack.may_recv(conn.handle) {
686 return std::task::Poll::Ready(Ok(()));
687 }
688
689 let waker = cx.waker().clone();
690 tokio::spawn(async move {
691 tokio::time::sleep(Duration::from_millis(1)).await;
692 waker.wake();
693 });
694
695 std::task::Poll::Pending
696 }
697}
698
699impl tokio::io::AsyncWrite for TunnelTcpStream {
700 fn poll_write(
701 self: std::pin::Pin<&mut Self>,
702 cx: &mut std::task::Context<'_>,
703 buf: &[u8],
704 ) -> std::task::Poll<std::io::Result<usize>> {
705 let conn = self.conn.clone();
706
707 conn.netstack.poll();
708
709 if conn.netstack.can_send(conn.handle) {
710 match conn.netstack.send(conn.handle, buf) {
711 Ok(n) => {
712 conn.netstack.poll();
713 return std::task::Poll::Ready(Ok(n));
714 }
715 Err(e) => {
716 return std::task::Poll::Ready(Err(std::io::Error::new(
717 std::io::ErrorKind::Other,
718 e.to_string(),
719 )));
720 }
721 }
722 }
723
724 if !conn.netstack.may_send(conn.handle) {
725 return std::task::Poll::Ready(Err(std::io::Error::new(
726 std::io::ErrorKind::BrokenPipe,
727 "Connection closed",
728 )));
729 }
730
731 let waker = cx.waker().clone();
732 tokio::spawn(async move {
733 tokio::time::sleep(Duration::from_millis(1)).await;
734 waker.wake();
735 });
736
737 std::task::Poll::Pending
738 }
739
740 fn poll_flush(
741 self: std::pin::Pin<&mut Self>,
742 _cx: &mut std::task::Context<'_>,
743 ) -> std::task::Poll<std::io::Result<()>> {
744 self.conn.netstack.poll();
745 std::task::Poll::Ready(Ok(()))
746 }
747
748 fn poll_shutdown(
749 self: std::pin::Pin<&mut Self>,
750 _cx: &mut std::task::Context<'_>,
751 ) -> std::task::Poll<std::io::Result<()>> {
752 self.conn.shutdown();
753 self.conn.netstack.poll();
754 std::task::Poll::Ready(Ok(()))
755 }
756}