oxirs_graphrag/lib.rs
1//! # OxiRS GraphRAG
2//!
3//! **GraphRAG** (Graph Retrieval-Augmented Generation) is a production-ready
4//! Rust library that combines **knowledge-graph topology traversal** with
5//! **vector similarity search** to deliver context-rich answers for LLM
6//! pipelines — without any network dependencies at query time.
7//!
8//! It is the JVM-free, pure-Rust counterpart of Microsoft's GraphRAG and
9//! LangChain's knowledge-graph QA stack, integrated directly with the OxiRS
10//! semantic-web engine.
11//!
12//! ## Data-flow overview
13//!
14//! ```text
15//! Natural-Language Query
16//! │
17//! ▼
18//! ┌───────────────────┐
19//! │ Query Embedding │ (oxirs-embed / Node2Vec / TransE)
20//! └────────┬──────────┘
21//! │
22//! ┌──────┴──────┐
23//! │ │
24//! ▼ ▼
25//! Vector Keyword
26//! KNN BM25
27//! Search Search
28//! │ │
29//! └──────┬──────┘
30//! │
31//! ▼
32//! ┌───────────────┐
33//! │ RRF Fusion │ Reciprocal Rank Fusion → Seed Entities
34//! └───────┬───────┘
35//! │
36//! ▼
37//! ┌────────────────────────┐
38//! │ SPARQL N-hop Expansion│ Graph traversal (up to 500 triples)
39//! └────────────┬───────────┘
40//! │
41//! ▼
42//! ┌────────────────────────┐
43//! │ Community Detection │ Louvain / Leiden clustering
44//! └────────────┬───────────┘
45//! │
46//! ▼
47//! ┌────────────────────────┐
48//! │ Context Building │ Subgraph → natural-language context
49//! └────────────┬───────────┘
50//! │
51//! ▼
52//! ┌────────────────────────┐
53//! │ LLM Generation │ Answer + citations
54//! └────────────────────────┘
55//! ```
56//!
57//! ## Key modules
58//!
59//! | Module | Purpose |
60//! |--------|---------|
//! | [`triple_extractor`] | Rule-based NLP → RDF triple extraction |
//! | [`community_detector`] | Greedy label-propagation community detection |
//! | [`path_finder`] | BFS / DFS shortest-path retrieval in KGs |
//! | [`graph_embedder`] | Node2Vec-style random-walk structural embeddings |
//! | [`summarizer`] | Cluster-based subgraph summarization for LLM context |
//! | [`path_ranker`] | Predicate-weighted path ranking |
//! | [`context_builder`] | N-hop subgraph extraction and truncation |
//! | [`knowledge_fusion`] | Multi-source KG fusion with provenance |
//! | [`graph_summarization`] | PageRank-style community summary generation |
//! | [`entity_linking`] | Entity linking and disambiguation |
//! | [`explainability`] | Attention weights, path explanation, provenance |
//! | [`feedback`] | Session-scoped user-feedback weight adaptation |
//! | [`graph`] | Core community detection and graph traversal primitives |
//! | [`retrieval`] | Hybrid vector + keyword retrieval with RRF fusion |
//! | [`generation`] | Prompt templates and LLM context building |
//! | [`temporal`] | Temporal knowledge graph retrieval |
77//!
78//! ## Quickstart — standalone pipeline (no network, no LLM)
79//!
80//! The example below runs an end-to-end mini-pipeline entirely in memory on a
81//! synthetic 8-node knowledge graph: extract triples from text, detect
82//! communities, find paths, and summarize the result.
83//!
84//! ```rust
85//! use oxirs_graphrag::triple_extractor::{ExtractionConfig, TripleExtractor};
86//! use oxirs_graphrag::community_detector::{CommunityGraph, CommunityDetector};
87//! use oxirs_graphrag::path_finder::{KnowledgeEdge, PathFinder, PathFinderConfig};
88//! use oxirs_graphrag::summarizer::{KgEdge, KgNode, KgSubgraph, SubgraphSummarizer};
89//!
90//! // ── Step 1: Extract triples from natural language ─────────────────────────
91//! let corpus = [
92//! "Alice is a data scientist.",
93//! "Bob works at ACME.",
94//! "Carol is a software engineer.",
95//! "Dave is part of the AI team.",
96//! "ACME has a research division.",
97//! ];
98//! let extractor = TripleExtractor::with_defaults(ExtractionConfig::default());
99//! let all_triples: Vec<_> = corpus
100//! .iter()
101//! .flat_map(|sentence| extractor.extract(sentence))
102//! .collect();
103//! assert!(!all_triples.is_empty(), "at least one triple extracted");
104//!
105//! // ── Step 2: Build community graph and detect clusters ─────────────────────
106//! let mut cg = CommunityGraph::new();
107//! // 8 synthetic nodes
108//! for (id, label) in [
109//! (1u64, "Alice"), (2, "Bob"), (3, "Carol"), (4, "Dave"),
110//! (5, "ACME"), (6, "AI-Team"), (7, "Research"), (8, "Berlin"),
111//! ] {
112//! cg.add_node(id, label);
113//! }
114//! for (a, b) in [(1,5),(2,5),(3,6),(4,6),(5,7),(6,7),(7,8),(1,2)] {
115//! cg.add_edge(a, b, 1.0);
116//! }
117//! let detector = CommunityDetector::new(2, 50);
118//! let detection = detector.detect(&mut cg);
119//! assert!(!detection.communities.is_empty(), "at least one community");
120//!
121//! // ── Step 3: Graph path retrieval ──────────────────────────────────────────
122//! let edges = vec![
123//! KnowledgeEdge::new("Alice", "works_at", "ACME"),
124//! KnowledgeEdge::new("ACME", "located_in", "Berlin"),
125//! KnowledgeEdge::new("Bob", "knows", "Alice"),
126//! KnowledgeEdge::new("Alice", "member_of", "AI-Team"),
127//! KnowledgeEdge::new("AI-Team", "part_of", "ACME"),
128//! KnowledgeEdge::new("Carol", "works_at", "ACME"),
129//! KnowledgeEdge::new("Dave", "leads", "AI-Team"),
130//! KnowledgeEdge::new("Research", "division_of", "ACME"),
131//! ];
132//! let finder = PathFinder::new(edges, PathFinderConfig::default());
133//! let paths = finder.bfs_paths("Bob", "Berlin", 4);
134//! assert!(!paths.is_empty(), "path Bob→Berlin found");
135//!
136//! // ── Step 4: Summarize subgraph for LLM context ────────────────────────────
137//! let mut subgraph = KgSubgraph::new();
138//! for (id, label, ty) in [
139//! ("alice", "Alice", "Person"),
140//! ("bob", "Bob", "Person"),
141//! ("carol", "Carol", "Person"),
142//! ("acme", "ACME", "Organization"),
143//! ("berlin", "Berlin", "Place"),
144//! ("ai_team", "AI-Team", "Team"),
145//! ("research", "Research", "Department"),
146//! ("dave", "Dave", "Person"),
147//! ] {
148//! subgraph.add_node(KgNode::simple(id, label, ty));
149//! }
150//! subgraph.add_edge(KgEdge::unweighted("alice", "acme", "works_at"));
151//! subgraph.add_edge(KgEdge::unweighted("acme", "berlin","located_in"));
152//!
153//! let summarizer = SubgraphSummarizer::new();
154//! let clusters = summarizer.summarize(&subgraph, 10);
155//! assert!(!clusters.is_empty(), "at least one cluster");
156//! let text_summary = summarizer.generate_text_summary(&clusters);
157//! assert!(!text_summary.is_empty(), "non-empty summary text");
158//! ```
159//!
160//! ## Full engine usage (async, requires trait impls)
161//!
162//! For production usage with a real vector index, embedding model, SPARQL engine,
163//! and LLM client:
164//!
165//! ```rust,ignore
166//! use oxirs_graphrag::{GraphRAGEngine, GraphRAGConfig};
167//! use std::sync::Arc;
168//!
169//! let config = GraphRAGConfig {
170//! top_k: 20,
171//! expansion_hops: 2,
172//! enable_communities: true,
173//! ..Default::default()
174//! };
175//!
176//! // Provide your own implementations of VectorIndexTrait, EmbeddingModelTrait,
177//! // SparqlEngineTrait, and LlmClientTrait:
178//! let engine = GraphRAGEngine::new(
179//! Arc::new(my_vec_index),
180//! Arc::new(my_embedder),
181//! Arc::new(my_sparql),
182//! Arc::new(my_llm),
183//! config,
184//! );
185//!
186//! let result = engine.query("What safety issues affect battery cells?").await?;
187//! println!("Answer: {}", result.answer);
188//! println!("Confidence: {:.2}", result.confidence);
189//! ```
190//!
191//! See [`docs/tutorial.md`](https://github.com/cool-japan/oxirs/blob/master/ai/oxirs-graphrag/docs/tutorial.md)
192//! for a step-by-step walkthrough.
193
194pub mod cache;
195pub mod config;
196pub mod distributed;
// v1.1.0: Graph summarization for RAG is provided by `graph_summarization` (declared below)
198pub mod embeddings;
199pub mod federation;
200pub mod fusion;
201pub mod generation;
202pub mod graph;
203pub mod graph_summarization;
204pub mod query;
205pub mod reasoning;
206pub mod retrieval;
207pub mod sparql;
208pub mod streaming;
209pub mod temporal;
210
211// v1.1.0 TransE knowledge graph embedding model
212pub mod transe_model;
213
214// v1.1.0: Entity linking and disambiguation for knowledge graphs
215pub mod entity_linking;
216
217// v1.1.0 round 5: Community detection (Louvain-inspired greedy label propagation)
218pub mod community_detector;
219
220// v1.1.0 round 6: Knowledge graph path ranking (DFS + Dijkstra + scoring)
221pub mod path_ranker;
222
223// v1.1.0 round 7: String-to-RDF entity linking (mention detection + candidate ranking)
224pub mod entity_linker;
225
226// v1.1.0 round 11: Node2Vec-inspired graph embedding and structural node representations
227pub mod graph_embedder;
228
229// v1.1.0 round 12: Graph partitioning using greedy / label-propagation / bisection methods
230pub mod graph_partitioner;
231
232// v1.1.0 round 13: Rule-based knowledge triple extraction from natural language text
233pub mod triple_extractor;
234
235// v1.1.0 round 11: Multi-source knowledge fusion with provenance tracking
236pub mod knowledge_fusion;
237
238// v1.1.0 round 12: Context building for graph-based RAG (N-hop, ranking, truncation, formatting)
239pub mod context_builder;
240
241// v1.1.0 round 13: Graph path finding for RAG (BFS/DFS, shortest path, predicate filtering, scoring)
242pub mod path_finder;
243
244// v1.1.0 round 14: KG subgraph summarization via cluster-based abstraction
245pub mod summarizer;
246
247// v1.1.0 round 15: Entity type classification for knowledge graph nodes
248pub mod entity_classifier;
249
250// v1.1.0 round 16: Explainability — attention weights, path explanation, provenance
251pub mod explainability;
252
253// v1.1.0 round 17: Interactive refinement with user feedback
254pub mod feedback;
255
256// v0.4.0: Re-export new GraphSummarizer + GraphSummary types
257pub use summarizer::{GraphSummarizer, GraphSummary};
258// v0.4.0: Re-export new TripleRelevanceFeedback + Relevance types
259pub use feedback::{Relevance, TripleId, TripleRelevanceFeedback};
260
261// v0.3.0 / block-5: GNN encoder — phase a: GraphSAGE over the knowledge graph
262pub mod gnn_encoder;
263
264// v0.3.1: GNN encoder new components
265pub use gnn_encoder::{
266 AdjacencyGraph, EdgeList, GnnEncoder, GnnEncoderConfig, ScaledDotProductAttention,
267};
268
269// v0.3.0 / block-6: Hybrid GNN+LLM — phase b/c: LLM head with frozen GNN soft-prompt
270pub mod hybrid;
271
272// v0.3.0 / block-8: Hybrid GNN+LLM phase d — GGUF model loader + LoRA adapter
273#[cfg(feature = "gguf-loader")]
274pub mod model_loader;
275
276// v0.3.0 / block-7: Neuro-symbolic fusion — PINN-driven physics-informed entity scoring
277pub mod neuro_symbolic;
278
279use std::collections::HashMap;
280use std::sync::atomic::{AtomicU64, Ordering};
281use std::sync::Arc;
282use std::time::{Duration, SystemTime};
283
284use async_trait::async_trait;
285use chrono::{DateTime, Utc};
286use serde::{Deserialize, Serialize};
287use thiserror::Error;
288use tokio::sync::RwLock;
289
290// Re-exports
291pub use cache::query_cache::{CacheEntry, CacheStats, QueryCache, QueryCacheConfig};
292pub use config::{CacheConfiguration, GraphRAGConfig};
293pub use embeddings::node2vec::{
294 Node2VecConfig, Node2VecEmbedder, Node2VecEmbeddings, Node2VecWalkConfig,
295};
296pub use graph::community::{CommunityAlgorithm, CommunityConfig, CommunityDetector};
297pub use graph::embeddings::{CommunityAwareEmbeddings, CommunityStructure, EmbeddingConfig};
298pub use graph::traversal::GraphTraversal;
299pub use hybrid::lora::{LoraAdapter, LoraTrainer};
300pub use query::planner::QueryPlanner;
301pub use retrieval::fusion::FusionStrategy;
302
303// Feature-gated re-exports for GGUF model loader.
304#[cfg(feature = "gguf-loader")]
305pub use model_loader::{
306 GgufMetadata, GgufModelArch, GgufParseError, GgufParser, GgufTensorInfo, GgufValue,
307 ModelHandle, ModelInfo, ModelRegistry, RegistryError,
308};
309
/// GraphRAG error types
///
/// Every variant carries a free-form message describing the failure; the
/// `Display` output is the variant's prefix followed by that message.
#[derive(Error, Debug)]
pub enum GraphRAGError {
    /// A vector-index search failed.
    #[error("Vector search failed: {0}")]
    VectorSearchError(String),

    /// Graph traversal / expansion failed.
    #[error("Graph traversal failed: {0}")]
    GraphTraversalError(String),

    /// Community detection failed.
    #[error("Community detection failed: {0}")]
    CommunityDetectionError(String),

    /// The LLM client failed to generate an answer.
    #[error("LLM generation failed: {0}")]
    GenerationError(String),

    /// The embedding model failed to embed the input.
    #[error("Embedding failed: {0}")]
    EmbeddingError(String),

    /// A SPARQL query failed.
    #[error("SPARQL query failed: {0}")]
    SparqlError(String),

    /// The supplied configuration is invalid.
    #[error("Configuration error: {0}")]
    ConfigError(String),

    /// Any other failure that does not fit the categories above.
    #[error("Internal error: {0}")]
    InternalError(String),
}
337
/// Convenience alias for results whose error type is [`GraphRAGError`].
///
/// Not to be confused with [`GraphRAGResult2`], which is the payload returned
/// by a successful engine query.
pub type GraphRAGResult<T> = Result<T, GraphRAGError>;
339
/// Triple representation for RDF data
///
/// All three components are stored as plain strings; no validation of their
/// contents is performed by this type.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct Triple {
    /// Subject term of the statement
    pub subject: String,
    /// Predicate term of the statement
    pub predicate: String,
    /// Object term of the statement
    pub object: String,
}
347
348impl Triple {
349 pub fn new(
350 subject: impl Into<String>,
351 predicate: impl Into<String>,
352 object: impl Into<String>,
353 ) -> Self {
354 Self {
355 subject: subject.into(),
356 predicate: predicate.into(),
357 object: object.into(),
358 }
359 }
360}
361
/// Entity with relevance score
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoredEntity {
    /// Entity URI
    pub uri: String,
    /// Relevance score (nominally 0.0 - 1.0).
    ///
    /// Entities produced by the engine's fusion step carry a weighted
    /// reciprocal-rank value here, which is not normalized to span that range.
    pub score: f64,
    /// Source of the score (vector, keyword, fused, or graph)
    pub source: ScoreSource,
    /// Additional metadata (the engine's fusion step leaves this empty)
    pub metadata: HashMap<String, String>,
}
374
/// Source of entity score
///
/// Records which retrieval channel(s) contributed to a [`ScoredEntity`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ScoreSource {
    /// Score from vector similarity search
    Vector,
    /// Score from keyword/BM25 search
    Keyword,
    /// Fused score from multiple sources (entity was found by more than one channel)
    Fused,
    /// Score from graph traversal (path-based)
    Graph,
}
387
/// Community summary for hierarchical retrieval
///
/// Describes one group of related entities found in a retrieved subgraph,
/// together with a few triples that illustrate it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommunitySummary {
    /// Community identifier
    pub id: String,
    /// Human-readable summary of the community
    pub summary: String,
    /// Member entities in this community
    pub entities: Vec<String>,
    /// Representative triples from this community
    pub representative_triples: Vec<Triple>,
    /// Community level in hierarchy (0 = leaf, higher = more abstract)
    pub level: u32,
    /// Modularity score
    pub modularity: f64,
}
404
/// Query provenance for attribution
///
/// Captures what went into an answer so it can be traced back to the seed
/// entities and triples it was built from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryProvenance {
    /// Query timestamp (UTC)
    pub timestamp: DateTime<Utc>,
    /// Original query text
    pub original_query: String,
    /// Expanded query (if any)
    pub expanded_query: Option<String>,
    /// URIs of the seed entities used
    pub seed_entities: Vec<String>,
    /// Triples contributing to the answer
    pub source_triples: Vec<Triple>,
    /// Community summaries used (if hierarchical)
    pub community_sources: Vec<String>,
    /// Processing time in milliseconds
    pub processing_time_ms: u64,
}
423
/// GraphRAG query result
///
/// This is the success payload of an engine query; the `Result` alias with a
/// similar name is [`GraphRAGResult`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphRAGResult2 {
    /// Natural language answer
    pub answer: String,
    /// Source subgraph (RDF triples)
    pub subgraph: Vec<Triple>,
    /// Seed entities with scores
    pub seeds: Vec<ScoredEntity>,
    /// Community summaries (empty when community detection is disabled)
    pub communities: Vec<CommunitySummary>,
    /// Provenance information
    pub provenance: QueryProvenance,
    /// Confidence score (0.0 - 1.0)
    pub confidence: f64,
}
440
/// Trait for vector index operations
///
/// Implementations must be `Send + Sync` because the engine stores them
/// behind an `Arc`.
#[async_trait]
pub trait VectorIndexTrait: Send + Sync {
    /// Search for k nearest neighbors
    ///
    /// Returns `(entity identifier, score)` pairs. The engine treats the
    /// position in the returned vector as the rank (best first) when fusing
    /// results, and does not read the score itself.
    async fn search_knn(
        &self,
        query_vector: &[f32],
        k: usize,
    ) -> GraphRAGResult<Vec<(String, f32)>>;

    /// Search with similarity threshold
    ///
    /// Returns `(entity identifier, score)` pairs. Presumably only hits whose
    /// similarity is at least `threshold` are returned — the exact comparison
    /// is up to the implementation.
    async fn search_threshold(
        &self,
        query_vector: &[f32],
        threshold: f32,
    ) -> GraphRAGResult<Vec<(String, f32)>>;
}
458
/// Trait for embedding model operations
///
/// The engine embeds the raw query text with [`EmbeddingModelTrait::embed`]
/// and passes the resulting vector to the vector index.
#[async_trait]
pub trait EmbeddingModelTrait: Send + Sync {
    /// Embed text into vector
    async fn embed(&self, text: &str) -> GraphRAGResult<Vec<f32>>;

    /// Embed multiple texts in batch
    ///
    /// Presumably returns one vector per input, in input order — confirm
    /// against the implementation in use.
    async fn embed_batch(&self, texts: &[&str]) -> GraphRAGResult<Vec<Vec<f32>>>;
}
468
/// Trait for SPARQL engine operations
///
/// All methods take the complete query text as a string.
#[async_trait]
pub trait SparqlEngineTrait: Send + Sync {
    /// Execute SELECT query
    ///
    /// Each row maps a variable name to its value rendered as a string. The
    /// engine's keyword search looks up the keys `entity` and `score` (without
    /// a leading `?`) and parses `score` as `f32`, so numeric bindings must be
    /// returned in a form `str::parse::<f32>` accepts.
    async fn select(&self, query: &str) -> GraphRAGResult<Vec<HashMap<String, String>>>;

    /// Execute ASK query
    async fn ask(&self, query: &str) -> GraphRAGResult<bool>;

    /// Execute CONSTRUCT query
    ///
    /// Returns the constructed triples; the engine uses this for seed
    /// neighborhood expansion.
    async fn construct(&self, query: &str) -> GraphRAGResult<Vec<Triple>>;
}
481
/// Trait for LLM client operations
#[async_trait]
pub trait LlmClientTrait: Send + Sync {
    /// Generate response from context and query
    ///
    /// The engine passes a Markdown-formatted context string built from the
    /// retrieved subgraph together with the user's original query text.
    async fn generate(&self, context: &str, query: &str) -> GraphRAGResult<String>;

    /// Generate with streaming response
    ///
    /// Presumably `callback` is invoked with each chunk as it arrives and the
    /// returned string is the complete response — confirm with the
    /// implementation; the engine's non-streaming query path does not call
    /// this method.
    async fn generate_stream(
        &self,
        context: &str,
        query: &str,
        callback: Box<dyn Fn(&str) + Send + Sync>,
    ) -> GraphRAGResult<String>;
}
496
/// Cached result with metadata
#[derive(Debug, Clone)]
struct CachedResult {
    /// The query result being cached
    result: GraphRAGResult2,
    /// Wall-clock time at which the entry was stored
    timestamp: SystemTime,
    /// How long after `timestamp` the entry is considered fresh
    ttl: Duration,
}
504
505impl CachedResult {
506 /// Check if the cached result is still fresh
507 fn is_fresh(&self) -> bool {
508 self.timestamp
509 .elapsed()
510 .map(|elapsed| elapsed < self.ttl)
511 .unwrap_or(false)
512 }
513}
514
/// Cache configuration
///
/// Controls how long query results stay valid in the engine's result cache.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Base TTL in seconds (default: 3600 = 1 hour)
    pub base_ttl_seconds: u64,
    /// Minimum TTL in seconds (default: 300 = 5 minutes)
    pub min_ttl_seconds: u64,
    /// Maximum TTL in seconds (default: 86400 = 24 hours)
    pub max_ttl_seconds: u64,
    /// Enable adaptive TTL based on update frequency
    pub adaptive: bool,
}

impl Default for CacheConfig {
    /// One-hour base TTL, bounded to 5 minutes .. 24 hours, adaptive mode on.
    fn default() -> Self {
        const MINUTE: u64 = 60;
        const HOUR: u64 = 60 * MINUTE;

        CacheConfig {
            base_ttl_seconds: HOUR,
            min_ttl_seconds: 5 * MINUTE,
            max_ttl_seconds: 24 * HOUR,
            adaptive: true,
        }
    }
}
538
/// Main GraphRAG engine
///
/// Generic over the four pluggable back ends (vector index, embedding model,
/// SPARQL engine, LLM client); each is held behind an `Arc`.
pub struct GraphRAGEngine<V, E, S, L>
where
    V: VectorIndexTrait,
    E: EmbeddingModelTrait,
    S: SparqlEngineTrait,
    L: LlmClientTrait,
{
    /// Vector index for similarity search
    vec_index: Arc<V>,
    /// Embedding model for query vectorization
    embedding_model: Arc<E>,
    /// SPARQL engine for graph traversal
    sparql_engine: Arc<S>,
    /// LLM client for answer generation
    llm_client: Arc<L>,
    /// Configuration
    config: GraphRAGConfig,
    /// Query result cache with adaptive TTL, keyed by the exact query text
    cache: Arc<RwLock<lru::LruCache<String, CachedResult>>>,
    /// Cache configuration
    cache_config: CacheConfig,
    /// Graph update counter for adaptive TTL
    graph_update_count: Arc<AtomicU64>,
    /// Community detector (lazy initialized)
    // NOTE(review): the constructor sets this to `None`; confirm where (if
    // anywhere) it is populated and read.
    community_detector: Option<Arc<CommunityDetector>>,
}
566
567impl<V, E, S, L> GraphRAGEngine<V, E, S, L>
568where
569 V: VectorIndexTrait,
570 E: EmbeddingModelTrait,
571 S: SparqlEngineTrait,
572 L: LlmClientTrait,
573{
574 /// Create a new GraphRAG engine
575 pub fn new(
576 vec_index: Arc<V>,
577 embedding_model: Arc<E>,
578 sparql_engine: Arc<S>,
579 llm_client: Arc<L>,
580 config: GraphRAGConfig,
581 ) -> Self {
582 let cache_config = CacheConfig {
583 base_ttl_seconds: config.cache_config.base_ttl_seconds,
584 min_ttl_seconds: config.cache_config.min_ttl_seconds,
585 max_ttl_seconds: config.cache_config.max_ttl_seconds,
586 adaptive: config.cache_config.adaptive,
587 };
588
589 Self::with_cache_config(
590 vec_index,
591 embedding_model,
592 sparql_engine,
593 llm_client,
594 config,
595 cache_config,
596 )
597 }
598
599 /// Create a new GraphRAG engine with custom cache configuration
600 pub fn with_cache_config(
601 vec_index: Arc<V>,
602 embedding_model: Arc<E>,
603 sparql_engine: Arc<S>,
604 llm_client: Arc<L>,
605 config: GraphRAGConfig,
606 cache_config: CacheConfig,
607 ) -> Self {
608 const DEFAULT_CACHE_SIZE: std::num::NonZeroUsize = match std::num::NonZeroUsize::new(1000) {
609 Some(size) => size,
610 None => panic!("constant is non-zero"),
611 };
612
613 let cache_size = config
614 .cache_size
615 .and_then(std::num::NonZeroUsize::new)
616 .unwrap_or(DEFAULT_CACHE_SIZE);
617
618 Self {
619 vec_index,
620 embedding_model,
621 sparql_engine,
622 llm_client,
623 config,
624 cache: Arc::new(RwLock::new(lru::LruCache::new(cache_size))),
625 cache_config,
626 graph_update_count: Arc::new(AtomicU64::new(0)),
627 community_detector: None,
628 }
629 }
630
631 /// Calculate adaptive TTL based on graph update frequency
632 fn calculate_ttl(&self) -> Duration {
633 if !self.cache_config.adaptive {
634 return Duration::from_secs(self.cache_config.base_ttl_seconds);
635 }
636
637 let updates_per_hour = self.graph_update_count.load(Ordering::Relaxed) as f64;
638
639 // More updates = shorter TTL
640 let ttl_secs = if updates_per_hour > 100.0 {
641 self.cache_config.min_ttl_seconds // High update rate: 5 min TTL
642 } else if updates_per_hour > 10.0 {
643 self.cache_config.base_ttl_seconds / 2 // Medium: 30 min TTL
644 } else {
645 self.cache_config.max_ttl_seconds // Low update rate: 24 hour TTL
646 };
647
648 Duration::from_secs(ttl_secs)
649 }
650
651 /// Record graph update for adaptive TTL calculation
652 pub fn record_graph_update(&self) {
653 self.graph_update_count.fetch_add(1, Ordering::Relaxed);
654 }
655
656 /// Get current cache hit rate for monitoring
657 pub async fn get_cache_stats(&self) -> (usize, usize) {
658 let cache = self.cache.read().await;
659 (cache.len(), cache.cap().get())
660 }
661
662 /// Execute a GraphRAG query
663 pub async fn query(&self, query: &str) -> GraphRAGResult<GraphRAGResult2> {
664 let start_time = std::time::Instant::now();
665
666 // Check cache with freshness validation
667 {
668 let cache = self.cache.read().await;
669 if let Some(cached) = cache.peek(&query.to_string()) {
670 if cached.is_fresh() {
671 return Ok(cached.result.clone());
672 }
673 }
674 }
675
676 // 1. Embed query
677 let query_vec = self.embedding_model.embed(query).await?;
678
679 // 2. Vector retrieval (Top-K)
680 let vector_results = self
681 .vec_index
682 .search_knn(&query_vec, self.config.top_k)
683 .await?;
684
685 // 3. Keyword retrieval (BM25) - simplified for now
686 let keyword_results = self.keyword_search(query).await?;
687
688 // 4. Fusion (RRF)
689 let seeds = self.fuse_results(&vector_results, &keyword_results)?;
690
691 // 5. Graph expansion (SPARQL)
692 let subgraph = self.expand_graph(&seeds).await?;
693
694 // 6. Community detection (optional)
695 let communities = if self.config.enable_communities {
696 self.detect_communities(&subgraph)?
697 } else {
698 vec![]
699 };
700
701 // 7. Build context
702 let context = self.build_context(&subgraph, &communities, query)?;
703
704 // 8. Generate answer
705 let answer = self.llm_client.generate(&context, query).await?;
706
707 // Calculate confidence based on seed scores and graph coverage
708 let confidence = self.calculate_confidence(&seeds, &subgraph);
709
710 let result = GraphRAGResult2 {
711 answer,
712 subgraph: subgraph.clone(),
713 seeds: seeds.clone(),
714 communities,
715 provenance: QueryProvenance {
716 timestamp: Utc::now(),
717 original_query: query.to_string(),
718 expanded_query: None,
719 seed_entities: seeds.iter().map(|s| s.uri.clone()).collect(),
720 source_triples: subgraph,
721 community_sources: vec![],
722 processing_time_ms: start_time.elapsed().as_millis() as u64,
723 },
724 confidence,
725 };
726
727 // Update cache with adaptive TTL
728 let ttl = self.calculate_ttl();
729 let cached = CachedResult {
730 result: result.clone(),
731 timestamp: SystemTime::now(),
732 ttl,
733 };
734 self.cache.write().await.put(query.to_string(), cached);
735
736 Ok(result)
737 }
738
739 /// Keyword search using BM25 (simplified)
740 async fn keyword_search(&self, query: &str) -> GraphRAGResult<Vec<(String, f32)>> {
741 // Build SPARQL query with text matching
742 let terms: Vec<&str> = query.split_whitespace().collect();
743 if terms.is_empty() {
744 return Ok(vec![]);
745 }
746
747 // Create SPARQL FILTER with regex for each term
748 let filters: Vec<String> = terms
749 .iter()
750 .map(|term| format!("REGEX(STR(?label), \"{}\", \"i\")", term))
751 .collect();
752
753 let sparql = format!(
754 r#"
755 SELECT DISTINCT ?entity (COUNT(*) AS ?score) WHERE {{
756 ?entity rdfs:label|schema:name|dc:title ?label .
757 FILTER({})
758 }}
759 GROUP BY ?entity
760 ORDER BY DESC(?score)
761 LIMIT {}
762 "#,
763 filters.join(" || "),
764 self.config.top_k
765 );
766
767 let results = self.sparql_engine.select(&sparql).await?;
768
769 Ok(results
770 .into_iter()
771 .filter_map(|row| {
772 let entity = row.get("entity")?.clone();
773 let score = row.get("score")?.parse::<f32>().ok()?;
774 Some((entity, score))
775 })
776 .collect())
777 }
778
779 /// Fuse vector and keyword results using Reciprocal Rank Fusion
780 fn fuse_results(
781 &self,
782 vector_results: &[(String, f32)],
783 keyword_results: &[(String, f32)],
784 ) -> GraphRAGResult<Vec<ScoredEntity>> {
785 let k = 60.0; // RRF constant
786
787 let mut scores: HashMap<String, (f64, ScoreSource)> = HashMap::new();
788
789 // Add vector scores
790 for (rank, (uri, score)) in vector_results.iter().enumerate() {
791 let rrf_score = 1.0 / (k + rank as f64 + 1.0);
792 scores.insert(
793 uri.clone(),
794 (
795 rrf_score * self.config.vector_weight as f64,
796 ScoreSource::Vector,
797 ),
798 );
799 }
800
801 // Add keyword scores
802 for (rank, (uri, _score)) in keyword_results.iter().enumerate() {
803 let rrf_score = 1.0 / (k + rank as f64 + 1.0);
804 let keyword_contribution = rrf_score * self.config.keyword_weight as f64;
805
806 match scores.get(uri).cloned() {
807 Some((existing_score, _)) => {
808 let new_score = existing_score + keyword_contribution;
809 scores.insert(uri.clone(), (new_score, ScoreSource::Fused));
810 }
811 None => {
812 scores.insert(uri.clone(), (keyword_contribution, ScoreSource::Keyword));
813 }
814 }
815 }
816
817 // Sort by score and take top results
818 let mut entities: Vec<ScoredEntity> = scores
819 .into_iter()
820 .map(|(uri, (score, source))| ScoredEntity {
821 uri,
822 score,
823 source,
824 metadata: HashMap::new(),
825 })
826 .collect();
827
828 entities.sort_by(|a, b| {
829 b.score
830 .partial_cmp(&a.score)
831 .unwrap_or(std::cmp::Ordering::Equal)
832 });
833 entities.truncate(self.config.max_seeds);
834
835 Ok(entities)
836 }
837
838 /// Expand graph from seed entities using SPARQL
839 async fn expand_graph(&self, seeds: &[ScoredEntity]) -> GraphRAGResult<Vec<Triple>> {
840 if seeds.is_empty() {
841 return Ok(vec![]);
842 }
843
844 let seed_uris: Vec<String> = seeds.iter().map(|s| format!("<{}>", s.uri)).collect();
845 let values = seed_uris.join(" ");
846
847 // N-hop neighbor expansion
848 let hops = self.config.expansion_hops;
849 let path_pattern = if hops == 1 {
850 "?seed ?p ?neighbor".to_string()
851 } else {
852 format!("?seed (:|!:){{1,{}}} ?neighbor", hops)
853 };
854
855 let sparql = format!(
856 r#"
857 CONSTRUCT {{
858 ?seed ?p ?o .
859 ?s ?p2 ?seed .
860 ?neighbor ?p3 ?o2 .
861 }}
862 WHERE {{
863 VALUES ?seed {{ {} }}
864 {{
865 ?seed ?p ?o .
866 }} UNION {{
867 ?s ?p2 ?seed .
868 }} UNION {{
869 {}
870 ?neighbor ?p3 ?o2 .
871 }}
872 }}
873 LIMIT {}
874 "#,
875 values, path_pattern, self.config.max_subgraph_size
876 );
877
878 self.sparql_engine.construct(&sparql).await
879 }
880
881 /// Detect communities in the subgraph using Louvain algorithm
882 fn detect_communities(&self, subgraph: &[Triple]) -> GraphRAGResult<Vec<CommunitySummary>> {
883 use petgraph::graph::UnGraph;
884
885 if subgraph.is_empty() {
886 return Ok(vec![]);
887 }
888
889 // Build undirected graph
890 let mut graph: UnGraph<String, ()> = UnGraph::new_undirected();
891 let mut node_indices: HashMap<String, petgraph::graph::NodeIndex> = HashMap::new();
892
893 for triple in subgraph {
894 let subj_idx = *node_indices
895 .entry(triple.subject.clone())
896 .or_insert_with(|| graph.add_node(triple.subject.clone()));
897 let obj_idx = *node_indices
898 .entry(triple.object.clone())
899 .or_insert_with(|| graph.add_node(triple.object.clone()));
900
901 if subj_idx != obj_idx {
902 graph.add_edge(subj_idx, obj_idx, ());
903 }
904 }
905
906 // Simple community detection based on connected components
907 // (Full Louvain implementation would be more complex)
908 let components = petgraph::algo::kosaraju_scc(&graph);
909
910 let communities: Vec<CommunitySummary> = components
911 .into_iter()
912 .enumerate()
913 .filter(|(_, component)| component.len() >= 2)
914 .map(|(idx, component)| {
915 let entities: Vec<String> = component
916 .iter()
917 .filter_map(|&node_idx| graph.node_weight(node_idx).cloned())
918 .collect();
919
920 let representative_triples: Vec<Triple> = subgraph
921 .iter()
922 .filter(|t| entities.contains(&t.subject) || entities.contains(&t.object))
923 .take(5)
924 .cloned()
925 .collect();
926
927 CommunitySummary {
928 id: format!("community_{}", idx),
929 summary: format!("Community with {} entities", entities.len()),
930 entities,
931 representative_triples,
932 level: 0,
933 modularity: 0.0,
934 }
935 })
936 .collect();
937
938 Ok(communities)
939 }
940
941 /// Build context string for LLM from subgraph and communities
942 fn build_context(
943 &self,
944 subgraph: &[Triple],
945 communities: &[CommunitySummary],
946 _query: &str,
947 ) -> GraphRAGResult<String> {
948 let mut context = String::new();
949
950 // Add community summaries if available
951 if !communities.is_empty() {
952 context.push_str("## Community Context\n\n");
953 for community in communities {
954 context.push_str(&format!("### {}\n", community.id));
955 context.push_str(&format!("{}\n", community.summary));
956 context.push_str(&format!("Entities: {}\n\n", community.entities.join(", ")));
957 }
958 }
959
960 // Add relevant triples
961 context.push_str("## Knowledge Graph Facts\n\n");
962 for triple in subgraph.iter().take(self.config.max_context_triples) {
963 context.push_str(&format!(
964 "- {} → {} → {}\n",
965 triple.subject, triple.predicate, triple.object
966 ));
967 }
968
969 Ok(context)
970 }
971
972 /// Calculate confidence score based on retrieval quality
973 fn calculate_confidence(&self, seeds: &[ScoredEntity], subgraph: &[Triple]) -> f64 {
974 if seeds.is_empty() {
975 return 0.0;
976 }
977
978 // Average seed score
979 let avg_seed_score: f64 = seeds.iter().map(|s| s.score).sum::<f64>() / seeds.len() as f64;
980
981 // Graph coverage (how many seeds appear in subgraph)
982 let seed_uris: std::collections::HashSet<_> = seeds.iter().map(|s| &s.uri).collect();
983 let covered: usize = subgraph
984 .iter()
985 .filter(|t| seed_uris.contains(&t.subject) || seed_uris.contains(&t.object))
986 .count();
987 let coverage = if subgraph.is_empty() {
988 0.0
989 } else {
990 (covered as f64 / subgraph.len() as f64).min(1.0)
991 };
992
993 // Combined confidence
994 (avg_seed_score * 0.6 + coverage * 0.4).min(1.0)
995 }
996}
997
#[cfg(test)]
mod tests {
    use super::*;

    /// `Triple::new` stores each argument in the field of the same name.
    #[test]
    fn test_triple_creation() {
        let built = Triple::new(
            "http://example.org/s",
            "http://example.org/p",
            "http://example.org/o",
        );

        let Triple {
            subject,
            predicate,
            object,
        } = built;
        assert_eq!(subject, "http://example.org/s");
        assert_eq!(predicate, "http://example.org/p");
        assert_eq!(object, "http://example.org/o");
    }

    /// A `ScoredEntity` built by hand exposes the score and source it was given.
    #[test]
    fn test_scored_entity() {
        let scored = ScoredEntity {
            uri: "http://example.org/entity".to_string(),
            score: 0.85,
            source: ScoreSource::Fused,
            metadata: HashMap::new(),
        };

        assert_eq!(scored.score, 0.85);
        assert_eq!(scored.source, ScoreSource::Fused);
    }
}
1025}