1use crate::{
8 core::{ChunkId, Entity, EntityId, EntityMention, Relationship, TextChunk},
9 entity::prompts::{EntityData, ExtractionOutput, PromptBuilder, RelationshipData},
10 ollama::OllamaClient,
11 GraphRAGError, Result,
12};
13use serde_json;
14
15pub struct LLMEntityExtractor {
17 ollama_client: OllamaClient,
18 prompt_builder: PromptBuilder,
19 temperature: f32,
20 max_tokens: usize,
21 keep_alive: Option<String>,
25}
26
27impl LLMEntityExtractor {
28 pub fn new(ollama_client: OllamaClient, entity_types: Vec<String>) -> Self {
34 Self {
35 ollama_client,
36 prompt_builder: PromptBuilder::new(entity_types),
37 temperature: 0.0, max_tokens: 1500,
39 keep_alive: None,
40 }
41 }
42
43 pub fn with_temperature(mut self, temperature: f32) -> Self {
45 self.temperature = temperature;
46 self
47 }
48
49 pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
51 self.max_tokens = max_tokens;
52 self
53 }
54
55 pub fn with_keep_alive(mut self, keep_alive: Option<String>) -> Self {
57 self.keep_alive = keep_alive;
58 self
59 }
60
61 pub fn estimate_tokens(text: &str) -> u32 {
63 (text.len() / 4) as u32
64 }
65
66 pub fn calculate_entity_num_ctx(built_prompt: &str, max_output_tokens: u32) -> u32 {
76 let prompt_tokens = Self::estimate_tokens(built_prompt);
77 let total = prompt_tokens + max_output_tokens;
78 let with_margin = (total as f32 * 1.20) as u32;
79 let rounded = ((with_margin + 1023) / 1024) * 1024;
80 rounded.max(4096).min(131_072)
81 }
82
83 #[cfg(feature = "async")]
88 pub async fn extract_from_chunk(
89 &self,
90 chunk: &TextChunk,
91 ) -> Result<(Vec<Entity>, Vec<Relationship>)> {
92 tracing::debug!(
93 "LLM extraction for chunk: {} (size: {} chars)",
94 chunk.id,
95 chunk.content.len()
96 );
97
98 let prompt = self.prompt_builder.build_extraction_prompt(&chunk.content);
100
101 let llm_response = self.call_llm_with_retry(&prompt).await?;
103
104 let extraction_output = self.parse_extraction_response(&llm_response)?;
106
107 let entities =
109 self.convert_to_entities(&extraction_output.entities, &chunk.id, &chunk.content)?;
110 let relationships =
111 self.convert_to_relationships(&extraction_output.relationships, &entities)?;
112
113 tracing::info!(
114 "LLM extracted {} entities and {} relationships from chunk {}",
115 entities.len(),
116 relationships.len(),
117 chunk.id
118 );
119
120 Ok((entities, relationships))
121 }
122
123 #[cfg(feature = "async")]
127 pub async fn extract_additional(
128 &self,
129 chunk: &TextChunk,
130 previous_entities: &[EntityData],
131 previous_relationships: &[RelationshipData],
132 ) -> Result<(Vec<Entity>, Vec<Relationship>)> {
133 tracing::debug!("LLM gleaning round for chunk: {}", chunk.id);
134
135 let prompt = self.prompt_builder.build_continuation_prompt(
137 &chunk.content,
138 previous_entities,
139 previous_relationships,
140 );
141
142 let llm_response = self.call_llm_with_retry(&prompt).await?;
144
145 let extraction_output = self.parse_extraction_response(&llm_response)?;
147
148 let entities =
150 self.convert_to_entities(&extraction_output.entities, &chunk.id, &chunk.content)?;
151 let relationships =
152 self.convert_to_relationships(&extraction_output.relationships, &entities)?;
153
154 tracing::info!(
155 "LLM gleaning extracted {} additional entities and {} relationships",
156 entities.len(),
157 relationships.len()
158 );
159
160 Ok((entities, relationships))
161 }
162
163 #[cfg(feature = "async")]
167 pub async fn check_completion(
168 &self,
169 chunk: &TextChunk,
170 entities: &[EntityData],
171 relationships: &[RelationshipData],
172 ) -> Result<bool> {
173 tracing::debug!("LLM completion check for chunk: {}", chunk.id);
174
175 let prompt =
177 self.prompt_builder
178 .build_completion_prompt(&chunk.content, entities, relationships);
179
180 let llm_response = self.call_llm_completion_check(&prompt).await?;
182
183 let response_trimmed = llm_response.trim().to_uppercase();
185 let is_complete = response_trimmed.starts_with("YES") || response_trimmed.contains("YES");
186
187 tracing::debug!(
188 "LLM completion check result: {} (response: {})",
189 if is_complete {
190 "COMPLETE"
191 } else {
192 "INCOMPLETE"
193 },
194 llm_response.trim()
195 );
196
197 Ok(is_complete)
198 }
199
200 #[cfg(feature = "async")]
207 async fn call_llm_with_retry(&self, prompt: &str) -> Result<String> {
208 use crate::ollama::OllamaGenerationParams;
209 let num_ctx = Self::calculate_entity_num_ctx(prompt, self.max_tokens as u32);
210 tracing::debug!(
211 "Entity extraction: prompt_len={} num_ctx={} keep_alive={:?}",
212 prompt.len(),
213 num_ctx,
214 self.keep_alive,
215 );
216 let params = OllamaGenerationParams {
217 num_predict: Some(self.max_tokens as u32),
218 temperature: Some(self.temperature),
219 num_ctx: Some(num_ctx),
220 keep_alive: self.keep_alive.clone(),
221 ..Default::default()
222 };
223 match self.ollama_client.generate_with_params(prompt, params.clone()).await {
224 Ok(response) => Ok(response),
225 Err(e) => {
226 tracing::warn!("LLM call failed, retrying: {}", e);
227 tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
228 self.ollama_client.generate_with_params(prompt, params).await
229 },
230 }
231 }
232
233 #[cfg(feature = "async")]
238 async fn call_llm_completion_check(&self, prompt: &str) -> Result<String> {
239 use crate::ollama::OllamaGenerationParams;
240 let num_ctx = Self::calculate_entity_num_ctx(prompt, 50);
241 let params = OllamaGenerationParams {
242 num_predict: Some(50),
243 temperature: Some(0.0),
244 num_ctx: Some(num_ctx),
245 keep_alive: self.keep_alive.clone(),
246 ..Default::default()
247 };
248 self.ollama_client.generate_with_params(prompt, params).await
249 }
250
251 fn parse_extraction_response(&self, response: &str) -> Result<ExtractionOutput> {
255 if let Ok(output) = serde_json::from_str::<ExtractionOutput>(response) {
257 return Ok(output);
258 }
259
260 if let Some(json_str) = Self::extract_json_from_markdown(response) {
262 if let Ok(output) = serde_json::from_str::<ExtractionOutput>(json_str) {
263 return Ok(output);
264 }
265 }
266
267 match self.repair_and_parse_json(response) {
269 Ok(output) => return Ok(output),
270 Err(e) => {
271 tracing::warn!("JSON repair failed: {}", e);
272 },
273 }
274
275 if let Some(json_str) = Self::find_json_in_text(response) {
277 if let Ok(output) = serde_json::from_str::<ExtractionOutput>(json_str) {
278 return Ok(output);
279 }
280
281 if let Ok(output) = self.repair_and_parse_json(json_str) {
283 return Ok(output);
284 }
285 }
286
287 tracing::error!(
289 "Failed to parse LLM response as JSON. Response preview: {}",
290 &response.chars().take(200).collect::<String>()
291 );
292 Ok(ExtractionOutput {
293 entities: vec![],
294 relationships: vec![],
295 })
296 }
297
298 fn extract_json_from_markdown(text: &str) -> Option<&str> {
300 if let Some(start) = text.find("```json") {
302 let json_start = start + 7; if let Some(end) = text[json_start..].find("```") {
304 return Some(text[json_start..json_start + end].trim());
305 }
306 }
307
308 if let Some(start) = text.find("```") {
309 let json_start = start + 3;
310 if let Some(end) = text[json_start..].find("```") {
311 let candidate = &text[json_start..json_start + end].trim();
312 if candidate.starts_with('{') || candidate.starts_with('[') {
314 return Some(candidate);
315 }
316 }
317 }
318
319 None
320 }
321
322 fn find_json_in_text(text: &str) -> Option<&str> {
324 if let Some(start) = text.find('{') {
326 if let Some(end) = text.rfind('}') {
327 if end > start {
328 return Some(&text[start..=end]);
329 }
330 }
331 }
332 None
333 }
334
335 fn repair_and_parse_json(&self, json_str: &str) -> Result<ExtractionOutput> {
337 let options = jsonfixer::JsonRepairOptions::default();
339 let fixed_json =
340 jsonfixer::repair_json(json_str, options).map_err(|e| GraphRAGError::Generation {
341 message: format!("JSON repair failed: {:?}", e),
342 })?;
343
344 serde_json::from_str::<ExtractionOutput>(&fixed_json).map_err(|e| {
345 GraphRAGError::Generation {
346 message: format!("Failed to parse repaired JSON: {}", e),
347 }
348 })
349 }
350
351 fn convert_to_entities(
353 &self,
354 entity_data: &[EntityData],
355 chunk_id: &ChunkId,
356 chunk_text: &str,
357 ) -> Result<Vec<Entity>> {
358 let mut entities = Vec::new();
359
360 for entity_item in entity_data {
361 let entity_id = EntityId::new(format!(
363 "{}_{}",
364 entity_item.entity_type,
365 self.normalize_name(&entity_item.name)
366 ));
367
368 let mentions = self.find_mentions(&entity_item.name, chunk_id, chunk_text);
370
371 let entity = Entity::new(
375 entity_id,
376 entity_item.name.clone(),
377 entity_item.entity_type.clone(),
378 0.9, )
380 .with_mentions(mentions);
381
382 entities.push(entity);
383 }
384
385 Ok(entities)
386 }
387
388 fn find_mentions(&self, name: &str, chunk_id: &ChunkId, text: &str) -> Vec<EntityMention> {
390 let mut mentions = Vec::new();
391 let mut start = 0;
392
393 while let Some(pos) = text[start..].find(name) {
394 let actual_pos = start + pos;
395 mentions.push(EntityMention {
396 chunk_id: chunk_id.clone(),
397 start_offset: actual_pos,
398 end_offset: actual_pos + name.len(),
399 confidence: 0.9,
400 });
401 start = actual_pos + name.len();
402 }
403
404 if mentions.is_empty() {
406 let name_lower = name.to_lowercase();
407 let text_lower = text.to_lowercase();
408 let mut start = 0;
409
410 while let Some(pos) = text_lower[start..].find(&name_lower) {
411 let actual_pos = start + pos;
412 mentions.push(EntityMention {
413 chunk_id: chunk_id.clone(),
414 start_offset: actual_pos,
415 end_offset: actual_pos + name.len(),
416 confidence: 0.85, });
418 start = actual_pos + name.len();
419 }
420 }
421
422 mentions
423 }
424
425 fn convert_to_relationships(
427 &self,
428 relationship_data: &[RelationshipData],
429 entities: &[Entity],
430 ) -> Result<Vec<Relationship>> {
431 let mut relationships = Vec::new();
432
433 let mut name_to_entity: std::collections::HashMap<String, &Entity> =
435 std::collections::HashMap::new();
436 for entity in entities {
437 name_to_entity.insert(entity.name.to_lowercase(), entity);
438 }
439
440 for rel_item in relationship_data {
441 let source_entity = name_to_entity.get(&rel_item.source.to_lowercase());
443 let target_entity = name_to_entity.get(&rel_item.target.to_lowercase());
444
445 if let (Some(source), Some(target)) = (source_entity, target_entity) {
446 let relationship = Relationship {
447 source: source.id.clone(),
448 target: target.id.clone(),
449 relation_type: rel_item.description.clone(),
450 confidence: rel_item.strength as f32,
451 context: vec![], embedding: None,
453 temporal_type: None,
454 temporal_range: None,
455 causal_strength: None,
456 };
457
458 relationships.push(relationship);
459 } else {
460 tracing::warn!(
461 "Skipping relationship: entity not found. Source: {}, Target: {}",
462 rel_item.source,
463 rel_item.target
464 );
465 }
466 }
467
468 Ok(relationships)
469 }
470
471 fn normalize_name(&self, name: &str) -> String {
473 name.to_lowercase()
474 .chars()
475 .filter(|c| c.is_alphanumeric() || *c == '_')
476 .collect::<String>()
477 .replace(' ', "_")
478 }
479}
480
481#[cfg(test)]
482mod tests {
483 use super::*;
484 use crate::{core::DocumentId, ollama::OllamaConfig};
485
486 fn create_test_chunk() -> TextChunk {
487 TextChunk::new(
488 ChunkId::new("chunk_001".to_string()),
489 DocumentId::new("doc_001".to_string()),
490 "Tom Sawyer is a young boy who lives in St. Petersburg with his Aunt Polly. \
491 Tom is best friends with Huckleberry Finn. They often go on adventures together."
492 .to_string(),
493 0,
494 150,
495 )
496 }
497
498 #[test]
499 fn test_extract_json_from_markdown() {
500 let markdown = r#"
501Here's the extraction:
502```json
503{
504 "entities": [],
505 "relationships": []
506}
507```
508"#;
509 let json = LLMEntityExtractor::extract_json_from_markdown(markdown);
510 assert!(json.is_some());
511 assert!(json.unwrap().contains("entities"));
512 }
513
514 #[test]
515 fn test_find_json_in_text() {
516 let text = "Some text before { \"entities\": [] } some text after";
517 let json = LLMEntityExtractor::find_json_in_text(text);
518 assert!(json.is_some());
519 assert_eq!(json.unwrap(), "{ \"entities\": [] }");
520 }
521
522 #[test]
523 fn test_parse_valid_json() {
524 let ollama_config = OllamaConfig::default();
525 let ollama_client = OllamaClient::new(ollama_config);
526 let extractor = LLMEntityExtractor::new(
527 ollama_client,
528 vec!["PERSON".to_string(), "LOCATION".to_string()],
529 );
530
531 let response = r#"
532{
533 "entities": [
534 {
535 "name": "Tom Sawyer",
536 "type": "PERSON",
537 "description": "A young boy"
538 }
539 ],
540 "relationships": []
541}
542"#;
543
544 let result = extractor.parse_extraction_response(response);
545 assert!(result.is_ok());
546 let output = result.unwrap();
547 assert_eq!(output.entities.len(), 1);
548 assert_eq!(output.entities[0].name, "Tom Sawyer");
549 }
550
551 #[test]
552 fn test_convert_to_entities() {
553 let ollama_config = OllamaConfig::default();
554 let ollama_client = OllamaClient::new(ollama_config);
555 let extractor = LLMEntityExtractor::new(ollama_client, vec!["PERSON".to_string()]);
556
557 let chunk = create_test_chunk();
558 let entity_data = vec![EntityData {
559 name: "Tom Sawyer".to_string(),
560 entity_type: "PERSON".to_string(),
561 description: "A young boy".to_string(),
562 }];
563
564 let entities = extractor
565 .convert_to_entities(&entity_data, &chunk.id, &chunk.content)
566 .unwrap();
567
568 assert_eq!(entities.len(), 1);
569 assert_eq!(entities[0].name, "Tom Sawyer");
570 assert_eq!(entities[0].entity_type, "PERSON");
571 assert!(!entities[0].mentions.is_empty());
572 }
573
574 #[test]
575 fn test_find_mentions() {
576 let ollama_config = OllamaConfig::default();
577 let ollama_client = OllamaClient::new(ollama_config);
578 let extractor = LLMEntityExtractor::new(ollama_client, vec!["PERSON".to_string()]);
579
580 let chunk = create_test_chunk();
581 let mentions = extractor.find_mentions("Tom", &chunk.id, &chunk.content);
582
583 assert!(!mentions.is_empty());
584 assert!(mentions.len() >= 2); }
586
587 #[test]
588 fn test_normalize_name() {
589 let ollama_config = OllamaConfig::default();
590 let ollama_client = OllamaClient::new(ollama_config);
591 let extractor = LLMEntityExtractor::new(ollama_client, vec!["PERSON".to_string()]);
592
593 assert_eq!(extractor.normalize_name("Tom Sawyer"), "tom_sawyer");
594 assert_eq!(extractor.normalize_name("New York City"), "new_york_city");
595 assert_eq!(extractor.normalize_name("Dr. Smith"), "dr_smith");
596 }
597}