use anyhow::{anyhow, Result};
use rayon::prelude::*;
use scirs2_core::ndarray_ext::{Array1, Array2};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{debug, info};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EntityLinkerConfig {
pub similarity_threshold: f32,
pub max_candidates: usize,
pub use_context: bool,
pub min_confidence: f32,
pub use_ann: bool,
pub k_neighbors: usize,
}
impl Default for EntityLinkerConfig {
fn default() -> Self {
Self {
similarity_threshold: 0.7,
max_candidates: 10,
use_context: true,
min_confidence: 0.5,
use_ann: true,
k_neighbors: 50,
}
}
}
/// One candidate entity returned by `EntityLinker::link_entity`.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct LinkingResult {
    /// Identifier of the linked entity (a key of the map passed to `EntityLinker::new`).
    pub entity_id: String,
    /// Final score compared against `min_confidence`; equals `similarity` when no context
    /// reranking was applied.
    pub confidence: f32,
    /// Cosine similarity between the mention embedding and the entity embedding.
    pub similarity: f32,
    /// Tags describing which signals contributed: `"context_aware"` after context
    /// reranking, otherwise empty.
    pub context_features: Vec<String>,
}
/// Links mention embeddings to a fixed set of known entities by cosine similarity.
pub struct EntityLinker {
    /// Dense matrix with one row per entity; row `i` corresponds to `entity_index[i]`.
    embedding_matrix: Array2<f32>,
    /// Entity ids in matrix-row order.
    entity_index: Vec<String>,
    /// The embeddings as supplied, keyed by entity id (served by `get_embedding`).
    entity_embeddings: Arc<HashMap<String, Array1<f32>>>,
    /// Thresholds and switches that control linking.
    config: EntityLinkerConfig,
}
impl EntityLinker {
    /// Weight of the mention/entity cosine similarity in the context-aware confidence.
    const MENTION_WEIGHT: f32 = 0.7;
    /// Weight of the mean entity/context cosine similarity in the context-aware confidence.
    const CONTEXT_WEIGHT: f32 = 0.3;

    /// Builds a linker over a fixed set of entity embeddings.
    ///
    /// The embeddings are packed into a dense matrix with one row per entity; row `i`
    /// belongs to `entity_index[i]`. Each row is L2-normalised once here, so per-query cosine
    /// similarity is a single dot product. All-zero embeddings remain all-zero rows and
    /// therefore score 0 against every query. Row order follows `HashMap` iteration order.
    ///
    /// # Errors
    ///
    /// Returns an error if `entity_embeddings` is empty, or if the embeddings do not all
    /// have the same dimension.
    pub fn new(
        config: EntityLinkerConfig,
        entity_embeddings: HashMap<String, Array1<f32>>,
    ) -> Result<Self> {
        let entity_count = entity_embeddings.len();
        let embedding_dim = entity_embeddings
            .values()
            .next()
            .map(|embedding| embedding.len())
            .ok_or_else(|| anyhow!("Empty entity embedding set"))?;

        let mut entity_index = Vec::with_capacity(entity_count);
        let mut embedding_matrix = Array2::<f32>::zeros((entity_count, embedding_dim));
        for (idx, (entity_id, embedding)) in entity_embeddings.iter().enumerate() {
            // A mismatched row cannot be stored in the matrix; report it instead of panicking.
            if embedding.len() != embedding_dim {
                return Err(anyhow!(
                    "Entity {} has embedding dimension {}, expected {}",
                    entity_id,
                    embedding.len(),
                    embedding_dim
                ));
            }
            let norm = embedding.dot(embedding).sqrt();
            // Zero vectors cannot be normalised; their row stays zero so they score 0.
            if norm != 0.0 {
                embedding_matrix.row_mut(idx).assign(&(embedding / norm));
            }
            entity_index.push(entity_id.clone());
        }
        info!(
            "Initialized EntityLinker with {} entities, dim={}",
            entity_count, embedding_dim
        );
        Ok(Self {
            config,
            entity_embeddings: Arc::new(entity_embeddings),
            entity_index,
            embedding_matrix,
        })
    }

    /// Links a mention embedding to the most similar known entities.
    ///
    /// Candidates are the `max_candidates` entities with the highest cosine similarity to
    /// `mention_embedding`.
    /// - If `context_embeddings` is `Some` and `config.use_context` is set, candidates are
    ///   rescored by blending their similarity with the mean context similarity.
    /// - Otherwise, candidates below `similarity_threshold` are dropped and confidence
    ///   equals similarity.
    ///
    /// Results with confidence below `min_confidence` are then removed. The returned list is
    /// ordered by descending confidence.
    ///
    /// # Errors
    ///
    /// Returns an error if the mention embedding has zero norm, or if the mention (or any
    /// context) embedding does not match the entity embedding dimension.
    pub fn link_entity(
        &self,
        mention_embedding: &Array1<f32>,
        context_embeddings: Option<&[Array1<f32>]>,
    ) -> Result<Vec<LinkingResult>> {
        let similarities = self.compute_similarities(mention_embedding)?;
        let mut candidates: Vec<(usize, f32)> = similarities.into_iter().enumerate().collect();
        candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        candidates.truncate(self.config.max_candidates);

        let results = match context_embeddings.filter(|_| self.config.use_context) {
            Some(ctx_emb) => self.rerank_with_context(&candidates, ctx_emb)?,
            None => candidates
                .into_iter()
                .filter(|&(_, sim)| sim >= self.config.similarity_threshold)
                .map(|(idx, sim)| LinkingResult {
                    entity_id: self.entity_index[idx].clone(),
                    confidence: sim,
                    similarity: sim,
                    context_features: vec![],
                })
                .collect(),
        };
        let filtered: Vec<LinkingResult> = results
            .into_iter()
            .filter(|r| r.confidence >= self.config.min_confidence)
            .collect();
        debug!("Linked {} candidate entities", filtered.len());
        Ok(filtered)
    }

    /// Links each mention in parallel, without context.
    ///
    /// Best-effort: a mention that fails to link (for example, a zero-norm or wrong-dimension
    /// vector) yields an empty list at its position instead of failing the whole batch. The
    /// failure is logged at debug level. The outer `Result` is currently always `Ok`.
    pub fn link_entities_batch(
        &self,
        mention_embeddings: &[Array1<f32>],
    ) -> Result<Vec<Vec<LinkingResult>>> {
        let results: Vec<Vec<LinkingResult>> = mention_embeddings
            .par_iter()
            .enumerate()
            .map(|(position, mention)| {
                self.link_entity(mention, None).unwrap_or_else(|err| {
                    debug!("Failed to link mention {}: {}", position, err);
                    Vec::new()
                })
            })
            .collect();
        Ok(results)
    }

    /// Cosine similarity between `query` and every entity, in matrix-row order.
    ///
    /// Relies on `new` having L2-normalised the rows, so only the query is normalised here.
    fn compute_similarities(&self, query: &Array1<f32>) -> Result<Vec<f32>> {
        let dim = self.embedding_matrix.ncols();
        if query.len() != dim {
            return Err(anyhow!(
                "Query dimension {} does not match entity dimension {}",
                query.len(),
                dim
            ));
        }
        let query_norm = query.dot(query).sqrt();
        if query_norm == 0.0 {
            return Err(anyhow!("Zero-norm query vector"));
        }
        let normalized_query = query / query_norm;
        let similarities: Vec<f32> = (0..self.embedding_matrix.nrows())
            .into_par_iter()
            .map(|i| self.embedding_matrix.row(i).dot(&normalized_query))
            .collect();
        Ok(similarities)
    }

    /// Rescores `candidates` (row index, similarity) using the context embeddings.
    ///
    /// `confidence = MENTION_WEIGHT * similarity + CONTEXT_WEIGHT * mean context similarity`.
    /// The output is re-sorted by descending confidence, so the reranking is visible to
    /// callers. The sort is stable, so ties keep the incoming similarity order.
    fn rerank_with_context(
        &self,
        candidates: &[(usize, f32)],
        context_embeddings: &[Array1<f32>],
    ) -> Result<Vec<LinkingResult>> {
        let dim = self.embedding_matrix.ncols();
        if let Some(bad) = context_embeddings.iter().find(|ctx| ctx.len() != dim) {
            return Err(anyhow!(
                "Context embedding dimension {} does not match entity dimension {}",
                bad.len(),
                dim
            ));
        }
        let mut results: Vec<LinkingResult> = candidates
            .iter()
            .map(|&(idx, base_sim)| {
                let context_sim = self.compute_context_similarity(idx, context_embeddings);
                LinkingResult {
                    entity_id: self.entity_index[idx].clone(),
                    confidence: Self::MENTION_WEIGHT * base_sim
                        + Self::CONTEXT_WEIGHT * context_sim,
                    similarity: base_sim,
                    context_features: vec!["context_aware".to_string()],
                }
            })
            .collect();
        results.sort_by(|a, b| {
            b.confidence
                .partial_cmp(&a.confidence)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        Ok(results)
    }

    /// Mean cosine similarity between entity row `idx` and each context embedding.
    ///
    /// Returns 0.0 for an empty context. A pair involving a zero vector contributes 0.0.
    /// Callers must ensure every context embedding has the entity dimension.
    fn compute_context_similarity(&self, idx: usize, context_embeddings: &[Array1<f32>]) -> f32 {
        if context_embeddings.is_empty() {
            return 0.0;
        }
        let entity = self.embedding_matrix.row(idx);
        // The entity norm does not depend on the context vector, so compute it once.
        let entity_norm = entity.dot(&entity).sqrt();
        let total_sim: f32 = context_embeddings
            .iter()
            .map(|ctx| {
                let ctx_norm = ctx.dot(ctx).sqrt();
                if entity_norm == 0.0 || ctx_norm == 0.0 {
                    0.0
                } else {
                    entity.dot(ctx) / (entity_norm * ctx_norm)
                }
            })
            .sum();
        total_sim / context_embeddings.len() as f32
    }

    /// Returns the embedding originally supplied for `entity_id`, if known.
    pub fn get_embedding(&self, entity_id: &str) -> Option<&Array1<f32>> {
        self.entity_embeddings.get(entity_id)
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelationPredictorConfig {
pub score_threshold: f32,
pub max_predictions: usize,
pub use_type_constraints: bool,
pub use_path_reasoning: bool,
}
impl Default for RelationPredictorConfig {
fn default() -> Self {
Self {
score_threshold: 0.6,
max_predictions: 10,
use_type_constraints: true,
use_path_reasoning: false,
}
}
}
/// A scored relation (and optionally tail entity) produced by `RelationPredictor`.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct RelationPrediction {
    /// Name of the predicted relation.
    pub relation: String,
    /// Tail entity the score refers to; always `Some` for predictions made in this module.
    pub tail_entity: Option<String>,
    /// Plausibility score in (0, 1]; higher is more plausible.
    pub score: f32,
    /// Currently identical to `score`.
    pub confidence: f32,
}
/// Predicts relations and tail entities from translation-style (head + relation ≈ tail)
/// embeddings.
pub struct RelationPredictor {
    /// Embedding per entity id.
    entity_embeddings: Arc<HashMap<String, Array1<f32>>>,
    /// Embedding per relation name.
    relation_embeddings: Arc<HashMap<String, Array1<f32>>>,
    /// Thresholds and limits applied to predictions.
    config: RelationPredictorConfig,
}
impl RelationPredictor {
pub fn new(
config: RelationPredictorConfig,
relation_embeddings: HashMap<String, Array1<f32>>,
entity_embeddings: HashMap<String, Array1<f32>>,
) -> Self {
info!(
"Initialized RelationPredictor with {} relations, {} entities",
relation_embeddings.len(),
entity_embeddings.len()
);
Self {
config,
relation_embeddings: Arc::new(relation_embeddings),
entity_embeddings: Arc::new(entity_embeddings),
}
}
pub fn predict_relations(
&self,
head_entity: &str,
tail_entity: &str,
) -> Result<Vec<RelationPrediction>> {
let head_emb = self
.entity_embeddings
.get(head_entity)
.ok_or_else(|| anyhow!("Unknown head entity: {}", head_entity))?;
let tail_emb = self
.entity_embeddings
.get(tail_entity)
.ok_or_else(|| anyhow!("Unknown tail entity: {}", tail_entity))?;
let mut predictions: Vec<RelationPrediction> = self
.relation_embeddings
.par_iter()
.map(|(rel, rel_emb)| {
let score = self.score_triple(head_emb, rel_emb, tail_emb);
RelationPrediction {
relation: rel.clone(),
tail_entity: Some(tail_entity.to_string()),
score,
confidence: score,
}
})
.filter(|pred| pred.score >= self.config.score_threshold)
.collect();
predictions.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
predictions.truncate(self.config.max_predictions);
Ok(predictions)
}
pub fn predict_tails(
&self,
head_entity: &str,
relation: &str,
) -> Result<Vec<RelationPrediction>> {
let head_emb = self
.entity_embeddings
.get(head_entity)
.ok_or_else(|| anyhow!("Unknown head entity: {}", head_entity))?;
let rel_emb = self
.relation_embeddings
.get(relation)
.ok_or_else(|| anyhow!("Unknown relation: {}", relation))?;
let expected_tail = head_emb + rel_emb;
let mut predictions: Vec<RelationPrediction> = self
.entity_embeddings
.par_iter()
.map(|(entity, entity_emb)| {
let distance = Self::euclidean_distance(&expected_tail, entity_emb);
let score = 1.0 / (1.0 + distance);
RelationPrediction {
relation: relation.to_string(),
tail_entity: Some(entity.clone()),
score,
confidence: score,
}
})
.filter(|pred| pred.score >= self.config.score_threshold)
.collect();
predictions.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
predictions.truncate(self.config.max_predictions);
Ok(predictions)
}
fn score_triple(&self, head: &Array1<f32>, relation: &Array1<f32>, tail: &Array1<f32>) -> f32 {
let expected_tail = head + relation;
let distance = Self::euclidean_distance(&expected_tail, tail);
1.0 / (1.0 + distance)
}
fn euclidean_distance(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
let diff = a - b;
diff.dot(&diff).sqrt()
}
pub fn predict_tails_batch(
&self,
queries: &[(String, String)], ) -> Result<Vec<Vec<RelationPrediction>>> {
let results: Vec<Vec<RelationPrediction>> = queries
.par_iter()
.map(|(head, rel)| self.predict_tails(head, rel).unwrap_or_default())
.collect();
Ok(results)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    /// Builds an id -> embedding map from borrowed ids.
    fn embedding_map(entries: Vec<(&str, Array1<f32>)>) -> HashMap<String, Array1<f32>> {
        entries
            .into_iter()
            .map(|(id, embedding)| (id.to_string(), embedding))
            .collect()
    }

    #[test]
    fn test_entity_linker_creation() {
        let embeddings = embedding_map(vec![
            ("entity1", array![0.1, 0.2, 0.3]),
            ("entity2", array![0.4, 0.5, 0.6]),
        ]);
        let built = EntityLinker::new(EntityLinkerConfig::default(), embeddings);
        assert!(built.is_ok());
    }

    #[test]
    fn test_entity_linking() {
        let embeddings = embedding_map(vec![
            ("entity1", array![1.0, 0.0, 0.0]),
            ("entity2", array![0.0, 1.0, 0.0]),
            ("entity3", array![0.7, 0.7, 0.0]),
        ]);
        let linker = EntityLinker::new(
            EntityLinkerConfig {
                similarity_threshold: 0.5,
                ..Default::default()
            },
            embeddings,
        )
        .expect("should succeed");
        let mention = array![0.9, 0.1, 0.0];
        let linked = linker.link_entity(&mention, None).expect("should succeed");
        assert!(!linked.is_empty());
        assert_eq!(linked[0].entity_id, "entity1");
    }

    #[test]
    fn test_relation_predictor_creation() {
        let entity_embeddings = embedding_map(vec![("entity1", array![0.1, 0.2, 0.3])]);
        let relation_embeddings = embedding_map(vec![("rel1", array![0.1, 0.1, 0.1])]);
        let predictor = RelationPredictor::new(
            RelationPredictorConfig::default(),
            relation_embeddings,
            entity_embeddings,
        );
        assert_eq!(predictor.relation_embeddings.len(), 1);
    }

    #[test]
    fn test_batch_entity_linking() {
        let embeddings = embedding_map(vec![
            ("entity1", array![1.0, 0.0, 0.0]),
            ("entity2", array![0.0, 1.0, 0.0]),
        ]);
        let linker = EntityLinker::new(EntityLinkerConfig::default(), embeddings)
            .expect("should succeed");
        let mentions = vec![array![0.9, 0.1, 0.0], array![0.1, 0.9, 0.0]];
        let batch = linker
            .link_entities_batch(&mentions)
            .expect("should succeed");
        assert_eq!(batch.len(), 2);
    }
}