microsandbox_network/dns/nameserver/
parse.rs1use std::fmt;
8use std::io;
9use std::net::{IpAddr, SocketAddr};
10use std::str::FromStr;
11
12use serde::{Deserialize, Serialize};
13
/// Port assigned when a nameserver is written without an explicit port
/// (the standard DNS port).
const DEFAULT_DNS_PORT: u16 = 53;
16
/// A DNS nameserver: either a concrete socket address or a hostname that is
/// looked up on demand by [`Nameserver::resolve`].
///
/// Parsed from and displayed as `IP:PORT` / `HOST:PORT` strings (see the
/// `FromStr` and `Display` impls) and serialized in that same string form.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Nameserver {
    /// A literal IP address and port, used as-is with no lookup.
    Addr(SocketAddr),
    /// A hostname that still needs resolving, plus the port to pair with
    /// whatever address it resolves to.
    Host {
        /// Hostname to resolve (never an IP literal when produced by parsing).
        host: String,
        /// Port combined with the resolved address.
        port: u16,
    },
}
33
34impl Nameserver {
35 pub async fn resolve(&self) -> io::Result<SocketAddr> {
40 match self {
41 Self::Addr(sa) => Ok(*sa),
42 Self::Host { host, port } => tokio::net::lookup_host((host.as_str(), *port))
43 .await?
44 .next()
45 .ok_or_else(|| {
46 io::Error::new(
47 io::ErrorKind::NotFound,
48 format!("no addresses resolved for {host}:{port}"),
49 )
50 }),
51 }
52 }
53}
54
55impl fmt::Display for Nameserver {
56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 match self {
58 Self::Addr(sa) => write!(f, "{sa}"),
59 Self::Host { host, port } => write!(f, "{host}:{port}"),
60 }
61 }
62}
63
64#[derive(Debug, thiserror::Error)]
66#[error("invalid nameserver {0:?}; expected IP, IP:PORT, HOST, or HOST:PORT")]
67pub struct ParseNameserverError(pub String);
68
69impl FromStr for Nameserver {
79 type Err = ParseNameserverError;
80
81 fn from_str(input: &str) -> Result<Self, Self::Err> {
82 let s = input.trim();
83 if s.is_empty() {
84 return Err(ParseNameserverError(input.to_owned()));
85 }
86
87 if let Ok(sa) = s.parse::<SocketAddr>() {
89 return Ok(Self::Addr(sa));
90 }
91
92 if let Ok(ip) = s.parse::<IpAddr>() {
94 return Ok(Self::Addr(SocketAddr::new(ip, DEFAULT_DNS_PORT)));
95 }
96
97 if let Some((host, port)) = s.rsplit_once(':')
102 && !host.is_empty()
103 && !host.contains(':')
104 && host.parse::<IpAddr>().is_err()
105 && let Ok(port) = port.parse::<u16>()
106 {
107 return Ok(Self::Host {
108 host: host.to_owned(),
109 port,
110 });
111 }
112
113 if !s.contains(char::is_whitespace) && !s.contains(':') {
116 return Ok(Self::Host {
117 host: s.to_owned(),
118 port: DEFAULT_DNS_PORT,
119 });
120 }
121
122 Err(ParseNameserverError(input.to_owned()))
123 }
124}
125
126impl Serialize for Nameserver {
129 fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
130 s.collect_str(self)
131 }
132}
133
134impl<'de> Deserialize<'de> for Nameserver {
135 fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
136 let s = String::deserialize(d)?;
137 s.parse().map_err(serde::de::Error::custom)
138 }
139}
140
141impl From<SocketAddr> for Nameserver {
143 fn from(sa: SocketAddr) -> Self {
144 Self::Addr(sa)
145 }
146}
147
148impl From<IpAddr> for Nameserver {
149 fn from(ip: IpAddr) -> Self {
150 Self::Addr(SocketAddr::new(ip, DEFAULT_DNS_PORT))
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an expected `Nameserver::Addr` from a socket-address literal.
    fn addr(s: &str) -> Nameserver {
        Nameserver::Addr(s.parse().unwrap())
    }

    /// Builds an expected `Nameserver::Host`.
    fn host(host: &str, port: u16) -> Nameserver {
        Nameserver::Host {
            host: host.to_owned(),
            port,
        }
    }

    #[test]
    fn parses_ipv4_bare() {
        assert_eq!("1.1.1.1".parse::<Nameserver>().unwrap(), addr("1.1.1.1:53"));
    }

    #[test]
    fn parses_ipv4_with_port() {
        assert_eq!(
            "8.8.8.8:5353".parse::<Nameserver>().unwrap(),
            addr("8.8.8.8:5353")
        );
    }

    #[test]
    fn parses_ipv6_bare() {
        assert_eq!(
            "2606:4700:4700::1111".parse::<Nameserver>().unwrap(),
            addr("[2606:4700:4700::1111]:53")
        );
    }

    #[test]
    fn parses_ipv6_bracketed_with_port() {
        assert_eq!(
            "[2606:4700:4700::1111]:53".parse::<Nameserver>().unwrap(),
            addr("[2606:4700:4700::1111]:53")
        );
    }

    #[test]
    fn parses_hostname_bare() {
        assert_eq!(
            "dns.google".parse::<Nameserver>().unwrap(),
            host("dns.google", 53)
        );
    }

    #[test]
    fn parses_hostname_with_port() {
        assert_eq!(
            "dns.google:53".parse::<Nameserver>().unwrap(),
            host("dns.google", 53)
        );
        assert_eq!(
            "my-dns.corp.internal:5353".parse::<Nameserver>().unwrap(),
            host("my-dns.corp.internal", 5353)
        );
    }

    #[test]
    fn trims_whitespace() {
        assert_eq!(
            "  1.1.1.1  ".parse::<Nameserver>().unwrap(),
            addr("1.1.1.1:53")
        );
    }

    #[test]
    fn rejects_empty() {
        assert!("".parse::<Nameserver>().is_err());
        assert!("   ".parse::<Nameserver>().is_err());
    }

    #[test]
    fn rejects_embedded_whitespace() {
        assert!("dns google".parse::<Nameserver>().is_err());
    }

    #[test]
    fn rejects_bad_port() {
        assert!("dns.google:notaport".parse::<Nameserver>().is_err());
        assert!("1.1.1.1:99999".parse::<Nameserver>().is_err());
    }

    #[test]
    fn rejects_malformed_host_port() {
        // Empty host, empty port, and an out-of-range hostname port.
        assert!(":53".parse::<Nameserver>().is_err());
        assert!("dns.google:".parse::<Nameserver>().is_err());
        assert!("dns.google:65536".parse::<Nameserver>().is_err());
    }

    #[test]
    fn display_roundtrip() {
        for s in ["1.1.1.1:53", "[2606:4700:4700::1111]:53", "dns.google:53"] {
            let ns: Nameserver = s.parse().unwrap();
            assert_eq!(ns.to_string(), s);
        }
    }

    #[test]
    fn display_feeds_back_into_parse() {
        for s in ["1.1.1.1", "dns.google", "dns.google:53"] {
            let ns: Nameserver = s.parse().unwrap();
            // Display output must parse back to an equal value.
            let reparsed: Nameserver = ns.to_string().parse().unwrap();
            assert_eq!(ns, reparsed);
        }
    }

    #[test]
    fn from_socket_addr_and_ip() {
        let sa: std::net::SocketAddr = "8.8.4.4:5353".parse().unwrap();
        assert_eq!(Nameserver::from(sa), addr("8.8.4.4:5353"));

        let ip: std::net::IpAddr = "9.9.9.9".parse().unwrap();
        assert_eq!(Nameserver::from(ip), addr("9.9.9.9:53"));
    }
}