// oxirs_embed/models/hole.rs
1//! HolE (Holographic Embeddings) Model
2//!
3//! Holographic Embeddings use circular correlation to combine entity and relation
4//! representations. This allows for efficient computation while maintaining expressiveness.
5//!
6//! Reference: Nickel, Rosasco, Poggio. "Holographic Embeddings of Knowledge Graphs." AAAI 2016.
7//!
8//! The scoring function is: f(h,r,t) = σ(r^T (h ★ t))
9//! where ★ denotes circular correlation
10
11use anyhow::{anyhow, Result};
12use rayon::prelude::*;
13use scirs2_core::ndarray_ext::{Array1, ArrayView1};
14use scirs2_core::random::Random;
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use std::fs::File;
18use std::io::{BufReader, BufWriter};
19use std::path::Path;
20use tracing::{debug, info};
21
22use crate::{EmbeddingModel, ModelConfig, ModelStats, NamedNode, TrainingStats, Triple, Vector};
23use uuid::Uuid;
24
/// HolE model configuration
///
/// Extends the shared [`ModelConfig`] with HolE-specific training
/// hyperparameters (weight decay, ranking margin, negative sampling,
/// score squashing).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HoLEConfig {
    /// Base model configuration (dimensions, learning rate, max epochs, ...)
    pub base: ModelConfig,
    /// L2 regularization coefficient (applied as weight decay during updates)
    pub regularization: f32,
    /// Margin for the pairwise ranking loss: max(0, margin + neg - pos)
    pub margin: f32,
    /// Number of negative samples generated per positive triple
    pub num_negatives: usize,
    /// When true, raw scores are passed through a sigmoid, bounding them to (0, 1)
    pub use_sigmoid: bool,
}
39
40impl Default for HoLEConfig {
41    fn default() -> Self {
42        Self {
43            base: ModelConfig::default(),
44            regularization: 0.0001,
45            margin: 1.0,
46            num_negatives: 10,
47            use_sigmoid: true,
48        }
49    }
50}
51
/// Serializable representation of HolE model for persistence
///
/// Mirrors [`HoLE`] but stores embeddings as plain `Vec<f32>` so the whole
/// struct can round-trip through serde-compatible binary encoders.
#[derive(Debug, Serialize, Deserialize)]
struct HoLESerializable {
    /// Unique identifier of the model instance this snapshot came from
    model_id: Uuid,
    /// Full HolE configuration (base config + HolE hyperparameters)
    config: HoLEConfig,
    /// Entity IRI -> embedding values (`Array1` flattened to `Vec` for serde)
    entity_embeddings: HashMap<String, Vec<f32>>,
    /// Relation IRI -> embedding values
    relation_embeddings: HashMap<String, Vec<f32>>,
    /// Training triples the model was built from
    triples: Vec<Triple>,
    /// Entity IRI -> dense id (assigned in insertion order)
    entity_to_id: HashMap<String, usize>,
    /// Relation IRI -> dense id
    relation_to_id: HashMap<String, usize>,
    /// Inverse of `entity_to_id`
    id_to_entity: HashMap<usize, String>,
    /// Inverse of `relation_to_id`
    id_to_relation: HashMap<usize, String>,
    /// Whether `train` has completed on this model
    is_trained: bool,
}
66
/// HolE (Holographic Embeddings) model
///
/// Uses circular correlation to combine entity embeddings and relation embeddings.
/// Efficient and expressive for knowledge graph completion tasks.
pub struct HoLE {
    /// Unique identifier generated at construction time
    model_id: Uuid,
    /// Model configuration (base + HolE hyperparameters)
    config: HoLEConfig,
    /// Entity IRI -> embedding vector of `config.base.dimensions` components
    entity_embeddings: HashMap<String, Array1<f32>>,
    /// Relation IRI -> embedding vector
    relation_embeddings: HashMap<String, Array1<f32>>,
    /// All training triples added via `add_triple`
    triples: Vec<Triple>,
    /// Entity IRI -> dense id (assigned in insertion order)
    entity_to_id: HashMap<String, usize>,
    /// Relation IRI -> dense id
    relation_to_id: HashMap<String, usize>,
    /// Inverse of `entity_to_id`
    id_to_entity: HashMap<usize, String>,
    /// Inverse of `relation_to_id`
    id_to_relation: HashMap<usize, String>,
    /// Set to true once `train` completes
    is_trained: bool,
}
83
84impl HoLE {
85    /// Create new HolE model with configuration
86    pub fn new(config: HoLEConfig) -> Self {
87        info!(
88            "Initialized HolE model with dimensions={}, learning_rate={}",
89            config.base.dimensions, config.base.learning_rate
90        );
91
92        Self {
93            model_id: Uuid::new_v4(),
94            config,
95            entity_embeddings: HashMap::new(),
96            relation_embeddings: HashMap::new(),
97            triples: Vec::new(),
98            entity_to_id: HashMap::new(),
99            relation_to_id: HashMap::new(),
100            id_to_entity: HashMap::new(),
101            id_to_relation: HashMap::new(),
102            is_trained: false,
103        }
104    }
105
106    /// Circular correlation of two vectors
107    ///
108    /// The circular correlation is computed via FFT for efficiency:
109    /// a ★ b = IFFT(conj(FFT(a)) ⊙ FFT(b))
110    ///
111    /// For simplicity, we use the direct definition here.
112    fn circular_correlation(&self, a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
113        let n = a.len();
114        let mut result = Array1::zeros(n);
115
116        for k in 0..n {
117            let mut sum = 0.0;
118            for i in 0..n {
119                let j = (i + k) % n;
120                sum += a[i] * b[j];
121            }
122            result[k] = sum;
123        }
124
125        result
126    }
127
128    /// Compute the score for a triple (h, r, t)
129    ///
130    /// f(h,r,t) = σ(r^T (h ★ t))
131    fn score_triple_internal(
132        &self,
133        head: &ArrayView1<f32>,
134        relation: &ArrayView1<f32>,
135        tail: &ArrayView1<f32>,
136    ) -> f32 {
137        // Compute circular correlation: h ★ t
138        let correlation = self.circular_correlation(head, tail);
139
140        // Compute dot product: r^T (h ★ t)
141        let score = relation.dot(&correlation);
142
143        // Apply sigmoid if configured
144        if self.config.use_sigmoid {
145            1.0 / (1.0 + (-score).exp())
146        } else {
147            score
148        }
149    }
150
151    /// Initialize embeddings for an entity
152    fn init_entity(&mut self, entity: &str) {
153        if !self.entity_embeddings.contains_key(entity) {
154            let id = self.entity_embeddings.len();
155            self.entity_to_id.insert(entity.to_string(), id);
156            self.id_to_entity.insert(id, entity.to_string());
157
158            // Initialize with uniform distribution scaled by 1/sqrt(d)
159            let scale = 1.0 / (self.config.base.dimensions as f32).sqrt();
160            let mut local_rng = Random::default();
161            let embedding = Array1::from_vec(
162                (0..self.config.base.dimensions)
163                    .map(|_| local_rng.gen_range(-scale..scale))
164                    .collect(),
165            );
166            self.entity_embeddings.insert(entity.to_string(), embedding);
167        }
168    }
169
170    /// Initialize embeddings for a relation
171    fn init_relation(&mut self, relation: &str) {
172        if !self.relation_embeddings.contains_key(relation) {
173            let id = self.relation_embeddings.len();
174            self.relation_to_id.insert(relation.to_string(), id);
175            self.id_to_relation.insert(id, relation.to_string());
176
177            // Initialize with uniform distribution scaled by 1/sqrt(d)
178            let scale = 1.0 / (self.config.base.dimensions as f32).sqrt();
179            let mut local_rng = Random::default();
180            let embedding = Array1::from_vec(
181                (0..self.config.base.dimensions)
182                    .map(|_| local_rng.gen_range(-scale..scale))
183                    .collect(),
184            );
185            self.relation_embeddings
186                .insert(relation.to_string(), embedding);
187        }
188    }
189
190    /// Generate negative samples by corrupting subject or object
191    fn generate_negative_samples(&mut self, triple: &Triple) -> Vec<Triple> {
192        let mut negatives = Vec::new();
193        let entity_list: Vec<String> = self.entity_embeddings.keys().cloned().collect();
194        let mut local_rng = Random::default();
195
196        for _ in 0..self.config.num_negatives {
197            // Randomly corrupt subject or object
198            if local_rng.gen_range(0.0..1.0) < 0.5 {
199                // Corrupt subject
200                let random_subject =
201                    entity_list[local_rng.random_range(0..entity_list.len())].clone();
202                negatives.push(Triple {
203                    subject: NamedNode::new(&random_subject)
204                        .expect("NamedNode creation should succeed for valid entity"),
205                    predicate: triple.predicate.clone(),
206                    object: triple.object.clone(),
207                });
208            } else {
209                // Corrupt object
210                let random_object =
211                    entity_list[local_rng.random_range(0..entity_list.len())].clone();
212                negatives.push(Triple {
213                    subject: triple.subject.clone(),
214                    predicate: triple.predicate.clone(),
215                    object: NamedNode::new(&random_object)
216                        .expect("NamedNode creation should succeed for valid entity"),
217                });
218            }
219        }
220
221        negatives
222    }
223
224    /// Perform one training step with margin-based ranking loss
225    fn train_step(&mut self) -> f32 {
226        let mut total_loss = 0.0;
227        let mut local_rng = Random::default();
228
229        // Shuffle triples for stochastic gradient descent
230        let mut indices: Vec<usize> = (0..self.triples.len()).collect();
231        for i in (1..indices.len()).rev() {
232            let j = local_rng.random_range(0..i + 1);
233            indices.swap(i, j);
234        }
235
236        for &idx in &indices {
237            let triple = &self.triples[idx].clone();
238
239            // Get embeddings
240            let subject_str = &triple.subject.iri;
241            let predicate_str = &triple.predicate.iri;
242            let object_str = &triple.object.iri;
243
244            let head_emb = self.entity_embeddings[subject_str].clone();
245            let rel_emb = self.relation_embeddings[predicate_str].clone();
246            let tail_emb = self.entity_embeddings[object_str].clone();
247
248            // Positive score
249            let pos_score =
250                self.score_triple_internal(&head_emb.view(), &rel_emb.view(), &tail_emb.view());
251
252            // Generate negative samples
253            let negatives = self.generate_negative_samples(triple);
254
255            for neg_triple in &negatives {
256                let neg_subject_str = &neg_triple.subject.iri;
257                let neg_object_str = &neg_triple.object.iri;
258
259                let neg_head_emb = self.entity_embeddings[neg_subject_str].clone();
260                let neg_tail_emb = self.entity_embeddings[neg_object_str].clone();
261
262                // Negative score
263                let neg_score = self.score_triple_internal(
264                    &neg_head_emb.view(),
265                    &rel_emb.view(),
266                    &neg_tail_emb.view(),
267                );
268
269                // Margin ranking loss: max(0, margin + neg_score - pos_score)
270                let loss = (self.config.margin + neg_score - pos_score).max(0.0);
271
272                if loss > 0.0 {
273                    total_loss += loss;
274
275                    // Compute gradients and update embeddings
276                    // For simplicity, we use a basic gradient update
277                    // In practice, more sophisticated optimizers should be used
278
279                    let lr = self.config.base.learning_rate as f32;
280
281                    // Update entity embeddings
282                    if let Some(head) = self.entity_embeddings.get_mut(subject_str) {
283                        *head = &*head * (1.0 - self.config.regularization * lr);
284                    }
285
286                    if let Some(tail) = self.entity_embeddings.get_mut(object_str) {
287                        *tail = &*tail * (1.0 - self.config.regularization * lr);
288                    }
289
290                    if let Some(neg_head) = self.entity_embeddings.get_mut(neg_subject_str) {
291                        *neg_head = &*neg_head * (1.0 - self.config.regularization * lr);
292                    }
293
294                    if let Some(neg_tail) = self.entity_embeddings.get_mut(neg_object_str) {
295                        *neg_tail = &*neg_tail * (1.0 - self.config.regularization * lr);
296                    }
297
298                    // Update relation embeddings
299                    if let Some(rel) = self.relation_embeddings.get_mut(predicate_str) {
300                        *rel = &*rel * (1.0 - self.config.regularization * lr);
301                    }
302                }
303            }
304        }
305
306        total_loss / (self.triples.len() as f32 * self.config.num_negatives as f32)
307    }
308}
309
310#[async_trait::async_trait]
311impl EmbeddingModel for HoLE {
312    fn config(&self) -> &ModelConfig {
313        &self.config.base
314    }
315
316    fn model_id(&self) -> &Uuid {
317        &self.model_id
318    }
319
320    fn model_type(&self) -> &'static str {
321        "HoLE"
322    }
323
324    fn add_triple(&mut self, triple: Triple) -> Result<()> {
325        // Initialize embeddings for new entities/relations
326        self.init_entity(&triple.subject.iri);
327        self.init_entity(&triple.object.iri);
328        self.init_relation(&triple.predicate.iri);
329
330        self.triples.push(triple);
331        Ok(())
332    }
333
334    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
335        let num_epochs = epochs.unwrap_or(self.config.base.max_epochs);
336
337        if self.triples.is_empty() {
338            return Err(anyhow!("No training data available"));
339        }
340
341        info!(
342            "Training HoLE model for {} epochs on {} triples",
343            num_epochs,
344            self.triples.len()
345        );
346
347        let start_time = std::time::Instant::now();
348        let mut loss_history = Vec::new();
349
350        for epoch in 0..num_epochs {
351            let loss = self.train_step();
352            loss_history.push(loss as f64);
353
354            if epoch % 10 == 0 {
355                debug!("Epoch {}/{}: loss = {:.6}", epoch + 1, num_epochs, loss);
356            }
357
358            // Check for convergence
359            if loss < 0.001 {
360                info!("Converged at epoch {}", epoch);
361                break;
362            }
363        }
364
365        let training_time = start_time.elapsed().as_secs_f64();
366        self.is_trained = true;
367
368        Ok(TrainingStats {
369            epochs_completed: num_epochs,
370            final_loss: *loss_history.last().unwrap_or(&0.0),
371            training_time_seconds: training_time,
372            convergence_achieved: loss_history.last().unwrap_or(&1.0) < &0.001,
373            loss_history,
374        })
375    }
376
377    fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
378        self.entity_embeddings
379            .get(entity)
380            .map(Vector::from_array1)
381            .ok_or_else(|| anyhow!("Unknown entity: {}", entity))
382    }
383
384    fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
385        self.relation_embeddings
386            .get(relation)
387            .map(Vector::from_array1)
388            .ok_or_else(|| anyhow!("Unknown relation: {}", relation))
389    }
390
391    fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
392        let head_emb = self
393            .entity_embeddings
394            .get(subject)
395            .ok_or_else(|| anyhow!("Unknown subject: {}", subject))?;
396        let rel_emb = self
397            .relation_embeddings
398            .get(predicate)
399            .ok_or_else(|| anyhow!("Unknown predicate: {}", predicate))?;
400        let tail_emb = self
401            .entity_embeddings
402            .get(object)
403            .ok_or_else(|| anyhow!("Unknown object: {}", object))?;
404
405        let score = self.score_triple_internal(&head_emb.view(), &rel_emb.view(), &tail_emb.view());
406        Ok(score as f64)
407    }
408
409    fn predict_objects(
410        &self,
411        subject: &str,
412        predicate: &str,
413        k: usize,
414    ) -> Result<Vec<(String, f64)>> {
415        let head_emb = self
416            .entity_embeddings
417            .get(subject)
418            .ok_or_else(|| anyhow!("Unknown subject: {}", subject))?;
419        let rel_emb = self
420            .relation_embeddings
421            .get(predicate)
422            .ok_or_else(|| anyhow!("Unknown predicate: {}", predicate))?;
423
424        let mut scored_objects: Vec<(String, f64)> = self
425            .entity_embeddings
426            .par_iter()
427            .map(|(entity, tail_emb)| {
428                let score =
429                    self.score_triple_internal(&head_emb.view(), &rel_emb.view(), &tail_emb.view());
430                (entity.clone(), score as f64)
431            })
432            .collect();
433
434        scored_objects.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
435        scored_objects.truncate(k);
436        Ok(scored_objects)
437    }
438
439    fn predict_subjects(
440        &self,
441        predicate: &str,
442        object: &str,
443        k: usize,
444    ) -> Result<Vec<(String, f64)>> {
445        let rel_emb = self
446            .relation_embeddings
447            .get(predicate)
448            .ok_or_else(|| anyhow!("Unknown predicate: {}", predicate))?;
449        let tail_emb = self
450            .entity_embeddings
451            .get(object)
452            .ok_or_else(|| anyhow!("Unknown object: {}", object))?;
453
454        let mut scored_subjects: Vec<(String, f64)> = self
455            .entity_embeddings
456            .par_iter()
457            .map(|(entity, head_emb)| {
458                let score =
459                    self.score_triple_internal(&head_emb.view(), &rel_emb.view(), &tail_emb.view());
460                (entity.clone(), score as f64)
461            })
462            .collect();
463
464        scored_subjects.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
465        scored_subjects.truncate(k);
466        Ok(scored_subjects)
467    }
468
469    fn predict_relations(
470        &self,
471        subject: &str,
472        object: &str,
473        k: usize,
474    ) -> Result<Vec<(String, f64)>> {
475        let head_emb = self
476            .entity_embeddings
477            .get(subject)
478            .ok_or_else(|| anyhow!("Unknown subject: {}", subject))?;
479        let tail_emb = self
480            .entity_embeddings
481            .get(object)
482            .ok_or_else(|| anyhow!("Unknown object: {}", object))?;
483
484        let mut scored_relations: Vec<(String, f64)> = self
485            .relation_embeddings
486            .par_iter()
487            .map(|(relation, rel_emb)| {
488                let score =
489                    self.score_triple_internal(&head_emb.view(), &rel_emb.view(), &tail_emb.view());
490                (relation.clone(), score as f64)
491            })
492            .collect();
493
494        scored_relations.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
495        scored_relations.truncate(k);
496        Ok(scored_relations)
497    }
498
499    fn get_entities(&self) -> Vec<String> {
500        self.entity_embeddings.keys().cloned().collect()
501    }
502
503    fn get_relations(&self) -> Vec<String> {
504        self.relation_embeddings.keys().cloned().collect()
505    }
506
507    fn get_stats(&self) -> ModelStats {
508        ModelStats {
509            num_entities: self.entity_embeddings.len(),
510            num_relations: self.relation_embeddings.len(),
511            num_triples: self.triples.len(),
512            dimensions: self.config.base.dimensions,
513            is_trained: self.is_trained,
514            model_type: "HoLE".to_string(),
515            creation_time: chrono::Utc::now(),
516            last_training_time: if self.is_trained {
517                Some(chrono::Utc::now())
518            } else {
519                None
520            },
521        }
522    }
523
524    fn save(&self, path: &str) -> Result<()> {
525        info!("Saving HolE model to {}", path);
526
527        // Convert Array1 to Vec for serialization
528        let entity_embeddings_vec: HashMap<String, Vec<f32>> = self
529            .entity_embeddings
530            .iter()
531            .map(|(k, v)| (k.clone(), v.to_vec()))
532            .collect();
533
534        let relation_embeddings_vec: HashMap<String, Vec<f32>> = self
535            .relation_embeddings
536            .iter()
537            .map(|(k, v)| (k.clone(), v.to_vec()))
538            .collect();
539
540        let serializable = HoLESerializable {
541            model_id: self.model_id,
542            config: self.config.clone(),
543            entity_embeddings: entity_embeddings_vec,
544            relation_embeddings: relation_embeddings_vec,
545            triples: self.triples.clone(),
546            entity_to_id: self.entity_to_id.clone(),
547            relation_to_id: self.relation_to_id.clone(),
548            id_to_entity: self.id_to_entity.clone(),
549            id_to_relation: self.id_to_relation.clone(),
550            is_trained: self.is_trained,
551        };
552
553        let file = File::create(path)?;
554        let writer = BufWriter::new(file);
555        oxicode::serde::encode_into_std_write(&serializable, writer, oxicode::config::standard())
556            .map_err(|e| anyhow!("Failed to serialize model: {}", e))?;
557
558        info!("Model saved successfully");
559        Ok(())
560    }
561
562    fn load(&mut self, path: &str) -> Result<()> {
563        info!("Loading HolE model from {}", path);
564
565        if !Path::new(path).exists() {
566            return Err(anyhow!("Model file not found: {}", path));
567        }
568
569        let file = File::open(path)?;
570        let reader = BufReader::new(file);
571        let (serializable, _): (HoLESerializable, _) =
572            oxicode::serde::decode_from_std_read(reader, oxicode::config::standard())
573                .map_err(|e| anyhow!("Failed to deserialize model: {}", e))?;
574
575        // Convert Vec back to Array1
576        let entity_embeddings: HashMap<String, Array1<f32>> = serializable
577            .entity_embeddings
578            .into_iter()
579            .map(|(k, v)| (k, Array1::from_vec(v)))
580            .collect();
581
582        let relation_embeddings: HashMap<String, Array1<f32>> = serializable
583            .relation_embeddings
584            .into_iter()
585            .map(|(k, v)| (k, Array1::from_vec(v)))
586            .collect();
587
588        // Update model state
589        self.model_id = serializable.model_id;
590        self.config = serializable.config;
591        self.entity_embeddings = entity_embeddings;
592        self.relation_embeddings = relation_embeddings;
593        self.triples = serializable.triples;
594        self.entity_to_id = serializable.entity_to_id;
595        self.relation_to_id = serializable.relation_to_id;
596        self.id_to_entity = serializable.id_to_entity;
597        self.id_to_relation = serializable.id_to_relation;
598        self.is_trained = serializable.is_trained;
599
600        info!("Model loaded successfully");
601        Ok(())
602    }
603
604    fn clear(&mut self) {
605        self.entity_embeddings.clear();
606        self.relation_embeddings.clear();
607        self.triples.clear();
608        self.entity_to_id.clear();
609        self.relation_to_id.clear();
610        self.id_to_entity.clear();
611        self.id_to_relation.clear();
612        self.is_trained = false;
613    }
614
615    fn is_trained(&self) -> bool {
616        self.is_trained
617    }
618
619    async fn encode(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
620        // TODO: Implement text encoding
621        Err(anyhow!("Text encoding not implemented for HoLE"))
622    }
623}
624
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    /// Build a test triple from plain IRI strings.
    fn t(s: &str, p: &str, o: &str) -> Triple {
        Triple::new(
            NamedNode::new(s).unwrap(),
            NamedNode::new(p).unwrap(),
            NamedNode::new(o).unwrap(),
        )
    }

    #[test]
    fn test_circular_correlation() {
        let model = HoLE::new(HoLEConfig::default());

        let a = array![1.0, 2.0, 3.0];
        let b = array![4.0, 5.0, 6.0];
        let result = model.circular_correlation(&a.view(), &b.view());

        // [a ★ b]_k = Σ_i a[i] * b[(i + k) % 3]:
        // k=0: 1*4 + 2*5 + 3*6 = 32
        // k=1: 1*5 + 2*6 + 3*4 = 29
        // k=2: 1*6 + 2*4 + 3*5 = 29
        let expected = [32.0f32, 29.0, 29.0];

        assert_eq!(result.len(), 3);
        for (got, want) in result.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-5);
        }
    }

    #[test]
    fn test_hole_creation() {
        let model = HoLE::new(HoLEConfig::default());

        // A fresh model has no registered entities or relations
        assert_eq!(model.entity_embeddings.len(), 0);
        assert_eq!(model.relation_embeddings.len(), 0);
    }

    #[tokio::test]
    async fn test_hole_training() {
        let config = HoLEConfig {
            base: ModelConfig {
                dimensions: 50,
                learning_rate: 0.01,
                max_epochs: 50,
                ..Default::default()
            },
            ..Default::default()
        };
        let mut model = HoLE::new(config);

        for triple in [
            t("alice", "knows", "bob"),
            t("bob", "knows", "charlie"),
            t("alice", "likes", "charlie"),
        ] {
            model.add_triple(triple).unwrap();
        }

        let stats = model.train(Some(50)).await.unwrap();

        assert_eq!(stats.epochs_completed, 50);
        assert!(stats.final_loss >= 0.0);
        assert!(stats.training_time_seconds > 0.0);

        // Every distinct entity/relation from the triples gets an embedding
        assert_eq!(model.entity_embeddings.len(), 3);
        assert_eq!(model.relation_embeddings.len(), 2);

        // With use_sigmoid, scores stay inside (0, 1)
        let score = model.score_triple("alice", "knows", "bob").unwrap();
        assert!((0.0..=1.0).contains(&score));
    }

    #[tokio::test]
    async fn test_hole_ranking() {
        let config = HoLEConfig {
            base: ModelConfig {
                dimensions: 50,
                max_epochs: 30,
                ..Default::default()
            },
            ..Default::default()
        };
        let mut model = HoLE::new(config);

        for triple in [
            t("alice", "knows", "bob"),
            t("alice", "knows", "charlie"),
        ] {
            model.add_triple(triple).unwrap();
        }

        model.train(Some(30)).await.unwrap();

        let ranked = model.predict_objects("alice", "knows", 2).unwrap();

        // At most k results, sorted by descending score
        assert!(ranked.len() <= 2);
        if let [first, second, ..] = ranked.as_slice() {
            assert!(first.1 >= second.1);
        }
    }

    #[tokio::test]
    async fn test_hole_save_load() {
        use std::env::temp_dir;

        let config = HoLEConfig {
            base: ModelConfig {
                dimensions: 30,
                max_epochs: 20,
                ..Default::default()
            },
            ..Default::default()
        };
        let mut model = HoLE::new(config);

        for triple in [
            t("alice", "knows", "bob"),
            t("bob", "likes", "charlie"),
        ] {
            model.add_triple(triple).unwrap();
        }

        model.train(Some(20)).await.unwrap();

        // Capture state before persisting
        let emb_before = model.get_entity_embedding("alice").unwrap();
        let score_before = model.score_triple("alice", "knows", "bob").unwrap();

        // Round-trip through disk
        let model_path = temp_dir().join("test_hole_model.bin");
        let path_str = model_path.to_str().unwrap();
        model.save(path_str).unwrap();

        let mut loaded_model = HoLE::new(HoLEConfig::default());
        loaded_model.load(path_str).unwrap();

        assert!(loaded_model.is_trained());
        assert_eq!(loaded_model.get_entities().len(), 3);
        assert_eq!(loaded_model.get_relations().len(), 2);

        // Embeddings survive the round-trip component-for-component
        let emb_after = loaded_model.get_entity_embedding("alice").unwrap();
        assert_eq!(emb_before.dimensions, emb_after.dimensions);
        for (before, after) in emb_before.values.iter().zip(emb_after.values.iter()) {
            assert!((before - after).abs() < 1e-6);
        }

        // Scoring stays consistent after reload
        let score_after = loaded_model.score_triple("alice", "knows", "bob").unwrap();
        assert!((score_before - score_after).abs() < 1e-6);

        std::fs::remove_file(model_path).ok();
    }

    #[test]
    fn test_hole_load_nonexistent() {
        let mut model = HoLE::new(HoLEConfig::default());
        assert!(model.load("/nonexistent/path/model.bin").is_err());
    }
}