mod schema;
use anyhow::{Context, Result};
use rusqlite::{params, Connection, OptionalExtension};
use std::{
fs,
io::ErrorKind,
path::{Path, PathBuf},
};
use crate::types::{
Edge, EdgeKind, FileRecord, IndexStats, Language, Node, NodeKind, UnresolvedReference,
Visibility,
};
/// Pragmas executed on every new connection before the schema is initialised.
///
/// Enables foreign-key enforcement, switches the journal to WAL, sets
/// `synchronous` to `NORMAL`, and sets `cache_size` to `-64000` (per the SQLite
/// documentation a negative value is a size in KiB rather than a page count).
const CONNECTION_PRAGMAS: &str = "PRAGMA foreign_keys = ON; \
PRAGMA journal_mode = WAL; \
PRAGMA synchronous = NORMAL; \
PRAGMA cache_size = -64000;";
/// Wrapper around a single SQLite connection holding files, nodes, edges and
/// unresolved references, plus the FTS tables kept alongside the nodes.
pub struct Database {
// Connection every query in this type runs on.
conn: Connection,
// Location of the database file; `None` when opened via `in_memory`.
path: Option<PathBuf>,
}
/// File-level view of one edge, as produced by `Database::get_edge_endpoints`.
#[derive(Debug, Clone)]
pub struct EdgeEndpoint {
/// `file_path` of the node the edge starts from.
pub source_file: String,
/// `file_path` of the node the edge points to.
pub target_file: String,
/// Edge kind; unknown stored values are read back as `EdgeKind::References`.
pub kind: EdgeKind,
/// Optional free-form detail stored with the edge.
pub detail: Option<String>,
}
impl Database {
pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
let path = path.as_ref().to_path_buf();
let conn = Connection::open(&path)?;
Self::from_connection(conn, Some(path))
}
/// Creates a database backed by an in-memory SQLite connection (no path is recorded).
pub fn in_memory() -> Result<Self> {
    Self::from_connection(Connection::open_in_memory()?, None)
}
/// Shared constructor: runs `CONNECTION_PRAGMAS` on `conn`, wraps it together
/// with `path`, and initialises the schema before handing the value back.
fn from_connection(conn: Connection, path: Option<PathBuf>) -> Result<Self> {
    conn.execute_batch(CONNECTION_PRAGMAS)?;
    let database = Self { conn, path };
    database.initialize().map(|()| database)
}
/// Returns the on-disk location of this database, or `None` when it has none.
pub fn path(&self) -> Option<&Path> {
    match &self.path {
        Some(location) => Some(location.as_path()),
        None => None,
    }
}
/// Closes the underlying connection, consuming the database.
///
/// # Errors
/// Returns the error reported by the connection's `close`; the connection
/// handed back alongside it is dropped.
pub fn close(self) -> Result<()> {
    self.conn.close().map_err(|(_conn, err)| err.into())
}
/// Applies `schema::SCHEMA` and then every statement in `schema::MIGRATIONS`.
///
/// A migration that fails with a message containing "duplicate column name"
/// is treated as already applied; any other failure aborts initialisation.
fn initialize(&self) -> Result<()> {
    self.conn.execute_batch(schema::SCHEMA)?;
    for stmt in schema::MIGRATIONS {
        match self.conn.execute(stmt, []) {
            Ok(_) => {}
            Err(e) if e.to_string().contains("duplicate column name") => {}
            Err(e) => return Err(e.into()),
        }
    }
    Ok(())
}
/// Inserts `file` into the `files` table, or overwrites every non-key column of
/// the existing row when a record with the same `path` is already present.
pub fn insert_or_update_file(&self, file: &FileRecord) -> Result<()> {
    const UPSERT_FILE_SQL: &str = r#"
INSERT INTO files (path, content_hash, language, size, modified_at, indexed_at, node_count)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
ON CONFLICT(path) DO UPDATE SET
content_hash = excluded.content_hash,
language = excluded.language,
size = excluded.size,
modified_at = excluded.modified_at,
indexed_at = excluded.indexed_at,
node_count = excluded.node_count
"#;
    // Counters are stored as signed 64-bit integers.
    let size = file.size as i64;
    let node_count = file.node_count as i64;
    self.conn.execute(
        UPSERT_FILE_SQL,
        params![
            file.path,
            file.content_hash,
            file.language.as_str(),
            size,
            file.modified_at,
            file.indexed_at,
            node_count,
        ],
    )?;
    Ok(())
}
/// Looks up the `files` row for `path`.
///
/// Returns `Ok(None)` when no such row exists.
///
/// The language column is written with `Language::as_str()` (see
/// `insert_or_update_file`), so it is decoded here with `Language::parse`, the
/// same decoder used for the `nodes.language` column, rather than with
/// `Language::from_extension`, which takes a file extension, not a stored name.
pub fn get_file(&self, path: &str) -> Result<Option<FileRecord>> {
    let record = self
        .conn
        .query_row(
            "SELECT path, content_hash, language, size, modified_at, indexed_at, node_count FROM files WHERE path = ?1",
            params![path],
            |row| {
                Ok(FileRecord {
                    path: row.get(0)?,
                    content_hash: row.get(1)?,
                    language: Language::parse(&row.get::<_, String>(2)?),
                    size: row.get::<_, i64>(3)? as u64,
                    modified_at: row.get(4)?,
                    indexed_at: row.get(5)?,
                    node_count: row.get::<_, i64>(6)? as u32,
                })
            },
        )
        .optional()?;
    Ok(record)
}
/// Reports whether `path` must be (re)indexed: `true` when the file is unknown
/// or its stored content hash differs from `content_hash`.
pub fn needs_reindex(&self, path: &str, content_hash: &str) -> Result<bool> {
    let needs = self
        .get_file(path)?
        .map_or(true, |file| file.content_hash != content_hash);
    Ok(needs)
}
/// Removes everything recorded for `path`: edges touching its nodes, the FTS
/// entries of those nodes, the nodes themselves, its unresolved references and
/// finally the `files` row.
///
/// The statements run one after another in the order listed; this method does
/// not open a transaction itself.
pub fn delete_file(&self, path: &str) -> Result<()> {
    // Order matters: edge and FTS cleanup select from `nodes`, so they must run
    // before the node rows are deleted.
    const STEPS: [&str; 7] = [
        "DELETE FROM edges WHERE source_id IN (SELECT id FROM nodes WHERE file_path = ?1)",
        "DELETE FROM edges WHERE target_id IN (SELECT id FROM nodes WHERE file_path = ?1)",
        "INSERT INTO nodes_fts(nodes_fts, rowid, name, qualified_name) SELECT 'delete', id, name, qualified_name FROM nodes WHERE file_path = ?1",
        "DELETE FROM nodes_semantic_fts WHERE rowid IN (SELECT id FROM nodes WHERE file_path = ?1)",
        "DELETE FROM nodes WHERE file_path = ?1",
        "DELETE FROM unresolved_refs WHERE file_path = ?1",
        "DELETE FROM files WHERE path = ?1",
    ];
    for sql in &STEPS {
        self.conn.execute(sql, params![path])?;
    }
    Ok(())
}
/// Inserts one node plus its rows in `nodes_fts` and `nodes_semantic_fts`,
/// returning the rowid assigned to the node. `node.id` is not used.
pub fn insert_node(&self, node: &Node) -> Result<i64> {
    const INSERT_NODE_SQL: &str = r#"
INSERT INTO nodes (
kind, name, qualified_name, file_path, start_line, end_line,
start_column, end_column, signature, visibility, docstring,
is_async, is_static, is_exported, is_test, is_generated, language
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17)
"#;
    self.conn.execute(
        INSERT_NODE_SQL,
        params![
            node.kind.as_str(),
            node.name,
            node.qualified_name,
            node.file_path,
            node.start_line as i64,
            node.end_line as i64,
            node.start_column as i64,
            node.end_column as i64,
            node.signature,
            node.visibility.as_str(),
            node.docstring,
            node.is_async,
            node.is_static,
            node.is_exported,
            node.is_test,
            node.is_generated,
            node.language.as_str(),
        ],
    )?;
    let node_id = self.conn.last_insert_rowid();

    // Both FTS tables are keyed by the node's rowid.
    self.conn.execute(
        "INSERT INTO nodes_fts(rowid, name, qualified_name) VALUES (?1, ?2, ?3)",
        params![node_id, node.name, node.qualified_name],
    )?;
    let semantic_tokens = build_semantic_tokens(node);
    self.conn.execute(
        "INSERT INTO nodes_semantic_fts(rowid, tokens) VALUES (?1, ?2)",
        params![node_id, semantic_tokens],
    )?;
    Ok(node_id)
}
/// Fetches the node with the given rowid, or `None` if there is no such node.
pub fn get_node(&self, id: i64) -> Result<Option<Node>> {
    let node = self
        .conn
        .query_row(
            "SELECT * FROM nodes WHERE id = ?1",
            params![id],
            Self::row_to_node,
        )
        .optional()?;
    Ok(node)
}
pub fn search_nodes(
&self,
query: &str,
kind: Option<NodeKind>,
limit: u32,
) -> Result<Vec<Node>> {
let use_fts = query.len() >= 2;
if use_fts {
let fts_query = format!("\"{}\"*", query.to_lowercase());
let sql = if kind.is_some() {
r#"
SELECT n.* FROM nodes n
INNER JOIN nodes_fts fts ON n.id = fts.rowid
WHERE nodes_fts MATCH ?1 AND n.kind = ?2
ORDER BY LENGTH(n.name), n.name
LIMIT ?3
"#
} else {
r#"
SELECT n.* FROM nodes n
INNER JOIN nodes_fts fts ON n.id = fts.rowid
WHERE nodes_fts MATCH ?1
ORDER BY LENGTH(n.name), n.name
LIMIT ?2
"#
};
let result = (|| -> Result<Vec<Node>> {
let mut stmt = self.conn.prepare(sql)?;
let mut nodes = Vec::new();
if let Some(k) = kind {
let rows = stmt.query_map(
params![fts_query, k.as_str(), limit as i64],
Self::row_to_node,
)?;
for row in rows {
nodes.push(row?);
}
} else {
let rows =
stmt.query_map(params![fts_query, limit as i64], Self::row_to_node)?;
for row in rows {
nodes.push(row?);
}
}
Ok(nodes)
})();
if let Ok(nodes) = result {
return Ok(nodes);
}
}
let pattern = format!("{}%", query.to_lowercase());
let sql = if kind.is_some() {
r#"
SELECT * FROM nodes
WHERE LOWER(name) LIKE ?1 AND kind = ?2
ORDER BY LENGTH(name), name
LIMIT ?3
"#
} else {
r#"
SELECT * FROM nodes
WHERE LOWER(name) LIKE ?1
ORDER BY LENGTH(name), name
LIMIT ?2
"#
};
let mut stmt = self.conn.prepare(sql)?;
let mut nodes = Vec::new();
if let Some(k) = kind {
let rows = stmt.query_map(params![pattern, k.as_str(), limit as i64], |row| {
Self::row_to_node(row)
})?;
for row in rows {
nodes.push(row?);
}
} else {
let rows = stmt.query_map(params![pattern, limit as i64], Self::row_to_node)?;
for row in rows {
nodes.push(row?);
}
}
Ok(nodes)
}
pub fn get_nodes_by_file(&self, file_path: &str) -> Result<Vec<Node>> {
let mut stmt = self
.conn
.prepare("SELECT * FROM nodes WHERE file_path = ?1 ORDER BY start_line")?;
let rows = stmt.query_map(params![file_path], Self::row_to_node)?;
let mut nodes = Vec::new();
for row in rows {
nodes.push(row?);
}
Ok(nodes)
}
pub fn get_edge_endpoints(&self) -> Result<Vec<EdgeEndpoint>> {
let mut stmt = self.conn.prepare(
"SELECT s.file_path, t.file_path, e.kind, e.detail \
FROM edges e \
JOIN nodes s ON e.source_id = s.id \
JOIN nodes t ON e.target_id = t.id",
)?;
let rows = stmt.query_map([], |row| {
Ok(EdgeEndpoint {
source_file: row.get(0)?,
target_file: row.get(1)?,
kind: EdgeKind::parse(&row.get::<_, String>(2)?).unwrap_or(EdgeKind::References),
detail: row.get(3)?,
})
})?;
let mut out = Vec::new();
for row in rows {
out.push(row?);
}
Ok(out)
}
pub fn get_nodes_by_kind(&self, kind: NodeKind) -> Result<Vec<Node>> {
let mut stmt = self
.conn
.prepare("SELECT * FROM nodes WHERE kind = ?1 ORDER BY name")?;
let rows = stmt.query_map(params![kind.as_str()], Self::row_to_node)?;
let mut nodes = Vec::new();
for row in rows {
nodes.push(row?);
}
Ok(nodes)
}
pub fn get_struct_fields(&self, struct_name: &str) -> Result<Vec<Node>> {
let mut stmt = self.conn.prepare(
"SELECT t.* FROM nodes t \
JOIN edges e ON e.target_id = t.id AND e.kind = 'contains' \
JOIN nodes s ON e.source_id = s.id \
WHERE s.name = ?1 AND s.kind IN ('struct','class','interface','trait','protocol') \
AND t.kind IN ('field','property')",
)?;
let rows = stmt.query_map(params![struct_name], Self::row_to_node)?;
let mut nodes = Vec::new();
for row in rows {
nodes.push(row?);
}
Ok(nodes)
}
pub fn get_dispatch_sites(&self, enum_name: &str) -> Result<Vec<(String, String)>> {
let mut stmt = self.conn.prepare(
"SELECT DISTINCT e.file_path, m.name \
FROM edges e \
JOIN nodes m ON e.target_id = m.id AND m.kind = 'enum_member' \
JOIN edges c ON c.target_id = m.id AND c.kind = 'contains' \
JOIN nodes en ON c.source_id = en.id AND en.kind = 'enum' \
WHERE e.kind = 'references' AND en.name = ?1 AND e.file_path IS NOT NULL",
)?;
let rows = stmt.query_map(params![enum_name], |row| {
Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
})?;
let mut out = Vec::new();
for row in rows {
out.push(row?);
}
Ok(out)
}
/// Returns one node whose `name` equals `name` exactly, if any. The query has
/// no `ORDER BY`, so which match is returned when several exist is unspecified.
pub fn find_node_by_name(&self, name: &str) -> Result<Option<Node>> {
    const SQL: &str = "SELECT * FROM nodes WHERE name = ?1 LIMIT 1";
    let hit = self
        .conn
        .query_row(SQL, params![name], Self::row_to_node)
        .optional()?;
    Ok(hit)
}
/// Builds a [`Node`] from a `nodes` row, reading columns by name.
///
/// An unrecognised `kind` becomes `NodeKind::Function`. Failures reading
/// `visibility`, `is_test`, `is_generated` or `language` are tolerated and
/// replaced by an empty string / `false`; every other column error propagates.
fn row_to_node(row: &rusqlite::Row) -> rusqlite::Result<Node> {
    let id = row.get("id")?;
    let kind_text: String = row.get("kind")?;
    let kind = NodeKind::parse(&kind_text).unwrap_or(NodeKind::Function);
    // These reads never fail the row: a missing/invalid value degrades to "".
    let visibility_text: String = row.get("visibility").unwrap_or_default();
    let language_text: String = row.get("language").unwrap_or_default();

    Ok(Node {
        id,
        kind,
        name: row.get("name")?,
        qualified_name: row.get("qualified_name")?,
        file_path: row.get("file_path")?,
        start_line: row.get::<_, i64>("start_line")? as u32,
        end_line: row.get::<_, i64>("end_line")? as u32,
        start_column: row.get::<_, i64>("start_column")? as u32,
        end_column: row.get::<_, i64>("end_column")? as u32,
        signature: row.get("signature")?,
        visibility: Visibility::parse(&visibility_text),
        docstring: row.get("docstring")?,
        is_async: row.get("is_async")?,
        is_static: row.get("is_static")?,
        is_exported: row.get("is_exported")?,
        is_test: row.get("is_test").unwrap_or(false),
        is_generated: row.get("is_generated").unwrap_or(false),
        language: Language::parse(&language_text),
    })
}
/// Inserts `edge` into the `edges` table and returns the rowid it was given.
/// `edge.id` is not used.
pub fn insert_edge(&self, edge: &Edge) -> Result<i64> {
    const INSERT_EDGE_SQL: &str = r#"
INSERT INTO edges (source_id, target_id, kind, file_path, line, column, detail)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
"#;
    let line = edge.line.map(|l| l as i64);
    let column = edge.column.map(|c| c as i64);
    self.conn.execute(
        INSERT_EDGE_SQL,
        params![
            edge.source_id,
            edge.target_id,
            edge.kind.as_str(),
            edge.file_path,
            line,
            column,
            edge.detail,
        ],
    )?;
    let edge_id = self.conn.last_insert_rowid();
    Ok(edge_id)
}
pub fn get_callers(&self, node_id: i64, limit: u32) -> Result<Vec<Node>> {
let mut stmt = self.conn.prepare(
r#"
SELECT n.* FROM nodes n
INNER JOIN edges e ON e.source_id = n.id
WHERE e.target_id = ?1 AND e.kind = 'calls'
LIMIT ?2
"#,
)?;
let rows = stmt.query_map(params![node_id, limit as i64], Self::row_to_node)?;
let mut nodes = Vec::new();
for row in rows {
nodes.push(row?);
}
Ok(nodes)
}
pub fn get_callees(&self, node_id: i64, limit: u32) -> Result<Vec<Node>> {
let mut stmt = self.conn.prepare(
r#"
SELECT n.* FROM nodes n
INNER JOIN edges e ON e.target_id = n.id
WHERE e.source_id = ?1 AND e.kind = 'calls'
LIMIT ?2
"#,
)?;
let rows = stmt.query_map(params![node_id, limit as i64], Self::row_to_node)?;
let mut nodes = Vec::new();
for row in rows {
nodes.push(row?);
}
Ok(nodes)
}
pub fn get_outgoing_edges(&self, node_id: i64) -> Result<Vec<Edge>> {
let mut stmt = self
.conn
.prepare("SELECT * FROM edges WHERE source_id = ?1")?;
let rows = stmt.query_map(params![node_id], Self::row_to_edge)?;
let mut edges = Vec::new();
for row in rows {
edges.push(row?);
}
Ok(edges)
}
pub fn get_incoming_edges(&self, node_id: i64) -> Result<Vec<Edge>> {
let mut stmt = self
.conn
.prepare("SELECT * FROM edges WHERE target_id = ?1")?;
let rows = stmt.query_map(params![node_id], Self::row_to_edge)?;
let mut edges = Vec::new();
for row in rows {
edges.push(row?);
}
Ok(edges)
}
/// Builds an [`Edge`] from an `edges` row, reading columns by position
/// (0 = id, 1 = source, 2 = target, 3 = kind, 4 = file path, 5 = line,
/// 6 = column, 7 = detail). Unknown kinds become `EdgeKind::References`.
fn row_to_edge(row: &rusqlite::Row) -> rusqlite::Result<Edge> {
    let id = row.get(0)?;
    let source_id = row.get(1)?;
    let target_id = row.get(2)?;
    let kind_text: String = row.get(3)?;
    let file_path = row.get(4)?;
    let line: Option<i64> = row.get(5)?;
    let column: Option<i64> = row.get(6)?;
    let detail = row.get(7)?;
    Ok(Edge {
        id,
        source_id,
        target_id,
        kind: EdgeKind::parse(&kind_text).unwrap_or(EdgeKind::References),
        file_path,
        line: line.map(|l| l as u32),
        column: column.map(|c| c as u32),
        detail,
    })
}
/// Stores one not-yet-resolved reference in `unresolved_refs`.
pub fn insert_unresolved_ref(&self, uref: &UnresolvedReference) -> Result<()> {
    const INSERT_REF_SQL: &str = r#"
INSERT INTO unresolved_refs (source_node_id, reference_name, kind, file_path, line, column, detail)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
"#;
    let line = uref.line as i64;
    let column = uref.column as i64;
    self.conn.execute(
        INSERT_REF_SQL,
        params![
            uref.source_node_id,
            uref.reference_name,
            uref.kind.as_str(),
            uref.file_path,
            line,
            column,
            uref.detail,
        ],
    )?;
    Ok(())
}
pub fn get_unresolved_refs(&self) -> Result<Vec<UnresolvedReference>> {
let mut stmt = self.conn.prepare("SELECT * FROM unresolved_refs")?;
let rows = stmt.query_map([], |row| {
Ok(UnresolvedReference {
source_node_id: row.get(1)?,
reference_name: row.get(2)?,
kind: EdgeKind::parse(&row.get::<_, String>(3)?).unwrap_or(EdgeKind::Calls),
file_path: row.get(4)?,
line: row.get::<_, i64>(5)? as u32,
column: row.get::<_, i64>(6)? as u32,
detail: row.get(7)?,
})
})?;
let mut refs = Vec::new();
for row in rows {
refs.push(row?);
}
Ok(refs)
}
pub fn resolve_references(&self) -> Result<u32> {
let refs = self.get_unresolved_refs()?;
let mut resolved = 0;
for uref in refs {
if let Some(target) =
self.find_target_preferring_file(&uref.reference_name, &uref.file_path)?
{
let edge = Edge::new(uref.source_node_id, target.id, uref.kind)
.at(uref.file_path.clone(), uref.line, uref.column)
.detail(uref.detail.clone());
self.insert_edge(&edge)?;
resolved += 1;
if uref.kind == EdgeKind::Calls {
let source_is_test = self
.conn
.query_row(
"SELECT is_test FROM nodes WHERE id = ?1",
params![uref.source_node_id],
|r| r.get::<_, bool>(0),
)
.unwrap_or(false);
if source_is_test {
let test_edge = Edge::new(uref.source_node_id, target.id, EdgeKind::Tests)
.at(uref.file_path.clone(), uref.line, uref.column);
self.insert_edge(&test_edge)?;
}
}
}
}
self.conn.execute("DELETE FROM unresolved_refs", [])?;
Ok(resolved)
}
pub fn resolve_references_for_files(&self, files: &[String]) -> Result<u32> {
if files.is_empty() {
return Ok(0);
}
let mut resolved = 0;
for file_path in files {
let mut stmt = self.conn.prepare(
"SELECT source_node_id, reference_name, kind, file_path, line, column, detail \
FROM unresolved_refs WHERE file_path = ?1",
)?;
let refs: Vec<UnresolvedReference> = stmt
.query_map(params![file_path], |row| {
Ok(UnresolvedReference {
source_node_id: row.get(0)?,
reference_name: row.get::<_, String>(1)?,
kind: EdgeKind::parse(&row.get::<_, String>(2)?).unwrap_or(EdgeKind::Calls),
file_path: row.get(3)?,
line: row.get::<_, i64>(4)? as u32,
column: row.get::<_, i64>(5)? as u32,
detail: row.get(6)?,
})
})?
.collect::<std::result::Result<Vec<_>, _>>()?;
for uref in &refs {
if let Some(target) =
self.find_target_preferring_file(&uref.reference_name, &uref.file_path)?
{
let edge = Edge::new(uref.source_node_id, target.id, uref.kind)
.at(uref.file_path.clone(), uref.line, uref.column)
.detail(uref.detail.clone());
self.insert_edge(&edge)?;
resolved += 1;
if uref.kind == EdgeKind::Calls {
let source_is_test = self
.conn
.query_row(
"SELECT is_test FROM nodes WHERE id = ?1",
params![uref.source_node_id],
|r| r.get::<_, bool>(0),
)
.unwrap_or(false);
if source_is_test {
let test_edge =
Edge::new(uref.source_node_id, target.id, EdgeKind::Tests).at(
uref.file_path.clone(),
uref.line,
uref.column,
);
self.insert_edge(&test_edge)?;
}
}
}
}
self.conn.execute(
"DELETE FROM unresolved_refs WHERE file_path = ?1",
params![file_path],
)?;
}
Ok(resolved)
}
/// Inserts `nodes` and their rows in both FTS tables.
///
/// Each node's `id` is overwritten with its new rowid; the returned map goes
/// from the id a node carried on entry to the rowid it was assigned.
pub fn insert_nodes_batch(
    &self,
    nodes: &mut [Node],
) -> Result<std::collections::HashMap<i64, i64>> {
    let remapped_ids = self.insert_nodes_base_batch(nodes)?;
    // The FTS helpers rely on `node.id` already holding the new rowid.
    self.insert_node_fts_rows(nodes)
        .and_then(|()| self.insert_semantic_fts_rows(nodes))?;
    Ok(remapped_ids)
}
/// Inserts `nodes` into the `nodes` table only, leaving both FTS tables
/// untouched. Ids are rewritten and mapped exactly as in `insert_nodes_batch`.
pub fn insert_nodes_batch_without_fts(
    &self,
    nodes: &mut [Node],
) -> Result<std::collections::HashMap<i64, i64>> {
    let remapped_ids = self.insert_nodes_base_batch(nodes)?;
    Ok(remapped_ids)
}
pub fn rebuild_fts_indexes(&self) -> Result<()> {
self.conn
.execute_batch("INSERT INTO nodes_fts(nodes_fts) VALUES('rebuild');")
.context("rebuild_fts_indexes: nodes_fts")?;
self.conn
.execute("DELETE FROM nodes_semantic_fts", [])
.context("rebuild_fts_indexes: nodes_semantic_fts clear")?;
let mut select = self
.conn
.prepare_cached("SELECT * FROM nodes ORDER BY id")?;
let rows = select.query_map([], Self::row_to_node)?;
let mut insert = self
.conn
.prepare_cached("INSERT INTO nodes_semantic_fts(rowid, tokens) VALUES (?1, ?2)")?;
for node in rows {
let node = node?;
insert.execute(params![node.id, build_semantic_tokens(&node)])?;
}
self.optimize_fts()
}
/// Inserts every node into the `nodes` table through one cached statement.
///
/// After each insert the node's `id` is replaced by its new rowid, and the
/// pair (previous id, new rowid) is recorded in the returned map.
fn insert_nodes_base_batch(
    &self,
    nodes: &mut [Node],
) -> Result<std::collections::HashMap<i64, i64>> {
    const INSERT_NODE_SQL: &str = r#"
INSERT INTO nodes (
kind, name, qualified_name, file_path, start_line, end_line,
start_column, end_column, signature, visibility, docstring,
is_async, is_static, is_exported, is_test, is_generated, language
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17)
"#;
    let mut remapped_ids = std::collections::HashMap::with_capacity(nodes.len());
    let mut insert = self.conn.prepare_cached(INSERT_NODE_SQL)?;
    for node in nodes {
        insert.execute(params![
            node.kind.as_str(),
            node.name,
            node.qualified_name,
            node.file_path,
            node.start_line as i64,
            node.end_line as i64,
            node.start_column as i64,
            node.end_column as i64,
            node.signature,
            node.visibility.as_str(),
            node.docstring,
            node.is_async,
            node.is_static,
            node.is_exported,
            node.is_test,
            node.is_generated,
            node.language.as_str(),
        ])?;
        let assigned = self.conn.last_insert_rowid();
        // Swap in the real rowid and remember which placeholder it replaced.
        let previous = std::mem::replace(&mut node.id, assigned);
        remapped_ids.insert(previous, assigned);
    }
    Ok(remapped_ids)
}
/// Adds one `nodes_fts` row per node, keyed by `node.id`.
fn insert_node_fts_rows(&self, nodes: &[Node]) -> Result<()> {
    let mut insert = self.conn.prepare_cached(
        "INSERT INTO nodes_fts(rowid, name, qualified_name) VALUES (?1, ?2, ?3)",
    )?;
    nodes.iter().try_for_each(|node| {
        insert
            .execute(params![node.id, node.name, node.qualified_name])
            .map(|_| ())
    })?;
    Ok(())
}
/// Adds one `nodes_semantic_fts` row per node, keyed by `node.id`, containing
/// the tokens produced by `build_semantic_tokens`.
fn insert_semantic_fts_rows(&self, nodes: &[Node]) -> Result<()> {
    let mut insert = self
        .conn
        .prepare_cached("INSERT INTO nodes_semantic_fts(rowid, tokens) VALUES (?1, ?2)")?;
    nodes.iter().try_for_each(|node| {
        insert
            .execute(params![node.id, build_semantic_tokens(node)])
            .map(|_| ())
    })?;
    Ok(())
}
/// Inserts `edges` after translating both endpoints through `id_map`.
///
/// Edges with an endpoint missing from `id_map` are skipped silently.
/// Returns the number of edges actually inserted.
pub fn insert_edges_batch(
    &self,
    edges: &[Edge],
    id_map: &std::collections::HashMap<i64, i64>,
) -> Result<u64> {
    const INSERT_EDGE_SQL: &str = r#"
INSERT INTO edges (source_id, target_id, kind, file_path, line, column, detail)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
"#;
    let mut inserted: u64 = 0;
    let mut insert = self.conn.prepare_cached(INSERT_EDGE_SQL)?;
    for edge in edges {
        let endpoints = (id_map.get(&edge.source_id), id_map.get(&edge.target_id));
        let (new_source, new_target) = match endpoints {
            (Some(&source), Some(&target)) => (source, target),
            _ => continue,
        };
        insert.execute(params![
            new_source,
            new_target,
            edge.kind.as_str(),
            edge.file_path,
            edge.line.map(|l| l as i64),
            edge.column.map(|c| c as i64),
            edge.detail,
        ])?;
        inserted += 1;
    }
    Ok(inserted)
}
/// Stores `refs` in `unresolved_refs` after translating each source node id
/// through `id_map`. References whose source id is not in the map are skipped.
pub fn insert_unresolved_refs_batch(
    &self,
    refs: &[UnresolvedReference],
    id_map: &std::collections::HashMap<i64, i64>,
) -> Result<()> {
    const INSERT_REF_SQL: &str = r#"
INSERT INTO unresolved_refs (source_node_id, reference_name, kind, file_path, line, column, detail)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
"#;
    let mut insert = self.conn.prepare_cached(INSERT_REF_SQL)?;
    for uref in refs {
        let new_source = match id_map.get(&uref.source_node_id) {
            Some(&mapped) => mapped,
            None => continue,
        };
        insert.execute(params![
            new_source,
            uref.reference_name,
            uref.kind.as_str(),
            uref.file_path,
            uref.line as i64,
            uref.column as i64,
            uref.detail,
        ])?;
    }
    Ok(())
}
pub fn get_stats(&self) -> Result<IndexStats> {
let total_files: i64 = self
.conn
.query_row("SELECT COUNT(*) FROM files", [], |row| row.get(0))?;
let total_nodes: i64 = self
.conn
.query_row("SELECT COUNT(*) FROM nodes", [], |row| row.get(0))?;
let total_edges: i64 = self
.conn
.query_row("SELECT COUNT(*) FROM edges", [], |row| row.get(0))?;
let db_size_bytes: i64 = self
.conn
.query_row(
"SELECT page_count * page_size FROM pragma_page_count(), pragma_page_size()",
[],
|row| row.get(0),
)
.unwrap_or(0);
let mut stmt = self
.conn
.prepare("SELECT language, COUNT(*) FROM nodes GROUP BY language")?;
let lang_rows = stmt.query_map([], |row| {
let lang_str: String = row.get(0)?;
let count: i64 = row.get(1)?;
Ok((Language::parse(&lang_str), count as u64))
})?;
let mut languages = Vec::new();
for row in lang_rows {
languages.push(row?);
}
let mut stmt = self
.conn
.prepare("SELECT kind, COUNT(*) FROM nodes GROUP BY kind")?;
let kind_rows = stmt.query_map([], |row| {
let kind_str: String = row.get(0)?;
let count: i64 = row.get(1)?;
Ok((
NodeKind::parse(&kind_str).unwrap_or(NodeKind::Function),
count as u64,
))
})?;
let mut node_kinds = Vec::new();
for row in kind_rows {
node_kinds.push(row?);
}
Ok(IndexStats {
total_files: total_files as u64,
total_nodes: total_nodes as u64,
total_edges: total_edges as u64,
db_size_bytes: db_size_bytes as u64,
languages,
node_kinds,
})
}
pub fn disable_fts_automerge(&self) -> Result<()> {
use anyhow::Context;
self.conn
.execute(
"INSERT INTO nodes_fts(nodes_fts, rank) VALUES('automerge', 0)",
[],
)
.context("disable_fts_automerge: nodes_fts")?;
self.conn
.execute(
"INSERT INTO nodes_semantic_fts(nodes_semantic_fts, rank) VALUES('automerge', 0)",
[],
)
.context("disable_fts_automerge: nodes_semantic_fts")?;
Ok(())
}
pub fn optimize_fts(&self) -> Result<()> {
use anyhow::Context;
self.conn
.execute_batch("INSERT INTO nodes_fts(nodes_fts) VALUES('optimize');")
.context("optimize_fts: nodes_fts")?;
self.conn
.execute_batch("INSERT INTO nodes_semantic_fts(nodes_semantic_fts) VALUES('optimize');")
.context("optimize_fts: nodes_semantic_fts")?;
Ok(())
}
pub fn begin_transaction(&mut self) -> Result<()> {
use anyhow::Context;
self.conn
.execute("BEGIN TRANSACTION", [])
.context("begin_transaction")?;
Ok(())
}
/// Executes `COMMIT` on the connection (no WAL checkpoint; see `commit`).
pub fn commit_transaction(&mut self) -> Result<()> {
    self.conn
        .execute("COMMIT", [])
        .context("commit: COMMIT")
        .map(|_rows| ())
}
/// Runs `PRAGMA wal_checkpoint(TRUNCATE)` on the connection.
pub fn checkpoint_wal_truncate(&self) -> Result<()> {
    self.conn
        .execute_batch("PRAGMA wal_checkpoint(TRUNCATE)")
        .context("wal_checkpoint")
}
/// Commits the current transaction and, if that succeeds, checkpoints the WAL
/// in `TRUNCATE` mode.
pub fn commit(&mut self) -> Result<()> {
    self.commit_transaction()
        .and_then(|()| self.checkpoint_wal_truncate())
}
/// Executes `ROLLBACK` on the connection.
pub fn rollback(&mut self) -> Result<()> {
    self.conn
        .execute("ROLLBACK", [])
        .context("rollback")
        .map(|_rows| ())
}
/// Finalises an on-disk database so its file can be moved: checkpoints the
/// WAL, closes the connection and deletes the `-wal`/`-shm` sidecar files.
///
/// Returns the database path. Fails for databases without a recorded path.
pub fn prepare_for_swap(self) -> Result<PathBuf> {
    let db_path = self
        .path
        .clone()
        .context("prepare_for_swap requires an on-disk database")?;
    self.checkpoint_wal_truncate()?;
    self.close()?;
    cleanup_sqlite_sidecars(&db_path).map(|()| db_path)
}
pub fn replace_with_shadow<P: AsRef<Path>>(&mut self, shadow_path: P) -> Result<()> {
let live_path = self
.path
.clone()
.context("replace_with_shadow requires an on-disk database")?;
let shadow_path = shadow_path.as_ref().to_path_buf();
self.checkpoint_wal_truncate()?;
let placeholder = Connection::open_in_memory().context("opening placeholder connection")?;
let live_conn = std::mem::replace(&mut self.conn, placeholder);
if let Err((conn, err)) = live_conn.close() {
self.conn = conn;
return Err(err.into());
}
cleanup_sqlite_sidecars(&live_path)?;
if let Err(err) = fs::rename(&shadow_path, &live_path) {
self.reopen_from_path(&live_path)?;
return Err(err.into());
}
cleanup_sqlite_sidecars(&shadow_path)?;
self.reopen_from_path(&live_path)
}
/// Deletes the database file at `path` and its `-wal`/`-shm` sidecars.
/// Files that do not exist are ignored.
pub fn cleanup_on_disk_path<P: AsRef<Path>>(path: P) -> Result<()> {
    let db_file = path.as_ref();
    remove_file_if_exists(db_file).and_then(|()| cleanup_sqlite_sidecars(db_file))
}
fn reopen_from_path(&mut self, path: &Path) -> Result<()> {
let reopened = Self::open(path)?;
self.conn = reopened.conn;
self.path = reopened.path;
Ok(())
}
pub fn get_hierarchy(&self, symbol: &str) -> Result<Vec<Node>> {
let mut stmt = self.conn.prepare(
"SELECT n.* FROM nodes n
INNER JOIN edges e ON e.source_id = n.id
INNER JOIN nodes target ON e.target_id = target.id
WHERE e.kind = 'contains' AND target.name = ?
UNION
SELECT n.* FROM nodes n
INNER JOIN edges e ON e.target_id = n.id
INNER JOIN nodes source ON e.source_id = source.id
WHERE e.kind = 'contains' AND source.name = ?",
)?;
let rows = stmt.query_map(params![symbol, symbol], Self::row_to_node)?;
let mut nodes = Vec::new();
for row in rows {
nodes.push(row?);
}
Ok(nodes)
}
/// Searches breadth-first along `calls` edges for paths from the node named
/// `from` to the node named `to`.
///
/// Returns at most five paths, each listing the nodes from source to target
/// inclusive. Paths longer than ten nodes are not extended, each node is
/// expanded at most once, and at most 100 callees are considered per node.
/// Returns an empty list when either name does not resolve to a node.
pub fn find_call_path(&self, from: &str, to: &str) -> Result<Vec<Vec<Node>>> {
    const MAX_PATHS: usize = 5;
    const MAX_PATH_NODES: usize = 10;
    const CALLEES_PER_NODE: u32 = 100;

    let source = self.find_node_by_name(from)?;
    let target = self.find_node_by_name(to)?;
    let (src, tgt) = match (source, target) {
        (Some(src), Some(tgt)) => (src, tgt),
        _ => return Ok(Vec::new()),
    };

    let mut paths = Vec::new();
    let mut visited = std::collections::HashSet::new();
    let mut queue = std::collections::VecDeque::new();
    queue.push_back((src.id, vec![src]));

    while let Some((current_id, path)) = queue.pop_front() {
        if current_id == tgt.id {
            paths.push(path);
            if paths.len() >= MAX_PATHS {
                break;
            }
            continue;
        }
        // `insert` returns false when the id was already expanded. The length
        // check comes first so over-long paths do not mark the node as visited.
        if path.len() > MAX_PATH_NODES || !visited.insert(current_id) {
            continue;
        }
        for callee in self.get_callees(current_id, CALLEES_PER_NODE)? {
            let callee_id = callee.id;
            let mut extended = path.clone();
            extended.push(callee);
            queue.push_back((callee_id, extended));
        }
    }
    Ok(paths)
}
pub fn find_unused_symbols(&self) -> Result<Vec<Node>> {
let mut stmt = self.conn.prepare(
"SELECT n.* FROM nodes n
WHERE n.kind IN ('function', 'method', 'class', 'struct', 'interface')
AND n.is_test = 0
AND n.is_generated = 0
AND n.id NOT IN (SELECT DISTINCT target_id FROM edges WHERE kind IN ('calls', 'references', 'instantiates', 'tests'))
ORDER BY n.file_path, n.start_line",
)?;
let rows = stmt.query_map([], Self::row_to_node)?;
let mut nodes = Vec::new();
for row in rows {
nodes.push(row?);
}
Ok(nodes)
}
pub fn find_implementations(&self, symbol: &str) -> Result<Vec<Node>> {
let mut stmt = self.conn.prepare(
"SELECT n.* FROM nodes n
INNER JOIN edges e ON e.source_id = n.id
INNER JOIN nodes target ON e.target_id = target.id
WHERE e.kind IN ('implements', 'extends') AND target.name = ?",
)?;
let rows = stmt.query_map([symbol], Self::row_to_node)?;
let mut nodes = Vec::new();
for row in rows {
nodes.push(row?);
}
Ok(nodes)
}
pub fn get_diff_impact(
&self,
file_path: &str,
start_line: u32,
end_line: u32,
) -> Result<Vec<Node>> {
let mut affected = Vec::new();
let mut stmt = self.conn.prepare(
"SELECT * FROM nodes
WHERE file_path = ?
AND ((start_line <= ? AND end_line >= ?)
OR (start_line >= ? AND start_line <= ?))",
)?;
let rows = stmt.query_map(
params![file_path, end_line, start_line, start_line, end_line],
Self::row_to_node,
)?;
for row in rows {
affected.push(row?);
}
let mut impacted = affected.clone();
for node in &affected {
let callers = self.get_callers(node.id, 100)?;
for caller in callers {
if !impacted.iter().any(|n| n.id == caller.id) {
impacted.push(caller);
}
}
}
Ok(impacted)
}
pub fn semantic_search(&self, query: &str, limit: u32) -> Result<Vec<Node>> {
let normalized = normalize_query_for_fts(query);
if normalized.is_empty() {
return Ok(Vec::new());
}
let mut stmt = self.conn.prepare(
"SELECT n.* FROM nodes n
INNER JOIN nodes_semantic_fts s ON s.rowid = n.id
WHERE nodes_semantic_fts MATCH ?1
ORDER BY bm25(nodes_semantic_fts)
LIMIT ?2",
)?;
let rows = stmt.query_map(params![normalized, limit as i64], Self::row_to_node);
let mut nodes = Vec::new();
match rows {
Ok(iter) => {
for row in iter {
nodes.push(row?);
}
}
Err(_) => return Ok(Vec::new()),
}
Ok(nodes)
}
/// Finds a node named `name`, preferring one defined in `source_file`; falls
/// back to `find_node_by_name` when that file has no such node.
pub fn find_target_preferring_file(
    &self,
    name: &str,
    source_file: &str,
) -> Result<Option<Node>> {
    let same_file = self
        .conn
        .query_row(
            "SELECT * FROM nodes WHERE name = ?1 AND file_path = ?2 LIMIT 1",
            params![name, source_file],
            Self::row_to_node,
        )
        .optional()?;
    match same_file {
        Some(node) => Ok(Some(node)),
        None => self.find_node_by_name(name),
    }
}
}
/// Builds the lower-cased text stored in `nodes_semantic_fts` for `node`: the
/// split tokens of its name, then of its qualified name (if any), then its
/// docstring verbatim (if any), separated by spaces.
fn build_semantic_tokens(node: &Node) -> String {
    let mut tokens = String::new();
    push_split_tokens(&mut tokens, &node.name);
    match &node.qualified_name {
        Some(qualified) => {
            tokens.push(' ');
            push_split_tokens(&mut tokens, qualified);
        }
        None => {}
    }
    match &node.docstring {
        Some(doc) => {
            tokens.push(' ');
            tokens.push_str(doc);
        }
        None => {}
    }
    tokens.to_lowercase()
}
/// Appends `s` to `out` followed by its constituent words.
///
/// The identifier is written verbatim first (surrounded by spaces), then split
/// into words at every non-alphanumeric character and at every transition from
/// an ASCII lowercase letter or digit to an ASCII uppercase letter. Each word
/// is followed by a single space. For example `"getUserName"` appends
/// `" getUserName get User Name "`.
fn push_split_tokens(out: &mut String, s: &str) {
    out.push(' ');
    out.push_str(s);
    out.push(' ');

    let mut current = String::new();
    // True when the previous character was an ASCII lowercase letter or digit,
    // i.e. an uppercase letter right after it starts a new camelCase word.
    let mut prev_lower = false;
    for ch in s.chars() {
        if ch.is_ascii_uppercase() && prev_lower && !current.is_empty() {
            out.push_str(&current);
            out.push(' ');
            current.clear();
        }
        if ch.is_alphanumeric() {
            current.push(ch);
            prev_lower = ch.is_ascii_lowercase() || ch.is_ascii_digit();
        } else {
            // Separator character: close the word in progress, if any.
            if !current.is_empty() {
                out.push_str(&current);
                out.push(' ');
                current.clear();
            }
            prev_lower = false;
        }
    }
    if !current.is_empty() {
        out.push_str(&current);
        out.push(' ');
    }
}
/// Reduces a free-form query to lower-cased alphanumeric tokens joined by
/// single spaces; every other character acts as a token separator.
fn normalize_query_for_fts(query: &str) -> String {
    let tokens: Vec<String> = query
        .split(|c: char| !c.is_alphanumeric())
        .filter(|token| !token.is_empty())
        .map(str::to_lowercase)
        .collect();
    tokens.join(" ")
}
/// Returns the paths of the `-wal` and `-shm` files that belong to the
/// database at `path`, in that order.
///
/// The suffix is appended to the raw OS string rather than to
/// `path.display()`, so paths that are not valid UTF-8 are preserved exactly
/// instead of being rewritten with replacement characters.
fn sqlite_sidecar_paths(path: &Path) -> [PathBuf; 2] {
    let with_suffix = |suffix: &str| {
        let mut name = path.as_os_str().to_os_string();
        name.push(suffix);
        PathBuf::from(name)
    };
    [with_suffix("-wal"), with_suffix("-shm")]
}
/// Deletes the `-wal` and `-shm` files next to `path`, stopping at the first
/// failure. Missing files are not an error.
fn cleanup_sqlite_sidecars(path: &Path) -> Result<()> {
    sqlite_sidecar_paths(path)
        .iter()
        .try_for_each(|sidecar| remove_file_if_exists(sidecar))
}
/// Removes the file at `path`; a file that is already absent counts as success.
fn remove_file_if_exists(path: &Path) -> Result<()> {
    if let Err(err) = fs::remove_file(path) {
        if err.kind() != ErrorKind::NotFound {
            return Err(err.into());
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::tempdir;
/// Test fixture: a public, exported Rust node spanning lines 1-10 with the
/// given name, kind and file, qualified as `test::<name>`.
fn create_test_node(name: &str, kind: NodeKind, file_path: &str) -> Node {
    let qualified_name = format!("test::{}", name);
    let signature = format!("fn {}()", name);
    Node {
        id: 0,
        kind,
        name: name.to_owned(),
        qualified_name: Some(qualified_name),
        file_path: file_path.to_owned(),
        start_line: 1,
        end_line: 10,
        start_column: 0,
        end_column: 1,
        signature: Some(signature),
        visibility: Visibility::Public,
        docstring: None,
        is_async: false,
        is_static: false,
        is_exported: true,
        is_test: false,
        is_generated: false,
        language: Language::Rust,
    }
}
/// Test fixture: a Rust file record for `path` with fixed hash, size,
/// timestamps and a node count of 5.
fn mk_file(path: &str) -> FileRecord {
    let timestamp = 1234567890;
    FileRecord {
        path: path.to_owned(),
        content_hash: String::from("abc123"),
        language: Language::Rust,
        size: 1000,
        modified_at: timestamp,
        indexed_at: timestamp,
        node_count: 5,
    }
}
/// An in-memory database can be created and initialised.
#[test]
fn test_in_memory_database_creation() {
    assert!(Database::in_memory().is_ok());
}
/// `open` remembers the path it was given.
#[test]
fn test_open_tracks_on_disk_path() {
    let scratch = tempdir().unwrap();
    let db_file = scratch.path().join("tracked.db");
    let db = Database::open(&db_file).unwrap();
    assert_eq!(db.path(), Some(db_file.as_path()));
}
/// A fresh database reports zero files, nodes and edges.
#[test]
fn test_database_stats_empty() {
    let stats = Database::in_memory().unwrap().get_stats().unwrap();
    assert_eq!(stats.total_files, 0);
    assert_eq!(stats.total_nodes, 0);
    assert_eq!(stats.total_edges, 0);
}
/// A stored file record can be read back with its fields intact.
#[test]
fn test_upsert_and_get_file() {
    let db = Database::in_memory().unwrap();
    db.insert_or_update_file(&mk_file("test.rs")).unwrap();

    let lookup = db.get_file("test.rs").unwrap();
    assert!(lookup.is_some());
    let record = lookup.unwrap();
    assert_eq!(record.path, "test.rs");
    assert_eq!(record.content_hash, "abc123");
    assert_eq!(record.node_count, 5);
}
/// Upserting an existing path overwrites the stored hash and node count.
#[test]
fn test_file_upsert_updates_existing() {
    let db = Database::in_memory().unwrap();
    let mut record = mk_file("src/lib.rs");
    db.insert_or_update_file(&record).unwrap();

    record.content_hash = "updated_hash".to_string();
    record.node_count = 10;
    db.insert_or_update_file(&record).unwrap();

    let stored = db.get_file("src/lib.rs").unwrap().unwrap();
    assert_eq!(stored.content_hash, "updated_hash");
    assert_eq!(stored.node_count, 10);
}
/// Looking up an unknown path yields `None`.
#[test]
fn test_get_nonexistent_file() {
    let db = Database::in_memory().unwrap();
    assert!(db.get_file("nonexistent.rs").unwrap().is_none());
}
/// A file that was never indexed needs indexing.
#[test]
fn test_needs_reindex_new_file() {
    let db = Database::in_memory().unwrap();
    assert!(db.needs_reindex("new_file.rs", "somehash").unwrap());
}
/// A file whose hash matches the stored one does not need reindexing.
#[test]
fn test_needs_reindex_unchanged_file() {
    let db = Database::in_memory().unwrap();
    db.insert_or_update_file(&mk_file("test.rs")).unwrap();
    assert!(!db.needs_reindex("test.rs", "abc123").unwrap());
}
/// A file whose hash differs from the stored one needs reindexing.
#[test]
fn test_needs_reindex_changed_file() {
    let db = Database::in_memory().unwrap();
    db.insert_or_update_file(&mk_file("test.rs")).unwrap();
    assert!(db.needs_reindex("test.rs", "different_hash").unwrap());
}
/// A node inserted with `insert_node` can be fetched by the returned id and
/// keeps its name and kind.
#[test]
fn test_insert_and_get_node() {
    let db = Database::in_memory().unwrap();
    db.insert_or_update_file(&mk_file("test.rs")).unwrap();

    let node = create_test_node("my_function", NodeKind::Function, "test.rs");
    let id = db.insert_node(&node).unwrap();

    // One `expect` replaces the former `assert!(is_some())` + `unwrap()` pair and
    // yields a descriptive message if the lookup unexpectedly returns `None`.
    let retrieved = db
        .get_node(id)
        .unwrap()
        .expect("node inserted above should be retrievable by its id");

    assert_eq!(retrieved.name, "my_function");
    assert_eq!(retrieved.kind, NodeKind::Function);
}
/// Fetching a node id that was never issued succeeds and yields `None`.
#[test]
fn test_get_nonexistent_node() {
    let db = Database::in_memory().unwrap();

    let lookup = db.get_node(999).unwrap();

    assert!(lookup.is_none());
}
/// `search_nodes` matches on a name fragment: two nodes contain "process",
/// exactly one contains "handle".
#[test]
fn test_search_nodes() {
    let db = Database::in_memory().unwrap();
    db.insert_or_update_file(&mk_file("test.rs")).unwrap();

    // Three functions in the same file, inserted in a fixed order.
    for name in ["process_data", "process_input", "handle_error"] {
        let node = create_test_node(name, NodeKind::Function, "test.rs");
        db.insert_node(&node).unwrap();
    }

    let process_hits = db.search_nodes("process", None, 10).unwrap();
    assert_eq!(process_hits.len(), 2);

    let handle_hits = db.search_nodes("handle", None, 10).unwrap();
    assert_eq!(handle_hits.len(), 1);
    assert_eq!(handle_hits[0].name, "handle_error");
}
/// With a kind filter, `search_nodes` must only return nodes of that kind even
/// when nodes of other kinds exist alongside them.
#[test]
fn test_search_nodes_with_kind_filter() {
    let db = Database::in_memory().unwrap();
    db.insert_or_update_file(&mk_file("test.rs")).unwrap();

    let class_node = create_test_node("MyClass", NodeKind::Class, "test.rs");
    db.insert_node(&class_node).unwrap();
    let function_node = create_test_node("my_function", NodeKind::Function, "test.rs");
    db.insert_node(&function_node).unwrap();

    let only_functions = db.search_nodes("my", Some(NodeKind::Function), 10).unwrap();

    assert_eq!(only_functions.len(), 1);
    assert_eq!(only_functions[0].kind, NodeKind::Function);
}
/// A mixed-case node name must be found by both an all-lowercase and an
/// all-uppercase query.
#[test]
fn test_search_nodes_case_insensitive() {
    let db = Database::in_memory().unwrap();
    db.insert_or_update_file(&mk_file("test.rs")).unwrap();

    let node = create_test_node("MyFunction", NodeKind::Function, "test.rs");
    db.insert_node(&node).unwrap();

    for query in ["myfunction", "MYFUNCTION"] {
        let hits = db.search_nodes(query, None, 10).unwrap();
        assert_eq!(hits.len(), 1);
    }
}
/// `find_node_by_name` returns the node for an existing name and `None` for an
/// unknown one.
#[test]
fn test_find_node_by_name() {
    let db = Database::in_memory().unwrap();
    db.insert_or_update_file(&mk_file("test.rs")).unwrap();

    let node = create_test_node("unique_name", NodeKind::Function, "test.rs");
    db.insert_node(&node).unwrap();

    // One `expect` replaces the former `assert!(is_some())` + `unwrap()` pair and
    // yields a descriptive message if the lookup unexpectedly returns `None`.
    let found = db
        .find_node_by_name("unique_name")
        .unwrap()
        .expect("node inserted above should be found by its name");
    assert_eq!(found.name, "unique_name");

    let missing = db.find_node_by_name("nonexistent").unwrap();
    assert!(missing.is_none());
}
/// Inserting an edge between two existing nodes returns a positive edge id.
#[test]
fn test_insert_edge() {
    let db = Database::in_memory().unwrap();
    db.insert_or_update_file(&mk_file("test.rs")).unwrap();

    let caller = create_test_node("caller", NodeKind::Function, "test.rs");
    let source = db.insert_node(&caller).unwrap();
    let callee = create_test_node("callee", NodeKind::Function, "test.rs");
    let target = db.insert_node(&callee).unwrap();

    // `id: 0` is a placeholder; the test only checks the id returned by the insert.
    let edge_id = db
        .insert_edge(&Edge {
            id: 0,
            source_id: source,
            target_id: target,
            kind: EdgeKind::Calls,
            file_path: Some("test.rs".to_string()),
            line: Some(5),
            column: Some(10),
            detail: None,
        })
        .unwrap();

    assert!(edge_id > 0);
}
/// With a single `Calls` edge caller -> callee, `get_callers(callee)` yields the
/// caller and `get_callees(caller)` yields the callee.
#[test]
fn test_get_callers_and_callees() {
    let db = Database::in_memory().unwrap();
    db.insert_or_update_file(&mk_file("test.rs")).unwrap();

    let from = db
        .insert_node(&create_test_node("caller", NodeKind::Function, "test.rs"))
        .unwrap();
    let to = db
        .insert_node(&create_test_node("callee", NodeKind::Function, "test.rs"))
        .unwrap();

    db.insert_edge(&Edge {
        id: 0,
        source_id: from,
        target_id: to,
        kind: EdgeKind::Calls,
        file_path: None,
        line: None,
        column: None,
        detail: None,
    })
    .unwrap();

    // Incoming direction: who calls `callee`?
    let incoming = db.get_callers(to, 10).unwrap();
    assert_eq!(incoming.len(), 1);
    assert_eq!(incoming[0].name, "caller");

    // Outgoing direction: whom does `caller` call?
    let outgoing = db.get_callees(from, 10).unwrap();
    assert_eq!(outgoing.len(), 1);
    assert_eq!(outgoing[0].name, "callee");
}
/// A single edge node1 -> node2 must appear as node1's only outgoing edge and
/// node2's only incoming edge, with the endpoints preserved.
#[test]
fn test_get_outgoing_and_incoming_edges() {
    let db = Database::in_memory().unwrap();
    db.insert_or_update_file(&mk_file("test.rs")).unwrap();

    let first = create_test_node("node1", NodeKind::Function, "test.rs");
    let first_id = db.insert_node(&first).unwrap();
    let second = create_test_node("node2", NodeKind::Function, "test.rs");
    let second_id = db.insert_node(&second).unwrap();

    let link = Edge {
        id: 0,
        source_id: first_id,
        target_id: second_id,
        kind: EdgeKind::Calls,
        file_path: None,
        line: None,
        column: None,
        detail: None,
    };
    db.insert_edge(&link).unwrap();

    let from_first = db.get_outgoing_edges(first_id).unwrap();
    assert_eq!(from_first.len(), 1);
    assert_eq!(from_first[0].target_id, second_id);

    let into_second = db.get_incoming_edges(second_id).unwrap();
    assert_eq!(into_second.len(), 1);
    assert_eq!(into_second[0].source_id, first_id);
}
/// An unresolved reference that was inserted is returned by `get_unresolved_refs`.
#[test]
fn test_unresolved_refs() {
    let db = Database::in_memory().unwrap();
    db.insert_or_update_file(&mk_file("test.rs")).unwrap();

    let caller = create_test_node("caller", NodeKind::Function, "test.rs");
    let source = db.insert_node(&caller).unwrap();

    // NOTE(review): `file_path` names "src/lib.rs" although only "test.rs" is
    // registered as a file here; presumably the reference path is not validated
    // against the files table — confirm.
    db.insert_unresolved_ref(&UnresolvedReference {
        source_node_id: source,
        reference_name: "some_function".to_string(),
        kind: EdgeKind::Calls,
        file_path: "src/lib.rs".to_string(),
        line: 5,
        column: 10,
        detail: None,
    })
    .unwrap();

    let pending = db.get_unresolved_refs().unwrap();
    assert_eq!(pending.len(), 1);
    assert_eq!(pending[0].reference_name, "some_function");
}
/// Resolving a pending reference whose name matches an existing node must
/// report one resolution, create one outgoing edge from the referencing node,
/// and leave no unresolved references behind.
#[test]
fn test_resolve_references() {
    let db = Database::in_memory().unwrap();
    db.insert_or_update_file(&mk_file("test.rs")).unwrap();

    let source = db
        .insert_node(&create_test_node("caller", NodeKind::Function, "test.rs"))
        .unwrap();
    // The target only has to exist; its id is not needed by the assertions.
    db.insert_node(&create_test_node(
        "target_func",
        NodeKind::Function,
        "test.rs",
    ))
    .unwrap();

    let pending = UnresolvedReference {
        source_node_id: source,
        reference_name: "target_func".to_string(),
        kind: EdgeKind::Calls,
        file_path: "test.rs".to_string(),
        line: 5,
        column: 10,
        detail: None,
    };
    db.insert_unresolved_ref(&pending).unwrap();

    assert_eq!(db.resolve_references().unwrap(), 1);
    assert_eq!(db.get_outgoing_edges(source).unwrap().len(), 1);
    assert!(db.get_unresolved_refs().unwrap().is_empty());
}
/// `resolve_references_for_files` must only resolve references that originate
/// in the listed files: the reference from `src/a.rs` becomes an edge, while
/// the one from `src/b.rs` stays unresolved.
#[test]
fn test_resolve_references_for_files() {
    let db = Database::in_memory().unwrap();
    db.insert_or_update_file(&mk_file("src/a.rs")).unwrap();
    db.insert_or_update_file(&mk_file("src/b.rs")).unwrap();

    // Insertion order is kept as target, caller_a, caller_b.
    db.insert_node(&create_test_node(
        "target_func",
        NodeKind::Function,
        "src/b.rs",
    ))
    .unwrap();
    let caller_a = db
        .insert_node(&create_test_node(
            "caller_a",
            NodeKind::Function,
            "src/a.rs",
        ))
        .unwrap();
    let caller_b = db
        .insert_node(&create_test_node(
            "caller_b",
            NodeKind::Function,
            "src/b.rs",
        ))
        .unwrap();

    // Both references point at `target_func`; only origin and line differ.
    let call_to_target = |source, file: &str, line| UnresolvedReference {
        source_node_id: source,
        reference_name: "target_func".to_string(),
        kind: EdgeKind::Calls,
        file_path: file.to_string(),
        line,
        column: 5,
        detail: None,
    };
    db.insert_unresolved_ref(&call_to_target(caller_a, "src/a.rs", 10))
        .unwrap();
    db.insert_unresolved_ref(&call_to_target(caller_b, "src/b.rs", 20))
        .unwrap();

    let resolved = db
        .resolve_references_for_files(&["src/a.rs".to_string()])
        .unwrap();
    assert_eq!(resolved, 1);

    // Only the reference from the listed file was turned into an edge.
    assert_eq!(db.get_outgoing_edges(caller_a).unwrap().len(), 1);
    assert_eq!(db.get_outgoing_edges(caller_b).unwrap().len(), 0);

    let still_pending = db.get_unresolved_refs().unwrap();
    assert_eq!(still_pending.len(), 1);
    assert_eq!(still_pending[0].file_path, "src/b.rs");
}
/// Resolving for an empty file list succeeds and resolves nothing.
#[test]
fn test_resolve_references_for_files_empty() {
    let db = Database::in_memory().unwrap();

    let count = db.resolve_references_for_files(&[]).unwrap();

    assert_eq!(count, 0);
}
/// After inserting one file with three nodes and no edges, `get_stats` must
/// report exactly those totals.
#[test]
fn test_stats() {
    let db = Database::in_memory().unwrap();
    db.insert_or_update_file(&mk_file("test.rs")).unwrap();

    let fixtures = [
        ("func1", NodeKind::Function),
        ("func2", NodeKind::Function),
        ("MyClass", NodeKind::Class),
    ];
    for (name, kind) in fixtures {
        db.insert_node(&create_test_node(name, kind, "test.rs"))
            .unwrap();
    }

    let totals = db.get_stats().unwrap();
    assert_eq!(totals.total_files, 1);
    assert_eq!(totals.total_nodes, 3);
    assert_eq!(totals.total_edges, 0);
}
/// Deleting a file must remove the file record, its nodes, and the edge between
/// those nodes, leaving all stats at zero.
#[test]
fn test_delete_file() {
    let db = Database::in_memory().unwrap();
    db.insert_or_update_file(&mk_file("test.rs")).unwrap();

    let first = db
        .insert_node(&create_test_node("func1", NodeKind::Function, "test.rs"))
        .unwrap();
    let second = db
        .insert_node(&create_test_node("func2", NodeKind::Function, "test.rs"))
        .unwrap();
    db.insert_edge(&Edge {
        id: 0,
        source_id: first,
        target_id: second,
        kind: EdgeKind::Calls,
        file_path: Some("test.rs".to_string()),
        line: None,
        column: None,
        detail: None,
    })
    .unwrap();

    db.delete_file("test.rs").unwrap();

    // The record itself and both nodes are gone...
    assert!(db.get_file("test.rs").unwrap().is_none());
    for removed in [first, second] {
        assert!(db.get_node(removed).unwrap().is_none());
    }

    // ...and nothing is left in the aggregate counters, edges included.
    let totals = db.get_stats().unwrap();
    assert_eq!(totals.total_files, 0);
    assert_eq!(totals.total_nodes, 0);
    assert_eq!(totals.total_edges, 0);
}
/// A node inserted inside a transaction is visible after `commit`.
#[test]
fn test_transaction_commit() {
    let mut db = Database::in_memory().unwrap();
    db.insert_or_update_file(&mk_file("test.rs")).unwrap();

    db.begin_transaction().unwrap();
    let pending = create_test_node("func1", NodeKind::Function, "test.rs");
    db.insert_node(&pending).unwrap();
    db.commit().unwrap();

    assert_eq!(db.get_stats().unwrap().total_nodes, 1);
}
/// A node inserted inside a transaction is discarded by `rollback`.
#[test]
fn test_transaction_rollback() {
    let mut db = Database::in_memory().unwrap();
    db.insert_or_update_file(&mk_file("test.rs")).unwrap();

    db.begin_transaction().unwrap();
    let pending = create_test_node("func1", NodeKind::Function, "test.rs");
    db.insert_node(&pending).unwrap();
    db.rollback().unwrap();

    assert_eq!(db.get_stats().unwrap().total_nodes, 0);
}
/// `prepare_for_swap` must return the database's own path, keep the main file
/// on disk, and leave no `-wal` / `-shm` sidecar files next to it.
#[test]
fn test_prepare_for_swap_cleans_sidecars() {
    let tmp = tempdir().unwrap();
    let path = tmp.path().join("shadow.db");

    // Write some data first so the database has had content to persist.
    let db = Database::open(&path).unwrap();
    db.insert_or_update_file(&mk_file("shadow.rs")).unwrap();
    let node = create_test_node("shadow_fn", NodeKind::Function, "shadow.rs");
    db.insert_node(&node).unwrap();

    let prepared_path = db.prepare_for_swap().unwrap();

    assert_eq!(prepared_path, path);
    assert!(prepared_path.exists());

    let wal_sidecar = PathBuf::from(format!("{}-wal", prepared_path.display()));
    let shm_sidecar = PathBuf::from(format!("{}-shm", prepared_path.display()));
    assert!(!wal_sidecar.exists());
    assert!(!shm_sidecar.exists());
}
/// After `replace_with_shadow`, the live handle keeps its original path but
/// serves the shadow database's contents, and the shadow file no longer exists.
#[test]
fn test_replace_with_shadow_reopens_new_contents() {
    let tmp = tempdir().unwrap();
    let live_path = tmp.path().join("live.db");
    let shadow_path = tmp.path().join("shadow.db");

    // Live database: contains only the "old" file and function.
    let mut live = Database::open(&live_path).unwrap();
    live.insert_or_update_file(&mk_file("old.rs")).unwrap();
    let old_node = create_test_node("old_fn", NodeKind::Function, "old.rs");
    live.insert_node(&old_node).unwrap();

    // Shadow database: contains only the "new" file and function.
    let shadow = Database::open(&shadow_path).unwrap();
    shadow.insert_or_update_file(&mk_file("new.rs")).unwrap();
    let new_node = create_test_node("new_fn", NodeKind::Function, "new.rs");
    shadow.insert_node(&new_node).unwrap();

    let prepared_shadow = shadow.prepare_for_swap().unwrap();
    live.replace_with_shadow(&prepared_shadow).unwrap();

    assert_eq!(live.path(), Some(live_path.as_path()));

    // Old contents are gone, new contents are served.
    assert!(live.find_node_by_name("old_fn").unwrap().is_none());
    assert!(live.find_node_by_name("new_fn").unwrap().is_some());
    assert!(live.get_file("old.rs").unwrap().is_none());
    assert!(live.get_file("new.rs").unwrap().is_some());

    assert!(!prepared_shadow.exists());
}
/// `cleanup_on_disk_path` must delete the database file together with its
/// `-wal` and `-shm` sidecar files.
#[test]
fn test_cleanup_on_disk_path_removes_db_and_sidecars() {
    let tmp = tempdir().unwrap();
    let path = tmp.path().join("cleanup.db");
    let wal_name = format!("{}-wal", path.display());
    let shm_name = format!("{}-shm", path.display());

    // Plain placeholder files are enough; the test only checks file removal.
    std::fs::write(&path, b"db").unwrap();
    std::fs::write(&wal_name, b"wal").unwrap();
    std::fs::write(&shm_name, b"shm").unwrap();

    Database::cleanup_on_disk_path(&path).unwrap();

    assert!(!path.exists());
    assert!(!PathBuf::from(wal_name).exists());
    assert!(!PathBuf::from(shm_name).exists());
}
/// A node inserted through the FTS-less batch path is not found by
/// `semantic_search` until `rebuild_fts_indexes` has run.
#[test]
fn test_rebuild_fts_indexes_restores_search() {
    let db = Database::in_memory().unwrap();
    db.insert_or_update_file(&mk_file("auth.rs")).unwrap();

    let mut documented = create_test_node("validateToken", NodeKind::Function, "auth.rs");
    documented.docstring = Some("Validate JWT bearer token".to_string());
    db.insert_nodes_batch_without_fts(&mut [documented]).unwrap();

    // Before the rebuild the docstring terms yield no hits.
    let before = db.semantic_search("jwt bearer", 10).unwrap();
    assert!(before.is_empty());

    db.disable_fts_automerge().unwrap();
    db.rebuild_fts_indexes().unwrap();

    let after = db.semantic_search("jwt bearer", 10).unwrap();
    assert!(after.iter().any(|hit| hit.name == "validateToken"));
}
/// With a `Contains` edge MyClass -> my_method, `get_hierarchy` on the method
/// yields the class and `get_hierarchy` on the class yields the method.
#[test]
fn test_get_hierarchy() {
    let db = Database::in_memory().unwrap();
    db.insert_or_update_file(&mk_file("test.rs")).unwrap();

    let owner = create_test_node("MyClass", NodeKind::Class, "test.rs");
    let owner_id = db.insert_node(&owner).unwrap();
    let member = create_test_node("my_method", NodeKind::Method, "test.rs");
    let member_id = db.insert_node(&member).unwrap();

    db.insert_edge(&Edge {
        id: 0,
        source_id: owner_id,
        target_id: member_id,
        kind: EdgeKind::Contains,
        file_path: None,
        line: None,
        column: None,
        detail: None,
    })
    .unwrap();

    // Queried from the contained side.
    let from_member = db.get_hierarchy("my_method").unwrap();
    assert_eq!(from_member.len(), 1);
    assert_eq!(from_member[0].name, "MyClass");

    // Queried from the containing side.
    let from_owner = db.get_hierarchy("MyClass").unwrap();
    assert_eq!(from_owner.len(), 1);
    assert_eq!(from_owner[0].name, "my_method");
}
/// For the call chain a -> b -> c, `find_call_path("a", "c")` must return
/// exactly one path visiting a, b, c in that order.
#[test]
fn test_find_call_path() {
    let db = Database::in_memory().unwrap();
    db.insert_or_update_file(&mk_file("test.rs")).unwrap();

    let a_id = db
        .insert_node(&create_test_node("a", NodeKind::Function, "test.rs"))
        .unwrap();
    let b_id = db
        .insert_node(&create_test_node("b", NodeKind::Function, "test.rs"))
        .unwrap();
    let c_id = db
        .insert_node(&create_test_node("c", NodeKind::Function, "test.rs"))
        .unwrap();

    // Two `Calls` edges forming the chain, inserted in chain order.
    for (from, to) in [(a_id, b_id), (b_id, c_id)] {
        db.insert_edge(&Edge {
            id: 0,
            source_id: from,
            target_id: to,
            kind: EdgeKind::Calls,
            file_path: None,
            line: None,
            column: None,
            detail: None,
        })
        .unwrap();
    }

    let paths = db.find_call_path("a", "c").unwrap();
    assert_eq!(paths.len(), 1);

    // Comparing the name sequence checks both the length (3) and the order.
    let visited: Vec<&str> = paths[0].iter().map(|node| node.name.as_str()).collect();
    assert_eq!(visited, ["a", "b", "c"]);
}
/// Only nodes with no incoming call are reported as unused: here that is
/// `unused_func` and `caller` itself, while `used_func` is excluded.
#[test]
fn test_find_unused_symbols() {
    let db = Database::in_memory().unwrap();
    db.insert_or_update_file(&mk_file("test.rs")).unwrap();

    let used_id = db
        .insert_node(&create_test_node(
            "used_func",
            NodeKind::Function,
            "test.rs",
        ))
        .unwrap();
    // Never referenced by any edge; its id is not needed.
    db.insert_node(&create_test_node(
        "unused_func",
        NodeKind::Function,
        "test.rs",
    ))
    .unwrap();
    let caller_id = db
        .insert_node(&create_test_node("caller", NodeKind::Function, "test.rs"))
        .unwrap();

    db.insert_edge(&Edge {
        id: 0,
        source_id: caller_id,
        target_id: used_id,
        kind: EdgeKind::Calls,
        file_path: None,
        line: None,
        column: None,
        detail: None,
    })
    .unwrap();

    let unused = db.find_unused_symbols().unwrap();
    let names: Vec<&str> = unused.iter().map(|node| node.name.as_str()).collect();

    assert_eq!(names.len(), 2);
    assert!(names.contains(&"unused_func"));
    assert!(names.contains(&"caller"));
}
/// Two structs linked to an interface by `Implements` edges must both be
/// returned by `find_implementations` for that interface's name.
#[test]
fn test_find_implementations() {
    let db = Database::in_memory().unwrap();
    db.insert_or_update_file(&mk_file("test.rs")).unwrap();

    let trait_id = db
        .insert_node(&create_test_node("MyTrait", NodeKind::Interface, "test.rs"))
        .unwrap();
    let first_impl = db
        .insert_node(&create_test_node("Impl1", NodeKind::Struct, "test.rs"))
        .unwrap();
    let second_impl = db
        .insert_node(&create_test_node("Impl2", NodeKind::Struct, "test.rs"))
        .unwrap();

    // Edge direction is implementor -> interface.
    for implementor in [first_impl, second_impl] {
        db.insert_edge(&Edge {
            id: 0,
            source_id: implementor,
            target_id: trait_id,
            kind: EdgeKind::Implements,
            file_path: None,
            line: None,
            column: None,
            detail: None,
        })
        .unwrap();
    }

    let found = db.find_implementations("MyTrait").unwrap();
    let names: Vec<&str> = found.iter().map(|node| node.name.as_str()).collect();

    assert_eq!(names.len(), 2);
    assert!(names.contains(&"Impl1"));
    assert!(names.contains(&"Impl2"));
}
/// A diff covering lines 15-18 of `test.rs` must report both the function that
/// spans lines 10-20 and the function that calls it.
#[test]
fn test_get_diff_impact() {
    let db = Database::in_memory().unwrap();
    db.insert_or_update_file(&mk_file("test.rs")).unwrap();

    // The touched function: its line range encloses the diff range used below.
    let mut touched = create_test_node("affected_func", NodeKind::Function, "test.rs");
    touched.start_line = 10;
    touched.end_line = 20;
    let touched_id = db.insert_node(&touched).unwrap();

    // The caller keeps whatever line range the fixture assigns.
    let dependent = create_test_node("caller_func", NodeKind::Function, "test.rs");
    let dependent_id = db.insert_node(&dependent).unwrap();

    db.insert_edge(&Edge {
        id: 0,
        source_id: dependent_id,
        target_id: touched_id,
        kind: EdgeKind::Calls,
        file_path: None,
        line: None,
        column: None,
        detail: None,
    })
    .unwrap();

    let impacted = db.get_diff_impact("test.rs", 15, 18).unwrap();
    let names: Vec<&str> = impacted.iter().map(|node| node.name.as_str()).collect();

    assert_eq!(names.len(), 2);
    assert!(names.contains(&"affected_func"));
    assert!(names.contains(&"caller_func"));
}
}
#[cfg(test)]
mod language_tests {
    use super::*;
    use crate::types::FileRecord;

    /// A node stored with `Language::Rust` and `Visibility::Private` must come
    /// back from the database with both enum values intact.
    #[test]
    fn test_language_roundtrip() {
        let db = Database::in_memory().unwrap();

        // The owning file record is written before the node that refers to it.
        db.insert_or_update_file(&FileRecord {
            path: "test.rs".to_string(),
            content_hash: "abc123".to_string(),
            language: Language::Rust,
            size: 100,
            modified_at: 0,
            indexed_at: 0,
            node_count: 1,
        })
        .unwrap();

        let original = Node {
            id: 0,
            kind: NodeKind::Function,
            name: "test_func".to_string(),
            qualified_name: None,
            file_path: "test.rs".to_string(),
            start_line: 1,
            end_line: 10,
            start_column: 0,
            end_column: 0,
            signature: Some("fn test_func()".to_string()),
            visibility: Visibility::Private,
            docstring: None,
            is_async: false,
            is_static: false,
            is_exported: false,
            is_test: false,
            is_generated: false,
            language: Language::Rust,
        };
        db.insert_node(&original).unwrap();

        let stored = db.find_node_by_name("test_func").unwrap().unwrap();

        assert_eq!(
            stored.language,
            Language::Rust,
            "Language should be Rust, got {:?}",
            stored.language
        );
        assert_eq!(
            stored.visibility,
            Visibility::Private,
            "Visibility should be Private, got {:?}",
            stored.visibility
        );
    }
}
#[cfg(test)]
mod language_tests_2 {
    use super::*;
    use crate::types::{FileRecord, NodeKind, UnresolvedReference, Visibility};

    /// Builds a minimal Rust file record for `path` with placeholder metadata.
    fn mk_file(path: &str) -> FileRecord {
        FileRecord {
            path: path.to_owned(),
            content_hash: String::from("h"),
            language: Language::Rust,
            size: 0,
            modified_at: 0,
            indexed_at: 0,
            node_count: 0,
        }
    }

    /// Builds a public, single-line Rust function node named `name` in
    /// `file_path`, with every flag cleared and no signature or docstring.
    fn mk_node(name: &str, file_path: &str) -> Node {
        Node {
            id: 0,
            kind: NodeKind::Function,
            name: name.to_owned(),
            qualified_name: None,
            file_path: file_path.to_owned(),
            start_line: 1,
            end_line: 1,
            start_column: 0,
            end_column: 0,
            signature: None,
            visibility: Visibility::Public,
            docstring: None,
            is_async: false,
            is_static: false,
            is_exported: false,
            is_test: false,
            is_generated: false,
            language: Language::Rust,
        }
    }

    /// Terms that occur only in a node's docstring must still lead
    /// `semantic_search` to that node.
    #[test]
    fn test_semantic_search_by_docstring() {
        let db = Database::in_memory().unwrap();
        db.insert_or_update_file(&mk_file("auth.rs")).unwrap();

        let mut documented = mk_node("validate_token", "auth.rs");
        documented.docstring =
            Some("Verifies a JWT bearer token against the signing key".to_string());
        db.insert_node(&documented).unwrap();

        // A second, unrelated node without a docstring.
        db.insert_node(&mk_node("calculate_total", "auth.rs"))
            .unwrap();

        let hits = db.semantic_search("jwt bearer", 10).unwrap();
        let found = hits.iter().any(|hit| hit.name == "validate_token");
        assert!(found, "semantic search should find by docstring");
    }

    /// A single word taken from a camelCase identifier must match that identifier.
    #[test]
    fn test_semantic_search_by_camel_case_split() {
        let db = Database::in_memory().unwrap();
        db.insert_or_update_file(&mk_file("svc.rs")).unwrap();
        db.insert_node(&mk_node("renderUserDashboard", "svc.rs"))
            .unwrap();

        let hits = db.semantic_search("dashboard", 10).unwrap();

        assert!(hits.iter().any(|hit| hit.name == "renderUserDashboard"));
    }

    /// When the referenced name exists both in the referencing file and in
    /// another file, resolution must pick the node from the same file.
    #[test]
    fn test_resolve_prefers_same_file() {
        let db = Database::in_memory().unwrap();
        for path in ["a.rs", "b.rs"] {
            db.insert_or_update_file(&mk_file(path)).unwrap();
        }

        let caller_id = db.insert_node(&mk_node("caller", "a.rs")).unwrap();
        let local_id = db.insert_node(&mk_node("helper", "a.rs")).unwrap();
        // Same name in a different file; it must not be chosen as the target.
        db.insert_node(&mk_node("helper", "b.rs")).unwrap();

        let pending = UnresolvedReference {
            source_node_id: caller_id,
            reference_name: "helper".to_string(),
            kind: EdgeKind::Calls,
            file_path: "a.rs".to_string(),
            line: 1,
            column: 0,
            detail: None,
        };
        db.insert_unresolved_ref(&pending).unwrap();
        db.resolve_references().unwrap();

        let edges = db.get_outgoing_edges(caller_id).unwrap();
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].target_id, local_id);
    }

    /// Nodes flagged `is_test` or `is_generated` must never be listed as
    /// unused symbols, even though nothing references them.
    #[test]
    fn test_unused_excludes_tests_and_generated() {
        let db = Database::in_memory().unwrap();
        db.insert_or_update_file(&mk_file("x.rs")).unwrap();

        let mut test_node = mk_node("test_thing", "x.rs");
        test_node.is_test = true;
        db.insert_node(&test_node).unwrap();

        let mut generated_node = mk_node("generated_thing", "x.rs");
        generated_node.is_generated = true;
        db.insert_node(&generated_node).unwrap();

        let unused = db.find_unused_symbols().unwrap();

        assert!(!unused.iter().any(|node| node.name == "test_thing"));
        assert!(!unused.iter().any(|node| node.name == "generated_thing"));
    }
}