Skip to main content

graphrag_core/entity/
gleaning_extractor.rs

1#![cfg_attr(not(feature = "async"), allow(unused_imports))]
2
3//! Gleaning-based entity extraction with TRUE LLM inference
4//!
5//! This module implements iterative gleaning refinement using actual LLM calls,
6//! not pattern matching. Based on Microsoft GraphRAG and LightRAG research.
7//!
8//! Expected performance: 15-30 seconds per chunk per round. For a 1000-page book
9//! with 4 gleaning rounds, expect 2-4 hours of processing time.
10
11use crate::{
12    core::{Entity, Relationship, Result, TextChunk},
13    entity::{
14        llm_extractor::LLMEntityExtractor,
15        prompts::{EntityData, RelationshipData},
16    },
17    ollama::OllamaClient,
18};
19use serde::{Deserialize, Serialize};
20use std::collections::HashMap;
21
/// Tunable settings for [`GleaningEntityExtractor`].
#[derive(Debug, Clone)]
pub struct GleaningConfig {
    /// Upper bound on extraction rounds, counting the initial extraction as
    /// round 1. Values of 0 or 1 therefore run only the initial extraction.
    pub max_gleaning_rounds: usize,
    /// Threshold for extraction completion (0.0-1.0).
    ///
    /// NOTE(review): not read by any code visible in this file — confirm
    /// whether another module consumes it.
    pub completion_threshold: f64,
    /// Minimum confidence for extracted entities (0.0-1.0).
    ///
    /// NOTE(review): not read by any code visible in this file — confirm
    /// whether another module consumes it.
    pub entity_confidence_threshold: f64,
    /// When `true`, the LLM is asked before every continuation round whether
    /// extraction is already complete; when `false`, that check is skipped.
    pub use_llm_completion_check: bool,
    /// Entity type labels handed to the underlying LLM extractor.
    pub entity_types: Vec<String>,
    /// Sampling temperature handed to the underlying LLM extractor
    /// (lower = more consistent).
    pub temperature: f32,
    /// Response token limit handed to the underlying LLM extractor.
    pub max_tokens: usize,
}

impl Default for GleaningConfig {
    /// Four rounds, LLM completion check enabled, five general-purpose entity
    /// types, temperature 0.0 and a 1500-token response limit.
    fn default() -> Self {
        let entity_types = ["PERSON", "ORGANIZATION", "LOCATION", "EVENT", "CONCEPT"]
            .iter()
            .map(|label| String::from(*label))
            .collect();

        Self {
            // Microsoft GraphRAG uses 4 rounds.
            max_gleaning_rounds: 4,
            completion_threshold: 0.85,
            entity_confidence_threshold: 0.7,
            use_llm_completion_check: true,
            entity_types,
            // Zero temperature for deterministic JSON extraction.
            temperature: 0.0,
            max_tokens: 1500,
        }
    }
}
60
/// Structured verdict on whether an extraction pass captured everything.
///
/// Derives `Serialize`/`Deserialize`, so it can be written to or parsed from
/// a serialized form.
///
/// NOTE(review): nothing in the visible part of this file constructs or reads
/// this type — presumably it is consumed elsewhere; confirm before removing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractionCompletionStatus {
    /// Whether extraction is considered complete
    pub is_complete: bool,
    /// Confidence score for completeness (0.0-1.0)
    pub confidence: f64,
    /// Aspects that may be missing from extraction
    pub missing_aspects: Vec<String>,
    /// Suggestions for improving extraction
    pub suggestions: Vec<String>,
}
73
/// Entity extractor that refines its results over several LLM rounds
/// ("gleaning").
///
/// All language-model work is delegated to the wrapped `LLMEntityExtractor`;
/// this type drives the rounds, merges their results, and converts between
/// the prompt-level data structs and the domain `Entity`/`Relationship` types.
pub struct GleaningEntityExtractor {
    // Only used by methods compiled with the `async` feature, hence the
    // conditional `dead_code` allowance.
    #[cfg_attr(not(feature = "async"), allow(dead_code))]
    llm_extractor: LLMEntityExtractor,
    // Settings supplied at construction; cloned into `GleaningStatistics`.
    config: GleaningConfig,
}
83
84impl GleaningEntityExtractor {
85    /// Create a new gleaning entity extractor with LLM client
86    ///
87    /// # Arguments
88    /// * `ollama_client` - Ollama client for LLM inference (REQUIRED)
89    /// * `config` - Gleaning configuration
90    pub fn new(ollama_client: OllamaClient, config: GleaningConfig) -> Self {
91        // Extract keep_alive before ollama_client is moved into the extractor
92        let keep_alive = ollama_client.config().keep_alive.clone();
93
94        // Create LLM extractor with configured entity types
95        let llm_extractor = LLMEntityExtractor::new(ollama_client, config.entity_types.clone())
96            .with_temperature(config.temperature)
97            .with_max_tokens(config.max_tokens)
98            .with_keep_alive(keep_alive);
99
100        Self {
101            llm_extractor,
102            config,
103        }
104    }
105
106    /// Extract entities with iterative refinement (gleaning) using TRUE LLM calls
107    ///
108    /// This is the REAL implementation that makes actual LLM API calls.
109    /// Expected time: 15-30 seconds per round = 60-120 seconds total for 4 rounds per chunk.
110    ///
111    /// # Performance
112    /// - 1 chunk, 4 rounds: ~2 minutes
113    /// - 100 chunks, 4 rounds: ~3-4 hours
114    /// - Tom Sawyer (1000 pages): ~2-4 hours
115    #[cfg(feature = "async")]
116    pub async fn extract_with_gleaning(
117        &self,
118        chunk: &TextChunk,
119    ) -> Result<(Vec<Entity>, Vec<Relationship>)> {
120        #[cfg(feature = "tracing")]
121        tracing::info!(
122            "🔍 Starting REAL LLM gleaning extraction for chunk: {} ({} chars)",
123            chunk.id,
124            chunk.content.len()
125        );
126
127        let start_time = std::time::Instant::now();
128
129        // Track all extracted entities and relationships across rounds
130        let mut all_entity_data: Vec<EntityData> = Vec::new();
131        let mut all_relationship_data: Vec<RelationshipData> = Vec::new();
132
133        // Round 1: Initial extraction (THIS IS A REAL LLM CALL!)
134        #[cfg(feature = "tracing")]
135        tracing::info!("📝 Round 1: Initial LLM extraction...");
136        let round_start = std::time::Instant::now();
137
138        let (initial_entities, initial_relationships) =
139            self.llm_extractor.extract_from_chunk(chunk).await?;
140
141        #[cfg(feature = "tracing")]
142        tracing::info!(
143            "✅ Round 1 complete: {} entities, {} relationships ({:.1}s)",
144            initial_entities.len(),
145            initial_relationships.len(),
146            round_start.elapsed().as_secs_f32()
147        );
148
149        // Convert to EntityData for tracking
150        let mut entity_data = self.convert_entities_to_data(&initial_entities);
151        let mut relationship_data = self.convert_relationships_to_data(&initial_relationships);
152
153        all_entity_data.append(&mut entity_data);
154        all_relationship_data.append(&mut relationship_data);
155
156        // Rounds 2-N: Gleaning continuation rounds
157        for round in 2..=self.config.max_gleaning_rounds {
158            #[cfg(feature = "tracing")]
159            tracing::info!("📝 Round {}: Gleaning continuation...", round);
160            let round_start = std::time::Instant::now();
161
162            // Check if extraction is complete using LLM (REAL LLM CALL!)
163            if self.config.use_llm_completion_check {
164                let is_complete = self
165                    .llm_extractor
166                    .check_completion(chunk, &all_entity_data, &all_relationship_data)
167                    .await?;
168
169                if is_complete {
170                    #[cfg(feature = "tracing")]
171                    tracing::info!(
172                        "✅ LLM determined extraction is COMPLETE after {} rounds ({:.1}s total)",
173                        round - 1,
174                        start_time.elapsed().as_secs_f32()
175                    );
176                    break;
177                }
178
179                #[cfg(feature = "tracing")]
180                tracing::debug!("⚠️  LLM determined extraction is INCOMPLETE, continuing...");
181            }
182
183            // Perform additional extraction round (REAL LLM CALL!)
184            let (additional_entities, additional_relationships) = self
185                .llm_extractor
186                .extract_additional(chunk, &all_entity_data, &all_relationship_data)
187                .await?;
188
189            #[cfg(feature = "tracing")]
190            tracing::info!(
191                "✅ Round {} complete: {} new entities, {} new relationships ({:.1}s)",
192                round,
193                additional_entities.len(),
194                additional_relationships.len(),
195                round_start.elapsed().as_secs_f32()
196            );
197
198            // If no new entities found, stop gleaning
199            if additional_entities.is_empty() && additional_relationships.is_empty() {
200                #[cfg(feature = "tracing")]
201                tracing::info!(
202                    "🛑 No additional entities found in round {}, stopping gleaning",
203                    round
204                );
205                break;
206            }
207
208            // Convert and merge new results
209            let new_entity_data = self.convert_entities_to_data(&additional_entities);
210            let mut new_relationship_data =
211                self.convert_relationships_to_data(&additional_relationships);
212
213            // Merge with length-based strategy (LightRAG approach)
214            all_entity_data = self.merge_entity_data(all_entity_data, new_entity_data);
215            all_relationship_data.append(&mut new_relationship_data);
216        }
217
218        // Convert back to domain entities and relationships
219        let final_entities =
220            self.convert_data_to_entities(&all_entity_data, &chunk.id, &chunk.content)?;
221        let final_relationships =
222            self.convert_data_to_relationships(&all_relationship_data, &final_entities)?;
223
224        // Deduplicate relationships
225        let deduplicated_relationships = self.deduplicate_relationships(final_relationships);
226
227        let total_time = start_time.elapsed().as_secs_f32();
228
229        #[cfg(feature = "tracing")]
230        tracing::info!(
231            "🎉 REAL LLM gleaning complete: {} entities, {} relationships ({:.1}s total)",
232            final_entities.len(),
233            deduplicated_relationships.len(),
234            total_time
235        );
236
237        Ok((final_entities, deduplicated_relationships))
238    }
239
240    /// Merge entity data using length-based strategy (LightRAG approach)
241    ///
242    /// When multiple rounds produce the same entity, keep the version with the longer description
243    /// as it likely contains more information.
244    #[cfg(feature = "async")]
245    fn merge_entity_data(
246        &self,
247        existing: Vec<EntityData>,
248        new: Vec<EntityData>,
249    ) -> Vec<EntityData> {
250        let mut merged: HashMap<String, EntityData> = HashMap::new();
251
252        // Add existing entities to map (normalized by lowercase name)
253        for entity in existing {
254            let key = entity.name.to_lowercase();
255            merged.insert(key, entity);
256        }
257
258        // Merge new entities - keep longer descriptions
259        for new_entity in new {
260            let key = new_entity.name.to_lowercase();
261
262            match merged.get(&key) {
263                Some(existing_entity) => {
264                    // Keep the entity with the longer description (more information)
265                    if new_entity.description.len() > existing_entity.description.len() {
266                        #[cfg(feature = "tracing")]
267                        tracing::debug!(
268                            "Merging entity '{}': keeping longer description ({} chars vs {} chars)",
269                            new_entity.name,
270                            new_entity.description.len(),
271                            existing_entity.description.len()
272                        );
273                        merged.insert(key, new_entity);
274                    } else {
275                        #[cfg(feature = "tracing")]
276                        tracing::debug!(
277                            "Entity '{}' already exists with longer description, keeping existing",
278                            new_entity.name
279                        );
280                    }
281                },
282                None => {
283                    // New entity, add it
284                    merged.insert(key, new_entity);
285                },
286            }
287        }
288
289        merged.into_values().collect()
290    }
291
292    /// Convert domain entities to EntityData
293    #[cfg(feature = "async")]
294    fn convert_entities_to_data(&self, entities: &[Entity]) -> Vec<EntityData> {
295        entities
296            .iter()
297            .map(|e| EntityData {
298                name: e.name.clone(),
299                entity_type: e.entity_type.clone(),
300                description: format!("{} (confidence: {:.2})", e.entity_type, e.confidence),
301            })
302            .collect()
303    }
304
305    /// Convert domain relationships to RelationshipData
306    #[cfg(feature = "async")]
307    fn convert_relationships_to_data(
308        &self,
309        relationships: &[Relationship],
310    ) -> Vec<RelationshipData> {
311        relationships
312            .iter()
313            .map(|r| RelationshipData {
314                source: r.source.0.clone(),
315                target: r.target.0.clone(),
316                description: r.relation_type.clone(),
317                strength: r.confidence as f64,
318            })
319            .collect()
320    }
321
322    /// Convert EntityData back to domain entities
323    #[cfg(feature = "async")]
324    fn convert_data_to_entities(
325        &self,
326        entity_data: &[EntityData],
327        chunk_id: &crate::core::ChunkId,
328        chunk_text: &str,
329    ) -> Result<Vec<Entity>> {
330        let mut entities = Vec::new();
331
332        for entity_item in entity_data {
333            // Generate entity ID
334            let entity_id = crate::core::EntityId::new(format!(
335                "{}_{}",
336                entity_item.entity_type,
337                self.normalize_name(&entity_item.name)
338            ));
339
340            // Find mentions in chunk
341            let mentions = self.find_mentions(&entity_item.name, chunk_id, chunk_text);
342
343            // Create entity with mentions
344            let entity = Entity::new(
345                entity_id,
346                entity_item.name.clone(),
347                entity_item.entity_type.clone(),
348                0.9, // High confidence since LLM-extracted
349            )
350            .with_mentions(mentions);
351
352            entities.push(entity);
353        }
354
355        Ok(entities)
356    }
357
358    /// Find all mentions of an entity name in the chunk text
359    #[cfg(feature = "async")]
360    fn find_mentions(
361        &self,
362        name: &str,
363        chunk_id: &crate::core::ChunkId,
364        text: &str,
365    ) -> Vec<crate::core::EntityMention> {
366        let mut mentions = Vec::new();
367        let mut start = 0;
368
369        while let Some(pos) = text[start..].find(name) {
370            let actual_pos = start + pos;
371            mentions.push(crate::core::EntityMention {
372                chunk_id: chunk_id.clone(),
373                start_offset: actual_pos,
374                end_offset: actual_pos + name.len(),
375                confidence: 0.9,
376            });
377            start = actual_pos + name.len();
378        }
379
380        // If no exact matches, try case-insensitive
381        if mentions.is_empty() {
382            let name_lower = name.to_lowercase();
383            let text_lower = text.to_lowercase();
384            let mut start = 0;
385
386            while let Some(pos) = text_lower[start..].find(&name_lower) {
387                let actual_pos = start + pos;
388                mentions.push(crate::core::EntityMention {
389                    chunk_id: chunk_id.clone(),
390                    start_offset: actual_pos,
391                    end_offset: actual_pos + name.len(),
392                    confidence: 0.85,
393                });
394                start = actual_pos + name.len();
395            }
396        }
397
398        mentions
399    }
400
401    /// Convert RelationshipData to domain Relationships
402    #[cfg(feature = "async")]
403    fn convert_data_to_relationships(
404        &self,
405        relationship_data: &[RelationshipData],
406        entities: &[Entity],
407    ) -> Result<Vec<Relationship>> {
408        let mut relationships = Vec::new();
409
410        // Build entity name to entity mapping
411        let mut name_to_entity: HashMap<String, &Entity> = HashMap::new();
412        for entity in entities {
413            name_to_entity.insert(entity.name.to_lowercase(), entity);
414        }
415
416        for rel_item in relationship_data {
417            // Find source and target entities
418            let source_entity = name_to_entity.get(&rel_item.source.to_lowercase());
419            let target_entity = name_to_entity.get(&rel_item.target.to_lowercase());
420
421            if let (Some(source), Some(target)) = (source_entity, target_entity) {
422                let relationship = Relationship {
423                    source: source.id.clone(),
424                    target: target.id.clone(),
425                    relation_type: rel_item.description.clone(),
426                    confidence: rel_item.strength as f32,
427                    context: vec![],
428                    embedding: None,
429                    temporal_type: None,
430                    temporal_range: None,
431                    causal_strength: None,
432                };
433
434                relationships.push(relationship);
435            } else {
436                #[cfg(feature = "tracing")]
437                tracing::warn!(
438                    "Skipping relationship: entity not found. Source: {}, Target: {}",
439                    rel_item.source,
440                    rel_item.target
441                );
442            }
443        }
444
445        Ok(relationships)
446    }
447
448    /// Deduplicate relationships by source-target-type combination
449    #[cfg(feature = "async")]
450    fn deduplicate_relationships(&self, relationships: Vec<Relationship>) -> Vec<Relationship> {
451        let mut seen = std::collections::HashSet::new();
452        let mut deduplicated = Vec::new();
453
454        for relationship in relationships {
455            let key = format!(
456                "{}->{}:{}",
457                relationship.source, relationship.target, relationship.relation_type
458            );
459
460            if !seen.contains(&key) {
461                seen.insert(key);
462                deduplicated.push(relationship);
463            }
464        }
465
466        deduplicated
467    }
468
469    /// Normalize entity name for ID generation
470    #[cfg(feature = "async")]
471    fn normalize_name(&self, name: &str) -> String {
472        name.to_lowercase()
473            .chars()
474            .filter(|c| c.is_alphanumeric() || *c == '_')
475            .collect::<String>()
476            .replace(' ', "_")
477    }
478
479    /// Get extraction statistics
480    pub fn get_statistics(&self) -> GleaningStatistics {
481        GleaningStatistics {
482            config: self.config.clone(),
483            llm_available: true, // Always true for real gleaning
484        }
485    }
486}
487
488/// Statistics for gleaning extraction process
489#[derive(Debug, Clone)]
490pub struct GleaningStatistics {
491    /// Gleaning configuration used
492    pub config: GleaningConfig,
493    /// Whether LLM is available for completion checking
494    pub llm_available: bool,
495}
496
497impl GleaningStatistics {
498    /// Print statistics to stdout
499    #[allow(dead_code)]
500    pub fn print(&self) {
501        #[cfg(feature = "tracing")]
502        tracing::info!("REAL LLM Gleaning Extraction Statistics");
503        #[cfg(feature = "tracing")]
504        tracing::info!("  Max rounds: {}", self.config.max_gleaning_rounds);
505        #[cfg(feature = "tracing")]
506        tracing::info!(
507            "  Completion threshold: {:.2}",
508            self.config.completion_threshold
509        );
510        #[cfg(feature = "tracing")]
511        tracing::info!(
512            "  Entity confidence threshold: {:.2}",
513            self.config.entity_confidence_threshold
514        );
515        #[cfg(feature = "tracing")]
516        tracing::info!(
517            "  Uses LLM completion check: {}",
518            self.config.use_llm_completion_check
519        );
520        #[cfg(feature = "tracing")]
521        tracing::info!("  LLM available: {}", self.llm_available);
522        #[cfg(feature = "tracing")]
523        tracing::info!("  Entity types: {:?}", self.config.entity_types);
524        #[cfg(feature = "tracing")]
525        tracing::info!("  Temperature: {}", self.config.temperature);
526    }
527}
528
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        core::{ChunkId, DocumentId, TextChunk},
        ollama::OllamaConfig,
    };

    /// Extractor built from default client and gleaning settings.
    fn make_extractor() -> GleaningEntityExtractor {
        let ollama_client = OllamaClient::new(OllamaConfig::default());
        GleaningEntityExtractor::new(ollama_client, GleaningConfig::default())
    }

    // Only the `async`-gated tests need a chunk.
    #[cfg(feature = "async")]
    fn create_test_chunk() -> TextChunk {
        TextChunk::new(
            ChunkId::new("test_chunk".to_string()),
            DocumentId::new("test_doc".to_string()),
            "Tom Sawyer is a young boy who lives in St. Petersburg with his Aunt Polly. Tom is best friends with Huckleberry Finn.".to_string(),
            0,
            120,
        )
    }

    #[test]
    fn test_gleaning_extractor_creation() {
        let extractor = make_extractor();

        let stats = extractor.get_statistics();
        assert_eq!(stats.config.max_gleaning_rounds, 4);
        assert!(stats.llm_available);
    }

    // The helpers exercised below only exist with the `async` feature, so the
    // tests must carry the same gate or the test build fails without it.
    #[cfg(feature = "async")]
    #[test]
    fn test_merge_entity_data() {
        let extractor = make_extractor();

        let existing = vec![EntityData {
            name: "Tom Sawyer".to_string(),
            entity_type: "PERSON".to_string(),
            description: "A boy".to_string(),
        }];

        let new = vec![
            EntityData {
                name: "Tom Sawyer".to_string(),
                entity_type: "PERSON".to_string(),
                description: "A young boy who lives in St. Petersburg".to_string(), // Longer description
            },
            EntityData {
                name: "Huck Finn".to_string(),
                entity_type: "PERSON".to_string(),
                description: "Tom's friend".to_string(),
            },
        ];

        let merged = extractor.merge_entity_data(existing, new);

        assert_eq!(merged.len(), 2); // Tom (merged) and Huck
        let tom = merged.iter().find(|e| e.name == "Tom Sawyer").unwrap();
        // The longer of the two descriptions must win.
        assert_eq!(tom.description, "A young boy who lives in St. Petersburg");
        assert!(merged.iter().any(|e| e.name == "Huck Finn"));
    }

    #[cfg(feature = "async")]
    #[test]
    fn test_find_mentions() {
        let extractor = make_extractor();

        let chunk = create_test_chunk();
        let mentions = extractor.find_mentions("Tom", &chunk.id, &chunk.content);

        // "Tom Sawyer" and "Tom is best friends"
        assert_eq!(mentions.len(), 2);
        for mention in &mentions {
            assert_eq!(
                &chunk.content[mention.start_offset..mention.end_offset],
                "Tom"
            );
        }
    }

    #[cfg(feature = "async")]
    #[test]
    fn test_deduplicate_relationships() {
        let extractor = make_extractor();

        let relationships = vec![
            Relationship::new(
                crate::core::EntityId::new("person_tom".to_string()),
                crate::core::EntityId::new("person_huck".to_string()),
                "FRIENDS_WITH".to_string(),
                0.9,
            ),
            Relationship::new(
                crate::core::EntityId::new("person_tom".to_string()),
                crate::core::EntityId::new("person_huck".to_string()),
                "FRIENDS_WITH".to_string(), // Duplicate
                0.85,
            ),
            Relationship::new(
                crate::core::EntityId::new("person_tom".to_string()),
                crate::core::EntityId::new("location_stpetersburg".to_string()),
                "LIVES_IN".to_string(),
                0.8,
            ),
        ];

        let deduplicated = extractor.deduplicate_relationships(relationships);

        assert_eq!(deduplicated.len(), 2); // Duplicate FRIENDS_WITH removed
        // First occurrence is kept and order is preserved.
        assert_eq!(deduplicated[0].relation_type, "FRIENDS_WITH");
        assert_eq!(deduplicated[1].relation_type, "LIVES_IN");
    }
}