#![allow(clippy::missing_panics_doc, clippy::must_use_candidate)]
use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use std::ffi::{OsStr, OsString};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex, OnceLock, PoisonError, RwLock, RwLockReadGuard, RwLockWriteGuard};
use tokio::io::AsyncWriteExt;
/// Returns the process-wide `RwLock` associated with the environment variable `key`.
///
/// Locks are created lazily, one per distinct key, and leaked so they can be
/// handed out as `'static` references; calling this again with the same key
/// yields the same lock. If the registry mutex is poisoned, the guard is
/// recovered from the poison error instead of panicking.
fn env_var_lock(key: &'static str) -> &'static RwLock<()> {
    static REGISTRY: OnceLock<Mutex<HashMap<&'static str, &'static RwLock<()>>>> = OnceLock::new();

    let mut locks = REGISTRY
        .get_or_init(Mutex::default)
        .lock()
        .unwrap_or_else(PoisonError::into_inner);

    // Fast path: a lock for this key was already registered.
    if let Some(&existing) = locks.get(key) {
        return existing;
    }

    // First request for this key: allocate a lock that lives for the rest of the process.
    let fresh: &'static RwLock<()> = Box::leak(Box::new(RwLock::new(())));
    locks.insert(key, fresh);
    fresh
}
thread_local! {
    /// Keys for which the current thread holds a live `EnvGuard`.
    ///
    /// `EnvGuard::take` inserts the key and `EnvGuard`'s `Drop` removes it;
    /// `env_var_read_lock` consults the set to detect a writer on the same thread.
    static WRITER_KEYS: RefCell<HashSet<&'static str>> = RefCell::default();
}
/// Acquires a shared (read) lock on the environment variable `key`.
///
/// Returns `None` without touching the lock when the calling thread is
/// recorded in `WRITER_KEYS` as holding a writer for `key`. Otherwise blocks
/// until the read lock is available and returns its guard; a poisoned lock is
/// recovered rather than treated as an error.
pub(crate) fn env_var_read_lock(key: &'static str) -> Option<RwLockReadGuard<'static, ()>> {
    let held_by_this_thread = WRITER_KEYS.with(|keys| keys.borrow().contains(key));
    if held_by_this_thread {
        // The writer on this thread already has exclusive access; do not lock again.
        return None;
    }

    let shared = env_var_lock(key)
        .read()
        .unwrap_or_else(PoisonError::into_inner);
    Some(shared)
}
/// Scoped override of one process environment variable.
///
/// Creating a guard takes the per-key write lock and snapshots the variable's
/// current value; dropping the guard restores that snapshot and releases the
/// lock. Binding the guard to a name keeps the override alive for the scope.
#[must_use = "dropping the guard immediately restores the variable; bind it to a named variable"]
pub struct EnvGuard {
    /// Name of the environment variable this guard controls.
    key: &'static str,
    /// Value the variable had when the guard was created (`None` if it was unset).
    prior: Option<OsString>,
    /// Write lock for `key`, held for the guard's whole lifetime.
    _lock: RwLockWriteGuard<'static, ()>,
}
impl EnvGuard {
    /// Locks `key` for writing and snapshots its current value without changing it.
    ///
    /// Blocks until every other holder of the key's lock (readers or another
    /// guard on a different thread) has released it. The key is recorded in
    /// `WRITER_KEYS` so that `env_var_read_lock` on this thread skips locking.
    ///
    /// # Panics
    ///
    /// Panics if the current thread already holds an `EnvGuard` for `key`.
    /// Re-acquiring a `std::sync::RwLock` for writing on the owning thread may
    /// deadlock, so the nesting is rejected up front with a clear message.
    pub fn take(key: &'static str) -> Self {
        let already_held = WRITER_KEYS.with(|keys| keys.borrow().contains(key));
        assert!(
            !already_held,
            "EnvGuard for {:?} is already held on this thread; nesting would deadlock",
            key,
        );

        let lock = env_var_lock(key)
            .write()
            .unwrap_or_else(PoisonError::into_inner);
        WRITER_KEYS.with(|keys| keys.borrow_mut().insert(key));
        let prior = std::env::var_os(key);
        Self {
            key,
            prior,
            _lock: lock,
        }
    }

    /// Takes the guard for `key` and sets the variable to `value`.
    ///
    /// # Panics
    ///
    /// Panics under the same condition as [`EnvGuard::take`].
    pub fn set(key: &'static str, value: impl AsRef<OsStr>) -> Self {
        let guard = Self::take(key);
        guard.set_to(value);
        guard
    }

    /// Takes the guard for `key` and removes the variable from the environment.
    ///
    /// # Panics
    ///
    /// Panics under the same condition as [`EnvGuard::take`].
    pub fn unset(key: &'static str) -> Self {
        let guard = Self::take(key);
        guard.clear();
        guard
    }

    /// Sets the guarded variable to `value`; may be called repeatedly.
    pub fn set_to(&self, value: impl AsRef<OsStr>) {
        // SAFETY: this guard holds the write lock for `self.key`, so no other
        // `EnvGuard` or `env_var_read_lock` holder for the key runs concurrently.
        // NOTE(review): soundness still relies on other threads not accessing the
        // environment outside these locks.
        unsafe {
            std::env::set_var(self.key, value);
        }
    }

    /// Removes the guarded variable from the environment; may be called repeatedly.
    pub fn clear(&self) {
        // SAFETY: same argument as in `set_to` — the write lock for `self.key`
        // is held for the guard's lifetime.
        unsafe {
            std::env::remove_var(self.key);
        }
    }
}
impl Drop for EnvGuard {
    /// Restores the variable to the value captured at construction (or removes
    /// it if it was unset), then removes the key from this thread's
    /// `WRITER_KEYS` set. The write lock held in `_lock` is released when the
    /// fields are dropped after this method returns.
    fn drop(&mut self) {
        let key = self.key;
        if let Some(original) = self.prior.as_deref() {
            unsafe { std::env::set_var(key, original) };
        } else {
            unsafe { std::env::remove_var(key) };
        }
        WRITER_KEYS.with(|keys| {
            keys.borrow_mut().remove(key);
        });
    }
}
use crate::object_store::ObjectStore;
use crate::protocol::backend;
use crate::protocol::{ProtocolError, run};
use crate::url::RemoteUrl;
/// Reports whether a working `git` executable can be run from this process.
///
/// Runs `git --version` once and caches the answer for the rest of the
/// process. The command must both spawn and exit successfully: a `git` binary
/// that exists but exits with a failure status counts as unavailable.
pub fn git_available() -> bool {
    static AVAIL: OnceLock<bool> = OnceLock::new();
    *AVAIL.get_or_init(|| {
        std::process::Command::new("git")
            .arg("--version")
            .output()
            .is_ok_and(|out| out.status.success())
    })
}
/// Runs `git` with `args` in directory `cwd` and requires it to succeed.
///
/// # Panics
///
/// Panics if the process cannot be spawned, or if it exits unsuccessfully; in
/// the latter case the message includes the captured stdout and stderr.
pub fn git(args: &[&str], cwd: &Path) {
    let mut command = std::process::Command::new("git");
    command.args(args).current_dir(cwd);

    let finished = command.output().expect("spawn git");
    if finished.status.success() {
        return;
    }
    panic!(
        "git {args:?} failed: stdout={} stderr={}",
        String::from_utf8_lossy(&finished.stdout),
        String::from_utf8_lossy(&finished.stderr),
    );
}
/// Runs `git` with `args` in directory `cwd` and returns its stdout as a `String`.
///
/// The output is returned untrimmed.
///
/// # Panics
///
/// Panics if the process cannot be spawned, if it exits unsuccessfully (the
/// message includes stderr), or if stdout is not valid UTF-8.
pub fn git_capture(args: &[&str], cwd: &Path) -> String {
    let std::process::Output {
        status,
        stdout,
        stderr,
    } = std::process::Command::new("git")
        .args(args)
        .current_dir(cwd)
        .output()
        .expect("spawn git");

    if !status.success() {
        panic!(
            "git {args:?} failed: {}",
            String::from_utf8_lossy(&stderr),
        );
    }
    String::from_utf8(stdout).expect("git stdout utf-8")
}
/// Creates a temporary git repository on branch `main` containing `n` commits.
///
/// Commit `i` writes the file `f{i}.txt` with the content `"{label}-{i}\n"`.
/// Returns the temporary directory handle together with the commit SHAs in
/// creation order (oldest first).
///
/// # Panics
///
/// Panics if the temporary directory cannot be created, a file cannot be
/// written, or any `git` invocation fails.
pub fn make_seed_repo(n: usize, label: &str) -> (tempfile::TempDir, Vec<String>) {
    let dir = tempfile::tempdir().expect("tempdir");
    let root = dir.path();

    // Repository setup: fixed identity and no commit signing.
    let setup: [&[&str]; 4] = [
        &["init", "--quiet", "--initial-branch=main"],
        &["config", "user.email", "test@example.com"],
        &["config", "user.name", "Test"],
        &["config", "commit.gpgsign", "false"],
    ];
    for step in setup {
        git(step, root);
    }

    let shas: Vec<String> = (0..n)
        .map(|i| {
            let contents = format!("{label}-{i}\n");
            std::fs::write(root.join(format!("f{i}.txt")), contents.as_bytes()).unwrap();
            git(&["add", "."], root);
            git(
                &["commit", "--quiet", "-m", "step", "--no-gpg-sign"],
                root,
            );
            git_capture(&["rev-parse", "HEAD"], root).trim().to_owned()
        })
        .collect();

    (dir, shas)
}
/// Drives the protocol `run` loop in-process over an in-memory duplex pipe.
///
/// `script` is written to the helper's input and the write side is then shut
/// down; everything the helper writes back is collected. Returns the collected
/// output bytes together with the `Result` produced by `run`.
///
/// # Panics
///
/// Panics if `backend::validate_format` fails, if writing or shutting down the
/// client writer fails with anything other than `BrokenPipe`, if reading the
/// helper output fails, or if either spawned task panics.
pub async fn drive_in(
remote: RemoteUrl,
store: Arc<dyn ObjectStore>,
script: &str,
repo_dir: PathBuf,
) -> (Vec<u8>, Result<(), ProtocolError>) {
// In-memory pipe with a 64 KiB buffer: one end for this test client, one for the helper.
let (client_side, helper_side) = tokio::io::duplex(64 * 1024);
let (helper_in, helper_out) = tokio::io::split(helper_side);
let (mut client_reader, mut client_writer) = tokio::io::split(client_side);
// Owned copy so the bytes can move into the spawned task.
let script_bytes = script.as_bytes().to_owned();
// Feed the script from a separate task so it is written while `run` executes.
let writer_task = tokio::spawn(async move {
// Maps a `BrokenPipe` error to success; any other error is passed through.
// NOTE(review): presumably tolerated because the helper may stop reading
// before the whole script is consumed — confirm.
let suppress_broken_pipe = |e: std::io::Error| {
if e.kind() == std::io::ErrorKind::BrokenPipe {
Ok(())
} else {
Err(e)
}
};
client_writer
.write_all(&script_bytes)
.await
.or_else(suppress_broken_pipe)
.unwrap();
// Shut the write side down once the script has been sent.
client_writer
.shutdown()
.await
.or_else(suppress_broken_pipe)
.unwrap();
});
// Drain the helper's output from a separate task until end of stream.
let reader_task = tokio::spawn(async move {
use tokio::io::AsyncReadExt;
let mut buf = Vec::new();
client_reader.read_to_end(&mut buf).await.unwrap();
buf
});
// Resolve the engine from the remote's kind, prefix and flags before starting `run`.
let engine = backend::validate_format(
remote.kind(),
store.as_ref(),
remote.prefix().unwrap_or_default(),
remote.flags().engine,
)
.await
.expect("validate_format must succeed in tests with valid setup");
// The sixth argument is passed as `None`; its meaning is defined by `run` (not visible here).
let result = run(
remote,
store,
engine,
tokio::io::BufReader::new(helper_in),
helper_out,
None,
repo_dir,
)
.await;
// Both helper halves were moved into `run`; wait for the client tasks to finish.
writer_task.await.unwrap();
let output = reader_task.await.unwrap();
(output, result)
}
// Unit tests for `EnvGuard` and `env_var_read_lock`. Each test uses its own
// environment-variable name.
#[cfg(test)]
mod env_guard_tests {
use super::EnvGuard;
// A variable that was unset before `EnvGuard::set` is unset again after the guard drops.
#[test]
fn set_then_drop_restores_unset_prior() {
let key = "GROS_ENV_GUARD_TEST_SET_THEN_UNSET";
// Start from a known state: the variable does not exist.
unsafe {
std::env::remove_var(key);
}
{
let _g = EnvGuard::set(key, "value");
assert_eq!(std::env::var(key).as_deref(), Ok("value"));
}
assert!(std::env::var_os(key).is_none());
}
// A variable that had a value before `EnvGuard::set` gets that value back after the guard drops.
#[test]
fn set_then_drop_restores_prior_set_value() {
let key = "GROS_ENV_GUARD_TEST_SET_THEN_RESET";
unsafe {
std::env::set_var(key, "original");
}
{
let _g = EnvGuard::set(key, "override");
assert_eq!(std::env::var(key).as_deref(), Ok("override"));
}
assert_eq!(std::env::var(key).as_deref(), Ok("original"));
// Clean up the variable this test created.
unsafe {
std::env::remove_var(key);
}
}
// `EnvGuard::unset` hides an existing value for the guard's lifetime and restores it on drop.
#[test]
fn unset_then_drop_restores_prior_value() {
let key = "GROS_ENV_GUARD_TEST_UNSET_THEN_RESET";
unsafe {
std::env::set_var(key, "original");
}
{
let _g = EnvGuard::unset(key);
assert!(std::env::var_os(key).is_none());
}
assert_eq!(std::env::var(key).as_deref(), Ok("original"));
// Clean up the variable this test created.
unsafe {
std::env::remove_var(key);
}
}
// Several `set_to` / `clear` calls on one guard still restore the value captured by `take`.
#[test]
fn take_then_multi_toggle_restores_original() {
let key = "GROS_ENV_GUARD_TEST_MULTI_TOGGLE";
unsafe {
std::env::set_var(key, "first");
}
{
let g = EnvGuard::take(key);
g.set_to("second");
assert_eq!(std::env::var(key).as_deref(), Ok("second"));
g.set_to("third");
assert_eq!(std::env::var(key).as_deref(), Ok("third"));
g.clear();
assert!(std::env::var_os(key).is_none());
}
assert_eq!(std::env::var(key).as_deref(), Ok("first"));
// Clean up the variable this test created.
unsafe {
std::env::remove_var(key);
}
}
// With no guard alive for the key, the read lock is acquired and returned.
#[test]
fn env_var_read_lock_succeeds_when_no_writer_active() {
let key = "GROS_ENV_READ_LOCK_NO_WRITER";
let guard = super::env_var_read_lock(key);
assert!(
guard.is_some(),
"no writer for this key — read must succeed"
);
}
// While this thread holds an `EnvGuard` for the key, `env_var_read_lock` returns `None`.
#[test]
fn env_var_read_lock_skips_when_same_thread_holds_writer() {
let key = "GROS_ENV_READ_LOCK_SAME_THREAD_RECURSION";
let _g = EnvGuard::take(key);
let read = super::env_var_read_lock(key);
assert!(
read.is_none(),
"same-thread writer must be detected so reader skips locking",
);
}
// Two threads can hold the read lock for the same key at the same time.
#[test]
fn env_var_read_lock_allows_concurrent_readers_across_threads() {
use std::sync::Arc;
let key = "GROS_ENV_READ_LOCK_MULTI_READERS";
// Two-party barrier: each thread waits on it while still holding its read guard.
let barrier = Arc::new(std::sync::Barrier::new(2));
let barrier_for_thread = barrier.clone();
let reader = std::thread::spawn(move || {
let guard = super::env_var_read_lock(key);
assert!(guard.is_some(), "reader on spawned thread must acquire");
barrier_for_thread.wait();
});
let guard = super::env_var_read_lock(key);
assert!(guard.is_some(), "reader on main thread must acquire");
barrier.wait();
reader.join().expect("reader thread");
}
// A reader on another thread does not acquire the lock until the writer guard is dropped.
#[test]
fn cross_thread_reader_blocks_until_writer_drops() {
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
let key = "GROS_ENV_READ_LOCK_CROSS_THREAD";
let guard = EnvGuard::set(key, "during");
// Set by the reader thread once `env_var_read_lock` has returned.
let acquired = Arc::new(AtomicBool::new(false));
let acquired_for_thread = acquired.clone();
let barrier = Arc::new(std::sync::Barrier::new(2));
let barrier_for_thread = barrier.clone();
let reader = std::thread::spawn(move || {
barrier_for_thread.wait();
let _r = super::env_var_read_lock(key);
acquired_for_thread.store(true, Ordering::SeqCst);
});
barrier.wait();
// Timing-based check: give the reader 20 ms to attempt the lock before
// asserting that it has not got through.
std::thread::sleep(std::time::Duration::from_millis(20));
assert!(
!acquired.load(Ordering::SeqCst),
"reader on another thread must block while the writer is held",
);
drop(guard);
reader.join().expect("reader thread");
assert!(
acquired.load(Ordering::SeqCst),
"reader must acquire after writer releases",
);
}
// A panic while the guard is alive still restores the prior value during unwinding.
#[test]
fn panic_inside_guard_still_restores_prior() {
let key = "GROS_ENV_GUARD_TEST_PANIC_RESTORE";
unsafe {
std::env::set_var(key, "before");
}
let outcome = std::panic::catch_unwind(|| {
let _g = EnvGuard::set(key, "during");
panic!("simulated test failure between set and remove");
});
assert!(outcome.is_err(), "the closure must have panicked");
assert_eq!(
std::env::var(key).as_deref(),
Ok("before"),
"Drop must restore the prior value on unwind",
);
// Clean up the variable this test created.
unsafe {
std::env::remove_var(key);
}
}
}