1use anyhow::{anyhow, Result};
13use rayon::prelude::*;
14use scirs2_core::ndarray_ext::{Array1, Array2, Array3};
15use scirs2_core::random::Random;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::fs::File;
19use std::io::{BufReader, BufWriter};
20use std::path::Path;
21use tracing::{debug, info};
22
23#[cfg(test)]
24use crate::NamedNode;
25use crate::{EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector};
26use uuid::Uuid;
27
/// Configuration for the [`ConvE`] model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvEConfig {
    /// Shared base settings; this file reads `dimensions`, `learning_rate`
    /// and `max_epochs` from it.
    pub base: ModelConfig,
    /// Width of the 2D grid an embedding is reshaped into. The grid height is
    /// `base.dimensions / reshape_width` (integer division).
    pub reshape_width: usize,
    /// Number of convolution filters.
    pub num_filters: usize,
    /// Side length of each square convolution kernel.
    pub kernel_size: usize,
    /// Probability of zeroing an activation when dropout is applied in
    /// training mode.
    pub dropout_rate: f32,
    /// Regularization coefficient; multiplied by the learning rate to obtain
    /// the embedding decay applied during training.
    pub regularization: f32,
    /// Margin of the hinge loss `max(0, margin + neg_score - pos_score)`.
    pub margin: f32,
    /// Number of corrupted-tail negatives sampled per training triple.
    pub num_negatives: usize,
    /// NOTE(review): not read anywhere in this file; presumably reserved for a
    /// batch-normalization step — confirm before relying on it.
    pub use_batch_norm: bool,
}
50
51impl Default for ConvEConfig {
52 fn default() -> Self {
53 Self {
54 base: ModelConfig::default().with_dimensions(200),
55 reshape_width: 20, num_filters: 32,
57 kernel_size: 3,
58 dropout_rate: 0.3,
59 regularization: 0.0001,
60 margin: 1.0,
61 num_negatives: 10,
62 use_batch_norm: true,
63 }
64 }
65}
66
/// Serde-friendly mirror of [`ConvLayer`] that stores its arrays as nested
/// `Vec`s.
#[derive(Debug, Serialize, Deserialize)]
struct ConvLayerSerializable {
    /// One entry per filter; each filter is stored as a list of rows.
    filters: Vec<Vec<Vec<f32>>>,
    /// One bias value per filter.
    biases: Vec<f32>,
}
73
/// 2D convolution layer holding a bank of square filters and one bias per
/// filter.
struct ConvLayer {
    /// Filter kernels; `new` creates each as `kernel_size x kernel_size`.
    filters: Vec<Array2<f32>>,
    /// Bias added to every output position of the corresponding filter.
    biases: Array1<f32>,
}
81
82impl ConvLayer {
83 fn new(num_filters: usize, kernel_size: usize, rng: &mut Random) -> Self {
84 let scale = (2.0 / (kernel_size * kernel_size) as f32).sqrt();
85 let mut filters = Vec::new();
86
87 for _ in 0..num_filters {
88 let filter = Array2::from_shape_fn((kernel_size, kernel_size), |_| {
89 rng.random_range(-scale..scale)
90 });
91 filters.push(filter);
92 }
93
94 let biases = Array1::zeros(num_filters);
95
96 Self { filters, biases }
97 }
98
99 fn forward(&self, input: &Array2<f32>) -> Array3<f32> {
101 let kernel_size = self.filters[0].nrows();
102 let input_height = input.nrows();
103 let input_width = input.ncols();
104
105 let out_height = input_height.saturating_sub(kernel_size - 1);
106 let out_width = input_width.saturating_sub(kernel_size - 1);
107
108 if out_height == 0 || out_width == 0 {
109 return Array3::zeros((self.filters.len(), 1, 1));
111 }
112
113 let mut output = Array3::zeros((self.filters.len(), out_height, out_width));
114
115 for (f_idx, filter) in self.filters.iter().enumerate() {
116 for i in 0..out_height {
117 for j in 0..out_width {
118 let mut sum = 0.0;
119
120 for ki in 0..kernel_size {
121 for kj in 0..kernel_size {
122 sum += input[[i + ki, j + kj]] * filter[[ki, kj]];
123 }
124 }
125
126 output[[f_idx, i, j]] = sum + self.biases[f_idx];
127 }
128 }
129 }
130
131 output
132 }
133}
134
/// Serde-friendly mirror of [`FCLayer`] that stores its arrays as nested
/// `Vec`s.
#[derive(Debug, Serialize, Deserialize)]
struct FCLayerSerializable {
    /// Weight matrix stored row by row (one row per input unit).
    weights: Vec<Vec<f32>>,
    /// One bias value per output unit.
    bias: Vec<f32>,
}
141
/// Fully connected layer computing `input * weights + bias`.
struct FCLayer {
    /// Weight matrix of shape `(input_size, output_size)`.
    weights: Array2<f32>,
    /// Bias vector of length `output_size`.
    bias: Array1<f32>,
}
147
148impl FCLayer {
149 fn new(input_size: usize, output_size: usize, rng: &mut Random) -> Self {
150 let scale = (2.0 / input_size as f32).sqrt();
151 let weights = Array2::from_shape_fn((input_size, output_size), |_| {
152 rng.random_range(-scale..scale)
153 });
154 let bias = Array1::zeros(output_size);
155
156 Self { weights, bias }
157 }
158
159 fn forward(&self, input: &Array1<f32>) -> Array1<f32> {
160 let mut output = self.bias.clone();
161 for i in 0..output.len() {
162 for j in 0..input.len() {
163 output[i] += input[j] * self.weights[[j, i]];
164 }
165 }
166 output
167 }
168}
169
/// On-disk representation of a [`ConvE`] model: the same fields as the model,
/// with every array converted to plain `Vec`s so serde can handle them.
#[derive(Debug, Serialize, Deserialize)]
struct ConvESerializable {
    /// Identifier of the saved model.
    model_id: Uuid,
    /// Configuration the model was using when it was saved.
    config: ConvEConfig,
    /// Entity embeddings keyed by entity name.
    entity_embeddings: HashMap<String, Vec<f32>>,
    /// Relation embeddings keyed by relation name.
    relation_embeddings: HashMap<String, Vec<f32>>,
    /// Convolution layer parameters.
    conv_layer: ConvLayerSerializable,
    /// Fully connected layer parameters.
    fc_layer: FCLayerSerializable,
    /// Triples that had been added to the model.
    triples: Vec<Triple>,
    /// Entity name -> numeric id.
    entity_to_id: HashMap<String, usize>,
    /// Relation name -> numeric id.
    relation_to_id: HashMap<String, usize>,
    /// Numeric id -> entity name.
    id_to_entity: HashMap<usize, String>,
    /// Numeric id -> relation name.
    id_to_relation: HashMap<usize, String>,
    /// Whether training had completed at save time.
    is_trained: bool,
}
186
/// ConvE knowledge-graph embedding model: entity and relation embeddings plus
/// a convolution layer and a fully connected layer used to score triples.
pub struct ConvE {
    /// Unique identifier; freshly generated by `new` and overwritten by `load`.
    model_id: Uuid,
    /// Model configuration.
    config: ConvEConfig,
    /// Entity embeddings keyed by the IRI strings seen in added triples.
    entity_embeddings: HashMap<String, Array1<f32>>,
    /// Relation embeddings keyed by the predicate IRI strings.
    relation_embeddings: HashMap<String, Array1<f32>>,
    /// Convolution applied to the stacked head/relation grid.
    conv_layer: ConvLayer,
    /// Projection of the flattened convolution output back to embedding size.
    fc_layer: FCLayer,
    /// Training triples in insertion order.
    triples: Vec<Triple>,
    /// Entity name -> numeric id (assigned in order of first appearance).
    entity_to_id: HashMap<String, usize>,
    /// Relation name -> numeric id (assigned in order of first appearance).
    relation_to_id: HashMap<String, usize>,
    /// Numeric id -> entity name.
    id_to_entity: HashMap<usize, String>,
    /// Numeric id -> relation name.
    id_to_relation: HashMap<usize, String>,
    /// Set to `true` once `train` has finished.
    is_trained: bool,
}
202
203impl ConvE {
204 pub fn new(config: ConvEConfig) -> Self {
206 let mut rng = Random::default();
207
208 let reshape_height = config.base.dimensions / config.reshape_width;
210 let conv_out_height = reshape_height.saturating_sub(config.kernel_size - 1);
211 let conv_out_width = (config.reshape_width * 2).saturating_sub(config.kernel_size - 1);
212 let fc_input_size = config.num_filters * conv_out_height * conv_out_width;
213
214 let conv_layer = ConvLayer::new(config.num_filters, config.kernel_size, &mut rng);
215 let fc_layer = FCLayer::new(fc_input_size, config.base.dimensions, &mut rng);
216
217 info!(
218 "Initialized ConvE model: dim={}, filters={}, kernel={}, fc_input={}",
219 config.base.dimensions, config.num_filters, config.kernel_size, fc_input_size
220 );
221
222 Self {
223 model_id: Uuid::new_v4(),
224 config,
225 entity_embeddings: HashMap::new(),
226 relation_embeddings: HashMap::new(),
227 conv_layer,
228 fc_layer,
229 triples: Vec::new(),
230 entity_to_id: HashMap::new(),
231 relation_to_id: HashMap::new(),
232 id_to_entity: HashMap::new(),
233 id_to_relation: HashMap::new(),
234 is_trained: false,
235 }
236 }
237
238 fn reshape_embedding(&self, embedding: &Array1<f32>) -> Array2<f32> {
240 let height = self.config.base.dimensions / self.config.reshape_width;
241 let width = self.config.reshape_width;
242
243 Array2::from_shape_fn((height, width), |(i, j)| embedding[i * width + j])
244 }
245
246 fn relu(&self, x: f32) -> f32 {
248 x.max(0.0)
249 }
250
251 fn dropout(&mut self, values: &mut Array1<f32>, training: bool) {
253 if !training || self.config.dropout_rate == 0.0 {
254 return;
255 }
256
257 let mut local_rng = Random::default();
258 let keep_prob = 1.0 - self.config.dropout_rate;
259 for val in values.iter_mut() {
260 if local_rng.random_range(0.0..1.0) > keep_prob {
261 *val = 0.0;
262 } else {
263 *val /= keep_prob; }
265 }
266 }
267
268 fn forward(
270 &mut self,
271 head: &Array1<f32>,
272 relation: &Array1<f32>,
273 training: bool,
274 ) -> Array1<f32> {
275 let head_2d = self.reshape_embedding(head);
277 let rel_2d = self.reshape_embedding(relation);
278
279 let height = head_2d.nrows();
281 let width = head_2d.ncols() * 2;
282 let mut concat = Array2::zeros((height, width));
283
284 for i in 0..height {
285 for j in 0..head_2d.ncols() {
286 concat[[i, j]] = head_2d[[i, j]];
287 }
288 for j in 0..rel_2d.ncols() {
289 concat[[i, head_2d.ncols() + j]] = rel_2d[[i, j]];
290 }
291 }
292
293 let conv_out = self.conv_layer.forward(&concat);
295
296 let conv_out_relu = conv_out.mapv(|x| self.relu(x));
298
299 let flattened_size = conv_out_relu.len();
301 let mut flattened = Array1::zeros(flattened_size);
302 for (idx, &val) in conv_out_relu.iter().enumerate() {
303 flattened[idx] = val;
304 }
305
306 self.dropout(&mut flattened, training);
308
309 let mut output = self.fc_layer.forward(&flattened);
311
312 self.dropout(&mut output, training);
314
315 output
316 }
317
318 fn score_triple_internal(
320 &mut self,
321 head: &Array1<f32>,
322 relation: &Array1<f32>,
323 tail: &Array1<f32>,
324 ) -> f32 {
325 let projected = self.forward(head, relation, false);
326 projected.dot(tail)
328 }
329
330 fn init_entity(&mut self, entity: &str) {
332 if !self.entity_embeddings.contains_key(entity) {
333 let id = self.entity_embeddings.len();
334 self.entity_to_id.insert(entity.to_string(), id);
335 self.id_to_entity.insert(id, entity.to_string());
336
337 let mut local_rng = Random::default();
338 let scale = (6.0 / self.config.base.dimensions as f32).sqrt();
339 let embedding = Array1::from_vec(
340 (0..self.config.base.dimensions)
341 .map(|_| local_rng.random_range(-scale..scale))
342 .collect(),
343 );
344 self.entity_embeddings.insert(entity.to_string(), embedding);
345 }
346 }
347
348 fn init_relation(&mut self, relation: &str) {
350 if !self.relation_embeddings.contains_key(relation) {
351 let id = self.relation_embeddings.len();
352 self.relation_to_id.insert(relation.to_string(), id);
353 self.id_to_relation.insert(id, relation.to_string());
354
355 let mut local_rng = Random::default();
356 let scale = (6.0 / self.config.base.dimensions as f32).sqrt();
357 let embedding = Array1::from_vec(
358 (0..self.config.base.dimensions)
359 .map(|_| local_rng.random_range(-scale..scale))
360 .collect(),
361 );
362 self.relation_embeddings
363 .insert(relation.to_string(), embedding);
364 }
365 }
366
367 fn train_step(&mut self) -> f32 {
369 let mut total_loss = 0.0;
370 let mut local_rng = Random::default();
371
372 let mut indices: Vec<usize> = (0..self.triples.len()).collect();
374 for i in (1..indices.len()).rev() {
375 let j = local_rng.random_range(0..i + 1);
376 indices.swap(i, j);
377 }
378
379 for &idx in &indices {
380 let triple = &self.triples[idx].clone();
381
382 let subject_str = &triple.subject.iri;
383 let predicate_str = &triple.predicate.iri;
384 let object_str = &triple.object.iri;
385
386 let head_emb = self.entity_embeddings[subject_str].clone();
387 let rel_emb = self.relation_embeddings[predicate_str].clone();
388 let tail_emb = self.entity_embeddings[object_str].clone();
389
390 let pos_score = self.score_triple_internal(&head_emb, &rel_emb, &tail_emb);
392
393 let entity_list: Vec<String> = self.entity_embeddings.keys().cloned().collect();
395 for _ in 0..self.config.num_negatives {
396 let neg_tail_id = entity_list[local_rng.random_range(0..entity_list.len())].clone();
397 let neg_tail_emb = self.entity_embeddings[&neg_tail_id].clone();
398
399 let neg_score = self.score_triple_internal(&head_emb, &rel_emb, &neg_tail_emb);
400
401 let loss = (self.config.margin + neg_score - pos_score).max(0.0);
403 total_loss += loss;
404
405 if loss > 0.0 {
407 let lr = self.config.base.learning_rate as f32;
408 for emb in self.entity_embeddings.values_mut() {
410 *emb = &*emb * (1.0 - self.config.regularization * lr);
411 }
412 for emb in self.relation_embeddings.values_mut() {
413 *emb = &*emb * (1.0 - self.config.regularization * lr);
414 }
415 }
416 }
417 }
418
419 total_loss / (self.triples.len() as f32 * self.config.num_negatives as f32)
420 }
421}
422
423#[async_trait::async_trait]
424impl EmbeddingModel for ConvE {
425 fn config(&self) -> &ModelConfig {
426 &self.config.base
427 }
428
429 fn model_id(&self) -> &Uuid {
430 &self.model_id
431 }
432
433 fn model_type(&self) -> &'static str {
434 "ConvE"
435 }
436
437 fn add_triple(&mut self, triple: Triple) -> Result<()> {
438 self.init_entity(&triple.subject.iri);
439 self.init_entity(&triple.object.iri);
440 self.init_relation(&triple.predicate.iri);
441 self.triples.push(triple);
442 Ok(())
443 }
444
445 async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
446 let num_epochs = epochs.unwrap_or(self.config.base.max_epochs);
447
448 if self.triples.is_empty() {
449 return Err(anyhow!("No training data available"));
450 }
451
452 info!(
453 "Training ConvE model for {} epochs on {} triples",
454 num_epochs,
455 self.triples.len()
456 );
457
458 let start_time = std::time::Instant::now();
459 let mut loss_history = Vec::new();
460
461 for epoch in 0..num_epochs {
462 let loss = self.train_step();
463 loss_history.push(loss as f64);
464
465 if epoch % 10 == 0 {
466 debug!("Epoch {}/{}: loss = {:.6}", epoch + 1, num_epochs, loss);
467 }
468
469 if loss < 0.001 {
470 info!("Converged at epoch {}", epoch);
471 break;
472 }
473 }
474
475 let training_time = start_time.elapsed().as_secs_f64();
476 self.is_trained = true;
477
478 Ok(TrainingStats {
479 epochs_completed: num_epochs,
480 final_loss: *loss_history.last().unwrap_or(&0.0),
481 training_time_seconds: training_time,
482 convergence_achieved: loss_history.last().unwrap_or(&1.0) < &0.001,
483 loss_history,
484 })
485 }
486
487 fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
488 self.entity_embeddings
489 .get(entity)
490 .map(Vector::from_array1)
491 .ok_or_else(|| anyhow!("Unknown entity: {}", entity))
492 }
493
494 fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
495 self.relation_embeddings
496 .get(relation)
497 .map(Vector::from_array1)
498 .ok_or_else(|| anyhow!("Unknown relation: {}", relation))
499 }
500
501 fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
502 let head_emb = self
503 .entity_embeddings
504 .get(subject)
505 .ok_or_else(|| anyhow!("Unknown subject: {}", subject))?;
506 let rel_emb = self
507 .relation_embeddings
508 .get(predicate)
509 .ok_or_else(|| anyhow!("Unknown predicate: {}", predicate))?;
510 let tail_emb = self
511 .entity_embeddings
512 .get(object)
513 .ok_or_else(|| anyhow!("Unknown object: {}", object))?;
514
515 let score = (head_emb + rel_emb).dot(tail_emb);
518 Ok(score as f64)
519 }
520
521 fn predict_objects(
522 &self,
523 subject: &str,
524 predicate: &str,
525 k: usize,
526 ) -> Result<Vec<(String, f64)>> {
527 let head_emb = self
528 .entity_embeddings
529 .get(subject)
530 .ok_or_else(|| anyhow!("Unknown subject: {}", subject))?;
531 let rel_emb = self
532 .relation_embeddings
533 .get(predicate)
534 .ok_or_else(|| anyhow!("Unknown predicate: {}", predicate))?;
535
536 let combined = head_emb + rel_emb;
537 let mut scored_objects: Vec<(String, f64)> = self
538 .entity_embeddings
539 .par_iter()
540 .map(|(entity, tail_emb)| {
541 let score = combined.dot(tail_emb);
542 (entity.clone(), score as f64)
543 })
544 .collect();
545
546 scored_objects.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
547 scored_objects.truncate(k);
548 Ok(scored_objects)
549 }
550
551 fn predict_subjects(
552 &self,
553 predicate: &str,
554 object: &str,
555 k: usize,
556 ) -> Result<Vec<(String, f64)>> {
557 let rel_emb = self
558 .relation_embeddings
559 .get(predicate)
560 .ok_or_else(|| anyhow!("Unknown predicate: {}", predicate))?;
561 let tail_emb = self
562 .entity_embeddings
563 .get(object)
564 .ok_or_else(|| anyhow!("Unknown object: {}", object))?;
565
566 let mut scored_subjects: Vec<(String, f64)> = self
567 .entity_embeddings
568 .par_iter()
569 .map(|(entity, head_emb)| {
570 let score = (head_emb + rel_emb).dot(tail_emb);
571 (entity.clone(), score as f64)
572 })
573 .collect();
574
575 scored_subjects.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
576 scored_subjects.truncate(k);
577 Ok(scored_subjects)
578 }
579
580 fn predict_relations(
581 &self,
582 subject: &str,
583 object: &str,
584 k: usize,
585 ) -> Result<Vec<(String, f64)>> {
586 let head_emb = self
587 .entity_embeddings
588 .get(subject)
589 .ok_or_else(|| anyhow!("Unknown subject: {}", subject))?;
590 let tail_emb = self
591 .entity_embeddings
592 .get(object)
593 .ok_or_else(|| anyhow!("Unknown object: {}", object))?;
594
595 let mut scored_relations: Vec<(String, f64)> = self
596 .relation_embeddings
597 .par_iter()
598 .map(|(relation, rel_emb)| {
599 let score = (head_emb + rel_emb).dot(tail_emb);
600 (relation.clone(), score as f64)
601 })
602 .collect();
603
604 scored_relations.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
605 scored_relations.truncate(k);
606 Ok(scored_relations)
607 }
608
609 fn get_entities(&self) -> Vec<String> {
610 self.entity_embeddings.keys().cloned().collect()
611 }
612
613 fn get_relations(&self) -> Vec<String> {
614 self.relation_embeddings.keys().cloned().collect()
615 }
616
617 fn get_stats(&self) -> ModelStats {
618 ModelStats {
619 num_entities: self.entity_embeddings.len(),
620 num_relations: self.relation_embeddings.len(),
621 num_triples: self.triples.len(),
622 dimensions: self.config.base.dimensions,
623 is_trained: self.is_trained,
624 model_type: "ConvE".to_string(),
625 creation_time: chrono::Utc::now(),
626 last_training_time: if self.is_trained {
627 Some(chrono::Utc::now())
628 } else {
629 None
630 },
631 }
632 }
633
634 fn save(&self, path: &str) -> Result<()> {
635 info!("Saving ConvE model to {}", path);
636
637 let entity_embeddings_vec: HashMap<String, Vec<f32>> = self
639 .entity_embeddings
640 .iter()
641 .map(|(k, v)| (k.clone(), v.to_vec()))
642 .collect();
643
644 let relation_embeddings_vec: HashMap<String, Vec<f32>> = self
645 .relation_embeddings
646 .iter()
647 .map(|(k, v)| (k.clone(), v.to_vec()))
648 .collect();
649
650 let conv_filters: Vec<Vec<Vec<f32>>> = self
652 .conv_layer
653 .filters
654 .iter()
655 .map(|filter| {
656 let mut rows = Vec::new();
657 for i in 0..filter.nrows() {
658 let mut row = Vec::new();
659 for j in 0..filter.ncols() {
660 row.push(filter[[i, j]]);
661 }
662 rows.push(row);
663 }
664 rows
665 })
666 .collect();
667
668 let conv_layer_ser = ConvLayerSerializable {
669 filters: conv_filters,
670 biases: self.conv_layer.biases.to_vec(),
671 };
672
673 let mut fc_weights = Vec::new();
675 for i in 0..self.fc_layer.weights.nrows() {
676 let mut row = Vec::new();
677 for j in 0..self.fc_layer.weights.ncols() {
678 row.push(self.fc_layer.weights[[i, j]]);
679 }
680 fc_weights.push(row);
681 }
682
683 let fc_layer_ser = FCLayerSerializable {
684 weights: fc_weights,
685 bias: self.fc_layer.bias.to_vec(),
686 };
687
688 let serializable = ConvESerializable {
689 model_id: self.model_id,
690 config: self.config.clone(),
691 entity_embeddings: entity_embeddings_vec,
692 relation_embeddings: relation_embeddings_vec,
693 conv_layer: conv_layer_ser,
694 fc_layer: fc_layer_ser,
695 triples: self.triples.clone(),
696 entity_to_id: self.entity_to_id.clone(),
697 relation_to_id: self.relation_to_id.clone(),
698 id_to_entity: self.id_to_entity.clone(),
699 id_to_relation: self.id_to_relation.clone(),
700 is_trained: self.is_trained,
701 };
702
703 let file = File::create(path)?;
704 let writer = BufWriter::new(file);
705 oxicode::serde::encode_into_std_write(&serializable, writer, oxicode::config::standard())
706 .map_err(|e| anyhow!("Failed to serialize model: {}", e))?;
707
708 info!("Model saved successfully");
709 Ok(())
710 }
711
712 fn load(&mut self, path: &str) -> Result<()> {
713 info!("Loading ConvE model from {}", path);
714
715 if !Path::new(path).exists() {
716 return Err(anyhow!("Model file not found: {}", path));
717 }
718
719 let file = File::open(path)?;
720 let reader = BufReader::new(file);
721 let (serializable, _): (ConvESerializable, _) =
722 oxicode::serde::decode_from_std_read(reader, oxicode::config::standard())
723 .map_err(|e| anyhow!("Failed to deserialize model: {}", e))?;
724
725 let entity_embeddings: HashMap<String, Array1<f32>> = serializable
727 .entity_embeddings
728 .into_iter()
729 .map(|(k, v)| (k, Array1::from_vec(v)))
730 .collect();
731
732 let relation_embeddings: HashMap<String, Array1<f32>> = serializable
733 .relation_embeddings
734 .into_iter()
735 .map(|(k, v)| (k, Array1::from_vec(v)))
736 .collect();
737
738 let conv_filters: Vec<Array2<f32>> = serializable
740 .conv_layer
741 .filters
742 .into_iter()
743 .map(|filter_vec| {
744 let kernel_size = filter_vec.len();
745 Array2::from_shape_fn((kernel_size, kernel_size), |(i, j)| filter_vec[i][j])
746 })
747 .collect();
748
749 let conv_layer = ConvLayer {
750 filters: conv_filters,
751 biases: Array1::from_vec(serializable.conv_layer.biases),
752 };
753
754 let fc_weights_vec = serializable.fc_layer.weights;
756 let input_size = fc_weights_vec.len();
757 let output_size = if input_size > 0 {
758 fc_weights_vec[0].len()
759 } else {
760 0
761 };
762
763 let fc_weights =
764 Array2::from_shape_fn((input_size, output_size), |(i, j)| fc_weights_vec[i][j]);
765
766 let fc_layer = FCLayer {
767 weights: fc_weights,
768 bias: Array1::from_vec(serializable.fc_layer.bias),
769 };
770
771 self.model_id = serializable.model_id;
773 self.config = serializable.config;
774 self.entity_embeddings = entity_embeddings;
775 self.relation_embeddings = relation_embeddings;
776 self.conv_layer = conv_layer;
777 self.fc_layer = fc_layer;
778 self.triples = serializable.triples;
779 self.entity_to_id = serializable.entity_to_id;
780 self.relation_to_id = serializable.relation_to_id;
781 self.id_to_entity = serializable.id_to_entity;
782 self.id_to_relation = serializable.id_to_relation;
783 self.is_trained = serializable.is_trained;
784
785 info!("Model loaded successfully");
786 Ok(())
787 }
788
789 fn clear(&mut self) {
790 self.entity_embeddings.clear();
791 self.relation_embeddings.clear();
792 self.triples.clear();
793 self.entity_to_id.clear();
794 self.relation_to_id.clear();
795 self.id_to_entity.clear();
796 self.id_to_relation.clear();
797 self.is_trained = false;
798 }
799
800 fn is_trained(&self) -> bool {
801 self.is_trained
802 }
803
804 async fn encode(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
805 Err(anyhow!("Text encoding not implemented for ConvE"))
807 }
808}
809
#[cfg(test)]
mod tests {
    use super::*;

    /// Adds the two sample triples shared by the training and save/load tests:
    /// `alice knows bob` and `bob likes charlie`.
    fn add_sample_triples(model: &mut ConvE) {
        let samples = [("alice", "knows", "bob"), ("bob", "likes", "charlie")];
        for &(subject, predicate, object) in samples.iter() {
            model
                .add_triple(Triple::new(
                    NamedNode::new(subject).expect("should succeed"),
                    NamedNode::new(predicate).expect("should succeed"),
                    NamedNode::new(object).expect("should succeed"),
                ))
                .expect("should succeed");
        }
    }

    /// A freshly created model has no entities or relations.
    #[test]
    fn test_conve_creation() {
        let config = ConvEConfig::default();
        let model = ConvE::new(config);

        assert_eq!(model.entity_embeddings.len(), 0);
        assert_eq!(model.relation_embeddings.len(), 0);
    }

    /// Training on two triples runs the requested epochs and registers three
    /// entities and two relations.
    #[tokio::test]
    async fn test_conve_training() {
        let config = ConvEConfig {
            base: ModelConfig {
                dimensions: 50,
                learning_rate: 0.001,
                max_epochs: 5,
                ..Default::default()
            },
            reshape_width: 10,
            num_filters: 8,
            ..Default::default()
        };

        let mut model = ConvE::new(config);
        add_sample_triples(&mut model);

        let stats = model.train(Some(5)).await.expect("should succeed");
        assert_eq!(stats.epochs_completed, 5);
        assert!(stats.final_loss >= 0.0);
        assert_eq!(model.entity_embeddings.len(), 3);
        assert_eq!(model.relation_embeddings.len(), 2);
    }

    /// A saved model can be loaded into a differently configured instance and
    /// reproduces the same embeddings and scores.
    #[tokio::test]
    async fn test_conve_save_load() {
        use std::env::temp_dir;

        let config = ConvEConfig {
            base: ModelConfig {
                dimensions: 50,
                learning_rate: 0.001,
                max_epochs: 15,
                ..Default::default()
            },
            reshape_width: 10,
            num_filters: 8,
            kernel_size: 2,
            ..Default::default()
        };

        let mut model = ConvE::new(config);
        add_sample_triples(&mut model);

        model.train(Some(15)).await.expect("should succeed");

        let emb_before = model.get_entity_embedding("alice").expect("should succeed");
        let score_before = model
            .score_triple("alice", "knows", "bob")
            .expect("should succeed");

        // Include the process id so concurrent test processes sharing the
        // temp directory do not overwrite each other's file.
        let model_path = temp_dir().join(format!("test_conve_model_{}.bin", std::process::id()));
        let path_str = model_path.to_str().expect("should succeed");
        model.save(path_str).expect("should succeed");

        let mut loaded_model = ConvE::new(ConvEConfig::default());
        loaded_model.load(path_str).expect("should succeed");

        assert!(loaded_model.is_trained());
        assert_eq!(loaded_model.get_entities().len(), 3);
        assert_eq!(loaded_model.get_relations().len(), 2);

        let emb_after = loaded_model
            .get_entity_embedding("alice")
            .expect("should succeed");
        assert_eq!(emb_before.dimensions, emb_after.dimensions);
        for i in 0..emb_before.values.len() {
            assert!((emb_before.values[i] - emb_after.values[i]).abs() < 1e-6);
        }

        let score_after = loaded_model
            .score_triple("alice", "knows", "bob")
            .expect("should succeed");
        assert!((score_before - score_after).abs() < 1e-5);

        // Best-effort cleanup; a leftover file is harmless.
        std::fs::remove_file(model_path).ok();
    }

    /// Loading from a path that does not exist reports an error.
    #[test]
    fn test_conve_load_nonexistent() {
        let mut model = ConvE::new(ConvEConfig::default());
        let result = model.load("/nonexistent/path/model.bin");
        assert!(result.is_err());
    }
}