1use std::io;
2use std::net::{IpAddr, SocketAddr};
3use std::ops::Deref;
4use std::str::FromStr;
5use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
6use std::sync::{Arc, Mutex};
7use std::time::Duration;
8
9use domain::base::iana::{Rcode, Rtype};
10use domain::base::message::Message;
11use domain::base::message_builder::{AdditionalBuilder, MessageBuilder, StreamTarget};
12use domain::base::name::{Dname, ToDname};
13use domain::base::question::Question;
14use domain::rdata::A;
15use lru_time_cache::LruCache;
16use octseq::array::Array;
17
/// Lifetime of an entry in the resolver's host-lookup cache: ten minutes.
const DEFAULT_CACHE_EXPIRE: Duration = Duration::from_secs(600);
19
20#[cfg(not(feature = "tokio-runtime"))]
21use futures_util::{AsyncReadExt, AsyncWriteExt};
22
23#[cfg(feature = "slings-runtime")]
24use slings::{
25 net::{TcpStream, UdpSocket},
26 time::timeout,
27};
28
29#[cfg(feature = "awak-runtime")]
30use awak::{
31 net::{TcpStream, UdpSocket},
32 time::timeout,
33};
34
35#[cfg(feature = "tokio-runtime")]
36use tokio::{
37 io::{AsyncReadExt, AsyncWriteExt},
38 net::{TcpStream, UdpSocket},
39 time::timeout,
40};
41
42mod conf;
43
44pub use conf::{ResolvConf, ResolvOptions};
45use conf::{ServerConf, Transport};
46
/// How many times binding the local UDP socket is retried after a failed attempt before the
/// bind error is returned to the caller.
const RETRY_RANDOM_PORT: usize = 10;
48
49pub struct Resolver {
50 preferred: ServerList,
51 stream: ServerList,
52 options: ResolvOptions,
53 lru_cache: Mutex<LruCache<String, Vec<IpAddr>>>,
54}
55
56impl Resolver {
57 pub fn new() -> Self {
58 Self::from_conf(ResolvConf::default())
59 }
60
61 pub fn from_conf(conf: ResolvConf) -> Self {
62 Resolver {
63 preferred: ServerList::from_conf(&conf, |s| s.transport.is_preferred()),
64 stream: ServerList::from_conf(&conf, |s| s.transport.is_stream()),
65 options: conf.options,
66 lru_cache: Mutex::new(LruCache::with_expiry_duration(DEFAULT_CACHE_EXPIRE)),
67 }
68 }
69
70 fn options(&self) -> &ResolvOptions {
71 &self.options
72 }
73
74 pub async fn query<N: ToDname, Q: Into<Question<N>>>(&self, question: Q) -> io::Result<Answer> {
75 Query::new(self)?
76 .run(Query::create_message(question.into()))
77 .await
78 }
79
80 fn try_resolve_from_cache(&self, key: &str) -> Option<Vec<IpAddr>> {
81 self.lru_cache.lock().unwrap().get(key).cloned()
82 }
83
84 fn insert_into_cache(&self, key: &str, val: Vec<IpAddr>) {
85 self.lru_cache.lock().unwrap().insert(key.to_string(), val);
86 }
87
88 pub async fn lookup_host<T: AsRef<str>>(&self, host: T) -> io::Result<Vec<IpAddr>> {
89 let host = &host.as_ref();
90 if let Some(v) = self.try_resolve_from_cache(host) {
91 return Ok(v);
92 }
93
94 let qname = &Dname::<Vec<u8>>::from_str(host)
95 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
96 let answer = self.query((&qname, Rtype::A)).await?;
97 let name = answer.canonical_name();
98 let records = answer
99 .answer()
100 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?
101 .limit_to::<A>();
102
103 let mut ips = vec![];
104 for record in records.flatten() {
105 if Some(*record.owner()) == name {
106 ips.push(record.data().addr().into());
107 }
108 }
109 self.insert_into_cache(host, ips.clone());
110 Ok(ips)
111 }
112
113 pub async fn query_message(&self, message: QueryMessage) -> io::Result<Answer> {
114 Query::new(self)?.run(message).await
115 }
116}
117
118impl Default for Resolver {
119 fn default() -> Self {
120 Self::new()
121 }
122}
123
124pub struct Query<'a> {
125 resolver: &'a Resolver,
126 preferred: bool,
127 attempt: usize,
128 counter: ServerListCounter,
129 error: io::Result<Answer>,
130}
131
132impl<'a> Query<'a> {
133 pub fn new(resolver: &'a Resolver) -> io::Result<Self> {
134 let (preferred, counter) = if resolver.options().use_vc || resolver.preferred.is_empty() {
135 if resolver.stream.is_empty() {
136 return Err(io::Error::new(
137 io::ErrorKind::NotFound,
138 "no servers available",
139 ));
140 }
141 (false, resolver.stream.counter(resolver.options().rotate))
142 } else {
143 (true, resolver.preferred.counter(resolver.options().rotate))
144 };
145 Ok(Query {
146 resolver,
147 preferred,
148 attempt: 0,
149 counter,
150 error: Err(io::Error::new(io::ErrorKind::TimedOut, "all timed out")),
151 })
152 }
153
154 pub async fn run(mut self, mut message: QueryMessage) -> io::Result<Answer> {
155 loop {
156 match self.run_query(&mut message).await {
157 Ok(answer) => {
158 if answer.header().rcode() == Rcode::FormErr
159 && self.current_server().does_edns()
160 {
161 self.current_server().disable_edns();
162 continue;
163 } else if answer.header().rcode() == Rcode::ServFail {
164 self.update_error_servfail(answer);
165 } else if answer.header().tc()
166 && self.preferred
167 && !self.resolver.options().ign_tc
168 {
169 if self.switch_to_stream() {
170 continue;
171 } else {
172 return Ok(answer);
173 }
174 } else {
175 return Ok(answer);
176 }
177 }
178 Err(err) => self.update_error(err),
179 }
180 if !self.next_server() {
181 return self.error;
182 }
183 }
184 }
185
186 fn create_message(question: Question<impl ToDname>) -> QueryMessage {
187 let mut message =
188 MessageBuilder::from_target(StreamTarget::new(Default::default()).unwrap()).unwrap();
189 message.header_mut().set_rd(true);
190 let mut message = message.question();
191 message.push(question).unwrap();
192 message.additional()
193 }
194
195 async fn run_query(&mut self, message: &mut QueryMessage) -> io::Result<Answer> {
196 let server = self.current_server();
197 server.prepare_message(message);
198 server.query(message).await
199 }
200
201 fn current_server(&self) -> &ServerInfo {
202 let list = if self.preferred {
203 &self.resolver.preferred
204 } else {
205 &self.resolver.stream
206 };
207 self.counter.info(list)
208 }
209
210 fn update_error(&mut self, err: io::Error) {
211 if err.kind() != io::ErrorKind::TimedOut && self.error.is_err() {
212 self.error = Err(err)
213 }
214 }
215
216 fn update_error_servfail(&mut self, answer: Answer) {
217 self.error = Ok(answer)
218 }
219
220 fn switch_to_stream(&mut self) -> bool {
221 if !self.preferred {
222 return false;
223 }
224 self.preferred = false;
225 self.attempt = 0;
226 self.counter = self.resolver.stream.counter(self.resolver.options().rotate);
227 true
228 }
229
230 fn next_server(&mut self) -> bool {
231 if self.counter.next() {
232 return true;
233 }
234 self.attempt += 1;
235 if self.attempt >= self.resolver.options().attempts {
236 return false;
237 }
238 self.counter = if self.preferred {
239 self.resolver
240 .preferred
241 .counter(self.resolver.options().rotate)
242 } else {
243 self.resolver.stream.counter(self.resolver.options().rotate)
244 };
245 true
246 }
247}
248
249pub type QueryMessage = AdditionalBuilder<StreamTarget<Array<512>>>;
250
251#[derive(Clone)]
252pub struct Answer {
253 message: Message<Vec<u8>>,
254}
255
256impl Answer {
257 pub fn is_final(&self) -> bool {
258 (self.message.header().rcode() == Rcode::NoError
259 || self.message.header().rcode() == Rcode::NXDomain)
260 && !self.message.header().tc()
261 }
262
263 pub fn is_truncated(&self) -> bool {
264 self.message.header().tc()
265 }
266
267 pub fn into_message(self) -> Message<Vec<u8>> {
268 self.message
269 }
270}
271
272impl From<Message<Vec<u8>>> for Answer {
273 fn from(message: Message<Vec<u8>>) -> Self {
274 Answer { message }
275 }
276}
277
278#[derive(Clone, Debug)]
279struct ServerInfo {
280 conf: ServerConf,
281 edns: Arc<AtomicBool>,
282}
283
284impl ServerInfo {
285 pub fn does_edns(&self) -> bool {
286 self.edns.load(Ordering::Relaxed)
287 }
288
289 pub fn disable_edns(&self) {
290 self.edns.store(false, Ordering::Relaxed);
291 }
292
293 pub fn prepare_message(&self, query: &mut QueryMessage) {
294 query.rewind();
295 if self.does_edns() {
296 query
297 .opt(|opt| {
298 opt.set_udp_payload_size(self.conf.udp_payload_size);
299 Ok(())
300 })
301 .unwrap();
302 }
303 }
304
305 pub async fn query(&self, query: &QueryMessage) -> io::Result<Answer> {
306 let res = match self.conf.transport {
307 Transport::Udp => {
308 timeout(
309 self.conf.request_timeout,
310 Self::udp_query(query, self.conf.addr, self.conf.recv_size),
311 )
312 .await
313 }
314 Transport::Tcp => {
315 timeout(
316 self.conf.request_timeout,
317 Self::tcp_query(query, self.conf.addr),
318 )
319 .await
320 }
321 };
322 match res {
323 Ok(Ok(answer)) => Ok(answer),
324 Ok(Err(err)) => Err(err),
325 Err(_) => Err(io::Error::new(io::ErrorKind::TimedOut, "request timed out")),
326 }
327 }
328
329 pub async fn tcp_query(query: &QueryMessage, addr: SocketAddr) -> io::Result<Answer> {
330 let sock = &mut TcpStream::connect(&addr).await?;
331 sock.write_all(query.as_target().as_stream_slice()).await?;
332
333 loop {
334 let mut len_buf = [0u8; 2];
335 sock.read_exact(&mut len_buf).await?;
336 let len = u16::from_be_bytes(len_buf) as u64;
337 let mut buf = Vec::new();
338 sock.take(len).read_to_end(&mut buf).await?;
339 if let Ok(answer) = Message::from_octets(buf) {
340 if answer.is_answer(&query.as_message()) {
341 return Ok(answer.into());
342 }
343 } else {
344 return Err(io::Error::new(io::ErrorKind::Other, "short buf"));
345 }
346 }
347 }
348
349 pub async fn udp_query(
350 query: &QueryMessage,
351 addr: SocketAddr,
352 recv_size: usize,
353 ) -> io::Result<Answer> {
354 let sock = Self::udp_bind(addr.is_ipv4()).await?;
355 #[cfg(not(feature = "awak-runtime"))]
356 sock.connect(addr).await?;
357 #[cfg(feature = "awak-runtime")]
358 sock.connect(addr)?;
359 let sent = sock.send(query.as_target().as_dgram_slice()).await?;
360 if sent != query.as_target().as_dgram_slice().len() {
361 return Err(io::Error::new(io::ErrorKind::Other, "short UDP send"));
362 }
363 loop {
364 let mut buf = vec![0; recv_size];
365 let len = sock.recv(&mut buf).await?;
366 buf.truncate(len);
367 let answer = match Message::from_octets(buf) {
368 Ok(answer) => answer,
369 Err(_) => continue,
370 };
371 if !answer.is_answer(&query.as_message()) {
372 continue;
373 }
374 return Ok(answer.into());
375 }
376 }
377
378 async fn udp_bind(v4: bool) -> io::Result<UdpSocket> {
379 let mut i = 0;
380 loop {
381 let local: SocketAddr = if v4 {
382 ([0u8; 4], 0).into()
383 } else {
384 ([0u16; 8], 0).into()
385 };
386 #[cfg(feature = "tokio-runtime")]
387 let binder = UdpSocket::bind(&local).await;
388 #[cfg(not(feature = "tokio-runtime"))]
389 let binder = UdpSocket::bind(local);
390 match binder {
391 Ok(sock) => return Ok(sock),
392 Err(err) => {
393 if i == RETRY_RANDOM_PORT {
394 return Err(err);
395 } else {
396 i += 1
397 }
398 }
399 }
400 }
401 }
402}
403
404impl From<ServerConf> for ServerInfo {
405 fn from(conf: ServerConf) -> Self {
406 ServerInfo {
407 conf,
408 edns: Arc::new(AtomicBool::new(true)),
409 }
410 }
411}
412
413impl<'a> From<&'a ServerConf> for ServerInfo {
414 fn from(conf: &'a ServerConf) -> Self {
415 conf.clone().into()
416 }
417}
418
419#[derive(Clone, Debug)]
420struct ServerList {
421 servers: Vec<ServerInfo>,
422 start: Arc<AtomicUsize>,
423}
424
425impl ServerList {
426 pub fn from_conf<F>(conf: &ResolvConf, filter: F) -> Self
427 where
428 F: Fn(&ServerConf) -> bool,
429 {
430 ServerList {
431 servers: {
432 conf.servers
433 .iter()
434 .filter(|f| filter(f))
435 .map(Into::into)
436 .collect()
437 },
438 start: Arc::new(AtomicUsize::new(0)),
439 }
440 }
441
442 pub fn is_empty(&self) -> bool {
443 self.servers.is_empty()
444 }
445
446 pub fn counter(&self, rotate: bool) -> ServerListCounter {
447 let res = ServerListCounter::new(self);
448 if rotate {
449 self.rotate()
450 }
451 res
452 }
453
454 pub fn iter(&self) -> ServerListIter {
455 ServerListIter::new(self)
456 }
457
458 pub fn rotate(&self) {
459 self.start.fetch_add(1, Ordering::SeqCst);
460 }
461}
462
463impl<'a> IntoIterator for &'a ServerList {
464 type Item = &'a ServerInfo;
465 type IntoIter = ServerListIter<'a>;
466
467 fn into_iter(self) -> Self::IntoIter {
468 self.iter()
469 }
470}
471
472impl Deref for ServerList {
473 type Target = [ServerInfo];
474
475 fn deref(&self) -> &Self::Target {
476 self.servers.as_ref()
477 }
478}
479
480#[derive(Clone, Debug)]
481struct ServerListCounter {
482 cur: usize,
483 end: usize,
484}
485
486impl ServerListCounter {
487 fn new(list: &ServerList) -> Self {
488 if list.servers.is_empty() {
489 return ServerListCounter { cur: 0, end: 0 };
490 }
491
492 let start = list.start.load(Ordering::Relaxed) % list.servers.len();
493 ServerListCounter {
494 cur: start,
495 end: start + list.servers.len(),
496 }
497 }
498
499 #[allow(clippy::should_implement_trait)]
500 pub fn next(&mut self) -> bool {
501 let next = self.cur + 1;
502 if next < self.end {
503 self.cur = next;
504 true
505 } else {
506 false
507 }
508 }
509
510 pub fn info<'a>(&self, list: &'a ServerList) -> &'a ServerInfo {
511 &list[self.cur % list.servers.len()]
512 }
513}
514
515#[derive(Clone, Debug)]
516struct ServerListIter<'a> {
517 servers: &'a ServerList,
518 counter: ServerListCounter,
519}
520
521impl<'a> ServerListIter<'a> {
522 fn new(list: &'a ServerList) -> Self {
523 ServerListIter {
524 servers: list,
525 counter: ServerListCounter::new(list),
526 }
527 }
528}
529
530impl<'a> Iterator for ServerListIter<'a> {
531 type Item = &'a ServerInfo;
532
533 fn next(&mut self) -> Option<Self::Item> {
534 if self.counter.next() {
535 Some(self.counter.info(self.servers))
536 } else {
537 None
538 }
539 }
540}
541
542impl Deref for Answer {
543 type Target = Message<Vec<u8>>;
544
545 fn deref(&self) -> &Self::Target {
546 &self.message
547 }
548}
549
550impl AsRef<Message<Vec<u8>>> for Answer {
551 fn as_ref(&self) -> &Message<Vec<u8>> {
552 &self.message
553 }
554}