1#![cfg_attr(not(feature = "async"), allow(unused_imports))]
2
3use crate::{
10 core::{ChunkId, Entity, EntityId, EntityMention, Relationship, TextChunk},
11 entity::prompts::{EntityData, ExtractionOutput, PromptBuilder, RelationshipData},
12 ollama::OllamaClient,
13 GraphRAGError, Result,
14};
15use serde_json;
16
17pub struct LLMEntityExtractor {
19 #[cfg_attr(not(feature = "async"), allow(dead_code))]
20 ollama_client: OllamaClient,
21 #[cfg_attr(not(feature = "async"), allow(dead_code))]
22 prompt_builder: PromptBuilder,
23 temperature: f32,
24 max_tokens: usize,
25 keep_alive: Option<String>,
29}
30
31impl LLMEntityExtractor {
32 pub fn new(ollama_client: OllamaClient, entity_types: Vec<String>) -> Self {
38 Self {
39 ollama_client,
40 prompt_builder: PromptBuilder::new(entity_types),
41 temperature: 0.0, max_tokens: 1500,
43 keep_alive: None,
44 }
45 }
46
47 pub fn with_temperature(mut self, temperature: f32) -> Self {
49 self.temperature = temperature;
50 self
51 }
52
53 pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
55 self.max_tokens = max_tokens;
56 self
57 }
58
59 pub fn with_keep_alive(mut self, keep_alive: Option<String>) -> Self {
61 self.keep_alive = keep_alive;
62 self
63 }
64
/// Roughly estimates the token count of `text` as one token per four bytes.
///
/// The estimate is based on the UTF-8 byte length, not on character count.
/// It saturates at `u32::MAX` instead of silently wrapping for inputs whose
/// estimate does not fit in a `u32`.
pub fn estimate_tokens(text: &str) -> u32 {
    // Clamp in `usize` first so the narrowing cast below can never truncate.
    (text.len() / 4).min(u32::MAX as usize) as u32
}
69
70 pub fn calculate_entity_num_ctx(built_prompt: &str, max_output_tokens: u32) -> u32 {
80 let prompt_tokens = Self::estimate_tokens(built_prompt);
81 let total = prompt_tokens + max_output_tokens;
82 let with_margin = (total as f32 * 1.20) as u32;
83 let rounded = ((with_margin + 1023) / 1024) * 1024;
84 rounded.clamp(4096, 131_072)
85 }
86
87 #[cfg(feature = "async")]
92 pub async fn extract_from_chunk(
93 &self,
94 chunk: &TextChunk,
95 ) -> Result<(Vec<Entity>, Vec<Relationship>)> {
96 #[cfg(feature = "tracing")]
97 tracing::debug!(
98 "LLM extraction for chunk: {} (size: {} chars)",
99 chunk.id,
100 chunk.content.len()
101 );
102
103 let prompt = self.prompt_builder.build_extraction_prompt(&chunk.content);
105
106 let llm_response = self.call_llm_with_retry(&prompt).await?;
108
109 let extraction_output = self.parse_extraction_response(&llm_response)?;
111
112 let entities =
114 self.convert_to_entities(&extraction_output.entities, &chunk.id, &chunk.content)?;
115 let relationships =
116 self.convert_to_relationships(&extraction_output.relationships, &entities)?;
117
118 #[cfg(feature = "tracing")]
119 tracing::info!(
120 "LLM extracted {} entities and {} relationships from chunk {}",
121 entities.len(),
122 relationships.len(),
123 chunk.id
124 );
125
126 Ok((entities, relationships))
127 }
128
129 #[cfg(feature = "async")]
133 pub async fn extract_additional(
134 &self,
135 chunk: &TextChunk,
136 previous_entities: &[EntityData],
137 previous_relationships: &[RelationshipData],
138 ) -> Result<(Vec<Entity>, Vec<Relationship>)> {
139 #[cfg(feature = "tracing")]
140 tracing::debug!("LLM gleaning round for chunk: {}", chunk.id);
141
142 let prompt = self.prompt_builder.build_continuation_prompt(
144 &chunk.content,
145 previous_entities,
146 previous_relationships,
147 );
148
149 let llm_response = self.call_llm_with_retry(&prompt).await?;
151
152 let extraction_output = self.parse_extraction_response(&llm_response)?;
154
155 let entities =
157 self.convert_to_entities(&extraction_output.entities, &chunk.id, &chunk.content)?;
158 let relationships =
159 self.convert_to_relationships(&extraction_output.relationships, &entities)?;
160
161 #[cfg(feature = "tracing")]
162 tracing::info!(
163 "LLM gleaning extracted {} additional entities and {} relationships",
164 entities.len(),
165 relationships.len()
166 );
167
168 Ok((entities, relationships))
169 }
170
171 #[cfg(feature = "async")]
175 pub async fn check_completion(
176 &self,
177 chunk: &TextChunk,
178 entities: &[EntityData],
179 relationships: &[RelationshipData],
180 ) -> Result<bool> {
181 #[cfg(feature = "tracing")]
182 tracing::debug!("LLM completion check for chunk: {}", chunk.id);
183
184 let prompt =
186 self.prompt_builder
187 .build_completion_prompt(&chunk.content, entities, relationships);
188
189 let llm_response = self.call_llm_completion_check(&prompt).await?;
191
192 let response_trimmed = llm_response.trim().to_uppercase();
194 let is_complete = response_trimmed.starts_with("YES") || response_trimmed.contains("YES");
195
196 #[cfg(feature = "tracing")]
197 tracing::debug!(
198 "LLM completion check result: {} (response: {})",
199 if is_complete {
200 "COMPLETE"
201 } else {
202 "INCOMPLETE"
203 },
204 llm_response.trim()
205 );
206
207 Ok(is_complete)
208 }
209
210 #[cfg(feature = "async")]
217 async fn call_llm_with_retry(&self, prompt: &str) -> Result<String> {
218 use crate::ollama::OllamaGenerationParams;
219 let num_ctx = Self::calculate_entity_num_ctx(prompt, self.max_tokens as u32);
220 #[cfg(feature = "tracing")]
221 tracing::debug!(
222 "Entity extraction: prompt_len={} num_ctx={} keep_alive={:?}",
223 prompt.len(),
224 num_ctx,
225 self.keep_alive,
226 );
227 let params = OllamaGenerationParams {
228 num_predict: Some(self.max_tokens as u32),
229 temperature: Some(self.temperature),
230 num_ctx: Some(num_ctx),
231 keep_alive: self.keep_alive.clone(),
232 ..Default::default()
233 };
234 match self
235 .ollama_client
236 .generate_with_params(prompt, params.clone())
237 .await
238 {
239 Ok(response) => Ok(response),
240 Err(e) => {
241 #[cfg(feature = "tracing")]
242 tracing::warn!("LLM call failed, retrying: {}", e);
243 tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
244 self.ollama_client
245 .generate_with_params(prompt, params)
246 .await
247 },
248 }
249 }
250
251 #[cfg(feature = "async")]
256 async fn call_llm_completion_check(&self, prompt: &str) -> Result<String> {
257 use crate::ollama::OllamaGenerationParams;
258 let num_ctx = Self::calculate_entity_num_ctx(prompt, 50);
259 let params = OllamaGenerationParams {
260 num_predict: Some(50),
261 temperature: Some(0.0),
262 num_ctx: Some(num_ctx),
263 keep_alive: self.keep_alive.clone(),
264 ..Default::default()
265 };
266 self.ollama_client
267 .generate_with_params(prompt, params)
268 .await
269 }
270
271 #[cfg(feature = "async")]
275 fn parse_extraction_response(&self, response: &str) -> Result<ExtractionOutput> {
276 if let Ok(output) = serde_json::from_str::<ExtractionOutput>(response) {
278 return Ok(output);
279 }
280
281 if let Some(json_str) = Self::extract_json_from_markdown(response) {
283 if let Ok(output) = serde_json::from_str::<ExtractionOutput>(json_str) {
284 return Ok(output);
285 }
286 }
287
288 match self.repair_and_parse_json(response) {
290 Ok(output) => return Ok(output),
291 Err(_e) => {
292 #[cfg(feature = "tracing")]
293 tracing::warn!("JSON repair failed: {}", _e);
294 },
295 }
296
297 if let Some(json_str) = Self::find_json_in_text(response) {
299 if let Ok(output) = serde_json::from_str::<ExtractionOutput>(json_str) {
300 return Ok(output);
301 }
302
303 if let Ok(output) = self.repair_and_parse_json(json_str) {
305 return Ok(output);
306 }
307 }
308
309 #[cfg(feature = "tracing")]
311 tracing::error!(
312 "Failed to parse LLM response as JSON. Response preview: {}",
313 &response.chars().take(200).collect::<String>()
314 );
315 Ok(ExtractionOutput {
316 entities: vec![],
317 relationships: vec![],
318 })
319 }
320
321 #[cfg(feature = "async")]
/// Extracts the trimmed contents of the first Markdown code fence in `text`.
///
/// A fence opened with ```` ```json ```` is preferred and returned as-is.
/// Otherwise the first plain ```` ``` ```` fence is used, but only if its
/// contents start with `{` or `[`. Returns `None` when no closed fence
/// qualifies.
fn extract_json_from_markdown(text: &str) -> Option<&str> {
    const FENCE: &str = "```";
    const JSON_FENCE: &str = "```json";

    // Trimmed text between `body_start` and the next closing fence, if any.
    fn fenced_body(text: &str, body_start: usize) -> Option<&str> {
        let rest = &text[body_start..];
        rest.find(FENCE).map(|close| rest[..close].trim())
    }

    let tagged = text
        .find(JSON_FENCE)
        .and_then(|open| fenced_body(text, open + JSON_FENCE.len()));
    if tagged.is_some() {
        return tagged;
    }

    text.find(FENCE)
        .and_then(|open| fenced_body(text, open + FENCE.len()))
        .filter(|body| body.starts_with('{') || body.starts_with('['))
}
345
346 #[cfg(feature = "async")]
/// Returns the slice of `text` from its first `{` to its last `}` inclusive.
///
/// Returns `None` when either brace is missing or the last `}` does not come
/// after the first `{`. The slice is not validated as JSON.
fn find_json_in_text(text: &str) -> Option<&str> {
    let open = text.find('{')?;
    let close = text.rfind('}')?;

    if close > open {
        Some(&text[open..=close])
    } else {
        None
    }
}
359
360 #[cfg(feature = "async")]
362 fn repair_and_parse_json(&self, json_str: &str) -> Result<ExtractionOutput> {
363 let options = jsonfixer::JsonRepairOptions::default();
365 let fixed_json =
366 jsonfixer::repair_json(json_str, options).map_err(|e| GraphRAGError::Generation {
367 message: format!("JSON repair failed: {:?}", e),
368 })?;
369
370 serde_json::from_str::<ExtractionOutput>(&fixed_json).map_err(|e| {
371 GraphRAGError::Generation {
372 message: format!("Failed to parse repaired JSON: {}", e),
373 }
374 })
375 }
376
377 #[cfg(feature = "async")]
379 fn convert_to_entities(
380 &self,
381 entity_data: &[EntityData],
382 chunk_id: &ChunkId,
383 chunk_text: &str,
384 ) -> Result<Vec<Entity>> {
385 let mut entities = Vec::new();
386
387 for entity_item in entity_data {
388 let entity_id = EntityId::new(format!(
390 "{}_{}",
391 entity_item.entity_type,
392 self.normalize_name(&entity_item.name)
393 ));
394
395 let mentions = self.find_mentions(&entity_item.name, chunk_id, chunk_text);
397
398 let entity = Entity::new(
402 entity_id,
403 entity_item.name.clone(),
404 entity_item.entity_type.clone(),
405 0.9, )
407 .with_mentions(mentions);
408
409 entities.push(entity);
410 }
411
412 Ok(entities)
413 }
414
415 #[cfg(feature = "async")]
417 fn find_mentions(&self, name: &str, chunk_id: &ChunkId, text: &str) -> Vec<EntityMention> {
418 let mut mentions = Vec::new();
419 let mut start = 0;
420
421 while let Some(pos) = text[start..].find(name) {
422 let actual_pos = start + pos;
423 mentions.push(EntityMention {
424 chunk_id: chunk_id.clone(),
425 start_offset: actual_pos,
426 end_offset: actual_pos + name.len(),
427 confidence: 0.9,
428 });
429 start = actual_pos + name.len();
430 }
431
432 if mentions.is_empty() {
434 let name_lower = name.to_lowercase();
435 let text_lower = text.to_lowercase();
436 let mut start = 0;
437
438 while let Some(pos) = text_lower[start..].find(&name_lower) {
439 let actual_pos = start + pos;
440 mentions.push(EntityMention {
441 chunk_id: chunk_id.clone(),
442 start_offset: actual_pos,
443 end_offset: actual_pos + name.len(),
444 confidence: 0.85, });
446 start = actual_pos + name.len();
447 }
448 }
449
450 mentions
451 }
452
453 #[cfg(feature = "async")]
455 fn convert_to_relationships(
456 &self,
457 relationship_data: &[RelationshipData],
458 entities: &[Entity],
459 ) -> Result<Vec<Relationship>> {
460 let mut relationships = Vec::new();
461
462 let mut name_to_entity: std::collections::HashMap<String, &Entity> =
464 std::collections::HashMap::new();
465 for entity in entities {
466 name_to_entity.insert(entity.name.to_lowercase(), entity);
467 }
468
469 for rel_item in relationship_data {
470 let source_entity = name_to_entity.get(&rel_item.source.to_lowercase());
472 let target_entity = name_to_entity.get(&rel_item.target.to_lowercase());
473
474 if let (Some(source), Some(target)) = (source_entity, target_entity) {
475 let relationship = Relationship {
476 source: source.id.clone(),
477 target: target.id.clone(),
478 relation_type: rel_item.description.clone(),
479 confidence: rel_item.strength as f32,
480 context: vec![], embedding: None,
482 temporal_type: None,
483 temporal_range: None,
484 causal_strength: None,
485 };
486
487 relationships.push(relationship);
488 } else {
489 #[cfg(feature = "tracing")]
490 tracing::warn!(
491 "Skipping relationship: entity not found. Source: {}, Target: {}",
492 rel_item.source,
493 rel_item.target
494 );
495 }
496 }
497
498 Ok(relationships)
499 }
500
501 #[cfg(feature = "async")]
503 fn normalize_name(&self, name: &str) -> String {
504 name.to_lowercase()
505 .chars()
506 .filter(|c| c.is_alphanumeric() || *c == '_')
507 .collect::<String>()
508 .replace(' ', "_")
509 }
510}
511
// Every helper exercised below is compiled only with the `async` feature, so
// the test module must be gated on it as well; otherwise `cargo test` without
// that feature fails to compile.
#[cfg(all(test, feature = "async"))]
mod tests {
    use super::*;
    use crate::{core::DocumentId, ollama::OllamaConfig};

    /// Builds a short sample chunk that mentions "Tom" more than once.
    fn create_test_chunk() -> TextChunk {
        TextChunk::new(
            ChunkId::new("chunk_001".to_string()),
            DocumentId::new("doc_001".to_string()),
            "Tom Sawyer is a young boy who lives in St. Petersburg with his Aunt Polly. \
             Tom is best friends with Huckleberry Finn. They often go on adventures together."
                .to_string(),
            0,
            150,
        )
    }

    /// Builds an extractor around a default-configured Ollama client; no
    /// request is sent by any of the tests below.
    fn create_test_extractor(entity_types: &[&str]) -> LLMEntityExtractor {
        let ollama_client = OllamaClient::new(OllamaConfig::default());
        LLMEntityExtractor::new(
            ollama_client,
            entity_types.iter().map(|t| t.to_string()).collect(),
        )
    }

    #[test]
    fn test_extract_json_from_markdown() {
        let markdown = r#"
Here's the extraction:
```json
{
    "entities": [],
    "relationships": []
}
```
"#;
        let json = LLMEntityExtractor::extract_json_from_markdown(markdown);
        assert!(json.is_some());
        assert!(json.unwrap().contains("entities"));
    }

    #[test]
    fn test_find_json_in_text() {
        let text = "Some text before { \"entities\": [] } some text after";
        let json = LLMEntityExtractor::find_json_in_text(text);
        assert!(json.is_some());
        assert_eq!(json.unwrap(), "{ \"entities\": [] }");
    }

    #[test]
    fn test_parse_valid_json() {
        let extractor = create_test_extractor(&["PERSON", "LOCATION"]);

        let response = r#"
{
    "entities": [
        {
            "name": "Tom Sawyer",
            "type": "PERSON",
            "description": "A young boy"
        }
    ],
    "relationships": []
}
"#;

        let result = extractor.parse_extraction_response(response);
        assert!(result.is_ok());
        let output = result.unwrap();
        assert_eq!(output.entities.len(), 1);
        assert_eq!(output.entities[0].name, "Tom Sawyer");
    }

    #[test]
    fn test_convert_to_entities() {
        let extractor = create_test_extractor(&["PERSON"]);

        let chunk = create_test_chunk();
        let entity_data = vec![EntityData {
            name: "Tom Sawyer".to_string(),
            entity_type: "PERSON".to_string(),
            description: "A young boy".to_string(),
        }];

        let entities = extractor
            .convert_to_entities(&entity_data, &chunk.id, &chunk.content)
            .unwrap();

        assert_eq!(entities.len(), 1);
        assert_eq!(entities[0].name, "Tom Sawyer");
        assert_eq!(entities[0].entity_type, "PERSON");
        assert!(!entities[0].mentions.is_empty());
    }

    #[test]
    fn test_find_mentions() {
        let extractor = create_test_extractor(&["PERSON"]);

        let chunk = create_test_chunk();
        let mentions = extractor.find_mentions("Tom", &chunk.id, &chunk.content);

        assert!(!mentions.is_empty());
        // "Tom" occurs in "Tom Sawyer" and again in "Tom is best friends".
        assert!(mentions.len() >= 2);
    }
}