use anyhow::{anyhow, Result};
use rayon::prelude::*;
use scirs2_core::ndarray_ext::{Array1, Array2, Array3};
use scirs2_core::random::Random;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::path::Path;
use tracing::{debug, info};

#[cfg(test)]
use crate::NamedNode;
use crate::{EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector};
use uuid::Uuid;

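/// Configuration for the ConvE (convolutional 2D knowledge graph embedding) model.
///
/// `reshape_width` controls how 1D embeddings are reshaped into 2D feature maps
/// and is expected to divide `base.dimensions` evenly; the remaining fields tune
/// the convolution (`num_filters`, `kernel_size`), regularization
/// (`dropout_rate`, `regularization`), and the margin-based negative-sampling
/// loss (`margin`, `num_negatives`).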
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvEConfig {
    pub base: ModelConfig,
    pub reshape_width: usize,
    pub num_filters: usize,
    pub kernel_size: usize,
    pub dropout_rate: f32,
    pub regularization: f32,
    pub margin: f32,
    pub num_negatives: usize,
    pub use_batch_norm: bool,
}

impl Default for ConvEConfig {
    fn default() -> Self {
        Self {
            base: ModelConfig::default().with_dimensions(200),
            reshape_width: 20,
            num_filters: 32,
            kernel_size: 3,
            dropout_rate: 0.3,
            regularization: 0.0001,
            margin: 1.0,
            num_negatives: 10,
            use_batch_norm: true,
        }
    }
}

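/// Serde-friendly representation of [`ConvLayer`], storing filters and biases as plain vectors.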
#[derive(Debug, Serialize, Deserialize)]
struct ConvLayerSerializable {
    filters: Vec<Vec<Vec<f32>>>,
    biases: Vec<f32>,
}

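/// Minimal 2D convolutional layer: one square filter and one bias per output channel.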
struct ConvLayer {
    filters: Vec<Array2<f32>>,
    biases: Array1<f32>,
}

impl ConvLayer {
    fn new(num_filters: usize, kernel_size: usize, rng: &mut Random) -> Self {
        let scale = (2.0 / (kernel_size * kernel_size) as f32).sqrt();
        let mut filters = Vec::new();

        for _ in 0..num_filters {
            let filter =
                Array2::from_shape_fn((kernel_size, kernel_size), |_| rng.gen_range(-scale..scale));
            filters.push(filter);
        }

        let biases = Array1::zeros(num_filters);

        Self { filters, biases }
    }

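    /// Valid (no padding, stride 1) 2D convolution producing one feature map per
    /// filter; inputs smaller than the kernel yield an all-zero `(filters, 1, 1)` output.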
    fn forward(&self, input: &Array2<f32>) -> Array3<f32> {
        let kernel_size = self.filters[0].nrows();
        let input_height = input.nrows();
        let input_width = input.ncols();

        let out_height = input_height.saturating_sub(kernel_size - 1);
        let out_width = input_width.saturating_sub(kernel_size - 1);

        if out_height == 0 || out_width == 0 {
            return Array3::zeros((self.filters.len(), 1, 1));
        }

        let mut output = Array3::zeros((self.filters.len(), out_height, out_width));

        for (f_idx, filter) in self.filters.iter().enumerate() {
            for i in 0..out_height {
                for j in 0..out_width {
                    let mut sum = 0.0;

                    for ki in 0..kernel_size {
                        for kj in 0..kernel_size {
                            sum += input[[i + ki, j + kj]] * filter[[ki, kj]];
                        }
                    }

                    output[[f_idx, i, j]] = sum + self.biases[f_idx];
                }
            }
        }

        output
    }
}

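/// Serde-friendly representation of [`FCLayer`], storing weights and bias as plain vectors.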
#[derive(Debug, Serialize, Deserialize)]
struct FCLayerSerializable {
    weights: Vec<Vec<f32>>,
    bias: Vec<f32>,
}

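/// Fully connected (dense) layer with an `(input_size, output_size)` weight matrix.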
struct FCLayer {
    weights: Array2<f32>,
    bias: Array1<f32>,
}

impl FCLayer {
    fn new(input_size: usize, output_size: usize, rng: &mut Random) -> Self {
        let scale = (2.0 / input_size as f32).sqrt();
        let weights =
            Array2::from_shape_fn((input_size, output_size), |_| rng.gen_range(-scale..scale));
        let bias = Array1::zeros(output_size);

        Self { weights, bias }
    }

    fn forward(&self, input: &Array1<f32>) -> Array1<f32> {
        let mut output = self.bias.clone();
        for i in 0..output.len() {
            for j in 0..input.len() {
                output[i] += input[j] * self.weights[[j, i]];
            }
        }
        output
    }
}

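/// On-disk representation of a [`ConvE`] model, used by `save` and `load`.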
#[derive(Debug, Serialize, Deserialize)]
struct ConvESerializable {
    model_id: Uuid,
    config: ConvEConfig,
    entity_embeddings: HashMap<String, Vec<f32>>,
    relation_embeddings: HashMap<String, Vec<f32>>,
    conv_layer: ConvLayerSerializable,
    fc_layer: FCLayerSerializable,
    triples: Vec<Triple>,
    entity_to_id: HashMap<String, usize>,
    relation_to_id: HashMap<String, usize>,
    id_to_entity: HashMap<usize, String>,
    id_to_relation: HashMap<usize, String>,
    is_trained: bool,
}

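/// ConvE knowledge graph embedding model.
///
/// Head and relation embeddings are reshaped into 2D maps, concatenated side by
/// side, passed through a convolution, ReLU, and fully connected projection, and
/// the projected vector is scored against the tail embedding with a dot product.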
pub struct ConvE {
    model_id: Uuid,
    config: ConvEConfig,
    entity_embeddings: HashMap<String, Array1<f32>>,
    relation_embeddings: HashMap<String, Array1<f32>>,
    conv_layer: ConvLayer,
    fc_layer: FCLayer,
    triples: Vec<Triple>,
    entity_to_id: HashMap<String, usize>,
    relation_to_id: HashMap<String, usize>,
    id_to_entity: HashMap<usize, String>,
    id_to_relation: HashMap<usize, String>,
    is_trained: bool,
}

impl ConvE {
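    /// Creates a new, untrained ConvE model, sizing the convolutional and fully
    /// connected layers from the dimensions implied by `config`.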
    pub fn new(config: ConvEConfig) -> Self {
        let mut rng = Random::default();

        let reshape_height = config.base.dimensions / config.reshape_width;
        let conv_out_height = reshape_height.saturating_sub(config.kernel_size - 1);
        let conv_out_width = (config.reshape_width * 2).saturating_sub(config.kernel_size - 1);
        let fc_input_size = config.num_filters * conv_out_height * conv_out_width;

        let conv_layer = ConvLayer::new(config.num_filters, config.kernel_size, &mut rng);
        let fc_layer = FCLayer::new(fc_input_size, config.base.dimensions, &mut rng);

        info!(
            "Initialized ConvE model: dim={}, filters={}, kernel={}, fc_input={}",
            config.base.dimensions, config.num_filters, config.kernel_size, fc_input_size
        );

        Self {
            model_id: Uuid::new_v4(),
            config,
            entity_embeddings: HashMap::new(),
            relation_embeddings: HashMap::new(),
            conv_layer,
            fc_layer,
            triples: Vec::new(),
            entity_to_id: HashMap::new(),
            relation_to_id: HashMap::new(),
            id_to_entity: HashMap::new(),
            id_to_relation: HashMap::new(),
            is_trained: false,
        }
    }

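    /// Reshapes a 1D embedding into a `(dimensions / reshape_width, reshape_width)` matrix.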
    fn reshape_embedding(&self, embedding: &Array1<f32>) -> Array2<f32> {
        let height = self.config.base.dimensions / self.config.reshape_width;
        let width = self.config.reshape_width;

        Array2::from_shape_fn((height, width), |(i, j)| embedding[i * width + j])
    }

    fn relu(&self, x: f32) -> f32 {
        x.max(0.0)
    }

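    /// Inverted dropout: during training, zeroes each element with probability
    /// `dropout_rate` and rescales the survivors by `1 / keep_prob`.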
    fn dropout(&mut self, values: &mut Array1<f32>, training: bool) {
        if !training || self.config.dropout_rate == 0.0 {
            return;
        }

        let mut local_rng = Random::default();
        let keep_prob = 1.0 - self.config.dropout_rate;
        for val in values.iter_mut() {
            if local_rng.gen_range(0.0..1.0) > keep_prob {
                *val = 0.0;
            } else {
                *val /= keep_prob;
            }
        }
    }

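    /// ConvE forward pass: reshape head and relation to 2D, concatenate them side
    /// by side, convolve, apply ReLU, flatten, and project through the fully
    /// connected layer, applying dropout when `training` is true.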
    fn forward(
        &mut self,
        head: &Array1<f32>,
        relation: &Array1<f32>,
        training: bool,
    ) -> Array1<f32> {
        let head_2d = self.reshape_embedding(head);
        let rel_2d = self.reshape_embedding(relation);

        let height = head_2d.nrows();
        let width = head_2d.ncols() * 2;
        let mut concat = Array2::zeros((height, width));

        for i in 0..height {
            for j in 0..head_2d.ncols() {
                concat[[i, j]] = head_2d[[i, j]];
            }
            for j in 0..rel_2d.ncols() {
                concat[[i, head_2d.ncols() + j]] = rel_2d[[i, j]];
            }
        }

        let conv_out = self.conv_layer.forward(&concat);

        let conv_out_relu = conv_out.mapv(|x| self.relu(x));

        let flattened_size = conv_out_relu.len();
        let mut flattened = Array1::zeros(flattened_size);
        for (idx, &val) in conv_out_relu.iter().enumerate() {
            flattened[idx] = val;
        }

        self.dropout(&mut flattened, training);

        let mut output = self.fc_layer.forward(&flattened);

        self.dropout(&mut output, training);

        output
    }

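    /// Scores a triple as the dot product between the ConvE projection of
    /// (head, relation) and the tail embedding.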
    fn score_triple_internal(
        &mut self,
        head: &Array1<f32>,
        relation: &Array1<f32>,
        tail: &Array1<f32>,
    ) -> f32 {
        let projected = self.forward(head, relation, false);
        projected.dot(tail)
    }

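    /// Registers an entity and assigns it a Xavier-style uniform random embedding if unseen.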
    fn init_entity(&mut self, entity: &str) {
        if !self.entity_embeddings.contains_key(entity) {
            let id = self.entity_embeddings.len();
            self.entity_to_id.insert(entity.to_string(), id);
            self.id_to_entity.insert(id, entity.to_string());

            let mut local_rng = Random::default();
            let scale = (6.0 / self.config.base.dimensions as f32).sqrt();
            let embedding = Array1::from_vec(
                (0..self.config.base.dimensions)
                    .map(|_| local_rng.gen_range(-scale..scale))
                    .collect(),
            );
            self.entity_embeddings.insert(entity.to_string(), embedding);
        }
    }

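    /// Registers a relation and assigns it a Xavier-style uniform random embedding if unseen.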
    fn init_relation(&mut self, relation: &str) {
        if !self.relation_embeddings.contains_key(relation) {
            let id = self.relation_embeddings.len();
            self.relation_to_id.insert(relation.to_string(), id);
            self.id_to_relation.insert(id, relation.to_string());

            let mut local_rng = Random::default();
            let scale = (6.0 / self.config.base.dimensions as f32).sqrt();
            let embedding = Array1::from_vec(
                (0..self.config.base.dimensions)
                    .map(|_| local_rng.gen_range(-scale..scale))
                    .collect(),
            );
            self.relation_embeddings
                .insert(relation.to_string(), embedding);
        }
    }

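    /// Runs one training epoch: shuffles the triples (Fisher-Yates), samples
    /// `num_negatives` corrupted tails per positive triple, accumulates a
    /// margin-based ranking loss, and applies weight decay whenever the margin is
    /// violated. Returns the mean loss over all positive/negative pairs.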
    fn train_step(&mut self) -> f32 {
        let mut total_loss = 0.0;
        let mut local_rng = Random::default();

        let mut indices: Vec<usize> = (0..self.triples.len()).collect();
        for i in (1..indices.len()).rev() {
            let j = local_rng.random_range(0..i + 1);
            indices.swap(i, j);
        }

        for &idx in &indices {
            let triple = self.triples[idx].clone();

            let subject_str = &triple.subject.iri;
            let predicate_str = &triple.predicate.iri;
            let object_str = &triple.object.iri;

            let head_emb = self.entity_embeddings[subject_str].clone();
            let rel_emb = self.relation_embeddings[predicate_str].clone();
            let tail_emb = self.entity_embeddings[object_str].clone();

            let pos_score = self.score_triple_internal(&head_emb, &rel_emb, &tail_emb);

            let entity_list: Vec<String> = self.entity_embeddings.keys().cloned().collect();
            for _ in 0..self.config.num_negatives {
                let neg_tail_id = entity_list[local_rng.random_range(0..entity_list.len())].clone();
                let neg_tail_emb = self.entity_embeddings[&neg_tail_id].clone();

                let neg_score = self.score_triple_internal(&head_emb, &rel_emb, &neg_tail_emb);

                let loss = (self.config.margin + neg_score - pos_score).max(0.0);
                total_loss += loss;

                if loss > 0.0 {
                    let lr = self.config.base.learning_rate as f32;
                    for emb in self.entity_embeddings.values_mut() {
                        *emb = &*emb * (1.0 - self.config.regularization * lr);
                    }
                    for emb in self.relation_embeddings.values_mut() {
                        *emb = &*emb * (1.0 - self.config.regularization * lr);
                    }
                }
            }
        }

        total_loss / (self.triples.len() as f32 * self.config.num_negatives as f32)
    }
}

#[async_trait::async_trait]
impl EmbeddingModel for ConvE {
    fn config(&self) -> &ModelConfig {
        &self.config.base
    }

    fn model_id(&self) -> &Uuid {
        &self.model_id
    }

    fn model_type(&self) -> &'static str {
        "ConvE"
    }

    fn add_triple(&mut self, triple: Triple) -> Result<()> {
        self.init_entity(&triple.subject.iri);
        self.init_entity(&triple.object.iri);
        self.init_relation(&triple.predicate.iri);
        self.triples.push(triple);
        Ok(())
    }

    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
        let num_epochs = epochs.unwrap_or(self.config.base.max_epochs);

        if self.triples.is_empty() {
            return Err(anyhow!("No training data available"));
        }

        info!(
            "Training ConvE model for {} epochs on {} triples",
            num_epochs,
            self.triples.len()
        );

        let start_time = std::time::Instant::now();
        let mut loss_history = Vec::new();

        for epoch in 0..num_epochs {
            let loss = self.train_step();
            loss_history.push(loss as f64);

            if epoch % 10 == 0 {
                debug!("Epoch {}/{}: loss = {:.6}", epoch + 1, num_epochs, loss);
            }

            if loss < 0.001 {
                info!("Converged at epoch {}", epoch);
                break;
            }
        }

        let training_time = start_time.elapsed().as_secs_f64();
        self.is_trained = true;

        Ok(TrainingStats {
            epochs_completed: num_epochs,
            final_loss: *loss_history.last().unwrap_or(&0.0),
            training_time_seconds: training_time,
            convergence_achieved: loss_history.last().unwrap_or(&1.0) < &0.001,
            loss_history,
        })
    }

    fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
        self.entity_embeddings
            .get(entity)
            .map(Vector::from_array1)
            .ok_or_else(|| anyhow!("Unknown entity: {}", entity))
    }

    fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
        self.relation_embeddings
            .get(relation)
            .map(Vector::from_array1)
            .ok_or_else(|| anyhow!("Unknown relation: {}", relation))
    }

    fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
        let head_emb = self
            .entity_embeddings
            .get(subject)
            .ok_or_else(|| anyhow!("Unknown subject: {}", subject))?;
        let rel_emb = self
            .relation_embeddings
            .get(predicate)
            .ok_or_else(|| anyhow!("Unknown predicate: {}", predicate))?;
        let tail_emb = self
            .entity_embeddings
            .get(object)
            .ok_or_else(|| anyhow!("Unknown object: {}", object))?;

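        // Read-only scoring path: uses a simple (head + relation) . tail dot
        // product rather than the convolutional forward pass, which needs `&mut self`.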
        let score = (head_emb + rel_emb).dot(tail_emb);
        Ok(score as f64)
    }

    fn predict_objects(
        &self,
        subject: &str,
        predicate: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        let head_emb = self
            .entity_embeddings
            .get(subject)
            .ok_or_else(|| anyhow!("Unknown subject: {}", subject))?;
        let rel_emb = self
            .relation_embeddings
            .get(predicate)
            .ok_or_else(|| anyhow!("Unknown predicate: {}", predicate))?;

        let combined = head_emb + rel_emb;
        let mut scored_objects: Vec<(String, f64)> = self
            .entity_embeddings
            .par_iter()
            .map(|(entity, tail_emb)| {
                let score = combined.dot(tail_emb);
                (entity.clone(), score as f64)
            })
            .collect();

        scored_objects.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scored_objects.truncate(k);
        Ok(scored_objects)
    }

    fn predict_subjects(
        &self,
        predicate: &str,
        object: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        let rel_emb = self
            .relation_embeddings
            .get(predicate)
            .ok_or_else(|| anyhow!("Unknown predicate: {}", predicate))?;
        let tail_emb = self
            .entity_embeddings
            .get(object)
            .ok_or_else(|| anyhow!("Unknown object: {}", object))?;

        let mut scored_subjects: Vec<(String, f64)> = self
            .entity_embeddings
            .par_iter()
            .map(|(entity, head_emb)| {
                let score = (head_emb + rel_emb).dot(tail_emb);
                (entity.clone(), score as f64)
            })
            .collect();

        scored_subjects.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scored_subjects.truncate(k);
        Ok(scored_subjects)
    }

    fn predict_relations(
        &self,
        subject: &str,
        object: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        let head_emb = self
            .entity_embeddings
            .get(subject)
            .ok_or_else(|| anyhow!("Unknown subject: {}", subject))?;
        let tail_emb = self
            .entity_embeddings
            .get(object)
            .ok_or_else(|| anyhow!("Unknown object: {}", object))?;

        let mut scored_relations: Vec<(String, f64)> = self
            .relation_embeddings
            .par_iter()
            .map(|(relation, rel_emb)| {
                let score = (head_emb + rel_emb).dot(tail_emb);
                (relation.clone(), score as f64)
            })
            .collect();

        scored_relations.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scored_relations.truncate(k);
        Ok(scored_relations)
    }

    fn get_entities(&self) -> Vec<String> {
        self.entity_embeddings.keys().cloned().collect()
    }

    fn get_relations(&self) -> Vec<String> {
        self.relation_embeddings.keys().cloned().collect()
    }

    fn get_stats(&self) -> ModelStats {
        ModelStats {
            num_entities: self.entity_embeddings.len(),
            num_relations: self.relation_embeddings.len(),
            num_triples: self.triples.len(),
            dimensions: self.config.base.dimensions,
            is_trained: self.is_trained,
            model_type: "ConvE".to_string(),
            creation_time: chrono::Utc::now(),
            last_training_time: if self.is_trained {
                Some(chrono::Utc::now())
            } else {
                None
            },
        }
    }

    fn save(&self, path: &str) -> Result<()> {
        info!("Saving ConvE model to {}", path);

        let entity_embeddings_vec: HashMap<String, Vec<f32>> = self
            .entity_embeddings
            .iter()
            .map(|(k, v)| (k.clone(), v.to_vec()))
            .collect();

        let relation_embeddings_vec: HashMap<String, Vec<f32>> = self
            .relation_embeddings
            .iter()
            .map(|(k, v)| (k.clone(), v.to_vec()))
            .collect();

        let conv_filters: Vec<Vec<Vec<f32>>> = self
            .conv_layer
            .filters
            .iter()
            .map(|filter| {
                let mut rows = Vec::new();
                for i in 0..filter.nrows() {
                    let mut row = Vec::new();
                    for j in 0..filter.ncols() {
                        row.push(filter[[i, j]]);
                    }
                    rows.push(row);
                }
                rows
            })
            .collect();

        let conv_layer_ser = ConvLayerSerializable {
            filters: conv_filters,
            biases: self.conv_layer.biases.to_vec(),
        };

        let mut fc_weights = Vec::new();
        for i in 0..self.fc_layer.weights.nrows() {
            let mut row = Vec::new();
            for j in 0..self.fc_layer.weights.ncols() {
                row.push(self.fc_layer.weights[[i, j]]);
            }
            fc_weights.push(row);
        }

        let fc_layer_ser = FCLayerSerializable {
            weights: fc_weights,
            bias: self.fc_layer.bias.to_vec(),
        };

        let serializable = ConvESerializable {
            model_id: self.model_id,
            config: self.config.clone(),
            entity_embeddings: entity_embeddings_vec,
            relation_embeddings: relation_embeddings_vec,
            conv_layer: conv_layer_ser,
            fc_layer: fc_layer_ser,
            triples: self.triples.clone(),
            entity_to_id: self.entity_to_id.clone(),
            relation_to_id: self.relation_to_id.clone(),
            id_to_entity: self.id_to_entity.clone(),
            id_to_relation: self.id_to_relation.clone(),
            is_trained: self.is_trained,
        };

        let file = File::create(path)?;
        let writer = BufWriter::new(file);
        oxicode::serde::encode_into_std_write(&serializable, writer, oxicode::config::standard())
            .map_err(|e| anyhow!("Failed to serialize model: {}", e))?;

        info!("Model saved successfully");
        Ok(())
    }

    fn load(&mut self, path: &str) -> Result<()> {
        info!("Loading ConvE model from {}", path);

        if !Path::new(path).exists() {
            return Err(anyhow!("Model file not found: {}", path));
        }

        let file = File::open(path)?;
        let reader = BufReader::new(file);
        let (serializable, _): (ConvESerializable, _) =
            oxicode::serde::decode_from_std_read(reader, oxicode::config::standard())
                .map_err(|e| anyhow!("Failed to deserialize model: {}", e))?;

        let entity_embeddings: HashMap<String, Array1<f32>> = serializable
            .entity_embeddings
            .into_iter()
            .map(|(k, v)| (k, Array1::from_vec(v)))
            .collect();

        let relation_embeddings: HashMap<String, Array1<f32>> = serializable
            .relation_embeddings
            .into_iter()
            .map(|(k, v)| (k, Array1::from_vec(v)))
            .collect();

        let conv_filters: Vec<Array2<f32>> = serializable
            .conv_layer
            .filters
            .into_iter()
            .map(|filter_vec| {
                let kernel_size = filter_vec.len();
                Array2::from_shape_fn((kernel_size, kernel_size), |(i, j)| filter_vec[i][j])
            })
            .collect();

        let conv_layer = ConvLayer {
            filters: conv_filters,
            biases: Array1::from_vec(serializable.conv_layer.biases),
        };

        let fc_weights_vec = serializable.fc_layer.weights;
        let input_size = fc_weights_vec.len();
        let output_size = if input_size > 0 {
            fc_weights_vec[0].len()
        } else {
            0
        };

        let fc_weights =
            Array2::from_shape_fn((input_size, output_size), |(i, j)| fc_weights_vec[i][j]);

        let fc_layer = FCLayer {
            weights: fc_weights,
            bias: Array1::from_vec(serializable.fc_layer.bias),
        };

        self.model_id = serializable.model_id;
        self.config = serializable.config;
        self.entity_embeddings = entity_embeddings;
        self.relation_embeddings = relation_embeddings;
        self.conv_layer = conv_layer;
        self.fc_layer = fc_layer;
        self.triples = serializable.triples;
        self.entity_to_id = serializable.entity_to_id;
        self.relation_to_id = serializable.relation_to_id;
        self.id_to_entity = serializable.id_to_entity;
        self.id_to_relation = serializable.id_to_relation;
        self.is_trained = serializable.is_trained;

        info!("Model loaded successfully");
        Ok(())
    }

    fn clear(&mut self) {
        self.entity_embeddings.clear();
        self.relation_embeddings.clear();
        self.triples.clear();
        self.entity_to_id.clear();
        self.relation_to_id.clear();
        self.id_to_entity.clear();
        self.id_to_relation.clear();
        self.is_trained = false;
    }

    fn is_trained(&self) -> bool {
        self.is_trained
    }

    async fn encode(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
        Err(anyhow!("Text encoding not implemented for ConvE"))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_conve_creation() {
        let config = ConvEConfig::default();
        let model = ConvE::new(config);

        assert_eq!(model.entity_embeddings.len(), 0);
        assert_eq!(model.relation_embeddings.len(), 0);
    }

    #[tokio::test]
    async fn test_conve_training() {
        let config = ConvEConfig {
            base: ModelConfig {
                dimensions: 50,
                learning_rate: 0.001,
                max_epochs: 5,
                ..Default::default()
            },
            reshape_width: 10,
            num_filters: 8,
            ..Default::default()
        };

        let mut model = ConvE::new(config);

        model
            .add_triple(Triple::new(
                NamedNode::new("alice").unwrap(),
                NamedNode::new("knows").unwrap(),
                NamedNode::new("bob").unwrap(),
            ))
            .unwrap();

        model
            .add_triple(Triple::new(
                NamedNode::new("bob").unwrap(),
                NamedNode::new("likes").unwrap(),
                NamedNode::new("charlie").unwrap(),
            ))
            .unwrap();

        let stats = model.train(Some(5)).await.unwrap();

        assert_eq!(stats.epochs_completed, 5);
        assert!(stats.final_loss >= 0.0);
        assert_eq!(model.entity_embeddings.len(), 3);
        assert_eq!(model.relation_embeddings.len(), 2);
    }

    #[tokio::test]
    async fn test_conve_save_load() {
        use std::env::temp_dir;

        let config = ConvEConfig {
            base: ModelConfig {
                dimensions: 50,
                learning_rate: 0.001,
                max_epochs: 15,
                ..Default::default()
            },
            reshape_width: 10,
            num_filters: 8,
            kernel_size: 2,
            ..Default::default()
        };

        let mut model = ConvE::new(config);

        model
            .add_triple(Triple::new(
                NamedNode::new("alice").unwrap(),
                NamedNode::new("knows").unwrap(),
                NamedNode::new("bob").unwrap(),
            ))
            .unwrap();

        model
            .add_triple(Triple::new(
                NamedNode::new("bob").unwrap(),
                NamedNode::new("likes").unwrap(),
                NamedNode::new("charlie").unwrap(),
            ))
            .unwrap();

        model.train(Some(15)).await.unwrap();

        let emb_before = model.get_entity_embedding("alice").unwrap();
        let score_before = model.score_triple("alice", "knows", "bob").unwrap();

        let model_path = temp_dir().join("test_conve_model.bin");
        let path_str = model_path.to_str().unwrap();
        model.save(path_str).unwrap();

        let mut loaded_model = ConvE::new(ConvEConfig::default());
        loaded_model.load(path_str).unwrap();

        assert!(loaded_model.is_trained());
        assert_eq!(loaded_model.get_entities().len(), 3);
        assert_eq!(loaded_model.get_relations().len(), 2);

        let emb_after = loaded_model.get_entity_embedding("alice").unwrap();
        assert_eq!(emb_before.dimensions, emb_after.dimensions);
        for i in 0..emb_before.values.len() {
            assert!((emb_before.values[i] - emb_after.values[i]).abs() < 1e-6);
        }

        let score_after = loaded_model.score_triple("alice", "knows", "bob").unwrap();
        assert!((score_before - score_after).abs() < 1e-5);

        std::fs::remove_file(model_path).ok();
    }

    #[test]
    fn test_conve_load_nonexistent() {
        let mut model = ConvE::new(ConvEConfig::default());
        let result = model.load("/nonexistent/path/model.bin");
        assert!(result.is_err());
    }
}