use crate::aws::Error;
use std::path::Path;
use tokio::{
fs::File,
io::AsyncWriteExt,
process::Command,
time::{sleep, Duration},
};
use tracing::{info, warn};
/// Maximum number of `ssh`/`scp` invocations attempted before giving up.
pub const MAX_SSH_ATTEMPTS: usize = 30;

/// Maximum number of `systemctl is-active` polls before a service is declared timed out.
pub const MAX_POLL_ATTEMPTS: usize = 30;

/// Pause between consecutive retry or poll attempts.
pub const RETRY_INTERVAL: Duration = Duration::from_secs(15);

/// Protocol for deployer access rules.
/// NOTE(review): not used in this chunk; presumably consumed when building
/// firewall/security-group rules — confirm at the call sites.
pub const DEPLOYER_PROTOCOL: &str = "tcp";

/// Lowest port of the deployer port range (inclusive).
/// NOTE(review): typed `i32`, presumably to match an external API's port fields — confirm.
pub const DEPLOYER_MIN_PORT: i32 = 0;

/// Highest port of the deployer port range (inclusive); the largest 16-bit port number.
pub const DEPLOYER_MAX_PORT: i32 = u16::MAX as i32;
/// Returns this machine's public IPv4 address as reported by `ipv4.icanhazip.com`.
///
/// The response body is trimmed of surrounding whitespace (the service ends it
/// with a newline) and returned as-is.
///
/// # Errors
///
/// Returns an error if the request fails, if the service answers with a
/// non-success HTTP status, or if the body cannot be read. The status check
/// matters because otherwise an error page would be returned as though it
/// were an address and then end up in a CIDR rule.
pub async fn get_public_ip() -> Result<String, Error> {
    let body = reqwest::get("https://ipv4.icanhazip.com")
        .await?
        .error_for_status()?
        .text()
        .await?;
    Ok(body.trim().to_string())
}
/// Runs `command` on `ubuntu@{ip}` over SSH, retrying until it exits successfully.
///
/// Up to [`MAX_SSH_ATTEMPTS`] attempts are made, with [`RETRY_INTERVAL`] between
/// them (no wait after the final attempt). Success is judged by the exit status
/// of the local `ssh` process. `ssh` reports either its own connection failure
/// or the remote command's exit status, so a failing remote command is also
/// retried. `command` should therefore be safe to run more than once.
///
/// Host-key checking is disabled (`StrictHostKeyChecking=no`), and only the
/// identity in `key_file` is offered (`IdentitiesOnly=yes`).
///
/// # Errors
///
/// Returns an error immediately, without retrying, if the `ssh` process cannot
/// be spawned. Returns [`Error::SshFailed`] once every attempt has failed.
pub async fn ssh_execute(key_file: &str, ip: &str, command: &str) -> Result<(), Error> {
    for attempt in 1..=MAX_SSH_ATTEMPTS {
        let output = Command::new("ssh")
            .arg("-i")
            .arg(key_file)
            .arg("-o")
            .arg("IdentitiesOnly=yes")
            .arg("-o")
            .arg("ServerAliveInterval=600")
            .arg("-o")
            .arg("StrictHostKeyChecking=no")
            .arg(format!("ubuntu@{ip}"))
            .arg(command)
            .output()
            .await?;
        if output.status.success() {
            return Ok(());
        }
        warn!(
            ip,
            attempt,
            stderr = ?String::from_utf8_lossy(&output.stderr),
            stdout = ?String::from_utf8_lossy(&output.stdout),
            "SSH command failed"
        );
        // Only wait if another attempt follows; sleeping after the last one
        // just delays reporting the failure.
        if attempt < MAX_SSH_ATTEMPTS {
            sleep(RETRY_INTERVAL).await;
        }
    }
    Err(Error::SshFailed)
}
/// Polls `systemctl is-active {service}` on `ubuntu@{ip}` until it prints `active`.
///
/// Up to [`MAX_POLL_ATTEMPTS`] polls are made, with [`RETRY_INTERVAL`] between
/// them (no wait after the final poll). Only the trimmed stdout is inspected.
/// An SSH-level failure therefore just looks like "not yet active" and is polled again.
///
/// Special case: when `service` is exactly `"binary"` and it reports `failed`,
/// a warning is logged and `Ok(())` is returned. Per the log message, the
/// operator is expected to check the logs and update.
///
/// # Errors
///
/// Returns an error immediately if the `ssh` process cannot be spawned, or
/// [`Error::ServiceTimeout`] with `(ip, service)` if the service never becomes
/// active within the allotted polls.
pub async fn poll_service_active(key_file: &str, ip: &str, service: &str) -> Result<(), Error> {
    for attempt in 1..=MAX_POLL_ATTEMPTS {
        let output = Command::new("ssh")
            .arg("-i")
            .arg(key_file)
            .arg("-o")
            .arg("IdentitiesOnly=yes")
            .arg("-o")
            .arg("ServerAliveInterval=600")
            .arg("-o")
            .arg("StrictHostKeyChecking=no")
            .arg(format!("ubuntu@{ip}"))
            .arg(format!("systemctl is-active {service}"))
            .output()
            .await?;
        let parsed = String::from_utf8_lossy(&output.stdout);
        let parsed = parsed.trim();
        if parsed == "active" {
            return Ok(());
        }
        if service == "binary" && parsed == "failed" {
            warn!(service, "service failed to start (check logs and update)");
            return Ok(());
        }
        warn!(status = parsed, service, attempt, "service not yet active");
        // Skip the pointless wait once the last poll has been made.
        if attempt < MAX_POLL_ATTEMPTS {
            sleep(RETRY_INTERVAL).await;
        }
    }
    Err(Error::ServiceTimeout(ip.to_string(), service.to_string()))
}
/// Polls `systemctl is-active {service}` on `ubuntu@{ip}` until it prints `inactive`.
///
/// Up to [`MAX_POLL_ATTEMPTS`] polls are made, with [`RETRY_INTERVAL`] between
/// them (no wait after the final poll). Only the trimmed stdout is inspected.
/// An SSH-level failure therefore just looks like "not yet inactive" and is polled again.
///
/// Special case: when `service` is exactly `"binary"` and it reports `failed`,
/// a warning is logged and `Ok(())` is returned, since a failed unit is not running.
///
/// # Errors
///
/// Returns an error immediately if the `ssh` process cannot be spawned, or
/// [`Error::ServiceTimeout`] with `(ip, service)` if the service never becomes
/// inactive within the allotted polls.
pub async fn poll_service_inactive(key_file: &str, ip: &str, service: &str) -> Result<(), Error> {
    for attempt in 1..=MAX_POLL_ATTEMPTS {
        let output = Command::new("ssh")
            .arg("-i")
            .arg(key_file)
            .arg("-o")
            .arg("IdentitiesOnly=yes")
            .arg("-o")
            .arg("ServerAliveInterval=600")
            .arg("-o")
            .arg("StrictHostKeyChecking=no")
            .arg(format!("ubuntu@{ip}"))
            .arg(format!("systemctl is-active {service}"))
            .output()
            .await?;
        let parsed = String::from_utf8_lossy(&output.stdout);
        let parsed = parsed.trim();
        if parsed == "inactive" {
            return Ok(());
        }
        if service == "binary" && parsed == "failed" {
            warn!(service, "service was never active");
            return Ok(());
        }
        warn!(status = parsed, service, attempt, "service not yet inactive");
        // Skip the pointless wait once the last poll has been made.
        if attempt < MAX_POLL_ATTEMPTS {
            sleep(RETRY_INTERVAL).await;
        }
    }
    Err(Error::ServiceTimeout(ip.to_string(), service.to_string()))
}
pub async fn scp_download(
key_file: &str,
ip: &str,
remote_path: &str,
local_path: &str,
) -> Result<(), Error> {
for _ in 0..MAX_SSH_ATTEMPTS {
let output = Command::new("scp")
.arg("-i")
.arg(key_file)
.arg("-o")
.arg("IdentitiesOnly=yes")
.arg("-o")
.arg("ServerAliveInterval=600")
.arg("-o")
.arg("StrictHostKeyChecking=no")
.arg(format!("ubuntu@{ip}:{remote_path}"))
.arg(local_path)
.output()
.await?;
if output.status.success() {
return Ok(());
}
warn!(error = ?String::from_utf8_lossy(&output.stderr), "SCP failed");
sleep(RETRY_INTERVAL).await;
}
Err(Error::SshFailed)
}
/// Returns `ip` with a `/32` prefix length appended, i.e. a CIDR block that
/// matches that single IPv4 address.
///
/// The input is not validated; it is suffixed verbatim.
pub fn exact_cidr(ip: &str) -> String {
    const SINGLE_HOST_SUFFIX: &str = "/32";
    let mut cidr = String::with_capacity(ip.len() + SINGLE_HOST_SUFFIX.len());
    cidr.push_str(ip);
    cidr.push_str(SINGLE_HOST_SUFFIX);
    cidr
}
/// Number of attempts `download_file` makes before giving up on a URL.
pub const MAX_DOWNLOAD_ATTEMPTS: usize = 10;

/// Downloads `url` into `dest`, retrying failed attempts.
///
/// Makes up to [`MAX_DOWNLOAD_ATTEMPTS`] attempts, waiting [`RETRY_INTERVAL`]
/// between them (no wait after the final attempt). Each failure is logged
/// with its attempt number and error.
///
/// # Errors
///
/// Returns [`Error::DownloadFailed`] carrying `url` once every attempt has failed.
/// The individual attempt errors are only logged, not returned.
pub async fn download_file(url: &str, dest: &Path) -> Result<(), Error> {
    for attempt in 1..=MAX_DOWNLOAD_ATTEMPTS {
        let failure = match download_file_once(url, dest).await {
            Ok(()) => {
                info!(url = url, dest = ?dest, "downloaded file");
                return Ok(());
            }
            Err(failure) => failure,
        };
        warn!(
            url = url,
            attempt = attempt,
            error = ?failure,
            "download attempt failed"
        );
        let is_final_attempt = attempt == MAX_DOWNLOAD_ATTEMPTS;
        if !is_final_attempt {
            sleep(RETRY_INTERVAL).await;
        }
    }
    Err(Error::DownloadFailed(url.to_string()))
}
async fn download_file_once(url: &str, dest: &Path) -> Result<(), Error> {
let response = reqwest::get(url).await?;
if !response.status().is_success() {
return Err(Error::DownloadFailed(format!(
"HTTP {}: {}",
response.status(),
url
)));
}
let bytes = response.bytes().await?;
if let Some(parent) = dest.parent() {
tokio::fs::create_dir_all(parent).await?;
}
let mut file = File::create(dest).await?;
file.write_all(&bytes).await?;
file.flush().await?;
Ok(())
}