Skip to main content

commonware_deployer/aws/
utils.rs

1//! Utility functions for interacting with EC2 instances
2
3use crate::aws::Error;
4use std::path::Path;
5use tokio::{
6    fs::File,
7    io::AsyncWriteExt,
8    process::Command,
9    time::{sleep, Duration},
10};
11use tracing::{info, warn};
12
/// Upper bound on SSH/SCP attempts before an operation is reported as failed
pub const MAX_SSH_ATTEMPTS: usize = 30;

/// Upper bound on `systemctl is-active` polls before a service is reported as timed out
pub const MAX_POLL_ATTEMPTS: usize = 30;

/// Delay inserted between consecutive retry attempts
pub const RETRY_INTERVAL: Duration = Duration::from_secs(15);

/// Transport protocol of the deployer ingress rule
pub const DEPLOYER_PROTOCOL: &str = "tcp";

/// Lowest port of the deployer ingress range (start of the full 16-bit port space)
pub const DEPLOYER_MIN_PORT: i32 = u16::MIN as i32;

/// Highest port of the deployer ingress range (end of the full 16-bit port space)
pub const DEPLOYER_MAX_PORT: i32 = u16::MAX as i32;
30
31/// Fetch the current machine's public IPv4 address
32pub async fn get_public_ip() -> Result<String, Error> {
33    // icanhazip.com is maintained by Cloudflare as of 6/6/2021 (https://major.io/p/a-new-future-for-icanhazip/)
34    let result = reqwest::get("https://ipv4.icanhazip.com")
35        .await?
36        .text()
37        .await?
38        .trim()
39        .to_string();
40    Ok(result)
41}
42
43/// Executes a command on a remote instance via SSH with retries
44pub async fn ssh_execute(key_file: &str, ip: &str, command: &str) -> Result<(), Error> {
45    for _ in 0..MAX_SSH_ATTEMPTS {
46        let output = Command::new("ssh")
47            .arg("-i")
48            .arg(key_file)
49            .arg("-o")
50            .arg("IdentitiesOnly=yes")
51            .arg("-o")
52            .arg("ServerAliveInterval=600")
53            .arg("-o")
54            .arg("StrictHostKeyChecking=no")
55            .arg(format!("ubuntu@{ip}"))
56            .arg(command)
57            .output()
58            .await?;
59        if output.status.success() {
60            return Ok(());
61        }
62        warn!(ip, stderr = ?String::from_utf8_lossy(&output.stderr), stdout = ?String::from_utf8_lossy(&output.stdout), "SSH command failed");
63        sleep(RETRY_INTERVAL).await;
64    }
65    Err(Error::SshFailed)
66}
67
68/// Polls the status of a systemd service on a remote instance until active
69pub async fn poll_service_active(key_file: &str, ip: &str, service: &str) -> Result<(), Error> {
70    for _ in 0..MAX_POLL_ATTEMPTS {
71        let output = Command::new("ssh")
72            .arg("-i")
73            .arg(key_file)
74            .arg("-o")
75            .arg("IdentitiesOnly=yes")
76            .arg("-o")
77            .arg("ServerAliveInterval=600")
78            .arg("-o")
79            .arg("StrictHostKeyChecking=no")
80            .arg(format!("ubuntu@{ip}"))
81            .arg(format!("systemctl is-active {service}"))
82            .output()
83            .await?;
84        let parsed = String::from_utf8_lossy(&output.stdout);
85        let parsed = parsed.trim();
86        if parsed == "active" {
87            return Ok(());
88        }
89        if service == "binary" && parsed == "failed" {
90            warn!(service, "service failed to start (check logs and update)");
91            return Ok(());
92        }
93        warn!(status = parsed, service, "service not yet active");
94        sleep(RETRY_INTERVAL).await;
95    }
96    Err(Error::ServiceTimeout(ip.to_string(), service.to_string()))
97}
98
99/// Polls the status of a systemd service on a remote instance until it becomes inactive
100pub async fn poll_service_inactive(key_file: &str, ip: &str, service: &str) -> Result<(), Error> {
101    for _ in 0..MAX_POLL_ATTEMPTS {
102        let output = Command::new("ssh")
103            .arg("-i")
104            .arg(key_file)
105            .arg("-o")
106            .arg("IdentitiesOnly=yes")
107            .arg("-o")
108            .arg("ServerAliveInterval=600")
109            .arg("-o")
110            .arg("StrictHostKeyChecking=no")
111            .arg(format!("ubuntu@{ip}"))
112            .arg(format!("systemctl is-active {service}"))
113            .output()
114            .await?;
115        let parsed = String::from_utf8_lossy(&output.stdout);
116        let parsed = parsed.trim();
117        if parsed == "inactive" {
118            return Ok(());
119        }
120        if service == "binary" && parsed == "failed" {
121            warn!(service, "service was never active");
122            return Ok(());
123        }
124        warn!(status = parsed, service, "service not yet inactive");
125        sleep(RETRY_INTERVAL).await;
126    }
127    Err(Error::ServiceTimeout(ip.to_string(), service.to_string()))
128}
129
130/// Downloads a file from a remote instance via SCP with retries
131pub async fn scp_download(
132    key_file: &str,
133    ip: &str,
134    remote_path: &str,
135    local_path: &str,
136) -> Result<(), Error> {
137    for _ in 0..MAX_SSH_ATTEMPTS {
138        let output = Command::new("scp")
139            .arg("-i")
140            .arg(key_file)
141            .arg("-o")
142            .arg("IdentitiesOnly=yes")
143            .arg("-o")
144            .arg("ServerAliveInterval=600")
145            .arg("-o")
146            .arg("StrictHostKeyChecking=no")
147            .arg(format!("ubuntu@{ip}:{remote_path}"))
148            .arg(local_path)
149            .output()
150            .await?;
151        if output.status.success() {
152            return Ok(());
153        }
154        warn!(error = ?String::from_utf8_lossy(&output.stderr), "SCP failed");
155        sleep(RETRY_INTERVAL).await;
156    }
157    Err(Error::SshFailed)
158}
159
/// Turns an address into a single-host CIDR block by appending the `/32` prefix length.
///
/// The input is not validated; for example, `"10.0.0.1"` becomes `"10.0.0.1/32"`.
pub fn exact_cidr(ip: &str) -> String {
    [ip, "/32"].concat()
}
164
165/// Maximum number of download attempts before failing
166pub const MAX_DOWNLOAD_ATTEMPTS: usize = 10;
167
168/// Downloads a file from a URL to a local path with retries
169pub async fn download_file(url: &str, dest: &Path) -> Result<(), Error> {
170    for attempt in 1..=MAX_DOWNLOAD_ATTEMPTS {
171        match download_file_once(url, dest).await {
172            Ok(()) => {
173                info!(url = url, dest = ?dest, "downloaded file");
174                return Ok(());
175            }
176            Err(e) => {
177                warn!(
178                    url = url,
179                    attempt = attempt,
180                    error = ?e,
181                    "download attempt failed"
182                );
183                if attempt < MAX_DOWNLOAD_ATTEMPTS {
184                    sleep(RETRY_INTERVAL).await;
185                }
186            }
187        }
188    }
189    Err(Error::DownloadFailed(url.to_string()))
190}
191
192async fn download_file_once(url: &str, dest: &Path) -> Result<(), Error> {
193    let response = reqwest::get(url).await?;
194    if !response.status().is_success() {
195        return Err(Error::DownloadFailed(format!(
196            "HTTP {}: {}",
197            response.status(),
198            url
199        )));
200    }
201
202    let bytes = response.bytes().await?;
203
204    // Create parent directory if it doesn't exist
205    if let Some(parent) = dest.parent() {
206        tokio::fs::create_dir_all(parent).await?;
207    }
208
209    let mut file = File::create(dest).await?;
210    file.write_all(&bytes).await?;
211    file.flush().await?;
212
213    Ok(())
214}