Skip to main content

sqlite_knowledge_graph/
lib.rs

1//! SQLite-based Knowledge Graph Library
2//!
3//! This library provides a knowledge graph implementation built on SQLite with support for:
4//! - Entities with typed properties
5//! - Relations between entities with weights
6//! - Vector embeddings for semantic search
7//! - Custom SQLite functions for direct SQL operations
8//! - RAG (Retrieval-Augmented Generation) query functions
9//! - Graph algorithms (PageRank, Louvain, Connected Components)
10//!
11//! ## SQLite Extension
12//!
13//! This crate can be compiled as a SQLite loadable extension:
14//! ```bash
15//! cargo build --release
16//! sqlite3 db.db ".load ./target/release/libsqlite_knowledge_graph.dylib"
17//! sqlite3 db.db "SELECT kg_version();"
18//! ```
19
20pub mod algorithms;
21pub mod embed;
22pub mod error;
23pub mod extension;
24pub mod functions;
25pub mod graph;
26pub mod migrate;
27pub mod schema;
28pub mod vector;
29
30pub use algorithms::{
31    analyze_graph, connected_components, louvain_communities, pagerank, CommunityResult,
32    PageRankConfig,
33};
34pub use embed::{
35    check_dependencies, get_entities_needing_embedding, EmbeddingConfig, EmbeddingGenerator,
36    EmbeddingStats,
37};
38pub use error::{Error, Result};
39pub use extension::sqlite3_sqlite_knowledge_graph_init;
40pub use functions::register_functions;
41pub use graph::{Direction, GraphStats, PathStep, TraversalNode, TraversalPath, TraversalQuery};
42pub use graph::{Entity, Neighbor, Relation};
43pub use migrate::{
44    build_relationships, migrate_all, migrate_papers, migrate_skills, MigrationStats,
45};
46pub use schema::{create_schema, schema_exists};
47pub use vector::{cosine_similarity, SearchResult, VectorStore};
48pub use vector::{TurboQuantConfig, TurboQuantIndex, TurboQuantStats};
49
50use rusqlite::Connection;
51use serde::{Deserialize, Serialize};
52
53/// Semantic search result with entity information.
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct SearchResultWithEntity {
56    pub entity: Entity,
57    pub similarity: f32,
58}
59
60/// Graph context for an entity (root + neighbors).
61#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct GraphContext {
63    pub root_entity: Entity,
64    pub neighbors: Vec<Neighbor>,
65}
66
67/// Hybrid search result combining semantic similarity and graph context.
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct HybridSearchResult {
70    pub entity: Entity,
71    pub similarity: f32,
72    pub context: Option<GraphContext>,
73}
74
75/// Knowledge Graph Manager - main entry point for the library.
76#[derive(Debug)]
77pub struct KnowledgeGraph {
78    conn: Connection,
79}
80
81impl KnowledgeGraph {
82    /// Open a new knowledge graph database connection.
83    pub fn open<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
84        let conn = Connection::open(path)?;
85
86        // Enable foreign keys
87        conn.execute("PRAGMA foreign_keys = ON", [])?;
88
89        // Create schema if not exists
90        if !schema_exists(&conn)? {
91            create_schema(&conn)?;
92        }
93
94        // Register custom functions
95        register_functions(&conn)?;
96
97        Ok(Self { conn })
98    }
99
100    /// Open an in-memory knowledge graph (useful for testing).
101    pub fn open_in_memory() -> Result<Self> {
102        let conn = Connection::open_in_memory()?;
103
104        // Enable foreign keys
105        conn.execute("PRAGMA foreign_keys = ON", [])?;
106
107        // Create schema
108        create_schema(&conn)?;
109
110        // Register custom functions
111        register_functions(&conn)?;
112
113        Ok(Self { conn })
114    }
115
116    /// Get a reference to the underlying SQLite connection.
117    pub fn connection(&self) -> &Connection {
118        &self.conn
119    }
120
121    /// Begin a transaction for batch operations.
122    pub fn transaction(&self) -> Result<rusqlite::Transaction<'_>> {
123        Ok(self.conn.unchecked_transaction()?)
124    }
125
126    /// Insert an entity into the knowledge graph.
127    pub fn insert_entity(&self, entity: &Entity) -> Result<i64> {
128        graph::insert_entity(&self.conn, entity)
129    }
130
131    /// Get an entity by ID.
132    pub fn get_entity(&self, id: i64) -> Result<Entity> {
133        graph::get_entity(&self.conn, id)
134    }
135
136    /// List entities with optional filtering.
137    pub fn list_entities(
138        &self,
139        entity_type: Option<&str>,
140        limit: Option<i64>,
141    ) -> Result<Vec<Entity>> {
142        graph::list_entities(&self.conn, entity_type, limit)
143    }
144
145    /// Update an entity.
146    pub fn update_entity(&self, entity: &Entity) -> Result<()> {
147        graph::update_entity(&self.conn, entity)
148    }
149
150    /// Delete an entity.
151    pub fn delete_entity(&self, id: i64) -> Result<()> {
152        graph::delete_entity(&self.conn, id)
153    }
154
155    /// Insert a relation between entities.
156    pub fn insert_relation(&self, relation: &Relation) -> Result<i64> {
157        graph::insert_relation(&self.conn, relation)
158    }
159
160    /// Get neighbors of an entity using BFS traversal.
161    pub fn get_neighbors(&self, entity_id: i64, depth: u32) -> Result<Vec<Neighbor>> {
162        graph::get_neighbors(&self.conn, entity_id, depth)
163    }
164
165    /// Insert a vector embedding for an entity.
166    pub fn insert_vector(&self, entity_id: i64, vector: Vec<f32>) -> Result<()> {
167        let store = VectorStore::new();
168        store.insert_vector(&self.conn, entity_id, vector)
169    }
170
171    /// Search for similar entities using vector embeddings.
172    pub fn search_vectors(&self, query: Vec<f32>, k: usize) -> Result<Vec<SearchResult>> {
173        let store = VectorStore::new();
174        store.search_vectors(&self.conn, query, k)
175    }
176
177    // ========== TurboQuant Vector Index ==========
178
179    /// Create a TurboQuant index for fast approximate nearest neighbor search.
180    ///
181    /// TurboQuant provides:
182    /// - Instant indexing (no training required)
183    /// - 6x memory compression
184    /// - Near-zero accuracy loss
185    ///
186    /// # Arguments
187    /// * `config` - Optional configuration (uses defaults if None)
188    ///
189    /// # Example
190    /// ```ignore
191    /// let config = TurboQuantConfig {
192    ///     dimension: 384,
193    ///     bit_width: 3,
194    ///     seed: 42,
195    /// };
196    /// let mut index = kg.create_turboquant_index(Some(config))?;
197    ///
198    /// // Add vectors to index
199    /// for (entity_id, vector) in all_vectors {
200    ///     index.add_vector(entity_id, &vector)?;
201    /// }
202    ///
203    /// // Fast search
204    /// let results = index.search(&query_vector, 10)?;
205    /// ```
206    pub fn create_turboquant_index(
207        &self,
208        config: Option<TurboQuantConfig>,
209    ) -> Result<TurboQuantIndex> {
210        let config = config.unwrap_or_default();
211
212        TurboQuantIndex::new(config)
213    }
214
215    /// Build a TurboQuant index from all existing vectors in the database.
216    /// This is a convenience method that loads all vectors and indexes them.
217    pub fn build_turboquant_index(
218        &self,
219        config: Option<TurboQuantConfig>,
220    ) -> Result<TurboQuantIndex> {
221        // Get dimension from first vector
222        let dimension = self.get_vector_dimension()?.unwrap_or(384);
223
224        let config = config.unwrap_or(TurboQuantConfig {
225            dimension,
226            bit_width: 3,
227            seed: 42,
228        });
229
230        let mut index = TurboQuantIndex::new(config)?;
231
232        // Load all vectors
233        let vectors = self.load_all_vectors()?;
234
235        for (entity_id, vector) in vectors {
236            index.add_vector(entity_id, &vector)?;
237        }
238
239        Ok(index)
240    }
241
242    /// Get the dimension of stored vectors (if any exist).
243    fn get_vector_dimension(&self) -> Result<Option<usize>> {
244        let result = self
245            .conn
246            .query_row("SELECT dimension FROM kg_vectors LIMIT 1", [], |row| {
247                row.get::<_, i64>(0)
248            });
249
250        match result {
251            Ok(dim) => Ok(Some(dim as usize)),
252            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
253            Err(e) => Err(e.into()),
254        }
255    }
256
257    /// Load all vectors from the database.
258    fn load_all_vectors(&self) -> Result<Vec<(i64, Vec<f32>)>> {
259        let mut stmt = self
260            .conn
261            .prepare("SELECT entity_id, vector, dimension FROM kg_vectors")?;
262
263        let rows = stmt.query_map([], |row| {
264            let entity_id: i64 = row.get(0)?;
265            let vector_blob: Vec<u8> = row.get(1)?;
266            let dimension: i64 = row.get(2)?;
267
268            let mut vector = Vec::with_capacity(dimension as usize);
269            for chunk in vector_blob.chunks_exact(4) {
270                let bytes: [u8; 4] = chunk.try_into().unwrap();
271                vector.push(f32::from_le_bytes(bytes));
272            }
273
274            Ok((entity_id, vector))
275        })?;
276
277        let mut vectors = Vec::new();
278        for row in rows {
279            vectors.push(row?);
280        }
281
282        Ok(vectors)
283    }
284
285    // ========== RAG Query Functions ==========
286
287    /// Semantic search using vector embeddings.
288    /// Returns entities sorted by similarity score.
289    pub fn kg_semantic_search(
290        &self,
291        query_embedding: Vec<f32>,
292        k: usize,
293    ) -> Result<Vec<SearchResultWithEntity>> {
294        let results = self.search_vectors(query_embedding, k)?;
295
296        let mut entities_with_results = Vec::new();
297        for result in results {
298            let entity = self.get_entity(result.entity_id)?;
299            entities_with_results.push(SearchResultWithEntity {
300                entity,
301                similarity: result.similarity,
302            });
303        }
304
305        Ok(entities_with_results)
306    }
307
308    /// Get context around an entity using graph traversal.
309    /// Returns neighbors up to the specified depth.
310    pub fn kg_get_context(&self, entity_id: i64, depth: u32) -> Result<GraphContext> {
311        let root_entity = self.get_entity(entity_id)?;
312        let neighbors = self.get_neighbors(entity_id, depth)?;
313
314        Ok(GraphContext {
315            root_entity,
316            neighbors,
317        })
318    }
319
320    /// Hybrid search combining semantic search and graph context.
321    /// Performs semantic search first, then retrieves context for top-k results.
322    pub fn kg_hybrid_search(
323        &self,
324        _query_text: &str,
325        query_embedding: Vec<f32>,
326        k: usize,
327    ) -> Result<Vec<HybridSearchResult>> {
328        let semantic_results = self.kg_semantic_search(query_embedding, k)?;
329
330        let mut hybrid_results = Vec::new();
331        for result in semantic_results.iter() {
332            let entity_id = result.entity.id.ok_or(Error::EntityNotFound(0))?;
333            let context = self.kg_get_context(entity_id, 1)?; // Depth 1 context
334
335            hybrid_results.push(HybridSearchResult {
336                entity: result.entity.clone(),
337                similarity: result.similarity,
338                context: Some(context),
339            });
340        }
341
342        Ok(hybrid_results)
343    }
344
345    // ========== Graph Traversal Functions ==========
346
347    /// BFS traversal from a starting entity.
348    /// Returns all reachable entities within max_depth with depth information.
349    pub fn kg_bfs_traversal(
350        &self,
351        start_id: i64,
352        direction: Direction,
353        max_depth: u32,
354    ) -> Result<Vec<TraversalNode>> {
355        let query = TraversalQuery {
356            direction,
357            max_depth,
358            ..Default::default()
359        };
360        graph::bfs_traversal(&self.conn, start_id, query)
361    }
362
363    /// DFS traversal from a starting entity.
364    /// Returns all reachable entities within max_depth.
365    pub fn kg_dfs_traversal(
366        &self,
367        start_id: i64,
368        direction: Direction,
369        max_depth: u32,
370    ) -> Result<Vec<TraversalNode>> {
371        let query = TraversalQuery {
372            direction,
373            max_depth,
374            ..Default::default()
375        };
376        graph::dfs_traversal(&self.conn, start_id, query)
377    }
378
379    /// Find shortest path between two entities using BFS.
380    /// Returns the path with all intermediate steps (if exists).
381    pub fn kg_shortest_path(
382        &self,
383        from_id: i64,
384        to_id: i64,
385        max_depth: u32,
386    ) -> Result<Option<TraversalPath>> {
387        graph::find_shortest_path(&self.conn, from_id, to_id, max_depth)
388    }
389
390    /// Compute graph statistics.
391    pub fn kg_graph_stats(&self) -> Result<GraphStats> {
392        graph::compute_graph_stats(&self.conn)
393    }
394
395    // ========== Graph Algorithms ==========
396
397    /// Compute PageRank scores for all entities.
398    /// Returns a vector of (entity_id, score) sorted by score descending.
399    pub fn kg_pagerank(&self, config: Option<PageRankConfig>) -> Result<Vec<(i64, f64)>> {
400        algorithms::pagerank(&self.conn, config.unwrap_or_default())
401    }
402
403    /// Detect communities using Louvain algorithm.
404    /// Returns community memberships and modularity score.
405    pub fn kg_louvain(&self) -> Result<CommunityResult> {
406        algorithms::louvain_communities(&self.conn)
407    }
408
409    /// Find connected components in the graph.
410    /// Returns a list of components, each being a list of entity IDs.
411    pub fn kg_connected_components(&self) -> Result<Vec<Vec<i64>>> {
412        algorithms::connected_components(&self.conn)
413    }
414
415    /// Run full graph analysis (PageRank + Louvain + Connected Components).
416    pub fn kg_analyze(&self) -> Result<algorithms::GraphAnalysis> {
417        algorithms::analyze_graph(&self.conn)
418    }
419}
420
#[cfg(test)]
mod tests {
    use super::*;

    /// An in-memory graph must come up with its schema already in place.
    #[test]
    fn test_open_in_memory() {
        let graph = KnowledgeGraph::open_in_memory().unwrap();
        let has_schema = schema_exists(graph.connection()).unwrap();
        assert!(has_schema);
    }

    /// Exercise create, read, list, update and delete on a single entity.
    #[test]
    fn test_crud_operations() {
        let graph = KnowledgeGraph::open_in_memory().unwrap();

        // Create
        let mut paper = Entity::new("paper", "Test Paper");
        paper.set_property("author", serde_json::json!("John Doe"));
        let paper_id = graph.insert_entity(&paper).unwrap();

        // Read
        let fetched = graph.get_entity(paper_id).unwrap();
        assert_eq!(fetched.name, "Test Paper");

        // List, filtered by type
        let papers = graph.list_entities(Some("paper"), None).unwrap();
        assert_eq!(papers.len(), 1);

        // Update
        let mut revised = fetched.clone();
        revised.set_property("year", serde_json::json!(2024));
        graph.update_entity(&revised).unwrap();

        // Delete, then make sure nothing is left
        graph.delete_entity(paper_id).unwrap();
        let remaining = graph.list_entities(None, None).unwrap();
        assert_eq!(remaining.len(), 0);
    }

    /// A chain 1 -> 2 -> 3 exposes one neighbor at depth 1 and two at depth 2.
    #[test]
    fn test_graph_traversal() {
        let graph = KnowledgeGraph::open_in_memory().unwrap();

        let first = graph.insert_entity(&Entity::new("paper", "Paper 1")).unwrap();
        let second = graph.insert_entity(&Entity::new("paper", "Paper 2")).unwrap();
        let third = graph.insert_entity(&Entity::new("paper", "Paper 3")).unwrap();

        let first_cites_second = Relation::new(first, second, "cites", 0.8).unwrap();
        graph.insert_relation(&first_cites_second).unwrap();
        let second_cites_third = Relation::new(second, third, "cites", 0.9).unwrap();
        graph.insert_relation(&second_cites_third).unwrap();

        let direct = graph.get_neighbors(first, 1).unwrap();
        assert_eq!(direct.len(), 1);

        let two_hops = graph.get_neighbors(first, 2).unwrap();
        assert_eq!(two_hops.len(), 2);
    }

    /// The entity whose vector equals the query must be ranked first.
    #[test]
    fn test_vector_search() {
        let graph = KnowledgeGraph::open_in_memory().unwrap();

        let first = graph.insert_entity(&Entity::new("paper", "Paper 1")).unwrap();
        let second = graph.insert_entity(&Entity::new("paper", "Paper 2")).unwrap();

        graph.insert_vector(first, vec![1.0, 0.0, 0.0]).unwrap();
        graph.insert_vector(second, vec![0.0, 1.0, 0.0]).unwrap();

        let hits = graph.search_vectors(vec![1.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].entity_id, first);
    }
}