Skip to main content

dns_update/providers/
cloudflare.rs

1/*
2 * Copyright Stalwart Labs LLC See the COPYING
3 * file at the top-level directory of this distribution.
4 *
5 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 * https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 * <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
8 * option. This file may not be copied, modified, or distributed
9 * except according to those terms.
10 */
11
12use crate::{DnsRecord, DnsRecordType, Error, IntoFqdn, http::HttpClientBuilder};
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15use std::{
16    net::{Ipv4Addr, Ipv6Addr},
17    time::Duration,
18};
19
20#[derive(Clone)]
21pub struct CloudflareProvider {
22    client: HttpClientBuilder,
23}
24
25#[derive(Deserialize, Debug)]
26pub struct IdMap {
27    pub id: String,
28    pub name: String,
29}
30
31#[derive(Serialize, Debug)]
32pub struct Query {
33    name: String,
34    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
35    record_type: Option<&'static str>,
36    #[serde(rename = "match", skip_serializing_if = "Option::is_none")]
37    match_mode: Option<&'static str>,
38}
39
40#[derive(Serialize, Clone, Debug)]
41pub struct CreateDnsRecordParams<'a> {
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub ttl: Option<u32>,
44    #[serde(skip_serializing_if = "Option::is_none")]
45    pub priority: Option<u16>,
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub proxied: Option<bool>,
48    pub name: &'a str,
49    #[serde(flatten)]
50    pub content: DnsContent,
51}
52
53#[derive(Serialize, Clone, Debug)]
54pub struct UpdateDnsRecordParams<'a> {
55    #[serde(skip_serializing_if = "Option::is_none")]
56    pub ttl: Option<u32>,
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub proxied: Option<bool>,
59    pub name: &'a str,
60    #[serde(flatten)]
61    pub content: DnsContent,
62}
63
64#[derive(Deserialize, Serialize, Clone, Debug)]
65#[serde(tag = "type")]
66#[allow(clippy::upper_case_acronyms)]
67pub enum DnsContent {
68    A { content: Ipv4Addr },
69    AAAA { content: Ipv6Addr },
70    CNAME { content: String },
71    NS { content: String },
72    MX { content: String, priority: u16 },
73    TXT { content: String },
74    SRV { data: SrvData },
75    TLSA { data: TlsaData },
76    CAA { content: String },
77}
78
79#[derive(Deserialize, Serialize, Clone, Debug)]
80pub struct SrvData {
81    pub priority: u16,
82    pub weight: u16,
83    pub port: u16,
84    pub target: String,
85}
86
87#[derive(Deserialize, Serialize, Clone, Debug)]
88pub struct TlsaData {
89    pub usage: u8,
90    pub selector: u8,
91    pub matching_type: u8,
92    pub certificate: String,
93}
94
95#[derive(Deserialize, Serialize, Debug)]
96struct ApiResult<T> {
97    errors: Vec<ApiError>,
98    success: bool,
99    result: T,
100}
101
102#[derive(Deserialize, Serialize, Debug)]
103pub struct ApiError {
104    pub code: u16,
105    pub message: String,
106}
107
108impl CloudflareProvider {
109    pub(crate) fn new(
110        secret: impl AsRef<str>,
111        email: Option<impl AsRef<str>>,
112        timeout: Option<Duration>,
113    ) -> crate::Result<Self> {
114        let client = if let Some(email) = email {
115            HttpClientBuilder::default()
116                .with_header("X-Auth-Email", email.as_ref())
117                .with_header("X-Auth-Key", secret.as_ref())
118        } else {
119            HttpClientBuilder::default()
120                .with_header("Authorization", format!("Bearer {}", secret.as_ref()))
121        }
122        .with_timeout(timeout);
123
124        Ok(Self { client })
125    }
126
127    async fn obtain_zone_id(&self, origin: impl IntoFqdn<'_>) -> crate::Result<String> {
128        let origin = origin.into_name();
129        self.client
130            .get(format!(
131                "https://api.cloudflare.com/client/v4/zones?{}",
132                Query::name(origin.as_ref()).serialize()
133            ))
134            .send_with_retry::<ApiResult<Vec<IdMap>>>(3)
135            .await
136            .and_then(|r| r.unwrap_response("list zones"))
137            .and_then(|result| {
138                result
139                    .into_iter()
140                    .find(|zone| zone.name == origin.as_ref())
141                    .map(|zone| zone.id)
142                    .ok_or_else(|| Error::Api(format!("Zone {} not found", origin.as_ref())))
143            })
144    }
145
146    async fn obtain_record_id(
147        &self,
148        zone_id: &str,
149        name: impl IntoFqdn<'_>,
150        record_type: DnsRecordType,
151    ) -> crate::Result<String> {
152        let name = name.into_name();
153        self.client
154            .get(format!(
155                "https://api.cloudflare.com/client/v4/zones/{zone_id}/dns_records?{}",
156                Query::name_and_type(name.as_ref(), record_type).serialize()
157            ))
158            .send_with_retry::<ApiResult<Vec<IdMap>>>(3)
159            .await
160            .and_then(|r| r.unwrap_response("list DNS records"))
161            .and_then(|result| {
162                result
163                    .into_iter()
164                    .find(|record| record.name == name.as_ref())
165                    .map(|record| record.id)
166                    .ok_or_else(|| {
167                        Error::Api(format!(
168                            "DNS Record {} of type {} not found",
169                            name.as_ref(),
170                            record_type.as_str()
171                        ))
172                    })
173            })
174    }
175
176    pub(crate) async fn create(
177        &self,
178        name: impl IntoFqdn<'_>,
179        record: DnsRecord,
180        ttl: u32,
181        origin: impl IntoFqdn<'_>,
182    ) -> crate::Result<()> {
183        self.client
184            .post(format!(
185                "https://api.cloudflare.com/client/v4/zones/{}/dns_records",
186                self.obtain_zone_id(origin).await?
187            ))
188            .with_body(CreateDnsRecordParams {
189                ttl: ttl.into(),
190                priority: record.priority(),
191                proxied: false.into(),
192                name: name.into_name().as_ref(),
193                content: record.into(),
194            })?
195            .send_with_retry::<ApiResult<Value>>(3)
196            .await
197            .map(|_| ())
198    }
199
200    pub(crate) async fn update(
201        &self,
202        name: impl IntoFqdn<'_>,
203        record: DnsRecord,
204        ttl: u32,
205        origin: impl IntoFqdn<'_>,
206    ) -> crate::Result<()> {
207        let name = name.into_name();
208        self.client
209            .patch(format!(
210                "https://api.cloudflare.com/client/v4/zones/{}/dns_records/{}",
211                self.obtain_zone_id(origin).await?,
212                name.as_ref()
213            ))
214            .with_body(UpdateDnsRecordParams {
215                ttl: ttl.into(),
216                proxied: None,
217                name: name.as_ref(),
218                content: record.into(),
219            })?
220            .send_with_retry::<ApiResult<Value>>(3)
221            .await
222            .map(|_| ())
223    }
224
225    pub(crate) async fn delete(
226        &self,
227        name: impl IntoFqdn<'_>,
228        origin: impl IntoFqdn<'_>,
229        record_type: DnsRecordType,
230    ) -> crate::Result<()> {
231        let zone_id = self.obtain_zone_id(origin).await?;
232        let record_id = self.obtain_record_id(&zone_id, name, record_type).await?;
233
234        self.client
235            .delete(format!(
236                "https://api.cloudflare.com/client/v4/zones/{zone_id}/dns_records/{record_id}",
237            ))
238            .send_with_retry::<ApiResult<Value>>(3)
239            .await
240            .map(|_| ())
241    }
242}
243
244impl<T> ApiResult<T> {
245    fn unwrap_response(self, action_name: &str) -> crate::Result<T> {
246        if self.success {
247            Ok(self.result)
248        } else {
249            Err(Error::Api(format!(
250                "Failed to {action_name}: {:?}",
251                self.errors
252            )))
253        }
254    }
255}
256
257impl Query {
258    pub fn name(name: impl Into<String>) -> Self {
259        Self {
260            name: name.into(),
261            record_type: None,
262            match_mode: None,
263        }
264    }
265
266    pub fn name_and_type(name: impl Into<String>, record_type: DnsRecordType) -> Self {
267        Self {
268            name: name.into(),
269            record_type: Some(record_type.as_str()),
270            match_mode: Some("all"),
271        }
272    }
273
274    pub fn serialize(&self) -> String {
275        serde_urlencoded::to_string(self).unwrap()
276    }
277}
278
279impl From<DnsRecord> for DnsContent {
280    fn from(record: DnsRecord) -> Self {
281        match record {
282            DnsRecord::A(content) => DnsContent::A { content },
283            DnsRecord::AAAA(content) => DnsContent::AAAA { content },
284            DnsRecord::CNAME(content) => DnsContent::CNAME { content },
285            DnsRecord::NS(content) => DnsContent::NS { content },
286            DnsRecord::MX(mx) => DnsContent::MX {
287                content: mx.exchange,
288                priority: mx.priority,
289            },
290            DnsRecord::TXT(content) => DnsContent::TXT { content },
291            DnsRecord::SRV(srv) => DnsContent::SRV {
292                data: SrvData {
293                    priority: srv.priority,
294                    weight: srv.weight,
295                    port: srv.port,
296                    target: srv.target,
297                },
298            },
299            DnsRecord::TLSA(tlsa) => DnsContent::TLSA {
300                data: TlsaData {
301                    usage: u8::from(tlsa.cert_usage),
302                    selector: u8::from(tlsa.selector),
303                    matching_type: u8::from(tlsa.matching),
304                    certificate: tlsa
305                        .cert_data
306                        .iter()
307                        .map(|b| format!("{b:02x}"))
308                        .collect(),
309                },
310            },
311            DnsRecord::CAA(caa) => DnsContent::CAA {
312                content: caa.to_string(),
313            },
314        }
315    }
316}