Skip to main content

dns_update/providers/
rfc2136.rs

1/*
2 * Copyright Stalwart Labs LLC See the COPYING
3 * file at the top-level directory of this distribution.
4 *
5 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 * https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 * <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
8 * option. This file may not be copied, modified, or distributed
9 * except according to those terms.
10 */
11
12use crate::{CAARecord, DnsRecord, Error, IntoFqdn, TlsaCertUsage, TlsaMatching, TlsaSelector};
13use hickory_client::ClientError;
14use hickory_client::client::{Client, ClientHandle};
15use hickory_client::proto::ProtoError;
16use hickory_client::proto::dnssec::rdata::KEY;
17use hickory_client::proto::dnssec::rdata::tsig::TsigAlgorithm;
18use hickory_client::proto::dnssec::tsig::TSigner;
19use hickory_client::proto::dnssec::{Algorithm, DnsSecError, SigSigner, SigningKey};
20use hickory_client::proto::op::MessageFinalizer;
21use hickory_client::proto::op::ResponseCode;
22use hickory_client::proto::rr::rdata::caa::KeyValue;
23use hickory_client::proto::rr::rdata::tlsa::{CertUsage, Matching, Selector};
24use hickory_client::proto::rr::rdata::{A, AAAA, CAA, CNAME, MX, NS, SRV, TLSA, TXT};
25use hickory_client::proto::rr::{DNSClass, Name, RData, Record, RecordType};
26use hickory_client::proto::runtime::TokioRuntimeProvider;
27use hickory_client::proto::tcp::TcpClientStream;
28use hickory_client::proto::udp::UdpClientStream;
29use std::net::{AddrParseError, SocketAddr};
30use std::sync::Arc;
31
32#[derive(Clone)]
33pub struct Rfc2136Provider {
34    addr: DnsAddress,
35    signer: Arc<dyn MessageFinalizer>,
36}
37
38#[derive(Clone, Copy, Debug, PartialEq, Eq)]
39pub enum DnsAddress {
40    Tcp(SocketAddr),
41    Udp(SocketAddr),
42}
43
44impl Rfc2136Provider {
45    pub(crate) fn new_tsig(
46        addr: impl TryInto<DnsAddress>,
47        key_name: impl AsRef<str>,
48        key: impl Into<Vec<u8>>,
49        algorithm: TsigAlgorithm,
50    ) -> crate::Result<Self> {
51        Ok(Rfc2136Provider {
52            addr: addr
53                .try_into()
54                .map_err(|_| Error::Parse("Invalid address".to_string()))?,
55            signer: Arc::new(TSigner::new(
56                key.into(),
57                algorithm,
58                Name::from_ascii(key_name.as_ref())?,
59                60,
60            )?),
61        })
62    }
63
64    pub(crate) fn new_sig0(
65        addr: impl TryInto<DnsAddress>,
66        signer_name: impl AsRef<str>,
67        key: Box<dyn SigningKey>,
68        public_key: impl Into<Vec<u8>>,
69        algorithm: Algorithm,
70    ) -> crate::Result<Self> {
71        let sig0key = KEY::new(
72            Default::default(),
73            Default::default(),
74            Default::default(),
75            Default::default(),
76            algorithm,
77            public_key.into(),
78        );
79
80        let signer = SigSigner::sig0(sig0key, key, Name::from_str_relaxed(signer_name.as_ref())?);
81
82        Ok(Rfc2136Provider {
83            addr: addr
84                .try_into()
85                .map_err(|_| Error::Parse("Invalid address".to_string()))?,
86            signer: Arc::new(signer),
87        })
88    }
89
90    async fn connect(&self) -> crate::Result<Client> {
91        match &self.addr {
92            DnsAddress::Udp(addr) => {
93                let stream = UdpClientStream::builder(*addr, TokioRuntimeProvider::new())
94                    .with_signer(Some(self.signer.clone()))
95                    .build();
96                let (client, bg) = Client::connect(stream).await?;
97                tokio::spawn(bg);
98                Ok(client)
99            }
100            DnsAddress::Tcp(addr) => {
101                let (stream, sender) =
102                    TcpClientStream::new(*addr, None, None, TokioRuntimeProvider::new());
103                let (client, bg) = Client::new(stream, sender, Some(self.signer.clone())).await?;
104                tokio::spawn(bg);
105                Ok(client)
106            }
107        }
108    }
109
110    pub(crate) async fn create(
111        &self,
112        name: impl IntoFqdn<'_>,
113        record: DnsRecord,
114        ttl: u32,
115        origin: impl IntoFqdn<'_>,
116    ) -> crate::Result<()> {
117        let (_rr_type, rdata) = convert_record(record)?;
118        let record = Record::from_rdata(
119            Name::from_str_relaxed(name.into_name().as_ref())?,
120            ttl,
121            rdata,
122        );
123
124        let mut client = self.connect().await?;
125        let result = client
126            .create(record, Name::from_str_relaxed(origin.into_fqdn().as_ref())?)
127            .await?;
128        if result.response_code() == ResponseCode::NoError {
129            Ok(())
130        } else {
131            Err(crate::Error::Response(result.response_code().to_string()))
132        }
133    }
134
135    pub(crate) async fn update(
136        &self,
137        name: impl IntoFqdn<'_>,
138        record: DnsRecord,
139        ttl: u32,
140        origin: impl IntoFqdn<'_>,
141    ) -> crate::Result<()> {
142        let (_rr_type, rdata) = convert_record(record)?;
143        let record = Record::from_rdata(
144            Name::from_str_relaxed(name.into_name().as_ref())?,
145            ttl,
146            rdata,
147        );
148
149        let mut client = self.connect().await?;
150        let result = client
151            .append(
152                record,
153                Name::from_str_relaxed(origin.into_fqdn().as_ref())?,
154                false,
155            )
156            .await?;
157        if result.response_code() == ResponseCode::NoError {
158            Ok(())
159        } else {
160            Err(crate::Error::Response(result.response_code().to_string()))
161        }
162    }
163
164    pub(crate) async fn delete(
165        &self,
166        name: impl IntoFqdn<'_>,
167        origin: impl IntoFqdn<'_>,
168    ) -> crate::Result<()> {
169        let mut client = self.connect().await?;
170        let result = client
171            .delete_all(
172                Name::from_str_relaxed(name.into_name().as_ref())?,
173                Name::from_str_relaxed(origin.into_fqdn().as_ref())?,
174                DNSClass::IN,
175            )
176            .await?;
177        if result.response_code() == ResponseCode::NoError {
178            Ok(())
179        } else {
180            Err(crate::Error::Response(result.response_code().to_string()))
181        }
182    }
183}
184
185fn convert_record(record: DnsRecord) -> crate::Result<(RecordType, RData)> {
186    Ok(match record {
187        DnsRecord::A(content) => (RecordType::A, RData::A(A::from(content))),
188        DnsRecord::AAAA(content) => (RecordType::AAAA, RData::AAAA(AAAA::from(content))),
189        DnsRecord::CNAME(content) => (
190            RecordType::CNAME,
191            RData::CNAME(CNAME(Name::from_str_relaxed(content)?)),
192        ),
193        DnsRecord::NS(content) => (
194            RecordType::NS,
195            RData::NS(NS(Name::from_str_relaxed(content)?)),
196        ),
197        DnsRecord::MX(content) => (
198            RecordType::MX,
199            RData::MX(MX::new(
200                content.priority,
201                Name::from_str_relaxed(content.exchange)?,
202            )),
203        ),
204        DnsRecord::TXT(content) => (RecordType::TXT, RData::TXT(TXT::new(vec![content]))),
205        DnsRecord::SRV(content) => (
206            RecordType::SRV,
207            RData::SRV(SRV::new(
208                content.priority,
209                content.weight,
210                content.port,
211                Name::from_str_relaxed(content.target)?,
212            )),
213        ),
214        DnsRecord::TLSA(content) => (
215            RecordType::TLSA,
216            RData::TLSA(TLSA::new(
217                content.cert_usage.into(),
218                content.selector.into(),
219                content.matching.into(),
220                content.cert_data,
221            )),
222        ),
223        DnsRecord::CAA(caa) => (
224            RecordType::CAA,
225            RData::CAA(match caa {
226                CAARecord::Issue {
227                    issuer_critical,
228                    name,
229                    options,
230                } => CAA::new_issue(
231                    issuer_critical,
232                    name.map(Name::from_str_relaxed).transpose()?,
233                    options
234                        .into_iter()
235                        .map(|kv| KeyValue::new(kv.key, kv.value))
236                        .collect(),
237                ),
238                CAARecord::IssueWild {
239                    issuer_critical,
240                    name,
241                    options,
242                } => CAA::new_issuewild(
243                    issuer_critical,
244                    name.map(Name::from_str_relaxed).transpose()?,
245                    options
246                        .into_iter()
247                        .map(|kv| KeyValue::new(kv.key, kv.value))
248                        .collect(),
249                ),
250                CAARecord::Iodef {
251                    issuer_critical,
252                    url,
253                } => CAA::new_iodef(
254                    issuer_critical,
255                    url.parse()
256                        .map_err(|_| Error::Parse("Invalid URL in CAA record".to_string()))?,
257                ),
258            }),
259        ),
260    })
261}
262
263impl From<TlsaCertUsage> for CertUsage {
264    fn from(usage: TlsaCertUsage) -> Self {
265        match usage {
266            TlsaCertUsage::PkixTa => CertUsage::PkixTa,
267            TlsaCertUsage::PkixEe => CertUsage::PkixEe,
268            TlsaCertUsage::DaneTa => CertUsage::DaneTa,
269            TlsaCertUsage::DaneEe => CertUsage::DaneEe,
270            TlsaCertUsage::Private => CertUsage::Private,
271        }
272    }
273}
274
275impl From<TlsaMatching> for Matching {
276    fn from(matching: TlsaMatching) -> Self {
277        match matching {
278            TlsaMatching::Raw => Matching::Raw,
279            TlsaMatching::Sha256 => Matching::Sha256,
280            TlsaMatching::Sha512 => Matching::Sha512,
281            TlsaMatching::Private => Matching::Private,
282        }
283    }
284}
285
286impl From<TlsaSelector> for Selector {
287    fn from(selector: TlsaSelector) -> Self {
288        match selector {
289            TlsaSelector::Full => Selector::Full,
290            TlsaSelector::Spki => Selector::Spki,
291            TlsaSelector::Private => Selector::Private,
292        }
293    }
294}
295
296impl TryFrom<&str> for DnsAddress {
297    type Error = ();
298
299    fn try_from(url: &str) -> Result<Self, Self::Error> {
300        let (host, is_tcp) = if let Some(host) = url.strip_prefix("udp://") {
301            (host, false)
302        } else if let Some(host) = url.strip_prefix("tcp://") {
303            (host, true)
304        } else {
305            (url, false)
306        };
307        let (host, port) = if let Some(host) = host.strip_prefix('[') {
308            let (host, maybe_port) = host.rsplit_once(']').ok_or(())?;
309
310            (
311                host,
312                maybe_port
313                    .rsplit_once(':')
314                    .map(|(_, port)| port)
315                    .unwrap_or("53"),
316            )
317        } else if let Some((host, port)) = host.rsplit_once(':') {
318            (host, port)
319        } else {
320            (host, "53")
321        };
322
323        let addr = SocketAddr::new(host.parse().map_err(|_| ())?, port.parse().map_err(|_| ())?);
324
325        if is_tcp {
326            Ok(DnsAddress::Tcp(addr))
327        } else {
328            Ok(DnsAddress::Udp(addr))
329        }
330    }
331}
332
333impl TryFrom<&String> for DnsAddress {
334    type Error = ();
335
336    fn try_from(url: &String) -> Result<Self, Self::Error> {
337        DnsAddress::try_from(url.as_str())
338    }
339}
340
341impl TryFrom<String> for DnsAddress {
342    type Error = ();
343
344    fn try_from(url: String) -> Result<Self, Self::Error> {
345        DnsAddress::try_from(url.as_str())
346    }
347}
348
349impl From<crate::TsigAlgorithm> for TsigAlgorithm {
350    fn from(alg: crate::TsigAlgorithm) -> Self {
351        match alg {
352            crate::TsigAlgorithm::HmacMd5 => TsigAlgorithm::HmacMd5,
353            crate::TsigAlgorithm::Gss => TsigAlgorithm::Gss,
354            crate::TsigAlgorithm::HmacSha1 => TsigAlgorithm::HmacSha1,
355            crate::TsigAlgorithm::HmacSha224 => TsigAlgorithm::HmacSha224,
356            crate::TsigAlgorithm::HmacSha256 => TsigAlgorithm::HmacSha256,
357            crate::TsigAlgorithm::HmacSha256_128 => TsigAlgorithm::HmacSha256_128,
358            crate::TsigAlgorithm::HmacSha384 => TsigAlgorithm::HmacSha384,
359            crate::TsigAlgorithm::HmacSha384_192 => TsigAlgorithm::HmacSha384_192,
360            crate::TsigAlgorithm::HmacSha512 => TsigAlgorithm::HmacSha512,
361            crate::TsigAlgorithm::HmacSha512_256 => TsigAlgorithm::HmacSha512_256,
362        }
363    }
364}
365
366impl From<crate::Algorithm> for Algorithm {
367    fn from(alg: crate::Algorithm) -> Self {
368        match alg {
369            crate::Algorithm::RSASHA256 => Algorithm::RSASHA256,
370            crate::Algorithm::RSASHA512 => Algorithm::RSASHA512,
371            crate::Algorithm::ECDSAP256SHA256 => Algorithm::ECDSAP256SHA256,
372            crate::Algorithm::ECDSAP384SHA384 => Algorithm::ECDSAP384SHA384,
373            crate::Algorithm::ED25519 => Algorithm::ED25519,
374        }
375    }
376}
377
378impl From<ProtoError> for Error {
379    fn from(e: ProtoError) -> Self {
380        Error::Protocol(e.to_string())
381    }
382}
383
384impl From<AddrParseError> for Error {
385    fn from(e: AddrParseError) -> Self {
386        Error::Parse(e.to_string())
387    }
388}
389
390impl From<ClientError> for Error {
391    fn from(e: ClientError) -> Self {
392        Error::Client(e.to_string())
393    }
394}
395
396impl From<DnsSecError> for Error {
397    fn from(e: DnsSecError) -> Self {
398        Error::Protocol(e.to_string())
399    }
400}