// omni_dev/utils/ai_scratch.rs
use std::env;
use std::path::{Path, PathBuf};

use anyhow::{Context, Result};
7
8pub fn get_ai_scratch_dir() -> Result<PathBuf> {
10 if let Ok(ai_scratch) = env::var("AI_SCRATCH") {
12 if let Some(git_root_path) = ai_scratch.strip_prefix("git-root:") {
13 let git_root = find_git_root()?;
15 Ok(git_root.join(git_root_path))
16 } else {
17 Ok(PathBuf::from(ai_scratch))
19 }
20 } else {
21 let tmpdir = env::var("TMPDIR").unwrap_or_else(|_| "/tmp".to_string());
23 Ok(PathBuf::from(tmpdir))
24 }
25}
26
27fn find_git_root() -> Result<PathBuf> {
29 let current_dir = env::current_dir().context("Failed to get current directory")?;
30 find_git_root_from_path(¤t_dir)
31}
32
33fn find_git_root_from_path(start_path: &Path) -> Result<PathBuf> {
35 let mut current = start_path;
36
37 loop {
38 let git_dir = current.join(".git");
39 if git_dir.exists() {
40 return Ok(current.to_path_buf());
41 }
42
43 match current.parent() {
44 Some(parent) => current = parent,
45 None => {
46 return Err(anyhow::anyhow!(
47 "No git repository found in current directory or any parent directory"
48 ))
49 }
50 }
51 }
52}
53
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use std::env;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard, OnceLock, PoisonError};

    use tempfile::TempDir;

    /// Serialises every test that mutates process-wide environment variables.
    static ENV_TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();

    /// Holds `ENV_TEST_LOCK` for the lifetime of a test and, when dropped, restores
    /// every environment variable changed through it.
    struct EnvGuard {
        // Fields are dropped after `Drop::drop` runs, so the lock is still held while
        // the variables are being restored.
        _lock: MutexGuard<'static, ()>,
        // (name, value before the change). `OsString` so a non-UTF-8 original value
        // is restored faithfully instead of being deleted.
        vars: Vec<(String, Option<OsString>)>,
    }

    impl EnvGuard {
        fn new() -> Self {
            // A test that panics while holding the lock poisons it. `Drop` has already
            // restored the environment by then, so recover the guard rather than make
            // every later env test fail with a PoisonError.
            let lock = ENV_TEST_LOCK
                .get_or_init(|| Mutex::new(()))
                .lock()
                .unwrap_or_else(PoisonError::into_inner);
            Self {
                _lock: lock,
                vars: Vec::new(),
            }
        }

        fn set(&mut self, key: &str, value: &str) {
            self.vars.push((key.to_string(), env::var_os(key)));
            env::set_var(key, value);
        }

        fn remove(&mut self, key: &str) {
            self.vars.push((key.to_string(), env::var_os(key)));
            env::remove_var(key);
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // Undo newest-first so repeated edits of one key end at its original value.
            for (key, original_value) in self.vars.drain(..).rev() {
                match original_value {
                    Some(value) => env::set_var(&key, value),
                    None => env::remove_var(&key),
                }
            }
        }
    }

    #[test]
    fn get_ai_scratch_dir_with_direct_path() {
        let mut guard = EnvGuard::new();
        guard.set("AI_SCRATCH", "/custom/scratch/path");

        let result = get_ai_scratch_dir().unwrap();
        assert_eq!(result, PathBuf::from("/custom/scratch/path"));
    }

    #[test]
    fn get_ai_scratch_dir_fallback_to_tmpdir() {
        let mut guard = EnvGuard::new();
        guard.remove("AI_SCRATCH");
        guard.set("TMPDIR", "/custom/tmp");

        let result = get_ai_scratch_dir().unwrap();
        assert_eq!(result, PathBuf::from("/custom/tmp"));
    }

    #[test]
    fn get_ai_scratch_dir_fallback_to_slash_tmp_without_tmpdir() {
        let mut guard = EnvGuard::new();
        guard.remove("AI_SCRATCH");
        guard.remove("TMPDIR");

        let result = get_ai_scratch_dir().unwrap();
        assert_eq!(result, PathBuf::from("/tmp"));
    }

    #[test]
    fn find_git_root_from_path() {
        // Serialise with the env-mutating tests above.
        let _guard = EnvGuard::new();
        // Fixtures live under ./tmp, relative to the test's working directory.
        std::fs::create_dir_all("tmp").expect("failed to create ./tmp for test fixtures");
        let temp_dir = TempDir::new_in("tmp").unwrap();

        let git_dir = temp_dir.path().join(".git");
        std::fs::create_dir(&git_dir).unwrap();

        let sub_dir = temp_dir.path().join("subdir").join("deeper");
        std::fs::create_dir_all(&sub_dir).unwrap();

        // Found by walking up from a nested directory.
        let result = super::find_git_root_from_path(&sub_dir).unwrap();
        assert_eq!(result, temp_dir.path());

        // The start path itself is checked before its ancestors.
        let result = super::find_git_root_from_path(temp_dir.path()).unwrap();
        assert_eq!(result, temp_dir.path());
    }
}