use anyhow::Result;
use semver::Version;
use std::fs;
use std::path::PathBuf;
use std::time::{Duration, SystemTime};
pub use freenet::transport::{
clear_version_mismatch, get_open_connection_count, has_version_mismatch,
version_mismatch_generation,
};
// NOTE(review): presumably the process exits with this code when an
// `UpdateNeededError` is raised so a supervisor can install the new binary — confirm.
/// Exit code reserved for "an update is needed".
pub const EXIT_CODE_UPDATE_NEEDED: i32 = 42;
/// Backoff between GitHub update checks when no valid backoff is persisted on disk.
const INITIAL_BACKOFF: Duration = Duration::from_secs(60);
/// Upper bound for the update-check backoff; doubling is capped at this value.
const MAX_BACKOFF: Duration = Duration::from_secs(3600);
/// Update checks are skipped once this many update failures have been recorded.
const MAX_UPDATE_FAILURES: u32 = 3;
/// GitHub REST `releases/latest` endpoint for freenet-core, queried for the newest tag.
const GITHUB_API_URL: &str = "https://api.github.com/repos/freenet/freenet-core/releases/latest";
/// Error signalling that a newer release exists and the process is exiting so
/// it can be auto-updated.
#[derive(Debug)]
pub struct UpdateNeededError {
    /// Version string of the newer release (presumably as reported by GitHub).
    pub new_version: String,
}

impl std::fmt::Display for UpdateNeededError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { new_version } = self;
        write!(
            f,
            "Update available: version {new_version} is available on GitHub. Exiting for auto-update."
        )
    }
}

impl std::error::Error for UpdateNeededError {}
/// Outcome of [`check_if_update_available`].
#[derive(PartialEq, Debug)]
pub enum UpdateCheckResult {
    /// Nothing to act on: the check was gated (too many failures or backoff not
    /// elapsed), it failed, or GitHub had nothing newer.
    Skipped,
    /// A strictly newer release exists; carries its version string with any
    /// leading `v` removed.
    UpdateAvailable(String),
}
/// Asks GitHub for the latest release and reports whether it is newer than
/// `current_version`.
///
/// The check is gated by on-disk state. It is skipped once
/// `MAX_UPDATE_FAILURES` failed update attempts have been recorded. It is also
/// skipped while the persisted backoff since the last check has not elapsed.
/// Every outcome other than "newer version found" grows the backoff. Finding a
/// newer version clears the failure counter and resets the backoff.
pub async fn check_if_update_available(current_version: &str) -> UpdateCheckResult {
    // Gate 1: stop checking after repeated failed update attempts.
    if !should_attempt_update() {
        tracing::debug!(
            failures = get_update_failure_count(),
            max = MAX_UPDATE_FAILURES,
            "Skipping update check - too many previous failures"
        );
        return UpdateCheckResult::Skipped;
    }

    // Gate 2: respect the backoff window since the previous check.
    let backoff = get_current_backoff();
    if !should_check_for_update(backoff) {
        tracing::debug!(
            backoff_secs = backoff.as_secs(),
            "Skipping update check - backoff not elapsed"
        );
        return UpdateCheckResult::Skipped;
    }

    // Stamp the attempt before the network call so a failed request still
    // counts toward the backoff window.
    record_check_time();

    let latest = match get_latest_version().await {
        Ok(tag) => tag,
        Err(err) => {
            tracing::warn!(
                "Failed to check GitHub for updates: {}. Will retry with increased backoff.",
                err
            );
            increase_backoff();
            return UpdateCheckResult::Skipped;
        }
    };

    let running = match Version::parse(current_version) {
        Ok(parsed) => parsed,
        Err(err) => {
            tracing::warn!(
                "Failed to parse current version '{}': {}",
                current_version,
                err
            );
            increase_backoff();
            return UpdateCheckResult::Skipped;
        }
    };

    let published = match Version::parse(&latest) {
        Ok(parsed) => parsed,
        Err(err) => {
            tracing::warn!("Failed to parse latest version '{}': {}", latest, err);
            increase_backoff();
            return UpdateCheckResult::Skipped;
        }
    };

    // Same or older on GitHub: nothing to do yet, back off further.
    if published <= running {
        tracing::debug!(
            current = %current_version,
            latest = %latest,
            backoff_secs = backoff.as_secs(),
            "No newer version on GitHub yet, will retry with increased backoff"
        );
        increase_backoff();
        return UpdateCheckResult::Skipped;
    }

    tracing::info!(
        current = %current_version,
        latest = %latest,
        "Newer version confirmed on GitHub"
    );
    clear_update_failures();
    reset_backoff();
    UpdateCheckResult::UpdateAvailable(latest)
}
/// Fetches the tag of the latest GitHub release and returns it with leading
/// `v` characters removed (e.g. `v0.1.75` -> `0.1.75`).
///
/// # Errors
/// Fails in any of these cases:
/// - the HTTP client cannot be built;
/// - the request fails or exceeds the 10 s timeout;
/// - GitHub returns a non-success status;
/// - the body cannot be deserialized into an object with a `tag_name` string.
async fn get_latest_version() -> Result<String> {
    #[derive(serde::Deserialize)]
    struct Release {
        tag_name: String,
    }

    let http = reqwest::Client::builder()
        .user_agent("freenet-updater")
        .timeout(Duration::from_secs(10))
        .build()?;

    let response = http.get(GITHUB_API_URL).send().await?;
    let status = response.status();
    if !status.is_success() {
        anyhow::bail!("GitHub API returned {}", status);
    }

    let tag = response.json::<Release>().await?.tag_name;
    Ok(tag.trim_start_matches('v').to_owned())
}
/// Directory that holds the updater's persisted state files
/// (`<home>/.local/state/freenet`).
///
/// Returns `None` when the home directory cannot be determined.
fn state_dir() -> Option<PathBuf> {
    let home = dirs::home_dir()?;
    Some(home.join(".local/state/freenet"))
}
/// Modification time of the `last_update_check` marker file, used as the time
/// of the most recent update check.
///
/// Returns `None` in any of these cases:
/// - the state directory is unknown;
/// - the marker is missing or unreadable;
/// - the platform cannot report modification times.
fn get_last_check_time() -> Option<SystemTime> {
    let metadata = fs::metadata(state_dir()?.join("last_update_check")).ok()?;
    metadata.modified().ok()
}
fn record_check_time() {
if let Some(dir) = state_dir() {
let _mkdir = fs::create_dir_all(&dir);
let marker = dir.join("last_update_check");
let _write = fs::write(&marker, "");
}
}
fn get_current_backoff() -> Duration {
let path = state_dir().map(|d| d.join("update_backoff_secs"));
path.and_then(|p| fs::read_to_string(p).ok())
.and_then(|s| s.trim().parse::<u64>().ok())
.map(Duration::from_secs)
.unwrap_or(INITIAL_BACKOFF)
}
fn increase_backoff() {
if let Some(dir) = state_dir() {
let _mkdir = fs::create_dir_all(&dir);
let current = get_current_backoff();
let new_backoff = std::cmp::min(current * 2, MAX_BACKOFF);
let _write = fs::write(
dir.join("update_backoff_secs"),
new_backoff.as_secs().to_string(),
);
}
}
/// Deletes the persisted backoff so the next read falls back to `INITIAL_BACKOFF`.
///
/// A missing file or removal error is ignored.
pub fn reset_backoff() {
    let Some(dir) = state_dir() else { return };
    let _ = fs::remove_file(dir.join("update_backoff_secs"));
}
/// Returns `true` when more than `backoff` has passed since the last recorded check.
///
/// Also returns `true` when no check has been recorded, or when the elapsed time
/// cannot be computed (e.g. the recorded time is in the future).
fn should_check_for_update(backoff: Duration) -> bool {
    let since_last = get_last_check_time().and_then(|stamp| stamp.elapsed().ok());
    match since_last {
        Some(elapsed) => elapsed > backoff,
        None => true,
    }
}
/// Number of failed update attempts recorded on disk.
///
/// Returns 0 when the state directory is unknown, or the counter file is
/// missing, unreadable, or not a valid `u32`.
fn get_update_failure_count() -> u32 {
    let Some(dir) = state_dir() else { return 0 };
    match fs::read_to_string(dir.join("update_failures")) {
        Ok(text) => text.trim().parse::<u32>().unwrap_or(0),
        Err(_) => 0,
    }
}
// NOTE(review): no caller is visible in this module; `dead_code` is allowed
// presumably because the only caller is compiled conditionally — confirm and
// name the actual caller here.
/// Increments the persisted count of failed update attempts.
///
/// Once the count reaches `MAX_UPDATE_FAILURES`, `should_attempt_update`
/// returns `false` and `check_if_update_available` skips further checks.
/// Best-effort: filesystem errors are ignored.
#[allow(dead_code)]
pub fn record_update_failure() {
    if let Some(dir) = state_dir() {
        let _mkdir = fs::create_dir_all(&dir);
        // `saturating_add`: the count is read from a file on disk. A corrupted
        // `u32::MAX` would otherwise panic in debug builds, or wrap to 0 in
        // release builds and silently re-enable updates.
        let count = get_update_failure_count().saturating_add(1);
        let _write = fs::write(dir.join("update_failures"), count.to_string());
    }
}
/// Deletes the persisted failure counter, so the failure count reads as 0 again.
///
/// A missing file or removal error is ignored.
pub fn clear_update_failures() {
    let Some(dir) = state_dir() else { return };
    let _ = fs::remove_file(dir.join("update_failures"));
}
/// Returns `true` while fewer than `MAX_UPDATE_FAILURES` update failures are recorded.
pub fn should_attempt_update() -> bool {
    let failures = get_update_failure_count();
    MAX_UPDATE_FAILURES > failures
}
/// Returns `true` once the persisted update-check backoff is at (or beyond) `MAX_BACKOFF`.
pub fn has_reached_max_backoff() -> bool {
    let backoff = get_current_backoff();
    MAX_BACKOFF <= backoff
}
pub async fn startup_update_check(current_version: &str) -> Option<String> {
startup_update_check_with_fetcher(current_version, get_latest_version).await
}
/// Testable core of `startup_update_check`: obtains the latest version from
/// `fetcher` and compares it with `current_version`.
///
/// A fetch error is logged at warn level and yields `None`, so startup
/// continues on the current binary.
pub(crate) async fn startup_update_check_with_fetcher<F, Fut>(
    current_version: &str,
    fetcher: F,
) -> Option<String>
where
    F: FnOnce() -> Fut,
    Fut: std::future::Future<Output = Result<String>>,
{
    match fetcher().await {
        Ok(latest) => compare_versions_for_startup(current_version, &latest),
        Err(e) => {
            tracing::warn!(
                "Startup update check: failed to fetch latest version: {}. \
                 Continuing with current binary.",
                e
            );
            None
        }
    }
}
/// Returns `Some(latest)` only when `latest` is a strictly newer semver version
/// than `current`, so a downgrade is never suggested.
///
/// If either string fails to parse as semver, a warning is logged and `None`
/// is returned (fail open: keep running the current binary). Pre-release
/// ordering follows semver, so `0.1.75-alpha < 0.1.75`.
pub(crate) fn compare_versions_for_startup(current: &str, latest: &str) -> Option<String> {
    let running = Version::parse(current)
        .inspect_err(|e| {
            tracing::warn!(
                "Startup update check: failed to parse current version '{}': {}",
                current,
                e
            );
        })
        .ok()?;
    let published = Version::parse(latest)
        .inspect_err(|e| {
            tracing::warn!(
                "Startup update check: failed to parse latest version '{}': {}",
                latest,
                e
            );
        })
        .ok()?;
    (published > running).then(|| latest.to_string())
}
#[cfg(test)]
mod tests {
    use super::*;
    // `version_mismatch_generation` is already in scope through `super::*`
    // (re-exported at the top of this module).
    use freenet::transport::{set_open_connection_count, signal_version_mismatch};
    use std::sync::{Mutex, MutexGuard};

    /// Serializes tests that touch transport state.
    ///
    /// The mismatch flag/generation and open-connection accessors take no
    /// handle, so they operate on shared (presumably process-wide) state. The
    /// test harness runs tests on parallel threads. Without this lock, one
    /// test's `signal_version_mismatch()` can land between another test's
    /// `clear_version_mismatch()` and its `assert!(!has_version_mismatch())`.
    // NOTE(review): this lock is private to this module; other test modules in
    // the crate that touch the same globals would need to share it.
    static TRANSPORT_GLOBALS: Mutex<()> = Mutex::new(());

    /// Acquires `TRANSPORT_GLOBALS`, ignoring poisoning so one failed test does
    /// not cascade into failures of every other test that takes the lock.
    fn lock_transport_globals() -> MutexGuard<'static, ()> {
        TRANSPORT_GLOBALS
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn test_version_mismatch_flag() {
        let _guard = lock_transport_globals();
        clear_version_mismatch();
        assert!(!has_version_mismatch());
        signal_version_mismatch();
        assert!(has_version_mismatch());
        clear_version_mismatch();
        assert!(!has_version_mismatch());
    }

    #[test]
    fn test_mismatch_generation_increments() {
        let _guard = lock_transport_globals();
        let gen_before = version_mismatch_generation();
        signal_version_mismatch();
        let gen_after = version_mismatch_generation();
        assert!(
            gen_after > gen_before,
            "generation should increment on each signal"
        );
        signal_version_mismatch();
        assert!(version_mismatch_generation() > gen_after);
    }

    #[test]
    fn test_open_connection_count() {
        let _guard = lock_transport_globals();
        set_open_connection_count(0);
        assert_eq!(get_open_connection_count(), 0);
        set_open_connection_count(5);
        assert_eq!(get_open_connection_count(), 5);
        set_open_connection_count(0);
        assert_eq!(get_open_connection_count(), 0);
    }

    #[test]
    fn test_update_needed_error_display() {
        let err = UpdateNeededError {
            new_version: "0.1.74".to_string(),
        };
        let msg = format!("{}", err);
        assert!(msg.contains("0.1.74"));
        assert!(msg.contains("auto-update"));
    }

    #[test]
    fn test_compare_versions_newer_available() {
        assert_eq!(
            compare_versions_for_startup("0.1.74", "0.1.75"),
            Some("0.1.75".to_string())
        );
        assert_eq!(
            compare_versions_for_startup("0.1.74", "0.2.0"),
            Some("0.2.0".to_string())
        );
        assert_eq!(
            compare_versions_for_startup("0.1.74", "1.0.0"),
            Some("1.0.0".to_string())
        );
    }

    #[test]
    fn test_compare_versions_already_current() {
        assert_eq!(compare_versions_for_startup("0.1.75", "0.1.75"), None);
    }

    #[test]
    fn test_compare_versions_never_downgrades() {
        assert_eq!(compare_versions_for_startup("0.2.0", "0.1.99"), None);
        assert_eq!(compare_versions_for_startup("1.0.0", "0.9.99"), None);
    }

    #[test]
    fn test_compare_versions_unparseable_fails_open() {
        assert_eq!(
            compare_versions_for_startup("not-a-version", "0.1.75"),
            None
        );
        assert_eq!(compare_versions_for_startup("0.1.74", "also-garbage"), None);
        assert_eq!(compare_versions_for_startup("", "0.1.75"), None);
    }

    #[test]
    fn test_compare_versions_prerelease_semver_semantics() {
        assert_eq!(
            compare_versions_for_startup("0.1.75-alpha", "0.1.75"),
            Some("0.1.75".to_string())
        );
        assert_eq!(compare_versions_for_startup("0.1.75", "0.1.75-alpha"), None);
    }

    #[tokio::test]
    async fn test_startup_check_fetcher_error_returns_none() {
        let result = startup_update_check_with_fetcher("0.1.74", || async {
            anyhow::bail!("simulated network failure")
        })
        .await;
        assert_eq!(result, None);
    }

    #[tokio::test]
    async fn test_startup_check_finds_newer_version() {
        let result =
            startup_update_check_with_fetcher("0.1.74", || async { Ok("0.1.75".to_string()) })
                .await;
        assert_eq!(result, Some("0.1.75".to_string()));
    }

    #[tokio::test]
    async fn test_startup_check_no_update_when_current() {
        let result =
            startup_update_check_with_fetcher("0.1.75", || async { Ok("0.1.75".to_string()) })
                .await;
        assert_eq!(result, None);
    }

    #[tokio::test]
    async fn test_startup_check_refuses_downgrade() {
        let result =
            startup_update_check_with_fetcher("0.2.0", || async { Ok("0.1.99".to_string()) }).await;
        assert_eq!(result, None);
    }

    #[test]
    fn test_backoff_constants() {
        assert_eq!(INITIAL_BACKOFF, Duration::from_secs(60));
        assert_eq!(MAX_BACKOFF, Duration::from_secs(3600));
        let mut backoff = INITIAL_BACKOFF;
        for _ in 0..6 {
            backoff = std::cmp::min(backoff * 2, MAX_BACKOFF);
        }
        assert_eq!(backoff, MAX_BACKOFF);
    }
}