Skip to main content

graphrag_core/entity/
llm_extractor.rs

1#![cfg_attr(not(feature = "async"), allow(unused_imports))]
2
3//! LLM-based entity and relationship extraction
4//!
5//! This module provides TRUE LLM-based extraction using Ollama or any other LLM service.
6//! Unlike pattern-based extraction, this uses actual language model inference to extract
7//! entities and relationships from text with deep semantic understanding.
8
9use crate::{
10    core::{ChunkId, Entity, EntityId, EntityMention, Relationship, TextChunk},
11    entity::prompts::{EntityData, ExtractionOutput, PromptBuilder, RelationshipData},
12    ollama::OllamaClient,
13    GraphRAGError, Result,
14};
15use serde_json;
16
17/// LLM-based entity extractor that uses actual language model calls
18pub struct LLMEntityExtractor {
19    #[cfg_attr(not(feature = "async"), allow(dead_code))]
20    ollama_client: OllamaClient,
21    #[cfg_attr(not(feature = "async"), allow(dead_code))]
22    prompt_builder: PromptBuilder,
23    temperature: f32,
24    max_tokens: usize,
25    /// keep_alive value forwarded to every Ollama request (e.g. "1h").
26    /// When set, Ollama keeps the model in VRAM between requests so the KV
27    /// cache for the static prompt prefix is preserved across all chunks.
28    keep_alive: Option<String>,
29}
30
31impl LLMEntityExtractor {
32    /// Create a new LLM-based entity extractor
33    ///
34    /// # Arguments
35    /// * `ollama_client` - Ollama client for LLM inference
36    /// * `entity_types` - List of entity types to extract (e.g., ["PERSON", "LOCATION", "ORGANIZATION"])
37    pub fn new(ollama_client: OllamaClient, entity_types: Vec<String>) -> Self {
38        Self {
39            ollama_client,
40            prompt_builder: PromptBuilder::new(entity_types),
41            temperature: 0.0, // Zero temperature for deterministic JSON extraction
42            max_tokens: 1500,
43            keep_alive: None,
44        }
45    }
46
47    /// Set temperature for LLM generation (default: 0.1)
48    pub fn with_temperature(mut self, temperature: f32) -> Self {
49        self.temperature = temperature;
50        self
51    }
52
53    /// Set maximum tokens for LLM generation (default: 1500)
54    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
55        self.max_tokens = max_tokens;
56        self
57    }
58
59    /// Set keep_alive for KV cache persistence (e.g. "1h", "30m").
60    pub fn with_keep_alive(mut self, keep_alive: Option<String>) -> Self {
61        self.keep_alive = keep_alive;
62        self
63    }
64
/// Estimate the token count of `text` (≈ 1 token per 4 bytes).
///
/// The estimate is based on the UTF-8 byte length (`str::len`), not on the
/// number of characters, and rounds down: inputs shorter than four bytes
/// yield `0`. Lengths whose estimate does not fit in a `u32` saturate at
/// `u32::MAX` instead of silently wrapping.
pub fn estimate_tokens(text: &str) -> u32 {
    // Fully qualified so the call resolves regardless of the crate edition's prelude.
    <u32 as std::convert::TryFrom<usize>>::try_from(text.len() / 4).unwrap_or(u32::MAX)
}
69
70    /// Calculate the required `num_ctx` for one entity-extraction LLM call.
71    ///
72    /// Formula (mirrors `ContextualEnricher::calculate_num_ctx`):
73    /// ```text
74    /// num_ctx = tokens(built_prompt)   // instructions + entity types + chunk text
75    ///         + max_output_tokens       // JSON with entities + relationships
76    ///         + 20% safety margin       // avoids silent truncation
77    /// ```
78    /// Result is rounded up to the nearest 1024 and clamped to [4096, 131072].
79    pub fn calculate_entity_num_ctx(built_prompt: &str, max_output_tokens: u32) -> u32 {
80        let prompt_tokens = Self::estimate_tokens(built_prompt);
81        let total = prompt_tokens + max_output_tokens;
82        let with_margin = (total as f32 * 1.20) as u32;
83        let rounded = ((with_margin + 1023) / 1024) * 1024;
84        rounded.clamp(4096, 131_072)
85    }
86
/// Run one LLM extraction pass over `chunk`.
///
/// Builds the extraction prompt from the chunk text, sends it to the model
/// through `call_llm_with_retry`, parses the reply and converts the parsed
/// data into domain entities and relationships. This issues a real LLM
/// request, so expect it to take on the order of 15-30 seconds per chunk
/// depending on chunk size and model.
///
/// # Errors
/// Propagates any error from the LLM call, the response parsing or the
/// conversion steps.
#[cfg(feature = "async")]
pub async fn extract_from_chunk(
    &self,
    chunk: &TextChunk,
) -> Result<(Vec<Entity>, Vec<Relationship>)> {
    #[cfg(feature = "tracing")]
    tracing::debug!(
        "LLM extraction for chunk: {} (size: {} chars)",
        chunk.id,
        chunk.content.len()
    );

    // Prompt -> raw model reply -> structured extraction data.
    let prompt = self.prompt_builder.build_extraction_prompt(&chunk.content);
    let raw_reply = self.call_llm_with_retry(&prompt).await?;
    let parsed = self.parse_extraction_response(&raw_reply)?;

    // Relationships are resolved against the entities found in this pass.
    let entities = self.convert_to_entities(&parsed.entities, &chunk.id, &chunk.content)?;
    let relationships = self.convert_to_relationships(&parsed.relationships, &entities)?;

    #[cfg(feature = "tracing")]
    tracing::info!(
        "LLM extracted {} entities and {} relationships from chunk {}",
        entities.len(),
        relationships.len(),
        chunk.id
    );

    Ok((entities, relationships))
}
128
/// Run a gleaning (continuation) round for `chunk`.
///
/// Intended to follow an initial extraction: the continuation prompt shows
/// the model what was already found (`previous_entities`,
/// `previous_relationships`) so it can report what was missed.
///
/// NOTE(review): relationships returned by this round are resolved only
/// against the entities extracted in this same round, so a relationship
/// that points at one of `previous_entities` is skipped — confirm this is
/// intended.
///
/// # Errors
/// Propagates any error from the LLM call, the response parsing or the
/// conversion steps.
#[cfg(feature = "async")]
pub async fn extract_additional(
    &self,
    chunk: &TextChunk,
    previous_entities: &[EntityData],
    previous_relationships: &[RelationshipData],
) -> Result<(Vec<Entity>, Vec<Relationship>)> {
    #[cfg(feature = "tracing")]
    tracing::debug!("LLM gleaning round for chunk: {}", chunk.id);

    // Continuation prompt -> raw model reply -> structured extraction data.
    let prompt = self.prompt_builder.build_continuation_prompt(
        &chunk.content,
        previous_entities,
        previous_relationships,
    );
    let raw_reply = self.call_llm_with_retry(&prompt).await?;
    let parsed = self.parse_extraction_response(&raw_reply)?;

    let entities = self.convert_to_entities(&parsed.entities, &chunk.id, &chunk.content)?;
    let relationships = self.convert_to_relationships(&parsed.relationships, &entities)?;

    #[cfg(feature = "tracing")]
    tracing::info!(
        "LLM gleaning extracted {} additional entities and {} relationships",
        entities.len(),
        relationships.len()
    );

    Ok((entities, relationships))
}
170
/// Check if extraction is complete using LLM judgment
///
/// Builds the completion-check prompt from the chunk text and the data
/// extracted so far, asks the model for a YES/NO verdict and interprets it.
///
/// The verdict is the first whole word in the reply that equals `yes` or
/// `no` (case-insensitive). Replies without either word are treated as
/// "not complete". Matching whole words rather than substrings prevents
/// answers such as "NO, ... yes ..." or words that merely contain "yes"
/// from being read as a positive verdict.
///
/// # Errors
/// Propagates any error from the LLM call.
#[cfg(feature = "async")]
pub async fn check_completion(
    &self,
    chunk: &TextChunk,
    entities: &[EntityData],
    relationships: &[RelationshipData],
) -> Result<bool> {
    #[cfg(feature = "tracing")]
    tracing::debug!("LLM completion check for chunk: {}", chunk.id);

    // Build completion check prompt
    let prompt =
        self.prompt_builder
            .build_completion_prompt(&chunk.content, entities, relationships);

    // Short, zero-temperature generation: only a YES/NO answer is expected.
    let llm_response = self.call_llm_completion_check(&prompt).await?;

    // The first standalone YES/NO word decides; anything else is "incomplete".
    let verdict = llm_response
        .split(|c: char| !c.is_alphanumeric())
        .find(|word| word.eq_ignore_ascii_case("yes") || word.eq_ignore_ascii_case("no"));
    let is_complete = matches!(verdict, Some(word) if word.eq_ignore_ascii_case("yes"));

    #[cfg(feature = "tracing")]
    tracing::debug!(
        "LLM completion check result: {} (response: {})",
        if is_complete {
            "COMPLETE"
        } else {
            "INCOMPLETE"
        },
        llm_response.trim()
    );

    Ok(is_complete)
}
209
/// Call LLM with dynamic num_ctx and retry logic for extraction.
///
/// `num_ctx` is calculated from the built prompt via
/// [`calculate_entity_num_ctx`] before the request is sent, so the context
/// window is sized for the prompt plus the expected output. `keep_alive`
/// is forwarded with the request.
///
/// If the first request fails, the call waits two seconds and retries
/// exactly once with the same parameters.
///
/// # Errors
/// Returns the error of the second attempt when both attempts fail.
#[cfg(feature = "async")]
async fn call_llm_with_retry(&self, prompt: &str) -> Result<String> {
    use crate::ollama::OllamaGenerationParams;

    // Saturate rather than truncate if `max_tokens` exceeds what the request
    // field can express.
    let max_output_tokens =
        <u32 as std::convert::TryFrom<usize>>::try_from(self.max_tokens).unwrap_or(u32::MAX);
    let num_ctx = Self::calculate_entity_num_ctx(prompt, max_output_tokens);

    #[cfg(feature = "tracing")]
    tracing::debug!(
        "Entity extraction: prompt_len={} num_ctx={} keep_alive={:?}",
        prompt.len(),
        num_ctx,
        self.keep_alive,
    );

    let params = OllamaGenerationParams {
        num_predict: Some(max_output_tokens),
        temperature: Some(self.temperature),
        num_ctx: Some(num_ctx),
        keep_alive: self.keep_alive.clone(),
        ..Default::default()
    };

    // The parameters are cloned for the first attempt because the retry
    // needs them again.
    match self
        .ollama_client
        .generate_with_params(prompt, params.clone())
        .await
    {
        Ok(response) => Ok(response),
        // Underscore-prefixed: the error is only used when `tracing` is enabled.
        Err(_e) => {
            #[cfg(feature = "tracing")]
            tracing::warn!("LLM call failed, retrying: {}", _e);
            tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
            self.ollama_client
                .generate_with_params(prompt, params)
                .await
        },
    }
}
250
/// Send the completion-check prompt to the LLM and return its raw reply.
///
/// Only a short YES/NO answer is expected, so generation is capped at 50
/// tokens and runs at temperature 0.0. The context window is sized from the
/// prompt with the same 50-token output budget. Unlike the extraction call
/// there is no retry.
#[cfg(feature = "async")]
async fn call_llm_completion_check(&self, prompt: &str) -> Result<String> {
    use crate::ollama::OllamaGenerationParams;

    // Output budget for the verdict.
    const VERDICT_TOKENS: u32 = 50;

    let request = OllamaGenerationParams {
        num_ctx: Some(Self::calculate_entity_num_ctx(prompt, VERDICT_TOKENS)),
        num_predict: Some(VERDICT_TOKENS),
        temperature: Some(0.0),
        keep_alive: self.keep_alive.clone(),
        ..Default::default()
    };

    self.ollama_client
        .generate_with_params(prompt, request)
        .await
}
270
/// Turn a raw LLM reply into an [`ExtractionOutput`].
///
/// The reply is tried against progressively more forgiving strategies, in
/// this order, and the first success wins:
/// 1. parse the reply verbatim as JSON;
/// 2. parse the contents of a markdown code fence;
/// 3. run the whole reply through JSON repair;
/// 4. take the text between the first `{` and the last `}`, parse it, and
///    failing that, repair it.
///
/// This is deliberately best-effort: when every strategy fails the function
/// still returns `Ok` with an empty extraction (and logs an error when the
/// `tracing` feature is enabled).
#[cfg(feature = "async")]
fn parse_extraction_response(&self, response: &str) -> Result<ExtractionOutput> {
    let strict = |candidate: &str| serde_json::from_str::<ExtractionOutput>(candidate).ok();

    // 1. The reply is already valid JSON.
    if let Some(parsed) = strict(response) {
        return Ok(parsed);
    }

    // 2. JSON wrapped in a markdown code fence.
    if let Some(parsed) = Self::extract_json_from_markdown(response).and_then(strict) {
        return Ok(parsed);
    }

    // 3. Repair the reply as a whole.
    match self.repair_and_parse_json(response) {
        Ok(parsed) => return Ok(parsed),
        Err(_e) => {
            #[cfg(feature = "tracing")]
            tracing::warn!("JSON repair failed: {}", _e);
        },
    }

    // 4. JSON embedded in surrounding prose: strict parse first, then repair.
    if let Some(embedded) = Self::find_json_in_text(response) {
        let recovered =
            strict(embedded).or_else(|| self.repair_and_parse_json(embedded).ok());
        if let Some(parsed) = recovered {
            return Ok(parsed);
        }
    }

    // Nothing worked: report it and fall back to an empty extraction.
    #[cfg(feature = "tracing")]
    tracing::error!(
        "Failed to parse LLM response as JSON. Response preview: {}",
        &response.chars().take(200).collect::<String>()
    );
    Ok(ExtractionOutput {
        entities: Vec::new(),
        relationships: Vec::new(),
    })
}
320
/// Pull the contents of a markdown code fence out of `text`.
///
/// A fence opened with ```` ```json ```` is preferred and its trimmed body is
/// returned as-is. Otherwise the first plain ```` ``` ```` fence is used, but
/// only when its trimmed body starts with `{` or `[`. Returns `None` when no
/// closed fence qualifies.
#[cfg(feature = "async")]
fn extract_json_from_markdown(text: &str) -> Option<&str> {
    const FENCE: &str = "```";
    const JSON_FENCE: &str = "```json";

    /// Trimmed text between byte offset `from` and the next fence, if any.
    fn body_until_fence(text: &str, from: usize) -> Option<&str> {
        let rest = &text[from..];
        rest.find(FENCE).map(|len| rest[..len].trim())
    }

    // Explicitly tagged block: trusted to contain JSON.
    let tagged = text
        .find(JSON_FENCE)
        .and_then(|open| body_until_fence(text, open + JSON_FENCE.len()));
    if tagged.is_some() {
        return tagged;
    }

    // Untagged block: accept it only if the body looks like JSON.
    text.find(FENCE)
        .and_then(|open| body_until_fence(text, open + FENCE.len()))
        .filter(|body| body.starts_with('{') || body.starts_with('['))
}
345
/// Return the slice of `text` spanning the first `{` through the last `}`.
///
/// Yields `None` when either brace is missing or when the last `}` does not
/// come after the first `{`. The slice is not validated; it is only a
/// candidate for JSON parsing.
#[cfg(feature = "async")]
fn find_json_in_text(text: &str) -> Option<&str> {
    let open = text.find('{')?;
    let close = text.rfind('}')?;

    if close > open {
        Some(&text[open..=close])
    } else {
        None
    }
}
359
/// Repair malformed JSON with `jsonfixer`, then parse the repaired text.
///
/// # Errors
/// Returns `GraphRAGError::Generation` when the repair itself fails or when
/// the repaired text still does not deserialize into an
/// [`ExtractionOutput`].
#[cfg(feature = "async")]
fn repair_and_parse_json(&self, json_str: &str) -> Result<ExtractionOutput> {
    let options = jsonfixer::JsonRepairOptions::default();

    // Step 1: repair.
    let repaired = match jsonfixer::repair_json(json_str, options) {
        Ok(text) => text,
        Err(e) => {
            return Err(GraphRAGError::Generation {
                message: format!("JSON repair failed: {:?}", e),
            })
        },
    };

    // Step 2: strict parse of the repaired text.
    match serde_json::from_str::<ExtractionOutput>(&repaired) {
        Ok(output) => Ok(output),
        Err(e) => Err(GraphRAGError::Generation {
            message: format!("Failed to parse repaired JSON: {}", e),
        }),
    }
}
376
/// Turn parsed `EntityData` items into domain [`Entity`] values.
///
/// Each entity gets an id of the form `<entity_type>_<normalized name>`, a
/// fixed confidence of 0.9 (LLM-extracted), and the mentions of its name
/// found in `chunk_text`. The description carried by `EntityData` is not
/// transferred to the resulting entity.
///
/// Always returns `Ok`; the `Result` is kept for interface compatibility.
#[cfg(feature = "async")]
fn convert_to_entities(
    &self,
    entity_data: &[EntityData],
    chunk_id: &ChunkId,
    chunk_text: &str,
) -> Result<Vec<Entity>> {
    let entities = entity_data
        .iter()
        .map(|item| {
            let id = EntityId::new(format!(
                "{}_{}",
                item.entity_type,
                self.normalize_name(&item.name)
            ));
            let mentions = self.find_mentions(&item.name, chunk_id, chunk_text);

            Entity::new(
                id,
                item.name.clone(),
                item.entity_type.clone(),
                0.9, // High confidence since it's LLM-extracted
            )
            .with_mentions(mentions)
        })
        .collect();

    Ok(entities)
}
414
/// Find all mentions of an entity name in the chunk text
///
/// Exact (case-sensitive), non-overlapping occurrences are reported with a
/// confidence of 0.9. Only when there are none does the search fall back to
/// a case-insensitive comparison, reported with a confidence of 0.85.
///
/// `start_offset` / `end_offset` are byte offsets into `text` itself. In the
/// case-insensitive fallback the search runs on a lowercased copy whose byte
/// layout can differ from `text` (lowercasing may change a character's
/// encoded length), so every match is mapped back to the original
/// character boundaries.
///
/// An empty `name` yields no mentions.
#[cfg(feature = "async")]
fn find_mentions(&self, name: &str, chunk_id: &ChunkId, text: &str) -> Vec<EntityMention> {
    // An empty needle matches everywhere without advancing the search, so
    // there is nothing meaningful to report.
    if name.is_empty() {
        return Vec::new();
    }

    let exact: Vec<EntityMention> = text
        .match_indices(name)
        .map(|(start, hit)| EntityMention {
            chunk_id: chunk_id.clone(),
            start_offset: start,
            end_offset: start + hit.len(),
            confidence: 0.9,
        })
        .collect();
    if !exact.is_empty() {
        return exact;
    }

    // Case-insensitive fallback. Both sides are lowercased character by
    // character so they are folded identically.
    let needle: String = name.chars().flat_map(char::to_lowercase).collect();

    // `origin[i]` is the byte span in `text` of the character that produced
    // byte `i` of `haystack`.
    let mut haystack = String::with_capacity(text.len());
    let mut origin: Vec<(usize, usize)> = Vec::with_capacity(text.len());
    for (offset, ch) in text.char_indices() {
        let span = (offset, offset + ch.len_utf8());
        for lowered in ch.to_lowercase() {
            haystack.push(lowered);
            origin.resize(haystack.len(), span);
        }
    }

    haystack
        .match_indices(needle.as_str())
        .map(|(start, hit)| EntityMention {
            chunk_id: chunk_id.clone(),
            start_offset: origin[start].0,
            // `hit` is never empty, so `start + hit.len() - 1` is the last
            // matched byte; the mention ends where its source character ends.
            end_offset: origin[start + hit.len() - 1].1,
            confidence: 0.85, // Slightly lower confidence for case-insensitive match
        })
        .collect()
}
452
/// Turn parsed `RelationshipData` items into domain [`Relationship`] values.
///
/// Source and target names are resolved case-insensitively against
/// `entities`. A relationship is kept only when both endpoints resolve;
/// otherwise it is skipped (with a warning when the `tracing` feature is
/// enabled). The relationship description becomes the `relation_type` and
/// the strength becomes the confidence.
///
/// Always returns `Ok`; the `Result` is kept for interface compatibility.
#[cfg(feature = "async")]
fn convert_to_relationships(
    &self,
    relationship_data: &[RelationshipData],
    entities: &[Entity],
) -> Result<Vec<Relationship>> {
    // Lowercased name -> entity. For duplicate names the later entity wins.
    let by_name: std::collections::HashMap<String, &Entity> = entities
        .iter()
        .map(|entity| (entity.name.to_lowercase(), entity))
        .collect();

    let mut relationships = Vec::new();
    for rel in relationship_data {
        let endpoints = (
            by_name.get(&rel.source.to_lowercase()),
            by_name.get(&rel.target.to_lowercase()),
        );

        match endpoints {
            (Some(source), Some(target)) => relationships.push(Relationship {
                source: source.id.clone(),
                target: target.id.clone(),
                relation_type: rel.description.clone(),
                confidence: rel.strength as f32,
                context: vec![], // No context chunks for this relationship
                embedding: None,
                temporal_type: None,
                temporal_range: None,
                causal_strength: None,
            }),
            _ => {
                #[cfg(feature = "tracing")]
                tracing::warn!(
                    "Skipping relationship: entity not found. Source: {}, Target: {}",
                    rel.source,
                    rel.target
                );
            },
        }
    }

    Ok(relationships)
}
500
/// Normalize entity name for ID generation
///
/// Lowercases the name, turns each space into `_`, keeps alphanumeric
/// characters and existing underscores, and drops everything else
/// (e.g. "Tom Sawyer" -> "tom_sawyer", "St. Petersburg" -> "st_petersburg").
///
/// The space-to-underscore mapping has to happen before other characters are
/// discarded: filtering first would remove the spaces and leave nothing to
/// replace.
#[cfg(feature = "async")]
fn normalize_name(&self, name: &str) -> String {
    name.to_lowercase()
        .chars()
        .filter_map(|c| match c {
            ' ' => Some('_'),
            c if c.is_alphanumeric() || c == '_' => Some(c),
            _ => None,
        })
        .collect()
}
510}
511
// The helpers exercised below only exist when the `async` feature is enabled,
// so the test module is gated on it as well; otherwise `cargo test` without
// that feature would fail to compile.
#[cfg(all(test, feature = "async"))]
mod tests {
    use super::*;
    use crate::{core::DocumentId, ollama::OllamaConfig};

    /// Small fixture chunk mentioning several people and a location.
    fn create_test_chunk() -> TextChunk {
        TextChunk::new(
            ChunkId::new("chunk_001".to_string()),
            DocumentId::new("doc_001".to_string()),
            "Tom Sawyer is a young boy who lives in St. Petersburg with his Aunt Polly. \
             Tom is best friends with Huckleberry Finn. They often go on adventures together."
                .to_string(),
            0,
            150,
        )
    }

    #[test]
    fn test_extract_json_from_markdown() {
        let markdown = r#"
Here's the extraction:
```json
{
  "entities": [],
  "relationships": []
}
```
"#;
        let json = LLMEntityExtractor::extract_json_from_markdown(markdown);
        assert!(json.is_some());
        assert!(json.unwrap().contains("entities"));
    }

    #[test]
    fn test_find_json_in_text() {
        let text = "Some text before { \"entities\": [] } some text after";
        let json = LLMEntityExtractor::find_json_in_text(text);
        assert!(json.is_some());
        assert_eq!(json.unwrap(), "{ \"entities\": [] }");
    }

    #[test]
    fn test_parse_valid_json() {
        let ollama_config = OllamaConfig::default();
        let ollama_client = OllamaClient::new(ollama_config);
        let extractor = LLMEntityExtractor::new(
            ollama_client,
            vec!["PERSON".to_string(), "LOCATION".to_string()],
        );

        let response = r#"
{
  "entities": [
    {
      "name": "Tom Sawyer",
      "type": "PERSON",
      "description": "A young boy"
    }
  ],
  "relationships": []
}
"#;

        let result = extractor.parse_extraction_response(response);
        assert!(result.is_ok());
        let output = result.unwrap();
        assert_eq!(output.entities.len(), 1);
        assert_eq!(output.entities[0].name, "Tom Sawyer");
    }

    #[test]
    fn test_convert_to_entities() {
        let ollama_config = OllamaConfig::default();
        let ollama_client = OllamaClient::new(ollama_config);
        let extractor = LLMEntityExtractor::new(ollama_client, vec!["PERSON".to_string()]);

        let chunk = create_test_chunk();
        let entity_data = vec![EntityData {
            name: "Tom Sawyer".to_string(),
            entity_type: "PERSON".to_string(),
            description: "A young boy".to_string(),
        }];

        let entities = extractor
            .convert_to_entities(&entity_data, &chunk.id, &chunk.content)
            .unwrap();

        assert_eq!(entities.len(), 1);
        assert_eq!(entities[0].name, "Tom Sawyer");
        assert_eq!(entities[0].entity_type, "PERSON");
        assert!(!entities[0].mentions.is_empty());
    }

    #[test]
    fn test_find_mentions() {
        let ollama_config = OllamaConfig::default();
        let ollama_client = OllamaClient::new(ollama_config);
        let extractor = LLMEntityExtractor::new(ollama_client, vec!["PERSON".to_string()]);

        let chunk = create_test_chunk();
        let mentions = extractor.find_mentions("Tom", &chunk.id, &chunk.content);

        assert!(!mentions.is_empty());
        assert!(mentions.len() >= 2); // "Tom Sawyer" and "Tom is best friends"
    }
}
617}