Skip to main content

graphrag_core/core/
entity_adapters.rs

1//! Entity extraction adapters for core traits
2//!
3//! This module provides adapter implementations that bridge existing entity extractors
4//! with the core GraphRAG AsyncEntityExtractor trait.
5
6#[cfg(feature = "lightrag")]
7use crate::core::error::{GraphRAGError, Result};
8#[cfg(feature = "lightrag")]
9use crate::core::traits::AsyncEntityExtractor;
10#[cfg(feature = "lightrag")]
11use crate::core::Entity;
12#[cfg(feature = "lightrag")]
13use async_trait::async_trait;
14
15#[cfg(feature = "lightrag")]
16use crate::lightrag::graph_indexer::{ExtractedEntity, GraphIndexer};
17
/// Bridges a LightRAG [`GraphIndexer`] to the core [`AsyncEntityExtractor`] trait.
///
/// Entities reported by the wrapped indexer are converted into core [`Entity`]
/// values. The extraction methods of the trait implementation drop every entity
/// whose confidence is below `confidence_threshold`.
#[cfg(feature = "lightrag")]
pub struct GraphIndexerAdapter {
    /// Minimum confidence an extracted entity must reach to be returned.
    confidence_threshold: f32,
    /// Underlying LightRAG indexer that performs the actual extraction.
    indexer: GraphIndexer,
}
24
#[cfg(feature = "lightrag")]
impl GraphIndexerAdapter {
    /// Confidence threshold applied by [`GraphIndexerAdapter::new`].
    const DEFAULT_CONFIDENCE_THRESHOLD: f32 = 0.5;

    /// Create a new adapter around a freshly constructed [`GraphIndexer`].
    ///
    /// The confidence threshold starts at `0.5`; use
    /// [`with_confidence_threshold`](Self::with_confidence_threshold) to change it.
    ///
    /// # Errors
    ///
    /// Propagates (via `?`) any error returned by `GraphIndexer::new` for the
    /// given `entity_types` and `max_depth`.
    pub fn new(entity_types: Vec<String>, max_depth: usize) -> Result<Self> {
        Ok(Self {
            indexer: GraphIndexer::new(entity_types, max_depth)?,
            confidence_threshold: Self::DEFAULT_CONFIDENCE_THRESHOLD,
        })
    }

    /// Return this adapter with its confidence threshold replaced by `threshold`.
    ///
    /// The value is stored as-is, with no clamping or validation. Entities are
    /// kept only when `confidence >= threshold`. A NaN threshold therefore
    /// rejects every entity, because all comparisons against NaN are false.
    #[must_use = "with_confidence_threshold consumes the adapter and returns the updated one"]
    pub fn with_confidence_threshold(mut self, threshold: f32) -> Self {
        self.confidence_threshold = threshold;
        self
    }

    /// Convert a LightRAG [`ExtractedEntity`] into a core [`Entity`].
    ///
    /// `id`, `name`, `entity_type` and `confidence` are copied over. Fields the
    /// indexer does not provide are left empty: no mentions, no embedding, and
    /// no temporal information.
    fn convert_entity(&self, extracted: &ExtractedEntity) -> Entity {
        use crate::core::EntityId;
        Entity {
            id: EntityId::new(extracted.id.clone()),
            name: extracted.name.clone(),
            entity_type: extracted.entity_type.clone(),
            confidence: extracted.confidence,
            mentions: vec![], // GraphIndexer doesn't track mentions
            embedding: None,  // No embedding in GraphIndexer
            first_mentioned: None,
            last_mentioned: None,
            temporal_validity: None,
        }
    }
}
57
#[cfg(feature = "lightrag")]
#[async_trait]
impl AsyncEntityExtractor for GraphIndexerAdapter {
    type Entity = Entity;
    type Error = GraphRAGError;

    /// Extract entities from `text`, keeping only those whose confidence is at
    /// least the configured threshold.
    ///
    /// Delegates to [`extract_with_confidence`](Self::extract_with_confidence)
    /// and discards the scores, so both methods share one filtering path.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying indexer.
    async fn extract(&self, text: &str) -> Result<Vec<Self::Entity>> {
        let scored = self.extract_with_confidence(text).await?;
        Ok(scored.into_iter().map(|(entity, _)| entity).collect())
    }

    /// Extract entities from `text` together with their confidence scores.
    ///
    /// Entities whose confidence is below the configured threshold are
    /// discarded. `GraphIndexer::extract_from_text` is a synchronous call (it is
    /// not awaited here), so it runs directly on the task polling this future.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying indexer.
    async fn extract_with_confidence(&self, text: &str) -> Result<Vec<(Self::Entity, f32)>> {
        let result = self.indexer.extract_from_text(text)?;
        let threshold = self.confidence_threshold;

        Ok(result
            .entities
            .iter()
            .filter(|e| e.confidence >= threshold)
            .map(|e| (self.convert_entity(e), e.confidence))
            .collect())
    }

    /// Run [`extract`](Self::extract) on each text in order.
    ///
    /// The output has one entry per input text, in the same order.
    ///
    /// # Errors
    ///
    /// Stops at and returns the first extraction error.
    async fn extract_batch(&self, texts: &[&str]) -> Result<Vec<Vec<Self::Entity>>> {
        let mut results = Vec::with_capacity(texts.len());
        for &text in texts {
            results.push(self.extract(text).await?);
        }
        Ok(results)
    }

    /// Replace the minimum confidence required for an entity to be returned.
    ///
    /// The value is stored without validation.
    async fn set_confidence_threshold(&mut self, threshold: f32) {
        self.confidence_threshold = threshold;
    }

    /// Current minimum confidence required for an entity to be returned.
    async fn get_confidence_threshold(&self) -> f32 {
        self.confidence_threshold
    }
}
104
#[cfg(all(test, feature = "lightrag"))]
mod tests {
    use super::*;

    /// Basic extraction returns entities, and all of them respect the default 0.5 threshold.
    #[tokio::test]
    async fn test_graph_indexer_adapter() {
        let adapter =
            GraphIndexerAdapter::new(vec!["person".to_string(), "organization".to_string()], 3)
                .unwrap();

        let text = "John Smith works at Microsoft Corporation.";
        let entities = adapter.extract(text).await.unwrap();

        assert!(!entities.is_empty());
        for entity in &entities {
            assert!(entity.confidence >= 0.5);
        }
    }

    /// A custom threshold set via the builder is applied to extraction results.
    #[tokio::test]
    async fn test_confidence_threshold_filtering() {
        let adapter = GraphIndexerAdapter::new(vec!["person".to_string()], 3)
            .unwrap()
            .with_confidence_threshold(0.6);

        // Guard against a vacuous pass: the loop below checks nothing if no
        // entity is extracted, so first confirm the builder stored the threshold.
        assert_eq!(adapter.get_confidence_threshold().await, 0.6);

        let text = "John Smith works at Microsoft.";
        let entities = adapter.extract(text).await.unwrap();

        // All entities should have confidence >= 0.6
        for entity in &entities {
            assert!(entity.confidence >= 0.6);
        }
    }

    /// The score returned alongside each entity matches the entity's own
    /// confidence and respects the default threshold.
    #[tokio::test]
    async fn test_extract_with_confidence_matches_entity_confidence() {
        let adapter =
            GraphIndexerAdapter::new(vec!["person".to_string(), "organization".to_string()], 3)
                .unwrap();

        let text = "John Smith works at Microsoft Corporation.";
        let scored = adapter.extract_with_confidence(text).await.unwrap();

        assert!(!scored.is_empty());
        for (entity, confidence) in &scored {
            assert!(*confidence >= 0.5);
            assert_eq!(entity.confidence, *confidence);
        }
    }

    /// The trait-level setter and getter round-trip the threshold value.
    #[tokio::test]
    async fn test_set_confidence_threshold_round_trip() {
        let mut adapter = GraphIndexerAdapter::new(vec!["person".to_string()], 3).unwrap();

        assert_eq!(adapter.get_confidence_threshold().await, 0.5);
        adapter.set_confidence_threshold(0.8).await;
        assert_eq!(adapter.get_confidence_threshold().await, 0.8);
    }

    /// Batch extraction yields one result list per input text.
    #[tokio::test]
    async fn test_batch_extraction() {
        let adapter =
            GraphIndexerAdapter::new(vec!["person".to_string(), "location".to_string()], 3)
                .unwrap();

        let texts = vec!["Alice lives in Paris.", "Bob works in London."];

        let results = adapter.extract_batch(&texts).await.unwrap();
        assert_eq!(results.len(), 2);
    }
}