1use crate::models::{common::*, BaseModel};
10use crate::{EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector};
11use anyhow::{anyhow, Result};
12use async_trait::async_trait;
13use scirs2_core::ndarray_ext::Array2;
14#[allow(unused_imports)]
15use scirs2_core::random::{Random, Rng};
16use serde::{Deserialize, Serialize};
17use std::ops::AddAssign;
18use std::time::Instant;
19use tracing::{debug, info};
20use uuid::Uuid;
21
/// Gradient matrices in the order
/// (entity real, entity imaginary, relation real, relation imaginary),
/// as returned by `ComplEx::compute_gradients`.
type GradientTuple = (Array2<f64>, Array2<f64>, Array2<f64>, Array2<f64>);
24
/// State of a ComplEx knowledge-graph embedding model.
///
/// Entity and relation embeddings are each kept as two matrices holding the
/// real and the imaginary components separately.
#[derive(Debug)]
pub struct ComplEx {
    /// Shared model state: configuration, stored triples and the
    /// entity/relation lookups used throughout this type.
    base: BaseModel,
    /// Real components of the entity embeddings, one row per entity id.
    entity_embeddings_real: Array2<f64>,
    /// Imaginary components of the entity embeddings, one row per entity id.
    entity_embeddings_imag: Array2<f64>,
    /// Real components of the relation embeddings, one row per relation id.
    relation_embeddings_real: Array2<f64>,
    /// Imaginary components of the relation embeddings, one row per relation id.
    relation_embeddings_imag: Array2<f64>,
    /// Set by `initialize_embeddings` once the matrices are allocated;
    /// reset by `clear`.
    embeddings_initialized: bool,
    /// Regularization variant chosen in `new`.
    regularization: RegularizationType,
}
43
/// Regularization term added to the batch gradients during training.
///
/// `ComplEx::new` picks the variant from the `regularization` entry of the
/// model parameters.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum RegularizationType {
    /// Chosen when the `regularization` parameter is `1.0`.
    L2,
    /// Chosen when the parameter is `2.0`; also the fallback for a missing
    /// entry or any other value.
    N3,
    /// Chosen when the parameter is `0.0`; training adds no regularization term.
    None,
}
54
55impl ComplEx {
56 pub fn new(config: ModelConfig) -> Self {
58 let base = BaseModel::new(config.clone());
59
60 let regularization = match config.model_params.get("regularization") {
62 Some(0.0) => RegularizationType::None,
63 Some(1.0) => RegularizationType::L2,
64 Some(2.0) => RegularizationType::N3,
65 _ => RegularizationType::N3, };
67
68 Self {
69 base,
70 entity_embeddings_real: Array2::zeros((0, config.dimensions)),
71 entity_embeddings_imag: Array2::zeros((0, config.dimensions)),
72 relation_embeddings_real: Array2::zeros((0, config.dimensions)),
73 relation_embeddings_imag: Array2::zeros((0, config.dimensions)),
74 embeddings_initialized: false,
75 regularization,
76 }
77 }
78
79 fn initialize_embeddings(&mut self) {
81 if self.embeddings_initialized {
82 return;
83 }
84
85 let num_entities = self.base.num_entities();
86 let num_relations = self.base.num_relations();
87 let dimensions = self.base.config.dimensions;
88
89 if num_entities == 0 || num_relations == 0 {
90 return;
91 }
92
93 let mut rng = Random::default();
94
95 self.entity_embeddings_real =
97 xavier_init((num_entities, dimensions), dimensions, dimensions, &mut rng);
98
99 self.entity_embeddings_imag =
100 xavier_init((num_entities, dimensions), dimensions, dimensions, &mut rng);
101
102 self.relation_embeddings_real = xavier_init(
103 (num_relations, dimensions),
104 dimensions,
105 dimensions,
106 &mut rng,
107 );
108
109 self.relation_embeddings_imag = xavier_init(
110 (num_relations, dimensions),
111 dimensions,
112 dimensions,
113 &mut rng,
114 );
115
116 self.embeddings_initialized = true;
117 debug!(
118 "Initialized ComplEx embeddings: {} entities, {} relations, {} dimensions",
119 num_entities, num_relations, dimensions
120 );
121 }
122
123 fn score_triple_ids(
127 &self,
128 subject_id: usize,
129 predicate_id: usize,
130 object_id: usize,
131 ) -> Result<f64> {
132 if !self.embeddings_initialized {
133 return Err(anyhow!("Model not trained"));
134 }
135
136 let h_real = self.entity_embeddings_real.row(subject_id);
137 let h_imag = self.entity_embeddings_imag.row(subject_id);
138 let r_real = self.relation_embeddings_real.row(predicate_id);
139 let r_imag = self.relation_embeddings_imag.row(predicate_id);
140 let t_real = self.entity_embeddings_real.row(object_id);
141 let t_imag = self.entity_embeddings_imag.row(object_id);
142
143 let score = (&h_real * &r_real * t_real).sum()
146 + (&h_real * &r_imag * t_imag).sum()
147 + (&h_imag * &r_real * t_imag).sum()
148 - (&h_imag * &r_imag * t_real).sum();
149
150 Ok(score)
151 }
152
153 fn compute_gradients(
155 &self,
156 pos_triple: (usize, usize, usize),
157 neg_triple: (usize, usize, usize),
158 pos_score: f64,
159 neg_score: f64,
160 ) -> Result<GradientTuple> {
161 let mut entity_grads_real = Array2::zeros(self.entity_embeddings_real.raw_dim());
162 let mut entity_grads_imag = Array2::zeros(self.entity_embeddings_imag.raw_dim());
163 let mut relation_grads_real = Array2::zeros(self.relation_embeddings_real.raw_dim());
164 let mut relation_grads_imag = Array2::zeros(self.relation_embeddings_imag.raw_dim());
165
166 let pos_sigmoid = sigmoid(pos_score);
168 let neg_sigmoid = sigmoid(neg_score);
169
170 let pos_grad_coeff = pos_sigmoid - 1.0; let neg_grad_coeff = neg_sigmoid; self.add_triple_gradients(
175 pos_triple,
176 pos_grad_coeff,
177 &mut entity_grads_real,
178 &mut entity_grads_imag,
179 &mut relation_grads_real,
180 &mut relation_grads_imag,
181 );
182
183 self.add_triple_gradients(
185 neg_triple,
186 neg_grad_coeff,
187 &mut entity_grads_real,
188 &mut entity_grads_imag,
189 &mut relation_grads_real,
190 &mut relation_grads_imag,
191 );
192
193 Ok((
194 entity_grads_real,
195 entity_grads_imag,
196 relation_grads_real,
197 relation_grads_imag,
198 ))
199 }
200
201 fn add_triple_gradients(
203 &self,
204 triple: (usize, usize, usize),
205 grad_coeff: f64,
206 entity_grads_real: &mut Array2<f64>,
207 entity_grads_imag: &mut Array2<f64>,
208 relation_grads_real: &mut Array2<f64>,
209 relation_grads_imag: &mut Array2<f64>,
210 ) {
211 let (s, p, o) = triple;
212
213 let h_real = self.entity_embeddings_real.row(s);
214 let h_imag = self.entity_embeddings_imag.row(s);
215 let r_real = self.relation_embeddings_real.row(p);
216 let r_imag = self.relation_embeddings_imag.row(p);
217 let t_real = self.entity_embeddings_real.row(o);
218 let t_imag = self.entity_embeddings_imag.row(o);
219
220 let h_real_grad = (&r_real * &t_real + &r_imag * &t_imag) * grad_coeff;
224 let h_imag_grad = (&r_real * &t_imag - &r_imag * &t_real) * grad_coeff;
225
226 entity_grads_real.row_mut(s).add_assign(&h_real_grad);
227 entity_grads_imag.row_mut(s).add_assign(&h_imag_grad);
228
229 let r_real_grad = (&h_real * &t_real + &h_imag * &t_imag) * grad_coeff;
233 let r_imag_grad = (&h_real * &t_imag - &h_imag * &t_real) * grad_coeff;
234
235 relation_grads_real.row_mut(p).add_assign(&r_real_grad);
236 relation_grads_imag.row_mut(p).add_assign(&r_imag_grad);
237
238 let t_real_grad = (&h_real * &r_real - &h_imag * &r_imag) * grad_coeff;
242 let t_imag_grad = -(&h_real * &r_imag + &h_imag * &r_real) * grad_coeff;
243
244 entity_grads_real.row_mut(o).add_assign(&t_real_grad);
245 entity_grads_imag.row_mut(o).add_assign(&t_imag_grad);
246 }
247
248 fn apply_n3_regularization(
250 &self,
251 entity_grads_real: &mut Array2<f64>,
252 entity_grads_imag: &mut Array2<f64>,
253 relation_grads_real: &mut Array2<f64>,
254 relation_grads_imag: &mut Array2<f64>,
255 regularization_weight: f64,
256 ) {
257 *entity_grads_real += &(&self.entity_embeddings_real * regularization_weight);
263 *entity_grads_imag += &(&self.entity_embeddings_imag * regularization_weight);
264 *relation_grads_real += &(&self.relation_embeddings_real * regularization_weight);
265 *relation_grads_imag += &(&self.relation_embeddings_imag * regularization_weight);
266 }
267
268 async fn train_epoch(&mut self, learning_rate: f64) -> Result<f64> {
270 let mut rng = Random::default();
271
272 let mut total_loss = 0.0;
273 let num_batches = (self.base.triples.len() + self.base.config.batch_size - 1)
274 / self.base.config.batch_size;
275
276 let mut shuffled_triples = self.base.triples.clone();
278 for i in (1..shuffled_triples.len()).rev() {
280 let j = rng.random_range(0..i + 1);
281 shuffled_triples.swap(i, j);
282 }
283
284 for batch_triples in shuffled_triples.chunks(self.base.config.batch_size) {
285 let mut batch_entity_grads_real = Array2::zeros(self.entity_embeddings_real.raw_dim());
286 let mut batch_entity_grads_imag = Array2::zeros(self.entity_embeddings_imag.raw_dim());
287 let mut batch_relation_grads_real =
288 Array2::zeros(self.relation_embeddings_real.raw_dim());
289 let mut batch_relation_grads_imag =
290 Array2::zeros(self.relation_embeddings_imag.raw_dim());
291 let mut batch_loss = 0.0;
292
293 for &pos_triple in batch_triples {
294 let neg_samples = self
296 .base
297 .generate_negative_samples(self.base.config.negative_samples, &mut rng);
298
299 for neg_triple in neg_samples {
300 let pos_score =
302 self.score_triple_ids(pos_triple.0, pos_triple.1, pos_triple.2)?;
303 let neg_score =
304 self.score_triple_ids(neg_triple.0, neg_triple.1, neg_triple.2)?;
305
306 let pos_loss = logistic_loss(pos_score, 1.0);
308 let neg_loss = logistic_loss(neg_score, -1.0);
309 let total_triple_loss = pos_loss + neg_loss;
310
311 batch_loss += total_triple_loss;
312
313 let (
315 entity_grads_real,
316 entity_grads_imag,
317 relation_grads_real,
318 relation_grads_imag,
319 ) = self.compute_gradients(pos_triple, neg_triple, pos_score, neg_score)?;
320
321 batch_entity_grads_real += &entity_grads_real;
322 batch_entity_grads_imag += &entity_grads_imag;
323 batch_relation_grads_real += &relation_grads_real;
324 batch_relation_grads_imag += &relation_grads_imag;
325 }
326 }
327
328 match self.regularization {
330 RegularizationType::L2 => {
331 let reg_weight = self.base.config.l2_reg;
332 batch_entity_grads_real += &(&self.entity_embeddings_real * reg_weight);
333 batch_entity_grads_imag += &(&self.entity_embeddings_imag * reg_weight);
334 batch_relation_grads_real += &(&self.relation_embeddings_real * reg_weight);
335 batch_relation_grads_imag += &(&self.relation_embeddings_imag * reg_weight);
336 }
337 RegularizationType::N3 => {
338 self.apply_n3_regularization(
339 &mut batch_entity_grads_real,
340 &mut batch_entity_grads_imag,
341 &mut batch_relation_grads_real,
342 &mut batch_relation_grads_imag,
343 self.base.config.l2_reg,
344 );
345 }
346 RegularizationType::None => {}
347 }
348
349 self.entity_embeddings_real -= &(&batch_entity_grads_real * learning_rate);
351 self.entity_embeddings_imag -= &(&batch_entity_grads_imag * learning_rate);
352 self.relation_embeddings_real -= &(&batch_relation_grads_real * learning_rate);
353 self.relation_embeddings_imag -= &(&batch_relation_grads_imag * learning_rate);
354
355 total_loss += batch_loss;
356 }
357
358 Ok(total_loss / num_batches as f64)
359 }
360
361 fn get_entity_embedding_vector(&self, entity_id: usize) -> Vector {
363 let real_part = self.entity_embeddings_real.row(entity_id);
364 let imag_part = self.entity_embeddings_imag.row(entity_id);
365
366 let mut values = Vec::with_capacity(real_part.len() * 2);
368 for &val in real_part.iter() {
369 values.push(val as f32);
370 }
371 for &val in imag_part.iter() {
372 values.push(val as f32);
373 }
374
375 Vector::new(values)
376 }
377
378 fn get_relation_embedding_vector(&self, relation_id: usize) -> Vector {
380 let real_part = self.relation_embeddings_real.row(relation_id);
381 let imag_part = self.relation_embeddings_imag.row(relation_id);
382
383 let mut values = Vec::with_capacity(real_part.len() * 2);
385 for &val in real_part.iter() {
386 values.push(val as f32);
387 }
388 for &val in imag_part.iter() {
389 values.push(val as f32);
390 }
391
392 Vector::new(values)
393 }
394}
395
396#[async_trait]
397impl EmbeddingModel for ComplEx {
    /// Returns the configuration held by the underlying base model.
    fn config(&self) -> &ModelConfig {
        &self.base.config
    }
401
    /// Returns the identifier stored in the underlying base model.
    fn model_id(&self) -> &Uuid {
        &self.base.model_id
    }
405
    /// Returns the fixed model type name, `"ComplEx"`.
    fn model_type(&self) -> &'static str {
        "ComplEx"
    }
409
    /// Forwards `triple` to the base model, which stores it.
    ///
    /// The embedding matrices are not touched here.
    fn add_triple(&mut self, triple: Triple) -> Result<()> {
        self.base.add_triple(triple)
    }
413
414 async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
415 let start_time = Instant::now();
416 let max_epochs = epochs.unwrap_or(self.base.config.max_epochs);
417
418 self.initialize_embeddings();
420
421 if !self.embeddings_initialized {
422 return Err(anyhow!("No training data available"));
423 }
424
425 let mut loss_history = Vec::new();
426 let learning_rate = self.base.config.learning_rate;
427
428 info!("Starting ComplEx training for {} epochs", max_epochs);
429
430 for epoch in 0..max_epochs {
431 let epoch_loss = self.train_epoch(learning_rate).await?;
432 loss_history.push(epoch_loss);
433
434 if epoch % 100 == 0 {
435 debug!("Epoch {}: loss = {:.6}", epoch, epoch_loss);
436 }
437
438 if epoch > 10 && epoch_loss < 1e-6 {
440 info!("Converged at epoch {} with loss {:.6}", epoch, epoch_loss);
441 break;
442 }
443 }
444
445 self.base.mark_trained();
446 let training_time = start_time.elapsed().as_secs_f64();
447
448 Ok(TrainingStats {
449 epochs_completed: loss_history.len(),
450 final_loss: loss_history.last().copied().unwrap_or(0.0),
451 training_time_seconds: training_time,
452 convergence_achieved: loss_history.last().copied().unwrap_or(f64::INFINITY) < 1e-6,
453 loss_history,
454 })
455 }
456
457 fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
458 if !self.embeddings_initialized {
459 return Err(anyhow!("Model not trained"));
460 }
461
462 let entity_id = self
463 .base
464 .get_entity_id(entity)
465 .ok_or_else(|| anyhow!("Entity not found: {}", entity))?;
466
467 Ok(self.get_entity_embedding_vector(entity_id))
468 }
469
470 fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
471 if !self.embeddings_initialized {
472 return Err(anyhow!("Model not trained"));
473 }
474
475 let relation_id = self
476 .base
477 .get_relation_id(relation)
478 .ok_or_else(|| anyhow!("Relation not found: {}", relation))?;
479
480 Ok(self.get_relation_embedding_vector(relation_id))
481 }
482
483 fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
484 let subject_id = self
485 .base
486 .get_entity_id(subject)
487 .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
488 let predicate_id = self
489 .base
490 .get_relation_id(predicate)
491 .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
492 let object_id = self
493 .base
494 .get_entity_id(object)
495 .ok_or_else(|| anyhow!("Object not found: {}", object))?;
496
497 self.score_triple_ids(subject_id, predicate_id, object_id)
498 }
499
500 fn predict_objects(
501 &self,
502 subject: &str,
503 predicate: &str,
504 k: usize,
505 ) -> Result<Vec<(String, f64)>> {
506 if !self.embeddings_initialized {
507 return Err(anyhow!("Model not trained"));
508 }
509
510 let subject_id = self
511 .base
512 .get_entity_id(subject)
513 .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
514 let predicate_id = self
515 .base
516 .get_relation_id(predicate)
517 .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
518
519 let mut scores = Vec::new();
520
521 for object_id in 0..self.base.num_entities() {
522 let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
523 let object_name = self.base.get_entity(object_id).unwrap().clone();
524 scores.push((object_name, score));
525 }
526
527 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
528 scores.truncate(k);
529
530 Ok(scores)
531 }
532
533 fn predict_subjects(
534 &self,
535 predicate: &str,
536 object: &str,
537 k: usize,
538 ) -> Result<Vec<(String, f64)>> {
539 if !self.embeddings_initialized {
540 return Err(anyhow!("Model not trained"));
541 }
542
543 let predicate_id = self
544 .base
545 .get_relation_id(predicate)
546 .ok_or_else(|| anyhow!("Predicate not found: {}", predicate))?;
547 let object_id = self
548 .base
549 .get_entity_id(object)
550 .ok_or_else(|| anyhow!("Object not found: {}", object))?;
551
552 let mut scores = Vec::new();
553
554 for subject_id in 0..self.base.num_entities() {
555 let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
556 let subject_name = self.base.get_entity(subject_id).unwrap().clone();
557 scores.push((subject_name, score));
558 }
559
560 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
561 scores.truncate(k);
562
563 Ok(scores)
564 }
565
566 fn predict_relations(
567 &self,
568 subject: &str,
569 object: &str,
570 k: usize,
571 ) -> Result<Vec<(String, f64)>> {
572 if !self.embeddings_initialized {
573 return Err(anyhow!("Model not trained"));
574 }
575
576 let subject_id = self
577 .base
578 .get_entity_id(subject)
579 .ok_or_else(|| anyhow!("Subject not found: {}", subject))?;
580 let object_id = self
581 .base
582 .get_entity_id(object)
583 .ok_or_else(|| anyhow!("Object not found: {}", object))?;
584
585 let mut scores = Vec::new();
586
587 for predicate_id in 0..self.base.num_relations() {
588 let score = self.score_triple_ids(subject_id, predicate_id, object_id)?;
589 let predicate_name = self.base.get_relation(predicate_id).unwrap().clone();
590 scores.push((predicate_name, score));
591 }
592
593 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
594 scores.truncate(k);
595
596 Ok(scores)
597 }
598
    /// Returns the entity names reported by the base model.
    fn get_entities(&self) -> Vec<String> {
        self.base.get_entities()
    }
602
    /// Returns the relation names reported by the base model.
    fn get_relations(&self) -> Vec<String> {
        self.base.get_relations()
    }
606
    /// Returns the base model's statistics, labelled with the model type `"ComplEx"`.
    fn get_stats(&self) -> ModelStats {
        self.base.get_stats("ComplEx")
    }
610
    /// Placeholder for model persistence.
    ///
    /// Only logs `path` and returns `Ok(())`; no model data is written here.
    fn save(&self, path: &str) -> Result<()> {
        info!("Saving ComplEx model to {}", path);
        Ok(())
    }
615
    /// Placeholder for model loading.
    ///
    /// Only logs `path` and returns `Ok(())`; the model state is left unchanged.
    fn load(&mut self, path: &str) -> Result<()> {
        info!("Loading ComplEx model from {}", path);
        Ok(())
    }
620
621 fn clear(&mut self) {
622 self.base.clear();
623 self.entity_embeddings_real = Array2::zeros((0, self.base.config.dimensions));
624 self.entity_embeddings_imag = Array2::zeros((0, self.base.config.dimensions));
625 self.relation_embeddings_real = Array2::zeros((0, self.base.config.dimensions));
626 self.relation_embeddings_imag = Array2::zeros((0, self.base.config.dimensions));
627 self.embeddings_initialized = false;
628 }
629
    /// Returns the base model's `is_trained` flag.
    fn is_trained(&self) -> bool {
        self.base.is_trained
    }
633
    /// Text encoding is not supported by this model.
    ///
    /// # Errors
    ///
    /// Always returns an error; `_texts` is ignored.
    async fn encode(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
        Err(anyhow!(
            "Knowledge graph embedding model does not support text encoding"
        ))
    }
639}
640
#[cfg(test)]
mod tests {
    use super::*;
    use crate::NamedNode;

    /// Smoke test: train on a single triple, then read an embedding and a score.
    #[tokio::test]
    async fn test_complex_basic() -> Result<()> {
        let config = ModelConfig::default()
            .with_dimensions(50)
            .with_max_epochs(10)
            .with_seed(42);

        let mut model = ComplEx::new(config);

        let alice = NamedNode::new("http://example.org/alice")?;
        let knows = NamedNode::new("http://example.org/knows")?;
        let bob = NamedNode::new("http://example.org/bob")?;

        model.add_triple(Triple::new(alice.clone(), knows.clone(), bob.clone()))?;

        // The explicit epoch count overrides `max_epochs` from the config.
        let stats = model.train(Some(5)).await?;
        assert!(stats.epochs_completed > 0);

        let alice_emb = model.get_entity_embedding("http://example.org/alice")?;
        // Real and imaginary parts are concatenated: 2 × 50 dimensions.
        assert_eq!(alice_emb.dimensions, 100);
        let score = model.score_triple(
            "http://example.org/alice",
            "http://example.org/knows",
            "http://example.org/bob",
        )?;

        assert!(score.is_finite());

        Ok(())
    }
}