Skip to main content

dns_update/providers/
porkbun.rs

1/*
2 * Copyright Stalwart Labs LLC See the COPYING
3 * file at the top-level directory of this distribution.
4 *
5 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 * https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 * <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
8 * option. This file may not be copied, modified, or distributed
9 * except according to those terms.
10 */
11
12use crate::{
13    CAARecord, DnsRecord, DnsRecordType, Error, IntoFqdn, KeyValue, MXRecord, SRVRecord,
14    TLSARecord, TlsaCertUsage, TlsaMatching, TlsaSelector,
15    http::{HttpClient, HttpClientBuilder},
16    utils::strip_origin_from_name,
17};
18use serde::{Deserialize, Deserializer, Serialize};
19use std::{
20    net::{Ipv4Addr, Ipv6Addr},
21    time::Duration,
22};
23
24#[derive(Clone)]
25pub struct PorkBunProvider {
26    client: HttpClient,
27    api_key: String,
28    secret_api_key: String,
29    endpoint: String,
30}
31
32#[derive(Serialize, Debug)]
33pub struct AuthParams<'a> {
34    pub secretapikey: &'a str,
35    pub apikey: &'a str,
36}
37
38#[derive(Serialize, Debug)]
39pub struct DnsRecordParams<'a> {
40    #[serde(flatten)]
41    pub auth: AuthParams<'a>,
42    pub name: &'a str,
43    #[serde(skip_serializing_if = "Option::is_none")]
44    pub ttl: Option<u32>,
45    #[serde(skip_serializing_if = "Option::is_none")]
46    pub notes: Option<&'a str>,
47    #[serde(flatten)]
48    content: RecordData,
49}
50
51#[derive(Deserialize, Debug)]
52pub struct ApiResponse {
53    pub status: String,
54    pub message: Option<String>,
55}
56
57#[derive(Deserialize, Debug)]
58struct RetrieveResponse {
59    status: String,
60    #[serde(default)]
61    message: Option<String>,
62    #[serde(default)]
63    records: Vec<ListedRecord>,
64}
65
66#[derive(Deserialize, Debug, Clone)]
67struct ListedRecord {
68    id: String,
69    #[serde(rename = "type")]
70    record_type: String,
71    content: String,
72    #[serde(default, deserialize_with = "deserialize_opt_u16_from_string")]
73    prio: Option<u16>,
74}
75
76#[derive(Serialize, Clone, Debug, PartialEq, Eq)]
77#[serde(tag = "type")]
78#[allow(clippy::upper_case_acronyms)]
79pub enum RecordData {
80    A { content: Ipv4Addr },
81    MX { content: String, prio: u16 },
82    CNAME { content: String },
83    ALIAS { content: String },
84    TXT { content: String },
85    NS { content: String },
86    AAAA { content: Ipv6Addr },
87    SRV { content: String, prio: u16 },
88    TLSA { content: String },
89    CAA { content: String },
90    HTTPS { content: String },
91    SVCB { content: String },
92    SSHFP { content: String },
93}
94
/// Base URL used by [`PorkBunProvider::new`] for every request.
const DEFAULT_API_ENDPOINT: &str = "https://api.porkbun.com/api/json/v3";
96
97impl PorkBunProvider {
98    pub(crate) fn new(
99        api_key: impl AsRef<str>,
100        secret_api_key: impl AsRef<str>,
101        timeout: Option<Duration>,
102    ) -> Self {
103        let client = HttpClientBuilder::default().with_timeout(timeout).build();
104
105        Self {
106            client,
107            api_key: api_key.as_ref().to_string(),
108            secret_api_key: secret_api_key.as_ref().to_string(),
109            endpoint: DEFAULT_API_ENDPOINT.to_string(),
110        }
111    }
112
113    #[cfg(test)]
114    pub(crate) fn with_endpoint(self, endpoint: impl AsRef<str>) -> Self {
115        Self {
116            endpoint: endpoint.as_ref().to_string(),
117            ..self
118        }
119    }
120
121    pub(crate) async fn set_rrset(
122        &self,
123        name: impl IntoFqdn<'_>,
124        record_type: DnsRecordType,
125        ttl: u32,
126        records: Vec<DnsRecord>,
127        origin: impl IntoFqdn<'_>,
128    ) -> crate::Result<()> {
129        let name = name.into_name().into_owned();
130        let domain = origin.into_name().into_owned();
131        let subdomain = strip_origin_from_name(&name, &domain, Some(""));
132
133        if records.is_empty() {
134            return self
135                .delete_by_name_type(&domain, record_type.as_str(), &subdomain)
136                .await;
137        }
138
139        let desired = build_record_data(record_type, records)?;
140
141        let existing = self
142            .retrieve_by_name_type(&domain, record_type.as_str(), &subdomain)
143            .await?;
144
145        let mut existing_pool = existing;
146        let mut to_add: Vec<RecordData> = Vec::new();
147
148        for data in desired {
149            if let Some(idx) = existing_pool.iter().position(|r| listed_matches(r, &data)) {
150                existing_pool.swap_remove(idx);
151            } else {
152                to_add.push(data);
153            }
154        }
155
156        for entry in existing_pool {
157            self.delete_record(&domain, &entry.id).await?;
158        }
159        for data in to_add {
160            self.create_record(&domain, &subdomain, ttl, data).await?;
161        }
162        Ok(())
163    }
164
165    pub(crate) async fn add_to_rrset(
166        &self,
167        name: impl IntoFqdn<'_>,
168        record_type: DnsRecordType,
169        ttl: u32,
170        records: Vec<DnsRecord>,
171        origin: impl IntoFqdn<'_>,
172    ) -> crate::Result<()> {
173        if records.is_empty() {
174            return Ok(());
175        }
176        let name = name.into_name().into_owned();
177        let domain = origin.into_name().into_owned();
178        let subdomain = strip_origin_from_name(&name, &domain, Some(""));
179        let desired = build_record_data(record_type, records)?;
180        let existing = self
181            .retrieve_by_name_type(&domain, record_type.as_str(), &subdomain)
182            .await?;
183
184        for data in desired {
185            if existing.iter().any(|r| listed_matches(r, &data)) {
186                continue;
187            }
188            self.create_record(&domain, &subdomain, ttl, data).await?;
189        }
190        Ok(())
191    }
192
193    pub(crate) async fn remove_from_rrset(
194        &self,
195        name: impl IntoFqdn<'_>,
196        record_type: DnsRecordType,
197        records: Vec<DnsRecord>,
198        origin: impl IntoFqdn<'_>,
199    ) -> crate::Result<()> {
200        if records.is_empty() {
201            return Ok(());
202        }
203        let name = name.into_name().into_owned();
204        let domain = origin.into_name().into_owned();
205        let subdomain = strip_origin_from_name(&name, &domain, Some(""));
206        let to_remove = build_record_data(record_type, records)?;
207        let existing = self
208            .retrieve_by_name_type(&domain, record_type.as_str(), &subdomain)
209            .await?;
210
211        for data in to_remove {
212            if let Some(entry) = existing.iter().find(|r| listed_matches(r, &data)) {
213                self.delete_record(&domain, &entry.id).await?;
214            }
215        }
216        Ok(())
217    }
218
219    pub(crate) async fn list_rrset(
220        &self,
221        name: impl IntoFqdn<'_>,
222        record_type: DnsRecordType,
223        origin: impl IntoFqdn<'_>,
224    ) -> crate::Result<Vec<DnsRecord>> {
225        let name = name.into_name().into_owned();
226        let domain = origin.into_name().into_owned();
227        let subdomain = strip_origin_from_name(&name, &domain, Some(""));
228        let listed = self
229            .retrieve_by_name_type(&domain, record_type.as_str(), &subdomain)
230            .await?;
231        listed
232            .into_iter()
233            .map(|r| listed_to_dns_record(r, record_type))
234            .collect()
235    }
236
237    fn auth(&self) -> AuthParams<'_> {
238        AuthParams {
239            secretapikey: &self.secret_api_key,
240            apikey: &self.api_key,
241        }
242    }
243
244    async fn retrieve_by_name_type(
245        &self,
246        domain: &str,
247        record_type: &str,
248        subdomain: &str,
249    ) -> crate::Result<Vec<ListedRecord>> {
250        let url = retrieve_by_name_type_url(&self.endpoint, domain, record_type, subdomain);
251        let response: RetrieveResponse = self
252            .client
253            .post(url)
254            .with_body(self.auth())?
255            .send_with_retry(3)
256            .await?;
257        if response.status == "SUCCESS" {
258            Ok(response
259                .records
260                .into_iter()
261                .filter(|r| r.record_type.eq_ignore_ascii_case(record_type))
262                .collect())
263        } else {
264            Err(Error::Api(response.status_message()))
265        }
266    }
267
268    async fn create_record(
269        &self,
270        domain: &str,
271        subdomain: &str,
272        ttl: u32,
273        content: RecordData,
274    ) -> crate::Result<()> {
275        self.client
276            .post(format!(
277                "{endpoint}/dns/create/{domain}",
278                endpoint = self.endpoint,
279            ))
280            .with_body(DnsRecordParams {
281                auth: self.auth(),
282                name: subdomain,
283                ttl: Some(ttl),
284                notes: None,
285                content,
286            })?
287            .send_with_retry::<ApiResponse>(3)
288            .await?
289            .into_result()
290    }
291
292    async fn delete_record(&self, domain: &str, record_id: &str) -> crate::Result<()> {
293        self.client
294            .post(format!(
295                "{endpoint}/dns/delete/{domain}/{record_id}",
296                endpoint = self.endpoint,
297            ))
298            .with_body(self.auth())?
299            .send_with_retry::<ApiResponse>(3)
300            .await?
301            .into_result()
302    }
303
304    async fn delete_by_name_type(
305        &self,
306        domain: &str,
307        record_type: &str,
308        subdomain: &str,
309    ) -> crate::Result<()> {
310        self.client
311            .post(delete_by_name_type_url(
312                &self.endpoint,
313                domain,
314                record_type,
315                subdomain,
316            ))
317            .with_body(self.auth())?
318            .send_with_retry::<ApiResponse>(3)
319            .await?
320            .into_result()
321    }
322
323}
324
/// Builds the URL of the retrieve-by-name-and-type endpoint.
///
/// The subdomain segment is appended only when `subdomain` is non-empty.
fn retrieve_by_name_type_url(
    endpoint: &str,
    domain: &str,
    record_type: &str,
    subdomain: &str,
) -> String {
    let mut url = format!("{endpoint}/dns/retrieveByNameType/{domain}/{record_type}");
    if !subdomain.is_empty() {
        url.push('/');
        url.push_str(subdomain);
    }
    url
}
337
/// Builds the URL of the delete-by-name-and-type endpoint.
///
/// The subdomain segment is appended only when `subdomain` is non-empty.
fn delete_by_name_type_url(
    endpoint: &str,
    domain: &str,
    record_type: &str,
    subdomain: &str,
) -> String {
    let mut url = format!("{endpoint}/dns/deleteByNameType/{domain}/{record_type}");
    if !subdomain.is_empty() {
        url.push('/');
        url.push_str(subdomain);
    }
    url
}
350
351fn build_record_data(
352    expected_type: DnsRecordType,
353    records: Vec<DnsRecord>,
354) -> crate::Result<Vec<RecordData>> {
355    let mut out = Vec::with_capacity(records.len());
356    for record in records {
357        if record.as_type() != expected_type {
358            return Err(Error::Api(format!(
359                "RRSet record type mismatch: expected {}, got {}",
360                expected_type.as_str(),
361                record.as_type().as_str(),
362            )));
363        }
364        out.push(record.into());
365    }
366    Ok(out)
367}
368
369fn listed_matches(listed: &ListedRecord, data: &RecordData) -> bool {
370    if !listed.record_type.eq_ignore_ascii_case(data.variant_name()) {
371        return false;
372    }
373    let (expected_content, expected_prio) = data.as_content_prio();
374    if listed.content.trim_end_matches('.') != expected_content.trim_end_matches('.') {
375        return false;
376    }
377    match expected_prio {
378        Some(p) => listed.prio == Some(p),
379        None => true,
380    }
381}
382
383fn listed_to_dns_record(
384    listed: ListedRecord,
385    record_type: DnsRecordType,
386) -> crate::Result<DnsRecord> {
387    let content = listed.content;
388    let prio = listed.prio;
389    Ok(match record_type {
390        DnsRecordType::A => DnsRecord::A(
391            content
392                .parse()
393                .map_err(|e| Error::Parse(format!("invalid A record content {content}: {e}")))?,
394        ),
395        DnsRecordType::AAAA => DnsRecord::AAAA(
396            content
397                .parse()
398                .map_err(|e| Error::Parse(format!("invalid AAAA record content {content}: {e}")))?,
399        ),
400        DnsRecordType::CNAME => DnsRecord::CNAME(content.trim_end_matches('.').to_string()),
401        DnsRecordType::NS => DnsRecord::NS(content.trim_end_matches('.').to_string()),
402        DnsRecordType::MX => DnsRecord::MX(MXRecord {
403            exchange: content.trim_end_matches('.').to_string(),
404            priority: prio.unwrap_or(0),
405        }),
406        DnsRecordType::TXT => DnsRecord::TXT(content),
407        DnsRecordType::SRV => DnsRecord::SRV(parse_srv_content(&content, prio.unwrap_or(0))?),
408        DnsRecordType::TLSA => DnsRecord::TLSA(parse_tlsa_content(&content)?),
409        DnsRecordType::CAA => DnsRecord::CAA(parse_caa_content(&content)?),
410    })
411}
412
413fn parse_srv_content(content: &str, priority: u16) -> crate::Result<SRVRecord> {
414    let parts: Vec<&str> = content.split_whitespace().collect();
415    if parts.len() != 3 {
416        return Err(Error::Parse(format!(
417            "invalid SRV content {content}: expected 'weight port target'"
418        )));
419    }
420    let weight: u16 = parts[0]
421        .parse()
422        .map_err(|e| Error::Parse(format!("invalid SRV weight {}: {e}", parts[0])))?;
423    let port: u16 = parts[1]
424        .parse()
425        .map_err(|e| Error::Parse(format!("invalid SRV port {}: {e}", parts[1])))?;
426    Ok(SRVRecord {
427        priority,
428        weight,
429        port,
430        target: parts[2].trim_end_matches('.').to_string(),
431    })
432}
433
434fn parse_tlsa_content(content: &str) -> crate::Result<TLSARecord> {
435    let parts: Vec<&str> = content.split_whitespace().collect();
436    if parts.len() != 4 {
437        return Err(Error::Parse(format!(
438            "invalid TLSA content {content}: expected 'usage selector matching hex'"
439        )));
440    }
441    let usage: u8 = parts[0]
442        .parse()
443        .map_err(|e| Error::Parse(format!("invalid TLSA usage: {e}")))?;
444    let selector: u8 = parts[1]
445        .parse()
446        .map_err(|e| Error::Parse(format!("invalid TLSA selector: {e}")))?;
447    let matching: u8 = parts[2]
448        .parse()
449        .map_err(|e| Error::Parse(format!("invalid TLSA matching: {e}")))?;
450    Ok(TLSARecord {
451        cert_usage: tlsa_cert_usage_from_u8(usage)?,
452        selector: tlsa_selector_from_u8(selector)?,
453        matching: tlsa_matching_from_u8(matching)?,
454        cert_data: decode_hex(parts[3])?,
455    })
456}
457
458fn parse_caa_content(content: &str) -> crate::Result<CAARecord> {
459    let trimmed = content.trim();
460    let (flags_str, rest) = trimmed
461        .split_once(char::is_whitespace)
462        .ok_or_else(|| Error::Parse(format!("invalid CAA content {content}: missing tag")))?;
463    let (tag, raw_value) = rest
464        .trim_start()
465        .split_once(char::is_whitespace)
466        .ok_or_else(|| Error::Parse(format!("invalid CAA content {content}: missing value")))?;
467    let flags: u8 = flags_str
468        .parse()
469        .map_err(|e| Error::Parse(format!("invalid CAA flags {flags_str}: {e}")))?;
470    let value = raw_value.trim().trim_matches('"').to_string();
471    build_caa(flags, tag.to_string(), value)
472}
473
474fn build_caa(flags: u8, tag: String, value: String) -> crate::Result<CAARecord> {
475    let issuer_critical = flags & 0x80 != 0;
476    match tag.as_str() {
477        "issue" => {
478            let (name, options) = parse_caa_value(&value);
479            Ok(CAARecord::Issue {
480                issuer_critical,
481                name,
482                options,
483            })
484        }
485        "issuewild" => {
486            let (name, options) = parse_caa_value(&value);
487            Ok(CAARecord::IssueWild {
488                issuer_critical,
489                name,
490                options,
491            })
492        }
493        "iodef" => Ok(CAARecord::Iodef {
494            issuer_critical,
495            url: value,
496        }),
497        other => Err(Error::Parse(format!("unknown CAA tag: {other}"))),
498    }
499}
500
501fn parse_caa_value(value: &str) -> (Option<String>, Vec<KeyValue>) {
502    let mut parts = value.split(';').map(str::trim);
503    let name_part = parts.next().unwrap_or("").trim().to_string();
504    let name = if name_part.is_empty() {
505        None
506    } else {
507        Some(name_part)
508    };
509    let options = parts
510        .filter(|p| !p.is_empty())
511        .map(|p| match p.split_once('=') {
512            Some((k, v)) => KeyValue {
513                key: k.trim().to_string(),
514                value: v.trim().to_string(),
515            },
516            None => KeyValue {
517                key: p.trim().to_string(),
518                value: String::new(),
519            },
520        })
521        .collect();
522    (name, options)
523}
524
525fn decode_hex(hex: &str) -> crate::Result<Vec<u8>> {
526    if !hex.len().is_multiple_of(2) {
527        return Err(Error::Parse(format!("invalid hex string: {hex}")));
528    }
529    (0..hex.len())
530        .step_by(2)
531        .map(|i| {
532            u8::from_str_radix(&hex[i..i + 2], 16)
533                .map_err(|e| Error::Parse(format!("invalid hex byte: {e}")))
534        })
535        .collect()
536}
537
538fn tlsa_cert_usage_from_u8(value: u8) -> crate::Result<TlsaCertUsage> {
539    Ok(match value {
540        0 => TlsaCertUsage::PkixTa,
541        1 => TlsaCertUsage::PkixEe,
542        2 => TlsaCertUsage::DaneTa,
543        3 => TlsaCertUsage::DaneEe,
544        255 => TlsaCertUsage::Private,
545        _ => return Err(Error::Parse(format!("unknown TLSA cert usage: {value}"))),
546    })
547}
548
549fn tlsa_selector_from_u8(value: u8) -> crate::Result<TlsaSelector> {
550    Ok(match value {
551        0 => TlsaSelector::Full,
552        1 => TlsaSelector::Spki,
553        255 => TlsaSelector::Private,
554        _ => return Err(Error::Parse(format!("unknown TLSA selector: {value}"))),
555    })
556}
557
558fn tlsa_matching_from_u8(value: u8) -> crate::Result<TlsaMatching> {
559    Ok(match value {
560        0 => TlsaMatching::Raw,
561        1 => TlsaMatching::Sha256,
562        2 => TlsaMatching::Sha512,
563        255 => TlsaMatching::Private,
564        _ => return Err(Error::Parse(format!("unknown TLSA matching: {value}"))),
565    })
566}
567
568fn deserialize_opt_u16_from_string<'de, D>(deserializer: D) -> Result<Option<u16>, D::Error>
569where
570    D: Deserializer<'de>,
571{
572    #[derive(Deserialize)]
573    #[serde(untagged)]
574    enum Either {
575        Str(String),
576        Num(u16),
577        None,
578    }
579    match Option::<Either>::deserialize(deserializer)? {
580        None | Some(Either::None) => Ok(None),
581        Some(Either::Num(n)) => Ok(Some(n)),
582        Some(Either::Str(s)) => {
583            if s.is_empty() {
584                Ok(None)
585            } else {
586                s.parse::<u16>().map(Some).map_err(serde::de::Error::custom)
587            }
588        }
589    }
590}
591
592impl ApiResponse {
593    fn into_result(self) -> crate::Result<()> {
594        if self.status == "SUCCESS" {
595            Ok(())
596        } else {
597            Err(Error::Api(self.message.unwrap_or(self.status)))
598        }
599    }
600}
601
602impl RetrieveResponse {
603    fn status_message(self) -> String {
604        self.message.unwrap_or(self.status)
605    }
606}
607
608impl RecordData {
609    pub fn variant_name(&self) -> &'static str {
610        match self {
611            RecordData::A { .. } => "A",
612            RecordData::MX { .. } => "MX",
613            RecordData::CNAME { .. } => "CNAME",
614            RecordData::ALIAS { .. } => "ALIAS",
615            RecordData::TXT { .. } => "TXT",
616            RecordData::NS { .. } => "NS",
617            RecordData::AAAA { .. } => "AAAA",
618            RecordData::SRV { .. } => "SRV",
619            RecordData::TLSA { .. } => "TLSA",
620            RecordData::CAA { .. } => "CAA",
621            RecordData::HTTPS { .. } => "HTTPS",
622            RecordData::SVCB { .. } => "SVCB",
623            RecordData::SSHFP { .. } => "SSHFP",
624        }
625    }
626
627    fn as_content_prio(&self) -> (String, Option<u16>) {
628        match self {
629            RecordData::A { content } => (content.to_string(), None),
630            RecordData::AAAA { content } => (content.to_string(), None),
631            RecordData::CNAME { content }
632            | RecordData::ALIAS { content }
633            | RecordData::NS { content }
634            | RecordData::TXT { content }
635            | RecordData::TLSA { content }
636            | RecordData::CAA { content }
637            | RecordData::HTTPS { content }
638            | RecordData::SVCB { content }
639            | RecordData::SSHFP { content } => (content.clone(), None),
640            RecordData::MX { content, prio } => (content.clone(), Some(*prio)),
641            RecordData::SRV { content, prio } => (content.clone(), Some(*prio)),
642        }
643    }
644
645}
646
/// Removes every trailing `.` from `value`, reusing its allocation.
fn strip_trailing_dot(mut value: String) -> String {
    let kept = value.trim_end_matches('.').len();
    value.truncate(kept);
    value
}
654
655impl From<DnsRecord> for RecordData {
656    fn from(record: DnsRecord) -> Self {
657        match record {
658            DnsRecord::A(content) => RecordData::A { content },
659            DnsRecord::AAAA(content) => RecordData::AAAA { content },
660            DnsRecord::CNAME(content) => RecordData::CNAME {
661                content: strip_trailing_dot(content),
662            },
663            DnsRecord::NS(content) => RecordData::NS {
664                content: strip_trailing_dot(content),
665            },
666            DnsRecord::MX(mx) => RecordData::MX {
667                content: strip_trailing_dot(mx.exchange),
668                prio: mx.priority,
669            },
670            DnsRecord::TXT(content) => RecordData::TXT { content },
671            DnsRecord::SRV(srv) => RecordData::SRV {
672                content: format!(
673                    "{} {} {}",
674                    srv.weight,
675                    srv.port,
676                    strip_trailing_dot(srv.target)
677                ),
678                prio: srv.priority,
679            },
680            DnsRecord::TLSA(tlsa) => RecordData::TLSA {
681                content: tlsa.to_string(),
682            },
683            DnsRecord::CAA(caa) => RecordData::CAA {
684                content: caa.to_string(),
685            },
686        }
687    }
688}