pub mod algorithms;
pub mod embed;
pub mod error;
pub mod export;
pub mod extension;
pub mod functions;
pub mod graph;
pub mod migrate;
pub mod rag;
pub mod schema;
pub mod vector;
#[cfg(feature = "async")]
pub mod async_kg;
#[cfg(feature = "async")]
pub use async_kg::embed::AsyncEmbeddingGenerator;
#[cfg(feature = "async")]
pub use async_kg::AsyncKnowledgeGraph;
pub(crate) fn row_get_weight(row: &rusqlite::Row, col: usize) -> rusqlite::Result<f64> {
use rusqlite::types::ValueRef;
match row.get_ref(col)? {
ValueRef::Real(f) => Ok(f),
ValueRef::Integer(i) => Ok(i as f64),
ValueRef::Null => Ok(1.0), ValueRef::Blob(b) if b.len() == 8 => {
let mut bytes = [0u8; 8];
bytes.copy_from_slice(b);
Ok(f64::from_le_bytes(bytes)) }
ValueRef::Blob(b) if b.len() == 4 => {
let mut bytes = [0u8; 4];
bytes.copy_from_slice(b);
Ok(f32::from_le_bytes(bytes) as f64) }
ValueRef::Blob(_) => Ok(1.0), _ => Err(rusqlite::Error::InvalidColumnType(
col,
"weight".into(),
rusqlite::types::Type::Real,
)),
}
}
pub use algorithms::{
analyze_graph, connected_components, louvain_communities, pagerank, CommunityResult,
PageRankConfig,
};
pub use embed::{
check_dependencies, get_entities_needing_embedding, EmbeddingConfig, EmbeddingGenerator,
EmbeddingStats,
};
pub use error::{Error, Result};
pub use export::{D3ExportGraph, D3ExportMetadata, D3Link, D3Node, DotConfig};
pub use extension::sqlite3_sqlite_knowledge_graph_init;
pub use functions::register_functions;
pub use graph::{Direction, GraphStats, PathStep, TraversalNode, TraversalPath, TraversalQuery};
pub use graph::{Entity, Neighbor, Relation};
pub use graph::{HigherOrderNeighbor, HigherOrderPath, HigherOrderPathStep, Hyperedge};
pub use migrate::{
build_relationships, migrate_all, migrate_papers, migrate_skills, MigrationStats,
};
pub use rag::{embedder::Embedder, embedder::FixedEmbedder, RagConfig, RagEngine, RagResult};
pub use schema::{create_schema, schema_exists};
pub use vector::{cosine_similarity, SearchResult, VectorStore};
pub use vector::{TurboQuantConfig, TurboQuantIndex, TurboQuantStats};
use rusqlite::Connection;
use serde::{Deserialize, Serialize};
/// A vector-search hit resolved to the full [`Entity`] it refers to.
///
/// Produced by [`KnowledgeGraph::kg_semantic_search`] and
/// [`KnowledgeGraph::kg_similar_entities`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResultWithEntity {
    /// The matched entity, loaded from the graph by id.
    pub entity: Entity,
    /// Similarity score reported by the vector search for this entity.
    pub similarity: f32,
}
/// An entity together with its neighbourhood, as built by
/// [`KnowledgeGraph::kg_get_context`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphContext {
    /// The entity the context was built around.
    pub root_entity: Entity,
    /// Neighbours of `root_entity`, as returned by [`KnowledgeGraph::get_neighbors`].
    pub neighbors: Vec<Neighbor>,
}
/// A vector-search hit enriched with the matched entity's graph context.
///
/// Produced by [`KnowledgeGraph::kg_hybrid_search`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridSearchResult {
    /// The matched entity.
    pub entity: Entity,
    /// Similarity score reported by the vector search.
    pub similarity: f32,
    /// Neighbourhood of `entity`. `KnowledgeGraph::kg_hybrid_search` sets this
    /// to `Some` (depth-1 context); the type still allows `None`.
    pub context: Option<GraphContext>,
}
/// Synchronous handle to a knowledge graph stored in a SQLite database.
///
/// Every operation goes through the single owned [`rusqlite::Connection`].
#[derive(Debug)]
pub struct KnowledgeGraph {
    conn: Connection,
}
impl KnowledgeGraph {
pub fn open<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
let conn = Connection::open(path)?;
conn.execute("PRAGMA foreign_keys = ON", [])?;
if !schema_exists(&conn)? {
create_schema(&conn)?;
}
register_functions(&conn)?;
Ok(Self { conn })
}
pub fn open_in_memory() -> Result<Self> {
let conn = Connection::open_in_memory()?;
conn.execute("PRAGMA foreign_keys = ON", [])?;
create_schema(&conn)?;
register_functions(&conn)?;
Ok(Self { conn })
}
pub fn connection(&self) -> &Connection {
&self.conn
}
pub fn transaction(&self) -> Result<rusqlite::Transaction<'_>> {
Ok(self.conn.unchecked_transaction()?)
}
pub fn insert_entity(&self, entity: &Entity) -> Result<i64> {
graph::insert_entity(&self.conn, entity)
}
pub fn get_entity(&self, id: i64) -> Result<Entity> {
graph::get_entity(&self.conn, id)
}
pub fn list_entities(
&self,
entity_type: Option<&str>,
limit: Option<i64>,
) -> Result<Vec<Entity>> {
graph::list_entities(&self.conn, entity_type, limit)
}
pub fn update_entity(&self, entity: &Entity) -> Result<()> {
graph::update_entity(&self.conn, entity)
}
pub fn delete_entity(&self, id: i64) -> Result<()> {
graph::delete_entity(&self.conn, id)
}
pub fn insert_relation(&self, relation: &Relation) -> Result<i64> {
graph::insert_relation(&self.conn, relation)
}
pub fn get_neighbors(&self, entity_id: i64, depth: u32) -> Result<Vec<Neighbor>> {
graph::get_neighbors(&self.conn, entity_id, depth)
}
pub fn insert_vector(&self, entity_id: i64, vector: Vec<f32>) -> Result<()> {
let store = VectorStore::new();
store.insert_vector(&self.conn, entity_id, vector)
}
pub fn search_vectors(&self, query: Vec<f32>, k: usize) -> Result<Vec<SearchResult>> {
let store = VectorStore::new();
store.search_vectors(&self.conn, query, k)
}
pub fn create_turboquant_index(
&self,
config: Option<TurboQuantConfig>,
) -> Result<TurboQuantIndex> {
let config = config.unwrap_or_default();
TurboQuantIndex::new(config)
}
pub fn build_turboquant_index(
&self,
config: Option<TurboQuantConfig>,
) -> Result<TurboQuantIndex> {
let dimension = self.get_vector_dimension()?.unwrap_or(384);
let config = config.unwrap_or(TurboQuantConfig {
dimension,
bit_width: 3,
seed: 42,
});
let mut index = TurboQuantIndex::new(config)?;
let vectors = self.load_all_vectors()?;
for (entity_id, vector) in vectors {
index.add_vector(entity_id, &vector)?;
}
Ok(index)
}
fn get_vector_dimension(&self) -> Result<Option<usize>> {
let result = self
.conn
.query_row("SELECT dimension FROM kg_vectors LIMIT 1", [], |row| {
row.get::<_, i64>(0)
});
match result {
Ok(dim) => Ok(Some(dim as usize)),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(e.into()),
}
}
fn load_all_vectors(&self) -> Result<Vec<(i64, Vec<f32>)>> {
let mut stmt = self
.conn
.prepare("SELECT entity_id, vector, dimension FROM kg_vectors")?;
let rows = stmt.query_map([], |row| {
let entity_id: i64 = row.get(0)?;
let vector_blob: Vec<u8> = row.get(1)?;
let dimension: i64 = row.get(2)?;
let mut vector = Vec::with_capacity(dimension as usize);
for chunk in vector_blob.chunks_exact(4) {
let bytes: [u8; 4] = chunk.try_into().unwrap();
vector.push(f32::from_le_bytes(bytes));
}
Ok((entity_id, vector))
})?;
let mut vectors = Vec::new();
for row in rows {
vectors.push(row?);
}
Ok(vectors)
}
pub fn insert_hyperedge(&self, hyperedge: &Hyperedge) -> Result<i64> {
graph::insert_hyperedge(&self.conn, hyperedge)
}
pub fn get_hyperedge(&self, id: i64) -> Result<Hyperedge> {
graph::get_hyperedge(&self.conn, id)
}
pub fn list_hyperedges(
&self,
hyperedge_type: Option<&str>,
min_arity: Option<usize>,
max_arity: Option<usize>,
limit: Option<i64>,
) -> Result<Vec<Hyperedge>> {
graph::list_hyperedges(&self.conn, hyperedge_type, min_arity, max_arity, limit)
}
pub fn update_hyperedge(&self, hyperedge: &Hyperedge) -> Result<()> {
graph::update_hyperedge(&self.conn, hyperedge)
}
pub fn delete_hyperedge(&self, id: i64) -> Result<()> {
graph::delete_hyperedge(&self.conn, id)
}
pub fn get_higher_order_neighbors(
&self,
entity_id: i64,
min_arity: Option<usize>,
max_arity: Option<usize>,
) -> Result<Vec<HigherOrderNeighbor>> {
graph::get_higher_order_neighbors(&self.conn, entity_id, min_arity, max_arity)
}
pub fn get_entity_hyperedges(&self, entity_id: i64) -> Result<Vec<Hyperedge>> {
graph::get_entity_hyperedges(&self.conn, entity_id)
}
pub fn kg_higher_order_bfs(
&self,
start_id: i64,
max_depth: u32,
min_arity: Option<usize>,
) -> Result<Vec<TraversalNode>> {
graph::higher_order_bfs(&self.conn, start_id, max_depth, min_arity)
}
pub fn kg_higher_order_shortest_path(
&self,
from_id: i64,
to_id: i64,
max_depth: u32,
) -> Result<Option<HigherOrderPath>> {
graph::higher_order_shortest_path(&self.conn, from_id, to_id, max_depth)
}
pub fn kg_hyperedge_degree(&self, entity_id: i64) -> Result<f64> {
graph::hyperedge_degree(&self.conn, entity_id)
}
pub fn kg_hypergraph_entity_pagerank(
&self,
damping: Option<f64>,
max_iter: Option<usize>,
tolerance: Option<f64>,
) -> Result<std::collections::HashMap<i64, f64>> {
graph::hypergraph_entity_pagerank(
&self.conn,
damping.unwrap_or(0.85),
max_iter.unwrap_or(100),
tolerance.unwrap_or(1e-6),
)
}
pub fn kg_semantic_search(
&self,
query_embedding: Vec<f32>,
k: usize,
) -> Result<Vec<SearchResultWithEntity>> {
let results = self.search_vectors(query_embedding, k)?;
let mut entities_with_results = Vec::new();
for result in results {
let entity = self.get_entity(result.entity_id)?;
entities_with_results.push(SearchResultWithEntity {
entity,
similarity: result.similarity,
});
}
Ok(entities_with_results)
}
pub fn kg_get_context(&self, entity_id: i64, depth: u32) -> Result<GraphContext> {
let root_entity = self.get_entity(entity_id)?;
let neighbors = self.get_neighbors(entity_id, depth)?;
Ok(GraphContext {
root_entity,
neighbors,
})
}
pub fn kg_hybrid_search(
&self,
_query_text: &str,
query_embedding: Vec<f32>,
k: usize,
) -> Result<Vec<HybridSearchResult>> {
let semantic_results = self.kg_semantic_search(query_embedding, k)?;
let mut hybrid_results = Vec::new();
for result in semantic_results.iter() {
let entity_id = result.entity.id.ok_or(Error::EntityNotFound(0))?;
let context = self.kg_get_context(entity_id, 1)?;
hybrid_results.push(HybridSearchResult {
entity: result.entity.clone(),
similarity: result.similarity,
context: Some(context),
});
}
Ok(hybrid_results)
}
pub fn kg_similar_entities(
&self,
entity_id: i64,
k: usize,
) -> Result<Vec<SearchResultWithEntity>> {
let store = VectorStore::new();
let query_vec = store.get_vector(&self.conn, entity_id)?;
let results = store.search_vectors(&self.conn, query_vec, k + 1)?;
let mut out = Vec::new();
for r in results {
if r.entity_id == entity_id {
continue;
}
let entity = self.get_entity(r.entity_id)?;
out.push(SearchResultWithEntity {
entity,
similarity: r.similarity,
});
}
out.truncate(k);
Ok(out)
}
pub fn kg_find_related(
&self,
entity_id: i64,
threshold: f64,
) -> Result<Vec<(graph::Entity, f64)>> {
let neighbours = self.get_neighbors(entity_id, 1)?;
let mut results: Vec<(graph::Entity, f64)> = neighbours
.into_iter()
.filter(|n| n.relation.weight >= threshold)
.map(|n| (n.entity, n.relation.weight))
.collect();
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
Ok(results)
}
pub fn kg_bfs_traversal(
&self,
start_id: i64,
direction: Direction,
max_depth: u32,
) -> Result<Vec<TraversalNode>> {
let query = TraversalQuery {
direction,
max_depth,
..Default::default()
};
graph::bfs_traversal(&self.conn, start_id, query)
}
pub fn kg_dfs_traversal(
&self,
start_id: i64,
direction: Direction,
max_depth: u32,
) -> Result<Vec<TraversalNode>> {
let query = TraversalQuery {
direction,
max_depth,
..Default::default()
};
graph::dfs_traversal(&self.conn, start_id, query)
}
pub fn kg_shortest_path(
&self,
from_id: i64,
to_id: i64,
max_depth: u32,
) -> Result<Option<TraversalPath>> {
graph::find_shortest_path(&self.conn, from_id, to_id, max_depth)
}
pub fn kg_graph_stats(&self) -> Result<GraphStats> {
graph::compute_graph_stats(&self.conn)
}
pub fn kg_pagerank(&self, config: Option<PageRankConfig>) -> Result<Vec<(i64, f64)>> {
algorithms::pagerank(&self.conn, config.unwrap_or_default())
}
pub fn kg_louvain(&self) -> Result<CommunityResult> {
algorithms::louvain_communities(&self.conn)
}
pub fn kg_connected_components(&self) -> Result<Vec<Vec<i64>>> {
algorithms::connected_components(&self.conn)
}
pub fn kg_analyze(&self) -> Result<algorithms::GraphAnalysis> {
algorithms::analyze_graph(&self.conn)
}
pub fn export_json(&self) -> Result<D3ExportGraph> {
export::export_d3_json(&self.conn)
}
pub fn export_dot(&self, config: &DotConfig) -> Result<String> {
export::export_dot(&self.conn, config)
}
#[cfg(feature = "async")]
pub fn into_async(self) -> AsyncKnowledgeGraph {
AsyncKnowledgeGraph::from_sync(self)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Adds a `"paper"` entity named `name` to `kg` and returns its new id.
    fn add_paper(kg: &KnowledgeGraph, name: &str) -> i64 {
        kg.insert_entity(&Entity::new("paper", name)).unwrap()
    }

    /// Adds a relation `from -> to` of type `kind` carrying `weight`.
    fn link(kg: &KnowledgeGraph, from: i64, to: i64, kind: &str, weight: f64) {
        let relation = Relation::new(from, to, kind, weight).unwrap();
        kg.insert_relation(&relation).unwrap();
    }

    #[test]
    fn test_open_in_memory() {
        let graph = KnowledgeGraph::open_in_memory().unwrap();
        let has_schema = schema_exists(graph.connection()).unwrap();
        assert!(has_schema);
    }

    #[test]
    fn test_crud_operations() {
        let graph = KnowledgeGraph::open_in_memory().unwrap();

        let mut paper = Entity::new("paper", "Test Paper");
        paper.set_property("author", serde_json::json!("John Doe"));
        let id = graph.insert_entity(&paper).unwrap();

        let fetched = graph.get_entity(id).unwrap();
        assert_eq!(fetched.name, "Test Paper");
        assert_eq!(graph.list_entities(Some("paper"), None).unwrap().len(), 1);

        let mut changed = fetched;
        changed.set_property("year", serde_json::json!(2024));
        graph.update_entity(&changed).unwrap();

        graph.delete_entity(id).unwrap();
        assert!(graph.list_entities(None, None).unwrap().is_empty());
    }

    #[test]
    fn test_graph_traversal() {
        let graph = KnowledgeGraph::open_in_memory().unwrap();
        let first = add_paper(&graph, "Paper 1");
        let second = add_paper(&graph, "Paper 2");
        let third = add_paper(&graph, "Paper 3");
        link(&graph, first, second, "cites", 0.8);
        link(&graph, second, third, "cites", 0.9);

        assert_eq!(graph.get_neighbors(first, 1).unwrap().len(), 1);
        assert_eq!(graph.get_neighbors(first, 2).unwrap().len(), 2);
    }

    #[test]
    fn test_vector_search() {
        let graph = KnowledgeGraph::open_in_memory().unwrap();
        let first = add_paper(&graph, "Paper 1");
        let second = add_paper(&graph, "Paper 2");
        graph.insert_vector(first, vec![1.0, 0.0, 0.0]).unwrap();
        graph.insert_vector(second, vec![0.0, 1.0, 0.0]).unwrap();

        let hits = graph.search_vectors(vec![1.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].entity_id, first);
    }

    #[test]
    fn test_find_related_above_threshold() {
        let graph = KnowledgeGraph::open_in_memory().unwrap();
        let a = add_paper(&graph, "A");
        let b = add_paper(&graph, "B");
        let c = add_paper(&graph, "C");
        link(&graph, a, b, "related", 0.9);
        link(&graph, a, c, "related", 0.3);

        let related = graph.kg_find_related(a, 0.5).unwrap();
        assert_eq!(
            related.len(),
            1,
            "only B (weight 0.9) should pass threshold 0.5"
        );
        assert_eq!(related[0].0.id, Some(b));
    }

    #[test]
    fn test_find_related_sorted_descending() {
        let graph = KnowledgeGraph::open_in_memory().unwrap();
        let a = add_paper(&graph, "A");
        let b = add_paper(&graph, "B");
        let c = add_paper(&graph, "C");
        link(&graph, a, b, "related", 0.4);
        link(&graph, a, c, "related", 0.9);

        let related = graph.kg_find_related(a, 0.0).unwrap();
        assert_eq!(related.len(), 2);
        assert!(
            related[0].1 >= related[1].1,
            "results should be sorted by weight desc"
        );
        assert_eq!(related[0].0.id, Some(c));
    }

    #[test]
    fn test_find_related_threshold_one() {
        let graph = KnowledgeGraph::open_in_memory().unwrap();
        let a = add_paper(&graph, "A");
        let b = add_paper(&graph, "B");
        let c = add_paper(&graph, "C");
        link(&graph, a, b, "related", 1.0);
        link(&graph, a, c, "related", 0.9);

        let related = graph.kg_find_related(a, 1.0).unwrap();
        assert_eq!(related.len(), 1);
        assert_eq!(related[0].0.id, Some(b));
    }

    #[test]
    fn test_find_related_no_neighbours() {
        let graph = KnowledgeGraph::open_in_memory().unwrap();
        let lonely = add_paper(&graph, "Isolated");
        let related = graph.kg_find_related(lonely, 0.0).unwrap();
        assert!(related.is_empty(), "isolated entity should return empty");
    }

    #[test]
    fn test_find_related_entity_not_found() {
        let graph = KnowledgeGraph::open_in_memory().unwrap();
        let outcome = graph.kg_find_related(9999, 0.5);
        assert!(outcome.is_err(), "non-existent entity should return error");
    }

    #[test]
    fn test_similar_entities() {
        let graph = KnowledgeGraph::open_in_memory().unwrap();
        let a = add_paper(&graph, "A");
        let b = add_paper(&graph, "B");
        let c = add_paper(&graph, "C");
        graph.insert_vector(a, vec![1.0, 0.0, 0.0, 0.0]).unwrap();
        graph.insert_vector(b, vec![0.9, 0.1, 0.0, 0.0]).unwrap();
        graph.insert_vector(c, vec![0.0, 0.0, 1.0, 0.0]).unwrap();

        let similar = graph.kg_similar_entities(a, 2).unwrap();
        assert_eq!(similar.len(), 2);
        assert_eq!(similar[0].entity.name, "B");
        assert!(similar[0].similarity > similar[1].similarity);
    }

    #[test]
    fn test_similar_entities_excludes_self() {
        let graph = KnowledgeGraph::open_in_memory().unwrap();
        let only = add_paper(&graph, "X");
        graph.insert_vector(only, vec![1.0, 0.0, 0.0]).unwrap();

        let similar = graph.kg_similar_entities(only, 5).unwrap();
        assert!(similar.is_empty(), "self should not appear in results");
    }

    #[test]
    fn test_similar_entities_no_vector() {
        let graph = KnowledgeGraph::open_in_memory().unwrap();
        let bare = add_paper(&graph, "NoVec");
        let outcome = graph.kg_similar_entities(bare, 5);
        assert!(outcome.is_err(), "entity without vector should error");
    }
}