Skip to main content

runex_core/
precache.rs

1use std::collections::HashMap;
2use std::hash::{DefaultHasher, Hash, Hasher};
3use std::path::Path;
4
5use serde::{Deserialize, Serialize};
6
7use crate::model::Config;
8
/// Name of the environment variable holding the serialized command-existence
/// cache. `export_statement` emits a statement that sets it and `load_cache`
/// reads it back.
pub const CACHE_ENV_VAR: &str = "RUNEX_CMD_CACHE_V1";

/// Cache format version this build writes; `parse_cache` rejects any other value.
const CACHE_VERSION: u32 = 1;

/// Upper bound (256 KiB) on the byte length of the raw JSON taken from the
/// environment variable.
///
/// Checked before any parsing so an oversized, possibly hostile value cannot
/// cost much memory or CPU. For scale: 10 000 rules at roughly 25 bytes per
/// entry come to about 250 KB.
const MAX_CACHE_BYTES: usize = 1 << 18;

/// Upper bound on the number of entries in `CmdCache::commands`.
///
/// Sized to mirror `MAX_ABBR_RULES` in config validation. Because one rule may
/// list several commands this is a sizing heuristic rather than a strict
/// invariant.
const MAX_CACHE_COMMANDS: usize = 10_000;

/// Length of a fingerprint string: a `u64` rendered as zero-padded hex takes
/// two digits per byte, i.e. 16 characters.
const FINGERPRINT_LEN: usize = 2 * std::mem::size_of::<u64>();
27
28/// Serialized command existence cache.
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct CmdCache {
31    pub v: u32,
32    pub fingerprint: String,
33    pub commands: HashMap<String, bool>,
34}
35
/// Compute the staleness fingerprint for a cache built under the given `PATH`
/// value, config modification time and shell name.
///
/// The three inputs are fed, in that order, into [`DefaultHasher`]. The 64-bit
/// digest is rendered as 16 lowercase, zero-padded hex digits.
///
/// This is a fast non-cryptographic hash meant only to detect a stale cache,
/// not tampering: anyone able to rewrite the cache variable can rewrite `PATH`
/// as well. `DefaultHasher`'s algorithm is not guaranteed to stay the same
/// across Rust releases. A fingerprint produced by a `runex` built with a
/// different toolchain may therefore not match, in which case `load_cache`
/// simply returns `None`.
pub fn compute_fingerprint(path_env: &str, config_mtime: u64, shell: &str) -> String {
    let mut state = DefaultHasher::new();
    // Tuple hashing feeds each element in order, exactly like three separate
    // `.hash()` calls would.
    (path_env, config_mtime, shell).hash(&mut state);
    let digest: u64 = state.finish();
    format!("{digest:016x}")
}
48
49/// Collect all unique command names referenced by `when_command_exists`
50/// across all abbreviation rules in the config.
51pub fn collect_unique_commands(config: &Config) -> Vec<String> {
52    let mut seen = std::collections::HashSet::new();
53    let mut result = Vec::new();
54    for abbr in &config.abbr {
55        if let Some(cmds) = &abbr.when_command_exists {
56            for cmd_list in cmds.all_values() {
57                for cmd in cmd_list {
58                    if seen.insert(cmd.clone()) {
59                        result.push(cmd.clone());
60                    }
61                }
62            }
63        }
64    }
65    result
66}
67
68/// Build a cache by checking each command with the provided closure.
69pub fn build_cache<F>(
70    config: &Config,
71    fingerprint: &str,
72    command_exists: F,
73) -> CmdCache
74where
75    F: Fn(&str) -> bool,
76{
77    let cmds = collect_unique_commands(config);
78    let mut commands = HashMap::new();
79    for cmd in cmds {
80        commands.insert(cmd.clone(), command_exists(&cmd));
81    }
82    CmdCache {
83        v: CACHE_VERSION,
84        fingerprint: fingerprint.to_string(),
85        commands,
86    }
87}
88
/// Parse a `"cmd1=1,cmd2=0,..."` string into a `HashMap<String, bool>`.
///
/// Entries are separated by `,`. Whitespace around each entry, and around the
/// command name and value on either side of the first `=`, is ignored, so
/// `"lsd = 1, bat=0"` parses the same as `"lsd=1,bat=0"`.
///
/// A value of exactly `1` means the command exists; any other value is treated
/// as `false`. Entries without `=` (including empty ones) are skipped. If a
/// command appears more than once, the last entry wins.
pub fn parse_resolved(resolved: &str) -> HashMap<String, bool> {
    resolved
        .split(',')
        // An entry with no `=` (blank, or a bare name) carries no answer.
        .filter_map(|entry| entry.split_once('='))
        // Trim each side separately: trimming only the whole entry would keep
        // the spaces in "lsd = 1" and store key "lsd " with value false.
        .map(|(cmd, val)| (cmd.trim().to_string(), val.trim() == "1"))
        .collect()
}
105
106/// Build a cache from externally resolved command existence results.
107///
108/// Used when the calling shell (e.g. PowerShell) checks command existence
109/// via `Get-Command` instead of `which::which()`. Only commands referenced
110/// in the config's `when_command_exists` are included.
111pub fn build_cache_from_resolved(
112    config: &Config,
113    fingerprint: &str,
114    resolved_str: &str,
115) -> CmdCache {
116    let resolved = parse_resolved(resolved_str);
117    let cmds = collect_unique_commands(config);
118    let mut commands = HashMap::new();
119    for cmd in cmds {
120        let exists = resolved.get(&cmd).copied().unwrap_or(false);
121        commands.insert(cmd, exists);
122    }
123    CmdCache {
124        v: CACHE_VERSION,
125        fingerprint: fingerprint.to_string(),
126        commands,
127    }
128}
129
130/// Serialize a cache to JSON.
131pub fn cache_to_json(cache: &CmdCache) -> String {
132    serde_json::to_string(cache).unwrap_or_default()
133}
134
135/// Parse a cache from JSON, returning None on any failure.
136///
137/// Rejects inputs that are too large, have too many command entries, use an
138/// unexpected version, or have a malformed fingerprint. This is a
139/// defense-in-depth measure — the cache is untrusted input from an
140/// environment variable.
141pub fn parse_cache(json: &str) -> Option<CmdCache> {
142    if json.len() > MAX_CACHE_BYTES {
143        return None;
144    }
145    let cache: CmdCache = serde_json::from_str(json).ok()?;
146    if cache.v != CACHE_VERSION {
147        return None;
148    }
149    if cache.fingerprint.len() != FINGERPRINT_LEN
150        || !cache.fingerprint.chars().all(|c| c.is_ascii_hexdigit())
151    {
152        return None;
153    }
154    if cache.commands.len() > MAX_CACHE_COMMANDS {
155        return None;
156    }
157    Some(cache)
158}
159
160/// Load and validate a cache from the environment variable.
161///
162/// Returns None if:
163/// - env var is absent or empty
164/// - JSON is malformed
165/// - version != 1
166/// - fingerprint does not match expected
167pub fn load_cache(expected_fingerprint: &str) -> Option<CmdCache> {
168    let json = std::env::var(CACHE_ENV_VAR).ok()?;
169    let cache = parse_cache(&json)?;
170    if cache.fingerprint != expected_fingerprint {
171        return None;
172    }
173    Some(cache)
174}
175
/// Return the modification time of the file at `path` in whole seconds since
/// the Unix epoch.
///
/// Falls back to `0` when the metadata or modification time cannot be read
/// (e.g. the file does not exist), or when the timestamp predates the epoch.
/// Because the resolution is whole seconds, two writes within the same second
/// yield the same value.
pub fn config_mtime(path: &Path) -> u64 {
    let Ok(modified) = std::fs::metadata(path).and_then(|meta| meta.modified()) else {
        return 0;
    };
    match modified.duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
186
187/// Generate a shell export statement for the cache.
188pub fn export_statement(shell: &str, cache_json: &str) -> String {
189    match shell {
190        "bash" | "zsh" => {
191            let escaped = cache_json.replace('\'', "'\\''");
192            format!("export {}='{}'", CACHE_ENV_VAR, escaped)
193        }
194        "pwsh" => {
195            let escaped = cache_json.replace('\'', "''");
196            format!("$env:{}='{}'", CACHE_ENV_VAR, escaped)
197        }
198        "nu" => {
199            let escaped = cache_json.replace('\'', "''");
200            format!("$env.{} = '{}'", CACHE_ENV_VAR, escaped)
201        }
202        "clink" => {
203            // cmd.exe set command — no quotes around value for os.execute
204            let escaped = cache_json.replace('"', "\\\"");
205            format!("set {}={}", CACHE_ENV_VAR, escaped)
206        }
207        _ => {
208            let escaped = cache_json.replace('\'', "'\\''");
209            format!("export {}='{}'", CACHE_ENV_VAR, escaped)
210        }
211    }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{Abbr, Config, KeybindConfig, PerShellCmds, PerShellString, PrecacheConfig};

    /// Build a minimal config containing only the given abbreviation rules.
    fn test_config(abbrs: Vec<Abbr>) -> Config {
        Config {
            version: 1,
            keybind: KeybindConfig::default(),
            precache: PrecacheConfig::default(),
            abbr: abbrs,
        }
    }

    /// Build a rule `key -> exp` whose `when_command_exists` lists `cmds`
    /// (using the `All` variants of the per-shell types).
    fn abbr_when(key: &str, exp: &str, cmds: Vec<&str>) -> Abbr {
        Abbr {
            key: key.into(),
            expand: PerShellString::All(exp.into()),
            when_command_exists: Some(PerShellCmds::All(
                cmds.into_iter().map(String::from).collect(),
            )),
        }
    }

    #[test]
    fn cache_roundtrip() {
        let config = test_config(vec![
            abbr_when("ls", "lsd", vec!["lsd"]),
            abbr_when("7z", "7zip", vec!["7z"]),
        ]);
        let fp = compute_fingerprint("/usr/bin:/bin", 1234567890, "bash");
        let cache = build_cache(&config, &fp, |cmd| cmd == "lsd");

        let json = cache_to_json(&cache);
        let parsed = parse_cache(&json).expect("should parse");

        assert_eq!(parsed.v, 1);
        assert_eq!(parsed.fingerprint, fp);
        assert_eq!(parsed.commands.get("lsd"), Some(&true));
        assert_eq!(parsed.commands.get("7z"), Some(&false));
    }

    #[test]
    fn fingerprint_changes_on_path_change() {
        let fp1 = compute_fingerprint("/usr/bin:/bin", 100, "bash");
        let fp2 = compute_fingerprint("/usr/local/bin:/usr/bin:/bin", 100, "bash");
        assert_ne!(fp1, fp2);
    }

    #[test]
    fn fingerprint_changes_on_mtime_change() {
        let fp1 = compute_fingerprint("/usr/bin", 100, "bash");
        let fp2 = compute_fingerprint("/usr/bin", 200, "bash");
        assert_ne!(fp1, fp2);
    }

    #[test]
    fn fingerprint_changes_on_shell_change() {
        let fp1 = compute_fingerprint("/usr/bin", 100, "bash");
        let fp2 = compute_fingerprint("/usr/bin", 100, "pwsh");
        assert_ne!(fp1, fp2);
    }

    /// The fingerprint format must satisfy parse_cache's own format check,
    /// otherwise every cache we write would be rejected on load.
    #[test]
    fn fingerprint_format_is_accepted_by_parse_cache() {
        let fp = compute_fingerprint("/usr/bin", 100, "bash");
        assert_eq!(fp.len(), FINGERPRINT_LEN);
        assert!(fp.chars().all(|c| matches!(c, '0'..='9' | 'a'..='f')));
        let json = format!(r#"{{"v":1,"fingerprint":"{fp}","commands":{{}}}}"#);
        assert!(parse_cache(&json).is_some());
    }

    #[test]
    fn parse_invalid_json_returns_none() {
        assert!(parse_cache("not json").is_none());
        assert!(parse_cache("").is_none());
        assert!(parse_cache("{}").is_none());
    }

    #[test]
    fn parse_wrong_version_returns_none() {
        let json = r#"{"v":99,"fingerprint":"0123456789abcdef","commands":{}}"#;
        assert!(parse_cache(json).is_none());
    }

    #[test]
    fn parse_rejects_oversized_json() {
        // Just over MAX_CACHE_BYTES
        let huge = format!(
            r#"{{"v":1,"fingerprint":"0123456789abcdef","commands":{{"{}":true}}}}"#,
            "a".repeat(MAX_CACHE_BYTES)
        );
        assert!(parse_cache(&huge).is_none());
    }

    #[test]
    fn parse_rejects_bad_fingerprint_format() {
        // Too short
        let json = r#"{"v":1,"fingerprint":"abc","commands":{}}"#;
        assert!(parse_cache(json).is_none());

        // Right length but non-hex
        let json = r#"{"v":1,"fingerprint":"zzzzzzzzzzzzzzzz","commands":{}}"#;
        assert!(parse_cache(json).is_none());
    }

    #[test]
    fn parse_rejects_too_many_commands() {
        let mut cmds = String::from("{");
        for i in 0..=MAX_CACHE_COMMANDS {
            if i > 0 {
                cmds.push(',');
            }
            cmds.push_str(&format!(r#""cmd{i}":true"#));
        }
        cmds.push('}');
        let json = format!(r#"{{"v":1,"fingerprint":"0123456789abcdef","commands":{cmds}}}"#);
        assert!(parse_cache(&json).is_none());
    }

    /// The command-count limit is inclusive: exactly MAX_CACHE_COMMANDS is fine.
    #[test]
    fn parse_accepts_exactly_max_commands() {
        let entries: Vec<String> = (0..MAX_CACHE_COMMANDS)
            .map(|i| format!(r#""cmd{i}":true"#))
            .collect();
        let json = format!(
            r#"{{"v":1,"fingerprint":"0123456789abcdef","commands":{{{}}}}}"#,
            entries.join(",")
        );
        let parsed = parse_cache(&json).expect("limit should be inclusive");
        assert_eq!(parsed.commands.len(), MAX_CACHE_COMMANDS);
    }

    #[test]
    fn collect_unique_commands_deduplicates() {
        let config = test_config(vec![
            abbr_when("ls", "lsd", vec!["lsd"]),
            abbr_when("ll", "lsd -l", vec!["lsd"]), // same command
            abbr_when("7z", "7zip", vec!["7z"]),
        ]);
        let cmds = collect_unique_commands(&config);
        assert_eq!(cmds, vec!["lsd".to_string(), "7z".to_string()]);
    }

    #[test]
    fn collect_unique_commands_empty_config() {
        let config = test_config(vec![]);
        assert!(collect_unique_commands(&config).is_empty());
    }

    #[test]
    fn export_statement_bash() {
        let stmt = export_statement("bash", r#"{"v":1}"#);
        assert!(stmt.starts_with("export RUNEX_CMD_CACHE_V1="));
        assert!(stmt.contains(r#"{"v":1}"#));
    }

    #[test]
    fn export_statement_bash_escapes_single_quote() {
        let stmt = export_statement("bash", r#"{"a'b":true}"#);
        assert_eq!(stmt, r#"export RUNEX_CMD_CACHE_V1='{"a'\''b":true}'"#);
    }

    #[test]
    fn export_statement_pwsh() {
        let stmt = export_statement("pwsh", r#"{"v":1}"#);
        assert!(stmt.starts_with("$env:RUNEX_CMD_CACHE_V1="));
    }

    #[test]
    fn parse_resolved_basic() {
        let map = parse_resolved("lsd=1,bat=0,git=1");
        assert_eq!(map.get("lsd"), Some(&true));
        assert_eq!(map.get("bat"), Some(&false));
        assert_eq!(map.get("git"), Some(&true));
    }

    #[test]
    fn parse_resolved_empty() {
        let map = parse_resolved("");
        assert!(map.is_empty());
    }

    #[test]
    fn build_cache_from_resolved_uses_config_commands() {
        let config = test_config(vec![
            abbr_when("ls", "lsd", vec!["lsd"]),
            abbr_when("7z", "7zip", vec!["7z"]),
        ]);
        let fp = compute_fingerprint("/usr/bin", 100, "pwsh");
        let cache = build_cache_from_resolved(&config, &fp, "lsd=1,7z=0,extra=1");
        assert_eq!(cache.commands.get("lsd"), Some(&true));
        assert_eq!(cache.commands.get("7z"), Some(&false));
        // "extra" is not in config, so not in cache
        assert_eq!(cache.commands.get("extra"), None);
    }

    /// A referenced command the shell did not report on is recorded as missing.
    #[test]
    fn build_cache_from_resolved_defaults_missing_to_false() {
        let config = test_config(vec![abbr_when("ls", "lsd", vec!["lsd"])]);
        let fp = compute_fingerprint("/usr/bin", 100, "pwsh");
        let cache = build_cache_from_resolved(&config, &fp, "");
        assert_eq!(cache.commands.get("lsd"), Some(&false));
    }

    #[test]
    fn config_mtime_missing_file_is_zero() {
        assert_eq!(config_mtime(Path::new("/nonexistent/runex/config.toml")), 0);
    }
}