use crate::models::{common::*, BaseModel};
use crate::{EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use scirs2_core::ndarray_ext::{Array2, Array3};
use scirs2_core::random::{Random, Rng, SliceRandom};
use std::time::Instant;
use tracing::{debug, info};
use uuid::Uuid;

/// TuckER knowledge graph embedding model (Balažević et al., 2019).
///
/// A triple (h, r, t) is scored by contracting a shared, learnable core
/// tensor `W` with the head, relation, and tail embeddings:
/// `φ(h, r, t) = Σ_{i,j,k} W[i, j, k] · h[i] · r[j] · t[k]`.
#[derive(Debug)]
pub struct TuckER {
    base: BaseModel,
    /// Entity embedding matrix (num_entities × entity_dim)
    entity_embeddings: Array2<f64>,
    /// Relation embedding matrix (num_relations × relation_dim)
    relation_embeddings: Array2<f64>,
    /// Learnable core tensor with shape `core_dims`
    core_tensor: Array3<f64>,
    /// Whether the embeddings have been allocated and initialized
    embeddings_initialized: bool,
    /// Dimensionality of entity embeddings
    entity_dim: usize,
    /// Dimensionality of relation embeddings
    relation_dim: usize,
    /// Shape of the core tensor
    core_dims: (usize, usize, usize),
    /// Dropout rate applied to the embeddings during training
    dropout_rate: f64,
    /// Whether batch normalization is enabled
    batch_norm: bool,
}

impl TuckER {
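    /// Creates an untrained TuckER model from `config`.
    ///
    /// Entity/relation/core dimensions, `dropout_rate`, and `batch_norm` are
    /// read from `config.model_params`, falling back to `config.dimensions`
    /// (or to the defaults 0.3 and enabled). A minimal usage sketch, mirroring
    /// the test at the bottom of this file (`ignore`d here because the crate
    /// path is not fixed in this module):
    ///
    /// ```ignore
    /// let mut config = ModelConfig::default().with_dimensions(32).with_seed(7);
    /// config.model_params.insert("dropout_rate".to_string(), 0.2);
    /// let model = TuckER::new(config);
    /// assert_eq!(model.model_type(), "TuckER");
    /// ```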
    pub fn new(config: ModelConfig) -> Self {
        let base = BaseModel::new(config.clone());

        // Read TuckER-specific hyperparameters, defaulting to the shared
        // embedding dimensionality from the config.
        let entity_dim = config
            .model_params
            .get("entity_dim")
            .map(|&v| v as usize)
            .unwrap_or(config.dimensions);
        let relation_dim = config
            .model_params
            .get("relation_dim")
            .map(|&v| v as usize)
            .unwrap_or(config.dimensions);
        let core_dim1 = config
            .model_params
            .get("core_dim1")
            .map(|&v| v as usize)
            .unwrap_or(config.dimensions);
        let core_dim2 = config
            .model_params
            .get("core_dim2")
            .map(|&v| v as usize)
            .unwrap_or(config.dimensions);
        let core_dim3 = config
            .model_params
            .get("core_dim3")
            .map(|&v| v as usize)
            .unwrap_or(config.dimensions);
        let dropout_rate = config
            .model_params
            .get("dropout_rate")
            .copied()
            .unwrap_or(0.3);
        let batch_norm = config
            .model_params
            .get("batch_norm")
            .map(|&v| v > 0.0)
            .unwrap_or(true);

        Self {
            base,
            entity_embeddings: Array2::zeros((0, entity_dim)),
            relation_embeddings: Array2::zeros((0, relation_dim)),
            core_tensor: Array3::zeros((core_dim1, core_dim2, core_dim3)),
            embeddings_initialized: false,
            entity_dim,
            relation_dim,
            core_dims: (core_dim1, core_dim2, core_dim3),
            dropout_rate,
            batch_norm,
        }
    }

    /// Lazily allocates and initializes the embedding matrices and core tensor.
    fn initialize_embeddings(&mut self) {
        if self.embeddings_initialized {
            return;
        }

        let num_entities = self.base.num_entities();
        let num_relations = self.base.num_relations();

        if num_entities == 0 || num_relations == 0 {
            return;
        }

        let mut rng = Random::seed(self.base.config.seed.unwrap_or_else(|| {
            use std::time::{SystemTime, UNIX_EPOCH};
            SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_secs()
        }));

        self.entity_embeddings = xavier_init(
            (num_entities, self.entity_dim),
            self.entity_dim,
            self.entity_dim,
            &mut rng,
        );

        self.relation_embeddings = xavier_init(
            (num_relations, self.relation_dim),
            self.relation_dim,
            self.relation_dim,
            &mut rng,
        );

        // Initialize the core tensor from a uniform distribution with
        // fan-in-scaled bounds.
        let total_elements = self.core_dims.0 * self.core_dims.1 * self.core_dims.2;
        let std_dev = (2.0 / total_elements as f64).sqrt();

        for elem in self.core_tensor.iter_mut() {
            *elem = rng.random_range(-std_dev..std_dev);
        }

        normalize_embeddings(&mut self.entity_embeddings);
        normalize_embeddings(&mut self.relation_embeddings);

        self.embeddings_initialized = true;
        debug!(
            "Initialized TuckER embeddings: {} entities ({}D), {} relations ({}D), core tensor {:?}",
            num_entities, self.entity_dim, num_relations, self.relation_dim, self.core_dims
        );
    }

    /// Scores a triple by the Tucker contraction
    /// `φ(h, r, t) = Σ_{i,j,k} W[i, j, k] · h[i] · r[j] · t[k]`.
    fn score_triple_ids(
        &self,
        subject_id: usize,
        predicate_id: usize,
        object_id: usize,
    ) -> Result<f64> {
        if !self.embeddings_initialized {
            return Err(anyhow!("Model not trained"));
        }

        let h = self.entity_embeddings.row(subject_id);
        let r = self.relation_embeddings.row(predicate_id);
        let t = self.entity_embeddings.row(object_id);

        let mut score = 0.0;
        for i in 0..self.core_dims.0.min(h.len()) {
            for j in 0..self.core_dims.1.min(r.len()) {
                for k in 0..self.core_dims.2.min(t.len()) {
                    score += h[i] * r[j] * t[k] * self.core_tensor[(i, j, k)];
                }
            }
        }

        Ok(score)
    }

    /// Computes logistic-loss gradients for a positive/negative triple pair.
    fn compute_gradients(
        &self,
        pos_triple: (usize, usize, usize),
        neg_triple: (usize, usize, usize),
        _learning_rate: f64,
    ) -> Result<(Array2<f64>, Array2<f64>, Array3<f64>)> {
        let (pos_s, pos_p, pos_o) = pos_triple;
        let (neg_s, neg_p, neg_o) = neg_triple;

        let mut entity_grads = Array2::zeros(self.entity_embeddings.raw_dim());
        let mut relation_grads = Array2::zeros(self.relation_embeddings.raw_dim());
        let mut core_grads = Array3::zeros(self.core_tensor.raw_dim());

        let pos_score = self.score_triple_ids(pos_s, pos_p, pos_o)?;
        let neg_score = self.score_triple_ids(neg_s, neg_p, neg_o)?;

        // d/ds of -ln σ(s) is σ(s) - 1; d/ds of -ln σ(-s) is σ(s).
        let pos_sigmoid = 1.0 / (1.0 + (-pos_score).exp());
        let neg_sigmoid = 1.0 / (1.0 + (-neg_score).exp());

        let pos_grad = pos_sigmoid - 1.0;
        let neg_grad = neg_sigmoid;

        self.compute_triple_gradients(
            pos_triple,
            pos_grad,
            &mut entity_grads,
            &mut relation_grads,
            &mut core_grads,
        );

        self.compute_triple_gradients(
            neg_triple,
            neg_grad,
            &mut entity_grads,
            &mut relation_grads,
            &mut core_grads,
        );

        Ok((entity_grads, relation_grads, core_grads))
    }

    /// Accumulates the gradient of `loss_grad * φ(h, r, t)` with respect to
    /// the head, tail, and relation embeddings and the core tensor.
    fn compute_triple_gradients(
        &self,
        triple: (usize, usize, usize),
        loss_grad: f64,
        entity_grads: &mut Array2<f64>,
        relation_grads: &mut Array2<f64>,
        core_grads: &mut Array3<f64>,
    ) {
        let (s, p, o) = triple;

        let h = self.entity_embeddings.row(s);
        let r = self.relation_embeddings.row(p);
        let t = self.entity_embeddings.row(o);

        // ∂φ/∂h[i] = Σ_{j,k} W[i, j, k] · r[j] · t[k]
        for i in 0..self.core_dims.0.min(h.len()) {
            let mut h_grad = 0.0;
            for j in 0..self.core_dims.1.min(r.len()) {
                for k in 0..self.core_dims.2.min(t.len()) {
                    h_grad += r[j] * t[k] * self.core_tensor[(i, j, k)];
                }
            }
            entity_grads[[s, i]] += loss_grad * h_grad;
        }

        // ∂φ/∂t[k] = Σ_{i,j} W[i, j, k] · h[i] · r[j]
        for k in 0..self.core_dims.2.min(t.len()) {
            let mut t_grad = 0.0;
            for i in 0..self.core_dims.0.min(h.len()) {
                for j in 0..self.core_dims.1.min(r.len()) {
                    t_grad += h[i] * r[j] * self.core_tensor[(i, j, k)];
                }
            }
            entity_grads[[o, k]] += loss_grad * t_grad;
        }

        // ∂φ/∂r[j] = Σ_{i,k} W[i, j, k] · h[i] · t[k]
        for j in 0..self.core_dims.1.min(r.len()) {
            let mut r_grad = 0.0;
            for i in 0..self.core_dims.0.min(h.len()) {
                for k in 0..self.core_dims.2.min(t.len()) {
                    r_grad += h[i] * t[k] * self.core_tensor[(i, j, k)];
                }
            }
            relation_grads[[p, j]] += loss_grad * r_grad;
        }

        // ∂φ/∂W[i, j, k] = h[i] · r[j] · t[k]
        for i in 0..self.core_dims.0.min(h.len()) {
            for j in 0..self.core_dims.1.min(r.len()) {
                for k in 0..self.core_dims.2.min(t.len()) {
                    core_grads[[i, j, k]] += loss_grad * h[i] * r[j] * t[k];
                }
            }
        }
    }

    /// Runs one epoch of mini-batch SGD with negative sampling and a
    /// logistic loss, returning the mean batch loss.
    async fn train_epoch(&mut self, learning_rate: f64) -> Result<f64> {
        let mut rng = Random::seed(self.base.config.seed.unwrap_or_else(|| {
            use std::time::{SystemTime, UNIX_EPOCH};
            SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_secs()
        }));

        let mut total_loss = 0.0;
        let num_batches = (self.base.triples.len() + self.base.config.batch_size - 1)
            / self.base.config.batch_size;

        let mut shuffled_triples = self.base.triples.clone();
        shuffled_triples.shuffle(&mut rng);

        for batch_triples in shuffled_triples.chunks(self.base.config.batch_size) {
            let mut batch_entity_grads = Array2::zeros(self.entity_embeddings.raw_dim());
            let mut batch_relation_grads = Array2::zeros(self.relation_embeddings.raw_dim());
            let mut batch_core_grads = Array3::zeros(self.core_tensor.raw_dim());
            let mut batch_loss = 0.0;

            for &pos_triple in batch_triples {
                let neg_samples = self
                    .base
                    .generate_negative_samples(self.base.config.negative_samples, &mut rng);

                for neg_triple in neg_samples {
                    let pos_score =
                        self.score_triple_ids(pos_triple.0, pos_triple.1, pos_triple.2)?;
                    let neg_score =
                        self.score_triple_ids(neg_triple.0, neg_triple.1, neg_triple.2)?;

                    // Logistic loss: -ln σ(pos_score) - ln σ(-neg_score)
                    let pos_loss = -(1.0 / (1.0 + (-pos_score).exp())).ln();
                    let neg_loss = -(1.0 / (1.0 + neg_score.exp())).ln();
                    let loss = pos_loss + neg_loss;
                    batch_loss += loss;

                    let (entity_grads, relation_grads, core_grads) =
                        self.compute_gradients(pos_triple, neg_triple, learning_rate)?;

                    batch_entity_grads += &entity_grads;
                    batch_relation_grads += &relation_grads;
                    batch_core_grads += &core_grads;
                }
            }

            if batch_loss > 0.0 {
                gradient_update(
                    &mut self.entity_embeddings,
                    &batch_entity_grads,
                    learning_rate,
                    self.base.config.l2_reg,
                );

                gradient_update(
                    &mut self.relation_embeddings,
                    &batch_relation_grads,
                    learning_rate,
                    self.base.config.l2_reg,
                );

                // Update the core tensor with the accumulated gradients plus
                // L2 regularization.
                for ((i, j, k), value) in self.core_tensor.indexed_iter_mut() {
                    let grad = batch_core_grads[[i, j, k]] + self.base.config.l2_reg * *value;
                    *value -= learning_rate * grad;
                }

                if self.dropout_rate > 0.0 {
                    apply_dropout(&mut self.entity_embeddings, self.dropout_rate, &mut rng);
                    apply_dropout(&mut self.relation_embeddings, self.dropout_rate, &mut rng);
                }

                normalize_embeddings(&mut self.entity_embeddings);
                normalize_embeddings(&mut self.relation_embeddings);
            }

            total_loss += batch_loss;
        }

        Ok(total_loss / num_batches as f64)
    }
}

#[async_trait]
impl EmbeddingModel for TuckER {
    fn config(&self) -> &ModelConfig {
        &self.base.config
    }

    fn model_id(&self) -> &Uuid {
        &self.base.model_id
    }

    fn model_type(&self) -> &'static str {
        "TuckER"
    }

    fn add_triple(&mut self, triple: Triple) -> Result<()> {
        self.base.add_triple(triple)
    }

    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
        let start_time = Instant::now();
        let max_epochs = epochs.unwrap_or(self.base.config.max_epochs);

        self.initialize_embeddings();

        if !self.embeddings_initialized {
            return Err(anyhow!("No training data available"));
        }

        let mut loss_history = Vec::new();
        let learning_rate = self.base.config.learning_rate;

        info!("Starting TuckER training for {} epochs", max_epochs);

        for epoch in 0..max_epochs {
            let epoch_loss = self.train_epoch(learning_rate).await?;
            loss_history.push(epoch_loss);

            if epoch % 100 == 0 {
                debug!("Epoch {}: loss = {:.6}", epoch, epoch_loss);
            }

            // Early stopping once the loss has effectively vanished.
            if epoch > 10 && epoch_loss < 1e-6 {
                info!("Converged at epoch {} with loss {:.6}", epoch, epoch_loss);
                break;
            }
        }

        self.base.mark_trained();
        let training_time = start_time.elapsed().as_secs_f64();

        Ok(TrainingStats {
            epochs_completed: loss_history.len(),
            final_loss: loss_history.last().copied().unwrap_or(0.0),
            training_time_seconds: training_time,
            convergence_achieved: loss_history.last().copied().unwrap_or(f64::INFINITY) < 1e-6,
            loss_history,
        })
    }

    fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
        if !self.embeddings_initialized {
            return Err(anyhow!("Model not trained"));
        }

        let entity_id = self
            .base
            .get_entity_id(entity)
            .ok_or_else(|| anyhow!("Entity not found: {}", entity))?;

        let embedding = self.entity_embeddings.row(entity_id).to_owned();
        Ok(ndarray_to_vector(&embedding))
    }

    fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
        if !self.embeddings_initialized {
            return Err(anyhow!("Model not trained"));
        }

        let relation_id = self
            .base
            .get_relation_id(relation)
            .ok_or_else(|| anyhow!("Relation not found: {}", relation))?;

        let embedding = self.relation_embeddings.row(relation_id).to_owned();
        Ok(ndarray_to_vector(&embedding))
    }

    fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
        let subject_id = self
            .base
            .get_entity_id(subject)
            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
        let predicate_id = self
            .base
            .get_relation_id(predicate)
            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
        let object_id = self
            .base
            .get_entity_id(object)
            .ok_or_else(|| anyhow!("Object not found: {}", object))?;

        self.score_triple_ids(subject_id, predicate_id, object_id)
    }

    fn predict_objects(
        &self,
        subject: &str,
        predicate: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        if !self.embeddings_initialized {
            return Err(anyhow!("Model not trained"));
        }

        let subject_id = self
            .base
            .get_entity_id(subject)
            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
        let predicate_id = self
            .base
            .get_relation_id(predicate)
            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;

        // Score every known entity as a candidate object and keep the top k.
        let mut scores = Vec::new();
        for object_id in 0..self.base.num_entities() {
            let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
            let object_name = self.base.get_entity(object_id).unwrap().clone();
            scores.push((object_name, score));
        }

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        scores.truncate(k);

        Ok(scores)
    }

    fn predict_subjects(
        &self,
        predicate: &str,
        object: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        if !self.embeddings_initialized {
            return Err(anyhow!("Model not trained"));
        }

        let predicate_id = self
            .base
            .get_relation_id(predicate)
            .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
        let object_id = self
            .base
            .get_entity_id(object)
            .ok_or_else(|| anyhow!("Object not found: {}", object))?;

        let mut scores = Vec::new();
        for subject_id in 0..self.base.num_entities() {
            let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
            let subject_name = self.base.get_entity(subject_id).unwrap().clone();
            scores.push((subject_name, score));
        }

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        scores.truncate(k);

        Ok(scores)
    }

    fn predict_relations(
        &self,
        subject: &str,
        object: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        if !self.embeddings_initialized {
            return Err(anyhow!("Model not trained"));
        }

        let subject_id = self
            .base
            .get_entity_id(subject)
            .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
        let object_id = self
            .base
            .get_entity_id(object)
            .ok_or_else(|| anyhow!("Object not found: {}", object))?;

        let mut scores = Vec::new();
        for predicate_id in 0..self.base.num_relations() {
            let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
            let predicate_name = self.base.get_relation(predicate_id).unwrap().clone();
            scores.push((predicate_name, score));
        }

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        scores.truncate(k);

        Ok(scores)
    }

    fn get_entities(&self) -> Vec<String> {
        self.base.get_entities()
    }

    fn get_relations(&self) -> Vec<String> {
        self.base.get_relations()
    }

    fn get_stats(&self) -> ModelStats {
        self.base.get_stats("TuckER")
    }

    fn save(&self, path: &str) -> Result<()> {
        // Persistence is not implemented yet; this is a logging stub.
        info!("Saving TuckER model to {}", path);
        Ok(())
    }

    fn load(&mut self, path: &str) -> Result<()> {
        // Persistence is not implemented yet; this is a logging stub.
        info!("Loading TuckER model from {}", path);
        Ok(())
    }

    fn clear(&mut self) {
        self.base.clear();
        self.entity_embeddings = Array2::zeros((0, self.entity_dim));
        self.relation_embeddings = Array2::zeros((0, self.relation_dim));
        self.core_tensor = Array3::zeros(self.core_dims);
        self.embeddings_initialized = false;
    }

    fn is_trained(&self) -> bool {
        self.base.is_trained
    }

    async fn encode(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
        Err(anyhow!(
            "Knowledge graph embedding model does not support text encoding"
        ))
    }
}

/// Inverted dropout: zeroes each element with probability `dropout_rate` and
/// rescales the survivors by `1 / (1 - dropout_rate)` so the expected
/// magnitude is preserved.
fn apply_dropout<R: Rng>(embeddings: &mut Array2<f64>, dropout_rate: f64, rng: &mut Random<R>) {
    for elem in embeddings.iter_mut() {
        if rng.random::<f64>() < dropout_rate {
            *elem = 0.0;
        } else {
            *elem /= 1.0 - dropout_rate;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::NamedNode;

    #[tokio::test]
    #[cfg_attr(debug_assertions, ignore = "Training tests require release builds")]
    async fn test_tucker_basic() -> Result<()> {
        let mut config = ModelConfig::default()
            .with_dimensions(50)
            .with_max_epochs(10)
            .with_seed(42);

        config.model_params.insert("entity_dim".to_string(), 50.0);
        config.model_params.insert("relation_dim".to_string(), 50.0);
        config.model_params.insert("core_dim1".to_string(), 50.0);
        config.model_params.insert("core_dim2".to_string(), 50.0);
        config.model_params.insert("core_dim3".to_string(), 50.0);
        config.model_params.insert("dropout_rate".to_string(), 0.1);

        let mut model = TuckER::new(config);

        let alice = NamedNode::new("http://example.org/alice")?;
        let knows = NamedNode::new("http://example.org/knows")?;
        let bob = NamedNode::new("http://example.org/bob")?;

        model.add_triple(Triple::new(alice.clone(), knows.clone(), bob.clone()))?;
        model.add_triple(Triple::new(bob.clone(), knows.clone(), alice.clone()))?;

        let stats = model.train(Some(5)).await?;
        assert!(stats.epochs_completed > 0);

        let alice_emb = model.get_entity_embedding("http://example.org/alice")?;
        assert_eq!(alice_emb.dimensions, 50);

        let score = model.score_triple(
            "http://example.org/alice",
            "http://example.org/knows",
            "http://example.org/bob",
        )?;

        assert!(score.is_finite());

        Ok(())
    }
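
    // A small additional sanity check (sketch): an untrained model should
    // refuse embedding lookups rather than returning garbage. Relies only on
    // APIs defined in this file.
    #[test]
    fn test_untrained_model_rejects_lookups() {
        let model = TuckER::new(ModelConfig::default());
        assert!(!model.is_trained());
        assert!(model.get_entity_embedding("http://example.org/alice").is_err());
        assert!(model.get_relation_embedding("http://example.org/knows").is_err());
    }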

    #[test]
    fn test_tucker_creation() {
        let config = ModelConfig::default();
        let tucker = TuckER::new(config);
        assert!(!tucker.embeddings_initialized);
        assert_eq!(tucker.model_type(), "TuckER");
    }
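
    // Sketch of a unit test for `apply_dropout`, assuming the re-exported
    // `Array2::from_elem` constructor from ndarray is available: with a rate
    // of 0.0, inverted dropout must leave every element untouched.
    #[test]
    fn test_apply_dropout_zero_rate_is_identity() {
        let mut rng = Random::seed(42);
        let mut embeddings = Array2::from_elem((2, 3), 1.0);
        apply_dropout(&mut embeddings, 0.0, &mut rng);
        assert!(embeddings.iter().all(|&v| (v - 1.0).abs() < 1e-12));
    }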
}