netlify_ddns/
lib.rs

1#![forbid(unsafe_code)]
2#![deny(warnings)]
3
4pub mod netlify;
5
6use futures::future::FutureExt;
7use futures::{executor, future};
8
9use anyhow::{Context, Result};
10use tracing::{debug, info};
11
12use netlify::DnsRecord;
13
14#[derive(clap::ValueEnum, Clone, Debug)]
15pub enum IpType {
16    Ipv4,
17    Ipv6,
18}
19
20#[derive(Debug, clap::Parser)]
21#[command(
22    author,
23    version,
24    about,
25    long_about = None
26)]
27pub struct Args {
28    /// The full domain for the DNS record
29    #[arg(short, long)]
30    pub domain: String,
31
32    #[arg(short, long, default_value = "www")]
33    /// The subdomain segment for the DNS record
34    pub subdomain: String,
35
36    /// The TTL value in seconds to set with the record
37    #[arg(long, default_value = "3600")]
38    pub ttl: u32,
39
40    /// Whether an IPv6 "AAAA" or an IPv4 "A" record should be updated
41    #[arg(short, long, value_enum, ignore_case = true, default_value = "ipv4")]
42    pub ip_type: IpType,
43
44    /// Your Netlify personal access token
45    #[arg(short, long, env = "NETLIFY_TOKEN")]
46    pub token: String,
47}
48
49async fn query_ident_me(ip_type: &IpType) -> Result<String> {
50    #[cfg(test)]
51    let resp = match ip_type {
52        IpType::Ipv4 => ureq::get(&mockito::server_url()).call()?,
53        IpType::Ipv6 => ureq::get(&mockito::server_url()).call()?,
54    };
55    #[cfg(not(test))]
56    let resp = match ip_type {
57        IpType::Ipv4 => ureq::get("https://v4.ident.me/").call()?,
58        IpType::Ipv6 => ureq::get("https://v6.ident.me/").call()?,
59    };
60
61    let body = resp
62        .into_string()
63        .context("Failed to convert ident.me response into string.")?;
64    Ok(body)
65}
66
67async fn query_ipify_org(ip_type: &IpType) -> Result<String> {
68    #[cfg(test)]
69    let resp = match ip_type {
70        IpType::Ipv4 => ureq::get(&mockito::server_url()).call()?,
71        IpType::Ipv6 => ureq::get(&mockito::server_url()).call()?,
72    };
73    #[cfg(not(test))]
74    let resp = match ip_type {
75        IpType::Ipv4 => ureq::get("https://api.ipify.org/").call()?,
76        IpType::Ipv6 => ureq::get("https://api6.ipify.org/").call()?,
77    };
78
79    let body = resp
80        .into_string()
81        .context("Failed to convert ident.me response into string.")?;
82    Ok(body)
83}
84
85// Get the host machine's external IP address by querying multiple services and
86// taking the first response.
87async fn get_external_ip(ip_type: &IpType) -> Result<String> {
88    debug!("Querying third-party services for external IP...");
89
90    let third_parties = vec![
91        query_ident_me(ip_type).boxed(),
92        query_ipify_org(ip_type).boxed(),
93    ];
94
95    // Select the first succesful future, or the last failure.
96    let (ip, _) = future::select_ok(third_parties.into_iter())
97        .await
98        .context("All queries for external IP failed.")?;
99
100    info!("Found External IP: {}", ip);
101    Ok(ip)
102}
103
104fn get_conflicts(
105    dns_records: Vec<DnsRecord>,
106    args: &Args,
107    rec: &DnsRecord,
108) -> (Vec<DnsRecord>, Vec<DnsRecord>) {
109    let target_hostname = format!(
110        "{}{}{}",
111        &args.subdomain,
112        if args.subdomain.is_empty() { "" } else { "." },
113        &args.domain
114    );
115    dns_records
116        .into_iter()
117        .filter(|r| match args.ip_type {
118            IpType::Ipv4 => r.dns_type == "A",
119            IpType::Ipv6 => r.dns_type == "AAAA",
120        })
121        .filter(|r| r.hostname == target_hostname)
122        .partition(|r| r.hostname == target_hostname && r.value == rec.value && r.ttl == rec.ttl)
123}
124
125pub fn run(args: Args) -> Result<()> {
126    let ip = executor::block_on(get_external_ip(&args.ip_type))?;
127
128    let rec = DnsRecord {
129        hostname: args.subdomain.to_string(),
130        dns_type: match args.ip_type {
131            IpType::Ipv4 => "A".to_string(),
132            IpType::Ipv6 => "AAAA".to_string(),
133        },
134        ttl: Some(args.ttl),
135        value: ip,
136        id: None,
137    };
138
139    // Update the DNS record if it exists, otherwise add.
140    let dns_records = netlify::get_dns_records(&args.domain, &args.token)
141        .context("Unable to fetch DNS records.")?;
142
143    // Match on subdomain
144    let (exact, conflicts) = get_conflicts(dns_records, &args, &rec);
145
146    // Clear existing records for this subdomain, if any
147    for r in conflicts {
148        info!("Clearing conflicting DNS records for this subdomain.");
149        netlify::delete_dns_record(&args.domain, &args.token, r)
150            .context("Unable to delete DNS records.")?;
151    }
152
153    // Add new record
154    if exact.is_empty() {
155        info!("Adding the DNS record.");
156        let rec = netlify::add_dns_record(&args.domain, &args.token, rec)
157            .context("Unable to add the DNS record.")?;
158        info!("{:#?}", rec);
159    }
160
161    Ok(())
162}
163
#[cfg(test)]
mod test {
    use super::*;
    use mockito::mock;

    /// A record as Netlify would return it: type "A", TTL 3600, with an id.
    fn existing_record(hostname: &str, value: &str) -> DnsRecord {
        DnsRecord {
            hostname: hostname.to_string(),
            dns_type: "A".to_string(),
            ttl: Some(3600),
            value: value.to_string(),
            id: Some("abc123".to_string()),
        }
    }

    /// The desired, not-yet-created "A" record pointing at 1.2.3.4.
    fn desired_record(hostname: &str, ttl: u32) -> DnsRecord {
        DnsRecord {
            hostname: hostname.to_string(),
            dns_type: "A".to_string(),
            ttl: Some(ttl),
            value: "1.2.3.4".to_string(),
            id: None,
        }
    }

    /// IPv4 arguments targeting `subdomain` under helloworld.com.
    fn ipv4_args(subdomain: &str, ttl: u32) -> Args {
        Args {
            domain: "helloworld.com".to_string(),
            subdomain: subdomain.to_string(),
            ttl,
            ip_type: IpType::Ipv4,
            token: "123".to_string(),
        }
    }

    #[test]
    fn test_get_external_ip() {
        let _ipv4_mock = mock("GET", "/")
            .with_status(200)
            .with_header("content-type", "text/plain")
            .with_body("104.132.34.103")
            .create();
        let ip = executor::block_on(get_external_ip(&IpType::Ipv4)).unwrap();
        assert_eq!("104.132.34.103", &ip);

        let _ipv6_mock = mock("GET", "/")
            .with_status(200)
            .with_header("content-type", "text/plain")
            .with_body("2620:0:1003:fd00:95e9:369a:53cd:f035")
            .create();
        let ip = executor::block_on(get_external_ip(&IpType::Ipv6)).unwrap();
        assert_eq!("2620:0:1003:fd00:95e9:369a:53cd:f035", &ip);
    }

    #[test]
    fn test_get_external_ip_404() {
        let _not_found = mock("GET", "/")
            .with_status(404)
            .with_header("content-type", "text/plain")
            .with_body("Not found")
            .create();

        let result = executor::block_on(get_external_ip(&IpType::Ipv6));
        assert!(result.is_err(), "Should've gotten an error.");
    }

    #[test]
    fn test_conflicts() {
        // For each hostname flavour (basic subdomain, glob subdomain, bare domain)
        // there is one record with the desired value and one with a stale value.
        let dns_records = vec![
            existing_record("sub.helloworld.com", "1.2.3.4"),
            existing_record("sub.helloworld.com", "9.9.9.9"),
            existing_record("*.sub.helloworld.com", "1.2.3.4"),
            existing_record("*.sub.helloworld.com", "9.9.9.9"),
            existing_record("helloworld.com", "1.2.3.4"),
            existing_record("helloworld.com", "9.9.9.9"),
        ];

        for &subdomain in &["*.sub", "sub", ""] {
            let (exact, conflicts) = get_conflicts(
                dns_records.clone(),
                &ipv4_args(subdomain, 3600),
                &desired_record(subdomain, 3600),
            );
            assert_eq!(conflicts.len(), 1, "conflicts for subdomain {:?}", subdomain);
            assert_eq!(exact.len(), 1, "exact matches for subdomain {:?}", subdomain);
        }

        // TTL takes part in the exact-match comparison: a record with the same value
        // but a different TTL must be reported as a conflict.
        let (ttl_exact, ttl_conflicts) = get_conflicts(
            vec![existing_record("sub.helloworld.com", "1.2.3.4")],
            &ipv4_args("sub", 10),
            &desired_record("sub", 10),
        );
        assert_eq!(ttl_conflicts.len(), 1);
        assert_eq!(ttl_exact.len(), 0);
    }
}