use std::collections::HashMap;
use std::net::{Ipv4Addr, SocketAddr};
use std::sync::Arc;
use arcbox_dns::{DEFAULT_TTL, DnsQuery, DnsRecordType};
use tokio::net::UdpSocket;
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
/// Upstream resolver that queries not answered locally are forwarded to.
const GATEWAY: SocketAddr =
    SocketAddr::V4(std::net::SocketAddrV4::new(Ipv4Addr::new(10, 0, 2, 1), 53));

/// Local search domain: `<name>.arcbox.local` is looked up as `<name>`, and
/// unknown A/AAAA names under it are answered with NXDOMAIN.
const LOCAL_DOMAIN: &str = "arcbox.local";

/// Receive-buffer size, in bytes, for DNS datagrams read by this server
/// (the classic RFC 1035 UDP message limit).
const MAX_PACKET: usize = 512;
/// Map from lower-cased sandbox id to the IPv4 address served for it.
///
/// Uses a blocking `std::sync::RwLock`, so it can be accessed from non-async code.
pub type SandboxRegistry = Arc<std::sync::RwLock<HashMap<String, Ipv4Addr>>>;

/// Lazily-initialised storage behind [`sandbox_registry`].
static SANDBOX_REGISTRY: std::sync::OnceLock<SandboxRegistry> = std::sync::OnceLock::new();

/// Returns the process-wide sandbox registry, creating an empty one on first use.
///
/// Every call returns a reference to the same `Arc`, so entries written through
/// one handle are visible through all others.
pub fn sandbox_registry() -> &'static SandboxRegistry {
    SANDBOX_REGISTRY.get_or_init(SandboxRegistry::default)
}
/// Maximum time to wait for the upstream gateway to answer a forwarded query.
const FORWARD_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(2_000);
/// Lower-cased container alias -> IPv4 address currently answered for it.
pub type ContainerRegistry = Arc<RwLock<HashMap<String, Ipv4Addr>>>;

/// Lower-cased owner -> IPv4 address that owner registered for one alias.
type OwnerIps = HashMap<String, Ipv4Addr>;

/// Lower-cased alias -> every owner currently claiming it, with each owner's IP.
type AliasOwners = Arc<RwLock<HashMap<String, OwnerIps>>>;
/// UDP DNS server that answers names of registered containers and sandboxes
/// and forwards everything else to the upstream gateway.
pub struct GuestDnsServer {
    /// Alias -> IP actually answered for container lookups.
    containers: ContainerRegistry,
    /// Alias -> (owner -> IP); tracks every owner claiming an alias so the
    /// alias survives until its last owner deregisters.
    alias_owners: AliasOwners,
    /// Sandbox id -> IP; shared with the process-wide `sandbox_registry()`.
    sandboxes: SandboxRegistry,
    /// Cancelling this token makes `run` return.
    cancel: CancellationToken,
}
impl GuestDnsServer {
pub fn new(cancel: CancellationToken) -> Self {
Self {
containers: Arc::new(RwLock::new(HashMap::new())),
alias_owners: Arc::new(RwLock::new(HashMap::new())),
sandboxes: Arc::clone(sandbox_registry()),
cancel,
}
}
#[allow(dead_code)]
pub fn containers(&self) -> ContainerRegistry {
Arc::clone(&self.containers)
}
#[allow(dead_code)]
pub fn sandboxes(&self) -> SandboxRegistry {
Arc::clone(&self.sandboxes)
}
pub async fn register_container(&self, alias: &str, owner: &str, ip: Ipv4Addr) {
let key = alias.to_lowercase();
let owner_key = owner.to_lowercase();
tracing::debug!(alias = %key, owner = %owner_key, %ip, "dns: register container");
self.alias_owners
.write()
.await
.entry(key.clone())
.or_default()
.insert(owner_key, ip);
self.containers.write().await.insert(key, ip);
}
pub async fn deregister_container(&self, alias: &str, owner: &str) {
let key = alias.to_lowercase();
let owner_key = owner.to_lowercase();
tracing::debug!(alias = %key, owner = %owner_key, "dns: deregister container");
let mut owners = self.alias_owners.write().await;
if let Some(map) = owners.get_mut(&key) {
map.remove(&owner_key);
if map.is_empty() {
owners.remove(&key);
self.containers.write().await.remove(&key);
} else if let Some(&remaining_ip) = map.values().next() {
self.containers.write().await.insert(key, remaining_ip);
}
}
}
#[allow(dead_code)]
pub fn register_sandbox(&self, id: &str, ip: Ipv4Addr) {
let key = id.to_lowercase();
tracing::debug!(id = %key, %ip, "dns: register sandbox");
if let Ok(mut map) = self.sandboxes.write() {
map.insert(key, ip);
}
}
#[allow(dead_code)]
pub fn deregister_sandbox(&self, id: &str) {
let key = id.to_lowercase();
tracing::debug!(id = %key, "dns: deregister sandbox");
if let Ok(mut map) = self.sandboxes.write() {
map.remove(&key);
}
}
pub async fn run(&self) -> anyhow::Result<()> {
let sock = UdpSocket::bind("0.0.0.0:53").await?;
tracing::info!("guest DNS server listening on 0.0.0.0:53");
let mut buf = [0u8; MAX_PACKET];
loop {
tokio::select! {
() = self.cancel.cancelled() => {
tracing::info!("guest DNS server shutting down");
return Ok(());
}
result = sock.recv_from(&mut buf) => {
let (len, peer) = result?;
let data = &buf[..len];
match self.handle_query(data).await {
Ok(response) => {
if let Err(e) = sock.send_to(&response, peer).await {
tracing::warn!(error = %e, "dns: failed to send response");
}
}
Err(e) => {
tracing::debug!(error = %e, "dns: query parse failed, forwarding to gateway");
match self.forward_to_gateway(data).await {
Ok(response) => {
let _ = sock.send_to(&response, peer).await;
}
Err(_) => {
if let Ok(fail) = arcbox_dns::build_servfail(data) {
let _ = sock.send_to(&fail, peer).await;
}
}
}
}
}
}
}
}
}
async fn handle_query(&self, data: &[u8]) -> anyhow::Result<Vec<u8>> {
let query = DnsQuery::parse(data)?;
let name_lower = query.name.to_lowercase();
let is_a_query = query.qtype == DnsRecordType::A;
if let Some(&ip) = self.containers.read().await.get(&name_lower) {
if is_a_query {
return Ok(arcbox_dns::build_response_a(data, ip, DEFAULT_TTL)?);
}
return Ok(arcbox_dns::build_nodata(data)?);
}
let bare_name = name_lower
.strip_suffix(&format!(".{LOCAL_DOMAIN}"))
.unwrap_or(&name_lower);
if bare_name != name_lower {
if let Some(&ip) = self.containers.read().await.get(bare_name) {
if is_a_query {
return Ok(arcbox_dns::build_response_a(data, ip, DEFAULT_TTL)?);
}
return Ok(arcbox_dns::build_nodata(data)?);
}
}
if let Some(ip) = self
.sandboxes
.read()
.ok()
.and_then(|g| g.get(&name_lower).copied())
{
if is_a_query {
return Ok(arcbox_dns::build_response_a(data, ip, DEFAULT_TTL)?);
}
return Ok(arcbox_dns::build_nodata(data)?);
}
if bare_name != name_lower {
if let Some(ip) = self
.sandboxes
.read()
.ok()
.and_then(|g| g.get(bare_name).copied())
{
if is_a_query {
return Ok(arcbox_dns::build_response_a(data, ip, DEFAULT_TTL)?);
}
return Ok(arcbox_dns::build_nodata(data)?);
}
}
if (name_lower == LOCAL_DOMAIN || name_lower.ends_with(&format!(".{LOCAL_DOMAIN}")))
&& matches!(query.qtype, DnsRecordType::A | DnsRecordType::Aaaa)
{
return Ok(arcbox_dns::build_nxdomain(data)?);
}
self.forward_to_gateway(data).await
}
async fn forward_to_gateway(&self, data: &[u8]) -> anyhow::Result<Vec<u8>> {
let sock = UdpSocket::bind("0.0.0.0:0").await?;
sock.send_to(data, GATEWAY).await?;
let mut buf = [0u8; MAX_PACKET];
let len = tokio::time::timeout(FORWARD_TIMEOUT, sock.recv(&mut buf)).await??;
Ok(buf[..len].to_vec())
}
}