#![cfg_attr(not(feature = "async"), allow(unused_imports))]
use crate::{
core::{Entity, Relationship, Result, TextChunk},
entity::{
llm_extractor::LLMEntityExtractor,
prompts::{EntityData, RelationshipData},
},
ollama::OllamaClient,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Tunable settings for multi-round ("gleaning") entity extraction.
#[derive(Debug, Clone)]
pub struct GleaningConfig {
    /// Upper bound on extraction rounds; round 1 is the initial extraction and
    /// follow-up rounds are numbered `2..=max_gleaning_rounds`.
    pub max_gleaning_rounds: usize,
    /// NOTE(review): not read anywhere in this file — presumably a confidence
    /// cut-off for treating an extraction as complete; confirm against callers.
    pub completion_threshold: f64,
    /// NOTE(review): not read anywhere in this file — presumably the minimum
    /// confidence for keeping an entity; confirm against callers.
    pub entity_confidence_threshold: f64,
    /// When `true`, each follow-up round first asks the LLM whether the
    /// extraction is already complete and stops early if it says so.
    pub use_llm_completion_check: bool,
    /// Entity type labels handed to the underlying LLM extractor.
    pub entity_types: Vec<String>,
    /// Sampling temperature forwarded to the underlying LLM extractor.
    pub temperature: f32,
    /// Token limit forwarded to the underlying LLM extractor.
    pub max_tokens: usize,
}

impl Default for GleaningConfig {
    /// Four rounds at most, LLM completion check enabled, five generic entity
    /// types, temperature `0.0` and a 1500-token limit.
    fn default() -> Self {
        let entity_types: Vec<String> =
            ["PERSON", "ORGANIZATION", "LOCATION", "EVENT", "CONCEPT"]
                .iter()
                .map(|label| label.to_string())
                .collect();

        Self {
            max_gleaning_rounds: 4,
            completion_threshold: 0.85,
            entity_confidence_threshold: 0.7,
            use_llm_completion_check: true,
            entity_types,
            temperature: 0.0,
            max_tokens: 1500,
        }
    }
}
/// Serializable record describing whether an extraction is considered complete.
///
/// NOTE(review): this type is neither constructed nor read anywhere in this
/// file; the field descriptions below are inferred from the field names only —
/// confirm against the code that produces/consumes it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractionCompletionStatus {
/// Presumably `true` when nothing further needs to be extracted.
pub is_complete: bool,
/// Presumably the confidence attached to `is_complete` — range not established here.
pub confidence: f64,
/// Presumably free-text notes on what the extraction still lacks.
pub missing_aspects: Vec<String>,
/// Presumably free-text hints for a follow-up extraction round.
pub suggestions: Vec<String>,
}
/// Entity extractor that drives an [`LLMEntityExtractor`] through an initial
/// extraction plus optional follow-up ("gleaning") rounds.
pub struct GleaningEntityExtractor {
/// Underlying LLM-backed extractor. Within this file it is only read by
/// `extract_with_gleaning`, which exists only with the `async` feature —
/// hence the conditional `dead_code` allowance.
#[cfg_attr(not(feature = "async"), allow(dead_code))]
llm_extractor: LLMEntityExtractor,
/// Settings supplied at construction; also returned by `get_statistics`.
config: GleaningConfig,
}
impl GleaningEntityExtractor {
/// Builds a gleaning extractor on top of `ollama_client`.
///
/// The inner [`LLMEntityExtractor`] receives the entity types, temperature and
/// token limit from `config`, and the `keep_alive` value from the client's own
/// configuration.
pub fn new(ollama_client: OllamaClient, config: GleaningConfig) -> Self {
    // Must be read before the client is moved into the inner extractor.
    let keep_alive = ollama_client.config().keep_alive.clone();
    let entity_types = config.entity_types.clone();

    let base_extractor = LLMEntityExtractor::new(ollama_client, entity_types);
    let tuned_extractor = base_extractor
        .with_temperature(config.temperature)
        .with_max_tokens(config.max_tokens);

    Self {
        llm_extractor: tuned_extractor.with_keep_alive(keep_alive),
        config,
    }
}
/// Extracts entities and relationships from `chunk` in several LLM rounds.
///
/// Round 1 is a plain extraction. Each later round (up to
/// `config.max_gleaning_rounds`) optionally asks the LLM whether the current
/// result is complete (`config.use_llm_completion_check`) and, if not, asks
/// for additional entities/relationships given everything found so far.
/// Gleaning stops early when the LLM reports completion or a round yields
/// nothing new. Entities are merged by name between rounds; relationships are
/// accumulated and de-duplicated once at the end.
///
/// # Errors
///
/// Any error returned by the underlying extractor (`extract_from_chunk`,
/// `check_completion`, `extract_additional`) or by the final conversion steps
/// is propagated unchanged; results from earlier rounds are not returned in
/// that case.
#[cfg(feature = "async")]
// The loop counter `round` is only read by log statements, so it is unused
// when the `tracing` feature is off.
#[cfg_attr(not(feature = "tracing"), allow(unused_variables))]
pub async fn extract_with_gleaning(
    &self,
    chunk: &TextChunk,
) -> Result<(Vec<Entity>, Vec<Relationship>)> {
    #[cfg(feature = "tracing")]
    tracing::info!(
        "🔍 Starting REAL LLM gleaning extraction for chunk: {} ({} chars)",
        chunk.id,
        chunk.content.len()
    );
    // Timers only feed log output, so they are compiled in only with `tracing`.
    #[cfg(feature = "tracing")]
    let start_time = std::time::Instant::now();

    // ---- Round 1: initial extraction -------------------------------------
    #[cfg(feature = "tracing")]
    tracing::info!("📝 Round 1: Initial LLM extraction...");
    #[cfg(feature = "tracing")]
    let round_start = std::time::Instant::now();
    let (initial_entities, initial_relationships) =
        self.llm_extractor.extract_from_chunk(chunk).await?;
    #[cfg(feature = "tracing")]
    tracing::info!(
        "✅ Round 1 complete: {} entities, {} relationships ({:.1}s)",
        initial_entities.len(),
        initial_relationships.len(),
        round_start.elapsed().as_secs_f32()
    );

    // Running totals, kept in the prompt-friendly `*Data` form between rounds.
    let mut all_entity_data: Vec<EntityData> = self.convert_entities_to_data(&initial_entities);
    let mut all_relationship_data: Vec<RelationshipData> =
        self.convert_relationships_to_data(&initial_relationships);

    // ---- Rounds 2..=max: gleaning ----------------------------------------
    for round in 2..=self.config.max_gleaning_rounds {
        #[cfg(feature = "tracing")]
        tracing::info!("📝 Round {}: Gleaning continuation...", round);
        #[cfg(feature = "tracing")]
        let round_start = std::time::Instant::now();

        if self.config.use_llm_completion_check {
            let is_complete = self
                .llm_extractor
                .check_completion(chunk, &all_entity_data, &all_relationship_data)
                .await?;
            if is_complete {
                #[cfg(feature = "tracing")]
                tracing::info!(
                    "✅ LLM determined extraction is COMPLETE after {} rounds ({:.1}s total)",
                    round - 1,
                    start_time.elapsed().as_secs_f32()
                );
                break;
            }
            #[cfg(feature = "tracing")]
            tracing::debug!("⚠️ LLM determined extraction is INCOMPLETE, continuing...");
        }

        let (additional_entities, additional_relationships) = self
            .llm_extractor
            .extract_additional(chunk, &all_entity_data, &all_relationship_data)
            .await?;
        #[cfg(feature = "tracing")]
        tracing::info!(
            "✅ Round {} complete: {} new entities, {} new relationships ({:.1}s)",
            round,
            additional_entities.len(),
            additional_relationships.len(),
            round_start.elapsed().as_secs_f32()
        );

        // A round that adds nothing means further rounds are pointless.
        if additional_entities.is_empty() && additional_relationships.is_empty() {
            #[cfg(feature = "tracing")]
            tracing::info!(
                "🛑 No additional entities found in round {}, stopping gleaning",
                round
            );
            break;
        }

        let new_entity_data = self.convert_entities_to_data(&additional_entities);
        let new_relationship_data =
            self.convert_relationships_to_data(&additional_relationships);
        all_entity_data = self.merge_entity_data(all_entity_data, new_entity_data);
        all_relationship_data.extend(new_relationship_data);
    }

    // ---- Convert the accumulated data back into domain objects ------------
    let final_entities =
        self.convert_data_to_entities(&all_entity_data, &chunk.id, &chunk.content)?;
    let final_relationships =
        self.convert_data_to_relationships(&all_relationship_data, &final_entities)?;
    let deduplicated_relationships = self.deduplicate_relationships(final_relationships);

    #[cfg(feature = "tracing")]
    tracing::info!(
        "🎉 REAL LLM gleaning complete: {} entities, {} relationships ({:.1}s total)",
        final_entities.len(),
        deduplicated_relationships.len(),
        start_time.elapsed().as_secs_f32()
    );
    Ok((final_entities, deduplicated_relationships))
}
/// Merges `new` entity records into `existing`, keyed by lower-cased name.
///
/// * A `new` record whose name is not yet present is appended.
/// * A `new` record whose name is already present replaces the stored record
///   only when its description is strictly longer.
/// * Records inside `existing` that share a name collapse to the last one.
///
/// The result keeps first-seen order (all of `existing` first, then unseen
/// `new` records), so repeated runs over the same input produce the same
/// sequence instead of an arbitrary hash-map iteration order.
#[cfg(feature = "async")]
fn merge_entity_data(
    &self,
    existing: Vec<EntityData>,
    new: Vec<EntityData>,
) -> Vec<EntityData> {
    use std::collections::hash_map::Entry;

    let capacity = existing.len() + new.len();
    let mut merged: Vec<EntityData> = Vec::with_capacity(capacity);
    // Lower-cased name -> position of that entity inside `merged`.
    let mut slot_by_key: HashMap<String, usize> = HashMap::with_capacity(capacity);

    for entity in existing {
        match slot_by_key.entry(entity.name.to_lowercase()) {
            // Later duplicates overwrite earlier ones but keep their position.
            Entry::Occupied(slot) => merged[*slot.get()] = entity,
            Entry::Vacant(slot) => {
                slot.insert(merged.len());
                merged.push(entity);
            },
        }
    }

    for new_entity in new {
        match slot_by_key.entry(new_entity.name.to_lowercase()) {
            Entry::Occupied(slot) => {
                let existing_entity = &mut merged[*slot.get()];
                if new_entity.description.len() > existing_entity.description.len() {
                    #[cfg(feature = "tracing")]
                    tracing::debug!(
                        "Merging entity '{}': keeping longer description ({} chars vs {} chars)",
                        new_entity.name,
                        new_entity.description.len(),
                        existing_entity.description.len()
                    );
                    *existing_entity = new_entity;
                } else {
                    #[cfg(feature = "tracing")]
                    tracing::debug!(
                        "Entity '{}' already exists with longer description, keeping existing",
                        new_entity.name
                    );
                }
            },
            Entry::Vacant(slot) => {
                slot.insert(merged.len());
                merged.push(new_entity);
            },
        }
    }

    merged
}
/// Turns extracted entities into the `EntityData` records used between rounds.
///
/// The description is synthesised as `"<entity_type> (confidence: <c>)"` with
/// the confidence printed to two decimals.
#[cfg(feature = "async")]
fn convert_entities_to_data(&self, entities: &[Entity]) -> Vec<EntityData> {
    let mut records = Vec::with_capacity(entities.len());
    for entity in entities {
        let description = format!("{} (confidence: {:.2})", entity.entity_type, entity.confidence);
        records.push(EntityData {
            name: entity.name.clone(),
            entity_type: entity.entity_type.clone(),
            description,
        });
    }
    records
}
/// Turns extracted relationships into `RelationshipData` records.
///
/// `source`/`target` receive the inner string of the relationship's entity
/// IDs, the relation type becomes the description, and the confidence is
/// widened to `f64` as the strength.
#[cfg(feature = "async")]
fn convert_relationships_to_data(
    &self,
    relationships: &[Relationship],
) -> Vec<RelationshipData> {
    let mut records = Vec::with_capacity(relationships.len());
    for rel in relationships {
        let source = rel.source.0.clone();
        let target = rel.target.0.clone();
        records.push(RelationshipData {
            source,
            target,
            description: rel.relation_type.clone(),
            strength: f64::from(rel.confidence),
        });
    }
    records
}
/// Rebuilds `Entity` values from accumulated `EntityData` records.
///
/// Each entity gets the ID `"<entity_type>_<normalized name>"`, a fixed
/// confidence of `0.9`, and the mentions of its name found in `chunk_text`.
/// The function currently always returns `Ok`.
#[cfg(feature = "async")]
fn convert_data_to_entities(
    &self,
    entity_data: &[EntityData],
    chunk_id: &crate::core::ChunkId,
    chunk_text: &str,
) -> Result<Vec<Entity>> {
    let entities = entity_data
        .iter()
        .map(|record| {
            let raw_id = format!(
                "{}_{}",
                record.entity_type,
                self.normalize_name(&record.name)
            );
            let mentions = self.find_mentions(&record.name, chunk_id, chunk_text);
            Entity::new(
                crate::core::EntityId::new(raw_id),
                record.name.clone(),
                record.entity_type.clone(),
                0.9,
            )
            .with_mentions(mentions)
        })
        .collect();
    Ok(entities)
}
/// Locates non-overlapping occurrences of `name` inside `text`.
///
/// Exact (case-sensitive) matches are reported with confidence `0.9`. Only if
/// there are none, a case-insensitive search is run and its matches are
/// reported with confidence `0.85`. `start_offset`/`end_offset` are byte
/// offsets into `text` and always fall on character boundaries.
///
/// An empty `name` yields no mentions (an empty needle would otherwise match
/// forever without advancing).
#[cfg(feature = "async")]
fn find_mentions(
    &self,
    name: &str,
    chunk_id: &crate::core::ChunkId,
    text: &str,
) -> Vec<crate::core::EntityMention> {
    /// Treats final sigma like regular sigma so per-character lower-casing
    /// matches what whole-string lower-casing would have matched.
    fn fold_sigma(c: char) -> char {
        if c == 'ς' {
            'σ'
        } else {
            c
        }
    }

    if name.is_empty() {
        return Vec::new();
    }

    // Pass 1: exact matches (`match_indices` yields disjoint matches, left to right).
    let mut mentions: Vec<crate::core::EntityMention> = text
        .match_indices(name)
        .map(|(start, matched)| crate::core::EntityMention {
            chunk_id: chunk_id.clone(),
            start_offset: start,
            end_offset: start + matched.len(),
            confidence: 0.9,
        })
        .collect();
    if !mentions.is_empty() {
        return mentions;
    }

    // Pass 2: case-insensitive. Lower-casing can change a character's byte
    // length, so offsets found in a lower-cased copy are not valid for `text`.
    // Instead, every lower-cased character is tagged with the byte range of
    // the original character it came from.
    let needle: Vec<char> = name
        .chars()
        .flat_map(char::to_lowercase)
        .map(fold_sigma)
        .collect();
    let haystack: Vec<(char, usize, usize)> = text
        .char_indices()
        .flat_map(|(start, ch)| {
            let end = start + ch.len_utf8();
            ch.to_lowercase()
                .map(move |lower| (fold_sigma(lower), start, end))
        })
        .collect();

    let mut cursor = 0;
    while cursor + needle.len() <= haystack.len() {
        let window = &haystack[cursor..cursor + needle.len()];
        let is_match = window
            .iter()
            .map(|&(ch, _, _)| ch)
            .eq(needle.iter().copied());
        if is_match {
            mentions.push(crate::core::EntityMention {
                chunk_id: chunk_id.clone(),
                start_offset: window[0].1,
                end_offset: window[window.len() - 1].2,
                confidence: 0.85,
            });
            // Skip past the match so mentions never overlap.
            cursor += needle.len();
        } else {
            cursor += 1;
        }
    }
    mentions
}
/// Rebuilds `Relationship` values from accumulated `RelationshipData`.
///
/// Each record's `source`/`target` is resolved (case-insensitively) against
/// `entities`, first by entity name and then by entity ID. The ID fallback is
/// needed because records produced by `convert_relationships_to_data` carry
/// the entity-ID string rather than a display name. Records whose endpoints
/// cannot both be resolved are skipped (with a warning when `tracing` is on).
/// The function currently always returns `Ok`.
///
/// NOTE(review): the ID fallback only helps when the IDs carried by incoming
/// relationships equal the IDs of `entities` — confirm the upstream extractor
/// builds IDs the same way.
#[cfg(feature = "async")]
fn convert_data_to_relationships(
    &self,
    relationship_data: &[RelationshipData],
    entities: &[Entity],
) -> Result<Vec<Relationship>> {
    let mut relationships = Vec::new();

    let mut name_to_entity: HashMap<String, &Entity> =
        HashMap::with_capacity(entities.len() * 2);
    for entity in entities {
        name_to_entity.insert(entity.name.to_lowercase(), entity);
    }
    // IDs are added afterwards and never displace a name entry.
    for entity in entities {
        name_to_entity
            .entry(entity.id.0.to_lowercase())
            .or_insert(entity);
    }

    for rel_item in relationship_data {
        let source_entity = name_to_entity.get(&rel_item.source.to_lowercase());
        let target_entity = name_to_entity.get(&rel_item.target.to_lowercase());
        if let (Some(source), Some(target)) = (source_entity, target_entity) {
            let relationship = Relationship {
                source: source.id.clone(),
                target: target.id.clone(),
                relation_type: rel_item.description.clone(),
                confidence: rel_item.strength as f32,
                context: vec![],
                embedding: None,
                temporal_type: None,
                temporal_range: None,
                causal_strength: None,
            };
            relationships.push(relationship);
        } else {
            #[cfg(feature = "tracing")]
            tracing::warn!(
                "Skipping relationship: entity not found. Source: {}, Target: {}",
                rel_item.source,
                rel_item.target
            );
        }
    }
    Ok(relationships)
}
/// Drops repeated relationships, keeping the first occurrence of each
/// `(source, target, relation_type)` combination and preserving input order.
#[cfg(feature = "async")]
fn deduplicate_relationships(&self, relationships: Vec<Relationship>) -> Vec<Relationship> {
    let mut seen_keys = std::collections::HashSet::new();
    relationships
        .into_iter()
        .filter(|rel| {
            // `insert` returns `false` when the key was already present.
            seen_keys.insert(format!(
                "{}->{}:{}",
                rel.source, rel.target, rel.relation_type
            ))
        })
        .collect()
}
/// Lower-cases `name` and keeps only alphanumeric characters and `_`.
///
/// Everything else — including whitespace and punctuation — is removed, so
/// `"Tom Sawyer"` becomes `"tomsawyer"`.
///
/// NOTE(review): spaces are dropped rather than turned into underscores; the
/// result feeds entity IDs, so confirm the intended format before changing it.
#[cfg(feature = "async")]
fn normalize_name(&self, name: &str) -> String {
    let mut normalized = String::with_capacity(name.len());
    for ch in name.to_lowercase().chars() {
        if ch == '_' || ch.is_alphanumeric() {
            normalized.push(ch);
        }
    }
    normalized
}
/// Returns a snapshot of this extractor's configuration.
///
/// `llm_available` is always reported as `true`; no availability probe is
/// performed here.
pub fn get_statistics(&self) -> GleaningStatistics {
    let config = self.config.clone();
    GleaningStatistics {
        config,
        llm_available: true,
    }
}
}
/// Snapshot of a gleaning extractor's settings, as returned by
/// `GleaningEntityExtractor::get_statistics`.
#[derive(Debug, Clone)]
pub struct GleaningStatistics {
/// Copy of the extractor's configuration at the time of the snapshot.
pub config: GleaningConfig,
/// Whether an LLM is considered available; `get_statistics` always sets this to `true`.
pub llm_available: bool,
}
impl GleaningStatistics {
    /// Logs the statistics at `info` level through `tracing`.
    ///
    /// Does nothing when the crate is built without the `tracing` feature.
    #[allow(dead_code)]
    pub fn print(&self) {
        #[cfg(feature = "tracing")]
        {
            let cfg = &self.config;
            tracing::info!("REAL LLM Gleaning Extraction Statistics");
            tracing::info!(" Max rounds: {}", cfg.max_gleaning_rounds);
            tracing::info!(
                " Completion threshold: {:.2}",
                cfg.completion_threshold
            );
            tracing::info!(
                " Entity confidence threshold: {:.2}",
                cfg.entity_confidence_threshold
            );
            tracing::info!(
                " Uses LLM completion check: {}",
                cfg.use_llm_completion_check
            );
            tracing::info!(" LLM available: {}", self.llm_available);
            tracing::info!(" Entity types: {:?}", cfg.entity_types);
            tracing::info!(" Temperature: {}", cfg.temperature);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        core::{ChunkId, DocumentId, TextChunk},
        ollama::OllamaConfig,
    };

    /// Short passage mentioning "Tom" twice; used by the mention-search test.
    // Only the `async`-gated tests need it.
    #[cfg(feature = "async")]
    fn create_test_chunk() -> TextChunk {
        TextChunk::new(
            ChunkId::new("test_chunk".to_string()),
            DocumentId::new("test_doc".to_string()),
            "Tom Sawyer is a young boy who lives in St. Petersburg with his Aunt Polly. Tom is best friends with Huckleberry Finn.".to_string(),
            0,
            120,
        )
    }

    /// Extractor built from default client and gleaning settings.
    fn create_test_extractor() -> GleaningEntityExtractor {
        let ollama_client = OllamaClient::new(OllamaConfig::default());
        GleaningEntityExtractor::new(ollama_client, GleaningConfig::default())
    }

    #[test]
    fn test_gleaning_extractor_creation() {
        let extractor = create_test_extractor();
        let stats = extractor.get_statistics();
        assert_eq!(stats.config.max_gleaning_rounds, 4);
        assert!(stats.llm_available);
    }

    // The helpers exercised below only exist with the `async` feature, so the
    // tests are gated the same way to keep non-async builds compiling.
    #[cfg(feature = "async")]
    #[test]
    fn test_merge_entity_data() {
        let extractor = create_test_extractor();
        let existing = vec![EntityData {
            name: "Tom Sawyer".to_string(),
            entity_type: "PERSON".to_string(),
            description: "A boy".to_string(),
        }];
        let new = vec![
            EntityData {
                name: "Tom Sawyer".to_string(),
                entity_type: "PERSON".to_string(),
                description: "A young boy who lives in St. Petersburg".to_string(),
            },
            EntityData {
                name: "Huck Finn".to_string(),
                entity_type: "PERSON".to_string(),
                description: "Tom's friend".to_string(),
            },
        ];
        let merged = extractor.merge_entity_data(existing, new);
        // Tom is merged into one record; Huck is added.
        assert_eq!(merged.len(), 2);
        let tom = merged.iter().find(|e| e.name == "Tom Sawyer").unwrap();
        // The longer description wins.
        assert_eq!(tom.description, "A young boy who lives in St. Petersburg");
        assert!(merged.iter().any(|e| e.name == "Huck Finn"));
    }

    #[cfg(feature = "async")]
    #[test]
    fn test_find_mentions() {
        let extractor = create_test_extractor();
        let chunk = create_test_chunk();
        let mentions = extractor.find_mentions("Tom", &chunk.id, &chunk.content);
        // "Tom Sawyer ..." and "Tom is best friends ...".
        assert_eq!(mentions.len(), 2);
        for mention in &mentions {
            assert_eq!(
                &chunk.content[mention.start_offset..mention.end_offset],
                "Tom"
            );
        }
    }

    #[cfg(feature = "async")]
    #[test]
    fn test_deduplicate_relationships() {
        let extractor = create_test_extractor();
        let relationships = vec![
            Relationship::new(
                crate::core::EntityId::new("person_tom".to_string()),
                crate::core::EntityId::new("person_huck".to_string()),
                "FRIENDS_WITH".to_string(),
                0.9,
            ),
            // Same endpoints and type as above; only the confidence differs.
            Relationship::new(
                crate::core::EntityId::new("person_tom".to_string()),
                crate::core::EntityId::new("person_huck".to_string()),
                "FRIENDS_WITH".to_string(),
                0.85,
            ),
            Relationship::new(
                crate::core::EntityId::new("person_tom".to_string()),
                crate::core::EntityId::new("location_stpetersburg".to_string()),
                "LIVES_IN".to_string(),
                0.8,
            ),
        ];
        let deduplicated = extractor.deduplicate_relationships(relationships);
        assert_eq!(deduplicated.len(), 2);
    }
}