Skip to main content

oxirs_embed/
entity_linking.rs

1//! Entity Linking and Relation Prediction for Knowledge Graphs
2//!
3//! This module provides advanced entity linking and relation prediction capabilities
4//! using learned embeddings and similarity metrics with full SciRS2 integration.
5
6use anyhow::{anyhow, Result};
7use rayon::prelude::*;
8use scirs2_core::ndarray_ext::{Array1, Array2};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use tracing::{debug, info};
13
14/// Entity linker configuration
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct EntityLinkerConfig {
17    /// Similarity threshold for entity matching
18    pub similarity_threshold: f32,
19    /// Maximum number of candidate entities to consider
20    pub max_candidates: usize,
21    /// Enable context-aware linking
22    pub use_context: bool,
23    /// Minimum confidence score for linking
24    pub min_confidence: f32,
25    /// Enable approximate nearest neighbor search
26    pub use_ann: bool,
27    /// Number of nearest neighbors to retrieve
28    pub k_neighbors: usize,
29}
30
31impl Default for EntityLinkerConfig {
32    fn default() -> Self {
33        Self {
34            similarity_threshold: 0.7,
35            max_candidates: 10,
36            use_context: true,
37            min_confidence: 0.5,
38            use_ann: true,
39            k_neighbors: 50,
40        }
41    }
42}
43
44/// Entity linking result with confidence scores
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct LinkingResult {
47    /// Linked entity ID
48    pub entity_id: String,
49    /// Confidence score (0.0 to 1.0)
50    pub confidence: f32,
51    /// Similarity score
52    pub similarity: f32,
53    /// Context features used
54    pub context_features: Vec<String>,
55}
56
57/// Entity linker for knowledge graph entity resolution
58pub struct EntityLinker {
59    config: EntityLinkerConfig,
60    entity_embeddings: Arc<HashMap<String, Array1<f32>>>,
61    entity_index: Vec<String>,
62    embedding_matrix: Array2<f32>,
63}
64
65impl EntityLinker {
66    /// Create new entity linker
67    pub fn new(
68        config: EntityLinkerConfig,
69        entity_embeddings: HashMap<String, Array1<f32>>,
70    ) -> Result<Self> {
71        let entity_count = entity_embeddings.len();
72        if entity_count == 0 {
73            return Err(anyhow!("Empty entity embedding set"));
74        }
75
76        // Build entity index for fast lookup
77        let mut entity_index = Vec::with_capacity(entity_count);
78        let embedding_dim = entity_embeddings
79            .values()
80            .next()
81            .expect("entity_embeddings should not be empty")
82            .len();
83        let mut embedding_matrix = Array2::zeros((entity_count, embedding_dim));
84
85        for (idx, (entity_id, embedding)) in entity_embeddings.iter().enumerate() {
86            entity_index.push(entity_id.clone());
87            embedding_matrix.row_mut(idx).assign(embedding);
88        }
89
90        info!(
91            "Initialized EntityLinker with {} entities, dim={}",
92            entity_count, embedding_dim
93        );
94
95        Ok(Self {
96            config,
97            entity_embeddings: Arc::new(entity_embeddings),
98            entity_index,
99            embedding_matrix,
100        })
101    }
102
103    /// Link a mention to knowledge graph entities
104    pub fn link_entity(
105        &self,
106        mention_embedding: &Array1<f32>,
107        context_embeddings: Option<&[Array1<f32>]>,
108    ) -> Result<Vec<LinkingResult>> {
109        // Compute similarities with all entities
110        let similarities = self.compute_similarities(mention_embedding)?;
111
112        // Get top-k candidates
113        let mut candidates: Vec<(usize, f32)> = similarities
114            .iter()
115            .enumerate()
116            .map(|(idx, &sim)| (idx, sim))
117            .collect();
118
119        candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
120        candidates.truncate(self.config.max_candidates);
121
122        // Apply context if available
123        let results = if let Some(ctx_emb) = context_embeddings.filter(|_| self.config.use_context)
124        {
125            self.rerank_with_context(&candidates, ctx_emb)?
126        } else {
127            candidates
128                .into_iter()
129                .filter(|(_, sim)| *sim >= self.config.similarity_threshold)
130                .map(|(idx, sim)| LinkingResult {
131                    entity_id: self.entity_index[idx].clone(),
132                    confidence: sim,
133                    similarity: sim,
134                    context_features: vec![],
135                })
136                .collect()
137        };
138
139        // Filter by minimum confidence
140        let filtered: Vec<_> = results
141            .into_iter()
142            .filter(|r| r.confidence >= self.config.min_confidence)
143            .collect();
144
145        debug!("Linked {} candidate entities", filtered.len());
146
147        Ok(filtered)
148    }
149
150    /// Batch entity linking for multiple mentions
151    pub fn link_entities_batch(
152        &self,
153        mention_embeddings: &[Array1<f32>],
154    ) -> Result<Vec<Vec<LinkingResult>>> {
155        // Parallel processing using rayon
156        let results: Vec<Vec<LinkingResult>> = mention_embeddings
157            .par_iter()
158            .map(|mention| self.link_entity(mention, None).unwrap_or_default())
159            .collect();
160
161        Ok(results)
162    }
163
164    /// Compute cosine similarities efficiently
165    fn compute_similarities(&self, query: &Array1<f32>) -> Result<Vec<f32>> {
166        // Normalize query
167        let query_norm = query.dot(query).sqrt();
168        if query_norm == 0.0 {
169            return Err(anyhow!("Zero-norm query vector"));
170        }
171
172        let normalized_query = query / query_norm;
173
174        // Compute similarities using matrix multiplication
175        let similarities: Vec<f32> = (0..self.embedding_matrix.nrows())
176            .into_par_iter()
177            .map(|i| {
178                let entity_emb = self.embedding_matrix.row(i);
179                let entity_norm = entity_emb.dot(&entity_emb).sqrt();
180
181                if entity_norm == 0.0 {
182                    0.0
183                } else {
184                    let normalized_entity = entity_emb.to_owned() / entity_norm;
185                    normalized_query.dot(&normalized_entity)
186                }
187            })
188            .collect();
189
190        Ok(similarities)
191    }
192
193    /// Rerank candidates using context information
194    fn rerank_with_context(
195        &self,
196        candidates: &[(usize, f32)],
197        context_embeddings: &[Array1<f32>],
198    ) -> Result<Vec<LinkingResult>> {
199        let results: Vec<LinkingResult> = candidates
200            .iter()
201            .map(|(idx, base_sim)| {
202                let entity_embedding = self.embedding_matrix.row(*idx);
203
204                // Compute context similarity
205                let context_sim = self
206                    .compute_context_similarity(&entity_embedding.to_owned(), context_embeddings);
207
208                // Combine base similarity and context similarity
209                let confidence = 0.7 * base_sim + 0.3 * context_sim;
210
211                LinkingResult {
212                    entity_id: self.entity_index[*idx].clone(),
213                    confidence,
214                    similarity: *base_sim,
215                    context_features: vec!["context_aware".to_string()],
216                }
217            })
218            .collect();
219
220        Ok(results)
221    }
222
223    /// Compute context similarity score
224    fn compute_context_similarity(
225        &self,
226        entity_embedding: &Array1<f32>,
227        context_embeddings: &[Array1<f32>],
228    ) -> f32 {
229        if context_embeddings.is_empty() {
230            return 0.0;
231        }
232
233        // Average similarity with context
234        let total_sim: f32 = context_embeddings
235            .iter()
236            .map(|ctx| {
237                let norm1 = entity_embedding.dot(entity_embedding).sqrt();
238                let norm2 = ctx.dot(ctx).sqrt();
239
240                if norm1 == 0.0 || norm2 == 0.0 {
241                    0.0
242                } else {
243                    entity_embedding.dot(ctx) / (norm1 * norm2)
244                }
245            })
246            .sum();
247
248        total_sim / context_embeddings.len() as f32
249    }
250
251    /// Get entity embedding by ID
252    pub fn get_embedding(&self, entity_id: &str) -> Option<&Array1<f32>> {
253        self.entity_embeddings.get(entity_id)
254    }
255}
256
257/// Relation prediction configuration
258#[derive(Debug, Clone, Serialize, Deserialize)]
259pub struct RelationPredictorConfig {
260    /// Score threshold for relation prediction
261    pub score_threshold: f32,
262    /// Maximum number of relations to predict
263    pub max_predictions: usize,
264    /// Enable type constraints
265    pub use_type_constraints: bool,
266    /// Enable path-based reasoning
267    pub use_path_reasoning: bool,
268}
269
270impl Default for RelationPredictorConfig {
271    fn default() -> Self {
272        Self {
273            score_threshold: 0.6,
274            max_predictions: 10,
275            use_type_constraints: true,
276            use_path_reasoning: false,
277        }
278    }
279}
280
281/// Relation prediction result
282#[derive(Debug, Clone, Serialize, Deserialize)]
283pub struct RelationPrediction {
284    /// Predicted relation type
285    pub relation: String,
286    /// Tail entity (if predicting tails)
287    pub tail_entity: Option<String>,
288    /// Prediction score
289    pub score: f32,
290    /// Confidence level
291    pub confidence: f32,
292}
293
294/// Relation predictor for knowledge graph completion
295pub struct RelationPredictor {
296    config: RelationPredictorConfig,
297    relation_embeddings: Arc<HashMap<String, Array1<f32>>>,
298    entity_embeddings: Arc<HashMap<String, Array1<f32>>>,
299}
300
301impl RelationPredictor {
302    /// Create new relation predictor
303    pub fn new(
304        config: RelationPredictorConfig,
305        relation_embeddings: HashMap<String, Array1<f32>>,
306        entity_embeddings: HashMap<String, Array1<f32>>,
307    ) -> Self {
308        info!(
309            "Initialized RelationPredictor with {} relations, {} entities",
310            relation_embeddings.len(),
311            entity_embeddings.len()
312        );
313
314        Self {
315            config,
316            relation_embeddings: Arc::new(relation_embeddings),
317            entity_embeddings: Arc::new(entity_embeddings),
318        }
319    }
320
321    /// Predict relations between two entities
322    pub fn predict_relations(
323        &self,
324        head_entity: &str,
325        tail_entity: &str,
326    ) -> Result<Vec<RelationPrediction>> {
327        let head_emb = self
328            .entity_embeddings
329            .get(head_entity)
330            .ok_or_else(|| anyhow!("Unknown head entity: {}", head_entity))?;
331
332        let tail_emb = self
333            .entity_embeddings
334            .get(tail_entity)
335            .ok_or_else(|| anyhow!("Unknown tail entity: {}", tail_entity))?;
336
337        // Score all possible relations
338        let mut predictions: Vec<RelationPrediction> = self
339            .relation_embeddings
340            .par_iter()
341            .map(|(rel, rel_emb)| {
342                // TransE-style scoring: h + r ≈ t
343                let score = self.score_triple(head_emb, rel_emb, tail_emb);
344
345                RelationPrediction {
346                    relation: rel.clone(),
347                    tail_entity: Some(tail_entity.to_string()),
348                    score,
349                    confidence: score,
350                }
351            })
352            .filter(|pred| pred.score >= self.config.score_threshold)
353            .collect();
354
355        // Sort by score descending
356        predictions.sort_by(|a, b| {
357            b.score
358                .partial_cmp(&a.score)
359                .unwrap_or(std::cmp::Ordering::Equal)
360        });
361        predictions.truncate(self.config.max_predictions);
362
363        Ok(predictions)
364    }
365
366    /// Predict tail entities for a given head and relation
367    pub fn predict_tails(
368        &self,
369        head_entity: &str,
370        relation: &str,
371    ) -> Result<Vec<RelationPrediction>> {
372        let head_emb = self
373            .entity_embeddings
374            .get(head_entity)
375            .ok_or_else(|| anyhow!("Unknown head entity: {}", head_entity))?;
376
377        let rel_emb = self
378            .relation_embeddings
379            .get(relation)
380            .ok_or_else(|| anyhow!("Unknown relation: {}", relation))?;
381
382        // Compute expected tail embedding: t = h + r
383        let expected_tail = head_emb + rel_emb;
384
385        // Find nearest entities to expected tail
386        let mut predictions: Vec<RelationPrediction> = self
387            .entity_embeddings
388            .par_iter()
389            .map(|(entity, entity_emb)| {
390                let distance = Self::euclidean_distance(&expected_tail, entity_emb);
391                let score = 1.0 / (1.0 + distance); // Convert distance to score
392
393                RelationPrediction {
394                    relation: relation.to_string(),
395                    tail_entity: Some(entity.clone()),
396                    score,
397                    confidence: score,
398                }
399            })
400            .filter(|pred| pred.score >= self.config.score_threshold)
401            .collect();
402
403        predictions.sort_by(|a, b| {
404            b.score
405                .partial_cmp(&a.score)
406                .unwrap_or(std::cmp::Ordering::Equal)
407        });
408        predictions.truncate(self.config.max_predictions);
409
410        Ok(predictions)
411    }
412
413    /// Score a triple using TransE-style scoring
414    fn score_triple(&self, head: &Array1<f32>, relation: &Array1<f32>, tail: &Array1<f32>) -> f32 {
415        // TransE: score = -||h + r - t||
416        let expected_tail = head + relation;
417        let distance = Self::euclidean_distance(&expected_tail, tail);
418
419        // Convert to similarity score (higher is better)
420        1.0 / (1.0 + distance)
421    }
422
423    /// Compute Euclidean distance
424    fn euclidean_distance(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
425        let diff = a - b;
426        diff.dot(&diff).sqrt()
427    }
428
429    /// Batch prediction of tails
430    pub fn predict_tails_batch(
431        &self,
432        queries: &[(String, String)], // (head, relation) pairs
433    ) -> Result<Vec<Vec<RelationPrediction>>> {
434        let results: Vec<Vec<RelationPrediction>> = queries
435            .par_iter()
436            .map(|(head, rel)| self.predict_tails(head, rel).unwrap_or_default())
437            .collect();
438
439        Ok(results)
440    }
441}
442
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    /// Two orthogonal unit-vector entities shared by several linker tests.
    fn axis_entities() -> HashMap<String, Array1<f32>> {
        let mut embeddings = HashMap::new();
        embeddings.insert("entity1".to_string(), array![1.0, 0.0, 0.0]);
        embeddings.insert("entity2".to_string(), array![0.0, 1.0, 0.0]);
        embeddings
    }

    /// Predictor where `head + rel1` lands exactly on `tail`, while `rel2`
    /// points away from it.
    fn sample_predictor() -> RelationPredictor {
        let mut entity_embeddings = HashMap::new();
        entity_embeddings.insert("head".to_string(), array![1.0, 0.0, 0.0]);
        entity_embeddings.insert("tail".to_string(), array![1.0, 1.0, 0.0]);

        let mut relation_embeddings = HashMap::new();
        relation_embeddings.insert("rel1".to_string(), array![0.0, 1.0, 0.0]);
        relation_embeddings.insert("rel2".to_string(), array![0.0, -1.0, 0.0]);

        RelationPredictor::new(
            RelationPredictorConfig::default(),
            relation_embeddings,
            entity_embeddings,
        )
    }

    #[test]
    fn test_entity_linker_creation() {
        let mut embeddings = HashMap::new();
        embeddings.insert("entity1".to_string(), array![0.1, 0.2, 0.3]);
        embeddings.insert("entity2".to_string(), array![0.4, 0.5, 0.6]);

        let config = EntityLinkerConfig::default();
        let linker = EntityLinker::new(config, embeddings);
        assert!(linker.is_ok());
    }

    #[test]
    fn test_entity_linker_rejects_empty_input() {
        let linker = EntityLinker::new(EntityLinkerConfig::default(), HashMap::new());
        assert!(linker.is_err());
    }

    #[test]
    fn test_entity_linking() {
        let mut embeddings = axis_entities();
        embeddings.insert("entity3".to_string(), array![0.7, 0.7, 0.0]);

        let config = EntityLinkerConfig {
            similarity_threshold: 0.5,
            ..Default::default()
        };

        let linker = EntityLinker::new(config, embeddings).unwrap();

        // Query similar to entity1
        let query = array![0.9, 0.1, 0.0];
        let results = linker.link_entity(&query, None).unwrap();

        // entity1 (~0.99) and entity3 (~0.78) pass the 0.5 threshold;
        // entity2 (~0.11) does not.
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].entity_id, "entity1");
        assert_eq!(results[1].entity_id, "entity3");
        assert!(results[0].similarity >= results[1].similarity);
    }

    #[test]
    fn test_zero_norm_query_is_rejected() {
        let linker = EntityLinker::new(EntityLinkerConfig::default(), axis_entities()).unwrap();

        let query = array![0.0, 0.0, 0.0];
        assert!(linker.link_entity(&query, None).is_err());
    }

    #[test]
    fn test_entity_linking_with_context() {
        let linker = EntityLinker::new(EntityLinkerConfig::default(), axis_entities()).unwrap();

        let query = array![0.9, 0.1, 0.0];
        let context = [array![1.0, 0.0, 0.0]];
        let results = linker.link_entity(&query, Some(&context[..])).unwrap();

        // Only entity1 keeps a blended confidence above the 0.5 minimum.
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].entity_id, "entity1");
        assert_eq!(results[0].context_features, vec!["context_aware".to_string()]);
    }

    #[test]
    fn test_get_embedding() {
        let linker = EntityLinker::new(EntityLinkerConfig::default(), axis_entities()).unwrap();

        assert!(linker.get_embedding("entity1").is_some());
        assert!(linker.get_embedding("missing").is_none());
    }

    #[test]
    fn test_relation_predictor_creation() {
        let mut entity_embeddings = HashMap::new();
        entity_embeddings.insert("entity1".to_string(), array![0.1, 0.2, 0.3]);

        let mut relation_embeddings = HashMap::new();
        relation_embeddings.insert("rel1".to_string(), array![0.1, 0.1, 0.1]);

        let config = RelationPredictorConfig::default();
        let predictor = RelationPredictor::new(config, relation_embeddings, entity_embeddings);

        // Just verify creation succeeds
        assert_eq!(predictor.relation_embeddings.len(), 1);
    }

    #[test]
    fn test_predict_relations() {
        let predictor = sample_predictor();

        let predictions = predictor.predict_relations("head", "tail").unwrap();

        // rel1 is an exact translation (score 1.0); rel2 scores 1/3 and is
        // below the default 0.6 threshold.
        assert_eq!(predictions.len(), 1);
        assert_eq!(predictions[0].relation, "rel1");
        assert!((predictions[0].score - 1.0).abs() < 1e-6);
        assert_eq!(predictions[0].tail_entity.as_deref(), Some("tail"));
    }

    #[test]
    fn test_predict_relations_unknown_entity() {
        let predictor = sample_predictor();

        assert!(predictor.predict_relations("missing", "tail").is_err());
        assert!(predictor.predict_relations("head", "missing").is_err());
    }

    #[test]
    fn test_predict_tails() {
        let predictor = sample_predictor();

        let predictions = predictor.predict_tails("head", "rel1").unwrap();

        // Only `tail` sits on head + rel1; `head` itself scores 0.5 and is
        // filtered by the default 0.6 threshold.
        assert_eq!(predictions.len(), 1);
        assert_eq!(predictions[0].tail_entity.as_deref(), Some("tail"));
        assert!((predictions[0].score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_predict_tails_batch_tolerates_bad_queries() {
        let predictor = sample_predictor();

        let queries = vec![
            ("head".to_string(), "rel1".to_string()),
            ("missing".to_string(), "rel1".to_string()),
        ];

        let results = predictor.predict_tails_batch(&queries).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].len(), 1);
        assert!(results[1].is_empty());
    }

    #[test]
    fn test_batch_entity_linking() {
        let config = EntityLinkerConfig::default();
        let linker = EntityLinker::new(config, axis_entities()).unwrap();

        let queries = vec![array![0.9, 0.1, 0.0], array![0.1, 0.9, 0.0]];

        let results = linker.link_entities_batch(&queries).unwrap();
        assert_eq!(results.len(), 2);

        // Each query is close to exactly one axis entity.
        assert_eq!(results[0].len(), 1);
        assert_eq!(results[0][0].entity_id, "entity1");
        assert_eq!(results[1].len(), 1);
        assert_eq!(results[1][0].entity_id, "entity2");
    }
}
510}