use std::fs;
use std::fs::OpenOptions;
use std::path::Path;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::Duration;
/// Maximum number of items buffered in each internal channel (file-system
/// events or warnings) before new items are rejected.
const NOTIFY_EVENT_QUEUE_CAPACITY: usize = 1024;
/// Creates a bounded channel that buffers at most [`NOTIFY_EVENT_QUEUE_CAPACITY`] items.
///
/// Producers in this module send with `try_send`, so once the buffer is full
/// additional items are dropped instead of blocking the sender.
fn bounded_event_queue<T>() -> (std::sync::mpsc::SyncSender<T>, std::sync::mpsc::Receiver<T>) {
    let capacity = NOTIFY_EVENT_QUEUE_CAPACITY;
    std::sync::mpsc::sync_channel::<T>(capacity)
}
/// Guards `PROMPT.md` in the current directory by restoring it from a backup
/// whenever it is removed, using a background thread.
#[derive(Debug)]
pub struct PromptMonitor {
    /// Set by the monitor thread after a successful restore; reset by
    /// `check_and_restore`.
    restoration_detected: Arc<AtomicBool>,
    /// Set to ask the monitor thread to exit.
    stop_signal: Arc<AtomicBool>,
    /// Handle of the background thread; `None` until `start` has been called.
    monitor_thread: Option<thread::JoinHandle<()>>,
    /// Sending half of the bounded warning queue (cloned into the thread).
    warnings_tx: std::sync::mpsc::SyncSender<String>,
    /// Receiving half of the warning queue, drained by the owner.
    warnings_rx: std::sync::mpsc::Receiver<String>,
}
impl PromptMonitor {
pub fn new() -> std::io::Result<Self> {
let prompt_path = Path::new("PROMPT.md");
if !prompt_path.exists() {
return Err(std::io::Error::new(
std::io::ErrorKind::NotFound,
"PROMPT.md does not exist - cannot monitor",
));
}
let (warnings_tx, warnings_rx) = bounded_event_queue();
Ok(Self {
restoration_detected: Arc::new(AtomicBool::new(false)),
stop_signal: Arc::new(AtomicBool::new(false)),
monitor_thread: None,
warnings_tx,
warnings_rx,
})
}
pub fn start(&mut self) -> std::io::Result<()> {
if self.monitor_thread.is_some() {
return Err(std::io::Error::new(
std::io::ErrorKind::AlreadyExists,
"Monitor is already running",
));
}
let restoration_flag = Arc::clone(&self.restoration_detected);
let stop_signal = Arc::clone(&self.stop_signal);
let warnings = self.warnings_tx.clone();
let handle = thread::spawn(move || {
Self::monitor_thread_main(&restoration_flag, &stop_signal, warnings);
});
self.monitor_thread = Some(handle);
Ok(())
}
fn monitor_thread_main(
restoration_detected: &Arc<AtomicBool>,
stop_signal: &Arc<AtomicBool>,
warnings: std::sync::mpsc::SyncSender<String>,
) {
let (tx, rx) = bounded_event_queue();
match setup_directory_watcher(tx) {
Ok(_watcher) => run_watcher_event_loop(&rx, restoration_detected, stop_signal),
Err(e) => {
push_warning(&warnings, watcher_setup_error_message(&e));
Self::polling_monitor(restoration_detected, stop_signal);
}
}
}
fn handle_fs_event(event: ¬ify::Event, restoration_detected: &Arc<AtomicBool>) {
if is_restore_trigger_event(event) && Self::restore_from_backup() {
restoration_detected.store(true, Ordering::Release);
}
}
fn polling_monitor(restoration_detected: &Arc<AtomicBool>, stop_signal: &Arc<AtomicBool>) {
let previous_exists = AtomicBool::new(Path::new("PROMPT.md").exists());
std::iter::from_fn(|| {
if stop_signal.load(Ordering::Relaxed) {
return None;
}
thread::sleep(Duration::from_millis(100));
Some(Path::new("PROMPT.md").exists())
})
.for_each(|current_exists| {
let previous = previous_exists.swap(current_exists, Ordering::AcqRel);
if previous && !current_exists && Self::restore_from_backup() {
restoration_detected.store(true, Ordering::Release);
}
});
}
#[must_use]
pub fn restore_from_backup() -> bool {
let backup_paths = [
Path::new(".agent/PROMPT.md.backup"),
Path::new(".agent/PROMPT.md.backup.1"),
Path::new(".agent/PROMPT.md.backup.2"),
];
let prompt_path = Path::new("PROMPT.md");
backup_paths
.iter()
.filter_map(|backup_path| read_backup_content_secure(backup_path))
.filter(|backup_content| !backup_content.trim().is_empty())
.any(|backup_content| {
restore_prompt_content_atomic(prompt_path, backup_content.as_bytes()).is_ok()
})
}
#[must_use]
pub fn check_and_restore(&self) -> bool {
self.restoration_detected.swap(false, Ordering::AcqRel)
}
#[must_use]
pub fn drain_warnings(&self) -> Vec<String> {
drain_warnings(&self.warnings_rx)
}
#[must_use]
pub fn stop(mut self) -> Vec<String> {
self.stop_signal.store(true, Ordering::Release);
if let Some(handle) = self.monitor_thread.take() {
if let Err(panic_payload) = handle.join() {
push_warning(
&self.warnings_tx,
format!(
"File monitoring thread panicked: {}",
extract_panic_message(panic_payload)
),
);
}
}
self.drain_warnings()
}
}
/// Why the notification-based watcher could not be set up.
#[derive(Debug)]
enum MonitorSetupError {
    /// Constructing the platform watcher failed.
    Create(notify::Error),
    /// The watcher was created, but registering the current directory failed.
    Watch(notify::Error),
}
/// Builds the user-facing warning for a failed watcher setup.
///
/// The message names the step that failed and notes that protection
/// continues through polling.
fn watcher_setup_error_message(err: &MonitorSetupError) -> String {
    match *err {
        MonitorSetupError::Watch(ref e) => format!(
            "Failed to watch current directory: {e}. Falling back to periodic polling for PROMPT.md protection."
        ),
        MonitorSetupError::Create(ref e) => format!(
            "Failed to create file system watcher: {e}. Falling back to periodic polling for PROMPT.md protection."
        ),
    }
}
/// Drives the notification-based monitor.
///
/// Each wait for an event lasts at most 100 ms, so `stop_signal` is re-checked
/// regularly. After an event arrives, everything already queued is handled in
/// the same pass. `Err` payloads from the watcher are skipped.
///
/// Returns when `stop_signal` is set or when every sender for `rx` has been
/// dropped.
fn run_watcher_event_loop(
    rx: &std::sync::mpsc::Receiver<notify::Result<notify::Event>>,
    restoration_detected: &Arc<AtomicBool>,
    stop_signal: &Arc<AtomicBool>,
) {
    while !stop_signal.load(Ordering::Relaxed) {
        match rx.recv_timeout(Duration::from_millis(100)) {
            Ok(Ok(first_event)) => {
                PromptMonitor::handle_fs_event(&first_event, restoration_detected);
                // Handle the backlog before blocking again; watcher errors are skipped.
                for queued_event in rx.try_iter().flatten() {
                    PromptMonitor::handle_fs_event(&queued_event, restoration_detected);
                }
            }
            Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => break,
            // A timeout just loops back to the stop check; watcher errors are ignored.
            Ok(Err(_)) | Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {}
        }
    }
}
/// Converts a thread panic payload into a human-readable message.
///
/// `panic!` produces a `&'static str` payload for literal messages and a
/// `String` for formatted ones; a `&String` payload is also accepted. Any other
/// payload (e.g. from `std::panic::panic_any(42)`) has no printable form, so a
/// fixed placeholder is returned.
fn extract_panic_message(payload: Box<dyn std::any::Any + Send>) -> String {
    payload
        .downcast_ref::<String>()
        .cloned()
        .or_else(|| payload.downcast_ref::<&str>().map(ToString::to_string))
        .or_else(|| payload.downcast_ref::<&String>().map(|s| (*s).clone()))
        // The concrete type behind `dyn Any` cannot be named at runtime, and
        // `type_name_of_val(&payload)` only ever reports the `Box` itself.
        .unwrap_or_else(|| String::from("<non-string panic payload>"))
}
/// Creates the platform's recommended watcher and registers the current directory.
///
/// Each watcher result is forwarded into `event_sender` with `try_send`, so
/// the callback never blocks. When the queue is full, the result is dropped.
///
/// # Errors
///
/// Returns `MonitorSetupError::Create` if the watcher cannot be constructed,
/// and `MonitorSetupError::Watch` if registering `.` fails.
fn setup_directory_watcher(
    event_sender: std::sync::mpsc::SyncSender<notify::Result<notify::Event>>,
) -> std::result::Result<notify::RecommendedWatcher, MonitorSetupError> {
    let forward_event = move |res: notify::Result<notify::Event>| {
        let _ = event_sender.try_send(res);
    };
    let watcher =
        notify::recommended_watcher(forward_event).map_err(MonitorSetupError::Create)?;
    watcher
        .with_current_directory_watch()
        .map_err(MonitorSetupError::Watch)
}
trait WatcherRegistrationExt {
fn with_current_directory_watch(self) -> notify::Result<Self>
where
Self: Sized;
}
impl WatcherRegistrationExt for notify::RecommendedWatcher {
fn with_current_directory_watch(mut self) -> notify::Result<Self> {
use notify::Watcher;
self.watch(Path::new("."), notify::RecursiveMode::NonRecursive)?;
Ok(self)
}
}
/// Queues `warning` without blocking.
///
/// Delivery is best-effort: if the bounded queue is full or the receiver is
/// gone, the warning is discarded.
fn push_warning(warnings: &std::sync::mpsc::SyncSender<String>, warning: String) {
    if let Err(_full_or_disconnected) = warnings.try_send(warning) {
        // Intentionally dropped; warnings must never stall the caller.
    }
}
/// Collects every warning currently buffered in `warnings`, without blocking.
///
/// Stops at the first point where the queue is empty or disconnected.
fn drain_warnings(warnings: &std::sync::mpsc::Receiver<String>) -> Vec<String> {
    warnings.try_iter().collect()
}
fn read_backup_content_secure(path: &Path) -> Option<String> {
#[cfg(unix)]
{
use std::os::unix::fs::{MetadataExt, OpenOptionsExt};
let file = OpenOptions::new()
.read(true)
.custom_flags(libc::O_NOFOLLOW)
.open(path)
.ok()?;
let metadata = file.metadata().ok()?;
if !metadata.is_file() {
return None;
}
if metadata.nlink() != 1 {
return None;
}
std::io::read_to_string(file).ok()
}
#[cfg(not(unix))]
{
let meta = fs::symlink_metadata(path).ok()?;
if meta.file_type().is_symlink() {
return None;
}
if !meta.is_file() {
return None;
}
std::fs::read_to_string(path).ok()
}
}
/// Fails if `path` itself (without following symlinks) is a directory.
///
/// A missing path, an unreadable path, or a symlink (even one pointing at a
/// directory) is accepted.
fn ensure_not_directory(path: &Path) -> std::io::Result<()> {
    match fs::symlink_metadata(path) {
        Ok(meta) if meta.is_dir() => Err(std::io::Error::other("PROMPT.md path is a directory")),
        _ => Ok(()),
    }
}
/// Creates a new file at `path`, writes `content`, and syncs it to disk.
///
/// The file is opened with `create_new`, so an existing file or symlink
/// (including a dangling one) at `path` is never followed or overwritten.
/// The write and the sync use the same handle.
///
/// # Errors
///
/// Returns an error of kind `AlreadyExists` if `path` exists, or any error
/// from opening or writing. Syncing is best-effort, and its failure is ignored.
fn write_and_sync_temp(path: &Path, content: &[u8]) -> std::io::Result<()> {
    use std::io::Write;
    let mut file = OpenOptions::new().write(true).create_new(true).open(path)?;
    file.write_all(content)?;
    let _ = file.sync_all();
    Ok(())
}
/// Marks the file at `path` read-only.
///
/// - Unix: sets mode `0o444`, replacing all existing permission bits.
/// - Windows: sets the read-only attribute.
/// - Other platforms: does nothing and returns `Ok(())`.
fn make_file_readonly(path: &Path) -> std::io::Result<()> {
    #[cfg(windows)]
    {
        let mut permissions = fs::metadata(path)?.permissions();
        permissions.set_readonly(true);
        fs::set_permissions(path, permissions)?;
    }
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        let read_only_for_everyone = fs::Permissions::from_mode(0o444);
        fs::set_permissions(path, read_only_for_everyone)?;
    }
    Ok(())
}
/// Moves `src` to `dest`, deleting `src` if the move fails.
///
/// On Windows, an existing `dest` is removed first (best-effort). If the rename
/// fails, `src` is removed (best-effort) and the rename error is returned.
fn rename_or_cleanup(src: &Path, dest: &Path) -> std::io::Result<()> {
    #[cfg(windows)]
    {
        if dest.exists() {
            let _ = fs::remove_file(dest);
        }
    }
    match fs::rename(src, dest) {
        Ok(()) => Ok(()),
        Err(rename_error) => {
            // Do not leave the staged source file behind after a failed move.
            let _ = fs::remove_file(src);
            Err(rename_error)
        }
    }
}
fn restore_prompt_content_atomic(prompt_path: &Path, content: &[u8]) -> std::io::Result<()> {
ensure_not_directory(prompt_path)?;
let temp_name = unique_temp_name();
let temp_path = Path::new(&temp_name);
write_and_sync_temp(temp_path, content)?;
make_file_readonly(temp_path)?;
rename_or_cleanup(temp_path, prompt_path)
}
/// Returns a relative file name for staging a restore.
///
/// The name has the form `.prompt_restore_tmp_{pid}_{nanos}_{seq}`, where
/// `seq` is a process-wide counter. The counter keeps names distinct even when
/// two calls see the same clock reading, for example on platforms with coarse
/// timer resolution or when `PromptMonitor::restore_from_backup` is called
/// concurrently.
fn unique_temp_name() -> String {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    static SEQUENCE: AtomicUsize = AtomicUsize::new(0);

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    let pid = std::process::id();
    let seq = SEQUENCE.fetch_add(1, Ordering::Relaxed);
    format!(".prompt_restore_tmp_{pid}_{nanos}_{seq}")
}
/// Returns `true` when the last component of `path` is exactly `PROMPT.md`.
///
/// The comparison is case-sensitive, and the directory part is ignored.
fn is_prompt_md_path(path: &Path) -> bool {
    path.file_name()
        .is_some_and(|final_component| final_component == "PROMPT.md")
}
/// Returns `true` if `event` is a `Remove` event that involves a path whose
/// final component is `PROMPT.md`.
///
/// All other event kinds are ignored, and the directory part of each path is
/// not considered.
fn is_restore_trigger_event(event: &notify::Event) -> bool {
    matches!(event.kind, notify::EventKind::Remove(_))
        && event.paths.iter().any(|path| is_prompt_md_path(path))
}
impl Drop for PromptMonitor {
    /// Signals the monitor thread to stop and waits for it to exit.
    ///
    /// Joining, rather than detaching the handle, guarantees the background
    /// thread can no longer rewrite `PROMPT.md` once the monitor is gone. The
    /// thread checks the stop flag between waits of about 100 ms, so this
    /// normally returns quickly. A panic in the thread is ignored here; `stop`
    /// is the method that reports one as a warning.
    fn drop(&mut self) {
        self.stop_signal.store(true, Ordering::Release);
        if let Some(handle) = self.monitor_thread.take() {
            // Nowhere to report a panic during drop.
            let _ = handle.join();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::{drain_warnings, is_restore_trigger_event, push_warning};
    use std::path::PathBuf;

    /// Builds an event of `kind` that touches every path in `paths`.
    fn event_touching(kind: notify::EventKind, paths: Vec<&str>) -> notify::Event {
        let mut event = notify::Event::new(kind);
        for raw in paths {
            event = event.add_path(PathBuf::from(raw));
        }
        event
    }

    /// A removal event covering `paths`.
    fn remove_event(paths: Vec<&str>) -> notify::Event {
        event_touching(
            notify::EventKind::Remove(notify::event::RemoveKind::Any),
            paths,
        )
    }

    /// A creation event covering `paths`.
    fn create_event(paths: Vec<&str>) -> notify::Event {
        event_touching(
            notify::EventKind::Create(notify::event::CreateKind::Any),
            paths,
        )
    }

    #[test]
    fn drain_warnings_clears_buffer_after_read() {
        let (tx, rx) = std::sync::mpsc::sync_channel::<String>(16);
        for text in ["first warning", "second warning"] {
            push_warning(&tx, text.to_string());
        }
        assert_eq!(drain_warnings(&rx).len(), 2);
        let leftover = drain_warnings(&rx);
        assert!(leftover.is_empty(), "warnings should be cleared after drain");
    }

    #[test]
    fn restore_trigger_event_requires_remove_kind_and_prompt_path() {
        let cases = [
            (remove_event(vec!["PROMPT.md"]), true),
            (remove_event(vec!["README.md"]), false),
            (create_event(vec!["PROMPT.md"]), false),
        ];
        for (event, should_trigger) in &cases {
            assert_eq!(is_restore_trigger_event(event), *should_trigger);
        }
    }
}