oxirs_embed/
entity_linking.rs

1//! Entity Linking and Relation Prediction for Knowledge Graphs
2//!
3//! This module provides advanced entity linking and relation prediction capabilities
4//! using learned embeddings and similarity metrics with full SciRS2 integration.
5
6use anyhow::{anyhow, Result};
7use rayon::prelude::*;
8use scirs2_core::ndarray_ext::{Array1, Array2};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use tracing::{debug, info};
13
14/// Entity linker configuration
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct EntityLinkerConfig {
17    /// Similarity threshold for entity matching
18    pub similarity_threshold: f32,
19    /// Maximum number of candidate entities to consider
20    pub max_candidates: usize,
21    /// Enable context-aware linking
22    pub use_context: bool,
23    /// Minimum confidence score for linking
24    pub min_confidence: f32,
25    /// Enable approximate nearest neighbor search
26    pub use_ann: bool,
27    /// Number of nearest neighbors to retrieve
28    pub k_neighbors: usize,
29}
30
31impl Default for EntityLinkerConfig {
32    fn default() -> Self {
33        Self {
34            similarity_threshold: 0.7,
35            max_candidates: 10,
36            use_context: true,
37            min_confidence: 0.5,
38            use_ann: true,
39            k_neighbors: 50,
40        }
41    }
42}
43
44/// Entity linking result with confidence scores
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct LinkingResult {
47    /// Linked entity ID
48    pub entity_id: String,
49    /// Confidence score (0.0 to 1.0)
50    pub confidence: f32,
51    /// Similarity score
52    pub similarity: f32,
53    /// Context features used
54    pub context_features: Vec<String>,
55}
56
57/// Entity linker for knowledge graph entity resolution
58pub struct EntityLinker {
59    config: EntityLinkerConfig,
60    entity_embeddings: Arc<HashMap<String, Array1<f32>>>,
61    entity_index: Vec<String>,
62    embedding_matrix: Array2<f32>,
63}
64
65impl EntityLinker {
66    /// Create new entity linker
67    pub fn new(
68        config: EntityLinkerConfig,
69        entity_embeddings: HashMap<String, Array1<f32>>,
70    ) -> Result<Self> {
71        let entity_count = entity_embeddings.len();
72        if entity_count == 0 {
73            return Err(anyhow!("Empty entity embedding set"));
74        }
75
76        // Build entity index for fast lookup
77        let mut entity_index = Vec::with_capacity(entity_count);
78        let embedding_dim = entity_embeddings.values().next().unwrap().len();
79        let mut embedding_matrix = Array2::zeros((entity_count, embedding_dim));
80
81        for (idx, (entity_id, embedding)) in entity_embeddings.iter().enumerate() {
82            entity_index.push(entity_id.clone());
83            embedding_matrix.row_mut(idx).assign(embedding);
84        }
85
86        info!(
87            "Initialized EntityLinker with {} entities, dim={}",
88            entity_count, embedding_dim
89        );
90
91        Ok(Self {
92            config,
93            entity_embeddings: Arc::new(entity_embeddings),
94            entity_index,
95            embedding_matrix,
96        })
97    }
98
99    /// Link a mention to knowledge graph entities
100    pub fn link_entity(
101        &self,
102        mention_embedding: &Array1<f32>,
103        context_embeddings: Option<&[Array1<f32>]>,
104    ) -> Result<Vec<LinkingResult>> {
105        // Compute similarities with all entities
106        let similarities = self.compute_similarities(mention_embedding)?;
107
108        // Get top-k candidates
109        let mut candidates: Vec<(usize, f32)> = similarities
110            .iter()
111            .enumerate()
112            .map(|(idx, &sim)| (idx, sim))
113            .collect();
114
115        candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
116        candidates.truncate(self.config.max_candidates);
117
118        // Apply context if available
119        let results = if let Some(ctx_emb) = context_embeddings.filter(|_| self.config.use_context)
120        {
121            self.rerank_with_context(&candidates, ctx_emb)?
122        } else {
123            candidates
124                .into_iter()
125                .filter(|(_, sim)| *sim >= self.config.similarity_threshold)
126                .map(|(idx, sim)| LinkingResult {
127                    entity_id: self.entity_index[idx].clone(),
128                    confidence: sim,
129                    similarity: sim,
130                    context_features: vec![],
131                })
132                .collect()
133        };
134
135        // Filter by minimum confidence
136        let filtered: Vec<_> = results
137            .into_iter()
138            .filter(|r| r.confidence >= self.config.min_confidence)
139            .collect();
140
141        debug!("Linked {} candidate entities", filtered.len());
142
143        Ok(filtered)
144    }
145
146    /// Batch entity linking for multiple mentions
147    pub fn link_entities_batch(
148        &self,
149        mention_embeddings: &[Array1<f32>],
150    ) -> Result<Vec<Vec<LinkingResult>>> {
151        // Parallel processing using rayon
152        let results: Vec<Vec<LinkingResult>> = mention_embeddings
153            .par_iter()
154            .map(|mention| self.link_entity(mention, None).unwrap_or_default())
155            .collect();
156
157        Ok(results)
158    }
159
160    /// Compute cosine similarities efficiently
161    fn compute_similarities(&self, query: &Array1<f32>) -> Result<Vec<f32>> {
162        // Normalize query
163        let query_norm = query.dot(query).sqrt();
164        if query_norm == 0.0 {
165            return Err(anyhow!("Zero-norm query vector"));
166        }
167
168        let normalized_query = query / query_norm;
169
170        // Compute similarities using matrix multiplication
171        let similarities: Vec<f32> = (0..self.embedding_matrix.nrows())
172            .into_par_iter()
173            .map(|i| {
174                let entity_emb = self.embedding_matrix.row(i);
175                let entity_norm = entity_emb.dot(&entity_emb).sqrt();
176
177                if entity_norm == 0.0 {
178                    0.0
179                } else {
180                    let normalized_entity = entity_emb.to_owned() / entity_norm;
181                    normalized_query.dot(&normalized_entity)
182                }
183            })
184            .collect();
185
186        Ok(similarities)
187    }
188
189    /// Rerank candidates using context information
190    fn rerank_with_context(
191        &self,
192        candidates: &[(usize, f32)],
193        context_embeddings: &[Array1<f32>],
194    ) -> Result<Vec<LinkingResult>> {
195        let results: Vec<LinkingResult> = candidates
196            .iter()
197            .map(|(idx, base_sim)| {
198                let entity_embedding = self.embedding_matrix.row(*idx);
199
200                // Compute context similarity
201                let context_sim = self
202                    .compute_context_similarity(&entity_embedding.to_owned(), context_embeddings);
203
204                // Combine base similarity and context similarity
205                let confidence = 0.7 * base_sim + 0.3 * context_sim;
206
207                LinkingResult {
208                    entity_id: self.entity_index[*idx].clone(),
209                    confidence,
210                    similarity: *base_sim,
211                    context_features: vec!["context_aware".to_string()],
212                }
213            })
214            .collect();
215
216        Ok(results)
217    }
218
219    /// Compute context similarity score
220    fn compute_context_similarity(
221        &self,
222        entity_embedding: &Array1<f32>,
223        context_embeddings: &[Array1<f32>],
224    ) -> f32 {
225        if context_embeddings.is_empty() {
226            return 0.0;
227        }
228
229        // Average similarity with context
230        let total_sim: f32 = context_embeddings
231            .iter()
232            .map(|ctx| {
233                let norm1 = entity_embedding.dot(entity_embedding).sqrt();
234                let norm2 = ctx.dot(ctx).sqrt();
235
236                if norm1 == 0.0 || norm2 == 0.0 {
237                    0.0
238                } else {
239                    entity_embedding.dot(ctx) / (norm1 * norm2)
240                }
241            })
242            .sum();
243
244        total_sim / context_embeddings.len() as f32
245    }
246
247    /// Get entity embedding by ID
248    pub fn get_embedding(&self, entity_id: &str) -> Option<&Array1<f32>> {
249        self.entity_embeddings.get(entity_id)
250    }
251}
252
253/// Relation prediction configuration
254#[derive(Debug, Clone, Serialize, Deserialize)]
255pub struct RelationPredictorConfig {
256    /// Score threshold for relation prediction
257    pub score_threshold: f32,
258    /// Maximum number of relations to predict
259    pub max_predictions: usize,
260    /// Enable type constraints
261    pub use_type_constraints: bool,
262    /// Enable path-based reasoning
263    pub use_path_reasoning: bool,
264}
265
266impl Default for RelationPredictorConfig {
267    fn default() -> Self {
268        Self {
269            score_threshold: 0.6,
270            max_predictions: 10,
271            use_type_constraints: true,
272            use_path_reasoning: false,
273        }
274    }
275}
276
277/// Relation prediction result
278#[derive(Debug, Clone, Serialize, Deserialize)]
279pub struct RelationPrediction {
280    /// Predicted relation type
281    pub relation: String,
282    /// Tail entity (if predicting tails)
283    pub tail_entity: Option<String>,
284    /// Prediction score
285    pub score: f32,
286    /// Confidence level
287    pub confidence: f32,
288}
289
290/// Relation predictor for knowledge graph completion
291pub struct RelationPredictor {
292    config: RelationPredictorConfig,
293    relation_embeddings: Arc<HashMap<String, Array1<f32>>>,
294    entity_embeddings: Arc<HashMap<String, Array1<f32>>>,
295}
296
297impl RelationPredictor {
298    /// Create new relation predictor
299    pub fn new(
300        config: RelationPredictorConfig,
301        relation_embeddings: HashMap<String, Array1<f32>>,
302        entity_embeddings: HashMap<String, Array1<f32>>,
303    ) -> Self {
304        info!(
305            "Initialized RelationPredictor with {} relations, {} entities",
306            relation_embeddings.len(),
307            entity_embeddings.len()
308        );
309
310        Self {
311            config,
312            relation_embeddings: Arc::new(relation_embeddings),
313            entity_embeddings: Arc::new(entity_embeddings),
314        }
315    }
316
317    /// Predict relations between two entities
318    pub fn predict_relations(
319        &self,
320        head_entity: &str,
321        tail_entity: &str,
322    ) -> Result<Vec<RelationPrediction>> {
323        let head_emb = self
324            .entity_embeddings
325            .get(head_entity)
326            .ok_or_else(|| anyhow!("Unknown head entity: {}", head_entity))?;
327
328        let tail_emb = self
329            .entity_embeddings
330            .get(tail_entity)
331            .ok_or_else(|| anyhow!("Unknown tail entity: {}", tail_entity))?;
332
333        // Score all possible relations
334        let mut predictions: Vec<RelationPrediction> = self
335            .relation_embeddings
336            .par_iter()
337            .map(|(rel, rel_emb)| {
338                // TransE-style scoring: h + r ≈ t
339                let score = self.score_triple(head_emb, rel_emb, tail_emb);
340
341                RelationPrediction {
342                    relation: rel.clone(),
343                    tail_entity: Some(tail_entity.to_string()),
344                    score,
345                    confidence: score,
346                }
347            })
348            .filter(|pred| pred.score >= self.config.score_threshold)
349            .collect();
350
351        // Sort by score descending
352        predictions.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
353        predictions.truncate(self.config.max_predictions);
354
355        Ok(predictions)
356    }
357
358    /// Predict tail entities for a given head and relation
359    pub fn predict_tails(
360        &self,
361        head_entity: &str,
362        relation: &str,
363    ) -> Result<Vec<RelationPrediction>> {
364        let head_emb = self
365            .entity_embeddings
366            .get(head_entity)
367            .ok_or_else(|| anyhow!("Unknown head entity: {}", head_entity))?;
368
369        let rel_emb = self
370            .relation_embeddings
371            .get(relation)
372            .ok_or_else(|| anyhow!("Unknown relation: {}", relation))?;
373
374        // Compute expected tail embedding: t = h + r
375        let expected_tail = head_emb + rel_emb;
376
377        // Find nearest entities to expected tail
378        let mut predictions: Vec<RelationPrediction> = self
379            .entity_embeddings
380            .par_iter()
381            .map(|(entity, entity_emb)| {
382                let distance = Self::euclidean_distance(&expected_tail, entity_emb);
383                let score = 1.0 / (1.0 + distance); // Convert distance to score
384
385                RelationPrediction {
386                    relation: relation.to_string(),
387                    tail_entity: Some(entity.clone()),
388                    score,
389                    confidence: score,
390                }
391            })
392            .filter(|pred| pred.score >= self.config.score_threshold)
393            .collect();
394
395        predictions.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
396        predictions.truncate(self.config.max_predictions);
397
398        Ok(predictions)
399    }
400
401    /// Score a triple using TransE-style scoring
402    fn score_triple(&self, head: &Array1<f32>, relation: &Array1<f32>, tail: &Array1<f32>) -> f32 {
403        // TransE: score = -||h + r - t||
404        let expected_tail = head + relation;
405        let distance = Self::euclidean_distance(&expected_tail, tail);
406
407        // Convert to similarity score (higher is better)
408        1.0 / (1.0 + distance)
409    }
410
411    /// Compute Euclidean distance
412    fn euclidean_distance(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
413        let diff = a - b;
414        diff.dot(&diff).sqrt()
415    }
416
417    /// Batch prediction of tails
418    pub fn predict_tails_batch(
419        &self,
420        queries: &[(String, String)], // (head, relation) pairs
421    ) -> Result<Vec<Vec<RelationPrediction>>> {
422        let results: Vec<Vec<RelationPrediction>> = queries
423            .par_iter()
424            .map(|(head, rel)| self.predict_tails(head, rel).unwrap_or_default())
425            .collect();
426
427        Ok(results)
428    }
429}
430
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    #[test]
    fn test_entity_linker_creation() {
        let mut embeddings = HashMap::new();
        embeddings.insert("entity1".to_string(), array![0.1, 0.2, 0.3]);
        embeddings.insert("entity2".to_string(), array![0.4, 0.5, 0.6]);

        let config = EntityLinkerConfig::default();
        let linker = EntityLinker::new(config, embeddings);
        assert!(linker.is_ok());
    }

    #[test]
    fn test_entity_linker_rejects_empty_set() {
        let linker = EntityLinker::new(EntityLinkerConfig::default(), HashMap::new());
        assert!(linker.is_err());
    }

    #[test]
    fn test_entity_linking() {
        let mut embeddings = HashMap::new();
        embeddings.insert("entity1".to_string(), array![1.0, 0.0, 0.0]);
        embeddings.insert("entity2".to_string(), array![0.0, 1.0, 0.0]);
        embeddings.insert("entity3".to_string(), array![0.7, 0.7, 0.0]);

        let config = EntityLinkerConfig {
            similarity_threshold: 0.5,
            ..Default::default()
        };

        let linker = EntityLinker::new(config, embeddings).unwrap();

        // Query similar to entity1
        let query = array![0.9, 0.1, 0.0];
        let results = linker.link_entity(&query, None).unwrap();

        assert!(!results.is_empty());
        assert_eq!(results[0].entity_id, "entity1");
    }

    #[test]
    fn test_zero_norm_query_is_rejected() {
        let mut embeddings = HashMap::new();
        embeddings.insert("entity1".to_string(), array![1.0, 0.0, 0.0]);

        let linker = EntityLinker::new(EntityLinkerConfig::default(), embeddings).unwrap();
        assert!(linker.link_entity(&array![0.0, 0.0, 0.0], None).is_err());
    }

    #[test]
    fn test_get_embedding() {
        let mut embeddings = HashMap::new();
        embeddings.insert("entity1".to_string(), array![1.0, 0.0, 0.0]);

        let linker = EntityLinker::new(EntityLinkerConfig::default(), embeddings).unwrap();
        assert!(linker.get_embedding("entity1").is_some());
        assert!(linker.get_embedding("missing").is_none());
    }

    #[test]
    fn test_relation_predictor_creation() {
        let mut entity_embeddings = HashMap::new();
        entity_embeddings.insert("entity1".to_string(), array![0.1, 0.2, 0.3]);

        let mut relation_embeddings = HashMap::new();
        relation_embeddings.insert("rel1".to_string(), array![0.1, 0.1, 0.1]);

        let config = RelationPredictorConfig::default();
        let predictor = RelationPredictor::new(config, relation_embeddings, entity_embeddings);

        // Just verify creation succeeds
        assert_eq!(predictor.relation_embeddings.len(), 1);
    }

    /// Entities `h`, `t`, `other` and relations `r_match`, `r_other` where
    /// `h + r_match == t` exactly.
    fn translation_predictor() -> RelationPredictor {
        let mut entity_embeddings = HashMap::new();
        entity_embeddings.insert("h".to_string(), array![0.0, 0.0, 0.0]);
        entity_embeddings.insert("t".to_string(), array![1.0, 0.0, 0.0]);
        entity_embeddings.insert("other".to_string(), array![0.0, 1.0, 0.0]);

        let mut relation_embeddings = HashMap::new();
        relation_embeddings.insert("r_match".to_string(), array![1.0, 0.0, 0.0]);
        relation_embeddings.insert("r_other".to_string(), array![0.0, 1.0, 0.0]);

        RelationPredictor::new(
            RelationPredictorConfig::default(),
            relation_embeddings,
            entity_embeddings,
        )
    }

    #[test]
    fn test_predict_relations_exact_translation() {
        let predictor = translation_predictor();
        let predictions = predictor.predict_relations("h", "t").unwrap();

        // r_other lands at distance sqrt(2) -> score ~0.414 < 0.6, so it is filtered.
        assert_eq!(predictions.len(), 1);
        assert_eq!(predictions[0].relation, "r_match");
        assert_eq!(predictions[0].tail_entity.as_deref(), Some("t"));
        assert!((predictions[0].score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_predict_tails_finds_translated_entity() {
        let predictor = translation_predictor();
        let predictions = predictor.predict_tails("h", "r_match").unwrap();

        // h scores 0.5 and other ~0.414; both fall below the 0.6 threshold.
        assert_eq!(predictions.len(), 1);
        assert_eq!(predictions[0].tail_entity.as_deref(), Some("t"));
        assert!((predictions[0].score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_unknown_entities_and_relations_error() {
        let predictor = translation_predictor();
        assert!(predictor.predict_relations("missing", "t").is_err());
        assert!(predictor.predict_relations("h", "missing").is_err());
        assert!(predictor.predict_tails("h", "missing").is_err());
    }

    #[test]
    fn test_predict_tails_batch_is_best_effort() {
        let predictor = translation_predictor();
        let queries = vec![
            ("h".to_string(), "r_match".to_string()),
            ("missing".to_string(), "r_match".to_string()),
        ];

        let results = predictor.predict_tails_batch(&queries).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0][0].tail_entity.as_deref(), Some("t"));
        assert!(results[1].is_empty());
    }

    #[test]
    fn test_batch_entity_linking() {
        let mut embeddings = HashMap::new();
        embeddings.insert("entity1".to_string(), array![1.0, 0.0, 0.0]);
        embeddings.insert("entity2".to_string(), array![0.0, 1.0, 0.0]);

        let config = EntityLinkerConfig::default();
        let linker = EntityLinker::new(config, embeddings).unwrap();

        let queries = vec![array![0.9, 0.1, 0.0], array![0.1, 0.9, 0.0]];

        let results = linker.link_entities_batch(&queries).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0][0].entity_id, "entity1");
        assert_eq!(results[1][0].entity_id, "entity2");
    }
}