// dns_update/providers/cloudflare.rs

/*
 * Copyright Stalwart Labs LLC See the COPYING
 * file at the top-level directory of this distribution.
 *
 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
 * https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
 * <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
 * option. This file may not be copied, modified, or distributed
 * except according to those terms.
 */

use crate::{
    CAARecord, DnsRecord, DnsRecordType, Error, IntoFqdn, KeyValue, MXRecord, SRVRecord,
    TLSARecord, TlsaCertUsage, TlsaMatching, TlsaSelector,
    http::{HttpClient, HttpClientBuilder},
    utils::txt_chunks_to_text,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{
    borrow::Cow,
    net::{Ipv4Addr, Ipv6Addr},
    time::Duration,
};

26#[derive(Clone)]
27pub struct CloudflareProvider {
28    client: HttpClient,
29    endpoint: Cow<'static, str>,
30}
31
32#[derive(Deserialize, Debug)]
33pub struct IdMap {
34    pub id: String,
35    pub name: String,
36}
37
38#[derive(Serialize, Debug)]
39pub struct Query {
40    name: String,
41    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
42    record_type: Option<&'static str>,
43    #[serde(rename = "match", skip_serializing_if = "Option::is_none")]
44    match_mode: Option<&'static str>,
45}
46
47#[derive(Serialize, Clone, Debug)]
48pub struct CreateDnsRecordParams<'a> {
49    #[serde(skip_serializing_if = "Option::is_none")]
50    pub ttl: Option<u32>,
51    #[serde(skip_serializing_if = "Option::is_none")]
52    pub priority: Option<u16>,
53    #[serde(skip_serializing_if = "Option::is_none")]
54    pub proxied: Option<bool>,
55    pub name: &'a str,
56    #[serde(flatten)]
57    pub content: DnsContent,
58}
59
60#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Eq)]
61#[serde(tag = "type")]
62#[allow(clippy::upper_case_acronyms)]
63pub enum DnsContent {
64    A { content: Ipv4Addr },
65    AAAA { content: Ipv6Addr },
66    CNAME { content: String },
67    NS { content: String },
68    MX { content: String, priority: u16 },
69    TXT { content: String },
70    SRV { data: SrvData },
71    TLSA { data: TlsaData },
72    CAA { data: CaaData },
73}
74
75#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Eq)]
76pub struct SrvData {
77    pub priority: u16,
78    pub weight: u16,
79    pub port: u16,
80    pub target: String,
81}
82
83#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Eq)]
84pub struct TlsaData {
85    pub usage: u8,
86    pub selector: u8,
87    pub matching_type: u8,
88    pub certificate: String,
89}
90
91#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Eq)]
92pub struct CaaData {
93    pub flags: u8,
94    pub tag: String,
95    pub value: String,
96}
97
98#[derive(Deserialize, Debug, Clone)]
99struct ListedRecord {
100    id: String,
101    #[serde(flatten)]
102    content: DnsContent,
103}
104
105#[derive(Deserialize, Serialize, Debug)]
106struct ApiResult<T> {
107    errors: Vec<ApiError>,
108    success: bool,
109    result: T,
110}
111
/// Base URL of the Cloudflare v4 REST API used by `CloudflareProvider::new`.
const DEFAULT_API_ENDPOINT: &str = "https://api.cloudflare.com/client/v4";

114#[derive(Deserialize, Serialize, Debug)]
115pub struct ApiError {
116    pub code: u16,
117    pub message: String,
118}
119
120impl CloudflareProvider {
121    pub(crate) fn new(secret: impl AsRef<str>, timeout: Option<Duration>) -> crate::Result<Self> {
122        let client = HttpClientBuilder::default()
123            .with_header("Authorization", format!("Bearer {}", secret.as_ref()))
124            .with_timeout(timeout)
125            .build();
126
127        Ok(Self {
128            client,
129            endpoint: Cow::Borrowed(DEFAULT_API_ENDPOINT),
130        })
131    }
132
133    #[cfg(test)]
134    pub(crate) fn with_endpoint(self, endpoint: impl Into<Cow<'static, str>>) -> Self {
135        Self {
136            endpoint: endpoint.into(),
137            ..self
138        }
139    }
140
141    async fn obtain_zone_id(&self, origin: impl IntoFqdn<'_>) -> crate::Result<String> {
142        let origin = origin.into_name();
143        let mut candidate: &str = origin.as_ref();
144        loop {
145            let zones = self
146                .client
147                .get(format!(
148                    "{}/zones?{}",
149                    self.endpoint,
150                    Query::name(candidate).serialize()
151                ))
152                .send_with_retry::<ApiResult<Vec<IdMap>>>(3)
153                .await
154                .and_then(|r| r.unwrap_response("list zones"))?;
155            if let Some(zone) = zones.into_iter().find(|zone| zone.name == candidate) {
156                return Ok(zone.id);
157            }
158            match candidate.split_once('.') {
159                Some((_, rest)) if rest.contains('.') => candidate = rest,
160                _ => {
161                    return Err(Error::Api(format!(
162                        "No Cloudflare zone found for {}",
163                        origin.as_ref()
164                    )));
165                }
166            }
167        }
168    }
169
170    pub(crate) async fn set_rrset(
171        &self,
172        name: impl IntoFqdn<'_>,
173        record_type: DnsRecordType,
174        ttl: u32,
175        records: Vec<DnsRecord>,
176        origin: impl IntoFqdn<'_>,
177    ) -> crate::Result<()> {
178        let zone_id = self.obtain_zone_id(origin).await?;
179        let name = name.into_name().into_owned();
180        let desired = build_contents(record_type, records)?;
181        let existing = self.list_at(&zone_id, &name, record_type).await?;
182
183        let mut to_add = Vec::new();
184        let mut existing_unmatched: Vec<ListedRecord> = Vec::new();
185        let mut existing_iter = existing.into_iter();
186        let mut existing_pool: Vec<ListedRecord> = existing_iter.by_ref().collect();
187
188        for content in desired {
189            if let Some(idx) = existing_pool.iter().position(|r| r.content == content) {
190                existing_pool.swap_remove(idx);
191            } else {
192                to_add.push(content);
193            }
194        }
195        existing_unmatched.append(&mut existing_pool);
196
197        for entry in existing_unmatched {
198            self.delete_record(&zone_id, &entry.id).await?;
199        }
200        for content in to_add {
201            self.create_record(&zone_id, &name, ttl, content).await?;
202        }
203        Ok(())
204    }
205
206    pub(crate) async fn add_to_rrset(
207        &self,
208        name: impl IntoFqdn<'_>,
209        record_type: DnsRecordType,
210        ttl: u32,
211        records: Vec<DnsRecord>,
212        origin: impl IntoFqdn<'_>,
213    ) -> crate::Result<()> {
214        if records.is_empty() {
215            return Ok(());
216        }
217        let zone_id = self.obtain_zone_id(origin).await?;
218        let name = name.into_name().into_owned();
219        let desired = build_contents(record_type, records)?;
220        let existing = self.list_at(&zone_id, &name, record_type).await?;
221
222        for content in desired {
223            if existing.iter().any(|r| r.content == content) {
224                continue;
225            }
226            self.create_record(&zone_id, &name, ttl, content).await?;
227        }
228        Ok(())
229    }
230
231    pub(crate) async fn remove_from_rrset(
232        &self,
233        name: impl IntoFqdn<'_>,
234        record_type: DnsRecordType,
235        records: Vec<DnsRecord>,
236        origin: impl IntoFqdn<'_>,
237    ) -> crate::Result<()> {
238        if records.is_empty() {
239            return Ok(());
240        }
241        let zone_id = self.obtain_zone_id(origin).await?;
242        let name = name.into_name().into_owned();
243        let to_remove = build_contents(record_type, records)?;
244        let existing = self.list_at(&zone_id, &name, record_type).await?;
245
246        for content in to_remove {
247            if let Some(entry) = existing.iter().find(|r| r.content == content) {
248                self.delete_record(&zone_id, &entry.id).await?;
249            }
250        }
251        Ok(())
252    }
253
254    pub(crate) async fn list_rrset(
255        &self,
256        name: impl IntoFqdn<'_>,
257        record_type: DnsRecordType,
258        origin: impl IntoFqdn<'_>,
259    ) -> crate::Result<Vec<DnsRecord>> {
260        let zone_id = self.obtain_zone_id(origin).await?;
261        let name = name.into_name().into_owned();
262        let listed = self.list_at(&zone_id, &name, record_type).await?;
263        listed.into_iter().map(|r| r.content.try_into()).collect()
264    }
265
266    #[cfg(test)]
267    pub(crate) async fn list_contents_for_tests(
268        &self,
269        name: impl IntoFqdn<'_>,
270        record_type: DnsRecordType,
271        origin: impl IntoFqdn<'_>,
272    ) -> crate::Result<Vec<DnsContent>> {
273        let zone_id = self.obtain_zone_id(origin).await?;
274        let name = name.into_name().into_owned();
275        let listed = self.list_at(&zone_id, &name, record_type).await?;
276        Ok(listed.into_iter().map(|r| r.content).collect())
277    }
278
279    async fn list_at(
280        &self,
281        zone_id: &str,
282        name: &str,
283        record_type: DnsRecordType,
284    ) -> crate::Result<Vec<ListedRecord>> {
285        let url = format!(
286            "{}/zones/{zone_id}/dns_records?{}&per_page=100",
287            self.endpoint,
288            Query::name_and_type(name, record_type).serialize()
289        );
290        let response: ApiResult<Vec<ListedRecord>> =
291            self.client.get(url).send_with_retry(3).await?;
292        response.unwrap_response("list DNS records")
293    }
294
295    async fn create_record(
296        &self,
297        zone_id: &str,
298        name: &str,
299        ttl: u32,
300        content: DnsContent,
301    ) -> crate::Result<()> {
302        let priority = match &content {
303            DnsContent::MX { priority, .. } => Some(*priority),
304            _ => None,
305        };
306        self.client
307            .post(format!("{}/zones/{zone_id}/dns_records", self.endpoint))
308            .with_body(CreateDnsRecordParams {
309                ttl: Some(ttl),
310                priority,
311                proxied: Some(false),
312                name,
313                content,
314            })?
315            .send_with_retry::<ApiResult<Value>>(3)
316            .await
317            .map(|_| ())
318    }
319
320    async fn delete_record(&self, zone_id: &str, record_id: &str) -> crate::Result<()> {
321        self.client
322            .delete(format!(
323                "{}/zones/{zone_id}/dns_records/{record_id}",
324                self.endpoint
325            ))
326            .send_with_retry::<ApiResult<Value>>(3)
327            .await
328            .map(|_| ())
329    }
330}
331
332fn build_contents(
333    expected_type: DnsRecordType,
334    records: Vec<DnsRecord>,
335) -> crate::Result<Vec<DnsContent>> {
336    let mut out = Vec::with_capacity(records.len());
337    for record in records {
338        if record.as_type() != expected_type {
339            return Err(Error::Api(format!(
340                "RRSet record type mismatch: expected {}, got {}",
341                expected_type.as_str(),
342                record.as_type().as_str(),
343            )));
344        }
345        out.push(record.into());
346    }
347    Ok(out)
348}
349
350impl<T> ApiResult<T> {
351    fn unwrap_response(self, action_name: &str) -> crate::Result<T> {
352        if self.success {
353            Ok(self.result)
354        } else {
355            Err(Error::Api(format!(
356                "Failed to {action_name}: {:?}",
357                self.errors
358            )))
359        }
360    }
361}
362
363impl Query {
364    pub fn name(name: impl Into<String>) -> Self {
365        Self {
366            name: name.into(),
367            record_type: None,
368            match_mode: None,
369        }
370    }
371
372    pub fn name_and_type(name: impl Into<String>, record_type: DnsRecordType) -> Self {
373        Self {
374            name: name.into(),
375            record_type: Some(record_type.as_str()),
376            match_mode: Some("all"),
377        }
378    }
379
380    pub fn serialize(&self) -> String {
381        serde_urlencoded::to_string(self).unwrap()
382    }
383}
384
385impl From<DnsRecord> for DnsContent {
386    fn from(record: DnsRecord) -> Self {
387        match record {
388            DnsRecord::A(content) => DnsContent::A { content },
389            DnsRecord::AAAA(content) => DnsContent::AAAA { content },
390            DnsRecord::CNAME(content) => DnsContent::CNAME { content },
391            DnsRecord::NS(content) => DnsContent::NS { content },
392            DnsRecord::MX(mx) => DnsContent::MX {
393                content: mx.exchange,
394                priority: mx.priority,
395            },
396            DnsRecord::TXT(content) => {
397                let mut out = String::with_capacity(content.len() + 4);
398                txt_chunks_to_text(&mut out, &content, " ");
399                DnsContent::TXT { content: out }
400            }
401            DnsRecord::SRV(srv) => DnsContent::SRV {
402                data: SrvData {
403                    priority: srv.priority,
404                    weight: srv.weight,
405                    port: srv.port,
406                    target: srv.target,
407                },
408            },
409            DnsRecord::TLSA(tlsa) => DnsContent::TLSA {
410                data: TlsaData {
411                    usage: u8::from(tlsa.cert_usage),
412                    selector: u8::from(tlsa.selector),
413                    matching_type: u8::from(tlsa.matching),
414                    certificate: tlsa.cert_data.iter().map(|b| format!("{b:02x}")).collect(),
415                },
416            },
417            DnsRecord::CAA(caa) => {
418                let (flags, tag, value) = caa.decompose();
419                DnsContent::CAA {
420                    data: CaaData { flags, tag, value },
421                }
422            }
423        }
424    }
425}
426
427impl TryFrom<DnsContent> for DnsRecord {
428    type Error = Error;
429
430    fn try_from(content: DnsContent) -> crate::Result<Self> {
431        Ok(match content {
432            DnsContent::A { content } => DnsRecord::A(content),
433            DnsContent::AAAA { content } => DnsRecord::AAAA(content),
434            DnsContent::CNAME { content } => DnsRecord::CNAME(content),
435            DnsContent::NS { content } => DnsRecord::NS(content),
436            DnsContent::MX { content, priority } => DnsRecord::MX(MXRecord {
437                exchange: content,
438                priority,
439            }),
440            DnsContent::TXT { content } => DnsRecord::TXT(unquote_txt(&content)),
441            DnsContent::SRV { data } => DnsRecord::SRV(SRVRecord {
442                priority: data.priority,
443                weight: data.weight,
444                port: data.port,
445                target: data.target,
446            }),
447            DnsContent::TLSA { data } => DnsRecord::TLSA(TLSARecord {
448                cert_usage: tlsa_cert_usage_from_u8(data.usage)?,
449                selector: tlsa_selector_from_u8(data.selector)?,
450                matching: tlsa_matching_from_u8(data.matching_type)?,
451                cert_data: decode_hex(&data.certificate)?,
452            }),
453            DnsContent::CAA { data } => DnsRecord::CAA(build_caa(data)?),
454        })
455    }
456}
457
/// Turns TXT content as returned by the API into its plain text value.
///
/// Content that does not start with `"` (after trimming) is returned trimmed. Otherwise
/// every `"..."` chunk is decoded and the chunks are concatenated without separators;
/// anything between chunks is skipped. Inside a chunk, `\DDD` (three decimal digits,
/// value <= 255) yields that octet and `\X` yields `X` literally (so `\"` and `\\`
/// work), per RFC 1035 master-file escaping. An unterminated final chunk runs to the
/// end of the input.
///
/// Decoding works on bytes and converts to UTF-8 once at the end, so multi-byte
/// characters survive intact. Invalid UTF-8 produced by escapes becomes U+FFFD.
fn unquote_txt(content: &str) -> String {
    let trimmed = content.trim();
    if !trimmed.starts_with('"') {
        return trimmed.to_string();
    }
    let bytes = trimmed.as_bytes();
    let mut out: Vec<u8> = Vec::with_capacity(bytes.len());
    let mut pos = 0;
    while pos < bytes.len() {
        // Skip whatever sits between quoted chunks (normally a space).
        if bytes[pos] != b'"' {
            pos += 1;
            continue;
        }
        pos += 1;
        while let Some(&byte) = bytes.get(pos) {
            pos += 1;
            match byte {
                b'"' => break,
                b'\\' => {
                    let decimal = bytes
                        .get(pos..pos + 3)
                        .filter(|digits| digits.iter().all(u8::is_ascii_digit))
                        .and_then(|digits| {
                            let value = digits
                                .iter()
                                .fold(0u16, |acc, d| acc * 10 + u16::from(d - b'0'));
                            u8::try_from(value).ok()
                        });
                    if let Some(octet) = decimal {
                        out.push(octet);
                        pos += 3;
                    } else if let Some(&next) = bytes.get(pos) {
                        out.push(next);
                        pos += 1;
                    }
                }
                other => out.push(other),
            }
        }
    }
    String::from_utf8_lossy(&out).into_owned()
}

487fn decode_hex(hex: &str) -> crate::Result<Vec<u8>> {
488    if !hex.len().is_multiple_of(2) {
489        return Err(Error::Parse(format!("invalid hex string: {hex}")));
490    }
491    (0..hex.len())
492        .step_by(2)
493        .map(|i| {
494            u8::from_str_radix(&hex[i..i + 2], 16)
495                .map_err(|e| Error::Parse(format!("invalid hex byte: {e}")))
496        })
497        .collect()
498}
499
500fn tlsa_cert_usage_from_u8(value: u8) -> crate::Result<TlsaCertUsage> {
501    Ok(match value {
502        0 => TlsaCertUsage::PkixTa,
503        1 => TlsaCertUsage::PkixEe,
504        2 => TlsaCertUsage::DaneTa,
505        3 => TlsaCertUsage::DaneEe,
506        255 => TlsaCertUsage::Private,
507        _ => return Err(Error::Parse(format!("unknown TLSA cert usage: {value}"))),
508    })
509}
510
511fn tlsa_selector_from_u8(value: u8) -> crate::Result<TlsaSelector> {
512    Ok(match value {
513        0 => TlsaSelector::Full,
514        1 => TlsaSelector::Spki,
515        255 => TlsaSelector::Private,
516        _ => return Err(Error::Parse(format!("unknown TLSA selector: {value}"))),
517    })
518}
519
520fn tlsa_matching_from_u8(value: u8) -> crate::Result<TlsaMatching> {
521    Ok(match value {
522        0 => TlsaMatching::Raw,
523        1 => TlsaMatching::Sha256,
524        2 => TlsaMatching::Sha512,
525        255 => TlsaMatching::Private,
526        _ => return Err(Error::Parse(format!("unknown TLSA matching: {value}"))),
527    })
528}
529
530fn build_caa(data: CaaData) -> crate::Result<CAARecord> {
531    let issuer_critical = data.flags & 0x80 != 0;
532    match data.tag.as_str() {
533        "issue" => {
534            let (name, options) = parse_caa_value(&data.value);
535            Ok(CAARecord::Issue {
536                issuer_critical,
537                name,
538                options,
539            })
540        }
541        "issuewild" => {
542            let (name, options) = parse_caa_value(&data.value);
543            Ok(CAARecord::IssueWild {
544                issuer_critical,
545                name,
546                options,
547            })
548        }
549        "iodef" => Ok(CAARecord::Iodef {
550            issuer_critical,
551            url: data.value,
552        }),
553        other => Err(Error::Parse(format!("unknown CAA tag: {other}"))),
554    }
555}
556
557fn parse_caa_value(value: &str) -> (Option<String>, Vec<KeyValue>) {
558    let mut parts = value.split(';').map(str::trim);
559    let name_part = parts.next().unwrap_or("").trim().to_string();
560    let name = if name_part.is_empty() {
561        None
562    } else {
563        Some(name_part)
564    };
565    let options = parts
566        .filter(|p| !p.is_empty())
567        .map(|p| match p.split_once('=') {
568            Some((k, v)) => KeyValue {
569                key: k.trim().to_string(),
570                value: v.trim().to_string(),
571            },
572            None => KeyValue {
573                key: p.trim().to_string(),
574                value: String::new(),
575            },
576        })
577        .collect();
578    (name, options)
579}