Skip to main content

omni_dev/utils/
ai_scratch.rs

1//! AI scratch directory utilities.
2
3use std::env;
4use std::path::{Path, PathBuf};
5
6use anyhow::{Context, Result};
7
8/// Returns the AI scratch directory path based on environment variables and git root detection.
9pub fn get_ai_scratch_dir() -> Result<PathBuf> {
10    // Check for AI_SCRATCH environment variable first
11    if let Ok(ai_scratch) = env::var("AI_SCRATCH") {
12        if let Some(git_root_path) = ai_scratch.strip_prefix("git-root:") {
13            // Find git root and append the path
14            let git_root = find_git_root()?;
15            Ok(git_root.join(git_root_path))
16        } else {
17            // Use AI_SCRATCH directly
18            Ok(PathBuf::from(ai_scratch))
19        }
20    } else {
21        // Fall back to TMPDIR
22        let tmpdir = env::var("TMPDIR").unwrap_or_else(|_| "/tmp".to_string());
23        Ok(PathBuf::from(tmpdir))
24    }
25}
26
27/// Finds the closest ancestor directory containing a .git directory.
28fn find_git_root() -> Result<PathBuf> {
29    let current_dir = env::current_dir().context("Failed to get current directory")?;
30    find_git_root_from_path(&current_dir)
31}
32
33/// Finds the git root starting from a specific path.
34fn find_git_root_from_path(start_path: &Path) -> Result<PathBuf> {
35    let mut current = start_path;
36
37    loop {
38        let git_dir = current.join(".git");
39        if git_dir.exists() {
40            return Ok(current.to_path_buf());
41        }
42
43        match current.parent() {
44            Some(parent) => current = parent,
45            None => {
46                return Err(anyhow::anyhow!(
47                    "No git repository found in current directory or any parent directory"
48                ))
49            }
50        }
51    }
52}
53
#[cfg(test)]
// Tests are expected to panic on unexpected failures, so unwrap/expect are fine here.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use std::env;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard, OnceLock, PoisonError};
    use tempfile::TempDir;

    /// Serializes the tests in this module that take an `EnvGuard`. The process environment is
    /// shared by every test thread, so concurrent mutation would make results nondeterministic.
    static ENV_TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();

    /// Holds `ENV_TEST_LOCK` for the duration of a test and restores every variable it changed.
    struct EnvGuard {
        // `Drop::drop` below runs before any field is dropped, so the environment is restored
        // while this lock is still held.
        _lock: MutexGuard<'static, ()>,
        /// `(name, value before the change)` for each mutation, oldest first. `OsString` keeps
        /// non-Unicode originals restorable instead of silently deleting them.
        vars: Vec<(String, Option<OsString>)>,
    }

    impl EnvGuard {
        /// Acquires the global environment lock, blocking until other guarded tests finish.
        fn new() -> Self {
            // A test that panics while holding the guard poisons the mutex. The guarded data is
            // `()`, and the panicking test's `EnvGuard` restores the environment while
            // unwinding. So recover the lock rather than failing every later test with a
            // `PoisonError` that hides the original failure.
            let lock = ENV_TEST_LOCK
                .get_or_init(|| Mutex::new(()))
                .lock()
                .unwrap_or_else(PoisonError::into_inner);
            Self {
                _lock: lock,
                vars: Vec::new(),
            }
        }

        /// Sets `key` to `value`, remembering the previous value for restoration.
        fn set(&mut self, key: &str, value: &str) {
            self.vars.push((key.to_string(), env::var_os(key)));
            env::set_var(key, value);
        }

        /// Unsets `key`, remembering the previous value for restoration.
        fn remove(&mut self, key: &str) {
            self.vars.push((key.to_string(), env::var_os(key)));
            env::remove_var(key);
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // Undo newest-first so a variable touched more than once ends at its first snapshot.
            for (key, original_value) in self.vars.drain(..).rev() {
                match original_value {
                    Some(value) => env::set_var(&key, value),
                    None => env::remove_var(&key),
                }
            }
        }
    }

    /// Creates a fresh temporary directory under `./tmp`, relative to the test's working
    /// directory, rather than the system temp dir.
    fn local_temp_dir() -> TempDir {
        std::fs::create_dir_all("tmp").expect("failed to create ./tmp");
        TempDir::new_in("tmp").unwrap()
    }

    #[test]
    fn get_ai_scratch_dir_with_direct_path() {
        let mut guard = EnvGuard::new();
        guard.set("AI_SCRATCH", "/custom/scratch/path");

        let result = get_ai_scratch_dir().unwrap();
        assert_eq!(result, PathBuf::from("/custom/scratch/path"));
    }

    #[test]
    fn get_ai_scratch_dir_fallback_to_tmpdir() {
        let mut guard = EnvGuard::new();
        guard.remove("AI_SCRATCH");
        guard.set("TMPDIR", "/custom/tmp");

        let result = get_ai_scratch_dir().unwrap();
        assert_eq!(result, PathBuf::from("/custom/tmp"));
    }

    #[test]
    fn get_ai_scratch_dir_defaults_to_slash_tmp() {
        let mut guard = EnvGuard::new();
        guard.remove("AI_SCRATCH");
        guard.remove("TMPDIR");

        let result = get_ai_scratch_dir().unwrap();
        assert_eq!(result, PathBuf::from("/tmp"));
    }

    #[test]
    fn find_git_root_from_path() {
        let _guard = EnvGuard::new(); // Ensure clean environment

        let temp_dir = local_temp_dir();
        std::fs::create_dir(temp_dir.path().join(".git")).unwrap();

        let sub_dir = temp_dir.path().join("subdir").join("deeper");
        std::fs::create_dir_all(&sub_dir).unwrap();

        let result = super::find_git_root_from_path(&sub_dir).unwrap();
        assert_eq!(result, temp_dir.path());
    }

    #[test]
    fn find_git_root_from_path_returns_start_when_it_is_the_root() {
        let _guard = EnvGuard::new();

        let temp_dir = local_temp_dir();
        std::fs::create_dir(temp_dir.path().join(".git")).unwrap();

        let result = super::find_git_root_from_path(temp_dir.path()).unwrap();
        assert_eq!(result, temp_dir.path());
    }

    #[test]
    fn find_git_root_from_path_accepts_git_file() {
        let _guard = EnvGuard::new();

        // Worktrees and submodules use a `.git` *file* pointing at the real git directory.
        let temp_dir = local_temp_dir();
        std::fs::write(temp_dir.path().join(".git"), "gitdir: /elsewhere\n").unwrap();
        let sub_dir = temp_dir.path().join("nested");
        std::fs::create_dir(&sub_dir).unwrap();

        let result = super::find_git_root_from_path(&sub_dir).unwrap();
        assert_eq!(result, temp_dir.path());
    }
}