Skip to main content

aft/
symbol_cache_disk.rs

1use std::fs::{self, File};
2use std::io::{BufReader, BufWriter, Read, Write};
3use std::path::{Path, PathBuf};
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::time::{Duration, SystemTime, UNIX_EPOCH};
6
7use crate::parser::SymbolCache;
8use crate::search_index::{cache_relative_path, validate_cached_relative_path};
9use crate::symbols::Symbol;
10use crate::{slog_info, slog_warn};
11
/// Eight-byte signature written at the start of every cache file; the reader
/// rejects any file whose first eight bytes differ.
const MAGIC: &[u8; 8] = b"AFTSYM1\0";
/// On-disk format version written after the magic; the reader rejects any
/// other value.
const VERSION: u32 = 2;
/// Largest entry count the reader accepts from a cache header.
const MAX_ENTRIES: usize = 2_000_000;
/// Largest byte length the reader accepts for the stored project root or for
/// a single entry path.
const MAX_PATH_BYTES: usize = 16 * 1024;
/// Largest byte length the reader accepts for one entry's serialized symbols.
const MAX_SYMBOL_BYTES: usize = 16 * 1024 * 1024;
/// Process-wide counter mixed into temporary file names (together with the
/// process id and a timestamp) when the cache is written.
static TMP_COUNTER: AtomicU64 = AtomicU64::new(0);
18
/// Guard for the per-project `symbols.lock` file.
///
/// The lock is held for as long as the value is alive; dropping it removes the
/// lock file again.
pub struct SymbolCacheLock {
    /// Location of the lock file created by [`SymbolCacheLock::acquire`].
    path: PathBuf,
}

impl SymbolCacheLock {
    /// Creates `<storage_dir>/symbols/<project_key>/symbols.lock` exclusively.
    ///
    /// The directory is created if needed. While the lock file already exists
    /// the call retries up to 200 times, sleeping 10 ms after each failed
    /// attempt. The current process id is written into the file on a
    /// best-effort basis (write and sync failures are ignored).
    ///
    /// Note that nothing here removes a lock file left behind by a holder
    /// that never ran its destructor; in that case every attempt sees the
    /// existing file and the call times out.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if the directory cannot be created or
    /// the file cannot be opened for a reason other than already existing, and
    /// a "timed out" error once all attempts are exhausted.
    pub fn acquire(storage_dir: &Path, project_key: &str) -> std::io::Result<Self> {
        let lock_dir = storage_dir.join("symbols").join(project_key);
        fs::create_dir_all(&lock_dir)?;
        let path = lock_dir.join("symbols.lock");

        let mut remaining_attempts = 200u32;
        while remaining_attempts > 0 {
            remaining_attempts -= 1;

            let attempt = fs::OpenOptions::new()
                .write(true)
                .create_new(true)
                .open(&path);
            let mut file = match attempt {
                Ok(file) => file,
                Err(error) if error.kind() == std::io::ErrorKind::AlreadyExists => {
                    // Somebody else holds the lock; back off briefly and retry.
                    std::thread::sleep(Duration::from_millis(10));
                    continue;
                }
                Err(error) => return Err(error),
            };

            // The pid is informational only, so failures here are not fatal.
            let _ = writeln!(file, "{}", std::process::id());
            let _ = file.sync_all();
            return Ok(Self { path });
        }

        Err(std::io::Error::other(
            "timed out acquiring symbol cache lock",
        ))
    }
}

impl Drop for SymbolCacheLock {
    /// Releases the lock by deleting the lock file; errors are ignored.
    fn drop(&mut self) {
        fs::remove_file(&self.path).ok();
    }
}
56
/// Symbol cache contents as loaded from a cache file on disk.
#[derive(Debug, Clone)]
pub struct DiskSymbolCache {
    /// One record per cached file, in the order they appear in the cache file.
    pub(crate) entries: Vec<DiskSymbolEntry>,
}

/// A single file's record inside a [`DiskSymbolCache`].
#[derive(Debug, Clone)]
pub(crate) struct DiskSymbolEntry {
    /// Path of the file as stored in the cache; entries loaded from disk have
    /// passed `validate_cached_relative_path`.
    pub(crate) relative_path: PathBuf,
    /// Modification time stored for the file.
    pub(crate) mtime: SystemTime,
    /// Size value stored for the file.
    pub(crate) size: u64,
    /// BLAKE3 hash stored for the file.
    pub(crate) content_hash: blake3::Hash,
    /// Symbols stored for the file.
    pub(crate) symbols: Vec<Symbol>,
}

impl DiskSymbolCache {
    /// Number of entries in the cache.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` when the cache holds no entries.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}
80
/// Returns `<storage_dir>/symbols/<project_key>/symbols.bin`, the location of
/// the persisted symbol cache for a project.
pub(crate) fn cache_path(storage_dir: &Path, project_key: &str) -> PathBuf {
    let mut location = storage_dir.to_path_buf();
    location.push("symbols");
    location.push(project_key);
    location.push("symbols.bin");
    location
}
87
88pub fn read_from_disk(storage_dir: &Path, project_key: &str) -> Option<DiskSymbolCache> {
89    let data_path = cache_path(storage_dir, project_key);
90    if !data_path.exists() {
91        return None;
92    }
93
94    match read_cache_file(&data_path) {
95        Ok(cache) => Some(cache),
96        Err(error) => {
97            slog_warn!(
98                "corrupt symbol cache at {}: {}, rebuilding",
99                data_path.display(),
100                error
101            );
102            None
103        }
104    }
105}
106
107pub fn write_to_disk(
108    cache: &SymbolCache,
109    storage_dir: &Path,
110    project_key: &str,
111) -> std::io::Result<()> {
112    if cache.len() == 0 {
113        slog_info!("skipping symbol cache persistence (0 entries)");
114        return Ok(());
115    }
116
117    let project_root = cache.project_root().ok_or_else(|| {
118        std::io::Error::other("symbol cache project root is not set; cannot persist relative paths")
119    })?;
120
121    let _cache_lock = SymbolCacheLock::acquire(storage_dir, project_key)?;
122    let dir = storage_dir.join("symbols").join(project_key);
123    fs::create_dir_all(&dir)?;
124
125    let data_path = dir.join("symbols.bin");
126    let tmp_path = dir.join(format!(
127        "symbols.bin.tmp.{}.{}.{}",
128        std::process::id(),
129        SystemTime::now()
130            .duration_since(UNIX_EPOCH)
131            .unwrap_or(Duration::ZERO)
132            .as_nanos(),
133        TMP_COUNTER.fetch_add(1, Ordering::Relaxed)
134    ));
135    let write_result = write_cache_file(cache, &project_root, &tmp_path).and_then(|()| {
136        fs::rename(&tmp_path, &data_path)?;
137        if let Ok(dir_file) = File::open(&dir) {
138            let _ = dir_file.sync_all();
139        }
140        Ok(())
141    });
142
143    if write_result.is_err() {
144        let _ = fs::remove_file(&tmp_path);
145    }
146
147    write_result
148}
149
150fn read_cache_file(path: &Path) -> Result<DiskSymbolCache, String> {
151    let mut reader = BufReader::new(File::open(path).map_err(|error| error.to_string())?);
152
153    let mut magic = [0u8; 8];
154    reader
155        .read_exact(&mut magic)
156        .map_err(|error| format!("failed to read symbol cache magic: {error}"))?;
157    if &magic != MAGIC {
158        return Err("invalid symbol cache magic".to_string());
159    }
160
161    let version = read_u32(&mut reader)?;
162    if version != VERSION {
163        return Err(format!(
164            "unsupported symbol cache version: {version} (expected {VERSION})"
165        ));
166    }
167
168    let root_len = read_u32(&mut reader)? as usize;
169    let entry_count = read_u32(&mut reader)? as usize;
170    if root_len > MAX_PATH_BYTES {
171        return Err(format!("project root path too large: {root_len} bytes"));
172    }
173    if entry_count > MAX_ENTRIES {
174        return Err(format!("too many symbol cache entries: {entry_count}"));
175    }
176
177    let _project_root = PathBuf::from(read_string_with_len(&mut reader, root_len)?);
178    let mut entries = Vec::with_capacity(entry_count);
179
180    for _ in 0..entry_count {
181        let path_len = read_u32(&mut reader)? as usize;
182        if path_len > MAX_PATH_BYTES {
183            return Err(format!("cached path too large: {path_len} bytes"));
184        }
185        let relative_path = validate_cached_relative_path(&PathBuf::from(read_string_with_len(
186            &mut reader,
187            path_len,
188        )?))
189        .ok_or_else(|| "cached symbol path escapes project root".to_string())?;
190        let mtime_secs = read_i64(&mut reader)?;
191        let mtime_nanos = read_u32(&mut reader)?;
192        let size = read_u64(&mut reader)?;
193        let mut hash_bytes = [0u8; 32];
194        reader
195            .read_exact(&mut hash_bytes)
196            .map_err(|error| format!("failed to read symbol content hash: {error}"))?;
197        let content_hash = blake3::Hash::from_bytes(hash_bytes);
198        let symbol_bytes_len = read_u32(&mut reader)? as usize;
199        if symbol_bytes_len > MAX_SYMBOL_BYTES {
200            return Err(format!(
201                "cached symbol payload too large: {symbol_bytes_len} bytes"
202            ));
203        }
204
205        let mut symbol_bytes = vec![0u8; symbol_bytes_len];
206        reader
207            .read_exact(&mut symbol_bytes)
208            .map_err(|error| format!("failed to read symbol payload: {error}"))?;
209        let symbols: Vec<Symbol> = serde_json::from_slice(&symbol_bytes)
210            .map_err(|error| format!("failed to decode cached symbols: {error}"))?;
211
212        entries.push(DiskSymbolEntry {
213            relative_path,
214            mtime: system_time_from_parts(mtime_secs, mtime_nanos)?,
215            size,
216            content_hash,
217            symbols,
218        });
219    }
220
221    Ok(DiskSymbolCache { entries })
222}
223
224fn write_cache_file(
225    cache: &SymbolCache,
226    project_root: &Path,
227    tmp_path: &Path,
228) -> std::io::Result<()> {
229    let mut writer = BufWriter::new(File::create(tmp_path)?);
230    let entries = cache
231        .disk_entries()
232        .into_iter()
233        .map(|(path, mtime, size, content_hash, symbols)| {
234            cache_relative_path(project_root, path)
235                .map(|relative_path| (relative_path, mtime, size, content_hash, symbols))
236        })
237        .collect::<Option<Vec<_>>>()
238        .ok_or_else(|| std::io::Error::other("refusing to cache path outside project root"))?;
239    let root = project_root.to_string_lossy();
240    let root_len = u32::try_from(root.len())
241        .map_err(|_| std::io::Error::other("project root too large to cache"))?;
242    let entry_count = u32::try_from(entries.len())
243        .map_err(|_| std::io::Error::other("too many symbol cache entries"))?;
244
245    writer.write_all(MAGIC)?;
246    write_u32(&mut writer, VERSION)?;
247    write_u32(&mut writer, root_len)?;
248    write_u32(&mut writer, entry_count)?;
249    writer.write_all(root.as_bytes())?;
250
251    for (relative_path, mtime, size, content_hash, symbols) in entries {
252        let path_bytes = relative_path.to_string_lossy();
253        let path_len = u32::try_from(path_bytes.len())
254            .map_err(|_| std::io::Error::other("cached path too large"))?;
255        let (secs, nanos) = system_time_parts(mtime);
256        let symbol_bytes = serde_json::to_vec(symbols).map_err(|error| {
257            std::io::Error::other(format!("symbol serialization failed: {error}"))
258        })?;
259        let symbol_len = u32::try_from(symbol_bytes.len())
260            .map_err(|_| std::io::Error::other("cached symbol payload too large"))?;
261
262        write_u32(&mut writer, path_len)?;
263        writer.write_all(path_bytes.as_bytes())?;
264        write_i64(&mut writer, secs)?;
265        write_u32(&mut writer, nanos)?;
266        write_u64(&mut writer, size)?;
267        writer.write_all(content_hash.as_bytes())?;
268        write_u32(&mut writer, symbol_len)?;
269        writer.write_all(&symbol_bytes)?;
270    }
271
272    writer.flush()?;
273    writer.get_ref().sync_all()?;
274    Ok(())
275}
276
/// Splits `time` into whole seconds relative to the Unix epoch and a
/// non-negative nanosecond remainder, such that
/// `time == epoch + secs seconds + nanos nanoseconds`.
///
/// For instants before the epoch the seconds value is rounded towards
/// negative infinity so the nanosecond part stays in `0..1_000_000_000`
/// (e.g. 1.25 s before the epoch becomes `(-2, 750_000_000)`).
///
/// Second counts that do not fit in an `i64` saturate instead of wrapping or
/// overflowing.
fn system_time_parts(time: SystemTime) -> (i64, u32) {
    match time.duration_since(UNIX_EPOCH) {
        Ok(after_epoch) => (
            i64::try_from(after_epoch.as_secs()).unwrap_or(i64::MAX),
            after_epoch.subsec_nanos(),
        ),
        Err(error) => {
            let before_epoch = error.duration();
            // `0 - secs` computed without a lossy `as` cast: a distance of
            // 2^63 seconds maps to `i64::MIN` instead of overflowing on
            // negation, and anything larger saturates.
            let whole_secs = 0i64.saturating_sub_unsigned(before_epoch.as_secs());
            match before_epoch.subsec_nanos() {
                0 => (whole_secs, 0),
                // Borrow one second so the fractional part can be positive.
                nanos => (whole_secs.saturating_sub(1), 1_000_000_000 - nanos),
            }
        }
    }
}
294
/// Rebuilds a `SystemTime` from epoch-relative seconds plus a nanosecond
/// remainder, i.e. `epoch + secs seconds + nanos nanoseconds`.
///
/// # Errors
///
/// Returns a message when `nanos` is not below one billion or when the
/// resulting instant cannot be represented as a `SystemTime`.
fn system_time_from_parts(secs: i64, nanos: u32) -> Result<SystemTime, String> {
    if nanos >= 1_000_000_000 {
        return Err(format!(
            "invalid symbol cache mtime nanos: {nanos} >= 1_000_000_000"
        ));
    }

    let magnitude = secs.unsigned_abs();
    if secs.is_negative() {
        // Step back by the whole seconds first, then forward by the
        // (always non-negative) fractional part.
        let rebuilt = UNIX_EPOCH
            .checked_sub(Duration::from_secs(magnitude))
            .and_then(|base| base.checked_add(Duration::from_nanos(u64::from(nanos))));
        match rebuilt {
            Some(time) => Ok(time),
            None => Err(format!(
                "symbol cache negative mtime overflows SystemTime: {secs}.{nanos}"
            )),
        }
    } else {
        match UNIX_EPOCH.checked_add(Duration::new(magnitude, nanos)) {
            Some(time) => Ok(time),
            None => Err(format!(
                "symbol cache mtime overflows SystemTime: {secs}.{nanos}"
            )),
        }
    }
}
317
/// Reads exactly `len` bytes from `reader` and decodes them as UTF-8.
///
/// # Errors
///
/// Returns a message when fewer than `len` bytes are available or when the
/// bytes are not valid UTF-8.
fn read_string_with_len<R: Read>(reader: &mut R, len: usize) -> Result<String, String> {
    let mut buffer = vec![0u8; len];
    if let Err(error) = reader.read_exact(&mut buffer) {
        return Err(format!("failed to read string: {error}"));
    }
    match String::from_utf8(buffer) {
        Ok(text) => Ok(text),
        Err(error) => Err(format!("invalid utf-8 string: {error}")),
    }
}
325
/// Reads a little-endian `u32` from `reader`.
///
/// # Errors
///
/// Returns a message when four bytes cannot be read.
fn read_u32<R: Read>(reader: &mut R) -> Result<u32, String> {
    let mut raw = [0u8; 4];
    match reader.read_exact(&mut raw) {
        Ok(()) => Ok(u32::from_le_bytes(raw)),
        Err(error) => Err(format!("failed to read u32: {error}")),
    }
}
333
/// Reads a little-endian `i64` from `reader`.
///
/// # Errors
///
/// Returns a message when eight bytes cannot be read.
fn read_i64<R: Read>(reader: &mut R) -> Result<i64, String> {
    let mut raw = [0u8; 8];
    match reader.read_exact(&mut raw) {
        Ok(()) => Ok(i64::from_le_bytes(raw)),
        Err(error) => Err(format!("failed to read i64: {error}")),
    }
}
341
/// Reads a little-endian `u64` from `reader`.
///
/// # Errors
///
/// Returns a message when eight bytes cannot be read.
fn read_u64<R: Read>(reader: &mut R) -> Result<u64, String> {
    let mut raw = [0u8; 8];
    match reader.read_exact(&mut raw) {
        Ok(()) => Ok(u64::from_le_bytes(raw)),
        Err(error) => Err(format!("failed to read u64: {error}")),
    }
}
349
/// Writes `value` to `writer` as four little-endian bytes.
fn write_u32<W: Write>(writer: &mut W, value: u32) -> std::io::Result<()> {
    let encoded: [u8; 4] = value.to_le_bytes();
    writer.write_all(&encoded)
}
353
/// Writes `value` to `writer` as eight little-endian bytes.
fn write_i64<W: Write>(writer: &mut W, value: i64) -> std::io::Result<()> {
    let encoded: [u8; 8] = value.to_le_bytes();
    writer.write_all(&encoded)
}
357
/// Writes `value` to `writer` as eight little-endian bytes.
fn write_u64<W: Write>(writer: &mut W, value: u64) -> std::io::Result<()> {
    let encoded: [u8; 8] = value.to_le_bytes();
    writer.write_all(&encoded)
}
361
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbols::{Range, SymbolKind};

    /// Builds a minimal, non-exported function symbol called `name`.
    fn test_symbol(name: &str) -> Symbol {
        let range = Range {
            start_line: 0,
            start_col: 0,
            end_line: 0,
            end_col: 1,
        };
        Symbol {
            name: name.to_string(),
            kind: SymbolKind::Function,
            range,
            signature: None,
            scope_chain: Vec::new(),
            exported: false,
            parent: None,
        }
    }

    /// Writes a one-function source file named `file_name` into `project` and
    /// returns a cache rooted at `project` containing that single file.
    fn test_cache(project: &Path, file_name: &str) -> SymbolCache {
        let source_path = project.join(file_name);
        fs::write(&source_path, format!("fn {file_name}() {{}}\n")).expect("write file");
        let metadata = fs::metadata(&source_path).expect("metadata");
        let content_hash = blake3::hash(&fs::read(&source_path).expect("read file"));

        let mut cache = SymbolCache::new();
        cache.set_project_root(project.to_path_buf());
        cache.insert(
            source_path,
            metadata.modified().expect("mtime"),
            metadata.len(),
            content_hash,
            vec![test_symbol(file_name)],
        );
        cache
    }

    /// Two threads persisting to the same project must both succeed, leave a
    /// loadable single-entry cache, and leave no temporary files behind.
    #[test]
    fn concurrent_symbol_cache_writes_do_not_share_temp_file() {
        let scratch = tempfile::tempdir().expect("create temp dir");
        let project = scratch.path().join("project");
        fs::create_dir_all(&project).expect("create project");
        let storage = scratch.path().join("storage");

        let cache_a = test_cache(&project, "a");
        let cache_b = test_cache(&project, "b");

        let storage_a = storage.clone();
        let writer_a = std::thread::spawn(move || {
            write_to_disk(&cache_a, &storage_a, "unit-project").expect("write a");
        });
        let storage_b = storage.clone();
        let writer_b = std::thread::spawn(move || {
            write_to_disk(&cache_b, &storage_b, "unit-project").expect("write b");
        });

        writer_a.join().expect("writer a");
        writer_b.join().expect("writer b");

        let loaded = read_from_disk(&storage, "unit-project").expect("load symbol cache");
        assert_eq!(loaded.len(), 1);

        let cache_dir = storage.join("symbols").join("unit-project");
        let leftover_tmp = fs::read_dir(cache_dir)
            .expect("read symbol cache dir")
            .any(|entry| {
                entry
                    .expect("cache entry")
                    .file_name()
                    .to_string_lossy()
                    .contains(".tmp.")
            });
        assert!(!leftover_tmp);
    }

    /// A cache containing a file outside its project root must be refused.
    #[test]
    fn symbol_cache_rejects_paths_outside_project_root_on_write() {
        let scratch = tempfile::tempdir().expect("create temp dir");
        let project = scratch.path().join("project");
        fs::create_dir_all(&project).expect("create project");
        let outside = scratch.path().join("outside.rs");
        fs::write(&outside, "fn outside() {}\n").expect("write outside");
        let metadata = fs::metadata(&outside).expect("metadata");

        let mut cache = SymbolCache::new();
        cache.set_project_root(project);
        cache.insert(
            outside.clone(),
            metadata.modified().expect("mtime"),
            metadata.len(),
            blake3::hash(&fs::read(&outside).expect("read outside")),
            vec![test_symbol("outside")],
        );

        let error =
            write_to_disk(&cache, scratch.path(), "escape-project").expect_err("reject escape");
        assert!(error.to_string().contains("outside project root"));
    }
}