use std::env;
use std::path::{Path, PathBuf};
use anyhow::{Context, Result};
/// Resolves the scratch directory to use.
///
/// Resolution order:
/// 1. `AI_SCRATCH=git-root:<rel>` → `<repo root>/<rel>`, where the repo root is the nearest
///    ancestor of the current working directory (inclusive) that contains a `.git` entry.
/// 2. Any other `AI_SCRATCH` value → that value, used verbatim as a path.
/// 3. `AI_SCRATCH` unset (or not valid Unicode) → the value of `TMPDIR`, or `/tmp` when
///    `TMPDIR` is itself unset or not valid Unicode.
///
/// # Errors
/// Only the `git-root:` form can fail: when the current directory cannot be read or no
/// enclosing git repository is found.
pub fn get_ai_scratch_dir() -> Result<PathBuf> {
    let Ok(scratch) = env::var("AI_SCRATCH") else {
        let tmp = env::var("TMPDIR").unwrap_or_else(|_| "/tmp".to_string());
        return Ok(PathBuf::from(tmp));
    };

    match scratch.strip_prefix("git-root:") {
        // The remainder after the prefix is interpreted relative to the repository root.
        Some(relative) => Ok(find_git_root()?.join(relative)),
        None => Ok(PathBuf::from(scratch)),
    }
}
/// Locates the root of the git repository enclosing the process's current working directory.
///
/// # Errors
/// Fails if the current directory cannot be determined, or if neither it nor any of its
/// ancestors contains a `.git` entry.
fn find_git_root() -> Result<PathBuf> {
    let current_dir = env::current_dir().context("Failed to get current directory")?;
    find_git_root_from_path(&current_dir)
}
/// Walks upward from `start_path` (inclusive) and returns the first directory that
/// contains a `.git` entry.
///
/// The check uses [`Path::exists`], so `.git` may be either a directory or a regular file.
/// A broken symlink, or an entry whose metadata cannot be read, does not count.
///
/// # Errors
/// Returns an error when no directory between `start_path` and the filesystem root
/// contains `.git`.
fn find_git_root_from_path(start_path: &Path) -> Result<PathBuf> {
    // `ancestors()` yields `start_path` itself, then each successive `parent()`.
    start_path
        .ancestors()
        .find(|dir| dir.join(".git").exists())
        .map(Path::to_path_buf)
        .ok_or_else(|| {
            anyhow::anyhow!(
                "No git repository found in current directory or any parent directory"
            )
        })
}
#[cfg(test)]
// Tests are expected to panic on unexpected failures, so unwrap/expect are fine here.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use std::env;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard, OnceLock, PoisonError};
    use tempfile::TempDir;

    /// Serializes tests that touch process-wide environment variables. The environment is
    /// shared by all test threads, so concurrent mutation would make results racy.
    static ENV_TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();

    /// Holds `ENV_TEST_LOCK` for its lifetime and restores every environment variable it
    /// modified when dropped.
    struct EnvGuard {
        // Fields drop after `Drop::drop` runs, so the lock is released only once the
        // environment has been restored.
        _lock: MutexGuard<'static, ()>,
        // (variable name, value before modification), in modification order.
        vars: Vec<(String, Option<OsString>)>,
    }

    impl EnvGuard {
        fn new() -> Self {
            // A test that panics while holding the lock poisons the mutex. Its guard has
            // already restored the environment while unwinding, so recovering the lock is
            // safe and prevents one failure from cascading into every later env test.
            let lock = ENV_TEST_LOCK
                .get_or_init(|| Mutex::new(()))
                .lock()
                .unwrap_or_else(PoisonError::into_inner);
            Self {
                _lock: lock,
                vars: Vec::new(),
            }
        }

        /// Sets `key` to `value`, remembering the previous value for restoration.
        fn set(&mut self, key: &str, value: &str) {
            self.remember(key);
            env::set_var(key, value);
        }

        /// Unsets `key`, remembering the previous value for restoration.
        fn remove(&mut self, key: &str) {
            self.remember(key);
            env::remove_var(key);
        }

        /// Snapshots the current value of `key`. Uses `var_os` so that a non-UTF-8 original
        /// value is restored verbatim instead of being treated as "unset" and deleted.
        fn remember(&mut self, key: &str) {
            self.vars.push((key.to_string(), env::var_os(key)));
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // Undo in reverse order so a key modified more than once ends at its earliest
            // recorded value.
            for (key, original_value) in self.vars.drain(..).rev() {
                match original_value {
                    Some(value) => env::set_var(&key, value),
                    None => env::remove_var(&key),
                }
            }
        }
    }

    #[test]
    fn get_ai_scratch_dir_with_direct_path() {
        let mut guard = EnvGuard::new();
        guard.set("AI_SCRATCH", "/custom/scratch/path");
        let result = get_ai_scratch_dir().unwrap();
        assert_eq!(result, PathBuf::from("/custom/scratch/path"));
    }

    #[test]
    fn get_ai_scratch_dir_fallback_to_tmpdir() {
        let mut guard = EnvGuard::new();
        guard.remove("AI_SCRATCH");
        guard.set("TMPDIR", "/custom/tmp");
        let result = get_ai_scratch_dir().unwrap();
        assert_eq!(result, PathBuf::from("/custom/tmp"));
    }

    #[test]
    fn find_git_root_from_path() {
        let _guard = EnvGuard::new();
        let temp_dir = {
            // `create_dir_all` succeeds if the directory already exists, so any error
            // here is a real failure worth surfacing.
            std::fs::create_dir_all("tmp").expect("failed to create ./tmp for test fixtures");
            TempDir::new_in("tmp").unwrap()
        };
        let git_dir = temp_dir.path().join(".git");
        std::fs::create_dir(&git_dir).unwrap();
        let sub_dir = temp_dir.path().join("subdir").join("deeper");
        std::fs::create_dir_all(&sub_dir).unwrap();
        // Call the function under test explicitly; this test fn shadows its name.
        let result = super::find_git_root_from_path(&sub_dir).unwrap();
        assert_eq!(result, temp_dir.path());
    }
}