Skip to main content

dns_update/
utils.rs

1/*
2 * Copyright Stalwart Labs LLC See the COPYING
3 * file at the top-level directory of this distribution.
4 *
5 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 * https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 * <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
8 * option. This file may not be copied, modified, or distributed
9 * except according to those terms.
10 */
11
12use crate::{
13    CAARecord, DnsRecord, DnsRecordType, Error, IntoFqdn, KeyValue, MXRecord, SRVRecord,
14    TLSARecord, TlsaCertUsage, TlsaMatching, TlsaSelector, TsigAlgorithm,
15};
16use std::{
17    borrow::Cow,
18    fmt::{self, Display, Formatter},
19    str::FromStr,
20};
21
/// Largest number of bytes placed in a single TXT character-string.
///
/// A DNS `<character-string>` stores its length in one octet (RFC 1035 §3.3),
/// so 255 bytes is the hard ceiling for each piece.
const MAX_CHUNK_BYTES: usize = u8::MAX as usize;

/// Appends `text` to `output` as one or more double-quoted strings.
///
/// Each quoted string holds at most `MAX_CHUNK_BYTES` bytes of the *unescaped*
/// input and never splits a UTF-8 character; consecutive strings are joined
/// with `separator`. Backslashes and double quotes are escaped with a leading
/// backslash. An empty `text` produces `""`.
pub(crate) fn txt_chunks_to_text(output: &mut String, text: &str, separator: &str) {
    output.push('"');
    let mut used = 0usize;

    for c in text.chars() {
        let width = c.len_utf8();

        // Close the current string and open a new one when this character
        // would overflow it; a string is never left empty.
        if used != 0 && used + width > MAX_CHUNK_BYTES {
            output.push('"');
            output.push_str(separator);
            output.push('"');
            used = 0;
        }

        if c == '\\' || c == '"' {
            output.push('\\');
        }
        output.push(c);

        // The budget counts raw bytes, not the escaped representation.
        used += width;
    }

    output.push('"');
}

/// Splits `content` into pieces of at most `MAX_CHUNK_BYTES` bytes each,
/// breaking only on UTF-8 character boundaries.
///
/// Content that already fits (including the empty string) is returned as a
/// single element, moved rather than copied.
pub(crate) fn txt_chunks(content: String) -> Vec<String> {
    if content.len() > MAX_CHUNK_BYTES {
        let mut pieces = Vec::new();
        let mut start = 0usize;

        for (idx, c) in content.char_indices() {
            let end = idx + c.len_utf8();
            // Cut before `c` when the piece [start, end) would be too long.
            if idx > start && end - start > MAX_CHUNK_BYTES {
                pieces.push(content[start..idx].to_owned());
                start = idx;
            }
        }

        if start < content.len() {
            pieces.push(content[start..].to_owned());
        }

        pieces
    } else {
        vec![content]
    }
}
67
/// Strips the zone `origin` suffix from `name`, yielding a name relative to
/// the zone.
///
/// Trailing dots on both arguments are ignored. If `name` equals `origin`,
/// the result is `return_if_equal`, or `"@"` when that is `None`. If `name`
/// does not lie inside `origin`, it is returned unchanged (minus trailing dots).
///
/// Comparison ignores ASCII case, because DNS names compare case-insensitively
/// (RFC 4343); the returned relative part keeps the original casing of `name`.
pub(crate) fn strip_origin_from_name(
    name: &str,
    origin: &str,
    return_if_equal: Option<&str>,
) -> String {
    let name = name.trim_end_matches('.');
    let origin = origin.trim_end_matches('.');

    if name.eq_ignore_ascii_case(origin) {
        return return_if_equal.unwrap_or("@").to_string();
    }

    // `name` must end with ".<origin>". The byte at `dot` has to be an ASCII
    // '.', which also guarantees that `dot` and `dot + 1` are UTF-8 char
    // boundaries, so both slices below cannot panic.
    if let Some(dot) = name.len().checked_sub(origin.len() + 1) {
        if name.as_bytes()[dot] == b'.' && name[dot + 1..].eq_ignore_ascii_case(origin) {
            return name[..dot].to_string();
        }
    }

    name.to_string()
}
88
89impl fmt::Display for TLSARecord {
90    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
91        write!(
92            f,
93            "{} {} {} ",
94            u8::from(self.cert_usage),
95            u8::from(self.selector),
96            u8::from(self.matching),
97        )?;
98
99        for ch in &self.cert_data {
100            write!(f, "{:02x}", ch)?;
101        }
102
103        Ok(())
104    }
105}
106
107impl fmt::Display for KeyValue {
108    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
109        f.write_str(&self.key)?;
110        if !self.value.is_empty() {
111            write!(f, "={}", self.value)?;
112        }
113
114        Ok(())
115    }
116}
117
118impl CAARecord {
119    pub fn decompose(self) -> (u8, String, String) {
120        match self {
121            CAARecord::Issue {
122                issuer_critical,
123                name,
124                options,
125            } => {
126                let flags = if issuer_critical { 128 } else { 0 };
127                let mut value = name.unwrap_or_default();
128                for opt in &options {
129                    use std::fmt::Write;
130                    write!(value, "; {}", opt).unwrap();
131                }
132                (flags, "issue".to_string(), value)
133            }
134            CAARecord::IssueWild {
135                issuer_critical,
136                name,
137                options,
138            } => {
139                let flags = if issuer_critical { 128 } else { 0 };
140                let mut value = name.unwrap_or_default();
141                for opt in &options {
142                    use std::fmt::Write;
143                    write!(value, "; {}", opt).unwrap();
144                }
145                (flags, "issuewild".to_string(), value)
146            }
147            CAARecord::Iodef {
148                issuer_critical,
149                url,
150            } => {
151                let flags = if issuer_critical { 128 } else { 0 };
152                (flags, "iodef".to_string(), url)
153            }
154        }
155    }
156}
157
158impl fmt::Display for CAARecord {
159    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
160        match self {
161            CAARecord::Issue {
162                issuer_critical,
163                name,
164                options,
165            } => {
166                if *issuer_critical {
167                    f.write_str("128 ")?;
168                } else {
169                    f.write_str("0 ")?;
170                }
171                f.write_str("issue ")?;
172                f.write_str("\"")?;
173                if let Some(name) = name {
174                    f.write_str(name)?;
175                }
176                for opt in options {
177                    write!(f, ";{}", opt)?;
178                }
179                f.write_str("\"")?;
180            }
181            CAARecord::IssueWild {
182                issuer_critical,
183                name,
184                options,
185            } => {
186                if *issuer_critical {
187                    f.write_str("128 ")?;
188                } else {
189                    f.write_str("0 ")?;
190                }
191                f.write_str("issuewild ")?;
192                f.write_str("\"")?;
193                if let Some(name) = name {
194                    f.write_str(name)?;
195                }
196                for opt in options {
197                    write!(f, ";{}", opt)?;
198                }
199                f.write_str("\"")?;
200            }
201            CAARecord::Iodef {
202                issuer_critical,
203                url,
204            } => {
205                if *issuer_critical {
206                    f.write_str("128 ")?;
207                } else {
208                    f.write_str("0 ")?;
209                }
210                f.write_str("iodef ")?;
211                f.write_str("\"")?;
212                f.write_str(url)?;
213                f.write_str("\"")?;
214            }
215        }
216        Ok(())
217    }
218}
219
220impl Display for MXRecord {
221    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
222        write!(f, "{} {}", self.priority, self.exchange)
223    }
224}
225
226impl Display for SRVRecord {
227    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
228        write!(
229            f,
230            "{} {} {} {}",
231            self.priority, self.weight, self.port, self.target
232        )
233    }
234}
235
236impl DnsRecord {
237    pub fn as_type(&self) -> DnsRecordType {
238        match self {
239            DnsRecord::A { .. } => DnsRecordType::A,
240            DnsRecord::AAAA { .. } => DnsRecordType::AAAA,
241            DnsRecord::CNAME { .. } => DnsRecordType::CNAME,
242            DnsRecord::NS { .. } => DnsRecordType::NS,
243            DnsRecord::MX { .. } => DnsRecordType::MX,
244            DnsRecord::TXT { .. } => DnsRecordType::TXT,
245            DnsRecord::SRV { .. } => DnsRecordType::SRV,
246            DnsRecord::TLSA { .. } => DnsRecordType::TLSA,
247            DnsRecord::CAA { .. } => DnsRecordType::CAA,
248        }
249    }
250}
251
252impl DnsRecordType {
253    pub fn as_str(&self) -> &'static str {
254        match self {
255            DnsRecordType::A => "A",
256            DnsRecordType::AAAA => "AAAA",
257            DnsRecordType::CNAME => "CNAME",
258            DnsRecordType::NS => "NS",
259            DnsRecordType::MX => "MX",
260            DnsRecordType::TXT => "TXT",
261            DnsRecordType::SRV => "SRV",
262            DnsRecordType::TLSA => "TLSA",
263            DnsRecordType::CAA => "CAA",
264        }
265    }
266}
267
268impl From<TlsaCertUsage> for u8 {
269    fn from(usage: TlsaCertUsage) -> Self {
270        match usage {
271            TlsaCertUsage::PkixTa => 0,
272            TlsaCertUsage::PkixEe => 1,
273            TlsaCertUsage::DaneTa => 2,
274            TlsaCertUsage::DaneEe => 3,
275            TlsaCertUsage::Private => 255,
276        }
277    }
278}
279
280impl From<TlsaSelector> for u8 {
281    fn from(selector: TlsaSelector) -> Self {
282        match selector {
283            TlsaSelector::Full => 0,
284            TlsaSelector::Spki => 1,
285            TlsaSelector::Private => 255,
286        }
287    }
288}
289
290impl From<TlsaMatching> for u8 {
291    fn from(matching: TlsaMatching) -> Self {
292        match matching {
293            TlsaMatching::Raw => 0,
294            TlsaMatching::Sha256 => 1,
295            TlsaMatching::Sha512 => 2,
296            TlsaMatching::Private => 255,
297        }
298    }
299}
300
301impl<'x> IntoFqdn<'x> for &'x str {
302    fn into_fqdn(self) -> Cow<'x, str> {
303        if self.ends_with('.') {
304            Cow::Borrowed(self)
305        } else {
306            Cow::Owned(format!("{}.", self))
307        }
308    }
309
310    fn into_name(self) -> Cow<'x, str> {
311        if let Some(name) = self.strip_suffix('.') {
312            Cow::Borrowed(name)
313        } else {
314            Cow::Borrowed(self)
315        }
316    }
317}
318
319impl<'x> IntoFqdn<'x> for &'x String {
320    fn into_fqdn(self) -> Cow<'x, str> {
321        self.as_str().into_fqdn()
322    }
323
324    fn into_name(self) -> Cow<'x, str> {
325        self.as_str().into_name()
326    }
327}
328
329impl<'x> IntoFqdn<'x> for String {
330    fn into_fqdn(self) -> Cow<'x, str> {
331        if self.ends_with('.') {
332            Cow::Owned(self)
333        } else {
334            Cow::Owned(format!("{}.", self))
335        }
336    }
337
338    fn into_name(self) -> Cow<'x, str> {
339        if let Some(name) = self.strip_suffix('.') {
340            Cow::Owned(name.to_string())
341        } else {
342            Cow::Owned(self)
343        }
344    }
345}
346
347impl FromStr for TsigAlgorithm {
348    type Err = ();
349
350    fn from_str(s: &str) -> std::prelude::v1::Result<Self, Self::Err> {
351        match s {
352            "hmac-md5" => Ok(TsigAlgorithm::HmacMd5),
353            "gss" => Ok(TsigAlgorithm::Gss),
354            "hmac-sha1" => Ok(TsigAlgorithm::HmacSha1),
355            "hmac-sha224" => Ok(TsigAlgorithm::HmacSha224),
356            "hmac-sha256" => Ok(TsigAlgorithm::HmacSha256),
357            "hmac-sha256-128" => Ok(TsigAlgorithm::HmacSha256_128),
358            "hmac-sha384" => Ok(TsigAlgorithm::HmacSha384),
359            "hmac-sha384-192" => Ok(TsigAlgorithm::HmacSha384_192),
360            "hmac-sha512" => Ok(TsigAlgorithm::HmacSha512),
361            "hmac-sha512-256" => Ok(TsigAlgorithm::HmacSha512_256),
362            _ => Err(()),
363        }
364    }
365}
366
367impl std::error::Error for Error {}
368
369impl Display for Error {
370    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
371        match self {
372            Error::Protocol(e) => write!(f, "Protocol error: {}", e),
373            Error::Parse(e) => write!(f, "Parse error: {}", e),
374            Error::Client(e) => write!(f, "Client error: {}", e),
375            Error::Response(e) => write!(f, "Response error: {}", e),
376            Error::Api(e) => write!(f, "API error: {}", e),
377            Error::Serialize(e) => write!(f, "Serialize error: {}", e),
378            Error::Unauthorized => write!(f, "Unauthorized"),
379            Error::NotFound => write!(f, "Not found"),
380            Error::BadRequest => write!(f, "Bad request"),
381        }
382    }
383}
384
385impl Display for DnsRecordType {
386    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
387        write!(f, "{:?}", self)
388    }
389}