Skip to main content

dns_update/providers/
luadns.rs

/*
 * Copyright Stalwart Labs LLC See the COPYING
 * file at the top-level directory of this distribution.
 *
 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
 * https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
 * <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
 * option. This file may not be copied, modified, or distributed
 * except according to those terms.
 */
11
12use crate::{
13    CAARecord, DnsRecord, DnsRecordType, Error, IntoFqdn, KeyValue, MXRecord, SRVRecord,
14    TLSARecord, TlsaCertUsage, TlsaMatching, TlsaSelector,
15    http::{HttpClient, HttpClientBuilder},
16    utils::txt_chunks_to_text,
17};
18use base64::{Engine as _, engine::general_purpose::STANDARD as B64};
19use serde::{Deserialize, Serialize};
20use std::time::Duration;
21
/// Number of items requested per page (the `limit` query parameter) when listing.
const PAGE_SIZE: u32 = 500;
/// Base URL of the LuaDNS REST API; request paths such as `/v1/zones` are appended to it.
const DEFAULT_API_ENDPOINT: &str = "https://api.luadns.com";
24
25#[derive(Clone)]
26pub struct LuaDnsProvider {
27    client: HttpClient,
28    endpoint: String,
29}
30
31#[derive(Deserialize, Debug, Clone)]
32pub struct LuaZone {
33    pub id: i64,
34    pub name: String,
35}
36
37#[derive(Serialize, Deserialize, Debug, Clone)]
38pub struct LuaRecord {
39    #[serde(default, skip_serializing_if = "Option::is_none")]
40    pub id: Option<i64>,
41    pub name: String,
42    #[serde(rename = "type")]
43    pub rr_type: String,
44    pub content: String,
45    pub ttl: u32,
46    #[serde(default, skip_serializing_if = "Option::is_none")]
47    pub zone_id: Option<i64>,
48}
49
50impl LuaDnsProvider {
51    pub(crate) fn new(
52        api_username: impl AsRef<str>,
53        api_token: impl AsRef<str>,
54        timeout: Option<Duration>,
55    ) -> Self {
56        let raw = format!("{}:{}", api_username.as_ref(), api_token.as_ref());
57        let encoded = B64.encode(raw);
58        let client = HttpClientBuilder::default()
59            .with_header("Authorization", format!("Basic {encoded}"))
60            .with_header("Accept", "application/json")
61            .with_timeout(timeout)
62            .build();
63        Self {
64            client,
65            endpoint: DEFAULT_API_ENDPOINT.to_string(),
66        }
67    }
68
69    #[cfg(test)]
70    pub(crate) fn with_endpoint(self, endpoint: impl AsRef<str>) -> Self {
71        Self {
72            endpoint: endpoint.as_ref().to_string(),
73            ..self
74        }
75    }
76
77    async fn list_zones(&self) -> crate::Result<Vec<LuaZone>> {
78        let mut out: Vec<LuaZone> = Vec::new();
79        let mut page: u32 = 1;
80        loop {
81            let chunk: Vec<LuaZone> = self
82                .client
83                .get(format!(
84                    "{}/v1/zones?limit={PAGE_SIZE}&page={page}",
85                    self.endpoint
86                ))
87                .send_with_retry(3)
88                .await?;
89            let received = chunk.len();
90            out.extend(chunk);
91            if (received as u32) < PAGE_SIZE {
92                break;
93            }
94            page += 1;
95        }
96        Ok(out)
97    }
98
99    async fn find_zone(&self, origin: &str) -> crate::Result<LuaZone> {
100        let zones = self.list_zones().await?;
101        zones
102            .into_iter()
103            .find(|z| z.name.eq_ignore_ascii_case(origin))
104            .ok_or_else(|| Error::Api(format!("LuaDNS zone {origin} not found")))
105    }
106
107    async fn list_records(&self, zone_id: i64) -> crate::Result<Vec<LuaRecord>> {
108        let mut out: Vec<LuaRecord> = Vec::new();
109        let mut page: u32 = 1;
110        loop {
111            let chunk: Vec<LuaRecord> = self
112                .client
113                .get(format!(
114                    "{}/v1/zones/{zone_id}/records?limit={PAGE_SIZE}&page={page}",
115                    self.endpoint
116                ))
117                .send_with_retry(3)
118                .await?;
119            let received = chunk.len();
120            out.extend(chunk);
121            if (received as u32) < PAGE_SIZE {
122                break;
123            }
124            page += 1;
125        }
126        Ok(out)
127    }
128
129    async fn list_at(
130        &self,
131        zone_id: i64,
132        fqdn: &str,
133        record_type: DnsRecordType,
134    ) -> crate::Result<Vec<LuaRecord>> {
135        let target = fqdn.trim_end_matches('.');
136        let rr_type = record_type.as_str();
137        let records = self.list_records(zone_id).await?;
138        Ok(records
139            .into_iter()
140            .filter(|r| r.rr_type == rr_type && r.name.trim_end_matches('.') == target)
141            .collect())
142    }
143
144    pub(crate) async fn set_rrset(
145        &self,
146        name: impl IntoFqdn<'_>,
147        record_type: DnsRecordType,
148        ttl: u32,
149        records: Vec<DnsRecord>,
150        origin: impl IntoFqdn<'_>,
151    ) -> crate::Result<()> {
152        check_record_types(record_type, &records)?;
153        let origin_name = origin.into_name().to_string();
154        let zone = self.find_zone(&origin_name).await?;
155        let fqdn = name.into_fqdn().to_string();
156        let desired = build_contents(record_type, records)?;
157        let mut existing_pool = self.list_at(zone.id, &fqdn, record_type).await?;
158
159        let mut to_add = Vec::new();
160        for content in desired {
161            if let Some(idx) = existing_pool.iter().position(|r| r.content == content) {
162                existing_pool.swap_remove(idx);
163            } else {
164                to_add.push(content);
165            }
166        }
167
168        for entry in existing_pool {
169            if let Some(id) = entry.id {
170                self.delete_record(zone.id, id).await?;
171            }
172        }
173        for content in to_add {
174            self.create_raw(zone.id, &fqdn, record_type, ttl, content)
175                .await?;
176        }
177        Ok(())
178    }
179
180    pub(crate) async fn add_to_rrset(
181        &self,
182        name: impl IntoFqdn<'_>,
183        record_type: DnsRecordType,
184        ttl: u32,
185        records: Vec<DnsRecord>,
186        origin: impl IntoFqdn<'_>,
187    ) -> crate::Result<()> {
188        check_record_types(record_type, &records)?;
189        if records.is_empty() {
190            return Ok(());
191        }
192        let origin_name = origin.into_name().to_string();
193        let zone = self.find_zone(&origin_name).await?;
194        let fqdn = name.into_fqdn().to_string();
195        let desired = build_contents(record_type, records)?;
196        let existing = self.list_at(zone.id, &fqdn, record_type).await?;
197
198        for content in desired {
199            if existing.iter().any(|r| r.content == content) {
200                continue;
201            }
202            self.create_raw(zone.id, &fqdn, record_type, ttl, content)
203                .await?;
204        }
205        Ok(())
206    }
207
208    pub(crate) async fn remove_from_rrset(
209        &self,
210        name: impl IntoFqdn<'_>,
211        record_type: DnsRecordType,
212        records: Vec<DnsRecord>,
213        origin: impl IntoFqdn<'_>,
214    ) -> crate::Result<()> {
215        check_record_types(record_type, &records)?;
216        if records.is_empty() {
217            return Ok(());
218        }
219        let origin_name = origin.into_name().to_string();
220        let zone = self.find_zone(&origin_name).await?;
221        let fqdn = name.into_fqdn().to_string();
222        let to_remove = build_contents(record_type, records)?;
223        let existing = self.list_at(zone.id, &fqdn, record_type).await?;
224
225        for content in to_remove {
226            if let Some(entry) = existing.iter().find(|r| r.content == content)
227                && let Some(id) = entry.id
228            {
229                self.delete_record(zone.id, id).await?;
230            }
231        }
232        Ok(())
233    }
234
235    pub(crate) async fn list_rrset(
236        &self,
237        name: impl IntoFqdn<'_>,
238        record_type: DnsRecordType,
239        origin: impl IntoFqdn<'_>,
240    ) -> crate::Result<Vec<DnsRecord>> {
241        let origin_name = origin.into_name().to_string();
242        let zone = self.find_zone(&origin_name).await?;
243        let fqdn = name.into_fqdn().to_string();
244        let listed = self.list_at(zone.id, &fqdn, record_type).await?;
245        listed
246            .into_iter()
247            .map(|r| parse_content(record_type, &r.content))
248            .collect()
249    }
250
251    async fn create_raw(
252        &self,
253        zone_id: i64,
254        fqdn: &str,
255        record_type: DnsRecordType,
256        ttl: u32,
257        content: String,
258    ) -> crate::Result<()> {
259        let body = LuaRecord {
260            id: None,
261            name: fqdn.to_string(),
262            rr_type: record_type.as_str().to_string(),
263            content,
264            ttl,
265            zone_id: None,
266        };
267        self.client
268            .post(format!("{}/v1/zones/{zone_id}/records", self.endpoint))
269            .with_body(&body)?
270            .send_with_retry::<LuaRecord>(3)
271            .await
272            .map(|_| ())
273    }
274
275    async fn delete_record(&self, zone_id: i64, record_id: i64) -> crate::Result<()> {
276        self.client
277            .delete(format!(
278                "{}/v1/zones/{zone_id}/records/{record_id}",
279                self.endpoint
280            ))
281            .send_raw()
282            .await
283            .map(|_| ())
284    }
285}
286
/// Returns `name` in absolute form by appending a `.` when it does not already end with one.
fn ensure_dot(mut name: String) -> String {
    if !name.ends_with('.') {
        name.push('.');
    }
    name
}
294
295fn render_content(record: DnsRecord) -> String {
296    match record {
297        DnsRecord::A(addr) => addr.to_string(),
298        DnsRecord::AAAA(addr) => addr.to_string(),
299        DnsRecord::CNAME(content) => ensure_dot(content),
300        DnsRecord::NS(content) => ensure_dot(content),
301        DnsRecord::TXT(content) => {
302            let mut out = String::with_capacity(content.len() + 4);
303            txt_chunks_to_text(&mut out, &content, " ");
304            out
305        }
306        DnsRecord::MX(mx) => format!("{} {}", mx.priority, ensure_dot(mx.exchange)),
307        DnsRecord::SRV(srv) => format!(
308            "{} {} {} {}",
309            srv.priority,
310            srv.weight,
311            srv.port,
312            ensure_dot(srv.target)
313        ),
314        DnsRecord::TLSA(tlsa) => tlsa.to_string(),
315        DnsRecord::CAA(caa) => caa.to_string(),
316    }
317}
318
319fn build_contents(
320    expected_type: DnsRecordType,
321    records: Vec<DnsRecord>,
322) -> crate::Result<Vec<String>> {
323    let mut out = Vec::with_capacity(records.len());
324    for record in records {
325        if record.as_type() != expected_type {
326            return Err(Error::Api(format!(
327                "RRSet record type mismatch: expected {}, got {}",
328                expected_type.as_str(),
329                record.as_type().as_str(),
330            )));
331        }
332        out.push(render_content(record));
333    }
334    Ok(out)
335}
336
337fn check_record_types(expected: DnsRecordType, records: &[DnsRecord]) -> crate::Result<()> {
338    for r in records {
339        if r.as_type() != expected {
340            return Err(Error::Api(format!(
341                "RRSet record type mismatch: expected {}, got {}",
342                expected.as_str(),
343                r.as_type().as_str(),
344            )));
345        }
346    }
347    Ok(())
348}
349
350fn parse_content(record_type: DnsRecordType, content: &str) -> crate::Result<DnsRecord> {
351    Ok(match record_type {
352        DnsRecordType::A => DnsRecord::A(
353            content
354                .parse()
355                .map_err(|e| Error::Parse(format!("invalid A value '{content}': {e}")))?,
356        ),
357        DnsRecordType::AAAA => DnsRecord::AAAA(
358            content
359                .parse()
360                .map_err(|e| Error::Parse(format!("invalid AAAA value '{content}': {e}")))?,
361        ),
362        DnsRecordType::CNAME => DnsRecord::CNAME(strip_trailing_dot(content)),
363        DnsRecordType::NS => DnsRecord::NS(strip_trailing_dot(content)),
364        DnsRecordType::MX => parse_mx(content)?,
365        DnsRecordType::TXT => DnsRecord::TXT(parse_txt(content)),
366        DnsRecordType::SRV => parse_srv(content)?,
367        DnsRecordType::TLSA => parse_tlsa(content)?,
368        DnsRecordType::CAA => parse_caa(content)?,
369    })
370}
371
372fn parse_mx(value: &str) -> crate::Result<DnsRecord> {
373    let mut parts = value.splitn(2, char::is_whitespace);
374    let priority = parts
375        .next()
376        .ok_or_else(|| Error::Parse(format!("invalid MX value '{value}'")))?
377        .parse()
378        .map_err(|e| Error::Parse(format!("invalid MX priority in '{value}': {e}")))?;
379    let exchange = parts
380        .next()
381        .ok_or_else(|| Error::Parse(format!("invalid MX value '{value}'")))?
382        .trim();
383    Ok(DnsRecord::MX(MXRecord {
384        priority,
385        exchange: strip_trailing_dot(exchange),
386    }))
387}
388
389fn parse_srv(value: &str) -> crate::Result<DnsRecord> {
390    let mut parts = value.split_whitespace();
391    let priority = parts
392        .next()
393        .ok_or_else(|| Error::Parse(format!("invalid SRV value '{value}'")))?
394        .parse()
395        .map_err(|e| Error::Parse(format!("invalid SRV priority in '{value}': {e}")))?;
396    let weight = parts
397        .next()
398        .ok_or_else(|| Error::Parse(format!("invalid SRV value '{value}'")))?
399        .parse()
400        .map_err(|e| Error::Parse(format!("invalid SRV weight in '{value}': {e}")))?;
401    let port = parts
402        .next()
403        .ok_or_else(|| Error::Parse(format!("invalid SRV value '{value}'")))?
404        .parse()
405        .map_err(|e| Error::Parse(format!("invalid SRV port in '{value}': {e}")))?;
406    let target = parts
407        .next()
408        .ok_or_else(|| Error::Parse(format!("invalid SRV value '{value}'")))?;
409    Ok(DnsRecord::SRV(SRVRecord {
410        priority,
411        weight,
412        port,
413        target: strip_trailing_dot(target),
414    }))
415}
416
/// Decodes TXT content into its plain text value.
///
/// All double-quoted character-strings are concatenated. Inside quotes:
/// - `\DDD` (three decimal digits, value <= 255) yields that byte.
/// - `\X` yields `X` literally (RFC 1035 §5.1).
///
/// Text outside quotes is skipped. As a fallback, content that is non-empty,
/// yields no quoted data and does not start with `"` is returned trimmed as-is.
///
/// Decoding collects raw bytes and converts them once at the end, so
/// multi-byte UTF-8 characters are preserved. Byte sequences that are not
/// valid UTF-8 (only possible via `\DDD`) become U+FFFD.
fn parse_txt(value: &str) -> String {
    let trimmed = value.trim();
    let bytes = trimmed.as_bytes();
    let mut out: Vec<u8> = Vec::with_capacity(bytes.len());
    let mut in_quotes = false;
    let mut i = 0;
    while i < bytes.len() {
        let b = bytes[i];
        i += 1;
        if !in_quotes {
            in_quotes = b == b'"';
            continue;
        }
        match b {
            b'"' => in_quotes = false,
            b'\\' => {
                if let Some(byte) = decimal_escape(&bytes[i..]) {
                    out.push(byte);
                    i += 3;
                } else if let Some(&next) = bytes.get(i) {
                    out.push(next);
                    i += 1;
                }
            }
            other => out.push(other),
        }
    }
    if out.is_empty() && !trimmed.is_empty() && !trimmed.starts_with('"') {
        return trimmed.to_string();
    }
    String::from_utf8_lossy(&out).into_owned()
}

/// Returns the byte of a `\DDD` escape when `rest` starts with three ASCII
/// digits whose decimal value fits in a byte; `None` otherwise.
fn decimal_escape(rest: &[u8]) -> Option<u8> {
    match rest {
        [a, b, c, ..] if a.is_ascii_digit() && b.is_ascii_digit() && c.is_ascii_digit() => {
            let value = u16::from(a - b'0') * 100 + u16::from(b - b'0') * 10 + u16::from(c - b'0');
            u8::try_from(value).ok()
        }
        _ => None,
    }
}
445
446fn parse_tlsa(content: &str) -> crate::Result<DnsRecord> {
447    let mut parts = content.split_whitespace();
448    let usage: u8 = parts
449        .next()
450        .ok_or_else(|| Error::Parse(format!("invalid TLSA record: {content}")))?
451        .parse()
452        .map_err(|e| Error::Parse(format!("invalid TLSA usage: {e}")))?;
453    let selector: u8 = parts
454        .next()
455        .ok_or_else(|| Error::Parse(format!("invalid TLSA record: {content}")))?
456        .parse()
457        .map_err(|e| Error::Parse(format!("invalid TLSA selector: {e}")))?;
458    let matching: u8 = parts
459        .next()
460        .ok_or_else(|| Error::Parse(format!("invalid TLSA record: {content}")))?
461        .parse()
462        .map_err(|e| Error::Parse(format!("invalid TLSA matching: {e}")))?;
463    let hex: String = parts.collect::<Vec<_>>().join("");
464    Ok(DnsRecord::TLSA(TLSARecord {
465        cert_usage: tlsa_cert_usage_from_u8(usage)?,
466        selector: tlsa_selector_from_u8(selector)?,
467        matching: tlsa_matching_from_u8(matching)?,
468        cert_data: decode_hex(&hex)?,
469    }))
470}
471
472fn parse_caa(value: &str) -> crate::Result<DnsRecord> {
473    let mut parts = value.splitn(3, char::is_whitespace);
474    let flags: u8 = parts
475        .next()
476        .ok_or_else(|| Error::Parse(format!("invalid CAA value '{value}'")))?
477        .parse()
478        .map_err(|e| Error::Parse(format!("invalid CAA flags in '{value}': {e}")))?;
479    let tag = parts
480        .next()
481        .ok_or_else(|| Error::Parse(format!("invalid CAA value '{value}'")))?
482        .to_ascii_lowercase();
483    let raw_value = parts
484        .next()
485        .ok_or_else(|| Error::Parse(format!("invalid CAA value '{value}'")))?
486        .trim();
487    let unquoted = raw_value
488        .strip_prefix('"')
489        .and_then(|s| s.strip_suffix('"'))
490        .map(|s| s.replace("\\\"", "\""))
491        .unwrap_or_else(|| raw_value.to_string());
492
493    let issuer_critical = flags & 0x80 != 0;
494    match tag.as_str() {
495        "issue" => {
496            let (name, options) = parse_caa_kv(&unquoted);
497            Ok(DnsRecord::CAA(CAARecord::Issue {
498                issuer_critical,
499                name,
500                options,
501            }))
502        }
503        "issuewild" => {
504            let (name, options) = parse_caa_kv(&unquoted);
505            Ok(DnsRecord::CAA(CAARecord::IssueWild {
506                issuer_critical,
507                name,
508                options,
509            }))
510        }
511        "iodef" => Ok(DnsRecord::CAA(CAARecord::Iodef {
512            issuer_critical,
513            url: unquoted,
514        })),
515        other => Err(Error::Parse(format!("unknown CAA tag: {other}"))),
516    }
517}
518
519fn parse_caa_kv(value: &str) -> (Option<String>, Vec<KeyValue>) {
520    let mut parts = value.split(';').map(str::trim);
521    let name_part = parts.next().unwrap_or("").trim().to_string();
522    let name = if name_part.is_empty() {
523        None
524    } else {
525        Some(name_part)
526    };
527    let options = parts
528        .filter(|p| !p.is_empty())
529        .map(|p| match p.split_once('=') {
530            Some((k, v)) => KeyValue {
531                key: k.trim().to_string(),
532                value: v.trim().to_string(),
533            },
534            None => KeyValue {
535                key: p.trim().to_string(),
536                value: String::new(),
537            },
538        })
539        .collect();
540    (name, options)
541}
542
/// Returns an owned copy of `value` with a single trailing `.` removed, if present.
fn strip_trailing_dot(value: &str) -> String {
    match value.strip_suffix('.') {
        Some(relative) => relative.to_owned(),
        None => value.to_owned(),
    }
}
546
547fn decode_hex(hex: &str) -> crate::Result<Vec<u8>> {
548    if !hex.len().is_multiple_of(2) {
549        return Err(Error::Parse(format!("invalid hex string: {hex}")));
550    }
551    (0..hex.len())
552        .step_by(2)
553        .map(|i| {
554            u8::from_str_radix(&hex[i..i + 2], 16)
555                .map_err(|e| Error::Parse(format!("invalid hex byte: {e}")))
556        })
557        .collect()
558}
559
560fn tlsa_cert_usage_from_u8(value: u8) -> crate::Result<TlsaCertUsage> {
561    Ok(match value {
562        0 => TlsaCertUsage::PkixTa,
563        1 => TlsaCertUsage::PkixEe,
564        2 => TlsaCertUsage::DaneTa,
565        3 => TlsaCertUsage::DaneEe,
566        255 => TlsaCertUsage::Private,
567        _ => return Err(Error::Parse(format!("unknown TLSA cert usage: {value}"))),
568    })
569}
570
571fn tlsa_selector_from_u8(value: u8) -> crate::Result<TlsaSelector> {
572    Ok(match value {
573        0 => TlsaSelector::Full,
574        1 => TlsaSelector::Spki,
575        255 => TlsaSelector::Private,
576        _ => return Err(Error::Parse(format!("unknown TLSA selector: {value}"))),
577    })
578}
579
580fn tlsa_matching_from_u8(value: u8) -> crate::Result<TlsaMatching> {
581    Ok(match value {
582        0 => TlsaMatching::Raw,
583        1 => TlsaMatching::Sha256,
584        2 => TlsaMatching::Sha512,
585        255 => TlsaMatching::Private,
586        _ => return Err(Error::Parse(format!("unknown TLSA matching: {value}"))),
587    })
588}