pub mod embedder;
mod error;
pub use embedder::Embedder;
pub use error::RagError;
use crate::error::Result;
use crate::graph::{get_neighbors, Entity};
use crate::vector::{cosine_similarity, TurboQuantConfig, TurboQuantIndex, VectorStore};
use rusqlite::Connection;
use std::collections::HashMap;
/// A single retrieval hit together with its component scores and the
/// neighbouring entities attached as supporting context.
#[derive(Debug, Clone)]
pub struct RagResult {
    /// The matched entity, loaded via `crate::graph::get_entity`.
    pub entity: Entity,
    /// Exact cosine similarity between the query vector and this entity's
    /// stored vector (widened from f32).
    pub vector_score: f64,
    /// Fraction of the other pooled candidates that are direct (depth-1)
    /// graph neighbours of this entity; 0.0 when the pool has one entry.
    pub graph_score: f64,
    /// `vector_weight * vector_score + graph_weight * graph_score`.
    pub combined_score: f64,
    /// Context neighbours collected after scoring; empty until `search`
    /// fills it for the top-k results.
    pub context_entities: Vec<Entity>,
}
/// Tunables for the hybrid (vector + graph) retrieval pipeline.
#[derive(Debug, Clone)]
pub struct RagConfig {
    /// Weight of the exact vector score in the combined score.
    pub vector_weight: f64,
    /// Weight of the graph-overlap score in the combined score.
    pub graph_weight: f64,
    /// Number of approximate (ANN) candidates fetched in stage 1.
    pub top_k_candidates: usize,
    /// Pool size kept after the exact cosine rerank in stage 2.
    pub top_k_rerank: usize,
    /// Whether to grow the candidate pool via graph neighbours (RAPO).
    pub enable_graph_expansion: bool,
    /// Hop depth used when expanding the pool.
    pub graph_depth: u32,
    /// Hop depth used when collecting context entities per result.
    pub context_depth: u32,
    /// Maximum number of context entities attached to each result.
    pub max_context_entities: usize,
    /// Minimum exact vector score for a candidate to survive reranking
    /// or graph expansion.
    pub min_vector_score: f32,
    /// Minimum combined score for a result to be returned at all.
    pub min_combined_score: f64,
    /// Embedding dimension. NOTE(review): not read anywhere in this
    /// module (the index infers its dimension from stored vectors) —
    /// confirm whether external callers rely on it.
    pub vector_dimension: usize,
}
impl Default for RagConfig {
fn default() -> Self {
Self {
vector_weight: 0.6,
graph_weight: 0.4,
top_k_candidates: 50,
top_k_rerank: 20,
enable_graph_expansion: true,
graph_depth: 1,
context_depth: 2,
max_context_entities: 5,
min_vector_score: 0.0,
min_combined_score: 0.0,
vector_dimension: 384,
}
}
}
/// Hybrid retrieval engine: runs ANN candidate generation, exact reranking,
/// optional graph expansion, combined scoring, and context collection,
/// all driven by a `RagConfig`.
pub struct RagEngine {
    // Immutable pipeline configuration, fixed at construction.
    config: RagConfig,
}
impl RagEngine {
pub fn new(config: RagConfig) -> Self {
Self { config }
}
pub fn search(
&self,
conn: &Connection,
embedder: &dyn Embedder,
query: &str,
k: usize,
) -> Result<Vec<RagResult>> {
let query_vec = embedder.embed(query)?;
let ann_candidates = self.stage1_ann(conn, &query_vec)?;
if ann_candidates.is_empty() {
return Ok(Vec::new());
}
let mut reranked = self.stage2_rerank(conn, &query_vec, ann_candidates)?;
reranked.truncate(self.config.top_k_rerank);
let mut pool: HashMap<i64, f32> = reranked.into_iter().collect();
if self.config.enable_graph_expansion {
self.rapo_expand(conn, &query_vec, &mut pool)?;
}
let pool_size = pool.len();
let mut scored = self.score_and_filter(conn, &pool, pool_size)?;
scored.sort_by(|a, b| b.combined_score.partial_cmp(&a.combined_score).unwrap());
scored.truncate(k);
for result in &mut scored {
let entity_id = result.entity.id.unwrap_or(0);
result.context_entities = self.collect_context(conn, entity_id, &pool)?;
}
Ok(scored)
}
fn stage1_ann(&self, conn: &Connection, query_vec: &[f32]) -> Result<Vec<(i64, f32)>> {
let vector_count: i64 =
conn.query_row("SELECT COUNT(*) FROM kg_vectors", [], |r| r.get(0))?;
if vector_count == 0 {
return Ok(Vec::new());
}
let vectors_checksum: i64 = conn.query_row(
"SELECT COALESCE(SUM(entity_id), 0) FROM kg_vectors",
[],
|r| r.get(0),
)?;
let cached = load_turboquant_cache(conn, vector_count, vectors_checksum)?;
let index = match cached {
Some(idx) => idx,
None => {
let all_vectors = load_all_vectors(conn)?;
let dim = all_vectors[0].1.len();
let config = TurboQuantConfig {
dimension: dim,
bit_width: 3,
seed: 42,
};
let mut idx = TurboQuantIndex::new(config)?;
for (entity_id, vec) in &all_vectors {
idx.add_vector(*entity_id, vec)?;
}
save_turboquant_cache(conn, &idx, vector_count, vectors_checksum)?;
idx
}
};
let k = self.config.top_k_candidates.min(vector_count as usize);
index.search(query_vec, k)
}
fn stage2_rerank(
&self,
conn: &Connection,
query_vec: &[f32],
candidates: Vec<(i64, f32)>,
) -> Result<Vec<(i64, f32)>> {
let store = VectorStore::new();
let mut scored: Vec<(i64, f32)> = Vec::with_capacity(candidates.len());
for (entity_id, approx) in candidates {
if approx < self.config.min_vector_score {
continue;
}
match store.get_vector(conn, entity_id) {
Ok(vec) => {
let exact = cosine_similarity(query_vec, &vec);
if exact >= self.config.min_vector_score {
scored.push((entity_id, exact));
}
}
Err(_) => {
}
}
}
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
Ok(scored)
}
fn rapo_expand(
&self,
conn: &Connection,
query_vec: &[f32],
pool: &mut HashMap<i64, f32>,
) -> Result<()> {
let store = VectorStore::new();
let seeds: Vec<i64> = pool.keys().copied().collect();
for seed_id in seeds {
let neighbours = match get_neighbors(conn, seed_id, self.config.graph_depth) {
Ok(n) => n,
Err(_) => continue,
};
for nbr in neighbours {
let nbr_id = match nbr.entity.id {
Some(id) => id,
None => continue,
};
if pool.contains_key(&nbr_id) {
continue;
}
if let Ok(vec) = store.get_vector(conn, nbr_id) {
let score = cosine_similarity(query_vec, &vec);
if score >= self.config.min_vector_score {
pool.insert(nbr_id, score);
}
}
}
}
Ok(())
}
fn score_and_filter(
&self,
conn: &Connection,
pool: &HashMap<i64, f32>,
pool_size: usize,
) -> Result<Vec<RagResult>> {
let mut results = Vec::new();
for (&entity_id, &v_score) in pool {
let vector_score = v_score as f64;
let graph_score = if pool_size > 1 {
let neighbours = get_neighbors(conn, entity_id, 1).unwrap_or_default();
let overlap = neighbours
.iter()
.filter(|n| {
n.entity
.id
.map(|id| pool.contains_key(&id))
.unwrap_or(false)
})
.count();
overlap as f64 / (pool_size - 1) as f64
} else {
0.0
};
let combined_score =
self.config.vector_weight * vector_score + self.config.graph_weight * graph_score;
if combined_score < self.config.min_combined_score {
continue;
}
let entity = match crate::graph::get_entity(conn, entity_id) {
Ok(e) => e,
Err(_) => continue,
};
results.push(RagResult {
entity,
vector_score,
graph_score,
combined_score,
context_entities: Vec::new(), });
}
Ok(results)
}
fn collect_context(
&self,
conn: &Connection,
entity_id: i64,
pool: &HashMap<i64, f32>,
) -> Result<Vec<Entity>> {
let neighbours = match get_neighbors(conn, entity_id, self.config.context_depth) {
Ok(n) => n,
Err(_) => return Ok(Vec::new()),
};
let mut in_pool: Vec<Entity> = Vec::new();
let mut not_in_pool: Vec<Entity> = Vec::new();
for nbr in neighbours {
if let Some(id) = nbr.entity.id {
if pool.contains_key(&id) {
in_pool.push(nbr.entity);
} else {
not_in_pool.push(nbr.entity);
}
}
}
in_pool.extend(not_in_pool);
in_pool.truncate(self.config.max_context_entities);
Ok(in_pool)
}
}
fn load_turboquant_cache(
conn: &Connection,
current_count: i64,
current_checksum: i64,
) -> Result<Option<TurboQuantIndex>> {
let mut stmt = conn.prepare(
"SELECT index_blob, vector_count, vectors_checksum \
FROM kg_turboquant_cache WHERE id = 1",
)?;
let result = stmt.query_row([], |row| {
let blob: Vec<u8> = row.get(0)?;
let cached_count: i64 = row.get(1)?;
let cached_checksum: i64 = row.get(2)?;
Ok((blob, cached_count, cached_checksum))
});
match result {
Ok((blob, cached_count, cached_checksum))
if cached_count == current_count && cached_checksum == current_checksum =>
{
let index = TurboQuantIndex::from_bytes(&blob)
.map_err(|e| crate::error::Error::Other(e.to_string()))?;
Ok(Some(index))
}
Ok(_) | Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(e.into()),
}
}
fn save_turboquant_cache(
conn: &Connection,
index: &TurboQuantIndex,
vector_count: i64,
vectors_checksum: i64,
) -> Result<()> {
let blob = index
.to_bytes()
.map_err(|e| crate::error::Error::Other(e.to_string()))?;
conn.execute(
"INSERT INTO kg_turboquant_cache \
(id, index_blob, vector_count, vectors_checksum) \
VALUES (1, ?1, ?2, ?3) \
ON CONFLICT(id) DO UPDATE SET \
index_blob = excluded.index_blob, \
vector_count = excluded.vector_count, \
vectors_checksum = excluded.vectors_checksum",
rusqlite::params![blob, vector_count, vectors_checksum],
)?;
Ok(())
}
fn load_all_vectors(conn: &Connection) -> Result<Vec<(i64, Vec<f32>)>> {
let mut stmt = conn.prepare("SELECT entity_id, vector, dimension FROM kg_vectors")?;
let rows = stmt.query_map([], |row| {
let entity_id: i64 = row.get(0)?;
let blob: Vec<u8> = row.get(1)?;
let dim: i64 = row.get(2)?;
let mut vec = Vec::with_capacity(dim as usize);
for chunk in blob.chunks_exact(4) {
vec.push(f32::from_le_bytes(chunk.try_into().unwrap()));
}
Ok((entity_id, vec))
})?;
let mut out = Vec::new();
for row in rows {
out.push(row?);
}
Ok(out)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::entity::{insert_entity, Entity};
    use crate::graph::relation::{insert_relation, Relation};
    use crate::rag::embedder::FixedEmbedder;
    use crate::vector::VectorStore;
    use rusqlite::Connection;

    // Builds an in-memory DB with three "doc" entities carrying
    // dim-dimensional vectors (e1 along axis 0, e2 along axis 1, e3 a
    // 0.8/0.6 mix) plus two relations out of e1. Returns the connection
    // and the entity ids in insertion order.
    fn setup(dim: usize) -> (Connection, Vec<i64>) {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();
        let e1 = insert_entity(&conn, &Entity::new("doc", "Doc A")).unwrap();
        let e2 = insert_entity(&conn, &Entity::new("doc", "Doc B")).unwrap();
        let e3 = insert_entity(&conn, &Entity::new("doc", "Doc C")).unwrap();
        let store = VectorStore::new();
        let mut v1 = vec![0.0f32; dim];
        v1[0] = 1.0;
        store.insert_vector(&conn, e1, v1).unwrap();
        let mut v2 = vec![0.0f32; dim];
        v2[1] = 1.0;
        store.insert_vector(&conn, e2, v2).unwrap();
        let mut v3 = vec![0.0f32; dim];
        v3[0] = 0.8;
        v3[1] = 0.6;
        store.insert_vector(&conn, e3, v3).unwrap();
        insert_relation(&conn, &Relation::new(e1, e2, "related", 0.3).unwrap()).unwrap();
        insert_relation(&conn, &Relation::new(e1, e3, "related", 0.9).unwrap()).unwrap();
        (conn, vec![e1, e2, e3])
    }

    // A query aligned with e1's axis should rank e1 first with a perfect
    // vector score.
    #[test]
    fn test_basic_search() {
        let dim = 4;
        let (conn, ids) = setup(dim);
        let mut query = vec![0.0f32; dim];
        query[0] = 1.0;
        let embedder = FixedEmbedder(query);
        let engine = RagEngine::new(RagConfig {
            vector_dimension: dim,
            top_k_candidates: 10,
            top_k_rerank: 5,
            ..Default::default()
        });
        let results = engine.search(&conn, &embedder, "test query", 2).unwrap();
        assert!(!results.is_empty(), "should return at least one result");
        assert_eq!(results[0].entity.id, Some(ids[0]));
        assert!((results[0].vector_score - 1.0).abs() < 1e-5);
    }

    // Searching a schema with zero vectors must return an empty result
    // set rather than erroring.
    #[test]
    fn test_empty_db() {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();
        let embedder = FixedEmbedder(vec![1.0, 0.0, 0.0]);
        let engine = RagEngine::new(RagConfig::default());
        let results = engine.search(&conn, &embedder, "anything", 5).unwrap();
        assert!(results.is_empty());
    }

    // With top_k of 1, only e1 survives reranking; e2 should still appear
    // via RAPO graph expansion over the e1->e2 relation.
    #[test]
    fn test_graph_expansion() {
        let dim = 4;
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();
        let store = VectorStore::new();
        let e1 = insert_entity(&conn, &Entity::new("doc", "A")).unwrap();
        let e2 = insert_entity(&conn, &Entity::new("doc", "B")).unwrap();
        let mut v1 = vec![0.0f32; dim];
        v1[0] = 1.0;
        store.insert_vector(&conn, e1, v1).unwrap();
        let mut v2 = vec![0.0f32; dim];
        v2[1] = 1.0;
        store.insert_vector(&conn, e2, v2).unwrap();
        insert_relation(&conn, &Relation::new(e1, e2, "link", 1.0).unwrap()).unwrap();
        let mut query = vec![0.0f32; dim];
        query[0] = 1.0;
        let embedder = FixedEmbedder(query);
        let engine = RagEngine::new(RagConfig {
            vector_dimension: dim,
            top_k_candidates: 1, top_k_rerank: 1,
            enable_graph_expansion: true,
            ..Default::default()
        });
        let results = engine.search(&conn, &embedder, "q", 5).unwrap();
        let ids: Vec<i64> = results.iter().filter_map(|r| r.entity.id).collect();
        assert!(ids.contains(&e1));
        assert!(ids.contains(&e2), "RAPO should expand to e2");
    }

    // e1 has two related entities in setup(), so its result should carry
    // a non-empty context list.
    #[test]
    fn test_context_attached() {
        let dim = 4;
        let (conn, ids) = setup(dim);
        let mut query = vec![0.0f32; dim];
        query[0] = 1.0;
        let embedder = FixedEmbedder(query);
        let engine = RagEngine::new(RagConfig {
            vector_dimension: dim,
            context_depth: 1,
            max_context_entities: 3,
            ..Default::default()
        });
        let results = engine.search(&conn, &embedder, "q", 3).unwrap();
        let e1_result = results.iter().find(|r| r.entity.id == Some(ids[0]));
        assert!(e1_result.is_some());
        let ctx = &e1_result.unwrap().context_entities;
        assert!(!ctx.is_empty(), "e1 should have context neighbours");
    }

    // The first search must populate the TurboQuant cache row (id = 1).
    #[test]
    fn test_cache_written_on_first_query() {
        let dim = 4;
        let (conn, _ids) = setup(dim);
        let mut query = vec![0.0f32; dim];
        query[0] = 1.0;
        let embedder = FixedEmbedder(query);
        let engine = RagEngine::new(RagConfig {
            vector_dimension: dim,
            top_k_candidates: 10,
            top_k_rerank: 5,
            ..Default::default()
        });
        engine.search(&conn, &embedder, "q", 2).unwrap();
        let count: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM kg_turboquant_cache WHERE id = 1",
                [],
                |r| r.get(0),
            )
            .unwrap();
        assert_eq!(count, 1, "cache row should be created after first query");
    }

    // A repeated identical search (served from the cached index) must
    // return the same top result as the first.
    #[test]
    fn test_cache_hit_on_second_query() {
        let dim = 4;
        let (conn, _ids) = setup(dim);
        let mut query = vec![0.0f32; dim];
        query[0] = 1.0;
        let embedder = FixedEmbedder(query);
        let engine = RagEngine::new(RagConfig {
            vector_dimension: dim,
            top_k_candidates: 10,
            top_k_rerank: 5,
            ..Default::default()
        });
        let r1 = engine.search(&conn, &embedder, "q", 2).unwrap();
        let r2 = engine.search(&conn, &embedder, "q", 2).unwrap();
        assert_eq!(
            r1[0].entity.id, r2[0].entity.id,
            "cache hit should return identical results"
        );
    }

    // The cache row must record the vector count and the entity-id-sum
    // checksum used for staleness detection.
    #[test]
    fn test_cache_stores_checksum() {
        let dim = 4;
        let (conn, _ids) = setup(dim);
        let query = {
            let mut q = vec![0.0f32; dim];
            q[0] = 1.0;
            q
        };
        let embedder = FixedEmbedder(query);
        let engine = RagEngine::new(RagConfig {
            vector_dimension: dim,
            top_k_candidates: 10,
            top_k_rerank: 5,
            ..Default::default()
        });
        engine.search(&conn, &embedder, "q", 2).unwrap();
        let (count, checksum): (i64, i64) = conn
            .query_row(
                "SELECT vector_count, vectors_checksum FROM kg_turboquant_cache WHERE id = 1",
                [],
                |r| Ok((r.get(0)?, r.get(1)?)),
            )
            .unwrap();
        assert_eq!(count, 3);
        assert!(checksum > 0, "checksum should reflect entity_id sum");
    }

    // Swapping one vectorised entity for another keeps the count at 3 but
    // changes the id-sum checksum, so the cache must be rebuilt — this is
    // exactly the case a count-only fingerprint would miss.
    #[test]
    fn test_cache_invalidated_on_same_count_different_entity() {
        let dim = 4;
        let (conn, ids) = setup(dim);
        let query = {
            let mut q = vec![0.0f32; dim];
            q[0] = 1.0;
            q
        };
        let embedder = FixedEmbedder(query);
        let engine = RagEngine::new(RagConfig {
            vector_dimension: dim,
            top_k_candidates: 10,
            top_k_rerank: 5,
            ..Default::default()
        });
        engine.search(&conn, &embedder, "q", 2).unwrap();
        let checksum_before: i64 = conn
            .query_row(
                "SELECT vectors_checksum FROM kg_turboquant_cache WHERE id = 1",
                [],
                |r| r.get(0),
            )
            .unwrap();
        // Remove e3's vector and add a vector for a brand-new entity.
        conn.execute("DELETE FROM kg_vectors WHERE entity_id = ?1", [ids[2]])
            .unwrap();
        let e_new = crate::graph::entity::insert_entity(
            &conn,
            &crate::graph::entity::Entity::new("doc", "Doc Swap"),
        )
        .unwrap();
        let store = VectorStore::new();
        let mut v_new = vec![0.0f32; dim];
        v_new[3] = 1.0;
        store.insert_vector(&conn, e_new, v_new).unwrap();
        engine.search(&conn, &embedder, "q", 2).unwrap();
        let (count_after, checksum_after): (i64, i64) = conn
            .query_row(
                "SELECT vector_count, vectors_checksum FROM kg_turboquant_cache WHERE id = 1",
                [],
                |r| Ok((r.get(0)?, r.get(1)?)),
            )
            .unwrap();
        assert_eq!(count_after, 3, "vector count should still be 3 after swap");
        assert_ne!(
            checksum_after, checksum_before,
            "checksum must change after swapping one vector"
        );
    }

    // Adding a fourth vector changes the count, so the next search must
    // rebuild the cache and record the new count.
    #[test]
    fn test_cache_invalidated_after_new_vector() {
        let dim = 4;
        let (conn, _ids) = setup(dim);
        let mut query = vec![0.0f32; dim];
        query[0] = 1.0;
        let embedder = FixedEmbedder(query);
        let engine = RagEngine::new(RagConfig {
            vector_dimension: dim,
            top_k_candidates: 10,
            top_k_rerank: 5,
            ..Default::default()
        });
        engine.search(&conn, &embedder, "q", 2).unwrap();
        let cached_count_before: i64 = conn
            .query_row(
                "SELECT vector_count FROM kg_turboquant_cache WHERE id = 1",
                [],
                |r| r.get(0),
            )
            .unwrap();
        assert_eq!(cached_count_before, 3);
        let e4 = crate::graph::entity::insert_entity(
            &conn,
            &crate::graph::entity::Entity::new("doc", "Doc D"),
        )
        .unwrap();
        let store = VectorStore::new();
        let mut v4 = vec![0.0f32; dim];
        v4[2] = 1.0;
        store.insert_vector(&conn, e4, v4).unwrap();
        engine.search(&conn, &embedder, "q", 2).unwrap();
        let cached_count_after: i64 = conn
            .query_row(
                "SELECT vector_count FROM kg_turboquant_cache WHERE id = 1",
                [],
                |r| r.get(0),
            )
            .unwrap();
        assert_eq!(
            cached_count_after, 4,
            "cache should be rebuilt after new vector added"
        );
    }
}