1use crate::core::{Entity, EntityId, GraphRAGError, Result, TextChunk};
9use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct ExtractedRelationship {
14 pub source: String,
16 pub target: String,
18 pub relation_type: String,
20 pub description: String,
22 pub strength: f32,
24}
25
/// Outcome of asking the LLM to double-check an extracted triple against its
/// source text (see `validate_triple`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TripleValidation {
    /// Whether the source text supports the relationship.
    pub is_valid: bool,
    /// Confidence in the verdict; `validate_triple` clamps this to [0.0, 1.0].
    pub confidence: f32,
    /// Brief explanation of the verdict.
    pub reason: String,
    /// Optional suggested correction when the triple is judged invalid.
    pub suggested_fix: Option<String>,
}
41
/// Combined output of a single LLM extraction pass over one text chunk.
/// This is the shape the model is asked to return as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractionResult {
    /// Entities found in the chunk.
    pub entities: Vec<ExtractedEntity>,
    /// Relationships among those entities.
    pub relationships: Vec<ExtractedRelationship>,
}
50
/// A single entity found by the LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractedEntity {
    /// Entity name as it appears in the text.
    pub name: String,
    /// Entity category (PERSON, CONCEPT, LOCATION, ...). The model emits this
    /// under the JSON key `"type"`, hence the serde rename.
    #[serde(rename = "type")]
    pub entity_type: String,
    /// Optional brief description supplied by the model.
    pub description: Option<String>,
}
62
/// Extracts entities and relationships from text chunks, preferring an LLM
/// (via Ollama) and falling back to co-occurrence heuristics when no client
/// is configured.
pub struct LLMRelationshipExtractor {
    /// Configured Ollama client, or `None` when LLM extraction is disabled.
    pub ollama_client: Option<crate::ollama::OllamaClient>,
}
71
72impl LLMRelationshipExtractor {
73 pub fn new(ollama_config: Option<&crate::ollama::OllamaConfig>) -> Result<Self> {
85 let ollama_client = if let Some(config) = ollama_config {
86 if config.enabled {
87 let local_config = crate::ollama::OllamaConfig {
88 enabled: config.enabled,
89 host: config.host.clone(),
90 port: config.port,
91 chat_model: config.chat_model.clone(),
92 embedding_model: config.embedding_model.clone(),
93 timeout_seconds: config.timeout_seconds,
94 max_retries: config.max_retries,
95 fallback_to_hash: config.fallback_to_hash,
96 max_tokens: None,
97 temperature: None,
98 enable_caching: true,
99 keep_alive: config.keep_alive.clone(),
100 num_ctx: config.num_ctx,
101 };
102
103 Some(crate::ollama::OllamaClient::new(local_config))
104 } else {
105 None
106 }
107 } else {
108 None
109 };
110
111 Ok(Self { ollama_client })
112 }
113
    /// Builds the entity/relationship extraction prompt for one text chunk.
    ///
    /// The prompt asks the model to reply with JSON only, in the shape of
    /// `ExtractionResult`. Note the output format uses the key `"type"` for
    /// both entity and relationship categories.
    fn build_extraction_prompt(&self, chunk_content: &str) -> String {
        format!(
            r#"You are an expert at extracting entities and relationships from text.
Extract all meaningful entities and relationships from the provided text.

**ENTITIES**: Extract people, concepts, locations, events, organizations, and other significant entities.
For each entity provide:
- name: the entity name
- type: entity type (PERSON, CONCEPT, LOCATION, EVENT, ORGANIZATION, OBJECT, etc.)
- description: brief description of the entity (optional)

**RELATIONSHIPS**: For entities that interact or are related, extract their relationships.
For each relationship provide:
- source: source entity name (must match an entity name)
- target: target entity name (must match an entity name)
- type: relationship type (DISCUSSES, QUESTIONS, RESPONDS_TO, TEACHES, LOVES, ADMIRES, ARGUES_WITH, MENTIONS, WORKS_FOR, LOCATED_IN, etc.)
- description: brief explanation of why they are related
- strength: confidence score between 0.0 and 1.0

**IMPORTANT GUIDELINES**:
1. Extract relationships for entities that have meaningful connections
2. Choose descriptive relationship types that capture the nature of the connection
3. For philosophical/dialogue texts, use types like DISCUSSES, QUESTIONS, RESPONDS_TO
4. For narrative texts, use types like MEETS, HELPS, OPPOSES, TRAVELS_WITH
5. For technical texts, use types like IMPLEMENTS, DEPENDS_ON, EXTENDS
6. Provide higher strength values (0.8-1.0) for explicit relationships
7. Provide lower strength values (0.5-0.7) for implicit or inferred relationships

**TEXT TO ANALYZE**:
{chunk_content}

**OUTPUT FORMAT** (JSON only, no other text):
{{
  "entities": [
    {{"name": "Entity Name", "type": "PERSON", "description": "Brief description"}},
    ...
  ],
  "relationships": [
    {{"source": "Entity1", "target": "Entity2", "type": "DISCUSSES", "description": "Why they are related", "strength": 0.85}},
    ...
  ]
}}

Return ONLY valid JSON, nothing else."#,
            chunk_content = chunk_content
        )
    }
173
174 pub async fn extract_with_llm(&self, chunk: &TextChunk) -> Result<ExtractionResult> {
194 if let Some(client) = &self.ollama_client {
195 let prompt = self.build_extraction_prompt(&chunk.content);
196
197 #[cfg(feature = "tracing")]
198 tracing::debug!(
199 chunk_id = %chunk.id,
200 "Extracting entities and relationships with LLM"
201 );
202
203 match client.generate(&prompt).await {
204 Ok(response) => {
205 let json_str = response.trim();
207
208 let json_str = if let Some(start) = json_str.find('{') {
210 if let Some(end) = json_str.rfind('}') {
211 &json_str[start..=end]
212 } else {
213 json_str
214 }
215 } else {
216 json_str
217 };
218
219 match serde_json::from_str::<ExtractionResult>(json_str) {
220 Ok(result) => {
221 #[cfg(feature = "tracing")]
222 tracing::info!(
223 chunk_id = %chunk.id,
224 entity_count = result.entities.len(),
225 relationship_count = result.relationships.len(),
226 "Successfully extracted entities and relationships"
227 );
228 Ok(result)
229 },
230 Err(_e) => {
231 #[cfg(feature = "tracing")]
232 tracing::warn!(
233 chunk_id = %chunk.id,
234 error = %_e,
235 response = %json_str,
236 "Failed to parse LLM response as JSON, falling back to entity-only extraction"
237 );
238 Ok(ExtractionResult {
240 entities: Vec::new(),
241 relationships: Vec::new(),
242 })
243 },
244 }
245 },
246 Err(e) => {
247 #[cfg(feature = "tracing")]
248 tracing::error!(
249 chunk_id = %chunk.id,
250 error = %e,
251 "LLM extraction failed"
252 );
253 Err(GraphRAGError::EntityExtraction {
254 message: format!("LLM extraction failed: {}", e),
255 })
256 },
257 }
258 } else {
259 Err(GraphRAGError::Config {
260 message: "Ollama client not configured".to_string(),
261 })
262 }
263 }
264
265 #[cfg(feature = "async")]
281 pub async fn validate_triple(
282 &self,
283 source: &str,
284 relation_type: &str,
285 target: &str,
286 source_text: &str,
287 ) -> Result<TripleValidation> {
288 if let Some(client) = &self.ollama_client {
289 let prompt = format!(
290 r#"You are validating a relationship extracted from text.
291
292Text: "{}"
293
294Extracted Relationship:
295- Source: {}
296- Relationship: {}
297- Target: {}
298
299Does this text EXPLICITLY support this relationship?
300Consider:
3011. Are both entities mentioned in the text?
3022. Is the relationship type accurate?
3033. Is there direct evidence for this connection?
304
305Respond ONLY with valid JSON in this exact format:
306{{
307 "valid": true/false,
308 "confidence": 0.0-1.0,
309 "reason": "brief explanation",
310 "suggested_fix": "optional fix if invalid"
311}}
312
313JSON:"#,
314 source_text, source, relation_type, target
315 );
316
317 #[cfg(feature = "tracing")]
318 tracing::debug!(
319 source = %source,
320 relation = %relation_type,
321 target = %target,
322 "Validating relationship triple"
323 );
324
325 match client.generate(&prompt).await {
326 Ok(response) => {
327 let json_str = response.trim();
329 let json_str = if let Some(start) = json_str.find('{') {
330 if let Some(end) = json_str.rfind('}') {
331 &json_str[start..=end]
332 } else {
333 json_str
334 }
335 } else {
336 json_str
337 };
338
339 #[derive(Deserialize)]
341 struct ValidationJson {
342 valid: bool,
343 confidence: f32,
344 reason: String,
345 suggested_fix: Option<String>,
346 }
347
348 match serde_json::from_str::<ValidationJson>(json_str) {
349 Ok(val) => {
350 #[cfg(feature = "tracing")]
351 tracing::debug!(
352 source = %source,
353 target = %target,
354 valid = val.valid,
355 confidence = val.confidence,
356 "Triple validation complete"
357 );
358
359 Ok(TripleValidation {
360 is_valid: val.valid,
361 confidence: val.confidence.clamp(0.0, 1.0),
362 reason: val.reason,
363 suggested_fix: val.suggested_fix,
364 })
365 },
366 Err(_e) => {
367 #[cfg(feature = "tracing")]
368 tracing::warn!(
369 error = %_e,
370 response = %json_str,
371 "Failed to parse validation response, assuming valid"
372 );
373
374 Ok(TripleValidation {
376 is_valid: true,
377 confidence: 0.5,
378 reason: "Failed to parse validation response".to_string(),
379 suggested_fix: None,
380 })
381 },
382 }
383 },
384 Err(e) => {
385 #[cfg(feature = "tracing")]
386 tracing::error!(
387 error = %e,
388 "Triple validation failed"
389 );
390
391 Ok(TripleValidation {
393 is_valid: true,
394 confidence: 0.5,
395 reason: format!("Validation LLM call failed: {}", e),
396 suggested_fix: None,
397 })
398 },
399 }
400 } else {
401 Ok(TripleValidation {
403 is_valid: true,
404 confidence: 1.0,
405 reason: "Ollama client not configured, skipping validation".to_string(),
406 suggested_fix: None,
407 })
408 }
409 }
410
411 pub fn extract_relationships_fallback(
429 &self,
430 entities: &[Entity],
431 chunk: &TextChunk,
432 ) -> Vec<(EntityId, EntityId, String, f32)> {
433 let mut relationships = Vec::new();
434
435 let chunk_entities: Vec<&Entity> = entities
437 .iter()
438 .filter(|e| e.mentions.iter().any(|m| m.chunk_id == chunk.id))
439 .collect();
440
441 for i in 0..chunk_entities.len() {
443 for j in (i + 1)..chunk_entities.len() {
444 let entity1 = chunk_entities[i];
445 let entity2 = chunk_entities[j];
446
447 if let Some((rel_type, confidence)) =
449 self.infer_relationship_with_context(entity1, entity2, &chunk.content)
450 {
451 relationships.push((
452 entity1.id.clone(),
453 entity2.id.clone(),
454 rel_type,
455 confidence,
456 ));
457 }
458 }
459 }
460
461 relationships
462 }
463
464 fn infer_relationship_with_context(
481 &self,
482 entity1: &Entity,
483 entity2: &Entity,
484 context: &str,
485 ) -> Option<(String, f32)> {
486 let context_lower = context.to_lowercase();
487 let e1_name_lower = entity1.name.to_lowercase();
488 let e2_name_lower = entity2.name.to_lowercase();
489
490 let e1_pos = context_lower.find(&e1_name_lower)?;
492 let e2_pos = context_lower.find(&e2_name_lower)?;
493
494 let start = e1_pos.min(e2_pos);
496 let end = (e1_pos.max(e2_pos) + 50).min(context.len());
497 let window = &context_lower[start..end];
498
499 match (&entity1.entity_type[..], &entity2.entity_type[..]) {
501 ("PERSON", "PERSON") | ("CHARACTER", "CHARACTER") | ("SPEAKER", "SPEAKER") => {
503 if window.contains("said")
504 || window.contains("replied")
505 || window.contains("responded")
506 {
507 Some(("RESPONDS_TO".to_string(), 0.85))
508 } else if window.contains("asked") || window.contains("questioned") {
509 Some(("QUESTIONS".to_string(), 0.85))
510 } else if window.contains("taught") || window.contains("explained") {
511 Some(("TEACHES".to_string(), 0.80))
512 } else if window.contains("discussed") || window.contains("spoke about") {
513 Some(("DISCUSSES".to_string(), 0.80))
514 } else if window.contains("loved") || window.contains("admired") {
515 Some(("ADMIRES".to_string(), 0.85))
516 } else if window.contains("argued") || window.contains("disagreed") {
517 Some(("ARGUES_WITH".to_string(), 0.85))
518 } else if window.contains("met") || window.contains("encountered") {
519 Some(("MEETS".to_string(), 0.75))
520 } else {
521 Some(("INTERACTS_WITH".to_string(), 0.60))
523 }
524 },
525
526 ("PERSON", "CONCEPT") | ("CHARACTER", "CONCEPT") | ("SPEAKER", "CONCEPT") => {
528 if window.contains("discussed") || window.contains("spoke of") {
529 Some(("DISCUSSES".to_string(), 0.80))
530 } else if window.contains("defined") || window.contains("described") {
531 Some(("DEFINES".to_string(), 0.85))
532 } else if window.contains("questioned") || window.contains("wondered about") {
533 Some(("QUESTIONS".to_string(), 0.80))
534 } else {
535 Some(("MENTIONS".to_string(), 0.70))
536 }
537 },
538
539 ("CONCEPT", "PERSON") | ("CONCEPT", "CHARACTER") | ("CONCEPT", "SPEAKER") => {
541 Some(("DISCUSSED_BY".to_string(), 0.70))
542 },
543
544 ("PERSON", "ORGANIZATION") | ("ORGANIZATION", "PERSON") => {
546 if window.contains("works for") || window.contains("employed by") {
547 Some(("WORKS_FOR".to_string(), 0.90))
548 } else if window.contains("founded")
549 || window.contains("CEO")
550 || window.contains("leads")
551 {
552 Some(("LEADS".to_string(), 0.90))
553 } else {
554 Some(("ASSOCIATED_WITH".to_string(), 0.65))
555 }
556 },
557
558 ("PERSON", "LOCATION") | ("CHARACTER", "LOCATION") => {
560 if window.contains("born in") || window.contains("from") {
561 Some(("BORN_IN".to_string(), 0.90))
562 } else if window.contains("lives in") || window.contains("resides in") {
563 Some(("LIVES_IN".to_string(), 0.85))
564 } else if window.contains("traveled to") || window.contains("visited") {
565 Some(("VISITED".to_string(), 0.80))
566 } else {
567 Some(("LOCATED_IN".to_string(), 0.70))
568 }
569 },
570
571 ("ORGANIZATION", "LOCATION") | ("LOCATION", "ORGANIZATION") => {
573 if window.contains("headquartered") || window.contains("based in") {
574 Some(("HEADQUARTERED_IN".to_string(), 0.90))
575 } else {
576 Some(("LOCATED_IN".to_string(), 0.75))
577 }
578 },
579
580 ("CONCEPT", "CONCEPT") => {
582 if window.contains("similar to") || window.contains("related to") {
583 Some(("RELATED_TO".to_string(), 0.75))
584 } else if window.contains("opposite") || window.contains("contrasts with") {
585 Some(("CONTRASTS_WITH".to_string(), 0.80))
586 } else {
587 Some(("ASSOCIATED_WITH".to_string(), 0.60))
588 }
589 },
590
591 ("PERSON", "EVENT") | ("CHARACTER", "EVENT") => {
593 Some(("PARTICIPATES_IN".to_string(), 0.75))
594 },
595 ("EVENT", "LOCATION") => Some(("OCCURS_IN".to_string(), 0.80)),
596
597 _ => {
599 if (e1_pos as i32 - e2_pos as i32).abs() < 100 {
601 Some(("CO_OCCURS".to_string(), 0.50))
602 } else {
603 None
604 }
605 },
606 }
607 }
608}
609
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{ChunkId, DocumentId};

    // Prompt construction embeds the chunk text and both section headers.
    #[test]
    fn test_prompt_generation() {
        let extractor = LLMRelationshipExtractor::new(None).unwrap();
        let prompt = extractor.build_extraction_prompt("Socrates discusses love with Phaedrus.");

        assert!(prompt.contains("entities"));
        assert!(prompt.contains("relationships"));
        assert!(prompt.contains("Socrates discusses love with Phaedrus"));
    }

    // Heuristic fallback on a two-person chunk should yield a relationship.
    // NOTE(review): `extract_relationships_fallback` keeps only entities with
    // a mention whose chunk_id matches; this test assumes `Entity::new`
    // leaves these entities visible to that filter — confirm against the
    // `Entity` implementation.
    #[test]
    fn test_fallback_extraction() {
        let extractor = LLMRelationshipExtractor::new(None).unwrap();

        let chunk = TextChunk::new(
            ChunkId::new("test".to_string()),
            DocumentId::new("doc".to_string()),
            "Socrates discussed love with Phaedrus in Athens.".to_string(),
            0,
            50,
        );

        let entities = vec![
            Entity::new(
                EntityId::new("person_socrates".to_string()),
                "Socrates".to_string(),
                "PERSON".to_string(),
                0.9,
            ),
            Entity::new(
                EntityId::new("person_phaedrus".to_string()),
                "Phaedrus".to_string(),
                "PERSON".to_string(),
                0.9,
            ),
        ];

        let relationships = extractor.extract_relationships_fallback(&entities, &chunk);

        assert!(!relationships.is_empty());
    }

    // TripleValidation round-trips through serde with its field names intact.
    #[test]
    fn test_triple_validation_struct() {
        let validation = TripleValidation {
            is_valid: true,
            confidence: 0.85,
            reason: "The text explicitly states this relationship.".to_string(),
            suggested_fix: None,
        };

        assert!(validation.is_valid);
        assert_eq!(validation.confidence, 0.85);
        assert!(!validation.reason.is_empty());

        let json = serde_json::to_string(&validation).unwrap();
        assert!(json.contains("is_valid"));
        assert!(json.contains("confidence"));
        assert!(json.contains("reason"));
    }

    // TripleValidation deserializes from its own serialized field names
    // (distinct from the "valid" key the LLM prompt requests).
    #[test]
    fn test_triple_validation_deserialization() {
        let json = r#"{
            "is_valid": true,
            "confidence": 0.9,
            "reason": "Explicitly supported",
            "suggested_fix": null
        }"#;

        let validation: TripleValidation = serde_json::from_str(json).unwrap();
        assert!(validation.is_valid);
        assert_eq!(validation.confidence, 0.9);
        assert_eq!(validation.reason, "Explicitly supported");
        assert!(validation.suggested_fix.is_none());
    }

    // An invalid verdict can carry a suggested correction.
    #[test]
    fn test_triple_validation_with_suggested_fix() {
        let validation = TripleValidation {
            is_valid: false,
            confidence: 0.3,
            reason: "The relationship is implied but not explicit.".to_string(),
            suggested_fix: Some("Change TAUGHT to INFLUENCED".to_string()),
        };

        assert!(!validation.is_valid);
        assert!(validation.confidence < 0.5);
        assert!(validation.suggested_fix.is_some());

        let fix = validation.suggested_fix.unwrap();
        assert!(fix.contains("INFLUENCED"));
    }

    // Sanity-checks how validation confidence compares against a 0.7 cutoff.
    #[test]
    fn test_validation_confidence_thresholds() {
        let high_confidence = TripleValidation {
            is_valid: true,
            confidence: 0.95,
            reason: "Strong evidence".to_string(),
            suggested_fix: None,
        };

        let medium_confidence = TripleValidation {
            is_valid: true,
            confidence: 0.7,
            reason: "Moderate evidence".to_string(),
            suggested_fix: None,
        };

        let low_confidence = TripleValidation {
            is_valid: false,
            confidence: 0.3,
            reason: "Weak evidence".to_string(),
            suggested_fix: Some("Revise".to_string()),
        };

        let threshold = 0.7;
        assert!(high_confidence.confidence >= threshold);
        assert!(medium_confidence.confidence >= threshold);
        assert!(low_confidence.confidence < threshold);
    }

    // Without an Ollama client, validate_triple must not error: it returns
    // an "assumed valid" verdict with full confidence.
    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_validate_triple_without_ollama() {
        let extractor = LLMRelationshipExtractor::new(None).unwrap();

        let result = extractor
            .validate_triple("Socrates", "TAUGHT", "Plato", "Socrates taught Plato.")
            .await;

        assert!(
            result.is_ok(),
            "Should gracefully handle missing Ollama client"
        );

        let validation = result.unwrap();
        assert!(validation.is_valid, "Fallback should assume valid");
        assert_eq!(
            validation.confidence, 1.0,
            "Fallback should have high confidence"
        );
        assert!(
            validation.reason.contains("not configured"),
            "Reason should explain Ollama is not configured"
        );
    }
}