1pub mod algorithms;
21pub mod embed;
22pub mod error;
23pub mod extension;
24pub mod functions;
25pub mod graph;
26pub mod migrate;
27pub mod schema;
28pub mod vector;
29
30pub use algorithms::{
31 analyze_graph, connected_components, louvain_communities, pagerank, CommunityResult,
32 PageRankConfig,
33};
34pub use embed::{
35 check_dependencies, get_entities_needing_embedding, EmbeddingConfig, EmbeddingGenerator,
36 EmbeddingStats,
37};
38pub use error::{Error, Result};
39pub use extension::sqlite3_sqlite_knowledge_graph_init;
40pub use functions::register_functions;
41pub use graph::{Direction, GraphStats, PathStep, TraversalNode, TraversalPath, TraversalQuery};
42pub use graph::{Entity, Neighbor, Relation};
43pub use graph::{HigherOrderNeighbor, HigherOrderPath, HigherOrderPathStep, Hyperedge};
44pub use migrate::{
45 build_relationships, migrate_all, migrate_papers, migrate_skills, MigrationStats,
46};
47pub use schema::{create_schema, schema_exists};
48pub use vector::{cosine_similarity, SearchResult, VectorStore};
49pub use vector::{TurboQuantConfig, TurboQuantIndex, TurboQuantStats};
50
51use rusqlite::Connection;
52use serde::{Deserialize, Serialize};
53
/// A vector-search hit joined with the full entity record it points to.
///
/// Produced by [`KnowledgeGraph::kg_semantic_search`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResultWithEntity {
    /// The matched entity, loaded by id from the graph.
    pub entity: Entity,
    /// Similarity score reported by the vector search for this entity.
    pub similarity: f32,
}
60
/// An entity together with its surrounding graph neighborhood.
///
/// Produced by [`KnowledgeGraph::kg_get_context`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphContext {
    /// The entity the context was gathered around.
    pub root_entity: Entity,
    /// Neighbors of `root_entity`, as returned by [`KnowledgeGraph::get_neighbors`].
    pub neighbors: Vec<Neighbor>,
}
67
/// A semantic-search hit optionally enriched with graph context.
///
/// Produced by [`KnowledgeGraph::kg_hybrid_search`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridSearchResult {
    /// The matched entity.
    pub entity: Entity,
    /// Similarity score reported by the vector search for this entity.
    pub similarity: f32,
    /// Graph neighborhood of `entity`; `kg_hybrid_search` fills this with a
    /// depth-1 context for every hit.
    pub context: Option<GraphContext>,
}
75
/// Handle to a SQLite-backed knowledge graph.
///
/// Wraps a single `rusqlite::Connection`; every graph, vector, hypergraph and
/// algorithm operation on this type runs against that connection.
#[derive(Debug)]
pub struct KnowledgeGraph {
    /// The underlying SQLite connection; borrowed out via `connection()`.
    conn: Connection,
}
81
82impl KnowledgeGraph {
83 pub fn open<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
85 let conn = Connection::open(path)?;
86
87 conn.execute("PRAGMA foreign_keys = ON", [])?;
89
90 if !schema_exists(&conn)? {
92 create_schema(&conn)?;
93 }
94
95 register_functions(&conn)?;
97
98 Ok(Self { conn })
99 }
100
101 pub fn open_in_memory() -> Result<Self> {
103 let conn = Connection::open_in_memory()?;
104
105 conn.execute("PRAGMA foreign_keys = ON", [])?;
107
108 create_schema(&conn)?;
110
111 register_functions(&conn)?;
113
114 Ok(Self { conn })
115 }
116
117 pub fn connection(&self) -> &Connection {
119 &self.conn
120 }
121
122 pub fn transaction(&self) -> Result<rusqlite::Transaction<'_>> {
124 Ok(self.conn.unchecked_transaction()?)
125 }
126
127 pub fn insert_entity(&self, entity: &Entity) -> Result<i64> {
129 graph::insert_entity(&self.conn, entity)
130 }
131
132 pub fn get_entity(&self, id: i64) -> Result<Entity> {
134 graph::get_entity(&self.conn, id)
135 }
136
137 pub fn list_entities(
139 &self,
140 entity_type: Option<&str>,
141 limit: Option<i64>,
142 ) -> Result<Vec<Entity>> {
143 graph::list_entities(&self.conn, entity_type, limit)
144 }
145
146 pub fn update_entity(&self, entity: &Entity) -> Result<()> {
148 graph::update_entity(&self.conn, entity)
149 }
150
151 pub fn delete_entity(&self, id: i64) -> Result<()> {
153 graph::delete_entity(&self.conn, id)
154 }
155
156 pub fn insert_relation(&self, relation: &Relation) -> Result<i64> {
158 graph::insert_relation(&self.conn, relation)
159 }
160
161 pub fn get_neighbors(&self, entity_id: i64, depth: u32) -> Result<Vec<Neighbor>> {
163 graph::get_neighbors(&self.conn, entity_id, depth)
164 }
165
166 pub fn insert_vector(&self, entity_id: i64, vector: Vec<f32>) -> Result<()> {
168 let store = VectorStore::new();
169 store.insert_vector(&self.conn, entity_id, vector)
170 }
171
172 pub fn search_vectors(&self, query: Vec<f32>, k: usize) -> Result<Vec<SearchResult>> {
174 let store = VectorStore::new();
175 store.search_vectors(&self.conn, query, k)
176 }
177
178 pub fn create_turboquant_index(
208 &self,
209 config: Option<TurboQuantConfig>,
210 ) -> Result<TurboQuantIndex> {
211 let config = config.unwrap_or_default();
212
213 TurboQuantIndex::new(config)
214 }
215
216 pub fn build_turboquant_index(
219 &self,
220 config: Option<TurboQuantConfig>,
221 ) -> Result<TurboQuantIndex> {
222 let dimension = self.get_vector_dimension()?.unwrap_or(384);
224
225 let config = config.unwrap_or(TurboQuantConfig {
226 dimension,
227 bit_width: 3,
228 seed: 42,
229 });
230
231 let mut index = TurboQuantIndex::new(config)?;
232
233 let vectors = self.load_all_vectors()?;
235
236 for (entity_id, vector) in vectors {
237 index.add_vector(entity_id, &vector)?;
238 }
239
240 Ok(index)
241 }
242
243 fn get_vector_dimension(&self) -> Result<Option<usize>> {
245 let result = self
246 .conn
247 .query_row("SELECT dimension FROM kg_vectors LIMIT 1", [], |row| {
248 row.get::<_, i64>(0)
249 });
250
251 match result {
252 Ok(dim) => Ok(Some(dim as usize)),
253 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
254 Err(e) => Err(e.into()),
255 }
256 }
257
258 fn load_all_vectors(&self) -> Result<Vec<(i64, Vec<f32>)>> {
260 let mut stmt = self
261 .conn
262 .prepare("SELECT entity_id, vector, dimension FROM kg_vectors")?;
263
264 let rows = stmt.query_map([], |row| {
265 let entity_id: i64 = row.get(0)?;
266 let vector_blob: Vec<u8> = row.get(1)?;
267 let dimension: i64 = row.get(2)?;
268
269 let mut vector = Vec::with_capacity(dimension as usize);
270 for chunk in vector_blob.chunks_exact(4) {
271 let bytes: [u8; 4] = chunk.try_into().unwrap();
272 vector.push(f32::from_le_bytes(bytes));
273 }
274
275 Ok((entity_id, vector))
276 })?;
277
278 let mut vectors = Vec::new();
279 for row in rows {
280 vectors.push(row?);
281 }
282
283 Ok(vectors)
284 }
285
286 pub fn insert_hyperedge(&self, hyperedge: &Hyperedge) -> Result<i64> {
290 graph::insert_hyperedge(&self.conn, hyperedge)
291 }
292
293 pub fn get_hyperedge(&self, id: i64) -> Result<Hyperedge> {
295 graph::get_hyperedge(&self.conn, id)
296 }
297
298 pub fn list_hyperedges(
300 &self,
301 hyperedge_type: Option<&str>,
302 min_arity: Option<usize>,
303 max_arity: Option<usize>,
304 limit: Option<i64>,
305 ) -> Result<Vec<Hyperedge>> {
306 graph::list_hyperedges(&self.conn, hyperedge_type, min_arity, max_arity, limit)
307 }
308
309 pub fn update_hyperedge(&self, hyperedge: &Hyperedge) -> Result<()> {
311 graph::update_hyperedge(&self.conn, hyperedge)
312 }
313
314 pub fn delete_hyperedge(&self, id: i64) -> Result<()> {
316 graph::delete_hyperedge(&self.conn, id)
317 }
318
319 pub fn get_higher_order_neighbors(
321 &self,
322 entity_id: i64,
323 min_arity: Option<usize>,
324 max_arity: Option<usize>,
325 ) -> Result<Vec<HigherOrderNeighbor>> {
326 graph::get_higher_order_neighbors(&self.conn, entity_id, min_arity, max_arity)
327 }
328
329 pub fn get_entity_hyperedges(&self, entity_id: i64) -> Result<Vec<Hyperedge>> {
331 graph::get_entity_hyperedges(&self.conn, entity_id)
332 }
333
334 pub fn kg_higher_order_bfs(
336 &self,
337 start_id: i64,
338 max_depth: u32,
339 min_arity: Option<usize>,
340 ) -> Result<Vec<TraversalNode>> {
341 graph::higher_order_bfs(&self.conn, start_id, max_depth, min_arity)
342 }
343
344 pub fn kg_higher_order_shortest_path(
346 &self,
347 from_id: i64,
348 to_id: i64,
349 max_depth: u32,
350 ) -> Result<Option<HigherOrderPath>> {
351 graph::higher_order_shortest_path(&self.conn, from_id, to_id, max_depth)
352 }
353
354 pub fn kg_hyperedge_degree(&self, entity_id: i64) -> Result<f64> {
356 graph::hyperedge_degree(&self.conn, entity_id)
357 }
358
359 pub fn kg_hypergraph_entity_pagerank(
361 &self,
362 damping: Option<f64>,
363 max_iter: Option<usize>,
364 tolerance: Option<f64>,
365 ) -> Result<std::collections::HashMap<i64, f64>> {
366 graph::hypergraph_entity_pagerank(
367 &self.conn,
368 damping.unwrap_or(0.85),
369 max_iter.unwrap_or(100),
370 tolerance.unwrap_or(1e-6),
371 )
372 }
373
374 pub fn kg_semantic_search(
379 &self,
380 query_embedding: Vec<f32>,
381 k: usize,
382 ) -> Result<Vec<SearchResultWithEntity>> {
383 let results = self.search_vectors(query_embedding, k)?;
384
385 let mut entities_with_results = Vec::new();
386 for result in results {
387 let entity = self.get_entity(result.entity_id)?;
388 entities_with_results.push(SearchResultWithEntity {
389 entity,
390 similarity: result.similarity,
391 });
392 }
393
394 Ok(entities_with_results)
395 }
396
397 pub fn kg_get_context(&self, entity_id: i64, depth: u32) -> Result<GraphContext> {
400 let root_entity = self.get_entity(entity_id)?;
401 let neighbors = self.get_neighbors(entity_id, depth)?;
402
403 Ok(GraphContext {
404 root_entity,
405 neighbors,
406 })
407 }
408
409 pub fn kg_hybrid_search(
412 &self,
413 _query_text: &str,
414 query_embedding: Vec<f32>,
415 k: usize,
416 ) -> Result<Vec<HybridSearchResult>> {
417 let semantic_results = self.kg_semantic_search(query_embedding, k)?;
418
419 let mut hybrid_results = Vec::new();
420 for result in semantic_results.iter() {
421 let entity_id = result.entity.id.ok_or(Error::EntityNotFound(0))?;
422 let context = self.kg_get_context(entity_id, 1)?; hybrid_results.push(HybridSearchResult {
425 entity: result.entity.clone(),
426 similarity: result.similarity,
427 context: Some(context),
428 });
429 }
430
431 Ok(hybrid_results)
432 }
433
434 pub fn kg_bfs_traversal(
439 &self,
440 start_id: i64,
441 direction: Direction,
442 max_depth: u32,
443 ) -> Result<Vec<TraversalNode>> {
444 let query = TraversalQuery {
445 direction,
446 max_depth,
447 ..Default::default()
448 };
449 graph::bfs_traversal(&self.conn, start_id, query)
450 }
451
452 pub fn kg_dfs_traversal(
455 &self,
456 start_id: i64,
457 direction: Direction,
458 max_depth: u32,
459 ) -> Result<Vec<TraversalNode>> {
460 let query = TraversalQuery {
461 direction,
462 max_depth,
463 ..Default::default()
464 };
465 graph::dfs_traversal(&self.conn, start_id, query)
466 }
467
468 pub fn kg_shortest_path(
471 &self,
472 from_id: i64,
473 to_id: i64,
474 max_depth: u32,
475 ) -> Result<Option<TraversalPath>> {
476 graph::find_shortest_path(&self.conn, from_id, to_id, max_depth)
477 }
478
479 pub fn kg_graph_stats(&self) -> Result<GraphStats> {
481 graph::compute_graph_stats(&self.conn)
482 }
483
484 pub fn kg_pagerank(&self, config: Option<PageRankConfig>) -> Result<Vec<(i64, f64)>> {
489 algorithms::pagerank(&self.conn, config.unwrap_or_default())
490 }
491
492 pub fn kg_louvain(&self) -> Result<CommunityResult> {
495 algorithms::louvain_communities(&self.conn)
496 }
497
498 pub fn kg_connected_components(&self) -> Result<Vec<Vec<i64>>> {
501 algorithms::connected_components(&self.conn)
502 }
503
504 pub fn kg_analyze(&self) -> Result<algorithms::GraphAnalysis> {
506 algorithms::analyze_graph(&self.conn)
507 }
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    /// Every test starts from its own empty in-memory graph.
    fn fresh_graph() -> KnowledgeGraph {
        KnowledgeGraph::open_in_memory().unwrap()
    }

    #[test]
    fn test_open_in_memory() {
        let graph = fresh_graph();
        let has_schema = schema_exists(graph.connection()).unwrap();
        assert!(has_schema);
    }

    #[test]
    fn test_crud_operations() {
        let graph = fresh_graph();

        // Create.
        let mut paper = Entity::new("paper", "Test Paper");
        paper.set_property("author", serde_json::json!("John Doe"));
        let paper_id = graph.insert_entity(&paper).unwrap();

        // Read back by id, then by type.
        let fetched = graph.get_entity(paper_id).unwrap();
        assert_eq!(fetched.name, "Test Paper");
        assert_eq!(graph.list_entities(Some("paper"), None).unwrap().len(), 1);

        // Update.
        let mut edited = fetched;
        edited.set_property("year", serde_json::json!(2024));
        graph.update_entity(&edited).unwrap();

        // Delete and confirm nothing is left.
        graph.delete_entity(paper_id).unwrap();
        assert!(graph.list_entities(None, None).unwrap().is_empty());
    }

    #[test]
    fn test_graph_traversal() {
        let graph = fresh_graph();

        let first = graph.insert_entity(&Entity::new("paper", "Paper 1")).unwrap();
        let second = graph.insert_entity(&Entity::new("paper", "Paper 2")).unwrap();
        let third = graph.insert_entity(&Entity::new("paper", "Paper 3")).unwrap();

        // Chain: first -> second -> third.
        for (source, target, weight) in [(first, second, 0.8), (second, third, 0.9)] {
            graph
                .insert_relation(&Relation::new(source, target, "cites", weight).unwrap())
                .unwrap();
        }

        let one_hop = graph.get_neighbors(first, 1).unwrap();
        assert_eq!(one_hop.len(), 1);

        let two_hops = graph.get_neighbors(first, 2).unwrap();
        assert_eq!(two_hops.len(), 2);
    }

    #[test]
    fn test_vector_search() {
        let graph = fresh_graph();

        let along_x = graph.insert_entity(&Entity::new("paper", "Paper 1")).unwrap();
        let along_y = graph.insert_entity(&Entity::new("paper", "Paper 2")).unwrap();

        graph.insert_vector(along_x, vec![1.0, 0.0, 0.0]).unwrap();
        graph.insert_vector(along_y, vec![0.0, 1.0, 0.0]).unwrap();

        let hits = graph.search_vectors(vec![1.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].entity_id, along_x);
    }
}
589}