Skip to main content

dns_update/providers/
cloudflare.rs

1/*
2 * Copyright Stalwart Labs LLC See the COPYING
3 * file at the top-level directory of this distribution.
4 *
5 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 * https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 * <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
8 * option. This file may not be copied, modified, or distributed
9 * except according to those terms.
10 */
11
12use crate::{DnsRecord, DnsRecordType, Error, IntoFqdn, http::HttpClientBuilder};
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15use std::{
16    net::{Ipv4Addr, Ipv6Addr},
17    time::Duration,
18};
19
20#[derive(Clone)]
21pub struct CloudflareProvider {
22    client: HttpClientBuilder,
23}
24
25#[derive(Deserialize, Debug)]
26pub struct IdMap {
27    pub id: String,
28    pub name: String,
29}
30
31#[derive(Serialize, Debug)]
32pub struct Query {
33    name: String,
34    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
35    record_type: Option<&'static str>,
36    #[serde(rename = "match", skip_serializing_if = "Option::is_none")]
37    match_mode: Option<&'static str>,
38}
39
40#[derive(Serialize, Clone, Debug)]
41pub struct CreateDnsRecordParams<'a> {
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub ttl: Option<u32>,
44    #[serde(skip_serializing_if = "Option::is_none")]
45    pub priority: Option<u16>,
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub proxied: Option<bool>,
48    pub name: &'a str,
49    #[serde(flatten)]
50    pub content: DnsContent,
51}
52
53#[derive(Serialize, Clone, Debug)]
54pub struct UpdateDnsRecordParams<'a> {
55    #[serde(skip_serializing_if = "Option::is_none")]
56    pub ttl: Option<u32>,
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub proxied: Option<bool>,
59    pub name: &'a str,
60    #[serde(flatten)]
61    pub content: DnsContent,
62}
63
64#[derive(Deserialize, Serialize, Clone, Debug)]
65#[serde(tag = "type")]
66#[allow(clippy::upper_case_acronyms)]
67pub enum DnsContent {
68    A { content: Ipv4Addr },
69    AAAA { content: Ipv6Addr },
70    CNAME { content: String },
71    NS { content: String },
72    MX { content: String, priority: u16 },
73    TXT { content: String },
74    SRV { data: SrvData },
75    TLSA { data: TlsaData },
76    CAA { data: CaaData },
77}
78
79#[derive(Deserialize, Serialize, Clone, Debug)]
80pub struct SrvData {
81    pub priority: u16,
82    pub weight: u16,
83    pub port: u16,
84    pub target: String,
85}
86
87#[derive(Deserialize, Serialize, Clone, Debug)]
88pub struct TlsaData {
89    pub usage: u8,
90    pub selector: u8,
91    pub matching_type: u8,
92    pub certificate: String,
93}
94
95#[derive(Deserialize, Serialize, Clone, Debug)]
96pub struct CaaData {
97    pub flags: u8,
98    pub tag: String,
99    pub value: String,
100}
101
102#[derive(Deserialize, Serialize, Debug)]
103struct ApiResult<T> {
104    errors: Vec<ApiError>,
105    success: bool,
106    result: T,
107}
108
109#[derive(Deserialize, Serialize, Debug)]
110pub struct ApiError {
111    pub code: u16,
112    pub message: String,
113}
114
115impl CloudflareProvider {
116    pub(crate) fn new(
117        secret: impl AsRef<str>,
118        email: Option<impl AsRef<str>>,
119        timeout: Option<Duration>,
120    ) -> crate::Result<Self> {
121        let client = if let Some(email) = email {
122            HttpClientBuilder::default()
123                .with_header("X-Auth-Email", email.as_ref())
124                .with_header("X-Auth-Key", secret.as_ref())
125        } else {
126            HttpClientBuilder::default()
127                .with_header("Authorization", format!("Bearer {}", secret.as_ref()))
128        }
129        .with_timeout(timeout);
130
131        Ok(Self { client })
132    }
133
134    async fn obtain_zone_id(&self, origin: impl IntoFqdn<'_>) -> crate::Result<String> {
135        let origin = origin.into_name();
136        self.client
137            .get(format!(
138                "https://api.cloudflare.com/client/v4/zones?{}",
139                Query::name(origin.as_ref()).serialize()
140            ))
141            .send_with_retry::<ApiResult<Vec<IdMap>>>(3)
142            .await
143            .and_then(|r| r.unwrap_response("list zones"))
144            .and_then(|result| {
145                result
146                    .into_iter()
147                    .find(|zone| zone.name == origin.as_ref())
148                    .map(|zone| zone.id)
149                    .ok_or_else(|| Error::Api(format!("Zone {} not found", origin.as_ref())))
150            })
151    }
152
153    async fn obtain_record_id(
154        &self,
155        zone_id: &str,
156        name: impl IntoFqdn<'_>,
157        record_type: DnsRecordType,
158    ) -> crate::Result<String> {
159        let name = name.into_name();
160        self.client
161            .get(format!(
162                "https://api.cloudflare.com/client/v4/zones/{zone_id}/dns_records?{}",
163                Query::name_and_type(name.as_ref(), record_type).serialize()
164            ))
165            .send_with_retry::<ApiResult<Vec<IdMap>>>(3)
166            .await
167            .and_then(|r| r.unwrap_response("list DNS records"))
168            .and_then(|result| {
169                result
170                    .into_iter()
171                    .find(|record| record.name == name.as_ref())
172                    .map(|record| record.id)
173                    .ok_or_else(|| {
174                        Error::Api(format!(
175                            "DNS Record {} of type {} not found",
176                            name.as_ref(),
177                            record_type.as_str()
178                        ))
179                    })
180            })
181    }
182
183    pub(crate) async fn create(
184        &self,
185        name: impl IntoFqdn<'_>,
186        record: DnsRecord,
187        ttl: u32,
188        origin: impl IntoFqdn<'_>,
189    ) -> crate::Result<()> {
190        self.client
191            .post(format!(
192                "https://api.cloudflare.com/client/v4/zones/{}/dns_records",
193                self.obtain_zone_id(origin).await?
194            ))
195            .with_body(CreateDnsRecordParams {
196                ttl: ttl.into(),
197                priority: record.priority(),
198                proxied: false.into(),
199                name: name.into_name().as_ref(),
200                content: record.into(),
201            })?
202            .send_with_retry::<ApiResult<Value>>(3)
203            .await
204            .map(|_| ())
205    }
206
207    pub(crate) async fn update(
208        &self,
209        name: impl IntoFqdn<'_>,
210        record: DnsRecord,
211        ttl: u32,
212        origin: impl IntoFqdn<'_>,
213    ) -> crate::Result<()> {
214        let name = name.into_name();
215        self.client
216            .patch(format!(
217                "https://api.cloudflare.com/client/v4/zones/{}/dns_records/{}",
218                self.obtain_zone_id(origin).await?,
219                name.as_ref()
220            ))
221            .with_body(UpdateDnsRecordParams {
222                ttl: ttl.into(),
223                proxied: None,
224                name: name.as_ref(),
225                content: record.into(),
226            })?
227            .send_with_retry::<ApiResult<Value>>(3)
228            .await
229            .map(|_| ())
230    }
231
232    pub(crate) async fn delete(
233        &self,
234        name: impl IntoFqdn<'_>,
235        origin: impl IntoFqdn<'_>,
236        record_type: DnsRecordType,
237    ) -> crate::Result<()> {
238        let zone_id = self.obtain_zone_id(origin).await?;
239        let record_id = self.obtain_record_id(&zone_id, name, record_type).await?;
240
241        self.client
242            .delete(format!(
243                "https://api.cloudflare.com/client/v4/zones/{zone_id}/dns_records/{record_id}",
244            ))
245            .send_with_retry::<ApiResult<Value>>(3)
246            .await
247            .map(|_| ())
248    }
249}
250
251impl<T> ApiResult<T> {
252    fn unwrap_response(self, action_name: &str) -> crate::Result<T> {
253        if self.success {
254            Ok(self.result)
255        } else {
256            Err(Error::Api(format!(
257                "Failed to {action_name}: {:?}",
258                self.errors
259            )))
260        }
261    }
262}
263
264impl Query {
265    pub fn name(name: impl Into<String>) -> Self {
266        Self {
267            name: name.into(),
268            record_type: None,
269            match_mode: None,
270        }
271    }
272
273    pub fn name_and_type(name: impl Into<String>, record_type: DnsRecordType) -> Self {
274        Self {
275            name: name.into(),
276            record_type: Some(record_type.as_str()),
277            match_mode: Some("all"),
278        }
279    }
280
281    pub fn serialize(&self) -> String {
282        serde_urlencoded::to_string(self).unwrap()
283    }
284}
285
286impl From<DnsRecord> for DnsContent {
287    fn from(record: DnsRecord) -> Self {
288        match record {
289            DnsRecord::A(content) => DnsContent::A { content },
290            DnsRecord::AAAA(content) => DnsContent::AAAA { content },
291            DnsRecord::CNAME(content) => DnsContent::CNAME { content },
292            DnsRecord::NS(content) => DnsContent::NS { content },
293            DnsRecord::MX(mx) => DnsContent::MX {
294                content: mx.exchange,
295                priority: mx.priority,
296            },
297            DnsRecord::TXT(content) => DnsContent::TXT { content },
298            DnsRecord::SRV(srv) => DnsContent::SRV {
299                data: SrvData {
300                    priority: srv.priority,
301                    weight: srv.weight,
302                    port: srv.port,
303                    target: srv.target,
304                },
305            },
306            DnsRecord::TLSA(tlsa) => DnsContent::TLSA {
307                data: TlsaData {
308                    usage: u8::from(tlsa.cert_usage),
309                    selector: u8::from(tlsa.selector),
310                    matching_type: u8::from(tlsa.matching),
311                    certificate: tlsa
312                        .cert_data
313                        .iter()
314                        .map(|b| format!("{b:02x}"))
315                        .collect(),
316                },
317            },
318            DnsRecord::CAA(caa) => {
319                let (flags, tag, value) = caa.decompose();
320                DnsContent::CAA {
321                    data: CaaData { flags, tag, value },
322                }
323            }
324        }
325    }
326}