Skip to main content

aft/
symbol_cache_disk.rs

use std::fs::{self, File};
use std::io::{BufReader, BufWriter, Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Mutex;
use std::time::{Duration, SystemTime, UNIX_EPOCH};

use crate::fs_lock;
use crate::parser::SymbolCache;
use crate::search_index::{cache_relative_path, validate_cached_relative_path};
use crate::symbols::Symbol;
use crate::{slog_info, slog_warn};
13
/// Signature stored in the first 8 bytes of `symbols.bin`.
const MAGIC: &[u8; 8] = b"AFTSYM1\0";
/// On-disk format version; the reader rejects any other value.
const VERSION: u32 = 2;

// Upper bounds checked by the reader before it allocates for a field, so a
// corrupt header cannot request arbitrarily large buffers.
/// Maximum number of per-file entries accepted in one cache file.
const MAX_ENTRIES: usize = 2 * 1_000_000;
/// Maximum byte length of the project root or of one cached relative path (16 KiB).
const MAX_PATH_BYTES: usize = 1 << 14;
/// Maximum byte length of one entry's serialized symbol payload (16 MiB).
const MAX_SYMBOL_BYTES: usize = 1 << 24;

/// Per-process counter appended to temp file names so concurrent writers never collide.
static TMP_COUNTER: AtomicU64 = AtomicU64::new(0);
/// Held while taking the on-disk lock so threads in this process attempt it one at a time.
static SYMBOL_LOCK_ACQUIRE_MUTEX: Mutex<()> = Mutex::new(());
21
22pub struct SymbolCacheLock {
23    _guard: fs_lock::LockGuard,
24}
25
26impl SymbolCacheLock {
27    pub fn acquire(storage_dir: &Path, project_key: &str) -> std::io::Result<Self> {
28        let dir = storage_dir.join("symbols").join(project_key);
29        fs::create_dir_all(&dir)?;
30        let path = dir.join("symbols.lock");
31        let _acquire_guard = SYMBOL_LOCK_ACQUIRE_MUTEX
32            .lock()
33            .map_err(|_| std::io::Error::other("symbol cache lock acquisition mutex poisoned"))?;
34        fs_lock::try_acquire(&path, Duration::from_secs(2))
35            .map(|guard| Self { _guard: guard })
36            .map_err(|error| match error {
37                fs_lock::AcquireError::Timeout => {
38                    std::io::Error::other("timed out acquiring symbol cache lock")
39                }
40                fs_lock::AcquireError::Io(error) => error,
41            })
42    }
43}
44
/// Symbol cache contents decoded from `symbols.bin` by [`read_from_disk`].
#[derive(Debug, Clone)]
pub struct DiskSymbolCache {
    /// One record per cached file; when loaded from disk, entries keep the
    /// order in which they were stored in the file.
    pub(crate) entries: Vec<DiskSymbolEntry>,
}
49
/// One per-file record stored in the on-disk symbol cache.
#[derive(Debug, Clone)]
pub(crate) struct DiskSymbolEntry {
    /// Path relative to the project root; entries whose path escapes the root
    /// are rejected when the cache is loaded.
    pub(crate) relative_path: PathBuf,
    /// Modification time stored for the file (presumably captured when it was parsed).
    pub(crate) mtime: SystemTime,
    /// File size stored for the file (presumably in bytes, as the tests pass `metadata.len()`).
    pub(crate) size: u64,
    /// BLAKE3 hash stored for the file; the tests compute it over the file's bytes.
    pub(crate) content_hash: blake3::Hash,
    /// Symbols extracted from the file.
    pub(crate) symbols: Vec<Symbol>,
}
58
59impl DiskSymbolCache {
60    pub fn len(&self) -> usize {
61        self.entries.len()
62    }
63
64    pub fn is_empty(&self) -> bool {
65        self.entries.is_empty()
66    }
67}
68
/// Location of the persisted symbol cache:
/// `<storage_dir>/symbols/<project_key>/symbols.bin`.
pub(crate) fn cache_path(storage_dir: &Path, project_key: &str) -> PathBuf {
    let mut location = storage_dir.join("symbols");
    location.push(project_key);
    location.push("symbols.bin");
    location
}
75
76pub fn read_from_disk(storage_dir: &Path, project_key: &str) -> Option<DiskSymbolCache> {
77    let data_path = cache_path(storage_dir, project_key);
78    if !data_path.exists() {
79        return None;
80    }
81
82    match read_cache_file(&data_path) {
83        Ok(cache) => Some(cache),
84        Err(error) => {
85            slog_warn!(
86                "corrupt symbol cache at {}: {}, rebuilding",
87                data_path.display(),
88                error
89            );
90            None
91        }
92    }
93}
94
95pub fn write_to_disk(
96    cache: &SymbolCache,
97    storage_dir: &Path,
98    project_key: &str,
99) -> std::io::Result<()> {
100    if cache.len() == 0 {
101        slog_info!("skipping symbol cache persistence (0 entries)");
102        return Ok(());
103    }
104
105    let project_root = cache.project_root().ok_or_else(|| {
106        std::io::Error::other("symbol cache project root is not set; cannot persist relative paths")
107    })?;
108
109    let _cache_lock = SymbolCacheLock::acquire(storage_dir, project_key)?;
110    let dir = storage_dir.join("symbols").join(project_key);
111    fs::create_dir_all(&dir)?;
112
113    let data_path = dir.join("symbols.bin");
114    let tmp_path = dir.join(format!(
115        "symbols.bin.tmp.{}.{}.{}",
116        std::process::id(),
117        SystemTime::now()
118            .duration_since(UNIX_EPOCH)
119            .unwrap_or(Duration::ZERO)
120            .as_nanos(),
121        TMP_COUNTER.fetch_add(1, Ordering::Relaxed)
122    ));
123    let write_result = write_cache_file(cache, &project_root, &tmp_path).and_then(|()| {
124        fs::rename(&tmp_path, &data_path)?;
125        if let Ok(dir_file) = File::open(&dir) {
126            let _ = dir_file.sync_all();
127        }
128        Ok(())
129    });
130
131    if write_result.is_err() {
132        let _ = fs::remove_file(&tmp_path);
133    }
134
135    write_result
136}
137
138fn read_cache_file(path: &Path) -> Result<DiskSymbolCache, String> {
139    let mut reader = BufReader::new(File::open(path).map_err(|error| error.to_string())?);
140
141    let mut magic = [0u8; 8];
142    reader
143        .read_exact(&mut magic)
144        .map_err(|error| format!("failed to read symbol cache magic: {error}"))?;
145    if &magic != MAGIC {
146        return Err("invalid symbol cache magic".to_string());
147    }
148
149    let version = read_u32(&mut reader)?;
150    if version != VERSION {
151        return Err(format!(
152            "unsupported symbol cache version: {version} (expected {VERSION})"
153        ));
154    }
155
156    let root_len = read_u32(&mut reader)? as usize;
157    let entry_count = read_u32(&mut reader)? as usize;
158    if root_len > MAX_PATH_BYTES {
159        return Err(format!("project root path too large: {root_len} bytes"));
160    }
161    if entry_count > MAX_ENTRIES {
162        return Err(format!("too many symbol cache entries: {entry_count}"));
163    }
164
165    let _project_root = PathBuf::from(read_string_with_len(&mut reader, root_len)?);
166    let mut entries = Vec::with_capacity(entry_count);
167
168    for _ in 0..entry_count {
169        let path_len = read_u32(&mut reader)? as usize;
170        if path_len > MAX_PATH_BYTES {
171            return Err(format!("cached path too large: {path_len} bytes"));
172        }
173        let relative_path = validate_cached_relative_path(&PathBuf::from(read_string_with_len(
174            &mut reader,
175            path_len,
176        )?))
177        .ok_or_else(|| "cached symbol path escapes project root".to_string())?;
178        let mtime_secs = read_i64(&mut reader)?;
179        let mtime_nanos = read_u32(&mut reader)?;
180        let size = read_u64(&mut reader)?;
181        let mut hash_bytes = [0u8; 32];
182        reader
183            .read_exact(&mut hash_bytes)
184            .map_err(|error| format!("failed to read symbol content hash: {error}"))?;
185        let content_hash = blake3::Hash::from_bytes(hash_bytes);
186        let symbol_bytes_len = read_u32(&mut reader)? as usize;
187        if symbol_bytes_len > MAX_SYMBOL_BYTES {
188            return Err(format!(
189                "cached symbol payload too large: {symbol_bytes_len} bytes"
190            ));
191        }
192
193        let mut symbol_bytes = vec![0u8; symbol_bytes_len];
194        reader
195            .read_exact(&mut symbol_bytes)
196            .map_err(|error| format!("failed to read symbol payload: {error}"))?;
197        let symbols: Vec<Symbol> = serde_json::from_slice(&symbol_bytes)
198            .map_err(|error| format!("failed to decode cached symbols: {error}"))?;
199
200        entries.push(DiskSymbolEntry {
201            relative_path,
202            mtime: system_time_from_parts(mtime_secs, mtime_nanos)?,
203            size,
204            content_hash,
205            symbols,
206        });
207    }
208
209    Ok(DiskSymbolCache { entries })
210}
211
212fn write_cache_file(
213    cache: &SymbolCache,
214    project_root: &Path,
215    tmp_path: &Path,
216) -> std::io::Result<()> {
217    let mut writer = BufWriter::new(File::create(tmp_path)?);
218    let entries = cache
219        .disk_entries()
220        .into_iter()
221        .map(|(path, mtime, size, content_hash, symbols)| {
222            cache_relative_path(project_root, path)
223                .map(|relative_path| (relative_path, mtime, size, content_hash, symbols))
224        })
225        .collect::<Option<Vec<_>>>()
226        .ok_or_else(|| std::io::Error::other("refusing to cache path outside project root"))?;
227    let root = project_root.to_string_lossy();
228    let root_len = u32::try_from(root.len())
229        .map_err(|_| std::io::Error::other("project root too large to cache"))?;
230    let entry_count = u32::try_from(entries.len())
231        .map_err(|_| std::io::Error::other("too many symbol cache entries"))?;
232
233    writer.write_all(MAGIC)?;
234    write_u32(&mut writer, VERSION)?;
235    write_u32(&mut writer, root_len)?;
236    write_u32(&mut writer, entry_count)?;
237    writer.write_all(root.as_bytes())?;
238
239    for (relative_path, mtime, size, content_hash, symbols) in entries {
240        let path_bytes = relative_path.to_string_lossy();
241        let path_len = u32::try_from(path_bytes.len())
242            .map_err(|_| std::io::Error::other("cached path too large"))?;
243        let (secs, nanos) = system_time_parts(mtime);
244        let symbol_bytes = serde_json::to_vec(symbols).map_err(|error| {
245            std::io::Error::other(format!("symbol serialization failed: {error}"))
246        })?;
247        let symbol_len = u32::try_from(symbol_bytes.len())
248            .map_err(|_| std::io::Error::other("cached symbol payload too large"))?;
249
250        write_u32(&mut writer, path_len)?;
251        writer.write_all(path_bytes.as_bytes())?;
252        write_i64(&mut writer, secs)?;
253        write_u32(&mut writer, nanos)?;
254        write_u64(&mut writer, size)?;
255        writer.write_all(content_hash.as_bytes())?;
256        write_u32(&mut writer, symbol_len)?;
257        writer.write_all(&symbol_bytes)?;
258    }
259
260    writer.flush()?;
261    writer.get_ref().sync_all()?;
262    Ok(())
263}
264
/// Splits `time` into whole seconds relative to the Unix epoch plus a
/// nanosecond remainder in `0..1_000_000_000`.
///
/// Times before the epoch are floored. The seconds part is negative and the
/// nanoseconds count forward from it: 1.5 s before the epoch becomes
/// `(-2, 500_000_000)`. Seconds after the epoch that exceed `i64::MAX`
/// saturate to `i64::MAX`.
fn system_time_parts(time: SystemTime) -> (i64, u32) {
    let before_epoch = match time.duration_since(UNIX_EPOCH) {
        Ok(after_epoch) => {
            let whole = i64::try_from(after_epoch.as_secs()).unwrap_or(i64::MAX);
            return (whole, after_epoch.subsec_nanos());
        }
        Err(error) => error.duration(),
    };

    let whole = before_epoch.as_secs() as i64;
    match before_epoch.subsec_nanos() {
        0 => (-whole, 0),
        fraction => (-whole - 1, 1_000_000_000 - fraction),
    }
}
282
/// Rebuilds a `SystemTime` from a `(seconds, nanoseconds)` pair in the floored
/// form produced by `system_time_parts`.
///
/// # Errors
///
/// Returns an error if `nanos` is not below one second, or if the resulting
/// time cannot be represented by `SystemTime` on this platform.
fn system_time_from_parts(secs: i64, nanos: u32) -> Result<SystemTime, String> {
    const NANOS_PER_SEC: u32 = 1_000_000_000;

    if nanos >= NANOS_PER_SEC {
        return Err(format!(
            "invalid symbol cache mtime nanos: {nanos} >= 1_000_000_000"
        ));
    }

    if secs < 0 {
        // Step back by the (floored) whole seconds, then forward by the
        // fractional nanoseconds.
        let overflow =
            || format!("symbol cache negative mtime overflows SystemTime: {secs}.{nanos}");
        let whole_seconds_back = UNIX_EPOCH
            .checked_sub(Duration::from_secs(secs.unsigned_abs()))
            .ok_or_else(overflow)?;
        return whole_seconds_back
            .checked_add(Duration::from_nanos(u64::from(nanos)))
            .ok_or_else(overflow);
    }

    UNIX_EPOCH
        .checked_add(Duration::new(secs.unsigned_abs(), nanos))
        .ok_or_else(|| format!("symbol cache mtime overflows SystemTime: {secs}.{nanos}"))
}
305
/// Reads exactly `len` bytes from `reader` and decodes them as UTF-8.
fn read_string_with_len<R: Read>(reader: &mut R, len: usize) -> Result<String, String> {
    let mut buffer = vec![0u8; len];
    if let Err(error) = reader.read_exact(&mut buffer) {
        return Err(format!("failed to read string: {error}"));
    }
    match String::from_utf8(buffer) {
        Ok(text) => Ok(text),
        Err(error) => Err(format!("invalid utf-8 string: {error}")),
    }
}
313
/// Reads exactly `N` bytes from `reader`; `what` names the value in the error message.
fn read_fixed_bytes<R: Read, const N: usize>(
    reader: &mut R,
    what: &str,
) -> Result<[u8; N], String> {
    let mut buffer = [0u8; N];
    match reader.read_exact(&mut buffer) {
        Ok(()) => Ok(buffer),
        Err(error) => Err(format!("failed to read {what}: {error}")),
    }
}

/// Reads a little-endian `u32`.
fn read_u32<R: Read>(reader: &mut R) -> Result<u32, String> {
    read_fixed_bytes::<_, 4>(reader, "u32").map(u32::from_le_bytes)
}

/// Reads a little-endian `i64`.
fn read_i64<R: Read>(reader: &mut R) -> Result<i64, String> {
    read_fixed_bytes::<_, 8>(reader, "i64").map(i64::from_le_bytes)
}

/// Reads a little-endian `u64`.
fn read_u64<R: Read>(reader: &mut R) -> Result<u64, String> {
    read_fixed_bytes::<_, 8>(reader, "u64").map(u64::from_le_bytes)
}
337
/// Writes `value` as 4 little-endian bytes.
fn write_u32<W: Write>(writer: &mut W, value: u32) -> std::io::Result<()> {
    let encoded: [u8; 4] = value.to_le_bytes();
    writer.write_all(&encoded)
}

/// Writes `value` as 8 little-endian bytes (two's complement).
fn write_i64<W: Write>(writer: &mut W, value: i64) -> std::io::Result<()> {
    let encoded: [u8; 8] = value.to_le_bytes();
    writer.write_all(&encoded)
}

/// Writes `value` as 8 little-endian bytes.
fn write_u64<W: Write>(writer: &mut W, value: u64) -> std::io::Result<()> {
    let encoded: [u8; 8] = value.to_le_bytes();
    writer.write_all(&encoded)
}
349
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbols::{Range, SymbolKind};

    /// Builds a minimal, non-exported function symbol called `name`.
    fn test_symbol(name: &str) -> Symbol {
        let range = Range {
            start_line: 0,
            start_col: 0,
            end_line: 0,
            end_col: 1,
        };
        Symbol {
            name: name.to_owned(),
            kind: SymbolKind::Function,
            range,
            signature: None,
            scope_chain: vec![],
            exported: false,
            parent: None,
        }
    }

    /// Writes `<project>/<file_name>` and returns a cache holding one entry for it.
    fn test_cache(project: &Path, file_name: &str) -> SymbolCache {
        let source_file = project.join(file_name);
        fs::write(&source_file, format!("fn {file_name}() {{}}\n")).expect("write file");
        let metadata = fs::metadata(&source_file).expect("metadata");
        let file_bytes = fs::read(&source_file).expect("read file");

        let mut cache = SymbolCache::new();
        cache.set_project_root(project.to_path_buf());
        cache.insert(
            source_file,
            metadata.modified().expect("mtime"),
            metadata.len(),
            blake3::hash(&file_bytes),
            vec![test_symbol(file_name)],
        );
        cache
    }

    #[test]
    fn concurrent_symbol_cache_writes_do_not_share_temp_file() {
        let dir = tempfile::tempdir().expect("create temp dir");
        let project = dir.path().join("project");
        fs::create_dir_all(&project).expect("create project");
        let storage = dir.path().join("storage");

        let jobs = vec![
            ("write a", test_cache(&project, "a")),
            ("write b", test_cache(&project, "b")),
        ];
        let writers: Vec<_> = jobs
            .into_iter()
            .map(|(label, cache)| {
                let storage = storage.clone();
                std::thread::spawn(move || {
                    write_to_disk(&cache, &storage, "unit-project").expect(label);
                })
            })
            .collect();
        for (handle, label) in writers.into_iter().zip(["writer a", "writer b"]) {
            handle.join().expect(label);
        }

        let loaded = read_from_disk(&storage, "unit-project").expect("load symbol cache");
        assert_eq!(loaded.len(), 1);

        let cache_dir = storage.join("symbols").join("unit-project");
        let leftover_temp_file = fs::read_dir(&cache_dir)
            .expect("read symbol cache dir")
            .map(|entry| entry.expect("cache entry").file_name())
            .any(|name| name.to_string_lossy().contains(".tmp."));
        assert!(!leftover_temp_file);
    }

    #[test]
    fn symbol_cache_rejects_paths_outside_project_root_on_write() {
        let dir = tempfile::tempdir().expect("create temp dir");
        let project = dir.path().join("project");
        fs::create_dir_all(&project).expect("create project");
        let outside = dir.path().join("outside.rs");
        fs::write(&outside, "fn outside() {}\n").expect("write outside");
        let metadata = fs::metadata(&outside).expect("metadata");
        let outside_hash = blake3::hash(&fs::read(&outside).expect("read outside"));

        let mut cache = SymbolCache::new();
        cache.set_project_root(project);
        cache.insert(
            outside,
            metadata.modified().expect("mtime"),
            metadata.len(),
            outside_hash,
            vec![test_symbol("outside")],
        );

        let error = write_to_disk(&cache, dir.path(), "escape-project").expect_err("reject escape");
        let message = error.to_string();
        assert!(
            message.contains("outside project root"),
            "unexpected error: {message}"
        );
    }
}