use anyhow::Result;
use sqlitegraph::{
BackendDirection, EdgeSpec, GraphBackend, NeighborQuery, NodeId, NodeSpec, SnapshotId,
};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use crate::graph::schema::ReferenceNode;
use crate::ingest::c::CParser;
use crate::ingest::cpp::CppParser;
use crate::ingest::java::JavaParser;
use crate::ingest::javascript::JavaScriptParser;
use crate::ingest::pool;
use crate::ingest::python::PythonParser;
use crate::ingest::typescript::TypeScriptParser;
use crate::ingest::{detect::Language, detect_language, Parser};
use crate::references::ReferenceFact;
/// Reference-indexing operations layered on top of a shared graph backend.
///
/// The methods on this type create, look up and delete nodes of kind
/// `"Reference"` together with the `"REFERENCES"` edges that link them to
/// the nodes of the symbols they refer to.
pub struct ReferenceOps {
/// Shared handle to the graph storage used for every node and edge access.
pub backend: Arc<dyn GraphBackend>,
}
impl ReferenceOps {
/// Deletes every `"Reference"` node whose recorded file equals `path`.
///
/// All entity ids known to the backend are scanned. An entity is skipped
/// when it cannot be loaded, when its kind is not `"Reference"`, or when
/// its payload does not deserialize into a [`ReferenceNode`].
///
/// Returns the number of nodes that were deleted.
///
/// # Errors
///
/// Fails when the entity ids cannot be listed or when a deletion fails.
/// Deletions that already happened before the failing one are not undone
/// by this method.
pub fn delete_references_in_file(&self, path: &str) -> Result<usize> {
    let candidates = self.backend.entity_ids()?;
    let snapshot = SnapshotId::current();

    let mut doomed: Vec<i64> = candidates
        .into_iter()
        .filter(|&id| {
            let Ok(node) = self.backend.get_node(snapshot, id) else {
                return false;
            };
            if node.kind != "Reference" {
                return false;
            }
            serde_json::from_value::<ReferenceNode>(node.data)
                .is_ok_and(|stored| stored.file == path)
        })
        .collect();

    // Ascending id order keeps the sequence of backend deletions deterministic.
    doomed.sort_unstable();
    for &id in &doomed {
        self.backend.delete_entity(id)?;
    }
    Ok(doomed.len())
}
/// Extracts the references contained in `source` and stores the ones whose
/// target is present in `symbol_fqn_to_id`.
///
/// The language is detected from `path`. Every `"Symbol"` node currently in
/// the graph is turned into a `SymbolFact` (with `canonical_fqn` and
/// `display_fqn` left unset) and handed to a freshly constructed
/// language-specific parser. For each extracted reference whose
/// `referenced_symbol` is a key of `symbol_fqn_to_id`, a `"Reference"` node
/// and a `"REFERENCES"` edge to the mapped symbol id are inserted.
///
/// # Returns
///
/// The number of references the parser extracted. This can be larger than
/// the number actually stored, because unresolved references are skipped
/// but still counted. `Ok(0)` is returned when no language is detected for
/// `path` or when the backend cannot list its entity ids.
///
/// # Errors
///
/// Fails when the parser cannot be constructed or when inserting a node or
/// an edge fails.
#[expect(dead_code)]
pub fn index_references(
    &self,
    path: &str,
    source: &[u8],
    symbol_fqn_to_id: &HashMap<String, i64>,
) -> Result<usize> {
    let path_buf = PathBuf::from(path);
    // Without a detected language nothing can be extracted, so skip the
    // full graph scan below; the result is `Ok(0)` either way.
    let Some(language) = detect_language(&path_buf) else {
        return Ok(0);
    };

    // A backend that cannot list its entities is treated as "nothing indexed".
    let Ok(entity_ids) = self.backend.entity_ids() else {
        return Ok(0);
    };
    let snapshot = SnapshotId::current();

    let mut all_symbol_facts: Vec<crate::ingest::SymbolFact> = Vec::new();
    for entity_id in entity_ids {
        let Ok(node) = self.backend.get_node(snapshot, entity_id) else {
            continue;
        };
        if node.kind != "Symbol" {
            continue;
        }
        // The payload is moved, not cloned: it is not needed after decoding.
        let Ok(symbol_node) =
            serde_json::from_value::<crate::graph::schema::SymbolNode>(node.data)
        else {
            continue;
        };
        // Symbols without a name are not offered to the parser.
        let Some(name) = symbol_node.name else {
            continue;
        };

        let kind = match symbol_node.kind_normalized.as_deref() {
            Some("fn") => crate::ingest::SymbolKind::Function,
            Some("method") => crate::ingest::SymbolKind::Method,
            Some("struct") => crate::ingest::SymbolKind::Class,
            Some("enum") => crate::ingest::SymbolKind::Enum,
            Some("trait") => crate::ingest::SymbolKind::Interface,
            Some("mod") => crate::ingest::SymbolKind::Module,
            _ => crate::ingest::SymbolKind::Unknown,
        };
        // The fully-qualified name falls back to the plain name.
        let fqn = symbol_node.fqn.unwrap_or_else(|| name.clone());

        all_symbol_facts.push(crate::ingest::SymbolFact {
            file_path: PathBuf::from(node.file_path.as_deref().unwrap_or("")),
            kind,
            kind_normalized: symbol_node.kind_normalized.unwrap_or(symbol_node.kind),
            name: Some(name),
            // An empty qualified name is stored as "absent".
            fqn: Some(fqn).filter(|candidate| !candidate.is_empty()),
            canonical_fqn: None,
            display_fqn: None,
            byte_start: symbol_node.byte_start,
            byte_end: symbol_node.byte_end,
            start_line: symbol_node.start_line,
            start_col: symbol_node.start_col,
            end_line: symbol_node.end_line,
            end_col: symbol_node.end_col,
        });
    }

    // Exactly one arm runs and `path_buf` is not used afterwards, so it is
    // moved into the parser call instead of being cloned.
    let references = match language {
        Language::Rust => {
            let mut parser = Parser::new()?;
            parser.extract_references(path_buf, source, &all_symbol_facts)
        }
        Language::Python => {
            let mut parser = PythonParser::new()?;
            parser.extract_references(path_buf, source, &all_symbol_facts)
        }
        Language::C => {
            let mut parser = CParser::new()?;
            parser.extract_references(path_buf, source, &all_symbol_facts)
        }
        Language::Cpp => {
            let mut parser = CppParser::new()?;
            parser.extract_references(path_buf, source, &all_symbol_facts)
        }
        Language::Java => {
            let mut parser = JavaParser::new()?;
            parser.extract_references(path_buf, source, &all_symbol_facts)
        }
        Language::JavaScript => {
            let mut parser = JavaScriptParser::new()?;
            parser.extract_references(path_buf, source, &all_symbol_facts)
        }
        Language::TypeScript => {
            let mut parser = TypeScriptParser::new()?;
            parser.extract_references(path_buf, source, &all_symbol_facts)
        }
    };

    for reference in &references {
        // References to symbols outside the supplied map are not stored.
        let Some(&target_symbol_id) = symbol_fqn_to_id.get(&reference.referenced_symbol) else {
            continue;
        };
        let reference_id = self.insert_reference_node(reference)?;
        self.insert_references_edge(reference_id, NodeId::from(target_symbol_id), reference)?;
    }
    Ok(references.len())
}
pub fn index_references_with_symbol_id(
&self,
path: &str,
source: &[u8],
symbol_id_to_id: &HashMap<String, i64>,
fqn_to_id: &HashMap<String, i64>,
) -> Result<usize> {
let path_buf = PathBuf::from(path);
let language = detect_language(&path_buf);
let mut all_symbol_facts: Vec<crate::ingest::SymbolFact> = Vec::new();
let entity_ids = match self.backend.entity_ids() {
Ok(ids) => ids,
Err(_) => return Ok(0), };
let snapshot = SnapshotId::current();
for entity_id in entity_ids {
if let Ok(node) = self.backend.get_node(snapshot, entity_id) {
if node.kind == "Symbol" {
if let Ok(symbol_node) = serde_json::from_value::<
crate::graph::schema::SymbolNode,
>(node.data.clone())
{
if let Some(name) = &symbol_node.name {
let file_path_str = node.file_path.as_deref().unwrap_or("");
if std::str::from_utf8(file_path_str.as_bytes()).is_ok() {
let fqn = symbol_node
.fqn
.clone()
.or(symbol_node.name.clone())
.unwrap_or_default();
all_symbol_facts.push(crate::ingest::SymbolFact {
file_path: PathBuf::from(file_path_str),
kind: match symbol_node.kind_normalized.as_deref() {
Some("fn") => crate::ingest::SymbolKind::Function,
Some("method") => crate::ingest::SymbolKind::Method,
Some("struct") => crate::ingest::SymbolKind::Class,
Some("enum") => crate::ingest::SymbolKind::Enum,
Some("trait") => crate::ingest::SymbolKind::Interface,
Some("mod") => crate::ingest::SymbolKind::Module,
_ => crate::ingest::SymbolKind::Unknown,
},
kind_normalized: symbol_node
.kind_normalized
.clone()
.unwrap_or(symbol_node.kind.clone()),
name: Some(name.clone()),
fqn: if fqn.is_empty() { None } else { Some(fqn) },
canonical_fqn: symbol_node.canonical_fqn.clone(),
display_fqn: symbol_node.display_fqn.clone(),
byte_start: symbol_node.byte_start,
byte_end: symbol_node.byte_end,
start_line: symbol_node.start_line,
start_col: symbol_node.start_col,
end_line: symbol_node.end_line,
end_col: symbol_node.end_col,
});
}
}
}
}
}
}
let references = match language {
Some(Language::Rust) => pool::with_parser_opt(Language::Rust, |opt_parser| {
let parser = opt_parser
.take()
.expect("Parser pool corruption: parser was None"); let mut wrapper = Parser::from_parser(parser);
let result =
wrapper.extract_references(path_buf.clone(), source, &all_symbol_facts);
*opt_parser = Some(wrapper.parser);
result
})?,
Some(Language::Python) => pool::with_parser_opt(Language::Python, |opt_parser| {
let parser = opt_parser
.take()
.expect("Parser pool corruption: parser was None"); let mut wrapper = PythonParser::from_parser(parser);
let result =
wrapper.extract_references(path_buf.clone(), source, &all_symbol_facts);
*opt_parser = Some(wrapper.parser);
result
})?,
Some(Language::C) => pool::with_parser_opt(Language::C, |opt_parser| {
let parser = opt_parser
.take()
.expect("Parser pool corruption: parser was None"); let mut wrapper = CParser::from_parser(parser);
let result =
wrapper.extract_references(path_buf.clone(), source, &all_symbol_facts);
*opt_parser = Some(wrapper.parser);
result
})?,
Some(Language::Cpp) => pool::with_parser_opt(Language::Cpp, |opt_parser| {
let parser = opt_parser
.take()
.expect("Parser pool corruption: parser was None"); let mut wrapper = CppParser::from_parser(parser);
let result =
wrapper.extract_references(path_buf.clone(), source, &all_symbol_facts);
*opt_parser = Some(wrapper.parser);
result
})?,
Some(Language::Java) => pool::with_parser_opt(Language::Java, |opt_parser| {
let parser = opt_parser
.take()
.expect("Parser pool corruption: parser was None"); let mut wrapper = JavaParser::from_parser(parser);
let result =
wrapper.extract_references(path_buf.clone(), source, &all_symbol_facts);
*opt_parser = Some(wrapper.parser);
result
})?,
Some(Language::JavaScript) => {
pool::with_parser_opt(Language::JavaScript, |opt_parser| {
let parser = opt_parser
.take()
.expect("Parser pool corruption: parser was None"); let mut wrapper = JavaScriptParser::from_parser(parser);
let result =
wrapper.extract_references(path_buf.clone(), source, &all_symbol_facts);
*opt_parser = Some(wrapper.parser);
result
})?
}
Some(Language::TypeScript) => {
pool::with_parser_opt(Language::TypeScript, |opt_parser| {
let parser = opt_parser
.take()
.expect("Parser pool corruption: parser was None"); let mut wrapper = TypeScriptParser::from_parser(parser);
let result =
wrapper.extract_references(path_buf.clone(), source, &all_symbol_facts);
*opt_parser = Some(wrapper.parser);
result
})?
}
None => Vec::new(),
};
for reference in &references {
let target_symbol_id = symbol_id_to_id
.get(&reference.referenced_symbol)
.or_else(|| fqn_to_id.get(&reference.referenced_symbol));
if let Some(&target_symbol_id) = target_symbol_id {
let reference_id = self.insert_reference_node(reference)?;
self.insert_references_edge(
reference_id,
NodeId::from(target_symbol_id),
reference,
)?;
}
}
Ok(references.len())
}
/// Returns the stored references that point at the node `symbol_id`.
///
/// The backend is asked for the incoming `"REFERENCES"` neighbors of
/// `symbol_id`; each neighbor is then decoded into a [`ReferenceFact`].
/// Neighbors that fail to load or whose payload is not a reference are
/// left out of the result without reporting an error.
///
/// # Errors
///
/// Fails only when the neighbor query itself fails.
pub fn references_to_symbol(&mut self, symbol_id: i64) -> Result<Vec<ReferenceFact>> {
    let snapshot = SnapshotId::current();
    let incoming_references = NeighborQuery {
        direction: BackendDirection::Incoming,
        edge_type: Some("REFERENCES".to_string()),
    };
    let referrers = self
        .backend
        .neighbors(snapshot, symbol_id, incoming_references)?;

    Ok(referrers
        .into_iter()
        .filter_map(|referrer| self.reference_fact_from_node(referrer).ok().flatten())
        .collect())
}
/// Inserts a `"Reference"` node describing `reference` and returns its id.
///
/// The node is named `"ref to <referenced_symbol>"`, carries the lossy
/// string form of the reference's file path, and stores the byte and
/// line/column span as a serialized [`ReferenceNode`].
///
/// # Errors
///
/// Fails when the payload cannot be serialized or the backend rejects the
/// insertion.
fn insert_reference_node(&self, reference: &ReferenceFact) -> Result<NodeId> {
    // The same path string is stored both in the payload and on the node.
    let file = reference.file_path.to_string_lossy().into_owned();

    let payload = ReferenceNode {
        file: file.clone(),
        byte_start: reference.byte_start as u64,
        byte_end: reference.byte_end as u64,
        start_line: reference.start_line as u64,
        start_col: reference.start_col as u64,
        end_line: reference.end_line as u64,
        end_col: reference.end_col as u64,
    };

    let spec = NodeSpec {
        kind: "Reference".to_string(),
        name: format!("ref to {}", reference.referenced_symbol),
        file_path: Some(file),
        data: serde_json::to_value(payload)?,
    };

    let raw_id = self.backend.insert_node(spec)?;
    Ok(NodeId::from(raw_id))
}
/// Inserts a `"REFERENCES"` edge from the reference node to the symbol node.
///
/// The edge payload repeats the byte and line/column span of `reference`.
///
/// # Errors
///
/// Fails when the backend rejects the edge.
fn insert_references_edge(
    &self,
    reference_id: NodeId,
    symbol_id: NodeId,
    reference: &ReferenceFact,
) -> Result<()> {
    let span = serde_json::json!({
        "byte_start": reference.byte_start,
        "byte_end": reference.byte_end,
        "start_line": reference.start_line,
        "start_col": reference.start_col,
        "end_line": reference.end_line,
        "end_col": reference.end_col,
    });

    self.backend.insert_edge(EdgeSpec {
        from: reference_id.as_i64(),
        to: symbol_id.as_i64(),
        edge_type: "REFERENCES".to_string(),
        data: span,
    })?;
    Ok(())
}
/// Loads node `node_id` and converts it into a [`ReferenceFact`].
///
/// Returns `Ok(None)` when the node's payload does not deserialize into a
/// [`ReferenceNode`]. The referenced symbol is recovered from the node name
/// by removing the `"ref to "` prefix; a name without that prefix yields an
/// empty symbol string.
///
/// # Errors
///
/// Fails when the node cannot be loaded from the backend.
fn reference_fact_from_node(&self, node_id: i64) -> Result<Option<ReferenceFact>> {
    let node = self.backend.get_node(SnapshotId::current(), node_id)?;

    let Ok(stored) = serde_json::from_value::<ReferenceNode>(node.data) else {
        return Ok(None);
    };

    let referenced_symbol = node
        .name
        .strip_prefix("ref to ")
        .unwrap_or_default()
        .to_owned();

    let fact = ReferenceFact {
        file_path: PathBuf::from(stored.file),
        referenced_symbol,
        byte_start: stored.byte_start as usize,
        byte_end: stored.byte_end as usize,
        start_line: stored.start_line as usize,
        start_col: stored.start_col as usize,
        end_line: stored.end_line as usize,
        end_col: stored.end_col as usize,
    };
    Ok(Some(fact))
}
}
#[cfg(test)]
mod tests {
    use crate::graph::schema::SymbolNode;
    use sqlitegraph::{GraphBackend, SnapshotId};
    use std::collections::HashMap;

    /// Indexes a small Rust file, builds the symbol-id and FQN lookup maps
    /// from the stored `"Symbol"` nodes, and checks that
    /// `index_references_with_symbol_id` reports at least one reference.
    #[test]
    fn test_index_references_with_symbol_id_uses_fqn_fallback() {
        let workspace = tempfile::TempDir::new().unwrap();
        let db_path = workspace.path().join("test.db");
        let mut graph = crate::CodeGraph::open(&db_path).unwrap();

        let rust_file = workspace.path().join("test.rs");
        std::fs::write(
            &rust_file,
            r#"
fn foo() {}
fn bar() {
foo();
}
"#,
        )
        .unwrap();

        let path_str = rust_file.to_string_lossy().to_string();
        let source = std::fs::read(&rust_file).unwrap();
        graph.index_file(&path_str, &source).unwrap();

        // Rebuild both lookup tables from the symbols stored in the graph.
        let mut by_symbol_id: HashMap<String, i64> = HashMap::new();
        let mut by_fqn: HashMap<String, i64> = HashMap::new();
        let stored_ids = graph.files.backend.entity_ids().unwrap();
        let snapshot = SnapshotId::current();
        for entity_id in stored_ids {
            let Ok(node) = graph.files.backend.get_node(snapshot, entity_id) else {
                continue;
            };
            if node.kind != "Symbol" {
                continue;
            }
            let Ok(symbol) = serde_json::from_value::<SymbolNode>(node.data) else {
                continue;
            };
            if let Some(symbol_id) = symbol.symbol_id {
                by_symbol_id.insert(symbol_id, entity_id);
            }
            // The plain name stands in when no qualified name is stored.
            let fqn = symbol.fqn.or(symbol.name).unwrap_or_default();
            if !fqn.is_empty() {
                by_fqn.insert(fqn, entity_id);
            }
        }

        let count = graph
            .references
            .index_references_with_symbol_id(&path_str, &source, &by_symbol_id, &by_fqn)
            .unwrap();
        assert!(count > 0, "Expected at least one reference to be indexed");
    }
}