1use crate::{
13 CAARecord, DnsRecord, DnsRecordType, Error, IntoFqdn, KeyValue, MXRecord, SRVRecord,
14 http::{HttpClient, HttpClientBuilder},
15 utils::txt_chunks_to_text,
16};
17use serde::{Deserialize, Serialize};
18use std::time::Duration;
19
/// Base URL of the SafeDNS v1 REST API; request paths are appended to it.
const DEFAULT_API_ENDPOINT: &str = "https://api.ukfast.io/safedns/v1";
/// Value sent as the `per_page` query parameter when listing records.
const LIST_PAGE_SIZE: u32 = 200;
22
23#[derive(Clone)]
24pub struct SafeDnsProvider {
25 client: HttpClient,
26 endpoint: String,
27}
28
29#[derive(Serialize, Debug, Clone)]
30pub struct SafeDnsRecordPayload<'a> {
31 pub name: &'a str,
32 #[serde(rename = "type")]
33 pub record_type: &'a str,
34 pub content: String,
35 pub ttl: u32,
36 #[serde(skip_serializing_if = "Option::is_none")]
37 pub priority: Option<u16>,
38}
39
40#[derive(Deserialize, Debug, Clone)]
41pub struct SafeDnsRecord {
42 pub id: i64,
43 pub name: String,
44 #[serde(rename = "type")]
45 pub record_type: String,
46 #[serde(default)]
47 pub content: String,
48 #[serde(default)]
49 pub priority: Option<u16>,
50}
51
52#[derive(Deserialize, Debug)]
53pub struct ListRecordsResponse {
54 pub data: Vec<SafeDnsRecord>,
55 #[serde(default)]
56 pub meta: ListMeta,
57}
58
59#[derive(Deserialize, Debug, Default)]
60pub struct ListMeta {
61 #[serde(default)]
62 pub pagination: Pagination,
63}
64
65#[derive(Deserialize, Debug, Default)]
66pub struct Pagination {
67 #[serde(default)]
68 pub total_pages: u32,
69}
70
71#[derive(Deserialize, Debug)]
72pub struct AddRecordResponse {
73 #[allow(dead_code)]
74 pub data: SafeDnsRecord,
75}
76
/// The type/data/priority triple of a record in SafeDNS' textual form.
///
/// Carries neither the name nor the TTL, so the same value can be used to
/// build a creation payload and to compare against listed records.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SafeDnsRecordContent {
    /// Record type mnemonic, e.g. `"MX"`.
    pub record_type: &'static str,
    /// Record data as SafeDNS stores it.
    pub content: String,
    /// Priority for MX/SRV records, `None` otherwise.
    pub priority: Option<u16>,
}
83
84impl SafeDnsProvider {
85 pub(crate) fn new(auth_token: impl AsRef<str>, timeout: Option<Duration>) -> Self {
86 let client = HttpClientBuilder::default()
87 .with_header("Authorization", auth_token.as_ref())
88 .with_timeout(timeout)
89 .build();
90 Self {
91 client,
92 endpoint: DEFAULT_API_ENDPOINT.to_string(),
93 }
94 }
95
96 #[cfg(test)]
97 pub(crate) fn with_endpoint(self, endpoint: impl AsRef<str>) -> Self {
98 Self {
99 endpoint: endpoint.as_ref().to_string(),
100 ..self
101 }
102 }
103
104 pub(crate) async fn set_rrset(
105 &self,
106 name: impl IntoFqdn<'_>,
107 record_type: DnsRecordType,
108 ttl: u32,
109 records: Vec<DnsRecord>,
110 origin: impl IntoFqdn<'_>,
111 ) -> crate::Result<()> {
112 reject_unsupported(record_type)?;
113 let fqdn = name.into_name().into_owned();
114 let zone = origin.into_name().into_owned();
115 let desired = build_contents(record_type, records)?;
116 let existing = self.list_at(&zone, &fqdn, record_type).await?;
117
118 let mut existing_pool: Vec<SafeDnsRecord> = existing;
119 let mut to_add: Vec<SafeDnsRecordContent> = Vec::new();
120
121 for content in desired {
122 if let Some(idx) = existing_pool
123 .iter()
124 .position(|r| record_matches(r, &content))
125 {
126 existing_pool.swap_remove(idx);
127 } else {
128 to_add.push(content);
129 }
130 }
131
132 for entry in existing_pool {
133 self.delete_record(&zone, entry.id).await?;
134 }
135 for content in to_add {
136 self.create_record(&zone, &fqdn, ttl, &content).await?;
137 }
138 Ok(())
139 }
140
141 pub(crate) async fn add_to_rrset(
142 &self,
143 name: impl IntoFqdn<'_>,
144 record_type: DnsRecordType,
145 ttl: u32,
146 records: Vec<DnsRecord>,
147 origin: impl IntoFqdn<'_>,
148 ) -> crate::Result<()> {
149 reject_unsupported(record_type)?;
150 if records.is_empty() {
151 return Ok(());
152 }
153 let fqdn = name.into_name().into_owned();
154 let zone = origin.into_name().into_owned();
155 let desired = build_contents(record_type, records)?;
156 let existing = self.list_at(&zone, &fqdn, record_type).await?;
157
158 for content in desired {
159 if existing.iter().any(|r| record_matches(r, &content)) {
160 continue;
161 }
162 self.create_record(&zone, &fqdn, ttl, &content).await?;
163 }
164 Ok(())
165 }
166
167 pub(crate) async fn remove_from_rrset(
168 &self,
169 name: impl IntoFqdn<'_>,
170 record_type: DnsRecordType,
171 records: Vec<DnsRecord>,
172 origin: impl IntoFqdn<'_>,
173 ) -> crate::Result<()> {
174 reject_unsupported(record_type)?;
175 if records.is_empty() {
176 return Ok(());
177 }
178 let fqdn = name.into_name().into_owned();
179 let zone = origin.into_name().into_owned();
180 let to_remove = build_contents(record_type, records)?;
181 let existing = self.list_at(&zone, &fqdn, record_type).await?;
182
183 for content in to_remove {
184 if let Some(entry) = existing.iter().find(|r| record_matches(r, &content)) {
185 self.delete_record(&zone, entry.id).await?;
186 }
187 }
188 Ok(())
189 }
190
191 pub(crate) async fn list_rrset(
192 &self,
193 name: impl IntoFqdn<'_>,
194 record_type: DnsRecordType,
195 origin: impl IntoFqdn<'_>,
196 ) -> crate::Result<Vec<DnsRecord>> {
197 let fqdn = name.into_name().into_owned();
198 let zone = origin.into_name().into_owned();
199 let listed = self.list_at(&zone, &fqdn, record_type).await?;
200 listed
201 .into_iter()
202 .map(|r| safedns_record_to_dns_record(r, record_type))
203 .collect()
204 }
205
206 async fn list_at(
207 &self,
208 zone: &str,
209 name: &str,
210 record_type: DnsRecordType,
211 ) -> crate::Result<Vec<SafeDnsRecord>> {
212 let type_str = record_type.as_str();
213 let mut out: Vec<SafeDnsRecord> = Vec::new();
214 let mut page: u32 = 1;
215 loop {
216 let url = format!(
217 "{endpoint}/zones/{zone}/records?name:eq={name}&type:eq={type_str}&per_page={LIST_PAGE_SIZE}&page={page}",
218 endpoint = self.endpoint
219 );
220 let response: ListRecordsResponse = self.client.get(url).send_with_retry(3).await?;
221 let total_pages = response.meta.pagination.total_pages;
222 for record in response.data {
223 if record.name == name && record.record_type == type_str {
224 out.push(record);
225 }
226 }
227 if total_pages <= page {
228 break;
229 }
230 page += 1;
231 }
232 Ok(out)
233 }
234
235 async fn create_record(
236 &self,
237 zone: &str,
238 name: &str,
239 ttl: u32,
240 content: &SafeDnsRecordContent,
241 ) -> crate::Result<()> {
242 let body = SafeDnsRecordPayload {
243 name,
244 record_type: content.record_type,
245 content: content.content.clone(),
246 ttl,
247 priority: content.priority,
248 };
249
250 self.client
251 .post(format!(
252 "{endpoint}/zones/{zone}/records",
253 endpoint = self.endpoint
254 ))
255 .with_body(&body)?
256 .send_with_retry::<serde_json::Value>(3)
257 .await
258 .map(|_| ())
259 }
260
261 async fn delete_record(&self, zone: &str, record_id: i64) -> crate::Result<()> {
262 self.client
263 .delete(format!(
264 "{endpoint}/zones/{zone}/records/{record_id}",
265 endpoint = self.endpoint
266 ))
267 .send_with_retry::<serde_json::Value>(3)
268 .await
269 .map(|_| ())
270 }
271}
272
273fn reject_unsupported(record_type: DnsRecordType) -> crate::Result<()> {
274 if record_type == DnsRecordType::TLSA {
275 return Err(Error::Unsupported(
276 "TLSA records are not supported by SafeDNS".to_string(),
277 ));
278 }
279 Ok(())
280}
281
282fn build_contents(
283 expected_type: DnsRecordType,
284 records: Vec<DnsRecord>,
285) -> crate::Result<Vec<SafeDnsRecordContent>> {
286 let mut out = Vec::with_capacity(records.len());
287 for record in records {
288 if record.as_type() != expected_type {
289 return Err(Error::Api(format!(
290 "RRSet record type mismatch: expected {}, got {}",
291 expected_type.as_str(),
292 record.as_type().as_str(),
293 )));
294 }
295 out.push(SafeDnsRecordContent::try_from(record)?);
296 }
297 Ok(out)
298}
299
300fn record_matches(record: &SafeDnsRecord, content: &SafeDnsRecordContent) -> bool {
301 record.record_type == content.record_type
302 && record.content == content.content
303 && record.priority == content.priority
304}
305
306fn safedns_record_to_dns_record(
307 record: SafeDnsRecord,
308 record_type: DnsRecordType,
309) -> crate::Result<DnsRecord> {
310 match record_type {
311 DnsRecordType::A => record
312 .content
313 .parse()
314 .map(DnsRecord::A)
315 .map_err(|e| Error::Parse(format!("invalid A content {}: {e}", record.content))),
316 DnsRecordType::AAAA => record
317 .content
318 .parse()
319 .map(DnsRecord::AAAA)
320 .map_err(|e| Error::Parse(format!("invalid AAAA content {}: {e}", record.content))),
321 DnsRecordType::CNAME => Ok(DnsRecord::CNAME(record.content)),
322 DnsRecordType::NS => Ok(DnsRecord::NS(record.content)),
323 DnsRecordType::MX => Ok(DnsRecord::MX(MXRecord {
324 exchange: record.content,
325 priority: record.priority.unwrap_or(0),
326 })),
327 DnsRecordType::TXT => Ok(DnsRecord::TXT(unquote_txt(&record.content))),
328 DnsRecordType::SRV => parse_srv(&record.content, record.priority.unwrap_or(0)),
329 DnsRecordType::TLSA => Err(Error::Unsupported(
330 "TLSA records are not supported by SafeDNS".to_string(),
331 )),
332 DnsRecordType::CAA => parse_caa(&record.content),
333 }
334}
335
336fn parse_srv(content: &str, priority: u16) -> crate::Result<DnsRecord> {
337 let parts: Vec<&str> = content.split_whitespace().collect();
338 if parts.len() != 3 {
339 return Err(Error::Parse(format!(
340 "invalid SRV content (expected `<weight> <port> <target>`): {content}"
341 )));
342 }
343 let weight: u16 = parts[0]
344 .parse()
345 .map_err(|e| Error::Parse(format!("invalid SRV weight {}: {e}", parts[0])))?;
346 let port: u16 = parts[1]
347 .parse()
348 .map_err(|e| Error::Parse(format!("invalid SRV port {}: {e}", parts[1])))?;
349 Ok(DnsRecord::SRV(SRVRecord {
350 priority,
351 weight,
352 port,
353 target: parts[2].to_string(),
354 }))
355}
356
/// Turns SafeDNS TXT content back into the record's text.
///
/// Quoted content such as `"v=spf1 " "-all"` is decoded as follows:
/// - the double-quoted chunks are concatenated;
/// - each `\x` escape is replaced by `x`;
/// - spaces between chunks are dropped.
///
/// Content containing no double quote at all is treated as plain text and
/// returned unchanged.
fn unquote_txt(content: &str) -> String {
    // Unquoted content is already the raw text. Running it through the decoder
    // below would delete every space, since nothing counts as "inside quotes"
    // (`v=spf1 -all` would become `v=spf1-all`).
    if !content.contains('"') {
        return content.to_string();
    }

    let mut out = String::with_capacity(content.len());
    let mut chars = content.chars();
    let mut in_quote = false;
    while let Some(ch) = chars.next() {
        match ch {
            '"' => in_quote = !in_quote,
            '\\' => {
                if let Some(next) = chars.next() {
                    out.push(next);
                }
            }
            // Separators between quoted chunks are not part of the text.
            ' ' if !in_quote => {}
            _ => out.push(ch),
        }
    }
    out
}
377
378fn parse_caa(content: &str) -> crate::Result<DnsRecord> {
379 let trimmed = content.trim();
380 let (flags_str, rest) = trimmed
381 .split_once(char::is_whitespace)
382 .ok_or_else(|| Error::Parse(format!("invalid CAA content: {content}")))?;
383 let (tag, value_part) = rest
384 .trim_start()
385 .split_once(char::is_whitespace)
386 .ok_or_else(|| Error::Parse(format!("invalid CAA content: {content}")))?;
387 let flags: u8 = flags_str
388 .parse()
389 .map_err(|e| Error::Parse(format!("invalid CAA flags {flags_str}: {e}")))?;
390 let value = value_part
391 .trim()
392 .trim_start_matches('"')
393 .trim_end_matches('"')
394 .to_string();
395 let issuer_critical = flags & 0x80 != 0;
396 match tag.trim() {
397 "issue" => {
398 let (name, options) = parse_caa_value(&value);
399 Ok(DnsRecord::CAA(CAARecord::Issue {
400 issuer_critical,
401 name,
402 options,
403 }))
404 }
405 "issuewild" => {
406 let (name, options) = parse_caa_value(&value);
407 Ok(DnsRecord::CAA(CAARecord::IssueWild {
408 issuer_critical,
409 name,
410 options,
411 }))
412 }
413 "iodef" => Ok(DnsRecord::CAA(CAARecord::Iodef {
414 issuer_critical,
415 url: value,
416 })),
417 other => Err(Error::Parse(format!("unknown CAA tag: {other}"))),
418 }
419}
420
421fn parse_caa_value(value: &str) -> (Option<String>, Vec<KeyValue>) {
422 let mut parts = value.split(';').map(str::trim);
423 let name_part = parts.next().unwrap_or("").trim().to_string();
424 let name = if name_part.is_empty() {
425 None
426 } else {
427 Some(name_part)
428 };
429 let options = parts
430 .filter(|p| !p.is_empty())
431 .map(|p| match p.split_once('=') {
432 Some((k, v)) => KeyValue {
433 key: k.trim().to_string(),
434 value: v.trim().to_string(),
435 },
436 None => KeyValue {
437 key: p.trim().to_string(),
438 value: String::new(),
439 },
440 })
441 .collect();
442 (name, options)
443}
444
445impl TryFrom<DnsRecord> for SafeDnsRecordContent {
446 type Error = Error;
447
448 fn try_from(record: DnsRecord) -> Result<Self, Self::Error> {
449 match record {
450 DnsRecord::A(addr) => Ok(SafeDnsRecordContent {
451 record_type: "A",
452 content: addr.to_string(),
453 priority: None,
454 }),
455 DnsRecord::AAAA(addr) => Ok(SafeDnsRecordContent {
456 record_type: "AAAA",
457 content: addr.to_string(),
458 priority: None,
459 }),
460 DnsRecord::CNAME(target) => Ok(SafeDnsRecordContent {
461 record_type: "CNAME",
462 content: target,
463 priority: None,
464 }),
465 DnsRecord::NS(target) => Ok(SafeDnsRecordContent {
466 record_type: "NS",
467 content: target,
468 priority: None,
469 }),
470 DnsRecord::MX(mx) => Ok(SafeDnsRecordContent {
471 record_type: "MX",
472 content: mx.exchange,
473 priority: Some(mx.priority),
474 }),
475 DnsRecord::TXT(text) => {
476 let mut buf = String::new();
477 txt_chunks_to_text(&mut buf, &text, " ");
478 Ok(SafeDnsRecordContent {
479 record_type: "TXT",
480 content: buf,
481 priority: None,
482 })
483 }
484 DnsRecord::SRV(srv) => Ok(SafeDnsRecordContent {
485 record_type: "SRV",
486 content: format!("{} {} {}", srv.weight, srv.port, srv.target),
487 priority: Some(srv.priority),
488 }),
489 DnsRecord::TLSA(_) => Err(Error::Unsupported(
490 "TLSA records are not supported by SafeDNS".to_string(),
491 )),
492 DnsRecord::CAA(caa) => {
493 let (flags, tag, value) = caa.decompose();
494 Ok(SafeDnsRecordContent {
495 record_type: "CAA",
496 content: format!("{flags} {tag} \"{value}\""),
497 priority: None,
498 })
499 }
500 }
501 }
502}