Skip to main content

mars_agents/platform/
cache.rs

1//! Cache directory naming and root resolution.
2//!
3//! Generates filesystem-safe cache keys from URLs and external identifiers.
4
5use sha2::{Digest, Sha256};
6use std::path::PathBuf;
7
8use crate::error::MarsError;
9
/// Characters invalid in Windows path components.
const INVALID_CHARS: &[char] = &['/', '\\', ':', '<', '>', '"', '|', '?', '*'];

/// Windows reserved device names (matched ASCII case-insensitively).
///
/// Windows treats these as devices both on their own and when followed by an
/// extension (`NUL.txt`, `NUL.tar.gz`), i.e. whenever the part of a name
/// before its first `.` is one of them.
const RESERVED_NAMES: &[&str] = &[
    "CON", "PRN", "AUX", "NUL", "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7", "COM8",
    "COM9", "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9",
];

/// Maximum length, in bytes, of a component produced by [`safe_component`].
const MAX_COMPONENT_BYTES: usize = 200;

/// Generate a filesystem-safe single path component from an external identifier.
///
/// Rules (applied on all platforms for cross-platform determinism):
/// - Replace `/`, `\`, `:`, `<`, `>`, `"`, `|`, `?`, `*` and ASCII control
///   chars (including NUL) with `_`
/// - Truncate to 200 bytes, backing off to a UTF-8 character boundary
/// - Avoid trailing space or dot (also re-checked after truncation)
/// - Avoid Windows reserved device names (CON, PRN, AUX, NUL, COM1-9, LPT1-9),
///   with or without an extension, by inserting `_` directly after the
///   reserved stem: `CON` -> `CON_`, `con.txt` -> `con_.txt`
/// - An empty result becomes `_`
pub fn safe_component(raw: &str) -> String {
    let mut result: String = raw
        .chars()
        .map(|c| {
            // `is_ascii_control` already covers NUL ('\0').
            if INVALID_CHARS.contains(&c) || c.is_ascii_control() {
                '_'
            } else {
                c
            }
        })
        .collect();

    truncate_and_trim(&mut result);
    escape_reserved_name(&mut result);
    // Escaping inserts one byte, which can push a name sitting exactly at the
    // limit one byte over. The escape lives in the first few bytes, so
    // re-truncating from the end cannot undo it.
    truncate_and_trim(&mut result);

    if result.is_empty() {
        result.push('_');
    }
    result
}

/// Truncate `s` to at most [`MAX_COMPONENT_BYTES`] bytes on a char boundary,
/// then strip trailing spaces and dots (which Windows does not reliably
/// support at the end of a name, and which truncation can newly expose).
fn truncate_and_trim(s: &mut String) {
    if s.len() > MAX_COMPONENT_BYTES {
        let mut end = MAX_COMPONENT_BYTES;
        // Index 0 is always a char boundary, so this terminates.
        while !s.is_char_boundary(end) {
            end -= 1;
        }
        s.truncate(end);
    }
    let kept = s.trim_end_matches(|c: char| c == ' ' || c == '.').len();
    s.truncate(kept);
}

/// If the part of `s` before its first `.` is a Windows reserved device name
/// (ASCII case-insensitive), insert `_` right after that stem so the name no
/// longer refers to a device. Appending at the very end would not be enough:
/// `CON.txt_` still has the reserved stem `CON`.
fn escape_reserved_name(s: &mut String) {
    let stem_len = s.find('.').unwrap_or(s.len());
    let stem = &s[..stem_len];
    let reserved = RESERVED_NAMES
        .iter()
        .any(|name| stem.eq_ignore_ascii_case(name));
    if reserved {
        // `stem_len` is either `s.len()` or the byte index of an ASCII '.',
        // so it is always a valid char boundary.
        s.insert(stem_len, '_');
    }
}
67
68/// Generate a safe component with a hash suffix to prevent collisions.
69///
70/// Returns: `{safe_component(raw, prefix_chars=60)}_{hex8(sha256(raw))}`
71pub fn safe_component_with_hash(raw: &str) -> String {
72    let prefix = safe_component(raw);
73
74    // Truncate prefix to 60 chars for readable portion
75    let prefix_truncated = if prefix.len() > 60 {
76        let mut end = 60;
77        while end > 0 && !prefix.is_char_boundary(end) {
78            end -= 1;
79        }
80        &prefix[..end]
81    } else {
82        &prefix
83    };
84
85    // Compute SHA-256 hash and take first 8 hex chars
86    let mut hasher = Sha256::new();
87    hasher.update(raw.as_bytes());
88    let hash = hasher.finalize();
89    let hash_hex: String = hash.iter().take(4).map(|b| format!("{b:02x}")).collect();
90
91    format!("{prefix_truncated}_{hash_hex}")
92}
93
94/// Generate a cache directory component for a git clone URL.
95pub fn git_cache_component(url: &str) -> Result<String, MarsError> {
96    Ok(safe_component_with_hash(normalize_git_url(url)))
97}
98
99/// Generate a cache directory component for an archive URL + SHA.
100pub fn archive_cache_component(url: &str, sha: &str) -> Result<String, MarsError> {
101    let combined = format!("{url}@{sha}");
102    Ok(safe_component_with_hash(&combined))
103}
104
/// Normalize a git URL for cache key generation.
///
/// Steps, in order:
/// 1. Strip at most one leading scheme: `https://`, `http://`, `ssh://`, or
///    `git://`.
/// 2. Strip a leading `git@` user (SSH shorthand:
///    `git@github.com:foo/bar` -> `github.com:foo/bar`). The colon is kept
///    here; `safe_component` replaces it later.
/// 3. Strip trailing `/`s, then a trailing `.git`, then any `/` that exposes.
///    As a result, `…/bar`, `…/bar/`, `…/bar.git` and `…/bar.git/` all
///    normalize to `…/bar`.
///
/// Matching is case-sensitive. The result is a sub-slice of `url`.
fn normalize_git_url(url: &str) -> &str {
    let s = ["https://", "http://", "ssh://", "git://"]
        .iter()
        .find_map(|scheme| url.strip_prefix(scheme))
        .unwrap_or(url);

    let s = s.strip_prefix("git@").unwrap_or(s);

    // Trim slashes *before* looking for `.git`, otherwise `bar.git/` keeps
    // its suffix and gets a different cache key than `bar.git`.
    let s = s.trim_end_matches('/');
    let s = s.strip_suffix(".git").unwrap_or(s);
    s.trim_end_matches('/')
}
133
134/// Resolve the global cache root directory.
135///
136/// Resolution order:
137/// 1. `MARS_CACHE_DIR` env var
138/// 2. OS cache directory + `mars/cache`
139/// 3. `{cwd}/.mars/cache` fallback
140pub fn global_cache_root() -> Result<PathBuf, MarsError> {
141    if let Some(cache_dir) = std::env::var_os("MARS_CACHE_DIR") {
142        return Ok(PathBuf::from(cache_dir));
143    }
144
145    if let Some(cache_dir) = dirs::cache_dir() {
146        return Ok(cache_dir.join("mars").join("cache"));
147    }
148
149    Ok(std::env::current_dir()
150        .unwrap_or_else(|_| PathBuf::from("."))
151        .join(".mars")
152        .join("cache"))
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;
    use std::ffi::{OsStr, OsString};
    use std::path::Path;

    // `std::env::set_var` / `remove_var` are `unsafe fn` in the Rust 2024
    // edition and safe before it. Wrapping the calls in `unsafe` keeps these
    // helpers compiling under either edition; `#[allow(unused_unsafe)]`
    // silences the warning those blocks produce on pre-2024 editions.

    /// Set `key` to `value` in the process environment.
    #[allow(unused_unsafe)]
    fn env_set(key: &str, value: impl AsRef<OsStr>) {
        // SAFETY: mutating the environment is only sound if no other thread
        // reads or writes it concurrently. The env-mutating tests in this
        // module are `#[serial]`, so they do not race each other.
        // NOTE(review): assumes no test running in parallel elsewhere in the
        // crate touches the environment — confirm.
        unsafe {
            std::env::set_var(key, value);
        }
    }

    /// Remove `key` from the process environment.
    #[allow(unused_unsafe)]
    fn env_remove(key: &str) {
        // SAFETY: same reasoning as in `env_set`.
        unsafe {
            std::env::remove_var(key);
        }
    }

    /// Scoped override of one environment variable. On drop it restores the
    /// previous state: the old value if there was one, otherwise unset.
    struct EnvVarGuard {
        key: String,
        prev: Option<OsString>,
    }

    impl EnvVarGuard {
        /// Set `key` to the path `value` for the guard's lifetime.
        fn set_path(key: &str, value: &Path) -> Self {
            let prev = std::env::var_os(key);
            env_set(key, value);
            Self {
                key: key.to_string(),
                prev,
            }
        }

        /// Unset `key` for the guard's lifetime.
        fn remove(key: &str) -> Self {
            let prev = std::env::var_os(key);
            env_remove(key);
            Self {
                key: key.to_string(),
                prev,
            }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match &self.prev {
                Some(prev) => env_set(&self.key, prev),
                None => env_remove(&self.key),
            }
        }
    }

    #[test]
    fn safe_component_replaces_invalid_chars() {
        assert_eq!(safe_component("a/b\\c:d"), "a_b_c_d");
        assert_eq!(safe_component("file<>name"), "file__name");
        assert_eq!(safe_component("test|file"), "test_file");
    }

    #[test]
    fn safe_component_handles_trailing_space_dot() {
        assert_eq!(safe_component("test "), "test");
        assert_eq!(safe_component("test."), "test");
        assert_eq!(safe_component("test. "), "test");
    }

    #[test]
    fn safe_component_handles_reserved_names() {
        assert_eq!(safe_component("CON"), "CON_");
        assert_eq!(safe_component("con"), "con_");
        assert_eq!(safe_component("NUL"), "NUL_");
        assert_eq!(safe_component("COM1"), "COM1_");
        assert_eq!(safe_component("lpt9"), "lpt9_");
    }

    #[test]
    fn safe_component_handles_empty() {
        assert_eq!(safe_component(""), "_");
        assert_eq!(safe_component("..."), "_");
    }

    #[test]
    fn safe_component_with_hash_prevents_collisions() {
        // These would collide without hash.
        let a = safe_component_with_hash("a:b");
        let b = safe_component_with_hash("a/b");
        let c = safe_component_with_hash("a_b");

        // All different due to hash suffix.
        assert_ne!(a, b);
        assert_ne!(b, c);
        assert_ne!(a, c);
    }

    #[test]
    fn git_cache_component_no_colon() {
        // Explicit port URL must not have colon in result.
        let result = git_cache_component("git://gitlab.localtest.me:19424/group/pkg.git").unwrap();
        assert!(
            !result.contains(':'),
            "cache component should not contain colon: {result}"
        );
    }

    #[test]
    fn git_cache_component_various_urls() {
        // All should produce valid components without colons.
        let urls = [
            "https://github.com/foo/bar",
            "git@github.com:foo/bar.git",
            "ssh://git@github.com/foo/bar",
            "git://host:1234/path.git",
        ];
        for url in urls {
            let result = git_cache_component(url).unwrap();
            assert!(
                !result.contains(':'),
                "URL {url} produced component with colon: {result}"
            );
        }
    }

    #[test]
    fn archive_cache_component_no_colon() {
        let result = archive_cache_component("https://host:8080/archive.tar.gz", "abc123").unwrap();
        assert!(
            !result.contains(':'),
            "archive component should not contain colon: {result}"
        );
    }

    #[test]
    #[serial]
    fn global_cache_root_respects_env_var() {
        let temp = std::env::temp_dir().join("mars-test-cache");
        let _guard = EnvVarGuard::set_path("MARS_CACHE_DIR", &temp);

        let root = global_cache_root().unwrap();
        assert_eq!(root, temp);
    }

    #[test]
    #[serial]
    fn global_cache_root_uses_os_cache_when_no_env() {
        let _guard = EnvVarGuard::remove("MARS_CACHE_DIR");

        let root = global_cache_root().unwrap();

        if let Some(cache_dir) = dirs::cache_dir() {
            assert_eq!(root, cache_dir.join("mars").join("cache"));
        } else {
            assert!(
                root.ends_with(Path::new(".mars").join("cache")),
                "fallback root should end with .mars/cache: {root:?}"
            );
        }
    }
}