use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use crate::{EmbeddedDatabase, Error, Result, Tuple, Value};
use super::embed::{Embedder, NoopEmbedder};
use super::parse::{self, Language};
use super::resolver::{resolve_in_file, Resolution};
use super::symbols::{extract, Symbol};
/// Languages exposed through this module's public API.
///
/// Every variant converts one-to-one to and from the internal parser
/// `Language` type via the `From` impls defined in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SupportedLanguage {
/// Rust source.
Rust,
/// Python source.
Python,
/// TypeScript source.
TypeScript,
/// TSX source (kept distinct from `TypeScript`).
Tsx,
/// JavaScript source.
JavaScript,
/// Go source.
Go,
/// Markdown documents.
Markdown,
/// SQL scripts.
Sql,
}
impl SupportedLanguage {
pub fn as_str(self) -> &'static str {
Language::from(self).as_str()
}
pub fn all() -> &'static [SupportedLanguage] {
&[
SupportedLanguage::Rust,
SupportedLanguage::Python,
SupportedLanguage::TypeScript,
SupportedLanguage::Tsx,
SupportedLanguage::JavaScript,
SupportedLanguage::Go,
SupportedLanguage::Markdown,
SupportedLanguage::Sql,
]
}
}
impl From<SupportedLanguage> for Language {
fn from(s: SupportedLanguage) -> Self {
match s {
SupportedLanguage::Rust => Language::Rust,
SupportedLanguage::Python => Language::Python,
SupportedLanguage::TypeScript => Language::TypeScript,
SupportedLanguage::Tsx => Language::Tsx,
SupportedLanguage::JavaScript => Language::JavaScript,
SupportedLanguage::Go => Language::Go,
SupportedLanguage::Markdown => Language::Markdown,
SupportedLanguage::Sql => Language::Sql,
}
}
}
/// Maps each parser-level `Language` variant onto the public
/// `SupportedLanguage` variant of the same name.
impl From<Language> for SupportedLanguage {
    fn from(l: Language) -> Self {
        match l {
            Language::Rust => Self::Rust,
            Language::Python => Self::Python,
            Language::TypeScript => Self::TypeScript,
            Language::Tsx => Self::Tsx,
            Language::JavaScript => Self::JavaScript,
            Language::Go => Self::Go,
            Language::Markdown => Self::Markdown,
            Language::Sql => Self::Sql,
        }
    }
}
/// Options controlling a single code-index run.
#[derive(Debug, Clone)]
pub struct CodeIndexOptions {
    /// Name of the table whose `path` / `lang` / `content` columns supply
    /// the source files to index.
    pub source_table: String,
    /// When `true` (and `embed_endpoint` is set) `code_index` uses an HTTP
    /// embedder; otherwise a no-op embedder is used.
    pub embed_bodies: bool,
    /// URL handed to the HTTP embedder.
    pub embed_endpoint: Option<String>,
    /// Optional bearer token handed to the HTTP embedder.
    pub embed_bearer: Option<String>,
    /// Re-parse files even when their SHA-256 matches the stored one.
    pub force_reparse: bool,
    /// Explicit parse worker count; `None` derives it from the host.
    pub parallelism: Option<usize>,
    /// Maximum number of files per parse/write chunk; `None` processes all
    /// files as one chunk.
    pub chunk_size: Option<usize>,
}

impl CodeIndexOptions {
    /// Builds options for `name` with embedding and forced re-parse disabled
    /// and both `parallelism` and `chunk_size` left unset.
    pub fn for_table(name: impl Into<String>) -> Self {
        let source_table = name.into();
        Self {
            source_table,
            embed_bodies: false,
            embed_endpoint: None,
            embed_bearer: None,
            force_reparse: false,
            parallelism: None,
            chunk_size: None,
        }
    }

    /// Returns the number of parse workers to use.
    ///
    /// An explicit `parallelism` is honoured but raised to at least 1.
    /// Otherwise the host's available parallelism is used, limited to the
    /// range 1..=8 (falling back to 1 when it cannot be determined).
    pub(crate) fn resolved_parallelism(&self) -> usize {
        match self.parallelism {
            Some(requested) => requested.max(1),
            None => std::thread::available_parallelism()
                .map_or(1, |cores| cores.get())
                .clamp(1, 8),
        }
    }
}
/// Counters and timings reported by a code-index run.
#[derive(Debug, Clone, Default)]
pub struct CodeIndexStats {
/// Number of rows read from the source table.
pub files_seen: u64,
/// Number of files whose symbols/refs were written.
pub files_parsed: u64,
/// Files skipped because no parser/extractor matched their language.
pub files_skipped: u64,
/// Files skipped because their SHA-256 matched the stored one.
pub files_unchanged: u64,
/// Number of symbol rows written.
pub symbols_written: u64,
/// Number of reference rows written.
pub refs_written: u64,
/// Number of embedder calls that returned a vector.
pub embed_calls: u64,
/// Lower-cased language names of the files that had a parser, sorted.
pub languages_seen: Vec<String>,
/// Accumulated wall-clock time spent parsing, in milliseconds.
pub parse_elapsed_ms: u64,
/// Accumulated wall-clock time spent writing (including the cross-file
/// resolution pass), in milliseconds.
pub write_elapsed_ms: u64,
/// Worker count resolved for the parse thread pool.
pub parse_workers: u32,
/// Number of chunks that went through parse + write.
pub chunks_processed: u32,
}
/// Process-wide flag recording that the code-index extension has been set up.
/// It starts `false` and nothing in this block ever resets it.
static EXTENSION_INSTALLED: AtomicBool = AtomicBool::new(false);

/// Marks the extension as installed for the rest of the process lifetime.
pub fn mark_extension_installed() {
    let flag = &EXTENSION_INSTALLED;
    flag.store(true, Ordering::Relaxed);
}

/// Reports whether [`mark_extension_installed`] has been called.
pub fn is_extension_installed() -> bool {
    let flag = &EXTENSION_INSTALLED;
    flag.load(Ordering::Relaxed)
}
/// Metadata describing one registered AST index.
///
/// Entries are stored in the process-wide registry keyed by `index_name`.
#[derive(Debug, Clone)]
pub struct AstIndexMeta {
/// Registry key; registering a second entry with the same name replaces
/// the first.
pub index_name: String,
/// Table the index is attached to; matched by `ast_indexes_for_table`.
pub table: String,
/// NOTE(review): presumably the column holding file contents — confirm
/// against the code that consumes this struct.
pub content_col: String,
/// NOTE(review): presumably the optional column holding the language
/// name — confirm against the consumer.
pub lang_col: Option<String>,
/// NOTE(review): looks like the embedding service URL — confirm.
pub embed_endpoint: Option<String>,
/// NOTE(review): looks like the bearer token for the embedding service —
/// confirm.
pub embed_bearer: Option<String>,
/// NOTE(review): presumably mirrors `CodeIndexOptions::embed_bodies` —
/// confirm.
pub embed_bodies: bool,
/// NOTE(review): semantics not visible in this file — confirm.
pub auto_reparse: bool,
/// NOTE(review): semantics not visible in this file — confirm.
pub resolve_cross_file: bool,
/// Paused indexes are excluded from `ast_indexes_for_table`.
pub paused: bool,
}
/// Lazily-initialised, process-wide map of AST index name to metadata.
static AST_INDEXES: std::sync::OnceLock<std::sync::RwLock<HashMap<String, AstIndexMeta>>> =
    std::sync::OnceLock::new();

/// Returns the shared registry, creating an empty one on first use.
fn ast_registry() -> &'static std::sync::RwLock<HashMap<String, AstIndexMeta>> {
    AST_INDEXES.get_or_init(Default::default)
}
/// Stores `meta` in the registry under its `index_name`, replacing any entry
/// already registered under that name.
///
/// A poisoned lock is recovered rather than propagated as a panic.
pub fn register_ast_index(meta: AstIndexMeta) {
    let mut registry = match ast_registry().write() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    let key = meta.index_name.clone();
    registry.insert(key, meta);
}
/// Returns a clone of the metadata registered under `name`, if any.
///
/// A poisoned lock is recovered rather than propagated as a panic.
pub fn get_ast_index(name: &str) -> Option<AstIndexMeta> {
    let registry = match ast_registry().read() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    registry.get(name).cloned()
}
/// Returns clones of every non-paused index registered for `table_name`.
///
/// The order follows the registry's hash-map iteration order and is therefore
/// unspecified.
pub fn ast_indexes_for_table(table_name: &str) -> Vec<AstIndexMeta> {
    let registry = match ast_registry().read() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    let mut active = Vec::new();
    for meta in registry.values() {
        if meta.table == table_name && !meta.paused {
            active.push(meta.clone());
        }
    }
    active
}
/// Sets the `paused` flag of the index registered under `name`.
///
/// Returns `true` when the index exists, `false` when no index with that name
/// is registered.
pub fn set_ast_index_paused(name: &str, paused: bool) -> bool {
    let mut registry = match ast_registry().write() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    if let Some(meta) = registry.get_mut(name) {
        meta.paused = paused;
        true
    } else {
        false
    }
}
/// Runs a code-index pass, picking the embedder from `opts`.
///
/// An HTTP embedder (with the optional bearer token) is used only when both
/// `embed_bodies` is set and `embed_endpoint` is present; in every other case
/// a no-op embedder is used. The work itself is done by
/// [`code_index_with_embedder`].
pub fn code_index(db: &EmbeddedDatabase, opts: CodeIndexOptions) -> Result<CodeIndexStats> {
    let endpoint = if opts.embed_bodies {
        opts.embed_endpoint.as_deref()
    } else {
        None
    };
    let embedder: Box<dyn Embedder> = match endpoint {
        Some(url) => {
            let http = super::embed::HttpEmbedder::new(url);
            let http = match &opts.embed_bearer {
                Some(token) => http.with_bearer(token.clone()),
                None => http,
            };
            Box::new(http)
        }
        None => Box::new(NoopEmbedder),
    };
    code_index_with_embedder(db, opts, embedder)
}
pub fn code_index_with_embedder(
db: &EmbeddedDatabase,
opts: CodeIndexOptions,
embedder: Box<dyn Embedder>,
) -> Result<CodeIndexStats> {
bootstrap_tables(db)?;
mark_extension_installed();
let files = fetch_source_files(db, &opts.source_table)?;
let mut existing_sha = fetch_file_sha_map(db)?;
let mut stats = CodeIndexStats::default();
let mut lang_set: std::collections::BTreeSet<String> = std::collections::BTreeSet::new();
let truncated = opts.force_reparse && !existing_sha.is_empty();
if truncated {
for tbl in [
"_hdb_code_symbol_refs",
"_hdb_code_symbols",
"_hdb_code_ast_nodes",
"_hdb_code_files",
] {
let _ = db.execute(&format!("TRUNCATE {tbl}"));
}
existing_sha.clear();
tracing::debug!("force_reparse + populated KB: truncated _hdb_code_* tables");
}
let mut to_parse: Vec<(SourceFile, String, ParseExtractPair)> = Vec::new();
for file in files.into_iter() {
stats.files_seen += 1;
let resolution = match resolve_parser_and_extractor(&file.lang) {
Some(r) => r,
None => {
stats.files_skipped += 1;
continue;
}
};
lang_set.insert(file.lang.to_ascii_lowercase());
let sha = sha256_hex(&file.content);
let unchanged = existing_sha
.get(&file.path)
.map(|s| s == &sha)
.unwrap_or(false);
if unchanged && !opts.force_reparse {
stats.files_unchanged += 1;
continue;
}
to_parse.push((file, sha, resolution));
}
let touched = !to_parse.is_empty();
let workers = opts.resolved_parallelism();
stats.parse_workers = workers as u32;
let chunks: Vec<Vec<(SourceFile, String, ParseExtractPair)>> = match opts.chunk_size {
None => {
if to_parse.is_empty() {
Vec::new()
} else {
vec![to_parse]
}
}
Some(n) => {
let n = n.max(1);
let mut out = Vec::with_capacity(to_parse.len().div_ceil(n));
let mut iter = to_parse.into_iter();
loop {
let chunk: Vec<_> = (&mut iter).take(n).collect();
if chunk.is_empty() {
break;
}
out.push(chunk);
}
out
}
};
let pool = if chunks.is_empty() {
None
} else {
Some(
rayon::ThreadPoolBuilder::new()
.num_threads(workers)
.thread_name(|i| format!("hdb-code-index-{i}"))
.build()
.map_err(|e| Error::query_execution(format!(
"failed to build code-index thread pool ({e})"
)))?,
)
};
let manage_txn = !db.in_transaction();
let mut processed_paths: std::collections::HashSet<String> =
std::collections::HashSet::new();
for chunk in chunks {
let parse_started = std::time::Instant::now();
let parsed: Vec<Result<ParsedFile>> = if let Some(pool) = &pool {
use rayon::prelude::*;
pool.install(|| {
chunk
.into_par_iter()
.map(parse_extract_one)
.collect::<Vec<_>>()
})
} else {
chunk.into_iter().map(parse_extract_one).collect::<Vec<_>>()
};
stats.parse_elapsed_ms += parse_started.elapsed().as_millis() as u64;
stats.chunks_processed += 1;
let write_started = std::time::Instant::now();
if manage_txn {
db.begin()?;
}
let chunk_result = drain_chunk(
db,
&opts,
embedder.as_ref(),
&mut stats,
parsed,
truncated,
&mut processed_paths,
);
if manage_txn {
match chunk_result {
Ok(()) => db.commit()?,
Err(e) => {
let _ = db.rollback();
return Err(e);
}
}
} else {
chunk_result?;
}
stats.write_elapsed_ms += write_started.elapsed().as_millis() as u64;
}
stats.languages_seen = lang_set.into_iter().collect();
if touched {
let cross_started = std::time::Instant::now();
if manage_txn {
db.begin()?;
}
let cross_result = cross_file_resolve(db, &mut stats);
if manage_txn {
match cross_result {
Ok(()) => db.commit()?,
Err(e) => {
let _ = db.rollback();
return Err(e);
}
}
} else {
cross_result?;
}
stats.write_elapsed_ms += cross_started.elapsed().as_millis() as u64;
}
Ok(stats)
}
/// Writes one chunk of parse results to the database.
///
/// Any parse error in the chunk aborts the whole chunk before anything is
/// written. Otherwise each file is upserted, its stale symbols/refs are
/// removed (unless the tables were truncated during this call and the path is
/// being handled for the first time), and the symbols and refs of all files
/// are inserted in bulk. `stats` is updated only after all inserts succeeded.
fn drain_chunk(
    db: &EmbeddedDatabase,
    opts: &CodeIndexOptions,
    embedder: &dyn Embedder,
    stats: &mut CodeIndexStats,
    parsed: Vec<Result<ParsedFile>>,
    truncated_this_call: bool,
    processed_paths: &mut std::collections::HashSet<String>,
) -> Result<()> {
    let parsed_files: Vec<ParsedFile> = parsed.into_iter().collect::<Result<Vec<_>>>()?;
    if parsed_files.is_empty() {
        return Ok(());
    }

    // Upsert the file rows and clear out rows from a previous indexing run.
    let mut prepared: Vec<(i64, ParsedFile)> = Vec::with_capacity(parsed_files.len());
    for parsed_file in parsed_files {
        let seen_before = !processed_paths.insert(parsed_file.file.path.clone());
        let file_id = upsert_file(db, &opts.source_table, &parsed_file.file, &parsed_file.sha)?;
        if seen_before || !truncated_this_call {
            delete_stale_symbols_and_refs(db, file_id)?;
        }
        prepared.push((file_id, parsed_file));
    }

    // Flatten the symbols of every file into one batch, remembering which
    // slice of the batch belongs to which file and which file owns each symbol.
    let symbol_total: usize = prepared.iter().map(|(_, p)| p.symbols.len()).sum();
    let mut all_symbols: Vec<&Symbol> = Vec::with_capacity(symbol_total);
    let mut sym_ranges: Vec<(usize, usize)> = Vec::with_capacity(prepared.len());
    let mut sym_owner_file_ids: Vec<i64> = Vec::with_capacity(symbol_total);
    for (file_id, parsed_file) in &prepared {
        let count = parsed_file.symbols.len();
        sym_ranges.push((all_symbols.len(), count));
        all_symbols.extend(parsed_file.symbols.iter());
        sym_owner_file_ids.resize(sym_owner_file_ids.len() + count, *file_id);
    }

    let symbol_ids: Vec<i64> = if all_symbols.is_empty() {
        Vec::new()
    } else {
        bulk_insert_symbols_batched(db, embedder, stats, &sym_owner_file_ids, &all_symbols)?
    };

    let ref_total: usize = prepared.iter().map(|(_, p)| p.resolved.len()).sum();
    if ref_total > 0 {
        stats.refs_written += bulk_insert_refs_batched(db, &prepared, &symbol_ids, &sym_ranges)?;
    }

    stats.files_parsed += prepared.len() as u64;
    stats.symbols_written += symbol_total as u64;
    Ok(())
}
/// Removes every symbol and reference row owned by `file_id`.
///
/// References (from any file) that point at one of this file's symbols are
/// first detached — `to_symbol` set to NULL and `resolution` reset to
/// `'unresolved'` — so they do not keep pointing at deleted rows.
fn delete_stale_symbols_and_refs(db: &EmbeddedDatabase, file_id: i64) -> Result<()> {
    let select_sql = format!("SELECT node_id FROM _hdb_code_symbols WHERE file_id = {file_id}");
    let sym_rows = db.query(&select_sql, &[])?;

    // Collect the ids as text, ready to be joined into an IN (...) list.
    let mut stale_ids: Vec<String> = Vec::new();
    for row in sym_rows.iter() {
        match row.values.first() {
            Some(Value::Int4(n)) => stale_ids.push((*n as i64).to_string()),
            Some(Value::Int8(n)) => stale_ids.push(n.to_string()),
            _ => {}
        }
    }

    if !stale_ids.is_empty() {
        let csv = stale_ids.join(",");
        db.execute(&format!(
            "UPDATE _hdb_code_symbol_refs \
             SET to_symbol = NULL, resolution = 'unresolved' \
             WHERE to_symbol IN ({csv})"
        ))?;
    }

    // Refs go first: they reference the symbol rows deleted next.
    let delete_refs = format!("DELETE FROM _hdb_code_symbol_refs WHERE file_id = {file_id}");
    db.execute(&delete_refs)?;
    let delete_symbols = format!("DELETE FROM _hdb_code_symbols WHERE file_id = {file_id}");
    db.execute(&delete_symbols)?;
    Ok(())
}
/// Inserts `symbols` into `_hdb_code_symbols` in one bulk call and returns the
/// ids of the new rows, in input order.
///
/// `file_id_per_symbol[i]` is the owning file of `symbols[i]`; both slices
/// must have the same length. Every symbol with a non-empty signature is sent
/// to `embedder`; when at least one vector comes back, the `body_vec` column
/// is ensured and all vectors must share the first vector's dimension.
///
/// # Errors
/// Fails on embedder errors, a zero-length or mismatching vector, a schema
/// with fewer columns than the positional layout written here, or a database
/// error.
fn bulk_insert_symbols_batched(
    db: &EmbeddedDatabase,
    embedder: &dyn Embedder,
    stats: &mut CodeIndexStats,
    file_id_per_symbol: &[i64],
    symbols: &[&Symbol],
) -> Result<Vec<i64>> {
    debug_assert_eq!(file_id_per_symbol.len(), symbols.len());

    // One embedding slot per symbol; `None` when there is nothing to embed
    // or the embedder declined.
    let mut vectors: Vec<Option<Vec<f32>>> = Vec::with_capacity(symbols.len());
    for sym in symbols {
        let embedded = if sym.signature.is_empty() {
            None
        } else {
            embedder.embed(&sym.signature)?
        };
        if embedded.is_some() {
            stats.embed_calls += 1;
        }
        vectors.push(embedded);
    }

    // The first vector fixes the expected dimension.
    let first_dim = vectors.iter().flatten().map(|v| v.len()).next();
    let any_vec = first_dim.is_some();
    if let Some(dim) = first_dim {
        if dim == 0 {
            return Err(Error::query_execution(
                "embedder returned a zero-length vector",
            ));
        }
        ensure_body_vec_column(db, dim)?;
        if let Some(odd) = vectors.iter().flatten().find(|v| v.len() != dim) {
            return Err(Error::query_execution(format!(
                "embedder dimension mismatch: expected {dim}, got {}",
                odd.len()
            )));
        }
    }

    let n_cols = db
        .storage
        .catalog()
        .get_table_schema("_hdb_code_symbols")?
        .columns
        .len();
    let expected_min_cols = if any_vec { 13 } else { 12 };
    if n_cols < expected_min_cols {
        return Err(Error::query_execution(format!(
            "_hdb_code_symbols schema has {} cols, fast path expects ≥ {}",
            n_cols, expected_min_cols
        )));
    }

    let mut tuples: Vec<Tuple> = Vec::with_capacity(symbols.len());
    for (idx, (sym, vector)) in symbols.iter().zip(vectors).enumerate() {
        let mut values: Vec<Value> = Vec::with_capacity(n_cols);
        values.extend([
            Value::Null, // node_id: left to the database
            Value::Int8(file_id_per_symbol[idx]),
            Value::String(sym.name.clone()),
            Value::String(sym.qualified.clone()),
            Value::String(sym.kind.as_str().to_string()),
            Value::String(sym.signature.clone()),
            Value::String(sym.visibility.as_str().to_string()),
            Value::Int4(sym.line_start as i32),
            Value::Int4(sym.line_end as i32),
            Value::Int4(sym.byte_start as i32),
            Value::Int4(sym.byte_end as i32),
            Value::Null, // parent_id: not populated here
        ]);
        // The column-count check above guarantees a 13th column whenever any
        // vector exists.
        if any_vec {
            values.push(vector.map_or(Value::Null, Value::Vector));
        }
        // Pad any further columns with NULL.
        let padding = n_cols.saturating_sub(values.len());
        values.extend(std::iter::repeat_with(|| Value::Null).take(padding));
        tuples.push(Tuple::new(values));
    }

    let row_ids = db.bulk_insert_tuples("_hdb_code_symbols", tuples)?;
    Ok(row_ids.into_iter().map(|id| id as i64).collect())
}
/// Inserts the resolved references of every prepared file into
/// `_hdb_code_symbol_refs` in one bulk call and returns how many rows were
/// written.
///
/// `sym_ranges[i]` is the `(start, count)` window of `symbol_ids` holding the
/// row ids of `prepared[i]`'s symbols; reference indices are local to that
/// window. A `to_idx` outside the window is stored as NULL.
///
/// # Errors
/// Fails when the table has fewer than 8 columns, when a file's id window is
/// missing or out of bounds, when a reference's `from_idx` is invalid, or on a
/// database error.
fn bulk_insert_refs_batched(
    db: &EmbeddedDatabase,
    prepared: &[(i64, ParsedFile)],
    symbol_ids: &[i64],
    sym_ranges: &[(usize, usize)],
) -> Result<u64> {
    let schema = db.storage.catalog().get_table_schema("_hdb_code_symbol_refs")?;
    let n_cols = schema.columns.len();
    if n_cols < 8 {
        return Err(Error::query_execution(format!(
            "_hdb_code_symbol_refs schema has {} cols, fast path expects ≥ 8",
            n_cols
        )));
    }
    let total: usize = prepared.iter().map(|(_, p)| p.resolved.len()).sum();
    let mut tuples: Vec<Tuple> = Vec::with_capacity(total);
    for (i, (file_id, p)) in prepared.iter().enumerate() {
        // Checked lookups: inconsistent bookkeeping becomes an error instead
        // of an index/slice panic.
        let file_symbol_ids = sym_ranges
            .get(i)
            .and_then(|&(start, count)| symbol_ids.get(start..start.checked_add(count)?))
            .ok_or_else(|| {
                Error::query_execution(format!(
                    "symbol id range missing or out of bounds for file_id {file_id}"
                ))
            })?;
        for r in &p.resolved {
            let from_id = file_symbol_ids.get(r.from_idx).copied().ok_or_else(|| {
                Error::query_execution(format!(
                    "resolver produced invalid from_idx {} for file_id {}",
                    r.from_idx, file_id
                ))
            })?;
            let to_val = match r.to_idx {
                Some(idx) => file_symbol_ids
                    .get(idx)
                    .map(|id| Value::Int8(*id))
                    .unwrap_or(Value::Null),
                None => Value::Null,
            };
            let res = match r.resolution {
                Resolution::Exact => "exact",
                Resolution::Heuristic => "heuristic",
                Resolution::Unresolved => "unresolved",
            };
            let mut values: Vec<Value> = Vec::with_capacity(n_cols);
            values.push(Value::Null); // edge_id: left to the database
            values.push(Value::Int8(*file_id));
            values.push(Value::Int8(from_id));
            values.push(to_val);
            values.push(Value::String(r.to_name.clone()));
            values.push(Value::String(r.kind_str.to_string()));
            values.push(Value::Int4(r.line as i32));
            values.push(Value::String(res.to_string()));
            // Pad any further columns with NULL.
            while values.len() < n_cols {
                values.push(Value::Null);
            }
            tuples.push(Tuple::new(values));
        }
    }
    let written = tuples.len() as u64;
    if !tuples.is_empty() {
        db.bulk_insert_tuples("_hdb_code_symbol_refs", tuples)?;
    }
    Ok(written)
}
fn write_one_parsed(
db: &EmbeddedDatabase,
opts: &CodeIndexOptions,
embedder: &dyn Embedder,
stats: &mut CodeIndexStats,
parsed: ParsedFile,
skip_delete_stale: bool,
) -> Result<()> {
let ParsedFile { file, sha, symbols, resolved } = parsed;
let file_id = upsert_file(db, &opts.source_table, &file, &sha)?;
if !skip_delete_stale {
let sym_rows = db.query(
&format!("SELECT node_id FROM _hdb_code_symbols WHERE file_id = {file_id}"),
&[],
)?;
let stale_ids: Vec<i64> = sym_rows
.iter()
.filter_map(|r| match r.values.first() {
Some(Value::Int4(n)) => Some(*n as i64),
Some(Value::Int8(n)) => Some(*n),
_ => None,
})
.collect();
if !stale_ids.is_empty() {
let csv = stale_ids
.iter()
.map(|i| i.to_string())
.collect::<Vec<_>>()
.join(",");
db.execute(&format!(
"UPDATE _hdb_code_symbol_refs \
SET to_symbol = NULL, resolution = 'unresolved' \
WHERE to_symbol IN ({csv})"
))?;
}
db.execute(&format!(
"DELETE FROM _hdb_code_symbol_refs WHERE file_id = {file_id}"
))?;
db.execute(&format!(
"DELETE FROM _hdb_code_symbols WHERE file_id = {file_id}"
))?;
}
let symbol_ids = insert_symbols(db, file_id, &symbols, embedder, stats)?;
let refs_written = insert_refs(db, file_id, &symbol_ids, &resolved)?;
stats.files_parsed += 1;
stats.symbols_written += symbols.len() as u64;
stats.refs_written += refs_written;
Ok(())
}
/// Result of parsing one source file, ready to be written to the database.
struct ParsedFile {
/// The source row this result was produced from.
file: SourceFile,
/// Hex SHA-256 of the file content, stored alongside the file row.
sha: String,
/// Symbols extracted from the file.
symbols: Vec<Symbol>,
/// References after in-file resolution and the local-type/import rebinds.
resolved: Vec<super::resolver::ResolvedRef>,
}
/// Parses one file, extracts its symbols and references, and resolves the
/// references within the file.
///
/// Built-in languages use the static parser/extractor; dynamically registered
/// grammars are parsed by language name and handed to their registered
/// extractor. After in-file resolution, references are rebound via local
/// types (using each symbol's source text) and via imports.
///
/// # Errors
/// Returns the parser's error when the file cannot be parsed.
fn parse_extract_one(
    item: (SourceFile, String, ParseExtractPair),
) -> Result<ParsedFile> {
    let (file, sha, resolution) = item;
    let (symbols, refs) = match &resolution {
        ParseExtractPair::Static(lang) => {
            let tree = parse::parse(*lang, &file.content)?;
            extract(*lang, &file.content, &tree)
        }
        ParseExtractPair::Dynamic { extractor } => {
            let tree = parse::parse_by_name(&file.lang, &file.content)?;
            extractor.extract(&file.content, &tree)
        }
    };
    let mut resolved = resolve_in_file(&symbols, &refs);
    let bodies: Vec<super::resolver::FunctionBody<'_>> = symbols
        .iter()
        .filter_map(|s| {
            // Clamp offsets to the content length and drop empty spans.
            let lo = (s.byte_start as usize).min(file.content.len());
            let hi = (s.byte_end as usize).min(file.content.len());
            if hi <= lo {
                return None;
            }
            // `get` rather than indexing: an offset that is not on a char
            // boundary skips the symbol instead of panicking in a worker.
            let body_text = file.content.get(lo..hi)?;
            Some(super::resolver::FunctionBody {
                line_start: s.line_start,
                line_end: s.line_end,
                body_text,
            })
        })
        .collect();
    super::resolver::rebind_via_local_types(&mut resolved, &bodies);
    super::resolver::rebind_via_imports(&mut resolved);
    Ok(ParsedFile { file, sha, symbols, resolved })
}
/// Returns the SHA-256 digest of `s`'s UTF-8 bytes as a hex string.
fn sha256_hex(s: &str) -> String {
    use sha2::{Digest, Sha256};
    let digest = Sha256::digest(s.as_bytes());
    hex::encode(digest)
}
/// How a file is parsed and how its symbols are extracted.
enum ParseExtractPair {
/// A built-in language: both parsing and extraction dispatch on the
/// `Language` value.
Static(Language),
/// A dynamically registered grammar: the file is parsed by language name
/// and symbols come from the registered extractor.
Dynamic {
extractor: std::sync::Arc<dyn super::symbols::SymbolExtractor>,
},
}
/// Chooses the parser/extractor pair for a language name.
///
/// Built-in languages win. Otherwise the trimmed, lower-cased name must be
/// both a registered grammar and have a registered extractor. Returns `None`
/// when neither route applies.
fn resolve_parser_and_extractor(lang: &str) -> Option<ParseExtractPair> {
    match Language::from_lang_str(lang) {
        Some(builtin) => Some(ParseExtractPair::Static(builtin)),
        None => {
            let canonical = lang.trim().to_ascii_lowercase();
            let grammar_known = super::parse::registered_grammars()
                .iter()
                .any(|g| g == &canonical);
            if !grammar_known {
                return None;
            }
            super::symbols::registered_extractor(&canonical)
                .map(|extractor| ParseExtractPair::Dynamic { extractor })
        }
    }
}
/// Adds the `body_vec VECTOR(dim)` column to `_hdb_code_symbols` unless it
/// already exists.
fn ensure_body_vec_column(db: &EmbeddedDatabase, dim: usize) -> Result<()> {
    let ddl = format!(
        "ALTER TABLE _hdb_code_symbols ADD COLUMN IF NOT EXISTS body_vec VECTOR({dim})"
    );
    db.execute(&ddl)?;
    Ok(())
}
/// Loads the stored `path → sha256` pairs from `_hdb_code_files`.
///
/// A failing probe query (for instance because the table does not exist yet)
/// yields an empty map instead of an error. Rows whose path or hash is not a
/// string are ignored.
///
/// NOTE(review): the map is keyed by path only, while file rows are unique on
/// (source_table, path) — the same path in two source tables would collide
/// here; confirm whether that can happen.
fn fetch_file_sha_map(db: &EmbeddedDatabase) -> Result<HashMap<String, String>> {
    if db.query("SELECT 1 FROM _hdb_code_files LIMIT 1", &[]).is_err() {
        return Ok(HashMap::new());
    }
    let rows = db.query("SELECT path, sha256 FROM _hdb_code_files", &[])?;
    let mut sha_by_path = HashMap::with_capacity(rows.len());
    for row in rows {
        if let (Some(Value::String(path)), Some(Value::String(sha))) =
            (row.values.first(), row.values.get(1))
        {
            sha_by_path.insert(path.clone(), sha.clone());
        }
    }
    Ok(sha_by_path)
}
/// Rebinds references still marked `'unresolved'` to symbols from any file,
/// matching on the last segment of the referenced name.
///
/// For each symbol name, the row with the smallest `node_id` is the target.
/// The reference becomes `'exact'` when exactly one symbol has that name and
/// `'heuristic'` when several do. `stats` is currently not updated.
fn cross_file_resolve(db: &EmbeddedDatabase, stats: &mut CodeIndexStats) -> Result<()> {
    let symbol_rows = db.query(
        "SELECT name, node_id FROM _hdb_code_symbols ORDER BY name, node_id",
        &[],
    )?;

    // name → (id of the first row for that name, number of rows with that name)
    let mut by_name: std::collections::HashMap<String, (i64, u32)> =
        std::collections::HashMap::new();
    for row in symbol_rows {
        let Some(Value::String(name)) = row.values.first() else {
            continue;
        };
        let id = match row.values.get(1) {
            Some(Value::Int4(n)) => *n as i64,
            Some(Value::Int8(n)) => *n,
            _ => continue,
        };
        by_name
            .entry(name.clone())
            .and_modify(|(_, count)| *count += 1)
            .or_insert((id, 1));
    }

    let unresolved = db.query(
        "SELECT edge_id, to_name FROM _hdb_code_symbol_refs WHERE resolution = 'unresolved'",
        &[],
    )?;
    for row in unresolved {
        let edge_id = match row.values.first() {
            Some(Value::Int4(n)) => *n as i64,
            Some(Value::Int8(n)) => *n,
            _ => continue,
        };
        let Some(Value::String(to_name)) = row.values.get(1) else {
            continue;
        };
        let Some((id, count)) = by_name.get(last_segment(to_name)) else {
            continue;
        };
        let res = if *count == 1 { "exact" } else { "heuristic" };
        db.execute(&format!(
            "UPDATE _hdb_code_symbol_refs \
             SET to_symbol = {id}, resolution = '{res}' \
             WHERE edge_id = {edge_id}"
        ))?;
    }

    // The parameter is part of the signature but has nothing to record yet.
    let _ = stats;
    Ok(())
}
/// Returns the bare identifier at the end of a possibly qualified name.
///
/// Trailing `)` characters and everything from the first `(` onwards are
/// dropped, then the text after the rightmost path separator — `::` or `.`,
/// whichever occurs last — is returned. A name without separators is returned
/// unchanged (apart from the call-suffix stripping).
fn last_segment(name: &str) -> &str {
    let bare = name.trim_end_matches(')');
    let bare = bare.split('(').next().unwrap_or(bare);
    // Byte offset just past each kind of separator, if present.
    let after_path = bare.rfind("::").map(|idx| idx + 2);
    let after_dot = bare.rfind('.').map(|idx| idx + 1);
    // `None < Some(_)`, so `max` picks the rightmost separator of either kind.
    match after_path.max(after_dot) {
        Some(start) => &bare[start..],
        None => bare,
    }
}
/// Creates the `_hdb_code_files`, `_hdb_code_symbols` and
/// `_hdb_code_symbol_refs` tables and their lookup indexes if they do not
/// exist yet.
///
/// # Errors
/// A failing `CREATE TABLE` is returned to the caller; failing
/// `CREATE INDEX` statements are ignored.
fn bootstrap_tables(db: &EmbeddedDatabase) -> Result<()> {
// One row per indexed file, unique on (source_table, path).
db.execute(
r#"CREATE TABLE IF NOT EXISTS _hdb_code_files (
node_id BIGSERIAL PRIMARY KEY,
source_table TEXT NOT NULL,
path TEXT NOT NULL,
lang TEXT,
sha256 TEXT,
mtime TIMESTAMP,
summary TEXT,
UNIQUE(source_table, path)
)"#,
)?;
// One row per extracted symbol. The bulk-insert paths in this file write
// these columns positionally, so the column order here matters.
db.execute(
r#"CREATE TABLE IF NOT EXISTS _hdb_code_symbols (
node_id BIGSERIAL PRIMARY KEY,
file_id BIGINT NOT NULL REFERENCES _hdb_code_files(node_id),
name TEXT NOT NULL,
qualified TEXT,
kind TEXT,
signature TEXT,
visibility TEXT,
line_start INTEGER,
line_end INTEGER,
byte_start INTEGER,
byte_end INTEGER,
parent_id BIGINT
)"#,
)?;
// One row per reference edge; `to_symbol` stays NULL while unresolved.
// Also written positionally by the bulk-insert paths.
db.execute(
r#"CREATE TABLE IF NOT EXISTS _hdb_code_symbol_refs (
edge_id BIGSERIAL PRIMARY KEY,
file_id BIGINT NOT NULL REFERENCES _hdb_code_files(node_id),
from_symbol BIGINT NOT NULL REFERENCES _hdb_code_symbols(node_id),
to_symbol BIGINT REFERENCES _hdb_code_symbols(node_id),
to_name TEXT,
kind TEXT,
line INTEGER,
resolution TEXT
)"#,
)?;
// Index creation is best-effort: errors are deliberately discarded.
let _ = db.execute(
"CREATE INDEX IF NOT EXISTS idx_hdb_code_symbols_file_id ON _hdb_code_symbols(file_id)",
);
let _ = db.execute(
"CREATE INDEX IF NOT EXISTS idx_hdb_code_symbol_refs_file_id ON _hdb_code_symbol_refs(file_id)",
);
let _ = db.execute(
"CREATE INDEX IF NOT EXISTS idx_hdb_code_symbol_refs_to_symbol ON _hdb_code_symbol_refs(to_symbol)",
);
let _ = db.execute(
"CREATE INDEX IF NOT EXISTS idx_hdb_code_symbol_refs_from_symbol ON _hdb_code_symbol_refs(from_symbol)",
);
Ok(())
}
/// One row read from the source table.
#[derive(Debug, Clone)]
struct SourceFile {
/// Value of the `path` column.
path: String,
/// Value of the `lang` column; empty when the column is not a string.
lang: String,
/// Value of the `content` column; empty when the column is not a string.
content: String,
/// Always `None` when built by `fetch_source_files`.
/// NOTE(review): no reader of this field is visible in this file — confirm
/// whether it is still needed.
sha256: Option<String>,
}
/// Reads the `path`, `lang` and `content` columns of every row in
/// `source_table`.
///
/// Rows whose `path` is not a string are skipped; a non-string `lang` or
/// `content` is read as the empty string.
///
/// # Errors
/// Returns the database error when the query fails (for example when the
/// table or one of the columns does not exist).
fn fetch_source_files(db: &EmbeddedDatabase, source_table: &str) -> Result<Vec<SourceFile>> {
    // The table name is spliced into a quoted identifier, so embedded double
    // quotes must be doubled to keep the name from ending the identifier.
    let quoted_table = source_table.replace('"', "\"\"");
    let rows = db.query(
        &format!(r#"SELECT "path", "lang", "content" FROM "{quoted_table}""#),
        &[],
    )?;
    let mut out = Vec::with_capacity(rows.len());
    for row in rows {
        let path = match row.values.first() {
            Some(Value::String(s)) => s.clone(),
            _ => continue,
        };
        let lang = match row.values.get(1) {
            Some(Value::String(s)) => s.clone(),
            _ => String::new(),
        };
        let content = match row.values.get(2) {
            Some(Value::String(s)) => s.clone(),
            _ => String::new(),
        };
        out.push(SourceFile { path, lang, content, sha256: None });
    }
    Ok(out)
}
fn upsert_file(
db: &EmbeddedDatabase,
source_table: &str,
file: &SourceFile,
sha: &str,
) -> Result<i64> {
let path_val = Value::String(file.path.clone());
let lang_val = Value::String(file.lang.clone());
let sha_val = Value::String(sha.to_string());
let st_val = Value::String(source_table.to_string());
let existing = db.query_params(
"SELECT node_id FROM _hdb_code_files \
WHERE source_table = $1 AND path = $2",
&[st_val.clone(), path_val.clone()],
)?;
if let Some(row) = existing.first() {
if let Some(v) = row.values.first() {
let id = match v {
Value::Int4(n) => *n as i64,
Value::Int8(n) => *n,
other => {
return Err(Error::query_execution(format!(
"unexpected file_id type: {other:?}"
)))
}
};
db.execute_params_returning(
"UPDATE _hdb_code_files SET lang = $1, sha256 = $2 WHERE node_id = $3",
&[lang_val, sha_val, Value::Int8(id)],
)?;
return Ok(id);
}
}
let (_, rows) = db.execute_params_returning(
"INSERT INTO _hdb_code_files (source_table, path, lang, sha256) \
VALUES ($1, $2, $3, $4) RETURNING node_id",
&[st_val, path_val, lang_val, sha_val],
)?;
if let Some(row) = rows.first() {
if let Some(v) = row.values.first() {
return match v {
Value::Int4(n) => Ok(*n as i64),
Value::Int8(n) => Ok(*n),
other => Err(Error::query_execution(format!(
"unexpected file_id type: {other:?}"
))),
};
}
}
Err(Error::query_execution("RETURNING file_id yielded no rows"))
}
fn insert_symbols(
db: &EmbeddedDatabase,
file_id: i64,
symbols: &[Symbol],
embedder: &dyn Embedder,
stats: &mut CodeIndexStats,
) -> Result<Vec<i64>> {
if symbols.is_empty() {
return Ok(Vec::new());
}
let mut vectors: Vec<Option<Vec<f32>>> = Vec::with_capacity(symbols.len());
for sym in symbols {
let v = if !sym.signature.is_empty() {
embedder.embed(&sym.signature)?
} else {
None
};
if v.is_some() {
stats.embed_calls += 1;
}
vectors.push(v);
}
let any_vec = vectors.iter().any(Option::is_some);
if any_vec {
let dim = vectors
.iter()
.find_map(|v| v.as_ref().map(|v| v.len()))
.unwrap_or(0);
if dim == 0 {
return Err(Error::query_execution(
"embedder returned a zero-length vector",
));
}
ensure_body_vec_column(db, dim)?;
for v in &vectors {
if let Some(vec) = v {
if vec.len() != dim {
return Err(Error::query_execution(format!(
"embedder dimension mismatch: expected {dim}, got {}",
vec.len()
)));
}
}
}
}
let schema = db.storage.catalog().get_table_schema("_hdb_code_symbols")?;
let n_cols = schema.columns.len();
let expected_min_cols = if any_vec { 13 } else { 12 };
if n_cols < expected_min_cols {
return Err(Error::query_execution(format!(
"_hdb_code_symbols schema has {} cols, fast path expects ≥ {}",
n_cols, expected_min_cols
)));
}
let mut tuples: Vec<Tuple> = Vec::with_capacity(symbols.len());
for (idx, sym) in symbols.iter().enumerate() {
let mut values: Vec<Value> = Vec::with_capacity(n_cols);
values.push(Value::Null); values.push(Value::Int8(file_id));
values.push(Value::String(sym.name.clone()));
values.push(Value::String(sym.qualified.clone()));
values.push(Value::String(sym.kind.as_str().to_string()));
values.push(Value::String(sym.signature.clone()));
values.push(Value::String(sym.visibility.as_str().to_string()));
values.push(Value::Int4(sym.line_start as i32));
values.push(Value::Int4(sym.line_end as i32));
values.push(Value::Int4(sym.byte_start as i32));
values.push(Value::Int4(sym.byte_end as i32));
values.push(Value::Null); if any_vec && n_cols >= 13 {
let v = vectors
.get(idx)
.and_then(|v| v.clone())
.map(Value::Vector)
.unwrap_or(Value::Null);
values.push(v);
}
while values.len() < n_cols {
values.push(Value::Null);
}
tuples.push(Tuple::new(values));
}
let row_ids = db.bulk_insert_tuples("_hdb_code_symbols", tuples)?;
Ok(row_ids.into_iter().map(|id| id as i64).collect())
}
/// Inserts the resolved references of one file into `_hdb_code_symbol_refs`
/// and returns how many rows were written.
///
/// `symbol_ids[i]` is the row id of the file's `i`-th symbol; reference
/// indices point into that slice. A `to_idx` outside the slice is stored as
/// NULL. An empty `resolved` slice returns `0` without touching the database.
///
/// # Errors
/// Fails when the table has fewer than 8 columns, when a reference's
/// `from_idx` is invalid, or on a database error.
fn insert_refs(
    db: &EmbeddedDatabase,
    file_id: i64,
    symbol_ids: &[i64],
    resolved: &[super::resolver::ResolvedRef],
) -> Result<u64> {
    if resolved.is_empty() {
        return Ok(0);
    }
    let n_cols = db
        .storage
        .catalog()
        .get_table_schema("_hdb_code_symbol_refs")?
        .columns
        .len();
    if n_cols < 8 {
        return Err(Error::query_execution(format!(
            "_hdb_code_symbol_refs schema has {} cols, fast path expects ≥ 8",
            n_cols
        )));
    }

    let mut tuples: Vec<Tuple> = Vec::with_capacity(resolved.len());
    for r in resolved {
        let Some(&from_id) = symbol_ids.get(r.from_idx) else {
            return Err(Error::query_execution(format!(
                "resolver produced invalid from_idx {}",
                r.from_idx
            )));
        };
        let to_val = r
            .to_idx
            .and_then(|idx| symbol_ids.get(idx))
            .map_or(Value::Null, |id| Value::Int8(*id));
        let res = match r.resolution {
            Resolution::Exact => "exact",
            Resolution::Heuristic => "heuristic",
            Resolution::Unresolved => "unresolved",
        };
        let mut values: Vec<Value> = Vec::with_capacity(n_cols);
        values.extend([
            Value::Null, // edge_id: left to the database
            Value::Int8(file_id),
            Value::Int8(from_id),
            to_val,
            Value::String(r.to_name.clone()),
            Value::String(r.kind_str.to_string()),
            Value::Int4(r.line as i32),
            Value::String(res.to_string()),
        ]);
        // Pad any further columns with NULL.
        let padding = n_cols.saturating_sub(values.len());
        values.extend(std::iter::repeat_with(|| Value::Null).take(padding));
        tuples.push(Tuple::new(values));
    }

    let written = tuples.len() as u64;
    db.bulk_insert_tuples("_hdb_code_symbol_refs", tuples)?;
    Ok(written)
}
/// Renders `s` as a single-quoted SQL string literal, doubling every embedded
/// single quote.
fn sql_text(s: &str) -> String {
    let mut literal = String::with_capacity(s.len() + 2);
    literal.push('\'');
    for ch in s.chars() {
        if ch == '\'' {
            literal.push('\'');
        }
        literal.push(ch);
    }
    literal.push('\'');
    literal
}
/// Looks up the `path` of the `_hdb_code_files` row with the given id.
///
/// Returns `Ok(None)` when no such row exists or the column is not a string.
pub(super) fn file_path_by_id(db: &EmbeddedDatabase, file_id: i64) -> Result<Option<String>> {
    let sql = format!("SELECT path FROM _hdb_code_files WHERE node_id = {file_id}");
    let rows = db.query(&sql, &[])?;
    let path = match rows.first().and_then(|row| row.values.first()) {
        Some(Value::String(s)) => Some(s.clone()),
        _ => None,
    };
    Ok(path)
}
/// Looks up the `file_id` of the `_hdb_code_symbols` row with the given id.
///
/// Returns `Ok(None)` when no such row exists or the column is not an integer.
pub(super) fn file_id_for_symbol(
    db: &EmbeddedDatabase,
    symbol_id: i64,
) -> Result<Option<i64>> {
    let sql = format!("SELECT file_id FROM _hdb_code_symbols WHERE node_id = {symbol_id}");
    let rows = db.query(&sql, &[])?;
    let file_id = match rows.first().and_then(|row| row.values.first()) {
        Some(Value::Int4(n)) => Some(*n as i64),
        Some(Value::Int8(n)) => Some(*n),
        _ => None,
    };
    Ok(file_id)
}
/// Maps each symbol's qualified name to its position in `symbols`.
///
/// When several symbols share a qualified name, the last one wins.
#[allow(dead_code)]
pub(super) fn qualified_index<'a>(symbols: &'a [Symbol]) -> HashMap<&'a str, usize> {
    symbols
        .iter()
        .enumerate()
        .map(|(position, symbol)| (symbol.qualified.as_str(), position))
        .collect()
}