Skip to main content

dns_update/
utils.rs

/*
 * Copyright Stalwart Labs LLC. See the COPYING
 * file at the top-level directory of this distribution.
 *
 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
 * https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
 * <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
 * option. This file may not be copied, modified, or distributed
 * except according to those terms.
 */
11
12use crate::{
13    CAARecord, DnsRecord, DnsRecordType, Error, IntoFqdn, KeyValue, MXRecord, SRVRecord,
14    TLSARecord, TlsaCertUsage, TlsaMatching, TlsaSelector, TsigAlgorithm,
15};
16use std::{
17    borrow::Cow,
18    fmt::{self, Display, Formatter},
19    str::FromStr,
20};
21
/// Strips the zone `origin` from the end of `name`, yielding a name relative
/// to that origin.
///
/// Trailing dots on both arguments are ignored. The suffix comparison is
/// ASCII case-insensitive because DNS names are case-insensitive (RFC 4343);
/// the returned relative part keeps the original casing of `name`.
///
/// * If `name` equals `origin`, returns `return_if_equal`, or `"@"` when it
///   is `None`.
/// * If `name` ends with `.<origin>`, returns `name` with that suffix removed.
/// * Otherwise returns `name` (without trailing dots) unchanged.
pub(crate) fn strip_origin_from_name(
    name: &str,
    origin: &str,
    return_if_equal: Option<&str>,
) -> String {
    let name = name.trim_end_matches('.');
    let origin = origin.trim_end_matches('.');

    if name.eq_ignore_ascii_case(origin) {
        return return_if_equal.unwrap_or("@").to_string();
    }

    // Compare bytes so the check never panics on a multi-byte character, and
    // require a '.' right before the suffix so that `notexample.com` is not
    // mistaken for a subdomain of `example.com`.
    let (name_bytes, origin_bytes) = (name.as_bytes(), origin.as_bytes());
    if !origin_bytes.is_empty() && name_bytes.len() > origin_bytes.len() {
        let suffix_start = name_bytes.len() - origin_bytes.len();
        if name_bytes[suffix_start - 1] == b'.'
            && name_bytes[suffix_start..].eq_ignore_ascii_case(origin_bytes)
        {
            // `suffix_start - 1` indexes an ASCII '.', so it is a char boundary.
            return name[..suffix_start - 1].to_string();
        }
    }

    name.to_string()
}
42
43impl fmt::Display for TLSARecord {
44    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
45        write!(
46            f,
47            "{} {} {} ",
48            u8::from(self.cert_usage),
49            u8::from(self.selector),
50            u8::from(self.matching),
51        )?;
52
53        for ch in &self.cert_data {
54            write!(f, "{:02x}", ch)?;
55        }
56
57        Ok(())
58    }
59}
60
61impl fmt::Display for KeyValue {
62    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
63        f.write_str(&self.key)?;
64        if !self.value.is_empty() {
65            write!(f, "={}", self.value)?;
66        }
67
68        Ok(())
69    }
70}
71
72impl CAARecord {
73    pub fn decompose(self) -> (u8, String, String) {
74        match self {
75            CAARecord::Issue {
76                issuer_critical,
77                name,
78                options,
79            } => {
80                let flags = if issuer_critical { 128 } else { 0 };
81                let mut value = name.unwrap_or_default();
82                for opt in &options {
83                    use std::fmt::Write;
84                    write!(value, "; {}", opt).unwrap();
85                }
86                (flags, "issue".to_string(), value)
87            }
88            CAARecord::IssueWild {
89                issuer_critical,
90                name,
91                options,
92            } => {
93                let flags = if issuer_critical { 128 } else { 0 };
94                let mut value = name.unwrap_or_default();
95                for opt in &options {
96                    use std::fmt::Write;
97                    write!(value, "; {}", opt).unwrap();
98                }
99                (flags, "issuewild".to_string(), value)
100            }
101            CAARecord::Iodef {
102                issuer_critical,
103                url,
104            } => {
105                let flags = if issuer_critical { 128 } else { 0 };
106                (flags, "iodef".to_string(), url)
107            }
108        }
109    }
110}
111
112impl fmt::Display for CAARecord {
113    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
114        match self {
115            CAARecord::Issue {
116                issuer_critical,
117                name,
118                options,
119            } => {
120                if *issuer_critical {
121                    f.write_str("128 ")?;
122                } else {
123                    f.write_str("0 ")?;
124                }
125                f.write_str("issue ")?;
126                f.write_str("\"")?;
127                if let Some(name) = name {
128                    f.write_str(name)?;
129                }
130                for opt in options {
131                    write!(f, ";{}", opt)?;
132                }
133                f.write_str("\"")?;
134            }
135            CAARecord::IssueWild {
136                issuer_critical,
137                name,
138                options,
139            } => {
140                if *issuer_critical {
141                    f.write_str("128 ")?;
142                } else {
143                    f.write_str("0 ")?;
144                }
145                f.write_str("issuewild ")?;
146                f.write_str("\"")?;
147                if let Some(name) = name {
148                    f.write_str(name)?;
149                }
150                for opt in options {
151                    write!(f, ";{}", opt)?;
152                }
153                f.write_str("\"")?;
154            }
155            CAARecord::Iodef {
156                issuer_critical,
157                url,
158            } => {
159                if *issuer_critical {
160                    f.write_str("128 ")?;
161                } else {
162                    f.write_str("0 ")?;
163                }
164                f.write_str("iodef ")?;
165                f.write_str("\"")?;
166                f.write_str(url)?;
167                f.write_str("\"")?;
168            }
169        }
170        Ok(())
171    }
172}
173
174impl Display for MXRecord {
175    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
176        write!(f, "{} {}", self.priority, self.exchange)
177    }
178}
179
180impl Display for SRVRecord {
181    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
182        write!(
183            f,
184            "{} {} {} {}",
185            self.priority, self.weight, self.port, self.target
186        )
187    }
188}
189
190impl DnsRecord {
191    pub fn as_type(&self) -> DnsRecordType {
192        match self {
193            DnsRecord::A { .. } => DnsRecordType::A,
194            DnsRecord::AAAA { .. } => DnsRecordType::AAAA,
195            DnsRecord::CNAME { .. } => DnsRecordType::CNAME,
196            DnsRecord::NS { .. } => DnsRecordType::NS,
197            DnsRecord::MX { .. } => DnsRecordType::MX,
198            DnsRecord::TXT { .. } => DnsRecordType::TXT,
199            DnsRecord::SRV { .. } => DnsRecordType::SRV,
200            DnsRecord::TLSA { .. } => DnsRecordType::TLSA,
201            DnsRecord::CAA { .. } => DnsRecordType::CAA,
202        }
203    }
204}
205
206impl From<TlsaCertUsage> for u8 {
207    fn from(usage: TlsaCertUsage) -> Self {
208        match usage {
209            TlsaCertUsage::PkixTa => 0,
210            TlsaCertUsage::PkixEe => 1,
211            TlsaCertUsage::DaneTa => 2,
212            TlsaCertUsage::DaneEe => 3,
213            TlsaCertUsage::Private => 255,
214        }
215    }
216}
217
218impl From<TlsaSelector> for u8 {
219    fn from(selector: TlsaSelector) -> Self {
220        match selector {
221            TlsaSelector::Full => 0,
222            TlsaSelector::Spki => 1,
223            TlsaSelector::Private => 255,
224        }
225    }
226}
227
228impl From<TlsaMatching> for u8 {
229    fn from(matching: TlsaMatching) -> Self {
230        match matching {
231            TlsaMatching::Raw => 0,
232            TlsaMatching::Sha256 => 1,
233            TlsaMatching::Sha512 => 2,
234            TlsaMatching::Private => 255,
235        }
236    }
237}
238
239impl<'x> IntoFqdn<'x> for &'x str {
240    fn into_fqdn(self) -> Cow<'x, str> {
241        if self.ends_with('.') {
242            Cow::Borrowed(self)
243        } else {
244            Cow::Owned(format!("{}.", self))
245        }
246    }
247
248    fn into_name(self) -> Cow<'x, str> {
249        if let Some(name) = self.strip_suffix('.') {
250            Cow::Borrowed(name)
251        } else {
252            Cow::Borrowed(self)
253        }
254    }
255}
256
257impl<'x> IntoFqdn<'x> for &'x String {
258    fn into_fqdn(self) -> Cow<'x, str> {
259        self.as_str().into_fqdn()
260    }
261
262    fn into_name(self) -> Cow<'x, str> {
263        self.as_str().into_name()
264    }
265}
266
267impl<'x> IntoFqdn<'x> for String {
268    fn into_fqdn(self) -> Cow<'x, str> {
269        if self.ends_with('.') {
270            Cow::Owned(self)
271        } else {
272            Cow::Owned(format!("{}.", self))
273        }
274    }
275
276    fn into_name(self) -> Cow<'x, str> {
277        if let Some(name) = self.strip_suffix('.') {
278            Cow::Owned(name.to_string())
279        } else {
280            Cow::Owned(self)
281        }
282    }
283}
284
285impl FromStr for TsigAlgorithm {
286    type Err = ();
287
288    fn from_str(s: &str) -> std::prelude::v1::Result<Self, Self::Err> {
289        match s {
290            "hmac-md5" => Ok(TsigAlgorithm::HmacMd5),
291            "gss" => Ok(TsigAlgorithm::Gss),
292            "hmac-sha1" => Ok(TsigAlgorithm::HmacSha1),
293            "hmac-sha224" => Ok(TsigAlgorithm::HmacSha224),
294            "hmac-sha256" => Ok(TsigAlgorithm::HmacSha256),
295            "hmac-sha256-128" => Ok(TsigAlgorithm::HmacSha256_128),
296            "hmac-sha384" => Ok(TsigAlgorithm::HmacSha384),
297            "hmac-sha384-192" => Ok(TsigAlgorithm::HmacSha384_192),
298            "hmac-sha512" => Ok(TsigAlgorithm::HmacSha512),
299            "hmac-sha512-256" => Ok(TsigAlgorithm::HmacSha512_256),
300            _ => Err(()),
301        }
302    }
303}
304
305impl std::error::Error for Error {}
306
307impl Display for Error {
308    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
309        match self {
310            Error::Protocol(e) => write!(f, "Protocol error: {}", e),
311            Error::Parse(e) => write!(f, "Parse error: {}", e),
312            Error::Client(e) => write!(f, "Client error: {}", e),
313            Error::Response(e) => write!(f, "Response error: {}", e),
314            Error::Api(e) => write!(f, "API error: {}", e),
315            Error::Serialize(e) => write!(f, "Serialize error: {}", e),
316            Error::Unauthorized => write!(f, "Unauthorized"),
317            Error::NotFound => write!(f, "Not found"),
318            Error::BadRequest => write!(f, "Bad request"),
319        }
320    }
321}
322
323impl Display for DnsRecordType {
324    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
325        write!(f, "{:?}", self)
326    }
327}