1pub mod algorithms;
21pub mod embed;
22pub mod error;
23pub mod extension;
24pub mod functions;
25pub mod graph;
26pub mod migrate;
27pub mod schema;
28pub mod vector;
29
30pub use algorithms::{
31 analyze_graph, connected_components, louvain_communities, pagerank, CommunityResult,
32 PageRankConfig,
33};
34pub use embed::{
35 check_dependencies, get_entities_needing_embedding, EmbeddingConfig, EmbeddingGenerator,
36 EmbeddingStats,
37};
38pub use error::{Error, Result};
39pub use extension::sqlite3_sqlite_knowledge_graph_init;
40pub use functions::register_functions;
41pub use graph::{Direction, GraphStats, PathStep, TraversalNode, TraversalPath, TraversalQuery};
42pub use graph::{Entity, Neighbor, Relation};
43pub use migrate::{
44 build_relationships, migrate_all, migrate_papers, migrate_skills, MigrationStats,
45};
46pub use schema::{create_schema, schema_exists};
47pub use vector::{cosine_similarity, SearchResult, VectorStore};
48pub use vector::{TurboQuantConfig, TurboQuantIndex, TurboQuantStats};
49
50use rusqlite::Connection;
51use serde::{Deserialize, Serialize};
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct SearchResultWithEntity {
56 pub entity: Entity,
57 pub similarity: f32,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct GraphContext {
63 pub root_entity: Entity,
64 pub neighbors: Vec<Neighbor>,
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct HybridSearchResult {
70 pub entity: Entity,
71 pub similarity: f32,
72 pub context: Option<GraphContext>,
73}
74
75#[derive(Debug)]
77pub struct KnowledgeGraph {
78 conn: Connection,
79}
80
81impl KnowledgeGraph {
82 pub fn open<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
84 let conn = Connection::open(path)?;
85
86 conn.execute("PRAGMA foreign_keys = ON", [])?;
88
89 if !schema_exists(&conn)? {
91 create_schema(&conn)?;
92 }
93
94 register_functions(&conn)?;
96
97 Ok(Self { conn })
98 }
99
100 pub fn open_in_memory() -> Result<Self> {
102 let conn = Connection::open_in_memory()?;
103
104 conn.execute("PRAGMA foreign_keys = ON", [])?;
106
107 create_schema(&conn)?;
109
110 register_functions(&conn)?;
112
113 Ok(Self { conn })
114 }
115
116 pub fn connection(&self) -> &Connection {
118 &self.conn
119 }
120
121 pub fn transaction(&self) -> Result<rusqlite::Transaction<'_>> {
123 Ok(self.conn.unchecked_transaction()?)
124 }
125
126 pub fn insert_entity(&self, entity: &Entity) -> Result<i64> {
128 graph::insert_entity(&self.conn, entity)
129 }
130
131 pub fn get_entity(&self, id: i64) -> Result<Entity> {
133 graph::get_entity(&self.conn, id)
134 }
135
136 pub fn list_entities(
138 &self,
139 entity_type: Option<&str>,
140 limit: Option<i64>,
141 ) -> Result<Vec<Entity>> {
142 graph::list_entities(&self.conn, entity_type, limit)
143 }
144
145 pub fn update_entity(&self, entity: &Entity) -> Result<()> {
147 graph::update_entity(&self.conn, entity)
148 }
149
150 pub fn delete_entity(&self, id: i64) -> Result<()> {
152 graph::delete_entity(&self.conn, id)
153 }
154
155 pub fn insert_relation(&self, relation: &Relation) -> Result<i64> {
157 graph::insert_relation(&self.conn, relation)
158 }
159
160 pub fn get_neighbors(&self, entity_id: i64, depth: u32) -> Result<Vec<Neighbor>> {
162 graph::get_neighbors(&self.conn, entity_id, depth)
163 }
164
165 pub fn insert_vector(&self, entity_id: i64, vector: Vec<f32>) -> Result<()> {
167 let store = VectorStore::new();
168 store.insert_vector(&self.conn, entity_id, vector)
169 }
170
171 pub fn search_vectors(&self, query: Vec<f32>, k: usize) -> Result<Vec<SearchResult>> {
173 let store = VectorStore::new();
174 store.search_vectors(&self.conn, query, k)
175 }
176
177 pub fn create_turboquant_index(
207 &self,
208 config: Option<TurboQuantConfig>,
209 ) -> Result<TurboQuantIndex> {
210 let config = config.unwrap_or_default();
211
212 TurboQuantIndex::new(config)
213 }
214
215 pub fn build_turboquant_index(
218 &self,
219 config: Option<TurboQuantConfig>,
220 ) -> Result<TurboQuantIndex> {
221 let dimension = self.get_vector_dimension()?.unwrap_or(384);
223
224 let config = config.unwrap_or(TurboQuantConfig {
225 dimension,
226 bit_width: 3,
227 seed: 42,
228 });
229
230 let mut index = TurboQuantIndex::new(config)?;
231
232 let vectors = self.load_all_vectors()?;
234
235 for (entity_id, vector) in vectors {
236 index.add_vector(entity_id, &vector)?;
237 }
238
239 Ok(index)
240 }
241
242 fn get_vector_dimension(&self) -> Result<Option<usize>> {
244 let result = self
245 .conn
246 .query_row("SELECT dimension FROM kg_vectors LIMIT 1", [], |row| {
247 row.get::<_, i64>(0)
248 });
249
250 match result {
251 Ok(dim) => Ok(Some(dim as usize)),
252 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
253 Err(e) => Err(e.into()),
254 }
255 }
256
257 fn load_all_vectors(&self) -> Result<Vec<(i64, Vec<f32>)>> {
259 let mut stmt = self
260 .conn
261 .prepare("SELECT entity_id, vector, dimension FROM kg_vectors")?;
262
263 let rows = stmt.query_map([], |row| {
264 let entity_id: i64 = row.get(0)?;
265 let vector_blob: Vec<u8> = row.get(1)?;
266 let dimension: i64 = row.get(2)?;
267
268 let mut vector = Vec::with_capacity(dimension as usize);
269 for chunk in vector_blob.chunks_exact(4) {
270 let bytes: [u8; 4] = chunk.try_into().unwrap();
271 vector.push(f32::from_le_bytes(bytes));
272 }
273
274 Ok((entity_id, vector))
275 })?;
276
277 let mut vectors = Vec::new();
278 for row in rows {
279 vectors.push(row?);
280 }
281
282 Ok(vectors)
283 }
284
285 pub fn kg_semantic_search(
290 &self,
291 query_embedding: Vec<f32>,
292 k: usize,
293 ) -> Result<Vec<SearchResultWithEntity>> {
294 let results = self.search_vectors(query_embedding, k)?;
295
296 let mut entities_with_results = Vec::new();
297 for result in results {
298 let entity = self.get_entity(result.entity_id)?;
299 entities_with_results.push(SearchResultWithEntity {
300 entity,
301 similarity: result.similarity,
302 });
303 }
304
305 Ok(entities_with_results)
306 }
307
308 pub fn kg_get_context(&self, entity_id: i64, depth: u32) -> Result<GraphContext> {
311 let root_entity = self.get_entity(entity_id)?;
312 let neighbors = self.get_neighbors(entity_id, depth)?;
313
314 Ok(GraphContext {
315 root_entity,
316 neighbors,
317 })
318 }
319
320 pub fn kg_hybrid_search(
323 &self,
324 _query_text: &str,
325 query_embedding: Vec<f32>,
326 k: usize,
327 ) -> Result<Vec<HybridSearchResult>> {
328 let semantic_results = self.kg_semantic_search(query_embedding, k)?;
329
330 let mut hybrid_results = Vec::new();
331 for result in semantic_results.iter() {
332 let entity_id = result.entity.id.ok_or(Error::EntityNotFound(0))?;
333 let context = self.kg_get_context(entity_id, 1)?; hybrid_results.push(HybridSearchResult {
336 entity: result.entity.clone(),
337 similarity: result.similarity,
338 context: Some(context),
339 });
340 }
341
342 Ok(hybrid_results)
343 }
344
345 pub fn kg_bfs_traversal(
350 &self,
351 start_id: i64,
352 direction: Direction,
353 max_depth: u32,
354 ) -> Result<Vec<TraversalNode>> {
355 let query = TraversalQuery {
356 direction,
357 max_depth,
358 ..Default::default()
359 };
360 graph::bfs_traversal(&self.conn, start_id, query)
361 }
362
363 pub fn kg_dfs_traversal(
366 &self,
367 start_id: i64,
368 direction: Direction,
369 max_depth: u32,
370 ) -> Result<Vec<TraversalNode>> {
371 let query = TraversalQuery {
372 direction,
373 max_depth,
374 ..Default::default()
375 };
376 graph::dfs_traversal(&self.conn, start_id, query)
377 }
378
379 pub fn kg_shortest_path(
382 &self,
383 from_id: i64,
384 to_id: i64,
385 max_depth: u32,
386 ) -> Result<Option<TraversalPath>> {
387 graph::find_shortest_path(&self.conn, from_id, to_id, max_depth)
388 }
389
390 pub fn kg_graph_stats(&self) -> Result<GraphStats> {
392 graph::compute_graph_stats(&self.conn)
393 }
394
395 pub fn kg_pagerank(&self, config: Option<PageRankConfig>) -> Result<Vec<(i64, f64)>> {
400 algorithms::pagerank(&self.conn, config.unwrap_or_default())
401 }
402
403 pub fn kg_louvain(&self) -> Result<CommunityResult> {
406 algorithms::louvain_communities(&self.conn)
407 }
408
409 pub fn kg_connected_components(&self) -> Result<Vec<Vec<i64>>> {
412 algorithms::connected_components(&self.conn)
413 }
414
415 pub fn kg_analyze(&self) -> Result<algorithms::GraphAnalysis> {
417 algorithms::analyze_graph(&self.conn)
418 }
419}
420
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh in-memory graph must come up with its schema in place.
    #[test]
    fn test_open_in_memory() {
        let graph = KnowledgeGraph::open_in_memory().unwrap();
        let has_schema = schema_exists(graph.connection()).unwrap();
        assert!(has_schema);
    }

    /// Walks one entity through create, read, list, update and delete.
    #[test]
    fn test_crud_operations() {
        let graph = KnowledgeGraph::open_in_memory().unwrap();

        // Create.
        let mut paper = Entity::new("paper", "Test Paper");
        paper.set_property("author", serde_json::json!("John Doe"));
        let paper_id = graph.insert_entity(&paper).unwrap();

        // Read back by id.
        let fetched = graph.get_entity(paper_id).unwrap();
        assert_eq!(fetched.name, "Test Paper");

        // List by type.
        let papers = graph.list_entities(Some("paper"), None).unwrap();
        assert_eq!(papers.len(), 1);

        // Update.
        let mut revised = fetched.clone();
        revised.set_property("year", serde_json::json!(2024));
        graph.update_entity(&revised).unwrap();

        // Delete, then confirm nothing is left.
        graph.delete_entity(paper_id).unwrap();
        let remaining = graph.list_entities(None, None).unwrap();
        assert_eq!(remaining.len(), 0);
    }

    /// Neighbor lookup on the chain first -> middle -> last.
    #[test]
    fn test_graph_traversal() {
        let graph = KnowledgeGraph::open_in_memory().unwrap();

        let first = graph
            .insert_entity(&Entity::new("paper", "Paper 1"))
            .unwrap();
        let middle = graph
            .insert_entity(&Entity::new("paper", "Paper 2"))
            .unwrap();
        let last = graph
            .insert_entity(&Entity::new("paper", "Paper 3"))
            .unwrap();

        let first_to_middle = Relation::new(first, middle, "cites", 0.8).unwrap();
        graph.insert_relation(&first_to_middle).unwrap();
        let middle_to_last = Relation::new(middle, last, "cites", 0.9).unwrap();
        graph.insert_relation(&middle_to_last).unwrap();

        // One hop reaches only the middle paper.
        let one_hop = graph.get_neighbors(first, 1).unwrap();
        assert_eq!(one_hop.len(), 1);

        // Two hops reach the middle and the last paper.
        let two_hops = graph.get_neighbors(first, 2).unwrap();
        assert_eq!(two_hops.len(), 2);
    }

    /// The stored vector identical to the query must rank first.
    #[test]
    fn test_vector_search() {
        let graph = KnowledgeGraph::open_in_memory().unwrap();

        let along_x = graph
            .insert_entity(&Entity::new("paper", "Paper 1"))
            .unwrap();
        let along_y = graph
            .insert_entity(&Entity::new("paper", "Paper 2"))
            .unwrap();

        graph.insert_vector(along_x, vec![1.0, 0.0, 0.0]).unwrap();
        graph.insert_vector(along_y, vec![0.0, 1.0, 0.0]).unwrap();

        let hits = graph.search_vectors(vec![1.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].entity_id, along_x);
    }
}