1use std::fs::{self, File};
2use std::io::{BufReader, BufWriter, Read, Write};
3use std::path::{Path, PathBuf};
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::time::{Duration, SystemTime, UNIX_EPOCH};
6
7use crate::parser::SymbolCache;
8use crate::search_index::{cache_relative_path, validate_cached_relative_path};
9use crate::symbols::Symbol;
10use crate::{slog_info, slog_warn};
11
/// Eight-byte signature written at the start of every cache file and checked on read.
const MAGIC: &[u8; 8] = b"AFTSYM1\0";
/// On-disk format version; a file carrying any other value is rejected on read.
const VERSION: u32 = 2;
/// Largest entry count the reader accepts from a cache header.
const MAX_ENTRIES: usize = 2_000_000;
/// Largest byte length the reader accepts for the project root or a cached path.
const MAX_PATH_BYTES: usize = 16 * 1024;
/// Largest byte length the reader accepts for one entry's serialized symbols.
const MAX_SYMBOL_BYTES: usize = 16 * 1024 * 1024;
/// Process-wide sequence number mixed into temporary file names.
static TMP_COUNTER: AtomicU64 = AtomicU64::new(0);
18
/// Guard for the lock file that serializes symbol cache writers.
///
/// The lock is the file `<storage_dir>/symbols/<project_key>/symbols.lock`.
/// It is removed again when the guard is dropped.
pub struct SymbolCacheLock {
    /// Location of the lock file owned by this guard.
    path: PathBuf,
}

impl SymbolCacheLock {
    /// Creates the lock file exclusively, retrying while it already exists.
    ///
    /// Up to 200 attempts are made, sleeping 10 ms after each attempt that
    /// fails because the file is already present. The current process id is
    /// written into the file on a best-effort basis.
    ///
    /// A lock file left behind by a holder that never ran `Drop` is not
    /// detected here; acquisition then fails with the timeout error.
    ///
    /// # Errors
    ///
    /// Returns the I/O error from creating the directory or from opening the
    /// file for any reason other than `AlreadyExists`, or a "timed out" error
    /// once all attempts are used up.
    pub fn acquire(storage_dir: &Path, project_key: &str) -> std::io::Result<Self> {
        let lock_dir = storage_dir.join("symbols").join(project_key);
        fs::create_dir_all(&lock_dir)?;
        let path = lock_dir.join("symbols.lock");

        // `create_new` makes the open fail if the file exists, which is what
        // turns the file into a mutual-exclusion token.
        let mut exclusive = fs::OpenOptions::new();
        exclusive.write(true).create_new(true);

        for _attempt in 0..200 {
            let mut handle = match exclusive.open(&path) {
                Ok(handle) => handle,
                Err(error) if error.kind() == std::io::ErrorKind::AlreadyExists => {
                    std::thread::sleep(Duration::from_millis(10));
                    continue;
                }
                Err(error) => return Err(error),
            };

            // The pid is informational only, so failures here are ignored.
            let _ = writeln!(handle, "{}", std::process::id());
            let _ = handle.sync_all();
            return Ok(Self { path });
        }

        Err(std::io::Error::other(
            "timed out acquiring symbol cache lock",
        ))
    }
}

impl Drop for SymbolCacheLock {
    /// Releases the lock by deleting the lock file; removal errors are ignored.
    fn drop(&mut self) {
        fs::remove_file(&self.path).ok();
    }
}
56
/// In-memory form of a symbol cache file: the list of per-file entries it contained.
#[derive(Debug, Clone)]
pub struct DiskSymbolCache {
    /// Entries in the order they were stored in the file.
    pub(crate) entries: Vec<DiskSymbolEntry>,
}
61
/// One cached record: a file's identity data plus the symbols stored for it.
#[derive(Debug, Clone)]
pub(crate) struct DiskSymbolEntry {
    /// Path as stored in the cache file, after passing `validate_cached_relative_path`.
    pub(crate) relative_path: PathBuf,
    /// Modification time rebuilt from the stored seconds/nanoseconds pair.
    pub(crate) mtime: SystemTime,
    /// Stored size value (presumably the file length in bytes — confirm with the caller).
    pub(crate) size: u64,
    /// 32-byte BLAKE3 hash stored alongside the entry.
    pub(crate) content_hash: blake3::Hash,
    /// Symbols decoded from the entry's JSON payload.
    pub(crate) symbols: Vec<Symbol>,
}
70
impl DiskSymbolCache {
    /// Returns the number of entries held.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` when there are no entries.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}
80
/// Returns the cache file location for a project:
/// `<storage_dir>/symbols/<project_key>/symbols.bin`.
pub(crate) fn cache_path(storage_dir: &Path, project_key: &str) -> PathBuf {
    let mut location = storage_dir.to_path_buf();
    location.push("symbols");
    location.push(project_key);
    location.push("symbols.bin");
    location
}
87
88pub fn read_from_disk(storage_dir: &Path, project_key: &str) -> Option<DiskSymbolCache> {
89 let data_path = cache_path(storage_dir, project_key);
90 if !data_path.exists() {
91 return None;
92 }
93
94 match read_cache_file(&data_path) {
95 Ok(cache) => Some(cache),
96 Err(error) => {
97 slog_warn!(
98 "corrupt symbol cache at {}: {}, rebuilding",
99 data_path.display(),
100 error
101 );
102 None
103 }
104 }
105}
106
107pub fn write_to_disk(
108 cache: &SymbolCache,
109 storage_dir: &Path,
110 project_key: &str,
111) -> std::io::Result<()> {
112 if cache.len() == 0 {
113 slog_info!("skipping symbol cache persistence (0 entries)");
114 return Ok(());
115 }
116
117 let project_root = cache.project_root().ok_or_else(|| {
118 std::io::Error::other("symbol cache project root is not set; cannot persist relative paths")
119 })?;
120
121 let _cache_lock = SymbolCacheLock::acquire(storage_dir, project_key)?;
122 let dir = storage_dir.join("symbols").join(project_key);
123 fs::create_dir_all(&dir)?;
124
125 let data_path = dir.join("symbols.bin");
126 let tmp_path = dir.join(format!(
127 "symbols.bin.tmp.{}.{}.{}",
128 std::process::id(),
129 SystemTime::now()
130 .duration_since(UNIX_EPOCH)
131 .unwrap_or(Duration::ZERO)
132 .as_nanos(),
133 TMP_COUNTER.fetch_add(1, Ordering::Relaxed)
134 ));
135 let write_result = write_cache_file(cache, &project_root, &tmp_path).and_then(|()| {
136 fs::rename(&tmp_path, &data_path)?;
137 if let Ok(dir_file) = File::open(&dir) {
138 let _ = dir_file.sync_all();
139 }
140 Ok(())
141 });
142
143 if write_result.is_err() {
144 let _ = fs::remove_file(&tmp_path);
145 }
146
147 write_result
148}
149
150fn read_cache_file(path: &Path) -> Result<DiskSymbolCache, String> {
151 let mut reader = BufReader::new(File::open(path).map_err(|error| error.to_string())?);
152
153 let mut magic = [0u8; 8];
154 reader
155 .read_exact(&mut magic)
156 .map_err(|error| format!("failed to read symbol cache magic: {error}"))?;
157 if &magic != MAGIC {
158 return Err("invalid symbol cache magic".to_string());
159 }
160
161 let version = read_u32(&mut reader)?;
162 if version != VERSION {
163 return Err(format!(
164 "unsupported symbol cache version: {version} (expected {VERSION})"
165 ));
166 }
167
168 let root_len = read_u32(&mut reader)? as usize;
169 let entry_count = read_u32(&mut reader)? as usize;
170 if root_len > MAX_PATH_BYTES {
171 return Err(format!("project root path too large: {root_len} bytes"));
172 }
173 if entry_count > MAX_ENTRIES {
174 return Err(format!("too many symbol cache entries: {entry_count}"));
175 }
176
177 let _project_root = PathBuf::from(read_string_with_len(&mut reader, root_len)?);
178 let mut entries = Vec::with_capacity(entry_count);
179
180 for _ in 0..entry_count {
181 let path_len = read_u32(&mut reader)? as usize;
182 if path_len > MAX_PATH_BYTES {
183 return Err(format!("cached path too large: {path_len} bytes"));
184 }
185 let relative_path = validate_cached_relative_path(&PathBuf::from(read_string_with_len(
186 &mut reader,
187 path_len,
188 )?))
189 .ok_or_else(|| "cached symbol path escapes project root".to_string())?;
190 let mtime_secs = read_i64(&mut reader)?;
191 let mtime_nanos = read_u32(&mut reader)?;
192 let size = read_u64(&mut reader)?;
193 let mut hash_bytes = [0u8; 32];
194 reader
195 .read_exact(&mut hash_bytes)
196 .map_err(|error| format!("failed to read symbol content hash: {error}"))?;
197 let content_hash = blake3::Hash::from_bytes(hash_bytes);
198 let symbol_bytes_len = read_u32(&mut reader)? as usize;
199 if symbol_bytes_len > MAX_SYMBOL_BYTES {
200 return Err(format!(
201 "cached symbol payload too large: {symbol_bytes_len} bytes"
202 ));
203 }
204
205 let mut symbol_bytes = vec![0u8; symbol_bytes_len];
206 reader
207 .read_exact(&mut symbol_bytes)
208 .map_err(|error| format!("failed to read symbol payload: {error}"))?;
209 let symbols: Vec<Symbol> = serde_json::from_slice(&symbol_bytes)
210 .map_err(|error| format!("failed to decode cached symbols: {error}"))?;
211
212 entries.push(DiskSymbolEntry {
213 relative_path,
214 mtime: system_time_from_parts(mtime_secs, mtime_nanos)?,
215 size,
216 content_hash,
217 symbols,
218 });
219 }
220
221 Ok(DiskSymbolCache { entries })
222}
223
224fn write_cache_file(
225 cache: &SymbolCache,
226 project_root: &Path,
227 tmp_path: &Path,
228) -> std::io::Result<()> {
229 let mut writer = BufWriter::new(File::create(tmp_path)?);
230 let entries = cache
231 .disk_entries()
232 .into_iter()
233 .map(|(path, mtime, size, content_hash, symbols)| {
234 cache_relative_path(project_root, path)
235 .map(|relative_path| (relative_path, mtime, size, content_hash, symbols))
236 })
237 .collect::<Option<Vec<_>>>()
238 .ok_or_else(|| std::io::Error::other("refusing to cache path outside project root"))?;
239 let root = project_root.to_string_lossy();
240 let root_len = u32::try_from(root.len())
241 .map_err(|_| std::io::Error::other("project root too large to cache"))?;
242 let entry_count = u32::try_from(entries.len())
243 .map_err(|_| std::io::Error::other("too many symbol cache entries"))?;
244
245 writer.write_all(MAGIC)?;
246 write_u32(&mut writer, VERSION)?;
247 write_u32(&mut writer, root_len)?;
248 write_u32(&mut writer, entry_count)?;
249 writer.write_all(root.as_bytes())?;
250
251 for (relative_path, mtime, size, content_hash, symbols) in entries {
252 let path_bytes = relative_path.to_string_lossy();
253 let path_len = u32::try_from(path_bytes.len())
254 .map_err(|_| std::io::Error::other("cached path too large"))?;
255 let (secs, nanos) = system_time_parts(mtime);
256 let symbol_bytes = serde_json::to_vec(symbols).map_err(|error| {
257 std::io::Error::other(format!("symbol serialization failed: {error}"))
258 })?;
259 let symbol_len = u32::try_from(symbol_bytes.len())
260 .map_err(|_| std::io::Error::other("cached symbol payload too large"))?;
261
262 write_u32(&mut writer, path_len)?;
263 writer.write_all(path_bytes.as_bytes())?;
264 write_i64(&mut writer, secs)?;
265 write_u32(&mut writer, nanos)?;
266 write_u64(&mut writer, size)?;
267 writer.write_all(content_hash.as_bytes())?;
268 write_u32(&mut writer, symbol_len)?;
269 writer.write_all(&symbol_bytes)?;
270 }
271
272 writer.flush()?;
273 writer.get_ref().sync_all()?;
274 Ok(())
275}
276
/// Splits `time` into whole seconds and nanoseconds relative to the Unix epoch.
///
/// The seconds are floored, so the nanosecond part is always in
/// `0..1_000_000_000`; a time 1.5 s before the epoch becomes
/// `(-2, 500_000_000)`. Seconds that do not fit in `i64` saturate to
/// `i64::MAX` or `i64::MIN` instead of overflowing.
fn system_time_parts(time: SystemTime) -> (i64, u32) {
    match time.duration_since(UNIX_EPOCH) {
        Ok(after) => (
            i64::try_from(after.as_secs()).unwrap_or(i64::MAX),
            after.subsec_nanos(),
        ),
        Err(error) => {
            // `before` is how far `time` lies before the epoch.
            let before = error.duration();
            let fraction = before.subsec_nanos();

            // Flooring: a fractional part pushes the seconds one further
            // below zero, and the nanoseconds count up from that second.
            let (magnitude, nanos) = if fraction == 0 {
                (before.as_secs(), 0)
            } else {
                (before.as_secs().saturating_add(1), 1_000_000_000 - fraction)
            };

            // Checked subtraction avoids negating a value that may be 2^63,
            // which has no positive `i64` counterpart.
            let secs = 0i64.checked_sub_unsigned(magnitude).unwrap_or(i64::MIN);
            (secs, nanos)
        }
    }
}
294
/// Rebuilds a `SystemTime` from floored seconds and nanoseconds relative to
/// the Unix epoch.
///
/// For negative `secs` the whole seconds are subtracted from the epoch and
/// the nanoseconds are then added back.
///
/// # Errors
///
/// Returns a message if `nanos` is not below one second or the result cannot
/// be represented as a `SystemTime`.
fn system_time_from_parts(secs: i64, nanos: u32) -> Result<SystemTime, String> {
    if nanos >= 1_000_000_000 {
        return Err(format!(
            "invalid symbol cache mtime nanos: {nanos} >= 1_000_000_000"
        ));
    }

    match u64::try_from(secs) {
        // Non-negative seconds: a single forward offset from the epoch.
        Ok(whole) => UNIX_EPOCH
            .checked_add(Duration::new(whole, nanos))
            .ok_or_else(|| format!("symbol cache mtime overflows SystemTime: {secs}.{nanos}")),
        // Negative seconds: step back by the magnitude, then forward by nanos.
        Err(_) => UNIX_EPOCH
            .checked_sub(Duration::from_secs(secs.unsigned_abs()))
            .and_then(|base| base.checked_add(Duration::new(0, nanos)))
            .ok_or_else(|| {
                format!("symbol cache negative mtime overflows SystemTime: {secs}.{nanos}")
            }),
    }
}
317
/// Reads exactly `len` bytes from `reader` and decodes them as UTF-8.
///
/// # Errors
///
/// Returns a message if the reader ends early or fails, or if the bytes are
/// not valid UTF-8.
fn read_string_with_len<R: Read>(reader: &mut R, len: usize) -> Result<String, String> {
    let mut raw = vec![0u8; len];
    if let Err(error) = reader.read_exact(&mut raw) {
        return Err(format!("failed to read string: {error}"));
    }
    match String::from_utf8(raw) {
        Ok(text) => Ok(text),
        Err(error) => Err(format!("invalid utf-8 string: {error}")),
    }
}
325
/// Reads four bytes from `reader` and decodes them as a little-endian `u32`.
///
/// # Errors
///
/// Returns a message if four bytes cannot be read.
fn read_u32<R: Read>(reader: &mut R) -> Result<u32, String> {
    let mut buf = [0u8; 4];
    match reader.read_exact(&mut buf) {
        Ok(()) => Ok(u32::from_le_bytes(buf)),
        Err(error) => Err(format!("failed to read u32: {error}")),
    }
}
333
/// Reads eight bytes from `reader` and decodes them as a little-endian `i64`.
///
/// # Errors
///
/// Returns a message if eight bytes cannot be read.
fn read_i64<R: Read>(reader: &mut R) -> Result<i64, String> {
    let mut buf = [0u8; 8];
    match reader.read_exact(&mut buf) {
        Ok(()) => Ok(i64::from_le_bytes(buf)),
        Err(error) => Err(format!("failed to read i64: {error}")),
    }
}
341
/// Reads eight bytes from `reader` and decodes them as a little-endian `u64`.
///
/// # Errors
///
/// Returns a message if eight bytes cannot be read.
fn read_u64<R: Read>(reader: &mut R) -> Result<u64, String> {
    let mut buf = [0u8; 8];
    match reader.read_exact(&mut buf) {
        Ok(()) => Ok(u64::from_le_bytes(buf)),
        Err(error) => Err(format!("failed to read u64: {error}")),
    }
}
349
/// Writes `value` to `writer` as four little-endian bytes.
fn write_u32<W: Write>(writer: &mut W, value: u32) -> std::io::Result<()> {
    let encoded: [u8; 4] = value.to_le_bytes();
    writer.write_all(&encoded)
}
353
/// Writes `value` to `writer` as eight little-endian bytes.
fn write_i64<W: Write>(writer: &mut W, value: i64) -> std::io::Result<()> {
    let encoded: [u8; 8] = value.to_le_bytes();
    writer.write_all(&encoded)
}
357
/// Writes `value` to `writer` as eight little-endian bytes.
fn write_u64<W: Write>(writer: &mut W, value: u64) -> std::io::Result<()> {
    let encoded: [u8; 8] = value.to_le_bytes();
    writer.write_all(&encoded)
}
361
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbols::{Range, SymbolKind};

    /// Builds a minimal function symbol named `name` for use as cache content.
    fn test_symbol(name: &str) -> Symbol {
        Symbol {
            name: name.to_string(),
            kind: SymbolKind::Function,
            range: Range {
                start_line: 0,
                start_col: 0,
                end_line: 0,
                end_col: 1,
            },
            signature: None,
            scope_chain: Vec::new(),
            exported: false,
            parent: None,
        }
    }

    /// Writes a one-line source file into `project` and returns a cache
    /// rooted there that holds a single entry for that file.
    fn test_cache(project: &Path, file_name: &str) -> SymbolCache {
        let file = project.join(file_name);
        fs::write(&file, format!("fn {file_name}() {{}}\n")).expect("write file");
        let metadata = fs::metadata(&file).expect("metadata");
        let content_hash = blake3::hash(&fs::read(&file).expect("read file"));
        let mut cache = SymbolCache::new();
        cache.set_project_root(project.to_path_buf());
        cache.insert(
            file,
            metadata.modified().expect("mtime"),
            metadata.len(),
            content_hash,
            vec![test_symbol(file_name)],
        );
        cache
    }

    /// Two threads persist different caches under the same project key; both
    /// writes must succeed, the result must load, and no temp file may remain.
    #[test]
    fn concurrent_symbol_cache_writes_do_not_share_temp_file() {
        let dir = tempfile::tempdir().expect("create temp dir");
        let project = dir.path().join("project");
        fs::create_dir_all(&project).expect("create project");
        let storage = dir.path().join("storage");

        let cache_a = test_cache(&project, "a");
        let cache_b = test_cache(&project, "b");
        let storage_a = storage.clone();
        let writer_a = std::thread::spawn(move || {
            write_to_disk(&cache_a, &storage_a, "unit-project").expect("write a");
        });
        let storage_b = storage.clone();
        let writer_b = std::thread::spawn(move || {
            write_to_disk(&cache_b, &storage_b, "unit-project").expect("write b");
        });

        writer_a.join().expect("writer a");
        writer_b.join().expect("writer b");

        // Each cache holds one entry, so whichever write landed last leaves one.
        let loaded = read_from_disk(&storage, "unit-project").expect("load symbol cache");
        assert_eq!(loaded.len(), 1);
        // Temp files are named `symbols.bin.tmp.*`; none should be left over.
        assert!(fs::read_dir(storage.join("symbols").join("unit-project"))
            .expect("read symbol cache dir")
            .all(|entry| !entry
                .expect("cache entry")
                .file_name()
                .to_string_lossy()
                .contains(".tmp.")));
    }

    /// A cache entry whose file lies outside the project root must make
    /// `write_to_disk` fail with the "outside project root" error.
    #[test]
    fn symbol_cache_rejects_paths_outside_project_root_on_write() {
        let dir = tempfile::tempdir().expect("create temp dir");
        let project = dir.path().join("project");
        fs::create_dir_all(&project).expect("create project");
        // Sibling of the project directory, not inside it.
        let outside = dir.path().join("outside.rs");
        fs::write(&outside, "fn outside() {}\n").expect("write outside");
        let metadata = fs::metadata(&outside).expect("metadata");

        let mut cache = SymbolCache::new();
        cache.set_project_root(project);
        cache.insert(
            outside.clone(),
            metadata.modified().expect("mtime"),
            metadata.len(),
            blake3::hash(&fs::read(&outside).expect("read outside")),
            vec![test_symbol("outside")],
        );

        let error = write_to_disk(&cache, dir.path(), "escape-project").expect_err("reject escape");
        assert!(error.to_string().contains("outside project root"));
    }
}