1use crate::models::{common::*, BaseModel};
9use crate::{EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector};
10use anyhow::{anyhow, Result};
11use async_trait::async_trait;
12use scirs2_core::ndarray_ext::{Array1, Array2};
13#[allow(unused_imports)]
14use scirs2_core::random::{Random, Rng};
15use serde::{Deserialize, Serialize};
16use std::ops::{AddAssign, SubAssign};
17use std::time::Instant;
18use tracing::{debug, info};
19use uuid::Uuid;
20
/// TransE: translational knowledge-graph embedding model.
///
/// Each relation is modeled as a translation in embedding space: for a valid
/// triple `(h, r, t)` the model is trained so that `h + r ≈ t`. Triples are
/// scored by the negated distance `-d(h + r, t)` under the configured metric.
#[derive(Debug, Clone)]
pub struct TransE {
    /// Shared bookkeeping (config, entity/relation vocabularies, triples).
    base: BaseModel,
    /// Entity embedding matrix, shape `(num_entities, dimensions)`.
    entity_embeddings: Array2<f64>,
    /// Relation embedding matrix, shape `(num_relations, dimensions)`.
    relation_embeddings: Array2<f64>,
    /// Set once `initialize_embeddings` has allocated non-empty tables.
    embeddings_initialized: bool,
    /// Distance function used for scoring and training.
    distance_metric: DistanceMetric,
    /// Margin of the ranking loss (positive must beat negative by this much).
    margin: f64,
}
37
/// Distance function used to compare `h + r` against `t` when scoring.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum DistanceMetric {
    /// Manhattan distance (sum of absolute component differences).
    L1,
    /// Euclidean distance.
    L2,
    /// Cosine distance: `1 - cos(h + r, t)`.
    Cosine,
}
48
49impl TransE {
50 pub fn new(config: ModelConfig) -> Self {
52 let base = BaseModel::new(config.clone());
53
54 let distance_metric = match config.model_params.get("distance_metric") {
56 Some(0.0) => DistanceMetric::L1,
57 Some(1.0) => DistanceMetric::L2,
58 Some(2.0) => DistanceMetric::Cosine,
59 _ => DistanceMetric::L2, };
61
62 let margin = config.model_params.get("margin").copied().unwrap_or(1.0);
63
64 Self {
65 base,
66 entity_embeddings: Array2::zeros((0, config.dimensions)),
67 relation_embeddings: Array2::zeros((0, config.dimensions)),
68 embeddings_initialized: false,
69 distance_metric,
70 margin,
71 }
72 }
73
74 pub fn with_l1_distance(mut config: ModelConfig) -> Self {
76 config
77 .model_params
78 .insert("distance_metric".to_string(), 0.0);
79 Self::new(config)
80 }
81
82 pub fn with_l2_distance(mut config: ModelConfig) -> Self {
84 config
85 .model_params
86 .insert("distance_metric".to_string(), 1.0);
87 Self::new(config)
88 }
89
90 pub fn with_cosine_distance(mut config: ModelConfig) -> Self {
92 config
93 .model_params
94 .insert("distance_metric".to_string(), 2.0);
95 Self::new(config)
96 }
97
98 pub fn with_margin(mut config: ModelConfig, margin: f64) -> Self {
100 config.model_params.insert("margin".to_string(), margin);
101 Self::new(config)
102 }
103
    /// The distance metric this model scores triples with.
    pub fn distance_metric(&self) -> DistanceMetric {
        self.distance_metric
    }
108
    /// The margin of the ranking loss.
    pub fn margin(&self) -> f64 {
        self.margin
    }
113
    /// Allocate and initialize the embedding tables from the vocabulary
    /// accumulated so far.
    ///
    /// Idempotent: does nothing once embeddings exist. Also a no-op when no
    /// entities or relations have been added yet — callers detect that case
    /// by checking `embeddings_initialized` afterwards (see `train`).
    fn initialize_embeddings(&mut self) {
        if self.embeddings_initialized {
            return;
        }

        let num_entities = self.base.num_entities();
        let num_relations = self.base.num_relations();
        let dimensions = self.base.config.dimensions;

        if num_entities == 0 || num_relations == 0 {
            return;
        }

        let mut rng = Random::default();

        // Xavier initialization; fan_in/fan_out are both the embedding
        // dimension here.
        self.entity_embeddings =
            xavier_init((num_entities, dimensions), dimensions, dimensions, &mut rng);

        self.relation_embeddings = xavier_init(
            (num_relations, dimensions),
            dimensions,
            dimensions,
            &mut rng,
        );

        // TransE constrains entity embeddings to unit norm; relation
        // embeddings are left unnormalized.
        normalize_embeddings(&mut self.entity_embeddings);

        self.embeddings_initialized = true;
        debug!(
            "Initialized TransE embeddings: {} entities, {} relations, {} dimensions",
            num_entities, num_relations, dimensions
        );
    }
151
    /// Score a triple given numeric entity/relation ids.
    ///
    /// The score is the *negated* distance between `h + r` and `t`, so a
    /// higher score means a more plausible triple.
    ///
    /// # Errors
    /// Fails when the model has not been trained yet. Out-of-range ids will
    /// panic in the row lookups — callers are expected to pass valid ids.
    fn score_triple_ids(
        &self,
        subject_id: usize,
        predicate_id: usize,
        object_id: usize,
    ) -> Result<f64> {
        if !self.embeddings_initialized {
            return Err(anyhow!("Model not trained"));
        }

        let h = self.entity_embeddings.row(subject_id);
        let r = self.relation_embeddings.row(predicate_id);
        let t = self.entity_embeddings.row(object_id);

        // Translation residual h + r - t.
        let diff = &h + &r - t;

        let distance = match self.distance_metric {
            DistanceMetric::L1 => diff.mapv(|x| x.abs()).sum(),
            DistanceMetric::L2 => diff.mapv(|x| x * x).sum().sqrt(),
            DistanceMetric::Cosine => {
                // Cosine distance between (h + r) and t.
                let h_plus_r = &h + &r;
                let dot_product = (&h_plus_r * &t).sum();
                let norm_h_plus_r = h_plus_r.mapv(|x| x * x).sum().sqrt();
                let norm_t = t.mapv(|x| x * x).sum().sqrt();

                if norm_h_plus_r == 0.0 || norm_t == 0.0 {
                    // A zero vector has no direction; treat as maximally
                    // distant.
                    1.0
                } else {
                    let cosine_sim = dot_product / (norm_h_plus_r * norm_t);
                    // Clamp guards against floating-point overshoot past ±1.
                    1.0 - cosine_sim.clamp(-1.0, 1.0)
                }
            }
        };

        Ok(-distance)
    }
193
194 fn compute_gradients(
196 &self,
197 pos_triple: (usize, usize, usize),
198 neg_triple: (usize, usize, usize),
199 ) -> Result<(Array2<f64>, Array2<f64>)> {
200 let (pos_s, pos_p, pos_o) = pos_triple;
201 let (neg_s, neg_p, neg_o) = neg_triple;
202
203 let mut entity_grads = Array2::zeros(self.entity_embeddings.raw_dim());
204 let mut relation_grads = Array2::zeros(self.relation_embeddings.raw_dim());
205
206 let pos_h = self.entity_embeddings.row(pos_s);
208 let pos_r = self.relation_embeddings.row(pos_p);
209 let pos_t = self.entity_embeddings.row(pos_o);
210
211 let neg_h = self.entity_embeddings.row(neg_s);
212 let neg_r = self.relation_embeddings.row(neg_p);
213 let neg_t = self.entity_embeddings.row(neg_o);
214
215 let pos_diff = &pos_h + &pos_r - pos_t;
217 let neg_diff = &neg_h + &neg_r - neg_t;
218
219 let pos_distance = match self.distance_metric {
221 DistanceMetric::L1 => pos_diff.mapv(|x| x.abs()).sum(),
222 DistanceMetric::L2 => pos_diff.mapv(|x| x * x).sum().sqrt(),
223 DistanceMetric::Cosine => {
224 let norm = pos_diff.mapv(|x| x * x).sum().sqrt();
225 if norm > 1e-10 {
226 1.0 - (pos_diff.dot(&pos_diff) / (norm * norm)).clamp(-1.0, 1.0)
227 } else {
228 0.0
229 }
230 }
231 };
232
233 let neg_distance = match self.distance_metric {
234 DistanceMetric::L1 => neg_diff.mapv(|x| x.abs()).sum(),
235 DistanceMetric::L2 => neg_diff.mapv(|x| x * x).sum().sqrt(),
236 DistanceMetric::Cosine => {
237 let norm = neg_diff.mapv(|x| x * x).sum().sqrt();
238 if norm > 1e-10 {
239 1.0 - (neg_diff.dot(&neg_diff) / (norm * norm)).clamp(-1.0, 1.0)
240 } else {
241 0.0
242 }
243 }
244 };
245
246 let loss = self.margin + pos_distance - neg_distance;
248 if loss > 0.0 {
249 let pos_grad_direction = match self.distance_metric {
251 DistanceMetric::L1 => pos_diff.mapv(|x| {
252 if x > 0.0 {
253 1.0
254 } else if x < 0.0 {
255 -1.0
256 } else {
257 0.0
258 }
259 }),
260 DistanceMetric::L2 => {
261 if pos_distance > 1e-10 {
262 &pos_diff / pos_distance
263 } else {
264 Array1::zeros(pos_diff.len())
265 }
266 }
267 DistanceMetric::Cosine => {
268 let norm_sq = pos_diff.mapv(|x| x * x).sum();
269 if norm_sq > 1e-10 {
270 &pos_diff / norm_sq.sqrt()
271 } else {
272 Array1::zeros(pos_diff.len())
273 }
274 }
275 };
276
277 let neg_grad_direction = match self.distance_metric {
278 DistanceMetric::L1 => neg_diff.mapv(|x| {
279 if x > 0.0 {
280 1.0
281 } else if x < 0.0 {
282 -1.0
283 } else {
284 0.0
285 }
286 }),
287 DistanceMetric::L2 => {
288 if neg_distance > 1e-10 {
289 &neg_diff / neg_distance
290 } else {
291 Array1::zeros(neg_diff.len())
292 }
293 }
294 DistanceMetric::Cosine => {
295 let norm_sq = neg_diff.mapv(|x| x * x).sum();
296 if norm_sq > 1e-10 {
297 &neg_diff / norm_sq.sqrt()
298 } else {
299 Array1::zeros(neg_diff.len())
300 }
301 }
302 };
303
304 entity_grads.row_mut(pos_s).add_assign(&pos_grad_direction);
306 relation_grads
307 .row_mut(pos_p)
308 .add_assign(&pos_grad_direction);
309 entity_grads.row_mut(pos_o).sub_assign(&pos_grad_direction);
310
311 entity_grads.row_mut(neg_s).sub_assign(&neg_grad_direction);
313 relation_grads
314 .row_mut(neg_p)
315 .sub_assign(&neg_grad_direction);
316 entity_grads.row_mut(neg_o).add_assign(&neg_grad_direction);
317 }
318
319 Ok((entity_grads, relation_grads))
320 }
321
322 async fn train_epoch(&mut self, learning_rate: f64) -> Result<f64> {
324 let mut rng = Random::default();
325
326 let mut total_loss = 0.0;
327 let num_batches = (self.base.triples.len() + self.base.config.batch_size - 1)
328 / self.base.config.batch_size;
329
330 let mut shuffled_triples = self.base.triples.clone();
332 for i in (1..shuffled_triples.len()).rev() {
334 let j = rng.random_range(0..i + 1);
335 shuffled_triples.swap(i, j);
336 }
337
338 for batch_triples in shuffled_triples.chunks(self.base.config.batch_size) {
339 let mut batch_entity_grads = Array2::zeros(self.entity_embeddings.raw_dim());
340 let mut batch_relation_grads = Array2::zeros(self.relation_embeddings.raw_dim());
341 let mut batch_loss = 0.0;
342
343 for &pos_triple in batch_triples {
344 let neg_samples = self
346 .base
347 .generate_negative_samples(self.base.config.negative_samples, &mut rng);
348
349 for neg_triple in neg_samples {
350 let pos_score =
352 self.score_triple_ids(pos_triple.0, pos_triple.1, pos_triple.2)?;
353 let neg_score =
354 self.score_triple_ids(neg_triple.0, neg_triple.1, neg_triple.2)?;
355
356 let pos_distance = -pos_score;
358 let neg_distance = -neg_score;
359
360 let loss = margin_loss(pos_distance, neg_distance, self.margin);
362 batch_loss += loss;
363
364 if loss > 0.0 {
365 let (entity_grads, relation_grads) =
367 self.compute_gradients(pos_triple, neg_triple)?;
368 batch_entity_grads += &entity_grads;
369 batch_relation_grads += &relation_grads;
370 }
371 }
372 }
373
374 if batch_loss > 0.0 {
376 gradient_update(
377 &mut self.entity_embeddings,
378 &batch_entity_grads,
379 learning_rate,
380 self.base.config.l2_reg,
381 );
382
383 gradient_update(
384 &mut self.relation_embeddings,
385 &batch_relation_grads,
386 learning_rate,
387 self.base.config.l2_reg,
388 );
389
390 normalize_embeddings(&mut self.entity_embeddings);
392 }
393
394 total_loss += batch_loss;
395 }
396
397 Ok(total_loss / num_batches as f64)
398 }
399}
400
401#[async_trait]
402impl EmbeddingModel for TransE {
    /// Configuration this model was built with.
    fn config(&self) -> &ModelConfig {
        &self.base.config
    }
406
    /// Unique identifier assigned to this model instance.
    fn model_id(&self) -> &Uuid {
        &self.base.model_id
    }
410
    /// Static model-type name used in stats and logging.
    fn model_type(&self) -> &'static str {
        "TransE"
    }
414
    /// Register a training triple (delegates vocabulary/indexing to the base
    /// model).
    fn add_triple(&mut self, triple: Triple) -> Result<()> {
        self.base.add_triple(triple)
    }
418
    /// Train with margin-based ranking loss and SGD.
    ///
    /// Runs for `epochs` (or the configured `max_epochs` when `None`), with
    /// early stopping once the epoch loss drops below `1e-6` after the first
    /// ten epochs.
    ///
    /// # Errors
    /// Fails when no triples have been added (embeddings cannot be built).
    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
        let start_time = Instant::now();
        let max_epochs = epochs.unwrap_or(self.base.config.max_epochs);

        // Lazily builds the embedding tables from the triples added so far.
        self.initialize_embeddings();

        // `initialize_embeddings` is a no-op with no entities/relations, so
        // this flag detects an empty training set.
        if !self.embeddings_initialized {
            return Err(anyhow!("No training data available"));
        }

        let mut loss_history = Vec::new();
        let learning_rate = self.base.config.learning_rate;

        info!("Starting TransE training for {} epochs", max_epochs);

        for epoch in 0..max_epochs {
            let epoch_loss = self.train_epoch(learning_rate).await?;
            loss_history.push(epoch_loss);

            // Periodic progress logging only, to keep output manageable.
            if epoch % 100 == 0 {
                debug!("Epoch {}: loss = {:.6}", epoch, epoch_loss);
            }

            // Early stop after a warm-up of 10 epochs.
            if epoch > 10 && epoch_loss < 1e-6 {
                info!("Converged at epoch {} with loss {:.6}", epoch, epoch_loss);
                break;
            }
        }

        self.base.mark_trained();
        let training_time = start_time.elapsed().as_secs_f64();

        Ok(TrainingStats {
            epochs_completed: loss_history.len(),
            final_loss: loss_history.last().copied().unwrap_or(0.0),
            training_time_seconds: training_time,
            convergence_achieved: loss_history.last().copied().unwrap_or(f64::INFINITY) < 1e-6,
            loss_history,
        })
    }
461
462 fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
463 if !self.embeddings_initialized {
464 return Err(anyhow!("Model not trained"));
465 }
466
467 let entity_id = self
468 .base
469 .get_entity_id(entity)
470 .ok_or_else(|| anyhow!("Entity not found: {}", entity))?;
471
472 let embedding = self.entity_embeddings.row(entity_id).to_owned();
473 Ok(ndarray_to_vector(&embedding))
474 }
475
476 fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
477 if !self.embeddings_initialized {
478 return Err(anyhow!("Model not trained"));
479 }
480
481 let relation_id = self
482 .base
483 .get_relation_id(relation)
484 .ok_or_else(|| anyhow!("Relation not found: {}", relation))?;
485
486 let embedding = self.relation_embeddings.row(relation_id).to_owned();
487 Ok(ndarray_to_vector(&embedding))
488 }
489
490 fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
491 let subject_id = self
492 .base
493 .get_entity_id(subject)
494 .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
495 let predicate_id = self
496 .base
497 .get_relation_id(predicate)
498 .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
499 let object_id = self
500 .base
501 .get_entity_id(object)
502 .ok_or_else(|| anyhow!("Object not found: {}", object))?;
503
504 self.score_triple_ids(subject_id, predicate_id, object_id)
505 }
506
507 fn predict_objects(
508 &self,
509 subject: &str,
510 predicate: &str,
511 k: usize,
512 ) -> Result<Vec<(String, f64)>> {
513 if !self.embeddings_initialized {
514 return Err(anyhow!("Model not trained"));
515 }
516
517 let subject_id = self
518 .base
519 .get_entity_id(subject)
520 .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
521 let predicate_id = self
522 .base
523 .get_relation_id(predicate)
524 .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
525
526 let mut scores = Vec::new();
527
528 for object_id in 0..self.base.num_entities() {
529 let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
530 let object_name = self.base.get_entity(object_id).unwrap().clone();
531 scores.push((object_name, score));
532 }
533
534 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
535 scores.truncate(k);
536
537 Ok(scores)
538 }
539
540 fn predict_subjects(
541 &self,
542 predicate: &str,
543 object: &str,
544 k: usize,
545 ) -> Result<Vec<(String, f64)>> {
546 if !self.embeddings_initialized {
547 return Err(anyhow!("Model not trained"));
548 }
549
550 let predicate_id = self
551 .base
552 .get_relation_id(predicate)
553 .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
554 let object_id = self
555 .base
556 .get_entity_id(object)
557 .ok_or_else(|| anyhow!("Object not found: {}", object))?;
558
559 let mut scores = Vec::new();
560
561 for subject_id in 0..self.base.num_entities() {
562 let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
563 let subject_name = self.base.get_entity(subject_id).unwrap().clone();
564 scores.push((subject_name, score));
565 }
566
567 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
568 scores.truncate(k);
569
570 Ok(scores)
571 }
572
573 fn predict_relations(
574 &self,
575 subject: &str,
576 object: &str,
577 k: usize,
578 ) -> Result<Vec<(String, f64)>> {
579 if !self.embeddings_initialized {
580 return Err(anyhow!("Model not trained"));
581 }
582
583 let subject_id = self
584 .base
585 .get_entity_id(subject)
586 .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
587 let object_id = self
588 .base
589 .get_entity_id(object)
590 .ok_or_else(|| anyhow!("Object not found: {}", object))?;
591
592 let mut scores = Vec::new();
593
594 for predicate_id in 0..self.base.num_relations() {
595 let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
596 let predicate_name = self.base.get_relation(predicate_id).unwrap().clone();
597 scores.push((predicate_name, score));
598 }
599
600 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
601 scores.truncate(k);
602
603 Ok(scores)
604 }
605
    /// All known entity IRIs.
    fn get_entities(&self) -> Vec<String> {
        self.base.get_entities()
    }
609
    /// All known relation IRIs.
    fn get_relations(&self) -> Vec<String> {
        self.base.get_relations()
    }
613
    /// Model statistics tagged with the "TransE" model type.
    fn get_stats(&self) -> ModelStats {
        self.base.get_stats("TransE")
    }
617
    /// Persist the model to `path`.
    ///
    /// NOTE(review): this is currently a stub — it logs and returns `Ok(())`
    /// without writing anything, so callers get a silent false success.
    /// TODO: serialize embeddings + vocabulary once a format is decided.
    fn save(&self, path: &str) -> Result<()> {
        info!("Saving TransE model to {}", path);
        Ok(())
    }
624
    /// Load a model from `path`.
    ///
    /// NOTE(review): this is currently a stub — it logs and returns `Ok(())`
    /// without reading anything; the model state is left unchanged.
    /// TODO: implement alongside `save`.
    fn load(&mut self, path: &str) -> Result<()> {
        info!("Loading TransE model from {}", path);
        Ok(())
    }
631
    /// Reset the model: drop all triples, vocabulary, and embeddings,
    /// returning to the untrained state.
    fn clear(&mut self) {
        self.base.clear();
        self.entity_embeddings = Array2::zeros((0, self.base.config.dimensions));
        self.relation_embeddings = Array2::zeros((0, self.base.config.dimensions));
        self.embeddings_initialized = false;
    }
638
    /// Whether `train` has completed at least once.
    fn is_trained(&self) -> bool {
        self.base.is_trained
    }
642
    /// Text encoding is unsupported: TransE embeds graph nodes/edges, not
    /// free text, so this always returns an error.
    async fn encode(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
        Err(anyhow!(
            "TransE is a knowledge graph embedding model and does not support text encoding"
        ))
    }
648}
649
#[cfg(test)]
mod tests {
    use super::*;
    use crate::NamedNode;

    // End-to-end smoke test: add two triples, train briefly, and check that
    // embeddings and scores come out well-formed.
    #[tokio::test]
    async fn test_transe_basic() -> Result<()> {
        let config = ModelConfig::default()
            .with_dimensions(50)
            .with_max_epochs(10)
            .with_seed(42);

        let mut model = TransE::new(config);

        let alice = NamedNode::new("http://example.org/alice")?;
        let knows = NamedNode::new("http://example.org/knows")?;
        let bob = NamedNode::new("http://example.org/bob")?;

        model.add_triple(Triple::new(alice.clone(), knows.clone(), bob.clone()))?;
        model.add_triple(Triple::new(bob.clone(), knows.clone(), alice.clone()))?;

        let stats = model.train(Some(5)).await?;
        assert!(stats.epochs_completed > 0);

        // Embedding dimensionality must match the configured value.
        let alice_emb = model.get_entity_embedding("http://example.org/alice")?;
        assert_eq!(alice_emb.dimensions, 50);

        let score = model.score_triple(
            "http://example.org/alice",
            "http://example.org/knows",
            "http://example.org/bob",
        )?;

        // Scores are negated distances; here we only require a finite value.
        assert!(score.is_finite());

        Ok(())
    }

    // Checks that the convenience constructors select the intended metric and
    // margin, and that models under all three metrics train and produce
    // finite scores.
    #[tokio::test]
    async fn test_transe_distance_metrics() -> Result<()> {
        let base_config = ModelConfig::default()
            .with_dimensions(10)
            .with_max_epochs(5)
            .with_seed(42);

        let mut model_l1 = TransE::with_l1_distance(base_config.clone());
        assert!(matches!(model_l1.distance_metric(), DistanceMetric::L1));

        let mut model_l2 = TransE::with_l2_distance(base_config.clone());
        assert!(matches!(model_l2.distance_metric(), DistanceMetric::L2));

        let mut model_cosine = TransE::with_cosine_distance(base_config.clone());
        assert!(matches!(
            model_cosine.distance_metric(),
            DistanceMetric::Cosine
        ));

        let model_margin = TransE::with_margin(base_config.clone(), 2.0);
        assert_eq!(model_margin.margin(), 2.0);

        let alice = NamedNode::new("http://example.org/alice")?;
        let knows = NamedNode::new("http://example.org/knows")?;
        let bob = NamedNode::new("http://example.org/bob")?;
        let triple = Triple::new(alice, knows, bob);

        model_l1.add_triple(triple.clone())?;
        model_l2.add_triple(triple.clone())?;
        model_cosine.add_triple(triple.clone())?;

        model_l1.train(Some(3)).await?;
        model_l2.train(Some(3)).await?;
        model_cosine.train(Some(3)).await?;

        let score_l1 = model_l1.score_triple(
            "http://example.org/alice",
            "http://example.org/knows",
            "http://example.org/bob",
        )?;
        let score_l2 = model_l2.score_triple(
            "http://example.org/alice",
            "http://example.org/knows",
            "http://example.org/bob",
        )?;
        let score_cosine = model_cosine.score_triple(
            "http://example.org/alice",
            "http://example.org/knows",
            "http://example.org/bob",
        )?;

        assert!(score_l1.is_finite());
        assert!(score_l2.is_finite());
        assert!(score_cosine.is_finite());

        println!("L1 score: {score_l1}, L2 score: {score_l2}, Cosine score: {score_cosine}");

        Ok(())
    }
}