Skip to main content

mars_agents/platform/
cache.rs

1//! Cache directory naming and root resolution.
2//!
3//! Generates filesystem-safe cache keys from URLs and external identifiers.
4
5use sha2::{Digest, Sha256};
6use std::path::PathBuf;
7
8use crate::error::MarsError;
9
/// Characters invalid in Windows path components.
const INVALID_CHARS: &[char] = &['/', '\\', ':', '<', '>', '"', '|', '?', '*'];

/// Windows reserved device names (case-insensitive).
const RESERVED_NAMES: &[&str] = &[
    "CON", "PRN", "AUX", "NUL", "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7", "COM8",
    "COM9", "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9",
];

/// Upper bound, in bytes, on the length of a generated path component.
const MAX_COMPONENT_BYTES: usize = 200;

/// Returns `true` when `stem` equals a Windows reserved device name, ignoring ASCII case.
fn is_reserved_device_name(stem: &str) -> bool {
    RESERVED_NAMES
        .iter()
        .any(|reserved| stem.eq_ignore_ascii_case(reserved))
}

/// Cuts `name` to at most [`MAX_COMPONENT_BYTES`] on a char boundary, then drops any
/// trailing spaces or dots.
fn truncate_and_trim(name: &mut String) {
    if name.len() > MAX_COMPONENT_BYTES {
        let mut end = MAX_COMPONENT_BYTES;
        while end > 0 && !name.is_char_boundary(end) {
            end -= 1;
        }
        name.truncate(end);
    }

    let kept = name.trim_end_matches(|c: char| c == ' ' || c == '.').len();
    name.truncate(kept);
}

/// Generate a filesystem-safe single path component from an external identifier.
///
/// Rules (applied on all platforms for cross-platform determinism):
/// - Replace `/`, `\`, `:`, `<`, `>`, `"`, `|`, `?`, `*` and ASCII control chars
///   (including NUL) with `_`
/// - Truncate to 200 bytes, never splitting a UTF-8 character
/// - Never end in a space or dot (enforced *after* truncation, so a cut cannot expose one)
/// - Avoid Windows reserved device names (CON, PRN, AUX, NUL, COM1-9, LPT1-9), with or
///   without an extension: `_` is inserted directly after the stem, e.g. `CON` -> `CON_`
///   and `con.txt` -> `con_.txt`
/// - An empty result becomes `_`
///
/// The returned string is never empty and never longer than 200 bytes.
pub fn safe_component(raw: &str) -> String {
    // `is_ascii_control` already covers NUL, so no separate check is needed.
    let mut result: String = raw
        .chars()
        .map(|c| {
            if INVALID_CHARS.contains(&c) || c.is_ascii_control() {
                '_'
            } else {
                c
            }
        })
        .collect();

    // Enforce the size limit first so the checks below see the final text.
    truncate_and_trim(&mut result);

    // Windows resolves reserved names from the part before the first dot, so the guard
    // character must follow the stem; appending it after an extension would not help.
    let stem_len = result.find('.').unwrap_or(result.len());
    if is_reserved_device_name(&result[..stem_len]) {
        result.insert(stem_len, '_');
        // The insertion may push a maximal-length name one byte over the limit.
        truncate_and_trim(&mut result);
    }

    // Empty result becomes underscore
    if result.is_empty() {
        result.push('_');
    }

    result
}
67
68/// Generate a safe component with a hash suffix to prevent collisions.
69///
70/// Returns: `{safe_component(raw, prefix_chars=60)}_{hex8(sha256(raw))}`
71pub fn safe_component_with_hash(raw: &str) -> String {
72    let prefix = safe_component(raw);
73
74    // Truncate prefix to 60 chars for readable portion
75    let prefix_truncated = if prefix.len() > 60 {
76        let mut end = 60;
77        while end > 0 && !prefix.is_char_boundary(end) {
78            end -= 1;
79        }
80        &prefix[..end]
81    } else {
82        &prefix
83    };
84
85    // Compute SHA-256 hash and take first 8 hex chars
86    let mut hasher = Sha256::new();
87    hasher.update(raw.as_bytes());
88    let hash = hasher.finalize();
89    let hash_hex: String = hash.iter().take(4).map(|b| format!("{b:02x}")).collect();
90
91    format!("{prefix_truncated}_{hash_hex}")
92}
93
94/// Generate a cache directory component for a git clone URL.
95pub fn git_cache_component(url: &str) -> Result<String, MarsError> {
96    Ok(safe_component_with_hash(
97        &crate::source::canonical::canonicalize_git_url(url),
98    ))
99}
100
101/// Generate a cache directory component for an archive URL + SHA.
102pub fn archive_cache_component(url: &str, sha: &str) -> Result<String, MarsError> {
103    let combined = format!("{url}@{sha}");
104    Ok(safe_component_with_hash(&combined))
105}
106
107/// Resolve the global cache root directory.
108///
109/// Resolution order:
110/// 1. `MARS_CACHE_DIR` env var
111/// 2. OS cache directory + `mars/cache`
112/// 3. `{cwd}/.mars/cache` fallback
113pub fn global_cache_root() -> Result<PathBuf, MarsError> {
114    if let Some(cache_dir) = std::env::var_os("MARS_CACHE_DIR") {
115        return Ok(PathBuf::from(cache_dir));
116    }
117
118    if let Some(cache_dir) = dirs::cache_dir() {
119        return Ok(cache_dir.join("mars").join("cache"));
120    }
121
122    Ok(std::env::current_dir()
123        .unwrap_or_else(|_| PathBuf::from("."))
124        .join(".mars")
125        .join("cache"))
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;
    use std::ffi::{OsStr, OsString};
    use std::path::Path;

    /// Sets environment variable `key` to `value` for the whole process.
    // NOTE(review): the `unused_unsafe` allowance presumably keeps this warning-free on
    // editions where `set_var` is a safe function — confirm against the crate's edition.
    #[allow(unused_unsafe)]
    fn env_set(key: &str, value: impl AsRef<OsStr>) {
        // SAFETY: only reached through `EnvVarGuard`, which is used solely by `#[serial]`
        // tests in this module, so those tests never mutate the environment concurrently.
        unsafe {
            std::env::set_var(key, value);
        }
    }

    /// Removes environment variable `key` from the process environment.
    // NOTE(review): same `unused_unsafe` rationale as `env_set`.
    #[allow(unused_unsafe)]
    fn env_remove(key: &str) {
        // SAFETY: see `env_set`.
        unsafe {
            std::env::remove_var(key);
        }
    }

    /// Scoped override of a single environment variable.
    ///
    /// Remembers the value present at construction time and puts it back (or removes the
    /// variable again) when dropped.
    struct EnvVarGuard {
        key: String,
        prev: Option<OsString>,
    }

    impl EnvVarGuard {
        /// Snapshot the current value of `key` without modifying it.
        fn capture(key: &str) -> Self {
            Self {
                key: key.to_string(),
                prev: std::env::var_os(key),
            }
        }

        /// Override `key` with the given path until the guard is dropped.
        fn set_path(key: &str, value: &Path) -> Self {
            let guard = Self::capture(key);
            env_set(key, value);
            guard
        }

        /// Unset `key` until the guard is dropped.
        fn remove(key: &str) -> Self {
            let guard = Self::capture(key);
            env_remove(key);
            guard
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match &self.prev {
                Some(previous) => env_set(&self.key, previous),
                None => env_remove(&self.key),
            }
        }
    }

    #[test]
    fn safe_component_replaces_invalid_chars() {
        let cases = [
            ("a/b\\c:d", "a_b_c_d"),
            ("file<>name", "file__name"),
            ("test|file", "test_file"),
        ];
        for (input, expected) in cases {
            assert_eq!(safe_component(input), expected);
        }
    }

    #[test]
    fn safe_component_handles_trailing_space_dot() {
        for input in ["test ", "test.", "test. "] {
            assert_eq!(safe_component(input), "test");
        }
    }

    #[test]
    fn safe_component_handles_reserved_names() {
        let cases = [
            ("CON", "CON_"),
            ("con", "con_"),
            ("NUL", "NUL_"),
            ("COM1", "COM1_"),
            ("lpt9", "lpt9_"),
        ];
        for (input, expected) in cases {
            assert_eq!(safe_component(input), expected);
        }
    }

    #[test]
    fn safe_component_handles_empty() {
        for input in ["", "..."] {
            assert_eq!(safe_component(input), "_");
        }
    }

    #[test]
    fn safe_component_with_hash_prevents_collisions() {
        // All three inputs sanitize to the same text; only the hash suffix separates them.
        let from_colon = safe_component_with_hash("a:b");
        let from_slash = safe_component_with_hash("a/b");
        let from_underscore = safe_component_with_hash("a_b");

        assert_ne!(from_colon, from_slash);
        assert_ne!(from_slash, from_underscore);
        assert_ne!(from_colon, from_underscore);
    }

    #[test]
    fn git_cache_component_no_colon() {
        // A URL with an explicit port must not leak its colon into the component.
        let result = git_cache_component("git://gitlab.localtest.me:19424/group/pkg.git").unwrap();
        assert!(
            !result.contains(':'),
            "cache component should not contain colon: {result}"
        );
    }

    #[test]
    fn git_cache_component_various_urls() {
        // Every supported URL flavour must yield a colon-free component.
        let urls = [
            "https://github.com/foo/bar",
            "git@github.com:foo/bar.git",
            "ssh://git@github.com/foo/bar",
            "git://host:1234/path.git",
        ];
        for url in urls {
            let result = git_cache_component(url).unwrap();
            assert!(
                !result.contains(':'),
                "URL {url} produced component with colon: {result}"
            );
        }
    }

    #[test]
    fn archive_cache_component_no_colon() {
        let result = archive_cache_component("https://host:8080/archive.tar.gz", "abc123").unwrap();
        assert!(
            !result.contains(':'),
            "archive component should not contain colon: {result}"
        );
    }

    #[test]
    #[serial]
    fn global_cache_root_respects_env_var() {
        let expected = std::env::temp_dir().join("mars-test-cache");
        let _guard = EnvVarGuard::set_path("MARS_CACHE_DIR", &expected);

        let root = global_cache_root().unwrap();
        assert_eq!(root, expected);
    }

    #[test]
    #[serial]
    fn global_cache_root_uses_os_cache_when_no_env() {
        let _guard = EnvVarGuard::remove("MARS_CACHE_DIR");

        let root = global_cache_root().unwrap();

        match dirs::cache_dir() {
            Some(cache_dir) => assert_eq!(root, cache_dir.join("mars").join("cache")),
            None => assert!(
                root.ends_with(Path::new(".mars").join("cache")),
                "fallback root should end with .mars/cache: {root:?}"
            ),
        }
    }
}
292}