use std::collections::HashSet;
use std::net::IpAddr;
use std::sync::Arc;
use bytes::Bytes;
use smoltcp::iface::SocketSet;
use smoltcp::socket::udp;
use smoltcp::storage::PacketMetadata;
use smoltcp::wire::{IpAddress, IpEndpoint, IpListenEndpoint};
use tokio::sync::mpsc;
use super::common::config::NormalizedDnsConfig;
use super::forwarder::{DnsForwarder, DnsForwarderHandle};
use super::proxies::udp::UdpProxy;
use crate::config::DnsConfig;
use crate::policy::NetworkPolicy;
use crate::shared::SharedState;
use crate::stack::GatewayIps;
/// UDP port the interceptor socket is bound to.
const DNS_PORT: u16 = 53;
/// Size of the buffer each received query is copied into, and the per-slot
/// payload budget used to size the socket's rx/tx payload storage.
const DNS_MAX_SIZE: usize = 4 * 1024;
/// Number of datagrams the socket can hold in each direction.
const DNS_SOCKET_PACKET_SLOTS: usize = 16;
/// Capacity of each channel between the interceptor and the proxy task.
const CHANNEL_CAPACITY: usize = 64;
/// Bridges the smoltcp UDP socket on port 53 and the async DNS proxy task.
///
/// Built by `DnsInterceptor::new`; `DnsInterceptor::process` moves queries
/// out of the socket into the proxy and responses from the proxy back into
/// the socket.
pub(crate) struct DnsInterceptor {
/// Handle of the port-53 UDP socket inside the caller's `SocketSet`.
socket_handle: smoltcp::iface::SocketHandle,
/// Sends captured queries to the `UdpProxy` task.
query_tx: mpsc::Sender<DnsQuery>,
/// Receives responses from the `UdpProxy` task, to be written to the socket.
response_rx: mpsc::Receiver<DnsResponse>,
}
/// A DNS query datagram captured from the port-53 socket and handed to the
/// proxy task.
pub(crate) struct DnsQuery {
/// Raw UDP payload of the query.
pub(super) data: Bytes,
/// Remote endpoint (address and port) the query came from.
pub(super) source: IpEndpoint,
/// Destination address of the query as reported by smoltcp's
/// `UdpMetadata::local_address`: the address the sender originally targeted
/// rather than the gateway's own IP. May be `None` if smoltcp does not
/// report it.
pub(super) original_dst: Option<IpAddress>,
}
/// A DNS response produced by the proxy task, to be written back to the
/// port-53 socket.
pub(crate) struct DnsResponse {
/// Raw UDP payload of the response.
pub(crate) data: Bytes,
/// Endpoint the response is sent to.
pub(super) dest: IpEndpoint,
/// Source address placed on the outgoing datagram (smoltcp's
/// `UdpMetadata::local_address`). NOTE(review): presumably the
/// `original_dst` of the matching query, so the reply appears to come from
/// the resolver that was contacted; confirm in `UdpProxy`.
pub(super) source_addr: Option<IpAddress>,
}
impl DnsInterceptor {
    /// Creates the interceptor and starts its background tasks.
    ///
    /// Adds a UDP socket bound to port 53 (with no specific local address) to
    /// `sockets`. Then spawns the `DnsForwarder` and the `UdpProxy` on
    /// `tokio_handle`. Captured queries reach the proxy over one bounded
    /// channel, and responses come back over another.
    ///
    /// Returns the interceptor together with the spawned forwarder's handle.
    ///
    /// # Panics
    ///
    /// Panics if the socket cannot be bound to port 53.
    // Every argument is a distinct collaborator needed to wire up the
    // forwarder and proxy; bundling them would only move the list elsewhere.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        sockets: &mut SocketSet<'_>,
        dns_config: DnsConfig,
        shared: Arc<SharedState>,
        tokio_handle: &tokio::runtime::Handle,
        gateway_ips: Arc<HashSet<IpAddr>>,
        network_policy: Arc<NetworkPolicy>,
        gateway: GatewayIps,
    ) -> (Self, DnsForwarderHandle) {
        // Size each direction for DNS_SOCKET_PACKET_SLOTS datagrams of
        // DNS_MAX_SIZE bytes.
        let rx_meta = vec![PacketMetadata::EMPTY; DNS_SOCKET_PACKET_SLOTS];
        let rx_payload = vec![0u8; DNS_MAX_SIZE * DNS_SOCKET_PACKET_SLOTS];
        let tx_meta = vec![PacketMetadata::EMPTY; DNS_SOCKET_PACKET_SLOTS];
        let tx_payload = vec![0u8; DNS_MAX_SIZE * DNS_SOCKET_PACKET_SLOTS];
        let mut socket = udp::Socket::new(
            udp::PacketBuffer::new(rx_meta, rx_payload),
            udp::PacketBuffer::new(tx_meta, tx_payload),
        );
        // No local address: the socket is meant to catch port-53 traffic for
        // whatever resolver IP was targeted (presumably relying on the
        // interface's any-IP mode, as the tests below enable).
        socket
            .bind(IpListenEndpoint {
                addr: None,
                port: DNS_PORT,
            })
            .expect("failed to bind DNS socket to port 53");
        let socket_handle = sockets.add(socket);

        let (query_tx, query_rx) = mpsc::channel(CHANNEL_CAPACITY);
        let (response_tx, response_rx) = mpsc::channel(CHANNEL_CAPACITY);

        let normalized = Arc::new(NormalizedDnsConfig::from_config(dns_config));
        let forwarder_handle = DnsForwarder::spawn(
            tokio_handle,
            normalized,
            gateway_ips,
            network_policy,
            shared.clone(),
            gateway,
        );
        UdpProxy::spawn(
            tokio_handle,
            query_rx,
            response_tx,
            forwarder_handle.clone(),
            shared,
        );

        (
            Self {
                socket_handle,
                query_tx,
                response_rx,
            },
            forwarder_handle,
        )
    }

    /// Shuttles DNS traffic between the smoltcp socket and the proxy task.
    ///
    /// First drains every datagram queued on the socket into the query
    /// channel. Then writes pending responses into the socket while it has
    /// free packet slots. Never blocks; anything that cannot be queued is
    /// dropped and logged at debug level.
    // NOTE(review): presumably called from the network poll loop after each
    // `Interface::poll`; confirm with the caller.
    pub(crate) fn process(&mut self, sockets: &mut SocketSet<'_>) {
        let socket = sockets.get_mut::<udp::Socket>(self.socket_handle);
        self.forward_queries(socket);
        self.flush_responses(socket);
    }

    /// Reads every queued datagram from `socket` and forwards it as a
    /// [`DnsQuery`] without waiting for channel capacity.
    fn forward_queries(&self, socket: &mut udp::Socket<'_>) {
        let mut buf = [0u8; DNS_MAX_SIZE];
        while socket.can_recv() {
            match socket.recv_slice(&mut buf) {
                Ok((n, meta)) => {
                    let query = DnsQuery {
                        data: Bytes::copy_from_slice(&buf[..n]),
                        source: meta.endpoint,
                        original_dst: meta.local_address,
                    };
                    if self.query_tx.try_send(query).is_err() {
                        tracing::debug!("DNS query channel full, dropping query");
                    }
                }
                // smoltcp has already dequeued and discarded a datagram that
                // does not fit in `buf`. Keep draining so the queries queued
                // behind it are not stranded until the next call.
                Err(udp::RecvError::Truncated) => {
                    tracing::debug!(
                        "DNS query larger than {} bytes, dropping query",
                        DNS_MAX_SIZE
                    );
                }
                Err(_) => break,
            }
        }
    }

    /// Writes pending responses into `socket` while it has free packet slots.
    ///
    /// A response is only taken off the channel once the socket reports
    /// `can_send`, so responses left over stay queued for the next call.
    fn flush_responses(&mut self, socket: &mut udp::Socket<'_>) {
        while socket.can_send() {
            let response = match self.response_rx.try_recv() {
                Ok(response) => response,
                // Channel empty (or proxy gone): nothing more to send now.
                Err(_) => break,
            };
            let mut meta = udp::UdpMetadata::from(response.dest);
            meta.local_address = response.source_addr;
            // `can_send` only guarantees a free metadata slot, not payload
            // room, so the enqueue can still fail; the response is already
            // off the channel, so make the drop visible.
            if let Err(err) = socket.send_slice(&response.data, meta) {
                tracing::debug!(?err, "failed to enqueue DNS response, dropping it");
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use smoltcp::iface::{Config, Interface, SocketSet};
    use smoltcp::time::Instant;
    use smoltcp::wire::{HardwareAddress, IpCidr, Ipv4Address, Ipv6Address};
    use crate::device::SmoltcpDevice;

    /// Packet buffer with 4 metadata slots and 1 KiB of payload storage,
    /// enough for the handful of small datagrams these tests exchange.
    fn small_packet_buffer() -> udp::PacketBuffer<'static> {
        udp::PacketBuffer::new(
            vec![smoltcp::storage::PacketMetadata::EMPTY; 4],
            vec![0u8; 1024],
        )
    }

    /// UDP socket bound the same way as the production interceptor:
    /// port 53 with no specific local address.
    fn dns_socket() -> udp::Socket<'static> {
        let mut socket = udp::Socket::new(small_packet_buffer(), small_packet_buffer());
        socket
            .bind(IpListenEndpoint {
                addr: None,
                port: DNS_PORT,
            })
            .expect("bind addr:None succeeds");
        socket
    }

    /// A socket bound with `addr: None` accepts outgoing replies to both IPv6
    /// and IPv4 peers that carry an explicit `local_address`.
    #[test]
    fn udp_socket_bind_accepts_ipv6_endpoint() {
        let mut socket = dns_socket();

        let v6: Ipv6Address = "fd42::1".parse().unwrap();
        let v6_dest = IpEndpoint {
            addr: IpAddress::Ipv6(v6),
            port: 12345,
        };
        let mut meta = udp::UdpMetadata::from(v6_dest);
        meta.local_address = Some(IpAddress::Ipv6("2606:4700:4700::1111".parse().unwrap()));
        socket
            .send_slice(b"v6 reply payload", meta)
            .expect("v6 send accepted by socket bound addr:None");

        let v4_dest = IpEndpoint {
            addr: IpAddress::Ipv4(Ipv4Address::new(10, 0, 0, 2)),
            port: 12345,
        };
        let mut meta_v4 = udp::UdpMetadata::from(v4_dest);
        meta_v4.local_address = Some(IpAddress::Ipv4(Ipv4Address::new(1, 1, 1, 1)));
        socket
            .send_slice(b"v4 reply payload", meta_v4)
            .expect("v4 send accepted by socket bound addr:None");
    }

    /// An IPv6 query addressed to an external resolver is captured by the
    /// port-53 socket, and its metadata still reports the resolver address.
    #[test]
    fn ipv6_udp_dns_packet_is_captured_with_local_address() {
        let shared = Arc::new(SharedState::new(8));
        let mtu = 1500;
        let mut device = SmoltcpDevice::new(shared.clone(), mtu);

        let gateway_v6: Ipv6Address = "fd42::1".parse().unwrap();
        let guest_v6: Ipv6Address = "fd42::2".parse().unwrap();
        let resolver_v6: Ipv6Address = "2606:4700:4700::1111".parse().unwrap();
        let gateway_mac: [u8; 6] = [0x02, 0, 0, 0, 0, 1];
        let guest_mac: [u8; 6] = [0x02, 0, 0, 0, 0, 2];

        let hw_addr = HardwareAddress::Ethernet(smoltcp::wire::EthernetAddress(gateway_mac));
        let mut iface = Interface::new(Config::new(hw_addr), &mut device, Instant::from_millis(0));
        iface.update_ip_addrs(|addrs| {
            addrs
                .push(IpCidr::new(IpAddress::Ipv6(gateway_v6), 64))
                .unwrap();
        });
        iface
            .routes_mut()
            .add_default_ipv6_route(gateway_v6)
            .unwrap();
        // Let the interface accept packets for addresses it does not own
        // (here: the external resolver).
        iface.set_any_ip(true);

        let mut sockets = SocketSet::new(vec![]);
        let handle = sockets.add(dns_socket());

        let payload = [0xDE, 0xAD, 0xBE, 0xEF];
        let frame = build_ipv6_udp_frame(
            guest_mac,
            gateway_mac,
            guest_v6,
            resolver_v6,
            33333,
            DNS_PORT,
            &payload,
        );
        shared.tx_ring.push(frame).unwrap();
        let _ = device.stage_next_frame().expect("frame staged");
        let _ = iface.poll(Instant::from_millis(0), &mut device, &mut sockets);

        let socket = sockets.get_mut::<udp::Socket>(handle);
        let mut buf = [0u8; 1024];
        let (n, meta) = socket.recv_slice(&mut buf).expect("v6 DNS packet captured");
        assert_eq!(&buf[..n], &payload);
        assert_eq!(
            meta.local_address,
            Some(IpAddress::Ipv6(resolver_v6)),
            "interceptor sees the original v6 destination, not the gateway IP"
        );
    }

    /// Builds an Ethernet/IPv6/UDP frame carrying `payload`, with the UDP
    /// checksum filled in by smoltcp's default checksum capabilities.
    fn build_ipv6_udp_frame(
        src_mac: [u8; 6],
        dst_mac: [u8; 6],
        src_ip: Ipv6Address,
        dst_ip: Ipv6Address,
        src_port: u16,
        dst_port: u16,
        payload: &[u8],
    ) -> Vec<u8> {
        use smoltcp::phy::ChecksumCapabilities;
        use smoltcp::wire::{
            EthernetAddress, EthernetFrame, EthernetProtocol, EthernetRepr, IpProtocol, Ipv6Packet,
            Ipv6Repr, UdpPacket, UdpRepr,
        };

        const ETH_HDR_LEN: usize = 14;
        const IPV6_HDR_LEN: usize = 40;
        const UDP_HDR_LEN: usize = 8;

        let ip_start = ETH_HDR_LEN;
        let udp_start = ip_start + IPV6_HDR_LEN;
        let udp_len = UDP_HDR_LEN + payload.len();
        let mut frame = vec![0u8; udp_start + udp_len];

        EthernetRepr {
            src_addr: EthernetAddress(src_mac),
            dst_addr: EthernetAddress(dst_mac),
            ethertype: EthernetProtocol::Ipv6,
        }
        .emit(&mut EthernetFrame::new_unchecked(&mut frame));

        Ipv6Repr {
            src_addr: src_ip,
            dst_addr: dst_ip,
            next_header: IpProtocol::Udp,
            payload_len: udp_len,
            hop_limit: 64,
        }
        .emit(&mut Ipv6Packet::new_unchecked(&mut frame[ip_start..udp_start]));

        UdpRepr { src_port, dst_port }.emit(
            &mut UdpPacket::new_unchecked(&mut frame[udp_start..]),
            &IpAddress::Ipv6(src_ip),
            &IpAddress::Ipv6(dst_ip),
            payload.len(),
            |buf| buf.copy_from_slice(payload),
            &ChecksumCapabilities::default(),
        );

        frame
    }
}