use std::collections::HashMap;
use std::path::PathBuf;
use sqlx::Row;
use crate::embedder::Embedding;
use crate::nl::normalize_for_fts;
use crate::parser::{ChunkType, Language};
use crate::store::helpers::{
bytes_to_embedding, clamp_line_number, ChunkIdentity, ChunkRow, ChunkSummary, IndexStats,
StoreError,
};
use crate::store::Store;
impl Store {
pub fn chunk_count(&self) -> Result<u64, StoreError> {
let _span = tracing::debug_span!("chunk_count").entered();
self.rt.block_on(async {
let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM chunks")
.fetch_one(&self.pool)
.await?;
Ok(row.0 as u64)
})
}
pub fn stats(&self) -> Result<IndexStats, StoreError> {
let _span = tracing::debug_span!("stats").entered();
self.rt.block_on(async {
let (total_chunks, total_files): (i64, i64) = sqlx::query_as(
"SELECT
(SELECT COUNT(*) FROM chunks),
(SELECT COUNT(DISTINCT origin) FROM chunks)",
)
.fetch_one(&self.pool)
.await?;
let lang_rows: Vec<(String, i64)> =
sqlx::query_as("SELECT language, COUNT(*) FROM chunks GROUP BY language")
.fetch_all(&self.pool)
.await?;
let chunks_by_language: HashMap<Language, u64> = lang_rows
.into_iter()
.filter_map(|(lang, count)| {
lang.parse()
.map_err(|_| {
tracing::warn!(
language = %lang,
count,
"Unknown language in database, skipping in stats"
);
})
.ok()
.map(|l| (l, count as u64))
})
.collect();
let type_rows: Vec<(String, i64)> =
sqlx::query_as("SELECT chunk_type, COUNT(*) FROM chunks GROUP BY chunk_type")
.fetch_all(&self.pool)
.await?;
let chunks_by_type: HashMap<ChunkType, u64> = type_rows
.into_iter()
.filter_map(|(ct, count)| {
ct.parse()
.map_err(|_| {
tracing::warn!(
chunk_type = %ct,
count,
"Unknown chunk_type in database, skipping in stats"
);
})
.ok()
.map(|c| (c, count as u64))
})
.collect();
let metadata_rows: Vec<(String, String)> = sqlx::query_as(
"SELECT key, value FROM metadata WHERE key IN ('model_name', 'created_at', 'updated_at', 'schema_version')",
)
.fetch_all(&self.pool)
.await?;
let metadata: HashMap<String, String> = metadata_rows.into_iter().collect();
let model_name = metadata.get("model_name").cloned().unwrap_or_else(|| {
tracing::debug!("metadata key 'model_name' missing, defaulting to empty");
String::new()
});
let created_at = metadata.get("created_at").cloned().unwrap_or_else(|| {
tracing::debug!("metadata key 'created_at' missing, defaulting to empty");
String::new()
});
let updated_at = metadata
.get("updated_at")
.cloned()
.unwrap_or_else(|| created_at.clone());
let schema_version: i32 = metadata
.get("schema_version")
.and_then(|s| {
s.parse().map_err(|e| {
tracing::warn!(raw = %s, error = %e, "Failed to parse schema_version, defaulting to 0");
}).ok()
})
.unwrap_or(0);
Ok(IndexStats {
total_chunks: total_chunks as u64,
total_files: total_files as u64,
chunks_by_language,
chunks_by_type,
index_size_bytes: 0,
created_at,
updated_at,
model_name,
schema_version,
})
})
}
pub fn get_chunks_by_origin(&self, origin: &str) -> Result<Vec<ChunkSummary>, StoreError> {
let _span = tracing::debug_span!("get_chunks_by_origin", origin = %origin).entered();
self.rt.block_on(async {
let rows: Vec<_> = sqlx::query(
"SELECT id, origin, language, chunk_type, name, signature, content, doc,
line_start, line_end, parent_id, parent_type_name
FROM chunks WHERE origin = ?1
ORDER BY line_start",
)
.bind(origin)
.fetch_all(&self.pool)
.await?;
Ok(rows
.iter()
.map(|r| ChunkSummary::from(ChunkRow::from_row(r)))
.collect())
})
}
pub fn get_chunks_by_origins_batch(
&self,
origins: &[&str],
) -> Result<HashMap<String, Vec<ChunkSummary>>, StoreError> {
let _span =
tracing::debug_span!("get_chunks_by_origins_batch", count = origins.len()).entered();
if origins.is_empty() {
return Ok(HashMap::new());
}
self.rt.block_on(async {
let mut result: HashMap<String, Vec<ChunkSummary>> = HashMap::new();
const BATCH_SIZE: usize = 500;
for batch in origins.chunks(BATCH_SIZE) {
let placeholders = crate::store::helpers::make_placeholders(batch.len());
let sql = format!(
"SELECT id, origin, language, chunk_type, name, signature, content, doc,
line_start, line_end, parent_id, parent_type_name
FROM chunks WHERE origin IN ({})
ORDER BY origin, line_start",
placeholders
);
let mut query = sqlx::query(&sql);
for origin in batch {
query = query.bind(*origin);
}
let rows: Vec<_> = query.fetch_all(&self.pool).await?;
for row in &rows {
let chunk = ChunkSummary::from(ChunkRow::from_row(row));
let origin_key: String = row.get("origin");
result.entry(origin_key).or_default().push(chunk);
}
}
Ok(result)
})
}
pub fn get_chunks_by_names_batch(
&self,
names: &[&str],
) -> Result<HashMap<String, Vec<ChunkSummary>>, StoreError> {
let _span =
tracing::debug_span!("get_chunks_by_names_batch", count = names.len()).entered();
if names.is_empty() {
return Ok(HashMap::new());
}
self.rt.block_on(async {
let mut result: HashMap<String, Vec<ChunkSummary>> = HashMap::new();
const BATCH_SIZE: usize = 500;
for batch in names.chunks(BATCH_SIZE) {
let placeholders = crate::store::helpers::make_placeholders(batch.len());
let sql = format!(
"SELECT id, origin, language, chunk_type, name, signature, content, doc,
line_start, line_end, parent_id, parent_type_name
FROM chunks WHERE name IN ({})
ORDER BY origin, line_start",
placeholders
);
let rows: Vec<_> = {
let mut q = sqlx::query(&sql);
for name in batch {
q = q.bind(*name);
}
q.fetch_all(&self.pool).await?
};
for row in &rows {
let chunk = ChunkSummary::from(ChunkRow::from_row(row));
result.entry(chunk.name.clone()).or_default().push(chunk);
}
}
Ok(result)
})
}
pub fn get_chunk_with_embedding(
&self,
id: &str,
) -> Result<Option<(ChunkSummary, Embedding)>, StoreError> {
let _span = tracing::debug_span!("get_chunk_with_embedding", id = %id).entered();
let dim = self.dim;
self.rt.block_on(async {
let results = self
.fetch_chunks_with_embeddings_by_ids_async(&[id])
.await?;
Ok(results.into_iter().next().and_then(|(row, bytes)| {
match bytes_to_embedding(&bytes, dim) {
Ok(emb) => Some((ChunkSummary::from(row), Embedding::new(emb))),
Err(e) => {
tracing::warn!(chunk_id = %row.id, error = %e, "Corrupt embedding for chunk, skipping");
None
}
}
}))
})
}
pub fn get_chunks_by_ids(
&self,
ids: &[&str],
) -> Result<HashMap<String, ChunkSummary>, StoreError> {
let _span = tracing::debug_span!("get_chunks_by_ids", count = ids.len()).entered();
self.rt.block_on(async {
let rows = self.fetch_chunks_by_ids_async(ids).await?;
Ok(rows
.into_iter()
.map(|(id, row)| (id, ChunkSummary::from(row)))
.collect())
})
}
pub fn get_embeddings_by_ids(
&self,
ids: &[&str],
) -> Result<HashMap<String, Embedding>, StoreError> {
let _span = tracing::debug_span!("get_embeddings_by_ids", count = ids.len()).entered();
if ids.is_empty() {
return Ok(HashMap::new());
}
const BATCH_SIZE: usize = 500;
let dim = self.dim;
let mut result = HashMap::new();
self.rt.block_on(async {
for batch in ids.chunks(BATCH_SIZE) {
let placeholders = crate::store::helpers::make_placeholders(batch.len());
let sql = format!(
"SELECT id, embedding FROM chunks WHERE id IN ({})",
placeholders
);
let rows: Vec<_> = {
let mut q = sqlx::query(&sql);
for id in batch {
q = q.bind(*id);
}
q.fetch_all(&self.pool).await?
};
for row in rows {
let id: String = row.get(0);
let bytes: Vec<u8> = row.get(1);
match bytes_to_embedding(&bytes, dim) {
Ok(emb) => {
result.insert(id, Embedding::new(emb));
}
Err(e) => {
tracing::trace!(chunk_id = %id, error = %e, "Skipping embedding");
}
}
}
}
Ok(result)
})
}
pub fn search_by_names_batch(
&self,
names: &[&str],
limit_per_name: usize,
) -> Result<HashMap<String, Vec<crate::store::SearchResult>>, StoreError> {
let _span =
tracing::info_span!("search_by_names_batch", count = names.len(), limit_per_name)
.entered();
if names.is_empty() {
return Ok(HashMap::new());
}
self.rt.block_on(async {
let mut result: HashMap<String, Vec<crate::store::SearchResult>> = HashMap::new();
let normalized_names: Vec<(&str, String)> = names
.iter()
.map(|n| (*n, crate::store::sanitize_fts_query(&normalize_for_fts(n))))
.filter(|(_, norm)| !norm.is_empty())
.collect();
const BATCH_SIZE: usize = 20;
for batch in normalized_names.chunks(BATCH_SIZE) {
let fts_terms: Vec<String> = batch
.iter()
.filter_map(|(_, norm)| {
debug_assert!(
!norm.contains('"'),
"sanitized query must not contain double quotes"
);
if norm.contains('"') {
return None;
}
Some(format!("name:\"{}\" OR name:\"{}\"*", norm, norm))
})
.collect();
let combined_fts = fts_terms.join(" OR ");
let total_limit = limit_per_name * batch.len();
let light_rows: Vec<_> = sqlx::query(
"SELECT c.id, c.name
FROM chunks c
JOIN chunks_fts f ON c.id = f.id
WHERE chunks_fts MATCH ?1
ORDER BY bm25(chunks_fts, 10.0, 1.0, 1.0, 1.0)
LIMIT ?2",
)
.bind(&combined_fts)
.bind(total_limit as i64)
.fetch_all(&self.pool)
.await?;
let mut matched: Vec<(String, String, f32)> = Vec::new();
let mut ids_to_fetch: Vec<String> = Vec::new();
for row in &light_rows {
let id: String = row.get("id");
let chunk_name: String = row.get("name");
for (original_name, _normalized) in batch {
let score = crate::store::score_name_match(&chunk_name, original_name);
if score > 0.0 {
let entry = result.entry(original_name.to_string()).or_default();
if entry.len() < limit_per_name {
ids_to_fetch.push(id.clone());
matched.push((id.clone(), original_name.to_string(), score));
}
break;
}
}
}
if ids_to_fetch.is_empty() {
continue;
}
let id_refs: Vec<&str> = ids_to_fetch.iter().map(|s| s.as_str()).collect();
let full_chunks = self.fetch_chunks_by_ids_async(&id_refs).await?;
for (id, query_name, score) in matched {
if let Some(chunk_row) = full_chunks.get(&id) {
let entry = result.entry(query_name).or_default();
if entry.len() < limit_per_name {
entry.push(crate::store::SearchResult {
chunk: ChunkSummary::from(chunk_row.clone()),
score,
});
}
}
}
}
Ok(result)
})
}
pub fn all_chunk_identities(&self) -> Result<Vec<ChunkIdentity>, StoreError> {
let _span = tracing::debug_span!("all_chunk_identities").entered();
self.all_chunk_identities_filtered(None)
}
pub fn chunk_type_language_map(
&self,
) -> Result<std::sync::Arc<crate::store::ChunkTypeMap>, StoreError> {
if let Some(cached) = self.chunk_type_map_cache.get() {
return Ok(std::sync::Arc::clone(cached));
}
let _span = tracing::debug_span!("chunk_type_language_map").entered();
let map = self.rt.block_on(async {
let rows: Vec<_> = sqlx::query("SELECT id, chunk_type, language FROM chunks")
.fetch_all(&self.pool)
.await?;
let mut map = HashMap::with_capacity(rows.len());
for row in &rows {
let id: String = row.get("id");
let ct: String = row.get("chunk_type");
let lang: String = row.get("language");
match (ct.parse(), lang.parse()) {
(Ok(chunk_type), Ok(language)) => {
map.insert(id, (chunk_type, language));
}
(ct_result, lang_result) => {
tracing::warn!(
chunk_id = %id,
chunk_type = %ct,
language = %lang,
ct_err = ?ct_result.err(),
lang_err = ?lang_result.err(),
"Skipping chunk with unparseable chunk_type or language"
);
}
}
}
Ok::<_, StoreError>(map)
})?;
let arc = std::sync::Arc::new(map);
let _ = self.chunk_type_map_cache.set(std::sync::Arc::clone(&arc));
Ok(arc)
}
pub fn chunks_paged(
&self,
after_rowid: i64,
limit: usize,
) -> Result<(Vec<ChunkSummary>, i64), StoreError> {
let _span = tracing::debug_span!("chunks_paged", after_rowid, limit).entered();
self.rt.block_on(async {
let rows: Vec<_> = sqlx::query(
"SELECT rowid, id, origin, language, chunk_type, name, signature, content, doc, \
line_start, line_end, content_hash, window_idx, parent_id, parent_type_name \
FROM chunks WHERE rowid > ?1 ORDER BY rowid ASC LIMIT ?2",
)
.bind(after_rowid)
.bind(limit as i64)
.fetch_all(&self.pool)
.await?;
let mut max_rowid = after_rowid;
let chunks: Vec<ChunkSummary> = rows
.iter()
.map(|row| {
let rowid: i64 = row.get("rowid");
if rowid > max_rowid {
max_rowid = rowid;
}
ChunkSummary::from(ChunkRow::from_row(row))
})
.collect();
Ok((chunks, max_rowid))
})
}
pub fn all_chunk_identities_filtered(
&self,
language: Option<&str>,
) -> Result<Vec<ChunkIdentity>, StoreError> {
let _span =
tracing::debug_span!("all_chunk_identities_filtered", language = ?language).entered();
self.rt.block_on(async {
let rows: Vec<_> = if let Some(lang) = language {
sqlx::query(
"SELECT id, origin, name, chunk_type, language, line_start, parent_id, window_idx FROM chunks WHERE language = ?1",
)
.bind(lang)
.fetch_all(&self.pool)
.await?
} else {
sqlx::query(
"SELECT id, origin, name, chunk_type, language, line_start, parent_id, window_idx FROM chunks",
)
.fetch_all(&self.pool)
.await?
};
Ok(rows
.iter()
.map(|row| ChunkIdentity {
id: row.get("id"),
file: PathBuf::from(row.get::<String, _>("origin")),
name: row.get("name"),
chunk_type: {
let raw: String = row.get("chunk_type");
raw.parse().unwrap_or_else(|_| {
tracing::warn!(raw = %raw, "Unknown chunk_type in DB, defaulting to Function");
ChunkType::Function
})
},
line_start: clamp_line_number(row.get::<i64, _>("line_start")),
language: {
let raw: String = row.get("language");
raw.parse().unwrap_or_else(|_| {
tracing::warn!(raw = %raw, "Unknown language in DB, defaulting to Rust");
Language::Rust
})
},
parent_id: row.get("parent_id"),
window_idx: row
.get::<Option<i64>, _>("window_idx")
.map(|i| i.clamp(0, u32::MAX as i64) as u32),
})
.collect())
})
}
}
#[cfg(test)]
mod tests {
use super::super::test_utils::make_chunk;
use crate::parser::Language;
use crate::test_helpers::{mock_embedding, setup_store};
#[test]
fn test_all_chunk_identities_filtered_by_language() {
let (store, _dir) = setup_store();
let mut rust_chunk = make_chunk("rs_fn", "src/lib.rs");
rust_chunk.language = Language::Rust;
let mut py_chunk = make_chunk("py_fn", "src/main.py");
py_chunk.language = Language::Python;
py_chunk.id = format!("src/main.py:1:{}", &py_chunk.content_hash[..8]);
let emb = mock_embedding(1.0);
store
.upsert_chunks_batch(
&[(rust_chunk, emb.clone()), (py_chunk, emb.clone())],
Some(100),
)
.unwrap();
let identities = store.all_chunk_identities_filtered(Some("rust")).unwrap();
assert_eq!(identities.len(), 1);
assert_eq!(identities[0].language, Language::Rust);
let identities = store.all_chunk_identities_filtered(Some("python")).unwrap();
assert_eq!(identities.len(), 1);
assert_eq!(identities[0].language, Language::Python);
let identities = store.all_chunk_identities_filtered(None).unwrap();
assert_eq!(identities.len(), 2);
}
#[test]
fn test_get_chunks_by_origin_sorted_by_line() {
let (store, _dir) = setup_store();
let mut c1 = make_chunk("fn_late", "src/lib.rs");
c1.line_start = 50;
c1.line_end = 60;
let mut c2 = make_chunk("fn_early", "src/lib.rs");
c2.line_start = 1;
c2.line_end = 10;
c2.id = format!("src/lib.rs:1:{}", &c2.content_hash[..8]);
let emb = mock_embedding(1.0);
store
.upsert_chunks_batch(&[(c1, emb.clone()), (c2, emb.clone())], Some(100))
.unwrap();
let chunks = store.get_chunks_by_origin("src/lib.rs").unwrap();
assert_eq!(chunks.len(), 2);
assert!(
chunks[0].line_start <= chunks[1].line_start,
"Chunks should be sorted by line_start"
);
}
#[test]
fn test_get_chunks_by_origin_empty() {
let (store, _dir) = setup_store();
let chunks = store.get_chunks_by_origin("nonexistent.rs").unwrap();
assert!(chunks.is_empty());
}
#[test]
fn test_chunks_paged_empty() {
let (store, _dir) = setup_store();
let (chunks, max_rowid) = store.chunks_paged(0, 10).unwrap();
assert!(chunks.is_empty());
assert_eq!(max_rowid, 0);
}
#[test]
fn test_chunks_paged_single_page() {
let (store, _dir) = setup_store();
let pairs: Vec<_> = (0..3)
.map(|i| {
let c = make_chunk(&format!("fn_{}", i), &format!("src/{}.rs", i));
(c, mock_embedding(i as f32))
})
.collect();
store.upsert_chunks_batch(&pairs, Some(100)).unwrap();
let (chunks, max_rowid) = store.chunks_paged(0, 10).unwrap();
assert_eq!(chunks.len(), 3);
assert!(max_rowid > 0);
}
#[test]
fn test_chunks_paged_multi_page() {
let (store, _dir) = setup_store();
let pairs: Vec<_> = (0..5)
.map(|i| {
let c = make_chunk(&format!("fn_{}", i), &format!("src/{}.rs", i));
(c, mock_embedding(i as f32))
})
.collect();
store.upsert_chunks_batch(&pairs, Some(100)).unwrap();
let (page1, cursor1) = store.chunks_paged(0, 2).unwrap();
assert_eq!(page1.len(), 2);
assert!(cursor1 > 0);
let (page2, cursor2) = store.chunks_paged(cursor1, 2).unwrap();
assert_eq!(page2.len(), 2);
assert!(cursor2 > cursor1);
let (page3, _cursor3) = store.chunks_paged(cursor2, 2).unwrap();
assert_eq!(page3.len(), 1);
assert_eq!(page1.len() + page2.len() + page3.len(), 5);
}
#[test]
fn test_chunks_paged_exact_boundary() {
let (store, _dir) = setup_store();
let pairs: Vec<_> = (0..4)
.map(|i| {
let c = make_chunk(&format!("fn_{}", i), &format!("src/{}.rs", i));
(c, mock_embedding(i as f32))
})
.collect();
store.upsert_chunks_batch(&pairs, Some(100)).unwrap();
let (page1, cursor1) = store.chunks_paged(0, 4).unwrap();
assert_eq!(page1.len(), 4);
let (page2, cursor2) = store.chunks_paged(cursor1, 4).unwrap();
assert!(page2.is_empty());
assert_eq!(cursor2, cursor1);
}
}