Skip to main content

dns_update/
utils.rs

/*
 * Copyright Stalwart Labs LLC See the COPYING
 * file at the top-level directory of this distribution.
 *
 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
 * https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
 * <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
 * option. This file may not be copied, modified, or distributed
 * except according to those terms.
 */
11
12use crate::{
13    CAARecord, DnsRecord, DnsRecordType, Error, IntoFqdn, KeyValue, MXRecord, SRVRecord,
14    TLSARecord, TlsaCertUsage, TlsaMatching, TlsaSelector, TsigAlgorithm,
15};
16use std::{
17    borrow::Cow,
18    fmt::{self, Display, Formatter},
19    str::FromStr,
20};
21
/// Strip `origin` from the end of `name`, returning the part of `name` that is
/// relative to `origin`.
///
/// Trailing dots on both arguments are ignored. Comparison is ASCII
/// case-insensitive, since DNS names are case-insensitive (RFC 4343).
///
/// - If `name` equals `origin`, returns `return_if_equal`, or `"@"` when that
///   is `None`.
/// - If `name` lies under `origin` (e.g. `www.example.com` under `example.com`),
///   returns the leading labels (`www`) with their original casing.
/// - Otherwise returns `name` unchanged, minus any trailing dots.
pub(crate) fn strip_origin_from_name(
    name: &str,
    origin: &str,
    return_if_equal: Option<&str>,
) -> String {
    let name = name.trim_end_matches('.');
    let origin = origin.trim_end_matches('.');

    if name.eq_ignore_ascii_case(origin) {
        return return_if_equal.unwrap_or("@").to_string();
    }

    // `name` must be at least `<label>.<origin>`. Requiring a '.' right before
    // the suffix makes sure we only strip on a label boundary, so
    // `notexample.com` is not treated as living under `example.com`.
    let bytes = name.as_bytes();
    if bytes.len() > origin.len() {
        let split = bytes.len() - origin.len();
        if bytes[split - 1] == b'.' && bytes[split..].eq_ignore_ascii_case(origin.as_bytes()) {
            // The byte at `split - 1` is an ASCII '.', so it is a char boundary
            // and this slice cannot split a multi-byte UTF-8 sequence.
            return name[..split - 1].to_string();
        }
    }

    name.to_string()
}
42
43impl fmt::Display for TLSARecord {
44    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
45        write!(
46            f,
47            "{} {} {} ",
48            u8::from(self.cert_usage),
49            u8::from(self.selector),
50            u8::from(self.matching),
51        )?;
52
53        for ch in &self.cert_data {
54            write!(f, "{:02x}", ch)?;
55        }
56
57        Ok(())
58    }
59}
60
61impl fmt::Display for KeyValue {
62    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
63        f.write_str(&self.key)?;
64        if !self.value.is_empty() {
65            write!(f, "={}", self.value)?;
66        }
67
68        Ok(())
69    }
70}
71
72impl CAARecord {
73    pub fn decompose(self) -> (u8, String, String) {
74        match self {
75            CAARecord::Issue {
76                issuer_critical,
77                name,
78                options,
79            } => {
80                let flags = if issuer_critical { 128 } else { 0 };
81                let mut value = name.unwrap_or_default();
82                for opt in &options {
83                    use std::fmt::Write;
84                    write!(value, "; {}", opt).unwrap();
85                }
86                (flags, "issue".to_string(), value)
87            }
88            CAARecord::IssueWild {
89                issuer_critical,
90                name,
91                options,
92            } => {
93                let flags = if issuer_critical { 128 } else { 0 };
94                let mut value = name.unwrap_or_default();
95                for opt in &options {
96                    use std::fmt::Write;
97                    write!(value, "; {}", opt).unwrap();
98                }
99                (flags, "issuewild".to_string(), value)
100            }
101            CAARecord::Iodef {
102                issuer_critical,
103                url,
104            } => {
105                let flags = if issuer_critical { 128 } else { 0 };
106                (flags, "iodef".to_string(), url)
107            }
108        }
109    }
110}
111
112impl fmt::Display for CAARecord {
113    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
114        match self {
115            CAARecord::Issue {
116                issuer_critical,
117                name,
118                options,
119            } => {
120                if *issuer_critical {
121                    f.write_str("128 ")?;
122                } else {
123                    f.write_str("0 ")?;
124                }
125                f.write_str("issue ")?;
126                f.write_str("\"")?;
127                if let Some(name) = name {
128                    f.write_str(name)?;
129                }
130                for opt in options {
131                    write!(f, ";{}", opt)?;
132                }
133                f.write_str("\"")?;
134            }
135            CAARecord::IssueWild {
136                issuer_critical,
137                name,
138                options,
139            } => {
140                if *issuer_critical {
141                    f.write_str("128 ")?;
142                } else {
143                    f.write_str("0 ")?;
144                }
145                f.write_str("issuewild ")?;
146                f.write_str("\"")?;
147                if let Some(name) = name {
148                    f.write_str(name)?;
149                }
150                for opt in options {
151                    write!(f, ";{}", opt)?;
152                }
153                f.write_str("\"")?;
154            }
155            CAARecord::Iodef {
156                issuer_critical,
157                url,
158            } => {
159                if *issuer_critical {
160                    f.write_str("128 ")?;
161                } else {
162                    f.write_str("0 ")?;
163                }
164                f.write_str("iodef ")?;
165                f.write_str("\"")?;
166                f.write_str(url)?;
167                f.write_str("\"")?;
168            }
169        }
170        Ok(())
171    }
172}
173
174impl Display for MXRecord {
175    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
176        write!(f, "{} {}", self.priority, self.exchange)
177    }
178}
179
180impl Display for SRVRecord {
181    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
182        write!(
183            f,
184            "{} {} {} {}",
185            self.priority, self.weight, self.port, self.target
186        )
187    }
188}
189
190impl DnsRecord {
191    pub fn as_type(&self) -> DnsRecordType {
192        match self {
193            DnsRecord::A { .. } => DnsRecordType::A,
194            DnsRecord::AAAA { .. } => DnsRecordType::AAAA,
195            DnsRecord::CNAME { .. } => DnsRecordType::CNAME,
196            DnsRecord::NS { .. } => DnsRecordType::NS,
197            DnsRecord::MX { .. } => DnsRecordType::MX,
198            DnsRecord::TXT { .. } => DnsRecordType::TXT,
199            DnsRecord::SRV { .. } => DnsRecordType::SRV,
200            DnsRecord::TLSA { .. } => DnsRecordType::TLSA,
201            DnsRecord::CAA { .. } => DnsRecordType::CAA,
202        }
203    }
204}
205
206impl DnsRecordType {
207    pub fn as_str(&self) -> &'static str {
208        match self {
209            DnsRecordType::A => "A",
210            DnsRecordType::AAAA => "AAAA",
211            DnsRecordType::CNAME => "CNAME",
212            DnsRecordType::NS => "NS",
213            DnsRecordType::MX => "MX",
214            DnsRecordType::TXT => "TXT",
215            DnsRecordType::SRV => "SRV",
216            DnsRecordType::TLSA => "TLSA",
217            DnsRecordType::CAA => "CAA",
218        }
219    }
220}
221
222impl From<TlsaCertUsage> for u8 {
223    fn from(usage: TlsaCertUsage) -> Self {
224        match usage {
225            TlsaCertUsage::PkixTa => 0,
226            TlsaCertUsage::PkixEe => 1,
227            TlsaCertUsage::DaneTa => 2,
228            TlsaCertUsage::DaneEe => 3,
229            TlsaCertUsage::Private => 255,
230        }
231    }
232}
233
234impl From<TlsaSelector> for u8 {
235    fn from(selector: TlsaSelector) -> Self {
236        match selector {
237            TlsaSelector::Full => 0,
238            TlsaSelector::Spki => 1,
239            TlsaSelector::Private => 255,
240        }
241    }
242}
243
244impl From<TlsaMatching> for u8 {
245    fn from(matching: TlsaMatching) -> Self {
246        match matching {
247            TlsaMatching::Raw => 0,
248            TlsaMatching::Sha256 => 1,
249            TlsaMatching::Sha512 => 2,
250            TlsaMatching::Private => 255,
251        }
252    }
253}
254
255impl<'x> IntoFqdn<'x> for &'x str {
256    fn into_fqdn(self) -> Cow<'x, str> {
257        if self.ends_with('.') {
258            Cow::Borrowed(self)
259        } else {
260            Cow::Owned(format!("{}.", self))
261        }
262    }
263
264    fn into_name(self) -> Cow<'x, str> {
265        if let Some(name) = self.strip_suffix('.') {
266            Cow::Borrowed(name)
267        } else {
268            Cow::Borrowed(self)
269        }
270    }
271}
272
273impl<'x> IntoFqdn<'x> for &'x String {
274    fn into_fqdn(self) -> Cow<'x, str> {
275        self.as_str().into_fqdn()
276    }
277
278    fn into_name(self) -> Cow<'x, str> {
279        self.as_str().into_name()
280    }
281}
282
283impl<'x> IntoFqdn<'x> for String {
284    fn into_fqdn(self) -> Cow<'x, str> {
285        if self.ends_with('.') {
286            Cow::Owned(self)
287        } else {
288            Cow::Owned(format!("{}.", self))
289        }
290    }
291
292    fn into_name(self) -> Cow<'x, str> {
293        if let Some(name) = self.strip_suffix('.') {
294            Cow::Owned(name.to_string())
295        } else {
296            Cow::Owned(self)
297        }
298    }
299}
300
301impl FromStr for TsigAlgorithm {
302    type Err = ();
303
304    fn from_str(s: &str) -> std::prelude::v1::Result<Self, Self::Err> {
305        match s {
306            "hmac-md5" => Ok(TsigAlgorithm::HmacMd5),
307            "gss" => Ok(TsigAlgorithm::Gss),
308            "hmac-sha1" => Ok(TsigAlgorithm::HmacSha1),
309            "hmac-sha224" => Ok(TsigAlgorithm::HmacSha224),
310            "hmac-sha256" => Ok(TsigAlgorithm::HmacSha256),
311            "hmac-sha256-128" => Ok(TsigAlgorithm::HmacSha256_128),
312            "hmac-sha384" => Ok(TsigAlgorithm::HmacSha384),
313            "hmac-sha384-192" => Ok(TsigAlgorithm::HmacSha384_192),
314            "hmac-sha512" => Ok(TsigAlgorithm::HmacSha512),
315            "hmac-sha512-256" => Ok(TsigAlgorithm::HmacSha512_256),
316            _ => Err(()),
317        }
318    }
319}
320
// Marker impl so `Error` can be used as a `std::error::Error` (e.g. boxed as
// `dyn Error`). All trait methods use their defaults, so `source()` returns
// `None`. The message text comes from this type's `Display` impl.
impl std::error::Error for Error {}
322
323impl Display for Error {
324    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
325        match self {
326            Error::Protocol(e) => write!(f, "Protocol error: {}", e),
327            Error::Parse(e) => write!(f, "Parse error: {}", e),
328            Error::Client(e) => write!(f, "Client error: {}", e),
329            Error::Response(e) => write!(f, "Response error: {}", e),
330            Error::Api(e) => write!(f, "API error: {}", e),
331            Error::Serialize(e) => write!(f, "Serialize error: {}", e),
332            Error::Unauthorized => write!(f, "Unauthorized"),
333            Error::NotFound => write!(f, "Not found"),
334            Error::BadRequest => write!(f, "Bad request"),
335        }
336    }
337}
338
339impl Display for DnsRecordType {
340    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
341        write!(f, "{:?}", self)
342    }
343}