dns_uri/
lib.rs

1#![doc = include_str!("../README.md")]
2#![no_std]
3
4extern crate alloc;
5
6use alloc::string::{String, ToString};
7use core::fmt;
8use core::net::{AddrParseError, IpAddr, SocketAddr};
9use core::num::ParseIntError;
10use core::str::FromStr;
11
12#[non_exhaustive]
13#[derive(Debug, Clone, PartialEq, Eq, Hash)]
14/// URI of a DNS server.
15///
16/// This implements parsing from the following string formats:
17///
18/// ```rust
19/// # use dns_uri::Uri;
20/// #
21/// // Regular DNS, IPv4, without port (default port 53)
22/// # let server: Uri = r#"
23/// 8.8.8.8
24/// # "#.trim().parse().unwrap();
25/// # assert!(matches!(server, Uri::Regular { .. }));
26/// // Regular DNS, IPv4, with port
27/// # let server: Uri = r#"
28/// 8.8.8.8:10053
29/// # "#.trim().parse().unwrap();
30/// # assert!(matches!(server, Uri::Regular { .. }));
31/// // Regular DNS, IPv6, without port (default port 53)
32/// # let server: Uri = r#"
33/// [2001:4860:4860::8888]
34/// # "#.trim().parse().unwrap();
35/// # assert!(matches!(server, Uri::Regular { .. }));
36/// // We don't accept bare IPv6 address without blanket.
37/// # assert!(matches!(server, Uri::Regular { .. }));
38/// # let server = r#"
39/// 2001:4860:4860::8888
40/// # "#.trim().parse::<Uri>().unwrap_err();
41/// // Regular DNS, IPv6, with port
42/// # let server: Uri = r#"
43/// [2001:4860:4860::8888]:10053
44/// # "#.trim().parse().unwrap();
45/// # assert!(matches!(server, Uri::Regular { .. }));
46/// // Regular DNS, in URI format.
47/// # let server: Uri = r#"
48/// udp://8.8.8.8
49/// # "#.trim().parse().unwrap();
50/// # let server: Uri = r#"
51/// udp://8.8.8.8:10053
52/// # "#.trim().parse().unwrap();
53/// # assert!(matches!(server, Uri::Regular { .. }));
54/// # let server: Uri = r#"
55/// tcp://[2001:4860:4860::8888]
56/// # "#.trim().parse().unwrap();
57/// # assert!(matches!(server, Uri::Regular { .. }));
58/// # let server: Uri = r#"
59/// tcp://[2001:4860:4860::8888]:10053
60/// # "#.trim().parse().unwrap();
61/// # assert!(matches!(server, Uri::Regular { .. }));
62/// // DNS over TLS.
63/// # let server: Uri = r#"
64/// tls://dns.google
65/// # "#.trim().parse().unwrap();
66/// # assert!(matches!(server, Uri::TLS { .. }));
67/// # let server: Uri = r#"
68/// tls://dns.google:10853
69/// # "#.trim().parse().unwrap();
70/// # assert!(matches!(server, Uri::TLS { .. }));
71/// // DNS over HTTPS, without custom endpoint.
72/// # let server: Uri = r#"
73/// https://dns.google
74/// # "#.trim().parse().unwrap();
75/// # assert!(matches!(server, Uri::HTTPS { .. }));
76/// # let server: Uri = r#"
77/// https://dns.google:8443
78/// # "#.trim().parse().unwrap();
79/// # assert!(matches!(server, Uri::HTTPS { .. }));
80/// // DNS over HTTPS, with custom endpoint.
81/// # let server: Uri = r#"
82/// https://dns.google/dns-query
83/// # "#.trim().parse().unwrap();
84/// # assert!(matches!(server, Uri::HTTPS { .. }));
85/// // For DoH, a root path `/` is also considered as a custom endpoint, please pay attention to this.
86/// # let server: Uri = r#"
87/// https://dns.google/
88/// # "#.trim().parse().unwrap();
89/// // For DoH / DoQ / DoT, you can specify custom SNI via query parameter `sni`.
90/// # let server: Uri = r#"
91/// tls://8.8.8.8?sni=dns.google
92/// # "#.trim().parse().unwrap();
93/// ```
94pub enum Uri {
95    #[non_exhaustive]
96    /// Regular DNS, over UDP or TCP.
97    Regular {
98        /// Server IP address
99        addr: IpAddr,
100
101        /// Server port
102        port: u16,
103
104        /// Prefer TCP over UDP.
105        prefer_tcp: bool,
106    },
107
108    #[non_exhaustive]
109    /// DNS over TLS
110    TLS {
111        /// Server [`Host`]
112        host: Host,
113
114        /// Server port
115        port: u16,
116    },
117
118    #[non_exhaustive]
119    /// DNS over HTTPS
120    HTTPS {
121        /// Server [`Host`]
122        host: Host,
123
124        /// Server port
125        port: u16,
126
127        /// The HTTP endpoint where the DNS `NameServer` provides service. Only
128        /// relevant to DNS-over-HTTPS.
129        custom_http_endpoint: Option<String>,
130
131        /// Force HTTP/3 (aka., DNS over HTTP/3).
132        force_http3: bool,
133    },
134
135    #[non_exhaustive]
136    /// DNS over QUIC
137    QUIC {
138        /// Server [`Host`]
139        host: Host,
140
141        /// Server port
142        port: u16,
143    },
144}
145
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
/// Host of an encrypted-transport DNS server ([`Uri::TLS`], [`Uri::HTTPS`], [`Uri::QUIC`]).
///
/// This is either a literal IP address (optionally with a custom TLS server name) or a
/// host that is not an IP literal.
pub enum Host {
    #[non_exhaustive]
    /// A literal IP address host, e.g. `tls://8.8.8.8`.
    IpAddr {
        /// The parsed IP address.
        addr: IpAddr,

        /// Custom server name (SNI) to use instead of the IP address. When parsing,
        /// this comes from the URI's `sni` query parameter. It is `None` if that
        /// parameter is absent.
        custom_server_name: Option<String>,
    },

    #[non_exhaustive]
    /// A host that is not an IP literal, e.g. `tls://dns.google`.
    ServerName {
        /// The host exactly as written in the URI authority. It is not validated as a
        /// domain name here. NOTE(review): presumably not percent-decoded either —
        /// confirm against `fluent_uri`'s `Authority::host`.
        name: String,
    },
}
166
167impl Host {
168    #[doc(hidden)]
169    #[must_use]
170    pub fn new_ip_addr(addr: IpAddr, custom_server_name: Option<String>) -> Self {
171        Self::IpAddr {
172            addr,
173            custom_server_name,
174        }
175    }
176
177    #[doc(hidden)]
178    #[must_use]
179    pub fn new_server_name(name: String) -> Self {
180        Self::ServerName { name }
181    }
182}
183
/// Scheme for regular DNS over UDP.
const SCHEME_UDP: &str = "udp";
/// Scheme for regular DNS, preferring TCP.
const SCHEME_TCP: &str = "tcp";
/// Scheme for DNS over TLS.
const SCHEME_TLS: &str = "tls";
/// Scheme for DNS over HTTPS.
const SCHEME_HTTPS: &str = "https";
/// Scheme for DNS over HTTPS that forces HTTP/3.
const SCHEME_H3: &str = "h3";
/// Scheme for DNS over QUIC.
const SCHEME_QUIC: &str = "quic";

/// Port assumed for regular DNS when the input gives none.
const DEFAULT_PORT_DNS: u16 = 53;
/// Port assumed for DNS over TLS when the input gives none.
const DEFAULT_PORT_DNS_OVER_TLS: u16 = 853;
/// Port assumed for DNS over QUIC when the input gives none.
const DEFAULT_PORT_DNS_OVER_QUIC: u16 = 853;
/// Port assumed for DNS over HTTPS (and HTTP/3) when the input gives none.
const DEFAULT_PORT_DNS_OVER_HTTPS: u16 = 443;
195
196impl fmt::Display for Uri {
197    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198        match self {
199            Self::Regular { addr, port, prefer_tcp } => {
200                write!(
201                    f,
202                    "{}://{}",
203                    if *prefer_tcp { SCHEME_TCP } else { SCHEME_UDP },
204                    SocketAddr::new(*addr, *port)
205                )
206            }
207            Self::TLS { host, port } => match host {
208                Host::IpAddr {
209                    addr,
210                    custom_server_name: None,
211                } => {
212                    write!(f, "{SCHEME_TLS}://{}", SocketAddr::new(*addr, *port))
213                }
214                Host::IpAddr {
215                    addr,
216                    custom_server_name: Some(custom_server_name),
217                } => write!(
218                    f,
219                    "{SCHEME_TLS}://{}?sni={custom_server_name}",
220                    SocketAddr::new(*addr, *port),
221                ),
222                Host::ServerName { name } => write!(f, "{SCHEME_TLS}://{name}:{port}"),
223            },
224            Self::HTTPS {
225                host,
226                port,
227                custom_http_endpoint,
228                force_http3,
229            } => match host {
230                Host::IpAddr {
231                    addr,
232                    custom_server_name: None,
233                } => {
234                    write!(
235                        f,
236                        "{}://{}{}",
237                        if *force_http3 { SCHEME_H3 } else { SCHEME_HTTPS },
238                        SocketAddr::new(*addr, *port),
239                        custom_http_endpoint.as_deref().unwrap_or(""),
240                    )
241                }
242                Host::IpAddr {
243                    addr,
244                    custom_server_name: Some(custom_server_name),
245                } => write!(
246                    f,
247                    "{}://{}{}?sni={}",
248                    if *force_http3 { SCHEME_H3 } else { SCHEME_HTTPS },
249                    SocketAddr::new(*addr, *port),
250                    custom_http_endpoint.as_deref().unwrap_or(""),
251                    custom_server_name,
252                ),
253                Host::ServerName { name } => write!(
254                    f,
255                    "{}://{}:{}{}",
256                    if *force_http3 { SCHEME_H3 } else { SCHEME_HTTPS },
257                    name,
258                    port,
259                    custom_http_endpoint.as_deref().unwrap_or(""),
260                ),
261            },
262            Self::QUIC { host, port } => match host {
263                Host::IpAddr {
264                    addr,
265                    custom_server_name: None,
266                } => {
267                    write!(f, "{SCHEME_QUIC}://{}", SocketAddr::new(*addr, *port))
268                }
269                Host::IpAddr {
270                    addr,
271                    custom_server_name: Some(custom_server_name),
272                } => write!(
273                    f,
274                    "{SCHEME_QUIC}://{}?sni={custom_server_name}",
275                    SocketAddr::new(*addr, *port),
276                ),
277                Host::ServerName { name } => write!(f, "{SCHEME_QUIC}://{name}:{port}"),
278            },
279        }
280    }
281}
282
283impl FromStr for Uri {
284    type Err = ParseError;
285
286    #[allow(clippy::too_many_lines)]
287    fn from_str(s: &str) -> Result<Self, Self::Err> {
288        if !s.contains('/') {
289            // Try to parse as a plain IP address or socket address.
290
291            let (closing_blacket_index, colon_index) = if s.starts_with('[') {
292                let r @ Some(closing_blacket_index) = memchr::memchr(b']', s.as_bytes()) else {
293                    return Err(ParseError::InvalidIpAddr);
294                };
295
296                (
297                    r,
298                    s.as_bytes().get(closing_blacket_index + 1).and_then(|colon| {
299                        if colon == &b':' {
300                            Some(closing_blacket_index + 1)
301                        } else {
302                            None
303                        }
304                    }),
305                )
306            } else {
307                (None, memchr::memchr(b':', s.as_bytes()))
308            };
309
310            match (closing_blacket_index, colon_index) {
311                // An IPv6 address with port, e.g., [::1]:53
312                (Some(closing_blacket_index), Some(colon)) => {
313                    return Ok(Self::Regular {
314                        addr: IpAddr::V6(s[1..closing_blacket_index].parse()?),
315                        port: s[colon + 1..].parse()?,
316                        prefer_tcp: false,
317                    });
318                }
319                // An IPv4 address with port, e.g., 127.0.0.1:53
320                (None, Some(colon)) => {
321                    return Ok(Self::Regular {
322                        addr: IpAddr::V4(s[..colon].parse()?),
323                        port: s[colon + 1..].parse()?,
324                        prefer_tcp: false,
325                    });
326                }
327                // An IPv6 address (with blanket) without port, e.g., [::1]
328                (Some(closing_blacket_index), None) => {
329                    return Ok(Self::Regular {
330                        addr: IpAddr::V6(s[1..closing_blacket_index].parse()?),
331                        port: DEFAULT_PORT_DNS,
332                        prefer_tcp: false,
333                    });
334                }
335                // An IPv4 address without port, e.g., 127.0.0.1
336                (None, None) => {
337                    return Ok(Self::Regular {
338                        addr: IpAddr::V4(s.parse()?),
339                        port: DEFAULT_PORT_DNS,
340                        prefer_tcp: false,
341                    });
342                }
343            }
344        }
345
346        let uri = fluent_uri::Uri::parse(s)?;
347
348        let authority = uri.authority().ok_or(ParseError::MissingHost)?;
349
350        let server_ip_addr = match authority.host_parsed() {
351            fluent_uri::component::Host::Ipv4(ipv4_addr) => Some(IpAddr::V4(ipv4_addr)),
352            fluent_uri::component::Host::Ipv6(ipv6_addr) => Some(IpAddr::V6(ipv6_addr)),
353            _ => None,
354        };
355        let server_port = authority.port_to_u16()?;
356
357        match uri.scheme().as_str() {
358            r @ (SCHEME_UDP | SCHEME_TCP) => Ok(Self::Regular {
359                addr: server_ip_addr.ok_or(ParseError::InvalidIpAddr)?,
360                port: server_port.unwrap_or(DEFAULT_PORT_DNS),
361                prefer_tcp: r == SCHEME_TCP,
362            }),
363            SCHEME_TLS => Ok(Self::TLS {
364                host: server_ip_addr.map_or_else(
365                    || Host::ServerName {
366                        name: authority.host().to_string(),
367                    },
368                    |addr| Host::IpAddr {
369                        addr,
370                        custom_server_name: custom_sni_from_query(&uri).map(ToString::to_string),
371                    },
372                ),
373                port: server_port.unwrap_or(DEFAULT_PORT_DNS_OVER_TLS),
374            }),
375            r @ (SCHEME_HTTPS | SCHEME_H3) => Ok(Self::HTTPS {
376                host: server_ip_addr.map_or_else(
377                    || Host::ServerName {
378                        name: authority.host().to_string(),
379                    },
380                    |addr| Host::IpAddr {
381                        addr,
382                        custom_server_name: custom_sni_from_query(&uri).map(ToString::to_string),
383                    },
384                ),
385                port: server_port.unwrap_or(DEFAULT_PORT_DNS_OVER_HTTPS),
386                custom_http_endpoint: (!uri.path().is_empty()).then_some(uri.path().to_string()),
387                force_http3: r == SCHEME_H3,
388            }),
389            SCHEME_QUIC => Ok(Self::QUIC {
390                host: server_ip_addr.map_or_else(
391                    || Host::ServerName {
392                        name: authority.host().to_string(),
393                    },
394                    |addr| Host::IpAddr {
395                        addr,
396                        custom_server_name: custom_sni_from_query(&uri).map(ToString::to_string),
397                    },
398                ),
399                port: server_port.unwrap_or(DEFAULT_PORT_DNS_OVER_QUIC),
400            }),
401            _ => Err(ParseError::UnsupportedScheme),
402        }
403    }
404}
405
406#[inline(always)]
407fn custom_sni_from_query<'a>(uri: &fluent_uri::Uri<&'a str>) -> Option<&'a str> {
408    uri.query()
409        .and_then(|query| query.split('&').find_map(|query| query.as_str().strip_prefix("sni=")))
410}
411
#[derive(Debug)]
/// Error returned when parsing a [`Uri`] from a string fails.
// NOTE(review): this type does not implement `core::error::Error`; consider adding it
// if the crate's MSRV allows (`core::error` is stable since Rust 1.81).
pub enum ParseError {
    /// The input contains a `/` but is not a syntactically valid URI. This wraps the
    /// underlying URI parser's error.
    InvalidUri(fluent_uri::error::ParseError),

    /// The URI has no usable host (e.g. it has no authority component).
    MissingHost,

    /// An IP address is malformed. This is also returned when an IP address is
    /// required but the host is not one (e.g. a domain name in a `udp://` / `tcp://`
    /// URI).
    InvalidIpAddr,

    /// The port could not be parsed as a `u16`.
    InvalidPort,

    /// The URI scheme is not one of the supported schemes.
    UnsupportedScheme,
}
430
431impl fmt::Display for ParseError {
432    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
433        match self {
434            Self::InvalidUri(err) => write!(f, "Invalid URI: {err}"),
435            Self::MissingHost => write!(f, "Missing host"),
436            Self::InvalidIpAddr => write!(f, "Invalid or missing IP address"),
437            Self::InvalidPort => write!(f, "Invalid port"),
438            Self::UnsupportedScheme => write!(f, "Unsupported scheme"),
439        }
440    }
441}
442
443impl From<fluent_uri::error::ParseError> for ParseError {
444    fn from(err: fluent_uri::error::ParseError) -> Self {
445        Self::InvalidUri(err)
446    }
447}
448
449impl From<AddrParseError> for ParseError {
450    fn from(_: AddrParseError) -> Self {
451        Self::InvalidIpAddr
452    }
453}
454
455impl From<ParseIntError> for ParseError {
456    fn from(_: ParseIntError) -> Self {
457        Self::InvalidPort
458    }
459}