// oxirs_embed/models/kg_embeddings.rs
1//! Knowledge Graph Embedding algorithms for link prediction.
2//!
3//! These algorithms learn dense vector representations of entities and
4//! relations in a knowledge graph to predict missing triples. All math
5//! is implemented with plain `Vec<f64>` to keep compilation simple and
6//! avoid external numerical-library dependencies in this file.
7//!
8//! Implemented algorithms:
9//! * **TransE**   – translating embeddings; score = −‖h + r − t‖₂
10//! * **DistMult** – bilinear diagonal model; score = Σ(hᵢ · rᵢ · tᵢ)
11//! * **RotatE**   – complex-space rotation; score = −‖h ∘ r − t‖₂
12
13use std::collections::HashMap;
14use std::fmt;
15
16// ---------------------------------------------------------------------------
17// Error type
18// ---------------------------------------------------------------------------
19
/// Errors produced by KG embedding operations.
#[derive(Debug)]
pub enum KgError {
    /// The model has not been trained yet (no embedding tables exist).
    NotTrained,
    /// An entity ID is out of range for the entity embedding table.
    UnknownEntity(EntityId),
    /// A relation ID is out of range for the relation embedding table.
    UnknownRelation(RelationId),
    /// The embedding dimension is zero or otherwise invalid.
    InvalidDimension,
    /// No triples were provided for training.
    NoTrainingData,
    /// A numerical issue occurred (NaN / Inf); the payload describes where.
    NumericalError(String),
    /// A `top_k` argument of zero was passed to a prediction method.
    InvalidTopK,
}
38
39impl fmt::Display for KgError {
40    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41        match self {
42            KgError::NotTrained => write!(f, "model has not been trained"),
43            KgError::UnknownEntity(id) => write!(f, "unknown entity id {id}"),
44            KgError::UnknownRelation(id) => write!(f, "unknown relation id {id}"),
45            KgError::InvalidDimension => write!(f, "embedding dimension must be > 0"),
46            KgError::NoTrainingData => write!(f, "no training triples provided"),
47            KgError::NumericalError(msg) => write!(f, "numerical error: {msg}"),
48            KgError::InvalidTopK => write!(f, "top_k must be > 0"),
49        }
50    }
51}
52
// Marker impl: `KgError` wraps no underlying source error, so the default
// `std::error::Error` methods are sufficient.
impl std::error::Error for KgError {}

/// Convenience alias: `Result` specialised to [`KgError`].
pub type KgResult<T> = Result<T, KgError>;
57
58// ---------------------------------------------------------------------------
59// Core types
60// ---------------------------------------------------------------------------
61
/// Index of an entity in the embedding table.
pub type EntityId = usize;
/// Index of a relation in the embedding table.
pub type RelationId = usize;

/// A single (head, relation, tail) triple.
///
/// All three components are plain indices, so the type derives `Copy` and
/// can be passed by value without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct KgTriple {
    /// Head-entity index.
    pub head: EntityId,
    /// Relation index.
    pub relation: RelationId,
    /// Tail-entity index.
    pub tail: EntityId,
}

impl KgTriple {
    /// Construct a new triple from raw indices.
    pub fn new(head: EntityId, relation: RelationId, tail: EntityId) -> Self {
        Self {
            head,
            relation,
            tail,
        }
    }
}
85
/// Hyper-parameters shared by every KG embedding trainer in this module.
#[derive(Debug, Clone)]
pub struct KgEmbeddingConfig {
    /// Dimensionality of entity and relation vectors (e.g. 50, 100, 200).
    pub embedding_dim: usize,
    /// Step size used by SGD.
    pub learning_rate: f64,
    /// How many full passes over the training triples to run.
    pub num_epochs: usize,
    /// Mini-batch size.
    pub batch_size: usize,
    /// Negative samples generated per positive triple.
    pub neg_samples: usize,
    /// Margin γ for the max-margin / hinge loss (primarily TransE).
    pub margin: f64,
    /// L2 regularisation coefficient.
    pub regularization: f64,
    /// Fixed seed for reproducibility (drives the internal LCG).
    pub seed: u64,
}

impl Default for KgEmbeddingConfig {
    /// Small-scale defaults: 50-dim vectors, 100 epochs, margin 1.0.
    fn default() -> Self {
        KgEmbeddingConfig {
            seed: 42,
            embedding_dim: 50,
            num_epochs: 100,
            batch_size: 32,
            neg_samples: 1,
            learning_rate: 0.01,
            margin: 1.0,
            regularization: 1e-4,
        }
    }
}
121
/// Trained embedding tables together with string→id look-ups.
#[derive(Debug, Clone)]
pub struct KgEmbeddings {
    /// Entity embedding matrix: `entity_embeddings[entity_id]` is a `dim`-vector.
    pub entity_embeddings: Vec<Vec<f64>>,
    /// Relation embedding matrix: `relation_embeddings[relation_id]` is a `dim`-vector.
    pub relation_embeddings: Vec<Vec<f64>>,
    /// Map from entity string to numeric id.
    /// NOTE(review): the trainers in this file leave this map empty; a caller
    /// that assigns string names is expected to populate it.
    pub entity_to_id: HashMap<String, EntityId>,
    /// Map from relation string to numeric id (also left empty by trainers).
    pub relation_to_id: HashMap<String, RelationId>,
}
134
/// Loss and convergence information collected during training.
#[derive(Debug, Clone)]
pub struct TrainingHistory {
    /// Mean loss recorded at the end of each epoch (one entry per epoch).
    pub losses: Vec<f64>,
    /// Loss at the final epoch (0.0 when no epochs ran).
    pub final_loss: f64,
    /// Total number of epochs that were actually run.
    pub epochs_trained: usize,
}
145
146// ---------------------------------------------------------------------------
147// Trait
148// ---------------------------------------------------------------------------
149
/// Shared interface for all KG embedding models.
pub trait KgModel {
    /// Score a (head, relation, tail) triple. Higher ⟹ more plausible.
    ///
    /// # Errors
    /// Implementations in this file return [`KgError::NotTrained`] before
    /// training, and [`KgError::UnknownEntity`] / [`KgError::UnknownRelation`]
    /// for out-of-range ids.
    fn score(&self, triple: &KgTriple) -> KgResult<f64>;

    /// Rank all entities as possible tails; returns top-`k` (entity, score)
    /// pairs, sorted by descending score in this file's implementations.
    ///
    /// # Errors
    /// `top_k == 0` yields [`KgError::InvalidTopK`].
    fn predict_tail(
        &self,
        head: EntityId,
        relation: RelationId,
        top_k: usize,
    ) -> KgResult<Vec<(EntityId, f64)>>;

    /// Rank all entities as possible heads; returns top-`k` (entity, score)
    /// pairs, sorted by descending score in this file's implementations.
    ///
    /// # Errors
    /// `top_k == 0` yields [`KgError::InvalidTopK`].
    fn predict_head(
        &self,
        relation: RelationId,
        tail: EntityId,
        top_k: usize,
    ) -> KgResult<Vec<(EntityId, f64)>>;
}
171
172// ---------------------------------------------------------------------------
173// Minimal deterministic pseudo-random number generator (LCG)
174// ---------------------------------------------------------------------------
175
/// A tiny, dependency-free LCG used for reproducible weight initialisation
/// and negative-sample corruption. Not suitable for cryptographic use.
#[derive(Debug, Clone)]
struct Lcg {
    state: u64,
}

impl Lcg {
    /// Knuth's MMIX multiplier / increment pair.
    const MUL: u64 = 6_364_136_223_846_793_005;
    const INC: u64 = 1_442_695_040_888_963_407;

    /// Seed the generator; `+1` avoids the all-zero state for seed 0.
    fn new(seed: u64) -> Self {
        Lcg {
            state: seed.wrapping_add(1),
        }
    }

    /// Advance the state and return a uniform value in `[0, 1)` built from
    /// the top 53 bits (the f64 mantissa width).
    fn next_f64(&mut self) -> f64 {
        self.state = self.state.wrapping_mul(Self::MUL).wrapping_add(Self::INC);
        (self.state >> 11) as f64 / (1u64 << 53) as f64
    }

    /// Uniform integer in `[0, n)`. `n` must be non-zero — the trailing
    /// `% n` (a guard against float rounding up to `n`) panics on zero.
    fn next_usize(&mut self, n: usize) -> usize {
        (self.next_f64() * n as f64) as usize % n
    }
}
204
205// ---------------------------------------------------------------------------
206// Helper math on plain Vec<f64>
207// ---------------------------------------------------------------------------
208
/// Euclidean (L2) norm of `v`; 0.0 for the empty slice.
fn l2_norm(v: &[f64]) -> f64 {
    let sum_sq = v.iter().fold(0.0_f64, |acc, x| acc + x * x);
    sum_sq.sqrt()
}
212
/// Euclidean distance between `a` and `b` (pairs up to the shorter slice).
fn l2_dist(a: &[f64], b: &[f64]) -> f64 {
    let mut sum = 0.0_f64;
    for (x, y) in a.iter().zip(b) {
        let d = x - y;
        sum += d * d;
    }
    sum.sqrt()
}
220
/// Inner product of `a` and `b` (pairs up to the shorter slice).
fn dot(a: &[f64], b: &[f64]) -> f64 {
    let mut acc = 0.0_f64;
    for (x, y) in a.iter().zip(b.iter()) {
        acc += x * y;
    }
    acc
}
224
/// Clamp every component of `v` into `[lo, hi]`, in place.
fn clamp_vec(v: &mut [f64], lo: f64, hi: f64) {
    v.iter_mut().for_each(|x| *x = x.clamp(lo, hi));
}
230
/// Scale `v` to unit L2 norm, in place. Near-zero vectors (norm ≤ 1e-12)
/// are left untouched to avoid dividing by (almost) zero.
fn normalize_vec(v: &mut [f64]) {
    let norm = v.iter().map(|x| x * x).sum::<f64>().sqrt();
    if norm <= 1e-12 {
        return;
    }
    v.iter_mut().for_each(|x| *x /= norm);
}
239
240// ---------------------------------------------------------------------------
241// Negative sampling
242// ---------------------------------------------------------------------------
243
244/// Generate a corrupted triple by replacing head OR tail randomly.
245fn corrupt_triple(
246    triple: &KgTriple,
247    num_entities: usize,
248    positive_set: &std::collections::HashSet<(usize, usize, usize)>,
249    rng: &mut Lcg,
250) -> KgTriple {
251    // Try up to 20 times to find a genuinely negative triple.
252    for _ in 0..20 {
253        let corrupt_head = rng.next_usize(2) == 0;
254        let candidate = if corrupt_head {
255            let new_head = rng.next_usize(num_entities);
256            KgTriple::new(new_head, triple.relation, triple.tail)
257        } else {
258            let new_tail = rng.next_usize(num_entities);
259            KgTriple::new(triple.head, triple.relation, new_tail)
260        };
261        if !positive_set.contains(&(candidate.head, candidate.relation, candidate.tail)) {
262            return candidate;
263        }
264    }
265    // Fallback: return candidate even if it overlaps (rare with large entity sets).
266    let new_tail = (triple.tail + 1) % num_entities;
267    KgTriple::new(triple.head, triple.relation, new_tail)
268}
269
270// ---------------------------------------------------------------------------
271// TransE
272// ---------------------------------------------------------------------------
273
/// **TransE** – translating embeddings.
///
/// The scoring function is `−‖h + r − t‖₂`.
/// Training uses max-margin (hinge) loss with SGD and entity-norm projection.
#[derive(Debug, Clone)]
pub struct TransE {
    /// Hyper-parameters used by `train`.
    pub config: KgEmbeddingConfig,
    /// Learned tables; `None` until `train` completes successfully.
    pub embeddings: Option<KgEmbeddings>,
    // Table sizes recorded by `train`; 0 until trained.
    num_entities: usize,
    num_relations: usize,
}
285
286impl TransE {
287    /// Create a new, untrained TransE model.
288    pub fn new(config: KgEmbeddingConfig) -> Self {
289        Self {
290            config,
291            embeddings: None,
292            num_entities: 0,
293            num_relations: 0,
294        }
295    }
296
297    /// Train on the provided triples.
298    ///
299    /// `num_entities` and `num_relations` define the size of the embedding
300    /// tables; IDs in `triples` must be in `[0, num_entities)` /
301    /// `[0, num_relations)`.
302    pub fn train(
303        &mut self,
304        triples: &[KgTriple],
305        num_entities: usize,
306        num_relations: usize,
307    ) -> KgResult<TrainingHistory> {
308        if triples.is_empty() {
309            return Err(KgError::NoTrainingData);
310        }
311        if self.config.embedding_dim == 0 {
312            return Err(KgError::InvalidDimension);
313        }
314        self.num_entities = num_entities;
315        self.num_relations = num_relations;
316
317        let dim = self.config.embedding_dim;
318        let mut rng = Lcg::new(self.config.seed);
319
320        // Initialise entity embeddings uniformly in [-6/√dim, 6/√dim].
321        let bound = 6.0 / (dim as f64).sqrt();
322        let mut ent_emb: Vec<Vec<f64>> = (0..num_entities)
323            .map(|_| {
324                let mut v: Vec<f64> = (0..dim)
325                    .map(|_| (rng.next_f64() * 2.0 - 1.0) * bound)
326                    .collect();
327                normalize_vec(&mut v);
328                v
329            })
330            .collect();
331
332        // Initialise relation embeddings uniformly in [-6/√dim, 6/√dim].
333        let mut rel_emb: Vec<Vec<f64>> = (0..num_relations)
334            .map(|_| {
335                (0..dim)
336                    .map(|_| (rng.next_f64() * 2.0 - 1.0) * bound)
337                    .collect()
338            })
339            .collect();
340
341        // Build positive-triple look-up for negative sampling.
342        let positive_set: std::collections::HashSet<(usize, usize, usize)> = triples
343            .iter()
344            .map(|t| (t.head, t.relation, t.tail))
345            .collect();
346
347        let lr = self.config.learning_rate;
348        let margin = self.config.margin;
349        let reg = self.config.regularization;
350        let mut losses = Vec::with_capacity(self.config.num_epochs);
351
352        for _epoch in 0..self.config.num_epochs {
353            let mut epoch_loss = 0.0_f64;
354            let mut count = 0usize;
355
356            for pos in triples {
357                for _ in 0..self.config.neg_samples {
358                    let neg = corrupt_triple(pos, num_entities, &positive_set, &mut rng);
359
360                    let h_pos = &ent_emb[pos.head];
361                    let r = &rel_emb[pos.relation];
362                    let t_pos = &ent_emb[pos.tail];
363                    let h_neg = &ent_emb[neg.head];
364                    let t_neg = &ent_emb[neg.tail];
365
366                    // Compute h+r−t for positive and negative.
367                    let pos_diff: Vec<f64> = (0..dim).map(|i| h_pos[i] + r[i] - t_pos[i]).collect();
368                    let neg_diff: Vec<f64> = (0..dim).map(|i| h_neg[i] + r[i] - t_neg[i]).collect();
369
370                    let d_pos = l2_norm(&pos_diff);
371                    let d_neg = l2_norm(&neg_diff);
372
373                    let loss = (margin + d_pos - d_neg).max(0.0);
374                    epoch_loss += loss;
375                    count += 1;
376
377                    if loss > 0.0 {
378                        // Gradient of L2 norm: ∂‖v‖/∂vᵢ = vᵢ/‖v‖.
379                        let grad_pos: Vec<f64> = if d_pos > 1e-12 {
380                            pos_diff.iter().map(|x| x / d_pos).collect()
381                        } else {
382                            vec![0.0; dim]
383                        };
384                        let grad_neg: Vec<f64> = if d_neg > 1e-12 {
385                            neg_diff.iter().map(|x| x / d_neg).collect()
386                        } else {
387                            vec![0.0; dim]
388                        };
389
390                        // Update positive triple components.
391                        for i in 0..dim {
392                            let g = grad_pos[i];
393                            ent_emb[pos.head][i] -= lr * (g + reg * ent_emb[pos.head][i]);
394                            rel_emb[pos.relation][i] -= lr * (g + reg * rel_emb[pos.relation][i]);
395                            ent_emb[pos.tail][i] += lr * (g - reg * ent_emb[pos.tail][i]);
396                        }
397
398                        // Update negative triple components.
399                        for i in 0..dim {
400                            let g = grad_neg[i];
401                            ent_emb[neg.head][i] += lr * (g + reg * ent_emb[neg.head][i]);
402                            ent_emb[neg.tail][i] -= lr * (g - reg * ent_emb[neg.tail][i]);
403                        }
404                    }
405
406                    // Project entity embeddings back to unit sphere.
407                    normalize_vec(&mut ent_emb[pos.head]);
408                    normalize_vec(&mut ent_emb[pos.tail]);
409                    normalize_vec(&mut ent_emb[neg.head]);
410                    normalize_vec(&mut ent_emb[neg.tail]);
411                }
412            }
413
414            let mean_loss = if count > 0 {
415                epoch_loss / count as f64
416            } else {
417                0.0
418            };
419            losses.push(mean_loss);
420        }
421
422        let final_loss = losses.last().copied().unwrap_or(0.0);
423        let epochs_trained = losses.len();
424
425        self.embeddings = Some(KgEmbeddings {
426            entity_embeddings: ent_emb,
427            relation_embeddings: rel_emb,
428            entity_to_id: HashMap::new(),
429            relation_to_id: HashMap::new(),
430        });
431
432        Ok(TrainingHistory {
433            losses,
434            final_loss,
435            epochs_trained,
436        })
437    }
438
439    /// Score a triple: −‖h + r − t‖₂  (higher = more plausible).
440    pub fn score(&self, triple: &KgTriple) -> KgResult<f64> {
441        let emb = self.embeddings.as_ref().ok_or(KgError::NotTrained)?;
442        let h = emb
443            .entity_embeddings
444            .get(triple.head)
445            .ok_or(KgError::UnknownEntity(triple.head))?;
446        let r = emb
447            .relation_embeddings
448            .get(triple.relation)
449            .ok_or(KgError::UnknownRelation(triple.relation))?;
450        let t = emb
451            .entity_embeddings
452            .get(triple.tail)
453            .ok_or(KgError::UnknownEntity(triple.tail))?;
454
455        Ok(-Self::score_fn(h, r, t))
456    }
457
458    /// Rank all entities as candidate tails; return top-`k`.
459    pub fn predict_tail(
460        &self,
461        head: EntityId,
462        relation: RelationId,
463        top_k: usize,
464    ) -> KgResult<Vec<(EntityId, f64)>> {
465        if top_k == 0 {
466            return Err(KgError::InvalidTopK);
467        }
468        let emb = self.embeddings.as_ref().ok_or(KgError::NotTrained)?;
469        let h = emb
470            .entity_embeddings
471            .get(head)
472            .ok_or(KgError::UnknownEntity(head))?;
473        let r = emb
474            .relation_embeddings
475            .get(relation)
476            .ok_or(KgError::UnknownRelation(relation))?;
477
478        let mut scored: Vec<(EntityId, f64)> = emb
479            .entity_embeddings
480            .iter()
481            .enumerate()
482            .map(|(id, t)| (id, -Self::score_fn(h, r, t)))
483            .collect();
484
485        scored.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
486        scored.truncate(top_k);
487        Ok(scored)
488    }
489
490    /// Rank all entities as candidate heads; return top-`k`.
491    pub fn predict_head(
492        &self,
493        relation: RelationId,
494        tail: EntityId,
495        top_k: usize,
496    ) -> KgResult<Vec<(EntityId, f64)>> {
497        if top_k == 0 {
498            return Err(KgError::InvalidTopK);
499        }
500        let emb = self.embeddings.as_ref().ok_or(KgError::NotTrained)?;
501        let r = emb
502            .relation_embeddings
503            .get(relation)
504            .ok_or(KgError::UnknownRelation(relation))?;
505        let t = emb
506            .entity_embeddings
507            .get(tail)
508            .ok_or(KgError::UnknownEntity(tail))?;
509
510        let mut scored: Vec<(EntityId, f64)> = emb
511            .entity_embeddings
512            .iter()
513            .enumerate()
514            .map(|(id, h)| (id, -Self::score_fn(h, r, t)))
515            .collect();
516
517        scored.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
518        scored.truncate(top_k);
519        Ok(scored)
520    }
521
522    /// Project all entity embeddings onto the unit sphere.
523    pub fn normalize_entities(&mut self) {
524        if let Some(ref mut emb) = self.embeddings {
525            for v in emb.entity_embeddings.iter_mut() {
526                normalize_vec(v);
527            }
528        }
529    }
530
531    /// TransE distance: ‖h + r − t‖₂.
532    fn score_fn(h: &[f64], r: &[f64], t: &[f64]) -> f64 {
533        let diff: Vec<f64> = (0..h.len()).map(|i| h[i] + r[i] - t[i]).collect();
534        l2_norm(&diff)
535    }
536}
537
// Trait plumbing: every `KgModel` method forwards to the inherent method of
// the same name. Inherent methods win method resolution, so the calls below
// do not recurse into the trait impl.
impl KgModel for TransE {
    fn score(&self, triple: &KgTriple) -> KgResult<f64> {
        self.score(triple)
    }

    fn predict_tail(
        &self,
        head: EntityId,
        relation: RelationId,
        top_k: usize,
    ) -> KgResult<Vec<(EntityId, f64)>> {
        self.predict_tail(head, relation, top_k)
    }

    fn predict_head(
        &self,
        relation: RelationId,
        tail: EntityId,
        top_k: usize,
    ) -> KgResult<Vec<(EntityId, f64)>> {
        self.predict_head(relation, tail, top_k)
    }
}
561
562// ---------------------------------------------------------------------------
563// DistMult
564// ---------------------------------------------------------------------------
565
/// **DistMult** – bilinear diagonal scoring.
///
/// Scoring: `Σ(hᵢ · rᵢ · tᵢ)`.
/// Trained with softplus (logistic) loss and SGD.
#[derive(Debug, Clone)]
pub struct DistMult {
    /// Hyper-parameters used by `train`.
    pub config: KgEmbeddingConfig,
    /// Learned tables; `None` until `train` completes successfully.
    pub embeddings: Option<KgEmbeddings>,
    // Table sizes recorded by `train`; 0 until trained.
    num_entities: usize,
    num_relations: usize,
}
577
578impl DistMult {
579    /// Create a new, untrained DistMult model.
580    pub fn new(config: KgEmbeddingConfig) -> Self {
581        Self {
582            config,
583            embeddings: None,
584            num_entities: 0,
585            num_relations: 0,
586        }
587    }
588
589    /// Train using negative-sampling and softplus loss.
590    pub fn train(
591        &mut self,
592        triples: &[KgTriple],
593        num_entities: usize,
594        num_relations: usize,
595    ) -> KgResult<TrainingHistory> {
596        if triples.is_empty() {
597            return Err(KgError::NoTrainingData);
598        }
599        if self.config.embedding_dim == 0 {
600            return Err(KgError::InvalidDimension);
601        }
602        self.num_entities = num_entities;
603        self.num_relations = num_relations;
604
605        let dim = self.config.embedding_dim;
606        let mut rng = Lcg::new(self.config.seed);
607        let bound = 1.0 / (dim as f64).sqrt();
608
609        let mut ent_emb: Vec<Vec<f64>> = (0..num_entities)
610            .map(|_| {
611                (0..dim)
612                    .map(|_| (rng.next_f64() * 2.0 - 1.0) * bound)
613                    .collect()
614            })
615            .collect();
616        let mut rel_emb: Vec<Vec<f64>> = (0..num_relations)
617            .map(|_| {
618                (0..dim)
619                    .map(|_| (rng.next_f64() * 2.0 - 1.0) * bound)
620                    .collect()
621            })
622            .collect();
623
624        let positive_set: std::collections::HashSet<(usize, usize, usize)> = triples
625            .iter()
626            .map(|t| (t.head, t.relation, t.tail))
627            .collect();
628
629        let lr = self.config.learning_rate;
630        let reg = self.config.regularization;
631        let mut losses = Vec::with_capacity(self.config.num_epochs);
632
633        for _epoch in 0..self.config.num_epochs {
634            let mut epoch_loss = 0.0_f64;
635            let mut count = 0usize;
636
637            for pos in triples {
638                // Positive sample loss: −log σ(score_pos)
639                {
640                    let s = Self::score_fn(
641                        &ent_emb[pos.head],
642                        &rel_emb[pos.relation],
643                        &ent_emb[pos.tail],
644                    );
645                    let sig = sigmoid(s);
646                    let loss = -sig.ln().max(-100.0);
647                    epoch_loss += loss;
648                    count += 1;
649
650                    // Gradient: -(1 - σ(s)) · ∂s/∂params
651                    let g = -(1.0 - sig);
652                    for i in 0..dim {
653                        let h_i = ent_emb[pos.head][i];
654                        let r_i = rel_emb[pos.relation][i];
655                        let t_i = ent_emb[pos.tail][i];
656                        ent_emb[pos.head][i] -= lr * (g * r_i * t_i + reg * h_i);
657                        rel_emb[pos.relation][i] -= lr * (g * h_i * t_i + reg * r_i);
658                        ent_emb[pos.tail][i] -= lr * (g * h_i * r_i + reg * t_i);
659                    }
660                    clamp_vec(&mut ent_emb[pos.head], -10.0, 10.0);
661                    clamp_vec(&mut rel_emb[pos.relation], -10.0, 10.0);
662                    clamp_vec(&mut ent_emb[pos.tail], -10.0, 10.0);
663                }
664
665                for _ in 0..self.config.neg_samples {
666                    let neg = corrupt_triple(pos, num_entities, &positive_set, &mut rng);
667                    let s = Self::score_fn(
668                        &ent_emb[neg.head],
669                        &rel_emb[neg.relation],
670                        &ent_emb[neg.tail],
671                    );
672                    let sig = sigmoid(-s);
673                    let loss = -sig.ln().max(-100.0);
674                    epoch_loss += loss;
675                    count += 1;
676
677                    let g = 1.0 - sig; // ∂loss/∂s = σ(s) - 0
678                    for i in 0..dim {
679                        let h_i = ent_emb[neg.head][i];
680                        let r_i = rel_emb[neg.relation][i];
681                        let t_i = ent_emb[neg.tail][i];
682                        ent_emb[neg.head][i] -= lr * (g * r_i * t_i + reg * h_i);
683                        rel_emb[neg.relation][i] -= lr * (g * h_i * t_i + reg * r_i);
684                        ent_emb[neg.tail][i] -= lr * (g * h_i * r_i + reg * t_i);
685                    }
686                    clamp_vec(&mut ent_emb[neg.head], -10.0, 10.0);
687                    clamp_vec(&mut ent_emb[neg.tail], -10.0, 10.0);
688                }
689            }
690
691            let mean_loss = if count > 0 {
692                epoch_loss / count as f64
693            } else {
694                0.0
695            };
696            losses.push(mean_loss);
697        }
698
699        let final_loss = losses.last().copied().unwrap_or(0.0);
700        let epochs_trained = losses.len();
701
702        self.embeddings = Some(KgEmbeddings {
703            entity_embeddings: ent_emb,
704            relation_embeddings: rel_emb,
705            entity_to_id: HashMap::new(),
706            relation_to_id: HashMap::new(),
707        });
708
709        Ok(TrainingHistory {
710            losses,
711            final_loss,
712            epochs_trained,
713        })
714    }
715
716    /// Score a triple: Σ(hᵢ · rᵢ · tᵢ).
717    pub fn score(&self, triple: &KgTriple) -> KgResult<f64> {
718        let emb = self.embeddings.as_ref().ok_or(KgError::NotTrained)?;
719        let h = emb
720            .entity_embeddings
721            .get(triple.head)
722            .ok_or(KgError::UnknownEntity(triple.head))?;
723        let r = emb
724            .relation_embeddings
725            .get(triple.relation)
726            .ok_or(KgError::UnknownRelation(triple.relation))?;
727        let t = emb
728            .entity_embeddings
729            .get(triple.tail)
730            .ok_or(KgError::UnknownEntity(triple.tail))?;
731        Ok(Self::score_fn(h, r, t))
732    }
733
734    /// Rank all entities as candidate tails; return top-`k`.
735    pub fn predict_tail(
736        &self,
737        head: EntityId,
738        relation: RelationId,
739        top_k: usize,
740    ) -> KgResult<Vec<(EntityId, f64)>> {
741        if top_k == 0 {
742            return Err(KgError::InvalidTopK);
743        }
744        let emb = self.embeddings.as_ref().ok_or(KgError::NotTrained)?;
745        let h = emb
746            .entity_embeddings
747            .get(head)
748            .ok_or(KgError::UnknownEntity(head))?;
749        let r = emb
750            .relation_embeddings
751            .get(relation)
752            .ok_or(KgError::UnknownRelation(relation))?;
753
754        let mut scored: Vec<(EntityId, f64)> = emb
755            .entity_embeddings
756            .iter()
757            .enumerate()
758            .map(|(id, t)| (id, Self::score_fn(h, r, t)))
759            .collect();
760        scored.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
761        scored.truncate(top_k);
762        Ok(scored)
763    }
764
765    /// Rank all entities as candidate heads; return top-`k`.
766    pub fn predict_head(
767        &self,
768        relation: RelationId,
769        tail: EntityId,
770        top_k: usize,
771    ) -> KgResult<Vec<(EntityId, f64)>> {
772        if top_k == 0 {
773            return Err(KgError::InvalidTopK);
774        }
775        let emb = self.embeddings.as_ref().ok_or(KgError::NotTrained)?;
776        let r = emb
777            .relation_embeddings
778            .get(relation)
779            .ok_or(KgError::UnknownRelation(relation))?;
780        let t = emb
781            .entity_embeddings
782            .get(tail)
783            .ok_or(KgError::UnknownEntity(tail))?;
784
785        let mut scored: Vec<(EntityId, f64)> = emb
786            .entity_embeddings
787            .iter()
788            .enumerate()
789            .map(|(id, h)| (id, Self::score_fn(h, r, t)))
790            .collect();
791        scored.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
792        scored.truncate(top_k);
793        Ok(scored)
794    }
795
796    /// DistMult scoring: Σ(hᵢ · rᵢ · tᵢ).
797    fn score_fn(h: &[f64], r: &[f64], t: &[f64]) -> f64 {
798        h.iter()
799            .zip(r.iter())
800            .zip(t.iter())
801            .map(|((hi, ri), ti)| hi * ri * ti)
802            .sum()
803    }
804}
805
// Trait plumbing: every `KgModel` method forwards to the inherent method of
// the same name. Inherent methods win method resolution, so the calls below
// do not recurse into the trait impl.
impl KgModel for DistMult {
    fn score(&self, triple: &KgTriple) -> KgResult<f64> {
        self.score(triple)
    }

    fn predict_tail(
        &self,
        head: EntityId,
        relation: RelationId,
        top_k: usize,
    ) -> KgResult<Vec<(EntityId, f64)>> {
        self.predict_tail(head, relation, top_k)
    }

    fn predict_head(
        &self,
        relation: RelationId,
        tail: EntityId,
        top_k: usize,
    ) -> KgResult<Vec<(EntityId, f64)>> {
        self.predict_head(relation, tail, top_k)
    }
}
829
830// ---------------------------------------------------------------------------
831// RotatE
832// ---------------------------------------------------------------------------
833
/// **RotatE** – knowledge-graph embedding by relational rotation in ℂ.
///
/// Each entity is embedded in ℂ^(d/2); each relation is a phase vector
/// θ ∈ ℝ^(d/2).  Scoring: `−‖h ∘ r − t‖` where `∘` is element-wise
/// complex multiplication with r as a unit-modulus complex number
/// (e^{iθ}).
#[derive(Debug, Clone)]
pub struct RotatE {
    /// Hyper-parameters used by `train`.
    pub config: KgEmbeddingConfig,
    /// Real parts of entity embeddings: `real[entity_id][k]`, with
    /// `k < ⌈embedding_dim / 2⌉` complex components per entity.
    pub entity_re: Option<Vec<Vec<f64>>>,
    /// Imaginary parts of entity embeddings: `imag[entity_id][k]`.
    pub entity_im: Option<Vec<Vec<f64>>>,
    /// Relation phase vectors: `phases[relation_id][k]`, initialised in (−π, π).
    pub relation_phases: Option<Vec<Vec<f64>>>,
    // Table sizes recorded by `train`; 0 until trained.
    num_entities: usize,
    num_relations: usize,
}
852
853impl RotatE {
    /// Create a new, untrained RotatE model.
    ///
    /// All embedding tables start as `None`; they are populated by `train`.
    pub fn new(config: KgEmbeddingConfig) -> Self {
        Self {
            config,
            entity_re: None,
            entity_im: None,
            relation_phases: None,
            num_entities: 0,
            num_relations: 0,
        }
    }
865
866    /// Train using negative-sampling and max-margin loss.
867    pub fn train(
868        &mut self,
869        triples: &[KgTriple],
870        num_entities: usize,
871        num_relations: usize,
872    ) -> KgResult<TrainingHistory> {
873        if triples.is_empty() {
874            return Err(KgError::NoTrainingData);
875        }
876        if self.config.embedding_dim == 0 {
877            return Err(KgError::InvalidDimension);
878        }
879
880        self.num_entities = num_entities;
881        self.num_relations = num_relations;
882
883        // Half-dimension: d/2 complex components per embedding.
884        let half_dim = (self.config.embedding_dim + 1) / 2;
885        let mut rng = Lcg::new(self.config.seed);
886        let pi = std::f64::consts::PI;
887
888        // Entity embeddings: unit-modulus complex (cos θ, sin θ).
889        let mut ent_re: Vec<Vec<f64>> = (0..num_entities)
890            .map(|_| (0..half_dim).map(|_| rng.next_f64() * 2.0 - 1.0).collect())
891            .collect();
892        let mut ent_im: Vec<Vec<f64>> = (0..num_entities)
893            .map(|_| (0..half_dim).map(|_| rng.next_f64() * 2.0 - 1.0).collect())
894            .collect();
895
896        // Normalise to unit modulus initially.
897        for i in 0..num_entities {
898            for k in 0..half_dim {
899                let norm = (ent_re[i][k].powi(2) + ent_im[i][k].powi(2))
900                    .sqrt()
901                    .max(1e-12);
902                ent_re[i][k] /= norm;
903                ent_im[i][k] /= norm;
904            }
905        }
906
907        // Relation phases in (−π, π).
908        let mut rel_phases: Vec<Vec<f64>> = (0..num_relations)
909            .map(|_| {
910                (0..half_dim)
911                    .map(|_| (rng.next_f64() * 2.0 - 1.0) * pi)
912                    .collect()
913            })
914            .collect();
915
916        let positive_set: std::collections::HashSet<(usize, usize, usize)> = triples
917            .iter()
918            .map(|t| (t.head, t.relation, t.tail))
919            .collect();
920
921        let lr = self.config.learning_rate;
922        let margin = self.config.margin;
923        let reg = self.config.regularization;
924        let mut losses = Vec::with_capacity(self.config.num_epochs);
925
926        for _epoch in 0..self.config.num_epochs {
927            let mut epoch_loss = 0.0_f64;
928            let mut count = 0usize;
929
930            for pos in triples {
931                for _ in 0..self.config.neg_samples {
932                    let neg = corrupt_triple(pos, num_entities, &positive_set, &mut rng);
933
934                    let d_pos = Self::dist_fn(
935                        &ent_re[pos.head],
936                        &ent_im[pos.head],
937                        &rel_phases[pos.relation],
938                        &ent_re[pos.tail],
939                        &ent_im[pos.tail],
940                    );
941                    let d_neg = Self::dist_fn(
942                        &ent_re[neg.head],
943                        &ent_im[neg.head],
944                        &rel_phases[neg.relation],
945                        &ent_re[neg.tail],
946                        &ent_im[neg.tail],
947                    );
948
949                    let loss = (margin + d_pos - d_neg).max(0.0);
950                    epoch_loss += loss;
951                    count += 1;
952
953                    if loss > 0.0 && d_pos > 1e-12 {
954                        // Gradient of ‖h∘r − t‖ w.r.t. phases and entity components.
955                        let r_re: Vec<f64> = rel_phases[pos.relation]
956                            .iter()
957                            .map(|&ph| ph.cos())
958                            .collect();
959                        let r_im: Vec<f64> = rel_phases[pos.relation]
960                            .iter()
961                            .map(|&ph| ph.sin())
962                            .collect();
963
964                        for k in 0..half_dim {
965                            let (res_re, res_im) = Self::complex_multiply(
966                                ent_re[pos.head][k],
967                                ent_im[pos.head][k],
968                                r_re[k],
969                                r_im[k],
970                            );
971                            let err_re = res_re - ent_re[pos.tail][k];
972                            let err_im = res_im - ent_im[pos.tail][k];
973
974                            // Gradients (positive sample – push apart).
975                            let g_scale = 1.0 / d_pos;
976
977                            // d(dist)/d(h_re_k)
978                            let d_h_re = g_scale * (err_re * r_re[k] + err_im * r_im[k]);
979                            // d(dist)/d(h_im_k)
980                            let d_h_im = g_scale * (err_im * r_re[k] - err_re * r_im[k]);
981                            // d(dist)/d(phase_k) = g * (-h_re*sin + h_im*cos)*err_re + (h_re*cos + h_im*(-sin))*err_im
982                            let d_ph = g_scale
983                                * ((-ent_re[pos.head][k] * r_im[k]
984                                    + ent_im[pos.head][k] * r_re[k])
985                                    * err_re
986                                    + (-ent_re[pos.head][k] * r_re[k]
987                                        - ent_im[pos.head][k] * r_im[k])
988                                        * err_im);
989                            // d(dist)/d(t_re_k)
990                            let d_t_re = g_scale * (-err_re);
991                            // d(dist)/d(t_im_k)
992                            let d_t_im = g_scale * (-err_im);
993
994                            ent_re[pos.head][k] -= lr * (d_h_re + reg * ent_re[pos.head][k]);
995                            ent_im[pos.head][k] -= lr * (d_h_im + reg * ent_im[pos.head][k]);
996                            rel_phases[pos.relation][k] -=
997                                lr * (d_ph + reg * rel_phases[pos.relation][k]);
998                            ent_re[pos.tail][k] -= lr * (d_t_re + reg * ent_re[pos.tail][k]);
999                            ent_im[pos.tail][k] -= lr * (d_t_im + reg * ent_im[pos.tail][k]);
1000                        }
1001
1002                        // Keep phases bounded in (−2π, 2π).
1003                        for ph in rel_phases[pos.relation].iter_mut() {
1004                            *ph = ph.clamp(-2.0 * pi, 2.0 * pi);
1005                        }
1006                    }
1007                }
1008            }
1009
1010            let mean_loss = if count > 0 {
1011                epoch_loss / count as f64
1012            } else {
1013                0.0
1014            };
1015            losses.push(mean_loss);
1016        }
1017
1018        let final_loss = losses.last().copied().unwrap_or(0.0);
1019        let epochs_trained = losses.len();
1020
1021        self.entity_re = Some(ent_re);
1022        self.entity_im = Some(ent_im);
1023        self.relation_phases = Some(rel_phases);
1024
1025        Ok(TrainingHistory {
1026            losses,
1027            final_loss,
1028            epochs_trained,
1029        })
1030    }
1031
1032    /// Score a triple: `−‖h ∘ e^{iθ} − t‖`.
1033    pub fn score(&self, triple: &KgTriple) -> KgResult<f64> {
1034        let ent_re = self.entity_re.as_ref().ok_or(KgError::NotTrained)?;
1035        let ent_im = self.entity_im.as_ref().ok_or(KgError::NotTrained)?;
1036        let phases = self.relation_phases.as_ref().ok_or(KgError::NotTrained)?;
1037
1038        let h_re = ent_re
1039            .get(triple.head)
1040            .ok_or(KgError::UnknownEntity(triple.head))?;
1041        let h_im = ent_im
1042            .get(triple.head)
1043            .ok_or(KgError::UnknownEntity(triple.head))?;
1044        let ph = phases
1045            .get(triple.relation)
1046            .ok_or(KgError::UnknownRelation(triple.relation))?;
1047        let t_re = ent_re
1048            .get(triple.tail)
1049            .ok_or(KgError::UnknownEntity(triple.tail))?;
1050        let t_im = ent_im
1051            .get(triple.tail)
1052            .ok_or(KgError::UnknownEntity(triple.tail))?;
1053
1054        Ok(-Self::dist_fn(h_re, h_im, ph, t_re, t_im))
1055    }
1056
1057    /// Rank all entities as candidate tails; return top-`k`.
1058    pub fn predict_tail(
1059        &self,
1060        head: EntityId,
1061        relation: RelationId,
1062        top_k: usize,
1063    ) -> KgResult<Vec<(EntityId, f64)>> {
1064        if top_k == 0 {
1065            return Err(KgError::InvalidTopK);
1066        }
1067        let ent_re = self.entity_re.as_ref().ok_or(KgError::NotTrained)?;
1068        let ent_im = self.entity_im.as_ref().ok_or(KgError::NotTrained)?;
1069        let phases = self.relation_phases.as_ref().ok_or(KgError::NotTrained)?;
1070
1071        let h_re = ent_re.get(head).ok_or(KgError::UnknownEntity(head))?;
1072        let h_im = ent_im.get(head).ok_or(KgError::UnknownEntity(head))?;
1073        let ph = phases
1074            .get(relation)
1075            .ok_or(KgError::UnknownRelation(relation))?;
1076
1077        let num = ent_re.len();
1078        let mut scored: Vec<(EntityId, f64)> = (0..num)
1079            .map(|id| (id, -Self::dist_fn(h_re, h_im, ph, &ent_re[id], &ent_im[id])))
1080            .collect();
1081        scored.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1082        scored.truncate(top_k);
1083        Ok(scored)
1084    }
1085
1086    /// Rank all entities as candidate heads; return top-`k`.
1087    pub fn predict_head(
1088        &self,
1089        relation: RelationId,
1090        tail: EntityId,
1091        top_k: usize,
1092    ) -> KgResult<Vec<(EntityId, f64)>> {
1093        if top_k == 0 {
1094            return Err(KgError::InvalidTopK);
1095        }
1096        let ent_re = self.entity_re.as_ref().ok_or(KgError::NotTrained)?;
1097        let ent_im = self.entity_im.as_ref().ok_or(KgError::NotTrained)?;
1098        let phases = self.relation_phases.as_ref().ok_or(KgError::NotTrained)?;
1099
1100        let ph = phases
1101            .get(relation)
1102            .ok_or(KgError::UnknownRelation(relation))?;
1103        let t_re = ent_re.get(tail).ok_or(KgError::UnknownEntity(tail))?;
1104        let t_im = ent_im.get(tail).ok_or(KgError::UnknownEntity(tail))?;
1105
1106        let num = ent_re.len();
1107        // For head prediction we find h such that h ∘ r ≈ t,
1108        // i.e. h ≈ t ∘ r̄ (conjugate rotation).
1109        // Score is still computed with the standard formula.
1110        let mut scored: Vec<(EntityId, f64)> = (0..num)
1111            .map(|id| (id, -Self::dist_fn(&ent_re[id], &ent_im[id], ph, t_re, t_im)))
1112            .collect();
1113        scored.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1114        scored.truncate(top_k);
1115        Ok(scored)
1116    }
1117
1118    /// Distance: ‖h ∘ e^{iθ} − t‖.
1119    fn dist_fn(h_re: &[f64], h_im: &[f64], phases: &[f64], t_re: &[f64], t_im: &[f64]) -> f64 {
1120        let sum_sq: f64 = phases
1121            .iter()
1122            .enumerate()
1123            .map(|(k, &ph)| {
1124                let (res_re, res_im) = Self::complex_multiply(h_re[k], h_im[k], ph.cos(), ph.sin());
1125                (res_re - t_re[k]).powi(2) + (res_im - t_im[k]).powi(2)
1126            })
1127            .sum();
1128        sum_sq.sqrt()
1129    }
1130
1131    /// Complex multiplication: (a + ib)(c + id) = (ac − bd) + i(ad + bc).
1132    pub fn complex_multiply(a_re: f64, a_im: f64, b_re: f64, b_im: f64) -> (f64, f64) {
1133        (a_re * b_re - a_im * b_im, a_re * b_im + a_im * b_re)
1134    }
1135}
1136
1137impl KgModel for RotatE {
1138    fn score(&self, triple: &KgTriple) -> KgResult<f64> {
1139        self.score(triple)
1140    }
1141
1142    fn predict_tail(
1143        &self,
1144        head: EntityId,
1145        relation: RelationId,
1146        top_k: usize,
1147    ) -> KgResult<Vec<(EntityId, f64)>> {
1148        self.predict_tail(head, relation, top_k)
1149    }
1150
1151    fn predict_head(
1152        &self,
1153        relation: RelationId,
1154        tail: EntityId,
1155        top_k: usize,
1156    ) -> KgResult<Vec<(EntityId, f64)>> {
1157        self.predict_head(relation, tail, top_k)
1158    }
1159}
1160
1161// ---------------------------------------------------------------------------
1162// Link prediction evaluation metrics
1163// ---------------------------------------------------------------------------
1164
1165/// Standard link-prediction evaluation metrics.
1166pub struct LinkPredictionEvaluator;
1167
1168impl LinkPredictionEvaluator {
1169    /// Hits@K: fraction of test triples for which the true tail entity
1170    /// appears in the top-`k` predictions.
1171    pub fn hits_at_k(model: &dyn KgModel, test_triples: &[KgTriple], k: usize) -> f64 {
1172        if test_triples.is_empty() || k == 0 {
1173            return 0.0;
1174        }
1175        let hits: usize = test_triples
1176            .iter()
1177            .filter(|t| {
1178                model
1179                    .predict_tail(t.head, t.relation, k)
1180                    .map(|preds| preds.iter().any(|(eid, _)| *eid == t.tail))
1181                    .unwrap_or(false)
1182            })
1183            .count();
1184        hits as f64 / test_triples.len() as f64
1185    }
1186
1187    /// Mean Rank: average rank of the true tail entity across all test triples
1188    /// (lower is better).
1189    pub fn mean_rank(model: &dyn KgModel, test_triples: &[KgTriple], num_entities: usize) -> f64 {
1190        if test_triples.is_empty() {
1191            return 0.0;
1192        }
1193        let total: usize = test_triples
1194            .iter()
1195            .map(|t| {
1196                model
1197                    .predict_tail(t.head, t.relation, num_entities)
1198                    .map(|preds| {
1199                        preds
1200                            .iter()
1201                            .position(|(eid, _)| *eid == t.tail)
1202                            .map(|p| p + 1)
1203                            .unwrap_or(num_entities + 1)
1204                    })
1205                    .unwrap_or(num_entities + 1)
1206            })
1207            .sum();
1208        total as f64 / test_triples.len() as f64
1209    }
1210
1211    /// Mean Reciprocal Rank: average of 1/rank for the true tail entity
1212    /// (higher is better; max = 1.0).
1213    pub fn mrr(model: &dyn KgModel, test_triples: &[KgTriple], num_entities: usize) -> f64 {
1214        if test_triples.is_empty() {
1215            return 0.0;
1216        }
1217        let sum: f64 = test_triples
1218            .iter()
1219            .map(|t| {
1220                model
1221                    .predict_tail(t.head, t.relation, num_entities)
1222                    .map(|preds| {
1223                        preds
1224                            .iter()
1225                            .position(|(eid, _)| *eid == t.tail)
1226                            .map(|p| 1.0 / (p as f64 + 1.0))
1227                            .unwrap_or(0.0)
1228                    })
1229                    .unwrap_or(0.0)
1230            })
1231            .sum();
1232        sum / test_triples.len() as f64
1233    }
1234}
1235
1236// ---------------------------------------------------------------------------
1237// Utilities
1238// ---------------------------------------------------------------------------
1239
1240/// Serialise `KgEmbeddings` to a simple CSV-like byte string.
1241///
1242/// Format:
1243/// ```text
1244/// ENTITIES <n>
1245/// <dim values per line>
1246/// RELATIONS <m>
1247/// <dim values per line>
1248/// ```
1249pub fn serialize_embeddings(emb: &KgEmbeddings) -> Vec<u8> {
1250    let mut out = String::new();
1251    out.push_str(&format!("ENTITIES {}\n", emb.entity_embeddings.len()));
1252    for row in &emb.entity_embeddings {
1253        let line: Vec<String> = row.iter().map(|x| format!("{x:.8}")).collect();
1254        out.push_str(&line.join(","));
1255        out.push('\n');
1256    }
1257    out.push_str(&format!("RELATIONS {}\n", emb.relation_embeddings.len()));
1258    for row in &emb.relation_embeddings {
1259        let line: Vec<String> = row.iter().map(|x| format!("{x:.8}")).collect();
1260        out.push_str(&line.join(","));
1261        out.push('\n');
1262    }
1263    out.into_bytes()
1264}
1265
1266/// Deserialise `KgEmbeddings` from the format produced by
1267/// [`serialize_embeddings`].
1268pub fn deserialize_embeddings(data: &[u8]) -> Result<KgEmbeddings, KgError> {
1269    let text = std::str::from_utf8(data)
1270        .map_err(|e| KgError::NumericalError(format!("utf8 error: {e}")))?;
1271    let mut lines = text.lines();
1272
1273    let parse_section_header = |line: &str, prefix: &str| -> Result<usize, KgError> {
1274        let rest = line
1275            .strip_prefix(prefix)
1276            .ok_or_else(|| KgError::NumericalError(format!("expected '{prefix}', got '{line}'")))?;
1277        rest.trim()
1278            .parse::<usize>()
1279            .map_err(|e| KgError::NumericalError(e.to_string()))
1280    };
1281
1282    let parse_row = |line: &str| -> Result<Vec<f64>, KgError> {
1283        line.split(',')
1284            .map(|s| {
1285                s.trim()
1286                    .parse::<f64>()
1287                    .map_err(|e| KgError::NumericalError(e.to_string()))
1288            })
1289            .collect()
1290    };
1291
1292    let ent_header = lines
1293        .next()
1294        .ok_or(KgError::NumericalError("empty data".into()))?;
1295    let num_ent = parse_section_header(ent_header, "ENTITIES ")?;
1296    let mut entity_embeddings = Vec::with_capacity(num_ent);
1297    for _ in 0..num_ent {
1298        let line = lines
1299            .next()
1300            .ok_or(KgError::NumericalError("truncated entity data".into()))?;
1301        entity_embeddings.push(parse_row(line)?);
1302    }
1303
1304    let rel_header = lines
1305        .next()
1306        .ok_or(KgError::NumericalError("missing RELATIONS header".into()))?;
1307    let num_rel = parse_section_header(rel_header, "RELATIONS ")?;
1308    let mut relation_embeddings = Vec::with_capacity(num_rel);
1309    for _ in 0..num_rel {
1310        let line = lines
1311            .next()
1312            .ok_or(KgError::NumericalError("truncated relation data".into()))?;
1313        relation_embeddings.push(parse_row(line)?);
1314    }
1315
1316    Ok(KgEmbeddings {
1317        entity_embeddings,
1318        relation_embeddings,
1319        entity_to_id: HashMap::new(),
1320        relation_to_id: HashMap::new(),
1321    })
1322}
1323
1324// ---------------------------------------------------------------------------
1325// Private sigmoid
1326// ---------------------------------------------------------------------------
1327
/// Logistic sigmoid: maps ℝ monotonically onto (0, 1).
#[inline]
fn sigmoid(x: f64) -> f64 {
    // `recip` is exactly `1.0 / (…)`, so this matches the textbook form.
    ((-x).exp() + 1.0).recip()
}
1332
1333// ===========================================================================
1334// Tests
1335// ===========================================================================
1336
1337// ===========================================================================
1338// Tests (in separate file to keep this file under 2000 lines)
1339// ===========================================================================
1340
// Unit tests live in `kg_embeddings_tests.rs` (same directory), pulled in
// via `#[path]` so this implementation file stays shorter.
#[cfg(test)]
#[path = "kg_embeddings_tests.rs"]
mod tests;