use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::time::{SystemTime, UNIX_EPOCH};
use std::{env, fs};
use serde::{Deserialize, Serialize};
const ENTRY_TTL_SECS: u64 = 300;
const SCRIPT_EXTENSIONS: &[&str] = &[".sh", ".bash", ".py", ".rb", ".pl", ".zsh"];
const SHELL_INTERPRETERS: &[&str] = &["bash", "sh", "zsh", "python", "python3", "ruby", "perl"];
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JournalEntry {
pub dangerous: bool,
pub timestamp: u64,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct WriteJournal {
pub entries: HashMap<String, JournalEntry>,
}
impl WriteJournal {
#[must_use]
pub fn load() -> Self {
Self::load_from(&default_journal_path())
}
#[must_use]
pub fn load_from(path: &Path) -> Self {
let Ok(content) = fs::read_to_string(path) else {
return Self::default();
};
serde_json::from_str(&content).unwrap_or_default()
}
pub fn save(&self) -> anyhow::Result<()> {
self.save_to(&default_journal_path())
}
pub fn save_to(&self, path: &Path) -> anyhow::Result<()> {
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)?;
}
let content = serde_json::to_string(self)?;
fs::write(path, content)?;
Ok(())
}
pub fn record(&mut self, file_path: &str, dangerous: bool) {
self.prune();
self.entries.insert(
file_path.to_owned(),
JournalEntry {
dangerous,
timestamp: now_secs(),
},
);
}
#[must_use]
pub fn is_dangerous(&self, file_path: &str) -> bool {
let now = now_secs();
self.entries.get(file_path).is_some_and(|e| {
e.dangerous && now.saturating_sub(e.timestamp) < ENTRY_TTL_SECS
})
}
pub fn prune(&mut self) {
let now = now_secs();
self.entries
.retain(|_, e| now.saturating_sub(e.timestamp) < ENTRY_TTL_SECS);
}
}
#[must_use]
pub fn extract_executed_paths(command: &str) -> Vec<String> {
let words: Vec<&str> = command.split_whitespace().collect();
let mut paths = Vec::new();
for (i, word) in words.iter().enumerate() {
let basename = Path::new(word)
.file_name()
.and_then(|n| n.to_str())
.unwrap_or(word);
if SHELL_INTERPRETERS.contains(&basename) {
if let Some(script) = words[i + 1..].iter().find(|w| !w.starts_with('-')) {
paths.push((*script).to_owned());
}
}
if is_path_like(word) && has_script_extension(word) {
paths.push((*word).to_owned());
}
}
paths.sort();
paths.dedup();
paths
}
fn is_path_like(word: &str) -> bool {
word.starts_with('/')
|| word.starts_with("./")
|| word.starts_with("~/")
}
fn has_script_extension(word: &str) -> bool {
SCRIPT_EXTENSIONS.iter().any(|ext| word.ends_with(ext))
}
fn default_journal_path() -> PathBuf {
if let Ok(runtime) = env::var("XDG_RUNTIME_DIR") {
PathBuf::from(runtime)
.join("guardrail")
.join("write-journal.json")
} else if let Ok(tmpdir) = env::var("TMPDIR") {
PathBuf::from(tmpdir).join("guardrail-journal.json")
} else {
let user = env::var("USER").unwrap_or_else(|_| "unknown".into());
PathBuf::from(format!("/tmp/guardrail-journal-{user}.json"))
}
}
fn now_secs() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
#[test]
fn journal_record_and_check() {
let mut journal = WriteJournal::default();
journal.record("/tmp/evil.sh", true);
assert!(journal.is_dangerous("/tmp/evil.sh"));
assert!(!journal.is_dangerous("/tmp/safe.sh"));
}
#[test]
fn journal_safe_write_not_dangerous() {
let mut journal = WriteJournal::default();
journal.record("/tmp/safe.sh", false);
assert!(!journal.is_dangerous("/tmp/safe.sh"));
}
#[test]
fn journal_empty_not_dangerous() {
let journal = WriteJournal::default();
assert!(!journal.is_dangerous("/tmp/anything.sh"));
}
#[test]
fn journal_overwrite_replaces_entry() {
let mut journal = WriteJournal::default();
journal.record("/tmp/file.sh", true);
assert!(journal.is_dangerous("/tmp/file.sh"));
journal.record("/tmp/file.sh", false);
assert!(!journal.is_dangerous("/tmp/file.sh"));
}
#[test]
fn journal_prune_removes_expired() {
let mut journal = WriteJournal::default();
journal.entries.insert(
"/tmp/old.sh".to_owned(),
JournalEntry {
dangerous: true,
timestamp: 1000, },
);
journal.entries.insert(
"/tmp/new.sh".to_owned(),
JournalEntry {
dangerous: true,
timestamp: now_secs(),
},
);
journal.prune();
assert!(!journal.entries.contains_key("/tmp/old.sh"));
assert!(journal.entries.contains_key("/tmp/new.sh"));
}
#[test]
fn journal_expired_entry_not_dangerous() {
let mut journal = WriteJournal::default();
journal.entries.insert(
"/tmp/expired.sh".to_owned(),
JournalEntry {
dangerous: true,
timestamp: 1000,
},
);
assert!(!journal.is_dangerous("/tmp/expired.sh"));
}
#[test]
fn journal_disk_round_trip() {
let dir = TempDir::new().unwrap();
let path = dir.path().join("test-journal.json");
let mut journal = WriteJournal::default();
journal.record("/tmp/evil.sh", true);
journal.record("/tmp/safe.sh", false);
journal.save_to(&path).unwrap();
let loaded = WriteJournal::load_from(&path);
assert!(loaded.is_dangerous("/tmp/evil.sh"));
assert!(!loaded.is_dangerous("/tmp/safe.sh"));
assert_eq!(loaded.entries.len(), 2);
}
#[test]
fn journal_load_missing_file_returns_empty() {
let journal = WriteJournal::load_from(Path::new("/nonexistent/journal.json"));
assert!(journal.entries.is_empty());
}
#[test]
fn journal_load_corrupt_file_returns_empty() {
let dir = TempDir::new().unwrap();
let path = dir.path().join("corrupt.json");
fs::write(&path, "not valid json {{{").unwrap();
let journal = WriteJournal::load_from(&path);
assert!(journal.entries.is_empty());
}
#[test]
fn extract_paths_bash_script() {
let paths = extract_executed_paths("bash /tmp/evil.sh");
assert!(paths.contains(&"/tmp/evil.sh".to_owned()));
}
#[test]
fn extract_paths_direct_script() {
let paths = extract_executed_paths("/tmp/deploy.sh --prod");
assert!(paths.contains(&"/tmp/deploy.sh".to_owned()));
}
#[test]
fn extract_paths_python() {
let paths = extract_executed_paths("python3 /tmp/script.py");
assert!(paths.contains(&"/tmp/script.py".to_owned()));
}
#[test]
fn extract_paths_no_scripts() {
let paths = extract_executed_paths("ls -la /tmp");
assert!(paths.is_empty());
}
#[test]
fn extract_paths_relative() {
let paths = extract_executed_paths("./deploy.sh");
assert!(paths.contains(&"./deploy.sh".to_owned()));
}
#[test]
fn extract_paths_shell_with_flags() {
let paths = extract_executed_paths("bash -x -e /tmp/test.sh");
assert!(paths.contains(&"/tmp/test.sh".to_owned()));
}
#[test]
fn extract_paths_tilde() {
let paths = extract_executed_paths("~/scripts/deploy.sh --env prod");
assert!(paths.contains(&"~/scripts/deploy.sh".to_owned()));
}
#[test]
fn extract_paths_no_extension() {
let paths = extract_executed_paths("bash /tmp/binary");
assert!(paths.contains(&"/tmp/binary".to_owned()));
}
#[test]
fn extract_paths_deduplicates() {
let paths = extract_executed_paths("bash /tmp/evil.sh");
assert_eq!(
paths.iter().filter(|p| *p == "/tmp/evil.sh").count(),
1,
"expected exactly one /tmp/evil.sh, got: {paths:?}"
);
}
#[test]
fn is_path_like_checks() {
assert!(is_path_like("/usr/bin/foo"));
assert!(is_path_like("./script.sh"));
assert!(is_path_like("~/bin/bar.sh"));
assert!(!is_path_like("plain"));
assert!(!is_path_like("--flag"));
}
#[test]
fn has_script_extension_checks() {
assert!(has_script_extension("foo.sh"));
assert!(has_script_extension("bar.py"));
assert!(has_script_extension("baz.rb"));
assert!(!has_script_extension("binary"));
assert!(!has_script_extension("file.txt"));
}
#[test]
fn extract_paths_empty_command() {
let paths = extract_executed_paths("");
assert!(paths.is_empty());
}
#[test]
fn extract_paths_whitespace_only() {
let paths = extract_executed_paths(" ");
assert!(paths.is_empty());
}
#[test]
fn extract_paths_multiple_interpreters() {
let paths = extract_executed_paths("bash /tmp/a.sh && python3 /tmp/b.py");
assert!(paths.contains(&"/tmp/a.sh".to_owned()));
assert!(paths.contains(&"/tmp/b.py".to_owned()));
}
#[test]
fn extract_paths_ruby_interpreter() {
let paths = extract_executed_paths("ruby /opt/script.rb");
assert!(paths.contains(&"/opt/script.rb".to_owned()));
}
#[test]
fn extract_paths_perl_interpreter() {
let paths = extract_executed_paths("perl /opt/script.pl");
assert!(paths.contains(&"/opt/script.pl".to_owned()));
}
#[test]
fn extract_paths_zsh_interpreter() {
let paths = extract_executed_paths("zsh ./setup.zsh");
assert!(paths.contains(&"./setup.zsh".to_owned()));
}
#[test]
fn extract_paths_sh_interpreter() {
let paths = extract_executed_paths("sh /tmp/run.sh");
assert!(paths.contains(&"/tmp/run.sh".to_owned()));
}
#[test]
fn extract_paths_no_flag_args() {
let paths = extract_executed_paths("python3 -u -B script.py");
assert!(paths.contains(&"script.py".to_owned()));
}
#[test]
fn extract_paths_bare_word_not_path() {
let paths = extract_executed_paths("echo hello world");
assert!(paths.is_empty());
}
#[test]
fn extract_paths_direct_bash_extension() {
let paths = extract_executed_paths("/usr/local/bin/setup.bash");
assert!(paths.contains(&"/usr/local/bin/setup.bash".to_owned()));
}
#[test]
fn extract_paths_multiple_direct_scripts() {
let paths = extract_executed_paths("./a.sh && ./b.py && ./c.rb");
assert!(paths.contains(&"./a.sh".to_owned()));
assert!(paths.contains(&"./b.py".to_owned()));
assert!(paths.contains(&"./c.rb".to_owned()));
}
#[test]
fn record_auto_prunes() {
let mut journal = WriteJournal::default();
journal.entries.insert(
"/tmp/old.sh".to_owned(),
JournalEntry { dangerous: true, timestamp: 1000 },
);
journal.record("/tmp/new.sh", true);
assert!(!journal.entries.contains_key("/tmp/old.sh"));
assert!(journal.entries.contains_key("/tmp/new.sh"));
}
#[test]
fn record_multiple_files() {
let mut journal = WriteJournal::default();
journal.record("/tmp/a.sh", true);
journal.record("/tmp/b.sh", false);
journal.record("/tmp/c.sh", true);
assert_eq!(journal.entries.len(), 3);
assert!(journal.is_dangerous("/tmp/a.sh"));
assert!(!journal.is_dangerous("/tmp/b.sh"));
assert!(journal.is_dangerous("/tmp/c.sh"));
}
#[test]
fn prune_all_expired() {
let mut journal = WriteJournal::default();
journal.entries.insert(
"/tmp/a.sh".to_owned(),
JournalEntry { dangerous: true, timestamp: 100 },
);
journal.entries.insert(
"/tmp/b.sh".to_owned(),
JournalEntry { dangerous: true, timestamp: 200 },
);
journal.prune();
assert!(journal.entries.is_empty());
}
#[test]
fn prune_keeps_fresh() {
let mut journal = WriteJournal::default();
let now = now_secs();
journal.entries.insert(
"/tmp/fresh.sh".to_owned(),
JournalEntry { dangerous: true, timestamp: now },
);
journal.prune();
assert_eq!(journal.entries.len(), 1);
}
#[test]
fn journal_save_creates_parent_dirs() {
let dir = TempDir::new().unwrap();
let path = dir.path().join("deep/nested/journal.json");
let journal = WriteJournal::default();
journal.save_to(&path).unwrap();
assert!(path.exists());
}
#[test]
fn journal_round_trip_preserves_timestamps() {
let dir = TempDir::new().unwrap();
let path = dir.path().join("ts-test.json");
let now = now_secs();
let mut journal = WriteJournal::default();
journal.entries.insert(
"/tmp/ts.sh".to_owned(),
JournalEntry { dangerous: true, timestamp: now },
);
journal.save_to(&path).unwrap();
let loaded = WriteJournal::load_from(&path);
let entry = loaded.entries.get("/tmp/ts.sh").unwrap();
assert_eq!(entry.timestamp, now);
assert!(entry.dangerous);
}
#[test]
fn is_path_like_edge_cases() {
assert!(!is_path_like(""));
assert!(!is_path_like("-"));
assert!(!is_path_like("~notapath"));
assert!(is_path_like("~/"));
assert!(!is_path_like("relative/path"));
}
#[test]
fn has_script_extension_all_types() {
for ext in SCRIPT_EXTENSIONS {
let filename = format!("test{ext}");
assert!(has_script_extension(&filename), "expected {filename} to have script extension");
}
}
#[test]
fn has_script_extension_false_positives() {
assert!(!has_script_extension("file.pyc"));
assert!(!has_script_extension("file.shell"));
assert!(!has_script_extension("file.rs"));
}
#[test]
fn extract_paths_interpreter_at_end_no_script() {
let paths = extract_executed_paths("echo hello && bash");
assert!(paths.is_empty(), "bare interpreter with no script should yield nothing");
}
#[test]
fn extract_paths_full_path_interpreter() {
let paths = extract_executed_paths("/usr/bin/bash /tmp/script.sh");
assert!(paths.contains(&"/tmp/script.sh".to_owned()));
}
#[test]
fn extract_paths_python3_full_path() {
let paths = extract_executed_paths("/usr/local/bin/python3 /opt/app.py");
assert!(paths.contains(&"/opt/app.py".to_owned()));
}
#[test]
fn default_journal_path_is_absolute() {
let path = default_journal_path();
assert!(
path.is_absolute(),
"journal path should be absolute, got: {}",
path.display()
);
}
#[test]
fn default_journal_path_contains_guardrail() {
let path = default_journal_path();
let path_str = path.to_string_lossy();
assert!(
path_str.contains("guardrail"),
"journal path should contain 'guardrail', got: {path_str}"
);
}
#[test]
fn journal_serde_round_trip() {
let mut journal = WriteJournal::default();
journal.record("/tmp/test.sh", true);
let json = serde_json::to_string(&journal).unwrap();
let loaded: WriteJournal = serde_json::from_str(&json).unwrap();
assert!(loaded.is_dangerous("/tmp/test.sh"));
}
#[test]
fn journal_entry_serde() {
let entry = JournalEntry {
dangerous: true,
timestamp: 12345,
};
let json = serde_json::to_string(&entry).unwrap();
let back: JournalEntry = serde_json::from_str(&json).unwrap();
assert!(back.dangerous);
assert_eq!(back.timestamp, 12345);
}
#[test]
fn now_secs_returns_reasonable_value() {
let ts = now_secs();
assert!(ts > 1_700_000_000, "timestamp should be recent, got: {ts}");
}
}