use std::io::{Read, Write};
use std::net::{Ipv4Addr, Shutdown, SocketAddr, SocketAddrV4, TcpListener, TcpStream, UdpSocket};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread::{self, JoinHandle};
use std::time::Duration;
use anyhow::{Context, Result};
/// Handle to a running DNS forwarder.
///
/// Created by `DnsHandle::start`. Dropping the handle sets the shared stop
/// flag and joins the forwarder threads.
pub struct DnsHandle {
// Shared shutdown flag: set to `true` by `stop_and_join` and polled by the
// UDP and TCP loops on every iteration.
stop: Arc<AtomicBool>,
// Join handles of the threads spawned by `start` (UDP loop, then TCP loop).
joins: Vec<JoinHandle<Result<()>>>,
}
impl DnsHandle {
    /// Starts a DNS forwarder listening on `bind_ip:53` over both UDP and TCP
    /// and relaying to `upstream_ip:53`.
    ///
    /// One thread is spawned for the UDP loop and one for the TCP accept loop.
    /// Both poll a shared stop flag; the UDP socket gets a 250 ms read timeout
    /// and the TCP listener is put into nonblocking mode so neither loop blocks
    /// indefinitely between checks of that flag.
    ///
    /// # Errors
    ///
    /// Returns an error if either socket cannot be bound or configured.
    pub fn start(bind_ip: Ipv4Addr, upstream_ip: Ipv4Addr) -> Result<Self> {
        let upstream_addr = SocketAddrV4::new(upstream_ip, 53);
        let bind_addr = SocketAddrV4::new(bind_ip, 53);

        // UDP side: bind, then bound the time spent in a single receive call.
        let udp_socket = UdpSocket::bind(bind_addr)
            .with_context(|| format!("failed to bind UDP DNS forwarder on {bind_addr}"))?;
        let poll_interval = Duration::from_millis(250);
        udp_socket
            .set_read_timeout(Some(poll_interval))
            .context("failed to set UDP DNS forwarder read timeout")?;

        // TCP side: bind, then make `accept` return immediately when idle.
        let tcp_listener = TcpListener::bind(bind_addr)
            .with_context(|| format!("failed to bind TCP DNS forwarder on {bind_addr}"))?;
        tcp_listener
            .set_nonblocking(true)
            .context("failed to set TCP DNS forwarder nonblocking")?;

        let stop = Arc::new(AtomicBool::new(false));
        // Element order matters: the UDP worker is spawned first, the TCP worker second.
        let joins = vec![
            {
                let flag = Arc::clone(&stop);
                thread::spawn(move || udp_loop(udp_socket, upstream_addr, flag))
            },
            {
                let flag = Arc::clone(&stop);
                thread::spawn(move || tcp_loop(tcp_listener, upstream_addr, flag))
            },
        ];

        Ok(Self { stop, joins })
    }

    /// Raises the stop flag and waits for every worker thread to exit.
    ///
    /// Threads are joined in reverse spawn order (last spawned first). The
    /// workers' results, including panics, are discarded.
    fn stop_and_join(&mut self) {
        self.stop.store(true, Ordering::Relaxed);
        for worker in self.joins.drain(..).rev() {
            let _ = worker.join();
        }
    }
}
impl Drop for DnsHandle {
fn drop(&mut self) {
self.stop_and_join();
}
}
fn udp_loop(socket: UdpSocket, upstream_addr: SocketAddrV4, stop: Arc<AtomicBool>) -> Result<()> {
let mut buf = [0_u8; 4096];
while !stop.load(Ordering::Relaxed) {
match socket.recv_from(&mut buf) {
Ok((n, peer)) => {
let SocketAddr::V4(peer) = peer else {
continue;
};
let response = forward_udp_query(&buf[..n], upstream_addr)?;
socket
.send_to(&response, peer)
.with_context(|| format!("failed to return UDP DNS response to {peer}"))?;
}
Err(err)
if err.kind() == std::io::ErrorKind::WouldBlock
|| err.kind() == std::io::ErrorKind::TimedOut => {}
Err(err) => return Err(err).context("UDP DNS forwarder recv_from failed"),
}
}
Ok(())
}
/// Sends one DNS `query` to `upstream_addr` over UDP and returns the reply bytes.
///
/// A fresh socket bound to an ephemeral port on the unspecified IPv4 address is
/// used for each call and connected to the upstream. The call waits at most
/// three seconds for a reply and reads it into a 4096-byte buffer.
///
/// # Errors
///
/// Returns an error if the socket cannot be bound, configured or connected, or
/// if sending the query or receiving the reply fails (including a timeout).
fn forward_udp_query(query: &[u8], upstream_addr: SocketAddrV4) -> Result<Vec<u8>> {
    let local_addr = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0);
    let socket = UdpSocket::bind(local_addr)
        .context("failed to bind UDP upstream socket for DNS forwarder")?;

    let reply_timeout = Duration::from_secs(3);
    socket
        .set_read_timeout(Some(reply_timeout))
        .context("failed to set UDP upstream read timeout")?;
    socket
        .connect(upstream_addr)
        .with_context(|| format!("failed to connect UDP DNS upstream {upstream_addr}"))?;

    socket
        .send(query)
        .with_context(|| format!("failed to send UDP DNS query to {upstream_addr}"))?;

    // Receive into a full-size buffer, then trim it down to the bytes received.
    let mut reply = vec![0_u8; 4096];
    let received = socket
        .recv(&mut reply)
        .with_context(|| format!("failed to receive UDP DNS response from {upstream_addr}"))?;
    reply.truncate(received);
    Ok(reply)
}
/// Accepts TCP DNS connections on `listener` until `stop` becomes `true`.
///
/// Every accepted connection is handed to `handle_tcp_connection` on its own
/// detached thread; that thread's result is ignored. The listener is expected
/// to be nonblocking: when no connection is pending the loop sleeps for 100 ms
/// and then re-checks `stop`.
///
/// # Errors
///
/// Returns an error when `accept` fails for a reason that is not specific to a
/// single incoming connection.
fn tcp_loop(
    listener: TcpListener,
    upstream_addr: SocketAddrV4,
    stop: Arc<AtomicBool>,
) -> Result<()> {
    use std::io::ErrorKind;

    while !stop.load(Ordering::Relaxed) {
        match listener.accept() {
            Ok((client, _)) => {
                thread::spawn(move || {
                    let _ = handle_tcp_connection(client, upstream_addr);
                });
            }
            Err(err) => match err.kind() {
                // Nothing pending on the nonblocking listener: back off briefly.
                ErrorKind::WouldBlock => thread::sleep(Duration::from_millis(100)),
                // The pending connection died before it could be accepted (or
                // the call was interrupted). That concerns one client only, so
                // keep the listener alive and try again.
                ErrorKind::ConnectionAborted
                | ErrorKind::ConnectionReset
                | ErrorKind::Interrupted => {}
                _ => return Err(err).context("TCP DNS forwarder accept failed"),
            },
        }
    }
    Ok(())
}
fn handle_tcp_connection(mut client: TcpStream, upstream_addr: SocketAddrV4) -> Result<()> {
let mut upstream = TcpStream::connect(upstream_addr)
.with_context(|| format!("failed to connect TCP DNS upstream {upstream_addr}"))?;
relay_bidirectional(&mut client, &mut upstream)?;
Ok(())
}
/// Copies bytes between `left` (the DNS client) and `right` (the upstream) in
/// both directions until both directions have finished.
///
/// Each direction runs on its own scoped thread. When one direction reaches
/// end-of-stream, the write half of its destination is shut down so the peer
/// sees EOF. When one direction fails, both sockets are shut down completely
/// so the opposite direction cannot stay blocked on a dead connection.
///
/// Both directions are always joined before this function returns.
///
/// # Errors
///
/// Returns the client->upstream failure if there is one, otherwise the
/// upstream->client failure. A panic in either relay thread is reported as an
/// error as well.
fn relay_bidirectional(left: &mut TcpStream, right: &mut TcpStream) -> Result<()> {
    /// Pumps `from` into `to` until EOF or error, then signals the outcome on
    /// the sockets as described on the enclosing function.
    fn pump(mut from: &TcpStream, mut to: &TcpStream) -> std::io::Result<u64> {
        // `Read` and `Write` are implemented for `&TcpStream`, so shared
        // references suffice and no descriptor has to be duplicated.
        let outcome = std::io::copy(&mut from, &mut to);
        match outcome {
            Ok(_) => {
                // Clean EOF from the source: half-close the destination.
                let _ = to.shutdown(Shutdown::Write);
            }
            Err(_) => {
                // Broken connection: tear everything down so the sibling
                // direction is woken up instead of waiting forever.
                let _ = from.shutdown(Shutdown::Both);
                let _ = to.shutdown(Shutdown::Both);
            }
        }
        outcome
    }

    let client: &TcpStream = left;
    let upstream: &TcpStream = right;

    // Join both directions inside the scope so neither thread outlives this
    // call and a panic in one is captured rather than propagated.
    let (client_to_upstream, upstream_to_client) = thread::scope(|scope| {
        let forward = scope.spawn(|| pump(client, upstream));
        let backward = scope.spawn(|| pump(upstream, client));
        (forward.join(), backward.join())
    });

    client_to_upstream
        .map_err(|_| anyhow::anyhow!("DNS TCP client->upstream relay thread panicked"))?
        .context("DNS TCP client->upstream relay failed")?;
    upstream_to_client
        .map_err(|_| anyhow::anyhow!("DNS TCP upstream->client relay thread panicked"))?
        .context("DNS TCP upstream->client relay failed")?;
    Ok(())
}