Skip to main content

dns_update/providers/
cloudflare.rs

1/*
2 * Copyright Stalwart Labs LLC See the COPYING
3 * file at the top-level directory of this distribution.
4 *
5 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 * https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 * <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
8 * option. This file may not be copied, modified, or distributed
9 * except according to those terms.
10 */
11
12use crate::{DnsRecord, Error, IntoFqdn, http::HttpClientBuilder};
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15use std::{
16    net::{Ipv4Addr, Ipv6Addr},
17    time::Duration,
18};
19
20#[derive(Clone)]
21pub struct CloudflareProvider {
22    client: HttpClientBuilder,
23}
24
25#[derive(Deserialize, Debug)]
26pub struct IdMap {
27    pub id: String,
28    pub name: String,
29}
30
31#[derive(Serialize, Debug)]
32pub struct Query {
33    name: String,
34}
35
36#[derive(Serialize, Clone, Debug)]
37pub struct CreateDnsRecordParams<'a> {
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub ttl: Option<u32>,
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub priority: Option<u16>,
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub proxied: Option<bool>,
44    pub name: &'a str,
45    #[serde(flatten)]
46    pub content: DnsContent,
47}
48
49#[derive(Serialize, Clone, Debug)]
50pub struct UpdateDnsRecordParams<'a> {
51    #[serde(skip_serializing_if = "Option::is_none")]
52    pub ttl: Option<u32>,
53    #[serde(skip_serializing_if = "Option::is_none")]
54    pub proxied: Option<bool>,
55    pub name: &'a str,
56    #[serde(flatten)]
57    pub content: DnsContent,
58}
59
60#[derive(Deserialize, Serialize, Clone, Debug)]
61#[serde(tag = "type")]
62#[allow(clippy::upper_case_acronyms)]
63pub enum DnsContent {
64    A { content: Ipv4Addr },
65    AAAA { content: Ipv6Addr },
66    CNAME { content: String },
67    NS { content: String },
68    MX { content: String, priority: u16 },
69    TXT { content: String },
70    SRV { data: SrvData },
71    TLSA { data: TlsaData },
72    CAA { content: String },
73}
74
75#[derive(Deserialize, Serialize, Clone, Debug)]
76pub struct SrvData {
77    pub priority: u16,
78    pub weight: u16,
79    pub port: u16,
80    pub target: String,
81}
82
83#[derive(Deserialize, Serialize, Clone, Debug)]
84pub struct TlsaData {
85    pub usage: u8,
86    pub selector: u8,
87    pub matching_type: u8,
88    pub certificate: String,
89}
90
91#[derive(Deserialize, Serialize, Debug)]
92struct ApiResult<T> {
93    errors: Vec<ApiError>,
94    success: bool,
95    result: T,
96}
97
98#[derive(Deserialize, Serialize, Debug)]
99pub struct ApiError {
100    pub code: u16,
101    pub message: String,
102}
103
104impl CloudflareProvider {
105    pub(crate) fn new(
106        secret: impl AsRef<str>,
107        email: Option<impl AsRef<str>>,
108        timeout: Option<Duration>,
109    ) -> crate::Result<Self> {
110        let client = if let Some(email) = email {
111            HttpClientBuilder::default()
112                .with_header("X-Auth-Email", email.as_ref())
113                .with_header("X-Auth-Key", secret.as_ref())
114        } else {
115            HttpClientBuilder::default()
116                .with_header("Authorization", format!("Bearer {}", secret.as_ref()))
117        }
118        .with_timeout(timeout);
119
120        Ok(Self { client })
121    }
122
123    async fn obtain_zone_id(&self, origin: impl IntoFqdn<'_>) -> crate::Result<String> {
124        let origin = origin.into_name();
125        self.client
126            .get(format!(
127                "https://api.cloudflare.com/client/v4/zones?{}",
128                Query::name(origin.as_ref()).serialize()
129            ))
130            .send_with_retry::<ApiResult<Vec<IdMap>>>(3)
131            .await
132            .and_then(|r| r.unwrap_response("list zones"))
133            .and_then(|result| {
134                result
135                    .into_iter()
136                    .find(|zone| zone.name == origin.as_ref())
137                    .map(|zone| zone.id)
138                    .ok_or_else(|| Error::Api(format!("Zone {} not found", origin.as_ref())))
139            })
140    }
141
142    async fn obtain_record_id(
143        &self,
144        zone_id: &str,
145        name: impl IntoFqdn<'_>,
146    ) -> crate::Result<String> {
147        let name = name.into_name();
148        self.client
149            .get(format!(
150                "https://api.cloudflare.com/client/v4/zones/{zone_id}/dns_records?{}",
151                Query::name(name.as_ref()).serialize()
152            ))
153            .send_with_retry::<ApiResult<Vec<IdMap>>>(3)
154            .await
155            .and_then(|r| r.unwrap_response("list DNS records"))
156            .and_then(|result| {
157                result
158                    .into_iter()
159                    .find(|record| record.name == name.as_ref())
160                    .map(|record| record.id)
161                    .ok_or_else(|| Error::Api(format!("DNS Record {} not found", name.as_ref())))
162            })
163    }
164
165    pub(crate) async fn create(
166        &self,
167        name: impl IntoFqdn<'_>,
168        record: DnsRecord,
169        ttl: u32,
170        origin: impl IntoFqdn<'_>,
171    ) -> crate::Result<()> {
172        self.client
173            .post(format!(
174                "https://api.cloudflare.com/client/v4/zones/{}/dns_records",
175                self.obtain_zone_id(origin).await?
176            ))
177            .with_body(CreateDnsRecordParams {
178                ttl: ttl.into(),
179                priority: record.priority(),
180                proxied: false.into(),
181                name: name.into_name().as_ref(),
182                content: record.into(),
183            })?
184            .send_with_retry::<ApiResult<Value>>(3)
185            .await
186            .map(|_| ())
187    }
188
189    pub(crate) async fn update(
190        &self,
191        name: impl IntoFqdn<'_>,
192        record: DnsRecord,
193        ttl: u32,
194        origin: impl IntoFqdn<'_>,
195    ) -> crate::Result<()> {
196        let name = name.into_name();
197        self.client
198            .patch(format!(
199                "https://api.cloudflare.com/client/v4/zones/{}/dns_records/{}",
200                self.obtain_zone_id(origin).await?,
201                name.as_ref()
202            ))
203            .with_body(UpdateDnsRecordParams {
204                ttl: ttl.into(),
205                proxied: None,
206                name: name.as_ref(),
207                content: record.into(),
208            })?
209            .send_with_retry::<ApiResult<Value>>(3)
210            .await
211            .map(|_| ())
212    }
213
214    pub(crate) async fn delete(
215        &self,
216        name: impl IntoFqdn<'_>,
217        origin: impl IntoFqdn<'_>,
218    ) -> crate::Result<()> {
219        let zone_id = self.obtain_zone_id(origin).await?;
220        let record_id = self.obtain_record_id(&zone_id, name).await?;
221
222        self.client
223            .delete(format!(
224                "https://api.cloudflare.com/client/v4/zones/{zone_id}/dns_records/{record_id}",
225            ))
226            .send_with_retry::<ApiResult<Value>>(3)
227            .await
228            .map(|_| ())
229    }
230}
231
232impl<T> ApiResult<T> {
233    fn unwrap_response(self, action_name: &str) -> crate::Result<T> {
234        if self.success {
235            Ok(self.result)
236        } else {
237            Err(Error::Api(format!(
238                "Failed to {action_name}: {:?}",
239                self.errors
240            )))
241        }
242    }
243}
244
245impl Query {
246    pub fn name(name: impl Into<String>) -> Self {
247        Self { name: name.into() }
248    }
249
250    pub fn serialize(&self) -> String {
251        serde_urlencoded::to_string(self).unwrap()
252    }
253}
254
255impl From<DnsRecord> for DnsContent {
256    fn from(record: DnsRecord) -> Self {
257        match record {
258            DnsRecord::A(content) => DnsContent::A { content },
259            DnsRecord::AAAA(content) => DnsContent::AAAA { content },
260            DnsRecord::CNAME(content) => DnsContent::CNAME { content },
261            DnsRecord::NS(content) => DnsContent::NS { content },
262            DnsRecord::MX(mx) => DnsContent::MX {
263                content: mx.exchange,
264                priority: mx.priority,
265            },
266            DnsRecord::TXT(content) => DnsContent::TXT { content },
267            DnsRecord::SRV(srv) => DnsContent::SRV {
268                data: SrvData {
269                    priority: srv.priority,
270                    weight: srv.weight,
271                    port: srv.port,
272                    target: srv.target,
273                },
274            },
275            DnsRecord::TLSA(tlsa) => DnsContent::TLSA {
276                data: TlsaData {
277                    usage: u8::from(tlsa.cert_usage),
278                    selector: u8::from(tlsa.selector),
279                    matching_type: u8::from(tlsa.matching),
280                    certificate: tlsa
281                        .cert_data
282                        .iter()
283                        .map(|b| format!("{b:02x}"))
284                        .collect(),
285                },
286            },
287            DnsRecord::CAA(caa) => DnsContent::CAA {
288                content: caa.to_string(),
289            },
290        }
291    }
292}