// dns_update/providers/cloudflare.rs

/*
 * Copyright Stalwart Labs LLC See the COPYING
 * file at the top-level directory of this distribution.
 *
 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
 * https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
 * <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
 * option. This file may not be copied, modified, or distributed
 * except according to those terms.
 */
11
12use crate::{DnsRecord, DnsRecordType, Error, IntoFqdn, http::HttpClientBuilder};
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15use std::{
16    borrow::Cow,
17    net::{Ipv4Addr, Ipv6Addr},
18    time::Duration,
19};
20
/// DNS provider that manages records through the Cloudflare v4 REST API.
#[derive(Clone)]
pub struct CloudflareProvider {
    /// Request builder template; `new` configures it with a `Bearer` `Authorization`
    /// header and the optional timeout.
    client: HttpClientBuilder,
    /// API base URL: `DEFAULT_API_ENDPOINT` by default, replaceable in test builds via
    /// `with_endpoint`.
    endpoint: Cow<'static, str>,
}
26
/// Minimal `{ id, name }` view of a Cloudflare zone or DNS record, used to map names to ids.
///
/// Any other fields present in the API object are ignored during deserialization.
#[derive(Deserialize, Debug)]
pub struct IdMap {
    /// Cloudflare object identifier.
    pub id: String,
    /// Zone or record name as returned by Cloudflare.
    pub name: String,
}
32
/// Query-string filter for Cloudflare list endpoints (encoded with `serde_urlencoded`).
///
/// Field order determines parameter order in the encoded string; `None` fields are omitted.
#[derive(Serialize, Debug)]
pub struct Query {
    /// Name to filter by, sent as `name=`.
    name: String,
    /// Optional record type filter, sent as `type=`.
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    record_type: Option<&'static str>,
    /// Optional match mode, sent as `match=`.
    #[serde(rename = "match", skip_serializing_if = "Option::is_none")]
    match_mode: Option<&'static str>,
}
41
/// JSON body for creating a DNS record.
///
/// `content` is flattened, so the record's `type` tag and its payload fields (`content` or
/// `data`, plus `priority` for MX) are serialized at the top level next to `name`.
/// NOTE(review): for MX records both this struct's `priority` and `DnsContent::MX`'s
/// `priority` may be emitted, producing a duplicate JSON key — confirm this is harmless.
#[derive(Serialize, Clone, Debug)]
pub struct CreateDnsRecordParams<'a> {
    /// Record TTL; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ttl: Option<u32>,
    /// Top-level record priority; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub priority: Option<u16>,
    /// Cloudflare `proxied` flag; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub proxied: Option<bool>,
    /// Record name.
    pub name: &'a str,
    /// Record type and type-specific payload, flattened into this object.
    #[serde(flatten)]
    pub content: DnsContent,
}
54
/// JSON body for updating an existing DNS record.
///
/// Unlike the create body there is no top-level `priority`; an MX priority is carried
/// inside the flattened `content`.
#[derive(Serialize, Clone, Debug)]
pub struct UpdateDnsRecordParams<'a> {
    /// Record TTL; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ttl: Option<u32>,
    /// Cloudflare `proxied` flag; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub proxied: Option<bool>,
    /// Record name.
    pub name: &'a str,
    /// Record type and type-specific payload, flattened into this object.
    #[serde(flatten)]
    pub content: DnsContent,
}
65
/// Type-specific payload of a Cloudflare DNS record.
///
/// Internally tagged: the variant name is written to the `type` field, and the variant's
/// fields (`content`, `data`, `priority`) sit next to it in the same JSON object.
#[derive(Deserialize, Serialize, Clone, Debug)]
#[serde(tag = "type")]
// Variant names double as the serde `type` tag values (there is no rename), so they are
// kept as upper-case record-type acronyms.
#[allow(clippy::upper_case_acronyms)]
pub enum DnsContent {
    /// IPv4 address record.
    A { content: Ipv4Addr },
    /// IPv6 address record.
    AAAA { content: Ipv6Addr },
    /// Canonical-name (alias) target.
    CNAME { content: String },
    /// Name-server host.
    NS { content: String },
    /// Mail exchanger host and its priority.
    MX { content: String, priority: u16 },
    /// Text content (already quoted by the `From<DnsRecord>` conversion in this file).
    TXT { content: String },
    /// Service locator, sent as a structured `data` object.
    SRV { data: SrvData },
    /// TLS association, sent as a structured `data` object.
    TLSA { data: TlsaData },
    /// Certification authority authorization, sent as a structured `data` object.
    CAA { data: CaaData },
}
80
/// Structured `data` payload of an SRV record.
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct SrvData {
    /// SRV priority.
    pub priority: u16,
    /// SRV weight.
    pub weight: u16,
    /// Service port.
    pub port: u16,
    /// Target host name.
    pub target: String,
}
88
/// Structured `data` payload of a TLSA record.
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct TlsaData {
    /// Certificate usage field.
    pub usage: u8,
    /// Selector field.
    pub selector: u8,
    /// Matching-type field.
    pub matching_type: u8,
    /// Certificate association data as a hex string (the `From<DnsRecord>` conversion in
    /// this file writes lowercase hex).
    pub certificate: String,
}
96
/// Structured `data` payload of a CAA record.
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct CaaData {
    /// CAA flags byte.
    pub flags: u8,
    /// Property tag.
    pub tag: String,
    /// Property value.
    pub value: String,
}
103
/// Envelope of a Cloudflare v4 API response (`success`, `errors`, `result`).
///
/// NOTE(review): `result` is required; if Cloudflare sends `"result": null` on failure, a
/// non-nullable `T` (e.g. `Vec<IdMap>`) will fail to deserialize before `success`/`errors`
/// can be inspected — confirm.
#[derive(Deserialize, Serialize, Debug)]
struct ApiResult<T> {
    /// Error entries reported by the API.
    errors: Vec<ApiError>,
    /// Whether the API reported the call as successful.
    success: bool,
    /// Response payload.
    result: T,
}
110
/// Base URL of the public Cloudflare v4 API, used by `CloudflareProvider::new`.
const DEFAULT_API_ENDPOINT: &str = "https://api.cloudflare.com/client/v4";
112
/// One entry of the `errors` array in a Cloudflare API response.
#[derive(Deserialize, Serialize, Debug)]
pub struct ApiError {
    /// Cloudflare error code.
    pub code: u16,
    /// Error message text.
    pub message: String,
}
118
119impl CloudflareProvider {
120    pub(crate) fn new(secret: impl AsRef<str>, timeout: Option<Duration>) -> crate::Result<Self> {
121        let client = HttpClientBuilder::default()
122            .with_header("Authorization", format!("Bearer {}", secret.as_ref()))
123            .with_timeout(timeout);
124
125        Ok(Self {
126            client,
127            endpoint: Cow::Borrowed(DEFAULT_API_ENDPOINT),
128        })
129    }
130
131    #[cfg(test)]
132    pub(crate) fn with_endpoint(self, endpoint: impl Into<Cow<'static, str>>) -> Self {
133        Self {
134            endpoint: endpoint.into(),
135            ..self
136        }
137    }
138
139    async fn obtain_zone_id(&self, origin: impl IntoFqdn<'_>) -> crate::Result<String> {
140        let origin = origin.into_name();
141        let mut candidate: &str = origin.as_ref();
142        loop {
143            let zones = self
144                .client
145                .get(format!(
146                    "{}/zones?{}",
147                    self.endpoint,
148                    Query::name(candidate).serialize()
149                ))
150                .send_with_retry::<ApiResult<Vec<IdMap>>>(3)
151                .await
152                .and_then(|r| r.unwrap_response("list zones"))?;
153            if let Some(zone) = zones.into_iter().find(|zone| zone.name == candidate) {
154                return Ok(zone.id);
155            }
156            match candidate.split_once('.') {
157                Some((_, rest)) if rest.contains('.') => candidate = rest,
158                _ => {
159                    return Err(Error::Api(format!(
160                        "No Cloudflare zone found for {}",
161                        origin.as_ref()
162                    )));
163                }
164            }
165        }
166    }
167
168    async fn obtain_record_id(
169        &self,
170        zone_id: &str,
171        name: impl IntoFqdn<'_>,
172        record_type: DnsRecordType,
173    ) -> crate::Result<String> {
174        let name = name.into_name();
175        self.client
176            .get(format!(
177                "{}/zones/{zone_id}/dns_records?{}",
178                self.endpoint,
179                Query::name_and_type(name.as_ref(), record_type).serialize()
180            ))
181            .send_with_retry::<ApiResult<Vec<IdMap>>>(3)
182            .await
183            .and_then(|r| r.unwrap_response("list DNS records"))
184            .and_then(|result| {
185                result
186                    .into_iter()
187                    .find(|record| record.name == name.as_ref())
188                    .map(|record| record.id)
189                    .ok_or_else(|| {
190                        Error::Api(format!(
191                            "DNS Record {} of type {} not found",
192                            name.as_ref(),
193                            record_type.as_str()
194                        ))
195                    })
196            })
197    }
198
199    pub(crate) async fn create(
200        &self,
201        name: impl IntoFqdn<'_>,
202        record: DnsRecord,
203        ttl: u32,
204        origin: impl IntoFqdn<'_>,
205    ) -> crate::Result<()> {
206        self.client
207            .post(format!(
208                "{}/zones/{}/dns_records",
209                self.endpoint,
210                self.obtain_zone_id(origin).await?
211            ))
212            .with_body(CreateDnsRecordParams {
213                ttl: ttl.into(),
214                priority: record.priority(),
215                proxied: false.into(),
216                name: name.into_name().as_ref(),
217                content: record.into(),
218            })?
219            .send_with_retry::<ApiResult<Value>>(3)
220            .await
221            .map(|_| ())
222    }
223
224    pub(crate) async fn update(
225        &self,
226        name: impl IntoFqdn<'_>,
227        record: DnsRecord,
228        ttl: u32,
229        origin: impl IntoFqdn<'_>,
230    ) -> crate::Result<()> {
231        let zone_id = self.obtain_zone_id(origin).await?;
232        let name = name.into_name();
233        let record_id = self
234            .obtain_record_id(&zone_id, name.as_ref(), record.as_type())
235            .await?;
236        self.client
237            .patch(format!(
238                "{}/zones/{zone_id}/dns_records/{record_id}",
239                self.endpoint,
240            ))
241            .with_body(UpdateDnsRecordParams {
242                ttl: ttl.into(),
243                proxied: None,
244                name: name.as_ref(),
245                content: record.into(),
246            })?
247            .send_with_retry::<ApiResult<Value>>(3)
248            .await
249            .map(|_| ())
250    }
251
252    pub(crate) async fn delete(
253        &self,
254        name: impl IntoFqdn<'_>,
255        origin: impl IntoFqdn<'_>,
256        record_type: DnsRecordType,
257    ) -> crate::Result<()> {
258        let zone_id = self.obtain_zone_id(origin).await?;
259        let record_id = self.obtain_record_id(&zone_id, name, record_type).await?;
260
261        self.client
262            .delete(format!(
263                "{}/zones/{zone_id}/dns_records/{record_id}",
264                self.endpoint,
265            ))
266            .send_with_retry::<ApiResult<Value>>(3)
267            .await
268            .map(|_| ())
269    }
270}
271
272impl<T> ApiResult<T> {
273    fn unwrap_response(self, action_name: &str) -> crate::Result<T> {
274        if self.success {
275            Ok(self.result)
276        } else {
277            Err(Error::Api(format!(
278                "Failed to {action_name}: {:?}",
279                self.errors
280            )))
281        }
282    }
283}
284
285impl Query {
286    pub fn name(name: impl Into<String>) -> Self {
287        Self {
288            name: name.into(),
289            record_type: None,
290            match_mode: None,
291        }
292    }
293
294    pub fn name_and_type(name: impl Into<String>, record_type: DnsRecordType) -> Self {
295        Self {
296            name: name.into(),
297            record_type: Some(record_type.as_str()),
298            match_mode: Some("all"),
299        }
300    }
301
302    pub fn serialize(&self) -> String {
303        serde_urlencoded::to_string(self).unwrap()
304    }
305}
306
307impl From<DnsRecord> for DnsContent {
308    fn from(record: DnsRecord) -> Self {
309        match record {
310            DnsRecord::A(content) => DnsContent::A { content },
311            DnsRecord::AAAA(content) => DnsContent::AAAA { content },
312            DnsRecord::CNAME(content) => DnsContent::CNAME { content },
313            DnsRecord::NS(content) => DnsContent::NS { content },
314            DnsRecord::MX(mx) => DnsContent::MX {
315                content: mx.exchange,
316                priority: mx.priority,
317            },
318            DnsRecord::TXT(content) => DnsContent::TXT {
319                content: format!("\"{}\"", content.replace('\"', "\\\"")),
320            },
321            DnsRecord::SRV(srv) => DnsContent::SRV {
322                data: SrvData {
323                    priority: srv.priority,
324                    weight: srv.weight,
325                    port: srv.port,
326                    target: srv.target,
327                },
328            },
329            DnsRecord::TLSA(tlsa) => DnsContent::TLSA {
330                data: TlsaData {
331                    usage: u8::from(tlsa.cert_usage),
332                    selector: u8::from(tlsa.selector),
333                    matching_type: u8::from(tlsa.matching),
334                    certificate: tlsa.cert_data.iter().map(|b| format!("{b:02x}")).collect(),
335                },
336            },
337            DnsRecord::CAA(caa) => {
338                let (flags, tag, value) = caa.decompose();
339                DnsContent::CAA {
340                    data: CaaData { flags, tag, value },
341                }
342            }
343        }
344    }
345}