use crate::ai::AiConfig;
use crate::model::{Literal, NamedNode, Triple};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Extracts subject–predicate–object relations from natural-language text
/// using pluggable NER, relation-classification, and entity-linking backends.
pub struct RelationExtractor {
    // Pipeline toggles and tuning parameters.
    config: ExtractionConfig,
    // Named-entity recognition backend.
    ner_model: Box<dyn NamedEntityRecognizer>,
    // Classifies the relation (if any) between an ordered pair of entities.
    relation_model: Box<dyn RelationClassifier>,
    // Resolves entity surface forms to knowledge-base identifiers.
    entity_linker: Box<dyn EntityLinker>,
    // Minimum confidence a relation must reach to be returned.
    // NOTE(review): duplicates `config.confidence_threshold`; keep in sync.
    confidence_threshold: f32,
}
/// Configuration for the relation-extraction pipeline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractionConfig {
    // Run named-entity recognition on each sentence.
    pub enable_ner: bool,
    // Run relation classification over entity pairs.
    pub enable_relation_classification: bool,
    // Attach knowledge-base IDs to recognized entities.
    pub enable_entity_linking: bool,
    // Relations below this confidence are discarded.
    pub confidence_threshold: f32,
    // Upper bound on sentence length, in tokens or characters.
    // NOTE(review): unit is not established by this file — confirm at the
    // model boundary before enforcing it.
    pub max_sentence_length: usize,
    // Identifier of the underlying language model.
    pub language_model: String,
    // Resolve pronouns/coreferent mentions before extraction.
    pub enable_coreference: bool,
    // ISO language codes the pipeline accepts.
    pub supported_languages: Vec<String>,
}
impl Default for ExtractionConfig {
fn default() -> Self {
Self {
enable_ner: true,
enable_relation_classification: true,
enable_entity_linking: true,
confidence_threshold: 0.7,
max_sentence_length: 512,
language_model: "bert-base-uncased".to_string(),
enable_coreference: true,
supported_languages: vec!["en".to_string()],
}
}
}
/// A single relation (subject, predicate, object) extracted from text,
/// together with its provenance and confidence.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractedRelation {
    // The entity acting as the relation's subject.
    pub subject: ExtractedEntity,
    // Relation label, e.g. "worksFor".
    pub predicate: String,
    // The entity acting as the relation's object.
    pub object: ExtractedEntity,
    // Classifier confidence in [0, 1].
    pub confidence: f32,
    // Span of the source text this relation was extracted from.
    pub source_span: TextSpan,
    // The sentence (or wider context) the relation came from.
    pub context: String,
    // Free-form auxiliary data about the extraction.
    pub metadata: HashMap<String, String>,
}
/// A named entity recognized in text, optionally linked to a knowledge base.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractedEntity {
    // The entity's surface form as it appears in the text.
    pub text: String,
    // Coarse-grained entity category.
    pub entity_type: EntityType,
    // Knowledge-base IRI, if entity linking succeeded.
    pub kb_id: Option<String>,
    // Recognizer confidence in [0, 1].
    pub confidence: f32,
    // Where in the source text this entity occurs.
    pub span: TextSpan,
}
/// Coarse-grained entity categories, loosely following common NER tag sets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EntityType {
    Person,
    Organization,
    Location,
    Date,
    Time,
    Money,
    Percent,
    Product,
    Event,
    Concept,
    // Escape hatch for recognizer-specific types not listed above.
    Other(String),
}
/// A contiguous region of source text, with its offsets and extracted text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextSpan {
    // Inclusive start offset into the source text.
    pub start: usize,
    // Exclusive end offset into the source text.
    pub end: usize,
    // The text covered by [start, end).
    pub text: String,
}
/// A named-entity recognition backend.
pub trait NamedEntityRecognizer: Send + Sync {
    /// Extracts all entities found in `text`.
    fn extract_entities(&self, text: &str) -> Result<Vec<ExtractedEntity>>;
    /// The entity types this recognizer can produce.
    fn supported_types(&self) -> Vec<EntityType>;
}
/// A relation-classification backend.
pub trait RelationClassifier: Send + Sync {
    /// Returns the relation label and confidence for the ordered pair
    /// (`subject`, `object`) in `text`, or `None` when no relation holds.
    fn classify_relation(
        &self,
        text: &str,
        subject: &ExtractedEntity,
        object: &ExtractedEntity,
    ) -> Result<Option<(String, f32)>>;
    /// The relation labels this classifier can produce.
    fn supported_relations(&self) -> Vec<String>;
}
/// An entity-linking backend that maps surface forms to knowledge-base IRIs.
pub trait EntityLinker: Send + Sync {
    /// Returns the knowledge-base IRI for `entity`, or `None` if unlinkable.
    fn link_entity(&self, entity: &ExtractedEntity, context: &str) -> Result<Option<String>>;
    /// Metadata about the backing knowledge base.
    fn kb_info(&self) -> KnowledgeBaseInfo;
}
/// Descriptive metadata about a knowledge base used for entity linking.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeBaseInfo {
    // Human-readable knowledge-base name, e.g. "DBpedia".
    pub name: String,
    // IRI prefix under which entities are minted.
    pub base_uri: String,
    // Knowledge-base release/version identifier.
    pub version: String,
    // Approximate number of entities in the knowledge base.
    pub entity_count: usize,
}
impl RelationExtractor {
    /// Builds an extractor with the default [`ExtractionConfig`] and the
    /// built-in dummy NER, relation-classification, and entity-linking
    /// backends.
    ///
    /// `_config` is currently unused; it stays in the signature so callers
    /// do not break once real model loading is implemented.
    pub fn new(_config: &AiConfig) -> Result<Self> {
        let extraction_config = ExtractionConfig::default();
        let ner_model = Box::new(DummyNER::new());
        let relation_model = Box::new(DummyRelationClassifier::new());
        let entity_linker = Box::new(DummyEntityLinker::new());
        // Cache the threshold from the config instead of duplicating the
        // literal 0.7 here — previously the two values could silently drift.
        let confidence_threshold = extraction_config.confidence_threshold;
        Ok(Self {
            config: extraction_config,
            ner_model,
            relation_model,
            entity_linker,
            confidence_threshold,
        })
    }
    /// Runs the full pipeline (sentence segmentation → NER → entity linking
    /// → relation classification) over `text` and returns the relations
    /// whose confidence is at least `self.confidence_threshold`.
    pub async fn extract_relations(&self, text: &str) -> Result<Vec<ExtractedRelation>> {
        let sentences = self.segment_sentences(text);
        let mut all_relations = Vec::new();
        for sentence in sentences {
            // Each stage can be toggled independently through the config.
            let entities = if self.config.enable_ner {
                self.ner_model.extract_entities(&sentence)?
            } else {
                Vec::new()
            };
            let linked_entities = if self.config.enable_entity_linking {
                self.link_entities(&entities, &sentence).await?
            } else {
                entities
            };
            if self.config.enable_relation_classification {
                let relations =
                    self.extract_relations_from_entities(&sentence, &linked_entities)?;
                all_relations.extend(relations);
            }
        }
        // Drop low-confidence candidates before returning.
        let filtered_relations = all_relations
            .into_iter()
            .filter(|r| r.confidence >= self.confidence_threshold)
            .collect();
        Ok(filtered_relations)
    }
    /// Converts extracted relations into RDF triples.
    ///
    /// Linked entities use their knowledge-base IRI; unlinked ones get a
    /// synthetic `http://example.org/entity/...` IRI. Value-like objects
    /// (dates, times, money, percentages) become literals instead of IRIs.
    pub fn to_triples(&self, relations: &[ExtractedRelation]) -> Result<Vec<Triple>> {
        let mut triples = Vec::new();
        for relation in relations {
            let subject = if let Some(kb_id) = &relation.subject.kb_id {
                NamedNode::new(kb_id)?
            } else {
                NamedNode::new(format!(
                    "http://example.org/entity/{}",
                    relation.subject.text.replace(' ', "_")
                ))?
            };
            let predicate = NamedNode::new(format!(
                "http://example.org/relation/{}",
                relation.predicate.replace(' ', "_")
            ))?;
            let object = if let Some(kb_id) = &relation.object.kb_id {
                crate::model::Object::NamedNode(NamedNode::new(kb_id)?)
            } else {
                match relation.object.entity_type {
                    // Value-like entity types map naturally onto literals.
                    EntityType::Date
                    | EntityType::Time
                    | EntityType::Money
                    | EntityType::Percent => {
                        crate::model::Object::Literal(Literal::new(&relation.object.text))
                    }
                    _ => crate::model::Object::NamedNode(NamedNode::new(format!(
                        "http://example.org/entity/{}",
                        relation.object.text.replace(' ', "_")
                    ))?),
                }
            };
            let triple = Triple::new(subject, predicate, object);
            triples.push(triple);
        }
        Ok(triples)
    }
    /// Naive sentence segmentation on the `". "` delimiter.
    ///
    /// NOTE(review): this does not handle abbreviations ("Dr. Smith"), other
    /// terminators ("?", "!"), or non-English punctuation — adequate for the
    /// dummy pipeline only.
    fn segment_sentences(&self, text: &str) -> Vec<String> {
        text.split(". ")
            .map(|s| s.trim().to_string())
            .filter(|s| !s.is_empty())
            .collect()
    }
    /// Attaches knowledge-base IDs to entities where the linker finds a
    /// match.
    ///
    /// Linking is best-effort: a linker error leaves that entity unlinked
    /// rather than failing the whole extraction.
    async fn link_entities(
        &self,
        entities: &[ExtractedEntity],
        context: &str,
    ) -> Result<Vec<ExtractedEntity>> {
        let mut linked_entities = Vec::new();
        for entity in entities {
            let mut linked_entity = entity.clone();
            if let Ok(Some(kb_id)) = self.entity_linker.link_entity(entity, context) {
                linked_entity.kb_id = Some(kb_id);
            }
            linked_entities.push(linked_entity);
        }
        Ok(linked_entities)
    }
    /// Runs the relation classifier over every ordered pair of distinct
    /// entities in `sentence` and materialises its positive answers.
    fn extract_relations_from_entities(
        &self,
        sentence: &str,
        entities: &[ExtractedEntity],
    ) -> Result<Vec<ExtractedRelation>> {
        let mut relations = Vec::new();
        for (i, subject) in entities.iter().enumerate() {
            for (j, object) in entities.iter().enumerate() {
                if i == j {
                    continue;
                }
                // Classifier errors for a single pair are treated as "no
                // relation" so one bad pair cannot abort the sentence.
                if let Ok(Some((relation_type, confidence))) = self
                    .relation_model
                    .classify_relation(sentence, subject, object)
                {
                    relations.push(ExtractedRelation {
                        subject: subject.clone(),
                        predicate: relation_type,
                        object: object.clone(),
                        confidence,
                        // The dummy pipeline records the whole sentence as
                        // the evidence span.
                        source_span: TextSpan {
                            start: 0,
                            end: sentence.len(),
                            text: sentence.to_string(),
                        },
                        context: sentence.to_string(),
                        metadata: HashMap::new(),
                    });
                }
            }
        }
        Ok(relations)
    }
}
// Placeholder NER backend used until a real model is wired in.
struct DummyNER;
impl DummyNER {
    // Stateless, so construction is trivial.
    fn new() -> Self {
        Self
    }
}
impl NamedEntityRecognizer for DummyNER {
    /// Heuristic NER: every whitespace-separated token starting with an
    /// uppercase character is emitted as a `Person` entity with a fixed 0.8
    /// confidence.
    ///
    /// Spans now carry the token's real byte offsets in `text` (previously
    /// they were fabricated as `i * 5 .. (i + 1) * 5`, which matched no
    /// actual position in the input).
    fn extract_entities(&self, text: &str) -> Result<Vec<ExtractedEntity>> {
        let mut entities = Vec::new();
        // Track where the previous token ended so a repeated word resolves
        // to its own occurrence rather than the first match in the text.
        let mut search_from = 0;
        for word in text.split_whitespace() {
            let start = text[search_from..]
                .find(word)
                .map(|p| p + search_from)
                .unwrap_or(search_from);
            let end = start + word.len();
            search_from = end;
            if word.chars().next().unwrap_or(' ').is_uppercase() {
                entities.push(ExtractedEntity {
                    text: word.to_string(),
                    entity_type: EntityType::Person,
                    kb_id: None,
                    confidence: 0.8,
                    span: TextSpan {
                        start,
                        end,
                        text: word.to_string(),
                    },
                });
            }
        }
        Ok(entities)
    }
    /// Entity types this dummy recognizer claims to support.
    fn supported_types(&self) -> Vec<EntityType> {
        vec![
            EntityType::Person,
            EntityType::Organization,
            EntityType::Location,
        ]
    }
}
// Placeholder relation classifier based on keyword matching.
struct DummyRelationClassifier;
impl DummyRelationClassifier {
    // Stateless, so construction is trivial.
    fn new() -> Self {
        Self
    }
}
impl RelationClassifier for DummyRelationClassifier {
    /// Keyword-based classification: the first rule whose trigger words
    /// appear anywhere in `text` wins; otherwise no relation is reported.
    /// The entity arguments are ignored by this dummy implementation.
    fn classify_relation(
        &self,
        text: &str,
        _subject: &ExtractedEntity,
        _object: &ExtractedEntity,
    ) -> Result<Option<(String, f32)>> {
        // (trigger a, trigger b, relation label, confidence) — checked in
        // declaration order, mirroring the original if/else-if chain.
        let rules: [(&str, &str, &str, f32); 3] = [
            ("work", "employ", "worksFor", 0.85),
            ("live", "reside", "livesIn", 0.80),
            ("born", "birth", "bornIn", 0.90),
        ];
        let hit = rules
            .iter()
            .find(|&&(a, b, _, _)| text.contains(a) || text.contains(b))
            .map(|&(_, _, label, confidence)| (label.to_string(), confidence));
        Ok(hit)
    }
    /// Relation labels this classifier can in principle produce.
    fn supported_relations(&self) -> Vec<String> {
        ["worksFor", "livesIn", "bornIn", "marriedTo", "locatedIn"]
            .iter()
            .map(|label| label.to_string())
            .collect()
    }
}
// Placeholder entity linker that mints DBpedia-style IRIs.
struct DummyEntityLinker;
impl DummyEntityLinker {
    // Stateless, so construction is trivial.
    fn new() -> Self {
        Self
    }
}
impl EntityLinker for DummyEntityLinker {
    /// Links people and locations to synthetic DBpedia IRIs derived from the
    /// surface form; every other entity type is left unlinked. The context
    /// argument is ignored by this dummy implementation.
    fn link_entity(&self, entity: &ExtractedEntity, _context: &str) -> Result<Option<String>> {
        // Person and Location previously had two byte-identical arms; they
        // now share one arm so the IRI template cannot drift between them.
        match entity.entity_type {
            EntityType::Person | EntityType::Location => Ok(Some(format!(
                "http://dbpedia.org/resource/{}",
                entity.text.replace(' ', "_")
            ))),
            _ => Ok(None),
        }
    }
    /// Static description of the backing knowledge base.
    fn kb_info(&self) -> KnowledgeBaseInfo {
        KnowledgeBaseInfo {
            name: "DBpedia".to_string(),
            base_uri: "http://dbpedia.org/resource/".to_string(),
            version: "2023-09".to_string(),
            entity_count: 6_000_000,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ai::AiConfig;
    // Construction should succeed with a default AiConfig (dummy backends).
    #[tokio::test]
    async fn test_relation_extractor_creation() {
        let config = AiConfig::default();
        let extractor = RelationExtractor::new(&config);
        assert!(extractor.is_ok());
    }
    // End-to-end smoke test: the keyword classifier should fire on "works"
    // and "lives", so at least one relation must survive the threshold.
    #[tokio::test]
    async fn test_relation_extraction() {
        let config = AiConfig::default();
        let extractor = RelationExtractor::new(&config).expect("construction should succeed");
        let text = "John works for Microsoft. He lives in Seattle.";
        let relations = extractor
            .extract_relations(text)
            .await
            .expect("async operation should succeed");
        assert!(!relations.is_empty());
    }
    // Splitting on ". " yields three sentences; the first loses its period.
    #[test]
    fn test_sentence_segmentation() {
        let config = AiConfig::default();
        let extractor = RelationExtractor::new(&config).expect("construction should succeed");
        let text = "First sentence. Second sentence. Third sentence.";
        let sentences = extractor.segment_sentences(text);
        assert_eq!(sentences.len(), 3);
        assert_eq!(sentences[0], "First sentence");
    }
    // A hand-built relation with unlinked entities should map to exactly one
    // triple using synthetic example.org IRIs.
    #[test]
    fn test_to_triples() {
        let config = AiConfig::default();
        let extractor = RelationExtractor::new(&config).expect("construction should succeed");
        let relation = ExtractedRelation {
            subject: ExtractedEntity {
                text: "John".to_string(),
                entity_type: EntityType::Person,
                kb_id: None,
                confidence: 0.9,
                span: TextSpan {
                    start: 0,
                    end: 4,
                    text: "John".to_string(),
                },
            },
            predicate: "worksFor".to_string(),
            object: ExtractedEntity {
                text: "Microsoft".to_string(),
                entity_type: EntityType::Organization,
                kb_id: None,
                confidence: 0.85,
                span: TextSpan {
                    start: 15,
                    end: 24,
                    text: "Microsoft".to_string(),
                },
            },
            confidence: 0.8,
            source_span: TextSpan {
                start: 0,
                end: 25,
                text: "John works for Microsoft.".to_string(),
            },
            context: "John works for Microsoft.".to_string(),
            metadata: HashMap::new(),
        };
        let triples = extractor
            .to_triples(&[relation])
            .expect("operation should succeed");
        assert_eq!(triples.len(), 1);
    }
}