#![allow(clippy::await_holding_lock)]
use super::Store;
use crate::splade::SparseVector;
use crate::store::StoreError;
use sqlx::Row;
/// Increment the `splade_generation` counter in the `metadata` table,
/// inside the caller's transaction.
///
/// A missing row or an unparseable stored value both restart the counter
/// at 1; an existing valid value is incremented with saturation (so it can
/// never wrap back to a previously-seen generation).
pub(crate) async fn bump_splade_generation_tx(
    tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>,
) -> Result<(), StoreError> {
    let stored: Option<(String,)> =
        sqlx::query_as("SELECT value FROM metadata WHERE key = 'splade_generation'")
            .fetch_optional(&mut **tx)
            .await?;
    // Absent row and corrupt value are both treated as a reset point.
    let next: u64 = stored.map_or(1, |(raw,)| match raw.parse::<u64>() {
        Ok(n) => n.saturating_add(1),
        Err(e) => {
            tracing::warn!(
                raw = %raw,
                error = %e,
                "splade_generation metadata is not a valid u64, resetting to 1",
            );
            1
        }
    });
    sqlx::query(
        "INSERT INTO metadata (key, value) VALUES ('splade_generation', ?1) \
         ON CONFLICT(key) DO UPDATE SET value = excluded.value",
    )
    .bind(next.to_string())
    .execute(&mut **tx)
    .await?;
    Ok(())
}
impl Store {
pub fn upsert_sparse_vectors(
&self,
vectors: &[(String, SparseVector)],
) -> Result<usize, StoreError> {
let _span = tracing::info_span!("upsert_sparse_vectors", count = vectors.len()).entered();
if vectors.is_empty() {
return Ok(0);
}
self.rt.block_on(async {
let (_guard, mut tx) = self.begin_write().await?;
let mut total = 0usize;
tracing::debug!("Dropping idx_sparse_token before bulk insert");
sqlx::query("DROP INDEX IF EXISTS idx_sparse_token")
.execute(&mut *tx)
.await?;
use crate::store::helpers::sql::max_rows_per_statement;
const DELETE_BATCH: usize = max_rows_per_statement(1);
let chunk_ids: Vec<&str> = vectors.iter().map(|(id, _)| id.as_str()).collect();
for batch in chunk_ids.chunks(DELETE_BATCH) {
let mut qb: sqlx::QueryBuilder<sqlx::Sqlite> =
sqlx::QueryBuilder::new("DELETE FROM sparse_vectors WHERE chunk_id IN (");
let mut sep = qb.separated(", ");
for id in batch {
sep.push_bind(*id);
}
sep.push_unseparated(")");
qb.build().execute(&mut *tx).await?;
}
const ROWS_PER_INSERT: usize = max_rows_per_statement(3);
let mut pending: Vec<(&str, u32, f32)> = Vec::with_capacity(ROWS_PER_INSERT);
for (chunk_id, sparse) in vectors {
for &(token_id, weight) in sparse {
pending.push((chunk_id.as_str(), token_id, weight));
if pending.len() >= ROWS_PER_INSERT {
let mut qb: sqlx::QueryBuilder<sqlx::Sqlite> = sqlx::QueryBuilder::new(
"INSERT INTO sparse_vectors (chunk_id, token_id, weight)",
);
qb.push_values(pending.iter(), |mut b, &(cid, tid, w)| {
b.push_bind(cid).push_bind(tid as i64).push_bind(w);
});
qb.build().execute(&mut *tx).await?;
total += pending.len();
pending.clear();
}
}
}
if !pending.is_empty() {
let mut qb: sqlx::QueryBuilder<sqlx::Sqlite> = sqlx::QueryBuilder::new(
"INSERT INTO sparse_vectors (chunk_id, token_id, weight)",
);
qb.push_values(pending.iter(), |mut b, &(cid, tid, w)| {
b.push_bind(cid).push_bind(tid as i64).push_bind(w);
});
qb.build().execute(&mut *tx).await?;
total += pending.len();
}
tracing::debug!("Recreating idx_sparse_token after bulk insert");
sqlx::query("CREATE INDEX IF NOT EXISTS idx_sparse_token ON sparse_vectors(token_id)")
.execute(&mut *tx)
.await?;
bump_splade_generation_tx(&mut tx).await?;
tx.commit().await?;
tracing::info!(
entries = total,
chunks = vectors.len(),
"Sparse vectors upserted"
);
Ok(total)
})
}
/// Load every stored sparse vector, grouped per chunk id.
///
/// Rows arrive ordered by `chunk_id`, so grouping is a single linear pass.
/// Rows whose `token_id` does not fit in a `u32` are logged and skipped
/// (the chunk still appears in the result, possibly with an empty vector).
pub fn load_all_sparse_vectors(&self) -> Result<Vec<(String, SparseVector)>, StoreError> {
    let _span = tracing::info_span!("load_all_sparse_vectors").entered();
    self.rt.block_on(async {
        let rows: Vec<_> = sqlx::query(
            "SELECT chunk_id, token_id, weight FROM sparse_vectors ORDER BY chunk_id",
        )
        .fetch_all(&self.pool)
        .await?;
        let mut grouped: Vec<(String, SparseVector)> = Vec::new();
        for row in &rows {
            let chunk_id: String = row.get("chunk_id");
            let token_id: i64 = row.get("token_id");
            let weight: f64 = row.get("weight");
            // Rows are sorted, so a new chunk id starts a new output group.
            let start_group = grouped.last().map_or(true, |(id, _)| *id != chunk_id);
            if start_group {
                grouped.push((chunk_id, Vec::new()));
            }
            let (current_id, current_vec) =
                grouped.last_mut().expect("group pushed above when empty");
            if token_id < 0 || token_id > u32::MAX as i64 {
                tracing::warn!(token_id, chunk_id = %current_id, "Invalid token_id, skipping");
                continue;
            }
            current_vec.push((token_id as u32, weight as f32));
        }
        tracing::info!(
            chunks = grouped.len(),
            total_entries = rows.len(),
            "Sparse vectors loaded"
        );
        Ok(grouped)
    })
}
/// `(chunk_id, text)` pairs for every chunk, for SPLADE encoding.
///
/// The text is assembled by `chunk_splade_texts_query` from the chunk's
/// name, signature, and (when non-empty) doc.
pub fn chunk_splade_texts(&self) -> Result<Vec<(String, String)>, StoreError> {
    self.chunk_splade_texts_query("SELECT id, name, signature, doc FROM chunks")
}
/// `(chunk_id, text)` pairs only for chunks with no rows in
/// `sparse_vectors` yet, i.e. chunks still awaiting SPLADE encoding.
pub fn chunk_splade_texts_missing(&self) -> Result<Vec<(String, String)>, StoreError> {
    self.chunk_splade_texts_query(
        "SELECT c.id, c.name, c.signature, c.doc FROM chunks c \
         WHERE c.id NOT IN (SELECT DISTINCT chunk_id FROM sparse_vectors)",
    )
}
fn chunk_splade_texts_query(&self, sql: &str) -> Result<Vec<(String, String)>, StoreError> {
let _span = tracing::info_span!("chunk_splade_texts").entered();
self.rt.block_on(async {
let rows: Vec<_> = sqlx::query(sql).fetch_all(&self.pool).await?;
let result: Vec<(String, String)> = rows
.iter()
.map(|row| {
let id: String = row.get("id");
let name: String = row.get("name");
let sig: String = row.get("signature");
let doc: Option<String> = row.get("doc");
let text = match doc {
Some(d) if !d.is_empty() => format!("{} {} {}", name, sig, d),
_ => format!("{} {}", name, sig),
};
(id, text)
})
.collect();
tracing::info!(chunks = result.len(), "Loaded chunk texts for SPLADE");
Ok(result)
})
}
/// Delete sparse-vector rows whose chunk no longer exists.
///
/// Since the v19 schema the FK cascade should make this a no-op; any row
/// actually removed here is logged loudly. The SPLADE generation counter is
/// bumped only when something was deleted.
pub fn prune_orphan_sparse_vectors(&self) -> Result<usize, StoreError> {
    let _span = tracing::debug_span!("prune_orphan_sparse_vectors").entered();
    self.rt.block_on(async {
        let (_guard, mut tx) = self.begin_write().await?;
        let deleted = sqlx::query(
            "DELETE FROM sparse_vectors WHERE chunk_id NOT IN \
             (SELECT DISTINCT id FROM chunks)",
        )
        .execute(&mut *tx)
        .await?
        .rows_affected();
        if deleted > 0 {
            // A real deletion invalidates any cached SPLADE index.
            bump_splade_generation_tx(&mut tx).await?;
            tracing::warn!(
                rows = deleted,
                "prune_orphan_sparse_vectors deleted rows that should have been caught by \
                 the v19 FK cascade — either FK enforcement is disabled or this database was \
                 manipulated directly. Investigate."
            );
        }
        tx.commit().await?;
        Ok(deleted as usize)
    })
}
/// Read the current SPLADE generation counter.
///
/// Returns 0 when no counter has been stored yet or when the stored value
/// cannot be parsed (the latter is logged; a 0 forces the next SPLADE load
/// to rebuild from SQLite).
pub fn splade_generation(&self) -> Result<u64, StoreError> {
    let _span = tracing::debug_span!("splade_generation").entered();
    self.rt.block_on(async {
        let row: Option<(String,)> =
            sqlx::query_as("SELECT value FROM metadata WHERE key = 'splade_generation'")
                .fetch_optional(&self.pool)
                .await?;
        // No metadata row yet: the generation starts at 0.
        let Some((raw,)) = row else {
            return Ok(0);
        };
        match raw.parse::<u64>() {
            Ok(n) => Ok(n),
            Err(e) => {
                tracing::warn!(
                    raw = %raw,
                    error = %e,
                    "splade_generation metadata is not a valid u64, treating as 0 — \
                     next SPLADE load will rebuild from SQLite"
                );
                Ok(0)
            }
        }
    })
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Insert a minimal `chunks` row keyed by `id` (reused as origin and name),
/// with a zeroed embedding — just enough for `sparse_vectors` rows to
/// reference it without violating the foreign key.
fn insert_test_chunk(store: &Store, id: &str) {
store.rt.block_on(async {
// Zero-filled embedding: its content is irrelevant to these tests.
let embedding = crate::embedder::Embedding::new(vec![0.0f32; crate::EMBEDDING_DIM]);
let embedding_bytes =
crate::store::helpers::embedding_to_bytes(&embedding, crate::EMBEDDING_DIM)
.unwrap();
let now = chrono::Utc::now().to_rfc3339();
sqlx::query(
"INSERT INTO chunks (id, origin, source_type, language, chunk_type, name,
signature, content, content_hash, doc, line_start, line_end, embedding,
source_mtime, created_at, updated_at)
VALUES (?1, ?1, 'file', 'rust', 'function', ?1,
'', '', '', NULL, 1, 10, ?2, 0, ?3, ?3)",
)
.bind(id)
.bind(&embedding_bytes)
.bind(&now)
.execute(&store.pool)
.await
.unwrap();
});
}
/// Open a fresh store on disk in a temp dir with the schema initialised.
/// The `TempDir` is returned so the caller keeps the files alive.
fn setup_store() -> (Store, tempfile::TempDir) {
    let tmp = tempfile::TempDir::new().unwrap();
    let store = Store::open(&tmp.path().join("test.db")).unwrap();
    store.init(&crate::store::ModelInfo::default()).unwrap();
    (store, tmp)
}
/// Insert a `chunks` row with caller-controlled `name`, `signature`, and
/// optional `doc` (plus a zeroed embedding). Used to exercise the SPLADE
/// text-assembly queries.
fn insert_chunk_with_fields(
store: &Store,
id: &str,
name: &str,
signature: &str,
doc: Option<&str>,
) {
store.rt.block_on(async {
// Zero-filled embedding: its content is irrelevant to these tests.
let embedding = crate::embedder::Embedding::new(vec![0.0f32; crate::EMBEDDING_DIM]);
let embedding_bytes =
crate::store::helpers::embedding_to_bytes(&embedding, crate::EMBEDDING_DIM)
.unwrap();
let now = chrono::Utc::now().to_rfc3339();
sqlx::query(
"INSERT INTO chunks (id, origin, source_type, language, chunk_type, name,
signature, content, content_hash, doc, line_start, line_end, embedding,
source_mtime, created_at, updated_at)
VALUES (?1, ?1, 'file', 'rust', 'function', ?2,
?3, '', '', ?4, 1, 10, ?5, 0, ?6, ?6)",
)
.bind(id)
.bind(name)
.bind(signature)
.bind(doc)
.bind(&embedding_bytes)
.bind(&now)
.execute(&store.pool)
.await
.unwrap();
});
}
/// Upserted vectors come back intact, grouped and ordered by chunk id.
#[test]
fn test_sparse_roundtrip() {
    let (store, _dir) = setup_store();
    for id in ["chunk_a", "chunk_b"] {
        insert_test_chunk(&store, id);
    }
    let input = vec![
        (
            "chunk_a".to_string(),
            vec![(1u32, 0.5f32), (2, 0.3), (3, 0.8)],
        ),
        ("chunk_b".to_string(), vec![(1, 0.7), (4, 0.6)]),
    ];
    // 3 + 2 entries across the two chunks.
    assert_eq!(store.upsert_sparse_vectors(&input).unwrap(), 5);
    let loaded = store.load_all_sparse_vectors().unwrap();
    assert_eq!(loaded.len(), 2);
    assert_eq!(loaded[0].0, "chunk_a");
    assert_eq!(loaded[0].1.len(), 3);
    assert_eq!(loaded[1].0, "chunk_b");
    assert_eq!(loaded[1].1.len(), 2);
}
/// A second upsert for the same chunk fully replaces the earlier rows.
#[test]
fn test_sparse_upsert_replaces() {
    let (store, _dir) = setup_store();
    insert_test_chunk(&store, "chunk_a");
    let first = vec![("chunk_a".to_string(), vec![(1u32, 0.5f32)])];
    store.upsert_sparse_vectors(&first).unwrap();
    let second = vec![("chunk_a".to_string(), vec![(2u32, 0.9f32), (3, 0.1)])];
    store.upsert_sparse_vectors(&second).unwrap();
    let loaded = store.load_all_sparse_vectors().unwrap();
    assert_eq!(loaded.len(), 1);
    // Only the second upsert's tokens remain.
    assert_eq!(loaded[0].1.len(), 2);
    assert_eq!(loaded[0].1[0].0, 2);
}
/// Loading from a store with no sparse rows yields an empty list.
#[test]
fn test_sparse_empty() {
    let (store, _dir) = setup_store();
    assert!(store.load_all_sparse_vectors().unwrap().is_empty());
}
/// The generation counter starts at 0 and strictly increases per upsert.
#[test]
fn test_splade_generation_starts_at_zero_and_is_monotonic() {
    let (store, _dir) = setup_store();
    insert_test_chunk(&store, "c1");
    // Fresh store: no generation recorded yet.
    assert_eq!(store.splade_generation().unwrap(), 0);
    store
        .upsert_sparse_vectors(&[("c1".to_string(), vec![(1u32, 0.5f32)])])
        .unwrap();
    let gen_after_first = store.splade_generation().unwrap();
    assert!(gen_after_first >= 1, "upsert must bump generation");
    store
        .upsert_sparse_vectors(&[("c1".to_string(), vec![(2u32, 0.5f32)])])
        .unwrap();
    let gen_after_second = store.splade_generation().unwrap();
    assert!(
        gen_after_second > gen_after_first,
        "second upsert must bump generation strictly"
    );
}
/// Pruning when nothing is orphaned must leave the generation untouched.
#[test]
fn test_prune_orphan_no_rows_does_not_bump_generation() {
    let (store, _dir) = setup_store();
    insert_test_chunk(&store, "c1");
    store
        .upsert_sparse_vectors(&[("c1".to_string(), vec![(1u32, 0.5f32)])])
        .unwrap();
    let gen_before = store.splade_generation().unwrap();
    let pruned = store.prune_orphan_sparse_vectors().unwrap();
    assert_eq!(pruned, 0, "v19 FK cascade should leave zero orphans");
    assert_eq!(
        gen_before,
        store.splade_generation().unwrap(),
        "zero-orphan prune must NOT bump the generation"
    );
}
/// Deleting a chunk row cascades into `sparse_vectors` via the v19 FK.
#[test]
fn test_fk_cascade_removes_sparse_rows_on_chunk_delete() {
    let (store, _dir) = setup_store();
    for id in ["c1", "c2"] {
        insert_test_chunk(&store, id);
    }
    store
        .upsert_sparse_vectors(&[
            ("c1".to_string(), vec![(1u32, 0.5f32), (2, 0.3)]),
            ("c2".to_string(), vec![(3u32, 0.7f32)]),
        ])
        .unwrap();
    // Delete the chunk directly; the cascade should take its sparse rows.
    store.rt.block_on(async {
        sqlx::query("DELETE FROM chunks WHERE id = 'c1'")
            .execute(&store.pool)
            .await
            .unwrap();
    });
    let remaining = store.load_all_sparse_vectors().unwrap();
    assert_eq!(
        remaining.len(),
        1,
        "cascade should have dropped c1's sparse rows"
    );
    assert_eq!(remaining[0].0, "c2");
}
/// With a non-empty doc the text is "name signature doc".
#[test]
fn test_chunk_splade_texts_with_doc() {
    let (store, _dir) = setup_store();
    insert_chunk_with_fields(&store, "c1", "foo", "fn foo()", Some("does things"));
    let rows = store.chunk_splade_texts().unwrap();
    assert_eq!(rows.len(), 1);
    assert_eq!(rows[0].1, "foo fn foo() does things");
}
/// Without a doc the text is just "name signature".
#[test]
fn test_chunk_splade_texts_no_doc() {
    let (store, _dir) = setup_store();
    insert_chunk_with_fields(&store, "c1", "foo", "fn foo()", None);
    let rows = store.chunk_splade_texts().unwrap();
    assert_eq!(rows.len(), 1);
    assert_eq!(rows[0].1, "foo fn foo()");
}
/// An empty-string doc behaves like a missing doc.
#[test]
fn test_chunk_splade_texts_empty_doc_treated_as_missing() {
    let (store, _dir) = setup_store();
    insert_chunk_with_fields(&store, "c1", "foo", "fn foo()", Some(""));
    let rows = store.chunk_splade_texts().unwrap();
    assert_eq!(rows.len(), 1);
    assert_eq!(
        rows[0].1, "foo fn foo()",
        "empty doc should not be appended"
    );
}
/// Only chunks without sparse rows show up in the "missing" query.
#[test]
fn test_chunk_splade_texts_missing_skips_encoded() {
    let (store, _dir) = setup_store();
    for (id, name, sig) in [
        ("c1", "alpha", "fn alpha()"),
        ("c2", "beta", "fn beta()"),
        ("c3", "gamma", "fn gamma()"),
    ] {
        insert_chunk_with_fields(&store, id, name, sig, None);
    }
    // Encode c1 and c3; c2 stays without sparse rows.
    store
        .upsert_sparse_vectors(&[
            ("c1".to_string(), vec![(1u32, 0.5f32)]),
            ("c3".to_string(), vec![(2u32, 0.7f32)]),
        ])
        .unwrap();
    let missing = store.chunk_splade_texts_missing().unwrap();
    assert_eq!(missing.len(), 1, "only the un-encoded chunk should appear");
    assert_eq!(missing[0].0, "c2");
    assert_eq!(missing[0].1, "beta fn beta()");
}
/// Grouping keeps each chunk's full token list separate from its neighbours.
#[test]
fn test_load_all_sparse_vectors_groups_multiple_tokens() {
    let (store, _dir) = setup_store();
    for id in ["c1", "c2", "c3"] {
        insert_test_chunk(&store, id);
    }
    let vectors = vec![
        (
            "c1".to_string(),
            vec![(10u32, 0.1f32), (11, 0.2), (12, 0.3)],
        ),
        (
            "c2".to_string(),
            vec![(20u32, 0.4f32), (21, 0.5), (22, 0.6), (23, 0.7)],
        ),
        (
            "c3".to_string(),
            vec![(30u32, 0.8f32), (31, 0.9), (32, 1.0)],
        ),
    ];
    store.upsert_sparse_vectors(&vectors).unwrap();
    let loaded = store.load_all_sparse_vectors().unwrap();
    assert_eq!(loaded.len(), 3);
    for (id, expected_tokens) in [("c1", 3usize), ("c2", 4), ("c3", 3)] {
        let entry = loaded.iter().find(|(cid, _)| cid == id).unwrap();
        assert_eq!(
            entry.1.len(),
            expected_tokens,
            "{id} should have {expected_tokens} tokens"
        );
    }
}
}