use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::{Mutex, OwnedMutexGuard};
/// RAII guard that puts the process working directory back when dropped.
///
/// The working directory is process-wide state. The guard therefore also owns an
/// `OwnedMutexGuard<()>` (typically obtained from a [`CwdLock`]). Tasks that lock
/// the same mutex are excluded until the previous directory has been restored.
/// Code that changes the cwd without taking that lock is not protected.
#[must_use = "the previous cwd is restored as soon as the guard is dropped"]
#[derive(Debug)]
pub struct CwdRestoreGuard {
    /// Directory captured at construction; restored on drop.
    prev_cwd: PathBuf,
    /// Held for the guard's whole lifetime so cooperating tasks cannot touch the cwd.
    _guard: OwnedMutexGuard<()>,
}
impl CwdRestoreGuard {
    /// Records the current working directory, then switches the process to `new_cwd`.
    ///
    /// `guard` is kept alive inside the returned value. It is released only after
    /// the recorded directory has been restored on drop.
    ///
    /// # Errors
    ///
    /// Returns the error from [`std::env::current_dir`] if the current directory
    /// cannot be read, or from [`std::env::set_current_dir`] if `new_cwd` cannot be
    /// entered. In either case `guard` is released and the cwd is left unchanged.
    pub fn new(new_cwd: impl Into<PathBuf>, guard: OwnedMutexGuard<()>) -> std::io::Result<Self> {
        // Build the restore point first (single source of truth shared with
        // `acquire`). If the chdir below fails, dropping `this` just re-enters the
        // directory we are still in.
        let this = Self::acquire(guard)?;
        std::env::set_current_dir(new_cwd.into())?;
        Ok(this)
    }

    /// Records the current working directory without changing it.
    ///
    /// Dropping the returned guard switches the process back to the recorded
    /// directory, then releases `guard`. This is useful when the caller, or code
    /// it invokes, may change the cwd while the lock is held.
    ///
    /// # Errors
    ///
    /// Returns the error from [`std::env::current_dir`] if the current directory
    /// cannot be read; `guard` is released in that case.
    pub fn acquire(guard: OwnedMutexGuard<()>) -> std::io::Result<Self> {
        let prev_cwd = std::env::current_dir()?;
        Ok(Self {
            prev_cwd,
            _guard: guard,
        })
    }
}
impl Drop for CwdRestoreGuard {
fn drop(&mut self) {
if let Err(e) = std::env::set_current_dir(&self.prev_cwd) {
tracing::error!(
dir = %self.prev_cwd.display(),
error = %e,
"failed to restore process cwd"
);
}
}
}
/// Shared lock used to serialize changes to the process working directory.
///
/// Lock it with `lock_owned()` and hand the resulting guard to
/// `CwdRestoreGuard::new` or `CwdRestoreGuard::acquire`. Only code that takes
/// this lock is serialized; direct `std::env::set_current_dir` calls elsewhere
/// are unaffected.
pub type CwdLock = Arc<Mutex<()>>;
// Tests for `CwdRestoreGuard`. Every test changes or inspects the process-wide
// working directory, so each is marked `#[serial]` to keep them from running
// concurrently with one another.
#[cfg(test)]
mod tests {
use std::sync::Arc;
use serial_test::serial;
use tokio::sync::{Barrier, Mutex};
use super::{CwdLock, CwdRestoreGuard};
// `new` switches into a temp dir; once the guard is dropped the cwd must be back
// where it started. Paths are canonicalized before comparing, presumably so that
// symlinked temp locations (e.g. /tmp -> /private/tmp on macOS) compare equal.
#[tokio::test]
#[serial]
async fn drop_restores_cwd() {
let original = std::env::current_dir().unwrap().canonicalize().unwrap();
let tmp = tempfile::tempdir().unwrap();
let lock: CwdLock = Arc::new(Mutex::new(()));
{
let guard = lock.clone().lock_owned().await;
let _g = CwdRestoreGuard::new(tmp.path(), guard).unwrap();
// While the guard is alive, the process is inside the temp dir.
assert_eq!(
std::env::current_dir().unwrap().canonicalize().unwrap(),
tmp.path().canonicalize().unwrap()
);
}
// `_g` has been dropped at the end of the block above.
assert_eq!(
std::env::current_dir().unwrap().canonicalize().unwrap(),
original
);
}
// `acquire` does not move the cwd; dropping it must leave the cwd where it was.
#[tokio::test]
#[serial]
async fn acquire_only_restores_cwd() {
let original = std::env::current_dir().unwrap();
let lock: CwdLock = Arc::new(Mutex::new(()));
{
let guard = lock.clone().lock_owned().await;
let _g = CwdRestoreGuard::acquire(guard).unwrap();
}
assert_eq!(
std::env::current_dir().unwrap().canonicalize().unwrap(),
original.canonicalize().unwrap()
);
}
// While one task holds a `CwdRestoreGuard` built from the shared lock, a second
// task locking the same `CwdLock` must wait. It proceeds once the first guard is
// dropped.
#[tokio::test]
#[serial]
async fn m4_plain_agent_blocks_while_worktree_guard_held() {
let lock: CwdLock = Arc::new(Mutex::new(()));
// Two parties: the holder task and this test body.
let barrier = Arc::new(Barrier::new(2));
// `release_*` tells the holder to finish; `acquired_*` reports that it holds the guard.
let (release_tx, release_rx) = tokio::sync::oneshot::channel::<()>();
let (acquired_tx, acquired_rx) = tokio::sync::oneshot::channel::<()>();
let lock1 = Arc::clone(&lock);
let b1 = Arc::clone(&barrier);
// Holder: take the lock, wrap it in a guard, signal, then park until released.
let holder = tokio::spawn(async move {
let guard = lock1.clone().lock_owned().await;
let _g = CwdRestoreGuard::acquire(guard).unwrap();
let _ = acquired_tx.send(());
b1.wait().await;
let _ = release_rx.await;
});
// Do not start the waiter until the holder definitely owns the lock.
acquired_rx.await.unwrap();
let lock2 = Arc::clone(&lock);
let (second_acquired_tx, second_acquired_rx) = tokio::sync::oneshot::channel::<()>();
// Starts with zero permits; the waiter adds one only after it has got the lock.
let waiter_done = Arc::new(tokio::sync::Semaphore::new(0));
let waiter_done2 = Arc::clone(&waiter_done);
// Waiter: should block in `lock_owned()` until the holder's guard is dropped.
let waiter = tokio::spawn(async move {
let guard = lock2.clone().lock_owned().await;
// `let _` drops the guard (and the lock) immediately; only ordering is under test.
let _ = CwdRestoreGuard::acquire(guard);
let _ = second_acquired_tx.send(());
waiter_done2.add_permits(1);
});
// Give the waiter a chance to be polled and park on the lock.
tokio::task::yield_now().await;
// NOTE(review): if the waiter has not been polled yet at this point, this
// assertion passes trivially. It does not by itself prove the waiter is blocked.
assert!(
waiter_done.try_acquire().is_err(),
"second task must not acquire the guard while first holds it"
);
// Let the holder get past its barrier, then tell it to finish and drop its guard.
barrier.wait().await;
let _ = release_tx.send(());
holder.await.unwrap();
waiter.await.unwrap();
assert!(
second_acquired_rx.await.is_ok(),
"second task should have acquired the guard after first released it"
);
}
}