Skip to main content

sqlite_knowledge_graph/
lib.rs

1//! SQLite-based Knowledge Graph Library
2//!
3//! This library provides a knowledge graph implementation built on SQLite with support for:
4//! - Entities with typed properties
5//! - Relations between entities with weights
6//! - Vector embeddings for semantic search
7//! - Custom SQLite functions for direct SQL operations
8//! - RAG (Retrieval-Augmented Generation) query functions
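//! - Graph algorithms (PageRank, Louvain, Connected Components)
//! - Higher-order relations (hyperedges) and hypergraph traversal
//! - A TurboQuant index for approximate nearest-neighbor vector search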
10//!
11//! ## SQLite Extension
12//!
13//! This crate can be compiled as a SQLite loadable extension:
14//! ```bash
15//! cargo build --release
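//! # Load the extension and query it in the same sqlite3 session: an extension
//! # loaded by one `sqlite3` invocation is not available to a later, separate one.
//! # The entry point is passed explicitly because SQLite's derived default (built
//! # from the alphabetic characters of the file name) would not match
//! # `sqlite3_sqlite_knowledge_graph_init`. Omitting the file suffix lets SQLite
//! # try the platform's shared-library extension (.so, .dylib, .dll).
//! sqlite3 -cmd ".load ./target/release/libsqlite_knowledge_graph sqlite3_sqlite_knowledge_graph_init" \
//!     db.db "SELECT kg_version();"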
18//! ```
19
20pub mod algorithms;
21pub mod embed;
22pub mod error;
23pub mod extension;
24pub mod functions;
25pub mod graph;
26pub mod migrate;
27pub mod schema;
28pub mod vector;
29
30pub use algorithms::{
31    analyze_graph, connected_components, louvain_communities, pagerank, CommunityResult,
32    PageRankConfig,
33};
34pub use embed::{
35    check_dependencies, get_entities_needing_embedding, EmbeddingConfig, EmbeddingGenerator,
36    EmbeddingStats,
37};
38pub use error::{Error, Result};
39pub use extension::sqlite3_sqlite_knowledge_graph_init;
40pub use functions::register_functions;
41pub use graph::{Direction, GraphStats, PathStep, TraversalNode, TraversalPath, TraversalQuery};
42pub use graph::{Entity, Neighbor, Relation};
43pub use graph::{HigherOrderNeighbor, HigherOrderPath, HigherOrderPathStep, Hyperedge};
44pub use migrate::{
45    build_relationships, migrate_all, migrate_papers, migrate_skills, MigrationStats,
46};
47pub use schema::{create_schema, schema_exists};
48pub use vector::{cosine_similarity, SearchResult, VectorStore};
49pub use vector::{TurboQuantConfig, TurboQuantIndex, TurboQuantStats};
50
51use rusqlite::Connection;
52use serde::{Deserialize, Serialize};
53
54/// Semantic search result with entity information.
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct SearchResultWithEntity {
57    pub entity: Entity,
58    pub similarity: f32,
59}
60
61/// Graph context for an entity (root + neighbors).
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct GraphContext {
64    pub root_entity: Entity,
65    pub neighbors: Vec<Neighbor>,
66}
67
68/// Hybrid search result combining semantic similarity and graph context.
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct HybridSearchResult {
71    pub entity: Entity,
72    pub similarity: f32,
73    pub context: Option<GraphContext>,
74}
75
76/// Knowledge Graph Manager - main entry point for the library.
77#[derive(Debug)]
78pub struct KnowledgeGraph {
79    conn: Connection,
80}
81
82impl KnowledgeGraph {
83    /// Open a new knowledge graph database connection.
84    pub fn open<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
85        let conn = Connection::open(path)?;
86
87        // Enable foreign keys
88        conn.execute("PRAGMA foreign_keys = ON", [])?;
89
90        // Create schema if not exists
91        if !schema_exists(&conn)? {
92            create_schema(&conn)?;
93        }
94
95        // Register custom functions
96        register_functions(&conn)?;
97
98        Ok(Self { conn })
99    }
100
101    /// Open an in-memory knowledge graph (useful for testing).
102    pub fn open_in_memory() -> Result<Self> {
103        let conn = Connection::open_in_memory()?;
104
105        // Enable foreign keys
106        conn.execute("PRAGMA foreign_keys = ON", [])?;
107
108        // Create schema
109        create_schema(&conn)?;
110
111        // Register custom functions
112        register_functions(&conn)?;
113
114        Ok(Self { conn })
115    }
116
117    /// Get a reference to the underlying SQLite connection.
118    pub fn connection(&self) -> &Connection {
119        &self.conn
120    }
121
122    /// Begin a transaction for batch operations.
123    pub fn transaction(&self) -> Result<rusqlite::Transaction<'_>> {
124        Ok(self.conn.unchecked_transaction()?)
125    }
126
127    /// Insert an entity into the knowledge graph.
128    pub fn insert_entity(&self, entity: &Entity) -> Result<i64> {
129        graph::insert_entity(&self.conn, entity)
130    }
131
132    /// Get an entity by ID.
133    pub fn get_entity(&self, id: i64) -> Result<Entity> {
134        graph::get_entity(&self.conn, id)
135    }
136
137    /// List entities with optional filtering.
138    pub fn list_entities(
139        &self,
140        entity_type: Option<&str>,
141        limit: Option<i64>,
142    ) -> Result<Vec<Entity>> {
143        graph::list_entities(&self.conn, entity_type, limit)
144    }
145
146    /// Update an entity.
147    pub fn update_entity(&self, entity: &Entity) -> Result<()> {
148        graph::update_entity(&self.conn, entity)
149    }
150
151    /// Delete an entity.
152    pub fn delete_entity(&self, id: i64) -> Result<()> {
153        graph::delete_entity(&self.conn, id)
154    }
155
156    /// Insert a relation between entities.
157    pub fn insert_relation(&self, relation: &Relation) -> Result<i64> {
158        graph::insert_relation(&self.conn, relation)
159    }
160
161    /// Get neighbors of an entity using BFS traversal.
162    pub fn get_neighbors(&self, entity_id: i64, depth: u32) -> Result<Vec<Neighbor>> {
163        graph::get_neighbors(&self.conn, entity_id, depth)
164    }
165
166    /// Insert a vector embedding for an entity.
167    pub fn insert_vector(&self, entity_id: i64, vector: Vec<f32>) -> Result<()> {
168        let store = VectorStore::new();
169        store.insert_vector(&self.conn, entity_id, vector)
170    }
171
172    /// Search for similar entities using vector embeddings.
173    pub fn search_vectors(&self, query: Vec<f32>, k: usize) -> Result<Vec<SearchResult>> {
174        let store = VectorStore::new();
175        store.search_vectors(&self.conn, query, k)
176    }
177
178    // ========== TurboQuant Vector Index ==========
179
180    /// Create a TurboQuant index for fast approximate nearest neighbor search.
181    ///
182    /// TurboQuant provides:
183    /// - Instant indexing (no training required)
184    /// - 6x memory compression
185    /// - Near-zero accuracy loss
186    ///
187    /// # Arguments
188    /// * `config` - Optional configuration (uses defaults if None)
189    ///
190    /// # Example
191    /// ```ignore
192    /// let config = TurboQuantConfig {
193    ///     dimension: 384,
194    ///     bit_width: 3,
195    ///     seed: 42,
196    /// };
197    /// let mut index = kg.create_turboquant_index(Some(config))?;
198    ///
199    /// // Add vectors to index
200    /// for (entity_id, vector) in all_vectors {
201    ///     index.add_vector(entity_id, &vector)?;
202    /// }
203    ///
204    /// // Fast search
205    /// let results = index.search(&query_vector, 10)?;
206    /// ```
207    pub fn create_turboquant_index(
208        &self,
209        config: Option<TurboQuantConfig>,
210    ) -> Result<TurboQuantIndex> {
211        let config = config.unwrap_or_default();
212
213        TurboQuantIndex::new(config)
214    }
215
216    /// Build a TurboQuant index from all existing vectors in the database.
217    /// This is a convenience method that loads all vectors and indexes them.
218    pub fn build_turboquant_index(
219        &self,
220        config: Option<TurboQuantConfig>,
221    ) -> Result<TurboQuantIndex> {
222        // Get dimension from first vector
223        let dimension = self.get_vector_dimension()?.unwrap_or(384);
224
225        let config = config.unwrap_or(TurboQuantConfig {
226            dimension,
227            bit_width: 3,
228            seed: 42,
229        });
230
231        let mut index = TurboQuantIndex::new(config)?;
232
233        // Load all vectors
234        let vectors = self.load_all_vectors()?;
235
236        for (entity_id, vector) in vectors {
237            index.add_vector(entity_id, &vector)?;
238        }
239
240        Ok(index)
241    }
242
243    /// Get the dimension of stored vectors (if any exist).
244    fn get_vector_dimension(&self) -> Result<Option<usize>> {
245        let result = self
246            .conn
247            .query_row("SELECT dimension FROM kg_vectors LIMIT 1", [], |row| {
248                row.get::<_, i64>(0)
249            });
250
251        match result {
252            Ok(dim) => Ok(Some(dim as usize)),
253            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
254            Err(e) => Err(e.into()),
255        }
256    }
257
258    /// Load all vectors from the database.
259    fn load_all_vectors(&self) -> Result<Vec<(i64, Vec<f32>)>> {
260        let mut stmt = self
261            .conn
262            .prepare("SELECT entity_id, vector, dimension FROM kg_vectors")?;
263
264        let rows = stmt.query_map([], |row| {
265            let entity_id: i64 = row.get(0)?;
266            let vector_blob: Vec<u8> = row.get(1)?;
267            let dimension: i64 = row.get(2)?;
268
269            let mut vector = Vec::with_capacity(dimension as usize);
270            for chunk in vector_blob.chunks_exact(4) {
271                let bytes: [u8; 4] = chunk.try_into().unwrap();
272                vector.push(f32::from_le_bytes(bytes));
273            }
274
275            Ok((entity_id, vector))
276        })?;
277
278        let mut vectors = Vec::new();
279        for row in rows {
280            vectors.push(row?);
281        }
282
283        Ok(vectors)
284    }
285
286    // ========== Higher-Order Relations (Hyperedges) ==========
287
288    /// Insert a hyperedge (higher-order relation) into the knowledge graph.
289    pub fn insert_hyperedge(&self, hyperedge: &Hyperedge) -> Result<i64> {
290        graph::insert_hyperedge(&self.conn, hyperedge)
291    }
292
293    /// Get a hyperedge by ID.
294    pub fn get_hyperedge(&self, id: i64) -> Result<Hyperedge> {
295        graph::get_hyperedge(&self.conn, id)
296    }
297
298    /// List hyperedges with optional filtering.
299    pub fn list_hyperedges(
300        &self,
301        hyperedge_type: Option<&str>,
302        min_arity: Option<usize>,
303        max_arity: Option<usize>,
304        limit: Option<i64>,
305    ) -> Result<Vec<Hyperedge>> {
306        graph::list_hyperedges(&self.conn, hyperedge_type, min_arity, max_arity, limit)
307    }
308
309    /// Update a hyperedge.
310    pub fn update_hyperedge(&self, hyperedge: &Hyperedge) -> Result<()> {
311        graph::update_hyperedge(&self.conn, hyperedge)
312    }
313
314    /// Delete a hyperedge by ID.
315    pub fn delete_hyperedge(&self, id: i64) -> Result<()> {
316        graph::delete_hyperedge(&self.conn, id)
317    }
318
319    /// Get higher-order neighbors of an entity (connected through hyperedges).
320    pub fn get_higher_order_neighbors(
321        &self,
322        entity_id: i64,
323        min_arity: Option<usize>,
324        max_arity: Option<usize>,
325    ) -> Result<Vec<HigherOrderNeighbor>> {
326        graph::get_higher_order_neighbors(&self.conn, entity_id, min_arity, max_arity)
327    }
328
329    /// Get all hyperedges that an entity participates in.
330    pub fn get_entity_hyperedges(&self, entity_id: i64) -> Result<Vec<Hyperedge>> {
331        graph::get_entity_hyperedges(&self.conn, entity_id)
332    }
333
334    /// Higher-order BFS traversal through hyperedges.
335    pub fn kg_higher_order_bfs(
336        &self,
337        start_id: i64,
338        max_depth: u32,
339        min_arity: Option<usize>,
340    ) -> Result<Vec<TraversalNode>> {
341        graph::higher_order_bfs(&self.conn, start_id, max_depth, min_arity)
342    }
343
344    /// Find shortest path between two entities through hyperedges.
345    pub fn kg_higher_order_shortest_path(
346        &self,
347        from_id: i64,
348        to_id: i64,
349        max_depth: u32,
350    ) -> Result<Option<HigherOrderPath>> {
351        graph::higher_order_shortest_path(&self.conn, from_id, to_id, max_depth)
352    }
353
354    /// Compute hyperedge degree centrality for an entity.
355    pub fn kg_hyperedge_degree(&self, entity_id: i64) -> Result<f64> {
356        graph::hyperedge_degree(&self.conn, entity_id)
357    }
358
359    /// Compute entity-level hypergraph PageRank using Zhou formula.
360    pub fn kg_hypergraph_entity_pagerank(
361        &self,
362        damping: Option<f64>,
363        max_iter: Option<usize>,
364        tolerance: Option<f64>,
365    ) -> Result<std::collections::HashMap<i64, f64>> {
366        graph::hypergraph_entity_pagerank(
367            &self.conn,
368            damping.unwrap_or(0.85),
369            max_iter.unwrap_or(100),
370            tolerance.unwrap_or(1e-6),
371        )
372    }
373
374    // ========== RAG Query Functions ==========
375
376    /// Semantic search using vector embeddings.
377    /// Returns entities sorted by similarity score.
378    pub fn kg_semantic_search(
379        &self,
380        query_embedding: Vec<f32>,
381        k: usize,
382    ) -> Result<Vec<SearchResultWithEntity>> {
383        let results = self.search_vectors(query_embedding, k)?;
384
385        let mut entities_with_results = Vec::new();
386        for result in results {
387            let entity = self.get_entity(result.entity_id)?;
388            entities_with_results.push(SearchResultWithEntity {
389                entity,
390                similarity: result.similarity,
391            });
392        }
393
394        Ok(entities_with_results)
395    }
396
397    /// Get context around an entity using graph traversal.
398    /// Returns neighbors up to the specified depth.
399    pub fn kg_get_context(&self, entity_id: i64, depth: u32) -> Result<GraphContext> {
400        let root_entity = self.get_entity(entity_id)?;
401        let neighbors = self.get_neighbors(entity_id, depth)?;
402
403        Ok(GraphContext {
404            root_entity,
405            neighbors,
406        })
407    }
408
409    /// Hybrid search combining semantic search and graph context.
410    /// Performs semantic search first, then retrieves context for top-k results.
411    pub fn kg_hybrid_search(
412        &self,
413        _query_text: &str,
414        query_embedding: Vec<f32>,
415        k: usize,
416    ) -> Result<Vec<HybridSearchResult>> {
417        let semantic_results = self.kg_semantic_search(query_embedding, k)?;
418
419        let mut hybrid_results = Vec::new();
420        for result in semantic_results.iter() {
421            let entity_id = result.entity.id.ok_or(Error::EntityNotFound(0))?;
422            let context = self.kg_get_context(entity_id, 1)?; // Depth 1 context
423
424            hybrid_results.push(HybridSearchResult {
425                entity: result.entity.clone(),
426                similarity: result.similarity,
427                context: Some(context),
428            });
429        }
430
431        Ok(hybrid_results)
432    }
433
434    // ========== Graph Traversal Functions ==========
435
436    /// BFS traversal from a starting entity.
437    /// Returns all reachable entities within max_depth with depth information.
438    pub fn kg_bfs_traversal(
439        &self,
440        start_id: i64,
441        direction: Direction,
442        max_depth: u32,
443    ) -> Result<Vec<TraversalNode>> {
444        let query = TraversalQuery {
445            direction,
446            max_depth,
447            ..Default::default()
448        };
449        graph::bfs_traversal(&self.conn, start_id, query)
450    }
451
452    /// DFS traversal from a starting entity.
453    /// Returns all reachable entities within max_depth.
454    pub fn kg_dfs_traversal(
455        &self,
456        start_id: i64,
457        direction: Direction,
458        max_depth: u32,
459    ) -> Result<Vec<TraversalNode>> {
460        let query = TraversalQuery {
461            direction,
462            max_depth,
463            ..Default::default()
464        };
465        graph::dfs_traversal(&self.conn, start_id, query)
466    }
467
468    /// Find shortest path between two entities using BFS.
469    /// Returns the path with all intermediate steps (if exists).
470    pub fn kg_shortest_path(
471        &self,
472        from_id: i64,
473        to_id: i64,
474        max_depth: u32,
475    ) -> Result<Option<TraversalPath>> {
476        graph::find_shortest_path(&self.conn, from_id, to_id, max_depth)
477    }
478
479    /// Compute graph statistics.
480    pub fn kg_graph_stats(&self) -> Result<GraphStats> {
481        graph::compute_graph_stats(&self.conn)
482    }
483
484    // ========== Graph Algorithms ==========
485
486    /// Compute PageRank scores for all entities.
487    /// Returns a vector of (entity_id, score) sorted by score descending.
488    pub fn kg_pagerank(&self, config: Option<PageRankConfig>) -> Result<Vec<(i64, f64)>> {
489        algorithms::pagerank(&self.conn, config.unwrap_or_default())
490    }
491
492    /// Detect communities using Louvain algorithm.
493    /// Returns community memberships and modularity score.
494    pub fn kg_louvain(&self) -> Result<CommunityResult> {
495        algorithms::louvain_communities(&self.conn)
496    }
497
498    /// Find connected components in the graph.
499    /// Returns a list of components, each being a list of entity IDs.
500    pub fn kg_connected_components(&self) -> Result<Vec<Vec<i64>>> {
501        algorithms::connected_components(&self.conn)
502    }
503
504    /// Run full graph analysis (PageRank + Louvain + Connected Components).
505    pub fn kg_analyze(&self) -> Result<algorithms::GraphAnalysis> {
506        algorithms::analyze_graph(&self.conn)
507    }
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_open_in_memory() {
        let kg = KnowledgeGraph::open_in_memory().unwrap();
        assert!(schema_exists(kg.connection()).unwrap());
    }

    #[test]
    fn test_crud_operations() {
        let kg = KnowledgeGraph::open_in_memory().unwrap();

        // Create entity
        let mut entity = Entity::new("paper", "Test Paper");
        entity.set_property("author", serde_json::json!("John Doe"));
        let id = kg.insert_entity(&entity).unwrap();

        // Read entity
        let retrieved = kg.get_entity(id).unwrap();
        assert_eq!(retrieved.name, "Test Paper");

        // List entities
        let entities = kg.list_entities(Some("paper"), None).unwrap();
        assert_eq!(entities.len(), 1);

        // Update entity
        let mut updated = retrieved.clone();
        updated.set_property("year", serde_json::json!(2024));
        kg.update_entity(&updated).unwrap();

        // Delete entity
        kg.delete_entity(id).unwrap();
        let entities = kg.list_entities(None, None).unwrap();
        assert_eq!(entities.len(), 0);
    }

    #[test]
    fn test_graph_traversal() {
        let kg = KnowledgeGraph::open_in_memory().unwrap();

        // Create entities
        let id1 = kg.insert_entity(&Entity::new("paper", "Paper 1")).unwrap();
        let id2 = kg.insert_entity(&Entity::new("paper", "Paper 2")).unwrap();
        let id3 = kg.insert_entity(&Entity::new("paper", "Paper 3")).unwrap();

        // Create relations
        kg.insert_relation(&Relation::new(id1, id2, "cites", 0.8).unwrap())
            .unwrap();
        kg.insert_relation(&Relation::new(id2, id3, "cites", 0.9).unwrap())
            .unwrap();

        // Get neighbors depth 1
        let neighbors = kg.get_neighbors(id1, 1).unwrap();
        assert_eq!(neighbors.len(), 1);

        // Get neighbors depth 2
        let neighbors = kg.get_neighbors(id1, 2).unwrap();
        assert_eq!(neighbors.len(), 2);
    }

    #[test]
    fn test_vector_search() {
        let kg = KnowledgeGraph::open_in_memory().unwrap();

        // Create entities
        let id1 = kg.insert_entity(&Entity::new("paper", "Paper 1")).unwrap();
        let id2 = kg.insert_entity(&Entity::new("paper", "Paper 2")).unwrap();

        // Insert vectors
        kg.insert_vector(id1, vec![1.0, 0.0, 0.0]).unwrap();
        kg.insert_vector(id2, vec![0.0, 1.0, 0.0]).unwrap();

        // Search
        let results = kg.search_vectors(vec![1.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].entity_id, id1);
    }

    /// Hybrid search must return the matching entity with its depth-1
    /// neighborhood attached as context.
    #[test]
    fn test_hybrid_search_includes_context() {
        let kg = KnowledgeGraph::open_in_memory().unwrap();

        let id1 = kg.insert_entity(&Entity::new("paper", "Paper 1")).unwrap();
        let id2 = kg.insert_entity(&Entity::new("paper", "Paper 2")).unwrap();
        kg.insert_relation(&Relation::new(id1, id2, "cites", 0.8).unwrap())
            .unwrap();

        kg.insert_vector(id1, vec![1.0, 0.0, 0.0]).unwrap();
        kg.insert_vector(id2, vec![0.0, 1.0, 0.0]).unwrap();

        let results = kg
            .kg_hybrid_search("ignored", vec![1.0, 0.0, 0.0], 1)
            .unwrap();
        assert_eq!(results.len(), 1);

        let top = &results[0];
        assert_eq!(top.entity.id, Some(id1));
        assert_eq!(top.entity.name, "Paper 1");

        let context = top
            .context
            .as_ref()
            .expect("hybrid search results carry graph context");
        assert_eq!(context.root_entity.id, Some(id1));
        assert_eq!(context.neighbors.len(), 1);
    }

    /// The dimension probe and the raw vector loader used by
    /// `build_turboquant_index` must handle an empty table and round-trip
    /// stored vectors.
    #[test]
    fn test_vector_dimension_and_loading() {
        let kg = KnowledgeGraph::open_in_memory().unwrap();

        assert_eq!(kg.get_vector_dimension().unwrap(), None);
        assert!(kg.load_all_vectors().unwrap().is_empty());

        let id1 = kg.insert_entity(&Entity::new("paper", "Paper 1")).unwrap();
        let id2 = kg.insert_entity(&Entity::new("paper", "Paper 2")).unwrap();
        kg.insert_vector(id1, vec![1.0, 0.0, 0.0]).unwrap();
        kg.insert_vector(id2, vec![0.0, 1.0, 0.0]).unwrap();

        assert_eq!(kg.get_vector_dimension().unwrap(), Some(3));

        let mut loaded = kg.load_all_vectors().unwrap();
        loaded.sort_by_key(|(id, _)| *id);
        assert_eq!(
            loaded,
            vec![(id1, vec![1.0, 0.0, 0.0]), (id2, vec![0.0, 1.0, 0.0])]
        );
    }
}