use crate::error::SurgeonError;
use crate::language::SupportedLanguage;
use crate::parser::AstParser;
use crate::vue_zones::{parse_vue_multizone, MultiZoneTree};
use lru::LruCache;
use pathfinder_common::types::VersionHash;
use std::collections::HashMap;
use std::num::NonZeroUsize;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::Mutex;
use std::time::SystemTime;
use tokio::sync::OnceCell;
use tracing::instrument;
use tree_sitter::Tree;
/// Converts an I/O failure that happened while accessing `path` into a
/// [`SurgeonError`].
///
/// A `NotFound` error becomes [`SurgeonError::FileNotFound`] carrying an owned
/// copy of `path`; every other error kind is wrapped unchanged in
/// [`SurgeonError::Io`].
#[inline]
fn io_err(e: std::io::Error, path: &Path) -> SurgeonError {
    match e.kind() {
        std::io::ErrorKind::NotFound => SurgeonError::FileNotFound(path.to_path_buf()),
        _ => SurgeonError::Io(e),
    }
}
/// Shared, write-once slot for the outcome of a single-language parse: on
/// success the tree together with the source bytes, otherwise the error.
/// Every caller holding a clone of the `Arc` observes the same result.
type InFlightParse = Arc<OnceCell<Result<(Tree, Arc<[u8]>), SurgeonError>>>;
/// Shared, write-once slot for the outcome of a Vue multi-zone parse: on
/// success the multi-zone tree together with the content hash.
type InFlightVueParse = Arc<OnceCell<Result<(MultiZoneTree, VersionHash), SurgeonError>>>;
/// Parses currently registered as in progress, keyed by file path.
type InFlightParseMap = HashMap<PathBuf, InFlightParse>;
/// Vue multi-zone parses currently registered as in progress, keyed by file path.
type InFlightVueParseMap = HashMap<PathBuf, InFlightVueParse>;
/// One cached parse of a single-language source file.
#[derive(Debug)]
pub struct CacheEntry {
    /// Syntax tree produced by the parser for this file.
    pub tree: Tree,
    /// Raw bytes read from the file when it was parsed.
    pub source: Arc<[u8]>,
    /// Hash computed over the raw file bytes.
    pub content_hash: VersionHash,
    /// Language the file was parsed as; a lookup with a different language
    /// does not reuse this entry.
    pub lang: SupportedLanguage,
    /// Modification time reported by the filesystem just before the file was
    /// read; a lookup reuses the entry only while the mtime still matches.
    pub mtime: SystemTime,
}
/// One cached multi-zone parse of a Vue single-file component.
#[derive(Debug)]
pub struct MultiZoneEntry {
    /// The parsed zones of the component.
    pub multi: MultiZoneTree,
    /// Hash computed over the raw file bytes.
    pub content_hash: VersionHash,
    /// Modification time reported by the filesystem just before the file was
    /// read; a lookup reuses the entry only while the mtime still matches.
    pub mtime: SystemTime,
}
/// Path-keyed LRU cache of parsed syntax trees.
///
/// Single-language files and Vue multi-zone files are kept in two separate
/// LRU caches. Each cache has a companion map of in-progress parses so that
/// concurrent requests for the same path can share one parse.
#[derive(Debug)]
pub struct AstCache {
    /// Cached single-language parses, keyed by file path.
    entries: Mutex<LruCache<PathBuf, CacheEntry>>,
    /// Cached Vue multi-zone parses, keyed by file path.
    vue_entries: Mutex<LruCache<PathBuf, MultiZoneEntry>>,
    /// Single-language parses currently in progress, keyed by file path.
    in_flight: Mutex<InFlightParseMap>,
    /// Vue multi-zone parses currently in progress, keyed by file path.
    vue_in_flight: Mutex<InFlightVueParseMap>,
}
impl AstCache {
/// Creates an empty cache whose two LRU caches (single-language and Vue) each
/// hold at most `max_entries` entries.
///
/// A `max_entries` of `0` is treated as `1`, because the underlying LRU cache
/// requires a non-zero capacity.
#[must_use]
pub fn new(max_entries: usize) -> Self {
    // `NonZeroUsize::new` only returns `None` for 0; fall back to the smallest
    // legal capacity (1) instead of unwrapping.
    let capacity = NonZeroUsize::new(max_entries).unwrap_or(NonZeroUsize::MIN);
    Self {
        entries: Mutex::new(LruCache::new(capacity)),
        vue_entries: Mutex::new(LruCache::new(capacity)),
        in_flight: Mutex::new(HashMap::new()),
        vue_in_flight: Mutex::new(HashMap::new()),
    }
}
/// Returns the syntax tree and source bytes for `path`, parsing only when the
/// cache cannot answer.
///
/// The file's current modification time is read first. A cached entry is
/// reused only when both its mtime and its language match; otherwise the file
/// is read and parsed as `lang`, and the result is stored in the cache.
/// Callers that request the same path while a parse is registered as in
/// progress share that parse's result instead of starting their own.
///
/// # Errors
///
/// * [`SurgeonError::FileNotFound`] when `path` does not exist, including the
///   case where it disappears between the metadata check and the read.
/// * [`SurgeonError::Io`] when reading the file metadata fails for any other
///   reason.
/// * [`SurgeonError::ParseError`] when an internal lock is poisoned, or when
///   reading or parsing inside the shared parse fails (the underlying error
///   is embedded in `reason`).
#[instrument(skip(self), fields(cache_hit = false))]
pub async fn get_or_parse(
    &self,
    path: &Path,
    lang: SupportedLanguage,
) -> Result<(Tree, Arc<[u8]>), SurgeonError> {
    // Error reported whenever one of the internal mutexes is poisoned.
    fn poisoned(path: &Path, reason: &str) -> SurgeonError {
        SurgeonError::ParseError {
            path: path.to_path_buf(),
            reason: reason.into(),
        }
    }

    let meta = tokio::fs::metadata(path)
        .await
        .map_err(|e| io_err(e, path))?;
    // When the platform cannot report an mtime, fall back to a fixed timestamp.
    let current_mtime = meta.modified().unwrap_or(SystemTime::UNIX_EPOCH);

    // Fast path: an entry parsed from the same mtime with the same language.
    // The guard is confined to this block so it is never held across an await.
    {
        let mut cache = self
            .entries
            .lock()
            .map_err(|_| poisoned(path, "Lock poisoned"))?;
        if let Some(entry) = cache.get(path) {
            if entry.mtime == current_mtime && entry.lang == lang {
                tracing::Span::current().record("cache_hit", true);
                return Ok((entry.tree.clone(), Arc::clone(&entry.source)));
            }
        }
    }

    // Join the parse already registered for this path, or register a new one.
    // NOTE(review): the in-flight map is keyed by path only, so concurrent
    // requests for the same path with different `lang` values share one result.
    let cell = {
        let mut in_flight = self
            .in_flight
            .lock()
            .map_err(|_| poisoned(path, "In-flight lock poisoned"))?;
        in_flight
            .entry(path.to_path_buf())
            .or_insert_with(|| Arc::new(OnceCell::new()))
            .clone()
    };

    let result = cell
        .get_or_init(|| async {
            let content = tokio::fs::read(path).await.map_err(|e| io_err(e, path))?;
            let current_hash = VersionHash::compute(&content);
            let content_arc: Arc<[u8]> = Arc::from(content);
            let parse_input = lang.preprocess_source(&content_arc);
            let tree = AstParser::parse_source(path, lang, &parse_input)?;
            self.entries
                .lock()
                .map_err(|_| poisoned(path, "Lock poisoned"))?
                .put(
                    path.to_path_buf(),
                    CacheEntry {
                        tree: tree.clone(),
                        source: Arc::clone(&content_arc),
                        content_hash: current_hash,
                        lang,
                        mtime: current_mtime,
                    },
                );
            Ok::<_, SurgeonError>((tree, Arc::clone(&content_arc)))
        })
        .await;

    // Unregister the finished parse, but only if the map still points at OUR
    // cell. A later caller may already have registered a fresh cell for this
    // path; removing that one would let yet another caller start a redundant
    // parse alongside it.
    {
        let mut in_flight = self
            .in_flight
            .lock()
            .map_err(|_| poisoned(path, "In-flight lock poisoned"))?;
        let still_ours =
            matches!(in_flight.get(path), Some(current) if Arc::ptr_eq(current, &cell));
        if still_ours {
            in_flight.remove(path);
        }
    }

    match result.as_ref() {
        Ok((tree, source)) => Ok((tree.clone(), Arc::clone(source))),
        // Keep the "file is gone" signal intact so callers see the same variant
        // whether the file vanished before or after the metadata check.
        Err(SurgeonError::FileNotFound(missing)) => {
            Err(SurgeonError::FileNotFound(missing.clone()))
        }
        Err(e) => Err(SurgeonError::ParseError {
            path: path.to_path_buf(),
            reason: format!("Parse failed: {e}"),
        }),
    }
}
/// Returns the multi-zone parse and content hash for the Vue file at `path`,
/// parsing only when the cache cannot answer.
///
/// The file's current modification time is read first. A cached entry is
/// reused only when its mtime matches; otherwise the file is read, hashed and
/// parsed, and the result is stored in the Vue cache. Callers that request
/// the same path while a parse is registered as in progress share that
/// parse's result instead of starting their own.
///
/// # Errors
///
/// * [`SurgeonError::FileNotFound`] when `path` does not exist, including the
///   case where it disappears between the metadata check and the read.
/// * [`SurgeonError::Io`] when reading the file metadata fails for any other
///   reason.
/// * [`SurgeonError::ParseError`] when an internal lock is poisoned, or when
///   reading or parsing inside the shared parse fails (the underlying error
///   is embedded in `reason`).
#[instrument(skip(self), fields(cache_hit = false))]
pub async fn get_or_parse_vue(
    &self,
    path: &Path,
) -> Result<(MultiZoneTree, VersionHash), SurgeonError> {
    // Error reported whenever one of the internal mutexes is poisoned.
    fn poisoned(path: &Path, reason: &str) -> SurgeonError {
        SurgeonError::ParseError {
            path: path.to_path_buf(),
            reason: reason.into(),
        }
    }

    let meta = tokio::fs::metadata(path)
        .await
        .map_err(|e| io_err(e, path))?;
    // When the platform cannot report an mtime, fall back to a fixed timestamp.
    let current_mtime = meta.modified().unwrap_or(SystemTime::UNIX_EPOCH);

    // Fast path: an entry parsed from the same mtime. The guard is confined to
    // this block so it is never held across an await.
    {
        let mut cache = self
            .vue_entries
            .lock()
            .map_err(|_| poisoned(path, "Vue cache lock poisoned"))?;
        if let Some(entry) = cache.get(path) {
            if entry.mtime == current_mtime {
                tracing::Span::current().record("cache_hit", true);
                return Ok((entry.multi.clone(), entry.content_hash.clone()));
            }
        }
    }

    // Join the parse already registered for this path, or register a new one.
    let cell = {
        let mut vue_in_flight = self
            .vue_in_flight
            .lock()
            .map_err(|_| poisoned(path, "Vue in-flight lock poisoned"))?;
        vue_in_flight
            .entry(path.to_path_buf())
            .or_insert_with(|| Arc::new(OnceCell::new()))
            .clone()
    };

    let result = cell
        .get_or_init(|| async {
            let content = tokio::fs::read(path).await.map_err(|e| io_err(e, path))?;
            let content_hash = VersionHash::compute(&content);
            let multi =
                parse_vue_multizone(&content).map_err(|e| SurgeonError::ParseError {
                    path: path.to_path_buf(),
                    reason: format!("Vue multi-zone parse failed: {e}"),
                })?;
            self.vue_entries
                .lock()
                .map_err(|_| poisoned(path, "Vue cache lock poisoned"))?
                .put(
                    path.to_path_buf(),
                    MultiZoneEntry {
                        multi: multi.clone(),
                        content_hash: content_hash.clone(),
                        mtime: current_mtime,
                    },
                );
            Ok::<_, SurgeonError>((multi, content_hash))
        })
        .await;

    // Unregister the finished parse, but only if the map still points at OUR
    // cell. A later caller may already have registered a fresh cell for this
    // path; removing that one would let yet another caller start a redundant
    // parse alongside it.
    {
        let mut vue_in_flight = self
            .vue_in_flight
            .lock()
            .map_err(|_| poisoned(path, "Vue in-flight lock poisoned"))?;
        let still_ours =
            matches!(vue_in_flight.get(path), Some(current) if Arc::ptr_eq(current, &cell));
        if still_ours {
            vue_in_flight.remove(path);
        }
    }

    match result.as_ref() {
        Ok((multi, hash)) => Ok((multi.clone(), hash.clone())),
        // Keep the "file is gone" signal intact so callers see the same variant
        // whether the file vanished before or after the metadata check.
        Err(SurgeonError::FileNotFound(missing)) => {
            Err(SurgeonError::FileNotFound(missing.clone()))
        }
        Err(e) => Err(SurgeonError::ParseError {
            path: path.to_path_buf(),
            reason: format!("Parse failed: {e}"),
        }),
    }
}
pub fn invalidate(&self, path: &Path) {
if let Ok(mut lock) = self.entries.lock() {
lock.pop(path);
}
if let Ok(mut lock) = self.vue_entries.lock() {
lock.pop(path);
}
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
use super::*;
use std::io::Write;
use tempfile::{tempdir, NamedTempFile};
/// Looks the same Go file up twice and expects identical results, checks that
/// the cached mtime equals the one on disk, then appends to the file and
/// expects the longer source to come back.
#[tokio::test]
async fn test_cache_hits_and_misses() {
    // Byte length of "package main\nfunc A() {}\n".
    const INITIAL_LEN: usize = 25;

    let ast_cache = AstCache::new(2);
    let mut go_file = NamedTempFile::new().unwrap();
    writeln!(go_file, "package main\nfunc A() {{}}").unwrap();
    let go_path = go_file.path().to_path_buf();

    let (first_tree, first_src) = ast_cache
        .get_or_parse(&go_path, SupportedLanguage::Go)
        .await
        .unwrap();
    assert_eq!(first_src.len(), INITIAL_LEN);

    let (second_tree, second_src) = ast_cache
        .get_or_parse(&go_path, SupportedLanguage::Go)
        .await
        .unwrap();
    assert_eq!(second_src.len(), INITIAL_LEN);
    assert_eq!(
        first_tree.root_node().child_count(),
        second_tree.root_node().child_count()
    );

    // The entry must carry exactly the mtime the filesystem reports.
    let on_disk_mtime = std::fs::metadata(&go_path)
        .unwrap()
        .modified()
        .unwrap_or(SystemTime::UNIX_EPOCH);
    {
        let guard = ast_cache.entries.lock().unwrap();
        let cached = guard.peek(&go_path).unwrap();
        assert_eq!(cached.mtime, on_disk_mtime);
    }

    // Let the clock advance so that the append below produces a new mtime.
    // NOTE(review): assumes the filesystem's mtime granularity is finer than
    // 10 ms; confirm for the filesystems used in CI.
    std::thread::sleep(std::time::Duration::from_millis(10));
    writeln!(go_file, "func B() {{}}").unwrap();

    let (_third_tree, grown_src) = ast_cache
        .get_or_parse(&go_path, SupportedLanguage::Go)
        .await
        .unwrap();
    assert!(grown_src.len() > INITIAL_LEN);
}
/// With a capacity of two, looking `f1` up again before inserting `f3` must
/// cause `f2` (the least recently used entry) to be evicted.
///
/// No sleeps are needed between the lookups: LRU order is decided by the
/// order of the cache accesses, not by wall-clock time, and none of the files
/// is modified after it is first parsed.
#[tokio::test]
async fn test_cache_eviction_lru() {
    let cache = AstCache::new(2);
    let mut f1 = NamedTempFile::new().unwrap();
    writeln!(f1, "func A() {{}}").unwrap();
    let mut f2 = NamedTempFile::new().unwrap();
    writeln!(f2, "func B() {{}}").unwrap();
    let mut f3 = NamedTempFile::new().unwrap();
    writeln!(f3, "func C() {{}}").unwrap();

    // Fill the cache: f1, then f2.
    cache
        .get_or_parse(f1.path(), SupportedLanguage::Go)
        .await
        .unwrap();
    cache
        .get_or_parse(f2.path(), SupportedLanguage::Go)
        .await
        .unwrap();
    {
        let lock = cache.entries.lock().unwrap();
        assert_eq!(lock.len(), 2);
        assert!(lock.contains(f1.path()));
        assert!(lock.contains(f2.path()));
    }

    // Touch f1 so that f2 becomes the least recently used entry, then insert
    // f3 to push the cache over capacity.
    cache
        .get_or_parse(f1.path(), SupportedLanguage::Go)
        .await
        .unwrap();
    cache
        .get_or_parse(f3.path(), SupportedLanguage::Go)
        .await
        .unwrap();
    {
        let lock = cache.entries.lock().unwrap();
        assert_eq!(lock.len(), 2);
        assert!(lock.contains(f1.path()));
        assert!(!lock.contains(f2.path()));
        assert!(lock.contains(f3.path()));
    }
}
#[tokio::test]
async fn test_cache_invalidation() {
let cache = AstCache::new(2);
let mut f1 = NamedTempFile::new().unwrap();
writeln!(f1, "func A() {{}}").unwrap();
cache
.get_or_parse(f1.path(), SupportedLanguage::Go)
.await
.unwrap();
assert_eq!(cache.entries.lock().unwrap().len(), 1);
cache.invalidate(f1.path());
assert_eq!(cache.entries.lock().unwrap().len(), 0);
}
/// Parsing a Vue component yields script and template trees, and a second
/// lookup returns the same content hash while keeping a single cache entry.
#[tokio::test]
async fn test_vue_cache_hits_and_misses() {
    let component = b"<template>\n<div>Hello</div>\n</template>\n<script setup lang=\"ts\">\nconst x = 1\n</script>\n";
    let mut vue_file = NamedTempFile::new().unwrap();
    vue_file.write_all(component).unwrap();
    let ast_cache = AstCache::new(2);

    // First lookup: nothing cached yet, so the file is parsed.
    let (parsed, first_hash) = ast_cache.get_or_parse_vue(vue_file.path()).await.unwrap();
    assert!(parsed.script_tree.is_some());
    assert!(parsed.template_tree.is_some());
    assert!(!parsed.degraded);

    // Second lookup of the unchanged file.
    let (_again, second_hash) = ast_cache.get_or_parse_vue(vue_file.path()).await.unwrap();
    assert_eq!(first_hash, second_hash, "hash must be stable across cache hits");

    let vue_count = ast_cache.vue_entries.lock().unwrap().len();
    assert_eq!(vue_count, 1, "exactly one Vue entry cached");
}
/// Invalidating the Vue file's path removes its Vue entry and leaves the
/// entry of an unrelated Go file in place.
#[tokio::test]
async fn test_vue_cache_invalidation_clears_both_caches() {
    let ast_cache = AstCache::new(2);

    // One entry in the single-language cache ...
    let mut go_source = NamedTempFile::new().unwrap();
    writeln!(go_source, "func A() {{}}").unwrap();
    ast_cache
        .get_or_parse(go_source.path(), SupportedLanguage::Go)
        .await
        .unwrap();

    // ... and one entry in the Vue cache.
    let mut vue_source = NamedTempFile::new().unwrap();
    vue_source
        .write_all(b"<template><div/></template>\n")
        .unwrap();
    ast_cache.get_or_parse_vue(vue_source.path()).await.unwrap();

    assert_eq!(ast_cache.entries.lock().unwrap().len(), 1);
    assert_eq!(ast_cache.vue_entries.lock().unwrap().len(), 1);

    ast_cache.invalidate(vue_source.path());

    let vue_left = ast_cache.vue_entries.lock().unwrap().len();
    assert_eq!(vue_left, 0, "Vue entry cleared");
    let go_left = ast_cache.entries.lock().unwrap().len();
    assert_eq!(go_left, 1, "Go entry untouched");
}
/// Five concurrent lookups of the same file all succeed with matching results
/// and leave exactly one cache entry behind.
///
/// The test compares results only; it does not count how many parses ran.
#[tokio::test]
async fn test_singleflight_prevents_redundant_parsing() {
    let mut go_file = NamedTempFile::new().unwrap();
    go_file
        .write_all("package main\n".repeat(1000).as_bytes())
        .unwrap();
    let shared_path = Arc::new(go_file.path().to_path_buf());
    let shared_cache = Arc::new(AstCache::new(10));

    let tasks = (0..5).map(|_| {
        let cache = Arc::clone(&shared_cache);
        let path = Arc::clone(&shared_path);
        tokio::spawn(async move { cache.get_or_parse(&path, SupportedLanguage::Go).await })
    });
    let outcomes: Vec<_> = futures::future::join_all(tasks)
        .await
        .into_iter()
        .map(|joined| joined.unwrap().unwrap())
        .collect();

    // Every task must agree with the first one.
    let (first, rest) = outcomes.split_first().unwrap();
    for other in rest {
        assert_eq!(
            first.0.root_node().child_count(),
            other.0.root_node().child_count()
        );
        assert_eq!(first.1.len(), other.1.len());
    }

    assert_eq!(shared_cache.entries.lock().unwrap().len(), 1);
}
/// Three concurrent Vue lookups of the same file all succeed with the same
/// content hash and leave exactly one Vue cache entry behind.
#[tokio::test]
async fn test_singleflight_vue() {
    let mut vue_file = NamedTempFile::new().unwrap();
    vue_file
        .write_all(b"<template><div/></template>\n<script>const x = 1;</script>\n")
        .unwrap();
    let shared_path = Arc::new(vue_file.path().to_path_buf());
    let shared_cache = Arc::new(AstCache::new(10));

    let tasks = (0..3).map(|_| {
        let cache = Arc::clone(&shared_cache);
        let path = Arc::clone(&shared_path);
        tokio::spawn(async move { cache.get_or_parse_vue(&path).await })
    });
    let outcomes: Vec<_> = futures::future::join_all(tasks)
        .await
        .into_iter()
        .map(|joined| joined.unwrap().unwrap())
        .collect();

    // Every task must report the hash the first one reported.
    let (first, rest) = outcomes.split_first().unwrap();
    for other in rest {
        assert_eq!(first.1, other.1);
    }

    assert_eq!(shared_cache.vue_entries.lock().unwrap().len(), 1);
}
/// Looking up a path that does not exist must fail with `FileNotFound`.
#[tokio::test]
async fn test_get_or_parse_missing_file_returns_file_not_found() {
    // A fresh temporary directory guarantees the file name below is unused.
    let scratch = tempdir().unwrap();
    let absent = scratch
        .path()
        .join("this_file_does_not_exist_pathfinder.go");

    let outcome = AstCache::new(2)
        .get_or_parse(&absent, SupportedLanguage::Go)
        .await;
    let err = outcome.unwrap_err();

    assert!(
        matches!(err, SurgeonError::FileNotFound(_)),
        "expected FileNotFound, got: {err:?}"
    );
}
/// Looking up a Vue path that does not exist must fail with `FileNotFound`.
#[tokio::test]
async fn test_get_or_parse_vue_missing_file_returns_file_not_found() {
    // A fresh temporary directory guarantees the file name below is unused.
    let scratch = tempdir().unwrap();
    let absent = scratch
        .path()
        .join("this_file_does_not_exist_pathfinder.vue");

    let outcome = AstCache::new(2).get_or_parse_vue(&absent).await;
    let err = outcome.unwrap_err();

    assert!(
        matches!(err, SurgeonError::FileNotFound(_)),
        "expected FileNotFound, got: {err:?}"
    );
}
}