1use crate::ai::AiConfig;
7use crate::model::{Literal, NamedNode, Triple};
8use anyhow::Result;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
12pub struct RelationExtractor {
14 config: ExtractionConfig,
16
17 ner_model: Box<dyn NamedEntityRecognizer>,
19
20 relation_model: Box<dyn RelationClassifier>,
22
23 entity_linker: Box<dyn EntityLinker>,
25
26 confidence_threshold: f32,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct ExtractionConfig {
33 pub enable_ner: bool,
35
36 pub enable_relation_classification: bool,
38
39 pub enable_entity_linking: bool,
41
42 pub confidence_threshold: f32,
44
45 pub max_sentence_length: usize,
47
48 pub language_model: String,
50
51 pub enable_coreference: bool,
53
54 pub supported_languages: Vec<String>,
56}
57
58impl Default for ExtractionConfig {
59 fn default() -> Self {
60 Self {
61 enable_ner: true,
62 enable_relation_classification: true,
63 enable_entity_linking: true,
64 confidence_threshold: 0.7,
65 max_sentence_length: 512,
66 language_model: "bert-base-uncased".to_string(),
67 enable_coreference: true,
68 supported_languages: vec!["en".to_string()],
69 }
70 }
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ExtractedRelation {
76 pub subject: ExtractedEntity,
78
79 pub predicate: String,
81
82 pub object: ExtractedEntity,
84
85 pub confidence: f32,
87
88 pub source_span: TextSpan,
90
91 pub context: String,
93
94 pub metadata: HashMap<String, String>,
96}
97
98#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct ExtractedEntity {
101 pub text: String,
103
104 pub entity_type: EntityType,
106
107 pub kb_id: Option<String>,
109
110 pub confidence: f32,
112
113 pub span: TextSpan,
115}
116
117#[derive(Debug, Clone, Serialize, Deserialize)]
119pub enum EntityType {
120 Person,
121 Organization,
122 Location,
123 Date,
124 Time,
125 Money,
126 Percent,
127 Product,
128 Event,
129 Concept,
130 Other(String),
131}
132
133#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct TextSpan {
136 pub start: usize,
138
139 pub end: usize,
141
142 pub text: String,
144}
145
146pub trait NamedEntityRecognizer: Send + Sync {
148 fn extract_entities(&self, text: &str) -> Result<Vec<ExtractedEntity>>;
150
151 fn supported_types(&self) -> Vec<EntityType>;
153}
154
155pub trait RelationClassifier: Send + Sync {
157 fn classify_relation(
159 &self,
160 text: &str,
161 subject: &ExtractedEntity,
162 object: &ExtractedEntity,
163 ) -> Result<Option<(String, f32)>>;
164
165 fn supported_relations(&self) -> Vec<String>;
167}
168
169pub trait EntityLinker: Send + Sync {
171 fn link_entity(&self, entity: &ExtractedEntity, context: &str) -> Result<Option<String>>;
173
174 fn kb_info(&self) -> KnowledgeBaseInfo;
176}
177
178#[derive(Debug, Clone, Serialize, Deserialize)]
180pub struct KnowledgeBaseInfo {
181 pub name: String,
183
184 pub base_uri: String,
186
187 pub version: String,
189
190 pub entity_count: usize,
192}
193
194impl RelationExtractor {
195 pub fn new(_config: &AiConfig) -> Result<Self> {
197 let extraction_config = ExtractionConfig::default();
198
199 let ner_model = Box::new(DummyNER::new());
201
202 let relation_model = Box::new(DummyRelationClassifier::new());
204
205 let entity_linker = Box::new(DummyEntityLinker::new());
207
208 Ok(Self {
209 config: extraction_config,
210 ner_model,
211 relation_model,
212 entity_linker,
213 confidence_threshold: 0.7,
214 })
215 }
216
217 pub async fn extract_relations(&self, text: &str) -> Result<Vec<ExtractedRelation>> {
219 let sentences = self.segment_sentences(text);
221
222 let mut all_relations = Vec::new();
223
224 for sentence in sentences {
225 let entities = if self.config.enable_ner {
227 self.ner_model.extract_entities(&sentence)?
228 } else {
229 Vec::new()
230 };
231
232 let linked_entities = if self.config.enable_entity_linking {
234 self.link_entities(&entities, &sentence).await?
235 } else {
236 entities
237 };
238
239 if self.config.enable_relation_classification {
241 let relations =
242 self.extract_relations_from_entities(&sentence, &linked_entities)?;
243 all_relations.extend(relations);
244 }
245 }
246
247 let filtered_relations = all_relations
249 .into_iter()
250 .filter(|r| r.confidence >= self.confidence_threshold)
251 .collect();
252
253 Ok(filtered_relations)
254 }
255
256 pub fn to_triples(&self, relations: &[ExtractedRelation]) -> Result<Vec<Triple>> {
258 let mut triples = Vec::new();
259
260 for relation in relations {
261 let subject = if let Some(kb_id) = &relation.subject.kb_id {
263 NamedNode::new(kb_id)?
264 } else {
265 NamedNode::new(format!(
267 "http://example.org/entity/{}",
268 relation.subject.text.replace(' ', "_")
269 ))?
270 };
271
272 let predicate = NamedNode::new(format!(
274 "http://example.org/relation/{}",
275 relation.predicate.replace(' ', "_")
276 ))?;
277
278 let object = if let Some(kb_id) = &relation.object.kb_id {
280 crate::model::Object::NamedNode(NamedNode::new(kb_id)?)
281 } else {
282 match relation.object.entity_type {
284 EntityType::Date
285 | EntityType::Time
286 | EntityType::Money
287 | EntityType::Percent => {
288 crate::model::Object::Literal(Literal::new(&relation.object.text))
289 }
290 _ => crate::model::Object::NamedNode(NamedNode::new(format!(
291 "http://example.org/entity/{}",
292 relation.object.text.replace(' ', "_")
293 ))?),
294 }
295 };
296
297 let triple = Triple::new(subject, predicate, object);
298 triples.push(triple);
299 }
300
301 Ok(triples)
302 }
303
304 fn segment_sentences(&self, text: &str) -> Vec<String> {
306 text.split(". ")
308 .map(|s| s.trim().to_string())
309 .filter(|s| !s.is_empty())
310 .collect()
311 }
312
313 async fn link_entities(
315 &self,
316 entities: &[ExtractedEntity],
317 context: &str,
318 ) -> Result<Vec<ExtractedEntity>> {
319 let mut linked_entities = Vec::new();
320
321 for entity in entities {
322 let mut linked_entity = entity.clone();
323
324 if let Ok(Some(kb_id)) = self.entity_linker.link_entity(entity, context) {
325 linked_entity.kb_id = Some(kb_id);
326 }
327
328 linked_entities.push(linked_entity);
329 }
330
331 Ok(linked_entities)
332 }
333
334 fn extract_relations_from_entities(
336 &self,
337 sentence: &str,
338 entities: &[ExtractedEntity],
339 ) -> Result<Vec<ExtractedRelation>> {
340 let mut relations = Vec::new();
341
342 for (i, subject) in entities.iter().enumerate() {
344 for (j, object) in entities.iter().enumerate() {
345 if i != j {
346 if let Ok(Some((relation_type, confidence))) = self
347 .relation_model
348 .classify_relation(sentence, subject, object)
349 {
350 let relation = ExtractedRelation {
351 subject: subject.clone(),
352 predicate: relation_type,
353 object: object.clone(),
354 confidence,
355 source_span: TextSpan {
356 start: 0,
357 end: sentence.len(),
358 text: sentence.to_string(),
359 },
360 context: sentence.to_string(),
361 metadata: HashMap::new(),
362 };
363
364 relations.push(relation);
365 }
366 }
367 }
368 }
369
370 Ok(relations)
371 }
372}
373
/// Capitalisation-based stand-in for a real named-entity recognizer.
struct DummyNER;

impl DummyNER {
    /// Creates the (stateless) recognizer.
    fn new() -> Self {
        DummyNER
    }
}
382
383impl NamedEntityRecognizer for DummyNER {
384 fn extract_entities(&self, text: &str) -> Result<Vec<ExtractedEntity>> {
385 let words: Vec<&str> = text.split_whitespace().collect();
389 let mut entities = Vec::new();
390
391 for (i, word) in words.iter().enumerate() {
392 if word.chars().next().unwrap_or(' ').is_uppercase() {
394 let entity = ExtractedEntity {
395 text: word.to_string(),
396 entity_type: EntityType::Person, kb_id: None,
398 confidence: 0.8,
399 span: TextSpan {
400 start: i * 5, end: (i + 1) * 5,
402 text: word.to_string(),
403 },
404 };
405 entities.push(entity);
406 }
407 }
408
409 Ok(entities)
410 }
411
412 fn supported_types(&self) -> Vec<EntityType> {
413 vec![
414 EntityType::Person,
415 EntityType::Organization,
416 EntityType::Location,
417 ]
418 }
419}
420
/// Keyword-matching stand-in for a trained relation classifier.
struct DummyRelationClassifier;

impl DummyRelationClassifier {
    /// Creates the (stateless) classifier.
    fn new() -> Self {
        DummyRelationClassifier
    }
}
429
430impl RelationClassifier for DummyRelationClassifier {
431 fn classify_relation(
432 &self,
433 text: &str,
434 _subject: &ExtractedEntity,
435 _object: &ExtractedEntity,
436 ) -> Result<Option<(String, f32)>> {
437 if text.contains("work") || text.contains("employ") {
441 Ok(Some(("worksFor".to_string(), 0.85)))
442 } else if text.contains("live") || text.contains("reside") {
443 Ok(Some(("livesIn".to_string(), 0.80)))
444 } else if text.contains("born") || text.contains("birth") {
445 Ok(Some(("bornIn".to_string(), 0.90)))
446 } else {
447 Ok(None)
448 }
449 }
450
451 fn supported_relations(&self) -> Vec<String> {
452 vec![
453 "worksFor".to_string(),
454 "livesIn".to_string(),
455 "bornIn".to_string(),
456 "marriedTo".to_string(),
457 "locatedIn".to_string(),
458 ]
459 }
460}
461
/// Stand-in linker that fabricates DBpedia resource IRIs from mention text.
struct DummyEntityLinker;

impl DummyEntityLinker {
    /// Creates the (stateless) linker.
    fn new() -> Self {
        DummyEntityLinker
    }
}
470
471impl EntityLinker for DummyEntityLinker {
472 fn link_entity(&self, entity: &ExtractedEntity, _context: &str) -> Result<Option<String>> {
473 match entity.entity_type {
477 EntityType::Person => Ok(Some(format!(
478 "http://dbpedia.org/resource/{}",
479 entity.text.replace(' ', "_")
480 ))),
481 EntityType::Location => Ok(Some(format!(
482 "http://dbpedia.org/resource/{}",
483 entity.text.replace(' ', "_")
484 ))),
485 _ => Ok(None),
486 }
487 }
488
489 fn kb_info(&self) -> KnowledgeBaseInfo {
490 KnowledgeBaseInfo {
491 name: "DBpedia".to_string(),
492 base_uri: "http://dbpedia.org/resource/".to_string(),
493 version: "2023-09".to_string(),
494 entity_count: 6_000_000,
495 }
496 }
497}
498
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ai::AiConfig;

    /// Builds an extractor from the default AI configuration.
    fn build_extractor() -> RelationExtractor {
        RelationExtractor::new(&AiConfig::default()).expect("construction should succeed")
    }

    /// Unlinked entity whose span text equals its surface form.
    fn entity(
        text: &str,
        entity_type: EntityType,
        confidence: f32,
        start: usize,
        end: usize,
    ) -> ExtractedEntity {
        ExtractedEntity {
            text: text.to_string(),
            entity_type,
            kb_id: None,
            confidence,
            span: TextSpan {
                start,
                end,
                text: text.to_string(),
            },
        }
    }

    #[tokio::test]
    async fn test_relation_extractor_creation() {
        assert!(RelationExtractor::new(&AiConfig::default()).is_ok());
    }

    #[tokio::test]
    async fn test_relation_extraction() {
        let extractor = build_extractor();

        let relations = extractor
            .extract_relations("John works for Microsoft. He lives in Seattle.")
            .await
            .expect("async operation should succeed");

        assert!(!relations.is_empty());
    }

    #[test]
    fn test_sentence_segmentation() {
        let sentences = build_extractor()
            .segment_sentences("First sentence. Second sentence. Third sentence.");

        assert_eq!(sentences.len(), 3);
        assert_eq!(sentences[0], "First sentence");
    }

    #[test]
    fn test_to_triples() {
        let sentence = "John works for Microsoft.";

        let relation = ExtractedRelation {
            subject: entity("John", EntityType::Person, 0.9, 0, 4),
            predicate: "worksFor".to_string(),
            object: entity("Microsoft", EntityType::Organization, 0.85, 15, 24),
            confidence: 0.8,
            source_span: TextSpan {
                start: 0,
                end: 25,
                text: sentence.to_string(),
            },
            context: sentence.to_string(),
            metadata: HashMap::new(),
        };

        let triples = build_extractor()
            .to_triples(&[relation])
            .expect("operation should succeed");
        assert_eq!(triples.len(), 1);
    }
}