// oxirs_embed/models/quatd.rs

//! QuatD: Quaternion Embeddings for Knowledge Graph Completion
//!
//! QuatD represents every entity and every relation as a single quaternion (a point
//! in 4D space) and scores a triple `(h, r, t)` by comparing the Hamilton product
//! `h ∘ r` with `t`, following the QuatE approach.
//!
//! Reference: Zhang et al. "Quaternion Knowledge Graph Embeddings" (NeurIPS 2019)
7
8use crate::models::BaseModel;
9use crate::{EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector};
10use anyhow::{anyhow, Result};
11use async_trait::async_trait;
12use scirs2_core::ndarray_ext::Array2;
13use scirs2_core::random::{Random, SliceRandom};
14use std::time::Instant;
15use tracing::{debug, info};
16use uuid::Uuid;
17
/// A quaternion `w + x·i + y·j + z·k`, used as a single 4-D embedding vector.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Quaternion {
    /// Real (scalar) component.
    pub w: f64,
    /// Coefficient of the `i` unit.
    pub x: f64,
    /// Coefficient of the `j` unit.
    pub y: f64,
    /// Coefficient of the `k` unit.
    pub z: f64,
}

impl Quaternion {
    /// Create a quaternion from its four components.
    pub fn new(w: f64, x: f64, y: f64, z: f64) -> Self {
        Self { w, x, y, z }
    }

    /// Build a quaternion from a `[w, x, y, z]` slice.
    ///
    /// # Panics
    ///
    /// Panics if `arr` does not contain exactly four elements.
    pub fn from_array(arr: &[f64]) -> Self {
        match arr {
            [w, x, y, z] => Self::new(*w, *x, *y, *z),
            _ => panic!(
                "Quaternion::from_array expects exactly 4 components, got {}",
                arr.len()
            ),
        }
    }

    /// Return the components in `[w, x, y, z]` order.
    pub fn to_array(&self) -> [f64; 4] {
        [self.w, self.x, self.y, self.z]
    }

    /// Hamilton product `self ∘ other`.
    ///
    /// Quaternion multiplication is not commutative: in general
    /// `a.multiply(&b) != b.multiply(&a)` (e.g. `i ∘ j = k` but `j ∘ i = -k`).
    pub fn multiply(&self, other: &Quaternion) -> Quaternion {
        Quaternion {
            w: self.w * other.w - self.x * other.x - self.y * other.y - self.z * other.z,
            x: self.w * other.x + self.x * other.w + self.y * other.z - self.z * other.y,
            y: self.w * other.y - self.x * other.z + self.y * other.w + self.z * other.x,
            z: self.w * other.z + self.x * other.y - self.y * other.x + self.z * other.w,
        }
    }

    /// Conjugate `w - x·i - y·j - z·k`; it satisfies `q ∘ q̄ = |q|²`.
    pub fn conjugate(&self) -> Quaternion {
        Quaternion {
            w: self.w,
            x: -self.x,
            y: -self.y,
            z: -self.z,
        }
    }

    /// Euclidean norm (magnitude) of the four components.
    pub fn norm(&self) -> f64 {
        self.dot(self).sqrt()
    }

    /// Scale in place to unit norm.
    ///
    /// Quaternions whose norm is at most `1e-12` are left unchanged rather than
    /// divided by (almost) zero.
    pub fn normalize(&mut self) {
        let norm = self.norm();
        if norm > 1e-12 {
            self.w /= norm;
            self.x /= norm;
            self.y /= norm;
            self.z /= norm;
        }
    }

    /// Component-wise (4-D Euclidean) inner product.
    pub fn dot(&self, other: &Quaternion) -> f64 {
        self.w * other.w + self.x * other.x + self.y * other.y + self.z * other.z
    }

    /// Component-wise sum `self + other`.
    pub fn add(&self, other: &Quaternion) -> Quaternion {
        Quaternion {
            w: self.w + other.w,
            x: self.x + other.x,
            y: self.y + other.y,
            z: self.z + other.z,
        }
    }

    /// Component-wise difference `self - other`.
    pub fn subtract(&self, other: &Quaternion) -> Quaternion {
        Quaternion {
            w: self.w - other.w,
            x: self.x - other.x,
            y: self.y - other.y,
            z: self.z - other.z,
        }
    }

    /// Multiply every component by `scalar`.
    pub fn scale(&self, scalar: f64) -> Quaternion {
        Quaternion {
            w: self.w * scalar,
            x: self.x * scalar,
            y: self.y * scalar,
            z: self.z * scalar,
        }
    }
}
119
120/// QuatD embedding model
121#[derive(Debug)]
122pub struct QuatD {
123    /// Base model functionality
124    base: BaseModel,
125    /// Entity embeddings as quaternions (num_entities × 4)
126    entity_embeddings: Array2<f64>,
127    /// Relation embeddings as quaternions (num_relations × 4)
128    relation_embeddings: Array2<f64>,
129    /// Whether embeddings have been initialized
130    embeddings_initialized: bool,
131    /// Scoring function variant
132    scoring_function: QuatDScoringFunction,
133    /// Regularization parameters
134    quaternion_regularization: f64,
135}
136
/// Scoring function variants for QuatD.
///
/// Selected through the `"scoring_function"` model parameter
/// (`0.0` = `Standard`, `1.0` = `L2Distance`, `2.0` = `CosineSimilarity`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuatDScoringFunction {
    /// Raw inner product `⟨h ∘ r, t⟩` (higher means more plausible).
    Standard,
    /// Negative Euclidean distance `-‖h ∘ r - t‖₂`.
    L2Distance,
    /// Cosine similarity between `h ∘ r` and `t` (`0` when either norm is ~0).
    CosineSimilarity,
}
147
148impl QuatD {
149    /// Create a new QuatD model
150    pub fn new(config: ModelConfig) -> Self {
151        let base = BaseModel::new(config.clone());
152
153        // Get QuatD-specific parameters
154        let scoring_function = match config.model_params.get("scoring_function") {
155            Some(0.0) => QuatDScoringFunction::Standard,
156            Some(1.0) => QuatDScoringFunction::L2Distance,
157            Some(2.0) => QuatDScoringFunction::CosineSimilarity,
158            _ => QuatDScoringFunction::Standard,
159        };
160
161        let quaternion_regularization = config
162            .model_params
163            .get("quaternion_regularization")
164            .copied()
165            .unwrap_or(0.05);
166
167        Self {
168            base,
169            entity_embeddings: Array2::zeros((0, 4)), // 4D quaternions
170            relation_embeddings: Array2::zeros((0, 4)), // 4D quaternions
171            embeddings_initialized: false,
172            scoring_function,
173            quaternion_regularization,
174        }
175    }
176
177    /// Initialize embeddings after entities and relations are known
178    fn initialize_embeddings(&mut self) {
179        if self.embeddings_initialized {
180            return;
181        }
182
183        let num_entities = self.base.num_entities();
184        let num_relations = self.base.num_relations();
185
186        if num_entities == 0 || num_relations == 0 {
187            return;
188        }
189
190        let mut rng = Random::seed(self.base.config.seed.unwrap_or_else(|| {
191            use std::time::{SystemTime, UNIX_EPOCH};
192            SystemTime::now()
193                .duration_since(UNIX_EPOCH)
194                .expect("system time should be after UNIX_EPOCH")
195                .as_secs()
196        }));
197
198        // Initialize entity embeddings as quaternions
199        self.entity_embeddings =
200            Array2::from_shape_fn((num_entities, 4), |_| rng.gen_range(-0.1..0.1));
201
202        // Initialize relation embeddings as quaternions
203        self.relation_embeddings =
204            Array2::from_shape_fn((num_relations, 4), |_| rng.gen_range(-0.1..0.1));
205
206        // Normalize quaternions to unit length
207        self.normalize_all_quaternions();
208
209        self.embeddings_initialized = true;
210        debug!(
211            "Initialized QuatD embeddings: {} entities, {} relations (4D quaternions)",
212            num_entities, num_relations
213        );
214    }
215
216    /// Normalize all quaternion embeddings to unit length
217    fn normalize_all_quaternions(&mut self) {
218        // Normalize entity embeddings
219        for mut row in self.entity_embeddings.rows_mut() {
220            let mut quat =
221                Quaternion::from_array(row.as_slice().expect("row should be contiguous"));
222            quat.normalize();
223            let normalized = quat.to_array();
224            for (i, &val) in normalized.iter().enumerate() {
225                row[i] = val;
226            }
227        }
228
229        // Normalize relation embeddings
230        for mut row in self.relation_embeddings.rows_mut() {
231            let mut quat =
232                Quaternion::from_array(row.as_slice().expect("row should be contiguous"));
233            quat.normalize();
234            let normalized = quat.to_array();
235            for (i, &val) in normalized.iter().enumerate() {
236                row[i] = val;
237            }
238        }
239    }
240
241    /// Get quaternion from entity embeddings
242    fn get_entity_quaternion(&self, entity_id: usize) -> Quaternion {
243        let row = self.entity_embeddings.row(entity_id);
244        Quaternion::from_array(row.as_slice().expect("row should be contiguous"))
245    }
246
247    /// Get quaternion from relation embeddings
248    fn get_relation_quaternion(&self, relation_id: usize) -> Quaternion {
249        let row = self.relation_embeddings.row(relation_id);
250        Quaternion::from_array(row.as_slice().expect("row should be contiguous"))
251    }
252
253    /// Score a triple using QuatD scoring function
254    fn score_triple_ids(
255        &self,
256        subject_id: usize,
257        predicate_id: usize,
258        object_id: usize,
259    ) -> Result<f64> {
260        if !self.embeddings_initialized {
261            return Err(anyhow!("Model not trained"));
262        }
263
264        let h = self.get_entity_quaternion(subject_id);
265        let r = self.get_relation_quaternion(predicate_id);
266        let t = self.get_entity_quaternion(object_id);
267
268        match self.scoring_function {
269            QuatDScoringFunction::Standard => {
270                // QuatD scoring: σ(h ∘ r · t)
271                let hr = h.multiply(&r);
272                Ok(hr.dot(&t))
273            }
274            QuatDScoringFunction::L2Distance => {
275                // L2 distance: -||h ∘ r - t||₂
276                let hr = h.multiply(&r);
277                let diff = hr.subtract(&t);
278                Ok(-diff.norm())
279            }
280            QuatDScoringFunction::CosineSimilarity => {
281                // Cosine similarity between h ∘ r and t
282                let hr = h.multiply(&r);
283                let dot_product = hr.dot(&t);
284                let magnitude_product = hr.norm() * t.norm();
285                if magnitude_product > 1e-12 {
286                    Ok(dot_product / magnitude_product)
287                } else {
288                    Ok(0.0)
289                }
290            }
291        }
292    }
293
294    /// Compute gradients for QuatD
295    fn compute_gradients(
296        &self,
297        pos_triple: (usize, usize, usize),
298        neg_triple: (usize, usize, usize),
299    ) -> Result<(Array2<f64>, Array2<f64>)> {
300        let (pos_s, pos_p, pos_o) = pos_triple;
301        let (neg_s, neg_p, neg_o) = neg_triple;
302
303        let mut entity_grads = Array2::zeros(self.entity_embeddings.raw_dim());
304        let mut relation_grads = Array2::zeros(self.relation_embeddings.raw_dim());
305
306        // Compute scores
307        let pos_score = self.score_triple_ids(pos_s, pos_p, pos_o)?;
308        let neg_score = self.score_triple_ids(neg_s, neg_p, neg_o)?;
309
310        // Sigmoid derivatives
311        let pos_sigmoid = 1.0 / (1.0 + (-pos_score).exp());
312        let neg_sigmoid = 1.0 / (1.0 + (-neg_score).exp());
313
314        let pos_grad = pos_sigmoid - 1.0;
315        let neg_grad = neg_sigmoid;
316
317        // Compute gradients for positive triple
318        self.compute_triple_gradients(pos_triple, pos_grad, &mut entity_grads, &mut relation_grads);
319
320        // Compute gradients for negative triple
321        self.compute_triple_gradients(neg_triple, neg_grad, &mut entity_grads, &mut relation_grads);
322
323        Ok((entity_grads, relation_grads))
324    }
325
326    /// Compute gradients for a single triple
327    fn compute_triple_gradients(
328        &self,
329        triple: (usize, usize, usize),
330        loss_grad: f64,
331        entity_grads: &mut Array2<f64>,
332        relation_grads: &mut Array2<f64>,
333    ) {
334        let (s, p, o) = triple;
335
336        let h = self.get_entity_quaternion(s);
337        let r = self.get_relation_quaternion(p);
338        let t = self.get_entity_quaternion(o);
339
340        match self.scoring_function {
341            QuatDScoringFunction::Standard => {
342                // Gradients for h ∘ r · t scoring
343                let hr = h.multiply(&r);
344
345                // ∂score/∂h = (r · t) where · is quaternion multiplication with t
346                let r_conj = r.conjugate();
347                let grad_h = r_conj.multiply(&t).scale(loss_grad);
348
349                // ∂score/∂r = (h^* · t) where ^* is conjugate
350                let h_conj = h.conjugate();
351                let grad_r = h_conj.multiply(&t).scale(loss_grad);
352
353                // ∂score/∂t = (h ∘ r)
354                let grad_t = hr.scale(loss_grad);
355
356                // Add gradients
357                let grad_h_arr = grad_h.to_array();
358                let grad_r_arr = grad_r.to_array();
359                let grad_t_arr = grad_t.to_array();
360
361                for i in 0..4 {
362                    entity_grads[[s, i]] += grad_h_arr[i];
363                    relation_grads[[p, i]] += grad_r_arr[i];
364                    entity_grads[[o, i]] += grad_t_arr[i];
365                }
366            }
367            QuatDScoringFunction::L2Distance => {
368                // Gradients for -||h ∘ r - t||₂ scoring
369                let hr = h.multiply(&r);
370                let diff = hr.subtract(&t);
371                let norm = diff.norm();
372
373                if norm > 1e-12 {
374                    let scale = -loss_grad / norm;
375
376                    // Similar quaternion gradient computation but scaled by norm
377                    let r_conj = r.conjugate();
378                    let grad_h = r_conj.scale(scale);
379
380                    let h_conj = h.conjugate();
381                    let grad_r = h_conj.scale(scale);
382
383                    let grad_t = diff.scale(-scale);
384
385                    let grad_h_arr = grad_h.to_array();
386                    let grad_r_arr = grad_r.to_array();
387                    let grad_t_arr = grad_t.to_array();
388
389                    for i in 0..4 {
390                        entity_grads[[s, i]] += grad_h_arr[i];
391                        relation_grads[[p, i]] += grad_r_arr[i];
392                        entity_grads[[o, i]] += grad_t_arr[i];
393                    }
394                }
395            }
396            QuatDScoringFunction::CosineSimilarity => {
397                // Gradients for cosine similarity
398                let hr = h.multiply(&r);
399                let dot_product = hr.dot(&t);
400                let hr_norm = hr.norm();
401                let t_norm = t.norm();
402                let magnitude_product = hr_norm * t_norm;
403
404                if magnitude_product > 1e-12 {
405                    let cos_sim = dot_product / magnitude_product;
406
407                    // Complex gradients for cosine similarity - simplified version
408                    let scale = loss_grad / magnitude_product;
409
410                    let grad_hr = t
411                        .subtract(&hr.scale(cos_sim / (hr_norm * hr_norm)))
412                        .scale(scale);
413                    let grad_t = hr
414                        .subtract(&t.scale(cos_sim / (t_norm * t_norm)))
415                        .scale(scale);
416
417                    // Backpropagate through quaternion multiplication for grad_hr
418                    let r_conj = r.conjugate();
419                    let grad_h = r_conj.multiply(&grad_hr);
420
421                    let h_conj = h.conjugate();
422                    let grad_r = h_conj.multiply(&grad_hr);
423
424                    let grad_h_arr = grad_h.to_array();
425                    let grad_r_arr = grad_r.to_array();
426                    let grad_t_arr = grad_t.to_array();
427
428                    for i in 0..4 {
429                        entity_grads[[s, i]] += grad_h_arr[i];
430                        relation_grads[[p, i]] += grad_r_arr[i];
431                        entity_grads[[o, i]] += grad_t_arr[i];
432                    }
433                }
434            }
435        }
436    }
437
438    /// Perform one training epoch
439    async fn train_epoch(&mut self, learning_rate: f64) -> Result<f64> {
440        let mut rng = Random::seed(self.base.config.seed.unwrap_or_else(|| {
441            use std::time::{SystemTime, UNIX_EPOCH};
442            SystemTime::now()
443                .duration_since(UNIX_EPOCH)
444                .expect("system time should be after UNIX_EPOCH")
445                .as_secs()
446        }));
447
448        let mut total_loss = 0.0;
449        let num_batches = (self.base.triples.len() + self.base.config.batch_size - 1)
450            / self.base.config.batch_size;
451
452        // Create shuffled batches
453        let mut shuffled_triples = self.base.triples.clone();
454        shuffled_triples.shuffle(&mut rng);
455
456        for batch_triples in shuffled_triples.chunks(self.base.config.batch_size) {
457            let mut batch_entity_grads = Array2::zeros(self.entity_embeddings.raw_dim());
458            let mut batch_relation_grads = Array2::zeros(self.relation_embeddings.raw_dim());
459            let mut batch_loss = 0.0;
460
461            for &pos_triple in batch_triples {
462                // Generate negative samples
463                let neg_samples = self
464                    .base
465                    .generate_negative_samples(self.base.config.negative_samples, &mut rng);
466
467                for neg_triple in neg_samples {
468                    // Compute scores
469                    let pos_score =
470                        self.score_triple_ids(pos_triple.0, pos_triple.1, pos_triple.2)?;
471                    let neg_score =
472                        self.score_triple_ids(neg_triple.0, neg_triple.1, neg_triple.2)?;
473
474                    // Logistic loss
475                    let pos_loss = -(1.0 / (1.0 + (-pos_score).exp())).ln();
476                    let neg_loss = -(1.0 / (1.0 + neg_score.exp())).ln();
477                    let loss = pos_loss + neg_loss;
478                    batch_loss += loss;
479
480                    // Compute and accumulate gradients
481                    let (entity_grads, relation_grads) =
482                        self.compute_gradients(pos_triple, neg_triple)?;
483
484                    batch_entity_grads += &entity_grads;
485                    batch_relation_grads += &relation_grads;
486                }
487            }
488
489            // Apply gradients with quaternion regularization
490            if batch_loss > 0.0 {
491                // Update entity embeddings
492                for (((_i, _j), embedding_val), grad_val) in self
493                    .entity_embeddings
494                    .indexed_iter_mut()
495                    .zip(batch_entity_grads.iter())
496                {
497                    let reg_term = self.quaternion_regularization * *embedding_val;
498                    *embedding_val -= learning_rate * (grad_val + reg_term);
499                }
500
501                // Update relation embeddings
502                for (((_i, _j), embedding_val), grad_val) in self
503                    .relation_embeddings
504                    .indexed_iter_mut()
505                    .zip(batch_relation_grads.iter())
506                {
507                    let reg_term = self.quaternion_regularization * *embedding_val;
508                    *embedding_val -= learning_rate * (grad_val + reg_term);
509                }
510
511                // Normalize quaternions after update
512                self.normalize_all_quaternions();
513            }
514
515            total_loss += batch_loss;
516        }
517
518        Ok(total_loss / num_batches as f64)
519    }
520}
521
522#[async_trait]
523impl EmbeddingModel for QuatD {
524    fn config(&self) -> &ModelConfig {
525        &self.base.config
526    }
527
528    fn model_id(&self) -> &Uuid {
529        &self.base.model_id
530    }
531
532    fn model_type(&self) -> &'static str {
533        "QuatD"
534    }
535
536    fn add_triple(&mut self, triple: Triple) -> Result<()> {
537        self.base.add_triple(triple)
538    }
539
540    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
541        let start_time = Instant::now();
542        let max_epochs = epochs.unwrap_or(self.base.config.max_epochs);
543
544        // Initialize embeddings if needed
545        self.initialize_embeddings();
546
547        if !self.embeddings_initialized {
548            return Err(anyhow!("No training data available"));
549        }
550
551        let mut loss_history = Vec::new();
552        let learning_rate = self.base.config.learning_rate;
553
554        info!("Starting QuatD training for {} epochs", max_epochs);
555
556        for epoch in 0..max_epochs {
557            let epoch_loss = self.train_epoch(learning_rate).await?;
558            loss_history.push(epoch_loss);
559
560            if epoch % 100 == 0 {
561                debug!("Epoch {}: loss = {:.6}", epoch, epoch_loss);
562            }
563
564            // Simple convergence check
565            if epoch > 10 && epoch_loss < 1e-6 {
566                info!("Converged at epoch {} with loss {:.6}", epoch, epoch_loss);
567                break;
568            }
569        }
570
571        self.base.mark_trained();
572        let training_time = start_time.elapsed().as_secs_f64();
573
574        Ok(TrainingStats {
575            epochs_completed: loss_history.len(),
576            final_loss: loss_history.last().copied().unwrap_or(0.0),
577            training_time_seconds: training_time,
578            convergence_achieved: loss_history.last().copied().unwrap_or(f64::INFINITY) < 1e-6,
579            loss_history,
580        })
581    }
582
583    fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
584        if !self.embeddings_initialized {
585            return Err(anyhow!("Model not trained"));
586        }
587
588        let entity_id = self
589            .base
590            .get_entity_id(entity)
591            .ok_or_else(|| anyhow!("Entity not found: {}", entity))?;
592
593        let embedding = self.entity_embeddings.row(entity_id).to_owned();
594        Ok(Vector::new(
595            embedding.to_vec().into_iter().map(|x| x as f32).collect(),
596        ))
597    }
598
599    fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
600        if !self.embeddings_initialized {
601            return Err(anyhow!("Model not trained"));
602        }
603
604        let relation_id = self
605            .base
606            .get_relation_id(relation)
607            .ok_or_else(|| anyhow!("Relation not found: {}", relation))?;
608
609        let embedding = self.relation_embeddings.row(relation_id).to_owned();
610        Ok(Vector::new(
611            embedding.to_vec().into_iter().map(|x| x as f32).collect(),
612        ))
613    }
614
615    fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
616        let subject_id = self
617            .base
618            .get_entity_id(subject)
619            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
620        let predicate_id = self
621            .base
622            .get_relation_id(predicate)
623            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
624        let object_id = self
625            .base
626            .get_entity_id(object)
627            .ok_or_else(|| anyhow!("Object not found: {}", object))?;
628
629        self.score_triple_ids(subject_id, predicate_id, object_id)
630    }
631
632    fn predict_objects(
633        &self,
634        subject: &str,
635        predicate: &str,
636        k: usize,
637    ) -> Result<Vec<(String, f64)>> {
638        if !self.embeddings_initialized {
639            return Err(anyhow!("Model not trained"));
640        }
641
642        let subject_id = self
643            .base
644            .get_entity_id(subject)
645            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
646        let predicate_id = self
647            .base
648            .get_relation_id(predicate)
649            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
650
651        let mut scores = Vec::new();
652
653        for object_id in 0..self.base.num_entities() {
654            let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
655            let object_name = self
656                .base
657                .get_entity(object_id)
658                .expect("entity should exist in index")
659                .clone();
660            scores.push((object_name, score));
661        }
662
663        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("scores should be comparable"));
664        scores.truncate(k);
665
666        Ok(scores)
667    }
668
669    fn predict_subjects(
670        &self,
671        predicate: &str,
672        object: &str,
673        k: usize,
674    ) -> Result<Vec<(String, f64)>> {
675        if !self.embeddings_initialized {
676            return Err(anyhow!("Model not trained"));
677        }
678
679        let predicate_id = self
680            .base
681            .get_relation_id(predicate)
682            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
683        let object_id = self
684            .base
685            .get_entity_id(object)
686            .ok_or_else(|| anyhow!("Object not found: {}", object))?;
687
688        let mut scores = Vec::new();
689
690        for subject_id in 0..self.base.num_entities() {
691            let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
692            let subject_name = self
693                .base
694                .get_entity(subject_id)
695                .expect("entity should exist in index")
696                .clone();
697            scores.push((subject_name, score));
698        }
699
700        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("scores should be comparable"));
701        scores.truncate(k);
702
703        Ok(scores)
704    }
705
706    fn predict_relations(
707        &self,
708        subject: &str,
709        object: &str,
710        k: usize,
711    ) -> Result<Vec<(String, f64)>> {
712        if !self.embeddings_initialized {
713            return Err(anyhow!("Model not trained"));
714        }
715
716        let subject_id = self
717            .base
718            .get_entity_id(subject)
719            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
720        let object_id = self
721            .base
722            .get_entity_id(object)
723            .ok_or_else(|| anyhow!("Object not found: {}", object))?;
724
725        let mut scores = Vec::new();
726
727        for predicate_id in 0..self.base.num_relations() {
728            let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
729            let predicate_name = self
730                .base
731                .get_relation(predicate_id)
732                .expect("relation should exist in index")
733                .clone();
734            scores.push((predicate_name, score));
735        }
736
737        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("scores should be comparable"));
738        scores.truncate(k);
739
740        Ok(scores)
741    }
742
743    fn get_entities(&self) -> Vec<String> {
744        self.base.get_entities()
745    }
746
747    fn get_relations(&self) -> Vec<String> {
748        self.base.get_relations()
749    }
750
751    fn get_stats(&self) -> ModelStats {
752        self.base.get_stats("QuatD")
753    }
754
755    fn save(&self, path: &str) -> Result<()> {
756        info!("Saving QuatD model to {}", path);
757        Ok(())
758    }
759
760    fn load(&mut self, path: &str) -> Result<()> {
761        info!("Loading QuatD model from {}", path);
762        Ok(())
763    }
764
765    fn clear(&mut self) {
766        self.base.clear();
767        self.entity_embeddings = Array2::zeros((0, 4));
768        self.relation_embeddings = Array2::zeros((0, 4));
769        self.embeddings_initialized = false;
770    }
771
772    fn is_trained(&self) -> bool {
773        self.base.is_trained
774    }
775
776    async fn encode(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
777        Err(anyhow!(
778            "Knowledge graph embedding model does not support text encoding"
779        ))
780    }
781}
782
#[cfg(test)]
mod tests {
    use super::*;
    use crate::NamedNode;

    #[test]
    fn test_quaternion_operations() {
        let q1 = Quaternion::new(1.0, 2.0, 3.0, 4.0);
        let q2 = Quaternion::new(2.0, 3.0, 4.0, 5.0);

        // Hamilton product: (1 + 2i + 3j + 4k)(2 + 3i + 4j + 5k) = -36 + 6i + 12j + 12k
        let product = q1.multiply(&q2);
        assert_eq!(product.to_array(), [-36.0, 6.0, 12.0, 12.0]);

        // Non-commutativity: i ∘ j = k but j ∘ i = -k
        let i = Quaternion::new(0.0, 1.0, 0.0, 0.0);
        let j = Quaternion::new(0.0, 0.0, 1.0, 0.0);
        assert_eq!(i.multiply(&j).to_array(), [0.0, 0.0, 0.0, 1.0]);
        assert_eq!(j.multiply(&i).to_array(), [0.0, 0.0, 0.0, -1.0]);

        // Conjugate: q ∘ q̄ = |q|²
        let conj = q1.conjugate();
        assert_eq!(conj.w, q1.w);
        assert_eq!(conj.x, -q1.x);
        assert_eq!(q1.multiply(&conj).to_array(), [30.0, 0.0, 0.0, 0.0]);

        // Normalization
        let mut q3 = q1;
        q3.normalize();
        assert!((q3.norm() - 1.0).abs() < 1e-10);

        // A zero quaternion is left unchanged rather than divided by zero
        let mut zero = Quaternion::new(0.0, 0.0, 0.0, 0.0);
        zero.normalize();
        assert_eq!(zero.to_array(), [0.0; 4]);
    }

    #[tokio::test]
    async fn test_quatd_basic() -> Result<()> {
        let config = ModelConfig::default()
            .with_dimensions(4) // Always 4 for quaternions
            .with_max_epochs(10)
            .with_seed(42);

        let mut model = QuatD::new(config);

        let alice = NamedNode::new("http://example.org/alice")?;
        let knows = NamedNode::new("http://example.org/knows")?;
        let bob = NamedNode::new("http://example.org/bob")?;

        model.add_triple(Triple::new(alice.clone(), knows.clone(), bob.clone()))?;
        model.add_triple(Triple::new(bob.clone(), knows.clone(), alice.clone()))?;

        let stats = model.train(Some(5)).await?;
        assert!(stats.epochs_completed > 0);

        let alice_emb = model.get_entity_embedding("http://example.org/alice")?;
        assert_eq!(alice_emb.dimensions, 4); // Quaternion dimension

        let score = model.score_triple(
            "http://example.org/alice",
            "http://example.org/knows",
            "http://example.org/bob",
        )?;
        assert!(score.is_finite());

        // Clearing drops the embeddings again
        model.clear();
        assert!(model
            .get_entity_embedding("http://example.org/alice")
            .is_err());

        Ok(())
    }

    #[tokio::test]
    async fn test_quatd_untrained_errors() {
        let mut model = QuatD::new(ModelConfig::default());

        // No triples: nothing to embed
        assert!(model.train(Some(1)).await.is_err());
        assert!(model
            .get_entity_embedding("http://example.org/alice")
            .is_err());
        assert!(model
            .score_triple(
                "http://example.org/alice",
                "http://example.org/knows",
                "http://example.org/bob",
            )
            .is_err());
    }

    #[test]
    fn test_quatd_creation() {
        let config = ModelConfig::default();
        let quatd = QuatD::new(config);
        assert!(!quatd.embeddings_initialized);
        assert_eq!(quatd.model_type(), "QuatD");
    }
}