Skip to main content

dns_update/
utils.rs

1/*
2 * Copyright Stalwart Labs LLC See the COPYING
3 * file at the top-level directory of this distribution.
4 *
5 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 * https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 * <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
8 * option. This file may not be copied, modified, or distributed
9 * except according to those terms.
10 */
11
12use crate::{
13    CAARecord, DnsRecord, DnsRecordType, Error, IntoFqdn, KeyValue, MXRecord, SRVRecord,
14    TLSARecord, TlsaCertUsage, TlsaMatching, TlsaSelector, TsigAlgorithm,
15};
16use std::{
17    borrow::Cow,
18    fmt::{self, Display, Formatter},
19    str::FromStr,
20};
21
/// Appends `text` to `output` as one or more double-quoted chunks.
///
/// A new chunk is started (closing quote, `separator`, opening quote) whenever
/// the next character would push the current chunk past 255 bytes of UTF-8
/// text. Backslashes and double quotes are emitted with a preceding backslash;
/// that escape byte is not counted towards the 255-byte budget. Empty `text`
/// produces `""`.
pub(crate) fn write_txt_character_strings(output: &mut String, text: &str, separator: &str) {
    const MAX_CHUNK_BYTES: usize = 255;

    // UTF-8 bytes of `text` already placed in the chunk that is currently open.
    let mut chunk_len = 0usize;

    output.push('"');
    for ch in text.chars() {
        let width = ch.len_utf8();
        let chunk_is_full = chunk_len != 0 && chunk_len + width > MAX_CHUNK_BYTES;
        if chunk_is_full {
            // Close the current chunk and open the next one.
            output.push('"');
            output.push_str(separator);
            output.push('"');
            chunk_len = 0;
        }
        if matches!(ch, '\\' | '"') {
            output.push('\\');
        }
        output.push(ch);
        chunk_len += width;
    }
    output.push('"');
}
43
/// Returns `name` relative to `origin`.
///
/// Trailing dots are ignored on both arguments. When `name` equals `origin`,
/// `return_if_equal` is returned, or `@` if it is `None`. When `name` ends
/// with `.origin`, that suffix is removed. Otherwise `name` (without trailing
/// dots) is returned unchanged.
///
/// Comparisons ignore ASCII case because DNS names are case-insensitive; the
/// returned text keeps the casing of `name`.
pub(crate) fn strip_origin_from_name(
    name: &str,
    origin: &str,
    return_if_equal: Option<&str>,
) -> String {
    let name = name.trim_end_matches('.');
    let origin = origin.trim_end_matches('.');

    if name.eq_ignore_ascii_case(origin) {
        return return_if_equal.unwrap_or("@").to_string();
    }

    // `name` is below `origin` only if it ends with ".<origin>". `dot` is the
    // byte index where that separator would have to be; because the byte is
    // checked to be an ASCII '.', slicing on either side of it is always on a
    // character boundary.
    let relative = name
        .len()
        .checked_sub(origin.len() + 1)
        .filter(|&dot| name.as_bytes()[dot] == b'.')
        .filter(|&dot| name[dot + 1..].eq_ignore_ascii_case(origin))
        .map(|dot| &name[..dot]);

    relative.unwrap_or(name).to_string()
}
64
65impl fmt::Display for TLSARecord {
66    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
67        write!(
68            f,
69            "{} {} {} ",
70            u8::from(self.cert_usage),
71            u8::from(self.selector),
72            u8::from(self.matching),
73        )?;
74
75        for ch in &self.cert_data {
76            write!(f, "{:02x}", ch)?;
77        }
78
79        Ok(())
80    }
81}
82
83impl fmt::Display for KeyValue {
84    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
85        f.write_str(&self.key)?;
86        if !self.value.is_empty() {
87            write!(f, "={}", self.value)?;
88        }
89
90        Ok(())
91    }
92}
93
94impl CAARecord {
95    pub fn decompose(self) -> (u8, String, String) {
96        match self {
97            CAARecord::Issue {
98                issuer_critical,
99                name,
100                options,
101            } => {
102                let flags = if issuer_critical { 128 } else { 0 };
103                let mut value = name.unwrap_or_default();
104                for opt in &options {
105                    use std::fmt::Write;
106                    write!(value, "; {}", opt).unwrap();
107                }
108                (flags, "issue".to_string(), value)
109            }
110            CAARecord::IssueWild {
111                issuer_critical,
112                name,
113                options,
114            } => {
115                let flags = if issuer_critical { 128 } else { 0 };
116                let mut value = name.unwrap_or_default();
117                for opt in &options {
118                    use std::fmt::Write;
119                    write!(value, "; {}", opt).unwrap();
120                }
121                (flags, "issuewild".to_string(), value)
122            }
123            CAARecord::Iodef {
124                issuer_critical,
125                url,
126            } => {
127                let flags = if issuer_critical { 128 } else { 0 };
128                (flags, "iodef".to_string(), url)
129            }
130        }
131    }
132}
133
134impl fmt::Display for CAARecord {
135    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
136        match self {
137            CAARecord::Issue {
138                issuer_critical,
139                name,
140                options,
141            } => {
142                if *issuer_critical {
143                    f.write_str("128 ")?;
144                } else {
145                    f.write_str("0 ")?;
146                }
147                f.write_str("issue ")?;
148                f.write_str("\"")?;
149                if let Some(name) = name {
150                    f.write_str(name)?;
151                }
152                for opt in options {
153                    write!(f, ";{}", opt)?;
154                }
155                f.write_str("\"")?;
156            }
157            CAARecord::IssueWild {
158                issuer_critical,
159                name,
160                options,
161            } => {
162                if *issuer_critical {
163                    f.write_str("128 ")?;
164                } else {
165                    f.write_str("0 ")?;
166                }
167                f.write_str("issuewild ")?;
168                f.write_str("\"")?;
169                if let Some(name) = name {
170                    f.write_str(name)?;
171                }
172                for opt in options {
173                    write!(f, ";{}", opt)?;
174                }
175                f.write_str("\"")?;
176            }
177            CAARecord::Iodef {
178                issuer_critical,
179                url,
180            } => {
181                if *issuer_critical {
182                    f.write_str("128 ")?;
183                } else {
184                    f.write_str("0 ")?;
185                }
186                f.write_str("iodef ")?;
187                f.write_str("\"")?;
188                f.write_str(url)?;
189                f.write_str("\"")?;
190            }
191        }
192        Ok(())
193    }
194}
195
196impl Display for MXRecord {
197    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
198        write!(f, "{} {}", self.priority, self.exchange)
199    }
200}
201
202impl Display for SRVRecord {
203    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
204        write!(
205            f,
206            "{} {} {} {}",
207            self.priority, self.weight, self.port, self.target
208        )
209    }
210}
211
212impl DnsRecord {
213    pub fn as_type(&self) -> DnsRecordType {
214        match self {
215            DnsRecord::A { .. } => DnsRecordType::A,
216            DnsRecord::AAAA { .. } => DnsRecordType::AAAA,
217            DnsRecord::CNAME { .. } => DnsRecordType::CNAME,
218            DnsRecord::NS { .. } => DnsRecordType::NS,
219            DnsRecord::MX { .. } => DnsRecordType::MX,
220            DnsRecord::TXT { .. } => DnsRecordType::TXT,
221            DnsRecord::SRV { .. } => DnsRecordType::SRV,
222            DnsRecord::TLSA { .. } => DnsRecordType::TLSA,
223            DnsRecord::CAA { .. } => DnsRecordType::CAA,
224        }
225    }
226}
227
228impl DnsRecordType {
229    pub fn as_str(&self) -> &'static str {
230        match self {
231            DnsRecordType::A => "A",
232            DnsRecordType::AAAA => "AAAA",
233            DnsRecordType::CNAME => "CNAME",
234            DnsRecordType::NS => "NS",
235            DnsRecordType::MX => "MX",
236            DnsRecordType::TXT => "TXT",
237            DnsRecordType::SRV => "SRV",
238            DnsRecordType::TLSA => "TLSA",
239            DnsRecordType::CAA => "CAA",
240        }
241    }
242}
243
244impl From<TlsaCertUsage> for u8 {
245    fn from(usage: TlsaCertUsage) -> Self {
246        match usage {
247            TlsaCertUsage::PkixTa => 0,
248            TlsaCertUsage::PkixEe => 1,
249            TlsaCertUsage::DaneTa => 2,
250            TlsaCertUsage::DaneEe => 3,
251            TlsaCertUsage::Private => 255,
252        }
253    }
254}
255
256impl From<TlsaSelector> for u8 {
257    fn from(selector: TlsaSelector) -> Self {
258        match selector {
259            TlsaSelector::Full => 0,
260            TlsaSelector::Spki => 1,
261            TlsaSelector::Private => 255,
262        }
263    }
264}
265
266impl From<TlsaMatching> for u8 {
267    fn from(matching: TlsaMatching) -> Self {
268        match matching {
269            TlsaMatching::Raw => 0,
270            TlsaMatching::Sha256 => 1,
271            TlsaMatching::Sha512 => 2,
272            TlsaMatching::Private => 255,
273        }
274    }
275}
276
277impl<'x> IntoFqdn<'x> for &'x str {
278    fn into_fqdn(self) -> Cow<'x, str> {
279        if self.ends_with('.') {
280            Cow::Borrowed(self)
281        } else {
282            Cow::Owned(format!("{}.", self))
283        }
284    }
285
286    fn into_name(self) -> Cow<'x, str> {
287        if let Some(name) = self.strip_suffix('.') {
288            Cow::Borrowed(name)
289        } else {
290            Cow::Borrowed(self)
291        }
292    }
293}
294
295impl<'x> IntoFqdn<'x> for &'x String {
296    fn into_fqdn(self) -> Cow<'x, str> {
297        self.as_str().into_fqdn()
298    }
299
300    fn into_name(self) -> Cow<'x, str> {
301        self.as_str().into_name()
302    }
303}
304
305impl<'x> IntoFqdn<'x> for String {
306    fn into_fqdn(self) -> Cow<'x, str> {
307        if self.ends_with('.') {
308            Cow::Owned(self)
309        } else {
310            Cow::Owned(format!("{}.", self))
311        }
312    }
313
314    fn into_name(self) -> Cow<'x, str> {
315        if let Some(name) = self.strip_suffix('.') {
316            Cow::Owned(name.to_string())
317        } else {
318            Cow::Owned(self)
319        }
320    }
321}
322
323impl FromStr for TsigAlgorithm {
324    type Err = ();
325
326    fn from_str(s: &str) -> std::prelude::v1::Result<Self, Self::Err> {
327        match s {
328            "hmac-md5" => Ok(TsigAlgorithm::HmacMd5),
329            "gss" => Ok(TsigAlgorithm::Gss),
330            "hmac-sha1" => Ok(TsigAlgorithm::HmacSha1),
331            "hmac-sha224" => Ok(TsigAlgorithm::HmacSha224),
332            "hmac-sha256" => Ok(TsigAlgorithm::HmacSha256),
333            "hmac-sha256-128" => Ok(TsigAlgorithm::HmacSha256_128),
334            "hmac-sha384" => Ok(TsigAlgorithm::HmacSha384),
335            "hmac-sha384-192" => Ok(TsigAlgorithm::HmacSha384_192),
336            "hmac-sha512" => Ok(TsigAlgorithm::HmacSha512),
337            "hmac-sha512-256" => Ok(TsigAlgorithm::HmacSha512_256),
338            _ => Err(()),
339        }
340    }
341}
342
// Empty impl: `Error` relies on the default methods of `std::error::Error`;
// its message text comes from the `Display` impl in this file.
impl std::error::Error for Error {}
344
345impl Display for Error {
346    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
347        match self {
348            Error::Protocol(e) => write!(f, "Protocol error: {}", e),
349            Error::Parse(e) => write!(f, "Parse error: {}", e),
350            Error::Client(e) => write!(f, "Client error: {}", e),
351            Error::Response(e) => write!(f, "Response error: {}", e),
352            Error::Api(e) => write!(f, "API error: {}", e),
353            Error::Serialize(e) => write!(f, "Serialize error: {}", e),
354            Error::Unauthorized => write!(f, "Unauthorized"),
355            Error::NotFound => write!(f, "Not found"),
356            Error::BadRequest => write!(f, "Bad request"),
357        }
358    }
359}
360
361impl Display for DnsRecordType {
362    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
363        write!(f, "{:?}", self)
364    }
365}