use alloc::{boxed::Box, string::String, sync::Arc, vec::Vec};
use anyhow::{Result, bail};
use core::{
mem::{self, MaybeUninit},
pin::Pin,
ptr,
};
use crate::{
dns::query::Query,
entropy::Entropy,
net::{IpAddr, Ipv4Addr, Ipv6Addr, Socket, SocketAddr},
runtime::Runtime,
};
mod private {
    /// Module-private supertrait: because nothing outside this module can name it, nothing
    /// outside this module can implement `UpstreamSized`.
    pub trait Sealed {}
}

/// Marker for the `[SocketAddr; N]` array types accepted by `Upstream::create`.
///
/// Sealed through `private::Sealed` and implemented (by `impl_upstream!` below) only for
/// `N` in `1..=8`, so an `Upstream` never holds more addresses than its eight slots.
pub trait UpstreamSized: private::Sealed {}

/// Implements both sealing traits for `[SocketAddr; N]` for every listed length `N`.
macro_rules! impl_upstream {
    ($($len:expr),*) => {
        $(
            impl UpstreamSized for [SocketAddr; $len] {}
            impl private::Sealed for [SocketAddr; $len] {}
        )*
    };
}
impl_upstream!(1, 2, 3, 4, 5, 6, 7, 8);

/// One upstream DNS service, reachable at up to eight alternative addresses.
///
/// Field `.0` is the number of valid addresses; slots `0..self.0` of `.1` are initialized,
/// and the remaining slots are left uninitialized.
pub struct Upstream(usize, [MaybeUninit<SocketAddr>; 8]);

impl Upstream {
    /// Returns address number `alt` (0 is the primary), or `None` when this upstream has
    /// no address at that position.
    fn get_addr(&self, alt: usize) -> Option<SocketAddr> {
        let Self(len, slots) = self;
        if alt >= *len {
            return None;
        }
        // SAFETY: `Upstream::create` (the constructor in this module) writes slots `0..len`
        // and stores that same `len`; `alt < len` therefore names an initialized slot.
        Some(unsafe { slots[alt].assume_init() })
    }
}
impl Upstream {
const fn create<const N: usize>(addrs: [SocketAddr; N]) -> Upstream
where
[SocketAddr; N]: UpstreamSized,
{
let mut data = [MaybeUninit::uninit(); 8];
unsafe {
ptr::write(
(&mut data as *mut _ as *mut _),
(&addrs as *const _ as *const [MaybeUninit<SocketAddr>; N]).read(),
)
};
Self(N, data)
}
}
/// Google Public DNS upstream: two IPv4 then two IPv6 resolver addresses, all on port 53.
/// Listed primary-first; `get_addr(0)` yields 8.8.8.8.
pub const GOOGLE: Upstream = Upstream::create([
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53),
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 4, 4)), 53),
SocketAddr::new(
IpAddr::V6(Ipv6Addr::new(0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888)),
53,
),
SocketAddr::new(
IpAddr::V6(Ipv6Addr::new(0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8844)),
53,
),
]);
/// Cloudflare DNS upstream: two IPv4 then two IPv6 resolver addresses, all on port 53.
/// Listed primary-first; `get_addr(0)` yields 1.1.1.1.
pub const CLOUDFLARE: Upstream = Upstream::create([
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 53),
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 0, 0, 1)), 53),
SocketAddr::new(
IpAddr::V6(Ipv6Addr::new(0x2606, 0x4700, 0x4700, 0, 0, 0, 0, 0x1111)),
53,
),
SocketAddr::new(
IpAddr::V6(Ipv6Addr::new(0x2606, 0x4700, 0x4700, 0, 0, 0, 0, 0x1001)),
53,
),
]);
// Only Linux has a resolver backend so far; fail the build early everywhere else.
#[cfg(not(any(target_os = "linux")))]
compile_error!("unsupported platform, check back soon or make an issue/PR");

/// Creates the resolver for the current platform.
///
/// The resolver is configured with `GOOGLE` and then `CLOUDFLARE` as its upstreams.
///
/// # Errors
/// Propagates any error from the platform resolver's `create`.
pub fn platform_resolver<R: Runtime>() -> Result<Arc<dyn Resolver<R>>> {
    let config = Config {
        upstreams: Vec::from([GOOGLE, CLOUDFLARE]),
    };
    #[cfg(target_os = "linux")]
    Ok(Arc::new(linux::Resolver::<R>::create(config)?))
}
/// Resolver configuration.
pub struct Config {
    /// Upstream services, in the order the resolver tries them.
    upstreams: Vec<Upstream>,
}

/// A serialized DNS query message, ready to be written to an upstream connection.
pub struct Request(Vec<u8>);

impl Request {
    /// Serializes `query` as a single-question, recursion-desired DNS query.
    ///
    /// The transaction ID is drawn from `entropy`. The domain is encoded label by label
    /// (split on `.`). Empty labels are skipped, so a fully-qualified name with a trailing
    /// dot (`"example.com."`) and the root name (`""`) both encode correctly instead of
    /// producing a zero-length label in the middle of the name.
    ///
    /// # Panics
    /// Panics if `entropy.entropy()` returns an error (it is `unwrap`ped).
    pub fn build(entropy: &mut impl Entropy, query: Query) -> Request {
        // Header (12) + QNAME (domain bytes + first length byte + root byte) + QTYPE/QCLASS (4).
        let mut packet = Vec::with_capacity(12 + query.domain.len() + 2 + 4);

        // Header.
        // NOTE(review): the ID field is 16 bits on the wire; this assumes `Entropy::entropy`
        // yields a `u16` — a wider integer would mis-size the header. Confirm.
        packet.extend(entropy.entropy().unwrap().to_be_bytes());
        packet.extend(&0x0100u16.to_be_bytes()); // flags: standard query, RD (recursion desired)
        packet.extend(&1u16.to_be_bytes()); // QDCOUNT
        packet.extend(&0u16.to_be_bytes()); // ANCOUNT
        packet.extend(&0u16.to_be_bytes()); // NSCOUNT
        packet.extend(&0u16.to_be_bytes()); // ARCOUNT

        // Question: QNAME as length-prefixed labels terminated by the zero-length root label.
        for label in query.domain.split('.').filter(|label| !label.is_empty()) {
            // NOTE(review): DNS labels are limited to 63 bytes; a longer label is not rejected
            // here and its length byte would be truncated by this cast.
            packet.push(label.len() as u8);
            packet.extend(label.as_bytes());
        }
        packet.push(0);
        packet.extend(&(query.ty as u16).to_be_bytes());
        packet.extend(&(Class::IN as u16).to_be_bytes());
        Request(packet)
    }
}
#[cfg(target_os = "linux")]
mod linux {
use core::{marker::PhantomData, pin::Pin};
use crate::{
dns::{Config, Parse, Query, Request, Response},
entropy::Time,
net::Socket,
runtime::Runtime,
};
use alloc::{
boxed::Box,
string::{String, ToString},
sync::Arc,
vec,
};
use anyhow::{Result, bail};
use fast::sync::{Mutex, Waiter};
use futures::{AsyncReadExt, AsyncWriteExt, FutureExt};
pub struct Resolver<R: Runtime>(Arc<Mutex<Inner>>, PhantomData<R>);
pub struct Inner {
socket: Pin<Box<dyn Socket + Send>>,
}
pub struct Lookup {}
impl Future for Lookup {
type Output = Result<Response>;
fn poll(
self: core::pin::Pin<&mut Self>,
cx: &mut core::task::Context<'_>,
) -> core::task::Poll<Self::Output> {
todo!()
}
}
impl<R: Runtime> super::Resolver<R> for Resolver<R> {
fn create(config: Config) -> Result<Self>
where
Self: Sized,
{
let mut range = (0..).into_iter();
let mut alt = 0;
try {
loop {
let Some(i) = range.next() else {
unreachable!();
};
match R::Socket::connect(
config.upstreams[i % config.upstreams.len()]
.get_addr(alt)
.ok_or(anyhow::Error::msg("all options exhausted"))?,
) {
Ok(socket) => {
break Self(Arc::new(Mutex::new(Inner { socket })), PhantomData);
}
_ => (),
}
if i == config.upstreams.len() - 1 {
alt += 1;
}
}
}
}
fn resolve(
&self,
query: Query,
) -> core::pin::Pin<alloc::boxed::Box<dyn Future<Output = Result<Response>>>> {
let Self(mutex, _) = self;
let mutex = mutex.clone();
Box::pin(R::spawn(async move || {
let packet = Request::build(&mut Time, query);
let mut guard = mutex.lock(Box::pin(Waiter::default())).await;
guard.socket.write_all(&packet.0).await;
let mut buf = vec![0u8; 512];
let Ok(count) = guard.socket.read(&mut buf).await else {
bail!("failed to read authority response");
};
buf.truncate(count);
Ok(Response::parse(&buf))
}))
}
}
}
pub use query::*;
mod query {
    use alloc::string::String;

    /// A single DNS question: the domain name to resolve and the record type requested.
    pub struct Query {
        pub(crate) domain: String,
        pub(crate) ty: Type,
    }

    /// Resource-record TYPE codes; each variant's discriminant is its on-the-wire value.
    #[repr(u16)]
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    pub enum Type {
        A = 1,
        NS = 2,
        CNAME = 5,
        SOA = 6,
        PTR = 12,
        MX = 15,
        TXT = 16,
        AAAA = 28,
        SRV = 33,
        ANY = 255,
    }

    /// Resource-record CLASS codes; only the Internet class is defined here.
    #[repr(u16)]
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    pub enum Class {
        IN = 1,
    }

    /// Shared constructor behind the per-type helpers below.
    fn question(domain: &str, ty: Type) -> Query {
        Query {
            domain: domain.into(),
            ty,
        }
    }

    /// Builds an `A` (IPv4 address) query for `domain`.
    pub fn a(domain: &str) -> Query {
        question(domain, Type::A)
    }

    /// Builds an `AAAA` (IPv6 address) query for `domain`.
    pub fn aaaa(domain: &str) -> Query {
        question(domain, Type::AAAA)
    }

    /// Builds a `CNAME` (canonical name) query for `domain`.
    pub fn cname(domain: &str) -> Query {
        question(domain, Type::CNAME)
    }

    /// Builds an `MX` (mail exchange) query for `domain`.
    pub fn mx(domain: &str) -> Query {
        question(domain, Type::MX)
    }

    /// Builds an `NS` (name server) query for `domain`.
    pub fn ns(domain: &str) -> Query {
        question(domain, Type::NS)
    }

    /// Builds a `PTR` (pointer / reverse lookup) query for `domain`.
    pub fn ptr(domain: &str) -> Query {
        question(domain, Type::PTR)
    }

    /// Builds an `SOA` (start of authority) query for `domain`.
    pub fn soa(domain: &str) -> Query {
        question(domain, Type::SOA)
    }

    /// Builds an `SRV` (service locator) query for `domain`.
    pub fn srv(domain: &str) -> Query {
        question(domain, Type::SRV)
    }

    /// Builds a `TXT` (text) query for `domain`.
    pub fn txt(domain: &str) -> Query {
        question(domain, Type::TXT)
    }
}
/// A DNS resolver that runs on the async runtime `R`.
pub trait Resolver<R: Runtime> {
/// Builds a resolver from `config` (the Linux implementation connects to an upstream here).
fn create(config: Config) -> Result<Self>
where
Self: Sized;
/// Sends `query` and returns a boxed future that yields the decoded response.
fn resolve(&self, query: Query) -> Pin<Box<dyn Future<Output = Result<Response>>>>;
}
/// Presumably a DNS server role that answers queries itself; no implementation is visible
/// here — confirm intended semantics.
pub trait Authority {
/// Builds the authority from `config` and a socket.
/// NOTE(review): the type parameter `T: Socket` is not used by any argument (`socket` is
/// `impl Socket`); confirm which form was intended.
fn create<T: Socket>(config: Config, socket: impl Socket) -> Self
where
Self: Sized;
/// Returns a boxed future that yields the response for `query`.
fn answer(&self, query: Query) -> Pin<Box<dyn Future<Output = Result<Response>>>>;
}
/// A decoded DNS response message, split by section.
pub struct Response {
    pub answers: Vec<Answer>,
    pub authorities: Vec<Answer>,
    pub additionals: Vec<Answer>,
}

impl Parse for Response {
    /// Decodes a complete DNS response message (starting at its 12-byte header).
    ///
    /// The process is:
    /// 1. The header RCODE is checked.
    /// 2. Each question entry is validated and skipped.
    /// 3. Every resource record in the answer, authority and additional sections is decoded
    ///    with `parse_match_record`.
    ///
    /// Records that fail to decode are dropped rather than failing the whole response.
    ///
    /// # Errors
    /// - "invalid response" if the header is shorter than 12 bytes or a question entry is
    ///   truncated.
    /// - "name not found" for RCODE 3 (NXDOMAIN).
    /// - "server failure" for any other nonzero RCODE.
    /// - Name-decoding errors from `util::parse_name` for a malformed question name.
    fn parse(response: &[u8]) -> Result<Self>
    where
        Self: Sized,
    {
        if response.len() < 12 {
            bail!("invalid response");
        }
        // Big-endian u16 at a header offset; only used below the 12-byte length check.
        let header_u16 = |at: usize| u16::from_be_bytes([response[at], response[at + 1]]);
        let rcode = header_u16(2) & 0x000F;
        if rcode == 3 {
            bail!("name not found");
        } else if rcode != 0 {
            bail!("server failure");
        }
        let question_count = header_u16(4);
        let answer_count = header_u16(6);
        let authority_count = header_u16(8);
        let additional_count = header_u16(10);

        let mut offset = 12;
        for _ in 0..question_count {
            let (name_end, _name) = util::parse_name(response, offset)?;
            // QTYPE (2 bytes) and QCLASS (2 bytes) follow the question name.
            if name_end + 4 > response.len() {
                bail!("invalid response");
            }
            offset = name_end + 4;
        }

        // Decodes `count` consecutive records from the shared cursor, keeping the ones that parse.
        let mut section = |count: u16| -> Vec<Answer> {
            (0..count)
                .filter_map(|_| parse_match_record(response, &mut offset).ok())
                .collect()
        };
        let answers = section(answer_count);
        let authorities = section(authority_count);
        let additionals = section(additional_count);

        Ok(Response {
            answers,
            authorities,
            additionals,
        })
    }
}
fn parse_match_record(response: &[u8], offset: &mut usize) -> Result<Answer> {
let rtype = u16::from_be_bytes([response[*offset], response[*offset + 1]]);
*offset += 10;
Ok(match unsafe { mem::transmute::<_, Type>(rtype) } {
Type::A => Answer::A(Record::<Ipv4Addr>::parse(response)?),
Type::NS => Answer::AAAA(Record::<Ipv6Addr>::parse(response)?),
Type::CNAME => todo!(),
Type::SOA => todo!(),
Type::PTR => todo!(),
Type::MX => todo!(),
Type::TXT => todo!(),
Type::AAAA => todo!(),
Type::SRV => todo!(),
Type::ANY => todo!(),
})
}
/// One decoded resource record, tagged by its record type.
pub enum Answer {
    /// IPv4 address.
    A(Record<Ipv4Addr>),
    /// IPv6 address.
    AAAA(Record<Ipv6Addr>),
    /// Canonical name the owner is an alias for.
    CNAME(Record<String>),
    /// `(preference, exchange host)`, following the RFC 1035 MX RDATA layout.
    MX(Record<(u16, String)>),
    /// The record's character-strings, in order.
    TXT(Record<Vec<String>>),
    /// Target domain name of a pointer record.
    PTR(Record<String>),
    /// Authoritative name server host.
    NS(Record<String>),
    /// Start-of-authority data for the zone.
    SOA(Record<StartOfAuthority>),
}

/// RDATA of an SOA record (field order as in RFC 1035 §3.3.13).
///
/// Fields are public so callers outside this module can read decoded responses.
pub struct StartOfAuthority {
    /// Primary name server for the zone (MNAME).
    pub master: String,
    /// Mailbox of the person responsible for the zone, in domain-name form (RNAME).
    pub responsible: String,
    pub serial: u32,
    pub refresh: u32,
    pub retry: u32,
    pub expire: u32,
    pub minimum: u32,
}

/// A resource record: its owner name, TTL, and type-specific data.
///
/// Fields are public so callers outside this module can read the records in a `Response`.
pub struct Record<T> {
    /// Owner name of the record.
    pub name: String,
    /// Time to live as carried on the wire (seconds, per RFC 1035).
    pub ttl: u32,
    /// Type-specific payload.
    pub data: T,
}
mod util {
use alloc::string::String;
use anyhow::Result;
use anyhow::bail;
pub(crate) fn skip_name(data: &[u8], mut offset: usize) -> Result<usize> {
loop {
if offset >= data.len() {
bail!("invalid response");
}
let len = data[offset];
if len == 0 {
return Ok(offset + 1);
}
if (len & 0xC0) == 0xC0 {
return Ok(offset + 2);
}
offset += 1 + len as usize;
}
}
pub fn parse_name(data: &[u8], mut offset: usize) -> Result<(usize, String)> {
let mut name = String::new();
let mut jumps = 0;
let max_jumps = 5;
let original_offset = offset;
let mut jumped = false;
loop {
if jumps > max_jumps {
bail!("invalid response");
}
if offset >= data.len() {
bail!("invalid response");
}
let len = data[offset];
if len == 0 {
offset += 1;
break;
}
if (len & 0xC0) == 0xC0 {
if offset + 1 >= data.len() {
bail!("invalid response");
}
if !jumped {
jumped = true;
}
let ptr = (((len & 0x3F) as u16) << 8) | (data[offset + 1] as u16);
offset = ptr as usize;
jumps += 1;
continue;
}
offset += 1;
if offset + len as usize > data.len() {
bail!("invalid response");
}
if !name.is_empty() {
name.push('.');
}
let Ok(x) = core::str::from_utf8(&data[offset..offset + len as usize]) else {
bail!("failed to parse")
};
name.push_str(x);
offset += len as usize;
}
let final_offset = if jumped { original_offset + 2 } else { offset };
Ok((final_offset, name))
}
}
/// Decoding from the raw bytes of a DNS response message.
///
/// Every implementor in this module expects the complete message, starting with the
/// 12-byte header (each checks `response.len() < 12` and reads header fields), rather than
/// a slice positioned at a single record.
pub trait Parse {
fn parse(response: &[u8]) -> Result<Self>
where
Self: Sized;
}
impl Parse for Record<Ipv4Addr> {
    /// Scans a complete DNS response message and extracts its first `IN A` answer.
    ///
    /// The process is:
    /// 1. The header RCODE is checked.
    /// 2. The question section is skipped.
    /// 3. Every answer record is bounds-checked.
    ///
    /// The first answer with TYPE A, CLASS IN and 4 bytes of RDATA supplies the returned
    /// owner name, TTL and address. Later answers are still walked, and can still fail the
    /// parse, but are otherwise ignored.
    ///
    /// # Errors
    /// - "invalid response" for a short header or a truncated record.
    /// - "name not found" for RCODE 3; "server failure" for any other nonzero RCODE.
    /// - "no answers" when ANCOUNT is 0.
    /// - "no A records found" when no answer matches.
    /// - Name-decoding errors from `util`.
    fn parse(response: &[u8]) -> Result<Self> {
        if response.len() < 12 {
            bail!("invalid response");
        }
        let be16 = |at: usize| u16::from_be_bytes([response[at], response[at + 1]]);
        match be16(2) & 0x000F {
            0 => {}
            3 => bail!("name not found"),
            _ => bail!("server failure"),
        }
        let question_count = be16(4);
        let answer_count = be16(6);
        if answer_count == 0 {
            bail!("no answers");
        }

        let mut cursor = 12;
        for _ in 0..question_count {
            // QNAME, then QTYPE and QCLASS (2 bytes each).
            cursor = util::skip_name(response, cursor)? + 4;
        }

        let mut first: Option<Record<Ipv4Addr>> = None;
        for _ in 0..answer_count {
            let (after_name, owner) = util::parse_name(response, cursor)?;
            let rdata_start = after_name + 10;
            if rdata_start > response.len() {
                bail!("invalid response");
            }
            let rtype = be16(after_name);
            let rclass = be16(after_name + 2);
            let ttl = u32::from_be_bytes([
                response[after_name + 4],
                response[after_name + 5],
                response[after_name + 6],
                response[after_name + 7],
            ]);
            let rdlength = usize::from(be16(after_name + 8));
            let rdata_end = rdata_start + rdlength;
            if rdata_end > response.len() {
                bail!("invalid response");
            }
            let is_ipv4_answer =
                rtype == Type::A as u16 && rclass == Class::IN as u16 && rdlength == 4;
            if is_ipv4_answer && first.is_none() {
                first = Some(Record {
                    name: owner,
                    ttl,
                    data: Ipv4Addr::new(
                        response[rdata_start],
                        response[rdata_start + 1],
                        response[rdata_start + 2],
                        response[rdata_start + 3],
                    ),
                });
            }
            cursor = rdata_end;
        }

        match first {
            Some(record) => Ok(record),
            None => bail!("no A records found"),
        }
    }
}
impl Parse for Record<Ipv6Addr> {
fn parse(response: &[u8]) -> Result<Self> {
if response.len() < 12 {
bail!("invalid response");
}
let flags = u16::from_be_bytes([response[2], response[3]]);
let rcode = flags & 0x000F;
if rcode == 3 {
bail!("name not found");
} else if rcode != 0 {
bail!("server failure");
}
let question_count = u16::from_be_bytes([response[4], response[5]]);
let answer_count = u16::from_be_bytes([response[6], response[7]]);
if answer_count == 0 {
bail!("no answers");
}
let mut offset = 12;
for _ in 0..question_count {
offset = util::skip_name(response, offset)?;
offset += 4; }
let mut found_name = String::new();
let mut found_ttl = 0u32;
let mut found_addr = None;
for _ in 0..answer_count {
let (name_end, name) = util::parse_name(response, offset)?;
offset = name_end;
if offset + 10 > response.len() {
bail!("invalid response");
}
let rtype = u16::from_be_bytes([response[offset], response[offset + 1]]);
let rclass = u16::from_be_bytes([response[offset + 2], response[offset + 3]]);
let ttl = u32::from_be_bytes([
response[offset + 4],
response[offset + 5],
response[offset + 6],
response[offset + 7],
]);
let rdlength = u16::from_be_bytes([response[offset + 8], response[offset + 9]]);
offset += 10;
if offset + rdlength as usize > response.len() {
bail!("invalid response");
}
if rtype == Type::AAAA as u16 && rclass == Class::IN as u16 && rdlength == 16 {
if found_addr.is_none() {
found_name = name;
found_ttl = ttl;
let mut octets = [0u8; 16];
octets.copy_from_slice(&response[offset..offset + 16]);
found_addr = Some(Ipv6Addr(unsafe { mem::transmute(octets) }));
}
}
offset += rdlength as usize;
}
match found_addr {
Some(addr) => Ok(Record {
name: found_name,
ttl: found_ttl,
data: addr,
}),
None => bail!("no AAAA records found"),
}
}
}