// oxirs_embed/models/hole.rs

//! HolE (Holographic Embeddings) Model
//!
//! Holographic Embeddings use circular correlation to combine entity and relation
//! representations. This allows for efficient computation while maintaining expressiveness.
//!
//! Reference: Nickel, Rosasco, Poggio. "Holographic Embeddings of Knowledge Graphs." AAAI 2016.
//!
//! The scoring function is: f(h,r,t) = σ(r^T (h ★ t))
//! where ★ denotes circular correlation

use anyhow::{anyhow, Result};
use rayon::prelude::*;
use scirs2_core::ndarray_ext::{Array1, ArrayView1};
use scirs2_core::random::Random;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::path::Path;
use tracing::{debug, info};

use crate::{EmbeddingModel, ModelConfig, ModelStats, NamedNode, TrainingStats, Triple, Vector};
use uuid::Uuid;

/// Configuration for the HolE (Holographic Embeddings) model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HoLEConfig {
    /// Base model configuration (dimensions, learning rate, max epochs, ...)
    pub base: ModelConfig,
    /// L2 regularization coefficient, applied as weight decay during training
    pub regularization: f32,
    /// Margin for the pairwise ranking loss: max(0, margin + neg - pos)
    pub margin: f32,
    /// Number of negative samples generated per positive triple
    pub num_negatives: usize,
    /// When true, raw scores are squashed through a sigmoid into (0, 1)
    pub use_sigmoid: bool,
}

40impl Default for HoLEConfig {
41    fn default() -> Self {
42        Self {
43            base: ModelConfig::default(),
44            regularization: 0.0001,
45            margin: 1.0,
46            num_negatives: 10,
47            use_sigmoid: true,
48        }
49    }
50}
51
/// Serializable representation of HolE model for persistence
///
/// Embeddings are stored as plain `Vec<f32>` because `Array1` is not
/// serialized directly here; `save`/`load` convert back and forth.
#[derive(Debug, Serialize, Deserialize)]
struct HoLESerializable {
    // Stable identifier copied from the in-memory model
    model_id: Uuid,
    config: HoLEConfig,
    // entity IRI -> embedding values
    entity_embeddings: HashMap<String, Vec<f32>>,
    // relation IRI -> embedding values
    relation_embeddings: HashMap<String, Vec<f32>>,
    triples: Vec<Triple>,
    // Bidirectional id maps, mirrored from the in-memory model
    entity_to_id: HashMap<String, usize>,
    relation_to_id: HashMap<String, usize>,
    id_to_entity: HashMap<usize, String>,
    id_to_relation: HashMap<usize, String>,
    is_trained: bool,
}

/// HolE (Holographic Embeddings) model
///
/// Uses circular correlation to combine entity embeddings and relation embeddings.
/// Efficient and expressive for knowledge graph completion tasks.
pub struct HoLE {
    // Unique id assigned at construction (replaced on `load`)
    model_id: Uuid,
    config: HoLEConfig,
    // entity IRI -> dense embedding of `config.base.dimensions` values
    entity_embeddings: HashMap<String, Array1<f32>>,
    // relation IRI -> dense embedding of `config.base.dimensions` values
    relation_embeddings: HashMap<String, Array1<f32>>,
    // All positive training triples added via `add_triple`
    triples: Vec<Triple>,
    // Bidirectional id maps, assigned in insertion order
    entity_to_id: HashMap<String, usize>,
    relation_to_id: HashMap<String, usize>,
    id_to_entity: HashMap<usize, String>,
    id_to_relation: HashMap<usize, String>,
    is_trained: bool,
}

impl HoLE {
85    /// Create new HolE model with configuration
86    pub fn new(config: HoLEConfig) -> Self {
87        info!(
88            "Initialized HolE model with dimensions={}, learning_rate={}",
89            config.base.dimensions, config.base.learning_rate
90        );
91
92        Self {
93            model_id: Uuid::new_v4(),
94            config,
95            entity_embeddings: HashMap::new(),
96            relation_embeddings: HashMap::new(),
97            triples: Vec::new(),
98            entity_to_id: HashMap::new(),
99            relation_to_id: HashMap::new(),
100            id_to_entity: HashMap::new(),
101            id_to_relation: HashMap::new(),
102            is_trained: false,
103        }
104    }
105
106    /// Circular correlation of two vectors
107    ///
108    /// The circular correlation is computed via FFT for efficiency:
109    /// a ★ b = IFFT(conj(FFT(a)) ⊙ FFT(b))
110    ///
111    /// For simplicity, we use the direct definition here.
112    fn circular_correlation(&self, a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
113        let n = a.len();
114        let mut result = Array1::zeros(n);
115
116        for k in 0..n {
117            let mut sum = 0.0;
118            for i in 0..n {
119                let j = (i + k) % n;
120                sum += a[i] * b[j];
121            }
122            result[k] = sum;
123        }
124
125        result
126    }
127
128    /// Compute the score for a triple (h, r, t)
129    ///
130    /// f(h,r,t) = σ(r^T (h ★ t))
131    fn score_triple_internal(
132        &self,
133        head: &ArrayView1<f32>,
134        relation: &ArrayView1<f32>,
135        tail: &ArrayView1<f32>,
136    ) -> f32 {
137        // Compute circular correlation: h ★ t
138        let correlation = self.circular_correlation(head, tail);
139
140        // Compute dot product: r^T (h ★ t)
141        let score = relation.dot(&correlation);
142
143        // Apply sigmoid if configured
144        if self.config.use_sigmoid {
145            1.0 / (1.0 + (-score).exp())
146        } else {
147            score
148        }
149    }
150
151    /// Initialize embeddings for an entity
152    fn init_entity(&mut self, entity: &str) {
153        if !self.entity_embeddings.contains_key(entity) {
154            let id = self.entity_embeddings.len();
155            self.entity_to_id.insert(entity.to_string(), id);
156            self.id_to_entity.insert(id, entity.to_string());
157
158            // Initialize with uniform distribution scaled by 1/sqrt(d)
159            let scale = 1.0 / (self.config.base.dimensions as f32).sqrt();
160            let mut local_rng = Random::default();
161            let embedding = Array1::from_vec(
162                (0..self.config.base.dimensions)
163                    .map(|_| local_rng.gen_range(-scale..scale))
164                    .collect(),
165            );
166            self.entity_embeddings.insert(entity.to_string(), embedding);
167        }
168    }
169
170    /// Initialize embeddings for a relation
171    fn init_relation(&mut self, relation: &str) {
172        if !self.relation_embeddings.contains_key(relation) {
173            let id = self.relation_embeddings.len();
174            self.relation_to_id.insert(relation.to_string(), id);
175            self.id_to_relation.insert(id, relation.to_string());
176
177            // Initialize with uniform distribution scaled by 1/sqrt(d)
178            let scale = 1.0 / (self.config.base.dimensions as f32).sqrt();
179            let mut local_rng = Random::default();
180            let embedding = Array1::from_vec(
181                (0..self.config.base.dimensions)
182                    .map(|_| local_rng.gen_range(-scale..scale))
183                    .collect(),
184            );
185            self.relation_embeddings
186                .insert(relation.to_string(), embedding);
187        }
188    }
189
190    /// Generate negative samples by corrupting subject or object
191    fn generate_negative_samples(&mut self, triple: &Triple) -> Vec<Triple> {
192        let mut negatives = Vec::new();
193        let entity_list: Vec<String> = self.entity_embeddings.keys().cloned().collect();
194        let mut local_rng = Random::default();
195
196        for _ in 0..self.config.num_negatives {
197            // Randomly corrupt subject or object
198            if local_rng.gen_range(0.0..1.0) < 0.5 {
199                // Corrupt subject
200                let random_subject =
201                    entity_list[local_rng.random_range(0, entity_list.len())].clone();
202                negatives.push(Triple {
203                    subject: NamedNode::new(&random_subject).unwrap(),
204                    predicate: triple.predicate.clone(),
205                    object: triple.object.clone(),
206                });
207            } else {
208                // Corrupt object
209                let random_object =
210                    entity_list[local_rng.random_range(0, entity_list.len())].clone();
211                negatives.push(Triple {
212                    subject: triple.subject.clone(),
213                    predicate: triple.predicate.clone(),
214                    object: NamedNode::new(&random_object).unwrap(),
215                });
216            }
217        }
218
219        negatives
220    }
221
222    /// Perform one training step with margin-based ranking loss
223    fn train_step(&mut self) -> f32 {
224        let mut total_loss = 0.0;
225        let mut local_rng = Random::default();
226
227        // Shuffle triples for stochastic gradient descent
228        let mut indices: Vec<usize> = (0..self.triples.len()).collect();
229        for i in (1..indices.len()).rev() {
230            let j = local_rng.random_range(0, i + 1);
231            indices.swap(i, j);
232        }
233
234        for &idx in &indices {
235            let triple = &self.triples[idx].clone();
236
237            // Get embeddings
238            let subject_str = &triple.subject.iri;
239            let predicate_str = &triple.predicate.iri;
240            let object_str = &triple.object.iri;
241
242            let head_emb = self.entity_embeddings[subject_str].clone();
243            let rel_emb = self.relation_embeddings[predicate_str].clone();
244            let tail_emb = self.entity_embeddings[object_str].clone();
245
246            // Positive score
247            let pos_score =
248                self.score_triple_internal(&head_emb.view(), &rel_emb.view(), &tail_emb.view());
249
250            // Generate negative samples
251            let negatives = self.generate_negative_samples(triple);
252
253            for neg_triple in &negatives {
254                let neg_subject_str = &neg_triple.subject.iri;
255                let neg_object_str = &neg_triple.object.iri;
256
257                let neg_head_emb = self.entity_embeddings[neg_subject_str].clone();
258                let neg_tail_emb = self.entity_embeddings[neg_object_str].clone();
259
260                // Negative score
261                let neg_score = self.score_triple_internal(
262                    &neg_head_emb.view(),
263                    &rel_emb.view(),
264                    &neg_tail_emb.view(),
265                );
266
267                // Margin ranking loss: max(0, margin + neg_score - pos_score)
268                let loss = (self.config.margin + neg_score - pos_score).max(0.0);
269
270                if loss > 0.0 {
271                    total_loss += loss;
272
273                    // Compute gradients and update embeddings
274                    // For simplicity, we use a basic gradient update
275                    // In practice, more sophisticated optimizers should be used
276
277                    let lr = self.config.base.learning_rate as f32;
278
279                    // Update entity embeddings
280                    if let Some(head) = self.entity_embeddings.get_mut(subject_str) {
281                        *head = &*head * (1.0 - self.config.regularization * lr);
282                    }
283
284                    if let Some(tail) = self.entity_embeddings.get_mut(object_str) {
285                        *tail = &*tail * (1.0 - self.config.regularization * lr);
286                    }
287
288                    if let Some(neg_head) = self.entity_embeddings.get_mut(neg_subject_str) {
289                        *neg_head = &*neg_head * (1.0 - self.config.regularization * lr);
290                    }
291
292                    if let Some(neg_tail) = self.entity_embeddings.get_mut(neg_object_str) {
293                        *neg_tail = &*neg_tail * (1.0 - self.config.regularization * lr);
294                    }
295
296                    // Update relation embeddings
297                    if let Some(rel) = self.relation_embeddings.get_mut(predicate_str) {
298                        *rel = &*rel * (1.0 - self.config.regularization * lr);
299                    }
300                }
301            }
302        }
303
304        total_loss / (self.triples.len() as f32 * self.config.num_negatives as f32)
305    }
}

#[async_trait::async_trait]
impl EmbeddingModel for HoLE {
    /// Base configuration shared by all embedding models.
    fn config(&self) -> &ModelConfig {
        &self.config.base
    }

    /// Unique identifier assigned at construction (or restored by `load`).
    fn model_id(&self) -> &Uuid {
        &self.model_id
    }

    /// Static model-type name used in stats and logs.
    fn model_type(&self) -> &'static str {
        "HoLE"
    }

322    fn add_triple(&mut self, triple: Triple) -> Result<()> {
323        // Initialize embeddings for new entities/relations
324        self.init_entity(&triple.subject.iri);
325        self.init_entity(&triple.object.iri);
326        self.init_relation(&triple.predicate.iri);
327
328        self.triples.push(triple);
329        Ok(())
330    }
331
332    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
333        let num_epochs = epochs.unwrap_or(self.config.base.max_epochs);
334
335        if self.triples.is_empty() {
336            return Err(anyhow!("No training data available"));
337        }
338
339        info!(
340            "Training HoLE model for {} epochs on {} triples",
341            num_epochs,
342            self.triples.len()
343        );
344
345        let start_time = std::time::Instant::now();
346        let mut loss_history = Vec::new();
347
348        for epoch in 0..num_epochs {
349            let loss = self.train_step();
350            loss_history.push(loss as f64);
351
352            if epoch % 10 == 0 {
353                debug!("Epoch {}/{}: loss = {:.6}", epoch + 1, num_epochs, loss);
354            }
355
356            // Check for convergence
357            if loss < 0.001 {
358                info!("Converged at epoch {}", epoch);
359                break;
360            }
361        }
362
363        let training_time = start_time.elapsed().as_secs_f64();
364        self.is_trained = true;
365
366        Ok(TrainingStats {
367            epochs_completed: num_epochs,
368            final_loss: *loss_history.last().unwrap_or(&0.0),
369            training_time_seconds: training_time,
370            convergence_achieved: loss_history.last().unwrap_or(&1.0) < &0.001,
371            loss_history,
372        })
373    }
374
375    fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
376        self.entity_embeddings
377            .get(entity)
378            .map(Vector::from_array1)
379            .ok_or_else(|| anyhow!("Unknown entity: {}", entity))
380    }
381
382    fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
383        self.relation_embeddings
384            .get(relation)
385            .map(Vector::from_array1)
386            .ok_or_else(|| anyhow!("Unknown relation: {}", relation))
387    }
388
389    fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
390        let head_emb = self
391            .entity_embeddings
392            .get(subject)
393            .ok_or_else(|| anyhow!("Unknown subject: {}", subject))?;
394        let rel_emb = self
395            .relation_embeddings
396            .get(predicate)
397            .ok_or_else(|| anyhow!("Unknown predicate: {}", predicate))?;
398        let tail_emb = self
399            .entity_embeddings
400            .get(object)
401            .ok_or_else(|| anyhow!("Unknown object: {}", object))?;
402
403        let score = self.score_triple_internal(&head_emb.view(), &rel_emb.view(), &tail_emb.view());
404        Ok(score as f64)
405    }
406
407    fn predict_objects(
408        &self,
409        subject: &str,
410        predicate: &str,
411        k: usize,
412    ) -> Result<Vec<(String, f64)>> {
413        let head_emb = self
414            .entity_embeddings
415            .get(subject)
416            .ok_or_else(|| anyhow!("Unknown subject: {}", subject))?;
417        let rel_emb = self
418            .relation_embeddings
419            .get(predicate)
420            .ok_or_else(|| anyhow!("Unknown predicate: {}", predicate))?;
421
422        let mut scored_objects: Vec<(String, f64)> = self
423            .entity_embeddings
424            .par_iter()
425            .map(|(entity, tail_emb)| {
426                let score =
427                    self.score_triple_internal(&head_emb.view(), &rel_emb.view(), &tail_emb.view());
428                (entity.clone(), score as f64)
429            })
430            .collect();
431
432        scored_objects.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
433        scored_objects.truncate(k);
434        Ok(scored_objects)
435    }
436
437    fn predict_subjects(
438        &self,
439        predicate: &str,
440        object: &str,
441        k: usize,
442    ) -> Result<Vec<(String, f64)>> {
443        let rel_emb = self
444            .relation_embeddings
445            .get(predicate)
446            .ok_or_else(|| anyhow!("Unknown predicate: {}", predicate))?;
447        let tail_emb = self
448            .entity_embeddings
449            .get(object)
450            .ok_or_else(|| anyhow!("Unknown object: {}", object))?;
451
452        let mut scored_subjects: Vec<(String, f64)> = self
453            .entity_embeddings
454            .par_iter()
455            .map(|(entity, head_emb)| {
456                let score =
457                    self.score_triple_internal(&head_emb.view(), &rel_emb.view(), &tail_emb.view());
458                (entity.clone(), score as f64)
459            })
460            .collect();
461
462        scored_subjects.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
463        scored_subjects.truncate(k);
464        Ok(scored_subjects)
465    }
466
467    fn predict_relations(
468        &self,
469        subject: &str,
470        object: &str,
471        k: usize,
472    ) -> Result<Vec<(String, f64)>> {
473        let head_emb = self
474            .entity_embeddings
475            .get(subject)
476            .ok_or_else(|| anyhow!("Unknown subject: {}", subject))?;
477        let tail_emb = self
478            .entity_embeddings
479            .get(object)
480            .ok_or_else(|| anyhow!("Unknown object: {}", object))?;
481
482        let mut scored_relations: Vec<(String, f64)> = self
483            .relation_embeddings
484            .par_iter()
485            .map(|(relation, rel_emb)| {
486                let score =
487                    self.score_triple_internal(&head_emb.view(), &rel_emb.view(), &tail_emb.view());
488                (relation.clone(), score as f64)
489            })
490            .collect();
491
492        scored_relations.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
493        scored_relations.truncate(k);
494        Ok(scored_relations)
495    }
496
497    fn get_entities(&self) -> Vec<String> {
498        self.entity_embeddings.keys().cloned().collect()
499    }
500
501    fn get_relations(&self) -> Vec<String> {
502        self.relation_embeddings.keys().cloned().collect()
503    }
504
    /// Summary statistics for the current model state.
    fn get_stats(&self) -> ModelStats {
        ModelStats {
            num_entities: self.entity_embeddings.len(),
            num_relations: self.relation_embeddings.len(),
            num_triples: self.triples.len(),
            dimensions: self.config.base.dimensions,
            is_trained: self.is_trained,
            model_type: "HoLE".to_string(),
            // NOTE(review): both timestamps are taken at query time, not at the
            // actual creation/training moments; storing them on the struct would
            // make them accurate — TODO confirm intended semantics.
            creation_time: chrono::Utc::now(),
            last_training_time: if self.is_trained {
                Some(chrono::Utc::now())
            } else {
                None
            },
        }
    }

522    fn save(&self, path: &str) -> Result<()> {
523        info!("Saving HolE model to {}", path);
524
525        // Convert Array1 to Vec for serialization
526        let entity_embeddings_vec: HashMap<String, Vec<f32>> = self
527            .entity_embeddings
528            .iter()
529            .map(|(k, v)| (k.clone(), v.to_vec()))
530            .collect();
531
532        let relation_embeddings_vec: HashMap<String, Vec<f32>> = self
533            .relation_embeddings
534            .iter()
535            .map(|(k, v)| (k.clone(), v.to_vec()))
536            .collect();
537
538        let serializable = HoLESerializable {
539            model_id: self.model_id,
540            config: self.config.clone(),
541            entity_embeddings: entity_embeddings_vec,
542            relation_embeddings: relation_embeddings_vec,
543            triples: self.triples.clone(),
544            entity_to_id: self.entity_to_id.clone(),
545            relation_to_id: self.relation_to_id.clone(),
546            id_to_entity: self.id_to_entity.clone(),
547            id_to_relation: self.id_to_relation.clone(),
548            is_trained: self.is_trained,
549        };
550
551        let file = File::create(path)?;
552        let writer = BufWriter::new(file);
553        bincode::serialize_into(writer, &serializable)
554            .map_err(|e| anyhow!("Failed to serialize model: {}", e))?;
555
556        info!("Model saved successfully");
557        Ok(())
558    }
559
560    fn load(&mut self, path: &str) -> Result<()> {
561        info!("Loading HolE model from {}", path);
562
563        if !Path::new(path).exists() {
564            return Err(anyhow!("Model file not found: {}", path));
565        }
566
567        let file = File::open(path)?;
568        let reader = BufReader::new(file);
569        let serializable: HoLESerializable = bincode::deserialize_from(reader)
570            .map_err(|e| anyhow!("Failed to deserialize model: {}", e))?;
571
572        // Convert Vec back to Array1
573        let entity_embeddings: HashMap<String, Array1<f32>> = serializable
574            .entity_embeddings
575            .into_iter()
576            .map(|(k, v)| (k, Array1::from_vec(v)))
577            .collect();
578
579        let relation_embeddings: HashMap<String, Array1<f32>> = serializable
580            .relation_embeddings
581            .into_iter()
582            .map(|(k, v)| (k, Array1::from_vec(v)))
583            .collect();
584
585        // Update model state
586        self.model_id = serializable.model_id;
587        self.config = serializable.config;
588        self.entity_embeddings = entity_embeddings;
589        self.relation_embeddings = relation_embeddings;
590        self.triples = serializable.triples;
591        self.entity_to_id = serializable.entity_to_id;
592        self.relation_to_id = serializable.relation_to_id;
593        self.id_to_entity = serializable.id_to_entity;
594        self.id_to_relation = serializable.id_to_relation;
595        self.is_trained = serializable.is_trained;
596
597        info!("Model loaded successfully");
598        Ok(())
599    }
600
601    fn clear(&mut self) {
602        self.entity_embeddings.clear();
603        self.relation_embeddings.clear();
604        self.triples.clear();
605        self.entity_to_id.clear();
606        self.relation_to_id.clear();
607        self.id_to_entity.clear();
608        self.id_to_relation.clear();
609        self.is_trained = false;
610    }
611
612    fn is_trained(&self) -> bool {
613        self.is_trained
614    }
615
616    async fn encode(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
617        // TODO: Implement text encoding
618        Err(anyhow!("Text encoding not implemented for HoLE"))
619    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    // Circular correlation matches the hand-computed definition on a 3-vector.
    #[test]
    fn test_circular_correlation() {
        let config = HoLEConfig::default();
        let model = HoLE::new(config);

        let a = array![1.0, 2.0, 3.0];
        let b = array![4.0, 5.0, 6.0];

        let result = model.circular_correlation(&a.view(), &b.view());

        // Expected circular correlation
        // result[0] = a[0]*b[0] + a[1]*b[1] + a[2]*b[2] = 1*4 + 2*5 + 3*6 = 32
        // result[1] = a[0]*b[1] + a[1]*b[2] + a[2]*b[0] = 1*5 + 2*6 + 3*4 = 29
        // result[2] = a[0]*b[2] + a[1]*b[0] + a[2]*b[1] = 1*6 + 2*4 + 3*5 = 29

        assert_eq!(result.len(), 3);
        assert!((result[0] - 32.0).abs() < 1e-5);
        assert!((result[1] - 29.0).abs() < 1e-5);
        assert!((result[2] - 29.0).abs() < 1e-5);
    }

    // A freshly constructed model starts with no embeddings.
    #[test]
    fn test_hole_creation() {
        let config = HoLEConfig::default();
        let model = HoLE::new(config);

        assert_eq!(model.entity_embeddings.len(), 0);
        assert_eq!(model.relation_embeddings.len(), 0);
    }

    // End-to-end training on a tiny graph: stats are sane, embeddings exist
    // for every entity/relation, and sigmoid scores stay within [0, 1].
    #[tokio::test]
    async fn test_hole_training() {
        let config = HoLEConfig {
            base: ModelConfig {
                dimensions: 50,
                learning_rate: 0.01,
                max_epochs: 50,
                ..Default::default()
            },
            ..Default::default()
        };

        let mut model = HoLE::new(config);

        // Add some triples
        model
            .add_triple(Triple::new(
                NamedNode::new("alice").unwrap(),
                NamedNode::new("knows").unwrap(),
                NamedNode::new("bob").unwrap(),
            ))
            .unwrap();

        model
            .add_triple(Triple::new(
                NamedNode::new("bob").unwrap(),
                NamedNode::new("knows").unwrap(),
                NamedNode::new("charlie").unwrap(),
            ))
            .unwrap();

        model
            .add_triple(Triple::new(
                NamedNode::new("alice").unwrap(),
                NamedNode::new("likes").unwrap(),
                NamedNode::new("charlie").unwrap(),
            ))
            .unwrap();

        // Train the model
        let stats = model.train(Some(50)).await.unwrap();

        assert_eq!(stats.epochs_completed, 50);
        assert!(stats.final_loss >= 0.0);
        assert!(stats.training_time_seconds > 0.0);

        // Check that embeddings were created
        assert_eq!(model.entity_embeddings.len(), 3);
        assert_eq!(model.relation_embeddings.len(), 2);

        // Test prediction
        let score = model.score_triple("alice", "knows", "bob").unwrap();
        assert!((0.0..=1.0).contains(&score)); // Sigmoid bounded
    }

    // predict_objects returns at most k results in descending score order.
    #[tokio::test]
    async fn test_hole_ranking() {
        let config = HoLEConfig {
            base: ModelConfig {
                dimensions: 50,
                max_epochs: 30,
                ..Default::default()
            },
            ..Default::default()
        };

        let mut model = HoLE::new(config);

        // Add training data
        model
            .add_triple(Triple::new(
                NamedNode::new("alice").unwrap(),
                NamedNode::new("knows").unwrap(),
                NamedNode::new("bob").unwrap(),
            ))
            .unwrap();

        model
            .add_triple(Triple::new(
                NamedNode::new("alice").unwrap(),
                NamedNode::new("knows").unwrap(),
                NamedNode::new("charlie").unwrap(),
            ))
            .unwrap();

        // Train
        model.train(Some(30)).await.unwrap();

        // Rank objects
        let ranked = model.predict_objects("alice", "knows", 2).unwrap();

        assert!(ranked.len() <= 2);
        // Scores should be in descending order
        if ranked.len() >= 2 {
            assert!(ranked[0].1 >= ranked[1].1);
        }
    }

    // Round-trips a trained model through save/load and checks that
    // embeddings and scores are preserved within float tolerance.
    #[tokio::test]
    async fn test_hole_save_load() {
        use std::env::temp_dir;

        let config = HoLEConfig {
            base: ModelConfig {
                dimensions: 30,
                max_epochs: 20,
                ..Default::default()
            },
            ..Default::default()
        };

        let mut model = HoLE::new(config);

        // Add and train
        model
            .add_triple(Triple::new(
                NamedNode::new("alice").unwrap(),
                NamedNode::new("knows").unwrap(),
                NamedNode::new("bob").unwrap(),
            ))
            .unwrap();

        model
            .add_triple(Triple::new(
                NamedNode::new("bob").unwrap(),
                NamedNode::new("likes").unwrap(),
                NamedNode::new("charlie").unwrap(),
            ))
            .unwrap();

        model.train(Some(20)).await.unwrap();

        // Get embedding before save
        let emb_before = model.get_entity_embedding("alice").unwrap();
        let score_before = model.score_triple("alice", "knows", "bob").unwrap();

        // Save model
        let model_path = temp_dir().join("test_hole_model.bin");
        let path_str = model_path.to_str().unwrap();
        model.save(path_str).unwrap();

        // Create new model and load
        let mut loaded_model = HoLE::new(HoLEConfig::default());
        loaded_model.load(path_str).unwrap();

        // Verify loaded model
        assert!(loaded_model.is_trained());
        assert_eq!(loaded_model.get_entities().len(), 3);
        assert_eq!(loaded_model.get_relations().len(), 2);

        // Verify embeddings are preserved
        let emb_after = loaded_model.get_entity_embedding("alice").unwrap();
        assert_eq!(emb_before.dimensions, emb_after.dimensions);
        for i in 0..emb_before.values.len() {
            assert!((emb_before.values[i] - emb_after.values[i]).abs() < 1e-6);
        }

        // Verify scoring is consistent
        let score_after = loaded_model.score_triple("alice", "knows", "bob").unwrap();
        assert!((score_before - score_after).abs() < 1e-6);

        // Cleanup
        std::fs::remove_file(model_path).ok();
    }

    // Loading from a missing path must fail with an error, not panic.
    #[test]
    fn test_hole_load_nonexistent() {
        let mut model = HoLE::new(HoLEConfig::default());
        let result = model.load("/nonexistent/path/model.bin");
        assert!(result.is_err());
    }
}