Skip to main content

aft/
symbol_cache_disk.rs

1use std::fs::{self, File};
2use std::io::{BufReader, BufWriter, Read, Write};
3use std::path::{Path, PathBuf};
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::sync::Mutex;
6use std::time::{Duration, SystemTime, UNIX_EPOCH};
7
8use crate::fs_lock;
9use crate::parser::SymbolCache;
10use crate::search_index::{cache_relative_path, validate_cached_relative_path};
11use crate::symbols::Symbol;
12use crate::{slog_info, slog_warn};
13
/// File signature written at the very start of every `symbols.bin`.
const MAGIC: &[u8; 8] = b"AFTSYM1\0";
/// Version of the binary container layout (header + entry framing).
const FORMAT_VERSION: u32 = 3;

/// Version of the symbol extraction schema stored in the disk cache.
///
/// Bump this whenever symbol-extraction logic changes: tree-sitter grammar
/// upgrades, query updates, extractor behavior, or symbol shape changes. A
/// mismatch rejects persisted symbols so they are regenerated on next access.
pub const SCHEMA_VERSION: u32 = 3;

/// Largest entry count accepted from a cache file header.
const MAX_ENTRIES: usize = 2_000_000;
/// Largest encoded path (project root or relative path) accepted: 16 KiB.
const MAX_PATH_BYTES: usize = 16 << 10;
/// Largest serialized symbol payload accepted for a single entry: 16 MiB.
const MAX_SYMBOL_BYTES: usize = 16 << 20;
/// Per-process counter mixed into temp file names so concurrent writers never collide.
static TMP_COUNTER: AtomicU64 = AtomicU64::new(0);
/// Serializes in-process calls into `fs_lock::try_acquire` for the symbol cache lock.
static SYMBOL_LOCK_ACQUIRE_MUTEX: Mutex<()> = Mutex::new(());
29
30pub struct SymbolCacheLock {
31    _guard: fs_lock::LockGuard,
32}
33
34impl SymbolCacheLock {
35    pub fn acquire(storage_dir: &Path, project_key: &str) -> std::io::Result<Self> {
36        let dir = storage_dir.join("symbols").join(project_key);
37        fs::create_dir_all(&dir)?;
38        let path = dir.join("symbols.lock");
39        let _acquire_guard = SYMBOL_LOCK_ACQUIRE_MUTEX
40            .lock()
41            .map_err(|_| std::io::Error::other("symbol cache lock acquisition mutex poisoned"))?;
42        fs_lock::try_acquire(&path, Duration::from_secs(2))
43            .map(|guard| Self { _guard: guard })
44            .map_err(|error| match error {
45                fs_lock::AcquireError::Timeout => {
46                    std::io::Error::other("timed out acquiring symbol cache lock")
47                }
48                fs_lock::AcquireError::Io(error) => error,
49            })
50    }
51}
52
/// Symbol cache contents decoded from a `symbols.bin` file by `read_from_disk`.
#[derive(Debug, Clone)]
pub struct DiskSymbolCache {
    /// One record per cached source file, in the order they were stored in the file.
    pub(crate) entries: Vec<DiskSymbolEntry>,
}
57
/// One persisted file record: the file's identity at caching time plus its symbols.
///
/// NOTE(review): `mtime`, `size` and `content_hash` are presumably compared against the
/// live file by the consumer to decide whether `symbols` is still valid — confirm there.
#[derive(Debug, Clone)]
pub(crate) struct DiskSymbolEntry {
    /// Path relative to the project root; checked by `validate_cached_relative_path`
    /// when loaded so it cannot escape the root.
    pub(crate) relative_path: PathBuf,
    /// File modification time recorded when the symbols were cached.
    pub(crate) mtime: SystemTime,
    /// File length in bytes recorded when the symbols were cached.
    pub(crate) size: u64,
    /// BLAKE3 hash of the file contents recorded when the symbols were cached.
    pub(crate) content_hash: blake3::Hash,
    /// Symbols extracted from the file.
    pub(crate) symbols: Vec<Symbol>,
}
66
67impl DiskSymbolCache {
68    pub fn len(&self) -> usize {
69        self.entries.len()
70    }
71
72    pub fn is_empty(&self) -> bool {
73        self.entries.is_empty()
74    }
75}
76
/// Location of the persisted symbol cache for `project_key`:
/// `<storage_dir>/symbols/<project_key>/symbols.bin`.
pub(crate) fn cache_path(storage_dir: &Path, project_key: &str) -> PathBuf {
    let mut path = storage_dir.join("symbols");
    path.push(project_key);
    path.push("symbols.bin");
    path
}
83
84pub fn read_from_disk(storage_dir: &Path, project_key: &str) -> Option<DiskSymbolCache> {
85    let data_path = cache_path(storage_dir, project_key);
86    if !data_path.exists() {
87        return None;
88    }
89
90    match read_cache_file(&data_path) {
91        Ok(cache) => Some(cache),
92        Err(error) => {
93            slog_warn!(
94                "corrupt symbol cache at {}: {}, rebuilding",
95                data_path.display(),
96                error
97            );
98            None
99        }
100    }
101}
102
103pub fn write_to_disk(
104    cache: &SymbolCache,
105    storage_dir: &Path,
106    project_key: &str,
107) -> std::io::Result<()> {
108    if cache.len() == 0 {
109        slog_info!("skipping symbol cache persistence (0 entries)");
110        return Ok(());
111    }
112
113    let project_root = cache.project_root().ok_or_else(|| {
114        std::io::Error::other("symbol cache project root is not set; cannot persist relative paths")
115    })?;
116
117    let _cache_lock = SymbolCacheLock::acquire(storage_dir, project_key)?;
118    let dir = storage_dir.join("symbols").join(project_key);
119    fs::create_dir_all(&dir)?;
120
121    let data_path = dir.join("symbols.bin");
122    let tmp_path = dir.join(format!(
123        "symbols.bin.tmp.{}.{}.{}",
124        std::process::id(),
125        SystemTime::now()
126            .duration_since(UNIX_EPOCH)
127            .unwrap_or(Duration::ZERO)
128            .as_nanos(),
129        TMP_COUNTER.fetch_add(1, Ordering::Relaxed)
130    ));
131    let write_result = write_cache_file(cache, &project_root, &tmp_path).and_then(|()| {
132        fs::rename(&tmp_path, &data_path)?;
133        if let Ok(dir_file) = File::open(&dir) {
134            let _ = dir_file.sync_all();
135        }
136        Ok(())
137    });
138
139    if write_result.is_err() {
140        let _ = fs::remove_file(&tmp_path);
141    }
142
143    write_result
144}
145
146fn read_cache_file(path: &Path) -> Result<DiskSymbolCache, String> {
147    let mut reader = BufReader::new(File::open(path).map_err(|error| error.to_string())?);
148
149    let mut magic = [0u8; 8];
150    reader
151        .read_exact(&mut magic)
152        .map_err(|error| format!("failed to read symbol cache magic: {error}"))?;
153    if &magic != MAGIC {
154        return Err("invalid symbol cache magic".to_string());
155    }
156
157    let format_version = read_u32(&mut reader)?;
158    if format_version != FORMAT_VERSION {
159        return Err(format!(
160            "unsupported symbol cache format version: {format_version} (expected {FORMAT_VERSION})"
161        ));
162    }
163
164    let schema_version = read_u32(&mut reader)?;
165    if schema_version != SCHEMA_VERSION {
166        return Err(format!(
167            "unsupported symbol cache schema version: {schema_version} (expected {SCHEMA_VERSION})"
168        ));
169    }
170
171    let root_len = read_u32(&mut reader)? as usize;
172    let entry_count = read_u32(&mut reader)? as usize;
173    if root_len > MAX_PATH_BYTES {
174        return Err(format!("project root path too large: {root_len} bytes"));
175    }
176    if entry_count > MAX_ENTRIES {
177        return Err(format!("too many symbol cache entries: {entry_count}"));
178    }
179
180    let _project_root = PathBuf::from(read_string_with_len(&mut reader, root_len)?);
181    let mut entries = Vec::with_capacity(entry_count);
182
183    for _ in 0..entry_count {
184        let path_len = read_u32(&mut reader)? as usize;
185        if path_len > MAX_PATH_BYTES {
186            return Err(format!("cached path too large: {path_len} bytes"));
187        }
188        let relative_path = validate_cached_relative_path(&PathBuf::from(read_string_with_len(
189            &mut reader,
190            path_len,
191        )?))
192        .ok_or_else(|| "cached symbol path escapes project root".to_string())?;
193        let mtime_secs = read_i64(&mut reader)?;
194        let mtime_nanos = read_u32(&mut reader)?;
195        let size = read_u64(&mut reader)?;
196        let mut hash_bytes = [0u8; 32];
197        reader
198            .read_exact(&mut hash_bytes)
199            .map_err(|error| format!("failed to read symbol content hash: {error}"))?;
200        let content_hash = blake3::Hash::from_bytes(hash_bytes);
201        let symbol_bytes_len = read_u32(&mut reader)? as usize;
202        if symbol_bytes_len > MAX_SYMBOL_BYTES {
203            return Err(format!(
204                "cached symbol payload too large: {symbol_bytes_len} bytes"
205            ));
206        }
207
208        let mut symbol_bytes = vec![0u8; symbol_bytes_len];
209        reader
210            .read_exact(&mut symbol_bytes)
211            .map_err(|error| format!("failed to read symbol payload: {error}"))?;
212        let symbols: Vec<Symbol> = serde_json::from_slice(&symbol_bytes)
213            .map_err(|error| format!("failed to decode cached symbols: {error}"))?;
214
215        entries.push(DiskSymbolEntry {
216            relative_path,
217            mtime: system_time_from_parts(mtime_secs, mtime_nanos)?,
218            size,
219            content_hash,
220            symbols,
221        });
222    }
223
224    Ok(DiskSymbolCache { entries })
225}
226
227fn write_cache_file(
228    cache: &SymbolCache,
229    project_root: &Path,
230    tmp_path: &Path,
231) -> std::io::Result<()> {
232    let mut writer = BufWriter::new(File::create(tmp_path)?);
233    let entries = cache
234        .disk_entries()
235        .into_iter()
236        .map(|(path, mtime, size, content_hash, symbols)| {
237            cache_relative_path(project_root, path)
238                .map(|relative_path| (relative_path, mtime, size, content_hash, symbols))
239        })
240        .collect::<Option<Vec<_>>>()
241        .ok_or_else(|| std::io::Error::other("refusing to cache path outside project root"))?;
242    let root = project_root.to_string_lossy();
243    let root_len = u32::try_from(root.len())
244        .map_err(|_| std::io::Error::other("project root too large to cache"))?;
245    let entry_count = u32::try_from(entries.len())
246        .map_err(|_| std::io::Error::other("too many symbol cache entries"))?;
247
248    writer.write_all(MAGIC)?;
249    write_u32(&mut writer, FORMAT_VERSION)?;
250    write_u32(&mut writer, SCHEMA_VERSION)?;
251    write_u32(&mut writer, root_len)?;
252    write_u32(&mut writer, entry_count)?;
253    writer.write_all(root.as_bytes())?;
254
255    for (relative_path, mtime, size, content_hash, symbols) in entries {
256        let path_bytes = relative_path.to_string_lossy();
257        let path_len = u32::try_from(path_bytes.len())
258            .map_err(|_| std::io::Error::other("cached path too large"))?;
259        let (secs, nanos) = system_time_parts(mtime);
260        let symbol_bytes = serde_json::to_vec(symbols).map_err(|error| {
261            std::io::Error::other(format!("symbol serialization failed: {error}"))
262        })?;
263        let symbol_len = u32::try_from(symbol_bytes.len())
264            .map_err(|_| std::io::Error::other("cached symbol payload too large"))?;
265
266        write_u32(&mut writer, path_len)?;
267        writer.write_all(path_bytes.as_bytes())?;
268        write_i64(&mut writer, secs)?;
269        write_u32(&mut writer, nanos)?;
270        write_u64(&mut writer, size)?;
271        writer.write_all(content_hash.as_bytes())?;
272        write_u32(&mut writer, symbol_len)?;
273        writer.write_all(&symbol_bytes)?;
274    }
275
276    writer.flush()?;
277    writer.get_ref().sync_all()?;
278    Ok(())
279}
280
/// Splits `time` into whole seconds relative to the Unix epoch plus a non-negative
/// nanosecond fraction, so that `time == UNIX_EPOCH + secs + nanos` with
/// `0 <= nanos < 1_000_000_000`. Pre-epoch times round the seconds down: for example,
/// 1.5 s before the epoch becomes `(-2, 500_000_000)`.
///
/// Seconds that do not fit in `i64` saturate to `i64::MAX` / `i64::MIN` instead of
/// overflowing.
fn system_time_parts(time: SystemTime) -> (i64, u32) {
    match time.duration_since(UNIX_EPOCH) {
        Ok(after) => (
            i64::try_from(after.as_secs()).unwrap_or(i64::MAX),
            after.subsec_nanos(),
        ),
        Err(error) => {
            let before = error.duration();
            // Work in i128 so negating a magnitude as large as 2^63 s (representable by
            // some platforms' SystemTime) and borrowing a second for a non-zero fraction
            // cannot overflow; the old `-(secs as i64)` panicked in debug builds there.
            let whole = -i128::from(before.as_secs());
            let (secs, nanos) = match before.subsec_nanos() {
                0 => (whole, 0),
                fraction => (whole - 1, 1_000_000_000 - fraction),
            };
            (i64::try_from(secs).unwrap_or(i64::MIN), nanos)
        }
    }
}
298
/// Rebuilds a `SystemTime` from the `(secs, nanos)` pair produced by
/// `system_time_parts`: `UNIX_EPOCH + secs + nanos`, where `secs` may be negative and
/// `nanos` is always a forward fraction.
///
/// # Errors
/// Rejects `nanos >= 1_000_000_000`, and reports values the platform's `SystemTime`
/// cannot represent.
fn system_time_from_parts(secs: i64, nanos: u32) -> Result<SystemTime, String> {
    if nanos >= 1_000_000_000 {
        return Err(format!(
            "invalid symbol cache mtime nanos: {nanos} >= 1_000_000_000"
        ));
    }

    if secs < 0 {
        // Step back by the whole seconds first, then forward by the fraction.
        UNIX_EPOCH
            .checked_sub(Duration::from_secs(secs.unsigned_abs()))
            .and_then(|base| base.checked_add(Duration::from_nanos(u64::from(nanos))))
            .ok_or_else(|| {
                format!("symbol cache negative mtime overflows SystemTime: {secs}.{nanos}")
            })
    } else {
        UNIX_EPOCH
            .checked_add(Duration::new(secs.unsigned_abs(), nanos))
            .ok_or_else(|| format!("symbol cache mtime overflows SystemTime: {secs}.{nanos}"))
    }
}
321
/// Reads exactly `len` bytes from `reader` and decodes them as UTF-8.
///
/// # Errors
/// Returns a message if fewer than `len` bytes are available or the bytes are not
/// valid UTF-8.
fn read_string_with_len<R: Read>(reader: &mut R, len: usize) -> Result<String, String> {
    let mut raw = vec![0u8; len];
    if let Err(error) = reader.read_exact(&mut raw) {
        return Err(format!("failed to read string: {error}"));
    }
    String::from_utf8(raw).map_err(|error| format!("invalid utf-8 string: {error}"))
}
329
/// Reads a little-endian `u32` from `reader`.
fn read_u32<R: Read>(reader: &mut R) -> Result<u32, String> {
    let mut buf = [0u8; std::mem::size_of::<u32>()];
    match reader.read_exact(&mut buf) {
        Ok(()) => Ok(u32::from_le_bytes(buf)),
        Err(error) => Err(format!("failed to read u32: {error}")),
    }
}
337
/// Reads a little-endian `i64` from `reader`.
fn read_i64<R: Read>(reader: &mut R) -> Result<i64, String> {
    let mut buf = [0u8; 8];
    reader
        .read_exact(&mut buf)
        .map(|()| i64::from_le_bytes(buf))
        .map_err(|error| format!("failed to read i64: {error}"))
}
345
/// Reads a little-endian `u64` from `reader`.
fn read_u64<R: Read>(reader: &mut R) -> Result<u64, String> {
    let mut buf = [0u8; 8];
    if let Err(error) = reader.read_exact(&mut buf) {
        return Err(format!("failed to read u64: {error}"));
    }
    Ok(u64::from_le_bytes(buf))
}
353
/// Writes `value` to `writer` as 4 little-endian bytes.
fn write_u32<W: Write>(writer: &mut W, value: u32) -> std::io::Result<()> {
    let bytes = value.to_le_bytes();
    writer.write_all(&bytes)
}
357
/// Writes `value` to `writer` as 8 little-endian (two's-complement) bytes.
fn write_i64<W: Write>(writer: &mut W, value: i64) -> std::io::Result<()> {
    writer.write_all(value.to_le_bytes().as_slice())
}
361
/// Writes `value` to `writer` as 8 little-endian bytes.
fn write_u64<W: Write>(writer: &mut W, value: u64) -> std::io::Result<()> {
    writer.write_all(&u64::to_le_bytes(value))
}
365
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbols::{Range, SymbolKind};

    /// Builds a minimal, unexported function symbol named `name` spanning one column.
    fn make_symbol(name: &str) -> Symbol {
        let range = Range {
            start_line: 0,
            start_col: 0,
            end_line: 0,
            end_col: 1,
        };
        Symbol {
            name: name.to_string(),
            kind: SymbolKind::Function,
            range,
            signature: None,
            scope_chain: Vec::new(),
            exported: false,
            parent: None,
        }
    }

    /// Writes `<project>/<file_name>` and returns a cache rooted at `project` that holds
    /// exactly that file with its real metadata and content hash.
    fn single_file_cache(project: &Path, file_name: &str) -> SymbolCache {
        let source_path = project.join(file_name);
        fs::write(&source_path, format!("fn {file_name}() {{}}\n")).expect("write file");
        let meta = fs::metadata(&source_path).expect("metadata");
        let hash = blake3::hash(&fs::read(&source_path).expect("read file"));
        let mtime = meta.modified().expect("mtime");

        let mut cache = SymbolCache::new();
        cache.set_project_root(project.to_path_buf());
        cache.insert(
            source_path,
            mtime,
            meta.len(),
            hash,
            vec![make_symbol(file_name)],
        );
        cache
    }

    #[test]
    fn concurrent_symbol_cache_writes_do_not_share_temp_file() {
        let tmp = tempfile::tempdir().expect("create temp dir");
        let project = tmp.path().join("project");
        fs::create_dir_all(&project).expect("create project");
        let storage = tmp.path().join("storage");

        let caches = vec![
            single_file_cache(&project, "a"),
            single_file_cache(&project, "b"),
        ];
        let handles: Vec<_> = caches
            .into_iter()
            .zip(["write a", "write b"])
            .map(|(cache, message)| {
                let storage = storage.clone();
                std::thread::spawn(move || {
                    write_to_disk(&cache, &storage, "unit-project").expect(message);
                })
            })
            .collect();
        for (handle, message) in handles.into_iter().zip(["writer a", "writer b"]) {
            handle.join().expect(message);
        }

        let loaded = read_from_disk(&storage, "unit-project").expect("load symbol cache");
        assert_eq!(loaded.len(), 1);
        let leftover_tmp = fs::read_dir(storage.join("symbols").join("unit-project"))
            .expect("read symbol cache dir")
            .map(|entry| entry.expect("cache entry").file_name())
            .any(|name| name.to_string_lossy().contains(".tmp."));
        assert!(!leftover_tmp);
    }

    #[test]
    fn symbol_cache_rejects_mismatched_schema_version() {
        let storage = tempfile::tempdir().expect("create storage dir");
        let path = cache_path(storage.path(), "schema-project");
        fs::create_dir_all(path.parent().expect("cache parent")).expect("create cache dir");

        // Header with a schema version one past the current one, zero-length root and
        // zero entries.
        let mut bytes = MAGIC.to_vec();
        for word in [FORMAT_VERSION, SCHEMA_VERSION.wrapping_add(1), 0, 0] {
            bytes.extend_from_slice(&word.to_le_bytes());
        }
        fs::write(&path, bytes).expect("write wrong-schema cache");

        assert!(read_from_disk(storage.path(), "schema-project").is_none());
    }

    #[test]
    fn symbol_cache_rejects_paths_outside_project_root_on_write() {
        let tmp = tempfile::tempdir().expect("create temp dir");
        let project = tmp.path().join("project");
        fs::create_dir_all(&project).expect("create project");
        let outside = tmp.path().join("outside.rs");
        fs::write(&outside, "fn outside() {}\n").expect("write outside");
        let meta = fs::metadata(&outside).expect("metadata");
        let mtime = meta.modified().expect("mtime");
        let hash = blake3::hash(&fs::read(&outside).expect("read outside"));

        let mut cache = SymbolCache::new();
        cache.set_project_root(project);
        cache.insert(outside, mtime, meta.len(), hash, vec![make_symbol("outside")]);

        let error =
            write_to_disk(&cache, tmp.path(), "escape-project").expect_err("reject escape");
        assert!(error.to_string().contains("outside project root"));
    }
}