tdns_cli/
util.rs

1use std::{
2    fmt, fs,
3    net::{IpAddr, SocketAddr},
4    num::ParseIntError,
5    str::FromStr,
6};
7
8use failure::format_err;
9use futures::{
10    future::{self, Either},
11    Future,
12};
13use trust_dns::{
14    client::ClientHandle,
15    op::DnsResponse,
16    proto::{error::ProtoError, op::query::Query, xfer::DnsHandle},
17    rr::{self, Record, RecordType},
18};
19
/// A list of `T` values parsed from one comma-separated string such as `"a,b,c"`.
///
/// Every comma-delimited piece is handed to `T`'s `FromStr` exactly as written (no
/// whitespace trimming), so an empty input yields a single empty piece.
#[derive(Debug, Clone)]
pub struct CommaSeparated<T>(Vec<T>);

impl<T> CommaSeparated<T> {
    /// Consumes the wrapper and returns the parsed values in input order.
    pub fn into_vec(self) -> Vec<T> {
        let CommaSeparated(items) = self;
        items
    }
}

impl<T: Clone> CommaSeparated<T> {
    /// Returns a cloned copy of the parsed values, leaving `self` untouched.
    pub fn to_vec(&self) -> Vec<T> {
        self.0.to_vec()
    }
}

impl<T: FromStr> FromStr for CommaSeparated<T> {
    type Err = T::Err;

    /// Parses each comma-separated piece of `s`, stopping at the first piece `T` rejects.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let items: Vec<T> = s
            .split(',')
            .map(str::parse::<T>)
            .collect::<Result<_, _>>()?;
        Ok(CommaSeparated(items))
    }
}
46
47/// A potential unresolved host name, with an optional port number.
48#[derive(Debug, Clone)]
49pub enum SocketName {
50    HostName(rr::Name, Option<u16>),
51    SocketAddr(SocketAddr),
52    IpAddr(IpAddr),
53}
54
55impl SocketName {
56    pub fn resolve(
57        &self,
58        resolver: impl DnsHandle,
59        default_port: u16,
60    ) -> impl Future<Item = SocketAddr, Error = failure::Error> {
61        match self {
62            SocketName::HostName(name, port) => {
63                let port = port.unwrap_or(default_port);
64                Either::A(
65                    resolve_ip(resolver, name.clone()).map(move |ip| SocketAddr::new(ip, port)),
66                )
67            }
68            SocketName::IpAddr(addr) => Either::B(future::ok(SocketAddr::new(*addr, default_port))),
69            SocketName::SocketAddr(addr) => Either::B(future::ok(*addr)),
70        }
71    }
72}
73
74impl FromStr for SocketName {
75    type Err = ParseSocketNameError;
76
77    fn from_str(s: &str) -> Result<Self, Self::Err> {
78        s.parse()
79            .map(SocketName::SocketAddr)
80            .or_else(|_| s.parse().map(SocketName::IpAddr))
81            .or_else(|_| {
82                let parts: Vec<_> = s.split(':').collect();
83                match parts.len() {
84                    1 => Ok(SocketName::HostName(
85                        parts[0].parse().map_err(ParseSocketNameError::Name)?,
86                        None,
87                    )),
88                    2 => Ok(SocketName::HostName(
89                        parts[0].parse().map_err(ParseSocketNameError::Name)?,
90                        Some(parts[1].parse().map_err(ParseSocketNameError::Port)?),
91                    )),
92                    _ => Err(ParseSocketNameError::Invalid),
93                }
94            })
95    }
96}
97
98impl fmt::Display for ParseSocketNameError {
99    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
100        use ParseSocketNameError::*;
101        match self {
102            Invalid => write!(
103                f,
104                "invalid socket name, expected IP, IP:PORT, HOST, or HOST:PORT"
105            ),
106            Name(e) => write!(f, "invalid host name: {}", e),
107            Port(e) => write!(f, "invalid port: {}", e),
108        }
109    }
110}
111
112impl std::error::Error for ParseSocketNameError {}
113
114#[derive(Debug)]
115pub enum ParseSocketNameError {
116    Invalid,
117    Name(ProtoError),
118    Port(ParseIntError),
119}
120
121pub fn get_system_resolver() -> Option<SocketAddr> {
122    use resolv_conf::{Config, ScopedIp};
123    let resolv_conf = fs::read("/etc/resolv.conf").ok()?;
124    let config = Config::parse(&resolv_conf).ok()?;
125    config.nameservers.iter().find_map(|scoped| match scoped {
126        ScopedIp::V4(v4) => Some(SocketAddr::new(v4.clone().into(), 53)),
127        ScopedIp::V6(v6, _) => Some(SocketAddr::new(v6.clone().into(), 53)),
128    })
129}
130
131pub fn dns_query(
132    mut recursor: impl ClientHandle,
133    query: Query,
134) -> impl Future<Item = DnsResponse, Error = failure::Error> {
135    use future::Loop;
136    const MAX_TRIES: usize = 3;
137    future::loop_fn(0, move |count| {
138        let run_query = recursor.lookup(query.clone(), Default::default());
139        let name = query.name().clone();
140        run_query.then(move |result| match result {
141            Ok(addrs) => future::ok(Loop::Break(addrs)),
142            Err(_) if count < MAX_TRIES => future::ok(Loop::Continue(count + 1)),
143            Err(e) => future::err(format_err!(
144                "could not resolve server name '{}' (max retries reached): {}",
145                name,
146                e
147            )),
148        })
149    })
150}
151
152pub fn query_ip_addr(
153    recursor: impl ClientHandle,
154    name: rr::Name,
155) -> impl Future<Item = Vec<IpAddr>, Error = failure::Error> + 'static {
156    // FIXME: IPv6
157    dns_query(recursor, Query::query(name, RecordType::A)).map(|response| {
158        response
159            .answers()
160            .iter()
161            .filter_map(|r| r.rdata().to_ip_addr())
162            .collect()
163    })
164}
165
166pub fn get_ns_records<R>(
167    recursor: R,
168    domain: rr::Name,
169) -> impl Future<Item = Vec<Record>, Error = failure::Error>
170where
171    R: ClientHandle,
172{
173    dns_query(recursor, Query::query(domain, RecordType::NS))
174        .map(|response| response.answers().to_vec())
175}
176
177pub fn resolve_ip(
178    recursor: impl ClientHandle,
179    server_name: rr::Name,
180) -> impl Future<Item = IpAddr, Error = failure::Error> {
181    query_ip_addr(recursor.clone(), server_name.clone()).and_then(move |addrs| {
182        // TODO: handle multiple addresses
183        if let Some(addr) = addrs.first().cloned() {
184            Ok(addr)
185        } else {
186            Err(format_err!(
187                "could not resolve server '{}': no addresses found",
188                server_name
189            ))
190        }
191    })
192}