use std::path::{Path, PathBuf};
use std::time::Duration;
use anyhow::{Context, Result};
use async_nats::jetstream;
use base64::Engine as _;
use kanade_shared::kv::OBJECT_AGENT_RELEASES;
use kanade_shared::wire::EffectiveConfig;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use tokio::io::AsyncWriteExt;
use tokio::sync::watch;
use tracing::{error, info, warn};
/// Record of the most recent self-update swap, persisted as `last_swap.json`.
///
/// It is written right after a successful binary swap (just before the process exits). The
/// next start-up reads it to report a landed update, or to spot a swap that came back up on
/// the version it started from.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LastSwap {
    /// Version label the agent swapped to.
    target: String,
    /// Version that was running when the swap happened.
    running_before: String,
}
/// Self-update supervisor task: keeps the agent binary in line with the configured
/// `target_version`.
///
/// Start-up sequence:
/// 1. Obtain the agent-releases object store via `wait_for_object_store`.
/// 2. Read `last_swap.json` left by a previous swap. If that swap's target is now the running
///    version (and differs from the version it replaced), queue an `agent_update` event and
///    clear the record.
/// 3. Evaluate the current `target_version` once:
///    - skip it if quarantined;
///    - refuse it if the last swap to it came back on the same running version (a loop);
///    - otherwise sleep a random jitter and attempt the swap.
///
///    With no target, or a target equal to the running version, any leftover record is
///    cleared.
///
/// Afterwards every change to `target_version` goes through the same quarantine / jitter /
/// swap path. A successful swap exits the process (see `swap_and_restart`). The function only
/// returns when `cfg_rx.changed()` reports an error (the config sender was dropped).
pub async fn run(
    client: async_nats::Client,
    pc_id: String,
    running_version: String,
    mut cfg_rx: watch::Receiver<EffectiveConfig>,
    tracker: crate::staleness::Tracker,
) {
    let js = jetstream::new(client.clone());
    // NOTE(review): presumably retries until the object store is reachable (helper not shown
    // here); `tracker` and the "self_update" label are passed through to it.
    let store = crate::nats_retry::wait_for_object_store(
        &js,
        &client,
        &tracker,
        OBJECT_AGENT_RELEASES,
        "self_update",
    )
    .await;
    // Record written by `swap_and_restart` on a previous run, if any.
    let last_swap = read_last_swap();
    if let Some(prev) = &last_swap {
        info!(?prev, "recovered last_swap.json from prior cycle");
        // The previous swap landed: we now run its target, and it really is a different
        // version from before. Report it once and drop the record.
        if prev.target == running_version && prev.running_before != running_version {
            emit_update_event(&pc_id, &prev.running_before, &running_version);
            clear_last_swap();
        }
    }
    // Snapshot the config inside a block so the watch borrow guard is dropped before any await.
    let (mut current_target, jitter) = {
        let cfg = cfg_rx.borrow();
        (
            cfg.target_version.clone(),
            cfg.target_version_jitter_duration(),
        )
    };
    // Target refused at start-up because swapping to it reproduced the running version; it
    // stays set until `target_version` changes.
    let mut loop_blocked_target: Option<String> = None;
    if let Some(target) = current_target.as_deref()
        && target != running_version
    {
        if is_quarantined(target) {
            warn!(
                target,
                "self-update: target is quarantined (it crash-looped on a prior boot and was \
                 rolled back). Refusing to re-deploy it — this is what stops a bad rollout from \
                 looping rollout↔rollback. Republish a fixed binary under a new version, or clear \
                 the quarantine.",
            );
        } else if is_loop(&last_swap, target, &running_version) {
            loop_blocked_target = Some(target.to_string());
            warn!(
                target,
                running = %running_version,
                "self-update LOOP detected — previous swap to this target produced the same running_version. \
                 Refusing to swap again. The binary under this label has a label/version mismatch; \
                 republish it or clear target_version (`kanade config unset target_version`)."
            );
        } else {
            // Random delay bounded by the configured jitter before downloading (presumably to
            // stagger fleet-wide rollouts).
            sleep_jitter(jitter).await;
            if let Err(e) = attempt_swap(&store, target, &running_version).await {
                warn!(error = %e, target, "initial self-update fetch failed");
            }
        }
    } else if last_swap.is_some() {
        // No target, or already running it: any leftover swap record is stale.
        clear_last_swap();
    }
    loop {
        // An error means the config sender is gone; end the task.
        if cfg_rx.changed().await.is_err() {
            return;
        }
        let (new_target, jitter) = {
            let cfg = cfg_rx.borrow();
            (
                cfg.target_version.clone(),
                cfg.target_version_jitter_duration(),
            )
        };
        // Only a change of `target_version` itself matters. Re-publishing the same target
        // (including one whose swap attempt failed) is ignored until the value changes again.
        if new_target == current_target {
            continue;
        }
        current_target = new_target.clone();
        // Moving `target_version` away from the loop-blocked target lifts the block and
        // forgets the old swap record.
        if loop_blocked_target.is_some() && loop_blocked_target.as_deref() != new_target.as_deref()
        {
            info!("target_version changed; clearing loop block");
            loop_blocked_target = None;
            clear_last_swap();
        }
        if let Some(target) = new_target.as_deref()
            && target != running_version
        {
            // NOTE(review): `loop_blocked_target` is only set at start-up, and it is cleared
            // above on the first change away from it. So this looks unreachable as written;
            // it is kept as a guard.
            if loop_blocked_target.as_deref() == Some(target) {
                warn!(target, "still loop-blocked on this target; ignoring");
                continue;
            }
            if is_quarantined(target) {
                warn!(
                    target,
                    "self-update: target is quarantined (crash-looped on a prior boot); refusing \
                     to re-deploy. Republish a fixed version or clear the quarantine.",
                );
                continue;
            }
            // NOTE(review): config changes arriving during the jitter sleep or the retrying
            // download are only observed after `attempt_swap` returns.
            sleep_jitter(jitter).await;
            if let Err(e) = attempt_swap(&store, target, &running_version).await {
                warn!(error = %e, target, "self-update fetch failed");
            }
        }
    }
}
fn emit_update_event(pc_id: &str, from: &str, to: &str) {
let event = kanade_shared::wire::ObsEvent {
pc_id: pc_id.to_string(),
at: chrono::Utc::now(),
kind: "agent_update".to_string(),
source: "agent:self_update".to_string(),
event_record_id: Some(format!("self_update_{}", uuid::Uuid::new_v4().as_simple())),
payload: serde_json::json!({ "from": from, "to": to }),
};
let dir = kanade_shared::default_paths::data_dir().join("obs-outbox");
let res = crate::obs_outbox::ensure_outbox_dir(&dir)
.and_then(|()| crate::obs_outbox::enqueue(&dir, &event).map(|_| ()));
match res {
Ok(()) => info!(from, to, "queued agent_update obs event"),
Err(e) => warn!(error = %e, from, to, "failed to queue agent_update obs event"),
}
}
/// Returns `true` when `last` records a swap to `target` that was made while `running` was
/// already the running version. In other words, the earlier swap to `target` came back up on
/// the same version, so swapping again would loop.
fn is_loop(last: &Option<LastSwap>, target: &str, running: &str) -> bool {
    matches!(
        last,
        Some(prev) if prev.target == target && prev.running_before == running
    )
}
/// Asks the boot sentinel whether `target` has been quarantined.
///
/// Returns `false` when the current executable path cannot be resolved. The sentinel is
/// built with this binary's compile-time package version.
fn is_quarantined(target: &str) -> bool {
    use kanade_shared::boot_sentinel::BootSentinel;
    match std::env::current_exe() {
        Ok(exe_path) => {
            let data_dir = kanade_shared::default_paths::data_dir();
            let sentinel = BootSentinel::new(&data_dir, exe_path, env!("CARGO_PKG_VERSION"));
            sentinel.is_quarantined(target)
        }
        Err(_) => false,
    }
}
fn last_swap_path() -> Option<PathBuf> {
use kanade_shared::default_paths;
Some(default_paths::data_dir().join("last_swap.json"))
}
/// Loads the persisted [`LastSwap`] record.
///
/// Returns `None` when the file is absent, unreadable, or not valid JSON for the record.
fn read_last_swap() -> Option<LastSwap> {
    last_swap_path()
        .and_then(|path| std::fs::read(path).ok())
        .and_then(|raw| serde_json::from_slice(&raw).ok())
}
fn write_last_swap(target: &str, running_before: &str) {
let Some(path) = last_swap_path() else {
return;
};
if let Some(parent) = path.parent() {
let _ = std::fs::create_dir_all(parent);
}
let payload = LastSwap {
target: target.to_string(),
running_before: running_before.to_string(),
};
match serde_json::to_vec(&payload) {
Ok(b) => {
if let Err(e) = std::fs::write(&path, b) {
warn!(error = %e, ?path, "write last_swap.json");
}
}
Err(e) => warn!(error = %e, "encode last_swap.json"),
}
}
/// Deletes `last_swap.json`. Any removal error, including "not found", is ignored.
fn clear_last_swap() {
    let Some(path) = last_swap_path() else {
        return;
    };
    let _ = std::fs::remove_file(path);
}
/// Runs [`maybe_download`] up to three times, sleeping 15 s and then 45 s between failures.
///
/// Returns `Ok(())` on the first success. A completed swap exits the process, so in practice
/// `Ok` comes back only when `target == running`. After the third failure the last error is
/// returned.
async fn attempt_swap(
    store: &jetstream::object_store::ObjectStore,
    target: &str,
    running: &str,
) -> Result<()> {
    const ATTEMPTS: u32 = 3;
    let mut backoff = Duration::from_secs(15);
    let mut attempt: u32 = 1;
    loop {
        let err = match maybe_download(store, target, running).await {
            Ok(()) => return Ok(()),
            Err(err) => err,
        };
        warn!(
            attempt,
            max_attempts = ATTEMPTS,
            target,
            error = ?err,
            "self-update download attempt failed",
        );
        if attempt >= ATTEMPTS {
            return Err(err);
        }
        tokio::time::sleep(backoff).await;
        backoff *= 3;
        attempt += 1;
    }
}
/// Sleeps a random whole number of seconds in `0..=max.as_secs()`, logging the choice first.
///
/// - A zero `max` returns immediately without logging.
/// - Sub-second parts of `max` are dropped, so a non-zero `max` under one second logs and
///   sleeps 0 s.
async fn sleep_jitter(max: Duration) {
    if max.is_zero() {
        return;
    }
    let ceiling = max.as_secs();
    let chosen = match ceiling {
        0 => 0,
        upper => {
            use rand::RngExt;
            rand::rng().random_range(0..=upper)
        }
    };
    info!(
        jitter_max_secs = ceiling,
        sleep_secs = chosen,
        "self-update jitter — pausing before download"
    );
    tokio::time::sleep(Duration::from_secs(chosen)).await;
}
async fn maybe_download(
store: &jetstream::object_store::ObjectStore,
target: &str,
running: &str,
) -> Result<()> {
if target == running {
info!(target, "target_version matches running — no self-update");
return Ok(());
}
info!(
target,
running, "target_version drift — downloading new binary"
);
let mut object = store
.get(target)
.await
.with_context(|| format!("object store get '{target}'"))?;
let staging = staging_path(target)?;
if let Some(parent) = staging.parent() {
tokio::fs::create_dir_all(parent).await.ok();
}
let mut file = tokio::fs::File::create(&staging)
.await
.with_context(|| format!("create {staging:?}"))?;
let mut hasher = Sha256::new();
let mut buf = [0u8; 64 * 1024];
let mut total: u64 = 0;
loop {
let n = tokio::io::AsyncReadExt::read(&mut object, &mut buf)
.await
.context("read object chunk")?;
if n == 0 {
break;
}
file.write_all(&buf[..n])
.await
.context("write staged exe")?;
hasher.update(&buf[..n]);
total += n as u64;
}
file.sync_all().await.context("sync staged exe")?;
drop(file);
let digest = hasher.finalize();
if let Some(expected) = object.info.digest.as_deref() {
if !digest_matches(expected, digest.as_slice()) {
let _ = tokio::fs::remove_file(&staging).await;
let actual = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest);
anyhow::bail!(
"staged binary digest mismatch for '{target}': object store records {expected}, downloaded bytes hash to SHA-256={actual} — discarding staged file"
);
}
} else {
warn!(
target,
"object store entry carries no digest; proceeding without verification"
);
}
info!(
target,
path = ?staging,
bytes = total,
sha256 = %hex(&digest),
"staged new agent binary (digest verified) — beginning atomic swap",
);
swap_and_restart(&staging, target, running).await?;
Ok(())
}
/// Replaces the running executable with `staged`, records the swap, arms the boot sentinel,
/// and exits with code 64.
///
/// Steps:
/// 1. Copy `staged` to `<exe>.new`.
/// 2. Rename the live exe to `<exe>.old`.
/// 3. Rename `<exe>.new` into the live path. If this fails, `<exe>.old` is renamed back.
///
/// After a successful swap, `last_swap.json` is written and the sentinel is armed with
/// `<exe>.old` as the rollback image.
///
/// # Errors
/// Any copy/rename failure before the new binary is in place. On success this never returns:
/// it sleeps 250 ms and calls `std::process::exit(64)`.
async fn swap_and_restart(staged: &Path, target_version: &str, running: &str) -> Result<()> {
    let current = std::env::current_exe().context("current_exe")?;
    let exe_dir = current
        .parent()
        .context("current_exe has no parent directory")?;
    let exe_name = current
        .file_name()
        .and_then(|s| s.to_str())
        .context("current_exe has no UTF-8 file name")?
        .to_string();
    let new_path = exe_dir.join(format!("{exe_name}.new"));
    let old_path = exe_dir.join(format!("{exe_name}.old"));
    // Must run before the stale-file cleanup below. If an earlier attempt (retried by
    // `attempt_swap`) left the live path empty, `.old` is the only working binary and
    // deleting it would be unrecoverable.
    restore_orphaned_backup(&current, &old_path).await?;
    // Leftovers from an earlier swap.
    let _ = tokio::fs::remove_file(&new_path).await;
    let _ = tokio::fs::remove_file(&old_path).await;
    tokio::fs::copy(staged, &new_path)
        .await
        .with_context(|| format!("copy {staged:?} -> {new_path:?}"))?;
    // Rename rather than overwrite the live exe: presumably because a running image can be
    // renamed but not replaced in place (Windows).
    tokio::fs::rename(&current, &old_path)
        .await
        .with_context(|| format!("rename {current:?} -> {old_path:?}"))?;
    if let Err(e) = tokio::fs::rename(&new_path, &current).await {
        match tokio::fs::rename(&old_path, &current).await {
            Ok(()) => warn!(
                error = %e,
                "second rename failed; rolled the original exe back into place",
            ),
            Err(restore_err) => error!(
                error = %e,
                restore_error = %restore_err,
                exe = ?current,
                backup = ?old_path,
                "second rename failed AND rollback failed — service binary path is empty; \
                 manual repair required (rename the .old file back)",
            ),
        }
        return Err(e).with_context(|| format!("rename {new_path:?} -> {current:?}"));
    }
    write_last_swap(target_version, running);
    {
        use kanade_shared::boot_sentinel::BootSentinel;
        let sentinel = BootSentinel::new(
            &kanade_shared::default_paths::data_dir(),
            current.clone(),
            running,
        );
        if let Err(e) = sentinel.arm_for_swap(&old_path, target_version) {
            warn!(
                error = %e, target = target_version,
                "boot sentinel: arm_for_swap failed — crash-loop rollback disabled for this swap",
            );
        }
    }
    info!(
        target = target_version,
        replaced = ?current,
        backup = ?old_path,
        "swap complete — exiting (code 64); SCM failure-actions take over",
    );
    // Brief pause before exiting (presumably so buffered log output gets flushed).
    tokio::time::sleep(std::time::Duration::from_millis(250)).await;
    std::process::exit(64);
}

/// Moves `<exe>.old` back to the live path when an earlier swap attempt left the live path
/// empty (its final rename and its rollback both failed).
///
/// Does nothing unless the live path reports `NotFound` and the backup exists. If the restore
/// rename fails, an error is returned and the backup is left untouched.
async fn restore_orphaned_backup(current: &Path, old_path: &Path) -> Result<()> {
    let live_missing = matches!(
        tokio::fs::metadata(current).await,
        Err(e) if e.kind() == std::io::ErrorKind::NotFound
    );
    if !live_missing || tokio::fs::metadata(old_path).await.is_err() {
        return Ok(());
    }
    tokio::fs::rename(old_path, current)
        .await
        .with_context(|| format!("restore orphaned backup {old_path:?} -> {current:?}"))?;
    warn!(
        exe = ?current,
        backup = ?old_path,
        "live exe path was empty after an earlier failed swap; restored the backup",
    );
    Ok(())
}
fn staging_path(version: &str) -> Result<PathBuf> {
use kanade_shared::default_paths;
let exe = std::env::current_exe().context("current_exe")?;
let stem = exe
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("kanade-agent")
.to_string();
Ok(default_paths::data_dir()
.join("staging")
.join(format!("{stem}.{version}.staged")))
}
/// Encodes `bytes` as lowercase hexadecimal, two characters per byte.
fn hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut encoded = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        encoded.push(DIGITS[usize::from(byte >> 4)] as char);
        encoded.push(DIGITS[usize::from(byte & 0x0f)] as char);
    }
    encoded
}
/// Checks an object-store digest string of the form `SHA-256=<base64>` against `actual`.
///
/// - The `SHA-256=` tag matches in any ASCII case.
/// - The base64 body may be padded or unpadded, in the URL-safe or the standard alphabet.
/// - Anything else is a mismatch: a missing or different tag, an undecodable body, or
///   different bytes.
fn digest_matches(expected: &str, actual: &[u8]) -> bool {
    use base64::engine::general_purpose::{STANDARD_NO_PAD, URL_SAFE_NO_PAD};
    const PREFIX: &str = "SHA-256=";
    // `split_at_checked` returns None instead of panicking when the tag boundary would cut a
    // multi-byte character or the string is too short.
    let Some((tag, b64)) = expected.split_at_checked(PREFIX.len()) else {
        return false;
    };
    if !tag.eq_ignore_ascii_case(PREFIX) {
        return false;
    }
    // The NO_PAD engines reject padding, so strip it and accept both forms.
    let payload = b64.trim_end_matches('=');
    URL_SAFE_NO_PAD
        .decode(payload)
        .or_else(|_| STANDARD_NO_PAD.decode(payload))
        .is_ok_and(|decoded| decoded == actual)
}
#[cfg(test)]
mod tests {
    use super::*;
    use base64::engine::general_purpose::{STANDARD, URL_SAFE, URL_SAFE_NO_PAD};

    /// Fixed 32-byte fixture used as the expected SHA-256 digest in every case.
    const DIGEST: [u8; 32] = [
        0x21, 0x3e, 0x9b, 0xbd, 0xfc, 0x8e, 0x5c, 0x44, 0x6d, 0x51, 0x44, 0x24, 0xd0, 0xfe, 0xd3,
        0x98, 0x63, 0x24, 0xd7, 0xa0, 0xaa, 0x9e, 0x9a, 0x0c, 0xf8, 0x68, 0x71, 0x91, 0x1a, 0xc4,
        0xd2, 0x1f,
    ];

    /// Joins an algorithm tag and an encoded body into a recorded-digest string.
    fn recorded(tag: &str, body: String) -> String {
        format!("{tag}{body}")
    }

    #[test]
    fn matches_padded_url_safe_digest() {
        let value = recorded("SHA-256=", URL_SAFE.encode(DIGEST));
        assert!(value.ends_with('='), "fixture must carry padding");
        assert!(digest_matches(&value, &DIGEST));
    }

    #[test]
    fn matches_unpadded_and_standard_alphabet() {
        let cases = [
            recorded("SHA-256=", URL_SAFE_NO_PAD.encode(DIGEST)),
            recorded("SHA-256=", STANDARD.encode(DIGEST)),
        ];
        for value in &cases {
            assert!(digest_matches(value, &DIGEST));
        }
    }

    #[test]
    fn prefix_is_case_insensitive() {
        let value = recorded("sha-256=", URL_SAFE.encode(DIGEST));
        assert!(digest_matches(&value, &DIGEST));
    }

    #[test]
    fn rejects_wrong_bytes_missing_prefix_and_garbage() {
        let mut flipped = DIGEST;
        flipped[0] ^= 0xff;
        let good = recorded("SHA-256=", URL_SAFE.encode(DIGEST));
        assert!(!digest_matches(&good, &flipped));
        assert!(!digest_matches(&URL_SAFE.encode(DIGEST), &DIGEST));
        assert!(!digest_matches("SHA-256=not*valid*base64", &DIGEST));
    }
}