use rusqlite::{Connection, params};
use solo_core::{Error, Result, VectorIndex};
/// Outcome of a [`replay_pending_index`] run.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ReplayReport {
    /// Rows read from `pending_index` (only rows that join to an existing episode).
    pub rows_seen: usize,
    /// Rows added to the vector index and then deleted from `pending_index`.
    pub rows_replayed: usize,
    /// Rows that could not be decoded, indexed, or deleted; they stay queued.
    pub rows_failed: usize,
}
/// Comparison between the SQL episode count and the vector index size,
/// as produced by [`detect_drift`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DriftReport {
    /// Number of active episodes in the `hot` tier according to SQL.
    pub hot_episodes: usize,
    /// Number of entries reported by the vector index.
    pub index_len: usize,
    /// `hot_episodes - index_len`: positive when the index holds fewer
    /// entries than SQL, negative when it holds more.
    pub diff: i64,
}

impl DriftReport {
    /// Returns `true` when the SQL count and the index size agree exactly.
    #[must_use]
    pub fn is_clean(&self) -> bool {
        self.diff == 0
    }
}
pub fn replay_pending_index(
conn: &mut Connection,
hnsw: &dyn VectorIndex,
) -> Result<ReplayReport> {
let mut report = ReplayReport {
rows_seen: 0,
rows_replayed: 0,
rows_failed: 0,
};
let rows: Vec<(String, i64, Vec<u8>, i64)> = {
let mut stmt = conn
.prepare(
"SELECT p.memory_id, e.rowid, p.embedding, p.embedding_dim
FROM pending_index p
JOIN episodes e ON e.memory_id = p.memory_id
ORDER BY p.enqueued_at",
)
.map_err(|e| Error::storage(format!("prepare pending_index select: {e}")))?;
let mapped = stmt
.query_map([], |row| {
Ok((
row.get::<_, String>(0)?,
row.get::<_, i64>(1)?,
row.get::<_, Vec<u8>>(2)?,
row.get::<_, i64>(3)?,
))
})
.map_err(|e| Error::storage(format!("query_map pending_index: {e}")))?;
let mut out = Vec::new();
for r in mapped {
out.push(r.map_err(|e| Error::storage(format!("row decode: {e}")))?);
}
out
};
for (memory_id, rowid, blob, dim) in rows {
report.rows_seen += 1;
let dim = dim as usize;
if blob.len() != dim * 4 {
tracing::warn!(
%memory_id,
blob_len = blob.len(),
expected = dim * 4,
"pending_index row size mismatch (not F32×dim); skipping"
);
report.rows_failed += 1;
continue;
}
let slice: &[f32] = match bytemuck::try_cast_slice::<u8, f32>(&blob) {
Ok(s) => s,
Err(e) => {
tracing::warn!(
%memory_id,
error = %e,
"pending_index blob alignment cast failed; skipping"
);
report.rows_failed += 1;
continue;
}
};
if let Err(e) = hnsw.add(rowid, slice) {
tracing::warn!(%memory_id, error = %e, "hnsw.add during replay failed");
report.rows_failed += 1;
continue;
}
match conn.execute(
"DELETE FROM pending_index WHERE memory_id = ?",
params![memory_id],
) {
Ok(_) => report.rows_replayed += 1,
Err(e) => {
tracing::warn!(%memory_id, error = %e, "drain after replay failed");
report.rows_failed += 1;
}
}
}
tracing::info!(
seen = report.rows_seen,
replayed = report.rows_replayed,
failed = report.rows_failed,
"pending_index replay complete"
);
Ok(report)
}
/// Outcome of a [`rebuild_hnsw_from_sql`] run.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RebuildReport {
    /// Rows returned by the rebuild query, including ones that failed to decode.
    pub rows_seen: usize,
    /// Vectors successfully added to the index.
    pub rows_added: usize,
    /// Rows that could not be decoded or were rejected by the index.
    pub rows_skipped: usize,
}
pub fn rebuild_hnsw_from_sql(
conn: &Connection,
hnsw: &dyn VectorIndex,
current_embedder_id: i64,
) -> Result<RebuildReport> {
let mut stmt = conn
.prepare(
"SELECT e.rowid, em.vector, em.dim
FROM episodes e
JOIN embeddings em ON em.memory_id = e.memory_id
WHERE em.embedder_id = ?1
AND e.status = 'active'
ORDER BY e.rowid",
)
.map_err(|e| Error::storage(format!("prepare rebuild_hnsw_from_sql: {e}")))?;
let rows = stmt
.query_map(rusqlite::params![current_embedder_id], |r| {
Ok((
r.get::<_, i64>(0)?,
r.get::<_, Vec<u8>>(1)?,
r.get::<_, i64>(2)?,
))
})
.map_err(|e| Error::storage(format!("query_map rebuild_hnsw_from_sql: {e}")))?;
let mut report = RebuildReport::default();
for row in rows {
report.rows_seen += 1;
let (rowid, blob, dim) = match row {
Ok(r) => r,
Err(e) => {
tracing::warn!(error = %e, "rebuild_hnsw_from_sql: row decode failed; skipping");
report.rows_skipped += 1;
continue;
}
};
let dim = dim as usize;
if blob.len() != dim * 4 {
tracing::warn!(
rowid,
blob_len = blob.len(),
expected = dim * 4,
"rebuild_hnsw_from_sql: f32-vector size mismatch; skipping (run `solo reembed` to overwrite)"
);
report.rows_skipped += 1;
continue;
}
let slice: &[f32] = match bytemuck::try_cast_slice(&blob) {
Ok(s) => s,
Err(e) => {
tracing::warn!(
rowid,
error = %e,
"rebuild_hnsw_from_sql: blob alignment cast failed; skipping"
);
report.rows_skipped += 1;
continue;
}
};
if let Err(e) = hnsw.add(rowid, slice) {
tracing::warn!(rowid, error = %e, "rebuild_hnsw_from_sql: hnsw.add failed; skipping");
report.rows_skipped += 1;
continue;
}
report.rows_added += 1;
}
Ok(report)
}
pub fn detect_drift(conn: &Connection, hnsw: &dyn VectorIndex) -> Result<DriftReport> {
let hot_episodes: i64 = conn
.query_row(
"SELECT COUNT(*) FROM episodes WHERE tier = 'hot' AND status = 'active'",
[],
|r| r.get(0),
)
.map_err(|e| Error::storage(format!("count hot episodes: {e}")))?;
let index_len = hnsw.len();
let diff = hot_episodes - (index_len as i64);
Ok(DriftReport {
hot_episodes: hot_episodes as usize,
index_len,
diff,
})
}
#[cfg(test)]
mod tests {
// Tests for `replay_pending_index` and `detect_drift`. They run against a
// database from `open_test_db`, with a `StubVectorIndex` standing in for the
// real index and recording the `add` calls it receives.
use super::*;
use crate::test_support::{StubVectorIndex, fixture_episode, open_test_db};
use rusqlite::params;
use solo_core::{Tier, VectorIndex};
/// Inserts `fixture_episode(content)` into `episodes` and returns its
/// `memory_id` (as a string) together with the SQLite `rowid` it was given.
/// `encoding_context_json` is stored as `"{}"` and `provenance_json` as NULL;
/// `status` is not set here, so it takes whatever the schema provides.
fn insert_episode(conn: &Connection, content: &str) -> (String, i64) {
let ep = fixture_episode(content);
let memory_id = ep.memory_id.to_string();
let now_ms = chrono::Utc::now().timestamp_millis();
// Map the `Tier` enum onto the text stored in the `tier` column.
let tier = match ep.tier {
Tier::Hot => "hot",
Tier::Warm => "warm",
Tier::Cold => "cold",
};
conn.execute(
"INSERT INTO episodes (
memory_id, ts_ms, source_type, source_id, content,
encoding_context_json, provenance_json, confidence,
strength, salience, tier, created_at_ms, updated_at_ms
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
params![
memory_id,
ep.ts_ms,
ep.source_type,
ep.source_id,
ep.content,
"{}",
Option::<String>::None,
ep.confidence.0,
ep.strength,
ep.salience,
tier,
now_ms,
now_ms,
],
)
.unwrap();
let rowid = conn.last_insert_rowid();
(memory_id, rowid)
}
/// Queues an all-zero vector of `dim` f32s (`dim * 4` zero bytes) for
/// `memory_id` in `pending_index`, with `enqueued_at = 0`.
fn enqueue_pending(conn: &Connection, memory_id: &str, dim: usize) {
let zeros = vec![0u8; dim * 4];
conn.execute(
"INSERT INTO pending_index (memory_id, embedding, embedding_dim, enqueued_at)
VALUES (?, ?, ?, ?)",
params![memory_id, &zeros[..], dim as i64, 0i64],
)
.unwrap();
}
// Two well-formed queued rows: both are added to the index under their
// episode rowids, and `pending_index` ends up empty.
#[test]
fn replay_drains_all_rows_and_calls_add() {
let (mut conn, _tmp) = open_test_db();
let (mid_a, rowid_a) = insert_episode(&conn, "a");
let (mid_b, rowid_b) = insert_episode(&conn, "b");
enqueue_pending(&conn, &mid_a, 4);
enqueue_pending(&conn, &mid_b, 4);
let stub = StubVectorIndex::new(4);
let report = replay_pending_index(&mut conn, &stub).unwrap();
assert_eq!(report.rows_seen, 2);
assert_eq!(report.rows_replayed, 2);
assert_eq!(report.rows_failed, 0);
assert_eq!(stub.add_count(), 2);
let n: i64 = conn
.query_row("SELECT COUNT(*) FROM pending_index", [], |r| r.get(0))
.unwrap();
assert_eq!(n, 0);
// Compare as sets: the test does not pin the order of `add` calls.
let entries = stub.entries();
let added_rowids: std::collections::HashSet<i64> =
entries.iter().map(|(r, _)| *r).collect();
let expected: std::collections::HashSet<i64> =
[rowid_a, rowid_b].into_iter().collect();
assert_eq!(added_rowids, expected);
}
// The first replay drains the only queued row. The second therefore sees
// nothing and must not call `add` again.
#[test]
fn replay_is_idempotent_when_run_twice() {
let (mut conn, _tmp) = open_test_db();
let (mid, _rowid) = insert_episode(&conn, "x");
enqueue_pending(&conn, &mid, 4);
let stub = StubVectorIndex::new(4);
let r1 = replay_pending_index(&mut conn, &stub).unwrap();
assert_eq!(r1.rows_replayed, 1);
let r2 = replay_pending_index(&mut conn, &stub).unwrap();
assert_eq!(r2.rows_seen, 0);
assert_eq!(r2.rows_replayed, 0);
assert_eq!(stub.add_count(), 1, "no extra add on second run");
}
// The "bad" row stores 8 bytes while claiming dim 4 (16 bytes expected).
// It is counted as failed and stays queued, so after the replay it is the
// only row left in `pending_index`.
#[test]
fn replay_skips_size_mismatch_rows() {
let (mut conn, _tmp) = open_test_db();
let (mid_good, _) = insert_episode(&conn, "good");
let (mid_bad, _) = insert_episode(&conn, "bad");
enqueue_pending(&conn, &mid_good, 4);
conn.execute(
"INSERT INTO pending_index (memory_id, embedding, embedding_dim, enqueued_at)
VALUES (?, ?, ?, ?)",
params![mid_bad, &vec![0u8; 8][..], 4i64, 0i64],
)
.unwrap();
let stub = StubVectorIndex::new(4);
let report = replay_pending_index(&mut conn, &stub).unwrap();
assert_eq!(report.rows_seen, 2);
assert_eq!(report.rows_replayed, 1);
assert_eq!(report.rows_failed, 1);
let stuck: String = conn
.query_row(
"SELECT memory_id FROM pending_index",
[],
|r| r.get(0),
)
.unwrap();
assert_eq!(stuck, mid_bad);
}
// Two episodes in SQL and two vectors in the stub: no drift.
// NOTE(review): the `hot_episodes == 2` assertion presumably relies on
// `fixture_episode` yielding hot-tier episodes and on `status` defaulting to
// 'active' in the schema. Confirm against test_support and the migrations.
// The rowids passed to `stub.add` are arbitrary, because `detect_drift`
// only compares counts.
#[test]
fn drift_clean_when_index_matches_episodes() {
let (conn, _tmp) = open_test_db();
let _ = insert_episode(&conn, "a");
let _ = insert_episode(&conn, "b");
let stub = StubVectorIndex::new(4);
stub.add(1, &[0.0; 4]).unwrap();
stub.add(2, &[0.0; 4]).unwrap();
let drift = detect_drift(&conn, &stub).unwrap();
assert_eq!(drift.hot_episodes, 2);
assert_eq!(drift.index_len, 2);
assert!(drift.is_clean());
}
// Three episodes in SQL but only one indexed vector: diff is +2 (the index lags).
#[test]
fn drift_positive_when_index_lags_sql() {
let (conn, _tmp) = open_test_db();
let _ = insert_episode(&conn, "a");
let _ = insert_episode(&conn, "b");
let _ = insert_episode(&conn, "c");
let stub = StubVectorIndex::new(4);
stub.add(1, &[0.0; 4]).unwrap();
let drift = detect_drift(&conn, &stub).unwrap();
assert_eq!(drift.hot_episodes, 3);
assert_eq!(drift.index_len, 1);
assert_eq!(drift.diff, 2);
assert!(!drift.is_clean());
}
}