1use anyhow::{anyhow, Result};
7use rayon::prelude::*;
8use scirs2_core::ndarray_ext::{Array1, Array2};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use tracing::{debug, info};
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct EntityLinkerConfig {
17 pub similarity_threshold: f32,
19 pub max_candidates: usize,
21 pub use_context: bool,
23 pub min_confidence: f32,
25 pub use_ann: bool,
27 pub k_neighbors: usize,
29}
30
31impl Default for EntityLinkerConfig {
32 fn default() -> Self {
33 Self {
34 similarity_threshold: 0.7,
35 max_candidates: 10,
36 use_context: true,
37 min_confidence: 0.5,
38 use_ann: true,
39 k_neighbors: 50,
40 }
41 }
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct LinkingResult {
47 pub entity_id: String,
49 pub confidence: f32,
51 pub similarity: f32,
53 pub context_features: Vec<String>,
55}
56
57pub struct EntityLinker {
59 config: EntityLinkerConfig,
60 entity_embeddings: Arc<HashMap<String, Array1<f32>>>,
61 entity_index: Vec<String>,
62 embedding_matrix: Array2<f32>,
63}
64
65impl EntityLinker {
66 pub fn new(
68 config: EntityLinkerConfig,
69 entity_embeddings: HashMap<String, Array1<f32>>,
70 ) -> Result<Self> {
71 let entity_count = entity_embeddings.len();
72 if entity_count == 0 {
73 return Err(anyhow!("Empty entity embedding set"));
74 }
75
76 let mut entity_index = Vec::with_capacity(entity_count);
78 let embedding_dim = entity_embeddings.values().next().unwrap().len();
79 let mut embedding_matrix = Array2::zeros((entity_count, embedding_dim));
80
81 for (idx, (entity_id, embedding)) in entity_embeddings.iter().enumerate() {
82 entity_index.push(entity_id.clone());
83 embedding_matrix.row_mut(idx).assign(embedding);
84 }
85
86 info!(
87 "Initialized EntityLinker with {} entities, dim={}",
88 entity_count, embedding_dim
89 );
90
91 Ok(Self {
92 config,
93 entity_embeddings: Arc::new(entity_embeddings),
94 entity_index,
95 embedding_matrix,
96 })
97 }
98
99 pub fn link_entity(
101 &self,
102 mention_embedding: &Array1<f32>,
103 context_embeddings: Option<&[Array1<f32>]>,
104 ) -> Result<Vec<LinkingResult>> {
105 let similarities = self.compute_similarities(mention_embedding)?;
107
108 let mut candidates: Vec<(usize, f32)> = similarities
110 .iter()
111 .enumerate()
112 .map(|(idx, &sim)| (idx, sim))
113 .collect();
114
115 candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
116 candidates.truncate(self.config.max_candidates);
117
118 let results = if let Some(ctx_emb) = context_embeddings.filter(|_| self.config.use_context)
120 {
121 self.rerank_with_context(&candidates, ctx_emb)?
122 } else {
123 candidates
124 .into_iter()
125 .filter(|(_, sim)| *sim >= self.config.similarity_threshold)
126 .map(|(idx, sim)| LinkingResult {
127 entity_id: self.entity_index[idx].clone(),
128 confidence: sim,
129 similarity: sim,
130 context_features: vec![],
131 })
132 .collect()
133 };
134
135 let filtered: Vec<_> = results
137 .into_iter()
138 .filter(|r| r.confidence >= self.config.min_confidence)
139 .collect();
140
141 debug!("Linked {} candidate entities", filtered.len());
142
143 Ok(filtered)
144 }
145
146 pub fn link_entities_batch(
148 &self,
149 mention_embeddings: &[Array1<f32>],
150 ) -> Result<Vec<Vec<LinkingResult>>> {
151 let results: Vec<Vec<LinkingResult>> = mention_embeddings
153 .par_iter()
154 .map(|mention| self.link_entity(mention, None).unwrap_or_default())
155 .collect();
156
157 Ok(results)
158 }
159
160 fn compute_similarities(&self, query: &Array1<f32>) -> Result<Vec<f32>> {
162 let query_norm = query.dot(query).sqrt();
164 if query_norm == 0.0 {
165 return Err(anyhow!("Zero-norm query vector"));
166 }
167
168 let normalized_query = query / query_norm;
169
170 let similarities: Vec<f32> = (0..self.embedding_matrix.nrows())
172 .into_par_iter()
173 .map(|i| {
174 let entity_emb = self.embedding_matrix.row(i);
175 let entity_norm = entity_emb.dot(&entity_emb).sqrt();
176
177 if entity_norm == 0.0 {
178 0.0
179 } else {
180 let normalized_entity = entity_emb.to_owned() / entity_norm;
181 normalized_query.dot(&normalized_entity)
182 }
183 })
184 .collect();
185
186 Ok(similarities)
187 }
188
189 fn rerank_with_context(
191 &self,
192 candidates: &[(usize, f32)],
193 context_embeddings: &[Array1<f32>],
194 ) -> Result<Vec<LinkingResult>> {
195 let results: Vec<LinkingResult> = candidates
196 .iter()
197 .map(|(idx, base_sim)| {
198 let entity_embedding = self.embedding_matrix.row(*idx);
199
200 let context_sim = self
202 .compute_context_similarity(&entity_embedding.to_owned(), context_embeddings);
203
204 let confidence = 0.7 * base_sim + 0.3 * context_sim;
206
207 LinkingResult {
208 entity_id: self.entity_index[*idx].clone(),
209 confidence,
210 similarity: *base_sim,
211 context_features: vec!["context_aware".to_string()],
212 }
213 })
214 .collect();
215
216 Ok(results)
217 }
218
219 fn compute_context_similarity(
221 &self,
222 entity_embedding: &Array1<f32>,
223 context_embeddings: &[Array1<f32>],
224 ) -> f32 {
225 if context_embeddings.is_empty() {
226 return 0.0;
227 }
228
229 let total_sim: f32 = context_embeddings
231 .iter()
232 .map(|ctx| {
233 let norm1 = entity_embedding.dot(entity_embedding).sqrt();
234 let norm2 = ctx.dot(ctx).sqrt();
235
236 if norm1 == 0.0 || norm2 == 0.0 {
237 0.0
238 } else {
239 entity_embedding.dot(ctx) / (norm1 * norm2)
240 }
241 })
242 .sum();
243
244 total_sim / context_embeddings.len() as f32
245 }
246
247 pub fn get_embedding(&self, entity_id: &str) -> Option<&Array1<f32>> {
249 self.entity_embeddings.get(entity_id)
250 }
251}
252
253#[derive(Debug, Clone, Serialize, Deserialize)]
255pub struct RelationPredictorConfig {
256 pub score_threshold: f32,
258 pub max_predictions: usize,
260 pub use_type_constraints: bool,
262 pub use_path_reasoning: bool,
264}
265
266impl Default for RelationPredictorConfig {
267 fn default() -> Self {
268 Self {
269 score_threshold: 0.6,
270 max_predictions: 10,
271 use_type_constraints: true,
272 use_path_reasoning: false,
273 }
274 }
275}
276
277#[derive(Debug, Clone, Serialize, Deserialize)]
279pub struct RelationPrediction {
280 pub relation: String,
282 pub tail_entity: Option<String>,
284 pub score: f32,
286 pub confidence: f32,
288}
289
290pub struct RelationPredictor {
292 config: RelationPredictorConfig,
293 relation_embeddings: Arc<HashMap<String, Array1<f32>>>,
294 entity_embeddings: Arc<HashMap<String, Array1<f32>>>,
295}
296
297impl RelationPredictor {
298 pub fn new(
300 config: RelationPredictorConfig,
301 relation_embeddings: HashMap<String, Array1<f32>>,
302 entity_embeddings: HashMap<String, Array1<f32>>,
303 ) -> Self {
304 info!(
305 "Initialized RelationPredictor with {} relations, {} entities",
306 relation_embeddings.len(),
307 entity_embeddings.len()
308 );
309
310 Self {
311 config,
312 relation_embeddings: Arc::new(relation_embeddings),
313 entity_embeddings: Arc::new(entity_embeddings),
314 }
315 }
316
317 pub fn predict_relations(
319 &self,
320 head_entity: &str,
321 tail_entity: &str,
322 ) -> Result<Vec<RelationPrediction>> {
323 let head_emb = self
324 .entity_embeddings
325 .get(head_entity)
326 .ok_or_else(|| anyhow!("Unknown head entity: {}", head_entity))?;
327
328 let tail_emb = self
329 .entity_embeddings
330 .get(tail_entity)
331 .ok_or_else(|| anyhow!("Unknown tail entity: {}", tail_entity))?;
332
333 let mut predictions: Vec<RelationPrediction> = self
335 .relation_embeddings
336 .par_iter()
337 .map(|(rel, rel_emb)| {
338 let score = self.score_triple(head_emb, rel_emb, tail_emb);
340
341 RelationPrediction {
342 relation: rel.clone(),
343 tail_entity: Some(tail_entity.to_string()),
344 score,
345 confidence: score,
346 }
347 })
348 .filter(|pred| pred.score >= self.config.score_threshold)
349 .collect();
350
351 predictions.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
353 predictions.truncate(self.config.max_predictions);
354
355 Ok(predictions)
356 }
357
358 pub fn predict_tails(
360 &self,
361 head_entity: &str,
362 relation: &str,
363 ) -> Result<Vec<RelationPrediction>> {
364 let head_emb = self
365 .entity_embeddings
366 .get(head_entity)
367 .ok_or_else(|| anyhow!("Unknown head entity: {}", head_entity))?;
368
369 let rel_emb = self
370 .relation_embeddings
371 .get(relation)
372 .ok_or_else(|| anyhow!("Unknown relation: {}", relation))?;
373
374 let expected_tail = head_emb + rel_emb;
376
377 let mut predictions: Vec<RelationPrediction> = self
379 .entity_embeddings
380 .par_iter()
381 .map(|(entity, entity_emb)| {
382 let distance = Self::euclidean_distance(&expected_tail, entity_emb);
383 let score = 1.0 / (1.0 + distance); RelationPrediction {
386 relation: relation.to_string(),
387 tail_entity: Some(entity.clone()),
388 score,
389 confidence: score,
390 }
391 })
392 .filter(|pred| pred.score >= self.config.score_threshold)
393 .collect();
394
395 predictions.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
396 predictions.truncate(self.config.max_predictions);
397
398 Ok(predictions)
399 }
400
401 fn score_triple(&self, head: &Array1<f32>, relation: &Array1<f32>, tail: &Array1<f32>) -> f32 {
403 let expected_tail = head + relation;
405 let distance = Self::euclidean_distance(&expected_tail, tail);
406
407 1.0 / (1.0 + distance)
409 }
410
411 fn euclidean_distance(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
413 let diff = a - b;
414 diff.dot(&diff).sqrt()
415 }
416
417 pub fn predict_tails_batch(
419 &self,
420 queries: &[(String, String)], ) -> Result<Vec<Vec<RelationPrediction>>> {
422 let results: Vec<Vec<RelationPrediction>> = queries
423 .par_iter()
424 .map(|(head, rel)| self.predict_tails(head, rel).unwrap_or_default())
425 .collect();
426
427 Ok(results)
428 }
429}
430
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    /// Two entities on orthogonal unit axes.
    fn axis_embeddings() -> HashMap<String, Array1<f32>> {
        let mut embeddings = HashMap::new();
        embeddings.insert("entity1".to_string(), array![1.0, 0.0, 0.0]);
        embeddings.insert("entity2".to_string(), array![0.0, 1.0, 0.0]);
        embeddings
    }

    /// Predictor where `origin + shift == target` exactly.
    fn translation_predictor() -> RelationPredictor {
        let mut entity_embeddings = HashMap::new();
        entity_embeddings.insert("origin".to_string(), array![0.0, 0.0, 0.0]);
        entity_embeddings.insert("target".to_string(), array![1.0, 0.0, 0.0]);

        let mut relation_embeddings = HashMap::new();
        relation_embeddings.insert("shift".to_string(), array![1.0, 0.0, 0.0]);

        RelationPredictor::new(
            RelationPredictorConfig::default(),
            relation_embeddings,
            entity_embeddings,
        )
    }

    #[test]
    fn test_entity_linker_creation() {
        let mut embeddings = HashMap::new();
        embeddings.insert("entity1".to_string(), array![0.1, 0.2, 0.3]);
        embeddings.insert("entity2".to_string(), array![0.4, 0.5, 0.6]);

        let config = EntityLinkerConfig::default();
        let linker = EntityLinker::new(config, embeddings);
        assert!(linker.is_ok());
    }

    #[test]
    fn test_entity_linker_rejects_empty_embeddings() {
        let linker = EntityLinker::new(EntityLinkerConfig::default(), HashMap::new());
        assert!(linker.is_err());
    }

    #[test]
    fn test_entity_linking() {
        let mut embeddings = axis_embeddings();
        embeddings.insert("entity3".to_string(), array![0.7, 0.7, 0.0]);

        let config = EntityLinkerConfig {
            similarity_threshold: 0.5,
            ..Default::default()
        };

        let linker = EntityLinker::new(config, embeddings).unwrap();

        let query = array![0.9, 0.1, 0.0];
        let results = linker.link_entity(&query, None).unwrap();

        assert!(!results.is_empty());
        assert_eq!(results[0].entity_id, "entity1");
    }

    #[test]
    fn test_zero_norm_query_is_rejected() {
        let linker = EntityLinker::new(EntityLinkerConfig::default(), axis_embeddings()).unwrap();

        let query = array![0.0, 0.0, 0.0];
        assert!(linker.link_entity(&query, None).is_err());
    }

    #[test]
    fn test_entity_linking_with_context() {
        let mut embeddings = HashMap::new();
        embeddings.insert("entity1".to_string(), array![1.0, 0.0, 0.0]);

        let linker = EntityLinker::new(EntityLinkerConfig::default(), embeddings).unwrap();

        let query = array![1.0, 0.0, 0.0];
        let context: Vec<Array1<f32>> = vec![array![1.0, 0.0, 0.0]];
        let results = linker
            .link_entity(&query, Some(context.as_slice()))
            .unwrap();

        // Base and context similarity are both 1, so the blend is 1 as well.
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].entity_id, "entity1");
        assert!((results[0].confidence - 1.0).abs() < 1e-6);
        assert_eq!(results[0].context_features, vec!["context_aware".to_string()]);
    }

    #[test]
    fn test_relation_predictor_creation() {
        let mut entity_embeddings = HashMap::new();
        entity_embeddings.insert("entity1".to_string(), array![0.1, 0.2, 0.3]);

        let mut relation_embeddings = HashMap::new();
        relation_embeddings.insert("rel1".to_string(), array![0.1, 0.1, 0.1]);

        let config = RelationPredictorConfig::default();
        let predictor = RelationPredictor::new(config, relation_embeddings, entity_embeddings);

        assert_eq!(predictor.relation_embeddings.len(), 1);
        assert_eq!(predictor.entity_embeddings.len(), 1);
    }

    #[test]
    fn test_predict_tails_finds_translated_entity() {
        let predictor = translation_predictor();

        let predictions = predictor.predict_tails("origin", "shift").unwrap();

        // "target" is at distance 0 (score 1.0); "origin" is at distance 1
        // (score 0.5), which is below the default threshold of 0.6.
        assert_eq!(predictions.len(), 1);
        assert_eq!(predictions[0].tail_entity.as_deref(), Some("target"));
        assert!((predictions[0].score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_predict_relations_scores_exact_translation() {
        let predictor = translation_predictor();

        let predictions = predictor.predict_relations("origin", "target").unwrap();

        assert_eq!(predictions.len(), 1);
        assert_eq!(predictions[0].relation, "shift");
        assert!((predictions[0].score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_predict_relations_unknown_entity() {
        let predictor = translation_predictor();

        assert!(predictor.predict_relations("missing", "origin").is_err());
        assert!(predictor.predict_tails("origin", "missing").is_err());
    }

    #[test]
    fn test_batch_entity_linking() {
        let config = EntityLinkerConfig::default();
        let linker = EntityLinker::new(config, axis_embeddings()).unwrap();

        let queries = vec![array![0.9, 0.1, 0.0], array![0.1, 0.9, 0.0]];

        let results = linker.link_entities_batch(&queries).unwrap();
        assert_eq!(results.len(), 2);

        // Each query is close to exactly one axis; the other entity falls
        // below the default similarity threshold.
        assert_eq!(results[0].len(), 1);
        assert_eq!(results[0][0].entity_id, "entity1");
        assert_eq!(results[1].len(), 1);
        assert_eq!(results[1][0].entity_id, "entity2");
    }
}