1use crate::{
8 core::{ChunkId, Entity, EntityId, EntityMention, Relationship, TextChunk},
9 entity::prompts::{EntityData, ExtractionOutput, PromptBuilder, RelationshipData},
10 ollama::OllamaClient,
11 Result, GraphRAGError,
12};
13use serde_json;
14
/// Entity/relationship extractor that delegates recognition to an LLM served
/// by Ollama and parses the (possibly malformed) JSON the model returns.
pub struct LLMEntityExtractor {
    // Client used for every LLM generation call.
    ollama_client: OllamaClient,
    // Builds extraction / continuation / completion prompts for the
    // entity types supplied at construction time.
    prompt_builder: PromptBuilder,
    // Sampling temperature. NOTE(review): stored but not visibly passed to
    // `OllamaClient::generate` in this file — confirm where it is applied.
    temperature: f32,
    // Response token budget. NOTE(review): same caveat as `temperature`.
    max_tokens: usize,
}
22
23impl LLMEntityExtractor {
24 pub fn new(ollama_client: OllamaClient, entity_types: Vec<String>) -> Self {
30 Self {
31 ollama_client,
32 prompt_builder: PromptBuilder::new(entity_types),
33 temperature: 0.1, max_tokens: 1500,
35 }
36 }
37
38 pub fn with_temperature(mut self, temperature: f32) -> Self {
40 self.temperature = temperature;
41 self
42 }
43
44 pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
46 self.max_tokens = max_tokens;
47 self
48 }
49
50 #[cfg(feature = "async")]
55 pub async fn extract_from_chunk(
56 &self,
57 chunk: &TextChunk,
58 ) -> Result<(Vec<Entity>, Vec<Relationship>)> {
59 tracing::debug!("LLM extraction for chunk: {} (size: {} chars)", chunk.id, chunk.content.len());
60
61 let prompt = self.prompt_builder.build_extraction_prompt(&chunk.content);
63
64 let llm_response = self.call_llm_with_retry(&prompt).await?;
66
67 let extraction_output = self.parse_extraction_response(&llm_response)?;
69
70 let entities = self.convert_to_entities(&extraction_output.entities, &chunk.id, &chunk.content)?;
72 let relationships = self.convert_to_relationships(&extraction_output.relationships, &entities)?;
73
74 tracing::info!(
75 "LLM extracted {} entities and {} relationships from chunk {}",
76 entities.len(),
77 relationships.len(),
78 chunk.id
79 );
80
81 Ok((entities, relationships))
82 }
83
84 #[cfg(feature = "async")]
88 pub async fn extract_additional(
89 &self,
90 chunk: &TextChunk,
91 previous_entities: &[EntityData],
92 previous_relationships: &[RelationshipData],
93 ) -> Result<(Vec<Entity>, Vec<Relationship>)> {
94 tracing::debug!("LLM gleaning round for chunk: {}", chunk.id);
95
96 let prompt = self.prompt_builder.build_continuation_prompt(
98 &chunk.content,
99 previous_entities,
100 previous_relationships,
101 );
102
103 let llm_response = self.call_llm_with_retry(&prompt).await?;
105
106 let extraction_output = self.parse_extraction_response(&llm_response)?;
108
109 let entities = self.convert_to_entities(&extraction_output.entities, &chunk.id, &chunk.content)?;
111 let relationships = self.convert_to_relationships(&extraction_output.relationships, &entities)?;
112
113 tracing::info!(
114 "LLM gleaning extracted {} additional entities and {} relationships",
115 entities.len(),
116 relationships.len()
117 );
118
119 Ok((entities, relationships))
120 }
121
122 #[cfg(feature = "async")]
126 pub async fn check_completion(
127 &self,
128 chunk: &TextChunk,
129 entities: &[EntityData],
130 relationships: &[RelationshipData],
131 ) -> Result<bool> {
132 tracing::debug!("LLM completion check for chunk: {}", chunk.id);
133
134 let prompt = self.prompt_builder.build_completion_prompt(
136 &chunk.content,
137 entities,
138 relationships,
139 );
140
141 let llm_response = self.call_llm_completion_check(&prompt).await?;
143
144 let response_trimmed = llm_response.trim().to_uppercase();
146 let is_complete = response_trimmed.starts_with("YES") || response_trimmed.contains("YES");
147
148 tracing::debug!(
149 "LLM completion check result: {} (response: {})",
150 if is_complete { "COMPLETE" } else { "INCOMPLETE" },
151 llm_response.trim()
152 );
153
154 Ok(is_complete)
155 }
156
157 #[cfg(feature = "async")]
159 async fn call_llm_with_retry(&self, prompt: &str) -> Result<String> {
160 match self.ollama_client.generate(prompt).await {
163 Ok(response) => Ok(response),
164 Err(e) => {
165 tracing::warn!("LLM call failed, retrying: {}", e);
166 tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
168 self.ollama_client.generate(prompt).await
169 }
170 }
171 }
172
173 #[cfg(feature = "async")]
175 async fn call_llm_completion_check(&self, prompt: &str) -> Result<String> {
176 self.ollama_client.generate(prompt).await
179 }
180
181 fn parse_extraction_response(&self, response: &str) -> Result<ExtractionOutput> {
185 if let Ok(output) = serde_json::from_str::<ExtractionOutput>(response) {
187 return Ok(output);
188 }
189
190 if let Some(json_str) = Self::extract_json_from_markdown(response) {
192 if let Ok(output) = serde_json::from_str::<ExtractionOutput>(json_str) {
193 return Ok(output);
194 }
195 }
196
197 match self.repair_and_parse_json(response) {
199 Ok(output) => return Ok(output),
200 Err(e) => {
201 tracing::warn!("JSON repair failed: {}", e);
202 }
203 }
204
205 if let Some(json_str) = Self::find_json_in_text(response) {
207 if let Ok(output) = serde_json::from_str::<ExtractionOutput>(json_str) {
208 return Ok(output);
209 }
210
211 if let Ok(output) = self.repair_and_parse_json(json_str) {
213 return Ok(output);
214 }
215 }
216
217 tracing::error!("Failed to parse LLM response as JSON. Response preview: {}", &response.chars().take(200).collect::<String>());
219 Ok(ExtractionOutput {
220 entities: vec![],
221 relationships: vec![],
222 })
223 }
224
225 fn extract_json_from_markdown(text: &str) -> Option<&str> {
227 if let Some(start) = text.find("```json") {
229 let json_start = start + 7; if let Some(end) = text[json_start..].find("```") {
231 return Some(&text[json_start..json_start + end].trim());
232 }
233 }
234
235 if let Some(start) = text.find("```") {
236 let json_start = start + 3;
237 if let Some(end) = text[json_start..].find("```") {
238 let candidate = &text[json_start..json_start + end].trim();
239 if candidate.starts_with('{') || candidate.starts_with('[') {
241 return Some(candidate);
242 }
243 }
244 }
245
246 None
247 }
248
249 fn find_json_in_text(text: &str) -> Option<&str> {
251 if let Some(start) = text.find('{') {
253 if let Some(end) = text.rfind('}') {
254 if end > start {
255 return Some(&text[start..=end]);
256 }
257 }
258 }
259 None
260 }
261
262 fn repair_and_parse_json(&self, json_str: &str) -> Result<ExtractionOutput> {
264 let options = jsonfixer::JsonRepairOptions::default();
266 let fixed_json = jsonfixer::repair_json(json_str, options)
267 .map_err(|e| GraphRAGError::Generation {
268 message: format!("JSON repair failed: {:?}", e),
269 })?;
270
271 serde_json::from_str::<ExtractionOutput>(&fixed_json)
272 .map_err(|e| GraphRAGError::Generation {
273 message: format!("Failed to parse repaired JSON: {}", e),
274 })
275 }
276
277 fn convert_to_entities(
279 &self,
280 entity_data: &[EntityData],
281 chunk_id: &ChunkId,
282 chunk_text: &str,
283 ) -> Result<Vec<Entity>> {
284 let mut entities = Vec::new();
285
286 for data in entity_data {
287 let entity_id = EntityId::new(format!(
289 "{}_{}",
290 data.entity_type,
291 self.normalize_name(&data.name)
292 ));
293
294 let mentions = self.find_mentions(&data.name, chunk_id, chunk_text);
296
297 let entity = Entity::new(
301 entity_id,
302 data.name.clone(),
303 data.entity_type.clone(),
304 0.9, )
306 .with_mentions(mentions);
307
308 entities.push(entity);
309 }
310
311 Ok(entities)
312 }
313
314 fn find_mentions(&self, name: &str, chunk_id: &ChunkId, text: &str) -> Vec<EntityMention> {
316 let mut mentions = Vec::new();
317 let mut start = 0;
318
319 while let Some(pos) = text[start..].find(name) {
320 let actual_pos = start + pos;
321 mentions.push(EntityMention {
322 chunk_id: chunk_id.clone(),
323 start_offset: actual_pos,
324 end_offset: actual_pos + name.len(),
325 confidence: 0.9,
326 });
327 start = actual_pos + name.len();
328 }
329
330 if mentions.is_empty() {
332 let name_lower = name.to_lowercase();
333 let text_lower = text.to_lowercase();
334 let mut start = 0;
335
336 while let Some(pos) = text_lower[start..].find(&name_lower) {
337 let actual_pos = start + pos;
338 mentions.push(EntityMention {
339 chunk_id: chunk_id.clone(),
340 start_offset: actual_pos,
341 end_offset: actual_pos + name.len(),
342 confidence: 0.85, });
344 start = actual_pos + name.len();
345 }
346 }
347
348 mentions
349 }
350
351 fn convert_to_relationships(
353 &self,
354 relationship_data: &[RelationshipData],
355 entities: &[Entity],
356 ) -> Result<Vec<Relationship>> {
357 let mut relationships = Vec::new();
358
359 let mut name_to_entity: std::collections::HashMap<String, &Entity> = std::collections::HashMap::new();
361 for entity in entities {
362 name_to_entity.insert(entity.name.to_lowercase(), entity);
363 }
364
365 for data in relationship_data {
366 let source_entity = name_to_entity.get(&data.source.to_lowercase());
368 let target_entity = name_to_entity.get(&data.target.to_lowercase());
369
370 if let (Some(source), Some(target)) = (source_entity, target_entity) {
371 let relationship = Relationship {
372 source: source.id.clone(),
373 target: target.id.clone(),
374 relation_type: data.description.clone(),
375 confidence: data.strength as f32,
376 context: vec![], };
378
379 relationships.push(relationship);
380 } else {
381 tracing::warn!(
382 "Skipping relationship: entity not found. Source: {}, Target: {}",
383 data.source,
384 data.target
385 );
386 }
387 }
388
389 Ok(relationships)
390 }
391
392 fn normalize_name(&self, name: &str) -> String {
394 name.to_lowercase()
395 .chars()
396 .filter(|c| c.is_alphanumeric() || *c == '_')
397 .collect::<String>()
398 .replace(' ', "_")
399 }
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{core::DocumentId, ollama::OllamaConfig};

    /// Convenience constructor: an extractor wired to a default Ollama client
    /// (never actually contacted by these tests).
    fn build_extractor(entity_types: Vec<String>) -> LLMEntityExtractor {
        let client = OllamaClient::new(OllamaConfig::default());
        LLMEntityExtractor::new(client, entity_types)
    }

    /// Fixture chunk with a couple of repeated person names.
    fn create_test_chunk() -> TextChunk {
        TextChunk::new(
            ChunkId::new("chunk_001".to_string()),
            DocumentId::new("doc_001".to_string()),
            "Tom Sawyer is a young boy who lives in St. Petersburg with his Aunt Polly. \
             Tom is best friends with Huckleberry Finn. They often go on adventures together."
                .to_string(),
            0,
            150,
        )
    }

    #[test]
    fn test_extract_json_from_markdown() {
        let fenced = r#"
Here's the extraction:
```json
{
  "entities": [],
  "relationships": []
}
```
"#;
        let extracted = LLMEntityExtractor::extract_json_from_markdown(fenced);
        assert!(extracted.is_some());
        assert!(extracted.unwrap().contains("entities"));
    }

    #[test]
    fn test_find_json_in_text() {
        let haystack = "Some text before { \"entities\": [] } some text after";
        let found = LLMEntityExtractor::find_json_in_text(haystack);
        assert_eq!(found, Some("{ \"entities\": [] }"));
    }

    #[test]
    fn test_parse_valid_json() {
        let extractor = build_extractor(vec!["PERSON".to_string(), "LOCATION".to_string()]);

        let response = r#"
{
  "entities": [
    {
      "name": "Tom Sawyer",
      "type": "PERSON",
      "description": "A young boy"
    }
  ],
  "relationships": []
}
"#;

        let parsed = extractor
            .parse_extraction_response(response)
            .expect("valid JSON must parse");
        assert_eq!(parsed.entities.len(), 1);
        assert_eq!(parsed.entities[0].name, "Tom Sawyer");
    }

    #[test]
    fn test_convert_to_entities() {
        let extractor = build_extractor(vec!["PERSON".to_string()]);
        let chunk = create_test_chunk();

        let entity_data = vec![EntityData {
            name: "Tom Sawyer".to_string(),
            entity_type: "PERSON".to_string(),
            description: "A young boy".to_string(),
        }];

        let entities = extractor
            .convert_to_entities(&entity_data, &chunk.id, &chunk.content)
            .unwrap();

        assert_eq!(entities.len(), 1);
        assert_eq!(entities[0].name, "Tom Sawyer");
        assert_eq!(entities[0].entity_type, "PERSON");
        assert!(!entities[0].mentions.is_empty());
    }

    #[test]
    fn test_find_mentions() {
        let extractor = build_extractor(vec!["PERSON".to_string()]);
        let chunk = create_test_chunk();

        let mentions = extractor.find_mentions("Tom", &chunk.id, &chunk.content);

        // "Tom" occurs standalone and inside "Tom Sawyer" in the fixture.
        assert!(!mentions.is_empty());
        assert!(mentions.len() >= 2);
    }

    #[test]
    fn test_normalize_name() {
        let extractor = build_extractor(vec!["PERSON".to_string()]);

        assert_eq!(extractor.normalize_name("Tom Sawyer"), "tom_sawyer");
        assert_eq!(extractor.normalize_name("New York City"), "new_york_city");
        assert_eq!(extractor.normalize_name("Dr. Smith"), "dr_smith");
    }
}