graphrag_core/core/
entity_adapters.rs1#[cfg(feature = "lightrag")]
7use crate::core::error::{GraphRAGError, Result};
8#[cfg(feature = "lightrag")]
9use crate::core::traits::AsyncEntityExtractor;
10#[cfg(feature = "lightrag")]
11use crate::core::Entity;
12#[cfg(feature = "lightrag")]
13use async_trait::async_trait;
14
15#[cfg(feature = "lightrag")]
16use crate::lightrag::graph_indexer::{ExtractedEntity, GraphIndexer};
17
/// Adapter exposing a LightRAG [`GraphIndexer`] through the crate's
/// [`AsyncEntityExtractor`] trait.
///
/// Entities produced by the indexer are returned only when their confidence is
/// greater than or equal to `confidence_threshold`, after being converted into
/// the crate's [`Entity`] type.
#[cfg(feature = "lightrag")]
pub struct GraphIndexerAdapter {
    /// Inclusive lower bound on the confidence an extracted entity needs in
    /// order to be returned.
    confidence_threshold: f32,
    /// The wrapped LightRAG indexer that performs the actual extraction.
    indexer: GraphIndexer,
}
24
#[cfg(feature = "lightrag")]
impl GraphIndexerAdapter {
    /// Creates an adapter around a freshly constructed [`GraphIndexer`].
    ///
    /// `entity_types` and `max_depth` are forwarded unchanged to
    /// [`GraphIndexer::new`]. The confidence threshold starts at `0.5`.
    ///
    /// # Errors
    ///
    /// Returns whatever error [`GraphIndexer::new`] reports.
    pub fn new(entity_types: Vec<String>, max_depth: usize) -> Result<Self> {
        Ok(Self {
            indexer: GraphIndexer::new(entity_types, max_depth)?,
            confidence_threshold: 0.5,
        })
    }

    /// Builder-style setter: returns this adapter with its confidence threshold
    /// replaced by `threshold`.
    ///
    /// Entities are kept when `entity.confidence >= threshold`, so a `NaN`
    /// threshold rejects every entity (all comparisons with `NaN` are false).
    #[must_use = "this consumes the adapter and returns the reconfigured one; dropping it discards the adapter"]
    pub fn with_confidence_threshold(mut self, threshold: f32) -> Self {
        self.confidence_threshold = threshold;
        self
    }

    /// Converts an indexer-level [`ExtractedEntity`] into the crate's [`Entity`].
    ///
    /// Copies the id, name, entity type and confidence. Mentions are left empty,
    /// and the embedding and temporal fields are set to `None`, because the
    /// extracted entity does not supply them here.
    fn convert_entity(&self, extracted: &ExtractedEntity) -> Entity {
        use crate::core::EntityId;
        Entity {
            id: EntityId::new(extracted.id.clone()),
            name: extracted.name.clone(),
            entity_type: extracted.entity_type.clone(),
            confidence: extracted.confidence,
            mentions: vec![],
            embedding: None,
            first_mentioned: None,
            last_mentioned: None,
            temporal_validity: None,
        }
    }
}
57
#[cfg(feature = "lightrag")]
#[async_trait]
impl AsyncEntityExtractor for GraphIndexerAdapter {
    type Entity = Entity;
    type Error = GraphRAGError;

    /// Extracts entities from `text` whose confidence is at least the configured
    /// threshold.
    ///
    /// Delegates to `extract_with_confidence` and drops the scores, so the
    /// filtering rule is defined in exactly one place.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying indexer.
    async fn extract(&self, text: &str) -> Result<Vec<Self::Entity>> {
        let scored = self.extract_with_confidence(text).await?;
        Ok(scored.into_iter().map(|(entity, _)| entity).collect())
    }

    /// Extracts entities from `text` and pairs each with its confidence score.
    ///
    /// Only entities with `confidence >= threshold` are returned. The paired
    /// score is the same value that is copied into the converted entity.
    ///
    /// # Errors
    ///
    /// Propagates any error from `GraphIndexer::extract_from_text`.
    async fn extract_with_confidence(&self, text: &str) -> Result<Vec<(Self::Entity, f32)>> {
        let result = self.indexer.extract_from_text(text)?;
        let threshold = self.confidence_threshold;

        Ok(result
            .entities
            .iter()
            .filter(|e| e.confidence >= threshold)
            .map(|e| (self.convert_entity(e), e.confidence))
            .collect())
    }

    /// Runs `extract` on each text in order and collects the per-text results.
    ///
    /// Texts are processed sequentially. The first error aborts the batch and is
    /// returned.
    async fn extract_batch(&self, texts: &[&str]) -> Result<Vec<Vec<Self::Entity>>> {
        let mut results = Vec::with_capacity(texts.len());
        for text in texts {
            results.push(self.extract(text).await?);
        }
        Ok(results)
    }

    /// Replaces the minimum confidence required for an entity to be returned.
    async fn set_confidence_threshold(&mut self, threshold: f32) {
        self.confidence_threshold = threshold;
    }

    /// Returns the current minimum confidence threshold.
    async fn get_confidence_threshold(&self) -> f32 {
        self.confidence_threshold
    }
}
104
#[cfg(all(test, feature = "lightrag"))]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_graph_indexer_adapter() {
        let adapter =
            GraphIndexerAdapter::new(vec!["person".to_string(), "organization".to_string()], 3)
                .unwrap();

        let text = "John Smith works at Microsoft Corporation.";
        let entities = adapter.extract(text).await.unwrap();

        assert!(!entities.is_empty());
        for entity in &entities {
            assert!(entity.confidence >= 0.5);
        }
    }

    #[tokio::test]
    async fn test_confidence_threshold_filtering() {
        let adapter = GraphIndexerAdapter::new(vec!["person".to_string()], 3)
            .unwrap()
            .with_confidence_threshold(0.6);

        let text = "John Smith works at Microsoft.";
        let entities = adapter.extract(text).await.unwrap();

        for entity in &entities {
            assert!(entity.confidence >= 0.6);
        }
    }

    /// `extract_with_confidence` must apply the same filter as `extract` and
    /// report the same score that ends up in the converted entity.
    #[tokio::test]
    async fn test_extract_with_confidence_matches_extract() {
        let adapter =
            GraphIndexerAdapter::new(vec!["person".to_string(), "organization".to_string()], 3)
                .unwrap()
                .with_confidence_threshold(0.6);

        let text = "John Smith works at Microsoft Corporation.";
        let plain = adapter.extract(text).await.unwrap();
        let scored = adapter.extract_with_confidence(text).await.unwrap();

        assert_eq!(plain.len(), scored.len());
        for (entity, score) in &scored {
            assert!(*score >= 0.6);
            assert_eq!(entity.confidence, *score);
        }
    }

    /// The default threshold is 0.5 and the async setter/getter round-trip.
    #[tokio::test]
    async fn test_set_and_get_confidence_threshold() {
        let mut adapter = GraphIndexerAdapter::new(vec!["person".to_string()], 3).unwrap();
        assert_eq!(adapter.get_confidence_threshold().await, 0.5);

        adapter.set_confidence_threshold(0.8).await;
        assert_eq!(adapter.get_confidence_threshold().await, 0.8);
    }

    #[tokio::test]
    async fn test_batch_extraction() {
        let adapter =
            GraphIndexerAdapter::new(vec!["person".to_string(), "location".to_string()], 3)
                .unwrap();

        let texts = vec!["Alice lives in Paris.", "Bob works in London."];

        let results = adapter.extract_batch(&texts).await.unwrap();
        assert_eq!(results.len(), 2);

        // Each batch slot should match what a single-text extraction returns.
        for (text, batch) in texts.iter().zip(&results) {
            let single = adapter.extract(text).await.unwrap();
            assert_eq!(single.len(), batch.len());
        }
    }
}