use std::path::Path;
use crate::cli::commands::doctor_fix::{FixAction, FixKind, Journal};
use crate::cli::commands::lifecycle;
use crate::cli::output::{OutputConfig, OutputFormat};
use crate::cli::DoctorArgs;
use crate::config;
use crate::error::OlError;
use crate::hooks::jsonc;
use crate::telemetry::{self, Event};
/// Result of attempting to reverse one journaled `--fix` action.
///
/// The `Debug` form of this enum is what `doctor --restore` emits in the
/// `"outcome"` field of its JSON report, so variant names are user-visible.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RestoreOutcome {
    /// The live file was rolled back from its backup.
    Restored,
    /// The action cannot be undone (no backup recorded, marked
    /// non-reversible, or a kind with no on-disk state).
    NotReversible,
    /// A backup path was recorded but the file no longer exists.
    BackupMissing,
    /// The restore was attempted and failed; carries a human-readable reason.
    Failed(String),
}
/// Pairs a journaled fix action with the outcome of trying to reverse it.
#[derive(Clone, Debug)]
pub struct RestoreResult {
    /// The journal entry that was processed (cloned from the journal).
    pub action: FixAction,
    /// What happened when reversing `action`.
    pub outcome: RestoreOutcome,
}
pub fn run(_args: &DoctorArgs, output: &OutputConfig) -> Result<(), OlError> {
let started = std::time::Instant::now();
let ol_dir = config::openlatch_dir();
crate::cli::header::print(output, &["doctor", "--restore"]);
if output.format == OutputFormat::Human && !output.quiet {
eprintln!();
}
let journal = Journal::load(&ol_dir)?;
if journal.actions.is_empty() {
if output.format == OutputFormat::Json {
output.print_json(&serde_json::json!({
"command": "doctor_restore",
"actions_reversed": 0,
"results": [],
"message": "no actions in last --fix run; nothing to restore",
}));
} else if !output.quiet {
eprintln!("No actions in the last --fix run; nothing to restore.");
}
return Ok(());
}
let port = config::Config::load(None, None, false)
.map(|c| c.port)
.unwrap_or(config::PORT_RANGE_START);
let pid = lifecycle::read_pid_file();
let daemon_was_running = pid.map(lifecycle::is_process_alive).unwrap_or(false);
if daemon_was_running {
stop_daemon(port, &ol_dir);
}
let results = restore_actions(&journal.actions, &ol_dir);
let mut daemon_restart_required = daemon_was_running;
if daemon_was_running {
let token = std::fs::read_to_string(ol_dir.join("daemon.token"))
.map(|s| s.trim().to_string())
.unwrap_or_default();
if !token.is_empty() {
if let Ok(_pid) = lifecycle::spawn_daemon_background(port, &token) {
if !lifecycle::wait_for_health(port, 3) {
daemon_restart_required = false; tracing::warn!(
"doctor --restore: daemon spawned but /health did not return 200 within 3s"
);
}
}
}
}
let restored = results
.iter()
.filter(|r| matches!(r.outcome, RestoreOutcome::Restored))
.count();
let skipped = results.len() - restored;
telemetry::capture_global(Event::doctor_restore_run(
restored,
skipped,
daemon_restart_required,
started.elapsed().as_millis() as u64,
));
print_restore_results(&results, daemon_restart_required, output);
Ok(())
}
/// Reverses `actions` newest-first and reports one result per action.
///
/// Actions are undone in the opposite order to how they were journaled, so a
/// later fix is rolled back before an earlier one that may touch the same
/// file. The returned vector follows that same reversed order. `_ol_dir` is
/// currently unused.
pub fn restore_actions(actions: &[FixAction], _ol_dir: &Path) -> Vec<RestoreResult> {
    actions
        .iter()
        .rev()
        .map(|entry| RestoreResult {
            // Field initialisers run in written order: restore first, then clone.
            outcome: restore_one(entry),
            action: entry.clone(),
        })
        .collect()
}
/// Reverses a single journaled fix action.
///
/// Returns:
/// - `NotReversible` — the kind has no on-disk state to roll back
///   (`DaemonRestart`, `PidStaleRemove`, `SupervisionInstall`), the journal
///   marked the action non-reversible, or no backup path was recorded.
/// - `BackupMissing` — a backup path was recorded but that file is gone.
/// - `Restored` / `Failed(reason)` — the outcome of the actual restore.
fn restore_one(action: &FixAction) -> RestoreOutcome {
    // How a file-backed action is rolled back.
    enum Strategy {
        // Splice only OpenLatch-owned hook entries from the backup into the
        // live settings file, keeping hooks the user added since.
        SurgicalHooks,
        // Overwrite the live file with the backup verbatim.
        BlindCopy,
    }

    // Classify by kind first. Kinds with nothing on disk to undo are never
    // reversible, so an inconsistent journal entry (reversible flag or a
    // stale backup path) must not be reported as "backup missing".
    let strategy = match action.kind {
        FixKind::HookReinstall => Strategy::SurgicalHooks,
        FixKind::ConfigRewrite
        | FixKind::TokenRegenerate
        | FixKind::AgentIdInsert
        | FixKind::TelemetryReset
        | FixKind::BinaryCopy => Strategy::BlindCopy,
        FixKind::DaemonRestart | FixKind::PidStaleRemove | FixKind::SupervisionInstall => {
            return RestoreOutcome::NotReversible;
        }
    };

    if !action.reversible {
        return RestoreOutcome::NotReversible;
    }
    let Some(backup) = action.backup.as_ref() else {
        return RestoreOutcome::NotReversible;
    };
    if !backup.exists() {
        return RestoreOutcome::BackupMissing;
    }

    match strategy {
        Strategy::SurgicalHooks => match restore_hooks_surgical(&action.file, backup) {
            Ok(()) => RestoreOutcome::Restored,
            Err(e) => RestoreOutcome::Failed(format!("{}: {}", e.code, e.message)),
        },
        Strategy::BlindCopy => match std::fs::copy(backup, &action.file) {
            Ok(_) => RestoreOutcome::Restored,
            Err(e) => RestoreOutcome::Failed(e.to_string()),
        },
    }
}
/// Restores OpenLatch's hook entries in a settings file from a backup without
/// discarding hooks the user added after the backup was taken.
///
/// Steps:
/// 1. `jsonc::remove_owned_entries` strips the OpenLatch-owned entries from
///    the live text. This is presumably the entries marked
///    `"_openlatch": true`, the same marker `extract_owned_hook_entries`
///    uses; verify against the `jsonc` helper.
/// 2. The marked entries found in the backup are inserted back in.
///
/// If the live file does not exist, it is treated as `{}`.
///
/// # Errors
/// Returns `ERR_HOOK_WRITE_FAILED` if:
/// - the live file exists but cannot be read,
/// - the backup cannot be read, or
/// - the result cannot be written.
///
/// Errors from the `jsonc` parse/edit helpers are propagated unchanged.
pub fn restore_hooks_surgical(live_path: &Path, bak_path: &Path) -> Result<(), OlError> {
    // Only a genuinely absent file may be treated as empty. Any other read
    // failure (permissions, non-UTF-8 content) must abort. Otherwise the write
    // below would replace the user's entire settings file with just our
    // hook entries.
    let live_raw = match std::fs::read_to_string(live_path) {
        Ok(raw) => raw,
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => "{}".to_string(),
        Err(e) => {
            return Err(OlError::new(
                crate::error::ERR_HOOK_WRITE_FAILED,
                format!(
                    "cannot read settings.json '{}' for hook restore: {e}",
                    live_path.display()
                ),
            ));
        }
    };
    let bak_raw = std::fs::read_to_string(bak_path).map_err(|e| {
        OlError::new(
            crate::error::ERR_HOOK_WRITE_FAILED,
            format!(
                "cannot read backup '{}' for hook restore: {e}",
                bak_path.display()
            ),
        )
    })?;

    let bak_value = jsonc::parse_settings_value(&bak_raw)?;
    let owned_entries = extract_owned_hook_entries(&bak_value);
    let live_no_owned = jsonc::remove_owned_entries(&live_raw)?;
    let final_text = if owned_entries.is_empty() {
        live_no_owned
    } else {
        let (text, _) = jsonc::insert_hook_entries(&live_no_owned, &owned_entries)?;
        text
    };

    std::fs::write(live_path, final_text).map_err(|e| {
        OlError::new(
            crate::error::ERR_HOOK_WRITE_FAILED,
            format!(
                "cannot write restored settings.json '{}': {e}",
                live_path.display()
            ),
        )
    })?;
    Ok(())
}
/// Collects every hook entry marked `"_openlatch": true` from a parsed
/// settings document.
///
/// Looks under the top-level `"hooks"` object. For each event type whose
/// value is an array, each marked entry is returned as `(event_type, entry)`.
/// Event types are visited in the object's iteration order and entries in
/// array order. Returns an empty vector when `"hooks"` is absent or is not
/// an object. Entries where the marker is missing, `false`, or not a boolean
/// are skipped.
fn extract_owned_hook_entries(settings: &serde_json::Value) -> Vec<(String, serde_json::Value)> {
    settings
        .get("hooks")
        .and_then(serde_json::Value::as_object)
        .map(|hooks| {
            hooks
                .iter()
                .filter_map(|(event, value)| value.as_array().map(|list| (event, list)))
                .flat_map(|(event, list)| {
                    list.iter()
                        .filter(|entry| {
                            entry.get("_openlatch").and_then(serde_json::Value::as_bool)
                                == Some(true)
                        })
                        .map(move |entry| (event.clone(), entry.clone()))
                })
                .collect()
        })
        .unwrap_or_default()
}
/// Stops the daemon whose PID is returned by `lifecycle::read_pid_file`.
///
/// Does nothing if there is no PID. Otherwise:
/// 1. If `<ol_dir>/daemon.token` holds a non-empty token, calls
///    `lifecycle::send_shutdown_request` and polls every 100ms for up to 5s
///    for the process to exit.
/// 2. If the process is still alive, calls `lifecycle::force_kill` and polls
///    for up to 3s more.
/// 3. Removes `<ol_dir>/daemon.pid`.
///
/// Every step is best-effort: errors from the shutdown request and from
/// removing the PID file are ignored.
fn stop_daemon(port: u16, ol_dir: &Path) {
    let pid = match lifecycle::read_pid_file() {
        Some(pid) => pid,
        None => return,
    };

    // Poll every 100ms until the process is gone or `limit` has elapsed.
    let wait_for_exit = |limit: std::time::Duration| {
        let deadline = std::time::Instant::now() + limit;
        while std::time::Instant::now() < deadline && lifecycle::is_process_alive(pid) {
            std::thread::sleep(std::time::Duration::from_millis(100));
        }
    };

    let token = std::fs::read_to_string(ol_dir.join("daemon.token"))
        .map(|raw| raw.trim().to_string())
        .unwrap_or_default();
    if !token.is_empty() {
        // Ignored on purpose: if the graceful request fails, we escalate below.
        let _ = lifecycle::send_shutdown_request(port, &token);
        wait_for_exit(std::time::Duration::from_secs(5));
    }

    if lifecycle::is_process_alive(pid) {
        lifecycle::force_kill(pid);
        wait_for_exit(std::time::Duration::from_secs(3));
    }

    let _ = std::fs::remove_file(ol_dir.join("daemon.pid"));
}
/// Renders the restore report.
///
/// In JSON mode it always prints a summary object via `output.print_json`,
/// even when `quiet` is set. The object contains the restored/skipped counts,
/// `daemon_restart_required`, and one `{kind, file, outcome}` entry per
/// result, where `outcome` is the `Debug` form of `RestoreOutcome`.
///
/// In other modes it prints nothing when `quiet` is set. Otherwise it prints
/// one bullet per result and a totals line to stderr. `daemon_restart_required`
/// only appears in the JSON report.
fn print_restore_results(
    results: &[RestoreResult],
    daemon_restart_required: bool,
    output: &OutputConfig,
) {
    // Every result that is not `Restored` counts as skipped.
    let restored_count = results
        .iter()
        .filter(|r| r.outcome == RestoreOutcome::Restored)
        .count();
    let skipped_count = results.len() - restored_count;

    if output.format == OutputFormat::Json {
        let per_action: Vec<serde_json::Value> = results
            .iter()
            .map(|r| {
                serde_json::json!({
                    "kind": r.action.kind,
                    "file": r.action.file.display().to_string(),
                    "outcome": format!("{:?}", r.outcome),
                })
            })
            .collect();
        output.print_json(&serde_json::json!({
            "command": "doctor_restore",
            "actions_reversed": restored_count,
            "actions_skipped": skipped_count,
            "daemon_restart_required": daemon_restart_required,
            "results": per_action,
        }));
    } else if !output.quiet {
        for r in results {
            let file = r.action.file.display();
            let line = match &r.outcome {
                RestoreOutcome::Restored => format!("restored {file}"),
                RestoreOutcome::NotReversible => format!(
                    "skipped {:?} (not reversible — {})",
                    r.action.kind, r.action.note
                ),
                RestoreOutcome::BackupMissing => {
                    format!("backup missing for {file} — cannot restore")
                }
                RestoreOutcome::Failed(msg) => format!("restore failed for {file}: {msg}"),
            };
            eprintln!(" • {line}");
        }
        eprintln!();
        let plural = if restored_count == 1 { "" } else { "s" };
        eprintln!("Restored {restored_count} action{plural}, skipped {skipped_count}.");
    }
}
// Unit tests for the restore path. Each test works in its own TempDir.
#[cfg(test)]
mod tests {
use super::*;
use crate::cli::commands::doctor_fix::FixKind;
use chrono::Utc;
use serde_json::json;
use std::path::PathBuf;
use tempfile::TempDir;
// Builds a journal entry with `reversible: true`, ol_code "OL-TEST",
// note "test" and the current timestamp. Callers choose the live file,
// the optional backup path and the fix kind.
fn fixture_action(file: PathBuf, backup: Option<PathBuf>, kind: FixKind) -> FixAction {
FixAction {
ol_code: "OL-TEST".to_string(),
kind,
file,
backup,
reversible: true,
applied_at: Utc::now(),
note: "test".to_string(),
}
}
// ConfigRewrite uses a whole-file copy: the live file's contents must
// become exactly the backup's contents.
#[test]
fn test_restore_one_blind_swap_replaces_live_with_backup() {
let tmp = TempDir::new().unwrap();
let live = tmp.path().join("config.toml");
let bak = tmp.path().join("config.toml.bak");
std::fs::write(&live, "after = true").unwrap();
std::fs::write(&bak, "before = true").unwrap();
let action = fixture_action(live.clone(), Some(bak.clone()), FixKind::ConfigRewrite);
let outcome = restore_one(&action);
assert_eq!(outcome, RestoreOutcome::Restored);
assert_eq!(std::fs::read_to_string(&live).unwrap(), "before = true");
}
// A backup path is recorded but the file is never created, so the
// outcome must be BackupMissing rather than a failed copy.
#[test]
fn test_restore_one_returns_backup_missing_when_bak_absent() {
let tmp = TempDir::new().unwrap();
let live = tmp.path().join("config.toml");
let bak = tmp.path().join("config.toml.bak");
std::fs::write(&live, "live").unwrap();
let action = fixture_action(live, Some(bak), FixKind::ConfigRewrite);
let outcome = restore_one(&action);
assert_eq!(outcome, RestoreOutcome::BackupMissing);
}
// A DaemonRestart action with no backup recorded is reported as
// NotReversible. No files are touched (empty path).
#[test]
fn test_restore_one_skips_not_reversible_kinds() {
let action = fixture_action(PathBuf::new(), None, FixKind::DaemonRestart);
let outcome = restore_one(&action);
assert_eq!(outcome, RestoreOutcome::NotReversible);
}
// Only entries with `"_openlatch": true` are extracted, across every
// event type. The unmarked "Bash" matcher entry must be ignored.
#[test]
fn test_extract_owned_hook_entries_returns_only_marked_entries() {
let v = json!({
"hooks": {
"PreToolUse": [
{ "matcher": "Bash", "hooks": [] },
{ "_openlatch": true, "matcher": "", "hooks": [] }
],
"Stop": [
{ "_openlatch": true, "hooks": [] }
]
}
});
let owned = extract_owned_hook_entries(&v);
assert_eq!(owned.len(), 2);
let event_types: Vec<String> = owned.iter().map(|(k, _)| k.clone()).collect();
assert!(event_types.contains(&"PreToolUse".to_string()));
assert!(event_types.contains(&"Stop".to_string()));
}
// Surgical restore: the user's own hook ("echo hi") survives, the live
// owned entry (wrong URL) is dropped, and the backup's owned entry
// (localhost:7443) is put back.
#[test]
fn test_restore_hooks_surgical_preserves_user_hooks_and_replaces_owned() {
let tmp = TempDir::new().unwrap();
let live = tmp.path().join("settings.json");
let bak = tmp.path().join("settings.json.bak");
let live_raw = r#"{
"hooks": {
"PreToolUse": [
{ "matcher": "Bash", "hooks": [{ "type": "command", "command": "echo hi" }] },
{ "_openlatch": true, "matcher": "", "hooks": [{ "type": "http", "url": "http://wrong:9999/" }] }
]
}
}"#;
let bak_raw = r#"{
"hooks": {
"PreToolUse": [
{ "_openlatch": true, "matcher": "", "hooks": [{ "type": "http", "url": "http://localhost:7443/" }] }
]
}
}"#;
std::fs::write(&live, live_raw).unwrap();
std::fs::write(&bak, bak_raw).unwrap();
restore_hooks_surgical(&live, &bak).expect("restore must succeed");
let result_raw = std::fs::read_to_string(&live).unwrap();
assert!(
result_raw.contains("echo hi"),
"user hook lost: {result_raw}"
);
assert!(
result_raw.contains("http://localhost:7443/"),
"owned entry not restored: {result_raw}"
);
assert!(
!result_raw.contains("http://wrong:9999/"),
"old entry leaked through: {result_raw}"
);
}
// When the backup has no owned entries, the live owned entry must be
// removed. Either a missing or an empty PreToolUse array is accepted.
#[test]
fn test_restore_hooks_surgical_drops_owned_when_bak_has_none() {
let tmp = TempDir::new().unwrap();
let live = tmp.path().join("settings.json");
let bak = tmp.path().join("settings.json.bak");
std::fs::write(
&live,
r#"{
"hooks": {
"PreToolUse": [
{ "_openlatch": true, "matcher": "", "hooks": [] }
]
}
}"#,
)
.unwrap();
std::fs::write(&bak, r#"{ "hooks": {} }"#).unwrap();
restore_hooks_surgical(&live, &bak).expect("restore must succeed");
let result_raw = std::fs::read_to_string(&live).unwrap();
let parsed: serde_json::Value = serde_json::from_str(&result_raw).unwrap();
let arr = parsed["hooks"]["PreToolUse"].as_array();
assert!(
arr.map(|a| a.is_empty()).unwrap_or(true),
"owned entry survived removal: {result_raw}"
);
}
// Results come back newest-first (b, then a) and both files are
// restored from their backups.
// NOTE(review): "a" and "b" are independent files, so this pins the
// order of the returned results, not the order in which writes happened.
#[test]
fn test_restore_actions_iterates_in_reverse_order() {
let tmp = TempDir::new().unwrap();
let actions = vec![
fixture_action(
tmp.path().join("a"),
Some(tmp.path().join("a.bak")),
FixKind::ConfigRewrite,
),
fixture_action(
tmp.path().join("b"),
Some(tmp.path().join("b.bak")),
FixKind::ConfigRewrite,
),
];
std::fs::write(tmp.path().join("a.bak"), "A").unwrap();
std::fs::write(tmp.path().join("b.bak"), "B").unwrap();
std::fs::write(tmp.path().join("a"), "X").unwrap();
std::fs::write(tmp.path().join("b"), "X").unwrap();
let results = restore_actions(&actions, tmp.path());
assert_eq!(results.len(), 2);
assert!(results[0].action.file.ends_with("b"));
assert!(results[1].action.file.ends_with("a"));
assert_eq!(std::fs::read_to_string(tmp.path().join("a")).unwrap(), "A");
assert_eq!(std::fs::read_to_string(tmp.path().join("b")).unwrap(), "B");
}
}