use std::ffi::OsString;
/// Settings for one auto-update pass of a single released binary.
pub struct UpdaterConfig {
/// Prefix of the release asset name. Also used in the marker file name,
/// the `User-Agent` header, and as the prefix of reported error messages.
pub asset_prefix: &'static str,
/// Text placed after the architecture in the asset name; may be empty.
pub variant_suffix: &'static str,
/// Version of the running binary. Parsed as semver and compared against
/// the tag of the latest release.
pub current_version: &'static str,
/// Optional GitHub authorization value supplied by the caller. When it is
/// absent or blank, the updater tries the value stored in the config read
/// through the filesystem client instead.
pub github_authorization: Option<String>,
/// Sink for notifications and errors. When `None`, they are printed as
/// JSON lines on stdout.
pub handle: Option<crate::cli::output::Handle>,
}
/// Runs the auto-updater and never fails from the caller's point of view.
///
/// Any error produced by the update pass is reported through
/// `imp::emit_error` (using `config.handle` and `config.asset_prefix`) and
/// then dropped. `args` is forwarded to the update pass, which uses it to
/// re-launch the binary after a successful update.
pub async fn maybe_auto_update<I>(config: UpdaterConfig, args: I)
where
    I: IntoIterator<Item = OsString> + Clone,
{
    let outcome = imp::run(&config, args).await;
    let Err(failure) = outcome else {
        return;
    };
    imp::emit_error(config.handle.as_ref(), config.asset_prefix, &failure).await;
}
mod imp {
use std::ffi::OsString;
use std::path::{Path, PathBuf};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use super::UpdaterConfig;
use crate::cli::output::{Handle, Output};
use crate::cli::output::notification::{Notification, SkipReason, Updater};
/// Minimum age of the marker timestamp before another update check runs
/// (two hours).
const UPDATE_CHECK_INTERVAL: Duration = Duration::from_secs(2 * 3600);
/// Timeout applied to the release-metadata request.
const METADATA_TIMEOUT: Duration = Duration::from_secs(5);
/// Timeout applied to the asset download request.
const DOWNLOAD_TIMEOUT: Duration = Duration::from_secs(60);
/// GitHub endpoint queried for the latest release.
const RELEASES_API: &str =
"https://api.github.com/repos/ObjectiveAI/objectiveai/releases/latest";
/// Environment variable that disables the updater when present (any value,
/// including empty). It is also set to `1` on the re-launched child process.
const SKIP_ENV_VAR: &str = "OBJECTIVEAI_SKIP_UPDATE";
/// Failures that can abort an update pass.
#[derive(Debug, thiserror::Error)]
pub(super) enum Error {
/// `std::env::current_exe` failed.
#[error("could not locate current binary: {0}")]
CurrentExe(std::io::Error),
/// Creating the marker's parent directory or writing the marker failed.
#[error("write marker: {0}")]
WriteMarker(std::io::Error),
/// Sending a request or reading a response body failed; carries the
/// stringified transport error.
#[error("http: {0}")]
Http(String),
/// The metadata or download request returned a non-success status.
#[error("github returned status {0}")]
BadStatus(reqwest::StatusCode),
/// The release metadata body could not be deserialized.
#[error("malformed release metadata: {0}")]
BadMetadata(serde_json::Error),
/// The release tag or the current version is not valid semver.
#[error("semver parse: {0}")]
Semver(semver::Error),
/// The latest release has no asset with the expected name.
#[error("no asset named {0} in latest release")]
NoAsset(String),
/// Creating, writing, or flushing the staged file failed.
#[error("download: {0}")]
Download(std::io::Error),
/// Setting permissions on the staged file or renaming it into place failed.
#[error("swap: {0}")]
Swap(std::io::Error),
/// Launching the replaced binary failed.
#[error("re-exec: {0}")]
ReExec(std::io::Error),
}
/// Subset of the release metadata JSON that the updater reads.
#[derive(serde::Deserialize)]
struct Release {
// Release tag; an optional leading `v` is stripped before semver parsing.
tag_name: String,
// Files attached to the release; searched by exact name.
assets: Vec<Asset>,
}
/// One downloadable file attached to a release.
#[derive(serde::Deserialize)]
struct Asset {
// File name, compared against the expected asset name for this build.
name: String,
// URL the asset is downloaded from.
browser_download_url: String,
}
pub(super) async fn run<I>(config: &UpdaterConfig, args: I) -> Result<(), Error>
where
I: IntoIterator<Item = OsString> + Clone,
{
let Some(asset_name) = asset_name(config.asset_prefix, config.variant_suffix) else {
emit_notification(
config.handle.as_ref(),
Updater::Skipped { reason: SkipReason::UnsupportedPlatform },
)
.await;
return Ok(());
};
if std::env::var_os(SKIP_ENV_VAR).is_some() {
emit_notification(
config.handle.as_ref(),
Updater::Skipped { reason: SkipReason::OptedOut },
)
.await;
return Ok(());
}
let current_exe = std::env::current_exe().map_err(Error::CurrentExe)?;
if looks_like_dev_tree(¤t_exe) {
emit_notification(
config.handle.as_ref(),
Updater::Skipped { reason: SkipReason::DevTree },
)
.await;
return Ok(());
}
sweep_stale_old(¤t_exe);
let marker = marker_path(config.asset_prefix)?;
if !check_elapsed(&marker) {
return Ok(());
}
write_marker(&marker)?;
emit_notification(
config.handle.as_ref(),
Updater::Checking {
asset_name: asset_name.clone(),
current_version: config.current_version.to_string(),
},
)
.await;
let client = reqwest::Client::new();
let auth = github_authorization(config.github_authorization.as_deref()).await;
let release: Release = {
let mut req = client
.get(RELEASES_API)
.header("User-Agent", user_agent(config))
.header("Accept", "application/vnd.github+json")
.timeout(METADATA_TIMEOUT);
if let Some(header) = auth.as_deref() {
req = req.header("Authorization", header);
}
let resp = req
.send()
.await
.map_err(|e| Error::Http(e.to_string()))?;
let status = resp.status();
if !status.is_success() {
return Err(Error::BadStatus(status));
}
let body = resp
.bytes()
.await
.map_err(|e| Error::Http(e.to_string()))?;
serde_json::from_slice(&body).map_err(Error::BadMetadata)?
};
let remote_str = release.tag_name.strip_prefix('v').unwrap_or(&release.tag_name);
let remote = semver::Version::parse(remote_str).map_err(Error::Semver)?;
let local = semver::Version::parse(config.current_version).map_err(Error::Semver)?;
if remote <= local {
emit_notification(
config.handle.as_ref(),
Updater::UpToDate {
current_version: config.current_version.to_string(),
remote_version: remote.to_string(),
},
)
.await;
return Ok(());
}
let asset = release
.assets
.iter()
.find(|a| a.name == asset_name)
.ok_or_else(|| Error::NoAsset(asset_name.clone()))?;
emit_notification(
config.handle.as_ref(),
Updater::Found {
current_version: config.current_version.to_string(),
remote_version: remote.to_string(),
asset_name: asset_name.clone(),
url: asset.browser_download_url.clone(),
},
)
.await;
let new_path = staged_path(¤t_exe);
download_to(&client, &asset.browser_download_url, auth.as_deref(), &new_path, config)
.await?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
std::fs::set_permissions(&new_path, std::fs::Permissions::from_mode(0o755))
.map_err(Error::Swap)?;
}
self_replace(¤t_exe, &new_path)?;
emit_notification(
config.handle.as_ref(),
Updater::Installed {
current_version: config.current_version.to_string(),
remote_version: remote.to_string(),
},
)
.await;
re_exec(¤t_exe, args)
}
async fn emit<T: serde::Serialize>(handle: Option<&Handle>, output: &Output<T>) {
match handle {
Some(h) => h.emit(output).await,
None => {
let json = serde_json::to_string(output)
.expect("Output<T> serializes when T: Serialize");
println!("{json}");
if matches!(output, Output::Error(e) if e.fatal) {
eprintln!("{json}");
}
}
}
}
/// Wraps `value` in a notification envelope and emits it.
async fn emit_notification(handle: Option<&Handle>, value: Updater) {
    let notification = Notification { value };
    let envelope: Output<Updater> = Output::Notification(notification);
    emit(handle, &envelope).await;
}
pub(super) async fn emit_error(
handle: Option<&Handle>,
asset_prefix: &str,
e: &Error,
) {
let output: Output<serde_json::Value> = Output::Error(crate::cli::output::Error {
level: crate::cli::output::Level::Warn,
fatal: false,
message: serde_json::Value::String(format!("{asset_prefix}: auto-update error: {e}")),
});
emit(handle, &output).await;
}
/// Returns `(os, arch, executable_extension)` for the build target, or
/// `None` when no release asset is published for it.
///
/// Supported targets: Linux and macOS on x86_64/aarch64, Windows on x86_64.
/// The extension is `".exe"` on Windows and empty elsewhere.
fn platform_triple() -> Option<(&'static str, &'static str, &'static str)> {
    // `cfg!` is resolved at compile time, so exactly one branch survives.
    if cfg!(all(target_os = "linux", target_arch = "x86_64")) {
        Some(("linux", "x86_64", ""))
    } else if cfg!(all(target_os = "linux", target_arch = "aarch64")) {
        Some(("linux", "aarch64", ""))
    } else if cfg!(all(target_os = "macos", target_arch = "x86_64")) {
        Some(("macos", "x86_64", ""))
    } else if cfg!(all(target_os = "macos", target_arch = "aarch64")) {
        Some(("macos", "aarch64", ""))
    } else if cfg!(all(target_os = "windows", target_arch = "x86_64")) {
        Some(("windows", "x86_64", ".exe"))
    } else {
        None
    }
}
/// Builds the release asset name for the current build target:
/// `{prefix}-{os}-{arch}{variant_suffix}{ext}`.
///
/// Returns `None` when the target has no platform triple.
pub(super) fn asset_name(prefix: &str, variant_suffix: &str) -> Option<String> {
    platform_triple()
        .map(|(os, arch, ext)| format!("{prefix}-{os}-{arch}{variant_suffix}{ext}"))
}
/// Returns the sibling path the new binary is downloaded to:
/// `<file name>.new.<pid>` in the same directory as `current_exe`.
///
/// Falls back to the base name `objectiveai` when the path has no file name.
fn staged_path(current_exe: &Path) -> PathBuf {
    let base = match current_exe.file_name() {
        Some(name) => name.to_string_lossy().into_owned(),
        None => String::from("objectiveai"),
    };
    // The process id keeps concurrent updaters from sharing a staging file.
    let staged_name = format!("{base}.new.{}", std::process::id());
    current_exe.with_file_name(staged_name)
}
/// Reports whether any component of `current_exe` is one of the known build
/// output directory names, in which case the binary is treated as a
/// development build and not updated.
fn looks_like_dev_tree(current_exe: &Path) -> bool {
    const BUILD_DIRS: [&str; 4] = [
        "target",
        "target-objectiveai-mcp-filesystem",
        "target-objectiveai-mcp-proxy",
        "target-objectiveai-viewer",
    ];
    current_exe
        .components()
        .map(|component| component.as_os_str())
        .any(|name| BUILD_DIRS.iter().any(|dir| name == *dir))
}
fn user_agent(config: &UpdaterConfig) -> String {
format!("{}/{}", config.asset_prefix, config.current_version)
}
/// Returns the path of the last-check marker for this binary:
/// `updated-<asset_prefix>.txt` inside the filesystem client's base directory.
fn marker_path(asset_prefix: &str) -> Result<PathBuf, Error> {
    let file_name = format!("updated-{asset_prefix}.txt");
    let path = fs_client().base_dir().join(file_name);
    Ok(path)
}
fn fs_client() -> crate::filesystem::Client {
crate::filesystem::Client::new(None::<String>, None::<String>, None::<String>)
}
/// Resolves the `Authorization` header value to send to GitHub, if any.
///
/// The caller-supplied value wins when it is non-blank; otherwise the value
/// stored in the config read through the filesystem client is used (a config
/// read failure is treated as "no value").
///
/// The chosen value is normalized to `Bearer <token>`: surrounding
/// whitespace and an existing `Bearer` scheme prefix are stripped first.
/// Returns `None` when no token remains, so a blank value or a bare
/// `Bearer` never produces a header without a token.
async fn github_authorization(caller_supplied: Option<&str>) -> Option<String> {
    let raw = match caller_supplied
        .map(str::trim)
        .filter(|s| !s.is_empty())
    {
        Some(s) => Some(s.to_string()),
        None => {
            let client = fs_client();
            match client.read_config().await {
                Ok(mut config) => config
                    .api()
                    .headers()
                    .get_x_github_authorization()
                    .map(|s| s.to_string()),
                Err(_) => None,
            }
        }
    };
    raw.and_then(|value| {
        let value = value.trim();
        // Only treat `Bearer` as a scheme when it stands alone or is followed
        // by whitespace; a token that merely starts with those letters is
        // kept intact.
        let token = match value.strip_prefix("Bearer") {
            Some(rest) if rest.is_empty() || rest.starts_with(char::is_whitespace) => {
                rest.trim_start()
            }
            _ => value,
        };
        (!token.is_empty()).then(|| format!("Bearer {token}"))
    })
}
/// Reports whether an update check is due according to the marker file.
///
/// The marker holds the Unix timestamp (seconds) of the last check. A check
/// is due when the marker is missing, unreadable, or unparsable, when at
/// least `UPDATE_CHECK_INTERVAL` has passed, or when the recorded timestamp
/// lies in the future. The future case covers a clock that was set back or
/// a corrupted marker, which would otherwise suppress checks until the
/// clock caught up.
fn check_elapsed(marker: &Path) -> bool {
    let Ok(contents) = std::fs::read_to_string(marker) else {
        return true;
    };
    let Ok(ts) = contents.trim().parse::<u64>() else {
        return true;
    };
    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_secs())
        .unwrap_or(0);
    match now.checked_sub(ts) {
        Some(age) => age >= UPDATE_CHECK_INTERVAL.as_secs(),
        // Marker is ahead of the clock: do not trust it.
        None => true,
    }
}
fn write_marker(marker: &Path) -> Result<(), Error> {
if let Some(parent) = marker.parent() {
std::fs::create_dir_all(parent).map_err(Error::WriteMarker)?;
}
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0);
std::fs::write(marker, now.to_string()).map_err(Error::WriteMarker)
}
async fn download_to(
client: &reqwest::Client,
url: &str,
auth: Option<&str>,
dst: &Path,
config: &UpdaterConfig,
) -> Result<(), Error> {
use futures::StreamExt as _;
use tokio::io::AsyncWriteExt as _;
let mut req = client
.get(url)
.header("User-Agent", user_agent(config))
.timeout(DOWNLOAD_TIMEOUT);
if let Some(header) = auth {
req = req.header("Authorization", header);
}
let resp = req
.send()
.await
.map_err(|e| Error::Http(e.to_string()))?;
let status = resp.status();
if !status.is_success() {
return Err(Error::BadStatus(status));
}
let mut file = tokio::fs::File::create(dst).await.map_err(Error::Download)?;
let mut stream = resp.bytes_stream();
while let Some(chunk) = stream.next().await {
let chunk = chunk.map_err(|e| Error::Http(e.to_string()))?;
file.write_all(&chunk).await.map_err(Error::Download)?;
}
file.flush().await.map_err(Error::Download)?;
Ok(())
}
#[cfg(unix)]
fn self_replace(current: &Path, new: &Path) -> Result<(), Error> {
std::fs::rename(new, current).map_err(Error::Swap)
}
#[cfg(windows)]
fn self_replace(current: &Path, new: &Path) -> Result<(), Error> {
let old = current.with_extension("exe.old");
let _ = std::fs::remove_file(&old);
std::fs::rename(current, &old).map_err(Error::Swap)?;
std::fs::rename(new, current).map_err(|e| {
let _ = std::fs::rename(&old, current);
Error::Swap(e)
})
}
#[cfg(not(any(unix, windows)))]
fn self_replace(_current: &Path, _new: &Path) -> Result<(), Error> {
Err(Error::Swap(std::io::Error::new(
std::io::ErrorKind::Unsupported,
"self-replace not implemented on this platform",
)))
}
/// Removes the `*.exe.old` file left next to `current` by a previous
/// Windows self-replace. Does nothing on other platforms; removal errors are
/// ignored.
fn sweep_stale_old(current: &Path) {
    if cfg!(windows) {
        let leftover = current.with_extension("exe.old");
        let _ = std::fs::remove_file(leftover);
    }
}
fn re_exec<I>(current: &Path, args: I) -> Result<(), Error>
where
I: IntoIterator<Item = OsString>,
{
use std::process::{Command, Stdio};
let mut iter = args.into_iter();
let argv0 = iter.next(); let forwarded: Vec<OsString> = iter.collect();
let mut cmd = Command::new(current);
cmd.args(&forwarded)
.stdin(Stdio::inherit())
.stdout(Stdio::inherit())
.stderr(Stdio::inherit())
.env(SKIP_ENV_VAR, "1");
#[cfg(unix)]
if let Some(argv0) = argv0.as_ref() {
use std::os::unix::process::CommandExt as _;
cmd.arg0(argv0);
}
#[cfg(not(unix))]
let _ = argv0;
let status = cmd.status().map_err(Error::ReExec)?;
let code = match status.code() {
Some(c) => c,
None => {
#[cfg(unix)]
{
use std::os::unix::process::ExitStatusExt as _;
status.signal().map(|s| 128 + s).unwrap_or(1)
}
#[cfg(not(unix))]
{
1
}
}
};
std::process::exit(code);
}
}
#[cfg(all(test, feature = "updater"))]
mod tests {
/// On every supported build target, `asset_name` yields a name that starts
/// with the prefix, carries the variant suffix before any `.exe`, and ends
/// in `.exe` on Windows. The test is a no-op on unsupported targets.
#[test]
fn asset_name_resolves_for_current_target_and_prefix() {
    let supported_target = cfg!(any(
        all(target_os = "linux", target_arch = "x86_64"),
        all(target_os = "linux", target_arch = "aarch64"),
        all(target_os = "macos", target_arch = "x86_64"),
        all(target_os = "macos", target_arch = "aarch64"),
        all(target_os = "windows", target_arch = "x86_64"),
    ));
    if !supported_target {
        return;
    }

    let cases = [
        ("objectiveai", ""),
        ("objectiveai", "-no-viewer"),
        ("objectiveai-api", ""),
        ("objectiveai-mcp", ""),
        ("objectiveai-viewer", ""),
    ];
    for (prefix, variant) in cases {
        let Some(name) = super::imp::asset_name(prefix, variant) else {
            panic!("asset_name({prefix:?}, {variant:?}) is None");
        };
        assert!(name.starts_with(&format!("{prefix}-")));
        if !variant.is_empty() {
            // The variant sits before the executable extension, if any.
            let stem = name.strip_suffix(".exe").unwrap_or(&name);
            assert!(
                stem.ends_with(variant),
                "expected {stem} to end with {variant}"
            );
        }
        if cfg!(target_os = "windows") {
            assert!(name.ends_with(".exe"));
        }
    }
}
/// An update is needed only when the remote tag (with an optional leading
/// `v`) is strictly newer than the local version.
#[test]
fn version_ordering() {
    // (remote tag, local version, update expected)
    let cases = [
        ("v2.0.1", "2.0.0", true),
        ("2.0.1", "2.0.0", true),
        ("v2.0.0", "2.0.0", false),
        ("v2.0.0", "2.1.0", false),
        ("v3.0.0", "2.99.99", true),
    ];
    for (remote, local, expected) in cases {
        let remote_bare = remote.strip_prefix('v').unwrap_or(remote);
        let remote_version = semver::Version::parse(remote_bare).unwrap();
        let local_version = semver::Version::parse(local).unwrap();
        assert_eq!(remote_version > local_version, expected);
    }
}
/// Each `Updater` variant serializes with its snake_case name under the
/// `event` key.
#[test]
fn updater_notifications_serialize_with_event_tag() {
    use crate::cli::output::notification::{SkipReason, Updater};

    let current = || String::from("2.0.7");
    let newer = || String::from("2.0.8");
    let asset = || String::from("objectiveai-linux-x86_64");

    let cases = [
        (
            "checking",
            Updater::Checking {
                asset_name: asset(),
                current_version: current(),
            },
        ),
        (
            "up_to_date",
            Updater::UpToDate {
                current_version: current(),
                remote_version: current(),
            },
        ),
        (
            "found",
            Updater::Found {
                current_version: current(),
                remote_version: newer(),
                asset_name: asset(),
                url: String::from("https://example.invalid/asset"),
            },
        ),
        (
            "installed",
            Updater::Installed {
                current_version: current(),
                remote_version: newer(),
            },
        ),
        (
            "skipped",
            Updater::Skipped {
                reason: SkipReason::UnsupportedPlatform,
            },
        ),
    ];
    for (expected_tag, value) in &cases {
        let json = serde_json::to_value(value).unwrap();
        assert_eq!(json["event"], serde_json::json!(expected_tag));
    }
}
/// `Updater::Skipped` serializes as event `skipped` with the reason in
/// snake_case under the `reason` key.
#[test]
fn updater_skipped_variant_serialization() {
    use crate::cli::output::notification::{SkipReason, Updater};

    let cases = [
        (SkipReason::UnsupportedPlatform, "unsupported_platform"),
        (SkipReason::OptedOut, "opted_out"),
        (SkipReason::DevTree, "dev_tree"),
    ];
    for (reason, wire) in cases {
        let skipped = Updater::Skipped { reason };
        let json = serde_json::to_value(&skipped).unwrap();
        assert_eq!(json["event"], serde_json::json!("skipped"));
        assert_eq!(json["reason"], serde_json::json!(wire));
    }
}
}