Skip to main content

dns_update/providers/
safedns.rs

1/*
2 * Copyright Stalwart Labs LLC See the COPYING
3 * file at the top-level directory of this distribution.
4 *
5 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 * https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 * <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
8 * option. This file may not be copied, modified, or distributed
9 * except according to those terms.
10 */
11
12use crate::{
13    CAARecord, DnsRecord, DnsRecordType, Error, IntoFqdn, KeyValue, MXRecord, SRVRecord,
14    http::{HttpClient, HttpClientBuilder},
15    utils::txt_chunks_to_text,
16};
17use serde::{Deserialize, Serialize};
18use std::time::Duration;
19
/// Default base URL of the SafeDNS v1 REST API, used by `SafeDnsProvider::new`;
/// request paths such as `/zones/{zone}/records` are appended to it.
const DEFAULT_API_ENDPOINT: &str = "https://api.ukfast.io/safedns/v1";
/// Value sent as `per_page` when listing a zone's records.
const LIST_PAGE_SIZE: u32 = 200;
22
/// Client for the SafeDNS (api.ukfast.io) record API.
///
/// Holds an HTTP client that already carries the `Authorization` header, plus the base
/// endpoint URL that every request path is appended to.
#[derive(Clone)]
pub struct SafeDnsProvider {
    /// HTTP client configured in `new` with the auth header and optional timeout.
    client: HttpClient,
    /// Base API URL; `DEFAULT_API_ENDPOINT` unless overridden in tests.
    endpoint: String,
}
28
/// JSON body sent to `POST {endpoint}/zones/{zone}/records` to create one record.
#[derive(Serialize, Debug, Clone)]
pub struct SafeDnsRecordPayload<'a> {
    /// Record name (the provider passes the fully-qualified name here).
    pub name: &'a str,
    /// Record type string such as `"A"` or `"MX"`; serialized under the key `type`.
    #[serde(rename = "type")]
    pub record_type: &'a str,
    /// Record data in SafeDNS's textual form (e.g. `"<weight> <port> <target>"` for SRV).
    pub content: String,
    /// TTL to assign to the new record.
    pub ttl: u32,
    /// Priority for MX/SRV records; the key is omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub priority: Option<u16>,
}
39
/// A record as returned by the SafeDNS API.
///
/// Only the fields this provider uses are decoded. `content` falls back to an empty string
/// and `priority` to `None` when they are absent from the response. The record's TTL is
/// not decoded.
#[derive(Deserialize, Debug, Clone)]
pub struct SafeDnsRecord {
    /// Server-side record id, used in `DELETE .../records/{id}`.
    pub id: i64,
    /// Record name as reported by the API.
    pub name: String,
    /// Record type string (e.g. `"A"`), decoded from the JSON key `type`.
    #[serde(rename = "type")]
    pub record_type: String,
    /// Record data in SafeDNS's textual form.
    #[serde(default)]
    pub content: String,
    /// Priority as reported by the API (meaningful for MX/SRV).
    #[serde(default)]
    pub priority: Option<u16>,
}
51
/// Response body of `GET {endpoint}/zones/{zone}/records`: one page of records plus metadata.
#[derive(Deserialize, Debug)]
pub struct ListRecordsResponse {
    /// Records on this page.
    pub data: Vec<SafeDnsRecord>,
    /// Pagination metadata; defaults to empty (0 total pages) when the key is missing.
    #[serde(default)]
    pub meta: ListMeta,
}
58
/// The `meta` object of a list response; only `pagination` is decoded.
#[derive(Deserialize, Debug, Default)]
pub struct ListMeta {
    /// Pagination block; defaults to 0 total pages when missing.
    #[serde(default)]
    pub pagination: Pagination,
}
64
/// Pagination details of a list response.
#[derive(Deserialize, Debug, Default)]
pub struct Pagination {
    /// Total number of pages; 0 when the field is missing, which ends listing after page 1.
    #[serde(default)]
    pub total_pages: u32,
}
70
/// Response body of a record-creation request; `data` holds the created record.
#[derive(Deserialize, Debug)]
pub struct AddRecordResponse {
    // NOTE(review): this type is not referenced elsewhere in this file (`create_record`
    // decodes the response as `serde_json::Value`), hence the `dead_code` allowance.
    #[allow(dead_code)]
    pub data: SafeDnsRecord,
}
76
/// One record's data in the shape SafeDNS uses, derived from a [`DnsRecord`] via `TryFrom`.
///
/// Derives `PartialEq`/`Eq` so two values can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SafeDnsRecordContent {
    /// SafeDNS type string, e.g. `"A"`, `"MX"`, `"CAA"`.
    pub record_type: &'static str,
    /// Record data in SafeDNS's textual form.
    pub content: String,
    /// Priority; set only for MX and SRV by the `TryFrom<DnsRecord>` conversion.
    pub priority: Option<u16>,
}
83
84impl SafeDnsProvider {
85    pub(crate) fn new(auth_token: impl AsRef<str>, timeout: Option<Duration>) -> Self {
86        let client = HttpClientBuilder::default()
87            .with_header("Authorization", auth_token.as_ref())
88            .with_timeout(timeout)
89            .build();
90        Self {
91            client,
92            endpoint: DEFAULT_API_ENDPOINT.to_string(),
93        }
94    }
95
96    #[cfg(test)]
97    pub(crate) fn with_endpoint(self, endpoint: impl AsRef<str>) -> Self {
98        Self {
99            endpoint: endpoint.as_ref().to_string(),
100            ..self
101        }
102    }
103
104    pub(crate) async fn set_rrset(
105        &self,
106        name: impl IntoFqdn<'_>,
107        record_type: DnsRecordType,
108        ttl: u32,
109        records: Vec<DnsRecord>,
110        origin: impl IntoFqdn<'_>,
111    ) -> crate::Result<()> {
112        reject_unsupported(record_type)?;
113        let fqdn = name.into_name().into_owned();
114        let zone = origin.into_name().into_owned();
115        let desired = build_contents(record_type, records)?;
116        let existing = self.list_at(&zone, &fqdn, record_type).await?;
117
118        let mut existing_pool: Vec<SafeDnsRecord> = existing;
119        let mut to_add: Vec<SafeDnsRecordContent> = Vec::new();
120
121        for content in desired {
122            if let Some(idx) = existing_pool
123                .iter()
124                .position(|r| record_matches(r, &content))
125            {
126                existing_pool.swap_remove(idx);
127            } else {
128                to_add.push(content);
129            }
130        }
131
132        for entry in existing_pool {
133            self.delete_record(&zone, entry.id).await?;
134        }
135        for content in to_add {
136            self.create_record(&zone, &fqdn, ttl, &content).await?;
137        }
138        Ok(())
139    }
140
141    pub(crate) async fn add_to_rrset(
142        &self,
143        name: impl IntoFqdn<'_>,
144        record_type: DnsRecordType,
145        ttl: u32,
146        records: Vec<DnsRecord>,
147        origin: impl IntoFqdn<'_>,
148    ) -> crate::Result<()> {
149        reject_unsupported(record_type)?;
150        if records.is_empty() {
151            return Ok(());
152        }
153        let fqdn = name.into_name().into_owned();
154        let zone = origin.into_name().into_owned();
155        let desired = build_contents(record_type, records)?;
156        let existing = self.list_at(&zone, &fqdn, record_type).await?;
157
158        for content in desired {
159            if existing.iter().any(|r| record_matches(r, &content)) {
160                continue;
161            }
162            self.create_record(&zone, &fqdn, ttl, &content).await?;
163        }
164        Ok(())
165    }
166
167    pub(crate) async fn remove_from_rrset(
168        &self,
169        name: impl IntoFqdn<'_>,
170        record_type: DnsRecordType,
171        records: Vec<DnsRecord>,
172        origin: impl IntoFqdn<'_>,
173    ) -> crate::Result<()> {
174        reject_unsupported(record_type)?;
175        if records.is_empty() {
176            return Ok(());
177        }
178        let fqdn = name.into_name().into_owned();
179        let zone = origin.into_name().into_owned();
180        let to_remove = build_contents(record_type, records)?;
181        let existing = self.list_at(&zone, &fqdn, record_type).await?;
182
183        for content in to_remove {
184            if let Some(entry) = existing.iter().find(|r| record_matches(r, &content)) {
185                self.delete_record(&zone, entry.id).await?;
186            }
187        }
188        Ok(())
189    }
190
191    pub(crate) async fn list_rrset(
192        &self,
193        name: impl IntoFqdn<'_>,
194        record_type: DnsRecordType,
195        origin: impl IntoFqdn<'_>,
196    ) -> crate::Result<Vec<DnsRecord>> {
197        let fqdn = name.into_name().into_owned();
198        let zone = origin.into_name().into_owned();
199        let listed = self.list_at(&zone, &fqdn, record_type).await?;
200        listed
201            .into_iter()
202            .map(|r| safedns_record_to_dns_record(r, record_type))
203            .collect()
204    }
205
206    async fn list_at(
207        &self,
208        zone: &str,
209        name: &str,
210        record_type: DnsRecordType,
211    ) -> crate::Result<Vec<SafeDnsRecord>> {
212        let type_str = record_type.as_str();
213        let mut out: Vec<SafeDnsRecord> = Vec::new();
214        let mut page: u32 = 1;
215        loop {
216            let url = format!(
217                "{endpoint}/zones/{zone}/records?name:eq={name}&type:eq={type_str}&per_page={LIST_PAGE_SIZE}&page={page}",
218                endpoint = self.endpoint
219            );
220            let response: ListRecordsResponse = self.client.get(url).send_with_retry(3).await?;
221            let total_pages = response.meta.pagination.total_pages;
222            for record in response.data {
223                if record.name == name && record.record_type == type_str {
224                    out.push(record);
225                }
226            }
227            if total_pages <= page {
228                break;
229            }
230            page += 1;
231        }
232        Ok(out)
233    }
234
235    async fn create_record(
236        &self,
237        zone: &str,
238        name: &str,
239        ttl: u32,
240        content: &SafeDnsRecordContent,
241    ) -> crate::Result<()> {
242        let body = SafeDnsRecordPayload {
243            name,
244            record_type: content.record_type,
245            content: content.content.clone(),
246            ttl,
247            priority: content.priority,
248        };
249
250        self.client
251            .post(format!(
252                "{endpoint}/zones/{zone}/records",
253                endpoint = self.endpoint
254            ))
255            .with_body(&body)?
256            .send_with_retry::<serde_json::Value>(3)
257            .await
258            .map(|_| ())
259    }
260
261    async fn delete_record(&self, zone: &str, record_id: i64) -> crate::Result<()> {
262        self.client
263            .delete(format!(
264                "{endpoint}/zones/{zone}/records/{record_id}",
265                endpoint = self.endpoint
266            ))
267            .send_with_retry::<serde_json::Value>(3)
268            .await
269            .map(|_| ())
270    }
271}
272
273fn reject_unsupported(record_type: DnsRecordType) -> crate::Result<()> {
274    if record_type == DnsRecordType::TLSA {
275        return Err(Error::Unsupported(
276            "TLSA records are not supported by SafeDNS".to_string(),
277        ));
278    }
279    Ok(())
280}
281
282fn build_contents(
283    expected_type: DnsRecordType,
284    records: Vec<DnsRecord>,
285) -> crate::Result<Vec<SafeDnsRecordContent>> {
286    let mut out = Vec::with_capacity(records.len());
287    for record in records {
288        if record.as_type() != expected_type {
289            return Err(Error::Api(format!(
290                "RRSet record type mismatch: expected {}, got {}",
291                expected_type.as_str(),
292                record.as_type().as_str(),
293            )));
294        }
295        out.push(SafeDnsRecordContent::try_from(record)?);
296    }
297    Ok(out)
298}
299
300fn record_matches(record: &SafeDnsRecord, content: &SafeDnsRecordContent) -> bool {
301    record.record_type == content.record_type
302        && record.content == content.content
303        && record.priority == content.priority
304}
305
306fn safedns_record_to_dns_record(
307    record: SafeDnsRecord,
308    record_type: DnsRecordType,
309) -> crate::Result<DnsRecord> {
310    match record_type {
311        DnsRecordType::A => record
312            .content
313            .parse()
314            .map(DnsRecord::A)
315            .map_err(|e| Error::Parse(format!("invalid A content {}: {e}", record.content))),
316        DnsRecordType::AAAA => record
317            .content
318            .parse()
319            .map(DnsRecord::AAAA)
320            .map_err(|e| Error::Parse(format!("invalid AAAA content {}: {e}", record.content))),
321        DnsRecordType::CNAME => Ok(DnsRecord::CNAME(record.content)),
322        DnsRecordType::NS => Ok(DnsRecord::NS(record.content)),
323        DnsRecordType::MX => Ok(DnsRecord::MX(MXRecord {
324            exchange: record.content,
325            priority: record.priority.unwrap_or(0),
326        })),
327        DnsRecordType::TXT => Ok(DnsRecord::TXT(unquote_txt(&record.content))),
328        DnsRecordType::SRV => parse_srv(&record.content, record.priority.unwrap_or(0)),
329        DnsRecordType::TLSA => Err(Error::Unsupported(
330            "TLSA records are not supported by SafeDNS".to_string(),
331        )),
332        DnsRecordType::CAA => parse_caa(&record.content),
333    }
334}
335
336fn parse_srv(content: &str, priority: u16) -> crate::Result<DnsRecord> {
337    let parts: Vec<&str> = content.split_whitespace().collect();
338    if parts.len() != 3 {
339        return Err(Error::Parse(format!(
340            "invalid SRV content (expected `<weight> <port> <target>`): {content}"
341        )));
342    }
343    let weight: u16 = parts[0]
344        .parse()
345        .map_err(|e| Error::Parse(format!("invalid SRV weight {}: {e}", parts[0])))?;
346    let port: u16 = parts[1]
347        .parse()
348        .map_err(|e| Error::Parse(format!("invalid SRV port {}: {e}", parts[1])))?;
349    Ok(DnsRecord::SRV(SRVRecord {
350        priority,
351        weight,
352        port,
353        target: parts[2].to_string(),
354    }))
355}
356
/// Turns SafeDNS TXT content back into the record's plain text.
///
/// Quoted content (e.g. `"part one" "part two"`) is treated as a sequence of quoted chunks:
/// - quotes delimit chunks;
/// - whitespace between chunks is dropped;
/// - whitespace inside a chunk is kept;
/// - a backslash makes the following character literal.
///
/// Content without any `"` is returned verbatim. Such content has no chunk boundaries, and
/// stripping its spaces would corrupt values like `v=spf1 include:x -all`.
fn unquote_txt(content: &str) -> String {
    if !content.contains('"') {
        return content.to_string();
    }

    let mut out = String::with_capacity(content.len());
    let mut chars = content.chars();
    let mut in_quote = false;
    while let Some(ch) = chars.next() {
        match ch {
            '"' => in_quote = !in_quote,
            '\\' => {
                if let Some(next) = chars.next() {
                    out.push(next);
                }
            }
            // Separator between quoted chunks, not part of the value.
            c if c.is_whitespace() && !in_quote => {}
            _ => out.push(ch),
        }
    }
    out
}
377
378fn parse_caa(content: &str) -> crate::Result<DnsRecord> {
379    let trimmed = content.trim();
380    let (flags_str, rest) = trimmed
381        .split_once(char::is_whitespace)
382        .ok_or_else(|| Error::Parse(format!("invalid CAA content: {content}")))?;
383    let (tag, value_part) = rest
384        .trim_start()
385        .split_once(char::is_whitespace)
386        .ok_or_else(|| Error::Parse(format!("invalid CAA content: {content}")))?;
387    let flags: u8 = flags_str
388        .parse()
389        .map_err(|e| Error::Parse(format!("invalid CAA flags {flags_str}: {e}")))?;
390    let value = value_part
391        .trim()
392        .trim_start_matches('"')
393        .trim_end_matches('"')
394        .to_string();
395    let issuer_critical = flags & 0x80 != 0;
396    match tag.trim() {
397        "issue" => {
398            let (name, options) = parse_caa_value(&value);
399            Ok(DnsRecord::CAA(CAARecord::Issue {
400                issuer_critical,
401                name,
402                options,
403            }))
404        }
405        "issuewild" => {
406            let (name, options) = parse_caa_value(&value);
407            Ok(DnsRecord::CAA(CAARecord::IssueWild {
408                issuer_critical,
409                name,
410                options,
411            }))
412        }
413        "iodef" => Ok(DnsRecord::CAA(CAARecord::Iodef {
414            issuer_critical,
415            url: value,
416        })),
417        other => Err(Error::Parse(format!("unknown CAA tag: {other}"))),
418    }
419}
420
421fn parse_caa_value(value: &str) -> (Option<String>, Vec<KeyValue>) {
422    let mut parts = value.split(';').map(str::trim);
423    let name_part = parts.next().unwrap_or("").trim().to_string();
424    let name = if name_part.is_empty() {
425        None
426    } else {
427        Some(name_part)
428    };
429    let options = parts
430        .filter(|p| !p.is_empty())
431        .map(|p| match p.split_once('=') {
432            Some((k, v)) => KeyValue {
433                key: k.trim().to_string(),
434                value: v.trim().to_string(),
435            },
436            None => KeyValue {
437                key: p.trim().to_string(),
438                value: String::new(),
439            },
440        })
441        .collect();
442    (name, options)
443}
444
445impl TryFrom<DnsRecord> for SafeDnsRecordContent {
446    type Error = Error;
447
448    fn try_from(record: DnsRecord) -> Result<Self, Self::Error> {
449        match record {
450            DnsRecord::A(addr) => Ok(SafeDnsRecordContent {
451                record_type: "A",
452                content: addr.to_string(),
453                priority: None,
454            }),
455            DnsRecord::AAAA(addr) => Ok(SafeDnsRecordContent {
456                record_type: "AAAA",
457                content: addr.to_string(),
458                priority: None,
459            }),
460            DnsRecord::CNAME(target) => Ok(SafeDnsRecordContent {
461                record_type: "CNAME",
462                content: target,
463                priority: None,
464            }),
465            DnsRecord::NS(target) => Ok(SafeDnsRecordContent {
466                record_type: "NS",
467                content: target,
468                priority: None,
469            }),
470            DnsRecord::MX(mx) => Ok(SafeDnsRecordContent {
471                record_type: "MX",
472                content: mx.exchange,
473                priority: Some(mx.priority),
474            }),
475            DnsRecord::TXT(text) => {
476                let mut buf = String::new();
477                txt_chunks_to_text(&mut buf, &text, " ");
478                Ok(SafeDnsRecordContent {
479                    record_type: "TXT",
480                    content: buf,
481                    priority: None,
482                })
483            }
484            DnsRecord::SRV(srv) => Ok(SafeDnsRecordContent {
485                record_type: "SRV",
486                content: format!("{} {} {}", srv.weight, srv.port, srv.target),
487                priority: Some(srv.priority),
488            }),
489            DnsRecord::TLSA(_) => Err(Error::Unsupported(
490                "TLSA records are not supported by SafeDNS".to_string(),
491            )),
492            DnsRecord::CAA(caa) => {
493                let (flags, tag, value) = caa.decompose();
494                Ok(SafeDnsRecordContent {
495                    record_type: "CAA",
496                    content: format!("{flags} {tag} \"{value}\""),
497                    priority: None,
498                })
499            }
500        }
501    }
502}