1use std::io;
2use std::net::{IpAddr, SocketAddr};
3use std::ops::Deref;
4use std::str::FromStr;
5use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
6use std::sync::{Arc, Mutex};
7use std::time::Duration;
8
9use domain::base::iana::{Rcode, Rtype};
10use domain::base::message::Message;
11use domain::base::message_builder::{AdditionalBuilder, MessageBuilder, StreamTarget};
12use domain::base::name::{Name, ToName};
13use domain::base::question::Question;
14use domain::rdata::A;
15use lru_time_cache::LruCache;
16
/// Expiry duration (10 minutes) passed to `LruCache::with_expiry_duration` for the
/// resolver's hostname -> addresses cache.
const DEFAULT_CACHE_EXPIRE: Duration = Duration::from_secs(10 * 60);
18
19cfg_if::cfg_if! {
20 if #[cfg(feature = "slings-runtime")] {
21 use slings::{
22 net::{TcpStream, UdpSocket},
23 time::timeout,
24 };
25 use futures_util::{AsyncReadExt, AsyncWriteExt};
26 }
27 else if #[cfg(feature = "awak-runtime")] {
28 use awak::{
29 net::{TcpStream, UdpSocket},
30 time::timeout,
31 };
32 use futures_util::{AsyncReadExt, AsyncWriteExt};
33 }
34 else if #[cfg(feature = "tokio-runtime")] {
35 use tokio::{
36 io::{AsyncReadExt, AsyncWriteExt},
37 net::{TcpStream, UdpSocket},
38 time::timeout,
39 };
40 }
41}
42
43mod conf;
44
45pub use conf::{ResolvConf, ResolvOptions};
46use conf::{ServerConf, Transport};
47
/// Number of times `ServerInfo::udp_bind` retries after a failed bind before giving up,
/// i.e. at most `RETRY_RANDOM_PORT + 1` bind attempts in total.
const RETRY_RANDOM_PORT: usize = 10;
49
/// A DNS resolver that sends queries to the servers listed in a [`ResolvConf`].
///
/// Servers are split into a "preferred" list and a "stream" list when the resolver is
/// built, and successful `lookup_host` results are kept in a time-limited LRU cache.
pub struct Resolver {
    /// Servers whose transport reports `is_preferred()`; tried first unless the
    /// `use_vc` option is set.
    preferred: ServerList,
    /// Servers whose transport reports `is_stream()`; used when `use_vc` is set, when
    /// there are no preferred servers, or after a truncated preferred answer
    /// (unless `ign_tc` is set).
    stream: ServerList,
    /// Options copied from the configuration.
    options: ResolvOptions,
    /// Hostname -> addresses cache consulted by `lookup_host`.
    lru_cache: Mutex<LruCache<String, Vec<IpAddr>>>,
}
56
57impl Resolver {
58 pub fn new() -> Self {
59 Self::from_conf(ResolvConf::default())
60 }
61
62 pub fn from_conf(conf: ResolvConf) -> Self {
63 Resolver {
64 preferred: ServerList::from_conf(&conf, |s| s.transport.is_preferred()),
65 stream: ServerList::from_conf(&conf, |s| s.transport.is_stream()),
66 options: conf.options,
67 lru_cache: Mutex::new(LruCache::with_expiry_duration(DEFAULT_CACHE_EXPIRE)),
68 }
69 }
70
71 fn options(&self) -> &ResolvOptions {
72 &self.options
73 }
74
75 pub async fn query<N: ToName, Q: Into<Question<N>>>(&self, question: Q) -> io::Result<Answer> {
76 Query::new(self)?
77 .run(Query::create_message(question.into()))
78 .await
79 }
80
81 fn try_resolve_from_cache(&self, key: &str) -> Option<Vec<IpAddr>> {
82 self.lru_cache.lock().unwrap().get(key).cloned()
83 }
84
85 fn insert_into_cache(&self, key: &str, val: Vec<IpAddr>) {
86 self.lru_cache.lock().unwrap().insert(key.to_string(), val);
87 }
88
89 pub async fn lookup_host<T: AsRef<str>>(&self, host: T) -> io::Result<Vec<IpAddr>> {
90 let host = &host.as_ref();
91 if let Some(v) = self.try_resolve_from_cache(host) {
92 return Ok(v);
93 }
94
95 let qname = &Name::<Vec<u8>>::from_str(host)
96 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
97 let answer = self.query((&qname, Rtype::A)).await?;
98 let name = answer.canonical_name();
99 let records = answer
100 .answer()
101 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?
102 .limit_to::<A>();
103
104 let mut ips = vec![];
105 for record in records.flatten() {
106 if Some(*record.owner()) == name {
107 ips.push(record.data().addr().into());
108 }
109 }
110 self.insert_into_cache(host, ips.clone());
111 Ok(ips)
112 }
113
114 pub async fn query_message(&self, message: QueryMessage) -> io::Result<Answer> {
115 Query::new(self)?.run(message).await
116 }
117}
118
119impl Default for Resolver {
120 fn default() -> Self {
121 Self::new()
122 }
123}
124
/// State of one query run against a [`Resolver`]'s server lists.
pub struct Query<'a> {
    /// The resolver whose servers and options are used.
    resolver: &'a Resolver,
    /// `true` while querying the preferred list, `false` once on the stream list.
    preferred: bool,
    /// Number of completed passes over the current list; compared against the
    /// `attempts` option in `next_server`.
    attempt: usize,
    /// Position within the current server list.
    counter: ServerListCounter,
    /// Result returned when all servers are exhausted: starts as a "all timed out"
    /// error and may be replaced by a non-timeout error or a SERVFAIL answer.
    error: io::Result<Answer>,
}
132
133impl<'a> Query<'a> {
134 pub fn new(resolver: &'a Resolver) -> io::Result<Self> {
135 let (preferred, counter) = if resolver.options().use_vc || resolver.preferred.is_empty() {
136 if resolver.stream.is_empty() {
137 return Err(io::Error::new(
138 io::ErrorKind::NotFound,
139 "no servers available",
140 ));
141 }
142 (false, resolver.stream.counter(resolver.options().rotate))
143 } else {
144 (true, resolver.preferred.counter(resolver.options().rotate))
145 };
146 Ok(Query {
147 resolver,
148 preferred,
149 attempt: 0,
150 counter,
151 error: Err(io::Error::new(io::ErrorKind::TimedOut, "all timed out")),
152 })
153 }
154
155 pub async fn run(mut self, mut message: QueryMessage) -> io::Result<Answer> {
156 loop {
157 match self.run_query(&mut message).await {
158 Ok(answer) => {
159 if answer.header().rcode() == Rcode::FORMERR
160 && self.current_server().does_edns()
161 {
162 self.current_server().disable_edns();
163 continue;
164 } else if answer.header().rcode() == Rcode::SERVFAIL {
165 self.update_error_servfail(answer);
166 } else if answer.header().tc()
167 && self.preferred
168 && !self.resolver.options().ign_tc
169 {
170 if self.switch_to_stream() {
171 continue;
172 } else {
173 return Ok(answer);
174 }
175 } else {
176 return Ok(answer);
177 }
178 }
179 Err(err) => self.update_error(err),
180 }
181 if !self.next_server() {
182 return self.error;
183 }
184 }
185 }
186
187 fn create_message(question: Question<impl ToName>) -> QueryMessage {
188 let mut message =
189 MessageBuilder::from_target(StreamTarget::new_vec()).unwrap();
190 message.header_mut().set_rd(true);
191 let mut message = message.question();
192 message.push(question).unwrap();
193 message.additional()
194 }
195
196 async fn run_query(&mut self, message: &mut QueryMessage) -> io::Result<Answer> {
197 let server = self.current_server();
198 server.prepare_message(message);
199 server.query(message).await
200 }
201
202 fn current_server(&self) -> &ServerInfo {
203 let list = if self.preferred {
204 &self.resolver.preferred
205 } else {
206 &self.resolver.stream
207 };
208 self.counter.info(list)
209 }
210
211 fn update_error(&mut self, err: io::Error) {
212 if err.kind() != io::ErrorKind::TimedOut && self.error.is_err() {
213 self.error = Err(err)
214 }
215 }
216
217 fn update_error_servfail(&mut self, answer: Answer) {
218 self.error = Ok(answer)
219 }
220
221 fn switch_to_stream(&mut self) -> bool {
222 if !self.preferred {
223 return false;
224 }
225 self.preferred = false;
226 self.attempt = 0;
227 self.counter = self.resolver.stream.counter(self.resolver.options().rotate);
228 true
229 }
230
231 fn next_server(&mut self) -> bool {
232 if self.counter.next() {
233 return true;
234 }
235 self.attempt += 1;
236 if self.attempt >= self.resolver.options().attempts {
237 return false;
238 }
239 self.counter = if self.preferred {
240 self.resolver
241 .preferred
242 .counter(self.resolver.options().rotate)
243 } else {
244 self.resolver.stream.counter(self.resolver.options().rotate)
245 };
246 true
247 }
248}
249
/// A query under construction, positioned at the additional section.
///
/// The `StreamTarget` buffer can be sent as a datagram (`as_dgram_slice`, used for UDP)
/// or as a length-prefixed stream message (`as_stream_slice`, used for TCP).
pub type QueryMessage = AdditionalBuilder<StreamTarget<Vec<u8>>>;
251
/// A response message received from a server.
#[derive(Clone)]
pub struct Answer {
    /// The parsed response; also reachable through `Deref`/`AsRef`.
    message: Message<Vec<u8>>,
}
256
257impl Answer {
258 pub fn is_final(&self) -> bool {
259 (self.message.header().rcode() == Rcode::NOERROR
260 || self.message.header().rcode() == Rcode::NXDOMAIN)
261 && !self.message.header().tc()
262 }
263
264 pub fn is_truncated(&self) -> bool {
265 self.message.header().tc()
266 }
267
268 pub fn into_message(self) -> Message<Vec<u8>> {
269 self.message
270 }
271}
272
273impl From<Message<Vec<u8>>> for Answer {
274 fn from(message: Message<Vec<u8>>) -> Self {
275 Answer { message }
276 }
277}
278
/// A configured server together with its runtime EDNS state.
#[derive(Clone, Debug)]
struct ServerInfo {
    /// Address, transport, timeouts and buffer sizes for this server.
    conf: ServerConf,
    /// Whether queries to this server carry an OPT record. The flag is behind an `Arc`,
    /// so clones of this `ServerInfo` share it.
    edns: Arc<AtomicBool>,
}
284
285impl ServerInfo {
286 pub fn does_edns(&self) -> bool {
287 self.edns.load(Ordering::Relaxed)
288 }
289
290 pub fn disable_edns(&self) {
291 self.edns.store(false, Ordering::Relaxed);
292 }
293
294 pub fn prepare_message(&self, query: &mut QueryMessage) {
295 query.rewind();
296 if self.does_edns() {
297 query
298 .opt(|opt| {
299 opt.set_udp_payload_size(self.conf.udp_payload_size);
300 Ok(())
301 })
302 .unwrap();
303 }
304 }
305
306 pub async fn query(&self, query: &QueryMessage) -> io::Result<Answer> {
307 let res = match self.conf.transport {
308 Transport::Udp => {
309 timeout(
310 self.conf.request_timeout,
311 Self::udp_query(query, self.conf.addr, self.conf.recv_size),
312 )
313 .await
314 }
315 Transport::Tcp => {
316 timeout(
317 self.conf.request_timeout,
318 Self::tcp_query(query, self.conf.addr),
319 )
320 .await
321 }
322 };
323 match res {
324 Ok(Ok(answer)) => Ok(answer),
325 Ok(Err(err)) => Err(err),
326 Err(_) => Err(io::Error::new(io::ErrorKind::TimedOut, "request timed out")),
327 }
328 }
329
330 pub async fn tcp_query(query: &QueryMessage, addr: SocketAddr) -> io::Result<Answer> {
331 let sock = &mut TcpStream::connect(&addr).await?;
332 sock.write_all(query.as_target().as_stream_slice()).await?;
333
334 loop {
335 let mut len_buf = [0u8; 2];
336 sock.read_exact(&mut len_buf).await?;
337 let len = u16::from_be_bytes(len_buf) as u64;
338 let mut buf = Vec::new();
339 sock.take(len).read_to_end(&mut buf).await?;
340 if let Ok(answer) = Message::from_octets(buf) {
341 if answer.is_answer(&query.as_message()) {
342 return Ok(answer.into());
343 }
344 } else {
345 return Err(io::Error::new(io::ErrorKind::Other, "short buf"));
346 }
347 }
348 }
349
350 pub async fn udp_query(
351 query: &QueryMessage,
352 addr: SocketAddr,
353 recv_size: usize,
354 ) -> io::Result<Answer> {
355 let sock = Self::udp_bind(addr.is_ipv4()).await?;
356 #[cfg(not(feature = "awak-runtime"))]
357 sock.connect(addr).await?;
358 #[cfg(feature = "awak-runtime")]
359 sock.connect(addr)?;
360 let sent = sock.send(query.as_target().as_dgram_slice()).await?;
361 if sent != query.as_target().as_dgram_slice().len() {
362 return Err(io::Error::new(io::ErrorKind::Other, "short UDP send"));
363 }
364 loop {
365 let mut buf = vec![0; recv_size];
366 let len = sock.recv(&mut buf).await?;
367 buf.truncate(len);
368 let answer = match Message::from_octets(buf) {
369 Ok(answer) => answer,
370 Err(_) => continue,
371 };
372 if !answer.is_answer(&query.as_message()) {
373 continue;
374 }
375 return Ok(answer.into());
376 }
377 }
378
379 async fn udp_bind(v4: bool) -> io::Result<UdpSocket> {
380 let mut i = 0;
381 loop {
382 let local: SocketAddr = if v4 {
383 ([0u8; 4], 0).into()
384 } else {
385 ([0u16; 8], 0).into()
386 };
387 #[cfg(feature = "tokio-runtime")]
388 let binder = UdpSocket::bind(&local).await;
389 #[cfg(not(feature = "tokio-runtime"))]
390 let binder = UdpSocket::bind(local);
391 match binder {
392 Ok(sock) => return Ok(sock),
393 Err(err) => {
394 if i == RETRY_RANDOM_PORT {
395 return Err(err);
396 } else {
397 i += 1
398 }
399 }
400 }
401 }
402 }
403}
404
405impl From<ServerConf> for ServerInfo {
406 fn from(conf: ServerConf) -> Self {
407 ServerInfo {
408 conf,
409 edns: Arc::new(AtomicBool::new(true)),
410 }
411 }
412}
413
414impl<'a> From<&'a ServerConf> for ServerInfo {
415 fn from(conf: &'a ServerConf) -> Self {
416 conf.clone().into()
417 }
418}
419
/// An ordered list of servers with a shared rotation offset.
#[derive(Clone, Debug)]
struct ServerList {
    /// The servers in configuration order.
    servers: Vec<ServerInfo>,
    /// Rotation offset (reduced modulo the list length when read); behind an `Arc`, so
    /// clones of this list share it.
    start: Arc<AtomicUsize>,
}
425
426impl ServerList {
427 pub fn from_conf<F>(conf: &ResolvConf, filter: F) -> Self
428 where
429 F: Fn(&ServerConf) -> bool,
430 {
431 ServerList {
432 servers: {
433 conf.servers
434 .iter()
435 .filter(|f| filter(f))
436 .map(Into::into)
437 .collect()
438 },
439 start: Arc::new(AtomicUsize::new(0)),
440 }
441 }
442
443 pub fn is_empty(&self) -> bool {
444 self.servers.is_empty()
445 }
446
447 pub fn counter(&self, rotate: bool) -> ServerListCounter {
448 let res = ServerListCounter::new(self);
449 if rotate {
450 self.rotate()
451 }
452 res
453 }
454
455 pub fn iter(&self) -> ServerListIter<'_> {
456 ServerListIter::new(self)
457 }
458
459 pub fn rotate(&self) {
460 self.start.fetch_add(1, Ordering::SeqCst);
461 }
462}
463
464impl<'a> IntoIterator for &'a ServerList {
465 type Item = &'a ServerInfo;
466 type IntoIter = ServerListIter<'a>;
467
468 fn into_iter(self) -> Self::IntoIter {
469 self.iter()
470 }
471}
472
473impl Deref for ServerList {
474 type Target = [ServerInfo];
475
476 fn deref(&self) -> &Self::Target {
477 self.servers.as_ref()
478 }
479}
480
/// A position walking once around a [`ServerList`], starting at its rotation offset.
#[derive(Clone, Debug)]
struct ServerListCounter {
    /// Current position; not yet reduced modulo the list length (see `info`).
    cur: usize,
    /// Exclusive upper bound: start position plus list length (0 for an empty list).
    end: usize,
}
486
487impl ServerListCounter {
488 fn new(list: &ServerList) -> Self {
489 if list.servers.is_empty() {
490 return ServerListCounter { cur: 0, end: 0 };
491 }
492
493 let start = list.start.load(Ordering::Relaxed) % list.servers.len();
494 ServerListCounter {
495 cur: start,
496 end: start + list.servers.len(),
497 }
498 }
499
500 #[allow(clippy::should_implement_trait)]
501 pub fn next(&mut self) -> bool {
502 let next = self.cur + 1;
503 if next < self.end {
504 self.cur = next;
505 true
506 } else {
507 false
508 }
509 }
510
511 pub fn info<'a>(&self, list: &'a ServerList) -> &'a ServerInfo {
512 &list[self.cur % list.servers.len()]
513 }
514}
515
/// Iterator over a [`ServerList`]'s servers, driven by a [`ServerListCounter`].
#[derive(Clone, Debug)]
struct ServerListIter<'a> {
    /// The list being iterated.
    servers: &'a ServerList,
    /// Current position within `servers`.
    counter: ServerListCounter,
}
521
522impl<'a> ServerListIter<'a> {
523 fn new(list: &'a ServerList) -> Self {
524 ServerListIter {
525 servers: list,
526 counter: ServerListCounter::new(list),
527 }
528 }
529}
530
531impl<'a> Iterator for ServerListIter<'a> {
532 type Item = &'a ServerInfo;
533
534 fn next(&mut self) -> Option<Self::Item> {
535 if self.counter.next() {
536 Some(self.counter.info(self.servers))
537 } else {
538 None
539 }
540 }
541}
542
543impl Deref for Answer {
544 type Target = Message<Vec<u8>>;
545
546 fn deref(&self) -> &Self::Target {
547 &self.message
548 }
549}
550
551impl AsRef<Message<Vec<u8>>> for Answer {
552 fn as_ref(&self) -> &Message<Vec<u8>> {
553 &self.message
554 }
555}