mars_agents/platform/
cache.rs1use sha2::{Digest, Sha256};
6use std::path::PathBuf;
7
8use crate::error::MarsError;
9
/// Characters that may not appear in a path component; each is replaced with `_`.
const INVALID_CHARS: &[char] = &['/', '\\', ':', '<', '>', '"', '|', '?', '*'];

/// Windows device names. A component whose part before the first `.` equals one
/// of these (ignoring ASCII case) must be altered so it does not name the device.
const RESERVED_NAMES: &[&str] = &[
    "CON", "PRN", "AUX", "NUL", "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7", "COM8",
    "COM9", "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9",
];

/// Maximum length, in bytes, of a component returned by `safe_component`.
const MAX_COMPONENT_LEN: usize = 200;

/// Converts an arbitrary string into a single, portable path component.
///
/// The transformation:
/// 1. replaces every character in `INVALID_CHARS` and every ASCII control
///    character with `_`;
/// 2. truncates to at most `MAX_COMPONENT_LEN` bytes on a UTF-8 char boundary;
/// 3. strips trailing spaces and dots (not allowed at the end of a Windows name);
/// 4. if the part before the first `.` is a reserved device name, inserts `_`
///    directly after it (`"CON"` -> `"CON_"`, `"nul.txt"` -> `"nul_.txt"`);
/// 5. returns `"_"` if nothing is left.
///
/// The result is never empty and never longer than `MAX_COMPONENT_LEN` bytes.
/// Distinct inputs can map to the same output (e.g. `"a:b"` and `"a/b"`); use
/// [`safe_component_with_hash`] when uniqueness matters.
pub fn safe_component(raw: &str) -> String {
    let mut result: String = raw
        .chars()
        .map(|c| {
            if INVALID_CHARS.contains(&c) || c.is_ascii_control() {
                '_'
            } else {
                c
            }
        })
        .collect();

    // Truncate *before* trimming so the cut can never leave a trailing dot or space.
    truncate_at_char_boundary(&mut result, MAX_COMPONENT_LEN);
    trim_trailing_dots_and_spaces(&mut result);

    if let Some(stem_len) = reserved_stem_len(&result) {
        // Windows treats "NUL.txt" like "NUL", so the `_` must break up the stem itself,
        // not be appended after the extension. Make room for it to respect the limit;
        // the stem is at most 4 ASCII bytes and cannot be trimmed away.
        truncate_at_char_boundary(&mut result, MAX_COMPONENT_LEN - 1);
        trim_trailing_dots_and_spaces(&mut result);
        result.insert(stem_len, '_');
    }

    if result.is_empty() {
        result.push('_');
    }

    result
}

/// Shortens `s` to at most `max_len` bytes without splitting a UTF-8 character.
fn truncate_at_char_boundary(s: &mut String, max_len: usize) {
    if s.len() > max_len {
        // `is_char_boundary(0)` is always true, so `find` always succeeds.
        let end = (0..=max_len)
            .rev()
            .find(|&i| s.is_char_boundary(i))
            .unwrap_or(0);
        s.truncate(end);
    }
}

/// Removes any run of trailing `' '` and `'.'` characters from `s`.
fn trim_trailing_dots_and_spaces(s: &mut String) {
    let kept = s.trim_end_matches(|c: char| c == ' ' || c == '.').len();
    s.truncate(kept);
}

/// Returns the byte length of `name`'s stem (the part before the first `.`)
/// if that stem is a reserved device name, ignoring ASCII case.
fn reserved_stem_len(name: &str) -> Option<usize> {
    let stem_len = name.find('.').unwrap_or(name.len());
    let stem = &name[..stem_len];
    if RESERVED_NAMES
        .iter()
        .any(|reserved| stem.eq_ignore_ascii_case(reserved))
    {
        Some(stem_len)
    } else {
        None
    }
}
67
68pub fn safe_component_with_hash(raw: &str) -> String {
72 let prefix = safe_component(raw);
73
74 let prefix_truncated = if prefix.len() > 60 {
76 let mut end = 60;
77 while end > 0 && !prefix.is_char_boundary(end) {
78 end -= 1;
79 }
80 &prefix[..end]
81 } else {
82 &prefix
83 };
84
85 let mut hasher = Sha256::new();
87 hasher.update(raw.as_bytes());
88 let hash = hasher.finalize();
89 let hash_hex: String = hash.iter().take(4).map(|b| format!("{b:02x}")).collect();
90
91 format!("{prefix_truncated}_{hash_hex}")
92}
93
94pub fn git_cache_component(url: &str) -> Result<String, MarsError> {
96 Ok(safe_component_with_hash(normalize_git_url(url)))
97}
98
99pub fn archive_cache_component(url: &str, sha: &str) -> Result<String, MarsError> {
101 let combined = format!("{url}@{sha}");
102 Ok(safe_component_with_hash(&combined))
103}
104
/// Strips transport details from a git URL so that equivalent spellings of the
/// same remote share a cache entry.
///
/// Removes, in order:
/// 1. a leading `https://`, `http://`, `ssh://` or `git://` scheme;
/// 2. a `git@` user prefix;
/// 3. trailing `/`s;
/// 4. a `.git` suffix;
/// 5. any `/`s left trailing after that.
///
/// For example, `https://host/repo.git/`, `https://host/repo.git` and
/// `https://host/repo` all become `host/repo`. Host, port and path are otherwise
/// kept verbatim (no case folding), and scp-style `host:path` remotes keep their `:`.
fn normalize_git_url(url: &str) -> &str {
    let without_scheme = ["https://", "http://", "ssh://", "git://"]
        .iter()
        .find_map(|scheme| url.strip_prefix(*scheme))
        .unwrap_or(url);

    let without_user = without_scheme
        .strip_prefix("git@")
        .unwrap_or(without_scheme);

    // Trim slashes before looking for `.git`, so "repo.git/" normalizes like "repo.git".
    let trimmed = without_user.trim_end_matches('/');
    let without_git = trimmed.strip_suffix(".git").unwrap_or(trimmed);

    without_git.trim_end_matches('/')
}
133
134pub fn global_cache_root() -> Result<PathBuf, MarsError> {
141 if let Some(cache_dir) = std::env::var_os("MARS_CACHE_DIR") {
142 return Ok(PathBuf::from(cache_dir));
143 }
144
145 if let Some(cache_dir) = dirs::cache_dir() {
146 return Ok(cache_dir.join("mars").join("cache"));
147 }
148
149 Ok(std::env::current_dir()
150 .unwrap_or_else(|_| PathBuf::from("."))
151 .join(".mars")
152 .join("cache"))
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;
    use std::ffi::{OsStr, OsString};
    use std::path::Path;

    // `std::env::set_var` / `remove_var` are `unsafe fn`s from edition 2024 onward but
    // safe in earlier editions; `allow(unused_unsafe)` lets these helpers compile
    // without warnings under either edition.

    /// Sets `key` to `value` in the process environment.
    #[allow(unused_unsafe)]
    fn env_set(key: &str, value: impl AsRef<OsStr>) {
        // SAFETY: only used by `#[serial]` tests, which never run concurrently with
        // each other. NOTE(review): this also relies on no non-serial test reading or
        // writing the environment at the same time.
        unsafe {
            std::env::set_var(key, value);
        }
    }

    /// Removes `key` from the process environment.
    #[allow(unused_unsafe)]
    fn env_remove(key: &str) {
        // SAFETY: same reasoning as `env_set`.
        unsafe {
            std::env::remove_var(key);
        }
    }

    /// Restores an environment variable to its previous value (or removes it if it
    /// was unset) when dropped, so a test's changes do not leak into other tests.
    struct EnvVarGuard {
        key: String,
        prev: Option<OsString>,
    }

    impl EnvVarGuard {
        /// Records the current value of `key` without modifying it.
        fn capture(key: &str) -> Self {
            Self {
                key: key.to_string(),
                prev: std::env::var_os(key),
            }
        }

        fn set_path(key: &str, value: &Path) -> Self {
            let guard = Self::capture(key);
            env_set(key, value);
            guard
        }

        fn remove(key: &str) -> Self {
            let guard = Self::capture(key);
            env_remove(key);
            guard
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.prev.take() {
                Some(prev) => env_set(&self.key, prev),
                None => env_remove(&self.key),
            }
        }
    }

    #[test]
    fn safe_component_replaces_invalid_chars() {
        assert_eq!(safe_component("a/b\\c:d"), "a_b_c_d");
        assert_eq!(safe_component("file<>name"), "file__name");
        assert_eq!(safe_component("test|file"), "test_file");
    }

    #[test]
    fn safe_component_handles_trailing_space_dot() {
        assert_eq!(safe_component("test "), "test");
        assert_eq!(safe_component("test."), "test");
        assert_eq!(safe_component("test. "), "test");
    }

    #[test]
    fn safe_component_handles_reserved_names() {
        assert_eq!(safe_component("CON"), "CON_");
        assert_eq!(safe_component("con"), "con_");
        assert_eq!(safe_component("NUL"), "NUL_");
        assert_eq!(safe_component("COM1"), "COM1_");
        assert_eq!(safe_component("lpt9"), "lpt9_");
    }

    #[test]
    fn safe_component_handles_empty() {
        assert_eq!(safe_component(""), "_");
        assert_eq!(safe_component("..."), "_");
    }

    #[test]
    fn safe_component_with_hash_prevents_collisions() {
        let a = safe_component_with_hash("a:b");
        let b = safe_component_with_hash("a/b");
        let c = safe_component_with_hash("a_b");

        assert_ne!(a, b);
        assert_ne!(b, c);
        assert_ne!(a, c);
    }

    #[test]
    fn git_cache_component_no_colon() {
        let result = git_cache_component("git://gitlab.localtest.me:19424/group/pkg.git").unwrap();
        assert!(
            !result.contains(':'),
            "cache component should not contain colon: {result}"
        );
    }

    #[test]
    fn git_cache_component_various_urls() {
        let urls = [
            "https://github.com/foo/bar",
            "git@github.com:foo/bar.git",
            "ssh://git@github.com/foo/bar",
            "git://host:1234/path.git",
        ];
        for url in urls {
            let result = git_cache_component(url).unwrap();
            assert!(
                !result.contains(':'),
                "URL {url} produced component with colon: {result}"
            );
        }
    }

    #[test]
    fn archive_cache_component_no_colon() {
        let result = archive_cache_component("https://host:8080/archive.tar.gz", "abc123").unwrap();
        assert!(
            !result.contains(':'),
            "archive component should not contain colon: {result}"
        );
    }

    #[test]
    #[serial]
    fn global_cache_root_respects_env_var() {
        let temp = std::env::temp_dir().join("mars-test-cache");
        let _guard = EnvVarGuard::set_path("MARS_CACHE_DIR", &temp);

        let root = global_cache_root().unwrap();
        assert_eq!(root, temp);
    }

    #[test]
    #[serial]
    fn global_cache_root_uses_os_cache_when_no_env() {
        let _guard = EnvVarGuard::remove("MARS_CACHE_DIR");

        let root = global_cache_root().unwrap();

        if let Some(cache_dir) = dirs::cache_dir() {
            assert_eq!(root, cache_dir.join("mars").join("cache"));
        } else {
            assert!(
                root.ends_with(Path::new(".mars").join("cache")),
                "fallback root should end with .mars/cache: {root:?}"
            );
        }
    }
}