// graphrag_core/entity/llm_extractor.rs
//! LLM-based entity and relationship extraction
//!
//! This module provides TRUE LLM-based extraction using Ollama or any other LLM service.
//! Unlike pattern-based extraction, this uses actual language model inference to extract
//! entities and relationships from text with deep semantic understanding.
use crate::{
    core::{ChunkId, Entity, EntityId, EntityMention, Relationship, TextChunk},
    entity::prompts::{EntityData, ExtractionOutput, PromptBuilder, RelationshipData},
    ollama::OllamaClient,
    GraphRAGError, Result,
};
use serde_json;
14
/// LLM-based entity extractor that uses actual language model calls
pub struct LLMEntityExtractor {
    // Client used to issue generation requests to the Ollama backend.
    ollama_client: OllamaClient,
    // Builds the extraction, gleaning-continuation, and completion-check prompts.
    prompt_builder: PromptBuilder,
    // Sampling temperature for generation.
    // NOTE(review): never passed to `generate()` in this file — confirm the
    // client applies it elsewhere, otherwise the builder setter is a no-op.
    temperature: f32,
    // Token budget for generation.
    // NOTE(review): likewise unused by the calls visible here — verify.
    max_tokens: usize,
}
22
23impl LLMEntityExtractor {
    /// Create a new LLM-based entity extractor
    ///
    /// # Arguments
    /// * `ollama_client` - Ollama client for LLM inference
    /// * `entity_types` - List of entity types to extract (e.g., ["PERSON", "LOCATION", "ORGANIZATION"])
    ///
    /// Defaults: temperature 0.1, max_tokens 1500; override via the
    /// `with_temperature` / `with_max_tokens` builder methods.
    pub fn new(ollama_client: OllamaClient, entity_types: Vec<String>) -> Self {
        Self {
            ollama_client,
            prompt_builder: PromptBuilder::new(entity_types),
            temperature: 0.1, // Low temperature for consistent extraction
            max_tokens: 1500,
        }
    }
37
38    /// Set temperature for LLM generation (default: 0.1)
39    pub fn with_temperature(mut self, temperature: f32) -> Self {
40        self.temperature = temperature;
41        self
42    }
43
44    /// Set maximum tokens for LLM generation (default: 1500)
45    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
46        self.max_tokens = max_tokens;
47        self
48    }
49
50    /// Extract entities and relationships from a text chunk using LLM
51    ///
52    /// This is the REAL extraction that makes actual LLM API calls.
53    /// Expected time: 15-30 seconds per chunk depending on chunk size and model.
54    #[cfg(feature = "async")]
55    pub async fn extract_from_chunk(
56        &self,
57        chunk: &TextChunk,
58    ) -> Result<(Vec<Entity>, Vec<Relationship>)> {
59        tracing::debug!("LLM extraction for chunk: {} (size: {} chars)", chunk.id, chunk.content.len());
60
61        // Build extraction prompt
62        let prompt = self.prompt_builder.build_extraction_prompt(&chunk.content);
63
64        // Call LLM for extraction (THIS IS THE REAL LLM CALL!)
65        let llm_response = self.call_llm_with_retry(&prompt).await?;
66
67        // Parse response into structured data
68        let extraction_output = self.parse_extraction_response(&llm_response)?;
69
70        // Convert to domain entities and relationships
71        let entities = self.convert_to_entities(&extraction_output.entities, &chunk.id, &chunk.content)?;
72        let relationships = self.convert_to_relationships(&extraction_output.relationships, &entities)?;
73
74        tracing::info!(
75            "LLM extracted {} entities and {} relationships from chunk {}",
76            entities.len(),
77            relationships.len(),
78            chunk.id
79        );
80
81        Ok((entities, relationships))
82    }
83
84    /// Extract additional entities in a gleaning round (continuation)
85    ///
86    /// This is used after the initial extraction to catch missed entities.
87    #[cfg(feature = "async")]
88    pub async fn extract_additional(
89        &self,
90        chunk: &TextChunk,
91        previous_entities: &[EntityData],
92        previous_relationships: &[RelationshipData],
93    ) -> Result<(Vec<Entity>, Vec<Relationship>)> {
94        tracing::debug!("LLM gleaning round for chunk: {}", chunk.id);
95
96        // Build continuation prompt with previous extraction
97        let prompt = self.prompt_builder.build_continuation_prompt(
98            &chunk.content,
99            previous_entities,
100            previous_relationships,
101        );
102
103        // Call LLM for additional extraction
104        let llm_response = self.call_llm_with_retry(&prompt).await?;
105
106        // Parse response
107        let extraction_output = self.parse_extraction_response(&llm_response)?;
108
109        // Convert to domain entities
110        let entities = self.convert_to_entities(&extraction_output.entities, &chunk.id, &chunk.content)?;
111        let relationships = self.convert_to_relationships(&extraction_output.relationships, &entities)?;
112
113        tracing::info!(
114            "LLM gleaning extracted {} additional entities and {} relationships",
115            entities.len(),
116            relationships.len()
117        );
118
119        Ok((entities, relationships))
120    }
121
122    /// Check if extraction is complete using LLM judgment
123    ///
124    /// Uses the LLM to determine if all significant entities have been extracted.
125    #[cfg(feature = "async")]
126    pub async fn check_completion(
127        &self,
128        chunk: &TextChunk,
129        entities: &[EntityData],
130        relationships: &[RelationshipData],
131    ) -> Result<bool> {
132        tracing::debug!("LLM completion check for chunk: {}", chunk.id);
133
134        // Build completion check prompt
135        let prompt = self.prompt_builder.build_completion_prompt(
136            &chunk.content,
137            entities,
138            relationships,
139        );
140
141        // Call LLM with logit bias for YES/NO response
142        let llm_response = self.call_llm_completion_check(&prompt).await?;
143
144        // Parse YES/NO response
145        let response_trimmed = llm_response.trim().to_uppercase();
146        let is_complete = response_trimmed.starts_with("YES") || response_trimmed.contains("YES");
147
148        tracing::debug!(
149            "LLM completion check result: {} (response: {})",
150            if is_complete { "COMPLETE" } else { "INCOMPLETE" },
151            llm_response.trim()
152        );
153
154        Ok(is_complete)
155    }
156
157    /// Call LLM with retry logic for extraction
158    #[cfg(feature = "async")]
159    async fn call_llm_with_retry(&self, prompt: &str) -> Result<String> {
160        // Try to get structured JSON output if supported
161        // Otherwise fall back to regular generation
162        match self.ollama_client.generate(prompt).await {
163            Ok(response) => Ok(response),
164            Err(e) => {
165                tracing::warn!("LLM call failed, retrying: {}", e);
166                // Retry once
167                tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
168                self.ollama_client.generate(prompt).await
169            }
170        }
171    }
172
    /// Call LLM for completion check with short response
    ///
    /// Thin wrapper over a single `generate` call — unlike
    /// `call_llm_with_retry`, a failure here is not retried.
    #[cfg(feature = "async")]
    async fn call_llm_completion_check(&self, prompt: &str) -> Result<String> {
        // For completion check, we want a short YES/NO answer
        // In future, we can use logit bias to force YES/NO tokens
        self.ollama_client.generate(prompt).await
    }
180
181    /// Parse LLM response into structured extraction output
182    ///
183    /// Handles multiple JSON formats and attempts repair if needed
184    fn parse_extraction_response(&self, response: &str) -> Result<ExtractionOutput> {
185        // Strategy 1: Try direct JSON parsing
186        if let Ok(output) = serde_json::from_str::<ExtractionOutput>(response) {
187            return Ok(output);
188        }
189
190        // Strategy 2: Try to extract JSON from markdown code blocks
191        if let Some(json_str) = Self::extract_json_from_markdown(response) {
192            if let Ok(output) = serde_json::from_str::<ExtractionOutput>(json_str) {
193                return Ok(output);
194            }
195        }
196
197        // Strategy 3: Try JSON repair using jsonfixer
198        match self.repair_and_parse_json(response) {
199            Ok(output) => return Ok(output),
200            Err(e) => {
201                tracing::warn!("JSON repair failed: {}", e);
202            }
203        }
204
205        // Strategy 4: Look for JSON anywhere in the response
206        if let Some(json_str) = Self::find_json_in_text(response) {
207            if let Ok(output) = serde_json::from_str::<ExtractionOutput>(json_str) {
208                return Ok(output);
209            }
210
211            // Try repairing the extracted JSON
212            if let Ok(output) = self.repair_and_parse_json(json_str) {
213                return Ok(output);
214            }
215        }
216
217        // If all strategies fail, return empty extraction
218        tracing::error!("Failed to parse LLM response as JSON. Response preview: {}", &response.chars().take(200).collect::<String>());
219        Ok(ExtractionOutput {
220            entities: vec![],
221            relationships: vec![],
222        })
223    }
224
225    /// Extract JSON from markdown code blocks
226    fn extract_json_from_markdown(text: &str) -> Option<&str> {
227        // Look for ```json ... ``` or ``` ... ```
228        if let Some(start) = text.find("```json") {
229            let json_start = start + 7; // length of ```json
230            if let Some(end) = text[json_start..].find("```") {
231                return Some(&text[json_start..json_start + end].trim());
232            }
233        }
234
235        if let Some(start) = text.find("```") {
236            let json_start = start + 3;
237            if let Some(end) = text[json_start..].find("```") {
238                let candidate = &text[json_start..json_start + end].trim();
239                // Check if it looks like JSON
240                if candidate.starts_with('{') || candidate.starts_with('[') {
241                    return Some(candidate);
242                }
243            }
244        }
245
246        None
247    }
248
249    /// Find JSON object or array anywhere in text
250    fn find_json_in_text(text: &str) -> Option<&str> {
251        // Find first { and last }
252        if let Some(start) = text.find('{') {
253            if let Some(end) = text.rfind('}') {
254                if end > start {
255                    return Some(&text[start..=end]);
256                }
257            }
258        }
259        None
260    }
261
262    /// Attempt to repair malformed JSON using jsonfixer
263    fn repair_and_parse_json(&self, json_str: &str) -> Result<ExtractionOutput> {
264        // jsonfixer::repair_json returns Result<String, Error>
265        let options = jsonfixer::JsonRepairOptions::default();
266        let fixed_json = jsonfixer::repair_json(json_str, options)
267            .map_err(|e| GraphRAGError::Generation {
268                message: format!("JSON repair failed: {:?}", e),
269            })?;
270
271        serde_json::from_str::<ExtractionOutput>(&fixed_json)
272            .map_err(|e| GraphRAGError::Generation {
273                message: format!("Failed to parse repaired JSON: {}", e),
274            })
275    }
276
277    /// Convert EntityData to domain Entity objects
278    fn convert_to_entities(
279        &self,
280        entity_data: &[EntityData],
281        chunk_id: &ChunkId,
282        chunk_text: &str,
283    ) -> Result<Vec<Entity>> {
284        let mut entities = Vec::new();
285
286        for data in entity_data {
287            // Generate entity ID
288            let entity_id = EntityId::new(format!(
289                "{}_{}",
290                data.entity_type,
291                self.normalize_name(&data.name)
292            ));
293
294            // Find mentions in chunk
295            let mentions = self.find_mentions(&data.name, chunk_id, chunk_text);
296
297            // Create entity with mentions
298            // Note: Description is stored in the entity but not used in current Entity struct
299            // We store it in the entity name or as a separate field if needed
300            let entity = Entity::new(
301                entity_id,
302                data.name.clone(),
303                data.entity_type.clone(),
304                0.9, // High confidence since it's LLM-extracted
305            )
306            .with_mentions(mentions);
307
308            entities.push(entity);
309        }
310
311        Ok(entities)
312    }
313
314    /// Find all mentions of an entity name in the chunk text
315    fn find_mentions(&self, name: &str, chunk_id: &ChunkId, text: &str) -> Vec<EntityMention> {
316        let mut mentions = Vec::new();
317        let mut start = 0;
318
319        while let Some(pos) = text[start..].find(name) {
320            let actual_pos = start + pos;
321            mentions.push(EntityMention {
322                chunk_id: chunk_id.clone(),
323                start_offset: actual_pos,
324                end_offset: actual_pos + name.len(),
325                confidence: 0.9,
326            });
327            start = actual_pos + name.len();
328        }
329
330        // If no exact matches, try case-insensitive
331        if mentions.is_empty() {
332            let name_lower = name.to_lowercase();
333            let text_lower = text.to_lowercase();
334            let mut start = 0;
335
336            while let Some(pos) = text_lower[start..].find(&name_lower) {
337                let actual_pos = start + pos;
338                mentions.push(EntityMention {
339                    chunk_id: chunk_id.clone(),
340                    start_offset: actual_pos,
341                    end_offset: actual_pos + name.len(),
342                    confidence: 0.85, // Slightly lower confidence for case-insensitive match
343                });
344                start = actual_pos + name.len();
345            }
346        }
347
348        mentions
349    }
350
351    /// Convert RelationshipData to domain Relationship objects
352    fn convert_to_relationships(
353        &self,
354        relationship_data: &[RelationshipData],
355        entities: &[Entity],
356    ) -> Result<Vec<Relationship>> {
357        let mut relationships = Vec::new();
358
359        // Build entity name to ID mapping
360        let mut name_to_entity: std::collections::HashMap<String, &Entity> = std::collections::HashMap::new();
361        for entity in entities {
362            name_to_entity.insert(entity.name.to_lowercase(), entity);
363        }
364
365        for data in relationship_data {
366            // Find source and target entities
367            let source_entity = name_to_entity.get(&data.source.to_lowercase());
368            let target_entity = name_to_entity.get(&data.target.to_lowercase());
369
370            if let (Some(source), Some(target)) = (source_entity, target_entity) {
371                let relationship = Relationship {
372                    source: source.id.clone(),
373                    target: target.id.clone(),
374                    relation_type: data.description.clone(),
375                    confidence: data.strength as f32,
376                    context: vec![], // No context chunks for this relationship
377                };
378
379                relationships.push(relationship);
380            } else {
381                tracing::warn!(
382                    "Skipping relationship: entity not found. Source: {}, Target: {}",
383                    data.source,
384                    data.target
385                );
386            }
387        }
388
389        Ok(relationships)
390    }
391
392    /// Normalize entity name for ID generation
393    fn normalize_name(&self, name: &str) -> String {
394        name.to_lowercase()
395            .chars()
396            .filter(|c| c.is_alphanumeric() || *c == '_')
397            .collect::<String>()
398            .replace(' ', "_")
399    }
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{core::DocumentId, ollama::OllamaConfig};

    // Fixture chunk: "Tom" appears twice, plus "Huckleberry Finn" and places,
    // so mention-finding has something to chew on.
    fn create_test_chunk() -> TextChunk {
        TextChunk::new(
            ChunkId::new("chunk_001".to_string()),
            DocumentId::new("doc_001".to_string()),
            "Tom Sawyer is a young boy who lives in St. Petersburg with his Aunt Polly. \
             Tom is best friends with Huckleberry Finn. They often go on adventures together."
                .to_string(),
            0,
            150,
        )
    }

    // Fenced ```json blocks should be unwrapped to their inner JSON text.
    #[test]
    fn test_extract_json_from_markdown() {
        let markdown = r#"
Here's the extraction:
```json
{
  "entities": [],
  "relationships": []
}
```
"#;
        let json = LLMEntityExtractor::extract_json_from_markdown(markdown);
        assert!(json.is_some());
        assert!(json.unwrap().contains("entities"));
    }

    // A JSON object embedded in prose is located from first '{' to last '}'.
    #[test]
    fn test_find_json_in_text() {
        let text = "Some text before { \"entities\": [] } some text after";
        let json = LLMEntityExtractor::find_json_in_text(text);
        assert!(json.is_some());
        assert_eq!(json.unwrap(), "{ \"entities\": [] }");
    }

    // A clean JSON response parses directly (strategy 1) with no repair.
    #[test]
    fn test_parse_valid_json() {
        let ollama_config = OllamaConfig::default();
        let ollama_client = OllamaClient::new(ollama_config);
        let extractor = LLMEntityExtractor::new(
            ollama_client,
            vec!["PERSON".to_string(), "LOCATION".to_string()],
        );

        let response = r#"
{
  "entities": [
    {
      "name": "Tom Sawyer",
      "type": "PERSON",
      "description": "A young boy"
    }
  ],
  "relationships": []
}
"#;

        let result = extractor.parse_extraction_response(response);
        assert!(result.is_ok());
        let output = result.unwrap();
        assert_eq!(output.entities.len(), 1);
        assert_eq!(output.entities[0].name, "Tom Sawyer");
    }

    // EntityData maps onto an Entity whose mentions were found in the chunk.
    #[test]
    fn test_convert_to_entities() {
        let ollama_config = OllamaConfig::default();
        let ollama_client = OllamaClient::new(ollama_config);
        let extractor = LLMEntityExtractor::new(
            ollama_client,
            vec!["PERSON".to_string()],
        );

        let chunk = create_test_chunk();
        let entity_data = vec![EntityData {
            name: "Tom Sawyer".to_string(),
            entity_type: "PERSON".to_string(),
            description: "A young boy".to_string(),
        }];

        let entities = extractor
            .convert_to_entities(&entity_data, &chunk.id, &chunk.content)
            .unwrap();

        assert_eq!(entities.len(), 1);
        assert_eq!(entities[0].name, "Tom Sawyer");
        assert_eq!(entities[0].entity_type, "PERSON");
        assert!(!entities[0].mentions.is_empty());
    }

    // Every occurrence of the name should be reported, not just the first.
    #[test]
    fn test_find_mentions() {
        let ollama_config = OllamaConfig::default();
        let ollama_client = OllamaClient::new(ollama_config);
        let extractor = LLMEntityExtractor::new(ollama_client, vec!["PERSON".to_string()]);

        let chunk = create_test_chunk();
        let mentions = extractor.find_mentions("Tom", &chunk.id, &chunk.content);

        assert!(!mentions.is_empty());
        assert!(mentions.len() >= 2); // "Tom Sawyer" and "Tom is best friends"
    }

    // Documents the intended normalization: lowercase, spaces -> underscores,
    // all other punctuation dropped.
    #[test]
    fn test_normalize_name() {
        let ollama_config = OllamaConfig::default();
        let ollama_client = OllamaClient::new(ollama_config);
        let extractor = LLMEntityExtractor::new(ollama_client, vec!["PERSON".to_string()]);

        assert_eq!(extractor.normalize_name("Tom Sawyer"), "tom_sawyer");
        assert_eq!(extractor.normalize_name("New York City"), "new_york_city");
        assert_eq!(extractor.normalize_name("Dr. Smith"), "dr_smith");
    }
}