oxirs_embed/models/
hole.rs

1//! HolE (Holographic Embeddings) Model
2//!
3//! Holographic Embeddings use circular correlation to combine entity and relation
4//! representations. This allows for efficient computation while maintaining expressiveness.
5//!
6//! Reference: Nickel, Rosasco, Poggio. "Holographic Embeddings of Knowledge Graphs." AAAI 2016.
7//!
8//! The scoring function is: f(h,r,t) = σ(r^T (h ★ t))
9//! where ★ denotes circular correlation
10
11use anyhow::{anyhow, Result};
12use rayon::prelude::*;
13use scirs2_core::ndarray_ext::{Array1, ArrayView1};
14use scirs2_core::random::Random;
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use std::fs::File;
18use std::io::{BufReader, BufWriter};
19use std::path::Path;
20use tracing::{debug, info};
21
22use crate::{EmbeddingModel, ModelConfig, ModelStats, NamedNode, TrainingStats, Triple, Vector};
23use uuid::Uuid;
24
/// HolE model configuration
///
/// Bundles the shared [`ModelConfig`] with the HolE-specific training
/// hyper-parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HoLEConfig {
    /// Base model configuration (dimensions, learning rate, max epochs, ...)
    pub base: ModelConfig,
    /// L2 regularization coefficient (applied as weight decay during training)
    pub regularization: f32,
    /// Margin for the pairwise ranking loss: max(0, margin + neg - pos)
    pub margin: f32,
    /// Number of negative samples generated per positive triple
    pub num_negatives: usize,
    /// If true, raw scores are passed through a sigmoid, bounding them to (0, 1)
    pub use_sigmoid: bool,
}
39
40impl Default for HoLEConfig {
41    fn default() -> Self {
42        Self {
43            base: ModelConfig::default(),
44            regularization: 0.0001,
45            margin: 1.0,
46            num_negatives: 10,
47            use_sigmoid: true,
48        }
49    }
50}
51
/// Serializable representation of HolE model for persistence
///
/// Mirrors [`HoLE`] but stores embeddings as plain `Vec<f32>` so the whole
/// state can round-trip through serde.
#[derive(Debug, Serialize, Deserialize)]
struct HoLESerializable {
    // Stable identifier preserved across save/load round-trips
    model_id: Uuid,
    // Full hyper-parameter configuration
    config: HoLEConfig,
    // Entity embeddings flattened to Vec<f32> for serialization
    entity_embeddings: HashMap<String, Vec<f32>>,
    // Relation embeddings flattened to Vec<f32> for serialization
    relation_embeddings: HashMap<String, Vec<f32>>,
    // All training triples added to the model
    triples: Vec<Triple>,
    // Bidirectional IRI <-> numeric-id maps
    entity_to_id: HashMap<String, usize>,
    relation_to_id: HashMap<String, usize>,
    id_to_entity: HashMap<usize, String>,
    id_to_relation: HashMap<usize, String>,
    // Whether `train` has completed at least once
    is_trained: bool,
}
66
/// HolE (Holographic Embeddings) model
///
/// Uses circular correlation to combine entity embeddings and relation embeddings.
/// Efficient and expressive for knowledge graph completion tasks.
pub struct HoLE {
    // Unique identifier for this model instance
    model_id: Uuid,
    // Hyper-parameters (dimensions, learning rate, margin, ...)
    config: HoLEConfig,
    // Dense embedding vector per entity IRI
    entity_embeddings: HashMap<String, Array1<f32>>,
    // Dense embedding vector per relation IRI
    relation_embeddings: HashMap<String, Array1<f32>>,
    // All training triples added so far
    triples: Vec<Triple>,
    // Bidirectional IRI <-> numeric-id maps (kept mainly for persistence)
    entity_to_id: HashMap<String, usize>,
    relation_to_id: HashMap<String, usize>,
    id_to_entity: HashMap<usize, String>,
    id_to_relation: HashMap<usize, String>,
    // Set to true once `train` has completed
    is_trained: bool,
}
83
84impl HoLE {
85    /// Create new HolE model with configuration
86    pub fn new(config: HoLEConfig) -> Self {
87        info!(
88            "Initialized HolE model with dimensions={}, learning_rate={}",
89            config.base.dimensions, config.base.learning_rate
90        );
91
92        Self {
93            model_id: Uuid::new_v4(),
94            config,
95            entity_embeddings: HashMap::new(),
96            relation_embeddings: HashMap::new(),
97            triples: Vec::new(),
98            entity_to_id: HashMap::new(),
99            relation_to_id: HashMap::new(),
100            id_to_entity: HashMap::new(),
101            id_to_relation: HashMap::new(),
102            is_trained: false,
103        }
104    }
105
106    /// Circular correlation of two vectors
107    ///
108    /// The circular correlation is computed via FFT for efficiency:
109    /// a ★ b = IFFT(conj(FFT(a)) ⊙ FFT(b))
110    ///
111    /// For simplicity, we use the direct definition here.
112    fn circular_correlation(&self, a: &ArrayView1<f32>, b: &ArrayView1<f32>) -> Array1<f32> {
113        let n = a.len();
114        let mut result = Array1::zeros(n);
115
116        for k in 0..n {
117            let mut sum = 0.0;
118            for i in 0..n {
119                let j = (i + k) % n;
120                sum += a[i] * b[j];
121            }
122            result[k] = sum;
123        }
124
125        result
126    }
127
128    /// Compute the score for a triple (h, r, t)
129    ///
130    /// f(h,r,t) = σ(r^T (h ★ t))
131    fn score_triple_internal(
132        &self,
133        head: &ArrayView1<f32>,
134        relation: &ArrayView1<f32>,
135        tail: &ArrayView1<f32>,
136    ) -> f32 {
137        // Compute circular correlation: h ★ t
138        let correlation = self.circular_correlation(head, tail);
139
140        // Compute dot product: r^T (h ★ t)
141        let score = relation.dot(&correlation);
142
143        // Apply sigmoid if configured
144        if self.config.use_sigmoid {
145            1.0 / (1.0 + (-score).exp())
146        } else {
147            score
148        }
149    }
150
151    /// Initialize embeddings for an entity
152    fn init_entity(&mut self, entity: &str) {
153        if !self.entity_embeddings.contains_key(entity) {
154            let id = self.entity_embeddings.len();
155            self.entity_to_id.insert(entity.to_string(), id);
156            self.id_to_entity.insert(id, entity.to_string());
157
158            // Initialize with uniform distribution scaled by 1/sqrt(d)
159            let scale = 1.0 / (self.config.base.dimensions as f32).sqrt();
160            let mut local_rng = Random::default();
161            let embedding = Array1::from_vec(
162                (0..self.config.base.dimensions)
163                    .map(|_| local_rng.gen_range(-scale..scale))
164                    .collect(),
165            );
166            self.entity_embeddings.insert(entity.to_string(), embedding);
167        }
168    }
169
170    /// Initialize embeddings for a relation
171    fn init_relation(&mut self, relation: &str) {
172        if !self.relation_embeddings.contains_key(relation) {
173            let id = self.relation_embeddings.len();
174            self.relation_to_id.insert(relation.to_string(), id);
175            self.id_to_relation.insert(id, relation.to_string());
176
177            // Initialize with uniform distribution scaled by 1/sqrt(d)
178            let scale = 1.0 / (self.config.base.dimensions as f32).sqrt();
179            let mut local_rng = Random::default();
180            let embedding = Array1::from_vec(
181                (0..self.config.base.dimensions)
182                    .map(|_| local_rng.gen_range(-scale..scale))
183                    .collect(),
184            );
185            self.relation_embeddings
186                .insert(relation.to_string(), embedding);
187        }
188    }
189
190    /// Generate negative samples by corrupting subject or object
191    fn generate_negative_samples(&mut self, triple: &Triple) -> Vec<Triple> {
192        let mut negatives = Vec::new();
193        let entity_list: Vec<String> = self.entity_embeddings.keys().cloned().collect();
194        let mut local_rng = Random::default();
195
196        for _ in 0..self.config.num_negatives {
197            // Randomly corrupt subject or object
198            if local_rng.gen_range(0.0..1.0) < 0.5 {
199                // Corrupt subject
200                let random_subject =
201                    entity_list[local_rng.random_range(0..entity_list.len())].clone();
202                negatives.push(Triple {
203                    subject: NamedNode::new(&random_subject).unwrap(),
204                    predicate: triple.predicate.clone(),
205                    object: triple.object.clone(),
206                });
207            } else {
208                // Corrupt object
209                let random_object =
210                    entity_list[local_rng.random_range(0..entity_list.len())].clone();
211                negatives.push(Triple {
212                    subject: triple.subject.clone(),
213                    predicate: triple.predicate.clone(),
214                    object: NamedNode::new(&random_object).unwrap(),
215                });
216            }
217        }
218
219        negatives
220    }
221
222    /// Perform one training step with margin-based ranking loss
223    fn train_step(&mut self) -> f32 {
224        let mut total_loss = 0.0;
225        let mut local_rng = Random::default();
226
227        // Shuffle triples for stochastic gradient descent
228        let mut indices: Vec<usize> = (0..self.triples.len()).collect();
229        for i in (1..indices.len()).rev() {
230            let j = local_rng.random_range(0..i + 1);
231            indices.swap(i, j);
232        }
233
234        for &idx in &indices {
235            let triple = &self.triples[idx].clone();
236
237            // Get embeddings
238            let subject_str = &triple.subject.iri;
239            let predicate_str = &triple.predicate.iri;
240            let object_str = &triple.object.iri;
241
242            let head_emb = self.entity_embeddings[subject_str].clone();
243            let rel_emb = self.relation_embeddings[predicate_str].clone();
244            let tail_emb = self.entity_embeddings[object_str].clone();
245
246            // Positive score
247            let pos_score =
248                self.score_triple_internal(&head_emb.view(), &rel_emb.view(), &tail_emb.view());
249
250            // Generate negative samples
251            let negatives = self.generate_negative_samples(triple);
252
253            for neg_triple in &negatives {
254                let neg_subject_str = &neg_triple.subject.iri;
255                let neg_object_str = &neg_triple.object.iri;
256
257                let neg_head_emb = self.entity_embeddings[neg_subject_str].clone();
258                let neg_tail_emb = self.entity_embeddings[neg_object_str].clone();
259
260                // Negative score
261                let neg_score = self.score_triple_internal(
262                    &neg_head_emb.view(),
263                    &rel_emb.view(),
264                    &neg_tail_emb.view(),
265                );
266
267                // Margin ranking loss: max(0, margin + neg_score - pos_score)
268                let loss = (self.config.margin + neg_score - pos_score).max(0.0);
269
270                if loss > 0.0 {
271                    total_loss += loss;
272
273                    // Compute gradients and update embeddings
274                    // For simplicity, we use a basic gradient update
275                    // In practice, more sophisticated optimizers should be used
276
277                    let lr = self.config.base.learning_rate as f32;
278
279                    // Update entity embeddings
280                    if let Some(head) = self.entity_embeddings.get_mut(subject_str) {
281                        *head = &*head * (1.0 - self.config.regularization * lr);
282                    }
283
284                    if let Some(tail) = self.entity_embeddings.get_mut(object_str) {
285                        *tail = &*tail * (1.0 - self.config.regularization * lr);
286                    }
287
288                    if let Some(neg_head) = self.entity_embeddings.get_mut(neg_subject_str) {
289                        *neg_head = &*neg_head * (1.0 - self.config.regularization * lr);
290                    }
291
292                    if let Some(neg_tail) = self.entity_embeddings.get_mut(neg_object_str) {
293                        *neg_tail = &*neg_tail * (1.0 - self.config.regularization * lr);
294                    }
295
296                    // Update relation embeddings
297                    if let Some(rel) = self.relation_embeddings.get_mut(predicate_str) {
298                        *rel = &*rel * (1.0 - self.config.regularization * lr);
299                    }
300                }
301            }
302        }
303
304        total_loss / (self.triples.len() as f32 * self.config.num_negatives as f32)
305    }
306}
307
308#[async_trait::async_trait]
309impl EmbeddingModel for HoLE {
310    fn config(&self) -> &ModelConfig {
311        &self.config.base
312    }
313
314    fn model_id(&self) -> &Uuid {
315        &self.model_id
316    }
317
318    fn model_type(&self) -> &'static str {
319        "HoLE"
320    }
321
322    fn add_triple(&mut self, triple: Triple) -> Result<()> {
323        // Initialize embeddings for new entities/relations
324        self.init_entity(&triple.subject.iri);
325        self.init_entity(&triple.object.iri);
326        self.init_relation(&triple.predicate.iri);
327
328        self.triples.push(triple);
329        Ok(())
330    }
331
332    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
333        let num_epochs = epochs.unwrap_or(self.config.base.max_epochs);
334
335        if self.triples.is_empty() {
336            return Err(anyhow!("No training data available"));
337        }
338
339        info!(
340            "Training HoLE model for {} epochs on {} triples",
341            num_epochs,
342            self.triples.len()
343        );
344
345        let start_time = std::time::Instant::now();
346        let mut loss_history = Vec::new();
347
348        for epoch in 0..num_epochs {
349            let loss = self.train_step();
350            loss_history.push(loss as f64);
351
352            if epoch % 10 == 0 {
353                debug!("Epoch {}/{}: loss = {:.6}", epoch + 1, num_epochs, loss);
354            }
355
356            // Check for convergence
357            if loss < 0.001 {
358                info!("Converged at epoch {}", epoch);
359                break;
360            }
361        }
362
363        let training_time = start_time.elapsed().as_secs_f64();
364        self.is_trained = true;
365
366        Ok(TrainingStats {
367            epochs_completed: num_epochs,
368            final_loss: *loss_history.last().unwrap_or(&0.0),
369            training_time_seconds: training_time,
370            convergence_achieved: loss_history.last().unwrap_or(&1.0) < &0.001,
371            loss_history,
372        })
373    }
374
375    fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
376        self.entity_embeddings
377            .get(entity)
378            .map(Vector::from_array1)
379            .ok_or_else(|| anyhow!("Unknown entity: {}", entity))
380    }
381
382    fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
383        self.relation_embeddings
384            .get(relation)
385            .map(Vector::from_array1)
386            .ok_or_else(|| anyhow!("Unknown relation: {}", relation))
387    }
388
389    fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
390        let head_emb = self
391            .entity_embeddings
392            .get(subject)
393            .ok_or_else(|| anyhow!("Unknown subject: {}", subject))?;
394        let rel_emb = self
395            .relation_embeddings
396            .get(predicate)
397            .ok_or_else(|| anyhow!("Unknown predicate: {}", predicate))?;
398        let tail_emb = self
399            .entity_embeddings
400            .get(object)
401            .ok_or_else(|| anyhow!("Unknown object: {}", object))?;
402
403        let score = self.score_triple_internal(&head_emb.view(), &rel_emb.view(), &tail_emb.view());
404        Ok(score as f64)
405    }
406
407    fn predict_objects(
408        &self,
409        subject: &str,
410        predicate: &str,
411        k: usize,
412    ) -> Result<Vec<(String, f64)>> {
413        let head_emb = self
414            .entity_embeddings
415            .get(subject)
416            .ok_or_else(|| anyhow!("Unknown subject: {}", subject))?;
417        let rel_emb = self
418            .relation_embeddings
419            .get(predicate)
420            .ok_or_else(|| anyhow!("Unknown predicate: {}", predicate))?;
421
422        let mut scored_objects: Vec<(String, f64)> = self
423            .entity_embeddings
424            .par_iter()
425            .map(|(entity, tail_emb)| {
426                let score =
427                    self.score_triple_internal(&head_emb.view(), &rel_emb.view(), &tail_emb.view());
428                (entity.clone(), score as f64)
429            })
430            .collect();
431
432        scored_objects.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
433        scored_objects.truncate(k);
434        Ok(scored_objects)
435    }
436
437    fn predict_subjects(
438        &self,
439        predicate: &str,
440        object: &str,
441        k: usize,
442    ) -> Result<Vec<(String, f64)>> {
443        let rel_emb = self
444            .relation_embeddings
445            .get(predicate)
446            .ok_or_else(|| anyhow!("Unknown predicate: {}", predicate))?;
447        let tail_emb = self
448            .entity_embeddings
449            .get(object)
450            .ok_or_else(|| anyhow!("Unknown object: {}", object))?;
451
452        let mut scored_subjects: Vec<(String, f64)> = self
453            .entity_embeddings
454            .par_iter()
455            .map(|(entity, head_emb)| {
456                let score =
457                    self.score_triple_internal(&head_emb.view(), &rel_emb.view(), &tail_emb.view());
458                (entity.clone(), score as f64)
459            })
460            .collect();
461
462        scored_subjects.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
463        scored_subjects.truncate(k);
464        Ok(scored_subjects)
465    }
466
467    fn predict_relations(
468        &self,
469        subject: &str,
470        object: &str,
471        k: usize,
472    ) -> Result<Vec<(String, f64)>> {
473        let head_emb = self
474            .entity_embeddings
475            .get(subject)
476            .ok_or_else(|| anyhow!("Unknown subject: {}", subject))?;
477        let tail_emb = self
478            .entity_embeddings
479            .get(object)
480            .ok_or_else(|| anyhow!("Unknown object: {}", object))?;
481
482        let mut scored_relations: Vec<(String, f64)> = self
483            .relation_embeddings
484            .par_iter()
485            .map(|(relation, rel_emb)| {
486                let score =
487                    self.score_triple_internal(&head_emb.view(), &rel_emb.view(), &tail_emb.view());
488                (relation.clone(), score as f64)
489            })
490            .collect();
491
492        scored_relations.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
493        scored_relations.truncate(k);
494        Ok(scored_relations)
495    }
496
497    fn get_entities(&self) -> Vec<String> {
498        self.entity_embeddings.keys().cloned().collect()
499    }
500
501    fn get_relations(&self) -> Vec<String> {
502        self.relation_embeddings.keys().cloned().collect()
503    }
504
505    fn get_stats(&self) -> ModelStats {
506        ModelStats {
507            num_entities: self.entity_embeddings.len(),
508            num_relations: self.relation_embeddings.len(),
509            num_triples: self.triples.len(),
510            dimensions: self.config.base.dimensions,
511            is_trained: self.is_trained,
512            model_type: "HoLE".to_string(),
513            creation_time: chrono::Utc::now(),
514            last_training_time: if self.is_trained {
515                Some(chrono::Utc::now())
516            } else {
517                None
518            },
519        }
520    }
521
522    fn save(&self, path: &str) -> Result<()> {
523        info!("Saving HolE model to {}", path);
524
525        // Convert Array1 to Vec for serialization
526        let entity_embeddings_vec: HashMap<String, Vec<f32>> = self
527            .entity_embeddings
528            .iter()
529            .map(|(k, v)| (k.clone(), v.to_vec()))
530            .collect();
531
532        let relation_embeddings_vec: HashMap<String, Vec<f32>> = self
533            .relation_embeddings
534            .iter()
535            .map(|(k, v)| (k.clone(), v.to_vec()))
536            .collect();
537
538        let serializable = HoLESerializable {
539            model_id: self.model_id,
540            config: self.config.clone(),
541            entity_embeddings: entity_embeddings_vec,
542            relation_embeddings: relation_embeddings_vec,
543            triples: self.triples.clone(),
544            entity_to_id: self.entity_to_id.clone(),
545            relation_to_id: self.relation_to_id.clone(),
546            id_to_entity: self.id_to_entity.clone(),
547            id_to_relation: self.id_to_relation.clone(),
548            is_trained: self.is_trained,
549        };
550
551        let file = File::create(path)?;
552        let writer = BufWriter::new(file);
553        oxicode::serde::encode_into_std_write(&serializable, writer, oxicode::config::standard())
554            .map_err(|e| anyhow!("Failed to serialize model: {}", e))?;
555
556        info!("Model saved successfully");
557        Ok(())
558    }
559
560    fn load(&mut self, path: &str) -> Result<()> {
561        info!("Loading HolE model from {}", path);
562
563        if !Path::new(path).exists() {
564            return Err(anyhow!("Model file not found: {}", path));
565        }
566
567        let file = File::open(path)?;
568        let reader = BufReader::new(file);
569        let (serializable, _): (HoLESerializable, _) =
570            oxicode::serde::decode_from_std_read(reader, oxicode::config::standard())
571                .map_err(|e| anyhow!("Failed to deserialize model: {}", e))?;
572
573        // Convert Vec back to Array1
574        let entity_embeddings: HashMap<String, Array1<f32>> = serializable
575            .entity_embeddings
576            .into_iter()
577            .map(|(k, v)| (k, Array1::from_vec(v)))
578            .collect();
579
580        let relation_embeddings: HashMap<String, Array1<f32>> = serializable
581            .relation_embeddings
582            .into_iter()
583            .map(|(k, v)| (k, Array1::from_vec(v)))
584            .collect();
585
586        // Update model state
587        self.model_id = serializable.model_id;
588        self.config = serializable.config;
589        self.entity_embeddings = entity_embeddings;
590        self.relation_embeddings = relation_embeddings;
591        self.triples = serializable.triples;
592        self.entity_to_id = serializable.entity_to_id;
593        self.relation_to_id = serializable.relation_to_id;
594        self.id_to_entity = serializable.id_to_entity;
595        self.id_to_relation = serializable.id_to_relation;
596        self.is_trained = serializable.is_trained;
597
598        info!("Model loaded successfully");
599        Ok(())
600    }
601
602    fn clear(&mut self) {
603        self.entity_embeddings.clear();
604        self.relation_embeddings.clear();
605        self.triples.clear();
606        self.entity_to_id.clear();
607        self.relation_to_id.clear();
608        self.id_to_entity.clear();
609        self.id_to_relation.clear();
610        self.is_trained = false;
611    }
612
613    fn is_trained(&self) -> bool {
614        self.is_trained
615    }
616
617    async fn encode(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
618        // TODO: Implement text encoding
619        Err(anyhow!("Text encoding not implemented for HoLE"))
620    }
621}
622
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    /// Build a test triple from plain IRI strings.
    fn triple(s: &str, p: &str, o: &str) -> Triple {
        Triple::new(
            NamedNode::new(s).unwrap(),
            NamedNode::new(p).unwrap(),
            NamedNode::new(o).unwrap(),
        )
    }

    #[test]
    fn test_circular_correlation() {
        let model = HoLE::new(HoLEConfig::default());

        let lhs = array![1.0, 2.0, 3.0];
        let rhs = array![4.0, 5.0, 6.0];
        let corr = model.circular_correlation(&lhs.view(), &rhs.view());

        // Hand-computed expectations:
        //   corr[0] = 1*4 + 2*5 + 3*6 = 32
        //   corr[1] = 1*5 + 2*6 + 3*4 = 29
        //   corr[2] = 1*6 + 2*4 + 3*5 = 29
        assert_eq!(corr.len(), 3);
        assert!((corr[0] - 32.0).abs() < 1e-5);
        assert!((corr[1] - 29.0).abs() < 1e-5);
        assert!((corr[2] - 29.0).abs() < 1e-5);
    }

    #[test]
    fn test_hole_creation() {
        // A freshly constructed model holds no embeddings.
        let model = HoLE::new(HoLEConfig::default());
        assert!(model.entity_embeddings.is_empty());
        assert!(model.relation_embeddings.is_empty());
    }

    #[tokio::test]
    async fn test_hole_training() {
        let mut model = HoLE::new(HoLEConfig {
            base: ModelConfig {
                dimensions: 50,
                learning_rate: 0.01,
                max_epochs: 50,
                ..Default::default()
            },
            ..Default::default()
        });

        for (s, p, o) in [
            ("alice", "knows", "bob"),
            ("bob", "knows", "charlie"),
            ("alice", "likes", "charlie"),
        ] {
            model.add_triple(triple(s, p, o)).unwrap();
        }

        let stats = model.train(Some(50)).await.unwrap();

        assert_eq!(stats.epochs_completed, 50);
        assert!(stats.final_loss >= 0.0);
        assert!(stats.training_time_seconds > 0.0);

        // Three distinct entities and two distinct relations were seen.
        assert_eq!(model.entity_embeddings.len(), 3);
        assert_eq!(model.relation_embeddings.len(), 2);

        // With sigmoid activation the score is bounded to [0, 1].
        let score = model.score_triple("alice", "knows", "bob").unwrap();
        assert!((0.0..=1.0).contains(&score));
    }

    #[tokio::test]
    async fn test_hole_ranking() {
        let mut model = HoLE::new(HoLEConfig {
            base: ModelConfig {
                dimensions: 50,
                max_epochs: 30,
                ..Default::default()
            },
            ..Default::default()
        });

        model.add_triple(triple("alice", "knows", "bob")).unwrap();
        model
            .add_triple(triple("alice", "knows", "charlie"))
            .unwrap();
        model.train(Some(30)).await.unwrap();

        let ranked = model.predict_objects("alice", "knows", 2).unwrap();

        assert!(ranked.len() <= 2);
        // Results must come back highest-score first.
        if let [first, second, ..] = ranked.as_slice() {
            assert!(first.1 >= second.1);
        }
    }

    #[tokio::test]
    async fn test_hole_save_load() {
        use std::env::temp_dir;

        let mut model = HoLE::new(HoLEConfig {
            base: ModelConfig {
                dimensions: 30,
                max_epochs: 20,
                ..Default::default()
            },
            ..Default::default()
        });

        model.add_triple(triple("alice", "knows", "bob")).unwrap();
        model.add_triple(triple("bob", "likes", "charlie")).unwrap();
        model.train(Some(20)).await.unwrap();

        // Capture state before persisting.
        let emb_before = model.get_entity_embedding("alice").unwrap();
        let score_before = model.score_triple("alice", "knows", "bob").unwrap();

        // Round-trip through disk.
        let model_path = temp_dir().join("test_hole_model.bin");
        let path_str = model_path.to_str().unwrap();
        model.save(path_str).unwrap();

        let mut restored = HoLE::new(HoLEConfig::default());
        restored.load(path_str).unwrap();

        assert!(restored.is_trained());
        assert_eq!(restored.get_entities().len(), 3);
        assert_eq!(restored.get_relations().len(), 2);

        // Embeddings survive the round-trip unchanged (within f32 tolerance).
        let emb_after = restored.get_entity_embedding("alice").unwrap();
        assert_eq!(emb_before.dimensions, emb_after.dimensions);
        assert_eq!(emb_before.values.len(), emb_after.values.len());
        for (before, after) in emb_before.values.iter().zip(emb_after.values.iter()) {
            assert!((before - after).abs() < 1e-6);
        }

        // Scoring must be consistent across save/load.
        let score_after = restored.score_triple("alice", "knows", "bob").unwrap();
        assert!((score_before - score_after).abs() < 1e-6);

        // Cleanup
        std::fs::remove_file(model_path).ok();
    }

    #[test]
    fn test_hole_load_nonexistent() {
        let mut model = HoLE::new(HoLEConfig::default());
        assert!(model.load("/nonexistent/path/model.bin").is_err());
    }
}