use anyhow::{Context, Result, bail};
use semver::Version;
use std::path::{Path, PathBuf};
use tracing::{debug, info, warn};
/// Returns `true` if `identifier` is a safe GitHub owner/repo name.
///
/// Accepts 1..=100 characters drawn from ASCII alphanumerics plus `-`,
/// `_`, and `.`, and rejects names that begin or end with `.`/`-` or
/// that embed traversal-like sequences (`..`, `./`, `\`).
fn validate_repo_identifier(identifier: &str) -> bool {
    if identifier.is_empty() || identifier.len() > 100 {
        return false;
    }
    let charset_ok = identifier
        .chars()
        .all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '.'));
    if !charset_ok {
        return false;
    }
    let edges_ok = !(identifier.starts_with('.')
        || identifier.starts_with('-')
        || identifier.ends_with('.')
        || identifier.ends_with('-'));
    edges_ok
        && !identifier.contains("..")
        && !identifier.contains("./")
        && !identifier.contains('\\')
}
/// Validates that `path`, resolved against `base_dir`, stays inside
/// `base_dir`, returning the canonicalized result.
///
/// Rejects obviously unsafe inputs up front (`..`, leading `/` or `\`,
/// NUL bytes), then canonicalizes and confirms containment. When the
/// full path does not yet exist on disk, the parent is canonicalized
/// instead and the file name re-joined; if even the parent cannot be
/// canonicalized, falls back to purely lexical component validation.
///
/// # Errors
/// Fails when the path contains traversal patterns, `base_dir` cannot
/// be canonicalized, or the resolved path escapes `base_dir`.
fn validate_and_sanitize_path(path: &Path, base_dir: &Path) -> Result<PathBuf> {
    let path_str = path.to_string_lossy();
    // Cheap string-level screen before touching the filesystem.
    if path_str.contains("..")
        || path_str.starts_with('/')
        || path_str.starts_with('\\')
        || path_str.contains('\0')
    {
        bail!("Path contains unsafe traversal patterns: {path_str}");
    }
    let canonical_base = base_dir.canonicalize().with_context(|| {
        format!("Failed to canonicalize base directory: {}", base_dir.display())
    })?;
    let full_path = base_dir.join(path);
    let canonical_path = match full_path.canonicalize() {
        Ok(p) => p,
        Err(_) => {
            // Target may not exist yet (e.g. about to be written):
            // canonicalize the parent and re-attach the file name.
            if let Some(parent) = full_path.parent() {
                if let Some(filename) = full_path.file_name() {
                    match parent.canonicalize() {
                        Ok(canonical_parent) => canonical_parent.join(filename),
                        Err(_) => {
                            // Parent missing too: validate lexically only.
                            return validate_path_components(&full_path, &canonical_base);
                        }
                    }
                } else {
                    bail!("Invalid path structure: {}", full_path.display());
                }
            } else {
                bail!("Invalid path: {}", full_path.display());
            }
        }
    };
    // Containment check; symlinks were resolved by canonicalize above.
    if !canonical_path.starts_with(&canonical_base) {
        bail!(
            "Path traversal detected: {} is outside base directory {}",
            canonical_path.display(),
            canonical_base.display()
        );
    }
    Ok(canonical_path)
}
/// Rebuilds `path` on top of `base_dir`, accepting only plain (`Normal`)
/// components so the result can never escape `base_dir`.
///
/// `.` components are silently dropped; `..`, root, and prefix
/// components are rejected outright.
///
/// # Errors
/// Fails on NUL bytes, parent-directory components, or any component
/// that would re-anchor the path (root / drive prefix).
fn validate_path_components(path: &Path, base_dir: &Path) -> Result<PathBuf> {
    use std::path::Component;
    let mut safe = base_dir.to_path_buf();
    for part in path.components() {
        match part {
            Component::Normal(segment) => {
                let text = segment.to_string_lossy();
                if text.contains('\0') || text == "." || text == ".." {
                    bail!("Invalid path component: {text}");
                }
                safe.push(segment);
            }
            // A lone `.` is a no-op and safe to ignore.
            Component::CurDir => {}
            Component::ParentDir => bail!("Parent directory traversal not allowed"),
            // RootDir / Prefix would re-anchor the path outside base_dir.
            _ => bail!("Absolute path components not allowed in extraction"),
        }
    }
    Ok(safe)
}
/// How checksum verification of a downloaded archive is handled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ChecksumPolicy {
    /// Fail the download if no checksum is available or verification fails.
    Required,
    /// Verify when a checksum is available; warn and continue on failure.
    #[default]
    WarnOnFailure,
    /// Never attempt checksum verification.
    Skip,
}
/// Downloads and installs newer releases of the `agpm` binary from GitHub.
pub struct SelfUpdater {
    /// GitHub account that owns the release repository.
    repo_owner: String,
    /// Repository name the release assets are fetched from.
    repo_name: String,
    /// Version of the running binary, compared against release tags.
    current_version: String,
    /// When `true`, install the target version even if it is not newer.
    force: bool,
    /// How checksum verification failures are treated during download.
    checksum_policy: ChecksumPolicy,
}
impl Default for SelfUpdater {
    /// Default updater: the upstream `aig787/agpm` repository, the
    /// compiled-in crate version, no forcing, default checksum policy.
    fn default() -> Self {
        let repo_owner = String::from("aig787");
        let repo_name = String::from("agpm");
        // The identifiers are compile-time constants; these asserts only
        // guard against someone editing them into an invalid form.
        debug_assert!(validate_repo_identifier(&repo_owner), "Default repo_owner must be valid");
        debug_assert!(validate_repo_identifier(&repo_name), "Default repo_name must be valid");
        Self {
            current_version: env!("CARGO_PKG_VERSION").to_string(),
            force: false,
            checksum_policy: ChecksumPolicy::default(),
            repo_owner,
            repo_name,
        }
    }
}
impl SelfUpdater {
/// Creates an updater targeting the default `aig787/agpm` repository.
pub fn new() -> Self {
    Self::default()
}
/// Creates an updater for a custom GitHub `owner`/`name` repository.
///
/// # Errors
/// Fails when either identifier is rejected by `validate_repo_identifier`.
pub fn with_repo(repo_owner: &str, repo_name: &str) -> Result<Self> {
    // Owner is checked first, matching the original guard order.
    match (validate_repo_identifier(repo_owner), validate_repo_identifier(repo_name)) {
        (false, _) => bail!("Invalid repository owner: {repo_owner}"),
        (_, false) => bail!("Invalid repository name: {repo_name}"),
        (true, true) => Ok(Self {
            repo_owner: repo_owner.to_owned(),
            repo_name: repo_name.to_owned(),
            current_version: env!("CARGO_PKG_VERSION").to_owned(),
            force: false,
            checksum_policy: ChecksumPolicy::default(),
        }),
    }
}
/// Sets whether to reinstall even when the target version is not newer.
pub const fn force(mut self, force: bool) -> Self {
    self.force = force;
    self
}
/// Sets how checksum verification failures are handled during download.
pub const fn checksum_policy(mut self, policy: ChecksumPolicy) -> Self {
    self.checksum_policy = policy;
    self
}
/// Returns the version of the currently running binary.
pub fn current_version(&self) -> &str {
    &self.current_version
}
fn build_github_api_url(&self, endpoint: &str) -> String {
debug_assert!(
validate_repo_identifier(&self.repo_owner),
"Repository owner should be validated: {}",
self.repo_owner
);
debug_assert!(
validate_repo_identifier(&self.repo_name),
"Repository name should be validated: {}",
self.repo_name
);
format!("https://api.github.com/repos/{}/{}/{}", self.repo_owner, self.repo_name, endpoint)
}
fn build_github_download_url(&self, version: &str, filename: &str) -> String {
debug_assert!(
validate_repo_identifier(&self.repo_owner),
"Repository owner should be validated: {}",
self.repo_owner
);
debug_assert!(
validate_repo_identifier(&self.repo_name),
"Repository name should be validated: {}",
self.repo_name
);
format!(
"https://github.com/{}/{}/releases/download/v{}/{}",
self.repo_owner, self.repo_name, version, filename
)
}
/// Queries GitHub for the latest release and compares it (semver) with
/// the running version.
///
/// Returns `Ok(Some(version))` when a strictly newer release exists,
/// `Ok(None)` when already up to date or when the repository has no
/// releases (HTTP 404).
///
/// # Errors
/// Fails on network errors, non-404 API errors, a release missing
/// `tag_name`, or unparsable version strings.
pub async fn check_for_update(&self) -> Result<Option<String>> {
    debug!("Checking for updates from {}/{}", self.repo_owner, self.repo_name);
    let url = self.build_github_api_url("releases/latest");
    let client = reqwest::Client::new();
    let response = client
        .get(&url)
        .header("User-Agent", "agpm")
        .send()
        .await
        .context("Failed to fetch release information")?;
    if !response.status().is_success() {
        if response.status() == 404 {
            // No releases published yet: not an error, just nothing newer.
            warn!("No releases found");
            return Ok(None);
        }
        bail!("GitHub API error: {}", response.status());
    }
    let release: serde_json::Value = response.json().await?;
    // Tags are published as `vX.Y.Z`; strip the prefix for semver parsing.
    let latest_version = release["tag_name"]
        .as_str()
        .ok_or_else(|| anyhow::anyhow!("Release missing tag_name"))?
        .trim_start_matches('v');
    debug!("Latest version: {}", latest_version);
    let current =
        Version::parse(&self.current_version).context("Failed to parse current version")?;
    let latest = Version::parse(latest_version).context("Failed to parse latest version")?;
    if latest > current {
        info!("Update available: {} -> {}", self.current_version, latest_version);
        Ok(Some(latest_version.to_string()))
    } else {
        debug!("Already on latest version");
        Ok(None)
    }
}
/// Downloads and installs `target_version` (or the latest release when
/// `None`), replacing the running executable.
///
/// Returns `Ok(true)` when a new binary was installed, `Ok(false)` when
/// the current version already satisfies the target and `force` is off.
///
/// # Errors
/// Fails on network/API errors, unparsable versions, unsupported
/// platforms, or any download/extract/replace failure.
pub async fn update(&self, target_version: Option<&str>) -> Result<bool> {
    info!("Starting self-update process");
    // Resolve the target version: explicit argument wins, otherwise ask
    // GitHub for the latest release tag.
    let target_version = if let Some(v) = target_version {
        v.trim_start_matches('v').to_string()
    } else {
        let url = self.build_github_api_url("releases/latest");
        let client = reqwest::Client::new();
        let response = client
            .get(&url)
            .header("User-Agent", "agpm")
            .send()
            .await
            .context("Failed to fetch release information")?;
        if !response.status().is_success() {
            bail!("Failed to get latest release: HTTP {}", response.status());
        }
        let release: serde_json::Value = response.json().await?;
        release["tag_name"]
            .as_str()
            .ok_or_else(|| anyhow::anyhow!("Release missing tag_name"))?
            .trim_start_matches('v')
            .to_string()
    };
    let current = Version::parse(&self.current_version)?;
    let target = Version::parse(&target_version)?;
    // Unless forced, never downgrade or reinstall the same version.
    if !self.force && current >= target {
        info!("Already on version {} (target: {})", current, target);
        return Ok(false);
    }
    let archive_url = self.get_archive_url(&target_version)?;
    info!("Downloading from {}", archive_url);
    // Work inside a tempdir so partial downloads never touch the real
    // binary; the dir is removed when `temp_dir` drops.
    let temp_dir = tempfile::tempdir()?;
    let archive_path = temp_dir.path().join("agpm-archive");
    self.download_file(&archive_url, &archive_path).await?;
    let extracted_binary = self.extract_binary(&archive_path, temp_dir.path()).await?;
    self.replace_binary(&extracted_binary).await?;
    info!("Successfully updated to version {}", target_version);
    Ok(true)
}
/// Resolves the release-archive URL for `version` on the current
/// OS/architecture.
///
/// # Errors
/// Fails for platform combinations without published binaries.
fn get_archive_url(&self, version: &str) -> Result<String> {
    use std::env::consts::{ARCH, OS};
    let platform = match (OS, ARCH) {
        ("macos", "aarch64") => "aarch64-apple-darwin",
        ("macos", "x86_64") => "x86_64-apple-darwin",
        ("linux", "aarch64") => "aarch64-unknown-linux-gnu",
        ("linux", "x86_64") => "x86_64-unknown-linux-gnu",
        ("windows", "x86_64") => "x86_64-pc-windows-msvc",
        ("windows", "aarch64") => "aarch64-pc-windows-msvc",
        (os, arch) => bail!("Unsupported platform: {os}-{arch}"),
    };
    // Windows releases ship as zip; everything else as tar.xz.
    let extension = match OS {
        "windows" => "zip",
        _ => "tar.xz",
    };
    let filename = format!("agpm-cli-{platform}.{extension}");
    Ok(self.build_github_download_url(version, &filename))
}
/// Downloads `url` to `dest` with retries, a size cap, and checksum
/// handling according to `self.checksum_policy`.
///
/// Server errors (5xx) and transport failures are retried with
/// exponential backoff (1s, 2s, 4s); responses advertising more than
/// 100 MB are rejected before the body is read.
///
/// # Errors
/// Fails on persistent HTTP/transport errors, oversized archives, file
/// I/O errors, or (under `ChecksumPolicy::Required`) a missing or
/// failed checksum.
async fn download_file(&self, url: &str, dest: &std::path::Path) -> Result<()> {
    use tokio::io::AsyncWriteExt;
    let client = reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(300))
        .build()?;
    let mut retries = 3;
    let mut delay = std::time::Duration::from_secs(1);
    loop {
        match client.get(url).header("User-Agent", "agpm").send().await {
            Ok(response) => {
                if !response.status().is_success() {
                    // Only server-side errors are worth retrying.
                    if retries > 0 && response.status().is_server_error() {
                        warn!("Server error {}, retrying in {:?}...", response.status(), delay);
                        tokio::time::sleep(delay).await;
                        delay *= 2;
                        retries -= 1;
                        continue;
                    }
                    bail!("Failed to download: HTTP {}", response.status());
                }
                // Reject oversized archives before buffering the body.
                if let Some(content_length) = response.content_length()
                    && content_length > 100 * 1024 * 1024
                {
                    bail!("Archive too large: {content_length} bytes (max 100MB)");
                }
                let bytes = response.bytes().await?;
                let mut file = tokio::fs::File::create(dest).await?;
                file.write_all(&bytes).await?;
                // Flush to disk so later readers see a complete file.
                file.sync_all().await?;
                match self.checksum_policy {
                    ChecksumPolicy::Required => {
                        if let Some(checksum_url) = self.get_checksum_url(url) {
                            self.verify_checksum(&checksum_url, dest, &bytes).await?;
                        } else {
                            bail!(
                                "Checksum verification required but no checksum available for URL: {url}"
                            );
                        }
                    }
                    ChecksumPolicy::WarnOnFailure => {
                        // Best-effort: log failures but keep the download.
                        if let Some(checksum_url) = self.get_checksum_url(url) {
                            if let Err(e) =
                                self.verify_checksum(&checksum_url, dest, &bytes).await
                            {
                                warn!("Checksum verification failed, but continuing: {}", e);
                            }
                        } else {
                            warn!("No checksum available for verification: {}", url);
                        }
                    }
                    ChecksumPolicy::Skip => {
                        debug!("Skipping checksum verification as configured");
                    }
                }
                return Ok(());
            }
            Err(e) if retries > 0 => {
                warn!("Download failed: {}, retrying in {:?}...", e, delay);
                tokio::time::sleep(delay).await;
                delay *= 2;
                retries -= 1;
            }
            Err(e) => bail!("Failed to download after retries: {e}"),
        }
    }
}
/// Derives the companion `.sha256` URL for a GitHub download, if one
/// applies (non-GitHub URLs and checksum files themselves get `None`).
fn get_checksum_url(&self, url: &str) -> Option<String> {
    let eligible = url.contains("github.com") && !url.ends_with(".sha256");
    eligible.then(|| format!("{url}.sha256"))
}
/// Downloads the checksum file at `checksum_url` and verifies `content`
/// against it, deleting `file_path` on mismatch.
///
/// The checksum file's first whitespace-separated token must be a
/// 64-character hex SHA256 digest (case-insensitive).
///
/// # Errors
/// Fails when the checksum cannot be fetched, is malformed, or does not
/// match the SHA256 of `content`.
async fn verify_checksum(
    &self,
    checksum_url: &str,
    file_path: &std::path::Path,
    content: &[u8],
) -> Result<()> {
    use sha2::{Digest, Sha256};
    let client =
        reqwest::Client::builder().timeout(std::time::Duration::from_secs(30)).build()?;
    let response = client
        .get(checksum_url)
        .header("User-Agent", "agpm")
        .send()
        .await
        .context("Failed to download checksum file")?;
    if !response.status().is_success() {
        bail!("Failed to download checksum: HTTP {}", response.status());
    }
    let checksum_text =
        response.text().await.context("Failed to read checksum file content")?;
    // `sha256sum`-style output is "<digest>  <filename>"; take the first token.
    let expected_checksum = checksum_text
        .split_whitespace()
        .next()
        .ok_or_else(|| anyhow::anyhow!("Invalid checksum format: empty file"))?;
    if expected_checksum.len() != 64
        || !expected_checksum.chars().all(|c| c.is_ascii_hexdigit())
    {
        bail!("Invalid SHA256 checksum format: {expected_checksum}");
    }
    let mut hasher = Sha256::new();
    hasher.update(content);
    let actual_checksum = format!("{:x}", hasher.finalize());
    if expected_checksum.to_lowercase() != actual_checksum {
        // Remove the suspect download so it cannot be used accidentally.
        let _ = tokio::fs::remove_file(file_path).await;
        bail!(
            "Checksum verification failed! Expected: {expected_checksum}, Got: {actual_checksum}. File may be corrupted or tampered with."
        );
    }
    info!("Checksum verified successfully (SHA256: {})", &actual_checksum[..16]);
    Ok(())
}
/// Extracts the `agpm` binary from the downloaded archive into
/// `temp_dir` and returns its path.
///
/// Zip archives are unpacked in-process with per-entry path validation
/// and size limits; other archives are delegated to the system `tar`,
/// after which the extracted tree is scanned (and re-validated) for the
/// binary. Suspicious entries are skipped with a warning rather than
/// aborting the whole extraction.
///
/// # Errors
/// Fails when the archive cannot be read or extracted, exceeds size
/// limits, or no binary is found afterwards.
async fn extract_binary(
    &self,
    archive_path: &std::path::Path,
    temp_dir: &std::path::Path,
) -> Result<std::path::PathBuf> {
    let binary_name = if std::env::consts::OS == "windows" {
        "agpm.exe"
    } else {
        "agpm"
    };
    if archive_path.to_string_lossy().ends_with(".zip") {
        let archive_data = tokio::fs::read(archive_path).await?;
        let cursor = std::io::Cursor::new(archive_data);
        let mut archive = zip::ZipArchive::new(cursor)?;
        // Zip-bomb guard: cap the declared uncompressed total at 500 MB.
        let total_size: u64 = (0..archive.len())
            .map(|i| archive.by_index(i).map(|f| f.size()).unwrap_or(0))
            .sum();
        if total_size > 500 * 1024 * 1024 {
            bail!("Archive uncompressed size too large: {total_size} bytes");
        }
        for i in 0..archive.len() {
            let file = archive.by_index(i)?;
            let file_name = file.name();
            // NOTE(review): `ends_with` on the raw name also matches e.g.
            // "libagpm" — presumably release archives never contain such
            // entries; confirm against the release layout.
            if file_name.ends_with(&binary_name) {
                // Zip-slip guard: entry path must resolve inside temp_dir.
                let file_path = Path::new(file_name);
                if let Err(e) = validate_and_sanitize_path(file_path, temp_dir) {
                    warn!("Skipping malicious path {}: {}", file_name, e);
                    continue;
                }
                let path_components: Vec<&str> = file_name.split(&['/', '\\'][..]).collect();
                if path_components.len() > 3 {
                    warn!("Binary nested too deep in archive: {}", file_name);
                    continue;
                }
                let output_path = temp_dir.join(binary_name);
                use std::io::Read;
                let mut content = Vec::new();
                // Read at most 100 MB; hitting the cap means the entry's
                // declared size was untrustworthy.
                let size = file
                    .take(100 * 1024 * 1024)
                    .read_to_end(&mut content)?;
                if size >= 100 * 1024 * 1024 {
                    bail!("Binary file too large in archive");
                }
                tokio::fs::write(&output_path, content).await?;
                return Ok(output_path);
            }
        }
        bail!("Binary not found in archive");
    }
    // Non-zip path: delegate extraction to the platform `tar` binary.
    let output = tokio::process::Command::new("tar")
        .args(["-xf", &archive_path.to_string_lossy(), "-C", &temp_dir.to_string_lossy()])
        .output()
        .await?;
    if !output.status.success() {
        bail!("Failed to extract archive: {}", String::from_utf8_lossy(&output.stderr));
    }
    let mut entries = tokio::fs::read_dir(temp_dir).await?;
    while let Some(entry) = entries.next_entry().await? {
        let path = entry.path();
        let relative_path = if let Ok(rel) = path.strip_prefix(temp_dir) {
            rel
        } else {
            warn!("Skipping path outside temp directory: {:?}", path);
            continue;
        };
        // Re-validate what tar produced before trusting it.
        match validate_and_sanitize_path(relative_path, temp_dir) {
            Ok(validated_path) => {
                if validated_path != path {
                    warn!("Path validation mismatch, skipping: {:?}", path);
                    continue;
                }
            }
            Err(e) => {
                warn!("Skipping invalid path {:?}: {}", path, e);
                continue;
            }
        }
        if path.is_dir() {
            // Also look one level down for the binary inside extracted
            // directories.
            let binary_path = path.join(binary_name);
            if let Ok(metadata) = tokio::fs::metadata(&binary_path).await {
                let relative_binary_path = match binary_path.strip_prefix(temp_dir) {
                    Ok(rel) => rel,
                    Err(_) => continue,
                };
                match validate_and_sanitize_path(relative_binary_path, temp_dir) {
                    Ok(_) => {
                        if metadata.is_file() && metadata.len() < 100 * 1024 * 1024 {
                            return Ok(binary_path);
                        }
                    }
                    Err(e) => {
                        warn!("Invalid binary path {:?}: {}", binary_path, e);
                        continue;
                    }
                }
            }
        }
        if path.file_name() == Some(std::ffi::OsStr::new(binary_name))
            && let Ok(metadata) = tokio::fs::metadata(&path).await
            && metadata.is_file()
            && metadata.len() < 100 * 1024 * 1024
        {
            return Ok(path);
        }
    }
    bail!("Binary not found after extraction")
}
/// Replaces the currently running executable with `new_binary`.
///
/// Tries a plain rename first. If that fails — most commonly with EXDEV
/// because the temp dir and the installed binary live on different
/// filesystems, a case a bare retry can never fix — it falls back to
/// moving the running executable aside (same directory, so same
/// filesystem) and copying the new one into place. Moving the running
/// image aside first also avoids ETXTBSY from writing over a busy
/// binary on Linux. Transient failures are retried up to three times.
///
/// # Errors
/// Fails when the current executable path cannot be determined,
/// permissions cannot be set, or every replacement strategy fails
/// after retries.
async fn replace_binary(&self, new_binary: &std::path::Path) -> Result<()> {
    let current_exe = std::env::current_exe()?;
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        // Archive extraction does not always preserve the execute bit.
        let mut perms = tokio::fs::metadata(&new_binary).await?.permissions();
        perms.set_mode(0o755);
        tokio::fs::set_permissions(&new_binary, perms).await?;
    }
    let mut retries = 3;
    loop {
        match tokio::fs::rename(&new_binary, &current_exe).await {
            Ok(()) => return Ok(()),
            Err(rename_err) => {
                // Fallback: move the old binary aside, then copy the new
                // one over (copy works across filesystems; the set 0o755
                // mode is carried over from the source file).
                let backup = current_exe.with_extension("old");
                let fallback: Result<()> = async {
                    tokio::fs::rename(&current_exe, &backup).await?;
                    if let Err(copy_err) = tokio::fs::copy(&new_binary, &current_exe).await {
                        // Best-effort restore of the original binary.
                        let _ = tokio::fs::rename(&backup, &current_exe).await;
                        return Err(copy_err.into());
                    }
                    // A stale backup is harmless if removal fails (e.g.
                    // Windows keeps the running image locked).
                    let _ = tokio::fs::remove_file(&backup).await;
                    Ok(())
                }
                .await;
                match fallback {
                    Ok(()) => return Ok(()),
                    Err(e) if retries > 1 => {
                        warn!(
                            "Failed to replace binary (rename: {rename_err}; fallback: {e}), retrying"
                        );
                        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
                        retries -= 1;
                    }
                    Err(e) => bail!("Failed to replace binary: {e}"),
                }
            }
        }
    }
}
/// Convenience wrapper: update to the latest published release.
pub async fn update_to_latest(&self) -> Result<bool> {
    self.update(None).await
}
/// Convenience wrapper: update to a specific `version` (with or without
/// a leading `v`).
pub async fn update_to_version(&self, version: &str) -> Result<bool> {
    self.update(Some(version)).await
}
}