1use crate::http::{HttpClient, HttpClientBuilder};
13use crate::utils::{strip_origin_from_name, txt_chunks};
14use crate::{
15 CAARecord, DnsRecord, DnsRecordType, Error, IntoFqdn, KeyValue, MXRecord, Result, SRVRecord,
16};
17use serde::Deserialize;
18use serde_json::{Value, json};
19use std::net::{Ipv4Addr, Ipv6Addr};
20use std::str::FromStr;
21use std::sync::{Arc, Mutex};
22use std::time::{Duration, Instant};
23
24#[derive(Debug, Clone)]
25pub struct AzureDnsConfig {
26 pub tenant_id: String,
27 pub client_id: String,
28 pub client_secret: String,
29 pub subscription_id: String,
30 pub resource_group: String,
31 pub environment: AzureEnvironment,
32 pub request_timeout: Option<Duration>,
33}
34
/// Azure cloud instance whose identity and Resource Manager endpoints are used.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AzureEnvironment {
    /// Global Azure (`login.microsoftonline.com`, `management.azure.com`).
    Public,
    /// Azure China (`login.chinacloudapi.cn`, `management.chinacloudapi.cn`).
    China,
    /// Azure US Government (`login.microsoftonline.us`, `management.usgovcloudapi.net`).
    UsGovernment,
}

impl AzureEnvironment {
    /// Parses a cloud name leniently, falling back to [`AzureEnvironment::Public`].
    ///
    /// Matching ignores ASCII case and every non-alphanumeric character. This means
    /// `"china"`, `"AzureChinaCloud"`, `"usgovernment"`, `"us-government"` and
    /// `"AzureUSGovernment"` are all recognised. Any other value, including the empty
    /// string, yields `Public`.
    pub fn from_str_lossy(value: &str) -> Self {
        let normalized: String = value
            .chars()
            .filter(char::is_ascii_alphanumeric)
            .map(|c| c.to_ascii_lowercase())
            .collect();
        match normalized.as_str() {
            "china" | "chinacloud" | "azurechina" | "azurechinacloud" => AzureEnvironment::China,
            "usgovernment" | "usgov" | "usgovernmentcloud" | "azureusgov"
            | "azureusgovernment" | "azureusgovernmentcloud" => AzureEnvironment::UsGovernment,
            _ => AzureEnvironment::Public,
        }
    }

    /// Base URL of the OAuth2 authority (no trailing slash).
    fn login_host(self) -> &'static str {
        match self {
            AzureEnvironment::Public => "https://login.microsoftonline.com",
            AzureEnvironment::China => "https://login.chinacloudapi.cn",
            AzureEnvironment::UsGovernment => "https://login.microsoftonline.us",
        }
    }

    /// Base URL of the Resource Manager API (no trailing slash).
    fn management_host(self) -> &'static str {
        match self {
            AzureEnvironment::Public => "https://management.azure.com",
            AzureEnvironment::China => "https://management.chinacloudapi.cn",
            AzureEnvironment::UsGovernment => "https://management.usgovcloudapi.net",
        }
    }

    /// OAuth2 `scope` requested for the management API: its host plus `/.default`.
    fn scope(self) -> &'static str {
        match self {
            AzureEnvironment::Public => "https://management.azure.com/.default",
            AzureEnvironment::China => "https://management.chinacloudapi.cn/.default",
            AzureEnvironment::UsGovernment => "https://management.usgovcloudapi.net/.default",
        }
    }
}
75
76#[derive(Clone)]
77pub struct AzureDnsProvider {
78 client: HttpClient,
79 config: AzureDnsConfig,
80 token: Arc<Mutex<Option<(String, Instant)>>>,
81 endpoints: AzureEndpoints,
82}
83
/// Base URLs (without trailing slash) that the provider sends requests to.
#[derive(Clone)]
struct AzureEndpoints {
    /// OAuth2 authority root; `/{tenant}/oauth2/v2.0/token` is appended.
    login_url: String,
    /// Resource Manager root; the subscription/resource-group/zone path is appended.
    management_url: String,
}
89
/// `api-version` query parameter appended to every record-set URL.
const API_VERSION: &str = "2018-05-01";
91
92impl AzureDnsProvider {
93 pub fn new(config: AzureDnsConfig) -> Result<Self> {
94 let client = HttpClientBuilder::default()
95 .with_timeout(config.request_timeout)
96 .build();
97
98 let endpoints = AzureEndpoints {
99 login_url: config.environment.login_host().to_string(),
100 management_url: config.environment.management_host().to_string(),
101 };
102
103 Ok(Self {
104 client,
105 config,
106 token: Arc::new(Mutex::new(None)),
107 endpoints,
108 })
109 }
110
111 #[cfg(test)]
112 pub(crate) fn with_endpoints(
113 mut self,
114 login_url: impl AsRef<str>,
115 management_url: impl AsRef<str>,
116 ) -> Self {
117 self.endpoints = AzureEndpoints {
118 login_url: login_url.as_ref().trim_end_matches('/').to_string(),
119 management_url: management_url.as_ref().trim_end_matches('/').to_string(),
120 };
121 self
122 }
123
124 #[cfg(test)]
125 pub(crate) fn with_cached_token(self, token: impl Into<String>) -> Self {
126 *self.token.lock().expect("test token lock") =
127 Some((token.into(), Instant::now() + Duration::from_secs(55 * 60)));
128 self
129 }
130
131 async fn ensure_token(&self) -> Result<String> {
132 if let Some((ref token, expiry)) = *self.token_lock()?
133 && Instant::now() < expiry
134 {
135 return Ok(token.clone());
136 }
137
138 let url = format!(
139 "{}/{}/oauth2/v2.0/token",
140 self.endpoints.login_url, self.config.tenant_id
141 );
142 let form = serde_urlencoded::to_string([
143 ("grant_type", "client_credentials"),
144 ("client_id", self.config.client_id.as_str()),
145 ("client_secret", self.config.client_secret.as_str()),
146 ("scope", self.config.environment.scope()),
147 ])
148 .map_err(|e| Error::Api(format!("Failed to encode token request: {e}")))?;
149
150 let token_response: AzureTokenResponse = self
151 .client
152 .post(&url)
153 .with_header("content-type", "application/x-www-form-urlencoded")
154 .with_raw_body(form)
155 .send_with_retry(3)
156 .await?;
157
158 if token_response.access_token.is_empty() {
159 return Err(Error::Api(
160 "Azure token response missing access_token".into(),
161 ));
162 }
163
164 let lifetime = token_response
165 .expires_in
166 .unwrap_or(3600)
167 .saturating_sub(60)
168 .max(60);
169 let expiry = Instant::now() + Duration::from_secs(lifetime);
170 *self.token_lock()? = Some((token_response.access_token.clone(), expiry));
171 Ok(token_response.access_token)
172 }
173
174 pub(crate) async fn set_rrset(
175 &self,
176 name: impl IntoFqdn<'_>,
177 record_type: DnsRecordType,
178 ttl: u32,
179 records: Vec<DnsRecord>,
180 origin: impl IntoFqdn<'_>,
181 ) -> Result<()> {
182 check_record_types(record_type, &records)?;
183 check_cname_singleton(record_type, &records)?;
184 let zone = origin.into_name().to_ascii_lowercase();
185 let fqdn = name.into_name().to_ascii_lowercase();
186 let relative = relative_record_name(&fqdn, &zone);
187 let type_segment = azure_record_type(&record_type)?;
188 let url = self.record_url(&zone, type_segment, &relative);
189 let token = self.ensure_token().await?;
190
191 if records.is_empty() {
192 return self.delete_rrset_url(&url, &token, None).await;
193 }
194
195 self.put_rrset(&url, &token, ttl, record_type, &records, None)
196 .await
197 }
198
199 pub(crate) async fn add_to_rrset(
200 &self,
201 name: impl IntoFqdn<'_>,
202 record_type: DnsRecordType,
203 ttl: u32,
204 records: Vec<DnsRecord>,
205 origin: impl IntoFqdn<'_>,
206 ) -> Result<()> {
207 check_record_types(record_type, &records)?;
208 if records.is_empty() {
209 return Ok(());
210 }
211 check_cname_singleton(record_type, &records)?;
212 let zone = origin.into_name().to_ascii_lowercase();
213 let fqdn = name.into_name().to_ascii_lowercase();
214 let relative = relative_record_name(&fqdn, &zone);
215 let type_segment = azure_record_type(&record_type)?;
216 let url = self.record_url(&zone, type_segment, &relative);
217 let token = self.ensure_token().await?;
218
219 let fetched = self.fetch_rrset(&url, &token).await?;
220 let mut merged = fetched.records;
221 for record in records {
222 if !merged.iter().any(|r| r == &record) {
223 merged.push(record);
224 }
225 }
226 check_cname_singleton(record_type, &merged)?;
227 self.put_rrset(
228 &url,
229 &token,
230 ttl,
231 record_type,
232 &merged,
233 fetched.etag.as_deref(),
234 )
235 .await
236 }
237
238 pub(crate) async fn remove_from_rrset(
239 &self,
240 name: impl IntoFqdn<'_>,
241 record_type: DnsRecordType,
242 records: Vec<DnsRecord>,
243 origin: impl IntoFqdn<'_>,
244 ) -> Result<()> {
245 check_record_types(record_type, &records)?;
246 if records.is_empty() {
247 return Ok(());
248 }
249 let zone = origin.into_name().to_ascii_lowercase();
250 let fqdn = name.into_name().to_ascii_lowercase();
251 let relative = relative_record_name(&fqdn, &zone);
252 let type_segment = azure_record_type(&record_type)?;
253 let url = self.record_url(&zone, type_segment, &relative);
254 let token = self.ensure_token().await?;
255
256 let fetched = match self.fetch_rrset_optional(&url, &token).await? {
257 Some(fetched) => fetched,
258 None => return Ok(()),
259 };
260
261 let remaining: Vec<DnsRecord> = fetched
262 .records
263 .into_iter()
264 .filter(|r| !records.contains(r))
265 .collect();
266
267 if remaining.is_empty() {
268 return self
269 .delete_rrset_url(&url, &token, fetched.etag.as_deref())
270 .await;
271 }
272
273 let ttl = fetched.ttl.unwrap_or(0);
274 self.put_rrset(
275 &url,
276 &token,
277 ttl,
278 record_type,
279 &remaining,
280 fetched.etag.as_deref(),
281 )
282 .await
283 }
284
285 pub(crate) async fn list_rrset(
286 &self,
287 name: impl IntoFqdn<'_>,
288 record_type: DnsRecordType,
289 origin: impl IntoFqdn<'_>,
290 ) -> Result<Vec<DnsRecord>> {
291 let zone = origin.into_name().to_ascii_lowercase();
292 let fqdn = name.into_name().to_ascii_lowercase();
293 let relative = relative_record_name(&fqdn, &zone);
294 let type_segment = azure_record_type(&record_type)?;
295 let url = self.record_url(&zone, type_segment, &relative);
296 let token = self.ensure_token().await?;
297
298 match self.fetch_rrset_optional(&url, &token).await? {
299 Some(fetched) => Ok(fetched.records),
300 None => Ok(Vec::new()),
301 }
302 }
303
304 async fn put_rrset(
305 &self,
306 url: &str,
307 token: &str,
308 ttl: u32,
309 record_type: DnsRecordType,
310 records: &[DnsRecord],
311 if_match: Option<&str>,
312 ) -> Result<()> {
313 let mut properties = serde_json::Map::new();
314 properties.insert("TTL".to_string(), json!(ttl));
315 insert_rrset_payload(&mut properties, record_type, records)?;
316
317 let mut body = serde_json::Map::new();
318 body.insert("properties".to_string(), Value::Object(properties));
319
320 let mut request = self
321 .client
322 .put(url)
323 .with_header("authorization", format!("Bearer {token}"))
324 .with_body(&body)?;
325 if let Some(etag) = if_match {
326 request = request.with_header("if-match", etag);
327 }
328 request.send_with_retry::<Value>(3).await.map(|_| ())
329 }
330
331 async fn delete_rrset_url(&self, url: &str, token: &str, if_match: Option<&str>) -> Result<()> {
332 let mut request = self
333 .client
334 .delete(url)
335 .with_header("authorization", format!("Bearer {token}"));
336 if let Some(etag) = if_match {
337 request = request.with_header("if-match", etag);
338 }
339 request
340 .send_with_retry::<Value>(3)
341 .await
342 .map(|_| ())
343 .or_else(|err| match err {
344 Error::NotFound => Ok(()),
345 err => Err(err),
346 })
347 }
348
349 async fn fetch_rrset(&self, url: &str, token: &str) -> Result<FetchedRrset> {
350 match self.fetch_rrset_optional(url, token).await? {
351 Some(fetched) => Ok(fetched),
352 None => Ok(FetchedRrset::default()),
353 }
354 }
355
356 async fn fetch_rrset_optional(&self, url: &str, token: &str) -> Result<Option<FetchedRrset>> {
357 let value: Value = match self
358 .client
359 .get(url)
360 .with_header("authorization", format!("Bearer {token}"))
361 .send_with_retry(3)
362 .await
363 {
364 Ok(v) => v,
365 Err(Error::NotFound) => return Ok(None),
366 Err(err) => return Err(err),
367 };
368
369 let etag = value
370 .get("etag")
371 .and_then(Value::as_str)
372 .map(str::to_string);
373
374 let ttl = value
375 .get("properties")
376 .and_then(|p| p.get("TTL"))
377 .and_then(Value::as_u64)
378 .map(|v| v as u32);
379
380 let records = parse_rrset_records(&value)?;
381 Ok(Some(FetchedRrset { records, etag, ttl }))
382 }
383
384 fn record_url(&self, zone: &str, type_segment: &str, relative: &str) -> String {
385 format!(
386 "{}/subscriptions/{}/resourceGroups/{}/providers/Microsoft.Network/dnsZones/{}/{}/{}?api-version={}",
387 self.endpoints.management_url,
388 self.config.subscription_id,
389 self.config.resource_group,
390 zone,
391 type_segment,
392 relative,
393 API_VERSION,
394 )
395 }
396
397 fn token_lock(&self) -> Result<std::sync::MutexGuard<'_, Option<(String, Instant)>>> {
398 self.token
399 .lock()
400 .map_err(|_| Error::Client("Azure DNS token cache lock poisoned".into()))
401 }
402}
403
404fn relative_record_name(fqdn: &str, zone: &str) -> String {
405 let stripped = strip_origin_from_name(fqdn, zone, Some("@"));
406 if stripped.is_empty() {
407 "@".to_string()
408 } else {
409 stripped
410 }
411}
412
413fn azure_record_type(rt: &DnsRecordType) -> Result<&'static str> {
414 Ok(match rt {
415 DnsRecordType::A => "A",
416 DnsRecordType::AAAA => "AAAA",
417 DnsRecordType::CNAME => "CNAME",
418 DnsRecordType::MX => "MX",
419 DnsRecordType::NS => "NS",
420 DnsRecordType::TXT => "TXT",
421 DnsRecordType::SRV => "SRV",
422 DnsRecordType::CAA => "CAA",
423 DnsRecordType::TLSA => {
424 return Err(Error::Unsupported(
425 "TLSA records are not supported by Azure DNS".to_string(),
426 ));
427 }
428 })
429}
430
431#[derive(Default)]
432struct FetchedRrset {
433 records: Vec<DnsRecord>,
434 etag: Option<String>,
435 ttl: Option<u32>,
436}
437
438fn check_record_types(expected: DnsRecordType, records: &[DnsRecord]) -> Result<()> {
439 azure_record_type(&expected)?;
440 for r in records {
441 if r.as_type() != expected {
442 return Err(Error::Api(format!(
443 "RRSet record type mismatch: expected {}, got {}",
444 expected.as_str(),
445 r.as_type().as_str(),
446 )));
447 }
448 }
449 Ok(())
450}
451
452fn check_cname_singleton(record_type: DnsRecordType, records: &[DnsRecord]) -> Result<()> {
453 if record_type == DnsRecordType::CNAME && records.len() > 1 {
454 return Err(Error::Api(
455 "CNAME RRSet may contain at most one record".to_string(),
456 ));
457 }
458 Ok(())
459}
460
461fn insert_rrset_payload(
462 props: &mut serde_json::Map<String, Value>,
463 record_type: DnsRecordType,
464 records: &[DnsRecord],
465) -> Result<()> {
466 match record_type {
467 DnsRecordType::A => {
468 let arr: Vec<Value> = records
469 .iter()
470 .map(|r| match r {
471 DnsRecord::A(ip) => json!({"ipv4Address": ip.to_string()}),
472 _ => unreachable!(),
473 })
474 .collect();
475 props.insert("ARecords".to_string(), Value::Array(arr));
476 }
477 DnsRecordType::AAAA => {
478 let arr: Vec<Value> = records
479 .iter()
480 .map(|r| match r {
481 DnsRecord::AAAA(ip) => json!({"ipv6Address": ip.to_string()}),
482 _ => unreachable!(),
483 })
484 .collect();
485 props.insert("AAAARecords".to_string(), Value::Array(arr));
486 }
487 DnsRecordType::CNAME => {
488 if let Some(DnsRecord::CNAME(target)) = records.first() {
489 props.insert(
490 "CNAMERecord".to_string(),
491 json!({"cname": target.trim_end_matches('.')}),
492 );
493 }
494 }
495 DnsRecordType::NS => {
496 let arr: Vec<Value> = records
497 .iter()
498 .map(|r| match r {
499 DnsRecord::NS(target) => {
500 json!({"nsdname": target.trim_end_matches('.')})
501 }
502 _ => unreachable!(),
503 })
504 .collect();
505 props.insert("NSRecords".to_string(), Value::Array(arr));
506 }
507 DnsRecordType::MX => {
508 let arr: Vec<Value> = records
509 .iter()
510 .map(|r| match r {
511 DnsRecord::MX(mx) => json!({
512 "preference": mx.priority,
513 "exchange": mx.exchange.trim_end_matches('.'),
514 }),
515 _ => unreachable!(),
516 })
517 .collect();
518 props.insert("MXRecords".to_string(), Value::Array(arr));
519 }
520 DnsRecordType::TXT => {
521 let arr: Vec<Value> = records
522 .iter()
523 .map(|r| match r {
524 DnsRecord::TXT(text) => json!({"value": txt_chunks(text.clone())}),
525 _ => unreachable!(),
526 })
527 .collect();
528 props.insert("TXTRecords".to_string(), Value::Array(arr));
529 }
530 DnsRecordType::SRV => {
531 let arr: Vec<Value> = records
532 .iter()
533 .map(|r| match r {
534 DnsRecord::SRV(srv) => json!({
535 "priority": srv.priority,
536 "weight": srv.weight,
537 "port": srv.port,
538 "target": srv.target.trim_end_matches('.'),
539 }),
540 _ => unreachable!(),
541 })
542 .collect();
543 props.insert("SRVRecords".to_string(), Value::Array(arr));
544 }
545 DnsRecordType::CAA => {
546 let arr: Vec<Value> = records
547 .iter()
548 .map(|r| match r {
549 DnsRecord::CAA(caa) => {
550 let (flags, tag, value) = caa.clone().decompose();
551 json!({"flags": flags, "tag": tag, "value": value})
552 }
553 _ => unreachable!(),
554 })
555 .collect();
556 props.insert("caaRecords".to_string(), Value::Array(arr));
557 }
558 DnsRecordType::TLSA => {
559 return Err(Error::Unsupported(
560 "TLSA records are not supported by Azure DNS".to_string(),
561 ));
562 }
563 }
564 Ok(())
565}
566
567fn parse_rrset_records(value: &Value) -> Result<Vec<DnsRecord>> {
568 let props = match value.get("properties") {
569 Some(p) => p,
570 None => return Ok(Vec::new()),
571 };
572
573 let mut out = Vec::new();
574
575 if let Some(arr) = props.get("ARecords").and_then(Value::as_array) {
576 for entry in arr {
577 if let Some(addr) = entry.get("ipv4Address").and_then(Value::as_str)
578 && let Ok(ip) = Ipv4Addr::from_str(addr)
579 {
580 out.push(DnsRecord::A(ip));
581 }
582 }
583 }
584 if let Some(arr) = props.get("AAAARecords").and_then(Value::as_array) {
585 for entry in arr {
586 if let Some(addr) = entry.get("ipv6Address").and_then(Value::as_str)
587 && let Ok(ip) = Ipv6Addr::from_str(addr)
588 {
589 out.push(DnsRecord::AAAA(ip));
590 }
591 }
592 }
593 if let Some(obj) = props.get("CNAMERecord")
594 && let Some(target) = obj.get("cname").and_then(Value::as_str)
595 {
596 out.push(DnsRecord::CNAME(target.to_string()));
597 }
598 if let Some(arr) = props.get("NSRecords").and_then(Value::as_array) {
599 for entry in arr {
600 if let Some(target) = entry.get("nsdname").and_then(Value::as_str) {
601 out.push(DnsRecord::NS(target.to_string()));
602 }
603 }
604 }
605 if let Some(arr) = props.get("MXRecords").and_then(Value::as_array) {
606 for entry in arr {
607 let priority = entry.get("preference").and_then(Value::as_u64).unwrap_or(0) as u16;
608 if let Some(exchange) = entry.get("exchange").and_then(Value::as_str) {
609 out.push(DnsRecord::MX(MXRecord {
610 priority,
611 exchange: exchange.to_string(),
612 }));
613 }
614 }
615 }
616 if let Some(arr) = props.get("TXTRecords").and_then(Value::as_array) {
617 for entry in arr {
618 if let Some(values) = entry.get("value").and_then(Value::as_array) {
619 let joined: String = values
620 .iter()
621 .filter_map(Value::as_str)
622 .collect::<Vec<_>>()
623 .concat();
624 out.push(DnsRecord::TXT(joined));
625 }
626 }
627 }
628 if let Some(arr) = props.get("SRVRecords").and_then(Value::as_array) {
629 for entry in arr {
630 let priority = entry.get("priority").and_then(Value::as_u64).unwrap_or(0) as u16;
631 let weight = entry.get("weight").and_then(Value::as_u64).unwrap_or(0) as u16;
632 let port = entry.get("port").and_then(Value::as_u64).unwrap_or(0) as u16;
633 if let Some(target) = entry.get("target").and_then(Value::as_str) {
634 out.push(DnsRecord::SRV(SRVRecord {
635 priority,
636 weight,
637 port,
638 target: target.to_string(),
639 }));
640 }
641 }
642 }
643 if let Some(arr) = props.get("caaRecords").and_then(Value::as_array) {
644 for entry in arr {
645 let flags = entry.get("flags").and_then(Value::as_u64).unwrap_or(0) as u8;
646 let tag = entry.get("tag").and_then(Value::as_str).unwrap_or("");
647 let value = entry
648 .get("value")
649 .and_then(Value::as_str)
650 .unwrap_or("")
651 .to_string();
652 let issuer_critical = flags & 0x80 != 0;
653 let caa = match tag.to_ascii_lowercase().as_str() {
654 "issue" => CAARecord::Issue {
655 issuer_critical,
656 name: if value.is_empty() { None } else { Some(value) },
657 options: Vec::<KeyValue>::new(),
658 },
659 "issuewild" => CAARecord::IssueWild {
660 issuer_critical,
661 name: if value.is_empty() { None } else { Some(value) },
662 options: Vec::<KeyValue>::new(),
663 },
664 "iodef" => CAARecord::Iodef {
665 issuer_critical,
666 url: value,
667 },
668 _ => continue,
669 };
670 out.push(DnsRecord::CAA(caa));
671 }
672 }
673
674 Ok(out)
675}
676
677#[derive(Deserialize)]
678struct AzureTokenResponse {
679 access_token: String,
680 #[serde(default)]
681 expires_in: Option<u64>,
682}