use std::collections::HashSet;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
use ignore::WalkBuilder;
use tree_sitter::{Parser, Query, QueryCursor, StreamingIterator};
use crate::semantic::language::{Lang, LanguageRegistry};
use super::resolve::resolve_callee;
use super::{CodeGraph, Edge, EdgeKind, SymbolKind, SymbolNode, Visibility};
/// Everything `GraphIndexer::parse_file` extracts from one source file.
struct FileParseResult {
/// Symbols defined in the file, as produced by `extract_symbols`.
symbols: Vec<SymbolNode>,
/// Call sites found in the file, recorded by name and not yet turned into graph edges.
raw_calls: Vec<RawCall>,
}
/// A call site recorded by name only; the indexer later looks both ends up in
/// the graph and, when both resolve, adds an `EdgeKind::Calls` edge.
struct RawCall {
/// Name of the function or method symbol whose line span contains the call.
caller_name: String,
/// Source text covered by the calls query's `callee` capture.
callee_name: String,
/// 1-based line on which the `callee` capture starts.
line: usize,
}
/// File extensions (without the leading dot) that `collect_files_sync` keeps;
/// files with any other extension, or with none, are skipped by the walk.
const INDEXED_EXTENSIONS: &[&str] = &[
"rs", "py", "js", "ts", "tsx", "go", "java", "c", "cpp", "vue",
];
/// Parses source files with tree-sitter and writes the resulting symbols and
/// call edges into a shared [`CodeGraph`].
pub struct GraphIndexer {
/// Graph being populated; shared behind an async read/write lock.
graph: Arc<RwLock<CodeGraph>>,
/// Directory that `index_all` walks.
project_dir: PathBuf,
/// Single tree-sitter parser reused for every file; `parse_file` sets its
/// language before each parse.
parser: Parser,
}
impl GraphIndexer {
/// Creates an indexer that writes into `graph` and walks `project_dir`.
///
/// The tree-sitter parser starts without a language; one is assigned per file
/// when parsing.
pub fn new(graph: Arc<RwLock<CodeGraph>>, project_dir: PathBuf) -> Self {
    let parser = Parser::new();
    Self {
        graph,
        project_dir,
        parser,
    }
}
pub async fn index_all(&mut self, cancel: CancellationToken) {
if !should_index(&self.project_dir) {
return;
}
if cancel.is_cancelled() {
return;
}
let project_dir = self.project_dir.clone();
let files = tokio::task::spawn_blocking(move || collect_files_sync(&project_dir))
.await
.unwrap_or_default();
let current_paths: HashSet<PathBuf> = files.iter().map(|(p, _)| p.clone()).collect();
let (deleted, dirty_files) = {
let graph = self.graph.read().await;
let deleted: Vec<PathBuf> = graph
.file_mtimes
.keys()
.filter(|p| !current_paths.contains(*p))
.cloned()
.collect();
let dirty: Vec<(PathBuf, u64)> = files
.into_iter()
.filter(|(path, mtime)| graph.file_mtimes.get(path) != Some(mtime))
.collect();
(deleted, dirty)
};
const CPU_BREATHE_CHUNK: usize = 16;
const CPU_BREATHE_MS: u64 = 5;
let mut all_results: Vec<(PathBuf, u64, FileParseResult)> = Vec::new();
for (i, (path, mtime)) in dirty_files.into_iter().enumerate() {
if cancel.is_cancelled() {
return;
}
if let Some(result) = self.parse_file(&path) {
all_results.push((path, mtime, result));
}
tokio::task::yield_now().await;
if i > 0 && i % CPU_BREATHE_CHUNK == 0 {
tokio::time::sleep(std::time::Duration::from_millis(CPU_BREATHE_MS)).await;
}
}
if deleted.is_empty() && all_results.is_empty() {
return; }
if cancel.is_cancelled() {
return;
}
let mut graph = self.graph.write().await;
for path in &deleted {
graph.remove_file(path);
}
for (path, mtime, result) in &all_results {
graph.remove_file(path);
for sym in &result.symbols {
graph.add_symbol(sym.clone());
}
graph.file_mtimes.insert(path.clone(), *mtime);
}
for (_path, _mtime, result) in &all_results {
for raw_call in &result.raw_calls {
let caller_candidates = graph.find_by_name(&raw_call.caller_name);
let caller_id = caller_candidates.first().map(|s| s.id);
if let Some(caller_id) = caller_id {
let caller_file = graph.node(caller_id).unwrap().file.clone();
if let Some(callee_id) =
resolve_callee(&graph, &raw_call.callee_name, &caller_file, &[])
{
graph.add_edge(
caller_id,
Edge {
to: callee_id,
kind: EdgeKind::Calls,
line: raw_call.line,
},
);
}
}
}
}
}
pub async fn reindex_file(&mut self, path: &Path) {
let mtime = match std::fs::metadata(path) {
Ok(meta) => {
use std::time::UNIX_EPOCH;
meta.modified()
.ok()
.and_then(|t| t.duration_since(UNIX_EPOCH).ok())
.map(|d| d.as_secs())
.unwrap_or(0)
}
Err(_) => {
let mut graph = self.graph.write().await;
graph.remove_file(&path.to_path_buf());
return;
}
};
let result = match self.parse_file(path) {
Some(r) => r,
None => return,
};
let mut graph = self.graph.write().await;
let path_buf = path.to_path_buf();
graph.remove_file(&path_buf);
for sym in &result.symbols {
graph.add_symbol(sym.clone());
}
graph.file_mtimes.insert(path_buf.clone(), mtime);
for raw_call in &result.raw_calls {
let caller_candidates = graph.find_by_name(&raw_call.caller_name);
let caller_id = caller_candidates.first().map(|s| s.id);
if let Some(caller_id) = caller_id {
let caller_file = graph.node(caller_id).unwrap().file.clone();
if let Some(callee_id) =
resolve_callee(&graph, &raw_call.callee_name, &caller_file, &[])
{
graph.add_edge(
caller_id,
Edge {
to: callee_id,
kind: EdgeKind::Calls,
line: raw_call.line,
},
);
}
}
}
}
/// Parses one file and extracts its symbols and call sites.
///
/// Returns `None` when no language is detected for `path`, the file cannot be
/// read as UTF-8 text, the grammar cannot be set on the parser, or the parser
/// produces no tree.
fn parse_file(&mut self, path: &Path) -> Option<FileParseResult> {
    // Detect the language before touching the file contents so that a path
    // with no registered language does not cost a full read here.
    let lang = LanguageRegistry::detect(path)?;
    let source = std::fs::read_to_string(path).ok()?;
    self.parser.set_language(&lang.grammar()).ok()?;
    let tree = self.parser.parse(source.as_bytes(), None)?;
    let symbols = self.extract_symbols(path, &source, lang, &tree);
    let raw_calls = self.extract_calls(path, &source, lang, &tree, &symbols);
    Some(FileParseResult { symbols, raw_calls })
}
/// Runs the language's symbols query over `tree` and builds one
/// [`SymbolNode`] per distinct definition.
///
/// A query match yields a symbol only when it has both a `name` and a
/// `definition` capture. Matches whose definition covers a byte range that
/// was already emitted are skipped, so the first match for a range wins.
/// Lines are 1-based. Returns an empty vector when the query fails to compile
/// or lacks either capture name.
fn extract_symbols(
    &self,
    path: &Path,
    source: &str,
    lang: Lang,
    tree: &tree_sitter::Tree,
) -> Vec<SymbolNode> {
    let query_src = lang.symbols_query();
    let Ok(query) = Query::new(&lang.grammar(), query_src) else {
        return Vec::new();
    };
    let Some(def_idx) = query.capture_index_for_name("definition") else {
        return Vec::new();
    };
    let Some(name_idx) = query.capture_index_for_name("name") else {
        return Vec::new();
    };

    let file = path.to_path_buf();
    let mut emitted_spans: HashSet<(usize, usize)> = HashSet::new();
    let mut found = Vec::new();

    let mut cursor = QueryCursor::new();
    let mut hits = cursor.matches(&query, tree.root_node(), source.as_bytes());
    loop {
        hits.advance();
        let Some(hit) = hits.get() else {
            break;
        };

        // If a capture name repeats within one match, the last occurrence wins.
        let mut name_text: Option<String> = None;
        let mut def_node = None;
        for capture in hit.captures {
            if capture.index == name_idx {
                name_text =
                    Some(source[capture.node.start_byte()..capture.node.end_byte()].to_string());
            }
            if capture.index == def_idx {
                def_node = Some(capture.node);
            }
        }

        let (Some(name), Some(def)) = (name_text, def_node) else {
            continue;
        };
        // `insert` reports `false` when this definition span was seen before.
        if !emitted_spans.insert((def.start_byte(), def.end_byte())) {
            continue;
        }

        let start_line = def.start_position().row + 1;
        let end_line = def.end_position().row + 1;
        let id = CodeGraph::make_id(&file, &name, start_line);
        let kind = classify_symbol_kind(def.kind());
        found.push(SymbolNode {
            id,
            name,
            kind,
            visibility: Visibility::Unknown,
            file: file.clone(),
            start_line,
            end_line,
            signature: None,
        });
    }
    found
}
/// Runs the language's calls query over `tree` and records each call site as
/// a [`RawCall`].
///
/// For every `callee` capture, the caller is the last entry of `symbols` that
/// is a function or method and whose line span contains the call line
/// (1-based). Calls with no such enclosing symbol are dropped, as are calls
/// whose callee text equals the caller's name. Returns an empty vector when
/// the language has no calls query, the query fails to compile, or it has no
/// `callee` capture.
fn extract_calls(
    &self,
    _path: &Path,
    source: &str,
    lang: Lang,
    tree: &tree_sitter::Tree,
    symbols: &[SymbolNode],
) -> Vec<RawCall> {
    let Some(query_src) = lang.calls_query() else {
        return Vec::new();
    };
    let Ok(query) = Query::new(&lang.grammar(), query_src) else {
        return Vec::new();
    };
    let Some(callee_idx) = query.capture_index_for_name("callee") else {
        return Vec::new();
    };

    let mut cursor = QueryCursor::new();
    let mut hits = cursor.matches(&query, tree.root_node(), source.as_bytes());
    let mut calls = Vec::new();
    loop {
        hits.advance();
        let Some(hit) = hits.get() else {
            break;
        };

        for capture in hit.captures.iter().filter(|c| c.index == callee_idx) {
            let node = capture.node;
            let callee_name = source[node.start_byte()..node.end_byte()].to_string();
            let call_line = node.start_position().row + 1;

            // Searching from the back finds the last enclosing function/method.
            let enclosing = symbols.iter().rev().find(|sym| {
                matches!(sym.kind, SymbolKind::Function | SymbolKind::Method)
                    && sym.start_line <= call_line
                    && call_line <= sym.end_line
            });
            let Some(enclosing) = enclosing else {
                continue;
            };
            if enclosing.name == callee_name {
                continue;
            }

            calls.push(RawCall {
                caller_name: enclosing.name.clone(),
                callee_name,
                line: call_line,
            });
        }
    }
    calls
}
}
/// Maps a tree-sitter node kind string to the graph's [`SymbolKind`].
///
/// `impl_item` is reported as `Other("impl")`; any kind not listed here is
/// reported as `Other` carrying the kind string unchanged.
fn classify_symbol_kind(ts_kind: &str) -> SymbolKind {
    match ts_kind {
        // Callables
        "function_item"
        | "function_definition"
        | "function_declaration"
        | "func_literal" => SymbolKind::Function,
        "method_definition"
        | "method_declaration" => SymbolKind::Method,

        // Type-like definitions
        "struct_item"
        | "struct_specifier" => SymbolKind::Struct,
        "class_definition"
        | "class_declaration"
        | "class_specifier" => SymbolKind::Class,
        "trait_item" => SymbolKind::Trait,
        "interface_declaration" => SymbolKind::Interface,
        "enum_item"
        | "enum_declaration"
        | "enum_specifier" => SymbolKind::Enum,
        "type_item"
        | "type_alias_declaration" => SymbolKind::TypeAlias,

        // Values
        "const_item"
        | "const_declaration" => SymbolKind::Constant,
        "let_declaration"
        | "variable_declaration"
        | "static_item" => SymbolKind::Variable,

        // Namespacing and imports
        "mod_item"
        | "module" => SymbolKind::Module,
        "use_declaration"
        | "import_statement"
        | "import_declaration" => SymbolKind::Import,

        // Everything else
        "impl_item" => SymbolKind::Other(String::from("impl")),
        unrecognised => SymbolKind::Other(String::from(unrecognised)),
    }
}
/// Decides whether `project_dir` is a sensible target for indexing.
///
/// A directory carrying a project marker is always accepted. Without a
/// marker, the directory is rejected when it is the filesystem root or the
/// home directory, or when it is an umbrella directory holding several
/// projects; anything else is accepted.
pub fn should_index(project_dir: &Path) -> bool {
    // Short-circuiting keeps the checks in the same order: marker first, then
    // home/root, then the (more expensive) umbrella scan.
    looks_like_project(project_dir)
        || !(is_home_or_root(project_dir) || is_umbrella_dir(project_dir))
}
/// Returns `true` when `path` is exactly `/` or exactly the directory
/// reported by `crate::tool::real_home_dir()`.
fn is_home_or_root(path: &Path) -> bool {
    let is_root = path == Path::new("/");
    if is_root {
        return true;
    }
    match crate::tool::real_home_dir() {
        Some(home) => path == home.as_path(),
        None => false,
    }
}
/// Returns `true` when `dir` directly contains at least three subdirectories
/// that look like projects.
///
/// Only the first 200 readable directory entries are examined, and the scan
/// stops as soon as the third project child is found. An unreadable `dir`
/// counts as not being an umbrella.
fn is_umbrella_dir(dir: &Path) -> bool {
    const MAX_ENTRIES_SCANNED: usize = 200;
    const PROJECT_CHILD_THRESHOLD: usize = 3;

    match std::fs::read_dir(dir) {
        Err(_) => false,
        Ok(entries) => entries
            .flatten()
            .take(MAX_ENTRIES_SCANNED)
            .map(|entry| entry.path())
            .filter(|child| child.is_dir() && looks_like_project(child))
            // The threshold is reached exactly when a third match exists.
            .nth(PROJECT_CHILD_THRESHOLD - 1)
            .is_some(),
    }
}
/// Returns `true` when `dir` directly contains at least one well-known
/// project marker (a VCS directory or a build manifest).
fn looks_like_project(dir: &Path) -> bool {
    const MARKERS: &[&str] = &[
        ".git",
        "Cargo.toml",
        "package.json",
        "pyproject.toml",
        "go.mod",
        "pom.xml",
        "build.gradle",
        "build.gradle.kts",
    ];

    for marker in MARKERS {
        if dir.join(marker).exists() {
            return true;
        }
    }
    false
}
/// Walks `project_dir` and lists every indexable file with its mtime.
///
/// The walk is built with `hidden(true)` and `git_ignore(true)`. Entries that
/// fail to read, are not regular files, or whose extension is not in
/// `INDEXED_EXTENSIONS` are skipped. The mtime is whole seconds since the
/// Unix epoch, or `0` when it cannot be determined.
///
/// This does blocking filesystem I/O.
fn collect_files_sync(project_dir: &Path) -> Vec<(PathBuf, u64)> {
    use std::time::UNIX_EPOCH;

    WalkBuilder::new(project_dir)
        .hidden(true)
        .git_ignore(true)
        .build()
        .filter_map(Result::ok)
        .filter(|entry| entry.path().is_file())
        .filter(|entry| {
            entry
                .path()
                .extension()
                .and_then(|ext| ext.to_str())
                .map_or(false, |ext| INDEXED_EXTENSIONS.contains(&ext))
        })
        .map(|entry| {
            let modified_secs = entry
                .metadata()
                .ok()
                .and_then(|meta| meta.modified().ok())
                .and_then(|stamp| stamp.duration_since(UNIX_EPOCH).ok())
                .map_or(0, |since_epoch| since_epoch.as_secs());
            (entry.path().to_path_buf(), modified_secs)
        })
        .collect()
}
// Unit tests for the `should_index` guard and for cancellation handling in
// `GraphIndexer::index_all`. Each test works inside its own temporary directory.
#[cfg(test)]
mod tests {
use super::*;
// Creates the directory `parent/name` and writes an empty file for each entry
// in `markers`. Note that a `.git` marker is therefore created as a file,
// which is sufficient because marker detection only checks for existence.
fn mk(parent: &Path, name: &str, markers: &[&str]) {
let p = parent.join(name);
std::fs::create_dir_all(&p).unwrap();
for m in markers {
std::fs::write(p.join(m), "").unwrap();
}
}
// A directory that carries its own marker is always indexable.
#[test]
fn should_index_accepts_marked_project() {
let tmp = tempfile::TempDir::new().unwrap();
std::fs::write(tmp.path().join("Cargo.toml"), "[package]").unwrap();
assert!(should_index(tmp.path()));
}
// Four project children and no marker of its own: treated as an umbrella.
#[test]
fn should_index_refuses_umbrella_dir_with_many_child_projects() {
let tmp = tempfile::TempDir::new().unwrap();
mk(tmp.path(), "a", &[".git"]);
mk(tmp.path(), "b", &[".git"]);
mk(tmp.path(), "c", &["package.json"]);
mk(tmp.path(), "d", &["Cargo.toml"]);
assert!(
!should_index(tmp.path()),
"umbrella of 4 projects without own marker must be skipped"
);
}
// The marker check runs before the umbrella check, so a marker in the parent
// wins even when three children look like projects.
#[test]
fn should_index_accepts_umbrella_with_real_marker() {
let tmp = tempfile::TempDir::new().unwrap();
std::fs::write(tmp.path().join("Cargo.toml"), "[workspace]").unwrap();
mk(tmp.path(), "a", &[".git"]);
mk(tmp.path(), "b", &[".git"]);
mk(tmp.path(), "c", &[".git"]);
assert!(
should_index(tmp.path()),
"user-placed marker must override umbrella detection"
);
}
// `.atomcode` is not one of the recognised markers, so its presence must not
// turn an umbrella directory into an indexable project.
#[test]
fn should_index_refuses_umbrella_with_only_atomcode_storage_dir() {
let tmp = tempfile::TempDir::new().unwrap();
std::fs::create_dir_all(tmp.path().join(".atomcode")).unwrap();
std::fs::write(tmp.path().join(".atomcode").join("graph.bin"), b"x").unwrap();
mk(tmp.path(), "a", &[".git"]);
mk(tmp.path(), "b", &[".git"]);
mk(tmp.path(), "c", &[".git"]);
assert!(
!should_index(tmp.path()),
".atomcode dir must not rescue an umbrella from the guard"
);
}
// Two project children plus one plain directory stays below the threshold
// of three, so the parent is still indexable.
#[test]
fn should_index_accepts_dir_with_fewer_than_3_child_projects() {
let tmp = tempfile::TempDir::new().unwrap();
mk(tmp.path(), "a", &[".git"]);
mk(tmp.path(), "b", &[".git"]);
mk(tmp.path(), "other", &[]); assert!(
should_index(tmp.path()),
"2 child projects < umbrella threshold"
);
}
// The token is cancelled before `index_all` is called; the directory holds a
// parseable Rust file, yet the graph must end up with no recorded mtimes.
#[tokio::test]
async fn index_all_bails_on_cancelled_token() {
let tmp = tempfile::TempDir::new().unwrap();
// Here `.atomcode` is written as a plain file; it is not a project marker.
std::fs::write(tmp.path().join(".atomcode"), "").unwrap();
std::fs::write(
tmp.path().join("lib.rs"),
"pub fn foo() {}\npub fn bar() {}\n",
)
.unwrap();
let graph = Arc::new(RwLock::new(super::super::CodeGraph::default()));
let mut indexer = GraphIndexer::new(graph.clone(), tmp.path().to_path_buf());
let cancel = CancellationToken::new();
cancel.cancel();
indexer.index_all(cancel).await;
let g = graph.read().await;
assert!(
g.file_mtimes.is_empty(),
"cancelled indexer must not mutate graph"
);
}
}