#![cfg_attr(not(feature = "async"), allow(unused_imports))]
use crate::{
core::{ChunkId, Entity, EntityId, EntityMention, Relationship, TextChunk},
entity::prompts::{EntityData, ExtractionOutput, PromptBuilder, RelationshipData},
ollama::OllamaClient,
GraphRAGError, Result,
};
use serde_json;
/// Extracts entities and relationships from text chunks by prompting an Ollama-hosted LLM.
///
/// Construct with [`LLMEntityExtractor::new`] and tune with the `with_*` builder methods.
/// All fields are only read by the `async` extraction path, hence the `dead_code`
/// allowances when that feature is disabled.
pub struct LLMEntityExtractor {
    /// Client used for every generation request.
    #[cfg_attr(not(feature = "async"), allow(dead_code))]
    ollama_client: OllamaClient,
    /// Builds the extraction, continuation (gleaning) and completion-check prompts.
    #[cfg_attr(not(feature = "async"), allow(dead_code))]
    prompt_builder: PromptBuilder,
    /// Sampling temperature for extraction requests (completion checks always use `0.0`).
    #[cfg_attr(not(feature = "async"), allow(dead_code))]
    temperature: f32,
    /// Output-token budget (`num_predict`) for extraction requests; also sizes `num_ctx`.
    #[cfg_attr(not(feature = "async"), allow(dead_code))]
    max_tokens: usize,
    /// Optional Ollama `keep_alive` value forwarded with every request.
    #[cfg_attr(not(feature = "async"), allow(dead_code))]
    keep_alive: Option<String>,
}
impl LLMEntityExtractor {
/// Creates an extractor that asks the LLM for the given `entity_types`.
///
/// Defaults: temperature `0.0`, an output budget of 1500 tokens and no `keep_alive`.
pub fn new(ollama_client: OllamaClient, entity_types: Vec<String>) -> Self {
    let prompt_builder = PromptBuilder::new(entity_types);
    Self {
        keep_alive: None,
        max_tokens: 1500,
        temperature: 0.0,
        prompt_builder,
        ollama_client,
    }
}
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = temperature;
self
}
pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
self.max_tokens = max_tokens;
self
}
pub fn with_keep_alive(mut self, keep_alive: Option<String>) -> Self {
self.keep_alive = keep_alive;
self
}
/// Cheap token-count heuristic: one token per four bytes of UTF-8 (rounded down).
pub fn estimate_tokens(text: &str) -> u32 {
    const BYTES_PER_TOKEN: usize = 4;
    let approx_tokens = text.len() / BYTES_PER_TOKEN;
    approx_tokens as u32
}
/// Chooses an Ollama `num_ctx` big enough for `built_prompt` plus `max_output_tokens`.
///
/// The prompt estimate (see [`Self::estimate_tokens`]) plus the output budget gets a 20%
/// safety margin. The result is rounded up to a multiple of 1024 and clamped to
/// `4096..=131_072`.
pub fn calculate_entity_num_ctx(built_prompt: &str, max_output_tokens: u32) -> u32 {
    const MIN_NUM_CTX: u64 = 4096;
    const MAX_NUM_CTX: u64 = 131_072;
    const GRANULE: u64 = 1024;

    // Work in u64 so the sum, the margin and the round-up cannot overflow. In u32 a large
    // output budget panicked in debug builds and wrapped to a tiny context in release builds.
    let total = u64::from(Self::estimate_tokens(built_prompt)) + u64::from(max_output_tokens);
    // +20% margin, computed exactly (floor) instead of via an f32 multiply.
    let with_margin = total * 6 / 5;
    let rounded = (with_margin + GRANULE - 1) / GRANULE * GRANULE;
    // The clamp bounds the value by MAX_NUM_CTX, so narrowing to u32 is lossless.
    rounded.clamp(MIN_NUM_CTX, MAX_NUM_CTX) as u32
}
#[cfg(feature = "async")]
/// Runs the initial extraction pass over a single chunk.
///
/// Builds the extraction prompt and queries the LLM, retrying once on failure. The JSON reply
/// is parsed and converted into [`Entity`] values, with mentions located in the chunk text,
/// plus the [`Relationship`]s between them.
///
/// # Errors
/// Propagates LLM call failures and conversion errors. An unparseable reply is not an error;
/// it yields empty vectors.
pub async fn extract_from_chunk(
    &self,
    chunk: &TextChunk,
) -> Result<(Vec<Entity>, Vec<Relationship>)> {
    #[cfg(feature = "tracing")]
    tracing::debug!(
        "LLM extraction for chunk: {} (size: {} chars)",
        chunk.id,
        chunk.content.len()
    );

    let reply = {
        let prompt = self.prompt_builder.build_extraction_prompt(&chunk.content);
        self.call_llm_with_retry(&prompt).await?
    };
    let ExtractionOutput {
        entities: entity_data,
        relationships: relationship_data,
    } = self.parse_extraction_response(&reply)?;

    let entities = self.convert_to_entities(&entity_data, &chunk.id, &chunk.content)?;
    let relationships = self.convert_to_relationships(&relationship_data, &entities)?;

    #[cfg(feature = "tracing")]
    tracing::info!(
        "LLM extracted {} entities and {} relationships from chunk {}",
        entities.len(),
        relationships.len(),
        chunk.id
    );

    Ok((entities, relationships))
}
#[cfg(feature = "async")]
/// Runs one gleaning round: asks the LLM for entities and relationships missed so far.
///
/// `previous_entities` / `previous_relationships` are what earlier rounds found. They are
/// embedded in the continuation prompt. The return value holds only the entities produced by
/// this round, together with this round's relationships.
///
/// Relationship endpoints are resolved by name, case-insensitively, against this round's
/// entities and against `previous_entities`. A newly found relationship that links to an
/// entity from an earlier round is therefore kept instead of being discarded. Entity IDs are
/// derived deterministically from type and name, so such relationships point at the same IDs
/// earlier rounds produced (assuming `previous_entities` is the data those rounds were
/// built from).
///
/// NOTE(review): relationships the LLM merely repeats from earlier rounds are not filtered here.
///
/// # Errors
/// Propagates LLM call failures (after one retry) and entity/relationship conversion errors.
pub async fn extract_additional(
    &self,
    chunk: &TextChunk,
    previous_entities: &[EntityData],
    previous_relationships: &[RelationshipData],
) -> Result<(Vec<Entity>, Vec<Relationship>)> {
    #[cfg(feature = "tracing")]
    tracing::debug!("LLM gleaning round for chunk: {}", chunk.id);
    let prompt = self.prompt_builder.build_continuation_prompt(
        &chunk.content,
        previous_entities,
        previous_relationships,
    );
    let llm_response = self.call_llm_with_retry(&prompt).await?;
    let extraction_output = self.parse_extraction_response(&llm_response)?;

    // Earlier entities go first. On a name clash, this round's entity then wins the lookup
    // in `convert_to_relationships`, where later entries overwrite earlier ones.
    let mut known_entities =
        self.convert_to_entities(previous_entities, &chunk.id, &chunk.content)?;
    let previous_count = known_entities.len();
    known_entities.extend(self.convert_to_entities(
        &extraction_output.entities,
        &chunk.id,
        &chunk.content,
    )?);
    let relationships =
        self.convert_to_relationships(&extraction_output.relationships, &known_entities)?;
    // Return only this round's entities; the caller already holds the earlier ones (it
    // passed them in).
    let entities = known_entities.split_off(previous_count);

    #[cfg(feature = "tracing")]
    tracing::info!(
        "LLM gleaning extracted {} additional entities and {} relationships",
        entities.len(),
        relationships.len()
    );
    Ok((entities, relationships))
}
#[cfg(feature = "async")]
/// Asks the LLM whether `entities` and `relationships` already cover everything in `chunk`.
///
/// Returns `true` ("complete") only when the first standalone `YES`/`NO` word of the reply
/// is `YES` (see [`Self::is_affirmative_answer`]). Presumably the caller uses this to decide
/// whether another gleaning round is worthwhile.
///
/// # Errors
/// Propagates failures from the completion-check LLM call (there is no retry).
pub async fn check_completion(
    &self,
    chunk: &TextChunk,
    entities: &[EntityData],
    relationships: &[RelationshipData],
) -> Result<bool> {
    #[cfg(feature = "tracing")]
    tracing::debug!("LLM completion check for chunk: {}", chunk.id);
    let prompt =
        self.prompt_builder
            .build_completion_prompt(&chunk.content, entities, relationships);
    let llm_response = self.call_llm_completion_check(&prompt).await?;
    let is_complete = Self::is_affirmative_answer(&llm_response);
    #[cfg(feature = "tracing")]
    tracing::debug!(
        "LLM completion check result: {} (response: {})",
        if is_complete {
            "COMPLETE"
        } else {
            "INCOMPLETE"
        },
        llm_response.trim()
    );
    Ok(is_complete)
}

#[cfg(feature = "async")]
/// Interprets a free-form reply to a yes/no question, case-insensitively.
///
/// The reply is split into alphanumeric words. The first word that is exactly `YES` or `NO`
/// decides the answer, and a reply containing neither counts as "no". Matching whole words
/// avoids the substring pitfalls of `contains("YES")`, such as "NO … YES …", "EYES" or
/// "YESTERDAY" being read as yes.
fn is_affirmative_answer(response: &str) -> bool {
    let verdict = response
        .split(|c: char| !c.is_alphanumeric())
        .map(str::to_uppercase)
        .find(|word| word == "YES" || word == "NO");
    matches!(verdict.as_deref(), Some("YES"))
}
#[cfg(feature = "async")]
/// Sends an extraction prompt to Ollama and retries once, after a 2 s pause, if it fails.
///
/// The request uses the extractor's `temperature`, `max_tokens` (as `num_predict`) and
/// `keep_alive`. `num_ctx` is sized from the prompt via [`Self::calculate_entity_num_ctx`].
///
/// # Errors
/// If both attempts fail, the second attempt's error is returned. The first error is only
/// logged, and only with the `tracing` feature.
async fn call_llm_with_retry(&self, prompt: &str) -> Result<String> {
    use crate::ollama::OllamaGenerationParams;

    // NOTE(review): `as` truncates budgets above `u32::MAX`; presumably never configured
    // that high.
    let output_budget = self.max_tokens as u32;
    let num_ctx = Self::calculate_entity_num_ctx(prompt, output_budget);
    #[cfg(feature = "tracing")]
    tracing::debug!(
        "Entity extraction: prompt_len={} num_ctx={} keep_alive={:?}",
        prompt.len(),
        num_ctx,
        self.keep_alive,
    );
    let params = OllamaGenerationParams {
        num_predict: Some(output_budget),
        temperature: Some(self.temperature),
        num_ctx: Some(num_ctx),
        keep_alive: self.keep_alive.clone(),
        ..Default::default()
    };
    match self
        .ollama_client
        .generate_with_params(prompt, params.clone())
        .await
    {
        Ok(response) => Ok(response),
        // The underscore keeps non-`tracing` builds free of an unused-variable warning,
        // matching the `_e` convention in `parse_extraction_response`.
        Err(_err) => {
            #[cfg(feature = "tracing")]
            tracing::warn!("LLM call failed, retrying: {}", _err);
            tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
            self.ollama_client
                .generate_with_params(prompt, params)
                .await
        },
    }
}
#[cfg(feature = "async")]
/// Sends the short completion-check prompt to Ollama (single attempt, no retry).
///
/// The reply is capped at 50 tokens and sampled at temperature `0.0`. `num_ctx` is sized
/// from the prompt and that cap, and `keep_alive` follows the extractor's setting.
async fn call_llm_completion_check(&self, prompt: &str) -> Result<String> {
    use crate::ollama::OllamaGenerationParams;
    const ANSWER_TOKENS: u32 = 50;

    let params = OllamaGenerationParams {
        num_predict: Some(ANSWER_TOKENS),
        temperature: Some(0.0),
        num_ctx: Some(Self::calculate_entity_num_ctx(prompt, ANSWER_TOKENS)),
        keep_alive: self.keep_alive.clone(),
        ..Default::default()
    };
    self.ollama_client.generate_with_params(prompt, params).await
}
#[cfg(feature = "async")]
/// Parses a raw LLM reply into an [`ExtractionOutput`], trying increasingly lenient strategies.
///
/// Strategies run in this order, and the first success wins:
/// 1. the whole reply as JSON;
/// 2. the body of a Markdown code fence (`extract_json_from_markdown`);
/// 3. the whole reply after `jsonfixer` repair (a failure is logged under `tracing`);
/// 4. the span from the first `{` to the last `}`, first as-is and then repaired.
///
/// Never returns `Err`: when every strategy fails, the failure is logged (under `tracing`)
/// and an empty output is returned.
fn parse_extraction_response(&self, response: &str) -> Result<ExtractionOutput> {
    /// Strict `serde_json` parse, discarding the error.
    fn parse_strict(json: &str) -> Option<ExtractionOutput> {
        serde_json::from_str(json).ok()
    }

    if let Some(output) = parse_strict(response) {
        return Ok(output);
    }
    if let Some(output) = Self::extract_json_from_markdown(response).and_then(parse_strict) {
        return Ok(output);
    }
    match self.repair_and_parse_json(response) {
        Ok(output) => return Ok(output),
        #[cfg(feature = "tracing")]
        Err(err) => {
            tracing::warn!("JSON repair failed: {}", err);
        },
        #[cfg(not(feature = "tracing"))]
        Err(_) => {},
    }
    if let Some(candidate) = Self::find_json_in_text(response) {
        let salvaged = parse_strict(candidate)
            .or_else(|| self.repair_and_parse_json(candidate).ok());
        if let Some(output) = salvaged {
            return Ok(output);
        }
    }

    #[cfg(feature = "tracing")]
    tracing::error!(
        "Failed to parse LLM response as JSON. Response preview: {}",
        &response.chars().take(200).collect::<String>()
    );
    Ok(ExtractionOutput {
        entities: Vec::new(),
        relationships: Vec::new(),
    })
}
#[cfg(feature = "async")]
/// Pulls a JSON payload out of a Markdown code fence.
///
/// A fence opened with ```` ```json ```` is preferred, and its trimmed body is returned
/// as-is. Otherwise the first plain ```` ``` ```` fence is used, but only when its trimmed
/// body starts with `{` or `[`. Returns `None` if neither yields a closed fence.
fn extract_json_from_markdown(text: &str) -> Option<&str> {
    /// Trimmed text between the first `open` marker and the next closing fence.
    fn fenced_body<'a>(text: &'a str, open: &str) -> Option<&'a str> {
        let body_start = text.find(open)? + open.len();
        let rest = &text[body_start..];
        let body_len = rest.find("```")?;
        Some(rest[..body_len].trim())
    }

    if let Some(body) = fenced_body(text, "```json") {
        return Some(body);
    }
    fenced_body(text, "```").filter(|body| body.starts_with('{') || body.starts_with('['))
}
#[cfg(feature = "async")]
/// Returns the span from the first `{` to the last `}` (inclusive), if that span is non-empty.
///
/// This is a coarse heuristic for JSON embedded in prose. The span may still be invalid JSON.
fn find_json_in_text(text: &str) -> Option<&str> {
    let open = text.find('{')?;
    let close = text.rfind('}')?;
    // Lazily slice: when `close < open` the range would be invalid.
    (close > open).then(|| &text[open..=close])
}
#[cfg(feature = "async")]
/// Repairs malformed JSON with `jsonfixer` (default options), then deserialises it.
///
/// # Errors
/// Returns `GraphRAGError::Generation` if the repair itself fails, or if the repaired text
/// is still not a valid `ExtractionOutput`.
fn repair_and_parse_json(&self, json_str: &str) -> Result<ExtractionOutput> {
    let repaired = match jsonfixer::repair_json(json_str, jsonfixer::JsonRepairOptions::default())
    {
        Ok(text) => text,
        Err(e) => {
            return Err(GraphRAGError::Generation {
                message: format!("JSON repair failed: {:?}", e),
            })
        },
    };
    match serde_json::from_str::<ExtractionOutput>(&repaired) {
        Ok(output) => Ok(output),
        Err(e) => Err(GraphRAGError::Generation {
            message: format!("Failed to parse repaired JSON: {}", e),
        }),
    }
}
#[cfg(feature = "async")]
/// Converts LLM entity records into [`Entity`] values for one chunk.
///
/// Each ID is `"<entity_type>_<normalized name>"` (see `normalize_name`). The entity
/// confidence is fixed at `0.9`, and mentions are located in `chunk_text` via `find_mentions`.
/// Currently never returns `Err`.
fn convert_to_entities(
    &self,
    entity_data: &[EntityData],
    chunk_id: &ChunkId,
    chunk_text: &str,
) -> Result<Vec<Entity>> {
    let entities = entity_data
        .iter()
        .map(|item| {
            let id = EntityId::new(format!(
                "{}_{}",
                item.entity_type,
                self.normalize_name(&item.name)
            ));
            let mentions = self.find_mentions(&item.name, chunk_id, chunk_text);
            Entity::new(id, item.name.clone(), item.entity_type.clone(), 0.9)
                .with_mentions(mentions)
        })
        .collect();
    Ok(entities)
}
#[cfg(feature = "async")]
/// Locates occurrences of `name` in `text` and records them as [`EntityMention`]s of `chunk_id`.
///
/// Exact, case-sensitive matches are collected first, non-overlapping and left to right,
/// with confidence `0.9`. Only when there are none does a Unicode case-insensitive search run;
/// its hits get confidence `0.85`. Offsets are byte offsets into `text` and always fall on
/// UTF-8 character boundaries. An empty `name` yields no mentions.
fn find_mentions(&self, name: &str, chunk_id: &ChunkId, text: &str) -> Vec<EntityMention> {
    /// Byte length of the prefix of `haystack` that equals `needle_lower` after lowercasing,
    /// or `None` if there is no such prefix ending on a character boundary of `haystack`.
    fn ci_prefix_len(haystack: &str, needle_lower: &[char]) -> Option<usize> {
        let mut expected = needle_lower.iter();
        for (idx, ch) in haystack.char_indices() {
            for lower in ch.to_lowercase() {
                if expected.next() != Some(&lower) {
                    return None;
                }
            }
            if expected.as_slice().is_empty() {
                return Some(idx + ch.len_utf8());
            }
        }
        None
    }

    // `str::find("")` matches at every position without advancing, so an empty name would
    // loop forever below.
    if name.is_empty() {
        return Vec::new();
    }

    let mention = |start: usize, end: usize, confidence| EntityMention {
        chunk_id: chunk_id.clone(),
        start_offset: start,
        end_offset: end,
        confidence,
    };

    let mut mentions: Vec<EntityMention> = text
        .match_indices(name)
        .map(|(pos, matched)| mention(pos, pos + matched.len(), 0.9))
        .collect();
    if !mentions.is_empty() {
        return mentions;
    }

    // Case-insensitive pass over the ORIGINAL text. Lowercasing can change byte length
    // (e.g. U+0130 or the Kelvin sign U+212A), so searching a lowercased copy produced
    // offsets that did not point into `text` and could slice that copy mid-character.
    let needle_lower: Vec<char> = name.chars().flat_map(char::to_lowercase).collect();
    let mut pos = 0;
    while pos < text.len() {
        let rest = &text[pos..];
        if let Some(len) = ci_prefix_len(rest, &needle_lower) {
            mentions.push(mention(pos, pos + len, 0.85));
            pos += len;
        } else {
            // Step one character forward, staying on a UTF-8 boundary.
            pos += rest.chars().next().map_or(1, char::len_utf8);
        }
    }
    mentions
}
#[cfg(feature = "async")]
/// Turns LLM relationship records into [`Relationship`]s between the given `entities`.
///
/// Endpoints are matched by entity name, case-insensitively. When several entities share a
/// lowercased name, the last one in `entities` wins. Records whose source or target cannot be
/// resolved are skipped, with a warning under the `tracing` feature. A record's `description`
/// becomes the `relation_type` and its `strength` the confidence. Currently never returns
/// `Err`.
fn convert_to_relationships(
    &self,
    relationship_data: &[RelationshipData],
    entities: &[Entity],
) -> Result<Vec<Relationship>> {
    use std::collections::HashMap;

    let entity_by_name: HashMap<String, &Entity> = entities
        .iter()
        .map(|entity| (entity.name.to_lowercase(), entity))
        .collect();

    let mut relationships = Vec::new();
    for record in relationship_data {
        let endpoints = (
            entity_by_name.get(&record.source.to_lowercase()),
            entity_by_name.get(&record.target.to_lowercase()),
        );
        match endpoints {
            (Some(source), Some(target)) => relationships.push(Relationship {
                source: source.id.clone(),
                target: target.id.clone(),
                relation_type: record.description.clone(),
                confidence: record.strength as f32,
                context: vec![],
                embedding: None,
                temporal_type: None,
                temporal_range: None,
                causal_strength: None,
            }),
            _ => {
                #[cfg(feature = "tracing")]
                tracing::warn!(
                    "Skipping relationship: entity not found. Source: {}, Target: {}",
                    record.source,
                    record.target
                );
            },
        }
    }
    Ok(relationships)
}
#[cfg(feature = "async")]
/// Normalises an entity name into the ID fragment used by `convert_to_entities`.
///
/// The name is lowercased, each whitespace character becomes `_`, and anything that is
/// not alphanumeric or `_` is dropped, so `"Tom Sawyer"` becomes `"tom_sawyer"`.
///
/// NOTE(review): this changes IDs relative to the previous implementation (which produced
/// `"tomsawyer"`); re-check any persisted graphs keyed on the old IDs.
fn normalize_name(&self, name: &str) -> String {
    // Whitespace has to be mapped before the alphanumeric filter runs. Filtering first made
    // the old trailing `.replace(' ', "_")` a no-op and let "Tom Sawyer" and "Toms Awyer"
    // collide.
    name.to_lowercase()
        .chars()
        .filter_map(|c| {
            if c.is_whitespace() {
                Some('_')
            } else if c.is_alphanumeric() || c == '_' {
                Some(c)
            } else {
                None
            }
        })
        .collect()
}
}
// Every helper exercised below is compiled only with the `async` feature, so the tests must
// be gated on it too; otherwise `cargo test` without that feature fails to compile.
#[cfg(all(test, feature = "async"))]
mod tests {
    use super::*;
    use crate::{core::DocumentId, ollama::OllamaConfig};

    /// Builds an extractor over a default-configured Ollama client.
    /// None of the tests below send requests through it.
    fn create_test_extractor(entity_types: &[&str]) -> LLMEntityExtractor {
        let ollama_client = OllamaClient::new(OllamaConfig::default());
        LLMEntityExtractor::new(
            ollama_client,
            entity_types.iter().map(|t| t.to_string()).collect(),
        )
    }

    /// A small chunk mentioning "Tom" twice plus a few other characters and a place.
    fn create_test_chunk() -> TextChunk {
        TextChunk::new(
            ChunkId::new("chunk_001".to_string()),
            DocumentId::new("doc_001".to_string()),
            "Tom Sawyer is a young boy who lives in St. Petersburg with his Aunt Polly. \
             Tom is best friends with Huckleberry Finn. They often go on adventures together."
                .to_string(),
            0,
            150,
        )
    }

    #[test]
    fn test_extract_json_from_markdown() {
        let markdown = r#"
Here's the extraction:
```json
{
"entities": [],
"relationships": []
}
```
"#;
        let json = LLMEntityExtractor::extract_json_from_markdown(markdown);
        assert!(json.is_some());
        assert!(json.unwrap().contains("entities"));
    }

    #[test]
    fn test_find_json_in_text() {
        let text = "Some text before { \"entities\": [] } some text after";
        let json = LLMEntityExtractor::find_json_in_text(text);
        assert!(json.is_some());
        assert_eq!(json.unwrap(), "{ \"entities\": [] }");
    }

    #[test]
    fn test_parse_valid_json() {
        let extractor = create_test_extractor(&["PERSON", "LOCATION"]);
        let response = r#"
{
"entities": [
{
"name": "Tom Sawyer",
"type": "PERSON",
"description": "A young boy"
}
],
"relationships": []
}
"#;
        let result = extractor.parse_extraction_response(response);
        assert!(result.is_ok());
        let output = result.unwrap();
        assert_eq!(output.entities.len(), 1);
        assert_eq!(output.entities[0].name, "Tom Sawyer");
    }

    #[test]
    fn test_convert_to_entities() {
        let extractor = create_test_extractor(&["PERSON"]);
        let chunk = create_test_chunk();
        let entity_data = vec![EntityData {
            name: "Tom Sawyer".to_string(),
            entity_type: "PERSON".to_string(),
            description: "A young boy".to_string(),
        }];
        let entities = extractor
            .convert_to_entities(&entity_data, &chunk.id, &chunk.content)
            .unwrap();
        assert_eq!(entities.len(), 1);
        assert_eq!(entities[0].name, "Tom Sawyer");
        assert_eq!(entities[0].entity_type, "PERSON");
        assert!(!entities[0].mentions.is_empty());
    }

    #[test]
    fn test_find_mentions() {
        let extractor = create_test_extractor(&["PERSON"]);
        let chunk = create_test_chunk();
        let mentions = extractor.find_mentions("Tom", &chunk.id, &chunk.content);
        assert!(!mentions.is_empty());
        // "Tom" opens both sentences of the test chunk.
        assert!(mentions.len() >= 2);
    }
}