#![allow(clippy::too_many_arguments)]
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
pub mod install_state;
pub mod worker;
/// Best-effort lookup of the latest published provider version on npm.
///
/// Returns `Some(latest)` only when the registry answers within 2 s with a
/// success status and the reported version differs from `current_version`;
/// any failure along the way yields `None`.
pub async fn check_for_update(current_version: &str) -> Option<String> {
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(2))
        .build()
        .ok()?;
    let resp = client
        .get("https://registry.npmjs.org/@openlatch%2Fprovider/latest")
        .header("Accept", "application/json")
        .send()
        .await
        .ok()?;
    if !resp.status().is_success() {
        return None;
    }
    let body: serde_json::Value = resp.json().await.ok()?;
    let latest = body.get("version")?.as_str()?;
    if latest == current_version {
        None
    } else {
        Some(latest.to_owned())
    }
}
/// Bundled trusted-keys file (one base64 minisign public key per line,
/// `#` lines are comments), compiled into the binary at build time.
const TRUSTED_KEYS_FILE: &str = include_str!("../../signing/openlatch-provider.pub");
/// Trusted minisign public keys used to verify staged update binaries.
///
/// In test builds (or with the `insecure-test-keys` feature) the
/// `OPENLATCH_PROVIDER_TRUSTED_KEYS` env var may inject comma-separated
/// override keys; otherwise the compiled-in key file is parsed.
pub fn trusted_keys() -> Vec<String> {
    #[cfg(any(test, feature = "insecure-test-keys"))]
    {
        if let Ok(raw) = std::env::var("OPENLATCH_PROVIDER_TRUSTED_KEYS") {
            // Only a non-blank override takes effect; note that a value of
            // separators only (e.g. ",") still wins and yields zero keys,
            // matching the original behaviour.
            if !raw.trim().is_empty() {
                return raw
                    .split(',')
                    .map(str::trim)
                    .filter(|k| !k.is_empty())
                    .map(String::from)
                    .collect();
            }
        }
    }
    parse_trusted_keys(TRUSTED_KEYS_FILE)
}
/// Parse the trusted-keys file: each non-blank, non-`#` line is one key
/// (whitespace-trimmed). At most three keys are honoured; any further
/// lines are silently dropped.
fn parse_trusted_keys(input: &str) -> Vec<String> {
    let mut keys = Vec::with_capacity(3);
    for raw in input.lines() {
        let line = raw.trim();
        if line.is_empty() || line.starts_with('#') {
            continue;
        }
        keys.push(line.to_string());
        if keys.len() == 3 {
            break;
        }
    }
    keys
}
/// Failure modes for minisign signature verification of a staged binary.
#[derive(Debug, thiserror::Error)]
pub enum VerifyError {
    /// The signature file or a configured trusted key could not be parsed,
    /// or no trusted keys are configured at all.
    #[error("malformed minisign artefact: {0}")]
    Malformed(String),
    /// Every configured key was tried; none verified the signature.
    #[error("no trusted public key matched the signature")]
    NoTrustedKeyMatched,
    /// Reading the binary or the signature file failed.
    #[error("io error reading signed binary or signature: {0}")]
    Io(#[from] std::io::Error),
}
/// Verify `binary_path` against the detached minisign signature at
/// `sig_path` using any of the trusted keys.
///
/// Returns the index (within [`trusted_keys`]) of the key that matched.
///
/// # Errors
/// * [`VerifyError::Malformed`] — unparsable signature, an invalid trusted
///   key, or an empty trusted-key set.
/// * [`VerifyError::NoTrustedKeyMatched`] — every configured key failed.
/// * [`VerifyError::Io`] — reading either input file failed.
pub fn verify_with_any_trusted_key(
    binary_path: &Path,
    sig_path: &Path,
) -> Result<usize, VerifyError> {
    use minisign_verify::{PublicKey, Signature};
    let signature = Signature::decode(&std::fs::read_to_string(sig_path)?)
        .map_err(|e| VerifyError::Malformed(format!("decode signature: {e}")))?;
    let binary = std::fs::read(binary_path)?;
    let keys = trusted_keys();
    if keys.is_empty() {
        return Err(VerifyError::Malformed(
            "no trusted keys configured (signing/openlatch-provider.pub is empty)".into(),
        ));
    }
    for (idx, key_b64) in keys.iter().enumerate() {
        // An unparsable trusted key is a build/config defect: fail hard
        // rather than silently skipping it.
        let pk = PublicKey::from_base64(key_b64)
            .map_err(|e| VerifyError::Malformed(format!("trusted key #{idx} is invalid: {e}")))?;
        if pk.verify(&binary, &signature, true).is_ok() {
            return Ok(idx);
        }
    }
    Err(VerifyError::NoTrustedKeyMatched)
}
/// Failure modes for the post-extract `--version` smoke test.
#[derive(Debug, thiserror::Error)]
pub enum SanityError {
    /// Spawning the staged binary failed (missing, not executable, ...).
    #[error("staging binary failed to execute: {0}")]
    ExecFailed(#[from] std::io::Error),
    /// The staged binary ran but exited with a non-zero status.
    #[error("staging binary exited non-zero ({0})")]
    NonZeroExit(i32),
    /// `--version` output did not mention the expected version string.
    #[error("staging binary's --version output ({stdout:?}) did not contain expected version ({expected:?})")]
    VersionMismatch { stdout: String, expected: String },
}
/// Execute the staged binary with `--version` and confirm its stdout
/// mentions `expected_version`.
///
/// # Errors
/// [`SanityError::ExecFailed`] when the binary cannot be run,
/// [`SanityError::NonZeroExit`] on a non-zero exit status (signal-killed
/// processes report `-1`), [`SanityError::VersionMismatch`] otherwise.
pub fn sanity_check_version(
    staging_binary: &Path,
    expected_version: &str,
) -> Result<(), SanityError> {
    let output = std::process::Command::new(staging_binary)
        .arg("--version")
        .output()?;
    if !output.status.success() {
        let code = output.status.code().unwrap_or(-1);
        return Err(SanityError::NonZeroExit(code));
    }
    let stdout = String::from_utf8_lossy(&output.stdout).into_owned();
    if stdout.contains(expected_version) {
        Ok(())
    } else {
        Err(SanityError::VersionMismatch {
            stdout,
            expected: expected_version.to_owned(),
        })
    }
}
/// Failure modes for the in-place executable swap.
#[derive(Debug, thiserror::Error)]
pub enum SwapError {
    /// `std::env::current_exe()` failed, so there is nothing to replace.
    #[error("could not resolve current executable: {0}")]
    CurrentExe(#[source] std::io::Error),
    /// Snapshotting or self-replacing the running binary failed.
    #[error("self-replace of running daemon binary failed: {0}")]
    SelfReplace(String),
}
/// Paths recorded by a successful [`perform_swap`], enabling rollback via
/// [`restore_from_bak`].
#[derive(Debug, Clone)]
pub struct SwapHandle {
    /// Path of the (now-replaced) running executable.
    pub current_exe: PathBuf,
    /// Snapshot of the pre-swap executable (`<exe>.bak`).
    pub current_exe_bak: PathBuf,
}
/// Replace the running executable with `staging_exe`.
///
/// First snapshots the current executable to `<exe>.bak` so a failed update
/// can be rolled back (see [`restore_from_bak`] / [`rollback_from_bak`]),
/// then performs the in-place self-replace. If the self-replace fails, the
/// freshly written snapshot is removed again (best-effort).
///
/// # Errors
/// [`SwapError::CurrentExe`] when the running executable cannot be resolved;
/// [`SwapError::SelfReplace`] when the snapshot copy or the replace fails.
pub fn perform_swap(staging_exe: &Path) -> Result<SwapHandle, SwapError> {
    let current_exe = std::env::current_exe().map_err(SwapError::CurrentExe)?;
    let current_exe_bak = current_exe.with_extension("bak");
    // FIX: the previous source contained mojibake (`¤t_exe`) — the HTML
    // entity `&curren;` had eaten the `&curr` of `&current_exe`, leaving the
    // file non-compiling. Restored the intended borrows here and below.
    if let Err(e) = std::fs::copy(&current_exe, &current_exe_bak) {
        return Err(SwapError::SelfReplace(format!(
            "snapshot current_exe to {}: {e}",
            current_exe_bak.display()
        )));
    }
    if let Err(e) = self_replace::self_replace(staging_exe) {
        // Swap failed: the snapshot is useless now; removal is best-effort.
        let _ = std::fs::remove_file(&current_exe_bak);
        return Err(SwapError::SelfReplace(format!("{e}")));
    }
    Ok(SwapHandle {
        current_exe,
        current_exe_bak,
    })
}
/// Roll the running executable back to the `.bak` snapshot recorded in
/// `handle`, if that snapshot still exists, then delete the snapshot.
pub fn restore_from_bak(handle: &SwapHandle) -> std::io::Result<()> {
    let bak = &handle.current_exe_bak;
    if !bak.exists() {
        return Ok(());
    }
    self_replace::self_replace(bak)
        .map_err(|e| std::io::Error::other(format!("self_replace: {e}")))?;
    // The snapshot has served its purpose; removal failure is non-fatal.
    let _ = std::fs::remove_file(bak);
    Ok(())
}
/// How urgently an available update should be applied.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Severity {
    /// Routine release: wait for a quiet window before applying.
    #[default]
    Normal,
    /// Critical release: apply immediately, regardless of activity.
    Critical,
}
impl Severity {
    /// Parse the manifest's severity string; unknown values yield `None`.
    fn from_str_opt(s: &str) -> Option<Self> {
        if s == "normal" {
            Some(Self::Normal)
        } else if s == "critical" {
            Some(Self::Critical)
        } else {
            None
        }
    }
    /// Stable wire/log representation.
    pub fn as_str(self) -> &'static str {
        if self == Self::Critical {
            "critical"
        } else {
            "normal"
        }
    }
}
// Serialize as the lowercase string form ("normal"/"critical") used in
// status payloads, matching `Severity::as_str`.
impl serde::Serialize for Severity {
    fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
        s.serialize_str(self.as_str())
    }
}
/// Outcome of a registry check for a newer release.
#[derive(Debug, Clone)]
pub enum CheckResult {
    /// Already on (or ahead of) the newest published version.
    UpToDate {
        current: String,
    },
    /// A newer, fully-published release is ready to download.
    Available {
        current: String,
        latest: String,
        severity: Severity,
        /// Oldest provider version this release can auto-update from,
        /// when the release declares one.
        min_supported: Option<String>,
        /// URL of this target's platform tarball.
        tarball_url: String,
        /// npm SRI string (`sha512-...`) for the tarball.
        tarball_integrity: String,
    },
    /// The check could not complete; safe to retry later.
    Failed {
        reason: String,
    },
}
/// npm package name holding this build target's prebuilt binary, or `None`
/// on targets without a published platform package (which therefore cannot
/// auto-update).
pub fn platform_package_name() -> Option<&'static str> {
    // `cfg!` resolves at compile time, so only one arm survives codegen.
    let name = if cfg!(all(target_os = "macos", target_arch = "aarch64")) {
        "@openlatch/provider-darwin-arm64"
    } else if cfg!(all(target_os = "macos", target_arch = "x86_64")) {
        "@openlatch/provider-darwin-x64"
    } else if cfg!(all(target_os = "linux", target_arch = "x86_64")) {
        "@openlatch/provider-linux-x64"
    } else if cfg!(all(target_os = "linux", target_arch = "aarch64")) {
        "@openlatch/provider-linux-arm64"
    } else if cfg!(all(target_os = "windows", target_arch = "x86_64")) {
        "@openlatch/provider-win32-x64"
    } else {
        return None;
    };
    Some(name)
}
/// `true` when both arguments parse as semver and `current >= min`.
/// Unparsable input is conservatively treated as "not at least" (`false`).
pub fn version_at_least(current: &str, min: &str) -> bool {
    use semver::Version;
    let (Ok(cur), Ok(floor)) = (Version::parse(current), Version::parse(min)) else {
        return false;
    };
    cur >= floor
}
/// Query the npm registry for a newer provider release and, when one exists,
/// resolve the platform-specific tarball URL and SRI integrity string.
///
/// Flow: fetch `@openlatch/provider/latest` → semver-compare against
/// `current_version` → resolve this target's platform package manifest →
/// require `dist.tarball` + `dist.integrity` → HEAD-probe the tarball URL.
/// Optional `openlatch.severity` and `openlatch.min_supported_provider`
/// fields from the platform manifest are passed through to the caller.
pub async fn check(current_version: &str, registry_origin: &str) -> CheckResult {
    use semver::Version;
    let client = match reqwest::Client::builder()
        .timeout(Duration::from_secs(5))
        .use_rustls_tls()
        .build()
    {
        Ok(c) => c,
        Err(e) => {
            return CheckResult::Failed {
                reason: format!("http client build: {e}"),
            }
        }
    };
    // Normalise the origin so the joins below cannot produce `//`.
    let registry = registry_origin.trim_end_matches('/');
    let meta_url = format!("{registry}/@openlatch%2Fprovider/latest");
    let manifest = match http_get_json(&client, &meta_url).await {
        Ok(v) => v,
        Err(e) => {
            return CheckResult::Failed {
                reason: format!("manifest fetch: {e}"),
            }
        }
    };
    let Some(latest_str) = manifest.get("version").and_then(|v| v.as_str()) else {
        return CheckResult::Failed {
            reason: "manifest missing version".into(),
        };
    };
    let (Ok(current), Ok(latest)) = (Version::parse(current_version), Version::parse(latest_str))
    else {
        return CheckResult::Failed {
            reason: "version parse failed".into(),
        };
    };
    // Equal or older published version → nothing to do.
    if latest <= current {
        return CheckResult::UpToDate {
            current: current_version.to_string(),
        };
    }
    let Some(platform_pkg) = platform_package_name() else {
        return CheckResult::Failed {
            reason: format!(
                "no platform package for target_os={} target_arch={} — cannot auto-update",
                std::env::consts::OS,
                std::env::consts::ARCH,
            ),
        };
    };
    // Scoped package names must be URL-encoded (`/` → `%2F`) in registry paths.
    let plat_url = format!(
        "{registry}/{}/{latest_str}",
        platform_pkg.replace('/', "%2F"),
    );
    let v = match http_get_json(&client, &plat_url).await {
        Ok(v) => v,
        Err(e) => {
            // The meta package may be visible before the platform package
            // finishes publishing; report as a retryable failure.
            return CheckResult::Failed {
                reason: format!(
                    "platform package {platform_pkg} version {latest_str} not yet on registry: {e}"
                ),
            };
        }
    };
    let Some(tarball_url) = v
        .pointer("/dist/tarball")
        .and_then(|t| t.as_str())
        .map(String::from)
    else {
        return CheckResult::Failed {
            reason: "platform manifest missing dist.tarball".into(),
        };
    };
    let Some(tarball_integrity) = v
        .pointer("/dist/integrity")
        .and_then(|t| t.as_str())
        .map(String::from)
    else {
        return CheckResult::Failed {
            reason: "platform manifest missing dist.integrity".into(),
        };
    };
    // Cheap reachability probe before a later full download is attempted.
    if !http_head_ok(&client, &tarball_url).await {
        return CheckResult::Failed {
            reason: "platform tarball not yet reachable on registry CDN".into(),
        };
    }
    // Optional release metadata; absent/unknown values fall back to defaults.
    let severity = v
        .pointer("/openlatch/severity")
        .and_then(|s| s.as_str())
        .and_then(Severity::from_str_opt)
        .unwrap_or_default();
    let min_supported = v
        .pointer("/openlatch/min_supported_provider")
        .and_then(|s| s.as_str())
        .map(String::from);
    CheckResult::Available {
        current: current_version.to_string(),
        latest: latest_str.to_string(),
        severity,
        min_supported,
        tarball_url,
        tarball_integrity,
    }
}
/// GET `url` with a JSON `Accept` header and parse the body as JSON;
/// errors are stringified for the caller's `CheckResult::Failed` reason.
async fn http_get_json(client: &reqwest::Client, url: &str) -> Result<serde_json::Value, String> {
    let resp = client
        .get(url)
        .header("Accept", "application/json")
        .send()
        .await
        .map_err(|e| format!("send: {e}"))?;
    let status = resp.status();
    if !status.is_success() {
        return Err(format!("HTTP {status}"));
    }
    resp.json::<serde_json::Value>()
        .await
        .map_err(|e| format!("json: {e}"))
}
/// `true` when a HEAD request to `url` completes with a 2xx status.
async fn http_head_ok(client: &reqwest::Client, url: &str) -> bool {
    match client.head(url).send().await {
        Ok(resp) => resp.status().is_success(),
        Err(_) => false,
    }
}
/// Failure modes for the integrity-checked tarball download.
#[derive(Debug, thiserror::Error)]
pub enum DownloadError {
    /// Client construction, request, or body read failed.
    #[error("tarball download failed: {0}")]
    Http(String),
    /// The SRI string was not a decodable `sha512-<base64>` token.
    #[error("integrity field is not a recognised SRI hash: {0}")]
    IntegrityFormat(String),
    /// The downloaded bytes did not match the expected digest.
    #[error("downloaded tarball failed integrity check (SRI mismatch)")]
    IntegrityMismatch,
}
/// Download `url` and verify the body against the npm SRI string `sri`.
///
/// Only `sha512-<base64>` integrity tokens are accepted; the digest
/// comparison is constant-time. On success the raw tarball bytes are
/// returned.
pub async fn download_tarball(
    url: &str,
    sri: &str,
    timeout: Duration,
) -> Result<Vec<u8>, DownloadError> {
    use sha2::{Digest, Sha512};
    use subtle::ConstantTimeEq;
    // Pick the first `sha512-` token out of the whitespace-separated SRI list.
    let expected_b64 = sri
        .split_ascii_whitespace()
        .find_map(|tok| tok.strip_prefix("sha512-"))
        .ok_or_else(|| {
            let preview: String = sri.chars().take(40).collect();
            DownloadError::IntegrityFormat(format!("expected `sha512-...`, got {preview}"))
        })?;
    let expected = base64_decode_lenient(expected_b64).map_err(|e| {
        DownloadError::IntegrityFormat(format!("base64 decode of integrity hash failed: {e}"))
    })?;
    if expected.len() != 64 {
        return Err(DownloadError::IntegrityFormat(format!(
            "sha512 digest must be 64 bytes; got {}",
            expected.len()
        )));
    }
    let client = reqwest::Client::builder()
        .timeout(timeout)
        .use_rustls_tls()
        .build()
        .map_err(|e| DownloadError::Http(format!("client build: {e}")))?;
    let resp = client
        .get(url)
        .send()
        .await
        .map_err(|e| DownloadError::Http(format!("send: {e}")))?;
    if !resp.status().is_success() {
        return Err(DownloadError::Http(format!("HTTP {}", resp.status())));
    }
    let body = resp
        .bytes()
        .await
        .map_err(|e| DownloadError::Http(format!("read body: {e}")))?
        .to_vec();
    let digest = Sha512::digest(&body);
    if bool::from(digest.as_slice().ct_eq(&expected)) {
        Ok(body)
    } else {
        Err(DownloadError::IntegrityMismatch)
    }
}
/// Decode standard-alphabet base64, accepting either padded or unpadded
/// input (npm integrity hashes are inconsistent about trailing `=`).
fn base64_decode_lenient(b64: &str) -> Result<Vec<u8>, String> {
    use base64::engine::general_purpose::{STANDARD, STANDARD_NO_PAD};
    use base64::Engine;
    match STANDARD.decode(b64.as_bytes()) {
        Ok(bytes) => Ok(bytes),
        // Padded decode failed — retry without padding before giving up.
        Err(_) => STANDARD_NO_PAD
            .decode(b64.as_bytes())
            .map_err(|e| e.to_string()),
    }
}
// Hard cap on any single extracted tar entry (declared and actual size),
// bounding disk usage from a malicious or corrupt archive.
const ENTRY_SIZE_CAP: u64 = 64 * 1024 * 1024;
/// Failure modes for tarball extraction into the staging directory.
#[derive(Debug, thiserror::Error)]
pub enum ExtractError {
    #[error("io error during extraction: {0}")]
    Io(#[from] std::io::Error),
    /// An entry's declared or actual size exceeded [`ENTRY_SIZE_CAP`].
    #[error("tar entry exceeds 64 MB cap")]
    EntryTooLarge,
    /// A basename contained `..` or a path separator (defence in depth).
    #[error("tar entry has suspicious filename: {0}")]
    SuspiciousFilename(String),
    /// The same allow-listed basename appeared twice in the archive.
    #[error("tar contains duplicate entry for: {0}")]
    DuplicateEntry(String),
    /// A platform-required binary/signature file was absent after extraction.
    #[error("tar missing required file after extraction: {0}")]
    MissingRequired(String),
}
// Only these basenames are ever written out of the tarball; everything
// else in the archive (package.json, docs, unexpected files) is skipped.
const EXTRACT_ALLOWLIST: &[&str] = &[
    "openlatch-provider",
    "openlatch-provider.exe",
    "openlatch-provider.minisig",
    "openlatch-provider.exe.minisig",
];
/// Extract an update tarball (gzipped tar) into `staging_dir`, keeping only
/// allow-listed regular files.
///
/// Hardening against hostile archives:
/// * non-regular entries (dirs, symlinks, devices) are skipped;
/// * entries are written by *basename only* into `staging_dir`, so archive
///   paths cannot traverse outside it;
/// * declared and actual entry sizes are capped at [`ENTRY_SIZE_CAP`];
/// * duplicate allow-listed basenames are rejected.
///
/// After extraction the platform's required binary + signature pair must
/// exist or [`ExtractError::MissingRequired`] is returned.
pub fn extract_to_staging(tarball_bytes: &[u8], staging_dir: &Path) -> Result<(), ExtractError> {
    use std::io::Read;
    use tar::EntryType;
    std::fs::create_dir_all(staging_dir)?;
    let gz = flate2::read::GzDecoder::new(tarball_bytes);
    let mut archive = tar::Archive::new(gz);
    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
    for entry in archive.entries()? {
        let mut entry = entry?;
        if entry.header().entry_type() != EntryType::Regular {
            continue;
        }
        // Reject on the *declared* size first; the copy loop below re-checks
        // actual bytes in case the header lies.
        if entry.size() > ENTRY_SIZE_CAP {
            return Err(ExtractError::EntryTooLarge);
        }
        let path_in_tar = entry.path()?;
        let Some(basename_os) = path_in_tar.file_name() else {
            continue;
        };
        let Some(basename) = basename_os.to_str() else {
            continue;
        };
        // Defence in depth: file_name() should already exclude separators.
        if basename.contains("..") || basename.contains('/') || basename.contains('\\') {
            return Err(ExtractError::SuspiciousFilename(basename.to_string()));
        }
        if !EXTRACT_ALLOWLIST.contains(&basename) {
            continue;
        }
        if !seen.insert(basename.to_string()) {
            return Err(ExtractError::DuplicateEntry(basename.to_string()));
        }
        let dest = staging_dir.join(basename);
        let mut out = std::fs::File::create(&dest)?;
        let mut buf = [0u8; 64 * 1024];
        let mut written: u64 = 0;
        loop {
            let n = entry.read(&mut buf)?;
            if n == 0 {
                break;
            }
            written = written.saturating_add(n as u64);
            if written > ENTRY_SIZE_CAP {
                return Err(ExtractError::EntryTooLarge);
            }
            std::io::Write::write_all(&mut out, &buf[..n])?;
        }
        // Close the file before chmod-ing it.
        drop(out);
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            // The extracted binary must be executable for the sanity check.
            std::fs::set_permissions(&dest, std::fs::Permissions::from_mode(0o755))?;
        }
    }
    let required: &[&str] = if cfg!(windows) {
        &["openlatch-provider.exe", "openlatch-provider.exe.minisig"]
    } else {
        &["openlatch-provider", "openlatch-provider.minisig"]
    };
    for r in required {
        if !staging_dir.join(r).exists() {
            return Err(ExtractError::MissingRequired((*r).to_string()));
        }
    }
    Ok(())
}
/// Marker written after a swap and inspected on the next startup to confirm
/// the update (or trigger a rollback — see [`should_rollback`]).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct UpdateSentinel {
    /// Version before the swap.
    pub from: String,
    /// Version after the swap.
    pub to: String,
    /// Timestamp string supplied by the writer at apply time
    /// (format is the caller's choice; not parsed here).
    pub applied_at: String,
}
/// Location of the post-update sentinel file inside the provider dir.
pub fn sentinel_path() -> PathBuf {
    let mut path = crate::config::provider_dir();
    path.push("update-sentinel.json");
    path
}
/// Persist the update sentinel atomically (write a temp file, then rename),
/// creating the provider directory if needed.
pub fn write_sentinel(s: &UpdateSentinel) -> std::io::Result<()> {
    let path = sentinel_path();
    if let Some(parent) = path.parent() {
        std::fs::create_dir_all(parent)?;
    }
    let body = match serde_json::to_string_pretty(s) {
        Ok(b) => b,
        Err(e) => {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("serialise sentinel: {e}"),
            ))
        }
    };
    // Write-then-rename so a reader never observes a half-written sentinel.
    let tmp = path.with_extension("json.tmp");
    std::fs::write(&tmp, body)?;
    std::fs::rename(&tmp, &path)
}
/// Load the update sentinel. Returns `None` when the file is absent or
/// unreadable; a malformed sentinel is logged and also treated as absent.
pub fn read_sentinel() -> Option<UpdateSentinel> {
    let raw = std::fs::read_to_string(sentinel_path()).ok()?;
    serde_json::from_str::<UpdateSentinel>(&raw)
        .map_err(|e| {
            tracing::warn!(target: "update", error = %e, "update sentinel malformed; ignoring");
        })
        .ok()
}
/// Remove the sentinel file; a missing file counts as success.
pub fn delete_sentinel() -> std::io::Result<()> {
    if let Err(e) = std::fs::remove_file(sentinel_path()) {
        if e.kind() != std::io::ErrorKind::NotFound {
            return Err(e);
        }
    }
    Ok(())
}
pub fn cleanup_bak_files() -> std::io::Result<()> {
fn remove_if_present(p: &Path) -> std::io::Result<()> {
match std::fs::remove_file(p) {
Ok(()) => Ok(()),
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
Err(e) => Err(e),
}
}
if let Ok(exe) = std::env::current_exe() {
remove_if_present(&exe.with_extension("bak"))?;
}
let _ = std::fs::remove_file(restart_tracker_path());
Ok(())
}
/// Re-execute the current process image from disk (after a swap, that is
/// the new binary), preserving all CLI arguments.
///
/// * Unix: `execv` replaces this process in place; the call only returns
///   on error.
/// * Windows: a running exe cannot exec over itself, so a detached child
///   is spawned and the parent exits with status 0.
pub fn restart_into_new_binary() -> Result<(), String> {
    let exe = std::env::current_exe().map_err(|e| format!("current_exe: {e}"))?;
    // Skip argv[0]; it is rebuilt from `exe` for the new process below.
    let args: Vec<std::ffi::OsString> = std::env::args_os().skip(1).collect();
    tracing::info!(target: "update", exe = %exe.display(), "re-executing into new binary");
    #[cfg(unix)]
    {
        use std::ffi::CString;
        use std::os::unix::ffi::OsStrExt;
        let mut argv: Vec<CString> = Vec::with_capacity(args.len() + 1);
        argv.push(
            CString::new(exe.as_os_str().as_bytes()).map_err(|e| format!("argv0 cstring: {e}"))?,
        );
        for a in &args {
            argv.push(CString::new(a.as_bytes()).map_err(|e| format!("argv cstring: {e}"))?);
        }
        let argv_refs: Vec<&std::ffi::CStr> = argv.iter().map(|c| c.as_c_str()).collect();
        let cstr_path = CString::new(exe.as_os_str().as_bytes())
            .map_err(|e| format!("exe path cstring: {e}"))?;
        // execv only returns on failure; the Ok arm is unreachable in practice.
        match nix::unistd::execv(&cstr_path, &argv_refs) {
            Ok(_void) => Ok(()),
            Err(e) => Err(format!("execv: {e}")),
        }
    }
    #[cfg(windows)]
    {
        use std::os::windows::process::CommandExt;
        const CREATE_NEW_PROCESS_GROUP: u32 = 0x0000_0200;
        const DETACHED_PROCESS: u32 = 0x0000_0008;
        // Detach so the child survives this process exiting.
        let res = std::process::Command::new(&exe)
            .args(&args)
            .creation_flags(CREATE_NEW_PROCESS_GROUP | DETACHED_PROCESS)
            .spawn();
        match res {
            Ok(_child) => std::process::exit(0),
            Err(e) => Err(format!("spawn-detached: {e}")),
        }
    }
    #[cfg(not(any(unix, windows)))]
    {
        Err("restart_into_new_binary: unsupported platform".into())
    }
}
/// Where an update apply was initiated from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ApplyMode {
    /// Triggered via the daemon's RPC surface.
    Rpc,
    /// Triggered in-process (e.g. CLI / background worker).
    InProcess,
}
impl ApplyMode {
    /// Stable string form for logs and metrics.
    pub fn as_str(self) -> &'static str {
        if matches!(self, Self::Rpc) {
            "rpc"
        } else {
            "in_process"
        }
    }
}
/// Parameters controlling a single update-apply attempt.
#[derive(Debug, Clone)]
pub struct ApplyOptions {
    /// Version of the currently running binary.
    pub current_version: String,
    /// npm registry origin, e.g. `https://registry.npmjs.org`.
    pub registry_origin: String,
    /// Timeout for the tarball download.
    pub download_timeout: Duration,
    /// Allow updating even a cargo-installed binary.
    pub force_cargo_install: bool,
    /// Who initiated the apply.
    pub mode: ApplyMode,
}
impl ApplyOptions {
    /// Defaults suited to an interactive CLI invocation: 60 s download
    /// timeout, no cargo-install override, in-process mode.
    pub fn for_cli(current_version: impl Into<String>, registry_origin: impl Into<String>) -> Self {
        Self {
            current_version: current_version.into(),
            registry_origin: registry_origin.into(),
            download_timeout: Duration::from_secs(60),
            force_cargo_install: false,
            mode: ApplyMode::InProcess,
        }
    }
}
/// Pipeline stage of an update apply; used for status reporting and to
/// pinpoint where a failure occurred.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ApplyStage {
    Check,
    Download,
    Extract,
    Verify,
    Sanity,
    Swap,
    Drain,
    Restart,
    Healthz,
}
impl ApplyStage {
    /// Stable snake_case name, kept in sync with the serde representation.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Check => "check",
            Self::Download => "download",
            Self::Extract => "extract",
            Self::Verify => "verify",
            Self::Sanity => "sanity",
            Self::Swap => "swap",
            Self::Drain => "drain",
            Self::Restart => "restart",
            Self::Healthz => "healthz",
        }
    }
}
/// Lifecycle state of the most recent (or current) update attempt.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)]
#[serde(rename_all = "snake_case")]
pub enum UpdateStatusKind {
    /// No update activity so far.
    Idle,
    /// An apply is currently running.
    InProgress,
    /// The last apply finished successfully.
    Completed,
    /// The last apply failed.
    Failed,
}
/// Point-in-time view of update progress, serialised for status consumers.
#[derive(Debug, Clone, serde::Serialize)]
pub struct UpdateStatusSnapshot {
    /// Overall state of the update machinery.
    pub status: UpdateStatusKind,
    /// Current (or last-reached) pipeline stage, when one applies.
    pub stage: Option<ApplyStage>,
    /// Version being updated from, when known.
    pub from: Option<String>,
    /// Version being updated to, when known.
    pub to: Option<String>,
    /// Start timestamp of the attempt, when one has started.
    pub started_at: Option<String>,
    /// End timestamp of the attempt, when it has finished.
    pub ended_at: Option<String>,
    /// Failure reason, populated only for failed attempts.
    pub error: Option<String>,
}
impl UpdateStatusSnapshot {
    /// Snapshot representing "no update activity yet": `Idle` status with
    /// every optional field unset.
    pub fn idle() -> Self {
        Self {
            status: UpdateStatusKind::Idle,
            stage: None,
            from: None,
            to: None,
            started_at: None,
            ended_at: None,
            error: None,
        }
    }
}
/// Final outcome of an apply attempt (check → download → … → swap).
#[derive(Debug, Clone)]
pub enum ApplyResult {
    /// The binary was swapped successfully.
    Applied {
        from: String,
        to: String,
        severity: Severity,
        /// Wall-clock duration of the whole apply, in milliseconds.
        duration_ms: u64,
    },
    /// No newer version is available.
    UpToDate {
        current: String,
    },
    /// Refused to overwrite a cargo-installed binary; `suggestion`
    /// tells the user how to update manually.
    RefusedCargoInstall {
        suggestion: String,
    },
    /// A stage failed; `stage` identifies where and `reason` why.
    Failed {
        stage: ApplyStage,
        reason: String,
    },
}
/// Everything produced by the pre-swap pipeline. Holding `staging_dir`
/// keeps the verified staged binary alive until the swap happens (the
/// tempdir is deleted on drop).
pub struct SwapArtefacts {
    /// Tempdir owning the staged files.
    pub staging_dir: tempfile::TempDir,
    /// Path to the verified, sanity-checked staged executable.
    pub staging_exe: PathBuf,
    /// Version being updated from.
    pub from: String,
    /// Version being updated to.
    pub to: String,
    /// Severity declared by the release (defaults to normal).
    pub severity: Severity,
}
/// Run every pre-swap stage: install-method guard, registry check,
/// min-supported gate, download, extract, signature verify, sanity check.
///
/// Returns the staged artefacts ready for [`perform_swap`]. Short-circuit
/// outcomes (up-to-date / refused / failed-at-stage) are carried in the
/// `Err` position as a ready-to-return [`ApplyResult`].
pub async fn prepare_swap_artefacts(opts: &ApplyOptions) -> Result<SwapArtefacts, ApplyResult> {
    use install_state::{detect_install_method, InstallMethod};
    // cargo-installed binaries are managed by cargo; never overwrite them
    // unless the caller explicitly forces it.
    if !opts.force_cargo_install && matches!(detect_install_method(), InstallMethod::CargoInstall) {
        let suggestion = "Run: cargo install --force --locked openlatch-provider".to_string();
        tracing::warn!(target: "update", "refusing to auto-update cargo-install binary");
        return Err(ApplyResult::RefusedCargoInstall { suggestion });
    }
    // Stage: check — resolve latest version + tarball metadata.
    let check_result = check(&opts.current_version, &opts.registry_origin).await;
    let (latest, severity, min_supported, tarball_url, tarball_integrity) = match check_result {
        CheckResult::UpToDate { current } => {
            return Err(ApplyResult::UpToDate { current });
        }
        CheckResult::Failed { reason } => {
            return Err(ApplyResult::Failed {
                stage: ApplyStage::Check,
                reason,
            });
        }
        CheckResult::Available {
            latest,
            severity,
            min_supported,
            tarball_url,
            tarball_integrity,
            ..
        } => (
            latest,
            severity,
            min_supported,
            tarball_url,
            tarball_integrity,
        ),
    };
    // Gate: a release may declare the oldest version it can update from;
    // older providers must be reinstalled manually.
    if let Some(ref min) = min_supported {
        if !version_at_least(&opts.current_version, min) {
            return Err(ApplyResult::Failed {
                stage: ApplyStage::Check,
                reason: format!(
                    "current version {} is older than min_supported_provider {} for release {} — manual `npm install -g @openlatch/provider@{}` required",
                    opts.current_version, min, latest, latest
                ),
            });
        }
    }
    tracing::info!(
        target: "update",
        from = %opts.current_version,
        to = %latest,
        severity = %severity.as_str(),
        "auto-update apply started"
    );
    // Stage: download — fetch the tarball and verify its SRI digest.
    let bytes = match download_tarball(&tarball_url, &tarball_integrity, opts.download_timeout)
        .await
    {
        Ok(b) => b,
        Err(e) => {
            tracing::warn!(target: "update", error = %e, stage = "download", "tarball download failed");
            return Err(ApplyResult::Failed {
                stage: ApplyStage::Download,
                reason: e.to_string(),
            });
        }
    };
    // Stage: extract — unpack allow-listed files into a fresh tempdir.
    let staging_dir = tempfile::tempdir().map_err(|e| ApplyResult::Failed {
        stage: ApplyStage::Extract,
        reason: format!("create staging tempdir: {e}"),
    })?;
    if let Err(e) = extract_to_staging(&bytes, staging_dir.path()) {
        tracing::warn!(target: "update", error = %e, stage = "extract", "tar extraction failed");
        return Err(ApplyResult::Failed {
            stage: ApplyStage::Extract,
            reason: e.to_string(),
        });
    }
    let (exe_name, exe_sig_name) = if cfg!(windows) {
        ("openlatch-provider.exe", "openlatch-provider.exe.minisig")
    } else {
        ("openlatch-provider", "openlatch-provider.minisig")
    };
    let staging_exe = staging_dir.path().join(exe_name);
    let staging_exe_sig = staging_dir.path().join(exe_sig_name);
    // Stage: verify — minisign signature must match a trusted key.
    if let Err(e) = verify_with_any_trusted_key(&staging_exe, &staging_exe_sig) {
        tracing::warn!(
            target: "update",
            error = %e,
            stage = "verify",
            binary = "openlatch-provider",
            "signature verification failed"
        );
        return Err(ApplyResult::Failed {
            stage: ApplyStage::Verify,
            reason: format!("openlatch-provider verify: {e}"),
        });
    }
    // Stage: sanity — the staged binary must run and report the new version.
    if let Err(e) = sanity_check_version(&staging_exe, &latest) {
        tracing::warn!(target: "update", error = %e, stage = "sanity", "sanity check failed");
        return Err(ApplyResult::Failed {
            stage: ApplyStage::Sanity,
            reason: e.to_string(),
        });
    }
    Ok(SwapArtefacts {
        staging_dir,
        staging_exe,
        from: opts.current_version.clone(),
        to: latest,
        severity,
    })
}
/// Prepare artefacts and perform the swap in-process, returning the final
/// [`ApplyResult`]. On success the install-state record is stamped with the
/// new version.
pub async fn apply_local(opts: ApplyOptions) -> ApplyResult {
    let started = std::time::Instant::now();
    let artefacts = match prepare_swap_artefacts(&opts).await {
        Ok(a) => a,
        // Up-to-date / refused / failed — already a complete ApplyResult.
        Err(short_circuit) => return short_circuit,
    };
    if let Err(e) = perform_swap(&artefacts.staging_exe) {
        return ApplyResult::Failed {
            stage: ApplyStage::Swap,
            reason: e.to_string(),
        };
    }
    // Clamp u128 millis into u64 before the cast (the min() makes the `as`
    // cast lossless even for pathological durations).
    let duration_ms = started.elapsed().as_millis().min(u128::from(u64::MAX)) as u64;
    install_state::InstallState::stamp_for_running_binary(&artefacts.to);
    tracing::info!(
        target: "update",
        from = %artefacts.from,
        to = %artefacts.to,
        duration_ms = duration_ms,
        "auto-update apply completed"
    );
    ApplyResult::Applied {
        from: artefacts.from,
        to: artefacts.to,
        severity: artefacts.severity,
        duration_ms,
    }
}
/// Decide whether a pending update should be applied right now.
///
/// Critical updates apply immediately. Normal updates wait for a quiet
/// window — at least `quiet_window_secs` since the last recorded event and
/// zero in-flight requests — but are force-applied once `pending_age`
/// reaches `max_defer_secs`.
pub fn should_apply_now(
    severity: Severity,
    last_event_at: &AtomicU64,
    in_flight: &AtomicU32,
    pending_age: Duration,
    quiet_window_secs: u64,
    max_defer_secs: u64,
) -> bool {
    if matches!(severity, Severity::Critical) {
        return true;
    }
    // Clock-before-epoch degrades to 0, i.e. "idle forever".
    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |d| d.as_secs());
    let idle_secs = now.saturating_sub(last_event_at.load(Ordering::Relaxed));
    let quiet = idle_secs >= quiet_window_secs && in_flight.load(Ordering::Acquire) == 0;
    let overdue = pending_age.as_secs() >= max_defer_secs;
    quiet || overdue
}
/// Persisted list of recent process-start timestamps (unix seconds), used
/// by [`should_rollback`] to detect a post-update crash loop.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
struct RestartTracker {
    // Start times within the recent window, oldest first.
    starts: Vec<u64>,
}
/// Location of the restart-loop tracker file inside the provider dir.
fn restart_tracker_path() -> PathBuf {
    let mut path = crate::config::provider_dir();
    path.push("restart-tracker.json");
    path
}
// Number of starts within the window that counts as a crash loop.
const RESTART_LOOP_THRESHOLD: usize = 3;
// Sliding window (seconds) over which restarts are counted.
const RESTART_LOOP_WINDOW_SECS: u64 = 60;
// Maximum retained start timestamps in the tracker file.
const RESTART_TRACKER_CAP: usize = 10;
/// Startup crash-loop detector for freshly applied updates.
///
/// Only meaningful when both the update sentinel and the `.bak` snapshot
/// exist (i.e. this start immediately follows a swap). Each call appends
/// the current time to the on-disk restart tracker, prunes entries older
/// than [`RESTART_LOOP_WINDOW_SECS`], caps the list at
/// [`RESTART_TRACKER_CAP`], persists it best-effort (temp file + rename),
/// and reports `true` once [`RESTART_LOOP_THRESHOLD`] starts fall inside
/// the window.
pub fn should_rollback() -> bool {
    if !sentinel_path().exists() {
        return false;
    }
    let exe = match std::env::current_exe() {
        Ok(e) => e,
        Err(_) => return false,
    };
    // Without a snapshot there is nothing to roll back to.
    if !exe.with_extension("bak").exists() {
        return false;
    }
    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_secs())
        .unwrap_or(0);
    let path = restart_tracker_path();
    // An unreadable or corrupt tracker degrades to an empty history.
    let mut tracker: RestartTracker = std::fs::read_to_string(&path)
        .ok()
        .and_then(|s| serde_json::from_str(&s).ok())
        .unwrap_or_default();
    tracker
        .starts
        .retain(|t| now.saturating_sub(*t) <= RESTART_LOOP_WINDOW_SECS);
    tracker.starts.push(now);
    if tracker.starts.len() > RESTART_TRACKER_CAP {
        let drop_n = tracker.starts.len() - RESTART_TRACKER_CAP;
        tracker.starts.drain(0..drop_n);
    }
    // Persist best-effort; a failed write must not block the decision.
    if let Ok(body) = serde_json::to_string(&tracker) {
        if let Some(parent) = path.parent() {
            let _ = std::fs::create_dir_all(parent);
        }
        let tmp = path.with_extension("json.tmp");
        if std::fs::write(&tmp, body).is_ok() && std::fs::rename(&tmp, &path).is_err() {
            let _ = std::fs::remove_file(&tmp);
        }
    }
    tracker.starts.len() >= RESTART_LOOP_THRESHOLD
}
/// Replace the running executable with its `.bak` snapshot (when present),
/// then clear the update sentinel and restart tracker.
pub fn rollback_from_bak() -> std::io::Result<()> {
    let bak = std::env::current_exe()?.with_extension("bak");
    if bak.exists() {
        self_replace::self_replace(&bak)
            .map_err(|e| std::io::Error::other(format!("self_replace: {e}")))?;
        let _ = std::fs::remove_file(&bak);
    }
    // Bookkeeping cleanup is best-effort.
    let _ = delete_sentinel();
    let _ = std::fs::remove_file(restart_tracker_path());
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_trusted_keys_strips_comments_and_blanks() {
        // Fourth key must be dropped by the three-key cap.
        let input = "\
# heading comment
RWQ_FAKE_KEY_ONE
# indented comment
RWQ_FAKE_KEY_TWO
RWQ_FAKE_KEY_THREE
RWQ_EXTRA_KEY_DROPPED
";
        let parsed = parse_trusted_keys(input);
        assert_eq!(
            parsed,
            vec![
                "RWQ_FAKE_KEY_ONE".to_string(),
                "RWQ_FAKE_KEY_TWO".to_string(),
                "RWQ_FAKE_KEY_THREE".to_string(),
            ],
        );
    }

    #[test]
    fn parse_trusted_keys_empty_when_only_comments() {
        let input = "# comment one\n# comment two\n\n";
        assert!(parse_trusted_keys(input).is_empty());
    }

    #[test]
    fn version_at_least_handles_semver_correctly() {
        // Numeric (not lexicographic) component comparison.
        assert!(version_at_least("0.1.10", "0.1.9"));
        assert!(version_at_least("0.2.0", "0.1.99"));
        assert!(version_at_least("1.0.0", "0.99.99"));
        assert!(!version_at_least("0.1.0", "0.2.0"));
        assert!(!version_at_least("0.1.9", "0.1.10"));
        // Equal versions count as "at least".
        assert!(version_at_least("0.1.5", "0.1.5"));
    }

    #[test]
    fn version_at_least_returns_false_on_unparsable() {
        assert!(!version_at_least("not-a-version", "0.1.0"));
        assert!(!version_at_least("0.1.0", "not-a-version"));
    }

    #[test]
    fn severity_from_str_opt_round_trip() {
        assert_eq!(Severity::from_str_opt("normal"), Some(Severity::Normal));
        assert_eq!(Severity::from_str_opt("critical"), Some(Severity::Critical));
        // Unknown labels are rejected rather than defaulted here.
        assert_eq!(Severity::from_str_opt("warning"), None);
        assert_eq!(Severity::Normal.as_str(), "normal");
        assert_eq!(Severity::Critical.as_str(), "critical");
    }

    #[test]
    fn platform_package_name_resolves_for_current_target() {
        let name = platform_package_name();
        assert!(
            name.is_some(),
            "every CI target must have a platform package"
        );
        let n = name.unwrap();
        assert!(n.starts_with("@openlatch/provider-"));
    }

    #[tokio::test]
    async fn download_tarball_rejects_non_sha512_integrity() {
        // The SRI check fires before any network activity, so the bogus
        // address is never contacted.
        let err = download_tarball(
            "http://127.0.0.1:1/never",
            "sha256-abcd",
            Duration::from_secs(1),
        )
        .await
        .expect_err("must reject sha256 SRI");
        assert!(
            matches!(err, DownloadError::IntegrityFormat(_)),
            "got {err:?}"
        );
    }

    #[test]
    fn should_apply_now_critical_bypasses_everything() {
        // Critical severity wins despite recent activity and in-flight work.
        let last = AtomicU64::new(0);
        let live = AtomicU32::new(5);
        assert!(should_apply_now(
            Severity::Critical,
            &last,
            &live,
            Duration::from_secs(0),
            60,
            86_400,
        ));
    }

    #[test]
    fn should_apply_now_normal_quiet_and_idle_applies() {
        let last = AtomicU64::new(0);
        let live = AtomicU32::new(0);
        assert!(should_apply_now(
            Severity::Normal,
            &last,
            &live,
            Duration::from_secs(0),
            60,
            86_400,
        ));
    }

    #[test]
    fn should_apply_now_normal_in_flight_defers_within_cap() {
        // One in-flight request blocks the quiet-window path.
        let last = AtomicU64::new(0);
        let live = AtomicU32::new(1);
        assert!(!should_apply_now(
            Severity::Normal,
            &last,
            &live,
            Duration::from_secs(0),
            60,
            86_400,
        ));
    }

    #[test]
    fn should_apply_now_normal_recent_activity_defers() {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let last = AtomicU64::new(now);
        let live = AtomicU32::new(0);
        assert!(!should_apply_now(
            Severity::Normal,
            &last,
            &live,
            Duration::from_secs(0),
            60,
            86_400,
        ));
    }

    #[test]
    fn should_apply_now_normal_hard_cap_overrides_in_flight() {
        // Once the max-defer age is reached, activity no longer defers.
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let last = AtomicU64::new(now);
        let live = AtomicU32::new(3);
        assert!(should_apply_now(
            Severity::Normal,
            &last,
            &live,
            Duration::from_secs(86_400),
            60,
            86_400,
        ));
    }

    #[test]
    fn restart_tracker_serialises_round_trip() {
        let mut t = RestartTracker::default();
        t.starts.push(100);
        t.starts.push(200);
        let json = serde_json::to_string(&t).unwrap();
        let back: RestartTracker = serde_json::from_str(&json).unwrap();
        assert_eq!(back.starts, vec![100, 200]);
    }
}