Skip to main content

dns_update/providers/
azuredns.rs

1/*
2 * Copyright Stalwart Labs LLC See the COPYING
3 * file at the top-level directory of this distribution.
4 *
5 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 * https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 * <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
8 * option. This file may not be copied, modified, or distributed
9 * except according to those terms.
10 */
11
12use crate::http::{HttpClient, HttpClientBuilder};
13use crate::utils::{strip_origin_from_name, txt_chunks};
14use crate::{
15    CAARecord, DnsRecord, DnsRecordType, Error, IntoFqdn, KeyValue, MXRecord, Result, SRVRecord,
16};
17use serde::Deserialize;
18use serde_json::{Value, json};
19use std::net::{Ipv4Addr, Ipv6Addr};
20use std::str::FromStr;
21use std::sync::{Arc, Mutex};
22use std::time::{Duration, Instant};
23
24#[derive(Debug, Clone)]
25pub struct AzureDnsConfig {
26    pub tenant_id: String,
27    pub client_id: String,
28    pub client_secret: String,
29    pub subscription_id: String,
30    pub resource_group: String,
31    pub environment: AzureEnvironment,
32    pub request_timeout: Option<Duration>,
33}
34
/// Azure cloud selecting the login host, management host and OAuth2 scope.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AzureEnvironment {
    /// Uses `login.microsoftonline.com` and `management.azure.com`.
    Public,
    /// Uses `login.chinacloudapi.cn` and `management.chinacloudapi.cn`.
    China,
    /// Uses `login.microsoftonline.us` and `management.usgovcloudapi.net`.
    UsGovernment,
}

impl AzureEnvironment {
    /// Parses an environment name, ignoring ASCII case.
    ///
    /// `"china"` and `"usgovernment"` select the matching cloud; every other
    /// input (including the empty string) yields [`AzureEnvironment::Public`].
    pub fn from_str_lossy(value: &str) -> Self {
        if value.eq_ignore_ascii_case("china") {
            Self::China
        } else if value.eq_ignore_ascii_case("usgovernment") {
            Self::UsGovernment
        } else {
            Self::Public
        }
    }

    /// Base URL of the login endpoint for this cloud (no trailing slash).
    fn login_host(self) -> &'static str {
        match self {
            Self::Public => "https://login.microsoftonline.com",
            Self::China => "https://login.chinacloudapi.cn",
            Self::UsGovernment => "https://login.microsoftonline.us",
        }
    }

    /// Base URL of the management endpoint for this cloud (no trailing slash).
    fn management_host(self) -> &'static str {
        match self {
            Self::Public => "https://management.azure.com",
            Self::China => "https://management.chinacloudapi.cn",
            Self::UsGovernment => "https://management.usgovcloudapi.net",
        }
    }

    /// Value sent as the `scope` form field of the token request.
    fn scope(self) -> &'static str {
        match self {
            Self::Public => "https://management.azure.com/.default",
            Self::China => "https://management.chinacloudapi.cn/.default",
            Self::UsGovernment => "https://management.usgovcloudapi.net/.default",
        }
    }
}
75
/// DNS provider that manages record sets through the Azure management API.
///
/// Cloning is cheap with respect to the token cache: clones share the same
/// cached token because it lives behind an `Arc`.
#[derive(Clone)]
pub struct AzureDnsProvider {
    // HTTP client used for the token request and all record-set calls.
    client: HttpClient,
    // Credentials plus subscription / resource-group scope.
    config: AzureDnsConfig,
    // Cached access token together with the instant after which it is no
    // longer reused; `None` until the first token has been fetched.
    token: Arc<Mutex<Option<(String, Instant)>>>,
    // Base URLs actually used for requests (overridable in tests).
    endpoints: AzureEndpoints,
}
83
/// Base URLs the provider sends requests to.
#[derive(Clone)]
struct AzureEndpoints {
    // Prefix of the OAuth2 token URL (`{login_url}/{tenant}/oauth2/v2.0/token`).
    login_url: String,
    // Prefix of every record-set resource URL.
    management_url: String,
}
89
/// Value of the `api-version` query parameter appended to every record-set URL.
const API_VERSION: &str = "2018-05-01";
91
92impl AzureDnsProvider {
93    pub fn new(config: AzureDnsConfig) -> Result<Self> {
94        let client = HttpClientBuilder::default()
95            .with_timeout(config.request_timeout)
96            .build();
97
98        let endpoints = AzureEndpoints {
99            login_url: config.environment.login_host().to_string(),
100            management_url: config.environment.management_host().to_string(),
101        };
102
103        Ok(Self {
104            client,
105            config,
106            token: Arc::new(Mutex::new(None)),
107            endpoints,
108        })
109    }
110
111    #[cfg(test)]
112    pub(crate) fn with_endpoints(
113        mut self,
114        login_url: impl AsRef<str>,
115        management_url: impl AsRef<str>,
116    ) -> Self {
117        self.endpoints = AzureEndpoints {
118            login_url: login_url.as_ref().trim_end_matches('/').to_string(),
119            management_url: management_url.as_ref().trim_end_matches('/').to_string(),
120        };
121        self
122    }
123
124    #[cfg(test)]
125    pub(crate) fn with_cached_token(self, token: impl Into<String>) -> Self {
126        *self.token.lock().expect("test token lock") =
127            Some((token.into(), Instant::now() + Duration::from_secs(55 * 60)));
128        self
129    }
130
131    async fn ensure_token(&self) -> Result<String> {
132        if let Some((ref token, expiry)) = *self.token_lock()?
133            && Instant::now() < expiry
134        {
135            return Ok(token.clone());
136        }
137
138        let url = format!(
139            "{}/{}/oauth2/v2.0/token",
140            self.endpoints.login_url, self.config.tenant_id
141        );
142        let form = serde_urlencoded::to_string([
143            ("grant_type", "client_credentials"),
144            ("client_id", self.config.client_id.as_str()),
145            ("client_secret", self.config.client_secret.as_str()),
146            ("scope", self.config.environment.scope()),
147        ])
148        .map_err(|e| Error::Api(format!("Failed to encode token request: {e}")))?;
149
150        let token_response: AzureTokenResponse = self
151            .client
152            .post(&url)
153            .with_header("content-type", "application/x-www-form-urlencoded")
154            .with_raw_body(form)
155            .send_with_retry(3)
156            .await?;
157
158        if token_response.access_token.is_empty() {
159            return Err(Error::Api(
160                "Azure token response missing access_token".into(),
161            ));
162        }
163
164        let lifetime = token_response
165            .expires_in
166            .unwrap_or(3600)
167            .saturating_sub(60)
168            .max(60);
169        let expiry = Instant::now() + Duration::from_secs(lifetime);
170        *self.token_lock()? = Some((token_response.access_token.clone(), expiry));
171        Ok(token_response.access_token)
172    }
173
174    pub(crate) async fn set_rrset(
175        &self,
176        name: impl IntoFqdn<'_>,
177        record_type: DnsRecordType,
178        ttl: u32,
179        records: Vec<DnsRecord>,
180        origin: impl IntoFqdn<'_>,
181    ) -> Result<()> {
182        check_record_types(record_type, &records)?;
183        check_cname_singleton(record_type, &records)?;
184        let zone = origin.into_name().to_ascii_lowercase();
185        let fqdn = name.into_name().to_ascii_lowercase();
186        let relative = relative_record_name(&fqdn, &zone);
187        let type_segment = azure_record_type(&record_type)?;
188        let url = self.record_url(&zone, type_segment, &relative);
189        let token = self.ensure_token().await?;
190
191        if records.is_empty() {
192            return self.delete_rrset_url(&url, &token, None).await;
193        }
194
195        self.put_rrset(&url, &token, ttl, record_type, &records, None)
196            .await
197    }
198
199    pub(crate) async fn add_to_rrset(
200        &self,
201        name: impl IntoFqdn<'_>,
202        record_type: DnsRecordType,
203        ttl: u32,
204        records: Vec<DnsRecord>,
205        origin: impl IntoFqdn<'_>,
206    ) -> Result<()> {
207        check_record_types(record_type, &records)?;
208        if records.is_empty() {
209            return Ok(());
210        }
211        check_cname_singleton(record_type, &records)?;
212        let zone = origin.into_name().to_ascii_lowercase();
213        let fqdn = name.into_name().to_ascii_lowercase();
214        let relative = relative_record_name(&fqdn, &zone);
215        let type_segment = azure_record_type(&record_type)?;
216        let url = self.record_url(&zone, type_segment, &relative);
217        let token = self.ensure_token().await?;
218
219        let fetched = self.fetch_rrset(&url, &token).await?;
220        let mut merged = fetched.records;
221        for record in records {
222            if !merged.iter().any(|r| r == &record) {
223                merged.push(record);
224            }
225        }
226        check_cname_singleton(record_type, &merged)?;
227        self.put_rrset(
228            &url,
229            &token,
230            ttl,
231            record_type,
232            &merged,
233            fetched.etag.as_deref(),
234        )
235        .await
236    }
237
238    pub(crate) async fn remove_from_rrset(
239        &self,
240        name: impl IntoFqdn<'_>,
241        record_type: DnsRecordType,
242        records: Vec<DnsRecord>,
243        origin: impl IntoFqdn<'_>,
244    ) -> Result<()> {
245        check_record_types(record_type, &records)?;
246        if records.is_empty() {
247            return Ok(());
248        }
249        let zone = origin.into_name().to_ascii_lowercase();
250        let fqdn = name.into_name().to_ascii_lowercase();
251        let relative = relative_record_name(&fqdn, &zone);
252        let type_segment = azure_record_type(&record_type)?;
253        let url = self.record_url(&zone, type_segment, &relative);
254        let token = self.ensure_token().await?;
255
256        let fetched = match self.fetch_rrset_optional(&url, &token).await? {
257            Some(fetched) => fetched,
258            None => return Ok(()),
259        };
260
261        let remaining: Vec<DnsRecord> = fetched
262            .records
263            .into_iter()
264            .filter(|r| !records.contains(r))
265            .collect();
266
267        if remaining.is_empty() {
268            return self
269                .delete_rrset_url(&url, &token, fetched.etag.as_deref())
270                .await;
271        }
272
273        let ttl = fetched.ttl.unwrap_or(0);
274        self.put_rrset(
275            &url,
276            &token,
277            ttl,
278            record_type,
279            &remaining,
280            fetched.etag.as_deref(),
281        )
282        .await
283    }
284
285    pub(crate) async fn list_rrset(
286        &self,
287        name: impl IntoFqdn<'_>,
288        record_type: DnsRecordType,
289        origin: impl IntoFqdn<'_>,
290    ) -> Result<Vec<DnsRecord>> {
291        let zone = origin.into_name().to_ascii_lowercase();
292        let fqdn = name.into_name().to_ascii_lowercase();
293        let relative = relative_record_name(&fqdn, &zone);
294        let type_segment = azure_record_type(&record_type)?;
295        let url = self.record_url(&zone, type_segment, &relative);
296        let token = self.ensure_token().await?;
297
298        match self.fetch_rrset_optional(&url, &token).await? {
299            Some(fetched) => Ok(fetched.records),
300            None => Ok(Vec::new()),
301        }
302    }
303
304    async fn put_rrset(
305        &self,
306        url: &str,
307        token: &str,
308        ttl: u32,
309        record_type: DnsRecordType,
310        records: &[DnsRecord],
311        if_match: Option<&str>,
312    ) -> Result<()> {
313        let mut properties = serde_json::Map::new();
314        properties.insert("TTL".to_string(), json!(ttl));
315        insert_rrset_payload(&mut properties, record_type, records)?;
316
317        let mut body = serde_json::Map::new();
318        body.insert("properties".to_string(), Value::Object(properties));
319
320        let mut request = self
321            .client
322            .put(url)
323            .with_header("authorization", format!("Bearer {token}"))
324            .with_body(&body)?;
325        if let Some(etag) = if_match {
326            request = request.with_header("if-match", etag);
327        }
328        request.send_with_retry::<Value>(3).await.map(|_| ())
329    }
330
331    async fn delete_rrset_url(&self, url: &str, token: &str, if_match: Option<&str>) -> Result<()> {
332        let mut request = self
333            .client
334            .delete(url)
335            .with_header("authorization", format!("Bearer {token}"));
336        if let Some(etag) = if_match {
337            request = request.with_header("if-match", etag);
338        }
339        request
340            .send_with_retry::<Value>(3)
341            .await
342            .map(|_| ())
343            .or_else(|err| match err {
344                Error::NotFound => Ok(()),
345                err => Err(err),
346            })
347    }
348
349    async fn fetch_rrset(&self, url: &str, token: &str) -> Result<FetchedRrset> {
350        match self.fetch_rrset_optional(url, token).await? {
351            Some(fetched) => Ok(fetched),
352            None => Ok(FetchedRrset::default()),
353        }
354    }
355
356    async fn fetch_rrset_optional(&self, url: &str, token: &str) -> Result<Option<FetchedRrset>> {
357        let value: Value = match self
358            .client
359            .get(url)
360            .with_header("authorization", format!("Bearer {token}"))
361            .send_with_retry(3)
362            .await
363        {
364            Ok(v) => v,
365            Err(Error::NotFound) => return Ok(None),
366            Err(err) => return Err(err),
367        };
368
369        let etag = value
370            .get("etag")
371            .and_then(Value::as_str)
372            .map(str::to_string);
373
374        let ttl = value
375            .get("properties")
376            .and_then(|p| p.get("TTL"))
377            .and_then(Value::as_u64)
378            .map(|v| v as u32);
379
380        let records = parse_rrset_records(&value)?;
381        Ok(Some(FetchedRrset { records, etag, ttl }))
382    }
383
384    fn record_url(&self, zone: &str, type_segment: &str, relative: &str) -> String {
385        format!(
386            "{}/subscriptions/{}/resourceGroups/{}/providers/Microsoft.Network/dnsZones/{}/{}/{}?api-version={}",
387            self.endpoints.management_url,
388            self.config.subscription_id,
389            self.config.resource_group,
390            zone,
391            type_segment,
392            relative,
393            API_VERSION,
394        )
395    }
396
397    fn token_lock(&self) -> Result<std::sync::MutexGuard<'_, Option<(String, Instant)>>> {
398        self.token
399            .lock()
400            .map_err(|_| Error::Client("Azure DNS token cache lock poisoned".into()))
401    }
402}
403
404fn relative_record_name(fqdn: &str, zone: &str) -> String {
405    let stripped = strip_origin_from_name(fqdn, zone, Some("@"));
406    if stripped.is_empty() {
407        "@".to_string()
408    } else {
409        stripped
410    }
411}
412
413fn azure_record_type(rt: &DnsRecordType) -> Result<&'static str> {
414    Ok(match rt {
415        DnsRecordType::A => "A",
416        DnsRecordType::AAAA => "AAAA",
417        DnsRecordType::CNAME => "CNAME",
418        DnsRecordType::MX => "MX",
419        DnsRecordType::NS => "NS",
420        DnsRecordType::TXT => "TXT",
421        DnsRecordType::SRV => "SRV",
422        DnsRecordType::CAA => "CAA",
423        DnsRecordType::TLSA => {
424            return Err(Error::Unsupported(
425                "TLSA records are not supported by Azure DNS".to_string(),
426            ));
427        }
428    })
429}
430
/// Result of fetching one record set; the default value is an empty set
/// with no ETag and no TTL.
#[derive(Default)]
struct FetchedRrset {
    // Records parsed from the response's `properties`.
    records: Vec<DnsRecord>,
    // Top-level `etag` string of the response, later sent as `if-match`.
    etag: Option<String>,
    // `properties.TTL` of the response, when present.
    ttl: Option<u32>,
}
437
438fn check_record_types(expected: DnsRecordType, records: &[DnsRecord]) -> Result<()> {
439    azure_record_type(&expected)?;
440    for r in records {
441        if r.as_type() != expected {
442            return Err(Error::Api(format!(
443                "RRSet record type mismatch: expected {}, got {}",
444                expected.as_str(),
445                r.as_type().as_str(),
446            )));
447        }
448    }
449    Ok(())
450}
451
452fn check_cname_singleton(record_type: DnsRecordType, records: &[DnsRecord]) -> Result<()> {
453    if record_type == DnsRecordType::CNAME && records.len() > 1 {
454        return Err(Error::Api(
455            "CNAME RRSet may contain at most one record".to_string(),
456        ));
457    }
458    Ok(())
459}
460
461fn insert_rrset_payload(
462    props: &mut serde_json::Map<String, Value>,
463    record_type: DnsRecordType,
464    records: &[DnsRecord],
465) -> Result<()> {
466    match record_type {
467        DnsRecordType::A => {
468            let arr: Vec<Value> = records
469                .iter()
470                .map(|r| match r {
471                    DnsRecord::A(ip) => json!({"ipv4Address": ip.to_string()}),
472                    _ => unreachable!(),
473                })
474                .collect();
475            props.insert("ARecords".to_string(), Value::Array(arr));
476        }
477        DnsRecordType::AAAA => {
478            let arr: Vec<Value> = records
479                .iter()
480                .map(|r| match r {
481                    DnsRecord::AAAA(ip) => json!({"ipv6Address": ip.to_string()}),
482                    _ => unreachable!(),
483                })
484                .collect();
485            props.insert("AAAARecords".to_string(), Value::Array(arr));
486        }
487        DnsRecordType::CNAME => {
488            if let Some(DnsRecord::CNAME(target)) = records.first() {
489                props.insert(
490                    "CNAMERecord".to_string(),
491                    json!({"cname": target.trim_end_matches('.')}),
492                );
493            }
494        }
495        DnsRecordType::NS => {
496            let arr: Vec<Value> = records
497                .iter()
498                .map(|r| match r {
499                    DnsRecord::NS(target) => {
500                        json!({"nsdname": target.trim_end_matches('.')})
501                    }
502                    _ => unreachable!(),
503                })
504                .collect();
505            props.insert("NSRecords".to_string(), Value::Array(arr));
506        }
507        DnsRecordType::MX => {
508            let arr: Vec<Value> = records
509                .iter()
510                .map(|r| match r {
511                    DnsRecord::MX(mx) => json!({
512                        "preference": mx.priority,
513                        "exchange": mx.exchange.trim_end_matches('.'),
514                    }),
515                    _ => unreachable!(),
516                })
517                .collect();
518            props.insert("MXRecords".to_string(), Value::Array(arr));
519        }
520        DnsRecordType::TXT => {
521            let arr: Vec<Value> = records
522                .iter()
523                .map(|r| match r {
524                    DnsRecord::TXT(text) => json!({"value": txt_chunks(text.clone())}),
525                    _ => unreachable!(),
526                })
527                .collect();
528            props.insert("TXTRecords".to_string(), Value::Array(arr));
529        }
530        DnsRecordType::SRV => {
531            let arr: Vec<Value> = records
532                .iter()
533                .map(|r| match r {
534                    DnsRecord::SRV(srv) => json!({
535                        "priority": srv.priority,
536                        "weight": srv.weight,
537                        "port": srv.port,
538                        "target": srv.target.trim_end_matches('.'),
539                    }),
540                    _ => unreachable!(),
541                })
542                .collect();
543            props.insert("SRVRecords".to_string(), Value::Array(arr));
544        }
545        DnsRecordType::CAA => {
546            let arr: Vec<Value> = records
547                .iter()
548                .map(|r| match r {
549                    DnsRecord::CAA(caa) => {
550                        let (flags, tag, value) = caa.clone().decompose();
551                        json!({"flags": flags, "tag": tag, "value": value})
552                    }
553                    _ => unreachable!(),
554                })
555                .collect();
556            props.insert("caaRecords".to_string(), Value::Array(arr));
557        }
558        DnsRecordType::TLSA => {
559            return Err(Error::Unsupported(
560                "TLSA records are not supported by Azure DNS".to_string(),
561            ));
562        }
563    }
564    Ok(())
565}
566
567fn parse_rrset_records(value: &Value) -> Result<Vec<DnsRecord>> {
568    let props = match value.get("properties") {
569        Some(p) => p,
570        None => return Ok(Vec::new()),
571    };
572
573    let mut out = Vec::new();
574
575    if let Some(arr) = props.get("ARecords").and_then(Value::as_array) {
576        for entry in arr {
577            if let Some(addr) = entry.get("ipv4Address").and_then(Value::as_str)
578                && let Ok(ip) = Ipv4Addr::from_str(addr)
579            {
580                out.push(DnsRecord::A(ip));
581            }
582        }
583    }
584    if let Some(arr) = props.get("AAAARecords").and_then(Value::as_array) {
585        for entry in arr {
586            if let Some(addr) = entry.get("ipv6Address").and_then(Value::as_str)
587                && let Ok(ip) = Ipv6Addr::from_str(addr)
588            {
589                out.push(DnsRecord::AAAA(ip));
590            }
591        }
592    }
593    if let Some(obj) = props.get("CNAMERecord")
594        && let Some(target) = obj.get("cname").and_then(Value::as_str)
595    {
596        out.push(DnsRecord::CNAME(target.to_string()));
597    }
598    if let Some(arr) = props.get("NSRecords").and_then(Value::as_array) {
599        for entry in arr {
600            if let Some(target) = entry.get("nsdname").and_then(Value::as_str) {
601                out.push(DnsRecord::NS(target.to_string()));
602            }
603        }
604    }
605    if let Some(arr) = props.get("MXRecords").and_then(Value::as_array) {
606        for entry in arr {
607            let priority = entry.get("preference").and_then(Value::as_u64).unwrap_or(0) as u16;
608            if let Some(exchange) = entry.get("exchange").and_then(Value::as_str) {
609                out.push(DnsRecord::MX(MXRecord {
610                    priority,
611                    exchange: exchange.to_string(),
612                }));
613            }
614        }
615    }
616    if let Some(arr) = props.get("TXTRecords").and_then(Value::as_array) {
617        for entry in arr {
618            if let Some(values) = entry.get("value").and_then(Value::as_array) {
619                let joined: String = values
620                    .iter()
621                    .filter_map(Value::as_str)
622                    .collect::<Vec<_>>()
623                    .concat();
624                out.push(DnsRecord::TXT(joined));
625            }
626        }
627    }
628    if let Some(arr) = props.get("SRVRecords").and_then(Value::as_array) {
629        for entry in arr {
630            let priority = entry.get("priority").and_then(Value::as_u64).unwrap_or(0) as u16;
631            let weight = entry.get("weight").and_then(Value::as_u64).unwrap_or(0) as u16;
632            let port = entry.get("port").and_then(Value::as_u64).unwrap_or(0) as u16;
633            if let Some(target) = entry.get("target").and_then(Value::as_str) {
634                out.push(DnsRecord::SRV(SRVRecord {
635                    priority,
636                    weight,
637                    port,
638                    target: target.to_string(),
639                }));
640            }
641        }
642    }
643    if let Some(arr) = props.get("caaRecords").and_then(Value::as_array) {
644        for entry in arr {
645            let flags = entry.get("flags").and_then(Value::as_u64).unwrap_or(0) as u8;
646            let tag = entry.get("tag").and_then(Value::as_str).unwrap_or("");
647            let value = entry
648                .get("value")
649                .and_then(Value::as_str)
650                .unwrap_or("")
651                .to_string();
652            let issuer_critical = flags & 0x80 != 0;
653            let caa = match tag.to_ascii_lowercase().as_str() {
654                "issue" => CAARecord::Issue {
655                    issuer_critical,
656                    name: if value.is_empty() { None } else { Some(value) },
657                    options: Vec::<KeyValue>::new(),
658                },
659                "issuewild" => CAARecord::IssueWild {
660                    issuer_critical,
661                    name: if value.is_empty() { None } else { Some(value) },
662                    options: Vec::<KeyValue>::new(),
663                },
664                "iodef" => CAARecord::Iodef {
665                    issuer_critical,
666                    url: value,
667                },
668                _ => continue,
669            };
670            out.push(DnsRecord::CAA(caa));
671        }
672    }
673
674    Ok(out)
675}
676
/// Subset of the OAuth2 token response that the provider reads.
#[derive(Deserialize)]
struct AzureTokenResponse {
    // Bearer token sent in the `authorization` header of later requests.
    access_token: String,
    // Token lifetime in seconds; optional in the response.
    #[serde(default)]
    expires_in: Option<u64>,
}