1use anyhow::{anyhow, Result};
7use rayon::prelude::*;
8use scirs2_core::ndarray_ext::{Array1, Array2};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use tracing::{debug, info};
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct EntityLinkerConfig {
17 pub similarity_threshold: f32,
19 pub max_candidates: usize,
21 pub use_context: bool,
23 pub min_confidence: f32,
25 pub use_ann: bool,
27 pub k_neighbors: usize,
29}
30
31impl Default for EntityLinkerConfig {
32 fn default() -> Self {
33 Self {
34 similarity_threshold: 0.7,
35 max_candidates: 10,
36 use_context: true,
37 min_confidence: 0.5,
38 use_ann: true,
39 k_neighbors: 50,
40 }
41 }
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct LinkingResult {
47 pub entity_id: String,
49 pub confidence: f32,
51 pub similarity: f32,
53 pub context_features: Vec<String>,
55}
56
57pub struct EntityLinker {
59 config: EntityLinkerConfig,
60 entity_embeddings: Arc<HashMap<String, Array1<f32>>>,
61 entity_index: Vec<String>,
62 embedding_matrix: Array2<f32>,
63}
64
65impl EntityLinker {
66 pub fn new(
68 config: EntityLinkerConfig,
69 entity_embeddings: HashMap<String, Array1<f32>>,
70 ) -> Result<Self> {
71 let entity_count = entity_embeddings.len();
72 if entity_count == 0 {
73 return Err(anyhow!("Empty entity embedding set"));
74 }
75
76 let mut entity_index = Vec::with_capacity(entity_count);
78 let embedding_dim = entity_embeddings
79 .values()
80 .next()
81 .expect("entity_embeddings should not be empty")
82 .len();
83 let mut embedding_matrix = Array2::zeros((entity_count, embedding_dim));
84
85 for (idx, (entity_id, embedding)) in entity_embeddings.iter().enumerate() {
86 entity_index.push(entity_id.clone());
87 embedding_matrix.row_mut(idx).assign(embedding);
88 }
89
90 info!(
91 "Initialized EntityLinker with {} entities, dim={}",
92 entity_count, embedding_dim
93 );
94
95 Ok(Self {
96 config,
97 entity_embeddings: Arc::new(entity_embeddings),
98 entity_index,
99 embedding_matrix,
100 })
101 }
102
103 pub fn link_entity(
105 &self,
106 mention_embedding: &Array1<f32>,
107 context_embeddings: Option<&[Array1<f32>]>,
108 ) -> Result<Vec<LinkingResult>> {
109 let similarities = self.compute_similarities(mention_embedding)?;
111
112 let mut candidates: Vec<(usize, f32)> = similarities
114 .iter()
115 .enumerate()
116 .map(|(idx, &sim)| (idx, sim))
117 .collect();
118
119 candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
120 candidates.truncate(self.config.max_candidates);
121
122 let results = if let Some(ctx_emb) = context_embeddings.filter(|_| self.config.use_context)
124 {
125 self.rerank_with_context(&candidates, ctx_emb)?
126 } else {
127 candidates
128 .into_iter()
129 .filter(|(_, sim)| *sim >= self.config.similarity_threshold)
130 .map(|(idx, sim)| LinkingResult {
131 entity_id: self.entity_index[idx].clone(),
132 confidence: sim,
133 similarity: sim,
134 context_features: vec![],
135 })
136 .collect()
137 };
138
139 let filtered: Vec<_> = results
141 .into_iter()
142 .filter(|r| r.confidence >= self.config.min_confidence)
143 .collect();
144
145 debug!("Linked {} candidate entities", filtered.len());
146
147 Ok(filtered)
148 }
149
150 pub fn link_entities_batch(
152 &self,
153 mention_embeddings: &[Array1<f32>],
154 ) -> Result<Vec<Vec<LinkingResult>>> {
155 let results: Vec<Vec<LinkingResult>> = mention_embeddings
157 .par_iter()
158 .map(|mention| self.link_entity(mention, None).unwrap_or_default())
159 .collect();
160
161 Ok(results)
162 }
163
164 fn compute_similarities(&self, query: &Array1<f32>) -> Result<Vec<f32>> {
166 let query_norm = query.dot(query).sqrt();
168 if query_norm == 0.0 {
169 return Err(anyhow!("Zero-norm query vector"));
170 }
171
172 let normalized_query = query / query_norm;
173
174 let similarities: Vec<f32> = (0..self.embedding_matrix.nrows())
176 .into_par_iter()
177 .map(|i| {
178 let entity_emb = self.embedding_matrix.row(i);
179 let entity_norm = entity_emb.dot(&entity_emb).sqrt();
180
181 if entity_norm == 0.0 {
182 0.0
183 } else {
184 let normalized_entity = entity_emb.to_owned() / entity_norm;
185 normalized_query.dot(&normalized_entity)
186 }
187 })
188 .collect();
189
190 Ok(similarities)
191 }
192
193 fn rerank_with_context(
195 &self,
196 candidates: &[(usize, f32)],
197 context_embeddings: &[Array1<f32>],
198 ) -> Result<Vec<LinkingResult>> {
199 let results: Vec<LinkingResult> = candidates
200 .iter()
201 .map(|(idx, base_sim)| {
202 let entity_embedding = self.embedding_matrix.row(*idx);
203
204 let context_sim = self
206 .compute_context_similarity(&entity_embedding.to_owned(), context_embeddings);
207
208 let confidence = 0.7 * base_sim + 0.3 * context_sim;
210
211 LinkingResult {
212 entity_id: self.entity_index[*idx].clone(),
213 confidence,
214 similarity: *base_sim,
215 context_features: vec!["context_aware".to_string()],
216 }
217 })
218 .collect();
219
220 Ok(results)
221 }
222
223 fn compute_context_similarity(
225 &self,
226 entity_embedding: &Array1<f32>,
227 context_embeddings: &[Array1<f32>],
228 ) -> f32 {
229 if context_embeddings.is_empty() {
230 return 0.0;
231 }
232
233 let total_sim: f32 = context_embeddings
235 .iter()
236 .map(|ctx| {
237 let norm1 = entity_embedding.dot(entity_embedding).sqrt();
238 let norm2 = ctx.dot(ctx).sqrt();
239
240 if norm1 == 0.0 || norm2 == 0.0 {
241 0.0
242 } else {
243 entity_embedding.dot(ctx) / (norm1 * norm2)
244 }
245 })
246 .sum();
247
248 total_sim / context_embeddings.len() as f32
249 }
250
251 pub fn get_embedding(&self, entity_id: &str) -> Option<&Array1<f32>> {
253 self.entity_embeddings.get(entity_id)
254 }
255}
256
257#[derive(Debug, Clone, Serialize, Deserialize)]
259pub struct RelationPredictorConfig {
260 pub score_threshold: f32,
262 pub max_predictions: usize,
264 pub use_type_constraints: bool,
266 pub use_path_reasoning: bool,
268}
269
270impl Default for RelationPredictorConfig {
271 fn default() -> Self {
272 Self {
273 score_threshold: 0.6,
274 max_predictions: 10,
275 use_type_constraints: true,
276 use_path_reasoning: false,
277 }
278 }
279}
280
281#[derive(Debug, Clone, Serialize, Deserialize)]
283pub struct RelationPrediction {
284 pub relation: String,
286 pub tail_entity: Option<String>,
288 pub score: f32,
290 pub confidence: f32,
292}
293
294pub struct RelationPredictor {
296 config: RelationPredictorConfig,
297 relation_embeddings: Arc<HashMap<String, Array1<f32>>>,
298 entity_embeddings: Arc<HashMap<String, Array1<f32>>>,
299}
300
301impl RelationPredictor {
302 pub fn new(
304 config: RelationPredictorConfig,
305 relation_embeddings: HashMap<String, Array1<f32>>,
306 entity_embeddings: HashMap<String, Array1<f32>>,
307 ) -> Self {
308 info!(
309 "Initialized RelationPredictor with {} relations, {} entities",
310 relation_embeddings.len(),
311 entity_embeddings.len()
312 );
313
314 Self {
315 config,
316 relation_embeddings: Arc::new(relation_embeddings),
317 entity_embeddings: Arc::new(entity_embeddings),
318 }
319 }
320
321 pub fn predict_relations(
323 &self,
324 head_entity: &str,
325 tail_entity: &str,
326 ) -> Result<Vec<RelationPrediction>> {
327 let head_emb = self
328 .entity_embeddings
329 .get(head_entity)
330 .ok_or_else(|| anyhow!("Unknown head entity: {}", head_entity))?;
331
332 let tail_emb = self
333 .entity_embeddings
334 .get(tail_entity)
335 .ok_or_else(|| anyhow!("Unknown tail entity: {}", tail_entity))?;
336
337 let mut predictions: Vec<RelationPrediction> = self
339 .relation_embeddings
340 .par_iter()
341 .map(|(rel, rel_emb)| {
342 let score = self.score_triple(head_emb, rel_emb, tail_emb);
344
345 RelationPrediction {
346 relation: rel.clone(),
347 tail_entity: Some(tail_entity.to_string()),
348 score,
349 confidence: score,
350 }
351 })
352 .filter(|pred| pred.score >= self.config.score_threshold)
353 .collect();
354
355 predictions.sort_by(|a, b| {
357 b.score
358 .partial_cmp(&a.score)
359 .unwrap_or(std::cmp::Ordering::Equal)
360 });
361 predictions.truncate(self.config.max_predictions);
362
363 Ok(predictions)
364 }
365
366 pub fn predict_tails(
368 &self,
369 head_entity: &str,
370 relation: &str,
371 ) -> Result<Vec<RelationPrediction>> {
372 let head_emb = self
373 .entity_embeddings
374 .get(head_entity)
375 .ok_or_else(|| anyhow!("Unknown head entity: {}", head_entity))?;
376
377 let rel_emb = self
378 .relation_embeddings
379 .get(relation)
380 .ok_or_else(|| anyhow!("Unknown relation: {}", relation))?;
381
382 let expected_tail = head_emb + rel_emb;
384
385 let mut predictions: Vec<RelationPrediction> = self
387 .entity_embeddings
388 .par_iter()
389 .map(|(entity, entity_emb)| {
390 let distance = Self::euclidean_distance(&expected_tail, entity_emb);
391 let score = 1.0 / (1.0 + distance); RelationPrediction {
394 relation: relation.to_string(),
395 tail_entity: Some(entity.clone()),
396 score,
397 confidence: score,
398 }
399 })
400 .filter(|pred| pred.score >= self.config.score_threshold)
401 .collect();
402
403 predictions.sort_by(|a, b| {
404 b.score
405 .partial_cmp(&a.score)
406 .unwrap_or(std::cmp::Ordering::Equal)
407 });
408 predictions.truncate(self.config.max_predictions);
409
410 Ok(predictions)
411 }
412
413 fn score_triple(&self, head: &Array1<f32>, relation: &Array1<f32>, tail: &Array1<f32>) -> f32 {
415 let expected_tail = head + relation;
417 let distance = Self::euclidean_distance(&expected_tail, tail);
418
419 1.0 / (1.0 + distance)
421 }
422
423 fn euclidean_distance(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
425 let diff = a - b;
426 diff.dot(&diff).sqrt()
427 }
428
429 pub fn predict_tails_batch(
431 &self,
432 queries: &[(String, String)], ) -> Result<Vec<Vec<RelationPrediction>>> {
434 let results: Vec<Vec<RelationPrediction>> = queries
435 .par_iter()
436 .map(|(head, rel)| self.predict_tails(head, rel).unwrap_or_default())
437 .collect();
438
439 Ok(results)
440 }
441}
442
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    /// A linker can be built from a non-empty embedding table.
    #[test]
    fn test_entity_linker_creation() {
        let table = HashMap::from([
            ("entity1".to_string(), array![0.1, 0.2, 0.3]),
            ("entity2".to_string(), array![0.4, 0.5, 0.6]),
        ]);

        let outcome = EntityLinker::new(EntityLinkerConfig::default(), table);
        assert!(outcome.is_ok());
    }

    /// The entity closest to the mention is ranked first.
    #[test]
    fn test_entity_linking() {
        let table = HashMap::from([
            ("entity1".to_string(), array![1.0, 0.0, 0.0]),
            ("entity2".to_string(), array![0.0, 1.0, 0.0]),
            ("entity3".to_string(), array![0.7, 0.7, 0.0]),
        ]);

        let settings = EntityLinkerConfig {
            similarity_threshold: 0.5,
            ..Default::default()
        };
        let linker = EntityLinker::new(settings, table).unwrap();

        // The mention points almost exactly along entity1's axis.
        let mention = array![0.9, 0.1, 0.0];
        let linked = linker.link_entity(&mention, None).unwrap();

        assert!(!linked.is_empty());
        assert_eq!(linked[0].entity_id, "entity1");
    }

    /// The predictor keeps the relation table it was constructed with.
    #[test]
    fn test_relation_predictor_creation() {
        let entities = HashMap::from([("entity1".to_string(), array![0.1, 0.2, 0.3])]);
        let relations = HashMap::from([("rel1".to_string(), array![0.1, 0.1, 0.1])]);

        let predictor =
            RelationPredictor::new(RelationPredictorConfig::default(), relations, entities);

        assert_eq!(predictor.relation_embeddings.len(), 1);
    }

    /// Batch linking returns one result list per mention.
    #[test]
    fn test_batch_entity_linking() {
        let table = HashMap::from([
            ("entity1".to_string(), array![1.0, 0.0, 0.0]),
            ("entity2".to_string(), array![0.0, 1.0, 0.0]),
        ]);
        let linker = EntityLinker::new(EntityLinkerConfig::default(), table).unwrap();

        let mentions = vec![array![0.9, 0.1, 0.0], array![0.1, 0.9, 0.0]];
        let per_mention = linker.link_entities_batch(&mentions).unwrap();

        assert_eq!(per_mention.len(), 2);
    }
}