// graphrag_core/entity/llm_extractor.rs
//! LLM-based entity and relationship extraction
//!
//! This module provides TRUE LLM-based extraction using Ollama or any other LLM service.
//! Unlike pattern-based extraction, this uses actual language model inference to extract
//! entities and relationships from text with deep semantic understanding.
use crate::{
    core::{ChunkId, Entity, EntityId, EntityMention, Relationship, TextChunk},
    entity::prompts::{EntityData, ExtractionOutput, PromptBuilder, RelationshipData},
    ollama::OllamaClient,
    GraphRAGError, Result,
};
use serde_json;
15/// LLM-based entity extractor that uses actual language model calls
16pub struct LLMEntityExtractor {
17    ollama_client: OllamaClient,
18    prompt_builder: PromptBuilder,
19    temperature: f32,
20    max_tokens: usize,
21    /// keep_alive value forwarded to every Ollama request (e.g. "1h").
22    /// When set, Ollama keeps the model in VRAM between requests so the KV
23    /// cache for the static prompt prefix is preserved across all chunks.
24    keep_alive: Option<String>,
25}
26
27impl LLMEntityExtractor {
28    /// Create a new LLM-based entity extractor
29    ///
30    /// # Arguments
31    /// * `ollama_client` - Ollama client for LLM inference
32    /// * `entity_types` - List of entity types to extract (e.g., ["PERSON", "LOCATION", "ORGANIZATION"])
33    pub fn new(ollama_client: OllamaClient, entity_types: Vec<String>) -> Self {
34        Self {
35            ollama_client,
36            prompt_builder: PromptBuilder::new(entity_types),
37            temperature: 0.0, // Zero temperature for deterministic JSON extraction
38            max_tokens: 1500,
39            keep_alive: None,
40        }
41    }
42
43    /// Set temperature for LLM generation (default: 0.1)
44    pub fn with_temperature(mut self, temperature: f32) -> Self {
45        self.temperature = temperature;
46        self
47    }
48
49    /// Set maximum tokens for LLM generation (default: 1500)
50    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
51        self.max_tokens = max_tokens;
52        self
53    }
54
55    /// Set keep_alive for KV cache persistence (e.g. "1h", "30m").
56    pub fn with_keep_alive(mut self, keep_alive: Option<String>) -> Self {
57        self.keep_alive = keep_alive;
58        self
59    }
60
61    /// Estimate token count from text length (≈ 1 token per 4 chars).
62    pub fn estimate_tokens(text: &str) -> u32 {
63        (text.len() / 4) as u32
64    }
65
66    /// Calculate the required `num_ctx` for one entity-extraction LLM call.
67    ///
68    /// Formula (mirrors `ContextualEnricher::calculate_num_ctx`):
69    /// ```text
70    /// num_ctx = tokens(built_prompt)   // instructions + entity types + chunk text
71    ///         + max_output_tokens       // JSON with entities + relationships
72    ///         + 20% safety margin       // avoids silent truncation
73    /// ```
74    /// Result is rounded up to the nearest 1024 and clamped to [4096, 131072].
75    pub fn calculate_entity_num_ctx(built_prompt: &str, max_output_tokens: u32) -> u32 {
76        let prompt_tokens = Self::estimate_tokens(built_prompt);
77        let total = prompt_tokens + max_output_tokens;
78        let with_margin = (total as f32 * 1.20) as u32;
79        let rounded = ((with_margin + 1023) / 1024) * 1024;
80        rounded.max(4096).min(131_072)
81    }
82
83    /// Extract entities and relationships from a text chunk using LLM
84    ///
85    /// This is the REAL extraction that makes actual LLM API calls.
86    /// Expected time: 15-30 seconds per chunk depending on chunk size and model.
87    #[cfg(feature = "async")]
88    pub async fn extract_from_chunk(
89        &self,
90        chunk: &TextChunk,
91    ) -> Result<(Vec<Entity>, Vec<Relationship>)> {
92        tracing::debug!(
93            "LLM extraction for chunk: {} (size: {} chars)",
94            chunk.id,
95            chunk.content.len()
96        );
97
98        // Build extraction prompt
99        let prompt = self.prompt_builder.build_extraction_prompt(&chunk.content);
100
101        // Call LLM for extraction (THIS IS THE REAL LLM CALL!)
102        let llm_response = self.call_llm_with_retry(&prompt).await?;
103
104        // Parse response into structured data
105        let extraction_output = self.parse_extraction_response(&llm_response)?;
106
107        // Convert to domain entities and relationships
108        let entities =
109            self.convert_to_entities(&extraction_output.entities, &chunk.id, &chunk.content)?;
110        let relationships =
111            self.convert_to_relationships(&extraction_output.relationships, &entities)?;
112
113        tracing::info!(
114            "LLM extracted {} entities and {} relationships from chunk {}",
115            entities.len(),
116            relationships.len(),
117            chunk.id
118        );
119
120        Ok((entities, relationships))
121    }
122
123    /// Extract additional entities in a gleaning round (continuation)
124    ///
125    /// This is used after the initial extraction to catch missed entities.
126    #[cfg(feature = "async")]
127    pub async fn extract_additional(
128        &self,
129        chunk: &TextChunk,
130        previous_entities: &[EntityData],
131        previous_relationships: &[RelationshipData],
132    ) -> Result<(Vec<Entity>, Vec<Relationship>)> {
133        tracing::debug!("LLM gleaning round for chunk: {}", chunk.id);
134
135        // Build continuation prompt with previous extraction
136        let prompt = self.prompt_builder.build_continuation_prompt(
137            &chunk.content,
138            previous_entities,
139            previous_relationships,
140        );
141
142        // Call LLM for additional extraction
143        let llm_response = self.call_llm_with_retry(&prompt).await?;
144
145        // Parse response
146        let extraction_output = self.parse_extraction_response(&llm_response)?;
147
148        // Convert to domain entities
149        let entities =
150            self.convert_to_entities(&extraction_output.entities, &chunk.id, &chunk.content)?;
151        let relationships =
152            self.convert_to_relationships(&extraction_output.relationships, &entities)?;
153
154        tracing::info!(
155            "LLM gleaning extracted {} additional entities and {} relationships",
156            entities.len(),
157            relationships.len()
158        );
159
160        Ok((entities, relationships))
161    }
162
163    /// Check if extraction is complete using LLM judgment
164    ///
165    /// Uses the LLM to determine if all significant entities have been extracted.
166    #[cfg(feature = "async")]
167    pub async fn check_completion(
168        &self,
169        chunk: &TextChunk,
170        entities: &[EntityData],
171        relationships: &[RelationshipData],
172    ) -> Result<bool> {
173        tracing::debug!("LLM completion check for chunk: {}", chunk.id);
174
175        // Build completion check prompt
176        let prompt =
177            self.prompt_builder
178                .build_completion_prompt(&chunk.content, entities, relationships);
179
180        // Call LLM with logit bias for YES/NO response
181        let llm_response = self.call_llm_completion_check(&prompt).await?;
182
183        // Parse YES/NO response
184        let response_trimmed = llm_response.trim().to_uppercase();
185        let is_complete = response_trimmed.starts_with("YES") || response_trimmed.contains("YES");
186
187        tracing::debug!(
188            "LLM completion check result: {} (response: {})",
189            if is_complete {
190                "COMPLETE"
191            } else {
192                "INCOMPLETE"
193            },
194            llm_response.trim()
195        );
196
197        Ok(is_complete)
198    }
199
200    /// Call LLM with dynamic num_ctx and retry logic for extraction.
201    ///
202    /// `num_ctx` is calculated from the built prompt via
203    /// [`calculate_entity_num_ctx`] before this call, ensuring Ollama never
204    /// silently truncates the prompt. `keep_alive` keeps the model loaded in
205    /// VRAM between chunk requests so the KV cache is preserved.
206    #[cfg(feature = "async")]
207    async fn call_llm_with_retry(&self, prompt: &str) -> Result<String> {
208        use crate::ollama::OllamaGenerationParams;
209        let num_ctx = Self::calculate_entity_num_ctx(prompt, self.max_tokens as u32);
210        tracing::debug!(
211            "Entity extraction: prompt_len={} num_ctx={} keep_alive={:?}",
212            prompt.len(),
213            num_ctx,
214            self.keep_alive,
215        );
216        let params = OllamaGenerationParams {
217            num_predict: Some(self.max_tokens as u32),
218            temperature: Some(self.temperature),
219            num_ctx: Some(num_ctx),
220            keep_alive: self.keep_alive.clone(),
221            ..Default::default()
222        };
223        match self.ollama_client.generate_with_params(prompt, params.clone()).await {
224            Ok(response) => Ok(response),
225            Err(e) => {
226                tracing::warn!("LLM call failed, retrying: {}", e);
227                tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
228                self.ollama_client.generate_with_params(prompt, params).await
229            },
230        }
231    }
232
233    /// Call LLM for completion check (expects a short YES/NO answer).
234    ///
235    /// Uses a much smaller `num_predict` (50 tokens) and a deterministic
236    /// temperature of 0.0 to get a reliable YES/NO response.
237    #[cfg(feature = "async")]
238    async fn call_llm_completion_check(&self, prompt: &str) -> Result<String> {
239        use crate::ollama::OllamaGenerationParams;
240        let num_ctx = Self::calculate_entity_num_ctx(prompt, 50);
241        let params = OllamaGenerationParams {
242            num_predict: Some(50),
243            temperature: Some(0.0),
244            num_ctx: Some(num_ctx),
245            keep_alive: self.keep_alive.clone(),
246            ..Default::default()
247        };
248        self.ollama_client.generate_with_params(prompt, params).await
249    }
250
251    /// Parse LLM response into structured extraction output
252    ///
253    /// Handles multiple JSON formats and attempts repair if needed
254    fn parse_extraction_response(&self, response: &str) -> Result<ExtractionOutput> {
255        // Strategy 1: Try direct JSON parsing
256        if let Ok(output) = serde_json::from_str::<ExtractionOutput>(response) {
257            return Ok(output);
258        }
259
260        // Strategy 2: Try to extract JSON from markdown code blocks
261        if let Some(json_str) = Self::extract_json_from_markdown(response) {
262            if let Ok(output) = serde_json::from_str::<ExtractionOutput>(json_str) {
263                return Ok(output);
264            }
265        }
266
267        // Strategy 3: Try JSON repair using jsonfixer
268        match self.repair_and_parse_json(response) {
269            Ok(output) => return Ok(output),
270            Err(e) => {
271                tracing::warn!("JSON repair failed: {}", e);
272            },
273        }
274
275        // Strategy 4: Look for JSON anywhere in the response
276        if let Some(json_str) = Self::find_json_in_text(response) {
277            if let Ok(output) = serde_json::from_str::<ExtractionOutput>(json_str) {
278                return Ok(output);
279            }
280
281            // Try repairing the extracted JSON
282            if let Ok(output) = self.repair_and_parse_json(json_str) {
283                return Ok(output);
284            }
285        }
286
287        // If all strategies fail, return empty extraction
288        tracing::error!(
289            "Failed to parse LLM response as JSON. Response preview: {}",
290            &response.chars().take(200).collect::<String>()
291        );
292        Ok(ExtractionOutput {
293            entities: vec![],
294            relationships: vec![],
295        })
296    }
297
298    /// Extract JSON from markdown code blocks
299    fn extract_json_from_markdown(text: &str) -> Option<&str> {
300        // Look for ```json ... ``` or ``` ... ```
301        if let Some(start) = text.find("```json") {
302            let json_start = start + 7; // length of ```json
303            if let Some(end) = text[json_start..].find("```") {
304                return Some(text[json_start..json_start + end].trim());
305            }
306        }
307
308        if let Some(start) = text.find("```") {
309            let json_start = start + 3;
310            if let Some(end) = text[json_start..].find("```") {
311                let candidate = &text[json_start..json_start + end].trim();
312                // Check if it looks like JSON
313                if candidate.starts_with('{') || candidate.starts_with('[') {
314                    return Some(candidate);
315                }
316            }
317        }
318
319        None
320    }
321
322    /// Find JSON object or array anywhere in text
323    fn find_json_in_text(text: &str) -> Option<&str> {
324        // Find first { and last }
325        if let Some(start) = text.find('{') {
326            if let Some(end) = text.rfind('}') {
327                if end > start {
328                    return Some(&text[start..=end]);
329                }
330            }
331        }
332        None
333    }
334
335    /// Attempt to repair malformed JSON using jsonfixer
336    fn repair_and_parse_json(&self, json_str: &str) -> Result<ExtractionOutput> {
337        // jsonfixer::repair_json returns Result<String, Error>
338        let options = jsonfixer::JsonRepairOptions::default();
339        let fixed_json =
340            jsonfixer::repair_json(json_str, options).map_err(|e| GraphRAGError::Generation {
341                message: format!("JSON repair failed: {:?}", e),
342            })?;
343
344        serde_json::from_str::<ExtractionOutput>(&fixed_json).map_err(|e| {
345            GraphRAGError::Generation {
346                message: format!("Failed to parse repaired JSON: {}", e),
347            }
348        })
349    }
350
351    /// Convert EntityData to domain Entity objects
352    fn convert_to_entities(
353        &self,
354        entity_data: &[EntityData],
355        chunk_id: &ChunkId,
356        chunk_text: &str,
357    ) -> Result<Vec<Entity>> {
358        let mut entities = Vec::new();
359
360        for entity_item in entity_data {
361            // Generate entity ID
362            let entity_id = EntityId::new(format!(
363                "{}_{}",
364                entity_item.entity_type,
365                self.normalize_name(&entity_item.name)
366            ));
367
368            // Find mentions in chunk
369            let mentions = self.find_mentions(&entity_item.name, chunk_id, chunk_text);
370
371            // Create entity with mentions
372            // Note: Description is stored in the entity but not used in current Entity struct
373            // We store it in the entity name or as a separate field if needed
374            let entity = Entity::new(
375                entity_id,
376                entity_item.name.clone(),
377                entity_item.entity_type.clone(),
378                0.9, // High confidence since it's LLM-extracted
379            )
380            .with_mentions(mentions);
381
382            entities.push(entity);
383        }
384
385        Ok(entities)
386    }
387
388    /// Find all mentions of an entity name in the chunk text
389    fn find_mentions(&self, name: &str, chunk_id: &ChunkId, text: &str) -> Vec<EntityMention> {
390        let mut mentions = Vec::new();
391        let mut start = 0;
392
393        while let Some(pos) = text[start..].find(name) {
394            let actual_pos = start + pos;
395            mentions.push(EntityMention {
396                chunk_id: chunk_id.clone(),
397                start_offset: actual_pos,
398                end_offset: actual_pos + name.len(),
399                confidence: 0.9,
400            });
401            start = actual_pos + name.len();
402        }
403
404        // If no exact matches, try case-insensitive
405        if mentions.is_empty() {
406            let name_lower = name.to_lowercase();
407            let text_lower = text.to_lowercase();
408            let mut start = 0;
409
410            while let Some(pos) = text_lower[start..].find(&name_lower) {
411                let actual_pos = start + pos;
412                mentions.push(EntityMention {
413                    chunk_id: chunk_id.clone(),
414                    start_offset: actual_pos,
415                    end_offset: actual_pos + name.len(),
416                    confidence: 0.85, // Slightly lower confidence for case-insensitive match
417                });
418                start = actual_pos + name.len();
419            }
420        }
421
422        mentions
423    }
424
425    /// Convert RelationshipData to domain Relationship objects
426    fn convert_to_relationships(
427        &self,
428        relationship_data: &[RelationshipData],
429        entities: &[Entity],
430    ) -> Result<Vec<Relationship>> {
431        let mut relationships = Vec::new();
432
433        // Build entity name to ID mapping
434        let mut name_to_entity: std::collections::HashMap<String, &Entity> =
435            std::collections::HashMap::new();
436        for entity in entities {
437            name_to_entity.insert(entity.name.to_lowercase(), entity);
438        }
439
440        for rel_item in relationship_data {
441            // Find source and target entities
442            let source_entity = name_to_entity.get(&rel_item.source.to_lowercase());
443            let target_entity = name_to_entity.get(&rel_item.target.to_lowercase());
444
445            if let (Some(source), Some(target)) = (source_entity, target_entity) {
446                let relationship = Relationship {
447                    source: source.id.clone(),
448                    target: target.id.clone(),
449                    relation_type: rel_item.description.clone(),
450                    confidence: rel_item.strength as f32,
451                    context: vec![], // No context chunks for this relationship
452                    embedding: None,
453                    temporal_type: None,
454                    temporal_range: None,
455                    causal_strength: None,
456                };
457
458                relationships.push(relationship);
459            } else {
460                tracing::warn!(
461                    "Skipping relationship: entity not found. Source: {}, Target: {}",
462                    rel_item.source,
463                    rel_item.target
464                );
465            }
466        }
467
468        Ok(relationships)
469    }
470
471    /// Normalize entity name for ID generation
472    fn normalize_name(&self, name: &str) -> String {
473        name.to_lowercase()
474            .chars()
475            .filter(|c| c.is_alphanumeric() || *c == '_')
476            .collect::<String>()
477            .replace(' ', "_")
478    }
479}
480
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{core::DocumentId, ollama::OllamaConfig};

    /// Build a small fixed chunk used by the extraction tests below.
    fn create_test_chunk() -> TextChunk {
        TextChunk::new(
            ChunkId::new("chunk_001".to_string()),
            DocumentId::new("doc_001".to_string()),
            "Tom Sawyer is a young boy who lives in St. Petersburg with his Aunt Polly. \
             Tom is best friends with Huckleberry Finn. They often go on adventures together."
                .to_string(),
            0,
            150,
        )
    }

    #[test]
    fn test_extract_json_from_markdown() {
        let markdown = r#"
Here's the extraction:
```json
{
  "entities": [],
  "relationships": []
}
```
"#;
        let json = LLMEntityExtractor::extract_json_from_markdown(markdown);
        assert!(json.is_some());
        assert!(json.unwrap().contains("entities"));
    }

    #[test]
    fn test_find_json_in_text() {
        let text = "Some text before { \"entities\": [] } some text after";
        let json = LLMEntityExtractor::find_json_in_text(text);
        assert!(json.is_some());
        assert_eq!(json.unwrap(), "{ \"entities\": [] }");
    }

    #[test]
    fn test_parse_valid_json() {
        let ollama_config = OllamaConfig::default();
        let ollama_client = OllamaClient::new(ollama_config);
        let extractor = LLMEntityExtractor::new(
            ollama_client,
            vec!["PERSON".to_string(), "LOCATION".to_string()],
        );

        let response = r#"
{
  "entities": [
    {
      "name": "Tom Sawyer",
      "type": "PERSON",
      "description": "A young boy"
    }
  ],
  "relationships": []
}
"#;

        let result = extractor.parse_extraction_response(response);
        assert!(result.is_ok());
        let output = result.unwrap();
        assert_eq!(output.entities.len(), 1);
        assert_eq!(output.entities[0].name, "Tom Sawyer");
    }

    #[test]
    fn test_convert_to_entities() {
        let ollama_config = OllamaConfig::default();
        let ollama_client = OllamaClient::new(ollama_config);
        let extractor = LLMEntityExtractor::new(ollama_client, vec!["PERSON".to_string()]);

        let chunk = create_test_chunk();
        let entity_data = vec![EntityData {
            name: "Tom Sawyer".to_string(),
            entity_type: "PERSON".to_string(),
            description: "A young boy".to_string(),
        }];

        let entities = extractor
            .convert_to_entities(&entity_data, &chunk.id, &chunk.content)
            .unwrap();

        assert_eq!(entities.len(), 1);
        assert_eq!(entities[0].name, "Tom Sawyer");
        assert_eq!(entities[0].entity_type, "PERSON");
        assert!(!entities[0].mentions.is_empty());
    }

    #[test]
    fn test_find_mentions() {
        let ollama_config = OllamaConfig::default();
        let ollama_client = OllamaClient::new(ollama_config);
        let extractor = LLMEntityExtractor::new(ollama_client, vec!["PERSON".to_string()]);

        let chunk = create_test_chunk();
        let mentions = extractor.find_mentions("Tom", &chunk.id, &chunk.content);

        assert!(!mentions.is_empty());
        assert!(mentions.len() >= 2); // "Tom Sawyer" and "Tom is best friends"
    }

    #[test]
    fn test_normalize_name() {
        let ollama_config = OllamaConfig::default();
        let ollama_client = OllamaClient::new(ollama_config);
        let extractor = LLMEntityExtractor::new(ollama_client, vec!["PERSON".to_string()]);

        // Note: these expectations require normalize_name to turn spaces into
        // underscores (they fail against the pre-fix filter-then-replace order).
        assert_eq!(extractor.normalize_name("Tom Sawyer"), "tom_sawyer");
        assert_eq!(extractor.normalize_name("New York City"), "new_york_city");
        assert_eq!(extractor.normalize_name("Dr. Smith"), "dr_smith");
    }
}