use crate::models::BaseModel;
use crate::{EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use scirs2_core::ndarray_ext::Array2;
use scirs2_core::random::{Random, SliceRandom};
use std::time::Instant;
use tracing::{debug, info};
use uuid::Uuid;

/// A quaternion `w + x·i + y·j + z·k` with `f64` components.
#[derive(Debug, Clone, Copy)]
pub struct Quaternion {
    pub w: f64,
    pub x: f64,
    pub y: f64,
    pub z: f64,
}

impl Quaternion {
    pub fn new(w: f64, x: f64, y: f64, z: f64) -> Self {
        Self { w, x, y, z }
    }

    /// Builds a quaternion from a 4-element slice `[w, x, y, z]`.
    pub fn from_array(arr: &[f64]) -> Self {
        assert_eq!(arr.len(), 4);
        Self::new(arr[0], arr[1], arr[2], arr[3])
    }

    pub fn to_array(&self) -> [f64; 4] {
        [self.w, self.x, self.y, self.z]
    }

    /// Hamilton product `self ⊗ other` (non-commutative).
    pub fn multiply(&self, other: &Quaternion) -> Quaternion {
        Quaternion {
            w: self.w * other.w - self.x * other.x - self.y * other.y - self.z * other.z,
            x: self.w * other.x + self.x * other.w + self.y * other.z - self.z * other.y,
            y: self.w * other.y - self.x * other.z + self.y * other.w + self.z * other.x,
            z: self.w * other.z + self.x * other.y - self.y * other.x + self.z * other.w,
        }
    }

    /// Conjugate: keeps the scalar part and negates the vector part.
    pub fn conjugate(&self) -> Quaternion {
        Quaternion {
            w: self.w,
            x: -self.x,
            y: -self.y,
            z: -self.z,
        }
    }

    pub fn norm(&self) -> f64 {
        (self.w * self.w + self.x * self.x + self.y * self.y + self.z * self.z).sqrt()
    }

    /// Normalizes to unit length in place; leaves near-zero quaternions unchanged.
    pub fn normalize(&mut self) {
        let norm = self.norm();
        if norm > 1e-12 {
            self.w /= norm;
            self.x /= norm;
            self.y /= norm;
            self.z /= norm;
        }
    }

    pub fn dot(&self, other: &Quaternion) -> f64 {
        self.w * other.w + self.x * other.x + self.y * other.y + self.z * other.z
    }

    pub fn add(&self, other: &Quaternion) -> Quaternion {
        Quaternion {
            w: self.w + other.w,
            x: self.x + other.x,
            y: self.y + other.y,
            z: self.z + other.z,
        }
    }

    pub fn subtract(&self, other: &Quaternion) -> Quaternion {
        Quaternion {
            w: self.w - other.w,
            x: self.x - other.x,
            y: self.y - other.y,
            z: self.z - other.z,
        }
    }

    pub fn scale(&self, scalar: f64) -> Quaternion {
        Quaternion {
            w: self.w * scalar,
            x: self.x * scalar,
            y: self.y * scalar,
            z: self.z * scalar,
        }
    }
}

/// QuatD knowledge graph embedding model: entities and relations are
/// represented as (unit) quaternions and triples are scored in quaternion space.
#[derive(Debug)]
pub struct QuatD {
    base: BaseModel,
    /// Entity embeddings, one quaternion (4 values) per entity.
    entity_embeddings: Array2<f64>,
    /// Relation embeddings, one quaternion (4 values) per relation.
    relation_embeddings: Array2<f64>,
    embeddings_initialized: bool,
    scoring_function: QuatDScoringFunction,
    quaternion_regularization: f64,
}

/// Scoring functions supported by QuatD.
#[derive(Debug, Clone, Copy)]
pub enum QuatDScoringFunction {
    /// Inner product between (h ⊗ r) and t.
    Standard,
    /// Negative L2 distance between (h ⊗ r) and t.
    L2Distance,
    /// Cosine similarity between (h ⊗ r) and t.
    CosineSimilarity,
}

impl QuatD {
    pub fn new(config: ModelConfig) -> Self {
        let base = BaseModel::new(config.clone());

        // The scoring function is selected via the numeric "scoring_function"
        // model parameter: 1.0 = L2 distance, 2.0 = cosine similarity, anything
        // else (including absence) falls back to the standard inner-product score.
        let scoring_function = match config.model_params.get("scoring_function").copied() {
            Some(v) if v == 1.0 => QuatDScoringFunction::L2Distance,
            Some(v) if v == 2.0 => QuatDScoringFunction::CosineSimilarity,
            _ => QuatDScoringFunction::Standard,
        };

        let quaternion_regularization = config
            .model_params
            .get("quaternion_regularization")
            .copied()
            .unwrap_or(0.05);

        Self {
            base,
            entity_embeddings: Array2::zeros((0, 4)),
            relation_embeddings: Array2::zeros((0, 4)),
            embeddings_initialized: false,
            scoring_function,
            quaternion_regularization,
        }
    }

    /// Lazily initializes entity and relation embeddings once data is available.
    fn initialize_embeddings(&mut self) {
        if self.embeddings_initialized {
            return;
        }

        let num_entities = self.base.num_entities();
        let num_relations = self.base.num_relations();

        if num_entities == 0 || num_relations == 0 {
            return;
        }

        let mut rng = Random::seed(self.base.config.seed.unwrap_or_else(|| {
            use std::time::{SystemTime, UNIX_EPOCH};
            SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("system time should be after UNIX_EPOCH")
                .as_secs()
        }));

        self.entity_embeddings =
            Array2::from_shape_fn((num_entities, 4), |_| rng.gen_range(-0.1..0.1));

        self.relation_embeddings =
            Array2::from_shape_fn((num_relations, 4), |_| rng.gen_range(-0.1..0.1));

        self.normalize_all_quaternions();

        self.embeddings_initialized = true;
        debug!(
            "Initialized QuatD embeddings: {} entities, {} relations (4D quaternions)",
            num_entities, num_relations
        );
    }

    fn normalize_all_quaternions(&mut self) {
        for mut row in self.entity_embeddings.rows_mut() {
            let mut quat =
                Quaternion::from_array(row.as_slice().expect("row should be contiguous"));
            quat.normalize();
            let normalized = quat.to_array();
            for (i, &val) in normalized.iter().enumerate() {
                row[i] = val;
            }
        }

        for mut row in self.relation_embeddings.rows_mut() {
            let mut quat =
                Quaternion::from_array(row.as_slice().expect("row should be contiguous"));
            quat.normalize();
            let normalized = quat.to_array();
            for (i, &val) in normalized.iter().enumerate() {
                row[i] = val;
            }
        }
    }

    fn get_entity_quaternion(&self, entity_id: usize) -> Quaternion {
        let row = self.entity_embeddings.row(entity_id);
        Quaternion::from_array(row.as_slice().expect("row should be contiguous"))
    }

    fn get_relation_quaternion(&self, relation_id: usize) -> Quaternion {
        let row = self.relation_embeddings.row(relation_id);
        Quaternion::from_array(row.as_slice().expect("row should be contiguous"))
    }

    /// Scores a triple given internal entity/relation ids using the configured
    /// scoring function.
    fn score_triple_ids(
        &self,
        subject_id: usize,
        predicate_id: usize,
        object_id: usize,
    ) -> Result<f64> {
        if !self.embeddings_initialized {
            return Err(anyhow!("Model not trained"));
        }

        let h = self.get_entity_quaternion(subject_id);
        let r = self.get_relation_quaternion(predicate_id);
        let t = self.get_entity_quaternion(object_id);

        match self.scoring_function {
            QuatDScoringFunction::Standard => {
                // score = <h ⊗ r, t>
                let hr = h.multiply(&r);
                Ok(hr.dot(&t))
            }
            QuatDScoringFunction::L2Distance => {
                // score = -||(h ⊗ r) - t||
                let hr = h.multiply(&r);
                let diff = hr.subtract(&t);
                Ok(-diff.norm())
            }
            QuatDScoringFunction::CosineSimilarity => {
                // score = cos(h ⊗ r, t)
                let hr = h.multiply(&r);
                let dot_product = hr.dot(&t);
                let magnitude_product = hr.norm() * t.norm();
                if magnitude_product > 1e-12 {
                    Ok(dot_product / magnitude_product)
                } else {
                    Ok(0.0)
                }
            }
        }
    }

    /// Computes gradients for a positive/negative triple pair under a logistic loss.
    fn compute_gradients(
        &self,
        pos_triple: (usize, usize, usize),
        neg_triple: (usize, usize, usize),
    ) -> Result<(Array2<f64>, Array2<f64>)> {
        let (pos_s, pos_p, pos_o) = pos_triple;
        let (neg_s, neg_p, neg_o) = neg_triple;

        let mut entity_grads = Array2::zeros(self.entity_embeddings.raw_dim());
        let mut relation_grads = Array2::zeros(self.relation_embeddings.raw_dim());

        let pos_score = self.score_triple_ids(pos_s, pos_p, pos_o)?;
        let neg_score = self.score_triple_ids(neg_s, neg_p, neg_o)?;

        // Logistic loss gradients: d/ds[-ln σ(s)] = σ(s) - 1 for positives,
        // d/ds[-ln σ(-s)] = σ(s) for negatives.
        let pos_sigmoid = 1.0 / (1.0 + (-pos_score).exp());
        let neg_sigmoid = 1.0 / (1.0 + (-neg_score).exp());

        let pos_grad = pos_sigmoid - 1.0;
        let neg_grad = neg_sigmoid;

        self.compute_triple_gradients(pos_triple, pos_grad, &mut entity_grads, &mut relation_grads);

        self.compute_triple_gradients(neg_triple, neg_grad, &mut entity_grads, &mut relation_grads);

        Ok((entity_grads, relation_grads))
    }

    /// Accumulates per-triple gradients into the provided gradient buffers.
    fn compute_triple_gradients(
        &self,
        triple: (usize, usize, usize),
        loss_grad: f64,
        entity_grads: &mut Array2<f64>,
        relation_grads: &mut Array2<f64>,
    ) {
        let (s, p, o) = triple;

        let h = self.get_entity_quaternion(s);
        let r = self.get_relation_quaternion(p);
        let t = self.get_entity_quaternion(o);

        match self.scoring_function {
            QuatDScoringFunction::Standard => {
                let hr = h.multiply(&r);

                let r_conj = r.conjugate();
                let grad_h = r_conj.multiply(&t).scale(loss_grad);

                let h_conj = h.conjugate();
                let grad_r = h_conj.multiply(&t).scale(loss_grad);

                let grad_t = hr.scale(loss_grad);

                let grad_h_arr = grad_h.to_array();
                let grad_r_arr = grad_r.to_array();
                let grad_t_arr = grad_t.to_array();

                for i in 0..4 {
                    entity_grads[[s, i]] += grad_h_arr[i];
                    relation_grads[[p, i]] += grad_r_arr[i];
                    entity_grads[[o, i]] += grad_t_arr[i];
                }
            }
            QuatDScoringFunction::L2Distance => {
                let hr = h.multiply(&r);
                let diff = hr.subtract(&t);
                let norm = diff.norm();

                if norm > 1e-12 {
                    let scale = -loss_grad / norm;

                    let r_conj = r.conjugate();
                    let grad_h = r_conj.scale(scale);

                    let h_conj = h.conjugate();
                    let grad_r = h_conj.scale(scale);

                    let grad_t = diff.scale(-scale);

                    let grad_h_arr = grad_h.to_array();
                    let grad_r_arr = grad_r.to_array();
                    let grad_t_arr = grad_t.to_array();

                    for i in 0..4 {
                        entity_grads[[s, i]] += grad_h_arr[i];
                        relation_grads[[p, i]] += grad_r_arr[i];
                        entity_grads[[o, i]] += grad_t_arr[i];
                    }
                }
            }
            QuatDScoringFunction::CosineSimilarity => {
                let hr = h.multiply(&r);
                let dot_product = hr.dot(&t);
                let hr_norm = hr.norm();
                let t_norm = t.norm();
                let magnitude_product = hr_norm * t_norm;

                if magnitude_product > 1e-12 {
                    let cos_sim = dot_product / magnitude_product;

                    let scale = loss_grad / magnitude_product;

                    let grad_hr = t
                        .subtract(&hr.scale(cos_sim / (hr_norm * hr_norm)))
                        .scale(scale);
                    let grad_t = hr
                        .subtract(&t.scale(cos_sim / (t_norm * t_norm)))
                        .scale(scale);

                    let r_conj = r.conjugate();
                    let grad_h = r_conj.multiply(&grad_hr);

                    let h_conj = h.conjugate();
                    let grad_r = h_conj.multiply(&grad_hr);

                    let grad_h_arr = grad_h.to_array();
                    let grad_r_arr = grad_r.to_array();
                    let grad_t_arr = grad_t.to_array();

                    for i in 0..4 {
                        entity_grads[[s, i]] += grad_h_arr[i];
                        relation_grads[[p, i]] += grad_r_arr[i];
                        entity_grads[[o, i]] += grad_t_arr[i];
                    }
                }
            }
        }
    }

    /// Runs one epoch of mini-batch SGD with negative sampling and quaternion
    /// regularization; returns the mean batch loss.
    async fn train_epoch(&mut self, learning_rate: f64) -> Result<f64> {
        let mut rng = Random::seed(self.base.config.seed.unwrap_or_else(|| {
            use std::time::{SystemTime, UNIX_EPOCH};
            SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("system time should be after UNIX_EPOCH")
                .as_secs()
        }));

        let mut total_loss = 0.0;
        let num_batches = (self.base.triples.len() + self.base.config.batch_size - 1)
            / self.base.config.batch_size;

        let mut shuffled_triples = self.base.triples.clone();
        shuffled_triples.shuffle(&mut rng);

        for batch_triples in shuffled_triples.chunks(self.base.config.batch_size) {
            let mut batch_entity_grads = Array2::zeros(self.entity_embeddings.raw_dim());
            let mut batch_relation_grads = Array2::zeros(self.relation_embeddings.raw_dim());
            let mut batch_loss = 0.0;

            for &pos_triple in batch_triples {
                let neg_samples = self
                    .base
                    .generate_negative_samples(self.base.config.negative_samples, &mut rng);

                for neg_triple in neg_samples {
                    let pos_score =
                        self.score_triple_ids(pos_triple.0, pos_triple.1, pos_triple.2)?;
                    let neg_score =
                        self.score_triple_ids(neg_triple.0, neg_triple.1, neg_triple.2)?;

                    // Logistic loss: -ln σ(pos_score) - ln σ(-neg_score)
                    let pos_loss = -(1.0 / (1.0 + (-pos_score).exp())).ln();
                    let neg_loss = -(1.0 / (1.0 + neg_score.exp())).ln();
                    let loss = pos_loss + neg_loss;
                    batch_loss += loss;

                    let (entity_grads, relation_grads) =
                        self.compute_gradients(pos_triple, neg_triple)?;

                    batch_entity_grads += &entity_grads;
                    batch_relation_grads += &relation_grads;
                }
            }

            if batch_loss > 0.0 {
                // SGD step with L2-style quaternion regularization.
                for (embedding_val, grad_val) in self
                    .entity_embeddings
                    .iter_mut()
                    .zip(batch_entity_grads.iter())
                {
                    let reg_term = self.quaternion_regularization * *embedding_val;
                    *embedding_val -= learning_rate * (grad_val + reg_term);
                }

                for (embedding_val, grad_val) in self
                    .relation_embeddings
                    .iter_mut()
                    .zip(batch_relation_grads.iter())
                {
                    let reg_term = self.quaternion_regularization * *embedding_val;
                    *embedding_val -= learning_rate * (grad_val + reg_term);
                }

                self.normalize_all_quaternions();
            }

            total_loss += batch_loss;
        }

        Ok(total_loss / num_batches as f64)
    }
}

#[async_trait]
impl EmbeddingModel for QuatD {
    fn config(&self) -> &ModelConfig {
        &self.base.config
    }

    fn model_id(&self) -> &Uuid {
        &self.base.model_id
    }

    fn model_type(&self) -> &'static str {
        "QuatD"
    }

    fn add_triple(&mut self, triple: Triple) -> Result<()> {
        self.base.add_triple(triple)
    }

    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
        let start_time = Instant::now();
        let max_epochs = epochs.unwrap_or(self.base.config.max_epochs);

        self.initialize_embeddings();

        if !self.embeddings_initialized {
            return Err(anyhow!("No training data available"));
        }

        let mut loss_history = Vec::new();
        let learning_rate = self.base.config.learning_rate;

        info!("Starting QuatD training for {} epochs", max_epochs);

        for epoch in 0..max_epochs {
            let epoch_loss = self.train_epoch(learning_rate).await?;
            loss_history.push(epoch_loss);

            if epoch % 100 == 0 {
                debug!("Epoch {}: loss = {:.6}", epoch, epoch_loss);
            }

            if epoch > 10 && epoch_loss < 1e-6 {
                info!("Converged at epoch {} with loss {:.6}", epoch, epoch_loss);
                break;
            }
        }

        self.base.mark_trained();
        let training_time = start_time.elapsed().as_secs_f64();

        Ok(TrainingStats {
            epochs_completed: loss_history.len(),
            final_loss: loss_history.last().copied().unwrap_or(0.0),
            training_time_seconds: training_time,
            convergence_achieved: loss_history.last().copied().unwrap_or(f64::INFINITY) < 1e-6,
            loss_history,
        })
    }

    fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
        if !self.embeddings_initialized {
            return Err(anyhow!("Model not trained"));
        }

        let entity_id = self
            .base
            .get_entity_id(entity)
            .ok_or_else(|| anyhow!("Entity not found: {}", entity))?;

        let embedding = self.entity_embeddings.row(entity_id).to_owned();
        Ok(Vector::new(
            embedding.to_vec().into_iter().map(|x| x as f32).collect(),
        ))
    }

    fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
        if !self.embeddings_initialized {
            return Err(anyhow!("Model not trained"));
        }

        let relation_id = self
            .base
            .get_relation_id(relation)
            .ok_or_else(|| anyhow!("Relation not found: {}", relation))?;

        let embedding = self.relation_embeddings.row(relation_id).to_owned();
        Ok(Vector::new(
            embedding.to_vec().into_iter().map(|x| x as f32).collect(),
        ))
    }

    fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
        let subject_id = self
            .base
            .get_entity_id(subject)
            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
        let predicate_id = self
            .base
            .get_relation_id(predicate)
            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
        let object_id = self
            .base
            .get_entity_id(object)
            .ok_or_else(|| anyhow!("Object not found: {}", object))?;

        self.score_triple_ids(subject_id, predicate_id, object_id)
    }

    fn predict_objects(
        &self,
        subject: &str,
        predicate: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        if !self.embeddings_initialized {
            return Err(anyhow!("Model not trained"));
        }

        let subject_id = self
            .base
            .get_entity_id(subject)
            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
        let predicate_id = self
            .base
            .get_relation_id(predicate)
            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;

        let mut scores = Vec::new();

        for object_id in 0..self.base.num_entities() {
            let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
            let object_name = self
                .base
                .get_entity(object_id)
                .expect("entity should exist in index")
                .clone();
            scores.push((object_name, score));
        }

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("scores should be comparable"));
        scores.truncate(k);

        Ok(scores)
    }

    fn predict_subjects(
        &self,
        predicate: &str,
        object: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        if !self.embeddings_initialized {
            return Err(anyhow!("Model not trained"));
        }

        let predicate_id = self
            .base
            .get_relation_id(predicate)
            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
        let object_id = self
            .base
            .get_entity_id(object)
            .ok_or_else(|| anyhow!("Object not found: {}", object))?;

        let mut scores = Vec::new();

        for subject_id in 0..self.base.num_entities() {
            let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
            let subject_name = self
                .base
                .get_entity(subject_id)
                .expect("entity should exist in index")
                .clone();
            scores.push((subject_name, score));
        }

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("scores should be comparable"));
        scores.truncate(k);

        Ok(scores)
    }

    fn predict_relations(
        &self,
        subject: &str,
        object: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        if !self.embeddings_initialized {
            return Err(anyhow!("Model not trained"));
        }

        let subject_id = self
            .base
            .get_entity_id(subject)
            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
        let object_id = self
            .base
            .get_entity_id(object)
            .ok_or_else(|| anyhow!("Object not found: {}", object))?;

        let mut scores = Vec::new();

        for predicate_id in 0..self.base.num_relations() {
            let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
            let predicate_name = self
                .base
                .get_relation(predicate_id)
                .expect("relation should exist in index")
                .clone();
            scores.push((predicate_name, score));
        }

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("scores should be comparable"));
        scores.truncate(k);

        Ok(scores)
    }

    fn get_entities(&self) -> Vec<String> {
        self.base.get_entities()
    }

    fn get_relations(&self) -> Vec<String> {
        self.base.get_relations()
    }

    fn get_stats(&self) -> ModelStats {
        self.base.get_stats("QuatD")
    }

    fn save(&self, path: &str) -> Result<()> {
        // NOTE: persistence is not implemented yet; this only logs the request.
        info!("Saving QuatD model to {}", path);
        Ok(())
    }

    fn load(&mut self, path: &str) -> Result<()> {
        // NOTE: persistence is not implemented yet; this only logs the request.
        info!("Loading QuatD model from {}", path);
        Ok(())
    }

    fn clear(&mut self) {
        self.base.clear();
        self.entity_embeddings = Array2::zeros((0, 4));
        self.relation_embeddings = Array2::zeros((0, 4));
        self.embeddings_initialized = false;
    }

    fn is_trained(&self) -> bool {
        self.base.is_trained
    }

    async fn encode(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
        Err(anyhow!(
            "Knowledge graph embedding model does not support text encoding"
        ))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::NamedNode;

    #[test]
    fn test_quaternion_operations() {
        let q1 = Quaternion::new(1.0, 2.0, 3.0, 4.0);
        let q2 = Quaternion::new(2.0, 3.0, 4.0, 5.0);

        // The Hamilton product should produce finite components.
        let product = q1.multiply(&q2);
        assert!(product.w.is_finite());

        // Conjugation keeps the scalar part and negates the vector part.
        let conj = q1.conjugate();
        assert_eq!(conj.w, q1.w);
        assert_eq!(conj.x, -q1.x);

        // Normalization yields a unit quaternion.
        let mut q3 = q1;
        q3.normalize();
        assert!((q3.norm() - 1.0).abs() < 1e-10);
    }
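
    // Added check (not part of the original suite): `multiply` above follows the
    // Hamilton convention, so the basis identity i ⊗ j = k and the property
    // q ⊗ conj(q) = (|q|², 0, 0, 0) should both hold; this test exercises them.
    #[test]
    fn test_quaternion_hamilton_identities() {
        let i = Quaternion::new(0.0, 1.0, 0.0, 0.0);
        let j = Quaternion::new(0.0, 0.0, 1.0, 0.0);

        // i ⊗ j = k
        let k = i.multiply(&j);
        assert!(k.w.abs() < 1e-12);
        assert!(k.x.abs() < 1e-12);
        assert!(k.y.abs() < 1e-12);
        assert!((k.z - 1.0).abs() < 1e-12);

        // q ⊗ conj(q) has scalar part |q|² and zero vector part.
        let q = Quaternion::new(1.0, 2.0, 3.0, 4.0);
        let qq = q.multiply(&q.conjugate());
        assert!((qq.w - q.norm() * q.norm()).abs() < 1e-9);
        assert!(qq.x.abs() < 1e-12 && qq.y.abs() < 1e-12 && qq.z.abs() < 1e-12);
    }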

    #[tokio::test]
    async fn test_quatd_basic() -> Result<()> {
        let config = ModelConfig::default()
            .with_dimensions(4)
            .with_max_epochs(10)
            .with_seed(42);

        let mut model = QuatD::new(config);

        let alice = NamedNode::new("http://example.org/alice")?;
        let knows = NamedNode::new("http://example.org/knows")?;
        let bob = NamedNode::new("http://example.org/bob")?;

        model.add_triple(Triple::new(alice.clone(), knows.clone(), bob.clone()))?;
        model.add_triple(Triple::new(bob.clone(), knows.clone(), alice.clone()))?;

        let stats = model.train(Some(5)).await?;
        assert!(stats.epochs_completed > 0);

        // QuatD embeddings are 4-dimensional quaternions.
        let alice_emb = model.get_entity_embedding("http://example.org/alice")?;
        assert_eq!(alice_emb.dimensions, 4);

        let score = model.score_triple(
            "http://example.org/alice",
            "http://example.org/knows",
            "http://example.org/bob",
        )?;

        assert!(score.is_finite());

        Ok(())
    }
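
    // Added sketch test (not from the original suite): after training on the same
    // tiny graph used above, ranking objects for (alice, knows, ?) should return
    // finite scores and respect the requested k. The seed value is arbitrary.
    #[tokio::test]
    async fn test_quatd_predict_objects() -> Result<()> {
        let config = ModelConfig::default()
            .with_dimensions(4)
            .with_max_epochs(5)
            .with_seed(7);

        let mut model = QuatD::new(config);

        let alice = NamedNode::new("http://example.org/alice")?;
        let knows = NamedNode::new("http://example.org/knows")?;
        let bob = NamedNode::new("http://example.org/bob")?;

        model.add_triple(Triple::new(alice.clone(), knows.clone(), bob.clone()))?;
        model.add_triple(Triple::new(bob.clone(), knows.clone(), alice.clone()))?;

        model.train(Some(3)).await?;

        // Ask for the top-1 object; both entities are candidates, so exactly one is returned.
        let predictions =
            model.predict_objects("http://example.org/alice", "http://example.org/knows", 1)?;
        assert_eq!(predictions.len(), 1);
        assert!(predictions[0].1.is_finite());

        Ok(())
    }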

    #[test]
    fn test_quatd_creation() {
        let config = ModelConfig::default();
        let quatd = QuatD::new(config);
        assert!(!quatd.embeddings_initialized);
        assert_eq!(quatd.model_type(), "QuatD");
    }
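
    // Added check (not part of the original suite): `normalize` guards against
    // division by a (near-)zero norm, so a zero quaternion should pass through unchanged.
    #[test]
    fn test_normalize_zero_quaternion_is_noop() {
        let mut q = Quaternion::new(0.0, 0.0, 0.0, 0.0);
        q.normalize();
        assert_eq!(q.to_array(), [0.0, 0.0, 0.0, 0.0]);
    }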
}