Skip to main content

dns_update/
utils.rs

1/*
2 * Copyright Stalwart Labs LLC See the COPYING
3 * file at the top-level directory of this distribution.
4 *
5 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 * https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 * <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
8 * option. This file may not be copied, modified, or distributed
9 * except according to those terms.
10 */
11
12use crate::{
13    CAARecord, DnsRecord, DnsRecordType, Error, IntoFqdn, KeyValue, MXRecord, SRVRecord,
14    TLSARecord, TlsaCertUsage, TlsaMatching, TlsaSelector, TsigAlgorithm,
15};
16use std::{
17    borrow::Cow,
18    fmt::{self, Display, Formatter},
19    str::FromStr,
20};
21
/// Largest chunk, in bytes of unescaped text, that TXT data is split into
/// (presumably the 255-byte DNS character-string limit).
const MAX_CHUNK_BYTES: usize = 255;

/// Appends `text` to `output` as one or more double-quoted strings.
///
/// A new quoted segment (preceded by `separator`) is started whenever the
/// current one would exceed `MAX_CHUNK_BYTES` bytes of *unescaped* text.
/// Segments never split a UTF-8 character and are never empty, except that an
/// empty `text` produces a single `""`. Backslashes and double quotes are
/// backslash-escaped; escaping does not count toward the chunk size.
pub(crate) fn txt_chunks_to_text(output: &mut String, text: &str, separator: &str) {
    output.push('"');

    let mut used = 0usize;
    for c in text.chars() {
        let width = c.len_utf8();

        // Close the segment and open a new one only if it already holds data
        // and this character would push it past the limit.
        if used != 0 && used + width > MAX_CHUNK_BYTES {
            output.push('"');
            output.push_str(separator);
            output.push('"');
            used = 0;
        }

        if matches!(c, '\\' | '"') {
            output.push('\\');
        }
        output.push(c);
        used += width;
    }

    output.push('"');
}

/// Splits `content` into pieces of at most `MAX_CHUNK_BYTES` bytes each,
/// never cutting through a UTF-8 character.
///
/// Content that already fits (including the empty string) is returned as a
/// single-element vector without copying.
pub(crate) fn txt_chunks(content: String) -> Vec<String> {
    if content.len() <= MAX_CHUNK_BYTES {
        return vec![content];
    }

    let mut pieces: Vec<String> = Vec::new();
    for c in content.chars() {
        match pieces.last_mut() {
            // Keep filling the current piece while the character still fits.
            Some(last) if last.len() + c.len_utf8() <= MAX_CHUNK_BYTES => last.push(c),
            // Otherwise (or for the very first character) start a new piece.
            _ => pieces.push(String::from(c)),
        }
    }
    pieces
}
67
/// Makes `name` relative to the zone `origin` by stripping the origin suffix.
///
/// Trailing dots on both arguments are ignored. Names are compared ASCII
/// case-insensitively, since DNS names are case-insensitive (RFC 4343).
///
/// * If `name` equals `origin`, returns `return_if_equal`, or `"@"` when it
///   is `None`.
/// * If `name` is a subdomain of `origin` (`<prefix>.<origin>`), returns
///   `<prefix>` with its original case preserved.
/// * Otherwise returns `name` unchanged (minus trailing dots).
pub(crate) fn strip_origin_from_name(
    name: &str,
    origin: &str,
    return_if_equal: Option<&str>,
) -> String {
    let name = name.trim_end_matches('.');
    let origin = origin.trim_end_matches('.');

    if name.eq_ignore_ascii_case(origin) {
        return return_if_equal.unwrap_or("@").to_string();
    }

    // `name` must be at least one byte longer than `origin` to hold the
    // separating dot; `dot` is the index where that '.' has to sit.
    if let Some(dot) = name.len().checked_sub(origin.len() + 1) {
        let bytes = name.as_bytes();
        if bytes[dot] == b'.' && bytes[dot + 1..].eq_ignore_ascii_case(origin.as_bytes()) {
            // `dot` indexes an ASCII '.', which is always a UTF-8 char boundary.
            return name[..dot].to_string();
        }
    }

    name.to_string()
}
88
89impl fmt::Display for TLSARecord {
90    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
91        write!(
92            f,
93            "{} {} {} ",
94            u8::from(self.cert_usage),
95            u8::from(self.selector),
96            u8::from(self.matching),
97        )?;
98
99        for ch in &self.cert_data {
100            write!(f, "{:02x}", ch)?;
101        }
102
103        Ok(())
104    }
105}
106
107impl fmt::Display for KeyValue {
108    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
109        f.write_str(&self.key)?;
110        if !self.value.is_empty() {
111            write!(f, "={}", self.value)?;
112        }
113
114        Ok(())
115    }
116}
117
118impl CAARecord {
119    pub fn decompose(self) -> (u8, String, String) {
120        match self {
121            CAARecord::Issue {
122                issuer_critical,
123                name,
124                options,
125            } => {
126                let flags = if issuer_critical { 128 } else { 0 };
127                let mut value = name.unwrap_or_default();
128                for opt in &options {
129                    use std::fmt::Write;
130                    write!(value, "; {}", opt).unwrap();
131                }
132                (flags, "issue".to_string(), value)
133            }
134            CAARecord::IssueWild {
135                issuer_critical,
136                name,
137                options,
138            } => {
139                let flags = if issuer_critical { 128 } else { 0 };
140                let mut value = name.unwrap_or_default();
141                for opt in &options {
142                    use std::fmt::Write;
143                    write!(value, "; {}", opt).unwrap();
144                }
145                (flags, "issuewild".to_string(), value)
146            }
147            CAARecord::Iodef {
148                issuer_critical,
149                url,
150            } => {
151                let flags = if issuer_critical { 128 } else { 0 };
152                (flags, "iodef".to_string(), url)
153            }
154        }
155    }
156}
157
158impl fmt::Display for CAARecord {
159    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
160        match self {
161            CAARecord::Issue {
162                issuer_critical,
163                name,
164                options,
165            } => {
166                if *issuer_critical {
167                    f.write_str("128 ")?;
168                } else {
169                    f.write_str("0 ")?;
170                }
171                f.write_str("issue ")?;
172                f.write_str("\"")?;
173                if let Some(name) = name {
174                    f.write_str(name)?;
175                }
176                for opt in options {
177                    write!(f, ";{}", opt)?;
178                }
179                f.write_str("\"")?;
180            }
181            CAARecord::IssueWild {
182                issuer_critical,
183                name,
184                options,
185            } => {
186                if *issuer_critical {
187                    f.write_str("128 ")?;
188                } else {
189                    f.write_str("0 ")?;
190                }
191                f.write_str("issuewild ")?;
192                f.write_str("\"")?;
193                if let Some(name) = name {
194                    f.write_str(name)?;
195                }
196                for opt in options {
197                    write!(f, ";{}", opt)?;
198                }
199                f.write_str("\"")?;
200            }
201            CAARecord::Iodef {
202                issuer_critical,
203                url,
204            } => {
205                if *issuer_critical {
206                    f.write_str("128 ")?;
207                } else {
208                    f.write_str("0 ")?;
209                }
210                f.write_str("iodef ")?;
211                f.write_str("\"")?;
212                f.write_str(url)?;
213                f.write_str("\"")?;
214            }
215        }
216        Ok(())
217    }
218}
219
220impl Display for MXRecord {
221    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
222        write!(f, "{} {}", self.priority, self.exchange)
223    }
224}
225
226impl Display for SRVRecord {
227    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
228        write!(
229            f,
230            "{} {} {} {}",
231            self.priority, self.weight, self.port, self.target
232        )
233    }
234}
235
236impl DnsRecord {
237    pub fn as_type(&self) -> DnsRecordType {
238        match self {
239            DnsRecord::A { .. } => DnsRecordType::A,
240            DnsRecord::AAAA { .. } => DnsRecordType::AAAA,
241            DnsRecord::CNAME { .. } => DnsRecordType::CNAME,
242            DnsRecord::NS { .. } => DnsRecordType::NS,
243            DnsRecord::MX { .. } => DnsRecordType::MX,
244            DnsRecord::TXT { .. } => DnsRecordType::TXT,
245            DnsRecord::SRV { .. } => DnsRecordType::SRV,
246            DnsRecord::TLSA { .. } => DnsRecordType::TLSA,
247            DnsRecord::CAA { .. } => DnsRecordType::CAA,
248        }
249    }
250}
251
252impl Display for DnsRecord {
253    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
254        match self {
255            DnsRecord::A(addr) => Display::fmt(addr, f),
256            DnsRecord::AAAA(addr) => Display::fmt(addr, f),
257            DnsRecord::CNAME(name) => f.write_str(name),
258            DnsRecord::NS(name) => f.write_str(name),
259            DnsRecord::MX(record) => Display::fmt(record, f),
260            DnsRecord::TXT(text) => f.write_str(text),
261            DnsRecord::SRV(record) => Display::fmt(record, f),
262            DnsRecord::TLSA(record) => Display::fmt(record, f),
263            DnsRecord::CAA(record) => Display::fmt(record, f),
264        }
265    }
266}
267
268impl DnsRecordType {
269    pub fn as_str(&self) -> &'static str {
270        match self {
271            DnsRecordType::A => "A",
272            DnsRecordType::AAAA => "AAAA",
273            DnsRecordType::CNAME => "CNAME",
274            DnsRecordType::NS => "NS",
275            DnsRecordType::MX => "MX",
276            DnsRecordType::TXT => "TXT",
277            DnsRecordType::SRV => "SRV",
278            DnsRecordType::TLSA => "TLSA",
279            DnsRecordType::CAA => "CAA",
280        }
281    }
282}
283
284impl From<TlsaCertUsage> for u8 {
285    fn from(usage: TlsaCertUsage) -> Self {
286        match usage {
287            TlsaCertUsage::PkixTa => 0,
288            TlsaCertUsage::PkixEe => 1,
289            TlsaCertUsage::DaneTa => 2,
290            TlsaCertUsage::DaneEe => 3,
291            TlsaCertUsage::Private => 255,
292        }
293    }
294}
295
296impl From<TlsaSelector> for u8 {
297    fn from(selector: TlsaSelector) -> Self {
298        match selector {
299            TlsaSelector::Full => 0,
300            TlsaSelector::Spki => 1,
301            TlsaSelector::Private => 255,
302        }
303    }
304}
305
306impl From<TlsaMatching> for u8 {
307    fn from(matching: TlsaMatching) -> Self {
308        match matching {
309            TlsaMatching::Raw => 0,
310            TlsaMatching::Sha256 => 1,
311            TlsaMatching::Sha512 => 2,
312            TlsaMatching::Private => 255,
313        }
314    }
315}
316
317impl<'x> IntoFqdn<'x> for &'x str {
318    fn into_fqdn(self) -> Cow<'x, str> {
319        if self.ends_with('.') {
320            Cow::Borrowed(self)
321        } else {
322            Cow::Owned(format!("{}.", self))
323        }
324    }
325
326    fn into_name(self) -> Cow<'x, str> {
327        if let Some(name) = self.strip_suffix('.') {
328            Cow::Borrowed(name)
329        } else {
330            Cow::Borrowed(self)
331        }
332    }
333}
334
335impl<'x> IntoFqdn<'x> for &'x String {
336    fn into_fqdn(self) -> Cow<'x, str> {
337        self.as_str().into_fqdn()
338    }
339
340    fn into_name(self) -> Cow<'x, str> {
341        self.as_str().into_name()
342    }
343}
344
345impl<'x> IntoFqdn<'x> for String {
346    fn into_fqdn(self) -> Cow<'x, str> {
347        if self.ends_with('.') {
348            Cow::Owned(self)
349        } else {
350            Cow::Owned(format!("{}.", self))
351        }
352    }
353
354    fn into_name(self) -> Cow<'x, str> {
355        if let Some(name) = self.strip_suffix('.') {
356            Cow::Owned(name.to_string())
357        } else {
358            Cow::Owned(self)
359        }
360    }
361}
362
363impl FromStr for TsigAlgorithm {
364    type Err = ();
365
366    fn from_str(s: &str) -> std::prelude::v1::Result<Self, Self::Err> {
367        match s {
368            "hmac-md5" => Ok(TsigAlgorithm::HmacMd5),
369            "gss" => Ok(TsigAlgorithm::Gss),
370            "hmac-sha1" => Ok(TsigAlgorithm::HmacSha1),
371            "hmac-sha224" => Ok(TsigAlgorithm::HmacSha224),
372            "hmac-sha256" => Ok(TsigAlgorithm::HmacSha256),
373            "hmac-sha256-128" => Ok(TsigAlgorithm::HmacSha256_128),
374            "hmac-sha384" => Ok(TsigAlgorithm::HmacSha384),
375            "hmac-sha384-192" => Ok(TsigAlgorithm::HmacSha384_192),
376            "hmac-sha512" => Ok(TsigAlgorithm::HmacSha512),
377            "hmac-sha512-256" => Ok(TsigAlgorithm::HmacSha512_256),
378            _ => Err(()),
379        }
380    }
381}
382
383impl std::error::Error for Error {}
384
385impl Display for Error {
386    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
387        match self {
388            Error::Protocol(e) => write!(f, "Protocol error: {}", e),
389            Error::Parse(e) => write!(f, "Parse error: {}", e),
390            Error::Client(e) => write!(f, "Client error: {}", e),
391            Error::Response(e) => write!(f, "Response error: {}", e),
392            Error::Api(e) => write!(f, "API error: {}", e),
393            Error::Serialize(e) => write!(f, "Serialize error: {}", e),
394            Error::Unauthorized => write!(f, "Unauthorized"),
395            Error::NotFound => write!(f, "Not found"),
396            Error::BadRequest => write!(f, "Bad request"),
397        }
398    }
399}
400
401impl Display for DnsRecordType {
402    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
403        write!(f, "{:?}", self)
404    }
405}