use core::mem;
use alloc::{
boxed::Box,
string::{String, ToString},
vec::Vec,
};
use domain::{
new::{
base::{
HeaderFlags, MessageItem, QClass, QType, Question, Record,
build::{MessageBuildError, MessageBuilder},
name::{NameCompressor, NameParseError, RevNameBuf},
parse::MessageParser,
wire::{AsBytes, U16},
},
rdata::{RecordData, Txt},
},
utils::dst::UnsizedCopy,
};
use thiserror::Error;
use url::Url;
use crate::coroutine::{DiscoveryCoroutine, DiscoveryCoroutineState, DiscoveryYield};
/// Address of Cloudflare's public DNS resolver (`1.1.1.1`, port 53).
///
/// Only compiled when the `cli` feature is enabled.
#[cfg(feature = "cli")]
pub(crate) const DNS_SERVER: &'static str = "1.1.1.1:53";
/// Size in bytes of the scratch buffer a TXT query is encoded into, including the
/// 2-byte length prefix written in front of the DNS message.
pub(crate) const DNS_QUERY_BUF_SIZE: usize = 4096;
#[derive(Debug, Error)]
pub enum DiscoveryDnsTxtError {
#[error("DNS TXT domain `{1}` is not a valid name")]
InvalidDomain(#[source] NameParseError, String),
#[error("DNS TXT query did not fit in the {DNS_QUERY_BUF_SIZE}-byte buffer")]
QueryTooLarge(#[source] MessageBuildError),
#[error("DNS TXT response could not be parsed")]
InvalidResponse(String),
}
/// Steps of the [`DiscoveryDnsTxt`] state machine.
///
/// `Done` is the `Default`, so `mem::take` on the state leaves the coroutine in its
/// terminal state unless a step explicitly stores the next one.
#[derive(Debug, Default)]
enum State {
    /// Encode the TXT query and queue it for writing.
    BuildQuery,
    /// Accumulate response bytes and parse the reply once it is complete.
    ParseResponse,
    /// Lookup finished (successfully or with an error); resuming again panics.
    #[default]
    Done,
}
/// Coroutine that looks up the `TXT` records of one domain through one resolver.
///
/// It performs no I/O itself. It yields write/read requests addressed to `resolver`,
/// and the driver feeds the bytes it reads back in through `resume`.
#[derive(Debug)]
pub struct DiscoveryDnsTxt {
    /// Domain whose TXT records are queried; parsed as a DNS name when the query is built.
    domain: String,
    /// Resolver that every yielded I/O request is addressed to.
    resolver: Url,
    /// Current step of the lookup state machine.
    state: State,
    /// When set, the next `resume` yields a read request before doing anything else.
    wants_read: bool,
    /// Encoded query waiting to be yielded as a write request.
    wants_write: Option<Vec<u8>>,
    /// Response bytes received so far: a 2-byte length prefix followed by the message.
    response: Vec<u8>,
}
impl DiscoveryDnsTxt {
    /// Creates a lookup of the `TXT` records of `domain`, sent to `resolver`.
    ///
    /// Nothing is validated or sent here. The domain is only parsed as a DNS name on the
    /// first `resume`, which is also when the query is built.
    pub fn new(domain: impl ToString, resolver: Url) -> Self {
        let domain = domain.to_string();
        Self {
            state: State::BuildQuery,
            domain,
            resolver,
            response: Vec::new(),
            wants_write: None,
            wants_read: false,
        }
    }
}
impl DiscoveryCoroutine for DiscoveryDnsTxt {
    type Yield = DiscoveryYield;
    type Return = Result<Vec<Record<RevNameBuf, Box<Txt>>>, DiscoveryDnsTxtError>;

    /// Advances the TXT lookup by one step.
    ///
    /// The exchange uses the 2-byte big-endian length-prefixed framing of DNS over TCP
    /// (RFC 1035 §4.2.2):
    ///
    /// 1. The first call builds a recursion-desired `TXT`/`IN` query and yields
    ///    [`DiscoveryYield::WantsWrite`] carrying the framed query. Its `arg` is ignored.
    /// 2. The next call yields [`DiscoveryYield::WantsRead`]. Its `arg` is ignored.
    /// 3. Each further call appends `arg` (if any) to the buffered response and yields
    ///    `WantsRead` again until the length prefix and the full message it announces
    ///    have arrived. Bytes beyond the announced length are ignored.
    /// 4. The message is parsed and every answer record carrying `TXT` data is returned.
    ///    Other items are skipped. The reply's header (ID, RCODE) is not validated.
    ///
    /// # Errors
    ///
    /// Completes with:
    /// - [`DiscoveryDnsTxtError::InvalidDomain`] if the domain is not a valid DNS name;
    /// - [`DiscoveryDnsTxtError::QueryTooLarge`] if the query does not fit in
    ///   [`DNS_QUERY_BUF_SIZE`] bytes;
    /// - [`DiscoveryDnsTxtError::InvalidResponse`] if the reply, or any item inside it,
    ///   fails to parse.
    ///
    /// # Panics
    ///
    /// Panics if called again after it has returned [`DiscoveryCoroutineState::Complete`].
    fn resume(
        &mut self,
        mut arg: Option<&[u8]>,
    ) -> DiscoveryCoroutineState<Self::Yield, Self::Return> {
        loop {
            // Pending I/O requests are drained before the state machine advances: a
            // queued write goes out first, then the read that will carry the reply.
            if let Some(bytes) = self.wants_write.take() {
                return DiscoveryCoroutineState::Yielded(DiscoveryYield::WantsWrite {
                    url: self.resolver.clone(),
                    bytes,
                });
            }
            if mem::take(&mut self.wants_read) {
                return DiscoveryCoroutineState::Yielded(DiscoveryYield::WantsRead {
                    url: self.resolver.clone(),
                });
            }
            // `mem::take` leaves `State::Done` behind, so every early `return` below
            // leaves the coroutine in its terminal state.
            match mem::take(&mut self.state) {
                State::BuildQuery => {
                    let query = match self.build_query() {
                        Ok(query) => query,
                        Err(err) => return DiscoveryCoroutineState::Complete(Err(err)),
                    };
                    self.wants_write = Some(query);
                    self.wants_read = true;
                    self.state = State::ParseResponse;
                }
                State::ParseResponse => {
                    if let Some(bytes) = arg.take() {
                        self.response.extend_from_slice(bytes);
                    }
                    let Some(message) = framed_body(&self.response) else {
                        // Length prefix or message body still incomplete: read more.
                        self.wants_read = true;
                        self.state = State::ParseResponse;
                        continue;
                    };
                    return DiscoveryCoroutineState::Complete(parse_txt_answers(message));
                }
                State::Done => {
                    panic!("DiscoveryDnsTxt::resume called after completion")
                }
            }
        }
    }
}

impl DiscoveryDnsTxt {
    /// Encodes a recursion-desired `TXT`/`IN` query for `self.domain`, preceded by its
    /// 2-byte big-endian length.
    ///
    /// If the domain is invalid, the string is moved out of `self.domain` into the error.
    fn build_query(&mut self) -> Result<Vec<u8>, DiscoveryDnsTxtError> {
        let qname = self.domain.parse::<RevNameBuf>().map_err(|err| {
            DiscoveryDnsTxtError::InvalidDomain(err, mem::take(&mut self.domain))
        })?;

        let mut buf = vec![0u8; DNS_QUERY_BUF_SIZE];
        let mut compressor = NameCompressor::default();
        // The message is built after the first two bytes, which are reserved for the
        // length prefix filled in below.
        let mut builder = MessageBuilder::new(
            &mut buf[2..],
            &mut compressor,
            // Presumably the header ID; the reply's ID is not checked against it.
            U16::new(1),
            *HeaderFlags::default().set_rd(true),
        );
        let question = Question {
            qname,
            qtype: QType::TXT,
            qclass: QClass::IN,
        };
        builder
            .push_question(&question)
            .map_err(DiscoveryDnsTxtError::QueryTooLarge)?;
        let msg_len = builder.finish().as_bytes().len();

        // The message lives in `buf[2..]` (DNS_QUERY_BUF_SIZE - 2 bytes), so this can only
        // fail if the buffer constant is raised beyond what a 16-bit prefix can describe.
        let prefix = u16::try_from(msg_len)
            .expect("DNS query length must fit in the 16-bit TCP length prefix");
        buf[..2].copy_from_slice(&prefix.to_be_bytes());
        buf.truncate(msg_len + 2);
        Ok(buf)
    }
}

/// Returns the DNS message contained in a length-prefixed response buffer.
///
/// Returns `None` while either the 2-byte big-endian length prefix or the full message
/// it announces is still missing. Bytes past the announced length are ignored.
fn framed_body(response: &[u8]) -> Option<&[u8]> {
    let prefix = response.get(..2)?;
    let body_len = usize::from(u16::from_be_bytes([prefix[0], prefix[1]]));
    response.get(2..2 + body_len)
}

/// Parses a DNS message and returns owned copies of every answer record with `TXT` data.
///
/// Non-answer items and answers of other record types are skipped. Any parse error,
/// whether in the message itself or in one of its items, becomes
/// [`DiscoveryDnsTxtError::InvalidResponse`]. A partial record list is therefore never
/// reported as success.
fn parse_txt_answers(
    message: &[u8],
) -> Result<Vec<Record<RevNameBuf, Box<Txt>>>, DiscoveryDnsTxtError> {
    let parser = MessageParser::new(message)
        .map_err(|err| DiscoveryDnsTxtError::InvalidResponse(err.to_string()))?;

    let mut records: Vec<Record<RevNameBuf, Box<Txt>>> = Vec::new();
    for item in parser {
        let record = match item {
            Ok(MessageItem::Answer(record)) => record,
            Ok(_) => continue,
            Err(err) => return Err(DiscoveryDnsTxtError::InvalidResponse(err.to_string())),
        };
        let RecordData::Txt(txt) = record.rdata else {
            continue;
        };
        records.push(Record {
            rname: record.rname,
            rtype: record.rtype,
            rclass: record.rclass,
            ttl: record.ttl,
            rdata: txt.unsized_copy_into(),
        });
    }
    Ok(records)
}