use anyhow::Result;
use sqlitegraph::{
BackendDirection, EdgeSpec, GraphBackend, NeighborQuery, NodeId, NodeSpec, SnapshotId,
};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use crate::graph::schema::CallNode;
use crate::ingest::c::CParser;
use crate::ingest::cpp::CppParser;
use crate::ingest::java::JavaParser;
use crate::ingest::javascript::JavaScriptParser;
use crate::ingest::pool;
use crate::ingest::python::PythonParser;
use crate::ingest::typescript::TypeScriptParser;
use crate::ingest::{detect::Language, detect_language, Parser, SymbolFact, SymbolKind};
use crate::references::CallFact;
/// Reads and writes call-site (`Call`) entities and the `CALLER` / `CALLS`
/// edges that connect them to symbol nodes in the graph.
pub struct CallOps {
    /// Graph storage that every operation on this type reads from and writes to.
    pub backend: Arc<dyn GraphBackend>,
}
impl CallOps {
pub fn delete_calls_in_file(&self, path: &str) -> Result<usize> {
let entity_ids = self.backend.entity_ids()?;
let mut to_delete: Vec<i64> = Vec::new();
for entity_id in entity_ids {
let snapshot = SnapshotId::current();
let node = match self.backend.get_node(snapshot, entity_id) {
Ok(n) => n,
Err(_) => continue,
};
if node.kind != "Call" {
continue;
}
let call_node: CallNode = match serde_json::from_value(node.data) {
Ok(value) => value,
Err(_) => continue,
};
if call_node.file == path {
to_delete.push(entity_id);
}
}
to_delete.sort_unstable();
for id in &to_delete {
self.backend.delete_entity(*id)?;
}
Ok(to_delete.len())
}
pub fn index_calls(
&self,
path: &str,
source: &[u8],
symbol_ids: &HashMap<String, i64>,
) -> Result<usize> {
let path_buf = PathBuf::from(path);
let language = detect_language(&path_buf);
let mut symbol_facts = Vec::new();
let mut current_file_facts = Vec::new();
let mut stable_symbol_ids: HashMap<(String, String), Option<String>> = HashMap::new();
for symbol_id in symbol_ids.values() {
let snapshot = SnapshotId::current();
let node = match self.backend.get_node(snapshot, *symbol_id) {
Ok(value) => value,
Err(_) => continue,
};
if node.kind != "Symbol" {
continue;
}
let symbol_node: Option<crate::graph::schema::SymbolNode> =
serde_json::from_value(node.data.clone()).ok();
let stable_id = symbol_node.as_ref().and_then(|n| n.symbol_id.clone());
let symbol_fact = match self.symbol_fact_from_node(&node) {
Some(value) => value,
None => continue,
};
if let Some(ref name) = symbol_fact.name {
let key = (
symbol_fact.file_path.to_string_lossy().to_string(),
name.clone(),
);
stable_symbol_ids.insert(key, stable_id);
}
if node.file_path.as_deref() == Some(path) {
current_file_facts.push(symbol_fact);
} else {
symbol_facts.push(symbol_fact);
}
}
symbol_facts.extend(current_file_facts);
let calls = match language {
Some(Language::Rust) => pool::with_parser_opt(Language::Rust, |opt_parser| {
let parser = opt_parser
.take()
.expect("Parser pool corruption: parser was None"); let mut wrapper = Parser::from_parser(parser);
let result = wrapper.extract_calls(path_buf.clone(), source, &symbol_facts);
*opt_parser = Some(wrapper.parser);
result
})?,
Some(Language::Python) => pool::with_parser_opt(Language::Python, |opt_parser| {
let parser = opt_parser
.take()
.expect("Parser pool corruption: parser was None"); let mut wrapper = PythonParser::from_parser(parser);
let result = wrapper.extract_calls(path_buf.clone(), source, &symbol_facts);
*opt_parser = Some(wrapper.parser);
result
})?,
Some(Language::C) => pool::with_parser_opt(Language::C, |opt_parser| {
let parser = opt_parser
.take()
.expect("Parser pool corruption: parser was None"); let mut wrapper = CParser::from_parser(parser);
let result = wrapper.extract_calls(path_buf.clone(), source, &symbol_facts);
*opt_parser = Some(wrapper.parser);
result
})?,
Some(Language::Cpp) => pool::with_parser_opt(Language::Cpp, |opt_parser| {
let parser = opt_parser
.take()
.expect("Parser pool corruption: parser was None"); let mut wrapper = CppParser::from_parser(parser);
let result = wrapper.extract_calls(path_buf.clone(), source, &symbol_facts);
*opt_parser = Some(wrapper.parser);
result
})?,
Some(Language::Java) => pool::with_parser_opt(Language::Java, |opt_parser| {
let parser = opt_parser
.take()
.expect("Parser pool corruption: parser was None"); let mut wrapper = JavaParser::from_parser(parser);
let result = wrapper.extract_calls(path_buf.clone(), source, &symbol_facts);
*opt_parser = Some(wrapper.parser);
result
})?,
Some(Language::JavaScript) => {
pool::with_parser_opt(Language::JavaScript, |opt_parser| {
let parser = opt_parser
.take()
.expect("Parser pool corruption: parser was None"); let mut wrapper = JavaScriptParser::from_parser(parser);
let result = wrapper.extract_calls(path_buf.clone(), source, &symbol_facts);
*opt_parser = Some(wrapper.parser);
result
})?
}
Some(Language::TypeScript) => {
pool::with_parser_opt(Language::TypeScript, |opt_parser| {
let parser = opt_parser
.take()
.expect("Parser pool corruption: parser was None"); let mut wrapper = TypeScriptParser::from_parser(parser);
let result = wrapper.extract_calls(path_buf.clone(), source, &symbol_facts);
*opt_parser = Some(wrapper.parser);
result
})?
}
None => Vec::new(),
};
let call_count = calls.len();
let mut name_to_ids: HashMap<String, Vec<i64>> = HashMap::new();
for (fqn, &id) in symbol_ids.iter() {
let simple_name = fqn.split("::").last().unwrap_or(fqn.as_str());
let simple_name = simple_name.split('.').next_back().unwrap_or(simple_name);
name_to_ids
.entry(simple_name.to_string())
.or_default()
.push(id);
}
for mut call in calls {
let caller_key = (
call.file_path.to_string_lossy().to_string(),
call.caller.clone(),
);
let callee_key = (
call.file_path.to_string_lossy().to_string(),
call.callee.clone(),
);
call.caller_symbol_id = stable_symbol_ids.get(&caller_key).and_then(|id| id.clone());
call.callee_symbol_id = stable_symbol_ids.get(&callee_key).and_then(|id| id.clone());
let callee_symbol_id = symbol_ids.get(&call.callee).or_else(|| {
name_to_ids.get(&call.callee).and_then(|ids| ids.first())
});
let caller_symbol_id = symbol_ids.get(&call.caller);
let call_id = self.insert_call_node(&call)?;
if let Some(&caller_id) = caller_symbol_id {
self.insert_caller_edge(NodeId::from(caller_id), call_id)?;
}
if let Some(&callee_id) = callee_symbol_id {
self.insert_calls_edge(call_id, NodeId::from(callee_id))?;
}
}
Ok(call_count)
}
pub fn calls_from_symbol(&mut self, symbol_id: i64) -> Result<Vec<CallFact>> {
let snapshot = SnapshotId::current();
let neighbor_ids = self.backend.neighbors(
snapshot,
symbol_id,
NeighborQuery {
direction: BackendDirection::Outgoing,
edge_type: Some("CALLER".to_string()),
},
)?;
let mut calls = Vec::new();
for call_node_id in neighbor_ids {
if let Ok(Some(call)) = self.call_fact_from_node(call_node_id) {
calls.push(call);
}
}
Ok(calls)
}
pub fn callers_of_symbol(&mut self, symbol_id: i64) -> Result<Vec<CallFact>> {
let snapshot = SnapshotId::current();
let neighbor_ids = self.backend.neighbors(
snapshot,
symbol_id,
NeighborQuery {
direction: BackendDirection::Incoming,
edge_type: Some("CALLS".to_string()),
},
)?;
let mut calls = Vec::new();
for call_node_id in neighbor_ids {
if let Ok(Some(call)) = self.call_fact_from_node(call_node_id) {
calls.push(call);
}
}
Ok(calls)
}
fn insert_call_node(&self, call: &CallFact) -> Result<NodeId> {
let call_node = CallNode {
file: call.file_path.to_string_lossy().to_string(),
caller: call.caller.clone(),
callee: call.callee.clone(),
caller_symbol_id: call.caller_symbol_id.clone(),
callee_symbol_id: call.callee_symbol_id.clone(),
byte_start: call.byte_start as u64,
byte_end: call.byte_end as u64,
start_line: call.start_line as u64,
start_col: call.start_col as u64,
end_line: call.end_line as u64,
end_col: call.end_col as u64,
};
let node_spec = NodeSpec {
kind: "Call".to_string(),
name: format!("{} calls {}", call.caller, call.callee),
file_path: Some(call.file_path.to_string_lossy().to_string()),
data: serde_json::to_value(call_node)?,
};
let id = self.backend.insert_node(node_spec)?;
Ok(NodeId::from(id))
}
fn insert_calls_edge(&self, call_id: NodeId, callee_id: NodeId) -> Result<()> {
let edge_spec = EdgeSpec {
from: call_id.as_i64(),
to: callee_id.as_i64(),
edge_type: "CALLS".to_string(),
data: serde_json::json!({}),
};
self.backend.insert_edge(edge_spec)?;
Ok(())
}
fn insert_caller_edge(&self, caller_id: NodeId, call_id: NodeId) -> Result<()> {
let edge_spec = EdgeSpec {
from: caller_id.as_i64(),
to: call_id.as_i64(),
edge_type: "CALLER".to_string(),
data: serde_json::json!({}),
};
self.backend.insert_edge(edge_spec)?;
Ok(())
}
fn call_fact_from_node(&self, node_id: i64) -> Result<Option<CallFact>> {
let snapshot = SnapshotId::current();
let node = self.backend.get_node(snapshot, node_id)?;
let call_node: Option<CallNode> = serde_json::from_value(node.data).ok();
let call_node = match call_node {
Some(n) => n,
None => return Ok(None),
};
Ok(Some(CallFact {
file_path: PathBuf::from(&call_node.file),
caller: call_node.caller,
callee: call_node.callee,
caller_symbol_id: call_node.caller_symbol_id,
callee_symbol_id: call_node.callee_symbol_id,
byte_start: call_node.byte_start as usize,
byte_end: call_node.byte_end as usize,
start_line: call_node.start_line as usize,
start_col: call_node.start_col as usize,
end_line: call_node.end_line as usize,
end_col: call_node.end_col as usize,
}))
}
fn symbol_fact_from_node(&self, node: &sqlitegraph::GraphEntity) -> Option<SymbolFact> {
let symbol_node: crate::graph::schema::SymbolNode =
serde_json::from_value(node.data.clone()).ok()?;
let file_path = node.file_path.as_deref()?;
let kind = match symbol_node.kind.as_str() {
"Function" => SymbolKind::Function,
"Method" => SymbolKind::Method,
"Class" => SymbolKind::Class,
"Interface" => SymbolKind::Interface,
"Enum" => SymbolKind::Enum,
"Module" => SymbolKind::Module,
"Union" => SymbolKind::Union,
"Namespace" => SymbolKind::Namespace,
"TypeAlias" => SymbolKind::TypeAlias,
"Unknown" => SymbolKind::Unknown,
_ => SymbolKind::Unknown,
};
let normalized_kind = symbol_node
.kind_normalized
.clone()
.unwrap_or_else(|| kind.normalized_key().to_string());
Some(SymbolFact {
file_path: PathBuf::from(file_path),
kind,
kind_normalized: normalized_kind,
name: symbol_node.name.clone(),
fqn: symbol_node.name,
canonical_fqn: None,
display_fqn: None,
byte_start: symbol_node.byte_start,
byte_end: symbol_node.byte_end,
start_line: symbol_node.start_line,
start_col: symbol_node.start_col,
end_line: symbol_node.end_line,
end_col: symbol_node.end_col,
})
}
}