Skip to main content

dns_update/providers/
rfc2136.rs

1/*
2 * Copyright Stalwart Labs LLC See the COPYING
3 * file at the top-level directory of this distribution.
4 *
5 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 * https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 * <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
8 * option. This file may not be copied, modified, or distributed
9 * except according to those terms.
10 */
11
12use crate::utils::txt_chunks;
13use crate::{
14    CAARecord, DnsRecord, DnsRecordType, Error, IntoFqdn, KeyValue as DnsKeyValue, MXRecord,
15    SRVRecord, TLSARecord, TlsaCertUsage, TlsaMatching, TlsaSelector,
16};
17use hickory_net::NetError;
18use hickory_net::client::{Client, ClientHandle};
19use hickory_net::runtime::TokioRuntimeProvider;
20use hickory_net::tcp::TcpClientStream;
21use hickory_net::udp::UdpClientStream;
22use hickory_net::xfer::DnsMultiplexer;
23use hickory_proto::ProtoError;
24use hickory_proto::dnssec::DnsSecError;
25use hickory_proto::op::ResponseCode;
26use hickory_proto::rr::rdata::caa::KeyValue;
27use hickory_proto::rr::rdata::tlsa::{CertUsage, Matching, Selector};
28use hickory_proto::rr::rdata::tsig::TsigAlgorithm;
29use hickory_proto::rr::rdata::{A, AAAA, CAA, CNAME, MX, NS, SRV, TLSA, TXT};
30use hickory_proto::rr::{DNSClass, Name, RData, Record, RecordSet, RecordType, TSigner};
31use std::net::{AddrParseError, SocketAddr};
32
/// DNS provider that manages records through RFC 2136 dynamic UPDATE messages.
#[derive(Clone)]
pub struct Rfc2136Provider {
    // Server (and transport) that UPDATE messages and queries are sent to.
    addr: DnsAddress,
    // TSIG signer applied to UPDATE requests; `None` sends them unsigned.
    signer: Option<TSigner>,
}
38
/// Address of the DNS server together with the transport used to reach it.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DnsAddress {
    /// Connect over TCP.
    Tcp(SocketAddr),
    /// Connect over UDP.
    Udp(SocketAddr),
}
44
45impl Rfc2136Provider {
46    pub(crate) fn new_tsig(
47        addr: impl TryInto<DnsAddress>,
48        key_name: impl AsRef<str>,
49        key: impl Into<Vec<u8>>,
50        algorithm: TsigAlgorithm,
51    ) -> crate::Result<Self> {
52        Ok(Rfc2136Provider {
53            addr: addr
54                .try_into()
55                .map_err(|_| Error::Parse("Invalid address".to_string()))?,
56            signer: Some(TSigner::new(
57                key.into(),
58                algorithm,
59                Name::from_ascii(key_name.as_ref())?,
60                60,
61            )?),
62        })
63    }
64
65    async fn connect(&self) -> crate::Result<Client<TokioRuntimeProvider>> {
66        self.connect_inner(self.signer.as_ref()).await
67    }
68
69    async fn connect_unsigned(&self) -> crate::Result<Client<TokioRuntimeProvider>> {
70        self.connect_inner(None).await
71    }
72
73    async fn connect_inner(
74        &self,
75        signer: Option<&TSigner>,
76    ) -> crate::Result<Client<TokioRuntimeProvider>> {
77        match &self.addr {
78            DnsAddress::Udp(addr) => {
79                let mut builder = UdpClientStream::builder(*addr, TokioRuntimeProvider::new());
80                if let Some(signer) = signer {
81                    builder = builder.with_signer(Some(signer.clone()));
82                }
83                let stream = builder.build();
84                let (client, bg) = Client::from_sender(stream);
85                tokio::spawn(bg);
86                Ok(client)
87            }
88            DnsAddress::Tcp(addr) => {
89                let (stream_future, sender) =
90                    TcpClientStream::new(*addr, None, None, TokioRuntimeProvider::new());
91                let stream = stream_future.await?;
92                let mut multiplexer = DnsMultiplexer::new(stream, sender);
93                if let Some(signer) = signer {
94                    multiplexer = multiplexer.with_signer(signer.clone());
95                }
96                let (client, bg) = Client::from_sender(multiplexer);
97                tokio::spawn(bg);
98                Ok(client)
99            }
100        }
101    }
102
103    pub(crate) async fn set_rrset(
104        &self,
105        name: impl IntoFqdn<'_>,
106        record_type: DnsRecordType,
107        ttl: u32,
108        records: Vec<DnsRecord>,
109        origin: impl IntoFqdn<'_>,
110    ) -> crate::Result<()> {
111        let owner = Name::from_str_relaxed(name.into_name().as_ref())?;
112        let zone = Name::from_str_relaxed(origin.into_fqdn().as_ref())?;
113        let rtype: RecordType = record_type.into();
114
115        let mut client = self.connect().await?;
116
117        let delete = Record::update0(owner.clone(), 0, rtype);
118        let result = client.delete_rrset(delete, zone.clone()).await?;
119        if result.response_code != ResponseCode::NoError {
120            return Err(Error::Response(result.response_code.to_string()));
121        }
122
123        if records.is_empty() {
124            return Ok(());
125        }
126
127        let rrset = build_rrset(owner, rtype, ttl, records)?;
128        let result = client.append(rrset, zone, false).await?;
129        if result.response_code != ResponseCode::NoError {
130            return Err(Error::Response(result.response_code.to_string()));
131        }
132        Ok(())
133    }
134
135    pub(crate) async fn add_to_rrset(
136        &self,
137        name: impl IntoFqdn<'_>,
138        record_type: DnsRecordType,
139        ttl: u32,
140        records: Vec<DnsRecord>,
141        origin: impl IntoFqdn<'_>,
142    ) -> crate::Result<()> {
143        if records.is_empty() {
144            return Ok(());
145        }
146        let owner = Name::from_str_relaxed(name.into_name().as_ref())?;
147        let zone = Name::from_str_relaxed(origin.into_fqdn().as_ref())?;
148        let rtype: RecordType = record_type.into();
149        let rrset = build_rrset(owner, rtype, ttl, records)?;
150
151        let mut client = self.connect().await?;
152        let result = client.append(rrset, zone, false).await?;
153        if result.response_code != ResponseCode::NoError {
154            return Err(Error::Response(result.response_code.to_string()));
155        }
156        Ok(())
157    }
158
159    pub(crate) async fn remove_from_rrset(
160        &self,
161        name: impl IntoFqdn<'_>,
162        record_type: DnsRecordType,
163        records: Vec<DnsRecord>,
164        origin: impl IntoFqdn<'_>,
165    ) -> crate::Result<()> {
166        if records.is_empty() {
167            return Ok(());
168        }
169        let owner = Name::from_str_relaxed(name.into_name().as_ref())?;
170        let zone = Name::from_str_relaxed(origin.into_fqdn().as_ref())?;
171        let rtype: RecordType = record_type.into();
172        let rrset = build_rrset(owner, rtype, 0, records)?;
173
174        let mut client = self.connect().await?;
175        let result = client.delete_by_rdata(rrset, zone).await?;
176        if result.response_code != ResponseCode::NoError {
177            return Err(Error::Response(result.response_code.to_string()));
178        }
179        Ok(())
180    }
181
182    pub(crate) async fn list_rrset(
183        &self,
184        name: impl IntoFqdn<'_>,
185        record_type: DnsRecordType,
186        _origin: impl IntoFqdn<'_>,
187    ) -> crate::Result<Vec<DnsRecord>> {
188        let owner = Name::from_str_relaxed(name.into_fqdn().as_ref())?;
189        let rtype: RecordType = record_type.into();
190
191        let mut client = self.connect_unsigned().await?;
192        let response = client.query(owner.clone(), DNSClass::IN, rtype).await?;
193        if response.response_code != ResponseCode::NoError
194            && response.response_code != ResponseCode::NXDomain
195        {
196            return Err(Error::Response(response.response_code.to_string()));
197        }
198
199        let mut out = Vec::new();
200        for record in response.answers.iter() {
201            if record.record_type() != rtype || record.name != owner {
202                continue;
203            }
204            out.push(rdata_to_dns_record(&record.data)?);
205        }
206        Ok(out)
207    }
208}
209
210fn rdata_to_dns_record(data: &RData) -> crate::Result<DnsRecord> {
211    Ok(match data {
212        RData::A(a) => DnsRecord::A(a.0),
213        RData::AAAA(aaaa) => DnsRecord::AAAA(aaaa.0),
214        RData::CNAME(cname) => DnsRecord::CNAME(strip_trailing_dot(&cname.0.to_utf8())),
215        RData::NS(ns) => DnsRecord::NS(strip_trailing_dot(&ns.0.to_utf8())),
216        RData::MX(mx) => DnsRecord::MX(MXRecord {
217            priority: mx.preference,
218            exchange: strip_trailing_dot(&mx.exchange.to_utf8()),
219        }),
220        RData::TXT(txt) => {
221            let combined: String = txt
222                .txt_data
223                .iter()
224                .map(|chunk| String::from_utf8_lossy(chunk).into_owned())
225                .collect();
226            DnsRecord::TXT(combined)
227        }
228        RData::SRV(srv) => DnsRecord::SRV(SRVRecord {
229            priority: srv.priority,
230            weight: srv.weight,
231            port: srv.port,
232            target: strip_trailing_dot(&srv.target.to_utf8()),
233        }),
234        RData::TLSA(tlsa) => DnsRecord::TLSA(TLSARecord {
235            cert_usage: tlsa_cert_usage_from(tlsa.cert_usage)?,
236            selector: tlsa_selector_from(tlsa.selector)?,
237            matching: tlsa_matching_from(tlsa.matching)?,
238            cert_data: tlsa.cert_data.clone(),
239        }),
240        RData::CAA(caa) => DnsRecord::CAA(caa_to_record(caa)?),
241        other => {
242            return Err(Error::Unsupported(format!(
243                "Unsupported RData type for list_rrset: {}",
244                other.record_type()
245            )));
246        }
247    })
248}
249
/// Returns an owned copy of `s` with at most one trailing `.` removed.
fn strip_trailing_dot(s: &str) -> String {
    match s.strip_suffix('.') {
        Some(without_dot) => without_dot.to_owned(),
        None => s.to_owned(),
    }
}
253
254fn caa_to_record(caa: &CAA) -> crate::Result<CAARecord> {
255    let issuer_critical = caa.issuer_critical;
256    let value_text = String::from_utf8_lossy(&caa.value).into_owned();
257    match caa.tag.as_str() {
258        "issue" => {
259            let (name, options) = parse_caa_value(&value_text);
260            Ok(CAARecord::Issue {
261                issuer_critical,
262                name,
263                options,
264            })
265        }
266        "issuewild" => {
267            let (name, options) = parse_caa_value(&value_text);
268            Ok(CAARecord::IssueWild {
269                issuer_critical,
270                name,
271                options,
272            })
273        }
274        "iodef" => Ok(CAARecord::Iodef {
275            issuer_critical,
276            url: value_text,
277        }),
278        other => Err(Error::Unsupported(format!(
279            "Unsupported CAA tag for list_rrset: {other}"
280        ))),
281    }
282}
283
284fn parse_caa_value(value: &str) -> (Option<String>, Vec<DnsKeyValue>) {
285    let mut parts = value.split(';').map(str::trim);
286    let name_part = parts.next().unwrap_or("").trim().to_string();
287    let name = if name_part.is_empty() {
288        None
289    } else {
290        Some(name_part)
291    };
292    let options = parts
293        .filter(|p| !p.is_empty())
294        .map(|p| match p.split_once('=') {
295            Some((k, v)) => DnsKeyValue {
296                key: k.trim().to_string(),
297                value: v.trim().to_string(),
298            },
299            None => DnsKeyValue {
300                key: p.trim().to_string(),
301                value: String::new(),
302            },
303        })
304        .collect();
305    (name, options)
306}
307
308fn tlsa_cert_usage_from(usage: CertUsage) -> crate::Result<TlsaCertUsage> {
309    Ok(match usage {
310        CertUsage::PkixTa => TlsaCertUsage::PkixTa,
311        CertUsage::PkixEe => TlsaCertUsage::PkixEe,
312        CertUsage::DaneTa => TlsaCertUsage::DaneTa,
313        CertUsage::DaneEe => TlsaCertUsage::DaneEe,
314        CertUsage::Private => TlsaCertUsage::Private,
315        other => return Err(Error::Api(format!("Unknown TLSA cert usage: {other:?}"))),
316    })
317}
318
319fn tlsa_selector_from(sel: Selector) -> crate::Result<TlsaSelector> {
320    Ok(match sel {
321        Selector::Full => TlsaSelector::Full,
322        Selector::Spki => TlsaSelector::Spki,
323        Selector::Private => TlsaSelector::Private,
324        other => return Err(Error::Api(format!("Unknown TLSA selector: {other:?}"))),
325    })
326}
327
328fn tlsa_matching_from(m: Matching) -> crate::Result<TlsaMatching> {
329    Ok(match m {
330        Matching::Raw => TlsaMatching::Raw,
331        Matching::Sha256 => TlsaMatching::Sha256,
332        Matching::Sha512 => TlsaMatching::Sha512,
333        Matching::Private => TlsaMatching::Private,
334        other => return Err(Error::Api(format!("Unknown TLSA matching: {other:?}"))),
335    })
336}
337
338fn build_rrset(
339    name: Name,
340    rtype: RecordType,
341    ttl: u32,
342    records: Vec<DnsRecord>,
343) -> crate::Result<RecordSet> {
344    let mut rrset = RecordSet::with_ttl(name, rtype, ttl);
345    for record in records {
346        let (record_type, rdata) = convert_record(record)?;
347        if record_type != rtype {
348            return Err(Error::Api(format!(
349                "RRSet record type mismatch: expected {rtype}, got {record_type}"
350            )));
351        }
352        rrset.add_rdata(rdata);
353    }
354    Ok(rrset)
355}
356
357impl From<DnsRecordType> for RecordType {
358    fn from(record_type: DnsRecordType) -> Self {
359        match record_type {
360            DnsRecordType::A => RecordType::A,
361            DnsRecordType::AAAA => RecordType::AAAA,
362            DnsRecordType::CNAME => RecordType::CNAME,
363            DnsRecordType::NS => RecordType::NS,
364            DnsRecordType::MX => RecordType::MX,
365            DnsRecordType::TXT => RecordType::TXT,
366            DnsRecordType::SRV => RecordType::SRV,
367            DnsRecordType::TLSA => RecordType::TLSA,
368            DnsRecordType::CAA => RecordType::CAA,
369        }
370    }
371}
372
373fn convert_record(record: DnsRecord) -> crate::Result<(RecordType, RData)> {
374    Ok(match record {
375        DnsRecord::A(content) => (RecordType::A, RData::A(A::from(content))),
376        DnsRecord::AAAA(content) => (RecordType::AAAA, RData::AAAA(AAAA::from(content))),
377        DnsRecord::CNAME(content) => (
378            RecordType::CNAME,
379            RData::CNAME(CNAME(Name::from_str_relaxed(content)?)),
380        ),
381        DnsRecord::NS(content) => (
382            RecordType::NS,
383            RData::NS(NS(Name::from_str_relaxed(content)?)),
384        ),
385        DnsRecord::MX(content) => (
386            RecordType::MX,
387            RData::MX(MX::new(
388                content.priority,
389                Name::from_str_relaxed(content.exchange)?,
390            )),
391        ),
392        DnsRecord::TXT(content) => (RecordType::TXT, RData::TXT(TXT::new(txt_chunks(content)))),
393        DnsRecord::SRV(content) => (
394            RecordType::SRV,
395            RData::SRV(SRV::new(
396                content.priority,
397                content.weight,
398                content.port,
399                Name::from_str_relaxed(content.target)?,
400            )),
401        ),
402        DnsRecord::TLSA(content) => (
403            RecordType::TLSA,
404            RData::TLSA(TLSA::new(
405                content.cert_usage.into(),
406                content.selector.into(),
407                content.matching.into(),
408                content.cert_data,
409            )),
410        ),
411        DnsRecord::CAA(caa) => (
412            RecordType::CAA,
413            RData::CAA(match caa {
414                CAARecord::Issue {
415                    issuer_critical,
416                    name,
417                    options,
418                } => CAA::new_issue(
419                    issuer_critical,
420                    name.map(Name::from_str_relaxed).transpose()?,
421                    options
422                        .into_iter()
423                        .map(|kv| KeyValue::new(kv.key, kv.value))
424                        .collect(),
425                ),
426                CAARecord::IssueWild {
427                    issuer_critical,
428                    name,
429                    options,
430                } => CAA::new_issuewild(
431                    issuer_critical,
432                    name.map(Name::from_str_relaxed).transpose()?,
433                    options
434                        .into_iter()
435                        .map(|kv| KeyValue::new(kv.key, kv.value))
436                        .collect(),
437                ),
438                CAARecord::Iodef {
439                    issuer_critical,
440                    url,
441                } => CAA::new_iodef(
442                    issuer_critical,
443                    url.parse()
444                        .map_err(|_| Error::Parse("Invalid URL in CAA record".to_string()))?,
445                ),
446            }),
447        ),
448    })
449}
450
451impl From<TlsaCertUsage> for CertUsage {
452    fn from(usage: TlsaCertUsage) -> Self {
453        match usage {
454            TlsaCertUsage::PkixTa => CertUsage::PkixTa,
455            TlsaCertUsage::PkixEe => CertUsage::PkixEe,
456            TlsaCertUsage::DaneTa => CertUsage::DaneTa,
457            TlsaCertUsage::DaneEe => CertUsage::DaneEe,
458            TlsaCertUsage::Private => CertUsage::Private,
459        }
460    }
461}
462
463impl From<TlsaMatching> for Matching {
464    fn from(matching: TlsaMatching) -> Self {
465        match matching {
466            TlsaMatching::Raw => Matching::Raw,
467            TlsaMatching::Sha256 => Matching::Sha256,
468            TlsaMatching::Sha512 => Matching::Sha512,
469            TlsaMatching::Private => Matching::Private,
470        }
471    }
472}
473
474impl From<TlsaSelector> for Selector {
475    fn from(selector: TlsaSelector) -> Self {
476        match selector {
477            TlsaSelector::Full => Selector::Full,
478            TlsaSelector::Spki => Selector::Spki,
479            TlsaSelector::Private => Selector::Private,
480        }
481    }
482}
483
484impl TryFrom<&str> for DnsAddress {
485    type Error = ();
486
487    fn try_from(url: &str) -> Result<Self, Self::Error> {
488        let (host, is_tcp) = if let Some(host) = url.strip_prefix("udp://") {
489            (host, false)
490        } else if let Some(host) = url.strip_prefix("tcp://") {
491            (host, true)
492        } else {
493            (url, false)
494        };
495        let (host, port) = if let Some(host) = host.strip_prefix('[') {
496            let (host, maybe_port) = host.rsplit_once(']').ok_or(())?;
497
498            (
499                host,
500                maybe_port
501                    .rsplit_once(':')
502                    .map(|(_, port)| port)
503                    .unwrap_or("53"),
504            )
505        } else if let Some((host, port)) = host.rsplit_once(':') {
506            (host, port)
507        } else {
508            (host, "53")
509        };
510
511        let addr = SocketAddr::new(host.parse().map_err(|_| ())?, port.parse().map_err(|_| ())?);
512
513        if is_tcp {
514            Ok(DnsAddress::Tcp(addr))
515        } else {
516            Ok(DnsAddress::Udp(addr))
517        }
518    }
519}
520
521impl TryFrom<&String> for DnsAddress {
522    type Error = ();
523
524    fn try_from(url: &String) -> Result<Self, Self::Error> {
525        DnsAddress::try_from(url.as_str())
526    }
527}
528
529impl TryFrom<String> for DnsAddress {
530    type Error = ();
531
532    fn try_from(url: String) -> Result<Self, Self::Error> {
533        DnsAddress::try_from(url.as_str())
534    }
535}
536
537impl From<crate::TsigAlgorithm> for TsigAlgorithm {
538    fn from(alg: crate::TsigAlgorithm) -> Self {
539        match alg {
540            crate::TsigAlgorithm::HmacMd5 => TsigAlgorithm::HmacMd5,
541            crate::TsigAlgorithm::Gss => TsigAlgorithm::Gss,
542            crate::TsigAlgorithm::HmacSha1 => TsigAlgorithm::HmacSha1,
543            crate::TsigAlgorithm::HmacSha224 => TsigAlgorithm::HmacSha224,
544            crate::TsigAlgorithm::HmacSha256 => TsigAlgorithm::HmacSha256,
545            crate::TsigAlgorithm::HmacSha256_128 => TsigAlgorithm::HmacSha256_128,
546            crate::TsigAlgorithm::HmacSha384 => TsigAlgorithm::HmacSha384,
547            crate::TsigAlgorithm::HmacSha384_192 => TsigAlgorithm::HmacSha384_192,
548            crate::TsigAlgorithm::HmacSha512 => TsigAlgorithm::HmacSha512,
549            crate::TsigAlgorithm::HmacSha512_256 => TsigAlgorithm::HmacSha512_256,
550        }
551    }
552}
553
554impl From<ProtoError> for Error {
555    fn from(e: ProtoError) -> Self {
556        Error::Protocol(e.to_string())
557    }
558}
559
560impl From<AddrParseError> for Error {
561    fn from(e: AddrParseError) -> Self {
562        Error::Parse(e.to_string())
563    }
564}
565
566impl From<NetError> for Error {
567    fn from(e: NetError) -> Self {
568        Error::Client(e.to_string())
569    }
570}
571
572impl From<DnsSecError> for Error {
573    fn from(e: DnsSecError) -> Self {
574        Error::Protocol(e.to_string())
575    }
576}