use anyhow::Result;
use semver::Version;
use std::fs;
use std::path::PathBuf;
use std::time::{Duration, SystemTime};
pub use freenet::transport::{
clear_version_mismatch, get_open_connection_count, has_version_mismatch,
version_mismatch_generation,
};
/// Exit status used when the process exits because a newer release is available.
/// NOTE(review): presumably interpreted by whatever supervises/relaunches the node — confirm.
pub const EXIT_CODE_UPDATE_NEEDED: i32 = 42;
/// Backoff used when no backoff state has been persisted yet.
const INITIAL_BACKOFF: Duration = Duration::from_secs(60);
/// Upper bound (one hour) for the doubling backoff between GitHub checks.
const MAX_BACKOFF: Duration = Duration::from_secs(60 * 60);
/// Recorded update failures at which update checks stop being attempted.
const MAX_UPDATE_FAILURES: u32 = 3;
/// GitHub REST endpoint for the repository's latest release.
const GITHUB_API_URL: &str =
    "https://api.github.com/repos/freenet/freenet-core/releases/latest";
/// Error signalling that a newer release was found and the process is exiting
/// so the update can be applied.
#[derive(Debug)]
pub struct UpdateNeededError {
    /// Version string of the newer release.
    pub new_version: String,
}

impl std::fmt::Display for UpdateNeededError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { new_version } = self;
        write!(
            f,
            "Update available: version {} is available on GitHub. Exiting for auto-update.",
            new_version
        )
    }
}

impl std::error::Error for UpdateNeededError {}
/// Outcome of [`check_if_update_available`].
#[derive(Debug, PartialEq)]
pub enum UpdateCheckResult {
    /// No update is offered right now. The reasons are: too many recorded
    /// failures, backoff not yet elapsed, lookup/parse error, or no newer release.
    Skipped,
    /// A strictly newer release exists; carries its version (leading `v` stripped).
    UpdateAvailable(String),
}
/// Asks GitHub whether a release newer than `current_version` exists.
///
/// The check is rate-limited by state persisted under `state_dir()`:
/// * it is skipped once `MAX_UPDATE_FAILURES` update failures are recorded;
/// * it is skipped until more than the current backoff has passed since the
///   last recorded check.
///
/// Every attempted check records the check time before contacting GitHub.
/// Finding a newer version clears the failure counter and resets the backoff.
/// Every other outcome doubles the backoff: network/HTTP error, an unparsable
/// version, or no newer release.
///
/// Returns [`UpdateCheckResult::UpdateAvailable`] only when the latest release
/// is strictly greater (by semver) than `current_version`.
pub async fn check_if_update_available(current_version: &str) -> UpdateCheckResult {
    // Gate 1: stop trying after too many failed update attempts.
    if !should_attempt_update() {
        tracing::debug!(
            failures = get_update_failure_count(),
            max = MAX_UPDATE_FAILURES,
            "Skipping update check - too many previous failures"
        );
        return UpdateCheckResult::Skipped;
    }

    // Gate 2: honour the persisted backoff window.
    let backoff = get_current_backoff();
    if !should_check_for_update(backoff) {
        tracing::debug!(
            backoff_secs = backoff.as_secs(),
            "Skipping update check - backoff not elapsed"
        );
        return UpdateCheckResult::Skipped;
    }

    // Stamp the attempt up front so failed lookups are rate-limited too.
    record_check_time();

    let latest = match get_latest_version().await {
        Ok(version) => version,
        Err(e) => {
            tracing::warn!(
                "Failed to check GitHub for updates: {}. Will retry with increased backoff.",
                e
            );
            increase_backoff();
            return UpdateCheckResult::Skipped;
        }
    };

    let running = match Version::parse(current_version) {
        Ok(parsed) => parsed,
        Err(e) => {
            tracing::warn!(
                "Failed to parse current version '{}': {}",
                current_version,
                e
            );
            increase_backoff();
            return UpdateCheckResult::Skipped;
        }
    };

    let published = match Version::parse(&latest) {
        Ok(parsed) => parsed,
        Err(e) => {
            tracing::warn!("Failed to parse latest version '{}': {}", latest, e);
            increase_backoff();
            return UpdateCheckResult::Skipped;
        }
    };

    // `Version` is totally ordered, so "not newer" is exactly `<=`.
    if published <= running {
        tracing::debug!(
            current = %current_version,
            latest = %latest,
            backoff_secs = backoff.as_secs(),
            "No newer version on GitHub yet, will retry with increased backoff"
        );
        increase_backoff();
        return UpdateCheckResult::Skipped;
    }

    tracing::info!(
        current = %current_version,
        latest = %latest,
        "Newer version confirmed on GitHub"
    );
    // NOTE(review): failures are cleared every time a newer release is seen, so the
    // MAX_UPDATE_FAILURES gate only trips if `record_update_failure` runs several
    // times between successful checks — confirm this is intended.
    clear_update_failures();
    reset_backoff();
    UpdateCheckResult::UpdateAvailable(latest)
}
/// Fetches the tag of the latest GitHub release, with leading `v` characters stripped.
///
/// # Errors
/// Fails in any of these cases:
/// * the HTTP client cannot be built;
/// * the request fails or exceeds the 10 s timeout;
/// * GitHub answers with a non-success status;
/// * the JSON body lacks a string `tag_name`.
async fn get_latest_version() -> Result<String> {
    /// The only field of GitHub's release payload that we consume.
    #[derive(serde::Deserialize)]
    struct Release {
        tag_name: String,
    }

    let http = reqwest::Client::builder()
        .user_agent("freenet-updater")
        .timeout(Duration::from_secs(10))
        .build()?;
    let response = http.get(GITHUB_API_URL).send().await?;
    let status = response.status();
    if !status.is_success() {
        anyhow::bail!("GitHub API returned {}", status);
    }
    let release = response.json::<Release>().await?;
    let version = release.tag_name.trim_start_matches('v');
    Ok(version.to_owned())
}
/// Directory holding the updater's persisted state (`<home>/.local/state/freenet`).
///
/// Returns `None` when the home directory cannot be determined.
fn state_dir() -> Option<PathBuf> {
    let home = dirs::home_dir()?;
    Some(home.join(".local/state/freenet"))
}
/// Time of the last recorded update check, read from the marker file's mtime.
///
/// Returns `None` in any of these cases:
/// * there is no state dir;
/// * the marker file cannot be stat'ed;
/// * the platform cannot report modification times.
fn get_last_check_time() -> Option<SystemTime> {
    let marker = state_dir()?.join("last_update_check");
    match fs::metadata(marker) {
        Ok(meta) => meta.modified().ok(),
        Err(_) => None,
    }
}
fn record_check_time() {
if let Some(dir) = state_dir() {
let _mkdir = fs::create_dir_all(&dir);
let marker = dir.join("last_update_check");
let _write = fs::write(&marker, "");
}
}
fn get_current_backoff() -> Duration {
let path = state_dir().map(|d| d.join("update_backoff_secs"));
path.and_then(|p| fs::read_to_string(p).ok())
.and_then(|s| s.trim().parse::<u64>().ok())
.map(Duration::from_secs)
.unwrap_or(INITIAL_BACKOFF)
}
/// Doubles the persisted backoff, capped at `MAX_BACKOFF`.
///
/// Best-effort: if the state dir is unavailable or the write fails, the stored
/// backoff is left unchanged.
fn increase_backoff() {
    let Some(dir) = state_dir() else {
        return;
    };
    let _mkdir = fs::create_dir_all(&dir);
    // `saturating_mul` rather than `*`: `Duration * u32` panics on overflow, and the
    // current value is parsed from a file on disk that may hold any u64.
    let doubled = get_current_backoff().saturating_mul(2).min(MAX_BACKOFF);
    let _write = fs::write(
        dir.join("update_backoff_secs"),
        doubled.as_secs().to_string(),
    );
}
/// Resets the backoff by deleting its state file, so the initial backoff applies again.
///
/// Best-effort: a missing file or removal error is ignored.
pub fn reset_backoff() {
    let Some(dir) = state_dir() else {
        return;
    };
    let _rm = fs::remove_file(dir.join("update_backoff_secs"));
}
/// Whether strictly more than `backoff` has passed since the last recorded check.
///
/// Also returns `true` when there is no usable record. That covers a missing
/// marker, and a marker time in the future (where `SystemTime::elapsed` fails).
fn should_check_for_update(backoff: Duration) -> bool {
    match get_last_check_time().map(|last| last.elapsed()) {
        Some(Ok(since_last)) => since_last > backoff,
        Some(Err(_)) | None => true,
    }
}
/// Number of failed update attempts recorded in the state dir.
///
/// A missing state dir, a missing file, or unparsable contents all count as `0`.
fn get_update_failure_count() -> u32 {
    let Some(dir) = state_dir() else {
        return 0;
    };
    match fs::read_to_string(dir.join("update_failures")) {
        Ok(text) => text.trim().parse::<u32>().unwrap_or(0),
        Err(_) => 0,
    }
}
/// Increments the persisted count of failed update attempts.
///
/// Once the count reaches `MAX_UPDATE_FAILURES`, `should_attempt_update` returns
/// `false` and update checks are skipped. Best-effort: I/O errors are ignored.
// NOTE(review): `dead_code` is allowed because this appears unused in at least
// some builds — presumably it is invoked elsewhere after a failed update; confirm.
#[allow(dead_code)]
pub fn record_update_failure() {
    let Some(dir) = state_dir() else {
        return;
    };
    let _mkdir = fs::create_dir_all(&dir);
    // Saturate instead of overflowing. A corrupted counter of `u32::MAX` must not
    // panic in debug builds. In release builds it must not wrap to 0, which would
    // silently re-enable updates.
    let count = get_update_failure_count().saturating_add(1);
    let _write = fs::write(dir.join("update_failures"), count.to_string());
}
/// Forgets all recorded update failures by deleting the counter file.
///
/// Best-effort: a missing file or removal error is ignored.
pub fn clear_update_failures() {
    let Some(dir) = state_dir() else {
        return;
    };
    let _rm = fs::remove_file(dir.join("update_failures"));
}
/// Whether update checks are still allowed, i.e. fewer than
/// `MAX_UPDATE_FAILURES` failures have been recorded.
pub fn should_attempt_update() -> bool {
    let failures = get_update_failure_count();
    failures < MAX_UPDATE_FAILURES
}
/// Whether the persisted backoff has grown to (or beyond) `MAX_BACKOFF`.
pub fn has_reached_max_backoff() -> bool {
    let backoff = get_current_backoff();
    backoff >= MAX_BACKOFF
}
#[cfg(test)]
mod tests {
    use super::*;
    // `version_mismatch_generation` already arrives via `super::*` (re-exported above).
    use freenet::transport::{set_open_connection_count, signal_version_mismatch};
    use std::sync::{Mutex, MutexGuard};

    /// Serializes the tests that mutate the shared version-mismatch state.
    ///
    /// The harness runs tests on parallel threads. Without this lock,
    /// `test_mismatch_generation_increments` can call `signal_version_mismatch()`
    /// at a bad moment in `test_version_mismatch_flag`: after its
    /// `clear_version_mismatch()` but before its `!has_version_mismatch()`
    /// assertion. That makes the flag test flaky. Only tests in this module are
    /// covered by the lock.
    static MISMATCH_STATE: Mutex<()> = Mutex::new(());

    /// Acquires [`MISMATCH_STATE`], ignoring poisoning (the guarded `()` holds no invariant).
    fn lock_mismatch_state() -> MutexGuard<'static, ()> {
        MISMATCH_STATE
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn test_version_mismatch_flag() {
        let _guard = lock_mismatch_state();
        clear_version_mismatch();
        assert!(!has_version_mismatch());
        signal_version_mismatch();
        assert!(has_version_mismatch());
        clear_version_mismatch();
        assert!(!has_version_mismatch());
    }

    #[test]
    fn test_mismatch_generation_increments() {
        let _guard = lock_mismatch_state();
        let gen_before = version_mismatch_generation();
        signal_version_mismatch();
        let gen_after = version_mismatch_generation();
        assert!(
            gen_after > gen_before,
            "generation should increment on each signal"
        );
        signal_version_mismatch();
        assert!(version_mismatch_generation() > gen_after);
        // Don't leak a raised flag into other tests in this binary.
        clear_version_mismatch();
    }

    #[test]
    fn test_open_connection_count() {
        set_open_connection_count(0);
        assert_eq!(get_open_connection_count(), 0);
        set_open_connection_count(5);
        assert_eq!(get_open_connection_count(), 5);
        set_open_connection_count(0);
        assert_eq!(get_open_connection_count(), 0);
    }

    #[test]
    fn test_update_needed_error_display() {
        let err = UpdateNeededError {
            new_version: "0.1.74".to_string(),
        };
        let msg = format!("{}", err);
        assert!(msg.contains("0.1.74"));
        assert!(msg.contains("auto-update"));
    }

    #[test]
    fn test_backoff_constants() {
        assert_eq!(INITIAL_BACKOFF, Duration::from_secs(60));
        assert_eq!(MAX_BACKOFF, Duration::from_secs(3600));
        // Six doublings from 60 s (60 * 2^6 = 3840 s) must hit the one-hour cap.
        let mut backoff = INITIAL_BACKOFF;
        for _ in 0..6 {
            backoff = std::cmp::min(backoff * 2, MAX_BACKOFF);
        }
        assert_eq!(backoff, MAX_BACKOFF);
    }
}