1use std::{
2 fmt, fs,
3 net::{IpAddr, SocketAddr},
4 num::ParseIntError,
5 str::FromStr,
6};
7
8use failure::format_err;
9use futures::{
10 future::{self, Either},
11 Future,
12};
13use trust_dns::{
14 client::ClientHandle,
15 op::DnsResponse,
16 proto::{error::ProtoError, op::query::Query, xfer::DnsHandle},
17 rr::{self, Record, RecordType},
18};
19
/// A list of `T` values parsed from one comma-separated string.
///
/// Every comma-delimited piece is handed to `T::from_str` as-is: pieces are
/// not trimmed, so `"1, 2"` passes `" 2"` to `T`'s parser.
#[derive(Debug, Clone)]
pub struct CommaSeparated<T>(Vec<T>);

impl<T> CommaSeparated<T> {
    /// Consumes the wrapper and hands back the parsed values.
    pub fn into_vec(self) -> Vec<T> {
        let CommaSeparated(items) = self;
        items
    }
}

impl<T: Clone> CommaSeparated<T> {
    /// Returns an owned copy of the parsed values, leaving `self` untouched.
    pub fn to_vec(&self) -> Vec<T> {
        self.0.iter().cloned().collect()
    }
}

impl<T: FromStr> FromStr for CommaSeparated<T> {
    type Err = T::Err;

    /// Parses each comma-separated piece of `s`, stopping at the first piece
    /// that `T` rejects and returning that error. An empty `s` is a single
    /// empty piece.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        s.split(',')
            .map(str::parse::<T>)
            .collect::<Result<Vec<T>, T::Err>>()
            .map(CommaSeparated)
    }
}
46
47#[derive(Debug, Clone)]
49pub enum SocketName {
50 HostName(rr::Name, Option<u16>),
51 SocketAddr(SocketAddr),
52 IpAddr(IpAddr),
53}
54
55impl SocketName {
56 pub fn resolve(
57 &self,
58 resolver: impl DnsHandle,
59 default_port: u16,
60 ) -> impl Future<Item = SocketAddr, Error = failure::Error> {
61 match self {
62 SocketName::HostName(name, port) => {
63 let port = port.unwrap_or(default_port);
64 Either::A(
65 resolve_ip(resolver, name.clone()).map(move |ip| SocketAddr::new(ip, port)),
66 )
67 }
68 SocketName::IpAddr(addr) => Either::B(future::ok(SocketAddr::new(*addr, default_port))),
69 SocketName::SocketAddr(addr) => Either::B(future::ok(*addr)),
70 }
71 }
72}
73
74impl FromStr for SocketName {
75 type Err = ParseSocketNameError;
76
77 fn from_str(s: &str) -> Result<Self, Self::Err> {
78 s.parse()
79 .map(SocketName::SocketAddr)
80 .or_else(|_| s.parse().map(SocketName::IpAddr))
81 .or_else(|_| {
82 let parts: Vec<_> = s.split(':').collect();
83 match parts.len() {
84 1 => Ok(SocketName::HostName(
85 parts[0].parse().map_err(ParseSocketNameError::Name)?,
86 None,
87 )),
88 2 => Ok(SocketName::HostName(
89 parts[0].parse().map_err(ParseSocketNameError::Name)?,
90 Some(parts[1].parse().map_err(ParseSocketNameError::Port)?),
91 )),
92 _ => Err(ParseSocketNameError::Invalid),
93 }
94 })
95 }
96}
97
98impl fmt::Display for ParseSocketNameError {
99 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
100 use ParseSocketNameError::*;
101 match self {
102 Invalid => write!(
103 f,
104 "invalid socket name, expected IP, IP:PORT, HOST, or HOST:PORT"
105 ),
106 Name(e) => write!(f, "invalid host name: {}", e),
107 Port(e) => write!(f, "invalid port: {}", e),
108 }
109 }
110}
111
112impl std::error::Error for ParseSocketNameError {}
113
114#[derive(Debug)]
115pub enum ParseSocketNameError {
116 Invalid,
117 Name(ProtoError),
118 Port(ParseIntError),
119}
120
121pub fn get_system_resolver() -> Option<SocketAddr> {
122 use resolv_conf::{Config, ScopedIp};
123 let resolv_conf = fs::read("/etc/resolv.conf").ok()?;
124 let config = Config::parse(&resolv_conf).ok()?;
125 config.nameservers.iter().find_map(|scoped| match scoped {
126 ScopedIp::V4(v4) => Some(SocketAddr::new(v4.clone().into(), 53)),
127 ScopedIp::V6(v6, _) => Some(SocketAddr::new(v6.clone().into(), 53)),
128 })
129}
130
131pub fn dns_query(
132 mut recursor: impl ClientHandle,
133 query: Query,
134) -> impl Future<Item = DnsResponse, Error = failure::Error> {
135 use future::Loop;
136 const MAX_TRIES: usize = 3;
137 future::loop_fn(0, move |count| {
138 let run_query = recursor.lookup(query.clone(), Default::default());
139 let name = query.name().clone();
140 run_query.then(move |result| match result {
141 Ok(addrs) => future::ok(Loop::Break(addrs)),
142 Err(_) if count < MAX_TRIES => future::ok(Loop::Continue(count + 1)),
143 Err(e) => future::err(format_err!(
144 "could not resolve server name '{}' (max retries reached): {}",
145 name,
146 e
147 )),
148 })
149 })
150}
151
152pub fn query_ip_addr(
153 recursor: impl ClientHandle,
154 name: rr::Name,
155) -> impl Future<Item = Vec<IpAddr>, Error = failure::Error> + 'static {
156 dns_query(recursor, Query::query(name, RecordType::A)).map(|response| {
158 response
159 .answers()
160 .iter()
161 .filter_map(|r| r.rdata().to_ip_addr())
162 .collect()
163 })
164}
165
166pub fn get_ns_records<R>(
167 recursor: R,
168 domain: rr::Name,
169) -> impl Future<Item = Vec<Record>, Error = failure::Error>
170where
171 R: ClientHandle,
172{
173 dns_query(recursor, Query::query(domain, RecordType::NS))
174 .map(|response| response.answers().to_vec())
175}
176
177pub fn resolve_ip(
178 recursor: impl ClientHandle,
179 server_name: rr::Name,
180) -> impl Future<Item = IpAddr, Error = failure::Error> {
181 query_ip_addr(recursor.clone(), server_name.clone()).and_then(move |addrs| {
182 if let Some(addr) = addrs.first().cloned() {
184 Ok(addr)
185 } else {
186 Err(format_err!(
187 "could not resolve server '{}': no addresses found",
188 server_name
189 ))
190 }
191 })
192}